[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n.idea/\n*.pyc\n*.iml\n\n# Specific stuff\nwandb/\ncache_dir/\nraw_datasets/\nraw_data/\n\nfinal_outputs/\noutputs/\nvalidation_logs/\napex/\n*.ckpt\n.vscode/\n"
  },
  {
    "path": "README.md",
    "content": "# [CVPR'24 Spotlight] State Space Models for Event Cameras\n<p align=\"center\">\n <a href=\"https://www.youtube.com/watch?v=WRZZJn6Me9M\">\n  <img src=\"https://github.com/uzh-rpg/ssms_event_cameras/blob/master/scripts/zubic_cvpr2024_youtube.png\" alt=\"youtube_video\"/>\n </a>\n</p>\n\nThis is the official PyTorch implementation of the CVPR 2024 paper [State Space Models for Event Cameras](https://arxiv.org/abs/2402.15584).\n\n### 🖼️ Check Out Our Poster! 🖼️ [here](https://download.ifi.uzh.ch/rpg/CVPR24_Zubic/Zubic_CVPR24_poster.pdf)\n\n## :white_check_mark: Updates\n* **` June. 14th, 2024`**: Everything is updated! Poster released! Check it above.\n* **` June. 6th, 2024`**: Video released! To watch our video, simply click on the YouTube play button above.\n* **` June. 1st, 2024`**: Our CVPR conference paper has also been accepted as a Spotlight presentation at \"The 3rd Workshop on Transformers for Vision (T4V).\"\n* **` April. 19th, 2024`**: The code, along with the best checkpoints, is released! The poster and video will be released shortly before CVPR 2024.\n\n## Citation\nIf you find this work and/or code useful, please cite our paper:\n\n```bibtex\n@InProceedings{Zubic_2024_CVPR,\n    author    = {Zubic, Nikola and Gehrig, Mathias and Scaramuzza, Davide},\n    title     = {State Space Models for Event Cameras},\n    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n    month     = {June},\n    year      = {2024},\n    pages     = {5819-5828}\n}\n```\n\n## SSM-ViT\n- The S5 model used in our SSM-ViT pipeline can be found [here](https://github.com/uzh-rpg/ssms_event_cameras/tree/master/RVT/models/layers/s5).\n- In particular, S5 is used instead of an RNN in a 4-stage hierarchical ViT backbone, and its forward function is exposed [here](https://github.com/uzh-rpg/ssms_event_cameras/blob/master/RVT/models/detection/recurrent_backbone/maxvit_rnn.py#L245). 
A nice property of this approach is that it does not need a 'for' loop over the sequence dimension; instead, it employs a parallel scan algorithm. This model assumes that a hidden state is carried over.\n- A standalone model, which can be used for any sequence modeling problem, does not use this hidden-state-carrying formulation by default. Its implementation is the same as the original JAX implementation, and it can be downloaded in zip format from [ssms_event_cameras/RVT/models/s5.zip](https://github.com/uzh-rpg/ssms_event_cameras/raw/master/RVT/models/s5.zip).\n\n## Installation\n### Conda\nWe highly recommend using [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) to reduce the installation time.\n```Bash\nconda create -y -n events_signals python=3.11\nconda activate events_signals\nconda install pytorch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 pytorch-cuda=11.8 -c pytorch -c nvidia\npip install lightning wandb pandas plotly opencv-python tabulate pycocotools bbox-visualizer StrEnum hydra-core einops torchdata tqdm numba h5py hdf5plugin lovely-tensors tensorboardX pykeops scikit-learn\n```\n\n## Required Data\nTo evaluate or train the S5-ViT model, you will need to download the required preprocessed datasets:\n\n<table><tbody>\n<th valign=\"bottom\"></th>\n<th valign=\"bottom\">1 Mpx</th>\n<th valign=\"bottom\">Gen1</th>\n<tr><td align=\"left\">pre-processed dataset</td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/datasets/preprocessed/gen4.tar\">download</a></td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/datasets/preprocessed/gen1.tar\">download</a></td>\n</tr>\n<tr><td align=\"left\">crc32</td>\n<td align=\"center\"><tt>c5ec7c38</tt></td>\n<td align=\"center\"><tt>5acab6f3</tt></td>\n</tr>\n</tbody></table>\n\nYou may also pre-process the dataset yourself by following the 
[instructions](https://github.com/NikolaZubic/ssms_event_cameras/blob/master/RVT/scripts/genx/README.md).\n\n## Pre-trained Checkpoints\n### 1 Mpx\n<table><tbody>\n<th valign=\"bottom\"></th>\n<th valign=\"bottom\">S5-ViT-Base</th>\n<th valign=\"bottom\">S5-ViT-Small</th>\n<tr><td align=\"left\">pre-trained checkpoint</td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/CVPR24_Zubic/gen4_base.ckpt\">download</a></td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/CVPR24_Zubic/gen4_small.ckpt\">download</a></td>\n</tr>\n</tbody></table>\n\n### Gen1\n<table><tbody>\n<th valign=\"bottom\"></th>\n<th valign=\"bottom\">S5-ViT-Base</th>\n<th valign=\"bottom\">S5-ViT-Small</th>\n<tr><td align=\"left\">pre-trained checkpoint</td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/CVPR24_Zubic/gen1_base.ckpt\">download</a></td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/CVPR24_Zubic/gen1_small.ckpt\">download</a></td>\n</tr>\n</tbody></table>\n\n## Evaluation\n- Evaluation scripts with the concrete parameters we used to train our models can be found [here](https://github.com/uzh-rpg/ssms_event_cameras/tree/master/scripts).\n- Set `DATA_DIR` as the path to either the 1 Mpx or Gen1 dataset directory.\n- Set `CKPT_PATH` to the path of the *correct* checkpoint matching the choice of the model and dataset.\n- Set\n  - `MDL_CFG=base` or\n  - `MDL_CFG=small`\n      \n  to load either the base or small model configuration.\n- Set `GPU_ID` to the PCI BUS ID of the GPU that you want to use, e.g. 
`GPU_ID=0`.\n  Only a single GPU is supported for evaluation.\n### 1 Mpx\n```Bash\npython RVT/validation.py dataset=gen4 dataset.path=${DATA_DIR} checkpoint=${CKPT_PATH} \\\nuse_test_set=1 hardware.gpus=${GPU_ID} +experiment/gen4=\"${MDL_CFG}.yaml\" \\\nbatch_size.eval=12 model.postprocess.confidence_threshold=0.001\n```\n### Gen1\n```Bash\npython RVT/validation.py dataset=gen1 dataset.path=${DATA_DIR} checkpoint=${CKPT_PATH} \\\nuse_test_set=1 hardware.gpus=${GPU_ID} +experiment/gen1=\"${MDL_CFG}.yaml\" \\\nbatch_size.eval=8 model.postprocess.confidence_threshold=0.001\n```\nWe use the same batch size for evaluation and training: 12 for the 1 Mpx dataset and 8 for the Gen1 dataset.\n\n## Evaluation results\nEvaluation should give the same results as shown below:\n- 47.7 and 47.8 mAP on Gen1 and 1 Mpx datasets for the base model, and\n- 46.6 and 46.5 mAP on Gen1 and 1 Mpx datasets for the small model.\n<p align=\"center\">\n  <img src=\"https://github.com/uzh-rpg/ssms_event_cameras/blob/master/scripts/checkpoints.png\">\n</p>\n\n## Training\n- Set `DATA_DIR` as the path to either the 1 Mpx or Gen1 dataset directory.\n- Set\n    - `MDL_CFG=base` or\n    - `MDL_CFG=small`\n  \n  to load either the base or the small configuration.\n- Set `GPU_IDS` to the PCI BUS IDs of the GPUs that you want to use, e.g. `GPU_IDS=[0,1]` to use GPUs 0 and 1.\n  **Using a list of IDs will enable single-node multi-GPU training.**\n  Pay attention to the batch size, which is defined per GPU.\n- Set `BATCH_SIZE_PER_GPU` such that the effective batch size matches the parameters below.\n  The **effective batch size** is (batch size per GPU)*(number of GPUs).\n- If you would like to change the effective batch size, we found the following learning rate scaling to work well for \nall models on both datasets:\n  \n  `lr = 2e-4 * sqrt(effective_batch_size/8)`.\n- The training code uses [W&B](https://wandb.ai/) for logging during training.\nHence, we assume that you have a W&B account. 
\n  - The training script below will create a new project called `ssms_event_cameras`. Adapt the project name and group name if necessary.\n \n### 1 Mpx\n- The effective batch size for the 1 Mpx training is 12.\n- To train the model on the 1 Mpx dataset, we need 2x A100 80 GB GPUs and use 12 workers per GPU for training and 4 workers per GPU for evaluation:\n```Bash\nGPU_IDS=[0,1]\nBATCH_SIZE_PER_GPU=6\nTRAIN_WORKERS_PER_GPU=12\nEVAL_WORKERS_PER_GPU=4\npython RVT/train.py model=rnndet dataset=gen4 dataset.path=${DATA_DIR} wandb.project_name=ssms_event_cameras \\\nwandb.group_name=1mpx +experiment/gen4=\"${MDL_CFG}.yaml\" hardware.gpus=${GPU_IDS} \\\nbatch_size.train=${BATCH_SIZE_PER_GPU} batch_size.eval=${BATCH_SIZE_PER_GPU} \\\nhardware.num_workers.train=${TRAIN_WORKERS_PER_GPU} hardware.num_workers.eval=${EVAL_WORKERS_PER_GPU}\n```\nFor example, to train on 4 GPUs, simply adapt `GPU_IDS` and `BATCH_SIZE_PER_GPU` accordingly:\n```Bash\nGPU_IDS=[0,1,2,3]\nBATCH_SIZE_PER_GPU=3\n```\n### Gen1\n- The effective batch size for the Gen1 training is 8.\n- To train the model on the Gen1 dataset, we need 1x A100 80 GB GPU and use 24 workers for training and 8 workers for evaluation:\n```Bash\nGPU_IDS=0\nBATCH_SIZE_PER_GPU=8\nTRAIN_WORKERS_PER_GPU=24\nEVAL_WORKERS_PER_GPU=8\npython RVT/train.py model=rnndet dataset=gen1 dataset.path=${DATA_DIR} wandb.project_name=ssms_event_cameras \\\nwandb.group_name=gen1 +experiment/gen1=\"${MDL_CFG}.yaml\" hardware.gpus=${GPU_IDS} \\\nbatch_size.train=${BATCH_SIZE_PER_GPU} batch_size.eval=${BATCH_SIZE_PER_GPU} \\\nhardware.num_workers.train=${TRAIN_WORKERS_PER_GPU} hardware.num_workers.eval=${EVAL_WORKERS_PER_GPU}\n```\n\n## Code Acknowledgments\nThis project has used code from the following projects:\n- [RVT](https://github.com/uzh-rpg/RVT) - Recurrent Vision Transformers for Object Detection with Event Cameras in PyTorch\n- [S4](https://github.com/state-spaces/s4) - Structured State Spaces for Sequence 
Modeling, in particular S4 and S4D models in PyTorch\n- [S5](https://github.com/lindermanlab/S5) - Simplified State Space Layers for Sequence Modeling in JAX\n- [S5 PyTorch](https://github.com/i404788/s5-pytorch) - S5 model in PyTorch\n"
  },
  {
    "path": "RVT/.gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": "RVT/LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 Mathias Gehrig\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "RVT/README.md",
    "content": "# RVT: Recurrent Vision Transformers for Object Detection with Event Cameras\n<p align=\"center\">\n  <img src=\"https://rpg.ifi.uzh.ch/img/papers/arxiv22_detection_mgehrig/combo.png\" width=\"750\">\n</p>\n\nThis is the official PyTorch implementation of the CVPR 2023 paper [Recurrent Vision Transformers for Object Detection with Event Cameras](https://arxiv.org/abs/2212.05598).\n\nWatch the [**video**](https://youtu.be/xZ-pNwHxHgY) for a quick overview.\n\n```bibtex\n@InProceedings{Gehrig_2023_CVPR,\n  author  = {Mathias Gehrig and Davide Scaramuzza},\n  title   = {Recurrent Vision Transformers for Object Detection with Event Cameras},\n  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n  year    = {2023},\n}\n```\n\n## Conda Installation\nWe highly recommend using [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) to reduce the installation time.\n```Bash\nconda create -y -n rvt python=3.9 pip\nconda activate rvt\nconda config --set channel_priority flexible\n\nCUDA_VERSION=11.8\n\nconda install -y h5py=3.8.0 blosc-hdf5-plugin=1.0.0 \\\nhydra-core=1.3.2 einops=0.6.0 torchdata=0.6.0 tqdm numba \\\npytorch=2.0.0 torchvision=0.15.0 pytorch-cuda=$CUDA_VERSION \\\n-c pytorch -c nvidia -c conda-forge\n\npython -m pip install pytorch-lightning==1.8.6 wandb==0.14.0 \\\npandas==1.5.3 plotly==5.13.1 opencv-python==4.6.0.66 tabulate==0.9.0 \\\npycocotools==2.0.6 bbox-visualizer==0.1.0 StrEnum==0.4.10\npython -m pip install 'git+https://github.com/facebookresearch/detectron2.git'\n```\nDetectron2 is not strictly required but speeds up the evaluation.\n\n## Required Data\nTo evaluate or train RVT, you will need to download the required preprocessed datasets:\n\n<table><tbody>\n<th valign=\"bottom\"></th>\n<th valign=\"bottom\">1 Mpx</th>\n<th valign=\"bottom\">Gen1</th>\n<tr><td align=\"left\">pre-processed dataset</td>\n<td align=\"center\"><a 
href=\"https://download.ifi.uzh.ch/rpg/RVT/datasets/preprocessed/gen4.tar\">download</a></td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/datasets/preprocessed/gen1.tar\">download</a></td>\n</tr>\n<tr><td align=\"left\">crc32</td>\n<td align=\"center\"><tt>c5ec7c38</tt></td>\n<td align=\"center\"><tt>5acab6f3</tt></td>\n</tr>\n</tbody></table>\n\nYou may also pre-process the dataset yourself by following the [instructions](scripts/genx/README.md).\n\n## Pre-trained Checkpoints\n### 1 Mpx\n<table><tbody>\n<th valign=\"bottom\"></th>\n<th valign=\"bottom\">RVT-Base</th>\n<th valign=\"bottom\">RVT-Small</th>\n<th valign=\"bottom\">RVT-Tiny</th>\n<tr><td align=\"left\">pre-trained checkpoint</td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/checkpoints/1mpx/rvt-b.ckpt\">download</a></td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/checkpoints/1mpx/rvt-s.ckpt\">download</a></td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/checkpoints/1mpx/rvt-t.ckpt\">download</a></td>\n</tr>\n<tr><td align=\"left\">md5</td>\n<td align=\"center\"><tt>72923a</tt></td>\n<td align=\"center\"><tt>a94207</tt></td>\n<td align=\"center\"><tt>5a3c78</tt></td>\n</tr>\n</tbody></table>\n\n### Gen1\n<table><tbody>\n<th valign=\"bottom\"></th>\n<th valign=\"bottom\">RVT-Base</th>\n<th valign=\"bottom\">RVT-Small</th>\n<th valign=\"bottom\">RVT-Tiny</th>\n<tr><td align=\"left\">pre-trained checkpoint</td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/checkpoints/gen1/rvt-b.ckpt\">download</a></td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/checkpoints/gen1/rvt-s.ckpt\">download</a></td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/checkpoints/gen1/rvt-t.ckpt\">download</a></td>\n</tr>\n<tr><td align=\"left\">md5</td>\n<td align=\"center\"><tt>839317</tt></td>\n<td align=\"center\"><tt>840f2b</tt></td>\n<td 
align=\"center\"><tt>a770b9</tt></td>\n</tr>\n</tbody></table>\n\n## Evaluation\n- Set `DATA_DIR` as the path to either the 1 Mpx or Gen1 dataset directory.\n- Set `CKPT_PATH` to the path of the *correct* checkpoint matching the choice of the model and dataset.\n- Set\n  - `MDL_CFG=base`, or\n  - `MDL_CFG=small`, or\n  - `MDL_CFG=tiny`\n  \n  to load either the base, small, or tiny model configuration.\n- Set\n  - `USE_TEST=1` to evaluate on the test set, or\n  - `USE_TEST=0` to evaluate on the validation set.\n- Set `GPU_ID` to the PCI BUS ID of the GPU that you want to use, e.g. `GPU_ID=0`.\n  Only a single GPU is supported for evaluation.\n### 1 Mpx\n```Bash\npython validation.py dataset=gen4 dataset.path=${DATA_DIR} checkpoint=${CKPT_PATH} \\\nuse_test_set=${USE_TEST} hardware.gpus=${GPU_ID} +experiment/gen4=\"${MDL_CFG}.yaml\" \\\nbatch_size.eval=8 model.postprocess.confidence_threshold=0.001\n```\n### Gen1\n```Bash\npython validation.py dataset=gen1 dataset.path=${DATA_DIR} checkpoint=${CKPT_PATH} \\\nuse_test_set=${USE_TEST} hardware.gpus=${GPU_ID} +experiment/gen1=\"${MDL_CFG}.yaml\" \\\nbatch_size.eval=8 model.postprocess.confidence_threshold=0.001\n```\n\n## Training\n- Set `DATA_DIR` as the path to either the 1 Mpx or Gen1 dataset directory.\n- Set\n    - `MDL_CFG=base`, or\n    - `MDL_CFG=small`, or\n    - `MDL_CFG=tiny`\n\n  to load either the base, small, or tiny model configuration.\n- Set `GPU_IDS` to the PCI BUS IDs of the GPUs that you want to use, e.g. 
`GPU_IDS=[0,1]` to use GPUs 0 and 1.\n  **Using a list of IDs will enable single-node multi-GPU training.**\n  Pay attention to the batch size, which is defined per GPU:\n- Set `BATCH_SIZE_PER_GPU` such that the effective batch size matches the parameters below.\n  The **effective batch size** is (batch size per GPU)*(number of GPUs).\n- If you would like to change the effective batch size, we found the following learning rate scaling to work well for \nall models on both datasets:\n  \n  `lr = 2e-4 * sqrt(effective_batch_size/8)`.\n- The training code uses [W&B](https://wandb.ai/) for logging during training.\nHence, we assume that you have a W&B account. \n  - The training script below will create a new project called `RVT`. Adapt the project name and group name if necessary.\n \n### 1 Mpx\n- The effective batch size for the 1 Mpx training is 24.\n- To train on 2 GPUs using 6 workers per GPU for training and 2 workers per GPU for evaluation:\n```Bash\nGPU_IDS=[0,1]\nBATCH_SIZE_PER_GPU=12\nTRAIN_WORKERS_PER_GPU=6\nEVAL_WORKERS_PER_GPU=2\npython train.py model=rnndet dataset=gen4 dataset.path=${DATA_DIR} wandb.project_name=RVT \\\nwandb.group_name=1mpx +experiment/gen4=\"${MDL_CFG}.yaml\" hardware.gpus=${GPU_IDS} \\\nbatch_size.train=${BATCH_SIZE_PER_GPU} batch_size.eval=${BATCH_SIZE_PER_GPU} \\\nhardware.num_workers.train=${TRAIN_WORKERS_PER_GPU} hardware.num_workers.eval=${EVAL_WORKERS_PER_GPU}\n```\nTo train on 4 GPUs instead, simply adapt `GPU_IDS` and `BATCH_SIZE_PER_GPU` accordingly:\n```Bash\nGPU_IDS=[0,1,2,3]\nBATCH_SIZE_PER_GPU=6\n```\n### Gen1\n- The effective batch size for the Gen1 training is 8.\n- To train on 1 GPU using 6 workers for training and 2 workers for evaluation:\n```Bash\nGPU_IDS=0\nBATCH_SIZE_PER_GPU=8\nTRAIN_WORKERS_PER_GPU=6\nEVAL_WORKERS_PER_GPU=2\npython train.py model=rnndet dataset=gen1 dataset.path=${DATA_DIR} wandb.project_name=RVT \\\nwandb.group_name=gen1 +experiment/gen1=\"${MDL_CFG}.yaml\" 
hardware.gpus=${GPU_IDS} \\\nbatch_size.train=${BATCH_SIZE_PER_GPU} batch_size.eval=${BATCH_SIZE_PER_GPU} \\\nhardware.num_workers.train=${TRAIN_WORKERS_PER_GPU} hardware.num_workers.eval=${EVAL_WORKERS_PER_GPU}\n```\n\n## Code Acknowledgments\nThis project has used code from the following projects:\n- [timm](https://github.com/huggingface/pytorch-image-models) for the MaxViT layer implementation in PyTorch\n- [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX) for the detection PAFPN/head\n"
  },
  {
    "path": "RVT/callbacks/custom.py",
    "content": "from omegaconf import DictConfig\nfrom lightning.pytorch.callbacks import Callback\nfrom lightning.pytorch.callbacks import ModelCheckpoint\n\nfrom callbacks.detection import DetectionVizCallback\n\n\ndef get_ckpt_callback(config: DictConfig) -> ModelCheckpoint:\n    model_name = config.model.name\n\n    prefix = \"val\"\n    if model_name == \"rnndet\":\n        metric = \"AP\"\n        mode = \"max\"\n    else:\n        raise NotImplementedError\n    ckpt_callback_monitor = prefix + \"/\" + metric\n    filename_monitor_str = prefix + \"_\" + metric\n\n    ckpt_filename = (\n        \"epoch={epoch:03d}-step={step}-\"\n        + filename_monitor_str\n        + \"={\"\n        + ckpt_callback_monitor\n        + \":.2f}\"\n    )\n    ckpt_callback = ModelCheckpoint(\n        monitor=ckpt_callback_monitor,\n        filename=ckpt_filename,\n        auto_insert_metric_name=False,  # the '/' in the metric name would otherwise create a subdirectory\n        save_top_k=1,\n        mode=mode,\n        every_n_epochs=config.logging.ckpt_every_n_epochs,\n        save_last=True,\n        verbose=True,\n    )\n    ckpt_callback.CHECKPOINT_NAME_LAST = \"last_epoch={epoch:03d}-step={step}\"\n    return ckpt_callback\n\n\ndef get_viz_callback(config: DictConfig) -> Callback:\n    model_name = config.model.name\n\n    if model_name == \"rnndet\":\n        return DetectionVizCallback(config=config)\n    raise NotImplementedError\n"
  },
  {
    "path": "RVT/callbacks/detection.py",
    "content": "from enum import Enum, auto\nfrom typing import Any\n\nimport torch\nfrom einops import rearrange\nfrom omegaconf import DictConfig\n\nfrom data.utils.types import ObjDetOutput\nfrom loggers.wandb_logger import WandbLogger\nfrom utils.evaluation.prophesee.visualize.vis_utils import (\n    LABELMAP_GEN1,\n    LABELMAP_GEN4_SHORT,\n    draw_bboxes,\n)\nfrom .viz_base import VizCallbackBase\n\n\nclass DetectionVizEnum(Enum):\n    EV_IMG = auto()\n    LABEL_IMG_PROPH = auto()\n    PRED_IMG_PROPH = auto()\n\n\nclass DetectionVizCallback(VizCallbackBase):\n    def __init__(self, config: DictConfig):\n        super().__init__(config=config, buffer_entries=DetectionVizEnum)\n\n        dataset_name = config.dataset.name\n        if dataset_name == \"gen1\":\n            self.label_map = LABELMAP_GEN1\n        elif dataset_name == \"gen4\":\n            self.label_map = LABELMAP_GEN4_SHORT\n        else:\n            raise NotImplementedError\n\n    def on_train_batch_end_custom(\n        self,\n        logger: WandbLogger,\n        outputs: Any,\n        batch: Any,\n        log_n_samples: int,\n        global_step: int,\n    ) -> None:\n        if outputs is None:\n            # If we tried to skip the training step (not supported in DDP in PL, atm)\n            return\n        ev_tensors = outputs[ObjDetOutput.EV_REPR]\n        num_samples = len(ev_tensors)\n        assert num_samples > 0\n        log_n_samples = min(num_samples, log_n_samples)\n\n        merged_img = []\n        captions = []\n        start_idx = num_samples - 1\n        end_idx = start_idx - log_n_samples\n        # for sample_idx in range(log_n_samples):\n        for sample_idx in range(start_idx, end_idx, -1):\n            ev_img = self.ev_repr_to_img(ev_tensors[sample_idx].cpu().numpy())\n\n            predictions_proph = outputs[ObjDetOutput.PRED_PROPH][sample_idx]\n            prediction_img = ev_img.copy()\n            draw_bboxes(prediction_img, predictions_proph, 
labelmap=self.label_map)\n\n            labels_proph = outputs[ObjDetOutput.LABELS_PROPH][sample_idx]\n            label_img = ev_img.copy()\n            draw_bboxes(label_img, labels_proph, labelmap=self.label_map)\n\n            merged_img.append(\n                rearrange(\n                    [prediction_img, label_img], \"pl H W C -> (pl H) W C\", pl=2, C=3\n                )\n            )\n            captions.append(f\"sample_{sample_idx}\")\n\n        logger.log_images(\n            key=\"train/predictions\",\n            images=merged_img,\n            caption=captions,\n            step=global_step,\n        )\n\n    def on_validation_batch_end_custom(self, batch: Any, outputs: Any):\n        if outputs[ObjDetOutput.SKIP_VIZ]:\n            return\n        ev_tensor = outputs[ObjDetOutput.EV_REPR]\n        assert isinstance(ev_tensor, torch.Tensor)\n\n        ev_img = self.ev_repr_to_img(ev_tensor.cpu().numpy())\n\n        predictions_proph = outputs[ObjDetOutput.PRED_PROPH]\n        prediction_img = ev_img.copy()\n        draw_bboxes(prediction_img, predictions_proph, labelmap=self.label_map)\n        self.add_to_buffer(DetectionVizEnum.PRED_IMG_PROPH, prediction_img)\n\n        labels_proph = outputs[ObjDetOutput.LABELS_PROPH]\n        label_img = ev_img.copy()\n        draw_bboxes(label_img, labels_proph, labelmap=self.label_map)\n        self.add_to_buffer(DetectionVizEnum.LABEL_IMG_PROPH, label_img)\n\n    def on_validation_epoch_end_custom(self, logger: WandbLogger):\n        pred_imgs = self.get_from_buffer(DetectionVizEnum.PRED_IMG_PROPH)\n        label_imgs = self.get_from_buffer(DetectionVizEnum.LABEL_IMG_PROPH)\n        assert len(pred_imgs) == len(label_imgs)\n        merged_img = []\n        captions = []\n        for idx, (pred_img, label_img) in enumerate(zip(pred_imgs, label_imgs)):\n            merged_img.append(\n                rearrange([pred_img, label_img], \"pl H W C -> (pl H) W C\", pl=2, C=3)\n            )\n            
captions.append(f\"sample_{idx}\")\n\n        logger.log_images(key=\"val/predictions\", images=merged_img, caption=captions)\n"
  },
  {
    "path": "RVT/callbacks/gradflow.py",
    "content": "from typing import Any\n\nimport lightning.pytorch as pl\nfrom lightning.pytorch.callbacks import Callback\nfrom lightning.pytorch.utilities.rank_zero import rank_zero_only\n\nfrom callbacks.utils.visualization import get_grad_flow_figure\n\n\nclass GradFlowLogCallback(Callback):\n    def __init__(self, log_every_n_train_steps: int):\n        super().__init__()\n        assert log_every_n_train_steps > 0\n        self.log_every_n_train_steps = log_every_n_train_steps\n\n    @rank_zero_only\n    def on_before_zero_grad(\n        self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Any\n    ) -> None:\n        # NOTE: we previously had this in the on_after_backward callback.\n        # This was fine for fp32, but with fp16 (AMP) it showed the loss-scaled gradients.\n        # That is why we moved it to on_before_zero_grad, where the gradients have already been unscaled.\n        global_step = trainer.global_step\n        if global_step % self.log_every_n_train_steps != 0:\n            return\n        named_parameters = pl_module.named_parameters()\n        figure = get_grad_flow_figure(named_parameters)\n        trainer.logger.log_metrics({\"train/gradients\": figure}, step=global_step)\n"
  },
  {
    "path": "RVT/callbacks/utils/visualization.py",
    "content": "import pandas as pd\nimport plotly.express as px\n\n\ndef get_grad_flow_figure(named_params):\n    \"\"\"Creates a bar chart of the mean absolute gradient of each parameter to visualize the gradients flowing through the layers of the network during training.\n    Can be used to check for possible vanishing/exploding gradient problems.\n    Usage: call this function after loss.backward().\n    \"\"\"\n    data_dict = {\n        \"name\": list(),\n        \"grad_abs\": list(),\n    }\n    for name, param in named_params:\n        if param.requires_grad and param.grad is not None:\n            grad_abs = param.grad.abs()\n            data_dict[\"name\"].append(name)\n            data_dict[\"grad_abs\"].append(grad_abs.mean().cpu().item())\n\n    data_frame = pd.DataFrame.from_dict(data_dict)\n\n    fig = px.bar(data_frame, x=\"name\", y=\"grad_abs\")\n    return fig\n"
  },
  {
    "path": "RVT/callbacks/viz_base.py",
    "content": "import random\nfrom enum import Enum\nfrom typing import Any, List, Optional, Type, Union\n\nimport numpy as np\nimport lightning.pytorch as pl\nimport torch as th\nfrom einops import rearrange, reduce\nfrom omegaconf import DictConfig\nfrom lightning.pytorch.callbacks import Callback\nfrom lightning.pytorch.utilities.rank_zero import rank_zero_only\n\nfrom loggers.wandb_logger import WandbLogger\n\n\nclass VizCallbackBase(Callback):\n    def __init__(self, config: DictConfig, buffer_entries: Type[Enum]):\n        super().__init__()\n\n        self.log_config = config.logging\n\n        self._training_has_started = False\n        self._selected_val_batches = False\n\n        self.buffer_entries = buffer_entries\n        self._val_batch_indices = list()\n        self._buffer = None\n        self._reset_buffer()\n\n    def _reset_buffer(self):\n        self._buffer = {entry: [] for entry in self.buffer_entries}\n\n    # Functions to be USED by subclasses -------------------------------------------------------------------------------\n\n    def add_to_buffer(self, key: Enum, value: Union[np.ndarray, th.Tensor]):\n        if isinstance(value, th.Tensor):\n            assert not value.requires_grad\n            value = value.cpu()\n        else:\n            assert isinstance(value, np.ndarray)\n        assert type(key) == self.buffer_entries\n        assert key in self._buffer\n        self._buffer[key].append(value)\n\n    def get_from_buffer(self, key: Enum) -> List[Union[np.ndarray, th.Tensor]]:\n        assert type(key) == self.buffer_entries\n        return self._buffer[key]\n\n    # Functions to be IMPLEMENTED by subclasses ------------------------------------------------------------------------\n\n    def on_train_batch_end_custom(\n        self,\n        logger: WandbLogger,\n        outputs: Any,\n        batch: Any,\n        log_n_samples: int,\n        global_step: int,\n    ) -> None:\n        raise NotImplementedError\n\n    def on_validation_batch_end_custom(self, batch: Any, outputs: Any) -> None:\n        raise NotImplementedError\n\n    def on_validation_epoch_end_custom(self, logger: WandbLogger) -> None:\n        raise NotImplementedError\n\n    # ------------------------------------------------------------------------------------------------------------------\n\n    def on_train_batch_end(\n        self,\n        trainer: pl.Trainer,\n        pl_module: pl.LightningModule,\n        outputs: Any,\n        batch: Any,\n        batch_idx: int,\n        unused: int = 0,\n    ) -> None:\n        log_train_hd = self.log_config.train.high_dim\n        if not log_train_hd.enable:\n            return\n\n        step = trainer.global_step\n        assert log_train_hd.every_n_steps > 0\n        if step % log_train_hd.every_n_steps != 0:\n            return\n\n        n_samples = log_train_hd.n_samples\n\n        logger: Optional[WandbLogger] = trainer.logger\n        assert isinstance(logger, WandbLogger)\n\n        global_step = trainer.global_step\n\n        self.on_train_batch_end_custom(\n            logger=logger,\n            outputs=outputs,\n            batch=batch,\n            log_n_samples=n_samples,\n            global_step=global_step,\n        )\n\n    @rank_zero_only\n    def on_validation_batch_end(\n        self,\n        trainer: pl.Trainer,\n        pl_module: pl.LightningModule,\n        outputs: Optional[Any],\n        batch: Any,\n        batch_idx: int,\n        dataloader_idx: int = 0,\n    ) -> None:\n        log_val_hd = self.log_config.validation.high_dim\n        log_freq_val_epochs = log_val_hd.every_n_epochs\n        if not log_val_hd.enable:\n            return\n        if dataloader_idx > 0:\n            raise NotImplementedError\n        if not self._training_has_started:\n            # PL runs a short validation sanity check before training starts.\n            # Hence, we skip validation batches until at least one training batch has run.\n            return\n        if not self._selected_val_batches:\n            # We only want to add validation batch indices during the first true validation run.\n            self._val_batch_indices.append(batch_idx)\n            return\n        assert len(self._val_batch_indices) > 0\n        if batch_idx not in self._val_batch_indices:\n            return\n        if trainer.current_epoch % log_freq_val_epochs != 0:\n            return\n\n        self.on_validation_batch_end_custom(batch, outputs)\n\n    def on_validation_epoch_start(\n        self, trainer: pl.Trainer, pl_module: pl.LightningModule\n    ) -> None:\n        self._reset_buffer()\n\n    @rank_zero_only\n    def on_validation_epoch_end(\n        self, trainer: pl.Trainer, pl_module: pl.LightningModule\n    ) -> None:\n        log_val_hd = self.log_config.validation.high_dim\n        log_n_samples = log_val_hd.n_samples\n        log_freq_val_epochs = log_val_hd.every_n_epochs\n        if len(self._val_batch_indices) == 0:\n            return\n        if not self._selected_val_batches:\n            num_samples = min(len(self._val_batch_indices), log_n_samples)\n            # Draw without replacement. A seeded local RNG keeps the selection reproducible\n            # without reseeding the global `random` state used elsewhere during training.\n            sampled_indices = random.Random(0).sample(self._val_batch_indices, num_samples)\n            self._val_batch_indices = sampled_indices\n            self._selected_val_batches = True\n            return\n        if trainer.current_epoch % log_freq_val_epochs != 0:\n            return\n\n        logger: Optional[WandbLogger] = trainer.logger\n        assert isinstance(logger, WandbLogger)\n        self.on_validation_epoch_end_custom(logger)\n\n    def on_train_batch_start(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\",\n        batch: Any,\n        batch_idx: int,\n    ) -> None:\n        self._training_has_started = True\n\n    @staticmethod\n    def ev_repr_to_img(x: np.ndarray):\n        ch, ht, wd = x.shape[-3:]\n        assert ch > 1 and ch % 2 == 0\n        ev_repr_reshaped = rearrange(x, \"(posneg C) H W -> posneg C H W\", posneg=2)\n        img_neg = np.asarray(\n            reduce(ev_repr_reshaped[0], \"C H W -> H W\", \"sum\"), dtype=\"int32\"\n        )\n        img_pos = np.asarray(\n            reduce(ev_repr_reshaped[1], \"C H W -> H W\", \"sum\"), dtype=\"int32\"\n        )\n        img_diff = img_pos - img_neg\n        img = 127 * np.ones((ht, wd, 3), dtype=np.uint8)\n        img[img_diff > 0] = 255\n        img[img_diff < 0] = 0\n        return img\n"
  },
  {
    "path": "RVT/config/dataset/base.yaml",
    "content": "name: ???\npath: ???\ntrain:\n  sampling: 'mixed' # ('random', 'stream', 'mixed')\n  random:\n    weighted_sampling: False\n  mixed:\n    w_stream: 1\n    w_random: 1\neval:\n  sampling: 'stream'\ndata_augmentation:\n  random:\n    prob_hflip: 0.5\n    rotate:\n      prob: 0\n      min_angle_deg: 2\n      max_angle_deg: 6\n    zoom:\n      prob: 0.8\n      zoom_in:\n        weight: 8\n        factor:\n          min: 1\n          max: 1.5\n      zoom_out:\n        weight: 2\n        factor:\n          min: 1\n          max: 1.2\n  stream:\n    prob_hflip: 0.5\n    rotate:\n      prob: 0\n      min_angle_deg: 2\n      max_angle_deg: 6\n    zoom:\n      prob: 0.5\n      zoom_out:\n        factor:\n          min: 1\n          max: 1.2"
  },
  {
    "path": "RVT/config/dataset/gen1.yaml",
    "content": "defaults:\n  - base\n\nname: gen1\nev_repr_name: 'stacked_histogram_dt=50_nbins=10'\nsequence_length: 21\nresolution_hw: [240, 304]\ndownsample_by_factor_2: False\nonly_load_end_labels: False"
  },
  {
    "path": "RVT/config/dataset/gen4.yaml",
    "content": "defaults:\n  - base\n\nname: gen4\nev_repr_name: 'stacked_histogram_dt=50_nbins=10'\nsequence_length: 10\nresolution_hw: [720, 1280]\ndownsample_by_factor_2: True\nonly_load_end_labels: False"
  },
  {
    "path": "RVT/config/experiment/gen1/base.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  backbone:\n    embed_dim: 64\n  fpn:\n    depth: 0.67\n"
  },
  {
    "path": "RVT/config/experiment/gen1/default.yaml",
    "content": "# @package _global_\ndefaults:\n  - /model/maxvit_yolox: default\n\ntraining:\n  precision: 32\n  max_epochs: 10000\n  max_steps: 400000\n  learning_rate: 0.0002\n  lr_scheduler:\n    use: True\n    total_steps: ${..max_steps}\n    pct_start: 0.005\n    div_factor: 20\n    final_div_factor: 10000\nvalidation:\n  val_check_interval: 10000\n  check_val_every_n_epoch: null\nbatch_size:\n  train: 8\n  eval: 8\nhardware:\n  num_workers:\n    train: 6\n    eval: 2\ndataset:\n  train:\n    sampling: 'mixed'\n    random:\n      weighted_sampling: False\n    mixed:\n      w_stream: 1\n      w_random: 1\n  eval:\n    sampling: 'stream'\n  ev_repr_name: 'stacked_histogram_dt=50_nbins=10'\n  sequence_length: 21\n  downsample_by_factor_2: False\n  only_load_end_labels: False\nmodel:\n  backbone:\n    partition_split_32: 1"
  },
  {
    "path": "RVT/config/experiment/gen1/small.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  backbone:\n    embed_dim: 48\n    stage:\n      attention:\n        dim_head: 24\n  fpn:\n    depth: 0.33"
  },
  {
    "path": "RVT/config/experiment/gen4/base.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  backbone:\n    embed_dim: 64\n  fpn:\n    depth: 0.67\n"
  },
  {
    "path": "RVT/config/experiment/gen4/default.yaml",
    "content": "# @package _global_\ndefaults:\n  - /model/maxvit_yolox: default\n\ntraining:\n  precision: 32\n  max_epochs: 10000\n  max_steps: 400000\n  learning_rate: 0.0002449489742783178 # 2e-4 * sqrt(effective_batch_size/8) = 2e-4 * sqrt(12/8)\n  lr_scheduler:\n    use: True\n    total_steps: ${..max_steps}\n    pct_start: 0.005\n    div_factor: 20\n    final_div_factor: 10000\nvalidation:\n  val_check_interval: 10000\n  check_val_every_n_epoch: null\nbatch_size:\n  train: 12\n  eval: 12\nhardware:\n  num_workers:\n    train: 6\n    eval: 2\ndataset:\n  train:\n    sampling: 'mixed'\n    random:\n      weighted_sampling: False\n    mixed:\n      w_stream: 1\n      w_random: 1\n  eval:\n    sampling: 'stream'\n  ev_repr_name: 'stacked_histogram_dt=50_nbins=10'\n  sequence_length: 10\n  downsample_by_factor_2: True\n  only_load_end_labels: False"
  },
  {
    "path": "RVT/config/experiment/gen4/small.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  backbone:\n    embed_dim: 48\n    stage:\n      attention:\n        dim_head: 24\n  fpn:\n    depth: 0.33"
  },
  {
    "path": "RVT/config/general.yaml",
    "content": "reproduce:\n  seed_everything: null # Union[int, null]\n  deterministic_flag: False # Must be True for fully deterministic behaviour (slows down training)\n  benchmark: False # Must be False for fully deterministic behaviour. Setting it to True could potentially speed up training.\ntraining:\n  precision: 16\n  max_epochs: 10000\n  max_steps: 400000\n  learning_rate: 0.0002\n  weight_decay: 0\n  gradient_clip_val: 1.0\n  limit_train_batches: 1.0\n  lr_scheduler:\n    use: True\n    total_steps: ${..max_steps}\n    pct_start: 0.005\n    div_factor: 25 # init_lr = max_lr / div_factor\n    final_div_factor: 10000 # final_lr = max_lr / final_div_factor (this differs from PyTorch's OneCycleLR param, where final_lr = init_lr / final_div_factor)\nvalidation:\n  limit_val_batches: 1.0\n  val_check_interval: null # Optional[int]\n  check_val_every_n_epoch: 1 # Optional[int]\nbatch_size:\n  train: 8\n  eval: 8\nhardware:\n  num_workers:\n    train: 6\n    eval: 2\n  gpus: 0 # Either a single integer (e.g. 3) or a list of integers (e.g. [3,5,6])\n  dist_backend: \"nccl\"\nlogging:\n  ckpt_every_n_epochs: 1\n  train:\n    metrics:\n      compute: false\n      detection_metrics_every_n_steps: null # Optional[int] -> null: every train epoch, int: every N steps\n    log_model_every_n_steps: 5000\n    log_every_n_steps: 500\n    high_dim:\n      enable: True\n      every_n_steps: 5000\n      n_samples: 4\n  validation:\n    high_dim:\n      enable: True\n      every_n_epochs: 1\n      n_samples: 8\nwandb:\n  #   How to use:\n  #   1) resume existing wandb run:                                 set artifact_name & wandb_runpath\n  #   2) resume full training state in new wandb run:               set artifact_name\n  #   3) resume only model weights of checkpoint in new wandb run:  set artifact_name & resume_only_weights=True\n  #\n  #   In addition: you can specify artifact_local_file to load the checkpoint from disk.\n  #   This is required, for example, when resuming training with DDP.\n  wandb_runpath: null # WandB run path. E.g. USERNAME/PROJECTNAME/1grv5kg6\n  artifact_name: null # Name of checkpoint/artifact. Required for resuming. E.g. USERNAME/PROJECTNAME/checkpoint-1grv5kg6-last:v15\n  artifact_local_file: null # If specified, will use the provided local filepath instead of downloading it. Required if resuming with DDP.\n  resume_only_weights: False\n  group_name: ??? # Specify group name of the run\n  project_name: RVT"
  },
  {
    "path": "RVT/config/model/base.yaml",
    "content": "name: ???"
  },
  {
    "path": "RVT/config/model/maxvit_yolox/default.yaml",
    "content": "# @package _global_\ndefaults:\n  - override /model: rnndet\n\nmodel:\n  backbone:\n    name: MaxViTRNN\n    compile:\n      enable: False\n      args:\n        mode: reduce-overhead\n    input_channels: 20\n    enable_masking: False\n    partition_split_32: 2\n    embed_dim: 64\n    dim_multiplier: [1, 2, 4, 8]\n    num_blocks: [1, 1, 1, 1]\n    T_max_chrono_init: [4, 8, 16, 32]\n    stem:\n      patch_size: 4\n    stage:\n      downsample:\n        type: patch\n        overlap: True\n        norm_affine: True\n      attention:\n        use_torch_mha: False\n        partition_size: ???\n        dim_head: 32\n        attention_bias: True\n        mlp_activation: gelu\n        mlp_gated: False\n        mlp_bias: True\n        mlp_ratio: 4\n        drop_mlp: 0\n        drop_path: 0\n        ls_init_value: 1e-5\n      lstm:\n        dws_conv: False\n        dws_conv_only_hidden: True\n        dws_conv_kernel_size: 3\n        drop_cell_update: 0\n      s5:\n        dim: 80\n        state_dim: 80\n      s4:\n        dim: 80\n        state_dim: 80 \n  fpn:\n    name: PAFPN\n    compile:\n      enable: False\n      args:\n        mode: reduce-overhead\n    depth: 0.67 # round(depth * 3) == num bottleneck blocks\n    # stage 1 is the first and len(num_layers) is the last\n    in_stages: [2, 3, 4]\n    depthwise: False\n    act: \"silu\"\n  head:\n    name: YoloX\n    compile:\n      enable: False\n      args:\n        mode: reduce-overhead\n    depthwise: False\n    act: \"silu\"\n  postprocess:\n    confidence_threshold: 0.1\n    nms_threshold: 0.45\n"
  },
  {
    "path": "RVT/config/model/rnndet.yaml",
    "content": "defaults:\n  - base\n\nname: rnndet\nbackbone:\n  name: ???\nfpn:\n  name: ???\nhead:\n  name: ???\npostprocess:\n  confidence_threshold: 0.1\n  nms_threshold: 0.45"
  },
  {
    "path": "RVT/config/modifier.py",
    "content": "import os\nfrom typing import Tuple\n\nimport math\nfrom omegaconf import DictConfig, open_dict\n\nfrom data.utils.spatial import get_dataloading_hw\n\n\ndef dynamically_modify_train_config(config: DictConfig):\n    with open_dict(config):\n        slurm_job_id = os.environ.get(\"SLURM_JOB_ID\")\n        if slurm_job_id and slurm_job_id != \"\":\n            config.slurm_job_id = int(slurm_job_id)\n\n        dataset_cfg = config.dataset\n\n        dataset_name = dataset_cfg.name\n        assert dataset_name in {\"gen1\", \"gen4\"}\n        dataset_hw = get_dataloading_hw(dataset_config=dataset_cfg)\n\n        mdl_cfg = config.model\n        mdl_name = mdl_cfg.name\n        if mdl_name == \"rnndet\":\n            backbone_cfg = mdl_cfg.backbone\n            backbone_name = backbone_cfg.name\n            if backbone_name == \"MaxViTRNN\":\n                partition_split_32 = backbone_cfg.partition_split_32\n                assert partition_split_32 in (1, 2, 4)\n\n                multiple_of = 32 * partition_split_32\n                mdl_hw = _get_modified_hw_multiple_of(\n                    hw=dataset_hw, multiple_of=multiple_of\n                )\n                print(f\"Set {backbone_name} backbone (height, width) to {mdl_hw}\")\n                backbone_cfg.in_res_hw = mdl_hw\n\n                attention_cfg = backbone_cfg.stage.attention\n                partition_size = tuple(x // (32 * partition_split_32) for x in mdl_hw)\n                assert (mdl_hw[0] // 32) % partition_size[\n                    0\n                ] == 0, f\"{mdl_hw[0]=}, {partition_size[0]=}\"\n                assert (mdl_hw[1] // 32) % partition_size[\n                    1\n                ] == 0, f\"{mdl_hw[1]=}, {partition_size[1]=}\"\n                print(f\"Set partition sizes: {partition_size}\")\n                attention_cfg.partition_size = partition_size\n            else:\n                print(f\"{backbone_name=} not available\")\n                raise 
NotImplementedError\n            num_classes = 2 if dataset_name == \"gen1\" else 3\n            mdl_cfg.head.num_classes = num_classes\n            print(f\"Set {num_classes=} for detection head\")\n        else:\n            print(f\"{mdl_name=} not available\")\n            raise NotImplementedError\n\n\ndef _get_modified_hw_multiple_of(\n    hw: Tuple[int, int], multiple_of: int\n) -> Tuple[int, ...]:\n    assert isinstance(hw, tuple), f\"{type(hw)=}, {hw=}\"\n    assert len(hw) == 2\n    assert isinstance(multiple_of, int)\n    assert multiple_of >= 1\n    if multiple_of == 1:\n        return hw\n    new_hw = tuple(math.ceil(x / multiple_of) * multiple_of for x in hw)\n    return new_hw\n"
  },
  {
    "path": "RVT/config/train.yaml",
    "content": "defaults:\n  - general\n  - dataset: ???\n  - model: rnndet\n  - optional model/dataset: ${model}_${dataset}"
  },
  {
    "path": "RVT/config/val.yaml",
    "content": "defaults:\n  - dataset: ???\n  - model: rnndet\n  - _self_\n\ncheckpoint: ???\nuse_test_set: False\nhardware:\n  num_workers:\n    eval: 4\n  gpus: 0 # GPU idx (multi-gpu not supported for validation)\nbatch_size:\n  eval: 8\ntraining:\n  precision: 16\n"
  },
  {
    "path": "RVT/data/genx_utils/collate.py",
    "content": "from copy import deepcopy\nfrom typing import Any, Callable, Dict, Optional, Type, Tuple, Union\n\nimport torch\n\nfrom data.genx_utils.collate_from_pytorch import collate, default_collate_fn_map\nfrom data.genx_utils.labels import ObjectLabels, SparselyBatchedObjectLabels\n\n\ndef collate_object_labels(\n    batch,\n    *,\n    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None\n):\n    return batch\n\n\ndef collate_sparsely_batched_object_labels(\n    batch,\n    *,\n    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None\n):\n    return SparselyBatchedObjectLabels.transpose_list(batch)\n\n\ncustom_collate_fn_map = deepcopy(default_collate_fn_map)\ncustom_collate_fn_map[ObjectLabels] = collate_object_labels\ncustom_collate_fn_map[SparselyBatchedObjectLabels] = (\n    collate_sparsely_batched_object_labels\n)\n\n\ndef custom_collate(batch: Any):\n    return collate(batch, collate_fn_map=custom_collate_fn_map)\n\n\ndef custom_collate_rnd(batch: Any):\n    samples = batch\n    # NOTE: We do not really need the worker id for map style datasets (rnd) but we still provide the id for consistency\n    worker_info = torch.utils.data.get_worker_info()\n    local_worker_id = 0 if worker_info is None else worker_info.id\n    return {\n        \"data\": custom_collate(samples),\n        \"worker_id\": local_worker_id,\n    }\n\n\ndef custom_collate_streaming(batch: Any):\n    \"\"\"We assume that we receive a batch collected by a worker of our streaming datapipe\"\"\"\n    samples = batch[0]\n    worker_id = batch[1]\n    assert isinstance(worker_id, int)\n    return {\n        \"data\": custom_collate(samples),\n        \"worker_id\": worker_id,\n    }\n"
  },
  {
    "path": "RVT/data/genx_utils/collate_from_pytorch.py",
    "content": "import collections\nimport contextlib\nimport re\n\nimport torch\n\ntorch_is_version_1 = int(torch.__version__.split(\".\")[0]) == 1\n\nfrom typing import Callable, Dict, Optional, Tuple, Type, Union\n\nnp_str_obj_array_pattern = re.compile(r\"[SaUO]\")\n\ndefault_collate_err_msg_format = (\n    \"default_collate: batch must contain tensors, numpy arrays, numbers, \"\n    \"dicts or lists; found {}\"\n)\n\n\ndef collate(\n    batch,\n    *,\n    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None\n):\n    r\"\"\"\n    General collate function that handles collection type of element within each batch\n    and opens function registry to deal with specific element types. `default_collate_fn_map`\n    provides default collate functions for tensors, numpy arrays, numbers and strings.\n\n    Args:\n        batch: a single batch to be collated\n        collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function.\n          If the element type isn't present in this dictionary,\n          this function will go through each key of the dictionary in the insertion order to\n          invoke the corresponding collate function if the element type is a subclass of the key.\n\n    Examples:\n        >>> # Extend this function to handle batch of tensors\n        >>> def collate_tensor_fn(batch, *, collate_fn_map):\n        ...     return torch.stack(batch, 0)\n        >>> def custom_collate(batch):\n        ...     collate_map = {torch.Tensor: collate_tensor_fn}\n        ...     
return collate(batch, collate_fn_map=collate_map)\n        >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`\n        >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})\n\n    Note:\n        Each collate function requires a positional argument for batch and a keyword argument\n        for the dictionary of collate functions as `collate_fn_map`.\n    \"\"\"\n    elem = batch[0]\n    elem_type = type(elem)\n\n    if collate_fn_map is not None:\n        if elem_type in collate_fn_map:\n            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)\n\n        for collate_type in collate_fn_map:\n            if isinstance(elem, collate_type):\n                return collate_fn_map[collate_type](\n                    batch, collate_fn_map=collate_fn_map\n                )\n\n    if isinstance(elem, collections.abc.Mapping):\n        try:\n            return elem_type(\n                {\n                    key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map)\n                    for key in elem\n                }\n            )\n        except TypeError:\n            # The mapping type may not support `__init__(iterable)`.\n            return {\n                key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map)\n                for key in elem\n            }\n    elif isinstance(elem, tuple) and hasattr(elem, \"_fields\"):  # namedtuple\n        return elem_type(\n            *(\n                collate(samples, collate_fn_map=collate_fn_map)\n                for samples in zip(*batch)\n            )\n        )\n    elif isinstance(elem, collections.abc.Sequence):\n        # check to make sure that the elements in batch have consistent size\n        it = iter(batch)\n        elem_size = len(next(it))\n        if not all(len(elem) == elem_size for elem in it):\n            raise RuntimeError(\"each element in list of batch should be of equal size\")\n        transposed = 
list(zip(*batch))  # It may be accessed twice, so we use a list.\n\n        if isinstance(elem, tuple):\n            return [\n                collate(samples, collate_fn_map=collate_fn_map)\n                for samples in transposed\n            ]  # Backwards compatibility.\n        else:\n            try:\n                return elem_type(\n                    [\n                        collate(samples, collate_fn_map=collate_fn_map)\n                        for samples in transposed\n                    ]\n                )\n            except TypeError:\n                # The sequence type may not support `__init__(iterable)` (e.g., `range`).\n                return [\n                    collate(samples, collate_fn_map=collate_fn_map)\n                    for samples in transposed\n                ]\n\n    raise TypeError(default_collate_err_msg_format.format(elem_type))\n\n\nif torch_is_version_1:\n\n    def collate_tensor_fn(\n        batch,\n        *,\n        collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None\n    ):\n        elem = batch[0]\n        out = None\n        if torch.utils.data.get_worker_info() is not None:\n            # If we're in a background process, concatenate directly into a\n            # shared memory tensor to avoid an extra copy\n            numel = sum(x.numel() for x in batch)\n            storage = elem.storage()._new_shared(numel, device=elem.device)\n            out = elem.new(storage).resize_(len(batch), *list(elem.size()))\n        return torch.stack(batch, 0, out=out)\n\nelse:\n\n    def collate_tensor_fn(\n        batch,\n        *,\n        collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None\n    ):\n        elem = batch[0]\n        out = None\n        if torch.utils.data.get_worker_info() is not None:\n            # If we're in a background process, concatenate directly into a\n            # shared memory tensor to avoid an extra copy\n            numel = 
sum(x.numel() for x in batch)\n            storage = elem._typed_storage()._new_shared(numel, device=elem.device)\n            out = elem.new(storage).resize_(len(batch), *list(elem.size()))\n        return torch.stack(batch, 0, out=out)\n\n\ndef collate_numpy_array_fn(\n    batch,\n    *,\n    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None\n):\n    elem = batch[0]\n    # array of string classes and object\n    if np_str_obj_array_pattern.search(elem.dtype.str) is not None:\n        raise TypeError(default_collate_err_msg_format.format(elem.dtype))\n\n    return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)\n\n\ndef collate_numpy_scalar_fn(\n    batch,\n    *,\n    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None\n):\n    return torch.as_tensor(batch)\n\n\ndef collate_float_fn(\n    batch,\n    *,\n    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None\n):\n    return torch.tensor(batch, dtype=torch.float64)\n\n\ndef collate_int_fn(\n    batch,\n    *,\n    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None\n):\n    return torch.tensor(batch)\n\n\ndef collate_str_fn(\n    batch,\n    *,\n    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None\n):\n    return batch\n\n\ndefault_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {\n    torch.Tensor: collate_tensor_fn\n}\nwith contextlib.suppress(ImportError):\n    import numpy as np\n\n    # For both ndarray and memmap (subclass of ndarray)\n    default_collate_fn_map[np.ndarray] = collate_numpy_array_fn\n    # See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html\n    # Skip string scalars\n    default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn\ndefault_collate_fn_map[float] = collate_float_fn\ndefault_collate_fn_map[int] = collate_int_fn\ndefault_collate_fn_map[str] = 
collate_str_fn\n"
  },
  {
    "path": "RVT/data/genx_utils/dataset_rnd.py",
    "content": "from collections import namedtuple\nfrom collections.abc import Iterable\nfrom pathlib import Path\nfrom typing import List\n\nimport numpy as np\nfrom omegaconf import DictConfig\nfrom torch.utils.data import ConcatDataset, Dataset\nfrom torch.utils.data.sampler import WeightedRandomSampler\nfrom tqdm import tqdm\n\nfrom data.genx_utils.labels import SparselyBatchedObjectLabels\nfrom data.genx_utils.sequence_rnd import SequenceForRandomAccess\nfrom data.utils.augmentor import RandomSpatialAugmentorGenX\nfrom data.utils.types import DatasetMode, LoaderDataDictGenX, DatasetType, DataType\n\n\nclass SequenceDataset(Dataset):\n    def __init__(\n        self, path: Path, dataset_mode: DatasetMode, dataset_config: DictConfig\n    ):\n        assert path.is_dir()\n\n        ### extract settings from config ###\n        sequence_length = dataset_config.sequence_length\n        assert isinstance(sequence_length, int)\n        assert sequence_length > 0\n        self.output_seq_len = sequence_length\n\n        ev_representation_name = dataset_config.ev_repr_name\n        downsample_by_factor_2 = dataset_config.downsample_by_factor_2\n        only_load_end_labels = dataset_config.only_load_end_labels\n\n        augm_config = dataset_config.data_augmentation\n\n        ####################################\n        if dataset_config.name == \"gen1\":\n            dataset_type = DatasetType.GEN1\n        elif dataset_config.name == \"gen4\":\n            dataset_type = DatasetType.GEN4\n        else:\n            raise NotImplementedError\n        self.sequence = SequenceForRandomAccess(\n            path=path,\n            ev_representation_name=ev_representation_name,\n            sequence_length=sequence_length,\n            dataset_type=dataset_type,\n            downsample_by_factor_2=downsample_by_factor_2,\n            only_load_end_labels=only_load_end_labels,\n        )\n\n        self.spatial_augmentor = None\n        if dataset_mode == 
DatasetMode.TRAIN:\n            resolution_hw = tuple(dataset_config.resolution_hw)\n            assert len(resolution_hw) == 2\n            ds_by_factor_2 = dataset_config.downsample_by_factor_2\n            if ds_by_factor_2:\n                resolution_hw = tuple(x // 2 for x in resolution_hw)\n            self.spatial_augmentor = RandomSpatialAugmentorGenX(\n                dataset_hw=resolution_hw,\n                automatic_randomization=True,\n                augm_config=augm_config.random,\n            )\n\n    def only_load_labels(self):\n        self.sequence.only_load_labels()\n\n    def load_everything(self):\n        self.sequence.load_everything()\n\n    def __len__(self):\n        return len(self.sequence)\n\n    def __getitem__(self, index: int) -> LoaderDataDictGenX:\n        item = self.sequence[index]\n\n        if (\n            self.spatial_augmentor is not None\n            and not self.sequence.is_only_loading_labels()\n        ):\n            item = self.spatial_augmentor(item)\n\n        return item\n\n\nclass CustomConcatDataset(ConcatDataset):\n    datasets: List[SequenceDataset]\n\n    def __init__(self, datasets: Iterable[SequenceDataset]):\n        super().__init__(datasets=datasets)\n\n    def only_load_labels(self):\n        for idx, dataset in enumerate(self.datasets):\n            self.datasets[idx].only_load_labels()\n\n    def load_everything(self):\n        for idx, dataset in enumerate(self.datasets):\n            self.datasets[idx].load_everything()\n\n\ndef build_random_access_dataset(\n    dataset_mode: DatasetMode, dataset_config: DictConfig\n) -> CustomConcatDataset:\n    dataset_path = Path(dataset_config.path)\n    assert dataset_path.is_dir(), f\"{str(dataset_path)}\"\n\n    mode2str = {\n        DatasetMode.TRAIN: \"train\",\n        DatasetMode.VALIDATION: \"val\",\n        DatasetMode.TESTING: \"test\",\n    }\n\n    split_path = dataset_path / mode2str[dataset_mode]\n    assert split_path.is_dir()\n\n    
seq_datasets = list()\n    for entry in tqdm(\n        split_path.iterdir(),\n        desc=f\"creating rnd access {mode2str[dataset_mode]} datasets\",\n    ):\n        seq_datasets.append(\n            SequenceDataset(\n                path=entry, dataset_mode=dataset_mode, dataset_config=dataset_config\n            )\n        )\n\n    return CustomConcatDataset(seq_datasets)\n\n\ndef get_weighted_random_sampler(dataset: CustomConcatDataset) -> WeightedRandomSampler:\n    class2count = dict()\n    ClassAndCount = namedtuple(\"ClassAndCount\", [\"class_ids\", \"counts\"])\n    classandcount_list = list()\n    print(\"--- START generating weighted random sampler ---\")\n    dataset.only_load_labels()\n    for idx, data in enumerate(tqdm(dataset, desc=\"iterate through dataset\")):\n        labels: SparselyBatchedObjectLabels = data[DataType.OBJLABELS_SEQ]\n        label_list, valid_batch_indices = labels.get_valid_labels_and_batch_indices()\n        class_ids_seq = list()\n        for label in label_list:\n            class_ids_numpy = np.asarray(label.class_id.numpy(), dtype=\"int32\")\n            class_ids_seq.append(class_ids_numpy)\n        class_ids_seq, counts_seq = np.unique(\n            np.concatenate(class_ids_seq), return_counts=True\n        )\n        for class_id, count in zip(class_ids_seq, counts_seq):\n            class2count[class_id] = class2count.get(class_id, 0) + count\n        classandcount_list.append(\n            ClassAndCount(class_ids=class_ids_seq, counts=counts_seq)\n        )\n    dataset.load_everything()\n\n    class2weight = {}\n    for class_id, count in class2count.items():\n        count = max(count, 1)\n        class2weight[class_id] = 1 / count\n\n    weights = []\n    for classandcount in classandcount_list:\n        weight = 0\n        for class_id, count in zip(classandcount.class_ids, classandcount.counts):\n            # Not only weight depending on class but also depending on number of occurrences.\n            # This 
will bias towards sampling \"frames\" with more bounding boxes.\n            weight += class2weight[class_id] * count\n        weights.append(weight)\n\n    print(\"--- DONE generating weighted random sampler ---\")\n    return WeightedRandomSampler(\n        weights=weights, num_samples=len(weights), replacement=True\n    )\n"
  },
  {
    "path": "RVT/data/genx_utils/dataset_streaming.py",
    "content": "from functools import partialmethod\nfrom pathlib import Path\nfrom typing import List, Union\n\nfrom omegaconf import DictConfig\nfrom torchdata.datapipes.map import MapDataPipe\nfrom tqdm import tqdm\n\nfrom data.genx_utils.sequence_for_streaming import (\n    SequenceForIter,\n    RandAugmentIterDataPipe,\n)\nfrom data.utils.stream_concat_datapipe import ConcatStreamingDataPipe\nfrom data.utils.stream_sharded_datapipe import ShardedStreamingDataPipe\nfrom data.utils.types import DatasetMode, DatasetType\n\n\ndef build_streaming_dataset(\n    dataset_mode: DatasetMode,\n    dataset_config: DictConfig,\n    batch_size: int,\n    num_workers: int,\n) -> Union[ConcatStreamingDataPipe, ShardedStreamingDataPipe]:\n    dataset_path = Path(dataset_config.path)\n    assert dataset_path.is_dir(), f\"{str(dataset_path)}\"\n\n    mode2str = {\n        DatasetMode.TRAIN: \"train\",\n        DatasetMode.VALIDATION: \"val\",\n        DatasetMode.TESTING: \"test\",\n    }\n\n    split_path = dataset_path / mode2str[dataset_mode]\n    assert split_path.is_dir()\n    datapipes = list()\n    num_full_sequences = 0\n    num_splits = 0\n    num_split_sequences = 0\n    guarantee_labels = dataset_mode == DatasetMode.TRAIN\n    for entry in tqdm(\n        split_path.iterdir(),\n        desc=f\"creating streaming {mode2str[dataset_mode]} datasets\",\n    ):\n        new_datapipes = get_sequences(\n            path=entry, dataset_config=dataset_config, guarantee_labels=guarantee_labels\n        )\n        if len(new_datapipes) == 1:\n            num_full_sequences += 1\n        else:\n            num_splits += 1\n            num_split_sequences += len(new_datapipes)\n        datapipes.extend(new_datapipes)\n    print(f\"{num_full_sequences=}\\n{num_splits=}\\n{num_split_sequences=}\")\n\n    if dataset_mode == DatasetMode.TRAIN:\n        return build_streaming_train_dataset(\n            datapipes=datapipes,\n            dataset_config=dataset_config,\n            
batch_size=batch_size,\n            num_workers=num_workers,\n        )\n    elif dataset_mode in (DatasetMode.VALIDATION, DatasetMode.TESTING):\n        return build_streaming_evaluation_dataset(\n            datapipes=datapipes, batch_size=batch_size\n        )\n    else:\n        raise NotImplementedError\n\n\ndef get_sequences(\n    path: Path, dataset_config: DictConfig, guarantee_labels: bool\n) -> List[SequenceForIter]:\n    assert path.is_dir()\n\n    ### extract settings from config ###\n    sequence_length = dataset_config.sequence_length\n    ev_representation_name = dataset_config.ev_repr_name\n    downsample_by_factor_2 = dataset_config.downsample_by_factor_2\n    if dataset_config.name == \"gen1\":\n        dataset_type = DatasetType.GEN1\n    elif dataset_config.name == \"gen4\":\n        dataset_type = DatasetType.GEN4\n    else:\n        raise NotImplementedError\n    ####################################\n    if guarantee_labels:\n        return SequenceForIter.get_sequences_with_guaranteed_labels(\n            path=path,\n            ev_representation_name=ev_representation_name,\n            sequence_length=sequence_length,\n            dataset_type=dataset_type,\n            downsample_by_factor_2=downsample_by_factor_2,\n        )\n    return [\n        SequenceForIter(\n            path=path,\n            ev_representation_name=ev_representation_name,\n            sequence_length=sequence_length,\n            dataset_type=dataset_type,\n            downsample_by_factor_2=downsample_by_factor_2,\n        )\n    ]\n\n\ndef partialclass(cls, *args, **kwargs):\n    class NewCls(cls):\n        __init__ = partialmethod(cls.__init__, *args, **kwargs)\n\n    return NewCls\n\n\ndef build_streaming_train_dataset(\n    datapipes: List[MapDataPipe],\n    dataset_config: DictConfig,\n    batch_size: int,\n    num_workers: int,\n) -> ConcatStreamingDataPipe:\n    assert len(datapipes) > 0\n    augmentation_datapipe_type = partialclass(\n        
RandAugmentIterDataPipe, dataset_config=dataset_config\n    )\n    streaming_dataset = ConcatStreamingDataPipe(\n        datapipe_list=datapipes,\n        batch_size=batch_size,\n        num_workers=num_workers,\n        augmentation_pipeline=augmentation_datapipe_type,\n        print_seed_debug=False,\n    )\n    return streaming_dataset\n\n\ndef build_streaming_evaluation_dataset(\n    datapipes: List[MapDataPipe], batch_size: int\n) -> ShardedStreamingDataPipe:\n    assert len(datapipes) > 0\n    fill_value = datapipes[0].get_fully_padded_sample()\n    streaming_dataset = ShardedStreamingDataPipe(\n        datapipe_list=datapipes, batch_size=batch_size, fill_value=fill_value\n    )\n    return streaming_dataset\n"
  },
  {
    "path": "RVT/data/genx_utils/labels.py",
    "content": "from __future__ import annotations\n\nfrom typing import List, Tuple, Union, Optional\n\nimport math\nimport numpy as np\nimport torch as th\nfrom einops import rearrange\nfrom torch.nn.functional import pad\n\n\nclass ObjectLabelBase:\n    _str2idx = {\n        \"t\": 0,\n        \"x\": 1,\n        \"y\": 2,\n        \"w\": 3,\n        \"h\": 4,\n        \"class_id\": 5,\n        \"class_confidence\": 6,\n    }\n\n    def __init__(self, object_labels: th.Tensor, input_size_hw: Tuple[int, int]):\n        assert isinstance(object_labels, th.Tensor)\n        assert object_labels.dtype in {th.float32, th.float64}\n        assert object_labels.ndim == 2\n        assert object_labels.shape[-1] == len(self._str2idx)\n        assert isinstance(input_size_hw, tuple)\n        assert len(input_size_hw) == 2\n\n        self.object_labels = object_labels\n        self._input_size_hw = input_size_hw\n        self._is_numpy = False\n\n    def clamp_to_frame_(self):\n        ht, wd = self.input_size_hw\n        x0 = th.clamp(self.x, min=0, max=wd - 1)\n        y0 = th.clamp(self.y, min=0, max=ht - 1)\n        x1 = th.clamp(self.x + self.w, min=0, max=wd - 1)\n        y1 = th.clamp(self.y + self.h, min=0, max=ht - 1)\n        w = x1 - x0\n        h = y1 - y0\n        assert th.all(w > 0)\n        assert th.all(h > 0)\n        self.x = x0\n        self.y = y0\n        self.w = w\n        self.h = h\n\n    def remove_flat_labels_(self):\n        keep = (self.w > 0) & (self.h > 0)\n        self.object_labels = self.object_labels[keep]\n\n    @classmethod\n    def create_empty(cls):\n        # This is useful to represent cases where no labels are available.\n        return ObjectLabelBase(\n            object_labels=th.empty((0, len(cls._str2idx))), input_size_hw=(0, 0)\n        )\n\n    def _assert_not_numpy(self):\n        assert (\n            not self._is_numpy\n        ), \"Labels have been converted to numpy. 
\\\n        Numpy is not supported for the intended operations.\"\n\n    def to(self, *args, **kwargs):\n        # This function executes torch.to on self tensors and returns self.\n        self._assert_not_numpy()\n        # This will be used by Pytorch Lightning to transfer to the relevant device\n        self.object_labels = self.object_labels.to(*args, **kwargs)\n        return self\n\n    def numpy_(self) -> None:\n        \"\"\"\n        In place conversion to numpy (detach + to cpu + to numpy).\n        Cannot be undone.\n        \"\"\"\n        self._is_numpy = True\n        self.object_labels = self.object_labels.detach().cpu().numpy()\n\n    @property\n    def input_size_hw(self) -> Tuple[int, int]:\n        return self._input_size_hw\n\n    @input_size_hw.setter\n    def input_size_hw(self, height_width: Tuple[int, int]):\n        assert isinstance(height_width, tuple)\n        assert len(height_width) == 2\n        assert height_width[0] > 0\n        assert height_width[1] > 0\n        self._input_size_hw = height_width\n\n    def get(self, request: str):\n        assert request in self._str2idx\n        return self.object_labels[:, self._str2idx[request]]\n\n    @property\n    def t(self):\n        return self.object_labels[:, self._str2idx[\"t\"]]\n\n    @property\n    def x(self):\n        return self.object_labels[:, self._str2idx[\"x\"]]\n\n    @x.setter\n    def x(self, value: Union[th.Tensor, np.ndarray]):\n        self.object_labels[:, self._str2idx[\"x\"]] = value\n\n    @property\n    def y(self):\n        return self.object_labels[:, self._str2idx[\"y\"]]\n\n    @y.setter\n    def y(self, value: Union[th.Tensor, np.ndarray]):\n        self.object_labels[:, self._str2idx[\"y\"]] = value\n\n    @property\n    def w(self):\n        return self.object_labels[:, self._str2idx[\"w\"]]\n\n    @w.setter\n    def w(self, value: Union[th.Tensor, np.ndarray]):\n        self.object_labels[:, self._str2idx[\"w\"]] = value\n\n    @property\n    def 
h(self):\n        return self.object_labels[:, self._str2idx[\"h\"]]\n\n    @h.setter\n    def h(self, value: Union[th.Tensor, np.ndarray]):\n        self.object_labels[:, self._str2idx[\"h\"]] = value\n\n    @property\n    def class_id(self):\n        return self.object_labels[:, self._str2idx[\"class_id\"]]\n\n    @property\n    def class_confidence(self):\n        return self.object_labels[:, self._str2idx[\"class_confidence\"]]\n\n    @property\n    def dtype(self):\n        return self.object_labels.dtype\n\n    @property\n    def device(self):\n        return self.object_labels.device\n\n\nclass ObjectLabelFactory(ObjectLabelBase):\n    def __init__(\n        self,\n        object_labels: th.Tensor,\n        objframe_idx_2_label_idx: th.Tensor,\n        input_size_hw: Tuple[int, int],\n        downsample_factor: Optional[float] = None,\n    ):\n        super().__init__(object_labels=object_labels, input_size_hw=input_size_hw)\n        assert objframe_idx_2_label_idx.dtype == th.int64\n        assert objframe_idx_2_label_idx.dim() == 1\n\n        self.objframe_idx_2_label_idx = objframe_idx_2_label_idx\n        self.downsample_factor = downsample_factor\n        if self.downsample_factor is not None:\n            assert self.downsample_factor > 1\n        self.clamp_to_frame_()\n\n    @staticmethod\n    def from_structured_array(\n        object_labels: np.ndarray,\n        objframe_idx_2_label_idx: np.ndarray,\n        input_size_hw: Tuple[int, int],\n        downsample_factor: Optional[float] = None,\n    ) -> ObjectLabelFactory:\n        np_labels = [\n            object_labels[key].astype(\"float32\") for key in ObjectLabels._str2idx.keys()\n        ]\n        np_labels = rearrange(np_labels, \"fields L -> L fields\")\n        torch_labels = th.from_numpy(np_labels)\n        objframe_idx_2_label_idx = th.from_numpy(\n            objframe_idx_2_label_idx.astype(\"int64\")\n        )\n        assert objframe_idx_2_label_idx.numel() == 
np.unique(object_labels[\"t\"]).size\n        return ObjectLabelFactory(\n            object_labels=torch_labels,\n            objframe_idx_2_label_idx=objframe_idx_2_label_idx,\n            input_size_hw=input_size_hw,\n            downsample_factor=downsample_factor,\n        )\n\n    def __len__(self):\n        return len(self.objframe_idx_2_label_idx)\n\n    def __getitem__(self, item: int) -> ObjectLabels:\n        assert item >= 0\n        length = len(self)\n        assert length > 0\n        assert item < length\n        is_last_item = item == length - 1\n\n        from_idx = self.objframe_idx_2_label_idx[item]\n        to_idx = (\n            self.object_labels.shape[0]\n            if is_last_item\n            else self.objframe_idx_2_label_idx[item + 1]\n        )\n        assert to_idx > from_idx\n        object_labels = ObjectLabels(\n            object_labels=self.object_labels[from_idx:to_idx].clone(),\n            input_size_hw=self.input_size_hw,\n        )\n        if self.downsample_factor is not None:\n            object_labels.scale_(scaling_multiplier=1 / self.downsample_factor)\n        return object_labels\n\n\nclass ObjectLabels(ObjectLabelBase):\n    def __init__(self, object_labels: th.Tensor, input_size_hw: Tuple[int, int]):\n        super().__init__(object_labels=object_labels, input_size_hw=input_size_hw)\n\n    def __len__(self) -> int:\n        return self.object_labels.shape[0]\n\n    def rotate_(self, angle_deg: float):\n        if len(self) == 0:\n            return\n        # (x0,y0)---(x1,y0)   p00---p10\n        #  |             |    |       |\n        #  |             |    |       |\n        # (x0,y1)---(x1,y1)   p01---p11\n        p00 = th.stack((self.x, self.y), dim=1)\n        p10 = th.stack((self.x + self.w, self.y), dim=1)\n        p01 = th.stack((self.x, self.y + self.h), dim=1)\n        p11 = th.stack((self.x + self.w, self.y + self.h), dim=1)\n        # points: 4 x N x 2\n        points = th.stack((p00, p10, p01, p11), 
dim=0)\n\n        cx = self._input_size_hw[1] // 2\n        cy = self._input_size_hw[0] // 2\n        center = th.tensor([cx, cy], device=self.device)\n\n        angle_rad = angle_deg / 180 * math.pi\n        # counter-clockwise rotation\n        rot_matrix = th.tensor(\n            [\n                [math.cos(angle_rad), math.sin(angle_rad)],\n                [-math.sin(angle_rad), math.cos(angle_rad)],\n            ],\n            device=self.device,\n        )\n\n        points = points - center\n        points = th.einsum(\"ij,pnj->pni\", rot_matrix, points)\n        points = points + center\n\n        height, width = self.input_size_hw\n        x0 = th.clamp(th.min(points[..., 0], dim=0)[0], min=0, max=width - 1)\n        y0 = th.clamp(th.min(points[..., 1], dim=0)[0], min=0, max=height - 1)\n        x1 = th.clamp(th.max(points[..., 0], dim=0)[0], min=0, max=width - 1)\n        y1 = th.clamp(th.max(points[..., 1], dim=0)[0], min=0, max=height - 1)\n\n        self.x = x0\n        self.y = y0\n        self.w = x1 - x0\n        self.h = y1 - y0\n\n        self.remove_flat_labels_()\n\n        assert th.all(self.x >= 0)\n        assert th.all(self.y >= 0)\n        assert th.all(self.x + self.w <= self.input_size_hw[1] - 1)\n        assert th.all(self.y + self.h <= self.input_size_hw[0] - 1)\n\n    def zoom_in_and_rescale_(\n        self, zoom_coordinates_x0y0: Tuple[int, int], zoom_in_factor: float\n    ):\n        \"\"\"\n        1) Computes a new smaller canvas size: original canvas scaled by a factor of 1/zoom_in_factor (downscaling)\n        2) Places the smaller canvas inside the original canvas at the top-left coordinates zoom_coordinates_x0y0\n        3) Extract the smaller canvas and rescale it back to the original resolution\n        \"\"\"\n        if len(self) == 0:\n            return\n        assert len(zoom_coordinates_x0y0) == 2\n        assert zoom_in_factor >= 1\n        if zoom_in_factor == 1:\n            return\n        z_x0, z_y0 = 
zoom_coordinates_x0y0\n        h_orig, w_orig = self.input_size_hw\n        assert 0 <= z_x0 <= w_orig - 1\n        assert 0 <= z_y0 <= h_orig - 1\n        zoom_window_h, zoom_window_w = tuple(\n            x / zoom_in_factor for x in self.input_size_hw\n        )\n        z_x1 = min(z_x0 + zoom_window_w, w_orig - 1)\n        assert z_x1 <= w_orig - 1, f\"{z_x1=} is larger than {w_orig-1=}\"\n        z_y1 = min(z_y0 + zoom_window_h, h_orig - 1)\n        assert z_y1 <= h_orig - 1, f\"{z_y1=} is larger than {h_orig-1=}\"\n\n        x0 = th.clamp(self.x, min=z_x0, max=z_x1 - 1)\n        y0 = th.clamp(self.y, min=z_y0, max=z_y1 - 1)\n\n        x1 = th.clamp(self.x + self.w, min=z_x0, max=z_x1 - 1)\n        y1 = th.clamp(self.y + self.h, min=z_y0, max=z_y1 - 1)\n\n        self.x = x0 - z_x0\n        self.y = y0 - z_y0\n        self.w = x1 - x0\n        self.h = y1 - y0\n        self.input_size_hw = (zoom_window_h, zoom_window_w)\n\n        self.remove_flat_labels_()\n\n        self.scale_(scaling_multiplier=zoom_in_factor)\n\n    def zoom_out_and_rescale_(\n        self, zoom_coordinates_x0y0: Tuple[int, int], zoom_out_factor: float\n    ):\n        \"\"\"\n        1) Scales the input by a factor of 1/zoom_out_factor (i.e. 
reduces the canvas size)\n        2) Places the downscaled canvas into the original canvas at the top-left coordinates zoom_coordinates_x0y0\n        \"\"\"\n        if len(self) == 0:\n            return\n        assert len(zoom_coordinates_x0y0) == 2\n        assert zoom_out_factor >= 1\n        if zoom_out_factor == 1:\n            return\n\n        h_orig, w_orig = self.input_size_hw\n        self.scale_(scaling_multiplier=1 / zoom_out_factor)\n\n        self.input_size_hw = (h_orig, w_orig)\n        z_x0, z_y0 = zoom_coordinates_x0y0\n        assert 0 <= z_x0 <= w_orig - 1\n        assert 0 <= z_y0 <= h_orig - 1\n\n        self.x = self.x + z_x0\n        self.y = self.y + z_y0\n\n    def scale_(self, scaling_multiplier: float):\n        if len(self) == 0:\n            return\n        assert scaling_multiplier > 0\n        if scaling_multiplier == 1:\n            return\n        img_ht, img_wd = self.input_size_hw\n        new_img_ht = scaling_multiplier * img_ht\n        new_img_wd = scaling_multiplier * img_wd\n        self.input_size_hw = (new_img_ht, new_img_wd)\n        x1 = th.clamp((self.x + self.w) * scaling_multiplier, max=new_img_wd - 1)\n        y1 = th.clamp((self.y + self.h) * scaling_multiplier, max=new_img_ht - 1)\n        self.x = self.x * scaling_multiplier\n        self.y = self.y * scaling_multiplier\n\n        self.w = x1 - self.x\n        self.h = y1 - self.y\n\n        self.remove_flat_labels_()\n\n    def flip_lr_(self) -> None:\n        if len(self) == 0:\n            return\n        self.x = self.input_size_hw[1] - 1 - self.x - self.w\n\n    def get_labels_as_tensors(self, format_: str = \"yolox\") -> th.Tensor:\n        self._assert_not_numpy()\n\n        if format_ == \"yolox\":\n            out = th.zeros((len(self), 5), dtype=th.float32, device=self.device)\n            if len(self) == 0:\n                return out\n            out[:, 0] = self.class_id\n            out[:, 1] = self.x + 0.5 * self.w\n            out[:, 2] = self.y 
+ 0.5 * self.h\n            out[:, 3] = self.w\n            out[:, 4] = self.h\n            return out\n        else:\n            raise NotImplementedError\n\n    @staticmethod\n    def get_labels_as_batched_tensor(\n        obj_label_list: List[ObjectLabels], format_: str = \"yolox\"\n    ) -> th.Tensor:\n        num_object_frames = len(obj_label_list)\n        assert num_object_frames > 0\n        max_num_labels_per_object_frame = max([len(x) for x in obj_label_list])\n        assert max_num_labels_per_object_frame > 0\n\n        if format_ == \"yolox\":\n            tensor_labels = []\n            for labels in obj_label_list:\n                obj_labels_tensor = labels.get_labels_as_tensors(format_=format_)\n                num_to_pad = max_num_labels_per_object_frame - len(labels)\n                padded_labels = pad(\n                    obj_labels_tensor, (0, 0, 0, num_to_pad), mode=\"constant\", value=0\n                )\n                tensor_labels.append(padded_labels)\n            tensor_labels = th.stack(tensors=tensor_labels, dim=0)\n            return tensor_labels\n        else:\n            raise NotImplementedError\n\n\nclass SparselyBatchedObjectLabels:\n    def __init__(self, sparse_object_labels_batch: List[Optional[ObjectLabels]]):\n        # Can contain None elements that indicate missing labels.\n        for entry in sparse_object_labels_batch:\n            assert isinstance(entry, ObjectLabels) or entry is None\n        self.sparse_object_labels_batch = sparse_object_labels_batch\n        self.set_empty_labels_to_none_()\n\n    def __len__(self) -> int:\n        return len(self.sparse_object_labels_batch)\n\n    def __iter__(self):\n        return iter(self.sparse_object_labels_batch)\n\n    def __getitem__(self, item: int) -> Optional[ObjectLabels]:\n        if item < 0 or item >= len(self):\n            raise IndexError(f\"Index ({item}) out of range (0, {len(self) - 1})\")\n        return self.sparse_object_labels_batch[item]\n\n    
def __add__(self, other: SparselyBatchedObjectLabels):\n        sparse_object_labels_batch = (\n            self.sparse_object_labels_batch + other.sparse_object_labels_batch\n        )\n        return SparselyBatchedObjectLabels(\n            sparse_object_labels_batch=sparse_object_labels_batch\n        )\n\n    def set_empty_labels_to_none_(self):\n        for idx, obj_label in enumerate(self.sparse_object_labels_batch):\n            if obj_label is not None and len(obj_label) == 0:\n                self.sparse_object_labels_batch[idx] = None\n\n    @property\n    def input_size_hw(self) -> Optional[Union[Tuple[int, int], Tuple[float, float]]]:\n        for obj_labels in self.sparse_object_labels_batch:\n            if obj_labels is not None:\n                return obj_labels.input_size_hw\n        return None\n\n    def zoom_in_and_rescale_(self, *args, **kwargs):\n        for idx, entry in enumerate(self.sparse_object_labels_batch):\n            if entry is not None:\n                self.sparse_object_labels_batch[idx].zoom_in_and_rescale_(\n                    *args, **kwargs\n                )\n        # We may have deleted labels. 
If no labels are left, set the object to None\n        self.set_empty_labels_to_none_()\n\n    def zoom_out_and_rescale_(self, *args, **kwargs):\n        for idx, entry in enumerate(self.sparse_object_labels_batch):\n            if entry is not None:\n                self.sparse_object_labels_batch[idx].zoom_out_and_rescale_(\n                    *args, **kwargs\n                )\n\n    def rotate_(self, *args, **kwargs):\n        for idx, entry in enumerate(self.sparse_object_labels_batch):\n            if entry is not None:\n                self.sparse_object_labels_batch[idx].rotate_(*args, **kwargs)\n\n    def scale_(self, *args, **kwargs):\n        for idx, entry in enumerate(self.sparse_object_labels_batch):\n            if entry is not None:\n                self.sparse_object_labels_batch[idx].scale_(*args, **kwargs)\n        # We may have deleted labels. If no labels are left, set the object to None\n        self.set_empty_labels_to_none_()\n\n    def flip_lr_(self):\n        for idx, entry in enumerate(self.sparse_object_labels_batch):\n            if entry is not None:\n                self.sparse_object_labels_batch[idx].flip_lr_()\n\n    def to(self, *args, **kwargs):\n        for idx, entry in enumerate(self.sparse_object_labels_batch):\n            if entry is not None:\n                self.sparse_object_labels_batch[idx].to(*args, **kwargs)\n        return self\n\n    def get_valid_labels_and_batch_indices(\n        self,\n    ) -> Tuple[List[ObjectLabels], List[int]]:\n        out = list()\n        valid_indices = list()\n        for idx, label in enumerate(self.sparse_object_labels_batch):\n            if label is not None:\n                out.append(label)\n                valid_indices.append(idx)\n        return out, valid_indices\n\n    @staticmethod\n    def transpose_list(\n        list_of_sparsely_batched_object_labels: List[SparselyBatchedObjectLabels],\n    ) -> List[SparselyBatchedObjectLabels]:\n        return [\n            
SparselyBatchedObjectLabels(list(labels_as_tuple))\n            for labels_as_tuple in zip(*list_of_sparsely_batched_object_labels)\n        ]\n"
  },
  {
    "path": "RVT/data/genx_utils/sequence_base.py",
    "content": "from pathlib import Path\nfrom typing import Any, List, Optional\n\nimport h5py\nimport numpy as np\nimport torch\nfrom torchdata.datapipes.map import MapDataPipe\n\nfrom data.genx_utils.labels import ObjectLabelFactory, ObjectLabels\nfrom data.utils.spatial import get_original_hw\nfrom data.utils.types import DatasetType\nfrom utils.timers import TimerDummy as Timer\n\n\ndef get_event_representation_dir(path: Path, ev_representation_name: str) -> Path:\n    ev_repr_dir = path / \"event_representations_v2\" / ev_representation_name\n    assert ev_repr_dir.is_dir(), f\"{ev_repr_dir}\"\n    return ev_repr_dir\n\n\ndef get_objframe_idx_2_repr_idx(path: Path, ev_representation_name: str) -> np.ndarray:\n    ev_repr_dir = get_event_representation_dir(\n        path=path, ev_representation_name=ev_representation_name\n    )\n    objframe_idx_2_repr_idx = np.load(str(ev_repr_dir / \"objframe_idx_2_repr_idx.npy\"))\n    return objframe_idx_2_repr_idx\n\n\nclass SequenceBase(MapDataPipe):\n    \"\"\"\n    Structure example of a sequence:\n    .\n    ├── event_representations_v2\n    │ └── ev_representation_name\n    │     ├── event_representations.h5\n    │     ├── objframe_idx_2_repr_idx.npy\n    │     └── timestamps_us.npy\n    └── labels_v2\n        ├── labels.npz\n        └── timestamps_us.npy\n    \"\"\"\n\n    def __init__(\n        self,\n        path: Path,\n        ev_representation_name: str,\n        sequence_length: int,\n        dataset_type: DatasetType,\n        downsample_by_factor_2: bool,\n        only_load_end_labels: bool,\n    ):\n        assert sequence_length >= 1\n        assert path.is_dir()\n        assert dataset_type in {\n            DatasetType.GEN1,\n            DatasetType.GEN4,\n        }, f\"{dataset_type} not implemented\"\n\n        self.only_load_end_labels = only_load_end_labels\n\n        ev_repr_dir = get_event_representation_dir(\n            path=path, ev_representation_name=ev_representation_name\n        )\n\n      
  labels_dir = path / \"labels_v2\"\n        assert labels_dir.is_dir()\n\n        height, width = get_original_hw(dataset_type)\n        self.seq_len = sequence_length\n\n        ds_factor_str = \"_ds2_nearest\" if downsample_by_factor_2 else \"\"\n        self.ev_repr_file = ev_repr_dir / f\"event_representations{ds_factor_str}.h5\"\n        assert self.ev_repr_file.exists(), f\"{str(self.ev_repr_file)=}\"\n\n        with Timer(timer_name=\"prepare labels\"):\n            label_data = np.load(str(labels_dir / \"labels.npz\"))\n            objframe_idx_2_label_idx = label_data[\"objframe_idx_2_label_idx\"]\n            labels = label_data[\"labels\"]\n            label_factory = ObjectLabelFactory.from_structured_array(\n                object_labels=labels,\n                objframe_idx_2_label_idx=objframe_idx_2_label_idx,\n                input_size_hw=(height, width),\n                downsample_factor=2 if downsample_by_factor_2 else None,\n            )\n            self.label_factory = label_factory\n\n        with Timer(timer_name=\"load objframe_idx_2_repr_idx\"):\n            self.objframe_idx_2_repr_idx = get_objframe_idx_2_repr_idx(\n                path=path, ev_representation_name=ev_representation_name\n            )\n        with Timer(timer_name=\"construct repr_idx_2_objframe_idx\"):\n            self.repr_idx_2_objframe_idx = dict(\n                zip(\n                    self.objframe_idx_2_repr_idx,\n                    range(len(self.objframe_idx_2_repr_idx)),\n                )\n            )\n\n    def _get_labels_from_repr_idx(self, repr_idx: int) -> Optional[ObjectLabels]:\n        objframe_idx = self.repr_idx_2_objframe_idx.get(repr_idx, None)\n        return None if objframe_idx is None else self.label_factory[objframe_idx]\n\n    def _get_event_repr_torch(self, start_idx: int, end_idx: int) -> List[torch.Tensor]:\n        assert end_idx > start_idx\n        with h5py.File(str(self.ev_repr_file), \"r\") as h5f:\n            ev_repr = 
h5f[\"data\"][start_idx:end_idx]\n        ev_repr = torch.from_numpy(ev_repr)\n        if ev_repr.dtype != torch.uint8:\n            ev_repr = torch.asarray(ev_repr, dtype=torch.float32)\n        ev_repr = torch.split(ev_repr, 1, dim=0)\n        # remove first dim that is always 1 due to how torch.split works\n        ev_repr = [x[0] for x in ev_repr]\n        return ev_repr\n\n    def __len__(self) -> int:\n        raise NotImplementedError\n\n    def __getitem__(self, index: int) -> Any:\n        raise NotImplementedError\n"
  },
  {
    "path": "RVT/data/genx_utils/sequence_for_streaming.py",
    "content": "from pathlib import Path\nfrom typing import List, Optional, Union, Tuple\n\nimport h5py\nimport numpy as np\nimport torch\nfrom omegaconf import DictConfig\nfrom torchdata.datapipes.iter import IterDataPipe\n\nfrom data.genx_utils.labels import SparselyBatchedObjectLabels\nfrom data.genx_utils.sequence_base import SequenceBase, get_objframe_idx_2_repr_idx\nfrom data.utils.augmentor import RandomSpatialAugmentorGenX\nfrom data.utils.types import DataType, DatasetType, LoaderDataDictGenX\nfrom utils.timers import TimerDummy as Timer\n\n\ndef _scalar_as_1d_array(scalar: Union[int, float]):\n    return np.atleast_1d(scalar)\n\n\ndef _get_ev_repr_range_indices(\n    indices: np.ndarray, max_len: int\n) -> List[Tuple[int, int]]:\n    \"\"\"\n    Computes a list of index ranges based on the input array of indices and a maximum length.\n    The index ranges are computed such that the difference between consecutive indices\n    should not exceed the maximum length (max_len).\n\n    Parameters:\n    -----------\n    indices : np.ndarray\n        A NumPy array of indices, where the indices are sorted in ascending order.\n    max_len : int\n        The maximum allowed length between consecutive indices.\n\n    Returns:\n    --------\n    out : List[Tuple[int, int]]\n        A list of tuples, where each tuple contains two integers representing the start and\n        stop indices of the range.\n    \"\"\"\n    meta_indices_stop = np.flatnonzero(np.diff(indices) > max_len)\n\n    meta_indices_start = np.concatenate((np.atleast_1d(0), meta_indices_stop + 1))\n    meta_indices_stop = np.concatenate(\n        (meta_indices_stop, np.atleast_1d(len(indices) - 1))\n    )\n\n    out = list()\n    for meta_idx_start, meta_idx_stop in zip(meta_indices_start, meta_indices_stop):\n        idx_start = max(indices[meta_idx_start] - max_len + 1, 0)\n        idx_stop = indices[meta_idx_stop] + 1\n        out.append((idx_start, idx_stop))\n    return out\n\n\nclass 
SequenceForIter(SequenceBase):\n    def __init__(\n        self,\n        path: Path,\n        ev_representation_name: str,\n        sequence_length: int,\n        dataset_type: DatasetType,\n        downsample_by_factor_2: bool,\n        range_indices: Optional[Tuple[int, int]] = None,\n    ):\n        super().__init__(\n            path=path,\n            ev_representation_name=ev_representation_name,\n            sequence_length=sequence_length,\n            dataset_type=dataset_type,\n            downsample_by_factor_2=downsample_by_factor_2,\n            only_load_end_labels=False,\n        )\n\n        with h5py.File(str(self.ev_repr_file), \"r\") as h5f:\n            num_ev_repr = h5f[\"data\"].shape[0]\n        if range_indices is None:\n            repr_idx_start = max(\n                self.objframe_idx_2_repr_idx[0] - sequence_length + 1, 0\n            )\n            repr_idx_stop = num_ev_repr\n        else:\n            repr_idx_start, repr_idx_stop = range_indices\n        # Set start idx such that the first label is no further than the last timestamp of the first sample sub-sequence\n        min_start_repr_idx = max(\n            self.objframe_idx_2_repr_idx[0] - sequence_length + 1, 0\n        )\n        assert (\n            0 <= min_start_repr_idx <= repr_idx_start < repr_idx_stop <= num_ev_repr\n        ), f\"{min_start_repr_idx=}, {repr_idx_start=}, {repr_idx_stop=}, {num_ev_repr=}, {path=}\"\n\n        self.start_indices = list(range(repr_idx_start, repr_idx_stop, sequence_length))\n        self.stop_indices = self.start_indices[1:] + [repr_idx_stop]\n        self.length = len(self.start_indices)\n\n        self._padding_representation = None\n\n    @staticmethod\n    def get_sequences_with_guaranteed_labels(\n        path: Path,\n        ev_representation_name: str,\n        sequence_length: int,\n        dataset_type: DatasetType,\n        downsample_by_factor_2: bool,\n    ) -> List[\"SequenceForIter\"]:\n        \"\"\"Generate sequences 
such that we always have labels within each sample of the sequence.\n        This is required for training, so that we are guaranteed to always have labels in the training step.\n        However, for validation we don't require this if we catch the special case.\n        \"\"\"\n        objframe_idx_2_repr_idx = get_objframe_idx_2_repr_idx(\n            path=path, ev_representation_name=ev_representation_name\n        )\n        # max diff for repr idx is sequence length\n        range_indices_list = _get_ev_repr_range_indices(\n            indices=objframe_idx_2_repr_idx, max_len=sequence_length\n        )\n        sequence_list = list()\n        for range_indices in range_indices_list:\n            sequence_list.append(\n                SequenceForIter(\n                    path=path,\n                    ev_representation_name=ev_representation_name,\n                    sequence_length=sequence_length,\n                    dataset_type=dataset_type,\n                    downsample_by_factor_2=downsample_by_factor_2,\n                    range_indices=range_indices,\n                )\n            )\n        return sequence_list\n\n    @property\n    def padding_representation(self) -> torch.Tensor:\n        if self._padding_representation is None:\n            ev_repr = self._get_event_repr_torch(start_idx=0, end_idx=1)[0]\n            self._padding_representation = torch.zeros_like(ev_repr)\n        return self._padding_representation\n\n    def get_fully_padded_sample(self) -> LoaderDataDictGenX:\n        is_first_sample = False\n        is_padded_mask = [True] * self.seq_len\n        ev_repr = [self.padding_representation] * self.seq_len\n        labels = [None] * self.seq_len\n        sparse_labels = SparselyBatchedObjectLabels(sparse_object_labels_batch=labels)\n        out = {\n            DataType.EV_REPR: ev_repr,\n            DataType.OBJLABELS_SEQ: sparse_labels,\n            DataType.IS_FIRST_SAMPLE: is_first_sample,\n            
DataType.IS_PADDED_MASK: is_padded_mask,\n        }\n        return out\n\n    def __len__(self):\n        return self.length\n\n    def __getitem__(self, index: int) -> LoaderDataDictGenX:\n        start_idx = self.start_indices[index]\n        end_idx = self.stop_indices[index]\n\n        # sequence info ###\n        sample_len = end_idx - start_idx\n        assert self.seq_len >= sample_len > 0, (\n            f\"{self.seq_len=}, {sample_len=}, {start_idx=}, {end_idx=}, \"\n            f\"\\n{self.start_indices=}\\n{self.stop_indices=}\"\n        )\n\n        is_first_sample = True if index == 0 else False\n        is_padded_mask = [False] * sample_len\n        ###################\n\n        # event representations ###\n        with Timer(timer_name=\"read ev reprs\"):\n            ev_repr = self._get_event_repr_torch(start_idx=start_idx, end_idx=end_idx)\n        assert len(ev_repr) == sample_len\n        ###########################\n\n        # labels ###\n        labels = list()\n        for repr_idx in range(start_idx, end_idx):\n            labels.append(self._get_labels_from_repr_idx(repr_idx))\n        assert len(labels) == len(ev_repr)\n        ############\n\n        # apply padding (if necessary) ###\n        if sample_len < self.seq_len:\n            padding_len = self.seq_len - sample_len\n\n            is_padded_mask.extend([True] * padding_len)\n            ev_repr.extend([self.padding_representation] * padding_len)\n            labels.extend([None] * padding_len)\n        ##################################\n\n        # convert labels to sparse labels for datapipes and dataloader\n        sparse_labels = SparselyBatchedObjectLabels(sparse_object_labels_batch=labels)\n\n        out = {\n            DataType.EV_REPR: ev_repr,\n            DataType.OBJLABELS_SEQ: sparse_labels,\n            DataType.IS_FIRST_SAMPLE: is_first_sample,\n            DataType.IS_PADDED_MASK: is_padded_mask,\n        }\n        return out\n\n\nclass 
RandAugmentIterDataPipe(IterDataPipe):\n    def __init__(self, source_dp: IterDataPipe, dataset_config: DictConfig):\n        super().__init__()\n        self.source_dp = source_dp\n\n        resolution_hw = tuple(dataset_config.resolution_hw)\n        assert len(resolution_hw) == 2\n        ds_by_factor_2 = dataset_config.downsample_by_factor_2\n        if ds_by_factor_2:\n            resolution_hw = tuple(x // 2 for x in resolution_hw)\n\n        augm_config = dataset_config.data_augmentation\n        self.spatial_augmentor = RandomSpatialAugmentorGenX(\n            dataset_hw=resolution_hw,\n            automatic_randomization=False,\n            augm_config=augm_config.stream,\n        )\n\n    def __iter__(self):\n        self.spatial_augmentor.randomize_augmentation()\n        for x in self.source_dp:\n            yield self.spatial_augmentor(x)\n"
  },
  {
    "path": "RVT/data/genx_utils/sequence_rnd.py",
    "content": "from pathlib import Path\n\nfrom data.genx_utils.labels import SparselyBatchedObjectLabels\nfrom data.genx_utils.sequence_base import SequenceBase\nfrom data.utils.types import DataType, DatasetType, LoaderDataDictGenX\nfrom utils.timers import TimerDummy as Timer\n\n\nclass SequenceForRandomAccess(SequenceBase):\n    def __init__(\n        self,\n        path: Path,\n        ev_representation_name: str,\n        sequence_length: int,\n        dataset_type: DatasetType,\n        downsample_by_factor_2: bool,\n        only_load_end_labels: bool,\n    ):\n        super().__init__(\n            path=path,\n            ev_representation_name=ev_representation_name,\n            sequence_length=sequence_length,\n            dataset_type=dataset_type,\n            downsample_by_factor_2=downsample_by_factor_2,\n            only_load_end_labels=only_load_end_labels,\n        )\n\n        self.start_idx_offset = None\n        for objframe_idx, repr_idx in enumerate(self.objframe_idx_2_repr_idx):\n            if repr_idx - self.seq_len + 1 >= 0:\n                # We can fit the sequence length to the label\n                self.start_idx_offset = objframe_idx\n                break\n        if self.start_idx_offset is None:\n            # This leads to actual length of 0:\n            self.start_idx_offset = len(self.label_factory)\n\n        self.length = len(self.label_factory) - self.start_idx_offset\n        assert len(self.label_factory) == len(self.objframe_idx_2_repr_idx)\n\n        # Useful for weighted sampler that is based on label statistics:\n        self._only_load_labels = False\n\n    def __len__(self):\n        return self.length\n\n    def __getitem__(self, index: int) -> LoaderDataDictGenX:\n        corrected_idx = index + self.start_idx_offset\n        labels_repr_idx = self.objframe_idx_2_repr_idx[corrected_idx]\n\n        end_idx = labels_repr_idx + 1\n        start_idx = end_idx - self.seq_len\n        assert_msg = (\n            
f\"{self.ev_repr_file=}, {self.start_idx_offset=}, {start_idx=}, {end_idx=}\"\n        )\n        assert start_idx >= 0, assert_msg\n\n        labels = list()\n        for repr_idx in range(start_idx, end_idx):\n            if self.only_load_end_labels and repr_idx < end_idx - 1:\n                labels.append(None)\n            else:\n                labels.append(self._get_labels_from_repr_idx(repr_idx))\n        sparse_labels = SparselyBatchedObjectLabels(sparse_object_labels_batch=labels)\n        if self._only_load_labels:\n            return {DataType.OBJLABELS_SEQ: sparse_labels}\n\n        with Timer(timer_name=\"read ev reprs\"):\n            ev_repr = self._get_event_repr_torch(start_idx=start_idx, end_idx=end_idx)\n        assert len(sparse_labels) == len(ev_repr)\n\n        is_first_sample = True  # Due to random loading\n        is_padded_mask = [False] * len(ev_repr)\n\n        out = {\n            DataType.EV_REPR: ev_repr,\n            DataType.OBJLABELS_SEQ: sparse_labels,\n            DataType.IS_FIRST_SAMPLE: is_first_sample,\n            DataType.IS_PADDED_MASK: is_padded_mask,\n        }\n        return out\n\n    def is_only_loading_labels(self) -> bool:\n        return self._only_load_labels\n\n    def only_load_labels(self):\n        self._only_load_labels = True\n\n    def load_everything(self):\n        self._only_load_labels = False\n"
  },
  {
    "path": "RVT/data/utils/augmentor.py",
    "content": "import collections.abc as abc\nfrom dataclasses import dataclass\nfrom typing import Any, Optional, Tuple, Union\nfrom warnings import filterwarnings, warn\n\nimport torch as th\nimport torch.distributions.categorical\nfrom omegaconf import DictConfig\nfrom torch.nn.functional import interpolate\nfrom torchvision.transforms import InterpolationMode\nfrom torchvision.transforms.functional import rotate\n\nfrom data.genx_utils.labels import ObjectLabels, SparselyBatchedObjectLabels\nfrom data.utils.types import DataType, LoaderDataDictGenX\nfrom utils.helpers import torch_uniform_sample_scalar\n\nNO_LABEL_WARN_MSG = (\n    \"No Labels found. This can lead to a crash and should not happen often.\"\n)\nfilterwarnings(\"always\", message=NO_LABEL_WARN_MSG)\n\n\n@dataclass\nclass ZoomOutState:\n    active: bool\n    x0: int\n    y0: int\n    zoom_out_factor: float\n\n\n@dataclass\nclass RotationState:\n    active: bool\n    angle_deg: float\n\n\n@dataclass\nclass AugmentationState:\n    apply_h_flip: bool\n    rotation: RotationState\n    apply_zoom_in: bool\n    zoom_out: ZoomOutState\n\n\nclass RandomSpatialAugmentorGenX:\n    def __init__(\n        self,\n        dataset_hw: Tuple[int, int],\n        automatic_randomization: bool,\n        augm_config: DictConfig,\n    ):\n        assert isinstance(dataset_hw, tuple)\n        assert len(dataset_hw) == 2\n        assert all(x > 0 for x in dataset_hw)\n        assert isinstance(automatic_randomization, bool)\n\n        self.hw_tuple = dataset_hw\n        self.automatic_randomization = automatic_randomization\n        self.h_flip_prob = augm_config.prob_hflip\n        self.rot_prob = augm_config.rotate.prob\n        self.rot_min_angle_deg = augm_config.rotate.get(\"min_angle_deg\", 0)\n        self.rot_max_angle_deg = augm_config.rotate.max_angle_deg\n        self.zoom_prob = augm_config.zoom.prob\n        zoom_out_weight = augm_config.zoom.zoom_out.get(\"weight\", 1)\n        self.min_zoom_out_factor = 
augm_config.zoom.zoom_out.factor.min\n        self.max_zoom_out_factor = augm_config.zoom.zoom_out.factor.max\n        has_zoom_in = \"zoom_in\" in augm_config.zoom\n        zoom_in_weight = augm_config.zoom.zoom_in.weight if has_zoom_in else 0\n        self.min_zoom_in_factor = (\n            augm_config.zoom.zoom_in.factor.min if has_zoom_in else 1\n        )\n        self.max_zoom_in_factor = (\n            augm_config.zoom.zoom_in.factor.max if has_zoom_in else 1\n        )\n\n        assert 0 <= self.h_flip_prob <= 1\n        assert 0 <= self.rot_prob <= 1\n        assert 0 <= self.rot_min_angle_deg <= self.rot_max_angle_deg\n        assert 0 <= self.zoom_prob <= 1\n        assert 0 <= zoom_in_weight\n        assert self.max_zoom_in_factor >= self.min_zoom_in_factor >= 1\n        assert 0 <= zoom_out_weight\n        assert self.max_zoom_out_factor >= self.min_zoom_out_factor >= 1\n        if not automatic_randomization:\n            # We are probably applying augmentation to a streaming dataset for which zoom in augm is not supported.\n            assert zoom_in_weight == 0, f\"{zoom_in_weight=}\"\n\n        self.zoom_in_or_out_distribution = torch.distributions.categorical.Categorical(\n            probs=th.tensor([zoom_in_weight, zoom_out_weight])\n        )\n\n        self.augm_state = AugmentationState(\n            apply_h_flip=False,\n            rotation=RotationState(active=False, angle_deg=0.0),\n            apply_zoom_in=False,\n            zoom_out=ZoomOutState(active=False, x0=0, y0=0, zoom_out_factor=1.0),\n        )\n\n    def randomize_augmentation(self):\n        \"\"\"Sample new augmentation parameters that will be consistently applied among the items.\n\n        This function only works with augmentations that are input-independent.\n        E.g. 
The zoom-in augmentation parameters depend on the labels and cannot be sampled in this function.\n        For the same reason, it is not a very reasonable augmentation for the streaming scenario.\n        \"\"\"\n        self.augm_state.apply_h_flip = self.h_flip_prob > th.rand(1).item()\n\n        self.augm_state.rotation.active = self.rot_prob > th.rand(1).item()\n        if self.augm_state.rotation.active:\n            sign = 1 if th.randn(1).item() >= 0 else -1\n            self.augm_state.rotation.angle_deg = sign * torch_uniform_sample_scalar(\n                min_value=self.rot_min_angle_deg, max_value=self.rot_max_angle_deg\n            )\n\n        # Zoom-in and zoom-out are mutually exclusive.\n        do_zoom = self.zoom_prob > th.rand(1).item()\n        do_zoom_in = self.zoom_in_or_out_distribution.sample().item() == 0\n        do_zoom_out = not do_zoom_in\n        do_zoom_in &= do_zoom\n        do_zoom_out &= do_zoom\n        self.augm_state.apply_zoom_in = do_zoom_in\n        self.augm_state.zoom_out.active = do_zoom_out\n        if do_zoom_out:\n            rand_zoom_out_factor = torch_uniform_sample_scalar(\n                min_value=self.min_zoom_out_factor, max_value=self.max_zoom_out_factor\n            )\n            height, width = self.hw_tuple\n            zoom_window_h, zoom_window_w = int(height / rand_zoom_out_factor), int(\n                width / rand_zoom_out_factor\n            )\n            x0_sampled = int(\n                torch_uniform_sample_scalar(\n                    min_value=0, max_value=width - zoom_window_w\n                )\n            )\n            y0_sampled = int(\n                torch_uniform_sample_scalar(\n                    min_value=0, max_value=height - zoom_window_h\n                )\n            )\n            self.augm_state.zoom_out.x0 = x0_sampled\n            self.augm_state.zoom_out.y0 = y0_sampled\n            self.augm_state.zoom_out.zoom_out_factor = rand_zoom_out_factor\n\n    def 
_zoom_out_and_rescale(\n        self, data_dict: LoaderDataDictGenX\n    ) -> LoaderDataDictGenX:\n        zoom_out_state = self.augm_state.zoom_out\n\n        zoom_out_factor = zoom_out_state.zoom_out_factor\n        if zoom_out_factor == 1:\n            return data_dict\n        return {\n            k: RandomSpatialAugmentorGenX._zoom_out_and_rescale_recursive(\n                v,\n                zoom_coordinates_x0y0=(zoom_out_state.x0, zoom_out_state.y0),\n                zoom_out_factor=zoom_out_factor,\n                datatype=k,\n            )\n            for k, v in data_dict.items()\n        }\n\n    @staticmethod\n    def _zoom_out_and_rescale_tensor(\n        input_: th.Tensor,\n        zoom_coordinates_x0y0: Tuple[int, int],\n        zoom_out_factor: float,\n        datatype: DataType,\n    ) -> th.Tensor:\n        assert len(zoom_coordinates_x0y0) == 2\n        assert isinstance(input_, th.Tensor)\n\n        if datatype == DataType.IMAGE or datatype == DataType.EV_REPR:\n            assert input_.ndim == 3, f\"{input_.shape=}\"\n            height, width = input_.shape[-2:]\n            zoom_window_h, zoom_window_w = int(height / zoom_out_factor), int(\n                width / zoom_out_factor\n            )\n            zoom_window = interpolate(\n                input_.unsqueeze(0),\n                size=(zoom_window_h, zoom_window_w),\n                mode=\"nearest-exact\",\n            )[0]\n            output = th.zeros_like(input_)\n\n            x0, y0 = zoom_coordinates_x0y0\n            assert x0 >= 0\n            assert y0 >= 0\n            output[:, y0 : y0 + zoom_window_h, x0 : x0 + zoom_window_w] = zoom_window\n            return output\n        raise NotImplementedError\n\n    @classmethod\n    def _zoom_out_and_rescale_recursive(\n        cls,\n        input_: Any,\n        zoom_coordinates_x0y0: Tuple[int, int],\n        zoom_out_factor: float,\n        datatype: DataType,\n    ):\n        if datatype in (DataType.IS_PADDED_MASK, 
DataType.IS_FIRST_SAMPLE):\n            return input_\n        if isinstance(input_, th.Tensor):\n            return cls._zoom_out_and_rescale_tensor(\n                input_=input_,\n                zoom_coordinates_x0y0=zoom_coordinates_x0y0,\n                zoom_out_factor=zoom_out_factor,\n                datatype=datatype,\n            )\n        if isinstance(input_, ObjectLabels) or isinstance(\n            input_, SparselyBatchedObjectLabels\n        ):\n            assert datatype == DataType.OBJLABELS or datatype == DataType.OBJLABELS_SEQ\n            input_.zoom_out_and_rescale_(\n                zoom_coordinates_x0y0=zoom_coordinates_x0y0,\n                zoom_out_factor=zoom_out_factor,\n            )\n            return input_\n        if isinstance(input_, abc.Sequence):\n            return [\n                RandomSpatialAugmentorGenX._zoom_out_and_rescale_recursive(\n                    x,\n                    zoom_coordinates_x0y0=zoom_coordinates_x0y0,\n                    zoom_out_factor=zoom_out_factor,\n                    datatype=datatype,\n                )\n                for x in input_\n            ]\n        if isinstance(input_, abc.Mapping):\n            return {\n                key: RandomSpatialAugmentorGenX._zoom_out_and_rescale_recursive(\n                    value,\n                    zoom_coordinates_x0y0=zoom_coordinates_x0y0,\n                    zoom_out_factor=zoom_out_factor,\n                    datatype=datatype,\n                )\n                for key, value in input_.items()\n            }\n        raise NotImplementedError\n\n    def _zoom_in_and_rescale(self, data_dict: LoaderDataDictGenX) -> LoaderDataDictGenX:\n        rand_zoom_in_factor = torch_uniform_sample_scalar(\n            min_value=self.min_zoom_in_factor, max_value=self.max_zoom_in_factor\n        )\n        if rand_zoom_in_factor == 1:\n            return data_dict\n\n        height, width = 
RandomSpatialAugmentorGenX._hw_from_data(data_dict=data_dict)\n        assert (height, width) == self.hw_tuple\n        zoom_window_h, zoom_window_w = int(height / rand_zoom_in_factor), int(\n            width / rand_zoom_in_factor\n        )\n        latest_objframe = get_most_recent_objframe(\n            data_dict=data_dict, check_if_nonempty=True\n        )\n        if latest_objframe is None:\n            warn(message=NO_LABEL_WARN_MSG, category=UserWarning, stacklevel=2)\n            return data_dict\n        x0_sampled, y0_sampled = randomly_sample_zoom_window_from_objframe(\n            objframe=latest_objframe,\n            zoom_window_height=zoom_window_h,\n            zoom_window_width=zoom_window_w,\n        )\n\n        return {\n            k: RandomSpatialAugmentorGenX._zoom_in_and_rescale_recursive(\n                v,\n                zoom_coordinates_x0y0=(x0_sampled, y0_sampled),\n                zoom_in_factor=rand_zoom_in_factor,\n                datatype=k,\n            )\n            for k, v in data_dict.items()\n        }\n\n    @staticmethod\n    def _zoom_in_and_rescale_tensor(\n        input_: th.Tensor,\n        zoom_coordinates_x0y0: Tuple[int, int],\n        zoom_in_factor: float,\n        datatype: DataType,\n    ) -> th.Tensor:\n        assert len(zoom_coordinates_x0y0) == 2\n        assert isinstance(input_, th.Tensor)\n\n        if datatype == DataType.IMAGE or datatype == DataType.EV_REPR:\n            assert input_.ndim == 3, f\"{input_.shape=}\"\n            height, width = input_.shape[-2:]\n            zoom_window_h, zoom_window_w = int(height / zoom_in_factor), int(\n                width / zoom_in_factor\n            )\n\n            x0, y0 = zoom_coordinates_x0y0\n            assert x0 >= 0\n            assert y0 >= 0\n            zoom_canvas = input_[\n                ..., y0 : y0 + zoom_window_h, x0 : x0 + zoom_window_w\n            ].unsqueeze(0)\n            output = interpolate(\n                zoom_canvas, 
size=(height, width), mode=\"nearest-exact\"\n            )\n            output = output[0]\n            return output\n        raise NotImplementedError\n\n    @classmethod\n    def _zoom_in_and_rescale_recursive(\n        cls,\n        input_: Any,\n        zoom_coordinates_x0y0: Tuple[int, int],\n        zoom_in_factor: float,\n        datatype: DataType,\n    ):\n        if datatype in (DataType.IS_PADDED_MASK, DataType.IS_FIRST_SAMPLE):\n            return input_\n        if isinstance(input_, th.Tensor):\n            return cls._zoom_in_and_rescale_tensor(\n                input_=input_,\n                zoom_coordinates_x0y0=zoom_coordinates_x0y0,\n                zoom_in_factor=zoom_in_factor,\n                datatype=datatype,\n            )\n        if isinstance(input_, ObjectLabels) or isinstance(\n            input_, SparselyBatchedObjectLabels\n        ):\n            assert datatype == DataType.OBJLABELS or datatype == DataType.OBJLABELS_SEQ\n            input_.zoom_in_and_rescale_(\n                zoom_coordinates_x0y0=zoom_coordinates_x0y0,\n                zoom_in_factor=zoom_in_factor,\n            )\n            return input_\n        if isinstance(input_, abc.Sequence):\n            return [\n                RandomSpatialAugmentorGenX._zoom_in_and_rescale_recursive(\n                    x,\n                    zoom_coordinates_x0y0=zoom_coordinates_x0y0,\n                    zoom_in_factor=zoom_in_factor,\n                    datatype=datatype,\n                )\n                for x in input_\n            ]\n        if isinstance(input_, abc.Mapping):\n            return {\n                key: RandomSpatialAugmentorGenX._zoom_in_and_rescale_recursive(\n                    value,\n                    zoom_coordinates_x0y0=zoom_coordinates_x0y0,\n                    zoom_in_factor=zoom_in_factor,\n                    datatype=datatype,\n                )\n                for key, value in input_.items()\n            }\n        raise 
NotImplementedError\n\n    def _rotate(self, data_dict: LoaderDataDictGenX) -> LoaderDataDictGenX:\n        angle_deg = self.augm_state.rotation.angle_deg\n        return {\n            k: RandomSpatialAugmentorGenX._rotate_recursive(\n                v, angle_deg=angle_deg, datatype=k\n            )\n            for k, v in data_dict.items()\n        }\n\n    @staticmethod\n    def _rotate_tensor(input_: Any, angle_deg: float, datatype: DataType):\n        assert isinstance(input_, th.Tensor)\n        if datatype == DataType.IMAGE or datatype == DataType.EV_REPR:\n            return rotate(\n                input_, angle=angle_deg, interpolation=InterpolationMode.NEAREST\n            )\n        raise NotImplementedError\n\n    @classmethod\n    def _rotate_recursive(cls, input_: Any, angle_deg: float, datatype: DataType):\n        if datatype in (DataType.IS_PADDED_MASK, DataType.IS_FIRST_SAMPLE):\n            return input_\n        if isinstance(input_, th.Tensor):\n            return cls._rotate_tensor(\n                input_=input_, angle_deg=angle_deg, datatype=datatype\n            )\n        if isinstance(input_, ObjectLabels) or isinstance(\n            input_, SparselyBatchedObjectLabels\n        ):\n            assert datatype == DataType.OBJLABELS or datatype == DataType.OBJLABELS_SEQ\n            input_.rotate_(angle_deg=angle_deg)\n            return input_\n        if isinstance(input_, abc.Sequence):\n            return [\n                RandomSpatialAugmentorGenX._rotate_recursive(\n                    x, angle_deg=angle_deg, datatype=datatype\n                )\n                for x in input_\n            ]\n        if isinstance(input_, abc.Mapping):\n            return {\n                key: RandomSpatialAugmentorGenX._rotate_recursive(\n                    value, angle_deg=angle_deg, datatype=datatype\n                )\n                for key, value in input_.items()\n            }\n        raise NotImplementedError\n\n    @staticmethod\n  
  def _flip(data_dict: LoaderDataDictGenX, type_: str) -> LoaderDataDictGenX:\n        assert type_ in {\"h\", \"v\"}\n        return {\n            k: RandomSpatialAugmentorGenX._flip_recursive(\n                v, flip_type=type_, datatype=k\n            )\n            for k, v in data_dict.items()\n        }\n\n    @staticmethod\n    def _flip_tensor(input_: Any, flip_type: str, datatype: DataType):\n        assert isinstance(input_, th.Tensor)\n        flip_axis = -1 if flip_type == \"h\" else -2\n        if datatype == DataType.IMAGE or datatype == DataType.EV_REPR:\n            return th.flip(input_, dims=[flip_axis])\n        if datatype == DataType.FLOW:\n            assert input_.shape[-3] == 2\n            flow_idx = 0 if flip_type == \"h\" else 1\n            input_ = th.flip(input_, dims=[flip_axis])\n            # Also flip the sign of the x (horizontal) or y (vertical) component of the flow.\n            input_[..., flow_idx, :, :] = -1 * input_[..., flow_idx, :, :]\n            return input_\n        raise NotImplementedError\n\n    @classmethod\n    def _flip_recursive(cls, input_: Any, flip_type: str, datatype: DataType):\n        if datatype in (DataType.IS_PADDED_MASK, DataType.IS_FIRST_SAMPLE):\n            return input_\n        if isinstance(input_, th.Tensor):\n            return cls._flip_tensor(\n                input_=input_, flip_type=flip_type, datatype=datatype\n            )\n        if isinstance(input_, ObjectLabels) or isinstance(\n            input_, SparselyBatchedObjectLabels\n        ):\n            assert datatype == DataType.OBJLABELS or datatype == DataType.OBJLABELS_SEQ\n            if flip_type == \"h\":\n                # in-place modification\n                input_.flip_lr_()\n                return input_\n            else:\n                raise NotImplementedError\n        if isinstance(input_, abc.Sequence):\n            return [\n                RandomSpatialAugmentorGenX._flip_recursive(\n                    x, 
flip_type=flip_type, datatype=datatype\n                )\n                for x in input_\n            ]\n        if isinstance(input_, abc.Mapping):\n            return {\n                key: RandomSpatialAugmentorGenX._flip_recursive(\n                    value, flip_type=flip_type, datatype=datatype\n                )\n                for key, value in input_.items()\n            }\n        raise NotImplementedError\n\n    @staticmethod\n    def _hw_from_data(data_dict: LoaderDataDictGenX) -> Tuple[int, int]:\n        height = None\n        width = None\n        for k, v in data_dict.items():\n            _hw = None\n            if k == DataType.OBJLABELS or k == DataType.OBJLABELS_SEQ:\n                hw = v.input_size_hw\n                if hw is not None:\n                    _hw = v.input_size_hw\n            elif k in (DataType.IMAGE, DataType.FLOW, DataType.EV_REPR):\n                _hw = v[0].shape[-2:]\n            if _hw is not None:\n                _height, _width = _hw\n                if height is None:\n                    assert width is None\n                    height, width = _height, _width\n                else:\n                    assert height == _height and width == _width\n        assert height is not None\n        assert width is not None\n        return height, width\n\n    def __call__(self, data_dict: LoaderDataDictGenX):\n        \"\"\"\n        :param data_dict: LoaderDataDictGenX type, image-based tensors must have (*, h, w) shape.\n        :return: map with same keys but spatially augmented values.\n        \"\"\"\n        if self.automatic_randomization:\n            self.randomize_augmentation()\n\n        if self.augm_state.apply_h_flip:\n            data_dict = self._flip(data_dict, type_=\"h\")\n        if self.augm_state.rotation.active:\n            data_dict = self._rotate(data_dict)\n        if self.augm_state.apply_zoom_in:\n            data_dict = self._zoom_in_and_rescale(data_dict=data_dict)\n        if 
self.augm_state.zoom_out.active:\n            assert not self.augm_state.apply_zoom_in\n            data_dict = self._zoom_out_and_rescale(data_dict=data_dict)\n        return data_dict\n\n\ndef get_most_recent_objframe(\n    data_dict: LoaderDataDictGenX, check_if_nonempty: bool = True\n) -> Optional[ObjectLabels]:\n    assert (\n        DataType.OBJLABELS_SEQ in data_dict\n    ), f\"Requires datatype {DataType.OBJLABELS_SEQ} to be present\"\n    sparse_obj_labels = data_dict[DataType.OBJLABELS_SEQ]\n    sparse_obj_labels: SparselyBatchedObjectLabels\n\n    for obj_label in reversed(sparse_obj_labels):\n        if obj_label is not None:\n            return_label = True if not check_if_nonempty else len(obj_label) > 0\n            if return_label:\n                return obj_label\n    # no labels found\n    return None\n\n\ndef randomly_sample_zoom_window_from_objframe(\n    objframe: ObjectLabels,\n    zoom_window_height: Union[int, float],\n    zoom_window_width: Union[int, float],\n) -> Tuple[int, int]:\n    input_height, input_width = objframe.input_size_hw\n    possible_samples = []\n    for idx in range(len(objframe)):\n        label_xywh = (\n            objframe.x[idx],\n            objframe.y[idx],\n            objframe.w[idx],\n            objframe.h[idx],\n        )\n        possible_samples.append(\n            randomly_sample_zoom_window_from_label_rectangle(\n                label_xywh=label_xywh,\n                input_height=input_height,\n                input_width=input_width,\n                zoom_window_height=zoom_window_height,\n                zoom_window_width=zoom_window_width,\n            )\n        )\n    assert len(possible_samples) > 0\n    # Using torch to sample, to avoid potential problems with multiprocessing.\n    # th.randint excludes `high`, so high=len(possible_samples) keeps every candidate reachable.\n    sample_idx = (\n        0\n        if len(possible_samples) == 1\n        else th.randint(low=0, high=len(possible_samples), size=(1,)).item()\n    )\n    x0_sample, y0_sample = possible_samples[sample_idx]\n    
assert input_width > x0_sample >= 0, f\"{x0_sample=}\"\n    assert input_height > y0_sample >= 0, f\"{y0_sample=}\"\n    return x0_sample, y0_sample\n\n\ndef randomly_sample_zoom_window_from_label_rectangle(\n    label_xywh: Tuple[Union[int, float, th.Tensor], ...],\n    input_height: Union[int, float],\n    input_width: Union[int, float],\n    zoom_window_height: Union[int, float],\n    zoom_window_width: Union[int, float],\n) -> Tuple[int, int]:\n    \"\"\"Computes a set of top-left coordinates from which the top-left corner of the zoom window\n    can be sampled such that the zoom window is guaranteed to contain the whole (rectangular) label.\n    Return a random sample from this set.\n\n    Notation:\n    (x0,y0)---(x1,y0)\n     |             |\n     |             |\n    (x0,y1)---(x1,y1)\n    \"\"\"\n    assert input_height >= zoom_window_height\n    assert input_width >= zoom_window_width\n    label_xywh = tuple(x.item() if isinstance(x, th.Tensor) else x for x in label_xywh)\n    x0_l, y0_l, w_l, h_l = label_xywh\n    x1_l = x0_l + w_l\n    y1_l = y0_l + h_l\n    assert x0_l >= 0\n    assert y0_l >= 0\n    assert w_l > 0\n    assert h_l > 0\n    assert x1_l <= input_width + 1e-2 - 1\n    assert y1_l <= input_height + 1e-2 - 1\n\n    x0_valid_region = max(x1_l - max(zoom_window_width, w_l), 0)\n    y0_valid_region = max(y1_l - max(zoom_window_height, h_l), 0)\n    x1_valid_region = min(x0_l + max(zoom_window_width, w_l), input_width - 1)\n    y1_valid_region = min(y0_l + max(zoom_window_height, h_l), input_height - 1)\n\n    x1_valid_region = max(x1_valid_region - zoom_window_width, x0_valid_region)\n    y1_valid_region = max(y1_valid_region - zoom_window_height, y0_valid_region)\n\n    x_topleft_sample = int(\n        torch_uniform_sample_scalar(\n            min_value=x0_valid_region, max_value=x1_valid_region\n        )\n    )\n    assert 0 <= x_topleft_sample < input_width\n    y_topleft_sample = int(\n        torch_uniform_sample_scalar(\n            
min_value=y0_valid_region, max_value=y1_valid_region\n        )\n    )\n    assert 0 <= y_topleft_sample < input_height\n    return x_topleft_sample, y_topleft_sample\n"
  },
  {
    "path": "RVT/data/utils/representations.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Optional, Tuple\n\nimport math\nimport numpy as np\nimport torch as th\n\n\nclass RepresentationBase(ABC):\n    @abstractmethod\n    def construct(\n        self, x: th.Tensor, y: th.Tensor, pol: th.Tensor, time: th.Tensor\n    ) -> th.Tensor: ...\n\n    @abstractmethod\n    def get_shape(self) -> Tuple[int, int, int]: ...\n\n    @staticmethod\n    @abstractmethod\n    def get_numpy_dtype() -> np.dtype: ...\n\n    @staticmethod\n    @abstractmethod\n    def get_torch_dtype() -> th.dtype: ...\n\n    @property\n    def dtype(self) -> th.dtype:\n        return self.get_torch_dtype()\n\n    @staticmethod\n    def _is_int_tensor(tensor: th.Tensor) -> bool:\n        return not th.is_floating_point(tensor) and not th.is_complex(tensor)\n\n\nclass StackedHistogram(RepresentationBase):\n    def __init__(\n        self,\n        bins: int,\n        height: int,\n        width: int,\n        count_cutoff: Optional[int] = None,\n        fastmode: bool = True,\n    ):\n        \"\"\"\n        In case of fastmode == True: use uint8 to construct the representation, but could lead to overflow.\n        In case of fastmode == False: use int16 to construct the representation, and convert to uint8 after clipping.\n\n        Note: Overflow should not be a big problem because it happens only for hot pixels. 
In case of overflow,\n        the value will just start accumulating from 0 again.\n        \"\"\"\n        assert bins >= 1\n        self.bins = bins\n        assert height >= 1\n        self.height = height\n        assert width >= 1\n        self.width = width\n        self.count_cutoff = count_cutoff\n        if self.count_cutoff is None:\n            self.count_cutoff = 255\n        else:\n            assert count_cutoff >= 1\n            self.count_cutoff = min(count_cutoff, 255)\n        self.fastmode = fastmode\n        self.channels = 2\n\n    @staticmethod\n    def get_numpy_dtype() -> np.dtype:\n        return np.dtype(\"uint8\")\n\n    @staticmethod\n    def get_torch_dtype() -> th.dtype:\n        return th.uint8\n\n    def merge_channel_and_bins(self, representation: th.Tensor):\n        assert representation.dim() == 4\n        return th.reshape(representation, (-1, self.height, self.width))\n\n    def get_shape(self) -> Tuple[int, int, int]:\n        return 2 * self.bins, self.height, self.width\n\n    def construct(\n        self, x: th.Tensor, y: th.Tensor, pol: th.Tensor, time: th.Tensor\n    ) -> th.Tensor:\n        device = x.device\n        assert y.device == pol.device == time.device == device\n        assert self._is_int_tensor(x)\n        assert self._is_int_tensor(y)\n        assert self._is_int_tensor(pol)\n        assert self._is_int_tensor(time)\n\n        dtype = th.uint8 if self.fastmode else th.int16\n\n        representation = th.zeros(\n            (self.channels, self.bins, self.height, self.width),\n            dtype=dtype,\n            device=device,\n            requires_grad=False,\n        )\n\n        if x.numel() == 0:\n            assert y.numel() == 0\n            assert pol.numel() == 0\n            assert time.numel() == 0\n            return self.merge_channel_and_bins(representation.to(th.uint8))\n        assert x.numel() == y.numel() == pol.numel() == time.numel()\n\n        assert pol.min() >= 0\n        assert 
pol.max() <= 1\n\n        bn, ch, ht, wd = self.bins, self.channels, self.height, self.width\n\n        # NOTE: assume sorted time\n        t0_int = time[0]\n        t1_int = time[-1]\n        assert t1_int >= t0_int\n        t_norm = time - t0_int\n        t_norm = t_norm / max((t1_int - t0_int), 1)\n        t_norm = t_norm * bn\n        t_idx = t_norm.floor()\n        t_idx = th.clamp(t_idx, max=bn - 1)\n\n        indices = (\n            x.long()\n            + wd * y.long()\n            + ht * wd * t_idx.long()\n            + bn * ht * wd * pol.long()\n        )\n        values = th.ones_like(indices, dtype=dtype, device=device)\n        representation.put_(indices, values, accumulate=True)\n        representation = th.clamp(representation, min=0, max=self.count_cutoff)\n        if not self.fastmode:\n            representation = representation.to(th.uint8)\n\n        return self.merge_channel_and_bins(representation)\n\n\ndef cumsum_channel(x: th.Tensor, num_channels: int):\n    for i in reversed(range(num_channels)):\n        x[i] = th.sum(input=x[: i + 1], dim=0)\n    return x\n\n\nclass MixedDensityEventStack(RepresentationBase):\n    def __init__(\n        self,\n        bins: int,\n        height: int,\n        width: int,\n        count_cutoff: Optional[int] = None,\n        allow_compilation: bool = False,\n    ):\n        assert bins >= 1\n        self.bins = bins\n        assert height >= 1\n        self.height = height\n        assert width >= 1\n        self.width = width\n        self.count_cutoff = count_cutoff\n        if self.count_cutoff is not None:\n            assert isinstance(count_cutoff, int)\n            assert 0 <= self.count_cutoff <= 2**7 - 1\n\n        self.cumsum_ch_opt = cumsum_channel\n\n        if allow_compilation:\n            # Will most likely not work with multiprocessing.\n            try:\n                self.cumsum_ch_opt = th.compile(cumsum_channel)\n            except AttributeError:\n                ...\n\n    
@staticmethod\n    def get_numpy_dtype() -> np.dtype:\n        return np.dtype(\"int8\")\n\n    @staticmethod\n    def get_torch_dtype() -> th.dtype:\n        return th.int8\n\n    def get_shape(self) -> Tuple[int, int, int]:\n        return self.bins, self.height, self.width\n\n    def construct(\n        self, x: th.Tensor, y: th.Tensor, pol: th.Tensor, time: th.Tensor\n    ) -> th.Tensor:\n        device = x.device\n        assert y.device == pol.device == time.device == device\n        assert self._is_int_tensor(x)\n        assert self._is_int_tensor(y)\n        assert self._is_int_tensor(pol)\n        assert self._is_int_tensor(time)\n\n        dtype = th.int8\n\n        representation = th.zeros(\n            (self.bins, self.height, self.width),\n            dtype=dtype,\n            device=device,\n            requires_grad=False,\n        )\n\n        if x.numel() == 0:\n            assert y.numel() == 0\n            assert pol.numel() == 0\n            assert time.numel() == 0\n            return representation\n        assert x.numel() == y.numel() == pol.numel() == time.numel()\n\n        assert pol.min() >= 0  # maybe remove because too costly\n        assert pol.max() <= 1  # maybe remove because too costly\n        pol = pol * 2 - 1\n\n        bn, ht, wd = self.bins, self.height, self.width\n\n        # NOTE: assume sorted time\n        t0_int = time[0]\n        t1_int = time[-1]\n        assert t1_int >= t0_int\n        t_norm = (time - t0_int) / max((t1_int - t0_int), 1)\n        t_norm = th.clamp(t_norm, min=1e-6, max=1 - 1e-6)\n        # Let N be the number of bins. I.e. 
bin \\in [0, N):\n        # Let f(bin) = t_norm, model the relationship between bin and normalized time \\in [0, 1]\n        # f(bin=N) = 1\n        # f(bin=N-1) = 1/2\n        # f(bin=N-2) = 1/2*1/2\n        # -> f(bin=N-i) = (1/2)^i\n        # Also: f(bin) = t_norm\n        #\n        # Hence, (1/2)^(N-bin) = t_norm\n        # And, bin = N - log(t_norm, base=1/2) = N - log(t_norm)/log(1/2)\n        bin_float = self.bins - th.log(t_norm) / math.log(1 / 2)\n        # Can go below 0 for t_norm close to 0 -> clamp to 0\n        bin_float = th.clamp(bin_float, min=0)\n        t_idx = bin_float.floor()\n\n        indices = x.long() + wd * y.long() + ht * wd * t_idx.long()\n        values = th.asarray(pol, dtype=dtype, device=device)\n        representation.put_(indices, values, accumulate=True)\n        representation = self.cumsum_ch_opt(representation, num_channels=self.bins)\n        if self.count_cutoff is not None:\n            representation = th.clamp(\n                representation, min=-self.count_cutoff, max=self.count_cutoff\n            )\n        return representation\n"
  },
  {
    "path": "RVT/data/utils/spatial.py",
    "content": "from omegaconf import DictConfig\n\nfrom data.utils.types import DatasetType\n\n_type_2_hw = {\n    DatasetType.GEN1: (240, 304),\n    DatasetType.GEN4: (720, 1280),\n}\n\n_str_2_type = {\n    \"gen1\": DatasetType.GEN1,\n    \"gen4\": DatasetType.GEN4,\n}\n\n\ndef get_original_hw(dataset_type: DatasetType):\n    return _type_2_hw[dataset_type]\n\n\ndef get_dataloading_hw(dataset_config: DictConfig):\n    dataset_name = dataset_config.name\n    hw = get_original_hw(dataset_type=_str_2_type[dataset_name])\n    downsample_by_factor_2 = dataset_config.downsample_by_factor_2\n    if downsample_by_factor_2:\n        hw = tuple(x // 2 for x in hw)\n    return hw\n"
  },
  {
    "path": "RVT/data/utils/stream_concat_datapipe.py",
    "content": "from typing import Any, Iterator, List, Optional, Type\n\nimport torch as th\nimport torch.distributed as dist\nfrom torch.utils.data import DataLoader\nfrom torchdata.datapipes.iter import (\n    Concater,\n    IterableWrapper,\n    IterDataPipe,\n    Zipper,\n)\nfrom torchdata.datapipes.map import MapDataPipe\n\n\nclass DummyIterDataPipe(IterDataPipe):\n    def __init__(self, source_dp: IterDataPipe):\n        super().__init__()\n        assert isinstance(source_dp, IterDataPipe)\n        self.source_dp = source_dp\n\n    def __iter__(self):\n        yield from self.source_dp\n\n\nclass ConcatStreamingDataPipe(IterDataPipe):\n    \"\"\"This Dataset avoids the sharding problem by instantiating randomized stream concatenation at the batch and\n    worker level.\n    Pros:\n    - Every single batch has valid samples. Consequently, the batch size is always constant.\n    Cons:\n    - There might be repeated samples in a batch. Although they should be different because of data augmentation.\n    - Cannot be used for validation or testing because we repeat the dataset multiple times in an epoch.\n\n    TLDR: preferred approach for training but not useful for validation or testing.\n    \"\"\"\n\n    def __init__(\n        self,\n        datapipe_list: List[MapDataPipe],\n        batch_size: int,\n        num_workers: int,\n        augmentation_pipeline: Optional[Type[IterDataPipe]] = None,\n        print_seed_debug: bool = False,\n    ):\n        super().__init__()\n        assert batch_size > 0\n\n        if augmentation_pipeline is not None:\n            self.augmentation_dp = augmentation_pipeline\n        else:\n            self.augmentation_dp = DummyIterDataPipe\n\n        # We require MapDataPipes instead of IterDataPipes because IterDataPipes must be deepcopied in each worker.\n        # Instead, MapDataPipes can be converted to IterDataPipes in each worker without requiring a deepcopy.\n        self.datapipe_list = datapipe_list\n        
self.batch_size = batch_size\n\n        self.print_seed_debug = print_seed_debug\n\n    @staticmethod\n    def random_torch_shuffle_list(data: List[Any]) -> Iterator[Any]:\n        assert isinstance(data, List)\n        return (data[idx] for idx in th.randperm(len(data)).tolist())\n\n    def _get_zipped_streams(self, datapipe_list: List[MapDataPipe], batch_size: int):\n        \"\"\"Use it only in the iter function of this class!!!\n        Reason: randomized shuffling must happen within each worker. Otherwise, the same random order will be used\n        for all workers.\n        \"\"\"\n        assert isinstance(datapipe_list, List)\n        assert batch_size > 0\n        streams = Zipper(\n            *(\n                Concater(\n                    *(\n                        self.augmentation_dp(x.to_iter_datapipe())\n                        for x in self.random_torch_shuffle_list(datapipe_list)\n                    )\n                )\n                for _ in range(batch_size)\n            )\n        )\n        return streams\n\n    def _print_seed_debug_info(self):\n        worker_info = th.utils.data.get_worker_info()\n        local_worker_id = 0 if worker_info is None else worker_info.id\n\n        # worker_info is None in the main process (num_workers == 0), which has no per-worker seed.\n        worker_torch_seed = None if worker_info is None else worker_info.seed\n        local_num_workers = 1 if worker_info is None else worker_info.num_workers\n        if dist.is_available() and dist.is_initialized():\n            global_rank = dist.get_rank()\n        else:\n            global_rank = 0\n        global_worker_id = global_rank * local_num_workers + local_worker_id\n\n        rnd_number = th.randn(1)\n        print(\n            f\"{worker_torch_seed=},\\t{global_worker_id=},\\t{global_rank=},\\t{local_worker_id=},\\t{rnd_number=}\",\n            flush=True,\n        )\n\n    def _get_zipped_streams_with_worker_id(self):\n        \"\"\"Use it only in the iter function of this class!!!\"\"\"\n        worker_info = th.utils.data.get_worker_info()\n        local_worker_id = 0 if 
worker_info is None else worker_info.id\n        worker_id_stream = IterableWrapper([local_worker_id]).cycle(count=None)\n        zipped_stream = self._get_zipped_streams(\n            datapipe_list=self.datapipe_list, batch_size=self.batch_size\n        )\n        return zipped_stream.zip(worker_id_stream)\n\n    def __iter__(self):\n        if self.print_seed_debug:\n            self._print_seed_debug_info()\n        return iter(self._get_zipped_streams_with_worker_id())\n"
  },
  {
    "path": "RVT/data/utils/stream_sharded_datapipe.py",
    "content": "from typing import Any, List, Optional\n\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data import DataLoader\nfrom torchdata.datapipes.iter import (\n    Concater,\n    IterableWrapper,\n    IterDataPipe,\n    ZipperLongest,\n)\nfrom torchdata.datapipes.map import MapDataPipe\n\n\nclass ShardedStreamingDataPipe(IterDataPipe):\n    def __init__(\n        self,\n        datapipe_list: List[MapDataPipe],\n        batch_size: int,\n        fill_value: Optional[Any] = None,\n    ):\n        super().__init__()\n        assert batch_size > 0\n\n        # We require MapDataPipes instead of IterDataPipes because IterDataPipes must be deepcopied in each worker.\n        # Instead, MapDataPipes can be converted to IterDataPipes in each worker without requiring a deepcopy.\n        # Note: Sorting is a heuristic to get potentially better distribution of workloads than taking the data as is.\n        # Sort iterators from long to short.\n        self.datapipe_list = sorted(datapipe_list, key=lambda x: len(x), reverse=True)\n        self.batch_size = batch_size\n        self.fill_value = fill_value\n\n    @staticmethod\n    def yield_pyramid_indices(start_idx: int, end_idx: int):\n        while True:\n            for idx in range(start_idx, end_idx):\n                yield idx\n            for idx in range(end_idx - 1, start_idx - 1, -1):\n                yield idx\n\n    @classmethod\n    def assign_datapipes_to_worker(\n        cls,\n        sorted_datapipe_list: List[MapDataPipe],\n        total_num_workers: int,\n        global_worker_id: int,\n    ) -> List[MapDataPipe]:\n        num_datapipes = len(sorted_datapipe_list)\n        assert (\n            num_datapipes >= total_num_workers > global_worker_id\n        ), f\"{num_datapipes=}, {total_num_workers=}, {global_worker_id=}\"\n        datapipes = []\n        # Assumes sorted datapipes from long to short.\n        global_worker_id_generator = cls.yield_pyramid_indices(\n            
start_idx=0, end_idx=total_num_workers\n        )\n        for dp in sorted_datapipe_list:\n            generated_global_worker_id = next(global_worker_id_generator)\n            if generated_global_worker_id == global_worker_id:\n                datapipes.append(dp)\n        # Every worker must receive at least one datapipe.\n        assert len(datapipes) > 0\n        return datapipes\n\n    def get_zipped_stream_from_worker_datapipes(\n        self, datapipe_list: List[MapDataPipe], batch_size: int\n    ) -> ZipperLongest:\n        num_datapipes = len(datapipe_list)\n        assert num_datapipes > 0\n        assert batch_size > 0\n        assert num_datapipes >= batch_size, (\n            \"Each worker must get at least 'batch_size' datapipes. \"\n            \"Otherwise, we would have to support dynamic batch sizes. \"\n            \"As a workaround, decrease the number of workers.\"\n        )\n        # Sort datapipe_list from long to short.\n        datapipe_list = sorted(datapipe_list, key=lambda x: len(x), reverse=True)\n        zipped_streams = [[] for _ in range(batch_size)]\n        batch_id_generator = self.yield_pyramid_indices(start_idx=0, end_idx=batch_size)\n        for datapipe in datapipe_list:\n            batch_idx = next(batch_id_generator)\n            zipped_streams[batch_idx].append(datapipe)\n        for idx, streams in enumerate(zipped_streams):\n            zipped_streams[idx] = Concater(\n                *(stream.to_iter_datapipe() for stream in streams)\n            )\n        zipped_streams = ZipperLongest(*zipped_streams, fill_value=self.fill_value)\n        return zipped_streams\n\n    def __iter__(self):\n        worker_info = torch.utils.data.get_worker_info()\n        local_worker_id = 0 if worker_info is None else worker_info.id\n        local_num_workers = 1 if worker_info is None else worker_info.num_workers\n        if dist.is_available() and dist.is_initialized():\n            world_size = dist.get_world_size()\n            global_rank = 
dist.get_rank()\n        else:\n            world_size = 1\n            global_rank = 0\n        total_num_workers = local_num_workers * world_size\n        global_worker_id = global_rank * local_num_workers + local_worker_id\n\n        local_datapipes = self.assign_datapipes_to_worker(\n            sorted_datapipe_list=self.datapipe_list,\n            total_num_workers=total_num_workers,\n            global_worker_id=global_worker_id,\n        )\n        zipped_stream = self.get_zipped_stream_from_worker_datapipes(\n            datapipe_list=local_datapipes, batch_size=self.batch_size\n        )\n        # We also stream the local worker id for the use case where a recurrent neural network saves\n        # its state based on the local worker id. The global worker id is not needed for that because the states\n        # are saved in each DDP process (per GPU) separately and do not need to be communicated between processes.\n\n        worker_id_stream = IterableWrapper([local_worker_id]).cycle(count=None)\n        zipped_stream = zipped_stream.zip(worker_id_stream)\n\n        return iter(zipped_stream)\n"
  },
  {
    "path": "RVT/data/utils/types.py",
    "content": "from enum import auto, Enum\n\ntry:\n    from enum import StrEnum\nexcept ImportError:\n    from strenum import StrEnum\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport torch as th\n\nfrom data.genx_utils.labels import ObjectLabels, SparselyBatchedObjectLabels\n\n\nclass DataType(Enum):\n    EV_REPR = auto()\n    FLOW = auto()\n    IMAGE = auto()\n    OBJLABELS = auto()\n    OBJLABELS_SEQ = auto()\n    IS_PADDED_MASK = auto()\n    IS_FIRST_SAMPLE = auto()\n    TOKEN_MASK = auto()\n\n\nclass DatasetType(Enum):\n    GEN1 = auto()\n    GEN4 = auto()\n\n\nclass DatasetMode(Enum):\n    TRAIN = auto()\n    VALIDATION = auto()\n    TESTING = auto()\n\n\nclass DatasetSamplingMode(StrEnum):\n    RANDOM = \"random\"\n    STREAM = \"stream\"\n    MIXED = \"mixed\"\n\n\nclass ObjDetOutput(Enum):\n    LABELS_PROPH = auto()\n    PRED_PROPH = auto()\n    EV_REPR = auto()\n    SKIP_VIZ = auto()\n\n\nLoaderDataDictGenX = Dict[\n    DataType,\n    Union[List[th.Tensor], ObjectLabels, SparselyBatchedObjectLabels, List[bool]],\n]\n\nLstmState = Optional[Tuple[th.Tensor]]\nLstmStates = List[LstmState]\n\nFeatureMap = th.Tensor\nBackboneFeatures = Dict[int, th.Tensor]\n"
  },
  {
    "path": "RVT/loggers/utils.py",
    "content": "from pathlib import Path\nfrom typing import Union\n\nimport wandb\nfrom omegaconf import DictConfig, OmegaConf\n\nfrom loggers.wandb_logger import WandbLogger\n\n\ndef get_wandb_logger(full_config: DictConfig) -> WandbLogger:\n    wandb_config = full_config.wandb\n    wandb_runpath = wandb_config.wandb_runpath\n\n    if wandb_runpath is None:\n        wandb_id = wandb.util.generate_id()\n        print(f\"new run: generating id {wandb_id}\")\n    else:\n        wandb_id = Path(wandb_runpath).name\n        print(f\"using provided id {wandb_id}\")\n\n    full_config_dict = OmegaConf.to_container(\n        full_config, resolve=True, throw_on_missing=True\n    )\n    logger = WandbLogger(\n        project=wandb_config.project_name,\n        group=wandb_config.group_name,\n        wandb_id=wandb_id,\n        log_model=True,\n        save_last_only_final=False,\n        save_code=True,\n        config_args=full_config_dict,\n    )\n\n    return logger\n\n\ndef get_ckpt_path(logger: WandbLogger, wandb_config: DictConfig) -> Union[Path, None]:\n    cfg = wandb_config\n    artifact_name = cfg.artifact_name\n    assert (\n        artifact_name is not None\n    ), \"Artifact name is required to resume from checkpoint.\"\n    print(f\"resuming checkpoint from artifact {artifact_name}\")\n    artifact_local_file = cfg.artifact_local_file\n    if artifact_local_file is not None:\n        artifact_local_file = Path(artifact_local_file)\n    if isinstance(logger, WandbLogger):\n        resume_path = logger.get_checkpoint(\n            artifact_name=artifact_name, artifact_filepath=artifact_local_file\n        )\n    else:\n        resume_path = artifact_local_file\n    assert resume_path.exists()\n    assert resume_path.suffix == \".ckpt\", resume_path.suffix\n    return resume_path\n"
  },
  {
    "path": "RVT/loggers/wandb_logger.py",
    "content": "\"\"\"\nThis is a modified version of the Pytorch Lightning logger\n\"\"\"\n\nimport time\nfrom argparse import Namespace\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Union\nfrom weakref import ReferenceType\n\nimport numpy as np\nimport lightning.pytorch as pl\nimport torch\nimport torch.nn as nn\n\npl_is_ge_1_6 = float(pl.__version__[:3]) >= 1.6\nassert pl_is_ge_1_6\n\nfrom lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint\nfrom lightning.pytorch.loggers.logger import rank_zero_experiment, Logger\nfrom lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn\nfrom lightning.fabric.utilities.logger import (\n    _add_prefix,\n    _convert_params,\n    _flatten_dict,\n    _sanitize_callable_params,\n)\n\nimport wandb\nfrom wandb.sdk.lib import RunDisabled\nfrom wandb.wandb_run import Run\n\n\nclass WandbLogger(Logger):\n    LOGGER_JOIN_CHAR = \"-\"\n    STEP_METRIC = \"trainer/global_step\"\n\n    def __init__(\n        self,\n        name: Optional[str] = None,\n        project: Optional[str] = None,\n        group: Optional[str] = None,\n        wandb_id: Optional[str] = None,\n        prefix: Optional[str] = \"\",\n        log_model: Optional[bool] = True,\n        save_last_only_final: Optional[bool] = False,\n        config_args: Optional[Dict[str, Any]] = None,\n        **kwargs,\n    ):\n        super().__init__()\n        self._experiment = None\n        self._log_model = log_model\n        self._prefix = prefix\n        self._logged_model_time = {}\n        self._checkpoint_callback = None\n        # Save last is determined by the checkpoint callback argument\n        self._save_last = None\n        # Whether to save the last checkpoint continuously (more storage) or only when the run is aborted\n        self._save_last_only_final = save_last_only_final\n        # Save the configuration args (e.g. 
parsed arguments) and log it in wandb\n        self._config_args = config_args\n        # set wandb init arguments\n        self._wandb_init = dict(\n            name=name,\n            project=project,\n            group=group,\n            id=wandb_id,\n            resume=\"allow\",\n            save_code=True,\n        )\n        self._wandb_init.update(**kwargs)\n        # extract parameters\n        self._name = self._wandb_init.get(\"name\")\n        self._id = self._wandb_init.get(\"id\")\n        # for save_top_k\n        self._public_run = None\n\n        # start wandb run (to create an attach_id for distributed modes)\n        wandb.require(\"service\")\n        _ = self.experiment\n\n    def get_checkpoint(\n        self, artifact_name: str, artifact_filepath: Optional[Path] = None\n    ) -> Path:\n        artifact = self.experiment.use_artifact(artifact_name)\n        if artifact_filepath is None:\n            assert artifact is not None, (\n                \"You are probably using DDP, \"\n                \"in which case you should provide an artifact filepath.\"\n            )\n            # TODO: specify download directory\n            artifact_dir = artifact.download()\n            artifact_filepath = next(Path(artifact_dir).iterdir())\n        assert artifact_filepath.exists()\n        assert artifact_filepath.suffix == \".ckpt\"\n        return artifact_filepath\n\n    def __getstate__(self) -> Dict[str, Any]:\n        state = self.__dict__.copy()\n        # args needed to reload correct experiment\n        if self._experiment is not None:\n            state[\"_id\"] = getattr(self._experiment, \"id\", None)\n            state[\"_attach_id\"] = getattr(self._experiment, \"_attach_id\", None)\n            state[\"_name\"] = self._experiment.name\n\n        # cannot be pickled\n        state[\"_experiment\"] = None\n        return state\n\n    @property\n    @rank_zero_experiment\n    def experiment(self) -> Run:\n        if self._experiment is 
None:\n            attach_id = getattr(self, \"_attach_id\", None)\n            if wandb.run is not None:\n                # wandb process already created in this instance\n                rank_zero_warn(\n                    \"There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse\"\n                    \" this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.\"\n                )\n                self._experiment = wandb.run\n            elif attach_id is not None and hasattr(wandb, \"_attach\"):\n                # attach to wandb process referenced\n                self._experiment = wandb._attach(attach_id)\n            else:\n                # create new wandb process\n                self._experiment = wandb.init(**self._wandb_init)\n                if self._config_args is not None:\n                    self._experiment.config.update(\n                        self._config_args, allow_val_change=True\n                    )\n\n                # define default x-axis\n                if isinstance(self._experiment, (Run, RunDisabled)) and getattr(\n                    self._experiment, \"define_metric\", None\n                ):\n                    self._experiment.define_metric(self.STEP_METRIC)\n                    self._experiment.define_metric(\n                        \"*\", step_metric=self.STEP_METRIC, step_sync=True\n                    )\n\n        assert isinstance(self._experiment, (Run, RunDisabled))\n        return self._experiment\n\n    def watch(\n        self,\n        model: nn.Module,\n        log: str = \"all\",\n        log_freq: int = 100,\n        log_graph: bool = True,\n    ):\n        self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph)\n\n    def add_step_metric(self, input_dict: dict, step: int) -> None:\n        input_dict.update({self.STEP_METRIC: step})\n\n    @rank_zero_only\n    def log_hyperparams(self, params: 
Union[Dict[str, Any], Namespace]) -> None:\n        params = _convert_params(params)\n        params = _flatten_dict(params)\n        params = _sanitize_callable_params(params)\n        self.experiment.config.update(params, allow_val_change=True)\n\n    @rank_zero_only\n    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:\n        assert rank_zero_only.rank == 0, \"experiment tried to log from global_rank != 0\"\n\n        metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)\n        if step is not None:\n            self.add_step_metric(metrics, step)\n            self.experiment.log({**metrics}, step=step)\n        else:\n            self.experiment.log(metrics)\n\n    @rank_zero_only\n    def log_images(\n        self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: str\n    ) -> None:\n        \"\"\"Log images (tensors, numpy arrays, PIL Images or file paths).\n        Optional kwargs are lists passed to each image (ex: caption, masks, boxes).\n\n        How to use: https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.loggers.wandb.html#weights-and-biases-logger\n        Taken from: https://github.com/PyTorchLightning/pytorch-lightning/blob/11e289ad9f95f5fe23af147fa4edcc9794f9b9a7/pytorch_lightning/loggers/wandb.py#L420\n        \"\"\"\n        if not isinstance(images, list):\n            raise TypeError(f'Expected a list as \"images\", found {type(images)}')\n        n = len(images)\n        for k, v in kwargs.items():\n            if len(v) != n:\n                raise ValueError(f\"Expected {n} items but only found {len(v)} for {k}\")\n        kwarg_list = [{k: kwargs[k][i] for k in kwargs.keys()} for i in range(n)]\n        metrics = {\n            key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]\n        }\n        self.log_metrics(metrics, step)\n\n    @rank_zero_only\n    def log_videos(\n        self,\n        key: str,\n        videos: 
List[Union[np.ndarray, str]],\n        step: Optional[int] = None,\n        captions: Optional[List[str]] = None,\n        fps: int = 4,\n        format_: Optional[str] = None,\n    ):\n        \"\"\"\n        :param video: List[(T,C,H,W)] or List[(N,T,C,H,W)]\n        :param captions: List[str] or None\n\n        More info: https://docs.wandb.ai/ref/python/data-types/video and\n        https://docs.wandb.ai/guides/track/log/media#other-media\n        \"\"\"\n        assert isinstance(videos, list)\n        if captions is not None:\n            assert isinstance(captions, list)\n            assert len(captions) == len(videos)\n        wandb_videos = list()\n        for idx, video in enumerate(videos):\n            caption = captions[idx] if captions is not None else None\n            wandb_videos.append(\n                wandb.Video(\n                    data_or_path=video, caption=caption, fps=fps, format=format_\n                )\n            )\n        self.log_metrics(metrics={key: wandb_videos}, step=step)\n\n    @property\n    def name(self) -> Optional[str]:\n        # This function seems to be only relevant if LoggerCollection is used.\n        # don't create an experiment if we don't have one\n        return self._experiment.project_name() if self._experiment else self._name\n\n    @property\n    def version(self) -> Optional[str]:\n        # This function seems to be only relevant if LoggerCollection is used.\n        # don't create an experiment if we don't have one\n        return self._experiment.id if self._experiment else self._id\n\n    @rank_zero_only\n    def after_save_checkpoint(\n        self, checkpoint_callback: \"ReferenceType[ModelCheckpoint]\"\n    ) -> None:\n        # log checkpoints as artifacts\n        if self._checkpoint_callback is None:\n            self._checkpoint_callback = checkpoint_callback\n            self._save_last = checkpoint_callback.save_last\n        if self._log_model:\n            self._scan_and_log_checkpoints(\n 
               checkpoint_callback, self._save_last and not self._save_last_only_final\n            )\n\n    @rank_zero_only\n    def finalize(self, status: str) -> None:\n        # log checkpoints as artifacts\n        if self._checkpoint_callback and self._log_model:\n            self._scan_and_log_checkpoints(self._checkpoint_callback, self._save_last)\n\n    def _get_public_run(self):\n        if self._public_run is None:\n            experiment = self.experiment\n            runpath = (\n                experiment._entity\n                + \"/\"\n                + experiment._project\n                + \"/\"\n                + experiment._run_id\n            )\n            api = wandb.Api()\n            self._public_run = api.run(path=runpath)\n        return self._public_run\n\n    def _num_logged_artifact(self):\n        public_run = self._get_public_run()\n        return len(public_run.logged_artifacts())\n\n    def _scan_and_log_checkpoints(\n        self, checkpoint_callback: \"ReferenceType[ModelCheckpoint]\", save_last: bool\n    ) -> None:\n        assert self._log_model\n        if self._checkpoint_callback is None:\n            self._checkpoint_callback = checkpoint_callback\n            self._save_last = checkpoint_callback.save_last\n\n        checkpoints = {\n            checkpoint_callback.best_model_path: checkpoint_callback.best_model_score,\n            **checkpoint_callback.best_k_models,\n        }\n        assert len(checkpoints) <= max(checkpoint_callback.save_top_k, 0)\n\n        if save_last:\n            last_model_path = Path(checkpoint_callback.last_model_path)\n            if last_model_path.exists():\n                checkpoints.update(\n                    {\n                        checkpoint_callback.last_model_path: checkpoint_callback.current_score\n                    }\n                )\n            else:\n                print(\n                    f\"last model checkpoint not found at 
{checkpoint_callback.last_model_path}\"\n                )\n\n        checkpoints = sorted(\n            (\n                (Path(path).stat().st_mtime, path, score)\n                for path, score in checkpoints.items()\n                if Path(path).is_file()\n            ),\n            key=lambda x: x[0],\n        )\n        # Retain only checkpoints that we have not logged before with one exception:\n        # If the name is the same (e.g. last checkpoint which should be overwritten),\n        # make sure that they are newer than the previously saved checkpoint by checking their modification time\n        checkpoints = [\n            ckpt\n            for ckpt in checkpoints\n            if ckpt[1] not in self._logged_model_time.keys()\n            or self._logged_model_time[ckpt[1]] < ckpt[0]\n        ]\n        # remove checkpoints with undefined (None) score\n        checkpoints = [x for x in checkpoints if x[2] is not None]\n\n        num_ckpt_logged_before = self._num_logged_artifact()\n        num_new_cktps = len(checkpoints)\n\n        if num_new_cktps == 0:\n            return\n\n        # log iteratively all new checkpoints\n        for time_, path, score in checkpoints:\n            score = score.item() if isinstance(score, torch.Tensor) else score\n            is_best = path == checkpoint_callback.best_model_path\n            is_last = path == checkpoint_callback.last_model_path\n            metadata = {\n                \"score\": score,\n                \"original_filename\": Path(path).name,\n                \"ModelCheckpoint\": {\n                    k: getattr(checkpoint_callback, k)\n                    for k in [\n                        \"monitor\",\n                        \"mode\",\n                        \"save_last\",\n                        \"save_top_k\",\n                        \"save_weights_only\",\n                    ]\n                    # ensure it does not break if `ModelCheckpoint` args change\n                    if 
hasattr(checkpoint_callback, k)\n                },\n            }\n            aliases = []\n            if is_best:\n                aliases.append(\"best\")\n            if is_last:\n                aliases.append(\"last\")\n            artifact_name = f\"checkpoint-{self.experiment.id}-\" + (\n                \"last\" if is_last else \"topK\"\n            )\n            artifact = wandb.Artifact(\n                name=artifact_name, type=\"model\", metadata=metadata\n            )\n            assert Path(path).exists()\n            artifact.add_file(path, name=f\"{self.experiment.id}.ckpt\")\n            self.experiment.log_artifact(artifact, aliases=aliases)\n            # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)\n            self._logged_model_time[path] = time_\n\n        timeout = 20\n        time_spent = 0\n        while self._num_logged_artifact() < num_ckpt_logged_before + num_new_cktps:\n            time.sleep(1)\n            time_spent += 1\n            if time_spent >= timeout:\n                rank_zero_warn(\n                    \"Timeout: Num logged artifacts never reached expected value.\"\n                )\n                print(f\"self._num_logged_artifact() = {self._num_logged_artifact()}\")\n                print(f\"num_ckpt_logged_before = {num_ckpt_logged_before}\")\n                print(f\"num_new_cktps = {num_new_cktps}\")\n                break\n\n        try:\n            self._rm_but_top_k(checkpoint_callback.save_top_k)\n        except KeyError:\n            pass\n\n    def _rm_but_top_k(self, top_k: int):\n        # top_k == -1: save all models\n        # top_k == 0: no models saved at all. 
The checkpoint callback does not return checkpoints.\n        # top_k > 0: keep only top k models (last and best will not be deleted)\n        def is_last(artifact):\n            return \"last\" in artifact.aliases\n\n        def is_best(artifact):\n            return \"best\" in artifact.aliases\n\n        def try_delete(artifact):\n            try:\n                artifact.delete(delete_aliases=True)\n            except wandb.errors.CommError:\n                print(\n                    f\"Failed to delete artifact {artifact.name} due to wandb.errors.CommError\"\n                )\n\n        public_run = self._get_public_run()\n\n        score2art = list()\n        for artifact in public_run.logged_artifacts():\n            score = artifact.metadata[\"score\"]\n            original_filename = artifact.metadata[\"original_filename\"]\n            if score == \"Infinity\":\n                print(\n                    f\"removing INF artifact (name, score, original_filename): ({artifact.name}, {score}, {original_filename})\"\n                )\n                try_delete(artifact)\n                continue\n            if score is None:\n                print(\n                    f\"removing None artifact (name, score, original_filename): ({artifact.name}, {score}, {original_filename})\"\n                )\n                try_delete(artifact)\n                continue\n            score2art.append((score, artifact))\n\n        # From high score to low score\n        score2art.sort(key=lambda x: x[0], reverse=True)\n\n        count = 0\n        for score, artifact in score2art:\n            original_filename = artifact.metadata[\"original_filename\"]\n            if \"last\" in original_filename and not is_last(artifact):\n                try_delete(artifact)\n                continue\n            if is_last(artifact):\n                continue\n            count += 1\n            if is_best(artifact):\n                continue\n            # if top_k == -1, we 
do not delete anything\n            if 0 <= top_k < count:\n                try_delete(artifact)\n"
  },
  {
    "path": "RVT/models/detection/__init__.py",
    "content": ""
  },
  {
    "path": "RVT/models/detection/recurrent_backbone/__init__.py",
    "content": "from omegaconf import DictConfig\n\nfrom .maxvit_rnn import RNNDetector as MaxViTRNNDetector\n\n\ndef build_recurrent_backbone(backbone_cfg: DictConfig):\n    name = backbone_cfg.name\n    if name == \"MaxViTRNN\":\n        return MaxViTRNNDetector(backbone_cfg)\n    else:\n        raise NotImplementedError(f\"Unsupported recurrent backbone: {name}\")\n"
  },
  {
    "path": "RVT/models/detection/recurrent_backbone/base.py",
    "content": "from typing import Tuple\n\nimport torch.nn as nn\n\n\nclass BaseDetector(nn.Module):\n    def get_stage_dims(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:\n        raise NotImplementedError\n\n    def get_strides(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:\n        raise NotImplementedError\n"
  },
  {
    "path": "RVT/models/detection/recurrent_backbone/maxvit_rnn.py",
    "content": "from typing import Dict, Optional, Tuple\nimport torch as th\nimport torch.nn as nn\nfrom omegaconf import DictConfig, OmegaConf\nfrom einops import rearrange\n\ntry:\n    from torch import compile as th_compile\nexcept ImportError:\n    th_compile = None\n\nfrom data.utils.types import FeatureMap, BackboneFeatures, LstmState, LstmStates\n\n# from models.layers.rnn import DWSConvLSTM2d\nfrom models.layers.s5.s5_model import S5Block\n\nfrom models.layers.maxvit.maxvit import (\n    PartitionAttentionCl,\n    nhwC_2_nChw,\n    get_downsample_layer_Cf2Cl,\n    PartitionType,\n)\n\nfrom .base import BaseDetector\n\n\nclass RNNDetector(BaseDetector):\n    def __init__(self, mdl_config: DictConfig):\n        super().__init__()\n\n        ###### Config ######\n        in_channels = mdl_config.input_channels\n        embed_dim = mdl_config.embed_dim\n        dim_multiplier_per_stage = tuple(mdl_config.dim_multiplier)\n        num_blocks_per_stage = tuple(mdl_config.num_blocks)\n        T_max_chrono_init_per_stage = tuple(mdl_config.T_max_chrono_init)\n        enable_masking = mdl_config.enable_masking\n\n        num_stages = len(num_blocks_per_stage)\n        assert num_stages == 4\n\n        assert isinstance(embed_dim, int)\n        assert num_stages == len(dim_multiplier_per_stage)\n        assert num_stages == len(num_blocks_per_stage)\n        assert num_stages == len(T_max_chrono_init_per_stage)\n\n        ###### Compile if requested ######\n        compile_cfg = mdl_config.get(\"compile\", None)\n        if compile_cfg is not None:\n            compile_mdl = compile_cfg.enable\n            if compile_mdl and th_compile is not None:\n                compile_args = OmegaConf.to_container(\n                    compile_cfg.args, resolve=True, throw_on_missing=True\n                )\n                self.forward = th_compile(self.forward, **compile_args)\n            elif compile_mdl:\n                print(\n                    \"Could not compile 
backbone because torch.compile is not available\"\n                )\n        ##################################\n\n        input_dim = in_channels\n        patch_size = mdl_config.stem.patch_size\n        stride = 1\n        self.stage_dims = [embed_dim * x for x in dim_multiplier_per_stage]\n\n        self.stages = nn.ModuleList()\n        self.strides = []\n        for stage_idx, (num_blocks, T_max_chrono_init_stage) in enumerate(\n            zip(num_blocks_per_stage, T_max_chrono_init_per_stage)\n        ):\n            spatial_downsample_factor = patch_size if stage_idx == 0 else 2\n            stage_dim = self.stage_dims[stage_idx]\n            enable_masking_in_stage = enable_masking and stage_idx == 0\n            stage = RNNDetectorStage(\n                dim_in=input_dim,\n                stage_dim=stage_dim,\n                spatial_downsample_factor=spatial_downsample_factor,\n                num_blocks=num_blocks,\n                enable_token_masking=enable_masking_in_stage,\n                T_max_chrono_init=T_max_chrono_init_stage,\n                stage_cfg=mdl_config.stage,\n            )\n            stride = stride * spatial_downsample_factor\n            self.strides.append(stride)\n\n            input_dim = stage_dim\n            self.stages.append(stage)\n\n        self.num_stages = num_stages\n\n    def get_stage_dims(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:\n        stage_indices = [x - 1 for x in stages]\n        assert min(stage_indices) >= 0, stage_indices\n        assert max(stage_indices) < len(self.stages), stage_indices\n        return tuple(self.stage_dims[stage_idx] for stage_idx in stage_indices)\n\n    def get_strides(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:\n        stage_indices = [x - 1 for x in stages]\n        assert min(stage_indices) >= 0, stage_indices\n        assert max(stage_indices) < len(self.stages), stage_indices\n        return tuple(self.strides[stage_idx] for stage_idx in stage_indices)\n\n  
  def forward(\n        self,\n        x: th.Tensor,\n        prev_states: Optional[LstmStates] = None,\n        token_mask: Optional[th.Tensor] = None,\n        train_step: bool = True,\n    ) -> Tuple[BackboneFeatures, LstmStates]:\n        if prev_states is None:\n            prev_states = [None] * self.num_stages\n        assert len(prev_states) == self.num_stages\n        states: LstmStates = list()\n        output: Dict[int, FeatureMap] = {}\n        for stage_idx, stage in enumerate(self.stages):\n            x, state = stage(\n                x,\n                prev_states[stage_idx],\n                token_mask if stage_idx == 0 else None,\n                train_step,\n            )\n            states.append(state)\n            stage_number = stage_idx + 1\n            output[stage_number] = x\n        return output, states\n\n\nclass MaxVitAttentionPairCl(nn.Module):\n    def __init__(self, dim: int, skip_first_norm: bool, attention_cfg: DictConfig):\n        super().__init__()\n\n        self.att_window = PartitionAttentionCl(\n            dim=dim,\n            partition_type=PartitionType.WINDOW,\n            attention_cfg=attention_cfg,\n            skip_first_norm=skip_first_norm,\n        )\n        self.att_grid = PartitionAttentionCl(\n            dim=dim,\n            partition_type=PartitionType.GRID,\n            attention_cfg=attention_cfg,\n            skip_first_norm=False,\n        )\n\n    def forward(self, x):\n        x = self.att_window(x)\n        x = self.att_grid(x)\n        return x\n\n\nclass RNNDetectorStage(nn.Module):\n    \"\"\"Operates with channel-first [NCHW] sequences as input and output.\n\n    Input x: (L, B, C_in, H, W), where L is the sequence length and B the batch size.\n    Output x: (L, B, stage_dim, H', W') after spatial downsampling.\n    The recurrent S5 state is passed in and returned as (B, C, H', W').\n    \"\"\"\n\n    def __init__(\n        self,\n        dim_in: int,\n        stage_dim: int,\n        spatial_downsample_factor: int,\n        num_blocks: int,\n        enable_token_masking: bool,\n        T_max_chrono_init: Optional[int],\n        stage_cfg: DictConfig,\n    ):\n        super().__init__()\n        assert isinstance(num_blocks, 
int) and num_blocks > 0\n        downsample_cfg = stage_cfg.downsample\n        lstm_cfg = stage_cfg.lstm\n        attention_cfg = stage_cfg.attention\n\n        self.downsample_cf2cl = get_downsample_layer_Cf2Cl(\n            dim_in=dim_in,\n            dim_out=stage_dim,\n            downsample_factor=spatial_downsample_factor,\n            downsample_cfg=downsample_cfg,\n        )\n        blocks = [\n            MaxVitAttentionPairCl(\n                dim=stage_dim,\n                skip_first_norm=i == 0 and self.downsample_cf2cl.output_is_normed(),\n                attention_cfg=attention_cfg,\n            )\n            for i in range(num_blocks)\n        ]\n        self.att_blocks = nn.ModuleList(blocks)\n\n        self.s5_block = S5Block(\n            dim=stage_dim, state_dim=stage_dim, bidir=False, bandlimit=0.5\n        )\n\n        \"\"\"\n        self.lstm = DWSConvLSTM2d(\n            dim=stage_dim,\n            dws_conv=lstm_cfg.dws_conv,\n            dws_conv_only_hidden=lstm_cfg.dws_conv_only_hidden,\n            dws_conv_kernel_size=lstm_cfg.dws_conv_kernel_size,\n            cell_update_dropout=lstm_cfg.get(\"drop_cell_update\", 0),\n        )\n        \"\"\"\n\n        ###### Mask Token ################\n        self.mask_token = (\n            nn.Parameter(th.zeros(1, 1, 1, stage_dim), requires_grad=True)\n            if enable_token_masking\n            else None\n        )\n\n        if self.mask_token is not None:\n            th.nn.init.normal_(self.mask_token, std=0.02)\n        ##################################\n\n    def forward(\n        self,\n        x: th.Tensor,\n        states: Optional[LstmState] = None,\n        token_mask: Optional[th.Tensor] = None,\n        train_step: bool = True,\n    ) -> Tuple[FeatureMap, LstmState]:\n        sequence_length = x.shape[0]\n        batch_size = x.shape[1]\n        x = rearrange(\n            x, \"L B C H W -> (L B) C H W\"\n        )  # where B' = (L B) is the new batch size\n        x = 
self.downsample_cf2cl(x)  # B' C H W -> B' H W C\n\n        if token_mask is not None:\n            assert self.mask_token is not None, \"No mask token present in this stage\"\n            x[token_mask] = self.mask_token\n\n        for blk in self.att_blocks:\n            x = blk(x)\n        x = nhwC_2_nChw(x)  # B' H W C -> B' C H W\n\n        new_h, new_w = x.shape[2], x.shape[3]\n\n        x = rearrange(x, \"(L B) C H W -> (B H W) L C\", L=sequence_length)\n\n        if states is None:\n            states = self.s5_block.s5.initial_state(\n                batch_size=batch_size * new_h * new_w\n            ).to(x.device)\n        else:\n            states = rearrange(states, \"B C H W -> (B H W) C\")\n\n        x, states = self.s5_block(x, states)\n\n        x = rearrange(\n            x, \"(B H W) L C -> L B C H W\", B=batch_size, H=int(new_h), W=int(new_w)\n        )\n\n        states = rearrange(states, \"(B H W) C -> B C H W\", H=new_h, W=new_w)\n\n        return x, states\n"
  },
  {
    "path": "RVT/models/detection/yolox/models/__init__.py",
    "content": ""
  },
  {
    "path": "RVT/models/detection/yolox/models/losses.py",
    "content": "#!/usr/bin/env python\n# -*- encoding: utf-8 -*-\n# Copyright (c) Megvii Inc. All rights reserved.\n\nimport torch\nimport torch.nn as nn\n\n\nclass IOUloss(nn.Module):\n    def __init__(self, reduction=\"none\", loss_type=\"iou\"):\n        super(IOUloss, self).__init__()\n        self.reduction = reduction\n        self.loss_type = loss_type\n\n    def forward(self, pred, target):\n        assert pred.shape[0] == target.shape[0]\n\n        # Boxes are in (cx, cy, w, h) format; flatten to (N, 4).\n        pred = pred.view(-1, 4)\n        target = target.view(-1, 4)\n        # Top-left and bottom-right corners of the intersection box.\n        tl = torch.max(\n            (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)\n        )\n        br = torch.min(\n            (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)\n        )\n\n        area_p = torch.prod(pred[:, 2:], 1)\n        area_g = torch.prod(target[:, 2:], 1)\n\n        # en is 1 where the boxes overlap and 0 otherwise.\n        en = (tl < br).type(tl.type()).prod(dim=1)\n        area_i = torch.prod(br - tl, 1) * en\n        area_u = area_p + area_g - area_i\n        iou = (area_i) / (area_u + 1e-16)\n\n        if self.loss_type == \"iou\":\n            loss = 1 - iou**2\n        elif self.loss_type == \"giou\":\n            # Corners of the smallest box enclosing both boxes.\n            c_tl = torch.min(\n                (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)\n            )\n            c_br = torch.max(\n                (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)\n            )\n            area_c = torch.prod(c_br - c_tl, 1)\n            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)\n            loss = 1 - giou.clamp(min=-1.0, max=1.0)\n        else:\n            raise NotImplementedError\n\n        if self.reduction == \"mean\":\n            loss = loss.mean()\n        elif self.reduction == \"sum\":\n            loss = loss.sum()\n\n        return loss\n"
  },
  {
    "path": "RVT/models/detection/yolox/models/network_blocks.py",
    "content": "#!/usr/bin/env python\n# -*- encoding: utf-8 -*-\n# Copyright (c) Megvii Inc. All rights reserved.\n\nimport torch\nimport torch.nn as nn\n\n\nclass SiLU(nn.Module):\n    \"\"\"export-friendly version of nn.SiLU()\"\"\"\n\n    @staticmethod\n    def forward(x):\n        return x * torch.sigmoid(x)\n\n\ndef get_activation(name=\"silu\", inplace=True):\n    if name == \"silu\":\n        module = nn.SiLU(inplace=inplace)\n    elif name == \"relu\":\n        module = nn.ReLU(inplace=inplace)\n    elif name == \"lrelu\":\n        module = nn.LeakyReLU(0.1, inplace=inplace)\n    else:\n        raise AttributeError(\"Unsupported act type: {}\".format(name))\n    return module\n\n\nclass BaseConv(nn.Module):\n    \"\"\"A Conv2d -> Batchnorm -> silu/leaky relu block\"\"\"\n\n    def __init__(\n        self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act=\"silu\"\n    ):\n        super().__init__()\n        # same padding\n        pad = (ksize - 1) // 2\n        self.conv = nn.Conv2d(\n            in_channels,\n            out_channels,\n            kernel_size=ksize,\n            stride=stride,\n            padding=pad,\n            groups=groups,\n            bias=bias,\n        )\n        self.bn = nn.BatchNorm2d(out_channels)\n        self.act = get_activation(act, inplace=True)\n\n    def forward(self, x):\n        return self.act(self.bn(self.conv(x)))\n\n    def fuseforward(self, x):\n        return self.act(self.conv(x))\n\n\nclass DWConv(nn.Module):\n    \"\"\"Depthwise Conv + Conv\"\"\"\n\n    def __init__(self, in_channels, out_channels, ksize, stride=1, act=\"silu\"):\n        super().__init__()\n        self.dconv = BaseConv(\n            in_channels,\n            in_channels,\n            ksize=ksize,\n            stride=stride,\n            groups=in_channels,\n            act=act,\n        )\n        self.pconv = BaseConv(\n            in_channels, out_channels, ksize=1, stride=1, groups=1, act=act\n        )\n\n    def 
forward(self, x):\n        x = self.dconv(x)\n        return self.pconv(x)\n\n\nclass Bottleneck(nn.Module):\n    # Standard bottleneck\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        shortcut=True,\n        expansion=0.5,\n        depthwise=False,\n        act=\"silu\",\n    ):\n        super().__init__()\n        hidden_channels = int(out_channels * expansion)\n        Conv = DWConv if depthwise else BaseConv\n        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)\n        self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)\n        self.use_add = shortcut and in_channels == out_channels\n\n    def forward(self, x):\n        y = self.conv2(self.conv1(x))\n        if self.use_add:\n            y = y + x\n        return y\n\n\nclass CSPLayer(nn.Module):\n    \"\"\"C3 in yolov5, CSP Bottleneck with 3 convolutions\"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        n=1,\n        shortcut=True,\n        expansion=0.5,\n        depthwise=False,\n        act=\"silu\",\n    ):\n        \"\"\"\n        Args:\n            in_channels (int): input channels.\n            out_channels (int): output channels.\n            n (int): number of Bottlenecks. 
Default value: 1.\n        \"\"\"\n        # ch_in, ch_out, number, shortcut, groups, expansion\n        super().__init__()\n        hidden_channels = int(out_channels * expansion)  # hidden channels\n        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)\n        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)\n        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)\n        module_list = [\n            Bottleneck(\n                hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act\n            )\n            for _ in range(n)\n        ]\n        self.m = nn.Sequential(*module_list)\n\n    def forward(self, x):\n        x_1 = self.conv1(x)\n        x_2 = self.conv2(x)\n        x_1 = self.m(x_1)\n        x = torch.cat((x_1, x_2), dim=1)\n        return self.conv3(x)\n"
  },
  {
    "path": "RVT/models/detection/yolox/models/yolo_head.py",
    "content": "\"\"\"\nOriginal Yolox Head code with slight modifications\n\"\"\"\n\nimport math\nfrom typing import Dict, Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ntry:\n    from torch import compile as th_compile\nexcept ImportError:\n    th_compile = None\n\nfrom models.detection.yolox.utils import bboxes_iou\n\nfrom .losses import IOUloss\nfrom .network_blocks import BaseConv, DWConv\n\n\nclass YOLOXHead(nn.Module):\n    def __init__(\n        self,\n        num_classes=80,\n        strides=(8, 16, 32),\n        in_channels=(256, 512, 1024),\n        act=\"silu\",\n        depthwise=False,\n        compile_cfg: Optional[Dict] = None,\n    ):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.decode_in_inference = True  # for deploy, set to False\n\n        self.cls_convs = nn.ModuleList()\n        self.reg_convs = nn.ModuleList()\n        self.cls_preds = nn.ModuleList()\n        self.reg_preds = nn.ModuleList()\n        self.obj_preds = nn.ModuleList()\n        self.stems = nn.ModuleList()\n        Conv = DWConv if depthwise else BaseConv\n\n        self.output_strides = None\n        self.output_grids = None\n\n        # Automatic width scaling according to original YoloX channel dims.\n        # in[-1]/out = 4/1\n        # out = in[-1]/4 = 256 * width\n        # -> width = in[-1]/1024\n        largest_base_dim_yolox = 1024\n        largest_base_dim_from_input = in_channels[-1]\n        width = largest_base_dim_from_input / largest_base_dim_yolox\n\n        hidden_dim = int(256 * width)\n\n        for i in range(len(in_channels)):\n            self.stems.append(\n                BaseConv(\n                    in_channels=in_channels[i],\n                    out_channels=hidden_dim,\n                    ksize=1,\n                    stride=1,\n                    act=act,\n                )\n            )\n            self.cls_convs.append(\n                nn.Sequential(\n             
       *[\n                        Conv(\n                            in_channels=hidden_dim,\n                            out_channels=hidden_dim,\n                            ksize=3,\n                            stride=1,\n                            act=act,\n                        ),\n                        Conv(\n                            in_channels=hidden_dim,\n                            out_channels=hidden_dim,\n                            ksize=3,\n                            stride=1,\n                            act=act,\n                        ),\n                    ]\n                )\n            )\n            self.reg_convs.append(\n                nn.Sequential(\n                    *[\n                        Conv(\n                            in_channels=hidden_dim,\n                            out_channels=hidden_dim,\n                            ksize=3,\n                            stride=1,\n                            act=act,\n                        ),\n                        Conv(\n                            in_channels=hidden_dim,\n                            out_channels=hidden_dim,\n                            ksize=3,\n                            stride=1,\n                            act=act,\n                        ),\n                    ]\n                )\n            )\n            self.cls_preds.append(\n                nn.Conv2d(\n                    in_channels=hidden_dim,\n                    out_channels=self.num_classes,\n                    kernel_size=1,\n                    stride=1,\n                    padding=0,\n                )\n            )\n            self.reg_preds.append(\n                nn.Conv2d(\n                    in_channels=hidden_dim,\n                    out_channels=4,\n                    kernel_size=1,\n                    stride=1,\n                    padding=0,\n                )\n            )\n            self.obj_preds.append(\n                nn.Conv2d(\n                    
in_channels=hidden_dim,\n                    out_channels=1,\n                    kernel_size=1,\n                    stride=1,\n                    padding=0,\n                )\n            )\n\n        self.use_l1 = False\n        self.l1_loss = nn.L1Loss(reduction=\"none\")\n        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction=\"none\")\n        self.iou_loss = IOUloss(reduction=\"none\")\n        self.strides = strides\n        self.grids = [torch.zeros(1)] * len(in_channels)\n\n        # According to Focal Loss paper:\n        self.initialize_biases(prior_prob=0.01)\n\n        ###### Compile if requested ######\n        if compile_cfg is not None:\n            compile_mdl = compile_cfg[\"enable\"]\n            if compile_mdl and th_compile is not None:\n                self.forward = th_compile(self.forward, **compile_cfg[\"args\"])\n            elif compile_mdl:\n                print(\n                    \"Could not compile YOLOXHead because torch.compile is not available\"\n                )\n        ##################################\n\n    def initialize_biases(self, prior_prob):\n        for conv in self.cls_preds:\n            b = conv.bias.view(1, -1)\n            b.data.fill_(-math.log((1 - prior_prob) / prior_prob))\n            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)\n\n        for conv in self.obj_preds:\n            b = conv.bias.view(1, -1)\n            b.data.fill_(-math.log((1 - prior_prob) / prior_prob))\n            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)\n\n    def forward(self, xin, labels=None):\n        train_outputs = []\n        inference_outputs = []\n        origin_preds = []\n        x_shifts = []\n        y_shifts = []\n        expanded_strides = []\n\n        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(\n            zip(self.cls_convs, self.reg_convs, self.strides, xin)\n        ):\n            x = self.stems[k](x)\n            cls_x = x\n            reg_x 
= x\n\n            cls_feat = cls_conv(cls_x)\n            cls_output = self.cls_preds[k](cls_feat)\n\n            reg_feat = reg_conv(reg_x)\n            reg_output = self.reg_preds[k](reg_feat)\n            obj_output = self.obj_preds[k](reg_feat)\n\n            if self.training:\n                output = torch.cat([reg_output, obj_output, cls_output], 1)\n                output, grid = self.get_output_and_grid(\n                    output, k, stride_this_level, xin[0].type()\n                )\n                x_shifts.append(grid[:, :, 0])\n                y_shifts.append(grid[:, :, 1])\n                expanded_strides.append(\n                    torch.zeros(1, grid.shape[1])\n                    .fill_(stride_this_level)\n                    .type_as(xin[0])\n                )\n                if self.use_l1:\n                    batch_size = reg_output.shape[0]\n                    hsize, wsize = reg_output.shape[-2:]\n                    reg_output = reg_output.view(batch_size, 1, 4, hsize, wsize)\n                    reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(\n                        batch_size, -1, 4\n                    )\n                    origin_preds.append(reg_output.clone())\n                train_outputs.append(output)\n            inference_output = torch.cat(\n                [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1\n            )\n            inference_outputs.append(inference_output)\n\n        # --------------------------------------------------------\n        # Modification: return decoded output also during training\n        # --------------------------------------------------------\n        losses = None\n        if self.training:\n            losses = self.get_losses(\n                x_shifts,\n                y_shifts,\n                expanded_strides,\n                labels,\n                torch.cat(train_outputs, 1),\n                origin_preds,\n                dtype=xin[0].dtype,\n            
)\n            assert len(losses) == 6\n            losses = {\n                \"loss\": losses[0],\n                \"iou_loss\": losses[1],\n                \"conf_loss\": losses[2],  # objectness\n                \"cls_loss\": losses[3],  # predicted class\n                \"l1_loss\": losses[4],\n                \"num_fg\": losses[5],\n            }\n        self.hw = [x.shape[-2:] for x in inference_outputs]\n        # [batch, n_anchors_all, 5 + num_classes]\n        outputs = torch.cat(\n            [x.flatten(start_dim=2) for x in inference_outputs], dim=2\n        ).permute(0, 2, 1)\n        if self.decode_in_inference:\n            return self.decode_outputs(outputs), losses\n        else:\n            return outputs, losses\n\n    def get_output_and_grid(self, output, k, stride, dtype):\n        grid = self.grids[k]\n\n        batch_size = output.shape[0]\n        n_ch = 5 + self.num_classes\n        hsize, wsize = output.shape[-2:]\n        if grid.shape[2:4] != output.shape[2:4]:\n            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])\n            grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)\n            self.grids[k] = grid\n\n        output = output.view(batch_size, 1, n_ch, hsize, wsize)\n        output = output.permute(0, 1, 3, 4, 2).reshape(batch_size, hsize * wsize, -1)\n        grid = grid.view(1, -1, 2)\n        output[..., :2] = (output[..., :2] + grid) * stride\n        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride\n        return output, grid\n\n    def decode_outputs(self, outputs):\n        if self.output_grids is None:\n            assert self.output_strides is None\n            dtype = outputs.dtype\n            device = outputs.device\n            grids = []\n            strides = []\n            for (hsize, wsize), stride in zip(self.hw, self.strides):\n                yv, xv = torch.meshgrid(\n                    [\n                        torch.arange(hsize, device=device, 
dtype=dtype),\n                        torch.arange(wsize, device=device, dtype=dtype),\n                    ]\n                )\n                grid = torch.stack((xv, yv), 2).view(1, -1, 2)\n                grids.append(grid)\n                shape = grid.shape[:2]\n                strides.append(\n                    torch.full((*shape, 1), stride, device=device, dtype=dtype)\n                )\n            self.output_grids = torch.cat(grids, dim=1)\n            self.output_strides = torch.cat(strides, dim=1)\n        outputs = torch.cat(\n            [\n                (outputs[..., 0:2] + self.output_grids) * self.output_strides,\n                torch.exp(outputs[..., 2:4]) * self.output_strides,\n                outputs[..., 4:],\n            ],\n            dim=-1,\n        )\n        return outputs\n\n    def get_losses(\n        self,\n        x_shifts,\n        y_shifts,\n        expanded_strides,\n        labels,\n        outputs,\n        origin_preds,\n        dtype,\n    ):\n        bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]\n        obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]\n        cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]\n\n        # calculate targets\n        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects\n\n        total_num_anchors = outputs.shape[1]\n        x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]\n        y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all]\n        expanded_strides = torch.cat(expanded_strides, 1)\n        if self.use_l1:\n            origin_preds = torch.cat(origin_preds, 1)\n\n        cls_targets = []\n        reg_targets = []\n        l1_targets = []\n        obj_targets = []\n        fg_masks = []\n\n        num_fg = 0.0\n        num_gts = 0.0\n\n        for batch_idx in range(outputs.shape[0]):\n            num_gt = int(nlabel[batch_idx])\n            num_gts += num_gt\n            if num_gt == 0:\n                cls_target 
= outputs.new_zeros((0, self.num_classes))\n                reg_target = outputs.new_zeros((0, 4))\n                l1_target = outputs.new_zeros((0, 4))\n                obj_target = outputs.new_zeros((total_num_anchors, 1))\n                fg_mask = outputs.new_zeros(total_num_anchors).bool()\n            else:\n                gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5]\n                gt_classes = labels[batch_idx, :num_gt, 0]\n                bboxes_preds_per_image = bbox_preds[batch_idx]\n\n                try:\n                    (\n                        gt_matched_classes,\n                        fg_mask,\n                        pred_ious_this_matching,\n                        matched_gt_inds,\n                        num_fg_img,\n                    ) = self.get_assignments(  # noqa\n                        batch_idx,\n                        num_gt,\n                        gt_bboxes_per_image,\n                        gt_classes,\n                        bboxes_preds_per_image,\n                        expanded_strides,\n                        x_shifts,\n                        y_shifts,\n                        cls_preds,\n                        obj_preds,\n                    )\n                except RuntimeError as e:\n                    # TODO: the string might change, consider a better way\n                    if \"CUDA out of memory. 
\" not in str(e):\n                        raise\n\n                    torch.cuda.empty_cache()\n                    (\n                        gt_matched_classes,\n                        fg_mask,\n                        pred_ious_this_matching,\n                        matched_gt_inds,\n                        num_fg_img,\n                    ) = self.get_assignments(  # noqa\n                        batch_idx,\n                        num_gt,\n                        gt_bboxes_per_image,\n                        gt_classes,\n                        bboxes_preds_per_image,\n                        expanded_strides,\n                        x_shifts,\n                        y_shifts,\n                        cls_preds,\n                        obj_preds,\n                        \"cpu\",\n                    )\n\n                torch.cuda.empty_cache()\n                num_fg += num_fg_img\n\n                cls_target = F.one_hot(\n                    gt_matched_classes.to(torch.int64), self.num_classes\n                ) * pred_ious_this_matching.unsqueeze(-1)\n                obj_target = fg_mask.unsqueeze(-1)\n                reg_target = gt_bboxes_per_image[matched_gt_inds]\n                if self.use_l1:\n                    l1_target = self.get_l1_target(\n                        outputs.new_zeros((num_fg_img, 4)),\n                        gt_bboxes_per_image[matched_gt_inds],\n                        expanded_strides[0][fg_mask],\n                        x_shifts=x_shifts[0][fg_mask],\n                        y_shifts=y_shifts[0][fg_mask],\n                    )\n\n            cls_targets.append(cls_target)\n            reg_targets.append(reg_target)\n            obj_targets.append(obj_target.to(dtype))\n            fg_masks.append(fg_mask)\n            if self.use_l1:\n                l1_targets.append(l1_target)\n\n        cls_targets = torch.cat(cls_targets, 0)\n        reg_targets = torch.cat(reg_targets, 0)\n        obj_targets = 
torch.cat(obj_targets, 0)\n        fg_masks = torch.cat(fg_masks, 0)\n        if self.use_l1:\n            l1_targets = torch.cat(l1_targets, 0)\n\n        num_fg = max(num_fg, 1)\n        loss_iou = (\n            self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)\n        ).sum() / num_fg\n        loss_obj = (\n            self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)\n        ).sum() / num_fg\n        loss_cls = (\n            self.bcewithlog_loss(\n                cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets\n            )\n        ).sum() / num_fg\n        if self.use_l1:\n            loss_l1 = (\n                self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)\n            ).sum() / num_fg\n        else:\n            loss_l1 = 0.0\n\n        reg_weight = 5.0\n        loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1\n\n        return (\n            loss,\n            reg_weight * loss_iou,\n            loss_obj,\n            loss_cls,\n            loss_l1,\n            num_fg / max(num_gts, 1),\n        )\n\n    def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):\n        l1_target[:, 0] = gt[:, 0] / stride - x_shifts\n        l1_target[:, 1] = gt[:, 1] / stride - y_shifts\n        l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps)\n        l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps)\n        return l1_target\n\n    @torch.no_grad()\n    def get_assignments(\n        self,\n        batch_idx,\n        num_gt,\n        gt_bboxes_per_image,\n        gt_classes,\n        bboxes_preds_per_image,\n        expanded_strides,\n        x_shifts,\n        y_shifts,\n        cls_preds,\n        obj_preds,\n        mode=\"gpu\",\n    ):\n        if mode == \"cpu\":\n            print(\"-----------Using CPU for the Current Batch-------------\")\n            gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()\n            bboxes_preds_per_image = 
bboxes_preds_per_image.cpu().float()\n            gt_classes = gt_classes.cpu().float()\n            expanded_strides = expanded_strides.cpu().float()\n            x_shifts = x_shifts.cpu()\n            y_shifts = y_shifts.cpu()\n\n        fg_mask, geometry_relation = self.get_geometry_constraint(\n            gt_bboxes_per_image,\n            expanded_strides,\n            x_shifts,\n            y_shifts,\n        )\n\n        bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]\n        cls_preds_ = cls_preds[batch_idx][fg_mask]\n        obj_preds_ = obj_preds[batch_idx][fg_mask]\n        num_in_boxes_anchor = bboxes_preds_per_image.shape[0]\n\n        if mode == \"cpu\":\n            gt_bboxes_per_image = gt_bboxes_per_image.cpu()\n            bboxes_preds_per_image = bboxes_preds_per_image.cpu()\n\n        pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)\n\n        gt_cls_per_image = F.one_hot(\n            gt_classes.to(torch.int64), self.num_classes\n        ).float()\n        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)\n\n        if mode == \"cpu\":\n            cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()\n\n        with torch.cuda.amp.autocast(enabled=False):\n            cls_preds_ = (\n                cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()\n            ).sqrt()\n            pair_wise_cls_loss = F.binary_cross_entropy(\n                cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),\n                gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1),\n                reduction=\"none\",\n            ).sum(-1)\n        del cls_preds_\n\n        cost = (\n            pair_wise_cls_loss\n            + 3.0 * pair_wise_ious_loss\n            + float(1e6) * (~geometry_relation)\n        )\n\n        (\n            num_fg,\n            gt_matched_classes,\n            pred_ious_this_matching,\n            matched_gt_inds,\n        ) = self.simota_matching(cost, 
pair_wise_ious, gt_classes, num_gt, fg_mask)\n        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss\n\n        if mode == \"cpu\":\n            gt_matched_classes = gt_matched_classes.cuda()\n            fg_mask = fg_mask.cuda()\n            pred_ious_this_matching = pred_ious_this_matching.cuda()\n            matched_gt_inds = matched_gt_inds.cuda()\n\n        return (\n            gt_matched_classes,\n            fg_mask,\n            pred_ious_this_matching,\n            matched_gt_inds,\n            num_fg,\n        )\n\n    def get_geometry_constraint(\n        self,\n        gt_bboxes_per_image,\n        expanded_strides,\n        x_shifts,\n        y_shifts,\n    ):\n        \"\"\"\n        Calculate whether the center of an object is located in a fixed range of\n        an anchor. This is used to avert inappropriate matching. It can also reduce\n        the number of candidate anchors so that the GPU memory is saved.\n        \"\"\"\n        expanded_strides_per_image = expanded_strides[0]\n        x_centers_per_image = (\n            (x_shifts[0] + 0.5) * expanded_strides_per_image\n        ).unsqueeze(0)\n        y_centers_per_image = (\n            (y_shifts[0] + 0.5) * expanded_strides_per_image\n        ).unsqueeze(0)\n\n        # in fixed center\n        center_radius = 1.5\n        center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius\n        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist\n        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist\n        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist\n        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist\n\n        c_l = x_centers_per_image - gt_bboxes_per_image_l\n        c_r = gt_bboxes_per_image_r - x_centers_per_image\n        c_t = y_centers_per_image - gt_bboxes_per_image_t\n        c_b = gt_bboxes_per_image_b - y_centers_per_image\n        center_deltas = torch.stack([c_l, 
c_t, c_r, c_b], 2)\n        is_in_centers = center_deltas.min(dim=-1).values > 0.0\n        anchor_filter = is_in_centers.sum(dim=0) > 0\n        geometry_relation = is_in_centers[:, anchor_filter]\n\n        return anchor_filter, geometry_relation\n\n    def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):\n        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)\n\n        n_candidate_k = min(10, pair_wise_ious.size(1))\n        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)\n        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)\n        for gt_idx in range(num_gt):\n            _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx], largest=False)\n            matching_matrix[gt_idx][pos_idx] = 1\n\n        del topk_ious, dynamic_ks, pos_idx\n\n        anchor_matching_gt = matching_matrix.sum(0)\n        # deal with the case that one anchor matches multiple ground-truths\n        if anchor_matching_gt.max() > 1:\n            multiple_match_mask = anchor_matching_gt > 1\n            _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)\n            matching_matrix[:, multiple_match_mask] *= 0\n            matching_matrix[cost_argmin, multiple_match_mask] = 1\n        fg_mask_inboxes = anchor_matching_gt > 0\n        num_fg = fg_mask_inboxes.sum().item()\n\n        fg_mask[fg_mask.clone()] = fg_mask_inboxes\n\n        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)\n        gt_matched_classes = gt_classes[matched_gt_inds]\n\n        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[\n            fg_mask_inboxes\n        ]\n        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds\n"
  },
  {
    "path": "RVT/models/detection/yolox/utils/__init__.py",
    "content": "#!/usr/bin/env python3\n# -*- coding:utf-8 -*-\n# Copyright (c) Megvii Inc. All rights reserved.\n\nfrom .boxes import *\nfrom .compat import meshgrid\n"
  },
  {
    "path": "RVT/models/detection/yolox/utils/boxes.py",
    "content": "#!/usr/bin/env python3\n# -*- coding:utf-8 -*-\n# Copyright (c) Megvii Inc. All rights reserved.\n\nimport numpy as np\n\nimport torch\nimport torchvision\n\n__all__ = [\n    \"filter_box\",\n    \"postprocess\",\n    \"bboxes_iou\",\n    \"matrix_iou\",\n    \"adjust_box_anns\",\n    \"xyxy2xywh\",\n    \"xyxy2cxcywh\",\n]\n\n\ndef filter_box(output, scale_range):\n    \"\"\"\n    output: (N, 5+class) shape\n    \"\"\"\n    min_scale, max_scale = scale_range\n    w = output[:, 2] - output[:, 0]\n    h = output[:, 3] - output[:, 1]\n    keep = (w * h > min_scale * min_scale) & (w * h < max_scale * max_scale)\n    return output[keep]\n\n\ndef postprocess(\n    prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False\n):\n    box_corner = prediction.new(prediction.shape)\n    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2\n    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2\n    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2\n    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2\n    prediction[:, :, :4] = box_corner[:, :, :4]\n\n    output = [None for _ in range(len(prediction))]\n    for i, image_pred in enumerate(prediction):\n        # If none are remaining => process next image\n        if not image_pred.size(0):\n            continue\n        # Get score and class with highest confidence\n        class_conf, class_pred = torch.max(\n            image_pred[:, 5 : 5 + num_classes], 1, keepdim=True\n        )\n\n        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()\n        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)\n        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)\n        detections = detections[conf_mask]\n        if not detections.size(0):\n            continue\n\n        if class_agnostic:\n            nms_out_index = torchvision.ops.nms(\n        
        detections[:, :4],\n                detections[:, 4] * detections[:, 5],\n                nms_thre,\n            )\n        else:\n            nms_out_index = torchvision.ops.batched_nms(\n                detections[:, :4],\n                detections[:, 4] * detections[:, 5],\n                detections[:, 6],\n                nms_thre,\n            )\n\n        detections = detections[nms_out_index]\n        if output[i] is None:\n            output[i] = detections\n        else:\n            output[i] = torch.cat((output[i], detections))\n\n    return output\n\n\ndef bboxes_iou(bboxes_a, bboxes_b, xyxy=True):\n    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:\n        raise IndexError\n\n    if xyxy:\n        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])\n        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])\n        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)\n        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)\n    else:\n        tl = torch.max(\n            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),\n            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),\n        )\n        br = torch.min(\n            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),\n            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),\n        )\n\n        area_a = torch.prod(bboxes_a[:, 2:], 1)\n        area_b = torch.prod(bboxes_b[:, 2:], 1)\n    en = (tl < br).type(tl.type()).prod(dim=2)\n    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())\n    return area_i / (area_a[:, None] + area_b - area_i)\n\n\ndef matrix_iou(a, b):\n    \"\"\"\n    Return the IoU of a and b (NumPy version, used for data augmentation).\n    \"\"\"\n    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])\n    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])\n\n    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)\n    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)\n    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)\n    return area_i / 
(area_a[:, np.newaxis] + area_b - area_i + 1e-12)\n\n\ndef adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max):\n    bbox[:, 0::2] = np.clip(bbox[:, 0::2] * scale_ratio + padw, 0, w_max)\n    bbox[:, 1::2] = np.clip(bbox[:, 1::2] * scale_ratio + padh, 0, h_max)\n    return bbox\n\n\ndef xyxy2xywh(bboxes):\n    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]\n    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]\n    return bboxes\n\n\ndef xyxy2cxcywh(bboxes):\n    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]\n    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]\n    bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5\n    bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5\n    return bboxes\n"
  },
  {
    "path": "RVT/models/detection/yolox/utils/compat.py",
    "content": "#!/usr/bin/env python3\n# -*- coding:utf-8 -*-\n\nimport torch\n\n_TORCH_VER = [int(x) for x in torch.__version__.split(\".\")[:2]]\n\n__all__ = [\"meshgrid\"]\n\n\ndef meshgrid(*tensors):\n    if _TORCH_VER >= [1, 10]:\n        return torch.meshgrid(*tensors, indexing=\"ij\")\n    else:\n        return torch.meshgrid(*tensors)\n"
  },
  {
    "path": "RVT/models/detection/yolox_extension/models/__init__.py",
    "content": ""
  },
  {
    "path": "RVT/models/detection/yolox_extension/models/build.py",
    "content": "from typing import Tuple\n\nfrom omegaconf import OmegaConf, DictConfig\n\nfrom .yolo_pafpn import YOLOPAFPN\nfrom ...yolox.models.yolo_head import YOLOXHead\n\n\ndef build_yolox_head(\n    head_cfg: DictConfig, in_channels: Tuple[int, ...], strides: Tuple[int, ...]\n):\n    head_cfg_dict = OmegaConf.to_container(\n        head_cfg, resolve=True, throw_on_missing=True\n    )\n    head_cfg_dict.pop(\"name\")\n    head_cfg_dict.pop(\"version\", None)\n    head_cfg_dict.update({\"in_channels\": in_channels})\n    head_cfg_dict.update({\"strides\": strides})\n    compile_cfg = head_cfg_dict.pop(\"compile\", None)\n    head_cfg_dict.update({\"compile_cfg\": compile_cfg})\n    return YOLOXHead(**head_cfg_dict)\n\n\ndef build_yolox_fpn(fpn_cfg: DictConfig, in_channels: Tuple[int, ...]):\n    fpn_cfg_dict = OmegaConf.to_container(fpn_cfg, resolve=True, throw_on_missing=True)\n    fpn_name = fpn_cfg_dict.pop(\"name\")\n    fpn_cfg_dict.update({\"in_channels\": in_channels})\n    if fpn_name in {\"PAFPN\", \"pafpn\"}:\n        compile_cfg = fpn_cfg_dict.pop(\"compile\", None)\n        fpn_cfg_dict.update({\"compile_cfg\": compile_cfg})\n        return YOLOPAFPN(**fpn_cfg_dict)\n    raise NotImplementedError\n"
  },
  {
    "path": "RVT/models/detection/yolox_extension/models/detector.py",
    "content": "from typing import Dict, Optional, Tuple, Union\n\nimport torch as th\nfrom omegaconf import DictConfig\n\ntry:\n    from torch import compile as th_compile\nexcept ImportError:\n    th_compile = None\n\nfrom ...recurrent_backbone import build_recurrent_backbone\nfrom .build import build_yolox_fpn, build_yolox_head\nfrom utils.timers import TimerDummy as CudaTimer\n\nfrom data.utils.types import BackboneFeatures, LstmStates\n\n\nclass YoloXDetector(th.nn.Module):\n    def __init__(self, model_cfg: DictConfig):\n        super().__init__()\n        backbone_cfg = model_cfg.backbone\n        fpn_cfg = model_cfg.fpn\n        head_cfg = model_cfg.head\n\n        self.backbone = build_recurrent_backbone(backbone_cfg)\n\n        in_channels = self.backbone.get_stage_dims(fpn_cfg.in_stages)\n        self.fpn = build_yolox_fpn(fpn_cfg, in_channels=in_channels)\n\n        strides = self.backbone.get_strides(fpn_cfg.in_stages)\n        self.yolox_head = build_yolox_head(\n            head_cfg, in_channels=in_channels, strides=strides\n        )\n\n    def forward_backbone(\n        self,\n        x: th.Tensor,\n        previous_states: Optional[LstmStates] = None,\n        token_mask: Optional[th.Tensor] = None,\n        train_step: bool = True,\n    ) -> Tuple[BackboneFeatures, LstmStates]:\n        with CudaTimer(device=x.device, timer_name=\"Backbone\"):\n            backbone_features, states = self.backbone(\n                x, previous_states, token_mask, train_step\n            )\n        return backbone_features, states\n\n    def forward_detect(\n        self, backbone_features: BackboneFeatures, targets: Optional[th.Tensor] = None\n    ) -> Tuple[th.Tensor, Union[Dict[str, th.Tensor], None]]:\n        device = next(iter(backbone_features.values())).device\n        with CudaTimer(device=device, timer_name=\"FPN\"):\n            fpn_features = self.fpn(backbone_features)\n        if self.training:\n            assert targets is not None\n            
with CudaTimer(device=device, timer_name=\"HEAD + Loss\"):\n                outputs, losses = self.yolox_head(fpn_features, targets)\n            return outputs, losses\n        with CudaTimer(device=device, timer_name=\"HEAD\"):\n            outputs, losses = self.yolox_head(fpn_features)\n        assert losses is None\n        return outputs, losses\n\n    def forward(\n        self,\n        x: th.Tensor,\n        previous_states: Optional[LstmStates] = None,\n        retrieve_detections: bool = True,\n        targets: Optional[th.Tensor] = None,\n    ) -> Tuple[Union[th.Tensor, None], Union[Dict[str, th.Tensor], None], LstmStates]:\n        backbone_features, states = self.forward_backbone(x, previous_states)\n        outputs, losses = None, None\n        if not retrieve_detections:\n            assert targets is None\n            return outputs, losses, states\n        outputs, losses = self.forward_detect(\n            backbone_features=backbone_features, targets=targets\n        )\n        return outputs, losses, states\n"
  },
  {
    "path": "RVT/models/detection/yolox_extension/models/yolo_pafpn.py",
    "content": "\"\"\"\nOriginal Yolox PAFPN code with slight modifications\n\"\"\"\n\nfrom typing import Dict, Optional, Tuple\n\nimport torch as th\nimport torch.nn as nn\n\ntry:\n    from torch import compile as th_compile\nexcept ImportError:\n    th_compile = None\n\nfrom ...yolox.models.network_blocks import BaseConv, CSPLayer, DWConv\nfrom data.utils.types import BackboneFeatures\n\n\nclass YOLOPAFPN(nn.Module):\n    \"\"\"\n    Removed the direct dependency on the backbone.\n    \"\"\"\n\n    def __init__(\n        self,\n        depth: float = 1.0,\n        in_stages: Tuple[int, ...] = (2, 3, 4),\n        in_channels: Tuple[int, ...] = (256, 512, 1024),\n        depthwise: bool = False,\n        act: str = \"silu\",\n        compile_cfg: Optional[Dict] = None,\n    ):\n        super().__init__()\n        assert len(in_stages) == len(in_channels)\n        assert len(in_channels) == 3, \"Current implementation only for 3 feature maps\"\n        self.in_features = in_stages\n        self.in_channels = in_channels\n        Conv = DWConv if depthwise else BaseConv\n\n        self.upsample = lambda x: nn.functional.interpolate(\n            x, scale_factor=2, mode=\"nearest-exact\"\n        )\n        self.lateral_conv0 = BaseConv(in_channels[2], in_channels[1], 1, 1, act=act)\n        self.C3_p4 = CSPLayer(\n            2 * in_channels[1],\n            in_channels[1],\n            round(3 * depth),\n            False,\n            depthwise=depthwise,\n            act=act,\n        )  # cat\n\n        self.reduce_conv1 = 
BaseConv(in_channels[1], in_channels[0], 1, 1, act=act)\n        self.C3_p3 = CSPLayer(\n            2 * in_channels[0],\n            in_channels[0],\n            round(3 * depth),\n            False,\n            depthwise=depthwise,\n            act=act,\n        )\n\n        # bottom-up conv\n        self.bu_conv2 = Conv(in_channels[0], in_channels[0], 3, 2, act=act)\n        self.C3_n3 = CSPLayer(\n            2 * in_channels[0],\n            in_channels[1],\n            round(3 * depth),\n            False,\n            depthwise=depthwise,\n            act=act,\n        )\n\n        # bottom-up conv\n        self.bu_conv1 = Conv(in_channels[1], in_channels[1], 3, 2, act=act)\n        self.C3_n4 = CSPLayer(\n            2 * in_channels[1],\n            in_channels[2],\n            round(3 * depth),\n            False,\n            depthwise=depthwise,\n            act=act,\n        )\n\n        ###### Compile if requested ######\n        if compile_cfg is not None:\n            compile_mdl = compile_cfg[\"enable\"]\n            if compile_mdl and th_compile is not None:\n                self.forward = th_compile(self.forward, **compile_cfg[\"args\"])\n            elif compile_mdl:\n                print(\"Could not compile PAFPN because torch.compile is not available\")\n        ##################################\n\n    def forward(self, input: BackboneFeatures):\n        \"\"\"\n        Args:\n            input: Backbone feature maps, indexed by stage number\n\n        Returns:\n            Tuple[Tensor]: FPN features (pan_out2, pan_out1, pan_out0), ordered from highest to lowest resolution.\n        \"\"\"\n        features = [input[f] for f in self.in_features]\n        x2, x1, x0 = features\n\n        fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32\n        f_out0 = self.upsample(fpn_out0)  # 512/16\n        f_out0 = th.cat([f_out0, x1], 1)  # 512->1024/16\n        f_out0 = self.C3_p4(f_out0)  # 1024->512/16\n\n        fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16\n        f_out1 = self.upsample(fpn_out1)  # 256/8\n        f_out1 = 
th.cat([f_out1, x2], 1)  # 256->512/8\n        pan_out2 = self.C3_p3(f_out1)  # 512->256/8\n\n        p_out1 = self.bu_conv2(pan_out2)  # 256->256/16\n        p_out1 = th.cat([p_out1, fpn_out1], 1)  # 256->512/16\n        pan_out1 = self.C3_n3(p_out1)  # 512->512/16\n\n        p_out0 = self.bu_conv1(pan_out1)  # 512->512/32\n        p_out0 = th.cat([p_out0, fpn_out0], 1)  # 512->1024/32\n        pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32\n\n        outputs = (pan_out2, pan_out1, pan_out0)\n        return outputs\n"
  },
  {
    "path": "RVT/models/layers/maxvit/__init__.py",
    "content": ""
  },
  {
    "path": "RVT/models/layers/maxvit/layers/__init__.py",
    "content": "from .activations import *\nfrom .adaptive_avgmax_pool import (\n    adaptive_avgmax_pool2d,\n    select_adaptive_pool2d,\n    AdaptiveAvgMaxPool2d,\n    SelectAdaptivePool2d,\n)\nfrom .blur_pool import BlurPool2d\nfrom .classifier import ClassifierHead, create_classifier\nfrom .cond_conv2d import CondConv2d, get_condconv_initializer\nfrom .config import (\n    is_exportable,\n    is_scriptable,\n    is_no_jit,\n    set_exportable,\n    set_scriptable,\n    set_no_jit,\n    set_layer_config,\n)\nfrom .conv2d_same import Conv2dSame, conv2d_same\nfrom .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct\nfrom .create_act import create_act_layer, get_act_layer, get_act_fn\nfrom .create_attn import get_attn, create_attn\nfrom .create_conv2d import create_conv2d\nfrom .create_norm import get_norm_layer, create_norm_layer\nfrom .create_norm_act import (\n    get_norm_act_layer,\n    create_norm_act_layer,\n    get_norm_act_layer,\n)\nfrom .drop import DropBlock2d, DropPath, drop_block_2d, drop_path\nfrom .eca import (\n    EcaModule,\n    CecaModule,\n    EfficientChannelAttn,\n    CircularEfficientChannelAttn,\n)\nfrom .evo_norm import (\n    EvoNorm2dB0,\n    EvoNorm2dB1,\n    EvoNorm2dB2,\n    EvoNorm2dS0,\n    EvoNorm2dS0a,\n    EvoNorm2dS1,\n    EvoNorm2dS1a,\n    EvoNorm2dS2,\n    EvoNorm2dS2a,\n)\nfrom .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm\nfrom .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d\nfrom .gather_excite import GatherExcite\nfrom .global_context import GlobalContext\nfrom .helpers import (\n    to_ntuple,\n    to_2tuple,\n    to_3tuple,\n    to_4tuple,\n    make_divisible,\n    extend_tuple,\n)\nfrom .inplace_abn import InplaceAbn\nfrom .linear import Linear\nfrom .mixed_conv2d import MixedConv2d\nfrom .mlp import Mlp, GluMlp, GatedMlp, ConvMlp\nfrom .non_local_attn import NonLocalAttn, BatNonLocalAttn\nfrom .norm import GroupNorm, GroupNorm1, LayerNorm, 
LayerNorm2d\nfrom .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm\nfrom .padding import get_padding, get_same_padding, pad_same\nfrom .patch_embed import PatchEmbed\nfrom .pool2d_same import AvgPool2dSame, create_pool2d\nfrom .squeeze_excite import (\n    SEModule,\n    SqueezeExcite,\n    EffectiveSEModule,\n    EffectiveSqueezeExcite,\n)\nfrom .selective_kernel import SelectiveKernel\nfrom .separable_conv import SeparableConv2d, SeparableConvNormAct\nfrom .space_to_depth import SpaceToDepthModule\nfrom .split_attn import SplitAttn\nfrom .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model\nfrom .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame\nfrom .test_time_pool import TestTimePoolHead, apply_test_time_pool\nfrom .trace_utils import _assert, _float_to_int\nfrom .weight_init import (\n    trunc_normal_,\n    trunc_normal_tf_,\n    variance_scaling_,\n    lecun_normal_,\n)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/activations.py",
    "content": "\"\"\" Activations\n\nA collection of activation functions and modules with a common interface so that they can\neasily be swapped. All have an `inplace` arg even if not used.\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\n\ndef swish(x, inplace: bool = False):\n    \"\"\"Swish - Described in: https://arxiv.org/abs/1710.05941\"\"\"\n    return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())\n\n\nclass Swish(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(Swish, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return swish(x, self.inplace)\n\n\ndef mish(x, inplace: bool = False):\n    \"\"\"Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681\n    NOTE: I don't have a working inplace variant\n    \"\"\"\n    return x.mul(F.softplus(x).tanh())\n\n\nclass Mish(nn.Module):\n    \"\"\"Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681\"\"\"\n\n    def __init__(self, inplace: bool = False):\n        super(Mish, self).__init__()\n\n    def forward(self, x):\n        return mish(x)\n\n\ndef sigmoid(x, inplace: bool = False):\n    return x.sigmoid_() if inplace else x.sigmoid()\n\n\n# PyTorch has this, but not with a consistent inplace argument interface\nclass Sigmoid(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(Sigmoid, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return x.sigmoid_() if self.inplace else x.sigmoid()\n\n\ndef tanh(x, inplace: bool = False):\n    return x.tanh_() if inplace else x.tanh()\n\n\n# PyTorch has this, but not with a consistent inplace argument interface\nclass Tanh(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(Tanh, self).__init__()\n        self.inplace = inplace\n\n    
def forward(self, x):\n        return x.tanh_() if self.inplace else x.tanh()\n\n\ndef hard_swish(x, inplace: bool = False):\n    inner = F.relu6(x + 3.0).div_(6.0)\n    return x.mul_(inner) if inplace else x.mul(inner)\n\n\nclass HardSwish(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(HardSwish, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return hard_swish(x, self.inplace)\n\n\ndef hard_sigmoid(x, inplace: bool = False):\n    if inplace:\n        return x.add_(3.0).clamp_(0.0, 6.0).div_(6.0)\n    else:\n        return F.relu6(x + 3.0) / 6.0\n\n\nclass HardSigmoid(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(HardSigmoid, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return hard_sigmoid(x, self.inplace)\n\n\ndef hard_mish(x, inplace: bool = False):\n    \"\"\"Hard Mish\n    Experimental, based on notes by Mish author Diganta Misra at\n      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md\n    \"\"\"\n    if inplace:\n        return x.mul_(0.5 * (x + 2).clamp(min=0, max=2))\n    else:\n        return 0.5 * x * (x + 2).clamp(min=0, max=2)\n\n\nclass HardMish(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(HardMish, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return hard_mish(x, self.inplace)\n\n\nclass PReLU(nn.PReLU):\n    \"\"\"Applies PReLU (w/ dummy inplace arg)\"\"\"\n\n    def __init__(\n        self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False\n    ) -> None:\n        super(PReLU, self).__init__(num_parameters=num_parameters, init=init)\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return F.prelu(input, self.weight)\n\n\ndef gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:\n    return F.gelu(x)\n\n\nclass GELU(nn.Module):\n    \"\"\"Applies 
the Gaussian Error Linear Units function (w/ dummy inplace arg)\"\"\"\n\n    def __init__(self, inplace: bool = False):\n        super(GELU, self).__init__()\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return F.gelu(input)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/activations_jit.py",
    "content": "\"\"\" Activations\n\nA collection of jit-scripted activations fn and modules with a common interface so that they can\neasily be swapped. All have an `inplace` arg even if not used.\n\nAll jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not\ncurrently work across in-place op boundaries, thus performance is equal to or less than the non-scripted\nversions if they contain in-place ops.\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\n\n@torch.jit.script\ndef swish_jit(x, inplace: bool = False):\n    \"\"\"Swish - Described in: https://arxiv.org/abs/1710.05941\"\"\"\n    return x.mul(x.sigmoid())\n\n\n@torch.jit.script\ndef mish_jit(x, _inplace: bool = False):\n    \"\"\"Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681\"\"\"\n    return x.mul(F.softplus(x).tanh())\n\n\nclass SwishJit(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(SwishJit, self).__init__()\n\n    def forward(self, x):\n        return swish_jit(x)\n\n\nclass MishJit(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(MishJit, self).__init__()\n\n    def forward(self, x):\n        return mish_jit(x)\n\n\n@torch.jit.script\ndef hard_sigmoid_jit(x, inplace: bool = False):\n    # return F.relu6(x + 3.) / 6.\n    return (x + 3).clamp(min=0, max=6).div(6.0)  # clamp seems ever so slightly faster?\n\n\nclass HardSigmoidJit(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(HardSigmoidJit, self).__init__()\n\n    def forward(self, x):\n        return hard_sigmoid_jit(x)\n\n\n@torch.jit.script\ndef hard_swish_jit(x, inplace: bool = False):\n    # return x * (F.relu6(x + 3.) 
/ 6)\n    return x * (x + 3).clamp(min=0, max=6).div(\n        6.0\n    )  # clamp seems ever so slightly faster?\n\n\nclass HardSwishJit(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(HardSwishJit, self).__init__()\n\n    def forward(self, x):\n        return hard_swish_jit(x)\n\n\n@torch.jit.script\ndef hard_mish_jit(x, inplace: bool = False):\n    \"\"\"Hard Mish\n    Experimental, based on notes by Mish author Diganta Misra at\n      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md\n    \"\"\"\n    return 0.5 * x * (x + 2).clamp(min=0, max=2)\n\n\nclass HardMishJit(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(HardMishJit, self).__init__()\n\n    def forward(self, x):\n        return hard_mish_jit(x)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/activations_me.py",
    "content": "\"\"\" Activations (memory-efficient w/ custom autograd)\n\nA collection of activations fn and modules with a common interface so that they can\neasily be swapped. All have an `inplace` arg even if not used.\n\nThese activations are not compatible with jit scripting or ONNX export of the model; please use either\nthe JIT or basic versions of the activations.\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\n\n@torch.jit.script\ndef swish_jit_fwd(x):\n    return x.mul(torch.sigmoid(x))\n\n\n@torch.jit.script\ndef swish_jit_bwd(x, grad_output):\n    x_sigmoid = torch.sigmoid(x)\n    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))\n\n\nclass SwishJitAutoFn(torch.autograd.Function):\n    \"\"\"torch.jit.script optimised Swish w/ memory-efficient checkpoint\n    Inspired by conversation between Jeremy Howard & Adam Paszke\n    https://twitter.com/jeremyphoward/status/1188251041835315200\n    \"\"\"\n\n    @staticmethod\n    def symbolic(g, x):\n        return g.op(\"Mul\", x, g.op(\"Sigmoid\", x))\n\n    @staticmethod\n    def forward(ctx, x):\n        ctx.save_for_backward(x)\n        return swish_jit_fwd(x)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        x = ctx.saved_tensors[0]\n        return swish_jit_bwd(x, grad_output)\n\n\ndef swish_me(x, inplace=False):\n    return SwishJitAutoFn.apply(x)\n\n\nclass SwishMe(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(SwishMe, self).__init__()\n\n    def forward(self, x):\n        return SwishJitAutoFn.apply(x)\n\n\n@torch.jit.script\ndef mish_jit_fwd(x):\n    return x.mul(torch.tanh(F.softplus(x)))\n\n\n@torch.jit.script\ndef mish_jit_bwd(x, grad_output):\n    x_sigmoid = torch.sigmoid(x)\n    x_tanh_sp = F.softplus(x).tanh()\n    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))\n\n\nclass 
MishJitAutoFn(torch.autograd.Function):\n    \"\"\"Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681\n    A memory efficient, jit scripted variant of Mish\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, x):\n        ctx.save_for_backward(x)\n        return mish_jit_fwd(x)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        x = ctx.saved_tensors[0]\n        return mish_jit_bwd(x, grad_output)\n\n\ndef mish_me(x, inplace=False):\n    return MishJitAutoFn.apply(x)\n\n\nclass MishMe(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(MishMe, self).__init__()\n\n    def forward(self, x):\n        return MishJitAutoFn.apply(x)\n\n\n@torch.jit.script\ndef hard_sigmoid_jit_fwd(x, inplace: bool = False):\n    return (x + 3).clamp(min=0, max=6).div(6.0)\n\n\n@torch.jit.script\ndef hard_sigmoid_jit_bwd(x, grad_output):\n    m = torch.ones_like(x) * ((x >= -3.0) & (x <= 3.0)) / 6.0\n    return grad_output * m\n\n\nclass HardSigmoidJitAutoFn(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x):\n        ctx.save_for_backward(x)\n        return hard_sigmoid_jit_fwd(x)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        x = ctx.saved_tensors[0]\n        return hard_sigmoid_jit_bwd(x, grad_output)\n\n\ndef hard_sigmoid_me(x, inplace: bool = False):\n    return HardSigmoidJitAutoFn.apply(x)\n\n\nclass HardSigmoidMe(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(HardSigmoidMe, self).__init__()\n\n    def forward(self, x):\n        return HardSigmoidJitAutoFn.apply(x)\n\n\n@torch.jit.script\ndef hard_swish_jit_fwd(x):\n    return x * (x + 3).clamp(min=0, max=6).div(6.0)\n\n\n@torch.jit.script\ndef hard_swish_jit_bwd(x, grad_output):\n    m = torch.ones_like(x) * (x >= 3.0)\n    m = torch.where((x >= -3.0) & (x <= 3.0), x / 3.0 + 0.5, m)\n    return grad_output * m\n\n\nclass HardSwishJitAutoFn(torch.autograd.Function):\n    
\"\"\"A memory efficient, jit-scripted HardSwish activation\"\"\"\n\n    @staticmethod\n    def forward(ctx, x):\n        ctx.save_for_backward(x)\n        return hard_swish_jit_fwd(x)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        x = ctx.saved_tensors[0]\n        return hard_swish_jit_bwd(x, grad_output)\n\n    @staticmethod\n    def symbolic(g, self):\n        input = g.op(\n            \"Add\", self, g.op(\"Constant\", value_t=torch.tensor(3, dtype=torch.float))\n        )\n        hardtanh_ = g.op(\n            \"Clip\",\n            input,\n            g.op(\"Constant\", value_t=torch.tensor(0, dtype=torch.float)),\n            g.op(\"Constant\", value_t=torch.tensor(6, dtype=torch.float)),\n        )\n        hardtanh_ = g.op(\n            \"Div\",\n            hardtanh_,\n            g.op(\"Constant\", value_t=torch.tensor(6, dtype=torch.float)),\n        )\n        return g.op(\"Mul\", self, hardtanh_)\n\n\ndef hard_swish_me(x, inplace=False):\n    return HardSwishJitAutoFn.apply(x)\n\n\nclass HardSwishMe(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(HardSwishMe, self).__init__()\n\n    def forward(self, x):\n        return HardSwishJitAutoFn.apply(x)\n\n\n@torch.jit.script\ndef hard_mish_jit_fwd(x):\n    return 0.5 * x * (x + 2).clamp(min=0, max=2)\n\n\n@torch.jit.script\ndef hard_mish_jit_bwd(x, grad_output):\n    m = torch.ones_like(x) * (x >= -2.0)\n    m = torch.where((x >= -2.0) & (x <= 0.0), x + 1.0, m)\n    return grad_output * m\n\n\nclass HardMishJitAutoFn(torch.autograd.Function):\n    \"\"\"A memory efficient, jit scripted variant of Hard Mish\n    Experimental, based on notes by Mish author Diganta Misra at\n      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, x):\n        ctx.save_for_backward(x)\n        return hard_mish_jit_fwd(x)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n 
       x = ctx.saved_tensors[0]\n        return hard_mish_jit_bwd(x, grad_output)\n\n\ndef hard_mish_me(x, inplace: bool = False):\n    return HardMishJitAutoFn.apply(x)\n\n\nclass HardMishMe(nn.Module):\n    def __init__(self, inplace: bool = False):\n        super(HardMishMe, self).__init__()\n\n    def forward(self, x):\n        return HardMishJitAutoFn.apply(x)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/adaptive_avgmax_pool.py",
    "content": "\"\"\" PyTorch selectable adaptive pooling\nAdaptive pooling with the ability to select the type of pooling from:\n    * 'avg' - Average pooling\n    * 'max' - Max pooling\n    * 'avgmax' - Sum of average and max pooling re-scaled by 0.5\n    * 'catavgmax' - Concatenation of average and max pooling along feature dim, doubles feature dim\n\nBoth a functional and an nn.Module version of the pooling are provided.\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef adaptive_pool_feat_mult(pool_type=\"avg\"):\n    if pool_type == \"catavgmax\":\n        return 2\n    else:\n        return 1\n\n\ndef adaptive_avgmax_pool2d(x, output_size=1):\n    x_avg = F.adaptive_avg_pool2d(x, output_size)\n    x_max = F.adaptive_max_pool2d(x, output_size)\n    return 0.5 * (x_avg + x_max)\n\n\ndef adaptive_catavgmax_pool2d(x, output_size=1):\n    x_avg = F.adaptive_avg_pool2d(x, output_size)\n    x_max = F.adaptive_max_pool2d(x, output_size)\n    return torch.cat((x_avg, x_max), 1)\n\n\ndef select_adaptive_pool2d(x, pool_type=\"avg\", output_size=1):\n    \"\"\"Selectable global pooling function with dynamic input kernel size\"\"\"\n    if pool_type == \"avg\":\n        x = F.adaptive_avg_pool2d(x, output_size)\n    elif pool_type == \"avgmax\":\n        x = adaptive_avgmax_pool2d(x, output_size)\n    elif pool_type == \"catavgmax\":\n        x = adaptive_catavgmax_pool2d(x, output_size)\n    elif pool_type == \"max\":\n        x = F.adaptive_max_pool2d(x, output_size)\n    else:\n        assert False, \"Invalid pool type: %s\" % pool_type\n    return x\n\n\nclass FastAdaptiveAvgPool2d(nn.Module):\n    def __init__(self, flatten=False):\n        super(FastAdaptiveAvgPool2d, self).__init__()\n        self.flatten = flatten\n\n    def forward(self, x):\n        return x.mean((2, 3), keepdim=not self.flatten)\n\n\nclass AdaptiveAvgMaxPool2d(nn.Module):\n    def __init__(self, 
output_size=1):\n        super(AdaptiveAvgMaxPool2d, self).__init__()\n        self.output_size = output_size\n\n    def forward(self, x):\n        return adaptive_avgmax_pool2d(x, self.output_size)\n\n\nclass AdaptiveCatAvgMaxPool2d(nn.Module):\n    def __init__(self, output_size=1):\n        super(AdaptiveCatAvgMaxPool2d, self).__init__()\n        self.output_size = output_size\n\n    def forward(self, x):\n        return adaptive_catavgmax_pool2d(x, self.output_size)\n\n\nclass SelectAdaptivePool2d(nn.Module):\n    \"\"\"Selectable global pooling layer with dynamic input kernel size\"\"\"\n\n    def __init__(self, output_size=1, pool_type=\"fast\", flatten=False):\n        super(SelectAdaptivePool2d, self).__init__()\n        self.pool_type = (\n            pool_type or \"\"\n        )  # convert other falsy values to empty string for consistent TS typing\n        self.flatten = nn.Flatten(1) if flatten else nn.Identity()\n        if pool_type == \"\":\n            self.pool = nn.Identity()  # pass through\n        elif pool_type == \"fast\":\n            assert output_size == 1\n            self.pool = FastAdaptiveAvgPool2d(flatten)\n            self.flatten = nn.Identity()\n        elif pool_type == \"avg\":\n            self.pool = nn.AdaptiveAvgPool2d(output_size)\n        elif pool_type == \"avgmax\":\n            self.pool = AdaptiveAvgMaxPool2d(output_size)\n        elif pool_type == \"catavgmax\":\n            self.pool = AdaptiveCatAvgMaxPool2d(output_size)\n        elif pool_type == \"max\":\n            self.pool = nn.AdaptiveMaxPool2d(output_size)\n        else:\n            assert False, \"Invalid pool type: %s\" % pool_type\n\n    def is_identity(self):\n        return not self.pool_type\n\n    def forward(self, x):\n        x = self.pool(x)\n        x = self.flatten(x)\n        return x\n\n    def feat_mult(self):\n        return adaptive_pool_feat_mult(self.pool_type)\n\n    def __repr__(self):\n        return (\n            
self.__class__.__name__\n            + \" (\"\n            + \"pool_type=\"\n            + self.pool_type\n            + \", flatten=\"\n            + str(self.flatten)\n            + \")\"\n        )\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/attention_pool2d.py",
    "content": "\"\"\" Attention Pool 2D\n\nImplementations of 2D spatial feature pooling using multi-head attention instead of average pool.\n\nBased on idea in CLIP by OpenAI, licensed Apache 2.0\nhttps://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py\n\nHacked together by / Copyright 2021 Ross Wightman\n\"\"\"\n\nfrom typing import Union, Tuple\n\nimport torch\nimport torch.nn as nn\n\nfrom .helpers import to_2tuple\nfrom .pos_embed import apply_rot_embed, RotaryEmbedding\nfrom .weight_init import trunc_normal_\n\n\nclass RotAttentionPool2d(nn.Module):\n    \"\"\"Attention based 2D feature pooling w/ rotary (relative) pos embedding.\n    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.\n\n    Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.\n    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py\n\n    NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from\n    train varies widely and falls off dramatically. I'm not sure if there is a way around this... 
-RW\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int = None,\n        embed_dim: int = None,\n        num_heads: int = 4,\n        qkv_bias: bool = True,\n    ):\n        super().__init__()\n        embed_dim = embed_dim or in_features\n        out_features = out_features or in_features\n        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)\n        self.proj = nn.Linear(embed_dim, out_features)\n        self.num_heads = num_heads\n        assert embed_dim % num_heads == 0\n        self.head_dim = embed_dim // num_heads\n        self.scale = self.head_dim**-0.5\n        self.pos_embed = RotaryEmbedding(self.head_dim)\n\n        trunc_normal_(self.qkv.weight, std=in_features**-0.5)\n        nn.init.zeros_(self.qkv.bias)\n\n    def forward(self, x):\n        B, _, H, W = x.shape\n        N = H * W\n        x = x.reshape(B, -1, N).permute(0, 2, 1)\n\n        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)\n\n        x = (\n            self.qkv(x)\n            .reshape(B, N + 1, 3, self.num_heads, self.head_dim)\n            .permute(2, 0, 3, 1, 4)\n        )\n        q, k, v = x[0], x[1], x[2]\n\n        qc, q = q[:, :, :1], q[:, :, 1:]\n        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))\n        q = apply_rot_embed(q, sin_emb, cos_emb)\n        q = torch.cat([qc, q], dim=2)\n\n        kc, k = k[:, :, :1], k[:, :, 1:]\n        k = apply_rot_embed(k, sin_emb, cos_emb)\n        k = torch.cat([kc, k], dim=2)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)\n        x = self.proj(x)\n        return x[:, 0]\n\n\nclass AttentionPool2d(nn.Module):\n    \"\"\"Attention based 2D feature pooling w/ learned (absolute) pos embedding.\n    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.\n\n    It was based on impl in CLIP by OpenAI\n    
https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py\n\n    NOTE: This requires the feature size upon construction and will prevent adaptive sizing of the network.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        feat_size: Union[int, Tuple[int, int]],\n        out_features: int = None,\n        embed_dim: int = None,\n        num_heads: int = 4,\n        qkv_bias: bool = True,\n    ):\n        super().__init__()\n\n        embed_dim = embed_dim or in_features\n        out_features = out_features or in_features\n        assert embed_dim % num_heads == 0\n        self.feat_size = to_2tuple(feat_size)\n        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)\n        self.proj = nn.Linear(embed_dim, out_features)\n        self.num_heads = num_heads\n        self.head_dim = embed_dim // num_heads\n        self.scale = self.head_dim**-0.5\n\n        spatial_dim = self.feat_size[0] * self.feat_size[1]\n        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))\n        trunc_normal_(self.pos_embed, std=in_features**-0.5)\n        trunc_normal_(self.qkv.weight, std=in_features**-0.5)\n        nn.init.zeros_(self.qkv.bias)\n\n    def forward(self, x):\n        B, _, H, W = x.shape\n        N = H * W\n        assert self.feat_size[0] == H\n        assert self.feat_size[1] == W\n        x = x.reshape(B, -1, N).permute(0, 2, 1)\n        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)\n        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)\n\n        x = (\n            self.qkv(x)\n            .reshape(B, N + 1, 3, self.num_heads, self.head_dim)\n            .permute(2, 0, 3, 1, 4)\n        )\n        q, k, v = x[0], x[1], x[2]\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)\n        x = self.proj(x)\n        return x[:, 0]\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/blur_pool.py",
    "content": "\"\"\"\nBlurPool layer inspired by\n - Kornia's Max_BlurPool2d\n - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`\n\nHacked together by Chris Ha and Ross Wightman\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom .padding import get_padding\n\n\nclass BlurPool2d(nn.Module):\n    r\"\"\"Creates a module that blurs and downsamples a given feature map.\n    See :cite:`zhang2019shiftinvar` for more details.\n    Corresponds to the Downsample class, which does blurring and subsampling\n\n    Args:\n        channels (int): number of input channels\n        filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.\n        stride (int): downsampling filter stride\n\n    Returns:\n        torch.Tensor: the transformed tensor.\n    \"\"\"\n\n    def __init__(self, channels, filt_size=3, stride=2) -> None:\n        super(BlurPool2d, self).__init__()\n        assert filt_size > 1\n        self.channels = channels\n        self.filt_size = filt_size\n        self.stride = stride\n        self.padding = [get_padding(filt_size, stride, dilation=1)] * 4\n        coeffs = torch.tensor(\n            (np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)\n        )\n        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(\n            self.channels, 1, 1, 1\n        )\n        self.register_buffer(\"filt\", blur_filter, persistent=False)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = F.pad(x, self.padding, \"reflect\")\n        return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/bottleneck_attn.py",
    "content": "\"\"\" Bottleneck Self Attention (Bottleneck Transformers)\n\nPaper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605\n\n@misc{2101.11605,\nAuthor = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani},\nTitle = {Bottleneck Transformers for Visual Recognition},\nYear = {2021},\n}\n\nBased on ref gist at: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2\n\nThis impl is a WIP, but given that it is based on the ref gist it is likely not too far off.\n\nHacked together by / Copyright 2021 Ross Wightman\n\"\"\"\n\nfrom typing import List\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .helpers import to_2tuple, make_divisible\nfrom .weight_init import trunc_normal_\nfrom .trace_utils import _assert\n\n\ndef rel_logits_1d(q, rel_k, permute_mask: List[int]):\n    \"\"\"Compute relative logits along one dimension\n\n    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2\n    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925\n\n    Args:\n        q: (batch * heads, height, width, dim)\n        rel_k: (2 * width - 1, dim)\n        permute_mask: permute output dim according to this\n    \"\"\"\n    B, H, W, dim = q.shape\n    x = q @ rel_k.transpose(-1, -2)\n    x = x.reshape(-1, W, 2 * W - 1)\n\n    # pad to shift from relative to absolute indexing\n    x_pad = F.pad(x, [0, 1]).flatten(1)\n    x_pad = F.pad(x_pad, [0, W - 1])\n\n    # reshape and slice out the padded elements\n    x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1)\n    x = x_pad[:, :W, W - 1 :]\n\n    # reshape and tile\n    x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1)\n    return x.permute(permute_mask)\n\n\nclass PosEmbedRel(nn.Module):\n    \"\"\"Relative Position Embedding\n    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2\n    Originally 
from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925\n    \"\"\"\n\n    def __init__(self, feat_size, dim_head, scale):\n        super().__init__()\n        self.height, self.width = to_2tuple(feat_size)\n        self.dim_head = dim_head\n        self.height_rel = nn.Parameter(\n            torch.randn(self.height * 2 - 1, dim_head) * scale\n        )\n        self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale)\n\n    def forward(self, q):\n        B, HW, _ = q.shape\n\n        # relative logits in width dimension.\n        q = q.reshape(B, self.height, self.width, -1)\n        rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))\n\n        # relative logits in height dimension.\n        q = q.transpose(1, 2)\n        rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))\n\n        rel_logits = rel_logits_h + rel_logits_w\n        rel_logits = rel_logits.reshape(B, HW, HW)\n        return rel_logits\n\n\nclass BottleneckAttn(nn.Module):\n    \"\"\"Bottleneck Attention\n    Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605\n\n    The internal dimensions of the attention module are controlled by the interaction of several arguments.\n      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set\n      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim\n      * the query and key (qk) dimensions are determined by\n        * num_heads * dim_head if dim_head is not None\n        * num_heads * (dim_out * qk_ratio // num_heads) if dim_head is None\n      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not used\n\n    Args:\n        dim (int): input dimension to the module\n        dim_out (int): output dimension of the module, same as dim if not set\n        stride (int): output 
stride of the module, avg pool used if stride == 2 (default: 1).\n        num_heads (int): parallel attention heads (default: 4)\n        dim_head (int): dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set\n        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)\n        qkv_bias (bool): add bias to q, k, and v projections\n        scale_pos_embed (bool): scale the position embedding as well as Q @ K\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        dim_out=None,\n        feat_size=None,\n        stride=1,\n        num_heads=4,\n        dim_head=None,\n        qk_ratio=1.0,\n        qkv_bias=False,\n        scale_pos_embed=False,\n    ):\n        super().__init__()\n        assert (\n            feat_size is not None\n        ), \"A concrete feature size matching expected input (H, W) is required\"\n        dim_out = dim_out or dim\n        assert dim_out % num_heads == 0\n        self.num_heads = num_heads\n        self.dim_head_qk = (\n            dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads\n        )\n        self.dim_head_v = dim_out // self.num_heads\n        self.dim_out_qk = num_heads * self.dim_head_qk\n        self.dim_out_v = num_heads * self.dim_head_v\n        self.scale = self.dim_head_qk**-0.5\n        self.scale_pos_embed = scale_pos_embed\n\n        self.qkv = nn.Conv2d(\n            dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias\n        )\n\n        # NOTE I'm only supporting relative pos embedding for now\n        self.pos_embed = PosEmbedRel(\n            feat_size, dim_head=self.dim_head_qk, scale=self.scale\n        )\n\n        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)  # fan-in\n        
trunc_normal_(self.pos_embed.height_rel, std=self.scale)\n        trunc_normal_(self.pos_embed.width_rel, std=self.scale)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        _assert(H == self.pos_embed.height, \"\")\n        _assert(W == self.pos_embed.width, \"\")\n\n        x = self.qkv(x)  # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W\n\n        # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v\n        # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted.\n        q, k, v = torch.split(\n            x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1\n        )\n        q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)\n        k = k.reshape(\n            B * self.num_heads, self.dim_head_qk, -1\n        )  # no transpose, for q @ k\n        v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)\n\n        if self.scale_pos_embed:\n            attn = (\n                q @ k + self.pos_embed(q)\n            ) * self.scale  # B * num_heads, H * W, H * W\n        else:\n            attn = (q @ k) * self.scale + self.pos_embed(q)\n        attn = attn.softmax(dim=-1)\n\n        out = (\n            (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W)\n        )  # B, dim_out, H, W\n        out = self.pool(out)\n        return out\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/cbam.py",
    "content": "\"\"\" CBAM (sort-of) Attention\n\nExperimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521\n\nWARNING: Results with these attention layers have been mixed. They can significantly reduce performance on\nsome tasks, especially fine-grained ones, it seems. I may end up removing this impl.\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nfrom torch import nn as nn\nimport torch.nn.functional as F\n\nfrom .conv_bn_act import ConvNormAct\nfrom .create_act import create_act_layer, get_act_layer\nfrom .helpers import make_divisible\n\n\nclass ChannelAttn(nn.Module):\n    \"\"\"Original CBAM channel attention module, currently avg + max pool variant only.\"\"\"\n\n    def __init__(\n        self,\n        channels,\n        rd_ratio=1.0 / 16,\n        rd_channels=None,\n        rd_divisor=1,\n        act_layer=nn.ReLU,\n        gate_layer=\"sigmoid\",\n        mlp_bias=False,\n    ):\n        super(ChannelAttn, self).__init__()\n        if not rd_channels:\n            rd_channels = make_divisible(\n                channels * rd_ratio, rd_divisor, round_limit=0.0\n            )\n        self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias)\n        self.act = act_layer(inplace=True)\n        self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias)\n        self.gate = create_act_layer(gate_layer)\n\n    def forward(self, x):\n        x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True))))\n        x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True))))\n        return x * self.gate(x_avg + x_max)\n\n\nclass LightChannelAttn(ChannelAttn):\n    \"\"\"An experimental 'lightweight' variant that sums avg + max pool first\"\"\"\n\n    def __init__(\n        self,\n        channels,\n        rd_ratio=1.0 / 16,\n        rd_channels=None,\n        rd_divisor=1,\n        act_layer=nn.ReLU,\n        gate_layer=\"sigmoid\",\n        mlp_bias=False,\n    ):\n        
super(LightChannelAttn, self).__init__(\n            channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias\n        )\n\n    def forward(self, x):\n        x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True)\n        x_attn = self.fc2(self.act(self.fc1(x_pool)))\n        # use the configured gate (sigmoid by default) instead of hard-coding F.sigmoid\n        return x * self.gate(x_attn)\n\n\nclass SpatialAttn(nn.Module):\n    \"\"\"Original CBAM spatial attention module\"\"\"\n\n    def __init__(self, kernel_size=7, gate_layer=\"sigmoid\"):\n        super(SpatialAttn, self).__init__()\n        self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False)\n        self.gate = create_act_layer(gate_layer)\n\n    def forward(self, x):\n        x_attn = torch.cat(\n            [x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1\n        )\n        x_attn = self.conv(x_attn)\n        return x * self.gate(x_attn)\n\n\nclass LightSpatialAttn(nn.Module):\n    \"\"\"An experimental 'lightweight' variant that sums avg_pool and max_pool results.\"\"\"\n\n    def __init__(self, kernel_size=7, gate_layer=\"sigmoid\"):\n        super(LightSpatialAttn, self).__init__()\n        self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False)\n        self.gate = create_act_layer(gate_layer)\n\n    def forward(self, x):\n        x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True)\n        x_attn = self.conv(x_attn)\n        return x * self.gate(x_attn)\n\n\nclass CbamModule(nn.Module):\n    def __init__(\n        self,\n        channels,\n        rd_ratio=1.0 / 16,\n        rd_channels=None,\n        rd_divisor=1,\n        spatial_kernel_size=7,\n        act_layer=nn.ReLU,\n        gate_layer=\"sigmoid\",\n        mlp_bias=False,\n    ):\n        super(CbamModule, self).__init__()\n        self.channel = ChannelAttn(\n            channels,\n            rd_ratio=rd_ratio,\n            rd_channels=rd_channels,\n            rd_divisor=rd_divisor,\n            act_layer=act_layer,\n      
      gate_layer=gate_layer,\n            mlp_bias=mlp_bias,\n        )\n        self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer)\n\n    def forward(self, x):\n        x = self.channel(x)\n        x = self.spatial(x)\n        return x\n\n\nclass LightCbamModule(nn.Module):\n    def __init__(\n        self,\n        channels,\n        rd_ratio=1.0 / 16,\n        rd_channels=None,\n        rd_divisor=1,\n        spatial_kernel_size=7,\n        act_layer=nn.ReLU,\n        gate_layer=\"sigmoid\",\n        mlp_bias=False,\n    ):\n        super(LightCbamModule, self).__init__()\n        self.channel = LightChannelAttn(\n            channels,\n            rd_ratio=rd_ratio,\n            rd_channels=rd_channels,\n            rd_divisor=rd_divisor,\n            act_layer=act_layer,\n            gate_layer=gate_layer,\n            mlp_bias=mlp_bias,\n        )\n        self.spatial = LightSpatialAttn(spatial_kernel_size, gate_layer=gate_layer)\n\n    def forward(self, x):\n        x = self.channel(x)\n        x = self.spatial(x)\n        return x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/classifier.py",
    "content": "\"\"\" Classifier head and layer factory\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom .adaptive_avgmax_pool import SelectAdaptivePool2d\n\n\ndef _create_pool(num_features, num_classes, pool_type=\"avg\", use_conv=False):\n    flatten_in_pool = not use_conv  # flatten when we use a Linear layer after pooling\n    if not pool_type:\n        assert (\n            num_classes == 0 or use_conv\n        ), \"Pooling can only be disabled if classifier is also removed or conv classifier is used\"\n        flatten_in_pool = (\n            False  # disable flattening if pooling is pass-through (no pooling)\n        )\n    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)\n    num_pooled_features = num_features * global_pool.feat_mult()\n    return global_pool, num_pooled_features\n\n\ndef _create_fc(num_features, num_classes, use_conv=False):\n    if num_classes <= 0:\n        fc = nn.Identity()  # pass-through (no classifier)\n    elif use_conv:\n        fc = nn.Conv2d(num_features, num_classes, 1, bias=True)\n    else:\n        fc = nn.Linear(num_features, num_classes, bias=True)\n    return fc\n\n\ndef create_classifier(num_features, num_classes, pool_type=\"avg\", use_conv=False):\n    global_pool, num_pooled_features = _create_pool(\n        num_features, num_classes, pool_type, use_conv=use_conv\n    )\n    fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)\n    return global_pool, fc\n\n\nclass ClassifierHead(nn.Module):\n    \"\"\"Classifier head w/ configurable global pooling and dropout.\"\"\"\n\n    def __init__(\n        self, in_chs, num_classes, pool_type=\"avg\", drop_rate=0.0, use_conv=False\n    ):\n        super(ClassifierHead, self).__init__()\n        self.drop_rate = drop_rate\n        self.global_pool, num_pooled_features = _create_pool(\n            in_chs, num_classes, pool_type, 
use_conv=use_conv\n        )\n        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)\n        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()\n\n    def forward(self, x, pre_logits: bool = False):\n        x = self.global_pool(x)\n        if self.drop_rate:\n            x = F.dropout(x, p=float(self.drop_rate), training=self.training)\n        if pre_logits:\n            return x.flatten(1)\n        else:\n            x = self.fc(x)\n            return self.flatten(x)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/cond_conv2d.py",
    "content": "\"\"\" PyTorch Conditionally Parameterized Convolution (CondConv)\n\nPaper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference\n(https://arxiv.org/abs/1904.04971)\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport math\nfrom functools import partial\nimport numpy as np\nimport torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom .helpers import to_2tuple\nfrom .conv2d_same import conv2d_same\nfrom .padding import get_padding_value\n\n\ndef get_condconv_initializer(initializer, num_experts, expert_shape):\n    def condconv_initializer(weight):\n        \"\"\"CondConv initializer function.\"\"\"\n        num_params = np.prod(expert_shape)\n        if (\n            len(weight.shape) != 2\n            or weight.shape[0] != num_experts\n            or weight.shape[1] != num_params\n        ):\n            raise (\n                ValueError(\n                    \"CondConv variables must have shape [num_experts, num_params]\"\n                )\n            )\n        for i in range(num_experts):\n            initializer(weight[i].view(expert_shape))\n\n    return condconv_initializer\n\n\nclass CondConv2d(nn.Module):\n    \"\"\"Conditionally Parameterized Convolution\n    Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py\n\n    Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:\n    https://github.com/pytorch/pytorch/issues/17983\n    \"\"\"\n\n    __constants__ = [\"in_channels\", \"out_channels\", \"dynamic_padding\"]\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size=3,\n        stride=1,\n        padding=\"\",\n        dilation=1,\n        groups=1,\n        bias=False,\n        num_experts=4,\n    ):\n        super(CondConv2d, self).__init__()\n\n        self.in_channels = in_channels\n        
self.out_channels = out_channels\n        self.kernel_size = to_2tuple(kernel_size)\n        self.stride = to_2tuple(stride)\n        padding_val, is_padding_dynamic = get_padding_value(\n            padding, kernel_size, stride=stride, dilation=dilation\n        )\n        self.dynamic_padding = (\n            is_padding_dynamic  # if in forward to work with torchscript\n        )\n        self.padding = to_2tuple(padding_val)\n        self.dilation = to_2tuple(dilation)\n        self.groups = groups\n        self.num_experts = num_experts\n\n        self.weight_shape = (\n            self.out_channels,\n            self.in_channels // self.groups,\n        ) + self.kernel_size\n        weight_num_param = 1\n        for wd in self.weight_shape:\n            weight_num_param *= wd\n        self.weight = torch.nn.Parameter(\n            torch.Tensor(self.num_experts, weight_num_param)\n        )\n\n        if bias:\n            self.bias_shape = (self.out_channels,)\n            self.bias = torch.nn.Parameter(\n                torch.Tensor(self.num_experts, self.out_channels)\n            )\n        else:\n            self.register_parameter(\"bias\", None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        init_weight = get_condconv_initializer(\n            partial(nn.init.kaiming_uniform_, a=math.sqrt(5)),\n            self.num_experts,\n            self.weight_shape,\n        )\n        init_weight(self.weight)\n        if self.bias is not None:\n            fan_in = np.prod(self.weight_shape[1:])\n            bound = 1 / math.sqrt(fan_in)\n            init_bias = get_condconv_initializer(\n                partial(nn.init.uniform_, a=-bound, b=bound),\n                self.num_experts,\n                self.bias_shape,\n            )\n            init_bias(self.bias)\n\n    def forward(self, x, routing_weights):\n        B, C, H, W = x.shape\n        weight = torch.matmul(routing_weights, self.weight)\n        new_weight_shape = (\n   
         B * self.out_channels,\n            self.in_channels // self.groups,\n        ) + self.kernel_size\n        weight = weight.view(new_weight_shape)\n        bias = None\n        if self.bias is not None:\n            bias = torch.matmul(routing_weights, self.bias)\n            bias = bias.view(B * self.out_channels)\n        # move batch elements with channels so each batch element can be efficiently convolved with separate kernel\n        # reshape instead of view to work with channels_last input\n        x = x.reshape(1, B * C, H, W)\n        if self.dynamic_padding:\n            out = conv2d_same(\n                x,\n                weight,\n                bias,\n                stride=self.stride,\n                padding=self.padding,\n                dilation=self.dilation,\n                groups=self.groups * B,\n            )\n        else:\n            out = F.conv2d(\n                x,\n                weight,\n                bias,\n                stride=self.stride,\n                padding=self.padding,\n                dilation=self.dilation,\n                groups=self.groups * B,\n            )\n        out = out.permute([1, 0, 2, 3]).view(\n            B, self.out_channels, out.shape[-2], out.shape[-1]\n        )\n\n        # Literal port (from TF definition)\n        # x = torch.split(x, 1, 0)\n        # weight = torch.split(weight, 1, 0)\n        # if self.bias is not None:\n        #     bias = torch.matmul(routing_weights, self.bias)\n        #     bias = torch.split(bias, 1, 0)\n        # else:\n        #     bias = [None] * B\n        # out = []\n        # for xi, wi, bi in zip(x, weight, bias):\n        #     wi = wi.view(*self.weight_shape)\n        #     if bi is not None:\n        #         bi = bi.view(*self.bias_shape)\n        #     out.append(self.conv_fn(\n        #         xi, wi, bi, stride=self.stride, padding=self.padding,\n        #         dilation=self.dilation, groups=self.groups))\n        # out = 
torch.cat(out, 0)\n        return out\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/config.py",
    "content": "\"\"\" Model / Layer Config singleton state\n\"\"\"\n\nfrom typing import Any, Optional\n\n__all__ = [\n    \"is_exportable\",\n    \"is_scriptable\",\n    \"is_no_jit\",\n    \"set_exportable\",\n    \"set_scriptable\",\n    \"set_no_jit\",\n    \"set_layer_config\",\n]\n\n# Set to True to prefer layers with no JIT optimization (includes activations)\n_NO_JIT = False\n\n# Set to True to prefer activation layers with no JIT optimization\n# NOTE: not currently used; there is no difference between no_jit and no_activation_jit yet, as the only layers\n# obeying the jit flags so far are activations. This will change as more layers are updated and/or added.\n_NO_ACTIVATION_JIT = False\n\n# Set to True if exporting a model with Same padding via ONNX\n_EXPORTABLE = False\n\n# Set to True if using torch.jit.script on a model\n_SCRIPTABLE = False\n\n\ndef is_no_jit():\n    return _NO_JIT\n\n\nclass set_no_jit:\n    def __init__(self, mode: bool) -> None:\n        global _NO_JIT\n        self.prev = _NO_JIT\n        _NO_JIT = mode\n\n    def __enter__(self) -> None:\n        pass\n\n    def __exit__(self, *args: Any) -> bool:\n        global _NO_JIT\n        _NO_JIT = self.prev\n        return False\n\n\ndef is_exportable():\n    return _EXPORTABLE\n\n\nclass set_exportable:\n    def __init__(self, mode: bool) -> None:\n        global _EXPORTABLE\n        self.prev = _EXPORTABLE\n        _EXPORTABLE = mode\n\n    def __enter__(self) -> None:\n        pass\n\n    def __exit__(self, *args: Any) -> bool:\n        global _EXPORTABLE\n        _EXPORTABLE = self.prev\n        return False\n\n\ndef is_scriptable():\n    return _SCRIPTABLE\n\n\nclass set_scriptable:\n    def __init__(self, mode: bool) -> None:\n        global _SCRIPTABLE\n        self.prev = _SCRIPTABLE\n        _SCRIPTABLE = mode\n\n    def __enter__(self) -> None:\n        pass\n\n    def __exit__(self, *args: Any) -> bool:\n        global _SCRIPTABLE\n        _SCRIPTABLE = self.prev\n      
  return False\n\n\nclass set_layer_config:\n    \"\"\"Layer config context manager that allows setting all layer config flags at once.\n    If a flag arg is None, it will not change the current value.\n    \"\"\"\n\n    def __init__(\n        self,\n        scriptable: Optional[bool] = None,\n        exportable: Optional[bool] = None,\n        no_jit: Optional[bool] = None,\n        no_activation_jit: Optional[bool] = None,\n    ):\n        global _SCRIPTABLE\n        global _EXPORTABLE\n        global _NO_JIT\n        global _NO_ACTIVATION_JIT\n        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT\n        if scriptable is not None:\n            _SCRIPTABLE = scriptable\n        if exportable is not None:\n            _EXPORTABLE = exportable\n        if no_jit is not None:\n            _NO_JIT = no_jit\n        if no_activation_jit is not None:\n            _NO_ACTIVATION_JIT = no_activation_jit\n\n    def __enter__(self) -> None:\n        pass\n\n    def __exit__(self, *args: Any) -> bool:\n        global _SCRIPTABLE\n        global _EXPORTABLE\n        global _NO_JIT\n        global _NO_ACTIVATION_JIT\n        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev\n        return False\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/conv2d_same.py",
    "content": "\"\"\" Conv2d w/ Same Padding\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Tuple, Optional\n\nfrom .padding import pad_same, get_padding_value\n\n\ndef conv2d_same(\n    x,\n    weight: torch.Tensor,\n    bias: Optional[torch.Tensor] = None,\n    stride: Tuple[int, int] = (1, 1),\n    padding: Tuple[int, int] = (0, 0),\n    dilation: Tuple[int, int] = (1, 1),\n    groups: int = 1,\n):\n    x = pad_same(x, weight.shape[-2:], stride, dilation)\n    return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)\n\n\nclass Conv2dSame(nn.Conv2d):\n    \"\"\"Tensorflow like 'SAME' convolution wrapper for 2D convolutions\"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        stride=1,\n        padding=0,\n        dilation=1,\n        groups=1,\n        bias=True,\n    ):\n        super(Conv2dSame, self).__init__(\n            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias\n        )\n\n    def forward(self, x):\n        return conv2d_same(\n            x,\n            self.weight,\n            self.bias,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.groups,\n        )\n\n\ndef create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):\n    padding = kwargs.pop(\"padding\", \"\")\n    kwargs.setdefault(\"bias\", False)\n    padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)\n    if is_dynamic:\n        return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)\n    else:\n        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/conv_bn_act.py",
    "content": "\"\"\" Conv2d + BN + Act\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport functools\nfrom torch import nn as nn\n\nfrom .create_conv2d import create_conv2d\nfrom .create_norm_act import get_norm_act_layer\n\n\nclass ConvNormAct(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size=1,\n        stride=1,\n        padding=\"\",\n        dilation=1,\n        groups=1,\n        bias=False,\n        apply_act=True,\n        norm_layer=nn.BatchNorm2d,\n        act_layer=nn.ReLU,\n        drop_layer=None,\n    ):\n        super(ConvNormAct, self).__init__()\n        self.conv = create_conv2d(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride=stride,\n            padding=padding,\n            dilation=dilation,\n            groups=groups,\n            bias=bias,\n        )\n\n        # NOTE for backwards compatibility with models that use separate norm and act layer definitions\n        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)\n        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`\n        norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}\n        self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)\n\n    @property\n    def in_channels(self):\n        return self.conv.in_channels\n\n    @property\n    def out_channels(self):\n        return self.conv.out_channels\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.bn(x)\n        return x\n\n\nConvBnAct = ConvNormAct\n\n\ndef create_aa(aa_layer, channels, stride=2, enable=True):\n    if not aa_layer or not enable:\n        return nn.Identity()\n    if isinstance(aa_layer, functools.partial):\n        if issubclass(aa_layer.func, nn.AvgPool2d):\n            return aa_layer()\n        else:\n            return aa_layer(channels)\n    elif issubclass(aa_layer, 
nn.AvgPool2d):\n        return aa_layer(stride)\n    else:\n        return aa_layer(channels=channels, stride=stride)\n\n\nclass ConvNormActAa(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size=1,\n        stride=1,\n        padding=\"\",\n        dilation=1,\n        groups=1,\n        bias=False,\n        apply_act=True,\n        norm_layer=nn.BatchNorm2d,\n        act_layer=nn.ReLU,\n        aa_layer=None,\n        drop_layer=None,\n    ):\n        super(ConvNormActAa, self).__init__()\n        use_aa = aa_layer is not None and stride == 2\n\n        self.conv = create_conv2d(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride=1 if use_aa else stride,\n            padding=padding,\n            dilation=dilation,\n            groups=groups,\n            bias=bias,\n        )\n\n        # NOTE for backwards compatibility with models that use separate norm and act layer definitions\n        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)\n        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`\n        norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}\n        self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)\n        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)\n\n    @property\n    def in_channels(self):\n        return self.conv.in_channels\n\n    @property\n    def out_channels(self):\n        return self.conv.out_channels\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.bn(x)\n        x = self.aa(x)\n        return x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/create_act.py",
    "content": "\"\"\" Activation Factory\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom typing import Union, Callable, Type\n\nfrom .activations import *\nfrom .activations_jit import *\nfrom .activations_me import *\nfrom .config import is_exportable, is_scriptable, is_no_jit\n\n# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.\n# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.\n# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used.\n_has_silu = \"silu\" in dir(torch.nn.functional)\n_has_hardswish = \"hardswish\" in dir(torch.nn.functional)\n_has_hardsigmoid = \"hardsigmoid\" in dir(torch.nn.functional)\n_has_mish = \"mish\" in dir(torch.nn.functional)\n\n\n_ACT_FN_DEFAULT = dict(\n    silu=F.silu if _has_silu else swish,\n    swish=F.silu if _has_silu else swish,\n    mish=F.mish if _has_mish else mish,\n    relu=F.relu,\n    relu6=F.relu6,\n    leaky_relu=F.leaky_relu,\n    elu=F.elu,\n    celu=F.celu,\n    selu=F.selu,\n    gelu=gelu,\n    sigmoid=sigmoid,\n    tanh=tanh,\n    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,\n    hard_swish=F.hardswish if _has_hardswish else hard_swish,\n    hard_mish=hard_mish,\n)\n\n_ACT_FN_JIT = dict(\n    silu=F.silu if _has_silu else swish_jit,\n    swish=F.silu if _has_silu else swish_jit,\n    mish=F.mish if _has_mish else mish_jit,\n    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,\n    hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,\n    hard_mish=hard_mish_jit,\n)\n\n_ACT_FN_ME = dict(\n    silu=F.silu if _has_silu else swish_me,\n    swish=F.silu if _has_silu else swish_me,\n    mish=F.mish if _has_mish else mish_me,\n    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,\n    hard_swish=F.hardswish if _has_hardswish else hard_swish_me,\n    hard_mish=hard_mish_me,\n)\n\n_ACT_FNS = (_ACT_FN_ME, 
_ACT_FN_JIT, _ACT_FN_DEFAULT)\nfor a in _ACT_FNS:\n    a.setdefault(\"hardsigmoid\", a.get(\"hard_sigmoid\"))\n    a.setdefault(\"hardswish\", a.get(\"hard_swish\"))\n\n\n_ACT_LAYER_DEFAULT = dict(\n    silu=nn.SiLU if _has_silu else Swish,\n    swish=nn.SiLU if _has_silu else Swish,\n    mish=nn.Mish if _has_mish else Mish,\n    relu=nn.ReLU,\n    relu6=nn.ReLU6,\n    leaky_relu=nn.LeakyReLU,\n    elu=nn.ELU,\n    prelu=PReLU,\n    celu=nn.CELU,\n    selu=nn.SELU,\n    gelu=GELU,\n    sigmoid=Sigmoid,\n    tanh=Tanh,\n    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,\n    hard_swish=nn.Hardswish if _has_hardswish else HardSwish,\n    hard_mish=HardMish,\n)\n\n_ACT_LAYER_JIT = dict(\n    silu=nn.SiLU if _has_silu else SwishJit,\n    swish=nn.SiLU if _has_silu else SwishJit,\n    mish=nn.Mish if _has_mish else MishJit,\n    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,\n    hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,\n    hard_mish=HardMishJit,\n)\n\n_ACT_LAYER_ME = dict(\n    silu=nn.SiLU if _has_silu else SwishMe,\n    swish=nn.SiLU if _has_silu else SwishMe,\n    mish=nn.Mish if _has_mish else MishMe,\n    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,\n    hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe,\n    hard_mish=HardMishMe,\n)\n\n_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)\nfor a in _ACT_LAYERS:\n    a.setdefault(\"hardsigmoid\", a.get(\"hard_sigmoid\"))\n    a.setdefault(\"hardswish\", a.get(\"hard_swish\"))\n\n\ndef get_act_fn(name: Union[Callable, str] = \"relu\"):\n    \"\"\"Activation Function Factory\n    Fetching activation fns by name with this function allows export or torch script friendly\n    functions to be returned dynamically based on current config.\n    \"\"\"\n    if not name:\n        return None\n    if isinstance(name, Callable):\n        return name\n    if not (is_no_jit() or is_exportable() or is_scriptable()):\n        
# If not exporting or scripting the model, first look for a memory-efficient version with\n        # custom autograd, then fallback\n        if name in _ACT_FN_ME:\n            return _ACT_FN_ME[name]\n    if not (is_no_jit() or is_exportable()):\n        if name in _ACT_FN_JIT:\n            return _ACT_FN_JIT[name]\n    return _ACT_FN_DEFAULT[name]\n\n\ndef get_act_layer(name: Union[Type[nn.Module], str] = \"relu\"):\n    \"\"\"Activation Layer Factory\n    Fetching activation layers by name with this function allows export or torch script friendly\n    functions to be returned dynamically based on current config.\n    \"\"\"\n    if not name:\n        return None\n    if not isinstance(name, str):\n        # callable, module, etc\n        return name\n    if not (is_no_jit() or is_exportable() or is_scriptable()):\n        if name in _ACT_LAYER_ME:\n            return _ACT_LAYER_ME[name]\n    if not (is_no_jit() or is_exportable()):\n        if name in _ACT_LAYER_JIT:\n            return _ACT_LAYER_JIT[name]\n    return _ACT_LAYER_DEFAULT[name]\n\n\ndef create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):\n    act_layer = get_act_layer(name)\n    if act_layer is None:\n        return None\n    if inplace is None:\n        return act_layer(**kwargs)\n    try:\n        return act_layer(inplace=inplace, **kwargs)\n    except TypeError:\n        # recover if act layer doesn't have inplace arg\n        return act_layer(**kwargs)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/create_attn.py",
    "content": "\"\"\" Attention Factory\n\nHacked together by / Copyright 2021 Ross Wightman\n\"\"\"\n\nimport torch\nfrom functools import partial\n\nfrom .bottleneck_attn import BottleneckAttn\nfrom .cbam import CbamModule, LightCbamModule\nfrom .eca import EcaModule, CecaModule\nfrom .gather_excite import GatherExcite\nfrom .global_context import GlobalContext\nfrom .halo_attn import HaloAttn\nfrom .lambda_layer import LambdaLayer\nfrom .non_local_attn import NonLocalAttn, BatNonLocalAttn\nfrom .selective_kernel import SelectiveKernel\nfrom .split_attn import SplitAttn\nfrom .squeeze_excite import SEModule, EffectiveSEModule\n\n\ndef get_attn(attn_type):\n    if isinstance(attn_type, torch.nn.Module):\n        return attn_type\n    module_cls = None\n    if attn_type:\n        if isinstance(attn_type, str):\n            attn_type = attn_type.lower()\n            # Lightweight attention modules (channel and/or coarse spatial).\n            # Typically added to existing network architecture blocks in addition to existing convolutions.\n            if attn_type == \"se\":\n                module_cls = SEModule\n            elif attn_type == \"ese\":\n                module_cls = EffectiveSEModule\n            elif attn_type == \"eca\":\n                module_cls = EcaModule\n            elif attn_type == \"ecam\":\n                module_cls = partial(EcaModule, use_mlp=True)\n            elif attn_type == \"ceca\":\n                module_cls = CecaModule\n            elif attn_type == \"ge\":\n                module_cls = GatherExcite\n            elif attn_type == \"gc\":\n                module_cls = GlobalContext\n            elif attn_type == \"gca\":\n                module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False)\n            elif attn_type == \"cbam\":\n                module_cls = CbamModule\n            elif attn_type == \"lcbam\":\n                module_cls = LightCbamModule\n\n            # Attention / attention-like modules w/ 
significant params\n            # Typically replace some of the existing workhorse convs in a network architecture.\n            # All of these accept a stride argument and can spatially downsample the input.\n            elif attn_type == \"sk\":\n                module_cls = SelectiveKernel\n            elif attn_type == \"splat\":\n                module_cls = SplitAttn\n\n            # Self-attention / attention-like modules w/ significant compute and/or params\n            # Typically replace some of the existing workhorse convs in a network architecture.\n            # All of these accept a stride argument and can spatially downsample the input.\n            elif attn_type == \"lambda\":\n                return LambdaLayer\n            elif attn_type == \"bottleneck\":\n                return BottleneckAttn\n            elif attn_type == \"halo\":\n                return HaloAttn\n            elif attn_type == \"nl\":\n                module_cls = NonLocalAttn\n            elif attn_type == \"bat\":\n                module_cls = BatNonLocalAttn\n\n            # Woops!\n            else:\n                assert False, \"Invalid attn module (%s)\" % attn_type\n        elif isinstance(attn_type, bool):\n            if attn_type:\n                module_cls = SEModule\n        else:\n            module_cls = attn_type\n    return module_cls\n\n\ndef create_attn(attn_type, channels, **kwargs):\n    module_cls = get_attn(attn_type)\n    if module_cls is not None:\n        # NOTE: it's expected the first (positional) argument of all attention layers is the number of input channels\n        return module_cls(channels, **kwargs)\n    return None\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/create_conv2d.py",
    "content": "\"\"\" Create Conv2d Factory Method\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom .mixed_conv2d import MixedConv2d\nfrom .cond_conv2d import CondConv2d\nfrom .conv2d_same import create_conv2d_pad\n\n\ndef create_conv2d(in_channels, out_channels, kernel_size, **kwargs):\n    \"\"\"Select a 2d convolution implementation based on arguments\n    Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.\n\n    Used extensively by EfficientNet, MobileNetv3 and related networks.\n    \"\"\"\n    if isinstance(kernel_size, list):\n        assert (\n            \"num_experts\" not in kwargs\n        )  # MixNet + CondConv combo not supported currently\n        if \"groups\" in kwargs:\n            groups = kwargs.pop(\"groups\")\n            if groups == in_channels:\n                kwargs[\"depthwise\"] = True\n            else:\n                assert groups == 1\n        # We're going to use only lists for defining the MixedConv2d kernel groups,\n        # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.\n        m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)\n    else:\n        depthwise = kwargs.pop(\"depthwise\", False)\n        # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0\n        groups = in_channels if depthwise else kwargs.pop(\"groups\", 1)\n        if \"num_experts\" in kwargs and kwargs[\"num_experts\"] > 0:\n            m = CondConv2d(\n                in_channels, out_channels, kernel_size, groups=groups, **kwargs\n            )\n        else:\n            m = create_conv2d_pad(\n                in_channels, out_channels, kernel_size, groups=groups, **kwargs\n            )\n    return m\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/create_norm.py",
    "content": "\"\"\" Norm Layer Factory\n\nCreate norm modules by string (to mirror create_act and create_norm_act fns)\n\nCopyright 2022 Ross Wightman\n\"\"\"\n\nimport types\nimport functools\n\nimport torch.nn as nn\n\nfrom .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d\n\n_NORM_MAP = dict(\n    batchnorm=nn.BatchNorm2d,\n    batchnorm2d=nn.BatchNorm2d,\n    batchnorm1d=nn.BatchNorm1d,\n    groupnorm=GroupNorm,\n    groupnorm1=GroupNorm1,\n    layernorm=LayerNorm,\n    layernorm2d=LayerNorm2d,\n)\n_NORM_TYPES = {m for n, m in _NORM_MAP.items()}\n\n\ndef create_norm_layer(layer_name, num_features, **kwargs):\n    # plain norm layers take no act_layer / apply_act args, and get_norm_layer accepts only the layer spec\n    layer = get_norm_layer(layer_name)\n    layer_instance = layer(num_features, **kwargs)\n    return layer_instance\n\n\ndef get_norm_layer(norm_layer):\n    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))\n    norm_kwargs = {}\n\n    # unbind partial fn, so args can be rebound later\n    if isinstance(norm_layer, functools.partial):\n        norm_kwargs.update(norm_layer.keywords)\n        norm_layer = norm_layer.func\n\n    if isinstance(norm_layer, str):\n        layer_name = norm_layer.replace(\"_\", \"\")\n        norm_layer = _NORM_MAP.get(layer_name, None)\n    elif norm_layer in _NORM_TYPES:\n        norm_layer = norm_layer\n    elif isinstance(norm_layer, types.FunctionType):\n        # if function type, assume it is a lambda/fn that creates a norm layer\n        norm_layer = norm_layer\n    else:\n        type_name = norm_layer.__name__.lower().replace(\"_\", \"\")\n        norm_layer = _NORM_MAP.get(type_name, None)\n        assert norm_layer is not None, f\"No equivalent norm layer for {type_name}\"\n\n    if norm_kwargs:\n        norm_layer = functools.partial(norm_layer, **norm_kwargs)  # bind/rebind args\n    return norm_layer\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/create_norm_act.py",
    "content": "\"\"\" NormAct (Normalization + Activation Layer) Factory\n\nCreate norm + act combo modules that attempt to be backwards compatible with separate norm + act\ninstances in models. Where these are used it will be possible to swap separate BN + act layers with\ncombined modules like IABN or EvoNorms.\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport types\nimport functools\n\nimport torch\n\nfrom .evo_norm import *\nfrom .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d\nfrom .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d\nfrom .inplace_abn import InplaceAbn\n\n_NORM_ACT_MAP = dict(\n    batchnorm=BatchNormAct2d,\n    batchnorm2d=BatchNormAct2d,\n    groupnorm=GroupNormAct,\n    groupnorm1=functools.partial(GroupNormAct, num_groups=1),\n    layernorm=LayerNormAct,\n    layernorm2d=LayerNormAct2d,\n    evonormb0=EvoNorm2dB0,\n    evonormb1=EvoNorm2dB1,\n    evonormb2=EvoNorm2dB2,\n    evonorms0=EvoNorm2dS0,\n    evonorms0a=EvoNorm2dS0a,\n    evonorms1=EvoNorm2dS1,\n    evonorms1a=EvoNorm2dS1a,\n    evonorms2=EvoNorm2dS2,\n    evonorms2a=EvoNorm2dS2a,\n    frn=FilterResponseNormAct2d,\n    frntlu=FilterResponseNormTlu2d,\n    inplaceabn=InplaceAbn,\n    iabn=InplaceAbn,\n)\n_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}\n# has act_layer arg to define act type\n_NORM_ACT_REQUIRES_ARG = {\n    BatchNormAct2d,\n    GroupNormAct,\n    LayerNormAct,\n    LayerNormAct2d,\n    FilterResponseNormAct2d,\n    InplaceAbn,\n}\n\n\ndef create_norm_act_layer(\n    layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs\n):\n    layer = get_norm_act_layer(layer_name, act_layer=act_layer)\n    layer_instance = layer(num_features, apply_act=apply_act, **kwargs)\n    if jit:\n        layer_instance = torch.jit.script(layer_instance)\n    return layer_instance\n\n\ndef get_norm_act_layer(norm_layer, act_layer=None):\n    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))\n    assert act_layer is None or isinstance(\n        act_layer, (type, str, types.FunctionType, functools.partial)\n    )\n    norm_act_kwargs = {}\n\n    # unbind partial fn, so args can be rebound later\n    if isinstance(norm_layer, functools.partial):\n        norm_act_kwargs.update(norm_layer.keywords)\n        norm_layer = norm_layer.func\n\n    if isinstance(norm_layer, str):\n        layer_name = norm_layer.replace(\"_\", \"\").lower().split(\"-\")[0]\n        norm_act_layer = _NORM_ACT_MAP.get(layer_name, None)\n    elif norm_layer in _NORM_ACT_TYPES:\n        norm_act_layer = norm_layer\n    elif isinstance(norm_layer, types.FunctionType):\n        # if function type, must be a lambda/fn that creates a norm_act layer\n        norm_act_layer = norm_layer\n    else:\n        type_name = norm_layer.__name__.lower()\n        # check more specific prefixes (groupnorm1, layernorm2d) before their shorter forms\n        if type_name.startswith(\"batchnorm\"):\n            norm_act_layer = BatchNormAct2d\n        elif type_name.startswith(\"groupnorm1\"):\n            norm_act_layer = functools.partial(GroupNormAct, num_groups=1)\n        elif type_name.startswith(\"groupnorm\"):\n            norm_act_layer = GroupNormAct\n        elif type_name.startswith(\"layernorm2d\"):\n            norm_act_layer = LayerNormAct2d\n        elif type_name.startswith(\"layernorm\"):\n            norm_act_layer = LayerNormAct\n        else:\n            assert False, f\"No equivalent norm_act layer for {type_name}\"\n\n    if norm_act_layer in _NORM_ACT_REQUIRES_ARG:\n        # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.\n        # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types\n        norm_act_kwargs.setdefault(\"act_layer\", act_layer)\n    if norm_act_kwargs:\n        norm_act_layer = functools.partial(\n            norm_act_layer, **norm_act_kwargs\n        )  # bind/rebind args\n    return norm_act_layer\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/drop.py",
    "content": "\"\"\" DropBlock, DropPath\n\nPyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.\n\nPapers:\nDropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)\n\nDeep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)\n\nCode:\nDropBlock impl inspired by two Tensorflow impl that I liked:\n - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74\n - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef drop_block_2d(\n    x,\n    drop_prob: float = 0.1,\n    block_size: int = 7,\n    gamma_scale: float = 1.0,\n    with_noise: bool = False,\n    inplace: bool = False,\n    batchwise: bool = False,\n):\n    \"\"\"DropBlock. See https://arxiv.org/pdf/1810.12890.pdf\n\n    DropBlock with an experimental gaussian noise option. 
This layer has been tested on a few training\n    runs with success, but needs further validation and possibly optimization for lower runtime impact.\n    \"\"\"\n    B, C, H, W = x.shape\n    total_size = W * H\n    clipped_block_size = min(block_size, min(W, H))\n    # seed_drop_rate, the gamma parameter\n    gamma = (\n        gamma_scale\n        * drop_prob\n        * total_size\n        / clipped_block_size**2\n        / ((W - block_size + 1) * (H - block_size + 1))\n    )\n\n    # Forces the block to be inside the feature map.\n    w_i, h_i = torch.meshgrid(\n        torch.arange(W).to(x.device), torch.arange(H).to(x.device)\n    )\n    valid_block = (\n        (w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)\n    ) & ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))\n    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)\n\n    if batchwise:\n        # one mask for whole batch, quite a bit faster\n        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)\n    else:\n        uniform_noise = torch.rand_like(x)\n    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)\n    block_mask = -F.max_pool2d(\n        -block_mask,\n        kernel_size=clipped_block_size,  # block_size,\n        stride=1,\n        padding=clipped_block_size // 2,\n    )\n\n    if with_noise:\n        normal_noise = (\n            torch.randn((1, C, H, W), dtype=x.dtype, device=x.device)\n            if batchwise\n            else torch.randn_like(x)\n        )\n        if inplace:\n            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))\n        else:\n            x = x * block_mask + normal_noise * (1 - block_mask)\n    else:\n        normalize_scale = (\n            block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)\n        ).to(x.dtype)\n        if inplace:\n            x.mul_(block_mask * normalize_scale)\n        else:\n 
           x = x * block_mask * normalize_scale\n    return x\n\n\ndef drop_block_fast_2d(\n    x: torch.Tensor,\n    drop_prob: float = 0.1,\n    block_size: int = 7,\n    gamma_scale: float = 1.0,\n    with_noise: bool = False,\n    inplace: bool = False,\n):\n    \"\"\"DropBlock. See https://arxiv.org/pdf/1810.12890.pdf\n\n    DropBlock with an experimental Gaussian noise option. Simplified from above without concern for valid\n    block mask at edges.\n    \"\"\"\n    B, C, H, W = x.shape\n    total_size = W * H\n    clipped_block_size = min(block_size, min(W, H))\n    gamma = (\n        gamma_scale\n        * drop_prob\n        * total_size\n        / clipped_block_size**2\n        / ((W - block_size + 1) * (H - block_size + 1))\n    )\n\n    block_mask = torch.empty_like(x).bernoulli_(gamma)\n    block_mask = F.max_pool2d(\n        block_mask.to(x.dtype),\n        kernel_size=clipped_block_size,\n        stride=1,\n        padding=clipped_block_size // 2,\n    )\n\n    if with_noise:\n        normal_noise = torch.empty_like(x).normal_()\n        if inplace:\n            x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)\n        else:\n            x = x * (1.0 - block_mask) + normal_noise * block_mask\n    else:\n        block_mask = 1 - block_mask\n        normalize_scale = (\n            block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)\n        ).to(dtype=x.dtype)\n        if inplace:\n            x.mul_(block_mask * normalize_scale)\n        else:\n            x = x * block_mask * normalize_scale\n    return x\n\n\nclass DropBlock2d(nn.Module):\n    \"\"\"DropBlock. See https://arxiv.org/pdf/1810.12890.pdf\"\"\"\n\n    def __init__(\n        self,\n        drop_prob: float = 0.1,\n        block_size: int = 7,\n        gamma_scale: float = 1.0,\n        with_noise: bool = False,\n        inplace: bool = False,\n        batchwise: bool = False,\n        fast: bool = True,\n    ):\n        super(DropBlock2d, self).__init__()\n        self.drop_prob = drop_prob\n        self.gamma_scale = gamma_scale\n        self.block_size = block_size\n        self.with_noise = with_noise\n        self.inplace = inplace\n        self.batchwise = batchwise\n        self.fast = fast  # FIXME finish comparisons of fast vs not\n\n    def forward(self, x):\n        if not self.training or not self.drop_prob:\n            return x\n        if self.fast:\n            return drop_block_fast_2d(\n                x,\n                self.drop_prob,\n                self.block_size,\n                self.gamma_scale,\n                self.with_noise,\n                self.inplace,\n            )\n        else:\n            return drop_block_2d(\n                x,\n                self.drop_prob,\n                self.block_size,\n                self.gamma_scale,\n                self.with_noise,\n                self.inplace,\n                self.batchwise,\n            )\n\n\ndef drop_path(\n    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True\n):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n\n    This is the same as the DropConnect impl I created for EfficientNet, etc. networks, however,\n    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n    'survival rate' as the argument.\n\n    \"\"\"\n    if drop_prob == 0.0 or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0],) + (1,) * (\n        x.ndim - 1\n    )  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n    if keep_prob > 0.0 and scale_by_keep:\n        random_tensor.div_(keep_prob)\n    return x * random_tensor\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\"\"\"\n\n    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n        self.scale_by_keep = scale_by_keep\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)\n\n    def extra_repr(self):\n        return f\"drop_prob={round(self.drop_prob,3):0.3f}\"\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/eca.py",
    "content": "\"\"\"\nECA module from ECAnet\n\npaper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks\nhttps://arxiv.org/abs/1910.03151\n\nOriginal ECA model borrowed from https://github.com/BangguWu/ECANet\n\nModified circular ECA implementation and adaption for use in timm package\nby Chris Ha https://github.com/VRandme\n\nOriginal License:\n\nMIT License\n\nCopyright (c) 2019 BangguWu, Qilong Wang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\nimport math\nfrom torch import nn\nimport torch.nn.functional as F\n\n\nfrom .create_act import create_act_layer\nfrom .helpers import make_divisible\n\n\nclass EcaModule(nn.Module):\n    \"\"\"Constructs an ECA module.\n\n    Args:\n        channels: Number of channels of the input feature map for use in adaptive kernel sizes\n            for actual calculations according to channel.\n            gamma, beta: when channel is given parameters of mapping function\n            refer to original paper https://arxiv.org/pdf/1910.03151.pdf\n            (default=None. if channel size not given, use k_size given for kernel size.)\n        kernel_size: Adaptive selection of kernel size (default=3)\n        gamma: used in kernel_size calc, see above\n        beta: used in kernel_size calc, see above\n        act_layer: optional non-linearity after conv, enables conv bias, this is an experiment\n        gate_layer: gating non-linearity to use\n    \"\"\"\n\n    def __init__(\n        self,\n        channels=None,\n        kernel_size=3,\n        gamma=2,\n        beta=1,\n        act_layer=None,\n        gate_layer=\"sigmoid\",\n        rd_ratio=1 / 8,\n        rd_channels=None,\n        rd_divisor=8,\n        use_mlp=False,\n    ):\n        super(EcaModule, self).__init__()\n        if channels is not None:\n            t = int(abs(math.log(channels, 2) + beta) / gamma)\n            kernel_size = max(t if t % 2 else t + 1, 3)\n        assert kernel_size % 2 == 1\n        padding = (kernel_size - 1) // 2\n        if use_mlp:\n            # NOTE 'mlp' mode is a timm experiment, not in paper\n            assert channels is not None\n            if rd_channels is None:\n                rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor)\n            act_layer = act_layer or nn.ReLU\n            self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True)\n            self.act = create_act_layer(act_layer)\n            self.conv2 = nn.Conv1d(\n                rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True\n            )\n        else:\n            self.conv = nn.Conv1d(\n                1, 1, kernel_size=kernel_size, padding=padding, bias=False\n            )\n            self.act = None\n            self.conv2 = None\n        self.gate = create_act_layer(gate_layer)\n\n    def forward(self, x):\n        y = x.mean((2, 3)).view(x.shape[0], 1, -1)  # view for 1d conv\n        y = self.conv(y)\n        if self.conv2 is not None:\n            y = self.act(y)\n            y = self.conv2(y)\n        y = self.gate(y).view(x.shape[0], -1, 1, 1)\n        return x * y.expand_as(x)\n\n\nEfficientChannelAttn = EcaModule  # alias\n\n\nclass CecaModule(nn.Module):\n    \"\"\"Constructs a circular ECA module.\n\n    ECA module where the conv uses circular padding rather than zero padding.\n    Unlike the spatial dimension, the channels do not have inherent ordering nor\n    locality. Although this module in essence applies such an assumption, it is unnecessary\n    to limit the channels on either \"edge\" from being circularly adapted to each other.\n    This will fundamentally increase connectivity and possibly increase performance metrics\n    (accuracy, robustness), without significantly impacting resource metrics\n    (parameter size, throughput, latency, etc)\n\n    Args:\n        channels: Number of channels of the input feature map for use in adaptive kernel sizes\n            for actual calculations according to channel.\n            gamma, beta: when channel is given parameters of mapping function\n            refer to original paper https://arxiv.org/pdf/1910.03151.pdf\n            (default=None. if channel size not given, use k_size given for kernel size.)\n        kernel_size: Adaptive selection of kernel size (default=3)\n        gamma: used in kernel_size calc, see above\n        beta: used in kernel_size calc, see above\n        act_layer: if set, only enables conv bias (no activation is applied in this module), this is an experiment\n        gate_layer: gating non-linearity to use\n    \"\"\"\n\n    def __init__(\n        self,\n        channels=None,\n        kernel_size=3,\n        gamma=2,\n        beta=1,\n        act_layer=None,\n        gate_layer=\"sigmoid\",\n    ):\n        super(CecaModule, self).__init__()\n        if channels is not None:\n            t = int(abs(math.log(channels, 2) + beta) / gamma)\n            kernel_size = max(t if t % 2 else t + 1, 3)\n        has_act = act_layer is not None\n        assert kernel_size % 2 == 1\n\n        # PyTorch circular padding mode is buggy as of pytorch 1.4\n        # see https://github.com/pytorch/pytorch/pull/17240\n        # implement manual circular padding\n        self.padding = (kernel_size - 1) // 2\n        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act)\n        self.gate = create_act_layer(gate_layer)\n\n    def forward(self, x):\n        y = x.mean((2, 3)).view(x.shape[0], 1, -1)\n        # Manually implement circular padding, F.pad does not seem to be bugged\n        y = F.pad(y, (self.padding, self.padding), mode=\"circular\")\n        y = self.conv(y)\n        y = self.gate(y).view(x.shape[0], -1, 1, 1)\n        return x * y.expand_as(x)\n\n\nCircularEfficientChannelAttn = CecaModule\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/evo_norm.py",
    "content": "\"\"\" EvoNorm in PyTorch\n\nBased on `Evolving Normalization-Activation Layers` - https://arxiv.org/abs/2004.02967\n@inproceedings{NEURIPS2020,\n author = {Liu, Hanxiao and Brock, Andy and Simonyan, Karen and Le, Quoc},\n booktitle = {Advances in Neural Information Processing Systems},\n editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},\n pages = {13539--13550},\n publisher = {Curran Associates, Inc.},\n title = {Evolving Normalization-Activation Layers},\n url = {https://proceedings.neurips.cc/paper/2020/file/9d4c03631b8b0c85ae08bf05eda37d0f-Paper.pdf},\n volume = {33},\n year = {2020}\n}\n\nAn attempt at getting decent performing EvoNorms running in PyTorch.\nWhile faster than other PyTorch impl, still quite a ways off the built-in BatchNorm\nin terms of memory usage and throughput on GPUs.\n\nI'm testing these modules on TPU w/ PyTorch XLA. Promising start but\ncurrently working around some issues with builtin torch/tensor.var/std. Unlike\nGPU, similar train speeds for EvoNormS variants and BatchNorm.\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom typing import Sequence, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .create_act import create_act_layer\nfrom .trace_utils import _assert\n\n\ndef instance_std(x, eps: float = 1e-5):\n    std = (\n        x.float()\n        .var(dim=(2, 3), unbiased=False, keepdim=True)\n        .add(eps)\n        .sqrt()\n        .to(x.dtype)\n    )\n    return std.expand(x.shape)\n\n\ndef instance_std_tpu(x, eps: float = 1e-5):\n    std = manual_var(x, dim=(2, 3)).add(eps).sqrt()\n    return std.expand(x.shape)\n\n\n# instance_std = instance_std_tpu\n\n\ndef instance_rms(x, eps: float = 1e-5):\n    rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(x.dtype)\n    return rms.expand(x.shape)\n\n\ndef manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):\n    xm = x.mean(dim=dim, 
keepdim=True)\n    if diff_sqm:\n        # difference of squared mean and mean squared, faster on TPU can be less stable\n        var = ((x * x).mean(dim=dim, keepdim=True) - (xm * xm)).clamp(0)\n    else:\n        var = ((x - xm) * (x - xm)).mean(dim=dim, keepdim=True)\n    return var\n\n\ndef group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):\n    B, C, H, W = x.shape\n    x_dtype = x.dtype\n    _assert(C % groups == 0, \"\")\n    if flatten:\n        x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues\n        std = (\n            x.float()\n            .var(dim=2, unbiased=False, keepdim=True)\n            .add(eps)\n            .sqrt()\n            .to(x_dtype)\n        )\n    else:\n        x = x.reshape(B, groups, C // groups, H, W)\n        std = (\n            x.float()\n            .var(dim=(2, 3, 4), unbiased=False, keepdim=True)\n            .add(eps)\n            .sqrt()\n            .to(x_dtype)\n        )\n    return std.expand(x.shape).reshape(B, C, H, W)\n\n\ndef group_std_tpu(\n    x,\n    groups: int = 32,\n    eps: float = 1e-5,\n    diff_sqm: bool = False,\n    flatten: bool = False,\n):\n    # This is a workaround for some stability / odd behaviour of .var and .std\n    # running on PyTorch XLA w/ TPUs. 
These manual var impl are producing much better results\n    B, C, H, W = x.shape\n    _assert(C % groups == 0, \"\")\n    if flatten:\n        x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues\n        var = manual_var(x, dim=-1, diff_sqm=diff_sqm)\n    else:\n        x = x.reshape(B, groups, C // groups, H, W)\n        var = manual_var(x, dim=(2, 3, 4), diff_sqm=diff_sqm)\n    return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W)\n\n\n# group_std = group_std_tpu  # FIXME TPU temporary\n\n\ndef group_rms(x, groups: int = 32, eps: float = 1e-5):\n    B, C, H, W = x.shape\n    _assert(C % groups == 0, \"\")\n    x_dtype = x.dtype\n    x = x.reshape(B, groups, C // groups, H, W)\n    rms = (\n        x.float()\n        .square()\n        .mean(dim=(2, 3, 4), keepdim=True)\n        .add(eps)\n        .sqrt_()\n        .to(x_dtype)\n    )\n    return rms.expand(x.shape).reshape(B, C, H, W)\n\n\nclass EvoNorm2dB0(nn.Module):\n    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-3, **_):\n        super().__init__()\n        self.apply_act = apply_act  # apply activation (non-linearity)\n        self.momentum = momentum\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(num_features))\n        self.bias = nn.Parameter(torch.zeros(num_features))\n        self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None\n        self.register_buffer(\"running_var\", torch.ones(num_features))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.ones_(self.weight)\n        nn.init.zeros_(self.bias)\n        if self.v is not None:\n            nn.init.ones_(self.v)\n\n    def forward(self, x):\n        _assert(x.dim() == 4, \"expected 4D input\")\n        x_dtype = x.dtype\n        v_shape = (1, -1, 1, 1)\n        if self.v is not None:\n            if self.training:\n                var = x.float().var(dim=(0, 2, 3), unbiased=False)\n                # var = 
manual_var(x, dim=(0, 2, 3)).squeeze()\n                n = x.numel() / x.shape[1]\n                self.running_var.copy_(\n                    self.running_var * (1 - self.momentum)\n                    + var.detach() * self.momentum * (n / (n - 1))\n                )\n            else:\n                var = self.running_var\n            left = var.add(self.eps).sqrt_().to(x_dtype).view(v_shape).expand_as(x)\n            v = self.v.to(x_dtype).view(v_shape)\n            right = x * v + instance_std(x, self.eps)\n            x = x / left.max(right)\n        return x * self.weight.to(x_dtype).view(v_shape) + self.bias.to(x_dtype).view(\n            v_shape\n        )\n\n\nclass EvoNorm2dB1(nn.Module):\n    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):\n        super().__init__()\n        self.apply_act = apply_act  # apply activation (non-linearity)\n        self.momentum = momentum\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(num_features))\n        self.bias = nn.Parameter(torch.zeros(num_features))\n        self.register_buffer(\"running_var\", torch.ones(num_features))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.ones_(self.weight)\n        nn.init.zeros_(self.bias)\n\n    def forward(self, x):\n        _assert(x.dim() == 4, \"expected 4D input\")\n        x_dtype = x.dtype\n        v_shape = (1, -1, 1, 1)\n        if self.apply_act:\n            if self.training:\n                var = x.float().var(dim=(0, 2, 3), unbiased=False)\n                n = x.numel() / x.shape[1]\n                self.running_var.copy_(\n                    self.running_var * (1 - self.momentum)\n                    + var.detach().to(self.running_var.dtype)\n                    * self.momentum\n                    * (n / (n - 1))\n                )\n            else:\n                var = self.running_var\n            var = var.to(x_dtype).view(v_shape)\n            left = 
var.add(self.eps).sqrt_()\n            right = (x + 1) * instance_rms(x, self.eps)\n            x = x / left.max(right)\n        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(\n            x_dtype\n        )\n\n\nclass EvoNorm2dB2(nn.Module):\n    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):\n        super().__init__()\n        self.apply_act = apply_act  # apply activation (non-linearity)\n        self.momentum = momentum\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(num_features))\n        self.bias = nn.Parameter(torch.zeros(num_features))\n        self.register_buffer(\"running_var\", torch.ones(num_features))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.ones_(self.weight)\n        nn.init.zeros_(self.bias)\n\n    def forward(self, x):\n        _assert(x.dim() == 4, \"expected 4D input\")\n        x_dtype = x.dtype\n        v_shape = (1, -1, 1, 1)\n        if self.apply_act:\n            if self.training:\n                var = x.float().var(dim=(0, 2, 3), unbiased=False)\n                n = x.numel() / x.shape[1]\n                self.running_var.copy_(\n                    self.running_var * (1 - self.momentum)\n                    + var.detach().to(self.running_var.dtype)\n                    * self.momentum\n                    * (n / (n - 1))\n                )\n            else:\n                var = self.running_var\n            var = var.to(x_dtype).view(v_shape)\n            left = var.add(self.eps).sqrt_()\n            right = instance_rms(x, self.eps) - x\n            x = x / left.max(right)\n        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(\n            x_dtype\n        )\n\n\nclass EvoNorm2dS0(nn.Module):\n    def __init__(\n        self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_\n    ):\n        super().__init__()\n        self.apply_act = apply_act  
# apply activation (non-linearity)\n        if group_size:\n            assert num_features % group_size == 0\n            self.groups = num_features // group_size\n        else:\n            self.groups = groups\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(num_features))\n        self.bias = nn.Parameter(torch.zeros(num_features))\n        self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.ones_(self.weight)\n        nn.init.zeros_(self.bias)\n        if self.v is not None:\n            nn.init.ones_(self.v)\n\n    def forward(self, x):\n        _assert(x.dim() == 4, \"expected 4D input\")\n        x_dtype = x.dtype\n        v_shape = (1, -1, 1, 1)\n        if self.v is not None:\n            v = self.v.view(v_shape).to(x_dtype)\n            x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)\n        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(\n            x_dtype\n        )\n\n\nclass EvoNorm2dS0a(EvoNorm2dS0):\n    def __init__(\n        self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-3, **_\n    ):\n        super().__init__(\n            num_features,\n            groups=groups,\n            group_size=group_size,\n            apply_act=apply_act,\n            eps=eps,\n        )\n\n    def forward(self, x):\n        _assert(x.dim() == 4, \"expected 4D input\")\n        x_dtype = x.dtype\n        v_shape = (1, -1, 1, 1)\n        d = group_std(x, self.groups, self.eps)\n        if self.v is not None:\n            v = self.v.view(v_shape).to(x_dtype)\n            x = x * (x * v).sigmoid()\n        x = x / d\n        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(\n            x_dtype\n        )\n\n\nclass EvoNorm2dS1(nn.Module):\n    def __init__(\n        self,\n        num_features,\n        groups=32,\n        group_size=None,\n    
    apply_act=True,\n        act_layer=None,\n        eps=1e-5,\n        **_\n    ):\n        super().__init__()\n        act_layer = act_layer or nn.SiLU\n        self.apply_act = apply_act  # apply activation (non-linearity)\n        if act_layer is not None and apply_act:\n            self.act = create_act_layer(act_layer)\n        else:\n            self.act = nn.Identity()\n        if group_size:\n            assert num_features % group_size == 0\n            self.groups = num_features // group_size\n        else:\n            self.groups = groups\n        self.eps = eps\n        self.pre_act_norm = False\n        self.weight = nn.Parameter(torch.ones(num_features))\n        self.bias = nn.Parameter(torch.zeros(num_features))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.ones_(self.weight)\n        nn.init.zeros_(self.bias)\n\n    def forward(self, x):\n        _assert(x.dim() == 4, \"expected 4D input\")\n        x_dtype = x.dtype\n        v_shape = (1, -1, 1, 1)\n        if self.apply_act:\n            x = self.act(x) / group_std(x, self.groups, self.eps)\n        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(\n            x_dtype\n        )\n\n\nclass EvoNorm2dS1a(EvoNorm2dS1):\n    def __init__(\n        self,\n        num_features,\n        groups=32,\n        group_size=None,\n        apply_act=True,\n        act_layer=None,\n        eps=1e-3,\n        **_\n    ):\n        super().__init__(\n            num_features,\n            groups=groups,\n            group_size=group_size,\n            apply_act=apply_act,\n            act_layer=act_layer,\n            eps=eps,\n        )\n\n    def forward(self, x):\n        _assert(x.dim() == 4, \"expected 4D input\")\n        x_dtype = x.dtype\n        v_shape = (1, -1, 1, 1)\n        x = self.act(x) / group_std(x, self.groups, self.eps)\n        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(\n            
x_dtype\n        )\n\n\nclass EvoNorm2dS2(nn.Module):\n    def __init__(\n        self,\n        num_features,\n        groups=32,\n        group_size=None,\n        apply_act=True,\n        act_layer=None,\n        eps=1e-5,\n        **_\n    ):\n        super().__init__()\n        act_layer = act_layer or nn.SiLU\n        self.apply_act = apply_act  # apply activation (non-linearity)\n        if act_layer is not None and apply_act:\n            self.act = create_act_layer(act_layer)\n        else:\n            self.act = nn.Identity()\n        if group_size:\n            assert num_features % group_size == 0\n            self.groups = num_features // group_size\n        else:\n            self.groups = groups\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(num_features))\n        self.bias = nn.Parameter(torch.zeros(num_features))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.ones_(self.weight)\n        nn.init.zeros_(self.bias)\n\n    def forward(self, x):\n        _assert(x.dim() == 4, \"expected 4D input\")\n        x_dtype = x.dtype\n        v_shape = (1, -1, 1, 1)\n        if self.apply_act:\n            x = self.act(x) / group_rms(x, self.groups, self.eps)\n        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(\n            x_dtype\n        )\n\n\nclass EvoNorm2dS2a(EvoNorm2dS2):\n    def __init__(\n        self,\n        num_features,\n        groups=32,\n        group_size=None,\n        apply_act=True,\n        act_layer=None,\n        eps=1e-3,\n        **_\n    ):\n        super().__init__(\n            num_features,\n            groups=groups,\n            group_size=group_size,\n            apply_act=apply_act,\n            act_layer=act_layer,\n            eps=eps,\n        )\n\n    def forward(self, x):\n        _assert(x.dim() == 4, \"expected 4D input\")\n        x_dtype = x.dtype\n        v_shape = (1, -1, 1, 1)\n        x = self.act(x) / group_rms(x, 
self.groups, self.eps)\n        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(\n            x_dtype\n        )\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/fast_norm.py",
    "content": "\"\"\" 'Fast' Normalization Functions\n\nFor GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32.\n\nAdditionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast)\n\nHacked together by / Copyright 2022 Ross Wightman\n\"\"\"\n\nfrom typing import List, Optional\n\nimport torch\nfrom torch.nn import functional as F\n\ntry:\n    from apex.normalization.fused_layer_norm import fused_layer_norm_affine\n\n    has_apex = True\nexcept ImportError:\n    has_apex = False\n\n\n# fast (ie lower precision LN) can be disabled with this flag if issues crop up\n_USE_FAST_NORM = False  # defaulting to False for now\n\n\ndef is_fast_norm():\n    return _USE_FAST_NORM\n\n\ndef set_fast_norm(enable=True):\n    global _USE_FAST_NORM\n    _USE_FAST_NORM = enable\n\n\ndef fast_group_norm(\n    x: torch.Tensor,\n    num_groups: int,\n    weight: Optional[torch.Tensor] = None,\n    bias: Optional[torch.Tensor] = None,\n    eps: float = 1e-5,\n) -> torch.Tensor:\n    if torch.jit.is_scripting():\n        # currently cannot use is_autocast_enabled within torchscript\n        return F.group_norm(x, num_groups, weight, bias, eps)\n\n    if torch.is_autocast_enabled():\n        # normally native AMP casts GN inputs to float32\n        # here we use the low precision autocast dtype\n        # FIXME what to do re CPU autocast?\n        dt = torch.get_autocast_gpu_dtype()\n        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)\n\n    with torch.cuda.amp.autocast(enabled=False):\n        return F.group_norm(x, num_groups, weight, bias, eps)\n\n\ndef fast_layer_norm(\n    x: torch.Tensor,\n    normalized_shape: List[int],\n    weight: Optional[torch.Tensor] = None,\n    bias: Optional[torch.Tensor] = None,\n    eps: float = 1e-5,\n) -> torch.Tensor:\n    if torch.jit.is_scripting():\n        # currently cannot use is_autocast_enabled within torchscript\n        return F.layer_norm(x, normalized_shape, weight, bias, 
eps)\n\n    if has_apex:\n        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)\n\n    if torch.is_autocast_enabled():\n        # normally native AMP casts LN inputs to float32\n        # apex LN does not, this is behaving like Apex\n        dt = torch.get_autocast_gpu_dtype()\n        # FIXME what to do re CPU autocast?\n        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)\n\n    with torch.cuda.amp.autocast(enabled=False):\n        return F.layer_norm(x, normalized_shape, weight, bias, eps)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/filter_response_norm.py",
    "content": "\"\"\" Filter Response Norm in PyTorch\n\nBased on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737\n\nHacked together by / Copyright 2021 Ross Wightman\n\"\"\"\n\nimport torch\nimport torch.nn as nn\n\nfrom .create_act import create_act_layer\nfrom .trace_utils import _assert\n\n\ndef inv_instance_rms(x, eps: float = 1e-5):\n    rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype)\n    return rms.expand(x.shape)\n\n\nclass FilterResponseNormTlu2d(nn.Module):\n    def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_):\n        super(FilterResponseNormTlu2d, self).__init__()\n        self.apply_act = apply_act  # apply activation (non-linearity)\n        self.rms = rms\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(num_features))\n        self.bias = nn.Parameter(torch.zeros(num_features))\n        self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.ones_(self.weight)\n        nn.init.zeros_(self.bias)\n        if self.tau is not None:\n            nn.init.zeros_(self.tau)\n\n    def forward(self, x):\n        _assert(x.dim() == 4, \"expected 4D input\")\n        x_dtype = x.dtype\n        v_shape = (1, -1, 1, 1)\n        x = x * inv_instance_rms(x, self.eps)\n        x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(\n            v_shape\n        ).to(dtype=x_dtype)\n        return (\n            torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype))\n            if self.tau is not None\n            else x\n        )\n\n\nclass FilterResponseNormAct2d(nn.Module):\n    def __init__(\n        self,\n        num_features,\n        apply_act=True,\n        act_layer=nn.ReLU,\n        inplace=None,\n        rms=True,\n        eps=1e-5,\n        **_\n    ):\n        super(FilterResponseNormAct2d, self).__init__()\n        if 
act_layer is not None and apply_act:\n            self.act = create_act_layer(act_layer, inplace=inplace)\n        else:\n            self.act = nn.Identity()\n        self.rms = rms\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(num_features))\n        self.bias = nn.Parameter(torch.zeros(num_features))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.ones_(self.weight)\n        nn.init.zeros_(self.bias)\n\n    def forward(self, x):\n        _assert(x.dim() == 4, \"expected 4D input\")\n        x_dtype = x.dtype\n        v_shape = (1, -1, 1, 1)\n        x = x * inv_instance_rms(x, self.eps)\n        x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(\n            v_shape\n        ).to(dtype=x_dtype)\n        return self.act(x)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/gather_excite.py",
    "content": "\"\"\" Gather-Excite Attention Block\n\nPaper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348\n\nOfficial code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet\n\nI've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another\nimpl that covers all of the cases.\n\nNOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation\n\nHacked together by / Copyright 2021 Ross Wightman\n\"\"\"\n\nimport math\n\nfrom torch import nn as nn\nimport torch.nn.functional as F\n\nfrom .create_act import create_act_layer, get_act_layer\nfrom .create_conv2d import create_conv2d\nfrom .helpers import make_divisible\nfrom .mlp import ConvMlp\n\n\nclass GatherExcite(nn.Module):\n    \"\"\"Gather-Excite Attention Module\"\"\"\n\n    def __init__(\n        self,\n        channels,\n        feat_size=None,\n        extra_params=False,\n        extent=0,\n        use_mlp=True,\n        rd_ratio=1.0 / 16,\n        rd_channels=None,\n        rd_divisor=1,\n        add_maxpool=False,\n        act_layer=nn.ReLU,\n        norm_layer=nn.BatchNorm2d,\n        gate_layer=\"sigmoid\",\n    ):\n        super(GatherExcite, self).__init__()\n        self.add_maxpool = add_maxpool\n        act_layer = get_act_layer(act_layer)\n        self.extent = extent\n        if extra_params:\n            self.gather = nn.Sequential()\n            if extent == 0:\n                assert (\n                    feat_size is not None\n                ), \"spatial feature size must be specified for global extent w/ params\"\n                self.gather.add_module(\n                    \"conv1\",\n                    create_conv2d(\n                        channels,\n                        channels,\n                        kernel_size=feat_size,\n                        stride=1,\n                        depthwise=True,\n                    ),\n                )\n                if 
norm_layer:\n                    self.gather.add_module(f\"norm1\", nn.BatchNorm2d(channels))\n            else:\n                assert extent % 2 == 0\n                num_conv = int(math.log2(extent))\n                for i in range(num_conv):\n                    self.gather.add_module(\n                        f\"conv{i + 1}\",\n                        create_conv2d(\n                            channels, channels, kernel_size=3, stride=2, depthwise=True\n                        ),\n                    )\n                    if norm_layer:\n                        self.gather.add_module(f\"norm{i + 1}\", nn.BatchNorm2d(channels))\n                    if i != num_conv - 1:\n                        self.gather.add_module(f\"act{i + 1}\", act_layer(inplace=True))\n        else:\n            self.gather = None\n            if self.extent == 0:\n                self.gk = 0\n                self.gs = 0\n            else:\n                assert extent % 2 == 0\n                self.gk = self.extent * 2 - 1\n                self.gs = self.extent\n\n        if not rd_channels:\n            rd_channels = make_divisible(\n                channels * rd_ratio, rd_divisor, round_limit=0.0\n            )\n        self.mlp = (\n            ConvMlp(channels, rd_channels, act_layer=act_layer)\n            if use_mlp\n            else nn.Identity()\n        )\n        self.gate = create_act_layer(gate_layer)\n\n    def forward(self, x):\n        size = x.shape[-2:]\n        if self.gather is not None:\n            x_ge = self.gather(x)\n        else:\n            if self.extent == 0:\n                # global extent\n                x_ge = x.mean(dim=(2, 3), keepdims=True)\n                if self.add_maxpool:\n                    # experimental codepath, may remove or change\n                    x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)\n            else:\n                x_ge = F.avg_pool2d(\n                    x,\n                    kernel_size=self.gk,\n      
              stride=self.gs,\n                    padding=self.gk // 2,\n                    count_include_pad=False,\n                )\n                if self.add_maxpool:\n                    # experimental codepath, may remove or change\n                    x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(\n                        x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2\n                    )\n        x_ge = self.mlp(x_ge)\n        if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:\n            x_ge = F.interpolate(x_ge, size=size)\n        return x * self.gate(x_ge)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/global_context.py",
    "content": "\"\"\" Global Context Attention Block\n\nPaper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`\n    - https://arxiv.org/abs/1904.11492\n\nOfficial code consulted as reference: https://github.com/xvjiarui/GCNet\n\nHacked together by / Copyright 2021 Ross Wightman\n\"\"\"\n\nfrom torch import nn as nn\nimport torch.nn.functional as F\n\nfrom .create_act import create_act_layer, get_act_layer\nfrom .helpers import make_divisible\nfrom .mlp import ConvMlp\nfrom .norm import LayerNorm2d\n\n\nclass GlobalContext(nn.Module):\n    def __init__(\n        self,\n        channels,\n        use_attn=True,\n        fuse_add=False,\n        fuse_scale=True,\n        init_last_zero=False,\n        rd_ratio=1.0 / 8,\n        rd_channels=None,\n        rd_divisor=1,\n        act_layer=nn.ReLU,\n        gate_layer=\"sigmoid\",\n    ):\n        super(GlobalContext, self).__init__()\n        act_layer = get_act_layer(act_layer)\n\n        self.conv_attn = (\n            nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None\n        )\n\n        if rd_channels is None:\n            rd_channels = make_divisible(\n                channels * rd_ratio, rd_divisor, round_limit=0.0\n            )\n        if fuse_add:\n            self.mlp_add = ConvMlp(\n                channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d\n            )\n        else:\n            self.mlp_add = None\n        if fuse_scale:\n            self.mlp_scale = ConvMlp(\n                channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d\n            )\n        else:\n            self.mlp_scale = None\n\n        self.gate = create_act_layer(gate_layer)\n        self.init_last_zero = init_last_zero\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        if self.conv_attn is not None:\n            nn.init.kaiming_normal_(\n                self.conv_attn.weight, mode=\"fan_in\", nonlinearity=\"relu\"\n           
 )\n        if self.mlp_add is not None:\n            nn.init.zeros_(self.mlp_add.fc2.weight)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n\n        if self.conv_attn is not None:\n            attn = self.conv_attn(x).reshape(B, 1, H * W)  # (B, 1, H * W)\n            attn = F.softmax(attn, dim=-1).unsqueeze(3)  # (B, 1, H * W, 1)\n            context = x.reshape(B, C, H * W).unsqueeze(1) @ attn\n            context = context.view(B, C, 1, 1)\n        else:\n            context = x.mean(dim=(2, 3), keepdim=True)\n\n        if self.mlp_scale is not None:\n            mlp_x = self.mlp_scale(context)\n            x = x * self.gate(mlp_x)\n        if self.mlp_add is not None:\n            mlp_x = self.mlp_add(context)\n            x = x + mlp_x\n\n        return x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/halo_attn.py",
    "content": "\"\"\" Halo Self Attention\n\nPaper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`\n    - https://arxiv.org/abs/2103.12731\n\n@misc{2103.12731,\nAuthor = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and\n    Jonathon Shlens},\nTitle = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones},\nYear = {2021},\n}\n\nStatus:\nThis impl is a WIP, there is no official ref impl and some details in paper weren't clear to me.\nThe attention mechanism works but it's slow as implemented.\n\nHacked together by / Copyright 2021 Ross Wightman\n\"\"\"\n\nfrom typing import List\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom .helpers import make_divisible\nfrom .weight_init import trunc_normal_\nfrom .trace_utils import _assert\n\n\ndef rel_logits_1d(q, rel_k, permute_mask: List[int]):\n    \"\"\"Compute relative logits along one dimension\n\n    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2\n    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925\n\n    Args:\n        q: (batch, height, width, dim)\n        rel_k: (2 * window - 1, dim)\n        permute_mask: permute output dim according to this\n    \"\"\"\n    B, H, W, dim = q.shape\n    rel_size = rel_k.shape[0]\n    win_size = (rel_size + 1) // 2\n\n    x = q @ rel_k.transpose(-1, -2)\n    x = x.reshape(-1, W, rel_size)\n\n    # pad to shift from relative to absolute indexing\n    x_pad = F.pad(x, [0, 1]).flatten(1)\n    x_pad = F.pad(x_pad, [0, rel_size - W])\n\n    # reshape and slice out the padded elements\n    x_pad = x_pad.reshape(-1, W + 1, rel_size)\n    x = x_pad[:, :W, win_size - 1 :]\n\n    # reshape and tile\n    x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1)\n    return x.permute(permute_mask)\n\n\nclass PosEmbedRel(nn.Module):\n    \"\"\"Relative Position Embedding\n    As 
per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2\n    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925\n\n    \"\"\"\n\n    def __init__(self, block_size, win_size, dim_head, scale):\n        \"\"\"\n        Args:\n            block_size (int): block size\n            win_size (int): neighbourhood window size\n            dim_head (int): attention head dim\n            scale (float): scale factor (for init)\n        \"\"\"\n        super().__init__()\n        self.block_size = block_size\n        self.dim_head = dim_head\n        self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)\n        self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)\n\n    def forward(self, q):\n        B, BB, HW, _ = q.shape\n\n        # relative logits in width dimension.\n        q = q.reshape(-1, self.block_size, self.block_size, self.dim_head)\n        rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))\n\n        # relative logits in height dimension.\n        q = q.transpose(1, 2)\n        rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))\n\n        rel_logits = rel_logits_h + rel_logits_w\n        rel_logits = rel_logits.reshape(B, BB, HW, -1)\n        return rel_logits\n\n\nclass HaloAttn(nn.Module):\n    \"\"\"Halo Attention\n\n    Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`\n        - https://arxiv.org/abs/2103.12731\n\n    The internal dimensions of the attention module are controlled by the interaction of several arguments.\n      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set\n      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim\n      * the query and key (qk) dimensions are determined by\n        * num_heads * dim_head if dim_head is not None\n       
 * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None\n      * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used\n\n    Args:\n        dim (int): input dimension to the module\n        dim_out (int): output dimension of the module, same as dim if not set\n        feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda)\n        stride: output stride of the module, query downscaled if > 1 (default: 1).\n        num_heads: parallel attention heads (default: 8).\n        dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set\n        block_size (int): size of blocks. (default: 8)\n        halo_size (int): size of halo overlap. (default: 3)\n        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)\n        qkv_bias (bool) : add bias to q, k, and v projections\n        avg_down (bool): use average pool downsample instead of strided query blocks\n        scale_pos_embed (bool): scale the position embedding as well as Q @ K\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        dim_out=None,\n        feat_size=None,\n        stride=1,\n        num_heads=8,\n        dim_head=None,\n        block_size=8,\n        halo_size=3,\n        qk_ratio=1.0,\n        qkv_bias=False,\n        avg_down=False,\n        scale_pos_embed=False,\n    ):\n        super().__init__()\n        dim_out = dim_out or dim\n        assert dim_out % num_heads == 0\n        assert stride in (1, 2)\n        self.num_heads = num_heads\n        self.dim_head_qk = (\n            dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads\n        )\n        self.dim_head_v = dim_out // self.num_heads\n        self.dim_out_qk = num_heads * self.dim_head_qk\n        self.dim_out_v = num_heads * self.dim_head_v\n        self.scale = self.dim_head_qk**-0.5\n        
self.scale_pos_embed = scale_pos_embed\n        self.block_size = self.block_size_ds = block_size\n        self.halo_size = halo_size\n        self.win_size = block_size + halo_size * 2  # neighbourhood window size\n        self.block_stride = 1\n        use_avg_pool = False\n        if stride > 1:\n            use_avg_pool = avg_down or block_size % stride != 0\n            self.block_stride = 1 if use_avg_pool else stride\n            self.block_size_ds = self.block_size // self.block_stride\n\n        # FIXME not clear if this stride behaviour is what the paper intended\n        # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving\n        # data in unfolded block form. I haven't wrapped my head around how that'd look.\n        self.q = nn.Conv2d(\n            dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias\n        )\n        self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)\n\n        self.pos_embed = PosEmbedRel(\n            block_size=self.block_size_ds,\n            win_size=self.win_size,\n            dim_head=self.dim_head_qk,\n            scale=self.scale,\n        )\n\n        self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity()\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        std = self.q.weight.shape[1] ** -0.5  # fan-in\n        trunc_normal_(self.q.weight, std=std)\n        trunc_normal_(self.kv.weight, std=std)\n        trunc_normal_(self.pos_embed.height_rel, std=self.scale)\n        trunc_normal_(self.pos_embed.width_rel, std=self.scale)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        _assert(H % self.block_size == 0, \"\")\n        _assert(W % self.block_size == 0, \"\")\n        num_h_blocks = H // self.block_size\n        num_w_blocks = W // self.block_size\n        num_blocks = num_h_blocks * num_w_blocks\n\n        q = self.q(x)\n        # unfold\n        q = q.reshape(\n            -1,\n         
   self.dim_head_qk,\n            num_h_blocks,\n            self.block_size_ds,\n            num_w_blocks,\n            self.block_size_ds,\n        ).permute(0, 1, 3, 5, 2, 4)\n        # B, num_heads * dim_head * block_size ** 2, num_blocks\n        q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(\n            1, 3\n        )\n        # B * num_heads, num_blocks, block_size ** 2, dim_head\n\n        kv = self.kv(x)\n        # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not\n        # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach.\n        # FIXME figure out how to switch impl between this and conv2d if XLA being used.\n        kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])\n        kv = (\n            kv.unfold(2, self.win_size, self.block_size)\n            .unfold(3, self.win_size, self.block_size)\n            .reshape(\n                B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1\n            )\n            .permute(0, 2, 3, 1)\n        )\n        k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)\n        # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v\n\n        if self.scale_pos_embed:\n            attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale\n        else:\n            attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q)\n        # B * num_heads, num_blocks, block_size ** 2, win_size ** 2\n        attn = attn.softmax(dim=-1)\n\n        out = (attn @ v).transpose(\n            1, 3\n        )  # B * num_heads, dim_head_v, block_size ** 2, num_blocks\n        # fold\n        out = out.reshape(\n            -1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks\n        )\n        out = (\n            out.permute(0, 3, 1, 4, 2)\n            .contiguous()\n            .view(B, 
self.dim_out_v, H // self.block_stride, W // self.block_stride)\n        )\n        # B, dim_out, H // block_stride, W // block_stride\n        out = self.pool(out)\n        return out\n\n\n\"\"\" Three alternatives for overlapping windows.\n\n`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold()\n\n    if is_xla:\n        # This code achieves haloing on PyTorch XLA with reasonable runtime trade-off, it is\n        # EXTREMELY slow for backward on a GPU though so I need a way of selecting based on environment.\n        WW = self.win_size ** 2\n        pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size)\n        kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size)\n    elif self.stride_tricks:\n        kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()\n        kv = kv.as_strided((\n            B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),\n            stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))\n    else:\n        kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)\n\n    kv = kv.reshape(\n       B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)\n\"\"\"\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/helpers.py",
    "content": "\"\"\" Layer/Module Helpers\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom itertools import repeat\nimport collections.abc\n\n\n# From PyTorch internals\ndef _ntuple(n):\n    def parse(x):\n        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):\n            return x\n        return tuple(repeat(x, n))\n\n    return parse\n\n\nto_1tuple = _ntuple(1)\nto_2tuple = _ntuple(2)\nto_3tuple = _ntuple(3)\nto_4tuple = _ntuple(4)\nto_ntuple = _ntuple\n\n\ndef make_divisible(v, divisor=8, min_value=None, round_limit=0.9):\n    min_value = min_value or divisor\n    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)\n    # Make sure that round down does not go down by more than 10%.\n    if new_v < round_limit * v:\n        new_v += divisor\n    return new_v\n\n\ndef extend_tuple(x, n):\n    # pads a tuple to specified n by padding with last value\n    if not isinstance(x, (tuple, list)):\n        x = (x,)\n    else:\n        x = tuple(x)\n    pad_n = n - len(x)\n    if pad_n <= 0:\n        return x[:n]\n    return x + (x[-1],) * pad_n\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/inplace_abn.py",
    "content": "import torch\nfrom torch import nn as nn\n\ntry:\n    from inplace_abn.functions import inplace_abn, inplace_abn_sync\n\n    has_iabn = True\nexcept ImportError:\n    has_iabn = False\n\n    def inplace_abn(\n        x,\n        weight,\n        bias,\n        running_mean,\n        running_var,\n        training=True,\n        momentum=0.1,\n        eps=1e-05,\n        activation=\"leaky_relu\",\n        activation_param=0.01,\n    ):\n        raise ImportError(\n            \"Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'\"\n        )\n\n    def inplace_abn_sync(**kwargs):\n        inplace_abn(**kwargs)\n\n\nclass InplaceAbn(nn.Module):\n    \"\"\"Activated Batch Normalization\n\n    This gathers a BatchNorm and an activation function in a single module\n\n    Parameters\n    ----------\n    num_features : int\n        Number of feature channels in the input and output.\n    eps : float\n        Small constant to prevent numerical issues.\n    momentum : float\n        Momentum factor applied to compute running statistics.\n    affine : bool\n        If `True` apply learned scale and shift transformation after normalization.\n    act_layer : str or nn.Module type\n        Name or type of the activation functions, one of: `leaky_relu`, `elu`\n    act_param : float\n        Negative slope for the `leaky_relu` activation.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_features,\n        eps=1e-5,\n        momentum=0.1,\n        affine=True,\n        apply_act=True,\n        act_layer=\"leaky_relu\",\n        act_param=0.01,\n        drop_layer=None,\n    ):\n        super(InplaceAbn, self).__init__()\n        self.num_features = num_features\n        self.affine = affine\n        self.eps = eps\n        self.momentum = momentum\n        if apply_act:\n            if isinstance(act_layer, str):\n                assert act_layer in (\"leaky_relu\", \"elu\", \"identity\", \"\")\n            
    self.act_name = act_layer if act_layer else \"identity\"\n            else:\n                # convert act layer passed as type to string\n                if act_layer == nn.ELU:\n                    self.act_name = \"elu\"\n                elif act_layer == nn.LeakyReLU:\n                    self.act_name = \"leaky_relu\"\n                elif act_layer is None or act_layer == nn.Identity:\n                    self.act_name = \"identity\"\n                else:\n                    assert False, f\"Invalid act layer {act_layer.__name__} for IABN\"\n        else:\n            self.act_name = \"identity\"\n        self.act_param = act_param\n        if self.affine:\n            self.weight = nn.Parameter(torch.ones(num_features))\n            self.bias = nn.Parameter(torch.zeros(num_features))\n        else:\n            self.register_parameter(\"weight\", None)\n            self.register_parameter(\"bias\", None)\n        self.register_buffer(\"running_mean\", torch.zeros(num_features))\n        self.register_buffer(\"running_var\", torch.ones(num_features))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.constant_(self.running_mean, 0)\n        nn.init.constant_(self.running_var, 1)\n        if self.affine:\n            nn.init.constant_(self.weight, 1)\n            nn.init.constant_(self.bias, 0)\n\n    def forward(self, x):\n        output = inplace_abn(\n            x,\n            self.weight,\n            self.bias,\n            self.running_mean,\n            self.running_var,\n            self.training,\n            self.momentum,\n            self.eps,\n            self.act_name,\n            self.act_param,\n        )\n        if isinstance(output, tuple):\n            output = output[0]\n        return output\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/lambda_layer.py",
    "content": "\"\"\" Lambda Layer\n\nPaper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`\n    - https://arxiv.org/abs/2102.08602\n\n@misc{2102.08602,\nAuthor = {Irwan Bello},\nTitle = {LambdaNetworks: Modeling Long-Range Interactions Without Attention},\nYear = {2021},\n}\n\nStatus:\nThis impl is a WIP. Code snippets in the paper were used as reference but\ngood chance some details are missing/wrong.\n\nI've only implemented local lambda conv based pos embeddings.\n\nFor a PyTorch impl that includes other embedding options checkout\nhttps://github.com/lucidrains/lambda-networks\n\nHacked together by / Copyright 2021 Ross Wightman\n\"\"\"\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom .helpers import to_2tuple, make_divisible\nfrom .weight_init import trunc_normal_\n\n\ndef rel_pos_indices(size):\n    size = to_2tuple(size)\n    pos = torch.stack(\n        torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))\n    ).flatten(1)\n    rel_pos = pos[:, None, :] - pos[:, :, None]\n    rel_pos[0] += size[0] - 1\n    rel_pos[1] += size[1] - 1\n    return rel_pos  # 2, H * W, H * W\n\n\nclass LambdaLayer(nn.Module):\n    \"\"\"Lambda Layer\n\n    Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`\n        - https://arxiv.org/abs/2102.08602\n\n    NOTE: intra-depth parameter 'u' is fixed at 1. 
It did not appear worth the complexity to add.\n\n    The internal dimensions of the lambda module are controlled via the interaction of several arguments.\n      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set\n      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim\n      * the query (q) and key (k) dimensions are determined by\n        * dim_head = (dim_out * qk_ratio // num_heads) if dim_head is None\n        * q = num_heads * dim_head, k = dim_head\n      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not set\n\n    Args:\n        dim (int): input dimension to the module\n        dim_out (int): output dimension of the module, same as dim if not set\n        feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W\n        stride (int): output stride of the module, avg pool used if stride == 2\n        num_heads (int): parallel attention heads.\n        dim_head (int): dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set\n        r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9)\n        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set.
 (default: 1.0)\n        qkv_bias (bool): add bias to q, k, and v projections\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        dim_out=None,\n        feat_size=None,\n        stride=1,\n        num_heads=4,\n        dim_head=16,\n        r=9,\n        qk_ratio=1.0,\n        qkv_bias=False,\n    ):\n        super().__init__()\n        dim_out = dim_out or dim\n        assert dim_out % num_heads == 0, \"dim_out should be divisible by num_heads\"\n        self.dim_qk = (\n            dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads\n        )\n        self.num_heads = num_heads\n        self.dim_v = dim_out // num_heads\n\n        self.qkv = nn.Conv2d(\n            dim,\n            num_heads * self.dim_qk + self.dim_qk + self.dim_v,\n            kernel_size=1,\n            bias=qkv_bias,\n        )\n        self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)\n        self.norm_v = nn.BatchNorm2d(self.dim_v)\n\n        if r is not None:\n            # local lambda convolution for pos\n            self.conv_lambda = nn.Conv3d(\n                1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0)\n            )\n            self.pos_emb = None\n            self.rel_pos_indices = None\n        else:\n            # relative pos embedding\n            assert feat_size is not None\n            feat_size = to_2tuple(feat_size)\n            rel_size = [2 * s - 1 for s in feat_size]\n            self.conv_lambda = None\n            self.pos_emb = nn.Parameter(\n                torch.zeros(rel_size[0], rel_size[1], self.dim_qk)\n            )\n            self.register_buffer(\n                \"rel_pos_indices\", rel_pos_indices(feat_size), persistent=False\n            )\n\n        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)  # fan-in\n        if self.conv_lambda is not
 None:\n            trunc_normal_(self.conv_lambda.weight, std=self.dim_qk**-0.5)\n        if self.pos_emb is not None:\n            trunc_normal_(self.pos_emb, std=0.02)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        M = H * W\n        qkv = self.qkv(x)\n        q, k, v = torch.split(\n            qkv, [self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1\n        )\n        q = (\n            self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2)\n        )  # B, num_heads, M, K\n        v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2)  # B, M, V\n        k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1)  # B, K, M\n\n        content_lam = k @ v  # B, K, V\n        content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V\n\n        if self.pos_emb is None:\n            position_lam = self.conv_lambda(\n                v.reshape(B, 1, H, W, self.dim_v)\n            )  # B, K, H, W, V\n            position_lam = position_lam.reshape(\n                B, 1, self.dim_qk, H * W, self.dim_v\n            ).transpose(\n                2, 3\n            )  # B, 1, M, K, V\n        else:\n            # FIXME relative pos embedding path not fully verified\n            pos_emb = self.pos_emb[\n                self.rel_pos_indices[0], self.rel_pos_indices[1]\n            ].expand(B, -1, -1, -1)\n            position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(\n                1\n            )  # B, 1, M, K, V\n        position_out = (q.unsqueeze(-2) @ position_lam).squeeze(\n            -2\n        )  # B, num_heads, M, V\n\n        out = (\n            (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W)\n        )  # B, C (num_heads * V), H, W\n        out = self.pool(out)\n        return out\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/linear.py",
    "content": "\"\"\" Linear layer (alternate definition)\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn as nn\n\n\nclass Linear(nn.Linear):\n    r\"\"\"Applies a linear transformation to the incoming data: :math:`y = xA^T + b`\n\n    Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting\n    weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.\n    \"\"\"\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        if torch.jit.is_scripting():\n            bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None\n            return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)\n        else:\n            return F.linear(input, self.weight, self.bias)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/median_pool.py",
    "content": "\"\"\" Median Pool\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom .helpers import to_2tuple, to_4tuple\n\n\nclass MedianPool2d(nn.Module):\n    \"\"\"Median pool (usable as median filter when stride=1) module.\n\n    Args:\n         kernel_size: size of pooling kernel, int or 2-tuple\n         stride: pool stride, int or 2-tuple\n         padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad\n         same: override padding and enforce same padding, boolean\n    \"\"\"\n\n    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):\n        super(MedianPool2d, self).__init__()\n        self.k = to_2tuple(kernel_size)\n        self.stride = to_2tuple(stride)\n        self.padding = to_4tuple(padding)  # convert to l, r, t, b\n        self.same = same\n\n    def _padding(self, x):\n        if self.same:\n            ih, iw = x.size()[2:]\n            if ih % self.stride[0] == 0:\n                ph = max(self.k[0] - self.stride[0], 0)\n            else:\n                ph = max(self.k[0] - (ih % self.stride[0]), 0)\n            if iw % self.stride[1] == 0:\n                pw = max(self.k[1] - self.stride[1], 0)\n            else:\n                pw = max(self.k[1] - (iw % self.stride[1]), 0)\n            pl = pw // 2\n            pr = pw - pl\n            pt = ph // 2\n            pb = ph - pt\n            padding = (pl, pr, pt, pb)\n        else:\n            padding = self.padding\n        return padding\n\n    def forward(self, x):\n        x = F.pad(x, self._padding(x), mode=\"reflect\")\n        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])\n        x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]\n        return x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/mixed_conv2d.py",
    "content": "\"\"\" PyTorch Mixed Convolution\n\nPaper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nfrom torch import nn as nn\n\nfrom .conv2d_same import create_conv2d_pad\n\n\ndef _split_channels(num_chan, num_groups):\n    split = [num_chan // num_groups for _ in range(num_groups)]\n    split[0] += num_chan - sum(split)\n    return split\n\n\nclass MixedConv2d(nn.ModuleDict):\n    \"\"\"Mixed Grouped Convolution\n\n    Based on MDConv and GroupedConv in MixNet impl:\n      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size=3,\n        stride=1,\n        padding=\"\",\n        dilation=1,\n        depthwise=False,\n        **kwargs\n    ):\n        super(MixedConv2d, self).__init__()\n\n        kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]\n        num_groups = len(kernel_size)\n        in_splits = _split_channels(in_channels, num_groups)\n        out_splits = _split_channels(out_channels, num_groups)\n        self.in_channels = sum(in_splits)\n        self.out_channels = sum(out_splits)\n        for idx, (k, in_ch, out_ch) in enumerate(\n            zip(kernel_size, in_splits, out_splits)\n        ):\n            conv_groups = in_ch if depthwise else 1\n            # use add_module to keep key space clean\n            self.add_module(\n                str(idx),\n                create_conv2d_pad(\n                    in_ch,\n                    out_ch,\n                    k,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=conv_groups,\n                    **kwargs\n                ),\n            )\n        self.splits = in_splits\n\n    def forward(self, x):\n       
 x_split = torch.split(x, self.splits, 1)\n        x_out = [c(x_split[i]) for i, c in enumerate(self.values())]\n        x = torch.cat(x_out, 1)\n        return x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/ml_decoder.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import nn\nfrom torch import nn, Tensor\nfrom torch.nn.modules.transformer import _get_activation_fn\n\n\ndef add_ml_decoder_head(model):\n    if hasattr(model, \"global_pool\") and hasattr(\n        model, \"fc\"\n    ):  # most CNN models, like Resnet50\n        model.global_pool = nn.Identity()\n        del model.fc\n        num_classes = model.num_classes\n        num_features = model.num_features\n        model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)\n    elif hasattr(model, \"global_pool\") and hasattr(model, \"classifier\"):  # EfficientNet\n        model.global_pool = nn.Identity()\n        del model.classifier\n        num_classes = model.num_classes\n        num_features = model.num_features\n        model.classifier = MLDecoder(\n            num_classes=num_classes, initial_num_features=num_features\n        )\n    elif (\n        \"RegNet\" in model._get_name() or \"TResNet\" in model._get_name()\n    ):  # hasattr(model, 'head')\n        del model.head\n        num_classes = model.num_classes\n        num_features = model.num_features\n        model.head = MLDecoder(\n            num_classes=num_classes, initial_num_features=num_features\n        )\n    else:\n        print(\"Model code-writing is not aligned currently with ml-decoder\")\n        exit(-1)\n    if hasattr(model, \"drop_rate\"):  # Ml-Decoder has inner dropout\n        model.drop_rate = 0\n    return model\n\n\nclass TransformerDecoderLayerOptimal(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        nhead=8,\n        dim_feedforward=2048,\n        dropout=0.1,\n        activation=\"relu\",\n        layer_norm_eps=1e-5,\n    ) -> None:\n        super(TransformerDecoderLayerOptimal, self).__init__()\n        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)\n        self.dropout = nn.Dropout(dropout)\n        self.dropout1 = nn.Dropout(dropout)\n        
self.dropout2 = nn.Dropout(dropout)\n        self.dropout3 = nn.Dropout(dropout)\n\n        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n\n        # Implementation of Feedforward model\n        self.linear1 = nn.Linear(d_model, dim_feedforward)\n        self.linear2 = nn.Linear(dim_feedforward, d_model)\n\n        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)\n        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)\n\n        self.activation = _get_activation_fn(activation)\n\n    def __setstate__(self, state):\n        if \"activation\" not in state:\n            state[\"activation\"] = torch.nn.functional.relu\n        super(TransformerDecoderLayerOptimal, self).__setstate__(state)\n\n    def forward(\n        self,\n        tgt: Tensor,\n        memory: Tensor,\n        tgt_mask: Optional[Tensor] = None,\n        memory_mask: Optional[Tensor] = None,\n        tgt_key_padding_mask: Optional[Tensor] = None,\n        memory_key_padding_mask: Optional[Tensor] = None,\n    ) -> Tensor:\n        tgt = tgt + self.dropout1(tgt)\n        tgt = self.norm1(tgt)\n        tgt2 = self.multihead_attn(tgt, memory, memory)[0]\n        tgt = tgt + self.dropout2(tgt2)\n        tgt = self.norm2(tgt)\n        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))\n        tgt = tgt + self.dropout3(tgt2)\n        tgt = self.norm3(tgt)\n        return tgt\n\n\n# @torch.jit.script\n# class ExtrapClasses(object):\n#     def __init__(self, num_queries: int, group_size: int):\n#         self.num_queries = num_queries\n#         self.group_size = group_size\n#\n#     def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap:\n#     torch.Tensor):\n#         # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size)\n#         h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups])\n#         w = class_embed_w.view((self.num_queries, h.shape[2], 
self.group_size))\n#         out = (h * w).sum(dim=2) + class_embed_b\n#         out = out.view((h.shape[0], self.group_size * self.num_queries))\n#         return out\n\n\n@torch.jit.script\nclass GroupFC(object):\n    def __init__(self, embed_len_decoder: int):\n        self.embed_len_decoder = embed_len_decoder\n\n    def __call__(\n        self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor\n    ):\n        for i in range(self.embed_len_decoder):\n            h_i = h[:, i, :]\n            w_i = duplicate_pooling[i, :, :]\n            out_extrap[:, i, :] = torch.matmul(h_i, w_i)\n\n\nclass MLDecoder(nn.Module):\n    def __init__(\n        self,\n        num_classes,\n        num_of_groups=-1,\n        decoder_embedding=768,\n        initial_num_features=2048,\n    ):\n        super(MLDecoder, self).__init__()\n        embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups\n        if embed_len_decoder > num_classes:\n            embed_len_decoder = num_classes\n\n        # switching to 768 initial embeddings\n        decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding\n        self.embed_standart = nn.Linear(initial_num_features, decoder_embedding)\n\n        # decoder\n        decoder_dropout = 0.1\n        num_layers_decoder = 1\n        dim_feedforward = 2048\n        layer_decode = TransformerDecoderLayerOptimal(\n            d_model=decoder_embedding,\n            dim_feedforward=dim_feedforward,\n            dropout=decoder_dropout,\n        )\n        self.decoder = nn.TransformerDecoder(\n            layer_decode, num_layers=num_layers_decoder\n        )\n\n        # non-learnable queries\n        self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding)\n        self.query_embed.requires_grad_(False)\n\n        # group fully-connected\n        self.num_classes = num_classes\n        self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999)\n        self.duplicate_pooling = 
torch.nn.Parameter(\n            torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor)\n        )\n        self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))\n        torch.nn.init.xavier_normal_(self.duplicate_pooling)\n        torch.nn.init.constant_(self.duplicate_pooling_bias, 0)\n        self.group_fc = GroupFC(embed_len_decoder)\n\n    def forward(self, x):\n        if len(x.shape) == 4:  # [bs,2048, 7,7]\n            embedding_spatial = x.flatten(2).transpose(1, 2)\n        else:  # [bs, 197,468]\n            embedding_spatial = x\n        embedding_spatial_786 = self.embed_standart(embedding_spatial)\n        embedding_spatial_786 = torch.nn.functional.relu(\n            embedding_spatial_786, inplace=True\n        )\n\n        bs = embedding_spatial_786.shape[0]\n        query_embed = self.query_embed.weight\n        # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)\n        tgt = query_embed.unsqueeze(1).expand(\n            -1, bs, -1\n        )  # no allocation of memory with expand\n        h = self.decoder(\n            tgt, embedding_spatial_786.transpose(0, 1)\n        )  # [embed_len_decoder, batch, 768]\n        h = h.transpose(0, 1)\n\n        out_extrap = torch.zeros(\n            h.shape[0],\n            h.shape[1],\n            self.duplicate_factor,\n            device=h.device,\n            dtype=h.dtype,\n        )\n        self.group_fc(h, self.duplicate_pooling, out_extrap)\n        h_out = out_extrap.flatten(1)[:, : self.num_classes]\n        h_out += self.duplicate_pooling_bias\n        logits = h_out\n        return logits\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/mlp.py",
    "content": "\"\"\" MLP module w/ dropout and configurable activation layer\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom torch import nn as nn\n\nfrom .helpers import to_2tuple\n\n\nclass Mlp(nn.Module):\n    \"\"\"MLP as used in Vision Transformer, MLP-Mixer and related networks\"\"\"\n\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        bias=True,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        bias = to_2tuple(bias)\n        drop_probs = to_2tuple(drop)\n\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n\n\nclass GluMlp(nn.Module):\n    \"\"\"MLP w/ GLU style gating\n    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.Sigmoid,\n        bias=True,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        assert hidden_features % 2 == 0\n        bias = to_2tuple(bias)\n        drop_probs = to_2tuple(drop)\n\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        self.fc2 = nn.Linear(hidden_features // 2, out_features, 
bias=bias[1])\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n    def init_weights(self):\n        # override init of fc1 w/ gate portion set to weight near zero, bias=1\n        fc1_mid = self.fc1.bias.shape[0] // 2\n        nn.init.ones_(self.fc1.bias[fc1_mid:])\n        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x, gates = x.chunk(2, dim=-1)\n        x = x * self.act(gates)\n        x = self.drop1(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n\n\nclass GatedMlp(nn.Module):\n    \"\"\"MLP as used in gMLP\"\"\"\n\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        gate_layer=None,\n        bias=True,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        bias = to_2tuple(bias)\n        drop_probs = to_2tuple(drop)\n\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        if gate_layer is not None:\n            assert hidden_features % 2 == 0\n            self.gate = gate_layer(hidden_features)\n            hidden_features = (\n                hidden_features // 2\n            )  # FIXME base reduction on gate property?\n        else:\n            self.gate = nn.Identity()\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.gate(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n\n\nclass ConvMlp(nn.Module):\n    \"\"\"MLP using 1x1 convs that keeps spatial dims\"\"\"\n\n    def __init__(\n        self,\n        in_features,\n    
    hidden_features=None,\n        out_features=None,\n        act_layer=nn.ReLU,\n        norm_layer=None,\n        bias=True,\n        drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        bias = to_2tuple(bias)\n\n        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])\n        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()\n        self.act = act_layer()\n        self.drop = nn.Dropout(drop)\n        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.norm(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        return x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/non_local_attn.py",
    "content": "\"\"\" Bilinear-Attention-Transform and Non-Local Attention\n\nPaper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms`\n    - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html\nAdapted from original code: https://github.com/BA-Transform/BAT-Image-Classification\n\"\"\"\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom .conv_bn_act import ConvNormAct\nfrom .helpers import make_divisible\nfrom .trace_utils import _assert\n\n\nclass NonLocalAttn(nn.Module):\n    \"\"\"Spatial NL block for image classification.\n\n    This was adapted from https://github.com/BA-Transform/BAT-Image-Classification\n    Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        use_scale=True,\n        rd_ratio=1 / 8,\n        rd_channels=None,\n        rd_divisor=8,\n        **kwargs\n    ):\n        super(NonLocalAttn, self).__init__()\n        if rd_channels is None:\n            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)\n        self.scale = in_channels**-0.5 if use_scale else 1.0\n        self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)\n        self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)\n        self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)\n        self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True)\n        self.norm = nn.BatchNorm2d(in_channels)\n        self.reset_parameters()\n\n    def forward(self, x):\n        shortcut = x\n\n        t = self.t(x)\n        p = self.p(x)\n        g = self.g(x)\n\n        B, C, H, W = t.size()\n        t = t.view(B, C, -1).permute(0, 2, 1)\n        p = p.view(B, C, -1)\n        g = g.view(B, C, 
-1).permute(0, 2, 1)\n\n        att = torch.bmm(t, p) * self.scale\n        att = F.softmax(att, dim=2)\n        x = torch.bmm(att, g)\n\n        x = x.permute(0, 2, 1).reshape(B, C, H, W)\n        x = self.z(x)\n        x = self.norm(x) + shortcut\n\n        return x\n\n    def reset_parameters(self):\n        for name, m in self.named_modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n                if len(list(m.parameters())) > 1:\n                    nn.init.constant_(m.bias, 0.0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 0)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.GroupNorm):\n                nn.init.constant_(m.weight, 0)\n                nn.init.constant_(m.bias, 0)\n\n\nclass BilinearAttnTransform(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        block_size,\n        groups,\n        act_layer=nn.ReLU,\n        norm_layer=nn.BatchNorm2d,\n    ):\n        super(BilinearAttnTransform, self).__init__()\n\n        self.conv1 = ConvNormAct(\n            in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer\n        )\n        self.conv_p = nn.Conv2d(\n            groups, block_size * block_size * groups, kernel_size=(block_size, 1)\n        )\n        self.conv_q = nn.Conv2d(\n            groups, block_size * block_size * groups, kernel_size=(1, block_size)\n        )\n        self.conv2 = ConvNormAct(\n            in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer\n        )\n        self.block_size = block_size\n        self.groups = groups\n        self.in_channels = in_channels\n\n    def resize_mat(self, x, t: int):\n        B, C, block_size, block_size1 = x.shape\n        _assert(block_size == block_size1, \"\")\n        if t <= 1:\n            return x\n        x = x.view(B * C, -1, 1, 1)\n        x = x * 
torch.eye(t, t, dtype=x.dtype, device=x.device)\n        x = x.view(B * C, block_size, block_size, t, t)\n        x = torch.cat(torch.split(x, 1, dim=1), dim=3)\n        x = torch.cat(torch.split(x, 1, dim=2), dim=4)\n        x = x.view(B, C, block_size * t, block_size * t)\n        return x\n\n    def forward(self, x):\n        _assert(x.shape[-1] % self.block_size == 0, \"\")\n        _assert(x.shape[-2] % self.block_size == 0, \"\")\n        B, C, H, W = x.shape\n        out = self.conv1(x)\n        rp = F.adaptive_max_pool2d(out, (self.block_size, 1))\n        cp = F.adaptive_max_pool2d(out, (1, self.block_size))\n        p = (\n            self.conv_p(rp)\n            .view(B, self.groups, self.block_size, self.block_size)\n            .sigmoid()\n        )\n        q = (\n            self.conv_q(cp)\n            .view(B, self.groups, self.block_size, self.block_size)\n            .sigmoid()\n        )\n        p = p / p.sum(dim=3, keepdim=True)\n        q = q / q.sum(dim=2, keepdim=True)\n        p = (\n            p.view(B, self.groups, 1, self.block_size, self.block_size)\n            .expand(\n                x.size(0),\n                self.groups,\n                C // self.groups,\n                self.block_size,\n                self.block_size,\n            )\n            .contiguous()\n        )\n        p = p.view(B, C, self.block_size, self.block_size)\n        q = (\n            q.view(B, self.groups, 1, self.block_size, self.block_size)\n            .expand(\n                x.size(0),\n                self.groups,\n                C // self.groups,\n                self.block_size,\n                self.block_size,\n            )\n            .contiguous()\n        )\n        q = q.view(B, C, self.block_size, self.block_size)\n        p = self.resize_mat(p, H // self.block_size)\n        q = self.resize_mat(q, W // self.block_size)\n        y = p.matmul(x)\n        y = y.matmul(q)\n\n        y = self.conv2(y)\n        return y\n\n\nclass 
BatNonLocalAttn(nn.Module):\n    \"\"\"BAT\n    Adapted from: https://github.com/BA-Transform/BAT-Image-Classification\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        block_size=7,\n        groups=2,\n        rd_ratio=0.25,\n        rd_channels=None,\n        rd_divisor=8,\n        drop_rate=0.2,\n        act_layer=nn.ReLU,\n        norm_layer=nn.BatchNorm2d,\n        **_\n    ):\n        super().__init__()\n        if rd_channels is None:\n            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)\n        self.conv1 = ConvNormAct(\n            in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer\n        )\n        self.ba = BilinearAttnTransform(\n            rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer\n        )\n        self.conv2 = ConvNormAct(\n            rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer\n        )\n        self.dropout = nn.Dropout2d(p=drop_rate)\n\n    def forward(self, x):\n        xl = self.conv1(x)\n        y = self.ba(xl)\n        y = self.conv2(y)\n        y = self.dropout(y)\n        return y + x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/norm.py",
    "content": "\"\"\" Normalization layers and wrappers\n\nNorm layer definitions that support fast norm and consistent channel arg order (always first arg).\n\nHacked together by / Copyright 2022 Ross Wightman\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm\n\n\nclass GroupNorm(nn.GroupNorm):\n    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):\n        # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN\n        super().__init__(num_groups, num_channels, eps=eps, affine=affine)\n        self.fast_norm = (\n            is_fast_norm()\n        )  # can't script unless we have these flags here (no globals)\n\n    def forward(self, x):\n        if self.fast_norm:\n            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)\n        else:\n            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)\n\n\nclass GroupNorm1(nn.GroupNorm):\n    \"\"\"Group Normalization with 1 group.\n    Input: tensor in shape [B, C, *]\n    \"\"\"\n\n    def __init__(self, num_channels, **kwargs):\n        super().__init__(1, num_channels, **kwargs)\n        self.fast_norm = (\n            is_fast_norm()\n        )  # can't script unless we have these flags here (no globals)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.fast_norm:\n            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)\n        else:\n            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)\n\n\nclass LayerNorm(nn.LayerNorm):\n    \"\"\"LayerNorm w/ fast norm option\"\"\"\n\n    def __init__(self, num_channels, eps=1e-6, affine=True):\n        super().__init__(num_channels, eps=eps, elementwise_affine=affine)\n        self._fast_norm = (\n            is_fast_norm()\n        )  # can't script unless we have 
these flags here (no globals)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self._fast_norm:\n            x = fast_layer_norm(\n                x, self.normalized_shape, self.weight, self.bias, self.eps\n            )\n        else:\n            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n        return x\n\n\nclass LayerNorm2d(nn.LayerNorm):\n    \"\"\"LayerNorm for channels of '2D' spatial NCHW tensors\"\"\"\n\n    def __init__(self, num_channels, eps=1e-6, affine=True):\n        super().__init__(num_channels, eps=eps, elementwise_affine=affine)\n        self._fast_norm = (\n            is_fast_norm()\n        )  # can't script unless we have these flags here (no globals)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = x.permute(0, 2, 3, 1)\n        if self._fast_norm:\n            x = fast_layer_norm(\n                x, self.normalized_shape, self.weight, self.bias, self.eps\n            )\n        else:\n            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n        x = x.permute(0, 3, 1, 2)\n        return x\n\n\ndef _is_contiguous(tensor: torch.Tensor) -> bool:\n    # jit is oh so lovely :/\n    if torch.jit.is_scripting():\n        return tensor.is_contiguous()\n    else:\n        return tensor.is_contiguous(memory_format=torch.contiguous_format)\n\n\n@torch.jit.script\ndef _layer_norm_cf(\n    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float\n):\n    s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)\n    x = (x - u) * torch.rsqrt(s + eps)\n    x = x * weight[:, None, None] + bias[:, None, None]\n    return x\n\n\ndef _layer_norm_cf_sqm(\n    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float\n):\n    u = x.mean(dim=1, keepdim=True)\n    s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)\n    x = (x - u) * torch.rsqrt(s + eps)\n    x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 
1, 1)\n    return x\n\n\nclass LayerNormExp2d(nn.LayerNorm):\n    \"\"\"LayerNorm for channels_first tensors with 2d spatial dimensions (i.e. N, C, H, W).\n\n    Experimental implementation w/ manual norm for non-contiguous tensors.\n\n    This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last\n    layout. However, benefits are not always clear and can perform worse on other GPUs.\n    \"\"\"\n\n    def __init__(self, num_channels, eps=1e-6):\n        super().__init__(num_channels, eps=eps)\n\n    def forward(self, x) -> torch.Tensor:\n        if _is_contiguous(x):\n            x = F.layer_norm(\n                x.permute(0, 2, 3, 1),\n                self.normalized_shape,\n                self.weight,\n                self.bias,\n                self.eps,\n            ).permute(0, 3, 1, 2)\n        else:\n            x = _layer_norm_cf(x, self.weight, self.bias, self.eps)\n        return x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/norm_act.py",
    "content": "\"\"\" Normalization + Activation Layers\n\nProvides Norm+Act fns for standard PyTorch norm layers such as\n* BatchNorm\n* GroupNorm\n* LayerNorm\n\nThis allows swapping with alternative layers that are natively both norm + act such as\n* EvoNorm (evo_norm.py)\n* FilterResponseNorm (filter_response_norm.py)\n* InplaceABN (inplace_abn.py)\n\nHacked together by / Copyright 2022 Ross Wightman\n\"\"\"\n\nfrom typing import Union, List, Optional, Any\n\nimport torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom .create_act import get_act_layer\nfrom .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm\nfrom .trace_utils import _assert\n\n\nclass BatchNormAct2d(nn.BatchNorm2d):\n    \"\"\"BatchNorm + Activation\n\n    This module performs BatchNorm + Activation in a manner that will remain backwards\n    compatible with weights trained with separate bn, act. This is why we inherit from BN\n    instead of composing it as a .bn member.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_features,\n        eps=1e-5,\n        momentum=0.1,\n        affine=True,\n        track_running_stats=True,\n        apply_act=True,\n        act_layer=nn.ReLU,\n        inplace=True,\n        drop_layer=None,\n        device=None,\n        dtype=None,\n    ):\n        try:\n            factory_kwargs = {\"device\": device, \"dtype\": dtype}\n            super(BatchNormAct2d, self).__init__(\n                num_features,\n                eps=eps,\n                momentum=momentum,\n                affine=affine,\n                track_running_stats=track_running_stats,\n                **factory_kwargs,\n            )\n        except TypeError:\n            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support\n            super(BatchNormAct2d, self).__init__(\n                num_features,\n                eps=eps,\n                momentum=momentum,\n                affine=affine,\n                
track_running_stats=track_running_stats,\n            )\n        self.drop = drop_layer() if drop_layer is not None else nn.Identity()\n        act_layer = get_act_layer(act_layer)  # string -> nn.Module\n        if act_layer is not None and apply_act:\n            act_args = dict(inplace=True) if inplace else {}\n            self.act = act_layer(**act_args)\n        else:\n            self.act = nn.Identity()\n\n    def forward(self, x):\n        # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing\n        _assert(x.ndim == 4, f\"expected 4D input (got {x.ndim}D input)\")\n\n        # exponential_average_factor is set to self.momentum\n        # (when it is available) only so that it gets updated\n        # in ONNX graph when this node is exported to ONNX.\n        if self.momentum is None:\n            exponential_average_factor = 0.0\n        else:\n            exponential_average_factor = self.momentum\n\n        if self.training and self.track_running_stats:\n            # TODO: if statement only here to tell the jit to skip emitting this when it is None\n            if self.num_batches_tracked is not None:  # type: ignore[has-type]\n                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]\n                if self.momentum is None:  # use cumulative moving average\n                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)\n                else:  # use exponential moving average\n                    exponential_average_factor = self.momentum\n\n        r\"\"\"\n        Decide whether the mini-batch stats should be used for normalization rather than the buffers.\n        Mini-batch stats are used in training mode, and in eval mode when buffers are None.\n        \"\"\"\n        if self.training:\n            bn_training = True\n        else:\n            bn_training = (self.running_mean is None) and (self.running_var is None)\n\n        r\"\"\"\n     
   Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be\n        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are\n        used for normalization (i.e. in eval mode when buffers are not None).\n        \"\"\"\n        x = F.batch_norm(\n            x,\n            # If buffers are not to be tracked, ensure that they won't be updated\n            (\n                self.running_mean\n                if not self.training or self.track_running_stats\n                else None\n            ),\n            self.running_var if not self.training or self.track_running_stats else None,\n            self.weight,\n            self.bias,\n            bn_training,\n            exponential_average_factor,\n            self.eps,\n        )\n        x = self.drop(x)\n        x = self.act(x)\n        return x\n\n\nclass SyncBatchNormAct(nn.SyncBatchNorm):\n    # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)\n    # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers\n    # but ONLY when used in conjunction with the timm conversion function below.\n    # Do not create this module directly or use the PyTorch conversion function.\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = super().forward(\n            x\n        )  # SyncBN doesn't work with torchscript anyways, so this is fine\n        if hasattr(self, \"drop\"):\n            x = self.drop(x)\n        if hasattr(self, \"act\"):\n            x = self.act(x)\n        return x\n\n\ndef convert_sync_batchnorm(module, process_group=None):\n    # convert both BatchNorm and BatchNormAct layers to Synchronized variants\n    module_output = module\n    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):\n        if isinstance(module, BatchNormAct2d):\n            # convert timm norm + act layer\n            module_output = 
SyncBatchNormAct(\n                module.num_features,\n                module.eps,\n                module.momentum,\n                module.affine,\n                module.track_running_stats,\n                process_group=process_group,\n            )\n            # set act and drop attr from the original module\n            module_output.act = module.act\n            module_output.drop = module.drop\n        else:\n            # convert standard BatchNorm layers\n            module_output = torch.nn.SyncBatchNorm(\n                module.num_features,\n                module.eps,\n                module.momentum,\n                module.affine,\n                module.track_running_stats,\n                process_group,\n            )\n        if module.affine:\n            with torch.no_grad():\n                module_output.weight = module.weight\n                module_output.bias = module.bias\n        module_output.running_mean = module.running_mean\n        module_output.running_var = module.running_var\n        module_output.num_batches_tracked = module.num_batches_tracked\n        if hasattr(module, \"qconfig\"):\n            module_output.qconfig = module.qconfig\n    for name, child in module.named_children():\n        module_output.add_module(name, convert_sync_batchnorm(child, process_group))\n    del module\n    return module_output\n\n\ndef _num_groups(num_channels, num_groups, group_size):\n    if group_size:\n        assert num_channels % group_size == 0\n        return num_channels // group_size\n    return num_groups\n\n\nclass GroupNormAct(nn.GroupNorm):\n    # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args\n    def __init__(\n        self,\n        num_channels,\n        num_groups=32,\n        eps=1e-5,\n        affine=True,\n        group_size=None,\n        apply_act=True,\n        act_layer=nn.ReLU,\n        inplace=True,\n        drop_layer=None,\n    ):\n        super(GroupNormAct, 
self).__init__(\n            _num_groups(num_channels, num_groups, group_size),\n            num_channels,\n            eps=eps,\n            affine=affine,\n        )\n        self.drop = drop_layer() if drop_layer is not None else nn.Identity()\n        act_layer = get_act_layer(act_layer)  # string -> nn.Module\n        if act_layer is not None and apply_act:\n            act_args = dict(inplace=True) if inplace else {}\n            self.act = act_layer(**act_args)\n        else:\n            self.act = nn.Identity()\n        self._fast_norm = is_fast_norm()\n\n    def forward(self, x):\n        if self._fast_norm:\n            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)\n        else:\n            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)\n        x = self.drop(x)\n        x = self.act(x)\n        return x\n\n\nclass LayerNormAct(nn.LayerNorm):\n    def __init__(\n        self,\n        normalization_shape: Union[int, List[int], torch.Size],\n        eps=1e-5,\n        affine=True,\n        apply_act=True,\n        act_layer=nn.ReLU,\n        inplace=True,\n        drop_layer=None,\n    ):\n        super(LayerNormAct, self).__init__(\n            normalization_shape, eps=eps, elementwise_affine=affine\n        )\n        self.drop = drop_layer() if drop_layer is not None else nn.Identity()\n        act_layer = get_act_layer(act_layer)  # string -> nn.Module\n        if act_layer is not None and apply_act:\n            act_args = dict(inplace=True) if inplace else {}\n            self.act = act_layer(**act_args)\n        else:\n            self.act = nn.Identity()\n        self._fast_norm = is_fast_norm()\n\n    def forward(self, x):\n        if self._fast_norm:\n            x = fast_layer_norm(\n                x, self.normalized_shape, self.weight, self.bias, self.eps\n            )\n        else:\n            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n        x = 
self.drop(x)\n        x = self.act(x)\n        return x\n\n\nclass LayerNormAct2d(nn.LayerNorm):\n    def __init__(\n        self,\n        num_channels,\n        eps=1e-5,\n        affine=True,\n        apply_act=True,\n        act_layer=nn.ReLU,\n        inplace=True,\n        drop_layer=None,\n    ):\n        super(LayerNormAct2d, self).__init__(\n            num_channels, eps=eps, elementwise_affine=affine\n        )\n        self.drop = drop_layer() if drop_layer is not None else nn.Identity()\n        act_layer = get_act_layer(act_layer)  # string -> nn.Module\n        if act_layer is not None and apply_act:\n            act_args = dict(inplace=True) if inplace else {}\n            self.act = act_layer(**act_args)\n        else:\n            self.act = nn.Identity()\n        self._fast_norm = is_fast_norm()\n\n    def forward(self, x):\n        x = x.permute(0, 2, 3, 1)\n        if self._fast_norm:\n            x = fast_layer_norm(\n                x, self.normalized_shape, self.weight, self.bias, self.eps\n            )\n        else:\n            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n        x = x.permute(0, 3, 1, 2)\n        x = self.drop(x)\n        x = self.act(x)\n        return x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/padding.py",
    "content": "\"\"\" Padding Helpers\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport math\nfrom typing import List, Tuple\n\nimport torch.nn.functional as F\n\n\n# Calculate symmetric padding for a convolution\ndef get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:\n    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2\n    return padding\n\n\n# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution\ndef get_same_padding(x: int, k: int, s: int, d: int):\n    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)\n\n\n# Can SAME padding for given args be done statically?\ndef is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):\n    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0\n\n\n# Dynamically pad input x with 'SAME' padding for conv with specified args\ndef pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):\n    ih, iw = x.size()[-2:]\n    pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(\n        iw, k[1], s[1], d[1]\n    )\n    if pad_h > 0 or pad_w > 0:\n        x = F.pad(\n            x,\n            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],\n            value=value,\n        )\n    return x\n\n\ndef get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:\n    dynamic = False\n    if isinstance(padding, str):\n        # for any string padding, the padding will be calculated for you, one of three ways\n        padding = padding.lower()\n        if padding == \"same\":\n            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact\n            if is_static_pad(kernel_size, **kwargs):\n                # static case, no extra overhead\n                padding = get_padding(kernel_size, **kwargs)\n            else:\n                # dynamic 'SAME' padding, has runtime/GPU memory overhead\n                padding 
= 0\n                dynamic = True\n        elif padding == \"valid\":\n            # 'VALID' padding, same as padding=0\n            padding = 0\n        else:\n            # Default to PyTorch style 'same'-ish symmetric padding\n            padding = get_padding(kernel_size, **kwargs)\n    return padding, dynamic\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/patch_embed.py",
    "content": "\"\"\" Image to Patch Embedding using Conv2d\n\nA convolution based approach to patchifying a 2D image w/ embedding projection.\n\nBased on the impl in https://github.com/google-research/vision_transformer\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom torch import nn as nn\n\nfrom .helpers import to_2tuple\nfrom .trace_utils import _assert\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"2D Image to Patch Embedding\"\"\"\n\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        embed_dim=768,\n        norm_layer=None,\n        flatten=True,\n    ):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.num_patches = self.grid_size[0] * self.grid_size[1]\n        self.flatten = flatten\n\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size\n        )\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        _assert(\n            H == self.img_size[0],\n            f\"Input image height ({H}) doesn't match model ({self.img_size[0]}).\",\n        )\n        _assert(\n            W == self.img_size[1],\n            f\"Input image width ({W}) doesn't match model ({self.img_size[1]}).\",\n        )\n        x = self.proj(x)\n        if self.flatten:\n            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC\n        x = self.norm(x)\n        return x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/pool2d_same.py",
    "content": "\"\"\" AvgPool2d w/ Same Padding\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import List, Tuple, Optional\n\nfrom .helpers import to_2tuple\nfrom .padding import pad_same, get_padding_value\n\n\ndef avg_pool2d_same(\n    x,\n    kernel_size: List[int],\n    stride: List[int],\n    padding: List[int] = (0, 0),\n    ceil_mode: bool = False,\n    count_include_pad: bool = True,\n):\n    # FIXME how to deal with count_include_pad vs not for external padding?\n    x = pad_same(x, kernel_size, stride)\n    return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)\n\n\nclass AvgPool2dSame(nn.AvgPool2d):\n    \"\"\"Tensorflow like 'SAME' wrapper for 2D average pooling\"\"\"\n\n    def __init__(\n        self,\n        kernel_size: int,\n        stride=None,\n        padding=0,\n        ceil_mode=False,\n        count_include_pad=True,\n    ):\n        kernel_size = to_2tuple(kernel_size)\n        stride = to_2tuple(stride)\n        super(AvgPool2dSame, self).__init__(\n            kernel_size, stride, (0, 0), ceil_mode, count_include_pad\n        )\n\n    def forward(self, x):\n        x = pad_same(x, self.kernel_size, self.stride)\n        return F.avg_pool2d(\n            x,\n            self.kernel_size,\n            self.stride,\n            self.padding,\n            self.ceil_mode,\n            self.count_include_pad,\n        )\n\n\ndef max_pool2d_same(\n    x,\n    kernel_size: List[int],\n    stride: List[int],\n    padding: List[int] = (0, 0),\n    dilation: List[int] = (1, 1),\n    ceil_mode: bool = False,\n):\n    x = pad_same(x, kernel_size, stride, value=-float(\"inf\"))\n    return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)\n\n\nclass MaxPool2dSame(nn.MaxPool2d):\n    \"\"\"Tensorflow like 'SAME' wrapper for 2D max pooling\"\"\"\n\n    def __init__(\n        self, kernel_size: int, stride=None, 
padding=0, dilation=1, ceil_mode=False\n    ):\n        kernel_size = to_2tuple(kernel_size)\n        stride = to_2tuple(stride)\n        dilation = to_2tuple(dilation)\n        super(MaxPool2dSame, self).__init__(\n            kernel_size, stride, (0, 0), dilation, ceil_mode\n        )\n\n    def forward(self, x):\n        x = pad_same(x, self.kernel_size, self.stride, value=-float(\"inf\"))\n        return F.max_pool2d(\n            x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode\n        )\n\n\ndef create_pool2d(pool_type, kernel_size, stride=None, **kwargs):\n    stride = stride or kernel_size\n    padding = kwargs.pop(\"padding\", \"\")\n    padding, is_dynamic = get_padding_value(\n        padding, kernel_size, stride=stride, **kwargs\n    )\n    if is_dynamic:\n        if pool_type == \"avg\":\n            return AvgPool2dSame(kernel_size, stride=stride, **kwargs)\n        elif pool_type == \"max\":\n            return MaxPool2dSame(kernel_size, stride=stride, **kwargs)\n        else:\n            assert False, f\"Unsupported pool type {pool_type}\"\n    else:\n        if pool_type == \"avg\":\n            return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)\n        elif pool_type == \"max\":\n            return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)\n        else:\n            assert False, f\"Unsupported pool type {pool_type}\"\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/pos_embed.py",
    "content": "import math\nfrom typing import List, Tuple, Optional, Union\n\nimport torch\nfrom torch import nn as nn\n\n\ndef pixel_freq_bands(\n    num_bands: int,\n    max_freq: float = 224.0,\n    linear_bands: bool = True,\n    dtype: torch.dtype = torch.float32,\n    device: Optional[torch.device] = None,\n):\n    if linear_bands:\n        bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)\n    else:\n        bands = 2 ** torch.linspace(\n            0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device\n        )\n    return bands * torch.pi\n\n\ndef inv_freq_bands(\n    num_bands: int,\n    temperature: float = 100000.0,\n    step: int = 2,\n    dtype: torch.dtype = torch.float32,\n    device: Optional[torch.device] = None,\n) -> torch.Tensor:\n    inv_freq = 1.0 / (\n        temperature\n        ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands)\n    )\n    return inv_freq\n\n\ndef build_sincos2d_pos_embed(\n    feat_shape: List[int],\n    dim: int = 64,\n    temperature: float = 10000.0,\n    reverse_coord: bool = False,\n    interleave_sin_cos: bool = False,\n    dtype: torch.dtype = torch.float32,\n    device: Optional[torch.device] = None,\n) -> torch.Tensor:\n    \"\"\"Build a 2D sin-cos position embedding.\n\n    Args:\n        feat_shape: spatial shape of the feature map, e.g. [H, W]\n        dim: embedding dimension, must be divisible by 4\n        temperature: temperature used to compute the inverse frequency bands\n        reverse_coord: stack grid order W, H instead of H, W\n        interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos\n        dtype: dtype of the returned embedding\n        device: device of the returned embedding\n\n    Returns:\n        Tensor of shape (H * W, dim), one embedding per spatial position\n    \"\"\"\n    assert (\n        dim % 4 == 0\n    ), \"Embed dimension must be divisible by 4 for sin-cos 2D position embedding\"\n    pos_dim = dim // 4\n    bands = inv_freq_bands(\n        pos_dim, temperature=temperature, step=1, dtype=dtype, device=device\n    )\n\n    if reverse_coord:\n        feat_shape = feat_shape[::-1]  # stack W, H instead of H, W\n    grid = (\n        torch.stack(\n            torch.meshgrid(\n                
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape]\n            )\n        )\n        .flatten(1)\n        .transpose(0, 1)\n    )\n    pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)\n    # FIXME add support for unflattened spatial dim?\n\n    stack_dim = (\n        2 if interleave_sin_cos else 1\n    )  # stack sin, cos, sin, cos  instead of sin sin cos cos\n    pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)\n    return pos_emb\n\n\ndef build_fourier_pos_embed(\n    feat_shape: List[int],\n    bands: Optional[torch.Tensor] = None,\n    num_bands: int = 64,\n    max_res: int = 224,\n    linear_bands: bool = False,\n    include_grid: bool = False,\n    concat_out: bool = True,\n    in_pixels: bool = True,\n    dtype: torch.dtype = torch.float32,\n    device: Optional[torch.device] = None,\n) -> List[torch.Tensor]:\n    if bands is None:\n        if in_pixels:\n            bands = pixel_freq_bands(\n                num_bands,\n                float(max_res),\n                linear_bands=linear_bands,\n                dtype=dtype,\n                device=device,\n            )\n        else:\n            bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)\n    else:\n        if device is None:\n            device = bands.device\n        if dtype is None:\n            dtype = bands.dtype\n\n    if in_pixels:\n        grid = torch.stack(\n            torch.meshgrid(\n                [\n                    torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=dtype)\n                    for s in feat_shape\n                ]\n            ),\n            dim=-1,\n        )\n    else:\n        grid = torch.stack(\n            torch.meshgrid(\n                [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]\n            ),\n            dim=-1,\n        )\n    grid = grid.unsqueeze(-1)\n    pos = grid * bands\n\n    pos_sin, pos_cos = pos.sin(), pos.cos()\n    out = (grid, pos_sin, 
pos_cos) if include_grid else (pos_sin, pos_cos)\n    # FIXME torchscript doesn't like multiple return types, probably need to always cat?\n    if concat_out:\n        out = torch.cat(out, dim=-1)\n    return out\n\n\nclass FourierEmbed(nn.Module):\n    def __init__(\n        self,\n        max_res: int = 224,\n        num_bands: int = 64,\n        concat_grid=True,\n        keep_spatial=False,\n    ):\n        super().__init__()\n        self.max_res = max_res\n        self.num_bands = num_bands\n        self.concat_grid = concat_grid\n        self.keep_spatial = keep_spatial\n        # pixel_freq_bands takes (num_bands, max_freq), in that order\n        self.register_buffer(\n            \"bands\", pixel_freq_bands(num_bands, float(max_res)), persistent=False\n        )\n\n    def forward(self, x):\n        B, C = x.shape[:2]\n        feat_shape = x.shape[2:]\n        emb = build_fourier_pos_embed(\n            feat_shape,\n            self.bands,\n            include_grid=self.concat_grid,\n            dtype=x.dtype,\n            device=x.device,\n        )\n        emb = emb.transpose(-1, -2).flatten(len(feat_shape))\n        batch_expand = (B,) + (-1,) * (x.ndim - 1)\n\n        # FIXME support nD\n        if self.keep_spatial:\n            x = torch.cat(\n                [x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1\n            )\n        else:\n            x = torch.cat(\n                [x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1\n            )\n            x = x.reshape(B, feat_shape.numel(), -1)\n\n        return x\n\n\ndef rot(x):\n    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)\n\n\ndef apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):\n    return x * cos_emb + rot(x) * sin_emb\n\n\ndef apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):\n    if isinstance(x, torch.Tensor):\n        x = [x]\n    return [t * cos_emb + rot(t) * sin_emb for t in x]\n\n\ndef apply_rot_embed_split(x: torch.Tensor, emb):\n    split = emb.shape[-1] // 2\n    return x * emb[:, :split] + rot(x) * emb[:, split:]\n\n\ndef build_rotary_pos_embed(\n    feat_shape: List[int],\n    bands: Optional[torch.Tensor] = None,\n    dim: int = 64,\n    max_freq: float = 224,\n    linear_bands: bool = False,\n    dtype: torch.dtype = torch.float32,\n    device: Optional[torch.device] = None,\n):\n    \"\"\"\n    NOTE: shape arg should include spatial dim only\n    \"\"\"\n    feat_shape = torch.Size(feat_shape)\n\n    sin_emb, cos_emb = build_fourier_pos_embed(\n        feat_shape,\n        bands=bands,\n        num_bands=dim // 4,\n        max_res=max_freq,\n        linear_bands=linear_bands,\n        concat_out=False,\n        device=device,\n        dtype=dtype,\n    )\n    N = feat_shape.numel()\n    sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)\n    cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)\n    return sin_emb, cos_emb\n\n\nclass RotaryEmbedding(nn.Module):\n    \"\"\"Rotary position embedding\n\n    NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not\n    been well tested, and will likely change. It will be moved to its own file.\n\n    The following impl/resources were referenced for this impl:\n    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py\n    * https://blog.eleuther.ai/rotary-embeddings/\n    \"\"\"\n\n    def __init__(self, dim, max_res=224, linear_bands: bool = False):\n        super().__init__()\n        self.dim = dim\n        self.register_buffer(\n            \"bands\",\n            pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands),\n            persistent=False,\n        )\n\n    def get_embed(self, shape: List[int]):\n        return build_rotary_pos_embed(shape, self.bands)\n\n    def forward(self, x):\n        # assuming channel-first tensor where spatial dim are >= 2\n        sin_emb, cos_emb = self.get_embed(x.shape[2:])\n        return apply_rot_embed(x, sin_emb, cos_emb)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/selective_kernel.py",
    "content": "\"\"\" Selective Kernel Convolution/Attention\n\nPaper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nfrom torch import nn as nn\n\nfrom .conv_bn_act import ConvNormActAa\nfrom .helpers import make_divisible\nfrom .trace_utils import _assert\n\n\ndef _kernel_valid(k):\n    if isinstance(k, (list, tuple)):\n        # validate every branch kernel, not just the first one\n        for ki in k:\n            _kernel_valid(ki)\n        return\n    assert k >= 3 and k % 2\n\n\nclass SelectiveKernelAttn(nn.Module):\n    def __init__(\n        self,\n        channels,\n        num_paths=2,\n        attn_channels=32,\n        act_layer=nn.ReLU,\n        norm_layer=nn.BatchNorm2d,\n    ):\n        \"\"\"Selective Kernel Attention Module\n\n        Selective Kernel attention mechanism factored out into its own module.\n\n        \"\"\"\n        super(SelectiveKernelAttn, self).__init__()\n        self.num_paths = num_paths\n        self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)\n        self.bn = norm_layer(attn_channels)\n        self.act = act_layer(inplace=True)\n        self.fc_select = nn.Conv2d(\n            attn_channels, channels * num_paths, kernel_size=1, bias=False\n        )\n\n    def forward(self, x):\n        _assert(x.shape[1] == self.num_paths, \"expected input of shape (B, num_paths, C, H, W)\")\n        x = x.sum(1).mean((2, 3), keepdim=True)\n        x = self.fc_reduce(x)\n        x = self.bn(x)\n        x = self.act(x)\n        x = self.fc_select(x)\n        B, C, H, W = x.shape\n        x = x.view(B, self.num_paths, C // self.num_paths, H, W)\n        x = torch.softmax(x, dim=1)\n        return x\n\n\nclass SelectiveKernel(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        out_channels=None,\n        kernel_size=None,\n        stride=1,\n        dilation=1,\n        groups=1,\n        rd_ratio=1.0 / 16,\n        rd_channels=None,\n        rd_divisor=8,\n        keep_3x3=True,\n        split_input=True,\n       
 act_layer=nn.ReLU,\n        norm_layer=nn.BatchNorm2d,\n        aa_layer=None,\n        drop_layer=None,\n    ):\n        \"\"\"Selective Kernel Convolution Module\n\n        As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.\n\n        Largest change is the input split, which divides the input channels across each convolution path, this can\n        be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps\n        the parameter count from ballooning when the convolutions themselves don't have groups, but still provides\n        a noteworthy increase in performance over similar param count models without this attention layer. -Ross W\n\n        Args:\n            in_channels (int):  module input (feature) channel count\n            out_channels (int):  module output (feature) channel count\n            kernel_size (int, list): kernel size for each convolution branch\n            stride (int): stride for convolutions\n            dilation (int): dilation for module as a whole, impacts dilation of each branch\n            groups (int): number of groups for each branch\n            rd_ratio (int, float): reduction factor for attention features\n            keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations\n            split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,\n                can be viewed as grouping by path, output expands to module out_channels count\n            act_layer (nn.Module): activation layer to use\n            norm_layer (nn.Module): batchnorm/norm layer to use\n            aa_layer (nn.Module): anti-aliasing module\n            drop_layer (nn.Module): spatial drop module in convs (drop block, etc)\n        \"\"\"\n        super(SelectiveKernel, self).__init__()\n        out_channels = out_channels or in_channels\n        kernel_size = 
kernel_size or [\n            3,\n            5,\n        ]  # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation\n        _kernel_valid(kernel_size)\n        if not isinstance(kernel_size, list):\n            kernel_size = [kernel_size] * 2\n        if keep_3x3:\n            dilation = [dilation * (k - 1) // 2 for k in kernel_size]\n            kernel_size = [3] * len(kernel_size)\n        else:\n            dilation = [dilation] * len(kernel_size)\n        self.num_paths = len(kernel_size)\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.split_input = split_input\n        if self.split_input:\n            assert in_channels % self.num_paths == 0\n            in_channels = in_channels // self.num_paths\n        groups = min(out_channels, groups)\n\n        conv_kwargs = dict(\n            stride=stride,\n            groups=groups,\n            act_layer=act_layer,\n            norm_layer=norm_layer,\n            aa_layer=aa_layer,\n            drop_layer=drop_layer,\n        )\n        self.paths = nn.ModuleList(\n            [\n                ConvNormActAa(\n                    in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs\n                )\n                for k, d in zip(kernel_size, dilation)\n            ]\n        )\n\n        attn_channels = rd_channels or make_divisible(\n            out_channels * rd_ratio, divisor=rd_divisor\n        )\n        self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)\n\n    def forward(self, x):\n        if self.split_input:\n            x_split = torch.split(x, self.in_channels // self.num_paths, 1)\n            x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]\n        else:\n            x_paths = [op(x) for op in self.paths]\n        x = torch.stack(x_paths, dim=1)\n        x_attn = self.attn(x)\n        x = x * x_attn\n        x = torch.sum(x, dim=1)\n        return x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/separable_conv.py",
    "content": "\"\"\" Depthwise Separable Conv Modules\n\nBasic DWS convs. Other variations of DWS exist with batch norm or activations between the\nDW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception.\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom torch import nn as nn\n\nfrom .create_conv2d import create_conv2d\nfrom .create_norm_act import get_norm_act_layer\n\n\nclass SeparableConvNormAct(nn.Module):\n    \"\"\"Separable Conv w/ trailing Norm and Activation\"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size=3,\n        stride=1,\n        dilation=1,\n        padding=\"\",\n        bias=False,\n        channel_multiplier=1.0,\n        pw_kernel_size=1,\n        norm_layer=nn.BatchNorm2d,\n        act_layer=nn.ReLU,\n        apply_act=True,\n        drop_layer=None,\n    ):\n        super(SeparableConvNormAct, self).__init__()\n\n        self.conv_dw = create_conv2d(\n            in_channels,\n            int(in_channels * channel_multiplier),\n            kernel_size,\n            stride=stride,\n            dilation=dilation,\n            padding=padding,\n            depthwise=True,\n        )\n\n        self.conv_pw = create_conv2d(\n            int(in_channels * channel_multiplier),\n            out_channels,\n            pw_kernel_size,\n            padding=padding,\n            bias=bias,\n        )\n\n        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)\n        norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}\n        self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)\n\n    @property\n    def in_channels(self):\n        return self.conv_dw.in_channels\n\n    @property\n    def out_channels(self):\n        return self.conv_pw.out_channels\n\n    def forward(self, x):\n        x = self.conv_dw(x)\n        x = self.conv_pw(x)\n        x = self.bn(x)\n        return 
x\n\n\nSeparableConvBnAct = SeparableConvNormAct\n\n\nclass SeparableConv2d(nn.Module):\n    \"\"\"Separable Conv\"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size=3,\n        stride=1,\n        dilation=1,\n        padding=\"\",\n        bias=False,\n        channel_multiplier=1.0,\n        pw_kernel_size=1,\n    ):\n        super(SeparableConv2d, self).__init__()\n\n        self.conv_dw = create_conv2d(\n            in_channels,\n            int(in_channels * channel_multiplier),\n            kernel_size,\n            stride=stride,\n            dilation=dilation,\n            padding=padding,\n            depthwise=True,\n        )\n\n        self.conv_pw = create_conv2d(\n            int(in_channels * channel_multiplier),\n            out_channels,\n            pw_kernel_size,\n            padding=padding,\n            bias=bias,\n        )\n\n    @property\n    def in_channels(self):\n        return self.conv_dw.in_channels\n\n    @property\n    def out_channels(self):\n        return self.conv_pw.out_channels\n\n    def forward(self, x):\n        x = self.conv_dw(x)\n        x = self.conv_pw(x)\n        return x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/space_to_depth.py",
    "content": "import torch\nimport torch.nn as nn\n\n\nclass SpaceToDepth(nn.Module):\n    def __init__(self, block_size=4):\n        super().__init__()\n        assert block_size == 4\n        self.bs = block_size\n\n    def forward(self, x):\n        N, C, H, W = x.size()\n        x = x.view(\n            N, C, H // self.bs, self.bs, W // self.bs, self.bs\n        )  # (N, C, H//bs, bs, W//bs, bs)\n        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)\n        x = x.view(\n            N, C * (self.bs**2), H // self.bs, W // self.bs\n        )  # (N, C*bs^2, H//bs, W//bs)\n        return x\n\n\n@torch.jit.script\nclass SpaceToDepthJit(object):\n    def __call__(self, x: torch.Tensor):\n        # assuming hard-coded that block_size==4 for acceleration\n        N, C, H, W = x.size()\n        x = x.view(N, C, H // 4, 4, W // 4, 4)  # (N, C, H//bs, bs, W//bs, bs)\n        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)\n        x = x.view(N, C * 16, H // 4, W // 4)  # (N, C*bs^2, H//bs, W//bs)\n        return x\n\n\nclass SpaceToDepthModule(nn.Module):\n    def __init__(self, no_jit=False):\n        super().__init__()\n        if not no_jit:\n            self.op = SpaceToDepthJit()\n        else:\n            self.op = SpaceToDepth()\n\n    def forward(self, x):\n        return self.op(x)\n\n\nclass DepthToSpace(nn.Module):\n    def __init__(self, block_size):\n        super().__init__()\n        self.bs = block_size\n\n    def forward(self, x):\n        N, C, H, W = x.size()\n        x = x.view(\n            N, self.bs, self.bs, C // (self.bs**2), H, W\n        )  # (N, bs, bs, C//bs^2, H, W)\n        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)\n        x = x.view(\n            N, C // (self.bs**2), H * self.bs, W * self.bs\n        )  # (N, C//bs^2, H * bs, W * bs)\n        return x\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/split_attn.py",
    "content": "\"\"\" Split Attention Conv2d (for ResNeSt Models)\n\nPaper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955\n\nAdapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt\n\nModified for torchscript compat, performance, and consistency with timm by Ross Wightman\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom .helpers import make_divisible\n\n\nclass RadixSoftmax(nn.Module):\n    def __init__(self, radix, cardinality):\n        super(RadixSoftmax, self).__init__()\n        self.radix = radix\n        self.cardinality = cardinality\n\n    def forward(self, x):\n        batch = x.size(0)\n        if self.radix > 1:\n            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)\n            x = F.softmax(x, dim=1)\n            x = x.reshape(batch, -1)\n        else:\n            x = torch.sigmoid(x)\n        return x\n\n\nclass SplitAttn(nn.Module):\n    \"\"\"Split-Attention (aka Splat)\"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels=None,\n        kernel_size=3,\n        stride=1,\n        padding=None,\n        dilation=1,\n        groups=1,\n        bias=False,\n        radix=2,\n        rd_ratio=0.25,\n        rd_channels=None,\n        rd_divisor=8,\n        act_layer=nn.ReLU,\n        norm_layer=None,\n        drop_layer=None,\n        **kwargs\n    ):\n        super(SplitAttn, self).__init__()\n        out_channels = out_channels or in_channels\n        self.radix = radix\n        mid_chs = out_channels * radix\n        if rd_channels is None:\n            attn_chs = make_divisible(\n                in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor\n            )\n        else:\n            attn_chs = rd_channels * radix\n\n        padding = kernel_size // 2 if padding is None else padding\n        self.conv = nn.Conv2d(\n            in_channels,\n            mid_chs,\n            
kernel_size,\n            stride,\n            padding,\n            dilation,\n            groups=groups * radix,\n            bias=bias,\n            **kwargs\n        )\n        self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()\n        self.drop = drop_layer() if drop_layer is not None else nn.Identity()\n        self.act0 = act_layer(inplace=True)\n        self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)\n        self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()\n        self.act1 = act_layer(inplace=True)\n        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)\n        self.rsoftmax = RadixSoftmax(radix, groups)\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.bn0(x)\n        x = self.drop(x)\n        x = self.act0(x)\n\n        B, RC, H, W = x.shape\n        if self.radix > 1:\n            x = x.reshape((B, self.radix, RC // self.radix, H, W))\n            x_gap = x.sum(dim=1)\n        else:\n            x_gap = x\n        x_gap = x_gap.mean((2, 3), keepdim=True)\n        x_gap = self.fc1(x_gap)\n        x_gap = self.bn1(x_gap)\n        x_gap = self.act1(x_gap)\n        x_attn = self.fc2(x_gap)\n\n        x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)\n        if self.radix > 1:\n            out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(\n                dim=1\n            )\n        else:\n            out = x * x_attn\n        return out.contiguous()\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/split_batchnorm.py",
    "content": "\"\"\" Split BatchNorm\n\nA PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through\na separate BN layer. The first split is passed through the parent BN layers with weight/bias\nkeys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'\nnamespace.\n\nThis allows easily removing the auxiliary BN layers after training to efficiently\nachieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,\n'Disentangled Learning via An Auxiliary BN'\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nimport torch.nn as nn\n\n\nclass SplitBatchNorm2d(torch.nn.BatchNorm2d):\n    def __init__(\n        self,\n        num_features,\n        eps=1e-5,\n        momentum=0.1,\n        affine=True,\n        track_running_stats=True,\n        num_splits=2,\n    ):\n        super().__init__(num_features, eps, momentum, affine, track_running_stats)\n        assert (\n            num_splits > 1\n        ), \"Should have at least one aux BN layer (num_splits at least 2)\"\n        self.num_splits = num_splits\n        self.aux_bn = nn.ModuleList(\n            [\n                nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)\n                for _ in range(num_splits - 1)\n            ]\n        )\n\n    def forward(self, input: torch.Tensor):\n        if self.training:  # aux BN only relevant while training\n            split_size = input.shape[0] // self.num_splits\n            assert (\n                input.shape[0] == split_size * self.num_splits\n            ), \"batch size must be evenly divisible by num_splits\"\n            split_input = input.split(split_size)\n            x = [super().forward(split_input[0])]\n            for i, a in enumerate(self.aux_bn):\n                x.append(a(split_input[i + 1]))\n            return torch.cat(x, dim=0)\n        else:\n            return super().forward(input)\n\n\ndef 
convert_splitbn_model(module, num_splits=2):\n    \"\"\"\n    Recursively traverse module and its children to replace all instances of\n    ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`.\n    Args:\n        module (torch.nn.Module): input module\n        num_splits: number of separate batchnorm layers to split input across\n    Example::\n        >>> # model is an instance of torch.nn.Module\n        >>> model = timm.models.convert_splitbn_model(model, num_splits=2)\n    \"\"\"\n    mod = module\n    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):\n        return module\n    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):\n        mod = SplitBatchNorm2d(\n            module.num_features,\n            module.eps,\n            module.momentum,\n            module.affine,\n            module.track_running_stats,\n            num_splits=num_splits,\n        )\n        mod.running_mean = module.running_mean\n        mod.running_var = module.running_var\n        mod.num_batches_tracked = module.num_batches_tracked\n        if module.affine:\n            mod.weight.data = module.weight.data.clone().detach()\n            mod.bias.data = module.bias.data.clone().detach()\n        for aux in mod.aux_bn:\n            aux.running_mean = module.running_mean.clone()\n            aux.running_var = module.running_var.clone()\n            aux.num_batches_tracked = module.num_batches_tracked.clone()\n            if module.affine:\n                aux.weight.data = module.weight.data.clone().detach()\n                aux.bias.data = module.bias.data.clone().detach()\n    for name, child in module.named_children():\n        mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))\n    del module\n    return mod\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/squeeze_excite.py",
    "content": "\"\"\" Squeeze-and-Excitation Channel Attention\n\nAn SE implementation originally based on PyTorch SE-Net impl.\nHas since evolved with additional functionality / configuration.\n\nPaper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507\n\nAlso included is Effective Squeeze-Excitation (ESE).\nPaper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667\n\nHacked together by / Copyright 2021 Ross Wightman\n\"\"\"\n\nfrom torch import nn as nn\n\nfrom .create_act import create_act_layer\nfrom .helpers import make_divisible\n\n\nclass SEModule(nn.Module):\n    \"\"\"SE Module as defined in original SE-Nets with a few additions\n    Additions include:\n        * divisor can be specified to keep channels % div == 0 (default: 8)\n        * reduction channels can be specified directly by arg (if rd_channels is set)\n        * reduction channels can be specified by float rd_ratio (default: 1/16)\n        * global max pooling can be added to the squeeze aggregation\n        * customizable activation, normalization, and gate layer\n    \"\"\"\n\n    def __init__(\n        self,\n        channels,\n        rd_ratio=1.0 / 16,\n        rd_channels=None,\n        rd_divisor=8,\n        add_maxpool=False,\n        bias=True,\n        act_layer=nn.ReLU,\n        norm_layer=None,\n        gate_layer=\"sigmoid\",\n    ):\n        super(SEModule, self).__init__()\n        self.add_maxpool = add_maxpool\n        if not rd_channels:\n            rd_channels = make_divisible(\n                channels * rd_ratio, rd_divisor, round_limit=0.0\n            )\n        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias)\n        self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()\n        self.act = create_act_layer(act_layer, inplace=True)\n        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias)\n        self.gate = create_act_layer(gate_layer)\n\n    def 
forward(self, x):\n        x_se = x.mean((2, 3), keepdim=True)\n        if self.add_maxpool:\n            # experimental codepath, may remove or change\n            x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)\n        x_se = self.fc1(x_se)\n        x_se = self.act(self.bn(x_se))\n        x_se = self.fc2(x_se)\n        return x * self.gate(x_se)\n\n\nSqueezeExcite = SEModule  # alias\n\n\nclass EffectiveSEModule(nn.Module):\n    \"\"\"Effective Squeeze-Excitation (ESE)\n    From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667\n    \"\"\"\n\n    def __init__(self, channels, add_maxpool=False, gate_layer=\"hard_sigmoid\", **_):\n        super(EffectiveSEModule, self).__init__()\n        self.add_maxpool = add_maxpool\n        self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)\n        self.gate = create_act_layer(gate_layer)\n\n    def forward(self, x):\n        x_se = x.mean((2, 3), keepdim=True)\n        if self.add_maxpool:\n            # experimental codepath, may remove or change\n            x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)\n        x_se = self.fc(x_se)\n        return x * self.gate(x_se)\n\n\nEffectiveSqueezeExcite = EffectiveSEModule  # alias\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/std_conv.py",
    "content": "\"\"\" Convolution with Weight Standardization (StdConv and ScaledStdConv)\n\nStdConv:\n@article{weightstandardization,\n  author    = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},\n  title     = {Weight Standardization},\n  journal   = {arXiv preprint arXiv:1903.10520},\n  year      = {2019},\n}\nCode: https://github.com/joe-siyuan-qiao/WeightStandardization\n\nScaledStdConv:\nPaper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`\n    - https://arxiv.org/abs/2101.08692\nOfficial Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets\n\nHacked together by / copyright Ross Wightman, 2021.\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .padding import get_padding, get_padding_value, pad_same\n\n\nclass StdConv2d(nn.Conv2d):\n    \"\"\"Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.\n\n    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -\n        https://arxiv.org/abs/1903.10520v2\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channel,\n        out_channels,\n        kernel_size,\n        stride=1,\n        padding=None,\n        dilation=1,\n        groups=1,\n        bias=False,\n        eps=1e-6,\n    ):\n        if padding is None:\n            padding = get_padding(kernel_size, stride, dilation)\n        super().__init__(\n            in_channel,\n            out_channels,\n            kernel_size,\n            stride=stride,\n            padding=padding,\n            dilation=dilation,\n            groups=groups,\n            bias=bias,\n        )\n        self.eps = eps\n\n    def forward(self, x):\n        weight = F.batch_norm(\n            self.weight.reshape(1, self.out_channels, -1),\n            None,\n            None,\n            training=True,\n            momentum=0.0,\n            eps=self.eps,\n        
).reshape_as(self.weight)\n        x = F.conv2d(\n            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups\n        )\n        return x\n\n\nclass StdConv2dSame(nn.Conv2d):\n    \"\"\"Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.\n\n    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -\n        https://arxiv.org/abs/1903.10520v2\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channel,\n        out_channels,\n        kernel_size,\n        stride=1,\n        padding=\"SAME\",\n        dilation=1,\n        groups=1,\n        bias=False,\n        eps=1e-6,\n    ):\n        padding, is_dynamic = get_padding_value(\n            padding, kernel_size, stride=stride, dilation=dilation\n        )\n        super().__init__(\n            in_channel,\n            out_channels,\n            kernel_size,\n            stride=stride,\n            padding=padding,\n            dilation=dilation,\n            groups=groups,\n            bias=bias,\n        )\n        self.same_pad = is_dynamic\n        self.eps = eps\n\n    def forward(self, x):\n        if self.same_pad:\n            x = pad_same(x, self.kernel_size, self.stride, self.dilation)\n        weight = F.batch_norm(\n            self.weight.reshape(1, self.out_channels, -1),\n            None,\n            None,\n            training=True,\n            momentum=0.0,\n            eps=self.eps,\n        ).reshape_as(self.weight)\n        x = F.conv2d(\n            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups\n        )\n        return x\n\n\nclass ScaledStdConv2d(nn.Conv2d):\n    \"\"\"Conv2d layer with Scaled Weight Standardization.\n\n    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -\n        https://arxiv.org/abs/2101.08692\n\n    NOTE: the operations used in this impl differ slightly from the DeepMind Haiku 
impl. The impact is minor.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        stride=1,\n        padding=None,\n        dilation=1,\n        groups=1,\n        bias=True,\n        gamma=1.0,\n        eps=1e-6,\n        gain_init=1.0,\n    ):\n        if padding is None:\n            padding = get_padding(kernel_size, stride, dilation)\n        super().__init__(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride=stride,\n            padding=padding,\n            dilation=dilation,\n            groups=groups,\n            bias=bias,\n        )\n        self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))\n        self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)\n        self.eps = eps\n\n    def forward(self, x):\n        weight = F.batch_norm(\n            self.weight.reshape(1, self.out_channels, -1),\n            None,\n            None,\n            weight=(self.gain * self.scale).view(-1),\n            training=True,\n            momentum=0.0,\n            eps=self.eps,\n        ).reshape_as(self.weight)\n        return F.conv2d(\n            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups\n        )\n\n\nclass ScaledStdConv2dSame(nn.Conv2d):\n    \"\"\"Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support\n\n    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -\n        https://arxiv.org/abs/2101.08692\n\n    NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. 
The impact is minor.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        stride=1,\n        padding=\"SAME\",\n        dilation=1,\n        groups=1,\n        bias=True,\n        gamma=1.0,\n        eps=1e-6,\n        gain_init=1.0,\n    ):\n        padding, is_dynamic = get_padding_value(\n            padding, kernel_size, stride=stride, dilation=dilation\n        )\n        super().__init__(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride=stride,\n            padding=padding,\n            dilation=dilation,\n            groups=groups,\n            bias=bias,\n        )\n        self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))\n        self.scale = gamma * self.weight[0].numel() ** -0.5\n        self.same_pad = is_dynamic\n        self.eps = eps\n\n    def forward(self, x):\n        if self.same_pad:\n            x = pad_same(x, self.kernel_size, self.stride, self.dilation)\n        weight = F.batch_norm(\n            self.weight.reshape(1, self.out_channels, -1),\n            None,\n            None,\n            weight=(self.gain * self.scale).view(-1),\n            training=True,\n            momentum=0.0,\n            eps=self.eps,\n        ).reshape_as(self.weight)\n        return F.conv2d(\n            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups\n        )\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/test_time_pool.py",
    "content": "\"\"\" Test Time Pooling (Average-Max Pool)\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport logging\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom .adaptive_avgmax_pool import adaptive_avgmax_pool2d\n\n\n_logger = logging.getLogger(__name__)\n\n\nclass TestTimePoolHead(nn.Module):\n    def __init__(self, base, original_pool=7):\n        super(TestTimePoolHead, self).__init__()\n        self.base = base\n        self.original_pool = original_pool\n        base_fc = self.base.get_classifier()\n        if isinstance(base_fc, nn.Conv2d):\n            self.fc = base_fc\n        else:\n            self.fc = nn.Conv2d(\n                self.base.num_features, self.base.num_classes, kernel_size=1, bias=True\n            )\n            self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))\n            self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))\n        self.base.reset_classifier(0)  # delete original fc layer\n\n    def forward(self, x):\n        x = self.base.forward_features(x)\n        x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)\n        x = self.fc(x)\n        x = adaptive_avgmax_pool2d(x, 1)\n        return x.view(x.size(0), -1)\n\n\ndef apply_test_time_pool(model, config, use_test_size=False):\n    test_time_pool = False\n    if not hasattr(model, \"default_cfg\") or not model.default_cfg:\n        return model, False\n    if use_test_size and \"test_input_size\" in model.default_cfg:\n        df_input_size = model.default_cfg[\"test_input_size\"]\n    else:\n        df_input_size = model.default_cfg[\"input_size\"]\n    if (\n        config[\"input_size\"][-1] > df_input_size[-1]\n        and config[\"input_size\"][-2] > df_input_size[-2]\n    ):\n        _logger.info(\n            \"Target input size %s > pretrained default %s, using test time pooling\"\n            % (str(config[\"input_size\"][-2:]), str(df_input_size[-2:]))\n        )\n        
model = TestTimePoolHead(model, original_pool=model.default_cfg[\"pool_size\"])\n        test_time_pool = True\n    return model, test_time_pool\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/trace_utils.py",
    "content": "try:\n    from torch import _assert\nexcept ImportError:\n\n    def _assert(condition: bool, message: str):\n        assert condition, message\n\n\ndef _float_to_int(x: float) -> int:\n    \"\"\"\n    Symbolic tracing helper to substitute for inbuilt `int`.\n    Hint: Inbuilt `int` can't accept an argument of type `Proxy`\n    \"\"\"\n    return int(x)\n"
  },
  {
    "path": "RVT/models/layers/maxvit/layers/weight_init.py",
    "content": "import torch\nimport math\nimport warnings\n\nfrom torch.nn.init import _calculate_fan_in_and_fan_out\n\n\ndef _trunc_normal_(tensor, mean, std, a, b):\n    # Cut & paste from PyTorch official master until it's in a few official releases - RW\n    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n    def norm_cdf(x):\n        # Computes standard normal cumulative distribution function\n        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0\n\n    if (mean < a - 2 * std) or (mean > b + 2 * std):\n        warnings.warn(\n            \"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \"\n            \"The distribution of values may be incorrect.\",\n            stacklevel=2,\n        )\n\n    # Values are generated by using a truncated uniform distribution and\n    # then using the inverse CDF for the normal distribution.\n    # Get upper and lower cdf values\n    l = norm_cdf((a - mean) / std)\n    u = norm_cdf((b - mean) / std)\n\n    # Uniformly fill tensor with values from [l, u], then translate to\n    # [2l-1, 2u-1].\n    tensor.uniform_(2 * l - 1, 2 * u - 1)\n\n    # Use inverse cdf transform for normal distribution to get truncated\n    # standard normal\n    tensor.erfinv_()\n\n    # Transform to proper mean, std\n    tensor.mul_(std * math.sqrt(2.0))\n    tensor.add_(mean)\n\n    # Clamp to ensure it's in the proper range\n    tensor.clamp_(min=a, max=b)\n    return tensor\n\n\ndef trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):\n    # type: (Tensor, float, float, float, float) -> Tensor\n    r\"\"\"Fills the input Tensor with values drawn from a truncated\n    normal distribution. The values are effectively drawn from the\n    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n    with values outside :math:`[a, b]` redrawn until they are within\n    the bounds. 
The method used for generating the random values works\n    best when :math:`a \\leq \\text{mean} \\leq b`.\n\n    NOTE: like the PyTorch trunc_normal_, this impl applies the bounds [a, b]\n    while sampling the normal with mean/std already applied, so the a, b args\n    should be adjusted to match the range of the mean, std args.\n\n    Args:\n        tensor: an n-dimensional `torch.Tensor`\n        mean: the mean of the normal distribution\n        std: the standard deviation of the normal distribution\n        a: the minimum cutoff value\n        b: the maximum cutoff value\n    Examples:\n        >>> w = torch.empty(3, 5)\n        >>> trunc_normal_(w)\n    \"\"\"\n    with torch.no_grad():\n        return _trunc_normal_(tensor, mean, std, a, b)\n\n\ndef trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):\n    # type: (Tensor, float, float, float, float) -> Tensor\n    r\"\"\"Fills the input Tensor with values drawn from a truncated\n    normal distribution. The values are effectively drawn from the\n    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n    with values outside :math:`[a, b]` redrawn until they are within\n    the bounds. 
The method used for generating the random values works\n    best when :math:`a \\leq \\text{mean} \\leq b`.\n\n    NOTE: this 'tf' variant behaves closer to the TensorFlow / JAX impl, where the\n    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0\n    and the result is subsequently scaled and shifted by the mean and std args.\n\n    Args:\n        tensor: an n-dimensional `torch.Tensor`\n        mean: the mean of the normal distribution\n        std: the standard deviation of the normal distribution\n        a: the minimum cutoff value\n        b: the maximum cutoff value\n    Examples:\n        >>> w = torch.empty(3, 5)\n        >>> trunc_normal_tf_(w)\n    \"\"\"\n    with torch.no_grad():\n        _trunc_normal_(tensor, 0, 1.0, a, b)\n        tensor.mul_(std).add_(mean)\n    return tensor\n\n\ndef variance_scaling_(tensor, scale=1.0, mode=\"fan_in\", distribution=\"normal\"):\n    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)\n    if mode == \"fan_in\":\n        denom = fan_in\n    elif mode == \"fan_out\":\n        denom = fan_out\n    elif mode == \"fan_avg\":\n        denom = (fan_in + fan_out) / 2\n    else:\n        raise ValueError(f\"invalid mode {mode}\")\n\n    variance = scale / denom\n\n    if distribution == \"truncated_normal\":\n        # constant is stddev of standard normal truncated to (-2, 2)\n        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)\n    elif distribution == \"normal\":\n        with torch.no_grad():\n            tensor.normal_(std=math.sqrt(variance))\n    elif distribution == \"uniform\":\n        bound = math.sqrt(3 * variance)\n        with torch.no_grad():\n            tensor.uniform_(-bound, bound)\n    else:\n        raise ValueError(f\"invalid distribution {distribution}\")\n\n\ndef lecun_normal_(tensor):\n    variance_scaling_(tensor, mode=\"fan_in\", distribution=\"truncated_normal\")\n"
  },
  {
    "path": "RVT/models/layers/maxvit/maxvit.py",
    "content": "\"\"\"\nPart of this code stems from rwightman's MaxVit implementation:\nhttps://github.com/huggingface/pytorch-image-models/blob/1885bdc4318cc3be459981ea1a26cd862220864d/timm/models/maxxvit.py\nthat is:\n- LayerScale\n- PartitionAttentionCl\n- window*\n- grid*\n- SelfAttentionCl\n\"\"\"\n\nfrom enum import Enum, auto\nfrom functools import partial\nfrom typing import Optional, Union, Tuple, List, Type\n\nimport math\nimport torch\nfrom omegaconf import DictConfig\nfrom torch import nn\n\nfrom .layers import DropPath, LayerNorm\nfrom .layers import get_act_layer, get_norm_layer\nfrom .layers import to_2tuple, _assert\n\n\nclass PartitionType(Enum):\n    WINDOW = auto()\n    GRID = auto()\n\n\ndef nChw_2_nhwC(x: torch.Tensor):\n    \"\"\"N C H W -> N H W C\"\"\"\n    assert x.ndim == 4\n    return x.permute(0, 2, 3, 1)\n\n\ndef nhwC_2_nChw(x: torch.Tensor):\n    \"\"\"N H W C -> N C H W\"\"\"\n    assert x.ndim == 4\n    return x.permute(0, 3, 1, 2)\n\n\nclass LayerScale(nn.Module):\n    def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):\n        super().__init__()\n        self.inplace = inplace\n        self.gamma = nn.Parameter(init_values * torch.ones(dim))\n\n    def forward(self, x):\n        gamma = self.gamma\n        return x.mul_(gamma) if self.inplace else x * gamma\n\n\nclass GLU(nn.Module):\n    def __init__(\n        self,\n        dim_in: int,\n        dim_out: int,\n        channel_last: bool,\n        act_layer: Type[nn.Module],\n        bias: bool = True,\n    ):\n        super().__init__()\n        # Different activation functions / versions of the gated linear unit:\n        # - ReGLU:  Relu\n        # - SwiGLU: Swish/SiLU\n        # - GeGLU:  GELU\n        # - GLU:    Sigmoid\n        # seem to be the most promising ones.\n        # Extensive quantitative eval in table 1: https://arxiv.org/abs/2102.11972\n        # Section 2 for explanation and implementation details: 
https://arxiv.org/abs/2002.05202\n        # NOTE: Pytorch has a native GLU implementation: https://pytorch.org/docs/stable/generated/torch.nn.GLU.html?highlight=glu#torch.nn.GLU\n        proj_out_dim = dim_out * 2\n        self.proj = (\n            nn.Linear(dim_in, proj_out_dim, bias=bias)\n            if channel_last\n            else nn.Conv2d(dim_in, proj_out_dim, kernel_size=1, stride=1, bias=bias)\n        )\n        self.channel_dim = -1 if channel_last else 1\n\n        self.act_layer = act_layer()\n\n    def forward(self, x: torch.Tensor):\n        x, gate = torch.tensor_split(self.proj(x), 2, dim=self.channel_dim)\n        return x * self.act_layer(gate)\n\n\nclass MLP(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        channel_last: bool,\n        expansion_ratio: int,\n        act_layer: Type[nn.Module],\n        gated: bool = True,\n        bias: bool = True,\n        drop_prob: float = 0.0,\n    ):\n        super().__init__()\n        inner_dim = int(dim * expansion_ratio)\n        if gated:\n            # To keep the number of parameters (approx) constant regardless of whether glu == True\n            # Section 2 for explanation: https://arxiv.org/abs/2002.05202\n            # inner_dim = round(inner_dim * 2 / 3)\n            # inner_dim = math.ceil(inner_dim * 2 / 3 / 32) * 32 # multiple of 32\n            # inner_dim = round(inner_dim * 2 / 3 / 32) * 32 # multiple of 32\n            inner_dim = math.floor(inner_dim * 2 / 3 / 32) * 32  # multiple of 32\n            proj_in = GLU(\n                dim_in=dim,\n                dim_out=inner_dim,\n                channel_last=channel_last,\n                act_layer=act_layer,\n                bias=bias,\n            )\n        else:\n            proj_in = nn.Sequential(\n                (\n                    nn.Linear(in_features=dim, out_features=inner_dim, bias=bias)\n                    if channel_last\n                    else nn.Conv2d(\n                        
in_channels=dim,\n                        out_channels=inner_dim,\n                        kernel_size=1,\n                        stride=1,\n                        bias=bias,\n                    )\n                ),\n                act_layer(),\n            )\n        self.net = nn.Sequential(\n            proj_in,\n            nn.Dropout(p=drop_prob),\n            (\n                nn.Linear(in_features=inner_dim, out_features=dim, bias=bias)\n                if channel_last\n                else nn.Conv2d(\n                    in_channels=inner_dim,\n                    out_channels=dim,\n                    kernel_size=1,\n                    stride=1,\n                    bias=bias,\n                )\n            ),\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\nclass DownsampleBase(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    @staticmethod\n    def output_is_normed():\n        raise NotImplementedError\n\n\ndef get_downsample_layer_Cf2Cl(\n    dim_in: int, dim_out: int, downsample_factor: int, downsample_cfg: DictConfig\n) -> DownsampleBase:\n    type = downsample_cfg.type\n    if type == \"patch\":\n        return ConvDownsampling_Cf2Cl(\n            dim_in=dim_in,\n            dim_out=dim_out,\n            downsample_factor=downsample_factor,\n            downsample_cfg=downsample_cfg,\n        )\n    raise NotImplementedError\n\n\nclass ConvDownsampling_Cf2Cl(DownsampleBase):\n    \"\"\"Downsample with input in NCHW [channel-first] format.\n    Output in NHWC [channel-last] format.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim_in: int,\n        dim_out: int,\n        downsample_factor: int,\n        downsample_cfg: DictConfig,\n    ):\n        super().__init__()\n        assert isinstance(dim_out, int)\n        assert isinstance(dim_in, int)\n        assert downsample_factor in (2, 4, 8)\n\n        norm_affine = downsample_cfg.get(\"norm_affine\", True)\n        overlap = 
downsample_cfg.get(\"overlap\", True)\n\n        if overlap:\n            kernel_size = (downsample_factor - 1) * 2 + 1\n            padding = kernel_size // 2\n        else:\n            kernel_size = downsample_factor\n            padding = 0\n        self.conv = nn.Conv2d(\n            in_channels=dim_in,\n            out_channels=dim_out,\n            kernel_size=kernel_size,\n            padding=padding,\n            stride=downsample_factor,\n            bias=False,\n        )\n        self.norm = LayerNorm(num_channels=dim_out, eps=1e-5, affine=norm_affine)\n\n    def forward(self, x: torch.Tensor):\n        x = self.conv(x)\n        x = nChw_2_nhwC(x)\n        x = self.norm(x)\n        return x\n\n    @staticmethod\n    def output_is_normed():\n        return True\n\n\nclass PartitionAttentionCl(nn.Module):\n    \"\"\"Grid or Block partition + Attn + FFN.\n    NxC 'channels last' tensor layout.\n\n    According to RW, NHWC attention is a few percent faster on GPUs (but slower on TPUs)\n    https://github.com/rwightman/pytorch-image-models/blob/4f72bae43be26d9764a08d83b88f8bd4ec3dbe43/timm/models/maxxvit.py#L1258\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        partition_type: PartitionType,\n        attention_cfg: DictConfig,\n        skip_first_norm: bool = False,\n    ):\n        super().__init__()\n        norm_eps = attention_cfg.get(\"norm_eps\", 1e-5)\n        partition_size = attention_cfg.partition_size\n        use_torch_mha = attention_cfg.use_torch_mha\n        dim_head = attention_cfg.get(\"dim_head\", 32)\n        attention_bias = attention_cfg.get(\"attention_bias\", True)\n        mlp_act_string = attention_cfg.mlp_activation\n        mlp_gated = attention_cfg.mlp_gated\n        mlp_bias = attention_cfg.get(\"mlp_bias\", True)\n        mlp_expand_ratio = attention_cfg.get(\"mlp_ratio\", 4)\n\n        drop_path = attention_cfg.get(\"drop_path\", 0.0)\n        drop_mlp = attention_cfg.get(\"drop_mlp\", 0.0)\n        
ls_init_value = attention_cfg.get(\"ls_init_value\", 1e-5)\n\n        assert isinstance(use_torch_mha, bool)\n        assert isinstance(mlp_gated, bool)\n        assert_activation_string(activation_string=mlp_act_string)\n        mlp_act_layer = get_act_layer(mlp_act_string)\n\n        self_attn_module = TorchMHSAWrapperCl if use_torch_mha else SelfAttentionCl\n\n        if isinstance(partition_size, int):\n            partition_size = to_2tuple(partition_size)\n        else:\n            partition_size = tuple(partition_size)\n            assert len(partition_size) == 2\n        self.partition_size = partition_size\n\n        norm_layer = partial(\n            get_norm_layer(\"layernorm\"), eps=norm_eps\n        )  # NOTE this block is channels-last\n\n        assert isinstance(partition_type, PartitionType)\n        self.partition_window = partition_type == PartitionType.WINDOW\n\n        self.norm1 = nn.Identity() if skip_first_norm else norm_layer(dim)\n        self.self_attn = self_attn_module(dim, dim_head=dim_head, bias=attention_bias)\n        self.ls1 = (\n            LayerScale(dim=dim, init_values=ls_init_value)\n            if ls_init_value > 0\n            else nn.Identity()\n        )\n        self.drop_path1 = (\n            DropPath(drop_prob=drop_path) if drop_path > 0 else nn.Identity()\n        )\n\n        self.norm2 = norm_layer(dim)\n        self.mlp = MLP(\n            dim=dim,\n            channel_last=True,\n            expansion_ratio=mlp_expand_ratio,\n            act_layer=mlp_act_layer,\n            gated=mlp_gated,\n            bias=mlp_bias,\n            drop_prob=drop_mlp,\n        )\n        self.ls2 = (\n            LayerScale(dim=dim, init_values=ls_init_value)\n            if ls_init_value > 0\n            else nn.Identity()\n        )\n        self.drop_path2 = (\n            DropPath(drop_prob=drop_path) if drop_path > 0 else nn.Identity()\n        )\n\n    def _partition_attn(self, x):\n        img_size = x.shape[1:3]\n        
if self.partition_window:\n            partitioned = window_partition(x, self.partition_size)\n        else:\n            partitioned = grid_partition(x, self.partition_size)\n\n        partitioned = self.self_attn(partitioned)\n\n        if self.partition_window:\n            x = window_reverse(\n                partitioned, self.partition_size, (img_size[0], img_size[1])\n            )\n        else:\n            x = grid_reverse(\n                partitioned, self.partition_size, (img_size[0], img_size[1])\n            )\n        return x\n\n    def forward(self, x):\n        x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))\n        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))\n        return x\n\n\ndef window_partition(x, window_size: Tuple[int, int]):\n    B, H, W, C = x.shape\n    _assert(\n        H % window_size[0] == 0,\n        f\"height ({H}) must be divisible by window ({window_size[0]})\",\n    )\n    _assert(\n        W % window_size[1] == 0,\n        f\"width ({W}) must be divisible by window ({window_size[1]})\",\n    )\n    x = x.view(\n        B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C\n    )\n    windows = (\n        x.permute(0, 1, 3, 2, 4, 5)\n        .contiguous()\n        .view(-1, window_size[0], window_size[1], C)\n    )\n    return windows\n\n\ndef window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):\n    H, W = img_size\n    C = windows.shape[-1]\n    x = windows.view(\n        -1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C\n    )\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)\n    return x\n\n\ndef grid_partition(x, grid_size: Tuple[int, int]):\n    B, H, W, C = x.shape\n    _assert(\n        H % grid_size[0] == 0, f\"height {H} must be divisible by grid {grid_size[0]}\"\n    )\n    _assert(\n        W % grid_size[1] == 0, f\"width {W} must be divisible by grid {grid_size[1]}\"\n    
)\n    x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C)\n    windows = (\n        x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C)\n    )\n    return windows\n\n\ndef grid_reverse(windows, grid_size: Tuple[int, int], img_size: Tuple[int, int]):\n    H, W = img_size\n    C = windows.shape[-1]\n    x = windows.view(\n        -1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C\n    )\n    x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C)\n    return x\n\n\nclass TorchMHSAWrapperCl(nn.Module):\n    \"\"\"Channels-last multi-head self-attention (B, ..., C)\"\"\"\n\n    def __init__(self, dim: int, dim_head: int = 32, bias: bool = True):\n        super().__init__()\n        assert dim % dim_head == 0\n        num_heads = dim // dim_head\n        self.mha = nn.MultiheadAttention(\n            embed_dim=dim, num_heads=num_heads, bias=bias, batch_first=True\n        )\n\n    def forward(self, x: torch.Tensor):\n        restore_shape = x.shape\n        B, C = restore_shape[0], restore_shape[-1]\n        x = x.view(B, -1, C)\n        attn_output, attn_output_weights = self.mha(query=x, key=x, value=x)\n        attn_output = attn_output.reshape(restore_shape)\n        return attn_output\n\n\nclass SelfAttentionCl(nn.Module):\n    \"\"\"Channels-last multi-head self-attention (B, ..., C)\"\"\"\n\n    def __init__(self, dim: int, dim_head: int = 32, bias: bool = True):\n        super().__init__()\n        self.num_heads = dim // dim_head\n        self.dim_head = dim_head\n        self.scale = dim_head**-0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=bias)\n        self.proj = nn.Linear(dim, dim, bias=bias)\n\n    def forward(self, x: torch.Tensor):\n        B = x.shape[0]\n        restore_shape = x.shape[:-1]\n\n        q, k, v = (\n            self.qkv(x)\n            .view(B, -1, self.num_heads, self.dim_head * 3)\n            .transpose(1, 2)\n            .chunk(3, 
dim=3)\n        )\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n\n        x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,))\n        x = self.proj(x)\n        return x\n\n\ndef assert_activation_string(\n    activation_string: Optional[Union[str, Tuple[str, ...], List[str]]]\n) -> None:\n    # Serves as a hacky documentation and sanity check.\n    # List of possible activation layer strings that are reasonable:\n    # https://github.com/rwightman/pytorch-image-models/blob/a520da9b495422bc773fb5dfe10819acb8bd7c5c/timm/models/layers/create_act.py#L62\n    if activation_string is None:\n        return\n    if isinstance(activation_string, str):\n        assert activation_string in (\n            \"silu\",\n            \"swish\",\n            \"mish\",\n            \"relu\",\n            \"relu6\",\n            \"leaky_relu\",\n            \"elu\",\n            \"prelu\",\n            \"celu\",\n            \"selu\",\n            \"gelu\",\n            \"sigmoid\",\n            \"tanh\",\n            \"hard_sigmoid\",\n            \"hard_swish\",\n            \"hard_mish\",\n        )\n    elif isinstance(activation_string, (tuple, list)):\n        for entry in activation_string:\n            assert_activation_string(activation_string=entry)\n    else:\n        raise NotImplementedError\n\n\ndef assert_norm2d_layer_string(\n    norm_layer: Optional[Union[str, Tuple[str, ...], List[str]]]\n) -> None:\n    # Serves as a hacky documentation and sanity check.\n    # List of possible norm layer strings that are reasonable:\n    # https://github.com/rwightman/pytorch-image-models/blob/4f72bae43be26d9764a08d83b88f8bd4ec3dbe43/timm/models/layers/create_norm.py#L14\n    if norm_layer is None:\n        return\n    if isinstance(norm_layer, str):\n        assert norm_layer in (\"batchnorm\", \"batchnorm2d\", \"groupnorm\", \"layernorm2d\")\n    elif isinstance(norm_layer, (tuple, list)):\n        for entry in 
norm_layer:\n            assert_norm2d_layer_string(norm_layer=entry)\n    else:\n        raise NotImplementedError\n"
  },
  {
    "path": "RVT/models/layers/rnn.py",
    "content": "from typing import Optional, Tuple\n\nimport torch as th\nimport torch.nn as nn\n\n\nclass DWSConvLSTM2d(nn.Module):\n    \"\"\"LSTM with (depthwise-separable) Conv option in NCHW [channel-first] format.\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        dws_conv: bool = True,\n        dws_conv_only_hidden: bool = True,\n        dws_conv_kernel_size: int = 3,\n        cell_update_dropout: float = 0.0,\n    ):\n        super().__init__()\n        assert isinstance(dws_conv, bool)\n        assert isinstance(dws_conv_only_hidden, bool)\n        self.dim = dim\n\n        xh_dim = dim * 2\n        gates_dim = dim * 4\n        conv3x3_dws_dim = dim if dws_conv_only_hidden else xh_dim\n        self.conv3x3_dws = (\n            nn.Conv2d(\n                in_channels=conv3x3_dws_dim,\n                out_channels=conv3x3_dws_dim,\n                kernel_size=dws_conv_kernel_size,\n                padding=dws_conv_kernel_size // 2,\n                groups=conv3x3_dws_dim,\n            )\n            if dws_conv\n            else nn.Identity()\n        )\n        self.conv1x1 = nn.Conv2d(\n            in_channels=xh_dim, out_channels=gates_dim, kernel_size=1\n        )\n        self.conv_only_hidden = dws_conv_only_hidden\n        self.cell_update_dropout = nn.Dropout(p=cell_update_dropout)\n\n    def forward(\n        self,\n        x: th.Tensor,\n        h_and_c_previous: Optional[Tuple[th.Tensor, th.Tensor]] = None,\n    ) -> Tuple[th.Tensor, th.Tensor]:\n        \"\"\"\n        :param x: (N C H W)\n        :param h_and_c_previous: ((N C H W), (N C H W))\n        :return: ((N C H W), (N C H W))\n        \"\"\"\n        if h_and_c_previous is None:\n            # generate zero states\n            hidden = th.zeros_like(x)\n            cell = th.zeros_like(x)\n            h_and_c_previous = (hidden, cell)\n        h_tm1, c_tm1 = h_and_c_previous\n\n        if self.conv_only_hidden:\n            h_tm1 = self.conv3x3_dws(h_tm1)\n        xh = 
th.cat((x, h_tm1), dim=1)\n        if not self.conv_only_hidden:\n            xh = self.conv3x3_dws(xh)\n        mix = self.conv1x1(xh)\n\n        gates, cell_input = th.tensor_split(mix, [self.dim * 3], dim=1)\n        assert gates.shape[1] == cell_input.shape[1] * 3\n\n        gates = th.sigmoid(gates)\n        forget_gate, input_gate, output_gate = th.tensor_split(gates, 3, dim=1)\n        assert forget_gate.shape == input_gate.shape == output_gate.shape\n\n        cell_input = self.cell_update_dropout(th.tanh(cell_input))\n\n        c_t = forget_gate * c_tm1 + input_gate * cell_input\n        h_t = output_gate * th.tanh(c_t)\n\n        return h_t, c_t\n"
  },
  {
    "path": "RVT/models/layers/s5/__init__.py",
    "content": "from .s5_model import *\n"
  },
  {
    "path": "RVT/models/layers/s5/jax_func.py",
    "content": "import torch\nimport numpy as np\nfrom torch.utils._pytree import tree_flatten, tree_unflatten\nfrom typing import (\n    overload,\n    Callable,\n    Iterable,\n    List,\n    TypeVar,\n    Any,\n    Literal,\n    Sequence,\n    Optional,\n)\nfrom functools import partial\nimport math\n\n\"\"\"\nJax-Pytorch ported functions, mostly interfaces are kept the same but unsupported features are removed:\n* Jax-Keyed RNGs are sampled from global RNG\n* Canonical/Named shapes/dtypes/etc are now regular shapes,dtypes\n\"\"\"\n\nT = TypeVar(\"T\")\nT1 = TypeVar(\"T1\")\nT2 = TypeVar(\"T2\")\nT3 = TypeVar(\"T3\")\n\n\n@overload\ndef safe_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> List[T]: ...\n\n\n@overload\ndef safe_map(\n    f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2]\n) -> List[T]: ...\n\n\n@overload\ndef safe_map(\n    f: Callable[[T1, T2, T3], T],\n    __arg1: Iterable[T1],\n    __arg2: Iterable[T2],\n    __arg3: Iterable[T3],\n) -> List[T]: ...\n\n\n@overload\ndef safe_map(\n    f: Callable[..., T],\n    __arg1: Iterable[Any],\n    __arg2: Iterable[Any],\n    __arg3: Iterable[Any],\n    __arg4: Iterable[Any],\n    *args,\n) -> List[T]: ...\n\n\ndef safe_map(f, *args):\n    args = list(map(list, args))\n    n = len(args[0])\n    for arg in args[1:]:\n        assert len(arg) == n, f\"length mismatch: {list(map(len, args))}\"\n    return list(map(f, *args))\n\n\ndef combine(tree, operator, a_flat, b_flat):\n    # Lower `fn` to operate on flattened sequences of elems.\n    a = tree_unflatten(a_flat, tree)\n    b = tree_unflatten(b_flat, tree)\n    c = operator(a, b)\n    c_flat, _ = tree_flatten(c)\n    return c_flat\n\n\ndef _scan(tree, operator, elems, axis: int):\n    \"\"\"Perform scan on `elems`.\"\"\"\n    num_elems = elems[0].shape[axis]\n\n    if num_elems < 2:\n        return elems\n\n    # Combine adjacent pairs of elements.\n    reduced_elems = combine(\n        tree,\n        operator,\n        
[torch.ops.aten.slice(elem, axis, 0, -1, 2) for elem in elems],\n        [torch.ops.aten.slice(elem, axis, 1, None, 2) for elem in elems],\n    )\n\n    # Recursively compute scan for partially reduced tensors.\n    odd_elems = _scan(tree, operator, reduced_elems, axis)\n\n    if num_elems % 2 == 0:\n        even_elems = combine(\n            tree,\n            operator,\n            [torch.ops.aten.slice(e, axis, 0, -1) for e in odd_elems],\n            [torch.ops.aten.slice(e, axis, 2, None, 2) for e in elems],\n        )\n    else:\n        even_elems = combine(\n            tree,\n            operator,\n            odd_elems,\n            [torch.ops.aten.slice(e, axis, 2, None, 2) for e in elems],\n        )\n\n    # The first element of a scan is the same as the first element\n    # of the original `elems`.\n    even_elems = [\n        (\n            torch.cat([torch.ops.aten.slice(elem, axis, 0, 1), result], dim=axis)\n            if result.shape.numel() > 0 and elem.shape[axis] > 0\n            else (\n                result\n                if result.shape.numel() > 0\n                else torch.ops.aten.slice(elem, axis, 0, 1)\n            )\n        )  # Jax allows/ignores concat with 0-dim, Pytorch does not\n        for (elem, result) in zip(elems, even_elems)\n    ]\n\n    return list(safe_map(partial(_interleave, axis=axis), even_elems, odd_elems))\n\n\n# Pytorch impl. 
of jax.lax.associative_scan\ndef associative_scan(operator: Callable, elems, axis: int = 0, reverse: bool = False):\n    # if not callable(operator):\n    #     raise TypeError(\"lax.associative_scan: fn argument should be callable.\")\n    elems_flat, tree = tree_flatten(elems)\n\n    # Only non-negative axes are supported (_interleave stacks along axis + 1).\n    assert (\n        0 <= axis < elems_flat[0].ndim\n    ), \"Axis should be within bounds of input\"\n\n    if reverse:\n        elems_flat = [torch.flip(elem, [axis]) for elem in elems_flat]\n\n    num_elems = int(elems_flat[0].shape[axis])\n    if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):\n        raise ValueError(\n            \"Array inputs to associative_scan must have the same size \"\n            \"along axis {}. (saw: {})\".format(axis, [elem.shape for elem in elems_flat])\n        )\n\n    scans = _scan(tree, operator, elems_flat, axis)\n\n    if reverse:\n        scans = [torch.flip(scanned, [axis]) for scanned in scans]\n\n    return tree_unflatten(scans, tree)\n\n\ndef test_associative_scan(shape=(1, 24, 24)):\n    import jax.lax\n    import jax\n\n    x = np.random.randn(*shape)\n    jx = jax.numpy.array(x)\n    tx = torch.tensor(x, dtype=torch.float32)\n\n    def nested_func(a, b):\n        a_i, b_i = a\n        a_j, b_j = b\n        return a_j * a_i, a_j * b_i + b_j\n\n    jy1, jy2 = jax.lax.associative_scan(nested_func, (jx, jx))\n    ty1, ty2 = associative_scan(nested_func, (tx, tx))\n    assert (\n        np.isclose(ty1.numpy(), np.array(jy1)).all()\n        and np.isclose(ty2.numpy(), np.array(jy2)).all()\n    ), \"Expected jax & pytorch impl to be close\"\n\n    jy1, jy2 = jax.lax.associative_scan(nested_func, (jx, jx), reverse=True)\n    ty1, ty2 = associative_scan(nested_func, (tx, tx), reverse=True)\n    assert (\n        np.isclose(ty1.numpy(), np.array(jy1)).all()\n        and np.isclose(ty2.numpy(), np.array(jy2)).all()\n    ), \"Expected jax & pytorch reverse impl to be close\"\n\n    print(\"Associative scan working as 
expected!\")\n\n\ndef _interleave(a, b, axis: int):\n    # https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors\n    b_trunc = a.shape[axis] == b.shape[axis] + 1\n    if b_trunc:\n        pad = [0, 0] * b.ndim\n        pad[(b.ndim - axis - 1) * 2 + 1] = (\n            1  # +1=always end of dim, pad-order is reversed so start is at end\n        )\n        b = torch.nn.functional.pad(b, pad)\n\n    stacked = torch.stack([a, b], dim=axis + 1)\n    interleaved = torch.flatten(stacked, start_dim=axis, end_dim=axis + 1)\n    if b_trunc:\n        # TODO: find torch alternative for slice_along axis for torch.jit.script to work\n        interleaved = torch.ops.aten.slice(\n            interleaved, axis, 0, b.shape[axis] + a.shape[axis] - 1\n        )\n    return interleaved\n\n\ndef test_interleave():\n    x, y = torch.randn(1, 32, 32), torch.randn(1, 32, 32)\n    v = _interleave(x, y, axis=1)\n    assert v.shape == (1, 64, 32)\n    assert (v[:, 0] == x[:, 0]).all()\n    assert (v[:, 1] == y[:, 0]).all()\n    assert (v[:, 2] == x[:, 1]).all()\n    assert (v[:, 3] == y[:, 1]).all()\n    assert (v[:, 4] == x[:, 2]).all()\n\n    v = _interleave(x, y, axis=2)\n    assert v.shape == (1, 32, 64)\n    assert (v[..., 0] == x[..., 0]).all()\n    assert (v[..., 1] == y[..., 0]).all()\n    assert (v[..., 2] == x[..., 1]).all()\n    assert (v[..., 3] == y[..., 1]).all()\n    assert (v[..., 4] == x[..., 2]).all()\n\n    x, y = torch.randn(1, 24, 24), torch.randn(1, 24, 24)\n    assert _interleave(x, y, axis=1).shape == (1, 48, 24)\n    assert _interleave(x, y, axis=2).shape == (1, 24, 48)\n\n    x, y = torch.randn(3, 96), torch.randn(2, 96)\n    v = _interleave(x, y, axis=0)\n    assert v.shape == (5, 96)\n    assert (v[0] == x[0]).all()\n    assert (v[1] == y[0]).all()\n    assert (v[2] == x[1]).all()\n    assert (v[3] == y[1]).all()\n    assert (v[4] == x[2]).all()\n    print(\"Interleave working as expected!\")\n\n\ndef _compute_fans(shape, 
fan_in_axes=None):\n    \"\"\"Computes the number of input and output units for a weight shape.\"\"\"\n    if len(shape) < 1:\n        fan_in = fan_out = 1\n    elif len(shape) == 1:\n        fan_in = fan_out = shape[0]\n    elif len(shape) == 2:\n        fan_in, fan_out = shape\n    else:\n        if fan_in_axes is not None:\n            # Compute fan-in using user-specified fan-in axes.\n            fan_in = np.prod([shape[i] for i in fan_in_axes])\n            fan_out = np.prod([s for i, s in enumerate(shape) if i not in fan_in_axes])\n        else:\n            # If no axes specified, assume convolution kernels (2D, 3D, or more.)\n            # kernel_shape: (..., input_depth, depth)\n            receptive_field_size = np.prod(shape[:-2])\n            fan_in = shape[-2] * receptive_field_size\n            fan_out = shape[-1] * receptive_field_size\n    return fan_in, fan_out\n\n\ndef uniform(shape, dtype=torch.float, minval=0.0, maxval=1.0, device=None):\n    src = torch.rand(shape, dtype=dtype, device=device)\n    if minval == 0 and maxval == 1.0:\n        return src\n    else:\n        return (src * (maxval - minval)) + minval\n\n\ndef _complex_uniform(shape: Sequence[int], dtype, device=None) -> torch.Tensor:\n    \"\"\"\n    Sample uniform random values within a disk on the complex plane,\n    with zero mean and unit variance.\n    \"\"\"\n    # Radius and angle are real-valued: sample them in the matching real dtype\n    # (torch.rand with a complex dtype would draw complex numbers instead).\n    real_dtype = complex_as_float_dtype(dtype)\n    r = torch.sqrt(2 * torch.rand(shape, dtype=real_dtype, device=device))\n    theta = 2 * torch.pi * torch.rand(shape, dtype=real_dtype, device=device)\n    return r * torch.exp(1j * theta)\n\n\ndef complex_as_float_dtype(dtype):\n    match dtype:\n        case torch.complex32:\n            return torch.float32  # NOTE: complex32 is not well supported yet\n        case torch.complex64:\n            return 
torch.float32\n        case torch.complex128:\n            return torch.float64\n        case _:\n            return dtype\n\n\ndef _complex_truncated_normal(\n    upper: float, shape: Sequence[int], dtype, device=None\n) -> 
torch.Tensor:\n    \"\"\"\n    Sample random values from a centered normal distribution on the complex plane,\n    whose modulus is truncated to `upper`, and the variance before the truncation\n    is one.\n    \"\"\"\n    real_dtype = torch.tensor(0, dtype=dtype).real.dtype\n    t = (\n        1 - torch.exp(torch.tensor(-(upper**2), dtype=dtype, device=device))\n    ) * torch.rand(shape, dtype=real_dtype, device=device).type(dtype)\n    r = torch.sqrt(-torch.log(1 - t))\n    theta = (\n        2 * torch.pi * torch.rand(shape, dtype=real_dtype, device=device).type(dtype)\n    )\n    return r * torch.exp(1j * theta)\n\n\ndef _truncated_normal(lower, upper, shape, dtype=torch.float):\n    if shape is None:\n        shape = torch.broadcast_shapes(np.shape(lower), np.shape(upper))\n\n    sqrt2 = math.sqrt(2)\n    a = math.erf(lower / sqrt2)\n    b = math.erf(upper / sqrt2)\n\n    # a<u<b\n    u = uniform(shape, dtype, minval=a, maxval=b)\n    out = sqrt2 * torch.erfinv(u)\n    # Clamp the value to the open interval (lower, upper) to make sure that\n    # rounding (or if we chose `a` for `u`) doesn't push us outside of the range.\n    with torch.no_grad():\n        return torch.clip(\n            out,\n            torch.nextafter(torch.tensor(lower), torch.tensor(np.inf, dtype=dtype)),\n            torch.nextafter(torch.tensor(upper), torch.tensor(-np.inf, dtype=dtype)),\n        )\n\n\ndef variance_scaling(\n    scale: float,\n    mode: Literal[\"fan_in\", \"fan_out\", \"fan_avg\"] = \"fan_in\",\n    distribution: Literal[\"truncated_normal\", \"normal\", \"uniform\"] = \"truncated_normal\",\n    fan_in_axes: Optional[Sequence[int]] = None,\n    dtype=torch.float,\n):\n    def init(shape: Sequence[float], dtype=dtype, device=None):\n        fan_in, fan_out = _compute_fans(shape, fan_in_axes)\n        match mode:\n            case \"fan_in\":\n                denom = max(1, fan_in)\n            case \"fan_out\":\n                denom = max(1, fan_out)\n            case 
\"fan_avg\":\n                denom = max(1, (fan_in + fan_out) / 2)\n            case _:\n                raise ValueError(\n                    f\"invalid mode for variance scaling initializer: {mode}\"\n                )\n\n        variance = scale / denom\n        match distribution:\n            case \"normal\":\n                return torch.normal(\n                    0, np.sqrt(variance), shape, dtype=dtype, device=device\n                )\n            case \"uniform\":\n                if dtype.is_complex:\n                    return _complex_uniform(\n                        shape, dtype=dtype, device=device\n                    ) * np.sqrt(variance)\n                else:\n                    return uniform(\n                        shape, dtype=dtype, device=device, minval=-1, maxval=1.0\n                    ) * np.sqrt(3 * variance)\n            case \"truncated_normal\":\n                # NOTE: jax.nn.initializers.variance_scaling *divides* sqrt(variance) by\n                # these constants (the std of the truncated distributions) so that the\n                # result has std sqrt(variance). Multiplying, as done here, gives a std of\n                # roughly 0.77 * sqrt(variance) for real dtypes (see test_variance_scaling).\n                if dtype.is_complex:\n                    stddev = np.sqrt(variance) * 0.95311164380491208\n                    return (\n                        _complex_truncated_normal(2, shape, dtype=dtype, device=device)\n                        * stddev\n                    )\n                else:\n                    stddev = np.sqrt(variance) * 0.87962566103423978\n                    # _truncated_normal samples on the default device; move to `device`.\n                    return (\n                        _truncated_normal(-2.0, 2.0, shape, dtype=dtype).to(device=device)\n                        * stddev\n                    )\n            case _:\n                raise ValueError(\n                    f\"invalid 
distribution for variance scaling initializer: {distribution}\"\n                )\n\n    return init\n\n\ndef lecun_normal(fan_in_axes=None, dtype=torch.float):\n    \"\"\"Builds a Lecun normal initializer.\n\n    Port of ``jax.nn.initializers.lecun_normal``: a specialization of\n    :func:`variance_scaling` where ``scale = 1.0``, ``mode=\"fan_in\"``, and\n    ``distribution=\"truncated_normal\"`` (see the `Lecun normal initializer`_).\n\n    Args:\n    fan_in_axes: optional sequence of axes of the weight shape that make up the\n      fan-in. If None, ``_compute_fans`` treats the shape as a dense (2D) weight\n      or a convolution kernel.\n    dtype: the default dtype of the weights.\n\n    Returns:\n    An initializer ``init(shape, dtype=dtype, device=None)`` returning a tensor.\n\n    Example:\n\n    >>> initializer = lecun_normal()\n    >>> w = initializer((2, 3))\n    >>> w.shape\n    torch.Size([2, 3])\n\n    .. _Lecun normal initializer: https://arxiv.org/abs/1706.02515\n    \"\"\"\n    return variance_scaling(\n        1.0, \"fan_in\", \"truncated_normal\", fan_in_axes=fan_in_axes, dtype=dtype\n    )\n\n\ndef test_variance_scaling():\n    v = variance_scaling(1.0, distribution=\"normal\")\n    n_f32 = v((1, 10000), dtype=torch.float)\n    assert np.isclose(\n        n_f32.std().item(), 1.0, rtol=0.015, atol=0.015\n    ), f\"std for f32 normal[0,1.0] is {n_f32.std()} != 1.0\"\n    del n_f32\n    # NOTE: this is used in the original as `complex_normal` (but with stddev=0.5**0.5)\n    n_c64 = v((1, 10000), dtype=torch.complex64)\n    assert np.isclose(\n        n_c64.std().item(), 1.0, rtol=0.015, atol=0.015\n    ), f\"std for c64 normal[0,1.0] is {n_c64.std()} != 1.0\"\n    del n_c64\n\n    # Truncated normal: expected std is ~0.8796**2 (~0.774), since the constant is multiplied in\n    v = variance_scaling(1.0, distribution=\"truncated_normal\")\n    tn_f32 = v((1, 10000), dtype=torch.float)\n    assert np.isclose(\n        tn_f32.std().item(), 0.775, rtol=0.015, atol=0.015\n    ), 
f\"std for f32 truncated normal[0,1.0] is {tn_f32.std()} != 0.775\"\n    del tn_f32\n\n    # NOTE: this is used in the original (both trunc_standard_normal & lecun_normal it seems),\n    # seems that they are using the fan-in/out feature to 'hide the low variance initialization'\n    # The actual std observed is np.sqrt(2/shape[1]/(2*shape[0])); shape[2] has no impact\n    v = variance_scaling(1.0, 
distribution=\"truncated_normal\")\n    tn_f32 = v((1, 10000, 2), dtype=torch.float)\n    tn_c32 = torch.complex(tn_f32[..., 0], tn_f32[..., 1])\n    expected_std = np.sqrt(2 / tn_f32.shape[1] / (2 * tn_f32.shape[0]))\n    assert np.isclose(\n        tn_c32.std().item(), expected_std, rtol=0.015, atol=0.015\n    ), f\"std for c64 truncated normal[0,1.0] is {tn_c32.std()} != {expected_std}\"\n    del tn_f32\n    del tn_c32\n\n    print(\"Variance scaling working as expected!\")\n\n\nif __name__ == \"__main__\":\n    test_variance_scaling()\n    test_interleave()\n    test_associative_scan()\n    test_associative_scan(shape=(2, 256, 24))\n    test_associative_scan(shape=(360, 96))\n"
  },
  {
    "path": "RVT/models/layers/s5/s5_init.py",
    "content": "import torch\nimport numpy as np\nfrom .jax_func import variance_scaling, lecun_normal, uniform\nimport scipy.linalg\n\n# Initialization Functions\n\n\ndef make_HiPPO(N):\n    \"\"\"Create a HiPPO-LegS matrix.\n    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py\n    Args:\n        N (int32): state size\n    Returns:\n        N x N HiPPO LegS matrix\n    \"\"\"\n    P = np.sqrt(1 + 2 * np.arange(N))\n    A = P[:, np.newaxis] * P[np.newaxis, :]\n    A = np.tril(A) - np.diag(np.arange(N))\n    return -A\n\n\ndef make_NPLR_HiPPO(N):\n    \"\"\"\n    Makes components needed for NPLR representation of HiPPO-LegS\n     From https://github.com/srush/annotated-s4/blob/main/s4/s4.py\n    Args:\n        N (int32): state size\n    Returns:\n        N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B\n    \"\"\"\n    # Make -HiPPO\n    hippo = make_HiPPO(N)\n\n    # Add in a rank 1 term. Makes it Normal.\n    P = np.sqrt(np.arange(N) + 0.5)\n\n    # HiPPO also specifies the B matrix\n    B = np.sqrt(2 * np.arange(N) + 1.0)\n    return hippo, P, B\n\n\ndef make_DPLR_HiPPO(N):\n    \"\"\"\n    Makes components needed for DPLR representation of HiPPO-LegS\n     From https://github.com/srush/annotated-s4/blob/main/s4/s4.py\n    Note, we will only use the diagonal part\n    Args:\n        N:\n    Returns:\n        eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B,\n        eigenvectors V, HiPPO B pre-conjugation\n    \"\"\"\n    A, P, B = make_NPLR_HiPPO(N)\n\n    S = A + P[:, np.newaxis] * P[np.newaxis, :]\n\n    S_diag = np.diagonal(S)\n    Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)\n\n    # Diagonalize S to V \\Lambda V^*\n    Lambda_imag, V = np.linalg.eigh(S * -1j)\n\n    P = V.conj().T @ P\n    B_orig = B\n    B = V.conj().T @ B\n    return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig\n\n\ndef make_Normal_S(N):\n    nhippo = make_HiPPO(N)\n    # Add in a rank 1 term. 
Makes it Normal.\n    p = 0.5 * np.sqrt(2 * np.arange(1, N + 1) + 1.0)\n    q = 2 * p\n    S = nhippo + p[:, np.newaxis] * q[np.newaxis, :]\n    return S\n\n\ndef make_Normal_HiPPO(N, B=1):\n    \"\"\"Create a normal approximation to HiPPO-LegS matrix.\n    For HiPPO matrix A, A=S+pqT is normal plus low-rank for\n    a certain normal matrix S and low rank terms p and q.\n    We are going to approximate the HiPPO matrix with the normal matrix S.\n    The eigendecomposition is computed with numpy's linalg.eig (the original\n    JAX implementation did this because Jax's linalg.eig does not run on GPU\n    for non-symmetric matrices) and the result is then converted to torch\n    tensors. This only has to be done once, at initialization.\n    Args:\n        N (int32): state size\n        B (int32): diagonal blocks\n    Returns:\n        Lambda (complex64): eigenvalues of S (N,)\n        V      (complex64): eigenvectors of S (N,N)\n    \"\"\"\n\n    assert N % B == 0, \"N must be divisible by B (the number of blocks)\"\n    S = (make_Normal_S(N // B),) * B\n    S = scipy.linalg.block_diag(*S)\n\n    # Diagonalize S to V \\Lambda V^*\n    Lambda, V = np.linalg.eig(S)\n\n    # Convert to torch tensors\n    return torch.tensor(Lambda), torch.tensor(V)\n\n\ndef log_step_initializer(dt_min=0.001, dt_max=0.1):\n    \"\"\"Initialize the learnable timescale Delta by sampling\n    log(Delta) uniformly between log(dt_min) and log(dt_max).\n    Args:\n        dt_min (float32): minimum value\n        dt_max (float32): maximum value\n    Returns:\n        init function\n    \"\"\"\n\n    def init(shape):\n        \"\"\"Init function\n        Args:\n            shape (tuple): desired shape\n        Returns:\n            sampled log_step (float32)\n        \"\"\"\n        return uniform(shape, minval=np.log(dt_min), maxval=np.log(dt_max))\n        # return torch.rand(shape) * (np.log(dt_max) - np.log(dt_min)) + np.log(dt_min)\n\n    
return init\n\n\ndef init_log_steps(H, dt_min, dt_max):\n    \"\"\"Initialize an array of learnable timescale parameters\n    Args:\n        H (int): number of timescales to initialize\n        dt_min (float32): minimum value\n        dt_max (float32): maximum value\n    Returns:\n        initialized array of timescales (float32): (H,)\n    \"\"\"\n    log_steps = []\n    for i in range(H):\n        log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(shape=(1,))\n        log_steps.append(log_step)\n\n    return torch.tensor(log_steps)\n\n\ndef init_VinvB(init_fun, Vinv):\n    \"\"\"Initialize B_tilde=V^{-1}B. First sample B. Then compute V^{-1}B.\n    Note we will parameterize this with two different matrices for complex\n    numbers.\n     Args:\n         init_fun:  the initialization function to use, e.g. lecun_normal()\n         shape (tuple): desired shape  (P,H)\n         Vinv: (complex64)     the inverse eigenvectors used for initialization\n     Returns:\n         B_tilde (complex64) of shape (P,H,2)\n    \"\"\"\n\n    def init(shape, dtype):\n        B = init_fun(shape, dtype)\n        VinvB = Vinv @ B.type(Vinv.dtype)\n        VinvB_real = VinvB.real\n        VinvB_imag = VinvB.imag\n        return torch.cat((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1)\n\n    return init\n\n\ndef trunc_standard_normal(shape):\n    \"\"\"Sample C with a truncated normal distribution with standard deviation 1.\n    Args:\n        shape (tuple): desired shape, of length 3, (H,P,_)\n    Returns:\n        sampled C matrix (float32) of shape (H,P,2) (for complex parameterization)\n    \"\"\"\n    H, P, _ = shape\n    Cs = []\n    for i in range(H):\n        C = lecun_normal()(shape=(1, P, 2))\n        Cs.append(C)\n    # torch.tensor() cannot build a tensor from a list of multi-element tensors; stack them instead\n    return torch.stack(Cs)[:, 0]\n\n\ndef init_CV(init_fun, shape, V) -> torch.Tensor:\n    \"\"\"Initialize C_tilde=CV. First sample C. 
Then compute CV.\n    Note we will parameterize this with two different matrices for complex\n    numbers.\n     Args:\n         init_fun:  the initialization function to use, e.g. lecun_normal()\n         shape (tuple): desired shape  (H,P)\n         V: (complex64)     the eigenvectors used for initialization\n     Returns:\n         C_tilde (complex64) of shape (H,P)\n    \"\"\"\n    C_ = init_fun(shape + (2,))\n    C = C_[..., 0] + 1j * C_[..., 1]\n    CV = C @ V\n    return CV\n\n\ndef init_columnwise_B(shape, dtype):\n    \"\"\"Initialize B matrix in columnwise fashion.\n    We will sample each column of B from a lecun_normal distribution.\n    This gives a different fan-in size than if we sample the entire\n    matrix B at once. We found this approach to be helpful for PathX.\n    It appears to be related to the point in\n    https://arxiv.org/abs/2206.12037 regarding the initialization of\n    the C matrix in S4, so potentially more important for the\n    C initialization than for B.\n     Args:\n         shape (tuple): desired shape, either of length 3, (P,H,_), or\n                      of length 2 (N,H) depending on whether the function is called\n                      from the low-rank factorization initialization or a dense\n                      initialization\n     Returns:\n         sampled B matrix (float32), either of shape (H,P) or\n          shape (H,P,2) (for complex parameterization)\n    \"\"\"\n    shape = shape[:2] + ((2,) if len(shape) == 3 else ())\n    lecun = variance_scaling(0.5 if len(shape) == 3 else 1.0, fan_in_axes=(0,))\n    return lecun(shape, dtype)\n\n\ndef init_columnwise_VinvB(init_fun, Vinv):\n    \"\"\"Same function as above, but with transpose applied to prevent shape mismatch\n    when using the columnwise initialization. 
In general this is unnecessary\n    and will be removed in future versions, but is left for now for consistency with\n    certain random seeds until we rerun experiments.\"\"\"\n\n    def init(shape, dtype):\n        B = init_fun(shape[:2], dtype)\n        # Cast B to the complex dtype of Vinv: torch matmul does not mix real and complex operands\n        VinvB = Vinv @ B.type(Vinv.dtype)\n        VinvB_real = VinvB.real\n        VinvB_imag = VinvB.imag\n        return torch.cat((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1)\n\n    return init\n\n\ndef init_rowwise_C(shape, dtype=torch.float):\n    \"\"\"Initialize C matrix in rowwise fashion. Analogous to init_columnwise_B function above.\n    We will sample each row of C from a lecun_normal distribution.\n    This gives a different fan-in size than if we sample the entire\n    matrix C at once. We found this approach to be helpful for PathX.\n    It appears to be related to the point in\n    https://arxiv.org/abs/2206.12037 regarding the initialization of\n    the C matrix in S4.\n     Args:\n         shape (tuple): desired shape, of length 3, (H,P,_)\n         dtype: dtype of the sampled matrix (defaults to torch.float, since\n                init_CV calls init_fun(shape) without a dtype)\n     Returns:\n         sampled C matrix (float32) of shape (H,P,2) (for complex parameterization)\n    \"\"\"\n    shape = shape[:2] + ((2,) if len(shape) == 3 else ())\n    lecun = variance_scaling(0.5, fan_in_axes=(0,))\n    return lecun(shape, dtype)\n"
  },
  {
    "path": "RVT/models/layers/s5/s5_model.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom typing import Literal, Tuple, Optional\nimport os, sys\nimport math\n\nROOT = os.getcwd()\nif str(ROOT) not in sys.path:\n    sys.path.append(str(ROOT))\nsys.path.append(os.path.join(ROOT, \"RVT\"))\n\nfrom models.layers.s5.jax_func import associative_scan\nfrom models.layers.s5.s5_init import *\n\n# Runtime functions\n\n\n@torch.jit.script\ndef binary_operator(\n    q_i: Tuple[torch.Tensor, torch.Tensor], q_j: Tuple[torch.Tensor, torch.Tensor]\n):\n    \"\"\"Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A.\n    Args:\n        q_i: tuple containing A_i and Bu_i at position i       (P,), (P,)\n        q_j: tuple containing A_j and Bu_j at position j       (P,), (P,)\n    Returns:\n        new element ( A_out, Bu_out )\n    \"\"\"\n    A_i, b_i = q_i\n    A_j, b_j = q_j\n    # return A_j * A_i, A_j * b_i + b_j\n    return A_j * A_i, torch.addcmul(b_j, A_j, b_i)\n\n\ndef apply_ssm(\n    Lambda_bars: torch.Tensor,\n    B_bars,\n    C_tilde,\n    D,\n    input_sequence,\n    prev_state,\n    bidir: bool = False,\n):\n    B_bars = as_complex(B_bars)\n    C_tilde = as_complex(C_tilde)\n    Lambda_bars = as_complex(Lambda_bars)\n\n    cinput_sequence = input_sequence.type(\n        Lambda_bars.dtype\n    )  # Cast to correct complex type\n\n    if B_bars.ndim == 3:\n        # Dynamic timesteps (significantly more expensive)\n        Bu_elements = torch.vmap(lambda B_bar, u: B_bar @ u)(B_bars, cinput_sequence)\n    else:\n        # Static timesteps\n        Bu_elements = torch.vmap(lambda u: B_bars @ u)(cinput_sequence)\n\n    if Lambda_bars.ndim == 1:  # Repeat for associative_scan\n        Lambda_bars = Lambda_bars.tile(input_sequence.shape[0], 1)\n\n    Lambda_bars[0] = Lambda_bars[0] * prev_state\n\n    _, xs = associative_scan(binary_operator, (Lambda_bars, Bu_elements))\n\n    if bidir:\n        _, xs2 = associative_scan(\n            binary_operator, (Lambda_bars, 
Bu_elements), reverse=True\n        )\n        xs = torch.cat((xs, xs2), axis=-1)\n\n    Du = torch.vmap(lambda u: D * u)(input_sequence)\n    # xs[-1] is the final hidden state (non-bidir); it is returned so it can be passed back as prev_state\n    return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du, xs[-1]\n\n\ndef apply_ssm_liquid(\n    Lambda_bars, B_bars, C_tilde, D, input_sequence, bidir: bool = False\n):\n    \"\"\"Liquid time constant SSM \u00e0 la dynamical systems given in Eq. 8 of\n    https://arxiv.org/abs/2209.12951\"\"\"\n    cinput_sequence = input_sequence.type(\n        Lambda_bars.dtype\n    )  # Cast to correct complex type\n\n    if B_bars.ndim == 3:\n        # Dynamic timesteps (significantly more expensive)\n        Bu_elements = torch.vmap(lambda B_bar, u: B_bar @ u)(B_bars, cinput_sequence)\n    else:\n        # Static timesteps\n        Bu_elements = torch.vmap(lambda u: B_bars @ u)(cinput_sequence)\n\n    if Lambda_bars.ndim == 1:  # Repeat for associative_scan\n        Lambda_bars = Lambda_bars.tile(input_sequence.shape[0], 1)\n\n    _, xs = associative_scan(binary_operator, (Lambda_bars + Bu_elements, Bu_elements))\n\n    if bidir:\n        _, xs2 = associative_scan(\n            binary_operator, (Lambda_bars, Bu_elements), reverse=True\n        )\n        xs = torch.cat((xs, xs2), axis=-1)\n\n    Du = torch.vmap(lambda u: D * u)(input_sequence)\n    return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du\n\n\n# Discretization functions\ndef discretize_bilinear(Lambda, B_tilde, Delta):\n    \"\"\"Discretize a diagonalized, continuous-time linear SSM\n    using bilinear transform method.\n    Args:\n        Lambda (complex64): diagonal state matrix              (P,)\n        B_tilde (complex64): input matrix                      (P, H)\n        Delta (float32): discretization step sizes             (P,)\n    Returns:\n        discretized Lambda_bar (complex64), B_bar (complex64)  (P,), (P,H)\n    \"\"\"\n    Lambda = torch.view_as_complex(Lambda)\n\n 
   Identity = torch.ones(Lambda.shape[0], device=Lambda.device)\n    BL = 1 / (Identity - (Delta / 2.0) * Lambda)\n    Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda)\n    B_bar = (BL * Delta)[..., None] * B_tilde\n\n    Lambda_bar = torch.view_as_real(Lambda_bar)\n    B_bar = torch.view_as_real(B_bar)\n\n    return Lambda_bar, B_bar\n\n\ndef discretize_zoh(Lambda, B_tilde, Delta):\n    \"\"\"Discretize a diagonalized, continuous-time linear SSM\n    using zero-order hold method.\n    Args:\n        Lambda (complex64): diagonal state matrix              (P,)\n        B_tilde (complex64): input matrix                      (P, H)\n        Delta (float32): discretization step sizes             (P,)\n    Returns:\n        discretized Lambda_bar (complex64), B_bar (complex64)  (P,), (P,H)\n    \"\"\"\n    # Identity = torch.ones(Lambda.shape[0], device=Lambda.device) # (replaced by -1)\n    # Lambda is stored as a real (P, 2) view; convert it as discretize_bilinear does\n    Lambda = torch.view_as_complex(Lambda)\n    Lambda_bar = torch.exp(Lambda * Delta)\n    B_bar = (1 / Lambda * (Lambda_bar - 1))[..., None] * B_tilde\n\n    # Return real views so apply_ssm's as_complex() works for both discretizations\n    Lambda_bar = torch.view_as_real(Lambda_bar)\n    B_bar = torch.view_as_real(B_bar)\n\n    return Lambda_bar, B_bar\n\n\ndef as_complex(t: torch.Tensor, dtype=torch.complex64):\n    assert t.shape[-1] == 2, \"as_complex can only be done on tensors with shape=(...,2)\"\n    nt = torch.complex(t[..., 0], t[..., 1])\n    if nt.dtype != dtype:\n        nt = nt.type(dtype)\n    return nt\n\n\nInitialization = Literal[\"dense_columns\", \"dense\", \"factorized\"]\n\n\nclass S5SSM(torch.nn.Module):\n    def __init__(\n        self,\n        lambdaInit: torch.Tensor,\n        V: torch.Tensor,\n        Vinv: torch.Tensor,\n        h: int,\n        p: int,\n        dt_min: float,\n        dt_max: float,\n        liquid: bool = False,\n        factor_rank: Optional[int] = None,\n        discretization: Literal[\"zoh\", 
\"bilinear\"] = \"bilinear\",\n        bcInit: Initialization = \"factorized\",\n        degree: int = 1,\n        bidir: bool = False,\n        step_scale: float = 1.0,\n        bandlimit: Optional[float] = None,\n    ):\n        \"\"\"The S5 SSM\n       
 Args:\n            lambdaInit  (complex64): Initial diagonal state matrix       (P,)\n            V           (complex64): Eigenvectors used for init          (P,P)\n            Vinv        (complex64): Inverse eigenvectors used for init  (P,P)\n            h           (int32):     Number of features of input seq\n            p           (int32):     state size\n            factor_rank (int32):     rank of low-rank factorization (if used)\n            bcInit      (string):    Specifies how B and C are initialized\n                        Options: [factorized: low-rank factorization,\n                                dense: dense matrix drawn from Lecun_normal,\n                                dense_columns: dense matrix where the columns\n                                of B and the rows of C are each drawn from Lecun_normal\n                                separately (i.e. different fan-in than the dense option).\n                                We found this initialization to be helpful for PathX.]\n            discretization: (string) Specifies discretization method\n                            options: [zoh: zero-order hold method,\n                                    bilinear: bilinear transform]\n            liquid:         (bool): use liquid_ssm from LiquidS4\n            dt_min:      (float32): minimum value to draw timescale values from when\n                                    initializing log_step\n            dt_max:      (float32): maximum value to draw timescale values from when\n                                    initializing log_step\n            step_scale:  (float32): allows for changing the step size, e.g. 
after training\n                                    on a different resolution for the speech commands benchmark\n        \"\"\"\n        super().__init__()\n        self.Lambda = torch.nn.Parameter(torch.view_as_real(lambdaInit))\n        self.degree = degree\n        self.liquid = liquid\n        self.bcInit = bcInit\n        self.bidir = bidir\n        self.bandlimit = bandlimit\n\n        cp = p\n        if self.bidir:\n            cp *= 2\n\n        match bcInit:\n            case \"complex_normal\":\n                self.C = torch.nn.Parameter(\n                    torch.normal(0, 0.5**0.5, (h, cp), dtype=torch.complex64)\n                )\n                self.B = torch.nn.Parameter(\n                    init_VinvB(lecun_normal(), Vinv)((p, h), torch.float)\n                )\n            case \"dense_columns\" | \"dense\":\n                if bcInit == \"dense_columns\":\n                    B_eigen_init = init_columnwise_VinvB\n                    B_init = init_columnwise_B\n                    C_init = init_rowwise_C\n                elif bcInit == \"dense\":\n                    B_eigen_init = init_VinvB\n                    B_init = C_init = lecun_normal()\n                # TODO: give all init_*VinvB functions the same interface\n                self.B = torch.nn.Parameter(\n                    B_eigen_init(B_init, Vinv)((p, h), torch.float)\n                )\n                if self.bidir:\n                    C = torch.cat(\n                        [init_CV(C_init, (h, p), V), init_CV(C_init, (h, p), V)],\n                        axis=-1,\n                    )\n                else:\n                    C = init_CV(C_init, (h, p), V)\n                self.C = torch.nn.Parameter(torch.view_as_real(C))\n            case _:\n                raise NotImplementedError(f\"BC_init method {bcInit} not implemented\")\n\n        # Initialize feedthrough (D) matrix\n        self.D = torch.nn.Parameter(\n            torch.rand(\n                h,\n            )\n     
   )\n        self.log_step = torch.nn.Parameter(init_log_steps(p, dt_min, dt_max))\n        match discretization:\n            case \"zoh\":\n                self.discretize = discretize_zoh\n            case \"bilinear\":\n                self.discretize = discretize_bilinear\n            case _:\n                raise ValueError(f\"Unknown discretization {discretization}\")\n\n        if self.bandlimit is not None:\n            step = step_scale * torch.exp(self.log_step)\n\n            freqs = step / step_scale * self.Lambda[:, 1].abs() / (2 * math.pi)\n            mask = torch.where(freqs < bandlimit * 0.5, 1, 0)  # (64, )\n            self.C = torch.nn.Parameter(\n                torch.view_as_real(torch.view_as_complex(self.C) * mask)\n            )\n\n    def initial_state(self, batch_size: Optional[int]):\n        batch_shape = (batch_size,) if batch_size is not None else ()\n        _, C_tilde = self.get_BC_tilde()\n\n        return torch.zeros((*batch_shape, C_tilde.shape[-2]))\n\n    def get_BC_tilde(self):\n        match self.bcInit:\n            case \"dense_columns\" | \"dense\" | \"complex_normal\":\n                B_tilde = as_complex(self.B)\n                C_tilde = self.C\n            case \"factorized\":\n                B_tilde = self.BP @ self.BH.T\n                C_tilde = self.CH.T @ self.CP\n        return B_tilde, C_tilde\n\n    def forward_rnn(self, signal, prev_state, step_scale: float | torch.Tensor = 1.0):\n        assert not self.bidir, \"Can't use bidirectional when manually stepping\"\n        B_tilde, C_tilde = self.get_BC_tilde()\n        step = step_scale * torch.exp(self.log_step)\n        Lambda_bar, B_bar = self.discretize(self.Lambda, B_tilde, step)\n        if self.degree != 1:\n            assert (\n                B_bar.shape[-2] == B_bar.shape[-1]\n            ), \"higher-order input operators must be full-rank\"\n            B_bar **= self.degree\n\n        if not torch.is_tensor(step_scale) or step_scale.ndim == 
0:\n            step_scale = torch.ones(signal.shape[-2], device=signal.device) * step_scale\n        step = step_scale[:, None] * torch.exp(self.log_step)\n        # https://arxiv.org/abs/2209.12951v1, Eq. 9\n        Bu = B_bar @ signal\n        if self.liquid:\n            Lambda_bar += Bu\n        # https://arxiv.org/abs/2208.04933v2, Eq. 2\n        x = Lambda_bar * prev_state + Bu\n        y = (C_tilde @ x + self.D * signal).real\n        return y, x\n\n    # NOTE: can only be used as RNN OR S5(MIMO) (no mixing)\n    def forward(self, signal, prev_state, step_scale: float | torch.Tensor = 1.0):\n        B_tilde, C_tilde = self.get_BC_tilde()\n        if self.degree != 1:\n            assert (\n                B_bar.shape[-2] == B_bar.shape[-1]\n            ), \"higher-order input operators must be full-rank\"\n            B_bar **= self.degree\n\n        if not torch.is_tensor(step_scale) or step_scale.ndim == 0:\n            # step_scale = torch.ones(signal.shape[-2], device=signal.device) * step_scale\n            step = step_scale * torch.exp(self.log_step)\n        else:\n            # TODO: This is very expensive due to individual steps being multiplied by B_tilde in self.discretize\n            step = step_scale[:, None] * torch.exp(self.log_step)\n\n        Lambda_bars, B_bars = self.discretize(self.Lambda, B_tilde, step)\n        # Lambda_bars, B_bars = torch.vmap(self.discretize, (None, None, 0))(self.Lambda, B_tilde, step)\n        forward = apply_ssm_liquid if self.liquid else apply_ssm\n        return forward(\n            Lambda_bars, B_bars, C_tilde, self.D, signal, prev_state, bidir=self.bidir\n        )\n\n\nclass S5(torch.nn.Module):\n    def __init__(\n        self,\n        width: int,\n        state_width: Optional[int] = None,\n        factor_rank: Optional[int] = None,\n        block_count: int = 1,\n        dt_min: float = 0.001,\n        dt_max: float = 0.1,\n        liquid: bool = False,\n        degree: int = 1,\n        bidir: bool = 
False,\n        bcInit: Optional[Initialization] = None,\n        bandlimit: Optional[float] = None,\n    ):\n        super().__init__()\n        state_width = state_width or width\n        assert (\n            state_width % block_count == 0\n        ), \"block_count should be a factor of state_width\"\n\n        block_size = state_width // block_count\n        Lambda, _, B, V, B_orig = make_DPLR_HiPPO(block_size)\n        Vinv = V.conj().T\n        Lambda, B, V, B_orig, Vinv = map(\n            lambda v: torch.tensor(v, dtype=torch.complex64),\n            (Lambda, B, V, B_orig, Vinv),\n        )\n        if block_count > 1:\n            Lambda = Lambda[:block_size]\n            V = V[:, :block_size]\n            Lambda = (Lambda * torch.ones((block_count, block_size))).ravel()\n            V = torch.block_diag(*([V] * block_count))\n            Vinv = torch.block_diag(*([Vinv] * block_count))\n\n        assert bool(factor_rank) != bool(\n            bcInit != \"factorized\"\n        ), \"Can't have `bcInit != factorized` and `factor_rank` defined\"\n        bc_init = \"factorized\" if factor_rank is not None else (bcInit or \"dense\")\n        self.width = width\n        self.seq = S5SSM(\n            Lambda,\n            V,\n            Vinv,\n            width,\n            state_width,\n            dt_min,\n            dt_max,\n            factor_rank=factor_rank,\n            bcInit=bc_init,\n            liquid=liquid,\n            degree=degree,\n            bidir=bidir,\n            bandlimit=bandlimit,\n        )\n\n    def initial_state(self, batch_size: Optional[int] = None):\n        return self.seq.initial_state(batch_size)\n\n    def forward(self, signal, prev_state, step_scale: float | torch.Tensor = 1.0):\n        # NOTE: step_scale can be float | Tensor[batch] | Tensor[batch, seq]\n        if not torch.is_tensor(step_scale):\n            # Duplicate across batchdim\n            step_scale = torch.ones(signal.shape[0], device=signal.device) * 
step_scale\n\n        return torch.vmap(lambda s, ps, ss: self.seq(s, prev_state=ps, step_scale=ss))(\n            signal, prev_state, step_scale\n        )\n\n\nclass GEGLU(torch.nn.Module):\n    def forward(self, x):\n        x, gates = x.chunk(2, dim=-1)\n        return x * F.gelu(gates)\n\n\nclass S5Block(torch.nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        state_dim: int,\n        bidir: bool,\n        block_count: int = 1,\n        liquid: bool = False,\n        degree: int = 1,\n        factor_rank: int | None = None,\n        bcInit: Optional[Initialization] = None,\n        ff_mult: float = 1.0,\n        glu: bool = True,\n        ff_dropout: float = 0.0,\n        attn_dropout: float = 0.0,\n        bandlimit: Optional[float] = None,\n    ):\n        super().__init__()\n        self.s5 = S5(\n            dim,\n            state_width=state_dim,\n            bidir=bidir,\n            block_count=block_count,\n            liquid=liquid,\n            degree=degree,\n            factor_rank=factor_rank,\n            bcInit=bcInit,\n            bandlimit=bandlimit,\n        )\n        self.attn_norm = torch.nn.LayerNorm(dim)\n        self.attn_dropout = torch.nn.Dropout(p=attn_dropout)\n        self.geglu = GEGLU() if glu else None\n        self.ff_enc = torch.nn.Linear(dim, int(dim * ff_mult) * (1 + glu), bias=False)\n        self.ff_dec = torch.nn.Linear(int(dim * ff_mult), dim, bias=False)\n        self.ff_norm = torch.nn.LayerNorm(dim)\n        self.ff_dropout = torch.nn.Dropout(p=ff_dropout)\n\n    def forward(self, x, states):\n        # Standard transformer-style block with GEGLU/Pre-LayerNorm\n        fx = self.attn_norm(x)\n        res = fx.clone()\n        x, new_state = self.s5(fx, states)\n        x = F.gelu(x) + res\n        x = self.attn_dropout(x)\n\n        fx = self.ff_norm(x)\n        res = fx.clone()\n        x = self.ff_enc(fx)\n        if self.geglu is not None:\n            x = self.geglu(x)\n        x = 
self.ff_dec(x) + res\n        x = self.ff_dropout(\n            x\n        )  # TODO: test if should be placed in between ff or after ff\n        return x, new_state\n\n\nif __name__ == \"__main__\":\n    import lovely_tensors as lt\n\n    lt.monkey_patch()\n\n    def tensor_stats(t: torch.Tensor):  # Clone of lovely_tensors for complex support\n        return f\"tensor[{t.shape}] n={t.shape.numel()}, u={t.mean()}, s={round(t.std().item(), 3)} var={round(t.var().item(), 3)}\\n\"\n\n    x = torch.rand([2, 256, 32]).cuda()\n    model = S5(32, 32, factor_rank=None).cuda()\n    print(\"B\", tensor_stats(model.seq.B.data))\n    print(\"C\", tensor_stats(model.seq.C.data))\n    # print('B', tensor_stats(model.seq.BH.data), tensor_stats(model.seq.BP.data))\n    # print('C', tensor_stats(model.seq.CH.data), tensor_stats(model.seq.CP.data))\n    # FIXME: unstable initialization\n    # state = model.initial_state(256)\n    # res = model(x, prev_state=state)\n    # print(res.shape, res.dtype, res)\n    res = model(x)  # warm-up\n    print(res.shape, res.dtype, res)\n\n    # Example 2: (B, L, H) inputs\n    x = torch.rand([2, 256, 32]).cuda()\n    model = S5Block(32, 32, False).cuda()\n    res = model(x)\n    print(res.shape, res.dtype, res)\n"
  },
  {
    "path": "RVT/models/layers/s5/triton_comparison.py",
    "content": "import torch\nimport numpy as np\nimport time\nimport triton\nimport triton.language as tl\nfrom triton.runtime.jit import TensorWrapper, reinterpret\nfrom jax_func import associative_scan\n\nint_dtypes = [\"int8\", \"int16\", \"int32\", \"int64\"]\nuint_dtypes = [\"uint8\", \"uint16\", \"uint32\", \"uint64\"]\nfloat_dtypes = [\"float16\", \"float32\", \"float64\"]\ndtypes = int_dtypes + uint_dtypes + float_dtypes\ndtypes_with_bfloat16 = dtypes + [\"bfloat16\"]\ntorch_dtypes = [\"bool\"] + int_dtypes + [\"uint8\"] + float_dtypes + [\"bfloat16\"]\n\n\ndef to_triton(x: np.ndarray, device=\"cuda\", dst_type=None):\n    t = x.dtype.name\n    if t in uint_dtypes:\n        signed_type_name = t.lstrip(\"u\")  # e.g. \"uint16\" -> \"int16\"\n        x_signed = x.astype(getattr(np, signed_type_name))\n        return reinterpret(\n            torch.tensor(x_signed, device=device).contiguous(), getattr(tl, t)\n        )\n    else:\n        if dst_type and \"float8\" in dst_type:\n            return reinterpret(\n                torch.tensor(x, device=device).contiguous(), getattr(tl, dst_type)\n            )\n        if t == \"float32\" and dst_type == \"bfloat16\":\n            return torch.tensor(x, device=device).contiguous().bfloat16()\n        return torch.tensor(x, device=device).contiguous()\n\n\ndef to_numpy(x):\n    if isinstance(x, TensorWrapper):\n        # FIXME: torch_dtype_name doesn't exist\n        return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))\n    elif isinstance(x, torch.Tensor):\n        if x.dtype is torch.bfloat16:\n            return x.cpu().float().numpy()\n        return x.cpu().numpy()\n    else:\n        raise ValueError(f\"Not a triton-compatible tensor: {x}\")\n\n\nif __name__ == \"__main__\":\n    use_gpu = True\n\n    if use_gpu:\n        device = torch.device(\"cuda:0\")\n    else:\n        device = None\n\n    triton_times = []\n    loop_times = []\n    loop_comp_times = []\n    jax_compat_times = 
[]\n\n    print(\"Initializing\")\n    op = \"cumsum\"\n    num_warps = 16\n\n    dim = 1\n    seq_len = 2048\n    batch = 4\n\n    dtype_str = \"float32\"\n    axis = 0\n    shape = (batch, seq_len, dim)\n    n_timings = 10000\n\n    x = np.random.rand(*shape).astype(dtype=np.float32)\n    inp = torch.tensor(x, device=device, requires_grad=True, dtype=torch.float32)\n    init = torch.zeros(shape[1], 1, device=device, requires_grad=True)\n    inp_scan = inp\n\n    @triton.jit\n    def sum_op(a, b):\n        return a + b\n\n    @triton.jit\n    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):\n        range_m = tl.arange(0, BLOCK_M)\n        range_n = tl.arange(0, BLOCK_N)\n        x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])\n        # tl.device_print(\"z\", x)\n        z = tl.associative_scan(x, 0, sum_op)\n        # tl.device_print(\"z\", z)\n        tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)\n\n    print(\"Triton\")\n    z = np.empty_like(x)\n    x_tri = to_triton(x, device=device)\n    numpy_op = np.cumsum\n    z_dtype_str = dtype_str\n    z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))\n    # triton result\n    z_tri = to_triton(z, device=device)\n    val = kernel[(1,)](\n        x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps\n    )\n    out_triton = to_numpy(z_tri)\n\n    for _ in range(n_timings):\n        # print('.', end='', flush=True)\n        start = time.monotonic_ns()\n        kernel[(1,)](\n            x_tri,\n            z_tri,\n            BLOCK_M=shape[0],\n            BLOCK_N=shape[1],\n            AXIS=axis,\n            num_warps=num_warps,\n        )\n        torch.cuda.synchronize()  # Kernel launches are asynchronous; wait before stopping the timer\n        stop = time.monotonic_ns()\n        triton_times.append((stop - start) / (10**9))\n\n    print(\"\\nFake scan\")\n\n    def f(carry, x):\n        return carry + x, carry + x\n\n    def _fake_scan(f, init, x):\n        zs = []\n        carry = init\n        for xp in 
x:\n            carry, out = f(carry, xp)\n            zs.append(out)\n        return carry, torch.stack(zs)\n\n    expected_carry_out, expected_ys = _fake_scan(f, init, inp_scan)\n\n    for _ in range(n_timings):\n        # print('.', end='', flush=True)\n        start = time.monotonic_ns()\n        expected_carry_out, expected_ys = _fake_scan(f, init, inp_scan)\n        torch.cuda.synchronize()  # Wait for the queued GPU ops before stopping the timer\n        stop = time.monotonic_ns()\n        loop_times.append((stop - start) / (10**9))\n\n    # _fake_scan_comp = torch.compile(_fake_scan, mode='reduce-overhead', fullgraph=True, dynamic=False)\n\n    # # Warm-up cycles\n    # print(\"\\nFake scan-compiled\")\n    # for _ in range(5):\n    #     expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan)\n\n    # for _ in range(n_timings):\n    #     print('.', end='', flush=True)\n    #     start = time.monotonic_ns()\n    #     expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan)\n    #     stop = time.monotonic_ns()\n    #     loop_comp_times.append((stop - start) / (10 ** 9))\n\n    def sum_op2(a, b):\n        return a + b, a + b\n\n    # Warm-up\n    print(\"\\njax_compat\")\n    for _ in range(5):\n        expected_ys_comp = associative_scan(sum_op2, inp_scan, axis=-1)\n\n    for _ in range(n_timings):\n        # print('.', end='', flush=True)\n        start = time.monotonic_ns()\n        expected_ys_comp = associative_scan(sum_op2, inp_scan, axis=-1)\n        torch.cuda.synchronize()  # Wait for the queued GPU ops before stopping the timer\n        stop = time.monotonic_ns()\n        jax_compat_times.append((stop - start) / (10**9))\n\n    print()\n    print(\"Times regular loop \" + str(np.array(loop_times).mean()))\n    # print('Times compiled loop ' + str(np.array(loop_comp_times).mean()))\n    print(\"Times triton \" + str(np.array(triton_times).mean()))\n    print(\"Times jax_compat \" + str(np.array(jax_compat_times).mean()))\n    print(\"Script ended\")\n"
  },
  {
    "path": "RVT/modules/__init__.py",
    "content": ""
  },
  {
    "path": "RVT/modules/data/genx.py",
    "content": "from functools import partial\nfrom typing import Any, Dict, Optional, Union\n\nimport math\nimport lightning.pytorch as pl\nfrom omegaconf import DictConfig\nfrom torch.utils.data import DataLoader, Dataset\n\nfrom data.genx_utils.collate import custom_collate_rnd, custom_collate_streaming\nfrom data.genx_utils.dataset_rnd import (\n    build_random_access_dataset,\n    get_weighted_random_sampler,\n    CustomConcatDataset,\n)\nfrom data.genx_utils.dataset_streaming import build_streaming_dataset\nfrom data.utils.spatial import get_dataloading_hw\nfrom data.utils.types import DatasetMode, DatasetSamplingMode\n\n\ndef get_dataloader_kwargs(\n    dataset: Union[Dataset, CustomConcatDataset],\n    sampling_mode: DatasetSamplingMode,\n    dataset_mode: DatasetMode,\n    dataset_config: DictConfig,\n    batch_size: int,\n    num_workers: int,\n) -> Dict[str, Any]:\n    if dataset_mode == DatasetMode.TRAIN:\n        if sampling_mode == DatasetSamplingMode.STREAM:\n            return dict(\n                dataset=dataset,\n                batch_size=None,\n                shuffle=False,  # Done already in the streaming datapipe\n                num_workers=num_workers,\n                pin_memory=False,\n                drop_last=False,  # Cannot be done with streaming datapipes\n                collate_fn=custom_collate_streaming,\n            )\n        if sampling_mode == DatasetSamplingMode.RANDOM:\n            use_weighted_rnd_sampling = dataset_config.train.random.weighted_sampling\n            sampler = (\n                get_weighted_random_sampler(dataset)\n                if use_weighted_rnd_sampling\n                else None\n            )\n            return dict(\n                dataset=dataset,\n                batch_size=batch_size,\n                shuffle=sampler is None,\n                sampler=sampler,\n                num_workers=num_workers,\n                pin_memory=False,\n                drop_last=True,  # Maintain the same 
batch size for logging\n                collate_fn=custom_collate_rnd,\n            )\n        raise NotImplementedError\n    elif dataset_mode in (DatasetMode.VALIDATION, DatasetMode.TESTING):\n        if sampling_mode == DatasetSamplingMode.STREAM:\n            return dict(\n                dataset=dataset,\n                batch_size=None,\n                shuffle=False,\n                num_workers=num_workers,\n                pin_memory=False,\n                drop_last=False,  # Cannot be done with streaming datapipes\n                collate_fn=custom_collate_streaming,\n            )\n        if sampling_mode == DatasetSamplingMode.RANDOM:\n            return dict(\n                dataset=dataset,\n                batch_size=batch_size,\n                shuffle=False,\n                num_workers=num_workers,\n                pin_memory=False,\n                drop_last=True,  # Maintain the same batch size for logging\n                collate_fn=custom_collate_rnd,\n            )\n        raise NotImplementedError\n    raise NotImplementedError\n\n\nclass DataModule(pl.LightningDataModule):\n    def __init__(\n        self,\n        dataset_config: DictConfig,\n        num_workers_train: int,\n        num_workers_eval: int,\n        batch_size_train: int,\n        batch_size_eval: int,\n    ):\n        super().__init__()\n        assert num_workers_train >= 0\n        assert num_workers_eval >= 0\n        assert batch_size_train >= 1\n        assert batch_size_eval >= 1\n\n        self.dataset_config = dataset_config\n        self.train_sampling_mode = dataset_config.train.sampling\n        self.eval_sampling_mode = dataset_config.eval.sampling\n\n        assert self.train_sampling_mode in iter(DatasetSamplingMode)\n        assert self.eval_sampling_mode in (\n            DatasetSamplingMode.STREAM,\n            DatasetSamplingMode.RANDOM,\n        )\n\n        # In DDP all configs are per process/GPU (num_workers, batch_size, ...).\n        
self.overall_batch_size_train = batch_size_train\n        self.overall_batch_size_eval = batch_size_eval\n        self.overall_num_workers_train = num_workers_train\n        self.overall_num_workers_eval = num_workers_eval\n\n        if self.eval_sampling_mode == DatasetSamplingMode.STREAM:\n            self.build_eval_dataset = partial(\n                build_streaming_dataset,\n                batch_size=self.overall_batch_size_eval,\n                num_workers=self.overall_num_workers_eval,\n            )\n        elif self.eval_sampling_mode == DatasetSamplingMode.RANDOM:\n            self.build_eval_dataset = build_random_access_dataset\n        else:\n            raise NotImplementedError\n\n        self.sampling_mode_2_dataset = dict()\n        self.sampling_mode_2_train_workers = dict()\n        self.sampling_mode_2_train_batch_size = dict()\n        self.validation_dataset = None\n        self.test_dataset = None\n\n    def get_dataloading_hw(self):\n        return get_dataloading_hw(dataset_config=self.dataset_config)\n\n    def set_mixed_sampling_mode_variables_for_train(self):\n        assert (\n            self.overall_batch_size_train >= 2\n        ), \"Cannot use mixed mode with batch size smaller than 2\"\n        assert (\n            self.overall_num_workers_train >= 2\n        ), \"Cannot use mixed mode with num workers smaller than 2\"\n        weight_random = self.dataset_config.train.mixed.w_random\n        weight_stream = self.dataset_config.train.mixed.w_stream\n        assert weight_random > 0\n        assert weight_stream > 0\n\n        # Set batch size according to weights.\n        bs_rnd = min(\n            round(\n                self.overall_batch_size_train\n                * weight_random\n                / (weight_stream + weight_random)\n            ),\n            self.overall_batch_size_train - 1,\n        )\n        bs_str = self.overall_batch_size_train - bs_rnd\n        
self.sampling_mode_2_train_batch_size[DatasetSamplingMode.RANDOM] = bs_rnd\n        self.sampling_mode_2_train_batch_size[DatasetSamplingMode.STREAM] = bs_str\n\n        # Set num workers according to batch size. Random sampling typically takes longer than stream sampling!\n        workers_rnd = min(\n            math.ceil(\n                self.overall_num_workers_train * bs_rnd / self.overall_batch_size_train\n            ),\n            self.overall_num_workers_train - 1,\n        )\n        workers_str = self.overall_num_workers_train - workers_rnd\n        self.sampling_mode_2_train_workers[DatasetSamplingMode.RANDOM] = workers_rnd\n        self.sampling_mode_2_train_workers[DatasetSamplingMode.STREAM] = workers_str\n\n        print(\n            f\"[Train] Local batch size for:\\nstream sampling:\\t{bs_str}\\nrandom sampling:\\t{bs_rnd}\\n\"\n            f\"[Train] Local num workers for:\\nstream sampling:\\t{workers_str}\\nrandom sampling:\\t{workers_rnd}\"\n        )\n\n    def setup(self, stage: Optional[str] = None) -> None:\n        if stage == \"fit\":\n            if self.train_sampling_mode == DatasetSamplingMode.MIXED:\n                self.set_mixed_sampling_mode_variables_for_train()\n            else:\n                self.sampling_mode_2_train_workers[self.train_sampling_mode] = (\n                    self.overall_num_workers_train\n                )\n                self.sampling_mode_2_train_batch_size[self.train_sampling_mode] = (\n                    self.overall_batch_size_train\n                )\n            # This code is a bit hacky because at this point we no longer use DatasetSamplingMode.MIXED,\n            # since we split it up into random and streaming. 
DatasetSamplingMode.MIXED was just used to determine\n            # whether we use both or not.\n            if self.train_sampling_mode in (\n                DatasetSamplingMode.RANDOM,\n                DatasetSamplingMode.MIXED,\n            ):\n                self.sampling_mode_2_dataset[DatasetSamplingMode.RANDOM] = (\n                    build_random_access_dataset(\n                        dataset_mode=DatasetMode.TRAIN,\n                        dataset_config=self.dataset_config,\n                    )\n                )\n            if self.train_sampling_mode in (\n                DatasetSamplingMode.STREAM,\n                DatasetSamplingMode.MIXED,\n            ):\n                self.sampling_mode_2_dataset[DatasetSamplingMode.STREAM] = (\n                    build_streaming_dataset(\n                        dataset_mode=DatasetMode.TRAIN,\n                        dataset_config=self.dataset_config,\n                        batch_size=self.sampling_mode_2_train_batch_size[\n                            DatasetSamplingMode.STREAM\n                        ],\n                        num_workers=self.sampling_mode_2_train_workers[\n                            DatasetSamplingMode.STREAM\n                        ],\n                    )\n                )\n\n            self.validation_dataset = self.build_eval_dataset(\n                dataset_mode=DatasetMode.VALIDATION, dataset_config=self.dataset_config\n            )\n        elif stage == \"validate\":\n            self.validation_dataset = self.build_eval_dataset(\n                dataset_mode=DatasetMode.VALIDATION, dataset_config=self.dataset_config\n            )\n        elif stage == \"test\":\n            self.test_dataset = self.build_eval_dataset(\n                dataset_mode=DatasetMode.TESTING, dataset_config=self.dataset_config\n            )\n        else:\n            raise NotImplementedError\n\n    def train_dataloader(self):\n        train_loaders = dict()\n        for 
sampling_mode, dataset in self.sampling_mode_2_dataset.items():\n            train_loaders[sampling_mode] = DataLoader(\n                **get_dataloader_kwargs(\n                    dataset=dataset,\n                    sampling_mode=sampling_mode,\n                    dataset_mode=DatasetMode.TRAIN,\n                    dataset_config=self.dataset_config,\n                    batch_size=self.sampling_mode_2_train_batch_size[sampling_mode],\n                    num_workers=self.sampling_mode_2_train_workers[sampling_mode],\n                )\n            )\n        if len(train_loaders) == 1:\n            train_loaders = next(iter(train_loaders.values()))\n            # Returns a single dataloader.\n            return train_loaders\n        assert len(train_loaders) == 2\n        # Returns a mapping from dataset sampling modes to dataloader.\n        return train_loaders\n\n    def val_dataloader(self):\n        return DataLoader(\n            **get_dataloader_kwargs(\n                dataset=self.validation_dataset,\n                sampling_mode=self.eval_sampling_mode,\n                dataset_mode=DatasetMode.VALIDATION,\n                dataset_config=self.dataset_config,\n                batch_size=self.overall_batch_size_eval,\n                num_workers=self.overall_num_workers_eval,\n            )\n        )\n\n    def test_dataloader(self):\n        return DataLoader(\n            **get_dataloader_kwargs(\n                dataset=self.test_dataset,\n                sampling_mode=self.eval_sampling_mode,\n                dataset_mode=DatasetMode.TESTING,\n                dataset_config=self.dataset_config,\n                batch_size=self.overall_batch_size_eval,\n                num_workers=self.overall_num_workers_eval,\n            )\n        )\n"
  },
  {
    "path": "RVT/modules/detection.py",
    "content": "from typing import Any, Optional, Tuple, Union, Dict\nfrom warnings import warn\n\nimport numpy as np\nimport lightning.pytorch as pl\nimport torch\nimport torch as th\nimport torch.distributed as dist\nfrom omegaconf import DictConfig\nfrom lightning.pytorch.utilities.types import STEP_OUTPUT\nfrom einops import rearrange\n\nfrom data.genx_utils.labels import ObjectLabels\nfrom data.utils.types import DataType, LstmStates, ObjDetOutput, DatasetSamplingMode\nfrom models.detection.yolox.utils.boxes import postprocess\nfrom models.detection.yolox_extension.models.detector import YoloXDetector\nfrom utils.evaluation.prophesee.evaluator import PropheseeEvaluator\nfrom utils.evaluation.prophesee.io.box_loading import to_prophesee\nfrom utils.padding import InputPadderFromShape\nfrom .utils.detection import (\n    BackboneFeatureSelector,\n    EventReprSelector,\n    RNNStates,\n    Mode,\n    mode_2_string,\n    merge_mixed_batches,\n)\n\n\nclass Module(pl.LightningModule):\n    def __init__(self, full_config: DictConfig):\n        super().__init__()\n\n        self.full_config = full_config\n\n        self.mdl_config = full_config.model\n        in_res_hw = tuple(self.mdl_config.backbone.in_res_hw)\n        self.input_padder = InputPadderFromShape(desired_hw=in_res_hw)\n\n        self.mdl = YoloXDetector(self.mdl_config)\n\n        self.mode_2_rnn_states: Dict[Mode, RNNStates] = {\n            Mode.TRAIN: RNNStates(),\n            Mode.VAL: RNNStates(),\n            Mode.TEST: RNNStates(),\n        }\n\n    def setup(self, stage: Optional[str] = None) -> None:\n        dataset_name = self.full_config.dataset.name\n        self.mode_2_hw: Dict[Mode, Optional[Tuple[int, int]]] = {}\n        self.mode_2_batch_size: Dict[Mode, Optional[int]] = {}\n        self.mode_2_psee_evaluator: Dict[Mode, Optional[PropheseeEvaluator]] = {}\n        self.mode_2_sampling_mode: Dict[Mode, DatasetSamplingMode] = {}\n\n        self.started_training = True\n\n        
dataset_train_sampling = self.full_config.dataset.train.sampling\n        dataset_eval_sampling = self.full_config.dataset.eval.sampling\n        assert dataset_train_sampling in iter(DatasetSamplingMode)\n        assert dataset_eval_sampling in (\n            DatasetSamplingMode.STREAM,\n            DatasetSamplingMode.RANDOM,\n        )\n        if stage == \"fit\":  # train + val\n            self.train_config = self.full_config.training\n            self.train_metrics_config = self.full_config.logging.train.metrics\n\n            if self.train_metrics_config.compute:\n                self.mode_2_psee_evaluator[Mode.TRAIN] = PropheseeEvaluator(\n                    dataset=dataset_name,\n                    downsample_by_2=self.full_config.dataset.downsample_by_factor_2,\n                )\n            self.mode_2_psee_evaluator[Mode.VAL] = PropheseeEvaluator(\n                dataset=dataset_name,\n                downsample_by_2=self.full_config.dataset.downsample_by_factor_2,\n            )\n            self.mode_2_sampling_mode[Mode.TRAIN] = dataset_train_sampling\n            self.mode_2_sampling_mode[Mode.VAL] = dataset_eval_sampling\n\n            for mode in (Mode.TRAIN, Mode.VAL):\n                self.mode_2_hw[mode] = None\n                self.mode_2_batch_size[mode] = None\n            self.started_training = False\n        elif stage == \"validate\":\n            mode = Mode.VAL\n            self.mode_2_psee_evaluator[mode] = PropheseeEvaluator(\n                dataset=dataset_name,\n                downsample_by_2=self.full_config.dataset.downsample_by_factor_2,\n            )\n            self.mode_2_sampling_mode[Mode.VAL] = dataset_eval_sampling\n            self.mode_2_hw[mode] = None\n            self.mode_2_batch_size[mode] = None\n        elif stage == \"test\":\n            mode = Mode.TEST\n            self.mode_2_psee_evaluator[mode] = PropheseeEvaluator(\n                dataset=dataset_name,\n                
downsample_by_2=self.full_config.dataset.downsample_by_factor_2,\n            )\n            self.mode_2_sampling_mode[Mode.TEST] = dataset_eval_sampling\n            self.mode_2_hw[mode] = None\n            self.mode_2_batch_size[mode] = None\n        else:\n            raise NotImplementedError\n\n    def forward(\n        self,\n        event_tensor: th.Tensor,\n        previous_states: Optional[LstmStates] = None,\n        retrieve_detections: bool = True,\n        targets=None,\n    ) -> Tuple[Union[th.Tensor, None], Union[Dict[str, th.Tensor], None], LstmStates]:\n        return self.mdl(\n            x=event_tensor,\n            previous_states=previous_states,\n            retrieve_detections=retrieve_detections,\n            targets=targets,\n        )\n\n    def get_worker_id_from_batch(self, batch: Any) -> int:\n        return batch[\"worker_id\"]\n\n    def get_data_from_batch(self, batch: Any):\n        return batch[\"data\"]\n\n    def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:\n        batch = merge_mixed_batches(batch)\n        data = self.get_data_from_batch(batch)\n        worker_id = self.get_worker_id_from_batch(batch)\n\n        mode = Mode.TRAIN\n        self.started_training = True\n        step = self.trainer.global_step\n        ev_tensor_sequence = data[DataType.EV_REPR]\n        sparse_obj_labels = data[DataType.OBJLABELS_SEQ]\n        is_first_sample = data[DataType.IS_FIRST_SAMPLE]\n        token_mask_sequence = data.get(DataType.TOKEN_MASK, None)\n\n        self.mode_2_rnn_states[mode].reset(\n            worker_id=worker_id, indices_or_bool_tensor=is_first_sample\n        )\n\n        sequence_len = len(ev_tensor_sequence)\n        assert sequence_len > 0\n        batch_size = len(sparse_obj_labels[0])\n        if self.mode_2_batch_size[mode] is None:\n            self.mode_2_batch_size[mode] = batch_size\n        else:\n            assert self.mode_2_batch_size[mode] == batch_size\n\n        prev_states = 
self.mode_2_rnn_states[mode].get_states(worker_id=worker_id)\n        backbone_feature_selector = BackboneFeatureSelector()\n        ev_repr_selector = EventReprSelector()\n        obj_labels = list()\n\n        ev_tensor_sequence = torch.stack(\n            ev_tensor_sequence\n        )  # shape: (sequence_len, batch_size, channels, height, width) = (L, B, C, H, W)\n        ev_tensor_sequence = ev_tensor_sequence.to(dtype=self.dtype)\n        ev_tensor_sequence = self.input_padder.pad_tensor_ev_repr(ev_tensor_sequence)\n        if token_mask_sequence is not None:\n            token_mask_sequence = torch.stack(token_mask_sequence)\n            token_mask_sequence = token_mask_sequence.to(dtype=self.dtype)\n            token_mask_sequence = self.input_padder.pad_token_mask(\n                token_mask=token_mask_sequence\n            )\n        else:\n            token_mask_sequence = None\n\n        if self.mode_2_hw[mode] is None:\n            self.mode_2_hw[mode] = tuple(ev_tensor_sequence.shape[-2:])\n        else:\n            assert self.mode_2_hw[mode] == ev_tensor_sequence.shape[-2:]\n\n        backbone_features, states = self.mdl.forward_backbone(\n            x=ev_tensor_sequence,\n            previous_states=prev_states,\n            token_mask=token_mask_sequence,\n            train_step=True,\n        )\n        prev_states = states\n\n        for tidx, curr_labels in enumerate(sparse_obj_labels):\n            (\n                current_labels,\n                valid_batch_indices,\n            ) = curr_labels.get_valid_labels_and_batch_indices()\n            # Store backbone features that correspond to the available labels.\n            if len(current_labels) > 0:\n                backbone_feature_selector.add_backbone_features(\n                    backbone_features={\n                        k: v[tidx] for k, v in backbone_features.items()\n                    },\n                    selected_indices=valid_batch_indices,\n                )\n          
      obj_labels.extend(current_labels)\n                ev_repr_selector.add_event_representations(\n                    event_representations=ev_tensor_sequence[tidx],\n                    selected_indices=valid_batch_indices,\n                )\n\n        self.mode_2_rnn_states[mode].save_states_and_detach(\n            worker_id=worker_id, states=prev_states\n        )\n        assert len(obj_labels) > 0\n        # Batch the backbone features and labels to parallelize the detection code.\n        selected_backbone_features = (\n            backbone_feature_selector.get_batched_backbone_features()\n        )\n        labels_yolox = ObjectLabels.get_labels_as_batched_tensor(\n            obj_label_list=obj_labels, format_=\"yolox\"\n        )\n        labels_yolox = labels_yolox.to(dtype=self.dtype)\n\n        predictions, losses = self.mdl.forward_detect(\n            backbone_features=selected_backbone_features, targets=labels_yolox\n        )\n\n        if self.mode_2_sampling_mode[mode] in (\n            DatasetSamplingMode.MIXED,\n            DatasetSamplingMode.RANDOM,\n        ):\n            # We only want to evaluate the last batch_size samples if we use random sampling (or mixed).\n            # This is because otherwise we would mostly evaluate the init phase of the sequence.\n            predictions = predictions[-batch_size:]\n            obj_labels = obj_labels[-batch_size:]\n\n        pred_processed = postprocess(\n            prediction=predictions,\n            num_classes=self.mdl_config.head.num_classes,\n            conf_thre=self.mdl_config.postprocess.confidence_threshold,\n            nms_thre=self.mdl_config.postprocess.nms_threshold,\n        )\n\n        loaded_labels_proph, yolox_preds_proph = to_prophesee(\n            obj_labels, pred_processed\n        )\n\n        assert losses is not None\n        assert \"loss\" in losses\n\n        # For visualization, we only use the last batch_size items.\n        output = {\n            
ObjDetOutput.LABELS_PROPH: loaded_labels_proph[-batch_size:],\n            ObjDetOutput.PRED_PROPH: yolox_preds_proph[-batch_size:],\n            ObjDetOutput.EV_REPR: ev_repr_selector.get_event_representations_as_list(\n                start_idx=-batch_size\n            ),\n            ObjDetOutput.SKIP_VIZ: False,\n            \"loss\": losses[\"loss\"],\n        }\n\n        # Logging\n        prefix = f\"{mode_2_string[mode]}/\"\n        log_dict = {f\"{prefix}{k}\": v for k, v in losses.items()}\n        self.log_dict(\n            log_dict, on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True\n        )\n\n        if mode in self.mode_2_psee_evaluator:\n            self.mode_2_psee_evaluator[mode].add_labels(loaded_labels_proph)\n            self.mode_2_psee_evaluator[mode].add_predictions(yolox_preds_proph)\n            if (\n                self.train_metrics_config.detection_metrics_every_n_steps is not None\n                and step > 0\n                and step % self.train_metrics_config.detection_metrics_every_n_steps\n                == 0\n            ):\n                self.run_psee_evaluator(mode=mode)\n\n        return output\n\n    def _val_test_step_impl(self, batch: Any, mode: Mode) -> Optional[STEP_OUTPUT]:\n        data = self.get_data_from_batch(batch)\n        worker_id = self.get_worker_id_from_batch(batch)\n\n        assert mode in (Mode.VAL, Mode.TEST)\n        ev_tensor_sequence = data[DataType.EV_REPR]\n        sparse_obj_labels = data[DataType.OBJLABELS_SEQ]\n        is_first_sample = data[DataType.IS_FIRST_SAMPLE]\n\n        self.mode_2_rnn_states[mode].reset(\n            worker_id=worker_id, indices_or_bool_tensor=is_first_sample\n        )\n\n        sequence_len = len(ev_tensor_sequence)\n        assert sequence_len > 0\n        batch_size = len(sparse_obj_labels[0])\n        if self.mode_2_batch_size[mode] is None:\n            self.mode_2_batch_size[mode] = batch_size\n        else:\n            assert 
self.mode_2_batch_size[mode] == batch_size\n\n        prev_states = self.mode_2_rnn_states[mode].get_states(worker_id=worker_id)\n        backbone_feature_selector = BackboneFeatureSelector()\n        ev_repr_selector = EventReprSelector()\n        obj_labels = list()\n\n        ev_tensor_sequence = torch.stack(\n            ev_tensor_sequence\n        )  # shape: (sequence_len, batch_size, channels, height, width) = (L, B, C, H, W)\n        ev_tensor_sequence = ev_tensor_sequence.to(dtype=self.dtype)\n        ev_tensor_sequence = self.input_padder.pad_tensor_ev_repr(ev_tensor_sequence)\n\n        if self.mode_2_hw[mode] is None:\n            self.mode_2_hw[mode] = tuple(ev_tensor_sequence.shape[-2:])\n        else:\n            assert self.mode_2_hw[mode] == ev_tensor_sequence.shape[-2:]\n\n        backbone_features, states = self.mdl.forward_backbone(\n            x=ev_tensor_sequence,\n            previous_states=prev_states,\n            train_step=False,\n        )\n\n        prev_states = states\n\n        for tidx in range(sequence_len):\n            collect_predictions = (tidx == sequence_len - 1) or (\n                self.mode_2_sampling_mode[mode] == DatasetSamplingMode.STREAM\n            )\n\n            if collect_predictions:\n                current_labels, valid_batch_indices = sparse_obj_labels[\n                    tidx\n                ].get_valid_labels_and_batch_indices()\n                # Store backbone features that correspond to the available labels.\n                if len(current_labels) > 0:\n                    backbone_feature_selector.add_backbone_features(\n                        backbone_features={\n                            k: v[tidx] for k, v in backbone_features.items()\n                        },\n                        selected_indices=valid_batch_indices,\n                    )\n\n                    obj_labels.extend(current_labels)\n                    ev_repr_selector.add_event_representations(\n                        
event_representations=ev_tensor_sequence[tidx],\n                        selected_indices=valid_batch_indices,\n                    )\n        self.mode_2_rnn_states[mode].save_states_and_detach(\n            worker_id=worker_id, states=prev_states\n        )\n        if len(obj_labels) == 0:\n            return {ObjDetOutput.SKIP_VIZ: True}\n        selected_backbone_features = (\n            backbone_feature_selector.get_batched_backbone_features()\n        )\n        predictions, _ = self.mdl.forward_detect(\n            backbone_features=selected_backbone_features\n        )\n\n        pred_processed = postprocess(\n            prediction=predictions,\n            num_classes=self.mdl_config.head.num_classes,\n            conf_thre=self.mdl_config.postprocess.confidence_threshold,\n            nms_thre=self.mdl_config.postprocess.nms_threshold,\n        )\n\n        loaded_labels_proph, yolox_preds_proph = to_prophesee(\n            obj_labels, pred_processed\n        )\n\n        # For visualization, we only use the last item (per batch).\n        output = {\n            ObjDetOutput.LABELS_PROPH: loaded_labels_proph[-1],\n            ObjDetOutput.PRED_PROPH: yolox_preds_proph[-1],\n            ObjDetOutput.EV_REPR: ev_repr_selector.get_event_representations_as_list(\n                start_idx=-1\n            )[0],\n            ObjDetOutput.SKIP_VIZ: False,\n        }\n\n        if self.started_training:\n            self.mode_2_psee_evaluator[mode].add_labels(loaded_labels_proph)\n            self.mode_2_psee_evaluator[mode].add_predictions(yolox_preds_proph)\n\n        return output\n\n    def validation_step(self, batch: Any, batch_idx: int) -> Optional[STEP_OUTPUT]:\n        return self._val_test_step_impl(batch=batch, mode=Mode.VAL)\n\n    def test_step(self, batch: Any, batch_idx: int) -> Optional[STEP_OUTPUT]:\n        return self._val_test_step_impl(batch=batch, mode=Mode.TEST)\n\n    def run_psee_evaluator(self, mode: Mode):\n        psee_evaluator = 
self.mode_2_psee_evaluator[mode]\n        batch_size = self.mode_2_batch_size[mode]\n        hw_tuple = self.mode_2_hw[mode]\n        if psee_evaluator is None:\n            warn(f\"psee_evaluator is None in {mode=}\", UserWarning, stacklevel=2)\n            return\n        assert batch_size is not None\n        assert hw_tuple is not None\n        if psee_evaluator.has_data():\n            metrics = psee_evaluator.evaluate_buffer(\n                img_height=hw_tuple[0], img_width=hw_tuple[1]\n            )\n            assert metrics is not None\n\n            prefix = f\"{mode_2_string[mode]}/\"\n            step = self.trainer.global_step\n            log_dict = {}\n            for k, v in metrics.items():\n                if isinstance(v, (int, float)):\n                    value = torch.tensor(v)\n                elif isinstance(v, np.ndarray):\n                    value = torch.from_numpy(v)\n                elif isinstance(v, torch.Tensor):\n                    value = v\n                else:\n                    raise NotImplementedError\n                assert (\n                    value.ndim == 0\n                ), f\"tensor must be a scalar.\\n{v=}\\n{type(v)=}\\n{value=}\\n{type(value)=}\"\n                # put them on the current device to avoid this error: https://github.com/Lightning-AI/lightning/discussions/2529\n                log_dict[f\"{prefix}{k}\"] = value.to(self.device)\n            # Somehow self.log does not work when we eval during the training epoch.\n            self.log_dict(\n                log_dict,\n                on_step=False,\n                on_epoch=True,\n                batch_size=batch_size,\n                sync_dist=True,\n            )\n            if dist.is_available() and dist.is_initialized():\n                # We now have to manually sync (average the metrics) across processes in case of distributed training.\n                # NOTE: This is necessary to ensure that we have the same numbers for the 
checkpoint metric (metadata)\n                # and wandb metric:\n                # - checkpoint callback is using the self.log function which uses global sync (avg across ranks)\n                # - wandb uses log_metrics that we reduce manually to global rank 0\n                dist.barrier()\n                for k, v in log_dict.items():\n                    dist.reduce(log_dict[k], dst=0, op=dist.ReduceOp.SUM)\n                    if dist.get_rank() == 0:\n                        log_dict[k] /= dist.get_world_size()\n            if self.trainer.is_global_zero:\n                # For some reason we need to increase the step by 2 to enable consistent logging in wandb here.\n                # I might not understand wandb logging correctly. This works reasonably well for now.\n                add_hack = 2\n                self.logger.log_metrics(metrics=log_dict, step=step + add_hack)\n\n            psee_evaluator.reset_buffer()\n        else:\n            warn(f\"psee_evaluator has no data in {mode=}\", UserWarning, stacklevel=2)\n\n    def on_train_epoch_end(self) -> None:\n        mode = Mode.TRAIN\n        if (\n            mode in self.mode_2_psee_evaluator\n            and self.train_metrics_config.detection_metrics_every_n_steps is None\n            and self.mode_2_hw[mode] is not None\n        ):\n            # For some reason PL calls this function when resuming.\n            # At that point train_height_width is not known yet, so we skip the evaluation.\n            self.run_psee_evaluator(mode=mode)\n\n    def on_validation_epoch_end(self) -> None:\n        mode = Mode.VAL\n        if self.started_training:\n            assert self.mode_2_psee_evaluator[mode].has_data()\n            self.run_psee_evaluator(mode=mode)\n\n    def on_test_epoch_end(self) -> None:\n        mode = Mode.TEST\n        assert self.mode_2_psee_evaluator[mode].has_data()\n        self.run_psee_evaluator(mode=mode)\n\n    def configure_optimizers(self) -> Any:\n        lr = 
self.train_config.learning_rate\n        weight_decay = self.train_config.weight_decay\n        optimizer = th.optim.AdamW(\n            self.mdl.parameters(), lr=lr, weight_decay=weight_decay\n        )\n\n        scheduler_params = self.train_config.lr_scheduler\n        if not scheduler_params.use:\n            return optimizer\n\n        total_steps = scheduler_params.total_steps\n        assert total_steps is not None\n        assert total_steps > 0\n        # Here we interpret the final lr as max_lr/final_div_factor.\n        # Note that Pytorch OneCycleLR interprets it as initial_lr/final_div_factor:\n        final_div_factor_pytorch = (\n            scheduler_params.final_div_factor / scheduler_params.div_factor\n        )\n        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n            optimizer=optimizer,\n            max_lr=lr,\n            div_factor=scheduler_params.div_factor,\n            final_div_factor=final_div_factor_pytorch,\n            total_steps=total_steps,\n            pct_start=scheduler_params.pct_start,\n            cycle_momentum=False,\n            anneal_strategy=\"linear\",\n        )\n        lr_scheduler_config = {\n            \"scheduler\": lr_scheduler,\n            \"interval\": \"step\",\n            \"frequency\": 1,\n            \"strict\": True,\n            \"name\": \"learning_rate\",\n        }\n\n        return {\"optimizer\": optimizer, \"lr_scheduler\": lr_scheduler_config}\n"
  },
  {
    "path": "RVT/modules/utils/detection.py",
    "content": "from enum import Enum, auto\nfrom typing import List, Optional, Union, Tuple, Dict, Any\n\nimport torch\nimport torch as th\n\nfrom data.genx_utils.labels import SparselyBatchedObjectLabels\nfrom data.utils.types import BackboneFeatures, LstmStates, DatasetSamplingMode\n\n\nclass Mode(Enum):\n    TRAIN = auto()\n    VAL = auto()\n    TEST = auto()\n\n\nmode_2_string = {\n    Mode.TRAIN: \"train\",\n    Mode.VAL: \"val\",\n    Mode.TEST: \"test\",\n}\n\n\nclass BackboneFeatureSelector:\n    def __init__(self):\n        self.features = None\n        self.reset()\n\n    def reset(self):\n        self.features = dict()\n\n    def add_backbone_features(\n        self,\n        backbone_features: BackboneFeatures,\n        selected_indices: Optional[List[int]] = None,\n    ) -> None:\n        if selected_indices is not None:\n            assert len(selected_indices) > 0\n        for k, v in backbone_features.items():\n            if k not in self.features:\n                self.features[k] = (\n                    [v[selected_indices]] if selected_indices is not None else [v]\n                )\n            else:\n                self.features[k].append(\n                    v[selected_indices] if selected_indices is not None else v\n                )\n\n    def get_batched_backbone_features(self) -> Optional[BackboneFeatures]:\n        if len(self.features) == 0:\n            return None\n        return {k: th.cat(v, dim=0) for k, v in self.features.items()}\n\n\nclass EventReprSelector:\n    def __init__(self):\n        self.repr_list = None\n        self.reset()\n\n    def reset(self):\n        self.repr_list = list()\n\n    def __len__(self):\n        return len(self.repr_list)\n\n    def add_event_representations(\n        self,\n        event_representations: th.Tensor,\n        selected_indices: Optional[List[int]] = None,\n    ) -> None:\n        if selected_indices is not None:\n            assert len(selected_indices) > 0\n        
self.repr_list.extend(\n            x[0] for x in event_representations[selected_indices].split(1)\n        )\n\n    def get_event_representations_as_list(\n        self, start_idx: int = 0, end_idx: Optional[int] = None\n    ) -> Optional[List[th.Tensor]]:\n        if len(self) == 0:\n            return None\n        if end_idx is None:\n            end_idx = len(self)\n        assert start_idx < end_idx, f\"{start_idx=}, {end_idx=}\"\n        return self.repr_list[start_idx:end_idx]\n\n\nclass RNNStates:\n    def __init__(self):\n        self.states = {}\n\n    def _has_states(self):\n        return len(self.states) > 0\n\n    @classmethod\n    def recursive_detach(cls, inp: Union[th.Tensor, List, Tuple, Dict]):\n        if isinstance(inp, th.Tensor):\n            return inp.detach()\n        if isinstance(inp, list):\n            return [cls.recursive_detach(x) for x in inp]\n        if isinstance(inp, tuple):\n            return tuple(cls.recursive_detach(x) for x in inp)\n        if isinstance(inp, dict):\n            return {k: cls.recursive_detach(v) for k, v in inp.items()}\n        raise NotImplementedError\n\n    @classmethod\n    def recursive_reset(\n        cls,\n        inp: Union[th.Tensor, List, Tuple, Dict],\n        indices_or_bool_tensor: Optional[Union[List[int], torch.Tensor]] = None,\n    ):\n        if isinstance(inp, th.Tensor):\n            assert (\n                inp.requires_grad is False\n            ), \"Not assumed here but should be the case.\"\n            if indices_or_bool_tensor is None:\n                inp[:] = 0\n            else:\n                assert len(indices_or_bool_tensor) > 0\n                inp[indices_or_bool_tensor] = 0\n            return inp\n        if isinstance(inp, list):\n            return [\n                cls.recursive_reset(x, indices_or_bool_tensor=indices_or_bool_tensor)\n                for x in inp\n            ]\n        if isinstance(inp, tuple):\n            return tuple(\n                
cls.recursive_reset(x, indices_or_bool_tensor=indices_or_bool_tensor)\n                for x in inp\n            )\n        if isinstance(inp, dict):\n            return {\n                k: cls.recursive_reset(v, indices_or_bool_tensor=indices_or_bool_tensor)\n                for k, v in inp.items()\n            }\n        raise NotImplementedError\n\n    def save_states_and_detach(self, worker_id: int, states: LstmStates) -> None:\n        self.states[worker_id] = self.recursive_detach(states)\n\n    def get_states(self, worker_id: int) -> Optional[LstmStates]:\n        if not self._has_states():\n            return None\n        if worker_id not in self.states:\n            return None\n        return self.states[worker_id]\n\n    def reset(\n        self,\n        worker_id: int,\n        indices_or_bool_tensor: Optional[Union[List[int], torch.Tensor]] = None,\n    ):\n        if not self._has_states():\n            return\n        if worker_id in self.states:\n            self.states[worker_id] = self.recursive_reset(\n                self.states[worker_id], indices_or_bool_tensor=indices_or_bool_tensor\n            )\n\n\ndef mixed_collate_fn(\n    x1: Union[th.Tensor, List[th.Tensor]], x2: Union[th.Tensor, List[th.Tensor]]\n):\n    if isinstance(x1, th.Tensor):\n        assert isinstance(x2, th.Tensor)\n        return th.cat((x1, x2))\n    if isinstance(x1, SparselyBatchedObjectLabels):\n        assert isinstance(x2, SparselyBatchedObjectLabels)\n        return x1 + x2\n    if isinstance(x1, list):\n        assert isinstance(x2, list)\n        assert len(x1) == len(x2)\n        return [mixed_collate_fn(x1=el_1, x2=el_2) for el_1, el_2 in zip(x1, x2)]\n    raise NotImplementedError\n\n\ndef merge_mixed_batches(batch: Dict[str, Any]):\n    if \"data\" in batch:\n        return batch\n    rnd_data = batch[DatasetSamplingMode.RANDOM][\"data\"]\n    stream_batch = batch[DatasetSamplingMode.STREAM]\n    # We only care about the worker id of the streaming 
dataloader because the states will be reset anyway for the\n    # random dataloader batch.\n    out = {\"worker_id\": stream_batch[\"worker_id\"]}\n    stream_data = stream_batch[\"data\"]\n    assert (\n        rnd_data.keys() == stream_data.keys()\n    ), f\"{rnd_data.keys()=}, {stream_data.keys()=}\"\n    data_out = dict()\n    for key in rnd_data.keys():\n        data_out[key] = mixed_collate_fn(stream_data[key], rnd_data[key])\n    out.update({\"data\": data_out})\n    return out\n"
  },
  {
    "path": "RVT/modules/utils/fetch.py",
    "content": "import lightning.pytorch as pl\nfrom omegaconf import DictConfig\n\nfrom modules.data.genx import DataModule as genx_data_module\nfrom modules.detection import Module as rnn_det_module\n\n\ndef fetch_model_module(config: DictConfig) -> pl.LightningModule:\n    model_str = config.model.name\n    if model_str == \"rnndet\":\n        return rnn_det_module(config)\n    raise NotImplementedError\n\n\ndef fetch_data_module(config: DictConfig) -> pl.LightningDataModule:\n    batch_size_train = config.batch_size.train\n    batch_size_eval = config.batch_size.eval\n    num_workers_generic = config.hardware.get(\"num_workers\", None)\n    num_workers_train = config.hardware.num_workers.get(\"train\", num_workers_generic)\n    num_workers_eval = config.hardware.num_workers.get(\"eval\", num_workers_generic)\n    dataset_str = config.dataset.name\n    if dataset_str in {\"gen1\", \"gen4\"}:\n        return genx_data_module(\n            config.dataset,\n            num_workers_train=num_workers_train,\n            num_workers_eval=num_workers_eval,\n            batch_size_train=batch_size_train,\n            batch_size_eval=batch_size_eval,\n        )\n    raise NotImplementedError\n"
  },
  {
    "path": "RVT/scripts/genx/README.md",
    "content": "# Pre-Processing the Original Dataset\n\n### 1. Download the data\n<table><tbody>\n<th valign=\"bottom\"></th>\n<th valign=\"bottom\">train</th>\n<th valign=\"bottom\">validation</th>\n<th valign=\"bottom\">test</th>\n<tr><td align=\"left\">1 Mpx</td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/datasets/gen4_tar/train.tar\">download</a></td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/datasets/gen4_tar/val.tar\">download</a></td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/datasets/gen4_tar/test.tar\">download</a></td>\n</tr>\n<tr><td align=\"left\">crc32</td>\n<td align=\"center\"><tt>d677488a</tt></td>\n<td align=\"center\"><tt>72f13c3e</tt></td>\n<td align=\"center\"><tt>643e61ef</tt></td>\n</tr>\n<tr><td align=\"left\">Gen1</td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/datasets/gen1_tar/train.tar\">download</a></td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/datasets/gen1_tar/val.tar\">download</a></td>\n<td align=\"center\"><a href=\"https://download.ifi.uzh.ch/rpg/RVT/datasets/gen1_tar/test.tar\">download</a></td>\n</tr>\n<tr><td align=\"left\">crc32</td>\n<td align=\"center\"><tt>3d23bd30</tt></td>\n<td align=\"center\"><tt>cc802022</tt></td>\n<td align=\"center\"><tt>cdd4fd69</tt></td>\n</tr>\n</tbody></table>\n\n### 2. Extract the tar files\nThe following directory structure is assumed:\n\n```\ndata_dir\n├── test\n│   ├── ..._bbox.npy\n│   ├── ..._td.dat.h5\n│   ...\n│\n├── train\n│   ├── ....npy\n│   ├── ..._td.dat.h5\n│   ...\n│\n└── val\n    ├── ..._bbox.npy\n    ├── ..._td.dat.h5\n    ... \n```\n\n### 3. 
Run the pre-processing script\n`${DATA_DIR}` should point to the directory structure mentioned above.\n`${DEST_DIR}` should point to the directory to which the data will be written.\n\nFor the 1 Mpx dataset:\n```Bash\nNUM_PROCESSES=20  # set to the number of parallel processes to use\npython preprocess_dataset.py ${DATA_DIR} ${DEST_DIR} conf_preprocess/representation/stacked_hist.yaml \\\nconf_preprocess/extraction/const_duration.yaml conf_preprocess/filter_gen4.yaml -ds gen4 -np ${NUM_PROCESSES}\n```\n\nFor the Gen1 dataset:\n```Bash\nNUM_PROCESSES=20  # set to the number of parallel processes to use\npython preprocess_dataset.py ${DATA_DIR} ${DEST_DIR} conf_preprocess/representation/stacked_hist.yaml \\\nconf_preprocess/extraction/const_duration.yaml conf_preprocess/filter_gen1.yaml -ds gen1 -np ${NUM_PROCESSES}\n```\n"
  },
  {
    "path": "RVT/scripts/genx/conf_preprocess/extraction/const_count.yaml",
    "content": "method: COUNT\nvalue: 50000"
  },
  {
    "path": "RVT/scripts/genx/conf_preprocess/extraction/const_duration.yaml",
    "content": "method: DURATION\n# value is in milliseconds!\nvalue: 50\n"
  },
  {
    "path": "RVT/scripts/genx/conf_preprocess/extraction/frequencies/const_duration_100hz.yaml",
    "content": "method: DURATION\n# value is in milliseconds!\nvalue: 10\n"
  },
  {
    "path": "RVT/scripts/genx/conf_preprocess/extraction/frequencies/const_duration_200hz.yaml",
    "content": "method: DURATION\n# value is in milliseconds!\nvalue: 5\n"
  },
  {
    "path": "RVT/scripts/genx/conf_preprocess/extraction/frequencies/const_duration_40hz.yaml",
    "content": "method: DURATION\n# value is in milliseconds!\nvalue: 25\n"
  },
  {
    "path": "RVT/scripts/genx/conf_preprocess/extraction/frequencies/const_duration_80hz.yaml",
    "content": "method: DURATION\n# value is in milliseconds!\nvalue: 12\n"
  },
  {
    "path": "RVT/scripts/genx/conf_preprocess/filter_gen1.yaml",
    "content": "apply_psee_bbox_filter: True\napply_faulty_bbox_filter: True"
  },
  {
    "path": "RVT/scripts/genx/conf_preprocess/filter_gen4.yaml",
    "content": "apply_psee_bbox_filter: False\napply_faulty_bbox_filter: True"
  },
  {
    "path": "RVT/scripts/genx/conf_preprocess/representation/mixeddensity_stack.yaml",
    "content": "name: \"mixeddensity_stack\"\nnbins: 10\ncount_cutoff: 32\n"
  },
  {
    "path": "RVT/scripts/genx/conf_preprocess/representation/stacked_hist.yaml",
    "content": "name: \"stacked_histogram\"\nnbins: 10\ncount_cutoff: 10\n"
  },
  {
    "path": "RVT/scripts/genx/preprocess_dataset.py",
    "content": "import os\n\nos.environ[\"OMP_NUM_THREADS\"] = \"1\"\nos.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"\nos.environ[\"MKL_NUM_THREADS\"] = \"1\"\nos.environ[\"VECLIB_MAXIMUM_THREADS\"] = \"1\"\nos.environ[\"NUMEXPR_NUM_THREADS\"] = \"1\"\n\nfrom abc import ABC, abstractmethod\nimport argparse\nfrom dataclasses import dataclass, field\nfrom enum import Enum, auto\nfrom functools import partial\nfrom multiprocessing import get_context\nfrom pathlib import Path\nimport shutil\nimport sys\n\nsys.path.append(\"../..\")\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nimport weakref\n\nimport h5py\nimport hdf5plugin\nfrom numba import jit\nimport numpy as np\nfrom omegaconf import OmegaConf, DictConfig, MISSING\nimport torch\nfrom tqdm import tqdm\n\nfrom utils.preprocessing import _blosc_opts\nfrom data.utils.representations import (\n    MixedDensityEventStack,\n    StackedHistogram,\n    RepresentationBase,\n)\n\n\nclass DataKeys(Enum):\n    InNPY = auto()\n    InH5 = auto()\n    OutLabelDir = auto()\n    OutEvReprDir = auto()\n    SplitType = auto()\n\n\nclass SplitType(Enum):\n    TRAIN = auto()\n    VAL = auto()\n    TEST = auto()\n\n\nsplit_name_2_type = {\n    \"train\": SplitType.TRAIN,\n    \"val\": SplitType.VAL,\n    \"test\": SplitType.TEST,\n}\n\ndataset_2_height = {\"gen1\": 240, \"gen4\": 720}\ndataset_2_width = {\"gen1\": 304, \"gen4\": 1280}\n\n# The following sequences would be discarded because all the labels would be removed after filtering:\ndirs_to_ignore = {\n    \"gen1\": (\n        \"17-04-06_09-57-37_6344500000_6404500000\",\n        \"17-04-13_19-17-27_976500000_1036500000\",\n        \"17-04-06_15-14-36_1159500000_1219500000\",\n        \"17-04-11_15-13-23_122500000_182500000\",\n    ),\n    \"gen4\": (),\n}\n\n\nclass NoLabelsException(Exception):\n    # Raised when no labels are present anymore in the sequence after filtering\n    ...\n\n\nclass H5Writer:\n    def __init__(\n        self, outfile: Path, key: str, 
ev_repr_shape: Tuple, numpy_dtype: np.dtype\n    ):\n        assert len(ev_repr_shape) == 3\n        self.h5f = h5py.File(str(outfile), \"w\")\n        # Sets a finalizer that ensures the file gets closed when the object is garbage collected\n        self._finalizer = weakref.finalize(self, self.close_callback, self.h5f)\n        self.key = key  # The dataset name/key inside the HDF5 file\n        self.numpy_dtype = numpy_dtype\n\n        # create the resizable hdf5 dataset\n        maxshape = (None,) + ev_repr_shape\n        chunkshape = (1,) + ev_repr_shape\n        self.maxshape = maxshape\n        self.h5f.create_dataset(\n            key,\n            dtype=self.numpy_dtype.name,\n            shape=chunkshape,\n            chunks=chunkshape,\n            maxshape=maxshape,\n            **_blosc_opts(complevel=1, shuffle=\"byte\"),\n        )\n        self.t_idx = 0\n\n    # __enter__ and __exit__ allow the class to be used as a context manager\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        self._finalizer()\n\n    @staticmethod\n    def close_callback(h5f: h5py.File):\n        h5f.close()\n\n    def close(self):\n        self.h5f.close()\n\n    def get_current_length(self):\n        return self.t_idx\n\n    def add_data(self, data: np.ndarray):\n        # append new data to the already initialized HDF5 dataset\n        assert data.dtype == self.numpy_dtype, f\"{data.dtype=}, {self.numpy_dtype=}\"\n        assert data.shape == self.maxshape[1:]\n        new_size = self.t_idx + 1\n        self.h5f[self.key].resize(new_size, axis=0)\n        self.h5f[self.key][\n            self.t_idx : new_size\n        ] = data  # write the new data to the last slot along the first dimension\n        self.t_idx = new_size  # advance the internal index (self.t_idx) to the next empty slot in the dataset\n\n\nclass H5Reader:\n    def __init__(self, h5_file: Path, dataset: str = \"gen4\"):\n        assert h5_file.exists()\n  
      assert h5_file.suffix == \".h5\"\n        assert dataset in {\"gen1\", \"gen4\"}\n\n        self.h5f = h5py.File(str(h5_file), \"r\")\n        self._finalizer = weakref.finalize(self, self._close_callback, self.h5f)\n        self.is_open = True\n\n        try:\n            self.height = self.h5f[\"events\"][\"height\"][()].item()\n            self.width = self.h5f[\"events\"][\"width\"][()].item()\n        except KeyError:\n            self.height = dataset_2_height[dataset]\n            self.width = dataset_2_width[dataset]\n\n        self.all_times = None\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        self._finalizer()\n\n    @staticmethod\n    def _close_callback(h5f: h5py.File):\n        h5f.close()\n\n    def close(self):\n        self.h5f.close()\n        self.is_open = False\n\n    def get_height_and_width(self) -> Tuple[int, int]:\n        return self.height, self.width\n\n    @property\n    def time(self) -> np.ndarray:\n        # We lazy-load the timestamps because they are typically not sorted everywhere.\n        # - Correct the event timestamps such that they are non-decreasing.\n        assert self.is_open\n        if self.all_times is None:\n            self.all_times = np.asarray(self.h5f[\"events\"][\"t\"])\n            self._correct_time(self.all_times)\n        return self.all_times\n\n    @staticmethod\n    @jit(nopython=True)\n    def _correct_time(time_array: np.ndarray):\n        assert time_array[0] >= 0\n        time_last = 0\n        for idx, time in enumerate(time_array):\n            if time < time_last:\n                time_array[idx] = time_last\n            else:\n                time_last = time\n\n    def get_event_slice(\n        self, idx_start: int, idx_end: int, convert_2_torch: bool = True\n    ):\n        assert self.is_open\n        assert idx_end >= idx_start\n        ev_data = self.h5f[\"events\"]\n        x_array = 
np.asarray(ev_data[\"x\"][idx_start:idx_end], 
dtype=\"int64\")\n        y_array = np.asarray(ev_data[\"y\"][idx_start:idx_end], dtype=\"int64\")\n        p_array = np.asarray(ev_data[\"p\"][idx_start:idx_end], dtype=\"int64\")\n        p_array = np.clip(p_array, a_min=0, a_max=None)\n        t_array = np.asarray(self.time[idx_start:idx_end], dtype=\"int64\")\n        assert np.all(t_array[:-1] <= t_array[1:])\n        ev_data = dict(\n            x=x_array if not convert_2_torch else torch.from_numpy(x_array),\n            y=y_array if not convert_2_torch else torch.from_numpy(y_array),\n            p=p_array if not convert_2_torch else torch.from_numpy(p_array),\n            t=t_array if not convert_2_torch else torch.from_numpy(t_array),\n            height=self.height,\n            width=self.width,\n        )\n        return ev_data\n\n\ndef prophesee_bbox_filter(labels: np.ndarray, dataset_type: str) -> np.ndarray:\n    assert dataset_type in {\"gen1\", \"gen4\"}\n\n    # Default values taken from: https://github.com/prophesee-ai/prophesee-automotive-dataset-toolbox/blob/0393adea2bf22d833893c8cb1d986fcbe4e6f82d/src/psee_evaluator.py#L23-L24\n    min_box_diag = 60 if dataset_type == \"gen4\" else 30\n    # Corrected values for min_box_side, taken from the supplementary material of the paper!\n    min_box_side = 20 if dataset_type == \"gen4\" else 10\n\n    w_lbl = labels[\"w\"]\n    h_lbl = labels[\"h\"]\n\n    diag_ok = w_lbl**2 + h_lbl**2 >= min_box_diag**2\n    side_ok = (w_lbl >= min_box_side) & (h_lbl >= min_box_side)\n    keep = diag_ok & side_ok\n    labels = labels[keep]\n    return labels\n\n\ndef conservative_bbox_filter(labels: np.ndarray) -> np.ndarray:\n    w_lbl = labels[\"w\"]\n    h_lbl = labels[\"h\"]\n    min_box_side = 5\n    side_ok = (w_lbl >= min_box_side) & (h_lbl >= min_box_side)\n    labels = labels[side_ok]\n    return labels\n\n\ndef remove_faulty_huge_bbox_filter(labels: np.ndarray, dataset_type: str) -> np.ndarray:\n    \"\"\"There are some labels which span the frame 
horizontally without 
actually covering an object.\"\"\"\n    assert dataset_type in {\"gen1\", \"gen4\"}\n    w_lbl = labels[\"w\"]\n    max_width = (9 * dataset_2_width[dataset_type]) // 10\n    side_ok = w_lbl <= max_width\n    labels = labels[side_ok]\n    return labels\n\n\ndef crop_to_fov_filter(labels: np.ndarray, dataset_type: str) -> np.ndarray:\n    assert dataset_type in {\"gen1\", \"gen4\"}, f\"{dataset_type=}\"\n    # In the gen1 and gen4 datasets the bounding box can be partially or completely outside the frame.\n    # We fix this labeling error by cropping to the FOV.\n    frame_height = dataset_2_height[dataset_type]\n    frame_width = dataset_2_width[dataset_type]\n    x_left = labels[\"x\"]\n    y_top = labels[\"y\"]\n    x_right = x_left + labels[\"w\"]\n    y_bottom = y_top + labels[\"h\"]\n    x_left_cropped = np.clip(x_left, a_min=0, a_max=frame_width - 1)\n    y_top_cropped = np.clip(y_top, a_min=0, a_max=frame_height - 1)\n    x_right_cropped = np.clip(x_right, a_min=0, a_max=frame_width - 1)\n    y_bottom_cropped = np.clip(y_bottom, a_min=0, a_max=frame_height - 1)\n\n    w_cropped = x_right_cropped - x_left_cropped\n    assert np.all(w_cropped >= 0)\n    h_cropped = y_bottom_cropped - y_top_cropped\n    assert np.all(h_cropped >= 0)\n\n    labels[\"x\"] = x_left_cropped\n    labels[\"y\"] = y_top_cropped\n    labels[\"w\"] = w_cropped\n    labels[\"h\"] = h_cropped\n\n    # Remove bboxes that have 0 height or width\n    keep = (labels[\"w\"] > 0) & (labels[\"h\"] > 0)\n    labels = labels[keep]\n    return labels\n\n\ndef prophesee_remove_labels_filter_gen4(labels: np.ndarray) -> np.ndarray:\n    # Original gen4 labels: pedestrian, two wheeler, car, truck, bus, traffic sign, traffic light\n    # gen4 labels to keep: pedestrian, two wheeler, car\n    # gen4 labels to remove: truck, bus, traffic sign, traffic light\n    #\n    # class_id in {0, 1, 2, 3, 4, 5, 6} in the order mentioned above\n    keep = labels[\"class_id\"] <= 2\n    labels = labels[keep]\n    
return labels\n\n\ndef apply_filters(\n    labels: np.ndarray,\n    split_type: SplitType,\n    filter_cfg: DictConfig,\n    dataset_type: str = \"gen1\",\n) -> np.ndarray:\n    assert isinstance(dataset_type, str)\n    if dataset_type == \"gen4\":\n        labels = prophesee_remove_labels_filter_gen4(labels=labels)\n    labels = crop_to_fov_filter(labels=labels, dataset_type=dataset_type)\n    if filter_cfg.apply_psee_bbox_filter:\n        labels = prophesee_bbox_filter(labels=labels, dataset_type=dataset_type)\n    else:\n        labels = conservative_bbox_filter(labels=labels)\n    if split_type == SplitType.TRAIN and filter_cfg.apply_faulty_bbox_filter:\n        labels = remove_faulty_huge_bbox_filter(\n            labels=labels, dataset_type=dataset_type\n        )\n    return labels\n\n\ndef get_base_delta_ts_for_labels_us(\n    unique_label_ts_us: np.ndarray, dataset_type: str = \"gen1\"\n) -> int:\n    if dataset_type == \"gen1\":\n        delta_t_us_4hz = 250000\n        return delta_t_us_4hz\n    assert dataset_type == \"gen4\"\n    diff_us = np.diff(unique_label_ts_us)\n    median_diff_us = np.median(diff_us)\n\n    hz = int(np.rint(10**6 / median_diff_us))\n    assert hz in {30, 60}, f\"{hz=} but should be either 30 or 60\"\n\n    delta_t_us_approx_10hz = int(6 * median_diff_us if hz == 60 else 3 * median_diff_us)\n    return delta_t_us_approx_10hz\n\n\ndef save_labels(\n    out_labels_dir: Path,\n    labels_per_frame: List[np.ndarray],\n    frame_timestamps_us: np.ndarray,\n    match_if_exists: bool = True,\n) -> None:\n    assert len(labels_per_frame) == len(frame_timestamps_us)\n    assert len(labels_per_frame) > 0\n    labels_v2 = list()\n    objframe_idx_2_label_idx = list()\n    start_idx = 0\n    for labels, timestamp in zip(labels_per_frame, frame_timestamps_us):\n        objframe_idx_2_label_idx.append(start_idx)\n        labels_v2.append(labels)\n        start_idx += len(labels)\n    assert len(labels_v2) == len(objframe_idx_2_label_idx)\n    
labels_v2 = np.concatenate(labels_v2)\n\n    outfile_labels = out_labels_dir / \"labels.npz\"\n    if outfile_labels.exists() and match_if_exists:\n        data_existing = np.load(str(outfile_labels))\n        labels_existing = data_existing[\"labels\"]\n        assert np.array_equal(labels_existing, labels_v2)\n        oi_2_li_existing = data_existing[\"objframe_idx_2_label_idx\"]\n        assert np.array_equal(oi_2_li_existing, objframe_idx_2_label_idx)\n    else:\n        np.savez(\n            str(outfile_labels),\n            labels=labels_v2,\n            objframe_idx_2_label_idx=objframe_idx_2_label_idx,\n        )\n\n    out_labels_ts_file = out_labels_dir / \"timestamps_us.npy\"\n    if out_labels_ts_file.exists() and match_if_exists:\n        frame_timestamps_us_existing = np.load(str(out_labels_ts_file))\n        assert np.array_equal(frame_timestamps_us_existing, frame_timestamps_us)\n    else:\n        np.save(str(out_labels_ts_file), frame_timestamps_us)\n\n\ndef labels_and_ev_repr_timestamps(\n    npy_file: Path,\n    split_type: SplitType,\n    filter_cfg: DictConfig,\n    align_t_ms: int,\n    ts_step_ev_repr_ms: int,\n    dataset_type: str,\n):\n    assert npy_file.exists()\n    assert npy_file.suffix == \".npy\"\n    ts_step_frame_ms = 100\n    assert ts_step_frame_ms >= ts_step_ev_repr_ms\n    assert ts_step_frame_ms % ts_step_ev_repr_ms == 0 and ts_step_ev_repr_ms > 0\n\n    align_t_us = align_t_ms * 1000\n    delta_t_us = ts_step_ev_repr_ms * 1000\n\n    sequence_labels = np.load(str(npy_file))\n    assert len(sequence_labels) > 0\n\n    sequence_labels = apply_filters(\n        labels=sequence_labels,\n        split_type=split_type,\n        filter_cfg=filter_cfg,\n        dataset_type=dataset_type,\n    )\n    if sequence_labels.size == 0:\n        raise NoLabelsException\n\n    unique_ts_us = np.unique(np.asarray(sequence_labels[\"t\"], dtype=\"int64\"))\n\n    base_delta_ts_labels_us = get_base_delta_ts_for_labels_us(\n        
unique_label_ts_us=unique_ts_us, dataset_type=dataset_type\n    )\n\n    # We extract the first label at or after align_t_us to keep it as the reference for the label extraction.\n    unique_ts_idx_first = np.searchsorted(unique_ts_us, align_t_us, side=\"left\")\n\n    # Extract \"frame\" timestamps from labels and prepare ev repr ts computation\n    num_ev_reprs_between_frame_ts = []\n    frame_timestamps_us = [unique_ts_us[unique_ts_idx_first]]\n    for unique_ts_idx in range(unique_ts_idx_first + 1, len(unique_ts_us)):\n        reference_time = frame_timestamps_us[-1]\n        ts = unique_ts_us[unique_ts_idx]\n        diff_to_ref = ts - reference_time\n        base_delta_count = round(diff_to_ref / base_delta_ts_labels_us)\n        diff_to_ref_rounded = base_delta_count * base_delta_ts_labels_us\n        if np.abs(diff_to_ref - diff_to_ref_rounded) <= 2000:\n            assert base_delta_count > 0\n            # We accept up to 2 milliseconds of jitter\n            frame_timestamps_us.append(ts)\n            num_ev_reprs_between_frame_ts.append(\n                base_delta_count * (ts_step_frame_ms // ts_step_ev_repr_ms)\n            )\n    frame_timestamps_us = np.asarray(frame_timestamps_us, dtype=\"int64\")\n    assert len(frame_timestamps_us) > 0, f\"{npy_file=}\"\n\n    start_indices_per_label = np.searchsorted(\n        sequence_labels[\"t\"], frame_timestamps_us, side=\"left\"\n    )\n    end_indices_per_label = np.searchsorted(\n        sequence_labels[\"t\"], frame_timestamps_us, side=\"right\"\n    )\n\n    # Create labels per \"frame\"\n    labels_per_frame = []\n    for idx_start, idx_end in zip(start_indices_per_label, end_indices_per_label):\n        labels = sequence_labels[idx_start:idx_end]\n        label_time_us = labels[\"t\"][0]\n        assert np.all(labels[\"t\"] == label_time_us)\n        labels_per_frame.append(labels)\n\n    if len(frame_timestamps_us) > 1:\n        assert (\n            np.diff(frame_timestamps_us).min() > 98000\n        
), f\"{np.diff(frame_timestamps_us).min()=}\"\n\n    # Event repr timestamps generation\n    ev_repr_timestamps_us_end = list(\n        reversed(range(frame_timestamps_us[0], 0, -delta_t_us))\n    )[1:-1]\n    assert (\n        len(num_ev_reprs_between_frame_ts) == len(frame_timestamps_us) - 1\n    ), f\"{len(num_ev_reprs_between_frame_ts)=}, {len(frame_timestamps_us)=}\"\n    for idx, (num_ev_repr_between, frame_ts_us_start, frame_ts_us_end) in enumerate(\n        zip(\n            num_ev_reprs_between_frame_ts,\n            frame_timestamps_us[:-1],\n            frame_timestamps_us[1:],\n        )\n    ):\n        new_edge_timestamps = np.asarray(\n            np.linspace(frame_ts_us_start, frame_ts_us_end, num_ev_repr_between + 1),\n            dtype=\"int64\",\n        ).tolist()\n        is_last_iter = idx == len(num_ev_reprs_between_frame_ts) - 1\n        if not is_last_iter:\n            new_edge_timestamps = new_edge_timestamps[:-1]\n        ev_repr_timestamps_us_end.extend(new_edge_timestamps)\n    if len(frame_timestamps_us) == 1:\n        # special case not handled in above for loop (no iter in this case)\n        # yes, it's hacky ...\n        ev_repr_timestamps_us_end.append(frame_timestamps_us[0])\n    ev_repr_timestamps_us_end = np.asarray(ev_repr_timestamps_us_end, dtype=\"int64\")\n\n    frameidx_2_repridx = np.searchsorted(\n        ev_repr_timestamps_us_end, frame_timestamps_us, side=\"left\"\n    )\n    assert len(frameidx_2_repridx) == len(frame_timestamps_us)\n\n    # Some sanity checks:\n    assert len(labels_per_frame) == len(frame_timestamps_us)\n    assert len(frame_timestamps_us) == len(frameidx_2_repridx)\n    for label, frame_ts_us, repr_idx in zip(\n        labels_per_frame, frame_timestamps_us, frameidx_2_repridx\n    ):\n        assert label[\"t\"][0] == frame_ts_us\n        assert frame_ts_us == ev_repr_timestamps_us_end[repr_idx]\n\n    return (\n        labels_per_frame,\n        frame_timestamps_us,\n        
ev_repr_timestamps_us_end,\n        frameidx_2_repridx,\n    )\n\n\ndef write_event_data(\n    in_h5_file: Path,\n    ev_out_dir: Path,\n    dataset: str,\n    event_representation: RepresentationBase,\n    ev_repr_num_events: Optional[int],\n    ev_repr_delta_ts_ms: Optional[int],\n    ev_repr_timestamps_us: np.ndarray,\n    downsample_by_2: bool,\n    frameidx2repridx: np.ndarray,\n) -> None:\n    frameidx2repridx_file = ev_out_dir / \"objframe_idx_2_repr_idx.npy\"\n    if frameidx2repridx_file.exists():\n        frameidx2repridx_loaded = np.load(str(frameidx2repridx_file))\n        assert np.array_equal(frameidx2repridx_loaded, frameidx2repridx)\n    else:\n        np.save(str(frameidx2repridx_file), frameidx2repridx)\n    timestamps_file = ev_out_dir / \"timestamps_us.npy\"\n    if timestamps_file.exists():\n        timestamps_loaded = np.load(str(timestamps_file))\n        assert np.array_equal(timestamps_loaded, ev_repr_timestamps_us)\n    else:\n        np.save(str(timestamps_file), ev_repr_timestamps_us)\n    write_event_representations(\n        in_h5_file=in_h5_file,\n        ev_out_dir=ev_out_dir,\n        dataset=dataset,\n        event_representation=event_representation,\n        ev_repr_num_events=ev_repr_num_events,\n        ev_repr_delta_ts_ms=ev_repr_delta_ts_ms,\n        ev_repr_timestamps_us=ev_repr_timestamps_us,\n        downsample_by_2=downsample_by_2,\n        overwrite_if_exists=False,\n    )\n\n\ndef downsample_ev_repr(x: torch.Tensor, scale_factor: float):\n    assert 0 < scale_factor < 1\n    orig_dtype = x.dtype\n    if orig_dtype == torch.int8:\n        x = torch.asarray(x, dtype=torch.int16)\n        x = torch.asarray(x + 128, dtype=torch.uint8)\n    x = torch.nn.functional.interpolate(\n        x, scale_factor=scale_factor, mode=\"nearest-exact\"\n    )\n    if orig_dtype == torch.int8:\n        x = torch.asarray(x, dtype=torch.int16)\n        x = torch.asarray(x - 128, dtype=torch.int8)\n    return x\n\n\ndef 
write_event_representations(\n    in_h5_file: Path,\n    ev_out_dir: Path,\n    dataset: str,\n    event_representation: RepresentationBase,\n    ev_repr_num_events: Optional[int],\n    ev_repr_delta_ts_ms: Optional[int],\n    ev_repr_timestamps_us: np.ndarray,\n    downsample_by_2: bool,\n    overwrite_if_exists: bool = False,\n) -> None:\n    ev_outfile = (\n        ev_out_dir\n        / f\"event_representations{'_ds2_nearest' if downsample_by_2 else ''}.h5\"\n    )\n    if ev_outfile.exists() and not overwrite_if_exists:\n        return\n    ev_outfile_in_progress = ev_outfile.parent / (\n        ev_outfile.stem + \"_in_progress\" + ev_outfile.suffix\n    )\n    if ev_outfile_in_progress.exists():\n        os.remove(ev_outfile_in_progress)\n    ev_repr_shape = tuple(event_representation.get_shape())\n    if downsample_by_2:\n        ev_repr_shape = ev_repr_shape[0], ev_repr_shape[1] // 2, ev_repr_shape[2] // 2\n    ev_repr_dtype = event_representation.get_numpy_dtype()\n    with H5Reader(in_h5_file, dataset=dataset) as h5_reader, H5Writer(\n        ev_outfile_in_progress,\n        key=\"data\",\n        ev_repr_shape=ev_repr_shape,\n        numpy_dtype=ev_repr_dtype,\n    ) as h5_writer:\n        height, width = h5_reader.get_height_and_width()\n        if downsample_by_2:\n            assert (height // 2, width // 2) == ev_repr_shape[-2:]\n        else:\n            assert (height, width) == ev_repr_shape[-2:]\n        ev_ts_us = h5_reader.time\n\n        end_indices = np.searchsorted(ev_ts_us, ev_repr_timestamps_us, side=\"right\")\n        if ev_repr_num_events is not None:\n            start_indices = np.maximum(end_indices - ev_repr_num_events, 0)\n        else:\n            assert ev_repr_delta_ts_ms is not None\n            start_indices = np.searchsorted(\n                ev_ts_us,\n                ev_repr_timestamps_us - ev_repr_delta_ts_ms * 1000,\n                side=\"left\",\n            )\n\n        for idx_start, idx_end in zip(start_indices, 
end_indices):\n            ev_window = h5_reader.get_event_slice(idx_start=idx_start, idx_end=idx_end)\n\n            ev_repr = event_representation.construct(\n                x=ev_window[\"x\"],\n                y=ev_window[\"y\"],\n                pol=ev_window[\"p\"],\n                time=ev_window[\"t\"],\n            )\n            if downsample_by_2:\n                ev_repr = ev_repr.unsqueeze(0)\n                ev_repr = downsample_ev_repr(x=ev_repr, scale_factor=0.5)\n                ev_repr_numpy = ev_repr.numpy()[0]\n            else:\n                ev_repr_numpy = ev_repr.numpy()\n            h5_writer.add_data(ev_repr_numpy)\n        num_written_ev_repr = h5_writer.get_current_length()\n    assert num_written_ev_repr == len(ev_repr_timestamps_us)\n    os.rename(ev_outfile_in_progress, ev_outfile)\n\n\ndef process_sequence(\n    dataset: str,\n    filter_cfg: DictConfig,\n    event_representation: RepresentationBase,\n    ev_repr_num_events: Optional[int],\n    ev_repr_delta_ts_ms: Optional[int],\n    ts_step_ev_repr_ms: int,\n    downsample_by_2: bool,\n    sequence_data: Dict[DataKeys, Union[Path, SplitType]],\n):\n    in_npy_file = sequence_data[DataKeys.InNPY]\n    in_h5_file = sequence_data[DataKeys.InH5]\n    out_labels_dir = sequence_data[DataKeys.OutLabelDir]\n    out_ev_repr_dir = sequence_data[DataKeys.OutEvReprDir]\n    split_type = sequence_data[DataKeys.SplitType]\n    assert out_labels_dir.is_dir()\n    assert ts_step_ev_repr_ms > 0\n    assert bool(ev_repr_num_events is not None) ^ bool(\n        ev_repr_delta_ts_ms is not None\n    ), f\"{ev_repr_num_events=}, {ev_repr_delta_ts_ms=}\"\n\n    # 1) extract: labels_per_frame, frame_timestamps_us, ev_repr_timestamps_us, frameidx2repridx\n    align_t_ms = 100\n    try:\n        (\n            labels_per_frame,\n            frame_timestamps_us,\n            ev_repr_timestamps_us,\n            frameidx2repridx,\n        ) = labels_and_ev_repr_timestamps(\n            
npy_file=in_npy_file,\n            split_type=split_type,\n            filter_cfg=filter_cfg,\n            align_t_ms=align_t_ms,\n            ts_step_ev_repr_ms=ts_step_ev_repr_ms,\n            dataset_type=dataset,\n        )\n    except NoLabelsException:\n        parent_dir = out_labels_dir.parent\n        print(f\"No labels after filtering. Deleting {str(parent_dir)}\")\n        shutil.rmtree(parent_dir)\n        return\n\n    # 2) save: labels_per_frame, frame_timestamps_us\n    save_labels(\n        out_labels_dir=out_labels_dir,\n        labels_per_frame=labels_per_frame,\n        frame_timestamps_us=frame_timestamps_us,\n    )\n\n    # 3) retrieve event data, compute event representations and save them\n    write_event_data(\n        in_h5_file=in_h5_file,\n        ev_out_dir=out_ev_repr_dir,\n        dataset=dataset,\n        event_representation=event_representation,\n        ev_repr_num_events=ev_repr_num_events,\n        ev_repr_delta_ts_ms=ev_repr_delta_ts_ms,\n        ev_repr_timestamps_us=ev_repr_timestamps_us,\n        downsample_by_2=downsample_by_2,\n        frameidx2repridx=frameidx2repridx,\n    )\n\n\nclass AggregationType(Enum):\n    COUNT = auto()\n    DURATION = auto()\n\n\naggregation_2_string = {\n    AggregationType.DURATION: \"dt\",\n    AggregationType.COUNT: \"ne\",\n}\n\n\n@dataclass\nclass FilterConf:\n    apply_psee_bbox_filter: bool = MISSING\n    apply_faulty_bbox_filter: bool = MISSING\n\n\n@dataclass\nclass EventWindowExtractionConf:\n    method: AggregationType = MISSING\n    value: int = MISSING\n\n\n@dataclass\nclass StackedHistogramConf:\n    name: str = MISSING\n    nbins: int = MISSING\n    count_cutoff: Optional[int] = MISSING\n    event_window_extraction: EventWindowExtractionConf = field(\n        default_factory=EventWindowExtractionConf\n    )\n    fastmode: bool = True\n\n\n@dataclass\nclass MixedDensityEventStackConf:\n    name: str = MISSING\n    nbins: int = MISSING\n    count_cutoff: Optional[int] = MISSING\n    
event_window_extraction: EventWindowExtractionConf = field(\n        default_factory=EventWindowExtractionConf\n    )\n\n\nname_2_structured_config = {\n    \"stacked_histogram\": StackedHistogramConf,\n    \"mixeddensity_stack\": MixedDensityEventStackConf,\n}\n\n\nclass EventRepresentationFactory(ABC):\n    def __init__(self, config: DictConfig):\n        self.config = config\n\n    @property\n    @abstractmethod\n    def name(self) -> str: ...\n\n    @abstractmethod\n    def create(self, height: int, width: int) -> Any: ...\n\n\nclass StackedHistogramFactory(EventRepresentationFactory):\n    @property\n    def name(self) -> str:\n        extraction = self.config.event_window_extraction\n        return f\"{self.config.name}_{aggregation_2_string[extraction.method]}={extraction.value}_nbins={self.config.nbins}\"\n\n    def create(self, height: int, width: int) -> StackedHistogram:\n        return StackedHistogram(\n            bins=self.config.nbins,\n            height=height,\n            width=width,\n            count_cutoff=self.config.count_cutoff,\n            fastmode=self.config.fastmode,\n        )\n\n\nclass MixedDensityStackFactory(EventRepresentationFactory):\n    @property\n    def name(self) -> str:\n        extraction = self.config.event_window_extraction\n        cutoff_str = (\n            f\"_cutoff={self.config.count_cutoff}\"\n            if self.config.count_cutoff is not None\n            else \"\"\n        )\n        return f\"{self.config.name}_{aggregation_2_string[extraction.method]}={extraction.value}_nbins={self.config.nbins}{cutoff_str}\"\n\n    def create(self, height: int, width: int) -> MixedDensityEventStack:\n        return MixedDensityEventStack(\n            bins=self.config.nbins,\n            height=height,\n            width=width,\n            count_cutoff=self.config.count_cutoff,\n        )\n\n\nname_2_ev_repr_factory = {\n    \"stacked_histogram\": StackedHistogramFactory,\n    \"mixeddensity_stack\": 
MixedDensityStackFactory,\n}\n\n\ndef get_configuration(\n    ev_repr_yaml_config: Path, extraction_yaml_config: Path\n) -> DictConfig:\n    config = OmegaConf.load(ev_repr_yaml_config)\n    event_window_extraction_config = OmegaConf.load(extraction_yaml_config)\n    event_window_extraction_config = OmegaConf.merge(\n        OmegaConf.structured(EventWindowExtractionConf), event_window_extraction_config\n    )\n    config.event_window_extraction = event_window_extraction_config\n    config_schema = OmegaConf.structured(name_2_structured_config[config.name])\n    config = OmegaConf.merge(config_schema, config)\n    return config\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"input_dir\",\n        help=\"Path to the input dataset directory containing train/, val/ and test/\",\n    )\n    parser.add_argument(\n        \"target_dir\", help=\"Output directory for the preprocessed dataset\"\n    )\n    parser.add_argument(\n        \"ev_repr_yaml_config\", help=\"Path to event representation yaml config file\"\n    )\n    parser.add_argument(\n        \"extraction_yaml_config\",\n        help=\"Path to event window extraction yaml config file\",\n    )\n    parser.add_argument(\n        \"bbox_filter_yaml_config\", help=\"Path to bbox filter yaml config file\"\n    )\n    parser.add_argument(\n        \"-ds\", \"--dataset\", default=\"gen1\", help=\"Dataset type: gen1 or gen4\"\n    )\n    parser.add_argument(\n        \"-np\",\n        \"--num_processes\",\n        type=int,\n        default=1,\n        help=\"Number of processes to run in parallel\",\n    )\n    args = parser.parse_args()\n\n    num_processes = args.num_processes\n\n    dataset = args.dataset\n    assert dataset in (\"gen1\", \"gen4\")\n    downsample_by_2 = dataset == \"gen4\"\n\n    config = get_configuration(\n        ev_repr_yaml_config=Path(args.ev_repr_yaml_config),\n        extraction_yaml_config=Path(args.extraction_yaml_config),\n    )\n\n    bbox_filter_yaml_config = Path(args.bbox_filter_yaml_config)\n    assert bbox_filter_yaml_config.exists()\n    filter_cfg = OmegaConf.load(str(bbox_filter_yaml_config))\n    
filter_cfg = OmegaConf.merge(OmegaConf.structured(FilterConf), filter_cfg)\n\n    print(\"\")\n    print(OmegaConf.to_yaml(config))\n\n    ev_repr_factory: EventRepresentationFactory = name_2_ev_repr_factory[config.name](\n        config\n    )\n    height = dataset_2_height[args.dataset]\n    width = dataset_2_width[args.dataset]\n    ev_repr = ev_repr_factory.create(height=height, width=width)\n    ev_repr_string = ev_repr_factory.name\n\n    dataset_input_path = Path(args.input_dir)\n    train_path = dataset_input_path / \"train\"\n    val_path = dataset_input_path / \"val\"\n    test_path = dataset_input_path / \"test\"\n    target_dir = Path(args.target_dir)\n    os.makedirs(target_dir, exist_ok=True)\n\n    assert train_path.exists(), f\"{train_path=}\"\n    assert val_path.exists(), f\"{val_path=}\"\n    assert test_path.exists(), f\"{test_path=}\"\n\n    seq_data_list = list()\n    for split in [train_path, val_path, test_path]:\n        split_out_dir = target_dir / split.name\n        os.makedirs(split_out_dir, exist_ok=True)\n        for npy_file in split.iterdir():\n            if npy_file.suffix != \".npy\":\n                continue\n            h5f_path = npy_file.parent / (\n                npy_file.stem.split(\"bbox\")[0]\n                + f\"td{'.dat' if dataset == 'gen1' else ''}.h5\"\n            )\n            assert h5f_path.exists(), f\"{h5f_path=}\"\n\n            dir_name = npy_file.stem.split(\"_bbox\")[0]\n            if dir_name in dirs_to_ignore[dataset]:\n                continue\n            out_seq_path = split_out_dir / dir_name\n\n            out_labels_path = out_seq_path / \"labels_v2\"\n            os.makedirs(out_labels_path, exist_ok=True)\n\n            out_ev_repr_parent_path = out_seq_path / \"event_representations_v2\"\n            out_ev_repr_path = out_ev_repr_parent_path / ev_repr_string\n            os.makedirs(out_ev_repr_path, exist_ok=True)\n\n            sequence_data = {\n                DataKeys.InNPY: 
npy_file,\n                DataKeys.InH5: h5f_path,\n                DataKeys.OutLabelDir: out_labels_path,\n                DataKeys.OutEvReprDir: out_ev_repr_path,\n                DataKeys.SplitType: split_name_2_type[split.name],\n            }\n            seq_data_list.append(sequence_data)\n\n    ev_repr_num_events = None\n    ev_repr_delta_ts_ms = None\n    if config.event_window_extraction.method == AggregationType.COUNT:\n        ev_repr_num_events = config.event_window_extraction.value\n    else:\n        assert config.event_window_extraction.method == AggregationType.DURATION\n        ev_repr_delta_ts_ms = config.event_window_extraction.value\n    ts_step_ev_repr_ms = 50  # Could be an argument of the script.\n\n    if num_processes > 1:\n        chunksize = 1\n        func = partial(\n            process_sequence,\n            dataset,\n            filter_cfg,\n            ev_repr,\n            ev_repr_num_events,\n            ev_repr_delta_ts_ms,\n            ts_step_ev_repr_ms,\n            downsample_by_2,\n        )\n        with get_context(\"spawn\").Pool(num_processes) as pool:\n            with tqdm(total=len(seq_data_list), desc=\"sequences\") as pbar:\n                for _ in pool.imap_unordered(\n                    func, iterable=seq_data_list, chunksize=chunksize\n                ):\n                    pbar.update()\n    else:\n        for entry in tqdm(seq_data_list, desc=\"sequences\"):\n            process_sequence(\n                dataset=dataset,\n                filter_cfg=filter_cfg,\n                event_representation=ev_repr,\n                ev_repr_num_events=ev_repr_num_events,\n                ev_repr_delta_ts_ms=ev_repr_delta_ts_ms,\n                ts_step_ev_repr_ms=ts_step_ev_repr_ms,\n                downsample_by_2=downsample_by_2,\n                sequence_data=entry,\n            )\n"
  },
  {
    "path": "RVT/scripts/genx/preprocess_dataset.sh",
    "content": "NUM_PROCESSES=20  # set to the number of parallel processes to use\nDATA_DIR=/data/scratch1/nzubic/datasets/gen1_tar/\nDEST_DIR=/data/scratch1/nzubic/datasets/RVT/gen1_frequencies/gen1_200hz/\nFREQUENCY=conf_preprocess/extraction/frequencies/const_duration_200hz.yaml\n\npython preprocess_dataset.py ${DATA_DIR} ${DEST_DIR} conf_preprocess/representation/stacked_hist.yaml ${FREQUENCY} \\\nconf_preprocess/filter_gen1.yaml -ds gen1 -np ${NUM_PROCESSES}\n"
  },
  {
    "path": "RVT/scripts/viz/viz_gt.py",
    "content": "import os\n\nos.environ[\"OMP_NUM_THREADS\"] = \"1\"  # export OMP_NUM_THREADS=1\nos.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"  # export OPENBLAS_NUM_THREADS=1\nos.environ[\"MKL_NUM_THREADS\"] = \"1\"  # export MKL_NUM_THREADS=1\nos.environ[\"VECLIB_MAXIMUM_THREADS\"] = \"1\"  # export VECLIB_MAXIMUM_THREADS=1\nos.environ[\"NUMEXPR_NUM_THREADS\"] = \"1\"  # export NUMEXPR_NUM_THREADS=1\n\nfrom pathlib import Path\nimport sys\n\ncurrent_filepath = Path(os.path.realpath(__file__))\nsys.path.insert(0, str(current_filepath.parent.parent.parent))\nfrom typing import Tuple, Optional\n\nimport imageio.v3 as iio\nimport torch as th\nfrom tqdm import tqdm\n\nfrom data.utils.types import DataType, DatasetType\nfrom data.genx_utils.sequence_for_streaming import SequenceForIter\nfrom data.genx_utils.labels import ObjectLabels\nfrom utils.evaluation.prophesee.io.box_loading import loaded_label_to_prophesee\nfrom callbacks.viz_base import VizCallbackBase\nimport cv2\nimport numpy as np\nimport bbox_visualizer as bbv\nimport hdf5plugin\n\nLABELMAP_GEN1 = (\"car\", \"pedestrian\")\nLABELMAP_GEN4_SHORT = (\"pedestrian\", \"two wheeler\", \"car\")\n\n\ndef draw_bboxes_bbv(\n    img, boxes, labelmap=LABELMAP_GEN1, hd_resolution: bool = False\n) -> np.ndarray:\n    \"\"\"\n    draw bboxes in the image img\n    \"\"\"\n    colors = cv2.applyColorMap(np.arange(0, 255).astype(np.uint8), cv2.COLORMAP_HSV)\n    colors = [tuple(*item) for item in colors.tolist()]\n\n    if labelmap == LABELMAP_GEN1:\n        classid2colors = {\n            0: (255, 255, 0),  # car -> yellow (rgb)\n            1: (0, 0, 255),  # ped -> blue (rgb)\n        }\n        scale_multiplier = 4\n    else:\n        assert labelmap == LABELMAP_GEN4_SHORT\n        classid2colors = {\n            0: (0, 0, 255),  # ped -> blue (rgb)\n            1: (0, 255, 255),  # 2-wheeler cyan (rgb)\n            2: (255, 255, 0),  # car -> yellow (rgb)\n        }\n        scale_multiplier = 1 if hd_resolution else 
2\n\n    add_score = True\n    ht, wd, ch = img.shape\n    dim_new_wh = (int(wd * scale_multiplier), int(ht * scale_multiplier))\n    if scale_multiplier != 1:\n        img = cv2.resize(img, dim_new_wh, interpolation=cv2.INTER_AREA)\n    for i in range(boxes.shape[0]):\n        pt1 = (int(boxes[\"x\"][i]), int(boxes[\"y\"][i]))\n        size = (int(boxes[\"w\"][i]), int(boxes[\"h\"][i]))\n        pt2 = (pt1[0] + size[0], pt1[1] + size[1])\n        bbox = (pt1[0], pt1[1], pt2[0], pt2[1])\n        bbox = tuple(x * scale_multiplier for x in bbox)\n\n        score = boxes[\"class_confidence\"][i]\n        class_id = boxes[\"class_id\"][i]\n        class_name = labelmap[class_id % len(labelmap)]\n        bbox_txt = class_name\n        if add_score:\n            bbox_txt += f\" {score:.2f}\"\n        color_tuple_rgb = classid2colors[class_id]\n        img = bbv.draw_rectangle(img, bbox, bbox_color=color_tuple_rgb)\n        img = bbv.add_label(\n            img, bbox_txt, bbox, text_bg_color=color_tuple_rgb, top=True\n        )\n\n    return img\n\n\ndef draw_predictions(\n    ev_repr: th.Tensor,\n    predictions_proph,\n    hd_resolution: bool = False,\n    labelmap=LABELMAP_GEN4_SHORT,\n):\n    img = VizCallbackBase.ev_repr_to_img(ev_repr.cpu().numpy())\n    if predictions_proph is not None:\n        img = draw_bboxes_bbv(\n            img, predictions_proph, labelmap=labelmap, hd_resolution=hd_resolution\n        )\n    return img\n\n\ndef gen_gt_generator(\n    seq_path: Path,\n    ev_representation_name: str,\n    downsample_by_factor_2: bool,\n    dataset_type: DatasetType = DatasetType.GEN4,\n) -> Tuple[th.Tensor, Optional[ObjectLabels]]:\n    sequence_length = 5\n\n    if dataset_type == DatasetType.GEN1:\n        map_dataset = SequenceForIter(\n            path=seq_path,\n            ev_representation_name=ev_representation_name,\n            sequence_length=sequence_length,\n            dataset_type=DatasetType.GEN1,\n            
downsample_by_factor_2=downsample_by_factor_2,\n        )\n    else:\n        map_dataset = SequenceForIter(\n            path=seq_path,\n            ev_representation_name=ev_representation_name,\n            sequence_length=sequence_length,\n            dataset_type=DatasetType.GEN4,\n            downsample_by_factor_2=downsample_by_factor_2,\n        )\n\n    iter_dataset = map_dataset.to_iter_datapipe()\n\n    for data in iter_dataset:\n        seq_ev_reprs = data[DataType.EV_REPR]\n        seq_labels = data[DataType.OBJLABELS_SEQ]\n\n        for idx, ev_repr in enumerate(seq_ev_reprs):\n            labels = seq_labels[idx]\n            yield ev_repr, labels\n\n\nif __name__ == \"__main__\":\n    SEQUENCE_PATH = \"/data/scratch1/nzubic/datasets/RVT/gen1_frequencies/gen1_40hz/test/17-04-04_11-00-13_cut_15_500000_60500000/\"\n    OUT_DIR_PATH = \"/data/scratch1/nzubic/out_viz/\"\n    DOWNSAMPLE = False\n    EV_REPR_NAME = \"stacked_histogram_dt=25_nbins=10\"  # dt varies depending on different frequencies\n    DATASET_TYPE = DatasetType.GEN1\n\n    seq_path = Path(SEQUENCE_PATH)\n    out_dir = Path(OUT_DIR_PATH)\n    os.makedirs(out_dir, exist_ok=False)\n\n    if DATASET_TYPE == DatasetType.GEN1:\n        labelmap = LABELMAP_GEN1\n    else:\n        labelmap = LABELMAP_GEN4_SHORT\n\n    viz_at_hd_resolution = None\n    prev_img_with_labels = None\n    for idx, (ev_repr, labels) in enumerate(\n        tqdm(\n            gen_gt_generator(\n                seq_path=seq_path,\n                ev_representation_name=EV_REPR_NAME,\n                downsample_by_factor_2=DOWNSAMPLE,\n                dataset_type=DATASET_TYPE,\n            )\n        )\n    ):\n        if viz_at_hd_resolution is None:\n            height, width = ev_repr.shape[-2:]\n            viz_at_hd_resolution = height * width > 9e5\n\n        have_labels = labels is not None\n        labels_proph = loaded_label_to_prophesee(labels) if have_labels else None\n        img = draw_predictions(\n         
   ev_repr=ev_repr,\n            predictions_proph=labels_proph,\n            hd_resolution=viz_at_hd_resolution,\n            labelmap=labelmap,\n        )\n\n        filename = f\"{idx}\".zfill(6) + \".png\"\n        img_filepath = out_dir / filename\n\n        if have_labels or prev_img_with_labels is None:\n            img_to_write = img\n        else:\n            img_to_write = prev_img_with_labels\n\n        iio.imwrite(str(img_filepath), img_to_write)\n\n        if labels_proph is not None:\n            prev_img_with_labels = img\n"
  },
  {
    "path": "RVT/train.py",
    "content": "import os\n\nos.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\nos.environ[\"OMP_NUM_THREADS\"] = \"1\"\nos.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"\nos.environ[\"MKL_NUM_THREADS\"] = \"1\"\nos.environ[\"VECLIB_MAXIMUM_THREADS\"] = \"1\"\nos.environ[\"NUMEXPR_NUM_THREADS\"] = \"1\"\n\nimport torch\n\ntorch.multiprocessing.set_sharing_strategy(\"file_system\")\nfrom torch.backends import cuda, cudnn\n\ncuda.matmul.allow_tf32 = True\ncudnn.allow_tf32 = True\n\nimport hydra\nimport hdf5plugin\nfrom omegaconf import DictConfig, OmegaConf\nimport lightning.pytorch as pl\nfrom lightning.pytorch.callbacks import LearningRateMonitor, ModelSummary\nfrom lightning.pytorch.strategies import DDPStrategy\n\nfrom callbacks.custom import get_ckpt_callback, get_viz_callback\nfrom callbacks.gradflow import GradFlowLogCallback\nfrom config.modifier import dynamically_modify_train_config\nfrom data.utils.types import DatasetSamplingMode\nfrom loggers.utils import get_wandb_logger, get_ckpt_path\nfrom modules.utils.fetch import fetch_data_module, fetch_model_module\nfrom modules.detection import Module\n\n\n@hydra.main(config_path=\"config\", config_name=\"train\", version_base=\"1.2\")\ndef main(config: DictConfig):\n    dynamically_modify_train_config(config)\n    # Just to check whether config can be resolved\n    OmegaConf.to_container(config, resolve=True, throw_on_missing=True)\n\n    print(\"------ Configuration ------\")\n    print(OmegaConf.to_yaml(config))\n    print(\"---------------------------\")\n\n    # ---------------------\n    # Reproducibility\n    # ---------------------\n    dataset_train_sampling = config.dataset.train.sampling\n    assert dataset_train_sampling in iter(DatasetSamplingMode)\n    disable_seed_everything = dataset_train_sampling in (\n        DatasetSamplingMode.STREAM,\n        DatasetSamplingMode.MIXED,\n    )\n    if disable_seed_everything:\n        print(\n            \"Disabling PL seed everything because of unresolved 
issues with shuffling during training on streaming \"\n            \"datasets\"\n        )\n    seed = config.reproduce.seed_everything\n    if seed is not None and not disable_seed_everything:\n        assert isinstance(seed, int)\n        print(f\"USING pl.seed_everything WITH {seed=}\")\n        pl.seed_everything(seed=seed, workers=True)\n\n    # ---------------------\n    # DDP\n    # ---------------------\n    gpu_config = config.hardware.gpus\n    gpus = (\n        OmegaConf.to_container(gpu_config)\n        if OmegaConf.is_config(gpu_config)\n        else gpu_config\n    )\n    gpus = gpus if isinstance(gpus, list) else [gpus]\n    distributed_backend = config.hardware.dist_backend\n    assert distributed_backend in (\"nccl\", \"gloo\"), f\"{distributed_backend=}\"\n    strategy = (\n        DDPStrategy(\n            process_group_backend=distributed_backend,\n            find_unused_parameters=True,\n            gradient_as_bucket_view=True,\n        )\n        if len(gpus) > 1\n        else \"auto\"\n    )\n\n    # ---------------------\n    # Data\n    # ---------------------\n    data_module = fetch_data_module(config=config)\n\n    # ---------------------\n    # Logging and Checkpoints\n    # ---------------------\n    logger = get_wandb_logger(config)\n    ckpt_path = None\n    if config.wandb.artifact_name is not None:\n        ckpt_path = get_ckpt_path(logger, wandb_config=config.wandb)\n\n    # ---------------------\n    # Model\n    # ---------------------\n    module = fetch_model_module(config=config)\n    if ckpt_path is not None and config.wandb.resume_only_weights:\n        print(\"Resuming only the weights instead of the full training state\")\n        module = Module.load_from_checkpoint(\n            str(ckpt_path), **{\"full_config\": config}, strict=False\n        )\n\n        ckpt_path = None\n\n    # ---------------------\n    # Callbacks and Misc\n    # ---------------------\n    callbacks = list()\n    
callbacks.append(get_ckpt_callback(config))\n    callbacks.append(GradFlowLogCallback(config.logging.train.log_model_every_n_steps))\n    if config.training.lr_scheduler.use:\n        callbacks.append(LearningRateMonitor(logging_interval=\"step\"))\n    if (\n        config.logging.train.high_dim.enable\n        or config.logging.validation.high_dim.enable\n    ):\n        viz_callback = get_viz_callback(config=config)\n        callbacks.append(viz_callback)\n    callbacks.append(ModelSummary(max_depth=2))\n\n    logger.watch(\n        model=module,\n        log=\"all\",\n        log_freq=config.logging.train.log_model_every_n_steps,\n        log_graph=True,\n    )\n\n    # ---------------------\n    # Training\n    # ---------------------\n\n    val_check_interval = config.validation.val_check_interval\n    check_val_every_n_epoch = config.validation.check_val_every_n_epoch\n    assert val_check_interval is None or check_val_every_n_epoch is None\n\n    trainer = pl.Trainer(\n        accelerator=\"gpu\",\n        callbacks=callbacks,\n        enable_checkpointing=True,\n        val_check_interval=val_check_interval,\n        check_val_every_n_epoch=check_val_every_n_epoch,\n        default_root_dir=None,\n        devices=gpus,\n        gradient_clip_val=config.training.gradient_clip_val,\n        gradient_clip_algorithm=\"value\",\n        limit_train_batches=config.training.limit_train_batches,\n        limit_val_batches=config.validation.limit_val_batches,\n        logger=logger,\n        log_every_n_steps=config.logging.train.log_every_n_steps,\n        plugins=None,\n        precision=config.training.precision,\n        max_epochs=config.training.max_epochs,\n        max_steps=config.training.max_steps,\n        strategy=strategy,\n        sync_batchnorm=False if strategy == \"auto\" else True,\n        # move_metrics_to_cpu=False,\n        benchmark=config.reproduce.benchmark,\n        deterministic=config.reproduce.deterministic_flag,\n    )\n    
trainer.fit(model=module, ckpt_path=ckpt_path, datamodule=data_module)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "RVT/utils/evaluation/prophesee/__init__.py",
    "content": ""
  },
  {
    "path": "RVT/utils/evaluation/prophesee/evaluation.py",
    "content": "from .io.box_filtering import filter_boxes\nfrom .metrics.coco_eval import evaluate_detection\n\n\ndef evaluate_list(\n    result_boxes_list,\n    gt_boxes_list,\n    height: int,\n    width: int,\n    camera: str = \"gen1\",\n    apply_bbox_filters: bool = True,\n    downsampled_by_2: bool = False,\n    return_aps: bool = True,\n):\n    assert camera in {\"gen1\", \"gen4\"}\n\n    if camera == \"gen1\":\n        classes = (\"car\", \"pedestrian\")\n    elif camera == \"gen4\":\n        classes = (\"pedestrian\", \"two-wheeler\", \"car\")\n    else:\n        raise NotImplementedError\n\n    if apply_bbox_filters:\n        # Default values taken from: https://github.com/prophesee-ai/prophesee-automotive-dataset-toolbox/blob/0393adea2bf22d833893c8cb1d986fcbe4e6f82d/src/psee_evaluator.py#L23-L24\n        min_box_diag = 60 if camera == \"gen4\" else 30\n        # In the supplementary material, they state that min_box_side is 20 for gen4.\n        min_box_side = 20 if camera == \"gen4\" else 10\n        if downsampled_by_2:\n            assert min_box_diag % 2 == 0\n            min_box_diag //= 2\n            assert min_box_side % 2 == 0\n            min_box_side //= 2\n\n        half_sec_us = int(5e5)\n        filter_boxes_fn = lambda x: filter_boxes(\n            x, half_sec_us, min_box_diag, min_box_side\n        )\n\n        gt_boxes_list = map(filter_boxes_fn, gt_boxes_list)\n        # NOTE: We also filter the predictions to follow the Prophesee evaluation protocol.\n        result_boxes_list = map(filter_boxes_fn, result_boxes_list)\n\n    return evaluate_detection(\n        gt_boxes_list,\n        result_boxes_list,\n        height=height,\n        width=width,\n        classes=classes,\n        return_aps=return_aps,\n    )\n"
  },
  {
    "path": "RVT/utils/evaluation/prophesee/evaluator.py",
    "content": "from typing import Any, List, Optional, Dict\nfrom warnings import warn\n\nimport numpy as np\n\nfrom utils.evaluation.prophesee.evaluation import evaluate_list\n\n\nclass PropheseeEvaluator:\n    LABELS = \"labels\"\n    PREDICTIONS = \"predictions\"\n\n    def __init__(self, dataset: str, downsample_by_2: bool):\n        super().__init__()\n        assert dataset in {\"gen1\", \"gen4\"}\n        self.dataset = dataset\n        self.downsample_by_2 = downsample_by_2\n\n        self._buffer = None\n        self._buffer_empty = True\n        self._reset_buffer()\n\n    def _reset_buffer(self):\n        self._buffer_empty = True\n        self._buffer = {\n            self.LABELS: list(),\n            self.PREDICTIONS: list(),\n        }\n\n    def _add_to_buffer(self, key: str, value: List[np.ndarray]):\n        assert isinstance(value, list)\n        for entry in value:\n            assert isinstance(entry, np.ndarray)\n        self._buffer_empty = False\n        assert self._buffer is not None\n        self._buffer[key].extend(value)\n\n    def _get_from_buffer(self, key: str) -> List[np.ndarray]:\n        assert not self._buffer_empty\n        assert self._buffer is not None\n        return self._buffer[key]\n\n    def add_predictions(self, predictions: List[np.ndarray]):\n        self._add_to_buffer(self.PREDICTIONS, predictions)\n\n    def add_labels(self, labels: List[np.ndarray]):\n        self._add_to_buffer(self.LABELS, labels)\n\n    def reset_buffer(self) -> None:\n        # E.g. 
call in on_validation_epoch_start\n        self._reset_buffer()\n\n    def has_data(self):\n        return not self._buffer_empty\n\n    def evaluate_buffer(\n        self, img_height: int, img_width: int\n    ) -> Optional[Dict[str, Any]]:\n        # e.g call in on_validation_epoch_end\n        if self._buffer_empty:\n            warn(\n                \"Attempt to use prophesee evaluation buffer, but it is empty\",\n                UserWarning,\n                stacklevel=2,\n            )\n            return\n\n        labels = self._get_from_buffer(self.LABELS)\n        predictions = self._get_from_buffer(self.PREDICTIONS)\n        assert len(labels) == len(predictions)\n        metrics = evaluate_list(\n            result_boxes_list=predictions,\n            gt_boxes_list=labels,\n            height=img_height,\n            width=img_width,\n            apply_bbox_filters=True,\n            downsampled_by_2=self.downsample_by_2,\n            camera=self.dataset,\n        )\n        return metrics\n"
  },
  {
    "path": "RVT/utils/evaluation/prophesee/io/__init__.py",
    "content": ""
  },
  {
    "path": "RVT/utils/evaluation/prophesee/io/box_filtering.py",
    "content": "\"\"\"\nDefine the same filtering that is applied in:\n\"Learning to detect objects on a 1 Megapixel Event Camera\" by Etienne Perot et al.\n\nNamely, we apply 2 different filters:\n1. skip all boxes at or before 0.5 s (before that, we assume there is not enough event history)\n2. filter out all boxes whose squared diagonal is < min_box_diag**2 or whose width or height is < min_box_side\n\nCopyright: (c) 2019-2020 Prophesee\n\"\"\"\n\nfrom __future__ import print_function\n\nimport numpy as np\n\n\ndef filter_boxes(boxes, skip_ts=int(5e5), min_box_diag=60, min_box_side=20):\n    \"\"\"Filters boxes according to the paper rule.\n\n    Note: the defaults are the thresholds used when evaluating at GEN4 resolution (1280x720).\n    Note: the initial timestamp of the video is assumed to be 0.\n\n    Args:\n        boxes (np.ndarray): structured box array with fields ['t','x','y','w','h','class_id','track_id','class_confidence']\n        (an example BBOX_DTYPE is provided in io/box_loading.py)\n        skip_ts (int): boxes with t <= skip_ts (in microseconds) are dropped\n        min_box_diag (int): minimum box diagonal, in pixels\n        min_box_side (int): minimum box width and height, in pixels\n\n    Returns:\n        boxes: filtered boxes\n    \"\"\"\n    ts = boxes[\"t\"]\n    width = boxes[\"w\"]\n    height = boxes[\"h\"]\n    diag_square = width**2 + height**2\n    # keep a box only if it passes every criterion\n    mask = (\n        (ts > skip_ts)\n        & (diag_square >= min_box_diag**2)\n        & (width >= min_box_side)\n        & (height >= min_box_side)\n    )\n    return boxes[mask]\n"
  },
  {
    "path": "RVT/utils/evaluation/prophesee/io/box_loading.py",
    "content": "\"\"\"\nDefines some tools to handle bounding boxes.\nIn particular:\n    -> defines the box dtype (BBOX_DTYPE)\n    -> converts legacy box field names to the current ones\n    -> converts loaded labels and YOLOX predictions to the Prophesee box format\n\nCopyright: (c) 2019-2020 Prophesee\n\"\"\"\n\nfrom __future__ import print_function\n\nfrom typing import List, Optional, Tuple\n\nimport numpy as np\nimport torch as th\n\nfrom data.genx_utils.labels import ObjectLabels\n\nBBOX_DTYPE = np.dtype(\n    {\n        \"names\": [\"t\", \"x\", \"y\", \"w\", \"h\", \"class_id\", \"track_id\", \"class_confidence\"],\n        \"formats\": [\"<i8\", \"<f4\", \"<f4\", \"<f4\", \"<f4\", \"<u4\", \"<u4\", \"<f4\"],\n        \"offsets\": [0, 8, 12, 16, 20, 24, 28, 32],\n        \"itemsize\": 40,\n    }\n)\n\nYOLOX_PRED_PROCESSED = List[Optional[th.Tensor]]\nLOADED_LABELS = List[ObjectLabels]\n\n\ndef reformat_boxes(boxes):\n    \"\"\"Reformat boxes according to the new field naming rule.\n    This keeps backward compatibility with iMerit annotations:\n        'ts' -> 't'\n        'confidence' -> 'class_confidence'\n    \"\"\"\n    if \"t\" not in boxes.dtype.names or \"class_confidence\" not in boxes.dtype.names:\n        new = np.zeros((len(boxes),), dtype=BBOX_DTYPE)\n        for name in boxes.dtype.names:\n            if name == \"ts\":\n                new[\"t\"] = boxes[name]\n            elif name == \"confidence\":\n                new[\"class_confidence\"] = boxes[name]\n            else:\n                new[name] = boxes[name]\n        return new\n    else:\n        return boxes\n\n\ndef loaded_label_to_prophesee(loaded_labels: ObjectLabels) -> np.ndarray:\n    loaded_labels.numpy_()\n    loaded_label_proph = np.zeros((len(loaded_labels),), dtype=BBOX_DTYPE)\n    for name in BBOX_DTYPE.names:\n        if name == \"track_id\":\n            # We don't have that and don't need it\n            continue\n        loaded_label_proph[name] = np.asarray(\n       
     
loaded_labels.get(name), dtype=BBOX_DTYPE[name]\n        )\n    return loaded_label_proph\n\n\ndef to_prophesee(\n    loaded_label_list: LOADED_LABELS, yolox_pred_list: YOLOX_PRED_PROCESSED\n) -> Tuple[List[np.ndarray], List[np.ndarray]]:\n    assert len(loaded_label_list) == len(yolox_pred_list)\n\n    loaded_label_list_proph = []\n    yolox_pred_list_proph = []\n    for loaded_labels, yolox_preds in zip(loaded_label_list, yolox_pred_list):\n        # TODO: use loaded_label_to_prophesee func here\n        time = None\n        # --- LOADED LABELS ---\n        loaded_labels.numpy_()\n        loaded_label_proph = np.zeros((len(loaded_labels),), dtype=BBOX_DTYPE)\n        for name in BBOX_DTYPE.names:\n            if name == \"track_id\":\n                # We don't have that and don't need it\n                continue\n            loaded_label_proph[name] = np.asarray(\n                loaded_labels.get(name), dtype=BBOX_DTYPE[name]\n            )\n            if name == \"t\":\n                time = np.unique(loaded_labels.get(name))\n                assert time.size == 1\n                time = time.item()\n        loaded_label_list_proph.append(loaded_label_proph)\n\n        # --- YOLOX PREDICTIONS ---\n        # Assumes batch of post-processed predictions from YoloX Head.\n        # See postprocessing: https://github.com/Megvii-BaseDetection/YOLOX/blob/a5bb5ab12a61b8a25a5c3c11ae6f06397eb9b296/yolox/utils/boxes.py#L32\n        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)\n        num_pred = 0 if yolox_preds is None else yolox_preds.shape[0]\n        yolox_pred_proph = np.zeros((num_pred,), dtype=BBOX_DTYPE)\n        if num_pred > 0:\n            yolox_preds = yolox_preds.detach().cpu().numpy()\n            assert yolox_preds.shape == (num_pred, 7)\n            yolox_pred_proph[\"t\"] = np.ones((num_pred,), dtype=BBOX_DTYPE[\"t\"]) * time\n            yolox_pred_proph[\"x\"] = np.asarray(yolox_preds[:, 0], dtype=BBOX_DTYPE[\"x\"])\n   
         yolox_pred_proph[\"y\"] = np.asarray(yolox_preds[:, 1], dtype=BBOX_DTYPE[\"y\"])\n            yolox_pred_proph[\"w\"] = np.asarray(\n                yolox_preds[:, 2] - yolox_preds[:, 0], dtype=BBOX_DTYPE[\"w\"]\n            )\n            yolox_pred_proph[\"h\"] = np.asarray(\n                yolox_preds[:, 3] - yolox_preds[:, 1], dtype=BBOX_DTYPE[\"h\"]\n            )\n            yolox_pred_proph[\"class_id\"] = np.asarray(\n                yolox_preds[:, 6], dtype=BBOX_DTYPE[\"class_id\"]\n            )\n            yolox_pred_proph[\"class_confidence\"] = np.asarray(\n                yolox_preds[:, 5], dtype=BBOX_DTYPE[\"class_confidence\"]\n            )\n        yolox_pred_list_proph.append(yolox_pred_proph)\n\n    return loaded_label_list_proph, yolox_pred_list_proph\n"
  },
  {
    "path": "RVT/utils/evaluation/prophesee/io/dat_events_tools.py",
    "content": "\"\"\"\nDefines some tools to handle events.\nIn particular:\n    -> defines events' types\n    -> defines functions to read events from binary .dat files using numpy\n    -> defines functions to write events to binary .dat files using numpy\n\nCopyright: (c) 2019-2020 Prophesee\n\"\"\"\n\nfrom __future__ import print_function\n\nimport datetime\nimport os\nimport sys\n\nimport numpy as np\n\nEV_TYPE = [(\"t\", \"u4\"), (\"_\", \"i4\")]  # Event2D\n\nEV_STRING = \"Event2D\"\n\n\ndef load_td_data(filename, ev_count=-1, ev_start=0):\n    \"\"\"\n    Loads TD data from files generated by the StreamLogger consumer for Event2D\n    events [t,x,y,p]. The type ID in the file header must be 0.\n    args :\n        - path to a dat file\n        - number of events (all if set to the default -1)\n        - index of the first event\n\n    return :\n        - dat, a dictionary-like structure containing the fields t, x, y, p\n    \"\"\"\n\n    with open(filename, \"rb\") as f:\n        _, ev_type, ev_size, _ = parse_header(f)\n        if ev_start > 0:\n            f.seek(ev_start * ev_size, 1)\n\n        dtype = EV_TYPE\n        dat = np.fromfile(f, dtype=dtype, count=ev_count)\n        xyp = None\n        if (\"_\", \"i4\") in dtype:\n            # unpack the '_' field: x in bits 0-13, y in bits 14-27, p in bit 28\n            x = np.bitwise_and(dat[\"_\"], 16383)\n            y = np.right_shift(np.bitwise_and(dat[\"_\"], 268419072), 14)\n            p = np.right_shift(np.bitwise_and(dat[\"_\"], 268435456), 28)\n            xyp = (x, y, p)\n        return _dat_transfer(dat, dtype, xyp=xyp)\n\n\ndef _dat_transfer(dat, dtype, xyp=None):\n    \"\"\"\n    Transfers the fields present in dtype from an old data structure to a new one.\n    xyp should be passed as a tuple.\n    args :\n        - dat: vector as directly read from file\n        - dtype: numpy dtype as a list of (field name, type) pairs, e.g. 
[('x','i4'), ('y','f2')]\n        - xyp: optional tuple containing x, y, p extracted from the '_' field and untangled by bit shifting and masking\n  
  \"\"\"\n    variables = []\n    xyp_index = -1\n    for i, (name, _) in enumerate(dtype):\n        if name == \"_\":\n            xyp_index = i\n            continue\n        variables.append((name, dat[name]))\n    if xyp and xyp_index == -1:\n        print(\"Error dat didn't contain a '_' field !\")\n        return\n    if xyp_index >= 0:\n        dtype = (\n            dtype[:xyp_index]\n            + [(\"x\", \"i2\"), (\"y\", \"i2\"), (\"p\", \"i2\")]\n            + dtype[xyp_index + 1 :]\n        )\n    new_dat = np.empty(dat.shape[0], dtype=dtype)\n    if xyp:\n        new_dat[\"x\"] = xyp[0].astype(np.uint16)\n        new_dat[\"y\"] = xyp[1].astype(np.uint16)\n        new_dat[\"p\"] = xyp[2].astype(np.uint16)\n    for name, arr in variables:\n        new_dat[name] = arr\n    return new_dat\n\n\ndef stream_td_data(file_handle, buffer, dtype, ev_count=-1):\n    \"\"\"\n    Streams data from opened file_handle\n    args :\n        - file_handle: file object\n        - buffer: pre-allocated buffer to fill with events\n        - dtype:  expected fields\n        - ev_count: number of events\n    \"\"\"\n\n    dat = np.fromfile(file_handle, dtype=dtype, count=ev_count)\n    count = len(dat[\"t\"])\n    for name, _ in dtype:\n        if name == \"_\":\n            buffer[\"x\"][:count] = np.bitwise_and(dat[\"_\"], 16383)\n            buffer[\"y\"][:count] = np.right_shift(\n                np.bitwise_and(dat[\"_\"], 268419072), 14\n            )\n            buffer[\"p\"][:count] = np.right_shift(\n                np.bitwise_and(dat[\"_\"], 268435456), 28\n            )\n        else:\n            buffer[name][:count] = dat[name]\n\n\ndef count_events(filename):\n    \"\"\"\n    Returns the number of events in a dat file\n    args :\n        - path to a dat file\n    \"\"\"\n    with open(filename, \"rb\") as f:\n        bod, _, ev_size, _ = parse_header(f)\n        f.seek(0, os.SEEK_END)\n        eod = f.tell()\n        if (eod - bod) % ev_size != 0:\n            
raise Exception(\"unexpected format !\")\n        return (eod - bod) // ev_size\n\n\ndef parse_header(f):\n    \"\"\"\n    Parses the header of a dat file\n    Args:\n        - f file handle to a dat file\n    return :\n        - int position of the file cursor after the header\n        - int type of event\n        - int size of event in bytes\n        - size (height, width) tuple of int or None\n    \"\"\"\n    f.seek(0, os.SEEK_SET)\n    bod = None\n    end_of_header = False\n    header = []\n    num_comment_line = 0\n    size = [None, None]\n    # parse header\n    while not end_of_header:\n        bod = f.tell()\n        line = f.readline()\n        if sys.version_info > (3, 0):\n            first_item = line.decode(\"latin-1\")[:2]\n        else:\n            first_item = line[:2]\n\n        if first_item != \"% \":\n            end_of_header = True\n        else:\n            words = line.split()\n            if len(words) > 1:\n                if words[1] == \"Date\":\n                    header += [\"Date\", words[2] + \" \" + words[3]]\n                if (\n                    words[1] == \"Height\" or words[1] == b\"Height\"\n                ):  # compliant with python 3 (and python2)\n                    size[0] = int(words[2])\n                    header += [\"Height\", words[2]]\n                if (\n                    words[1] == \"Width\" or words[1] == b\"Width\"\n                ):  # compliant with python 3 (and python2)\n                    size[1] = int(words[2])\n                    header += [\"Width\", words[2]]\n            else:\n                header += words[1:3]\n            num_comment_line += 1\n    # parse data\n    f.seek(bod, os.SEEK_SET)\n\n    if num_comment_line > 0:  # Ensure compatibility with previous files.\n        # Read event type\n        ev_type = np.frombuffer(f.read(1), dtype=np.uint8)[0]\n        # Read event size\n        ev_size = np.frombuffer(f.read(1), dtype=np.uint8)[0]\n    else:\n        ev_type = 0\n      
  ev_size = sum([int(n[-1]) for _, n in EV_TYPE])\n\n    bod = f.tell()\n    return bod, ev_type, ev_size, size\n\n\ndef write_header(filename, height=240, width=320, ev_type=0):\n    \"\"\"\n    Write the header of a dat file and return the open file handle.\n    \"\"\"\n    if max(height, width) > 2**14 - 1:\n        raise ValueError(\n            \"Coordinate values exceed the maximum range of the\"\n            \" binary .dat file format: max({:d},{:d}) vs 2^14 - 1\".format(height, width)\n        )\n    f = open(filename, \"w\")\n    f.write(\n        \"% Data file containing {:s} events.\\n\"\n        \"% Version 2\\n\".format(EV_STRING)\n    )\n    now = datetime.datetime.utcnow()\n    f.write(\n        \"% Date {}-{}-{} {}:{}:{}\\n\".format(\n            now.year, now.month, now.day, now.hour, now.minute, now.second\n        )\n    )\n\n    f.write(\"% Height {:d}\\n\" \"% Width {:d}\\n\".format(height, width))\n    # write event type and event size (in bytes)\n    ev_size = sum([int(b[-1]) for _, b in EV_TYPE])\n\n    np.array([ev_type, ev_size], dtype=np.uint8).tofile(f)\n    f.flush()\n    return f\n\n\ndef write_event_buffer(f, buffers):\n    \"\"\"\n    Write events with fields x, y, p, t into the file object f.\n    \"\"\"\n    # pack data as events\n    dtype = EV_TYPE\n    data_to_write = np.empty(len(buffers[\"t\"]), dtype=dtype)\n\n    for name, typ in buffers.dtype.fields.items():\n        if name == \"x\":\n            x = buffers[\"x\"].astype(\"i4\")\n        elif name == \"y\":\n            y = np.left_shift(buffers[\"y\"].astype(\"i4\"), 14)\n        elif name == \"p\":\n            # normalize polarity in place: 1 stays 1, anything else becomes 0\n            buffers[\"p\"] = (buffers[\"p\"] == 1).astype(buffers[\"p\"].dtype)\n            p = np.left_shift(buffers[\"p\"].astype(\"i4\"), 28)\n        else:\n            data_to_write[name] = buffers[name].astype(typ[0])\n\n    data_to_write[\"_\"] = x + y + p\n\n    # write data\n    data_to_write.tofile(f)\n    f.flush()\n"
  },
  {
    "path": "RVT/utils/evaluation/prophesee/io/npy_events_tools.py",
    "content": "#!/usr/bin/env python\n\n\"\"\"\nDefines some tools to handle events, mimicking dat_events_tools.py.\nIn particular :\n    -> defines functions to read events from binary .npy files using numpy\n    -> defines functions to write events to binary .dat files using numpy (TODO later)\n\nCopyright: (c) 2015-2019 Prophesee\n\"\"\"\nfrom __future__ import print_function\n\nimport numpy as np\n\n\ndef stream_td_data(file_handle, buffer, dtype, ev_count=-1):\n    \"\"\"\n    Streams data from opened file_handle\n    args :\n        - file_handle: file object\n        - buffer: pre-allocated buffer to fill with events\n        - dtype:  expected fields\n        - ev_count: number of events\n    \"\"\"\n    dat = np.fromfile(file_handle, dtype=dtype, count=ev_count)\n    count = len(dat[\"t\"])\n    for name, _ in dtype:\n        buffer[name][:count] = dat[name]\n\n\ndef parse_header(fhandle):\n    \"\"\"\n    Parses the header of a .npy file\n    Args:\n        - f file handle to a .npy file\n    return :\n        - int position of the file cursor after the header\n        - int type of event\n        - int size of event in bytes\n        - size (height, width) tuple of int or (None, None)\n    \"\"\"\n    version = np.lib.format.read_magic(fhandle)\n    shape, fortran, dtype = np.lib.format._read_array_header(fhandle, version)\n    assert not fortran, \"Fortran order arrays not supported\"\n    # Get the number of elements in one 'row' by taking\n    # a product over all other dimensions.\n    if len(shape) == 0:\n        count = 1\n    else:\n        count = np.multiply.reduce(shape, dtype=np.int64)\n    ev_size = dtype.itemsize\n    assert ev_size != 0\n    start = fhandle.tell()\n    # turn numpy.dtype into an iterable list\n    ev_type = [(x, str(dtype.fields[x][0])) for x in dtype.names]\n    # filter name to have only t and not ts\n    ev_type = [(name if name != \"ts\" else \"t\", desc) for name, desc in ev_type]\n    ev_type = [\n        (name if 
name != \"confidence\" else \"class_confidence\", desc)\n        for name, desc in ev_type\n    ]\n    size = (None, None)\n\n    return start, ev_type, ev_size, size\n"
  },
  {
    "path": "RVT/utils/evaluation/prophesee/io/psee_loader.py",
    "content": "\"\"\"\nThis class loads events from dat or npy files\n\nCopyright: (c) 2019-2020 Prophesee\n\"\"\"\n\nfrom __future__ import print_function\n\nimport os\n\nimport numpy as np\n\nfrom . import dat_events_tools as dat\nfrom . import npy_events_tools as npy_format\n\n\nclass PSEELoader(object):\n    \"\"\"\n    PSEELoader loads a dat or npy file and streams events\n    \"\"\"\n\n    def __init__(self, datfile):\n        \"\"\"\n        Constructor\n        :param datfile: binary dat or npy file\n        \"\"\"\n        self._extension = datfile.split(\".\")[-1]\n        assert self._extension in [\"dat\", \"npy\"], \"input file path = {}\".format(datfile)\n        if self._extension == \"dat\":\n            self._binary_format = dat\n        elif self._extension == \"npy\":\n            self._binary_format = npy_format\n        self._file = open(datfile, \"rb\")\n        (\n            self._start,\n            self.ev_type,\n            self._ev_size,\n            self._size,\n        ) = self._binary_format.parse_header(self._file)\n        assert self._ev_size != 0\n        if self._extension == \"dat\":\n            self._dtype = self._binary_format.EV_TYPE\n        elif self._extension == \"npy\":\n            self._dtype = self.ev_type\n        else:\n            assert False, \"unsupported extension\"\n\n        self._decode_dtype = []\n        for dtype in self._dtype:\n            if dtype[0] == \"_\":\n                self._decode_dtype += [(\"x\", \"u2\"), (\"y\", \"u2\"), (\"p\", \"u1\")]\n            else:\n                self._decode_dtype.append(dtype)\n\n        # count events from the file size\n        self._file.seek(0, os.SEEK_END)\n        self._end = self._file.tell()\n        self._ev_count = (self._end - self._start) // self._ev_size\n        self.done = False\n        self._file.seek(self._start)\n        # If the current time is t, the next event to be loaded has a\n        # timestamp greater than or equal to t (an event with timestamp exactly t is 
not loaded yet)\n        self.current_time = 0\n        self.duration_s = self.total_time() * 1e-6\n\n    def reset(self):\n        \"\"\"reset at beginning of file\"\"\"\n        self._file.seek(self._start)\n        self.done = False\n        self.current_time = 0\n\n    def event_count(self):\n        \"\"\"\n        getter on event_count\n        :return:\n        \"\"\"\n        return self._ev_count\n\n    def get_size(self):\n        \"\"\" \"(height, width) of the imager might be (None, None)\"\"\"\n        return self._size\n\n    def __repr__(self):\n        \"\"\"\n        prints properties\n        :return:\n        \"\"\"\n        wrd = \"\"\n        wrd += \"PSEELoader:\" + \"\\n\"\n        wrd += \"-----------\" + \"\\n\"\n        if self._extension == \"dat\":\n            wrd += \"Event Type: \" + str(self._binary_format.EV_STRING) + \"\\n\"\n        elif self._extension == \"npy\":\n            wrd += \"Event Type: numpy array element\\n\"\n        wrd += \"Event Size: \" + str(self._ev_size) + \" bytes\\n\"\n        wrd += \"Event Count: \" + str(self._ev_count) + \"\\n\"\n        wrd += \"Duration: \" + str(self.duration_s) + \" s \\n\"\n        wrd += \"-----------\" + \"\\n\"\n        return wrd\n\n    def load_n_events(self, ev_count):\n        \"\"\"\n        load batch of n events\n        :param ev_count: number of events that will be loaded\n        :return: events\n        Note that current time will be incremented to reach the timestamp of the first event not loaded yet\n        \"\"\"\n        event_buffer = np.empty((ev_count + 1,), dtype=self._decode_dtype)\n\n        pos = self._file.tell()\n        count = (self._end - pos) // self._ev_size\n        if ev_count >= count:\n            self.done = True\n            ev_count = count\n            self._binary_format.stream_td_data(\n                self._file, event_buffer, self._dtype, ev_count\n            )\n            self.current_time = event_buffer[\"t\"][ev_count - 1] + 1\n     
   else:\n            self._binary_format.stream_td_data(\n                self._file, event_buffer, self._dtype, ev_count + 1\n            )\n            self.current_time = event_buffer[\"t\"][ev_count]\n            self._file.seek(pos + ev_count * self._ev_size)\n\n        return event_buffer[:ev_count]\n\n    def load_delta_t(self, delta_t):\n        \"\"\"\n        loads a slice of time.\n        :param delta_t: (us) slice thickness\n        :return: events\n        Note that current time will be incremented by delta_t.\n        If an event is timestamped at exactly current_time it will not be loaded.\n        \"\"\"\n        if delta_t < 1:\n            raise ValueError(\n                \"load_delta_t(): delta_t must be at least 1 micro-second: {}\".format(\n                    delta_t\n                )\n            )\n\n        if self.done or (self._file.tell() >= self._end):\n            self.done = True\n            return np.empty((0,), dtype=self._decode_dtype)\n\n        final_time = self.current_time + delta_t\n        tmp_time = self.current_time\n        start = self._file.tell()\n        pos = start\n        nevs = 0\n        batch = 100000\n        event_buffer = []\n        # data is read by buffers until enough events are read or until the end of the file\n        while tmp_time < final_time and pos < self._end:\n            count = (min(self._end, pos + batch * self._ev_size) - pos) // self._ev_size\n            buffer = np.empty((count,), dtype=self._decode_dtype)\n            self._binary_format.stream_td_data(self._file, buffer, self._dtype, count)\n            tmp_time = buffer[\"t\"][-1]\n            event_buffer.append(buffer)\n            nevs += count\n            pos = self._file.tell()\n        if tmp_time >= final_time:\n            self.current_time = final_time\n        else:\n            self.current_time = tmp_time + 1\n        assert len(event_buffer) > 0\n        idx = np.searchsorted(event_buffer[-1][\"t\"], final_time)\n    
    event_buffer[-1] = event_buffer[-1][:idx]\n        event_buffer = np.concatenate(event_buffer)\n        idx = len(event_buffer)\n        self._file.seek(start + idx * self._ev_size)\n        self.done = self._file.tell() >= self._end\n        return event_buffer\n\n    def seek_event(self, ev_count):\n        \"\"\"\n        Seek to the position right after the first ev_count events of the file\n        :param ev_count: number of events to skip from the start of the data\n        Note that current time will be set to the timestamp of the next event.\n        \"\"\"\n        if ev_count <= 0:\n            self._file.seek(self._start)\n            self.current_time = 0\n        elif ev_count >= self._ev_count:\n            # we put the cursor one event before and read the last event\n            # which puts the file cursor at the right place\n            # current_time is set to the last event timestamp + 1\n            self._file.seek(self._start + (self._ev_count - 1) * self._ev_size)\n            self.current_time = (\n                np.fromfile(self._file, dtype=self._dtype, count=1)[\"t\"][0] + 1\n            )\n        else:\n            # we put the cursor at the ev_count-th event\n            self._file.seek(self._start + (ev_count) * self._ev_size)\n            # we read the timestamp of the following event (this changes the position in the file)\n            self.current_time = np.fromfile(self._file, dtype=self._dtype, count=1)[\n                \"t\"\n            ][0]\n            # which is why we go back to the right position here\n            self._file.seek(self._start + (ev_count) * self._ev_size)\n        self.done = self._file.tell() >= self._end\n\n    def seek_time(self, final_time, term_criterion=100000):\n        \"\"\"\n        Go to the time final_time inside the file. 
This is implemented using a binary search algorithm.\n        :param final_time: expected time\n        :param term_criterion: (number of events) binary search termination criterion;\n        once the search window is this small, those events are loaded into a buffer and a numpy searchsorted is done, so the result is always exact\n        \"\"\"\n        if final_time > self.total_time():\n            self._file.seek(self._end)\n            self.done = True\n            self.current_time = self.total_time() + 1\n            return\n\n        if final_time <= 0:\n            self.reset()\n            return\n\n        low = 0\n        high = self._ev_count\n\n        # binary search\n        while high - low > term_criterion:\n            middle = (low + high) // 2\n\n            self.seek_event(middle)\n            mid = np.fromfile(self._file, dtype=self._dtype, count=1)[\"t\"][0]\n\n            if mid > final_time:\n                high = middle\n            elif mid < final_time:\n                low = middle + 1\n            else:\n                self.current_time = final_time\n                self.done = self._file.tell() >= self._end\n                return\n        # we now know that it is between low and high\n        self.seek_event(low)\n        final_buffer = np.fromfile(self._file, dtype=self._dtype, count=high - low)[\"t\"]\n        final_index = np.searchsorted(final_buffer, final_time)\n\n        self.seek_event(low + final_index)\n        self.current_time = final_time\n        self.done = self._file.tell() >= self._end\n\n    def total_time(self):\n        \"\"\"\n        Get the total duration of the video in microseconds, provided there is no overflow\n        :return:\n        \"\"\"\n        if not self._ev_count:\n            return 0\n        # save the state of the loader\n        pos = self._file.tell()\n        current_time = self.current_time\n        done = self.done\n        # read the last event's timestamp\n        self.seek_event(self._ev_count - 1)\n        time = np.fromfile(self._file, 
dtype=self._dtype, count=1)[\"t\"][0]\n        # restore the state\n        self._file.seek(pos)\n        self.current_time = current_time\n        self.done = done\n\n        return time\n\n    def __del__(self):\n        self._file.close()\n"
  },
  {
    "path": "RVT/utils/evaluation/prophesee/metrics/__init__.py",
    "content": ""
  },
  {
    "path": "RVT/utils/evaluation/prophesee/metrics/coco_eval.py",
    "content": "\"\"\"\nCompute the COCO metric on bounding box files by matching timestamps\n\nCopyright: (c) 2019-2020 Prophesee\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport contextlib\nimport os\n\nimport numpy as np\nfrom pycocotools.coco import COCO\n\ntry:\n    coco_eval_type = \"cpp-based\"\n    from detectron2.evaluation.fast_eval_api import COCOeval_opt as COCOeval\nexcept ImportError:\n    coco_eval_type = \"python-based\"\n    from pycocotools.cocoeval import COCOeval\nprint(f\"Using {coco_eval_type} detection evaluation\")\n\n\ndef evaluate_detection(\n    gt_boxes_list,\n    dt_boxes_list,\n    classes=(\"car\", \"pedestrian\"),\n    height=240,\n    width=304,\n    time_tol=50000,\n    return_aps: bool = True,\n):\n    \"\"\"\n    Compute detection KPIs on a list of boxes in the numpy format, using the COCO Python API\n    https://github.com/cocodataset/cocoapi\n    KPIs are only computed on timestamps where there is at least one ground-truth box\n    (fully empty frames are not considered)\n\n    :param gt_boxes_list: list of numpy arrays for GT boxes (one per file)\n    :param dt_boxes_list: list of numpy arrays for detected boxes (one per file)\n    :param classes: iterable of class names\n    :param height: int for box size statistics\n    :param width: int for box size statistics\n    :param time_tol: int half-width, in microseconds, of the temporal window in which detections are matched to a GT timestamp\n    :param return_aps: if True, return the AP metrics as a dict; otherwise print the COCO summary\n    \"\"\"\n    flattened_gt = []\n    flattened_dt = []\n    for gt_boxes, dt_boxes in zip(gt_boxes_list, dt_boxes_list):\n        assert np.all(gt_boxes[\"t\"][1:] >= gt_boxes[\"t\"][:-1])\n        assert np.all(dt_boxes[\"t\"][1:] >= dt_boxes[\"t\"][:-1])\n\n        all_ts = np.unique(gt_boxes[\"t\"])\n\n        gt_win, dt_win = _match_times(all_ts, 
gt_boxes, dt_boxes, time_tol)\n        flattened_gt = flattened_gt + gt_win\n        flattened_dt = flattened_dt + 
dt_win\n    return _coco_eval(\n        flattened_gt,\n        flattened_dt,\n        height,\n        width,\n        labelmap=classes,\n        return_aps=return_aps,\n    )\n\n\ndef _match_times(all_ts, gt_boxes, dt_boxes, time_tol):\n    \"\"\"\n    Match ground truth boxes and detections at all timestamps using the specified tolerance.\n    Returns two lists of box arrays (ground truth and detections), one entry per timestamp.\n    \"\"\"\n    gt_size = len(gt_boxes)\n    dt_size = len(dt_boxes)\n\n    windowed_gt = []\n    windowed_dt = []\n\n    low_gt, high_gt = 0, 0\n    low_dt, high_dt = 0, 0\n    for ts in all_ts:\n        while low_gt < gt_size and gt_boxes[low_gt][\"t\"] < ts:\n            low_gt += 1\n        # the high index is at least as big as the low one\n        high_gt = max(low_gt, high_gt)\n        while high_gt < gt_size and gt_boxes[high_gt][\"t\"] <= ts:\n            high_gt += 1\n\n        # detections are allowed to fall inside a window of +/- time_tol around the ground-truth timestamp\n        low = ts - time_tol\n        high = ts + time_tol\n        while low_dt < dt_size and dt_boxes[low_dt][\"t\"] < low:\n            low_dt += 1\n        # the high index is at least as big as the low one\n        high_dt = max(low_dt, high_dt)\n        while high_dt < dt_size and dt_boxes[high_dt][\"t\"] <= high:\n            high_dt += 1\n\n        windowed_gt.append(gt_boxes[low_gt:high_gt])\n        windowed_dt.append(dt_boxes[low_dt:high_dt])\n\n    return windowed_gt, windowed_dt\n\n\ndef _coco_eval(\n    gts,\n    detections,\n    height,\n    width,\n    labelmap=(\"car\", \"pedestrian\"),\n    return_aps: bool = True,\n):\n    \"\"\"Simple helper function wrapping COCO's Python API.\n    :param gts: iterable of numpy boxes for the ground truth\n    :param detections: iterable of numpy boxes for the detections\n    :param height: int image height\n    :param width: int image width\n    :param labelmap: iterable of class labels\n    :param return_aps: if True, return the AP metrics as a dict\n    \"\"\"\n    categories = [\n        
{\"id\": id + 1, \"name\": class_name, 
\"supercategory\": \"none\"}\n        for id, class_name in enumerate(labelmap)\n    ]\n\n    num_detections = 0\n    for detection in detections:\n        num_detections += detection.size\n\n    # Meaning: https://cocodataset.org/#detection-eval\n    out_keys = (\"AP\", \"AP_50\", \"AP_75\", \"AP_S\", \"AP_M\", \"AP_L\")\n    out_dict = {k: 0.0 for k in out_keys}\n\n    if num_detections == 0:\n        # Corner case at the very beginning of the training.\n        print(\"No detections found for evaluation.\")\n        return out_dict if return_aps else None\n\n    dataset, results = _to_coco_format(\n        gts, detections, categories, height=height, width=width\n    )\n\n    coco_gt = COCO()\n    coco_gt.dataset = dataset\n    coco_gt.createIndex()\n    coco_pred = coco_gt.loadRes(results)\n\n    coco_eval = COCOeval(coco_gt, coco_pred, \"bbox\")\n    coco_eval.params.imgIds = np.arange(1, len(gts) + 1, dtype=int)\n    coco_eval.evaluate()\n    coco_eval.accumulate()\n    if return_aps:\n        with open(os.devnull, \"w\") as f, contextlib.redirect_stdout(f):\n            # info: https://stackoverflow.com/questions/8391411/how-to-block-calls-to-print\n            coco_eval.summarize()\n        for idx, key in enumerate(out_keys):\n            out_dict[key] = coco_eval.stats[idx]\n        return out_dict\n    # Otherwise, print the full summary and return None\n    coco_eval.summarize()\n\n\ndef coco_eval_return_metrics(coco_eval: COCOeval):\n    pass\n\n\ndef _to_coco_format(gts, detections, categories, height=240, width=304):\n    \"\"\"\n    Utility function converting our data into a COCO-compatible format\n    \"\"\"\n    annotations = []\n    results = []\n    images = []\n\n    # to dictionary\n    for image_id, (gt, pred) in enumerate(zip(gts, detections)):\n        im_id = image_id + 1\n\n        images.append(\n            {\n                \"date_captured\": \"2019\",\n                \"file_name\": \"n.a\",\n                \"id\": im_id,\n                
\"license\": 1,\n                \"url\": \"\",\n                \"height\": height,\n                \"width\": width,\n            }\n        )\n\n        for bbox in gt:\n            x1, y1 = bbox[\"x\"], bbox[\"y\"]\n            w, h = bbox[\"w\"], bbox[\"h\"]\n            area = w * h\n\n            annotation = {\n                \"area\": float(area),\n                \"iscrowd\": False,\n                \"image_id\": im_id,\n                \"bbox\": [x1, y1, w, h],\n                \"category_id\": int(bbox[\"class_id\"]) + 1,\n                \"id\": len(annotations) + 1,\n            }\n            annotations.append(annotation)\n\n        for bbox in pred:\n            image_result = {\n                \"image_id\": im_id,\n                \"category_id\": int(bbox[\"class_id\"]) + 1,\n                \"score\": float(bbox[\"class_confidence\"]),\n                \"bbox\": [bbox[\"x\"], bbox[\"y\"], bbox[\"w\"], bbox[\"h\"]],\n            }\n            results.append(image_result)\n\n    dataset = {\n        \"info\": {},\n        \"licenses\": [],\n        \"type\": \"instances\",\n        \"images\": images,\n        \"annotations\": annotations,\n        \"categories\": categories,\n    }\n    return dataset, results\n"
  },
  {
    "path": "RVT/utils/evaluation/prophesee/visualize/__init__.py",
    "content": ""
  },
  {
    "path": "RVT/utils/evaluation/prophesee/visualize/vis_utils.py",
    "content": "\"\"\"\nFunctions to display events and boxes\nCopyright: (c) 2019-2020 Prophesee\n\"\"\"\n\nfrom __future__ import print_function\n\nimport bbox_visualizer as bbv\nimport cv2\nimport numpy as np\n\nLABELMAP_GEN1 = (\"car\", \"pedestrian\")\nLABELMAP_GEN4 = (\n    \"pedestrian\",\n    \"two wheeler\",\n    \"car\",\n    \"truck\",\n    \"bus\",\n    \"traffic sign\",\n    \"traffic light\",\n)\nLABELMAP_GEN4_SHORT = (\"pedestrian\", \"two wheeler\", \"car\")\n\n\ndef make_binary_histo(events, img=None, width=304, height=240):\n    \"\"\"\n    simple display function that shows negative events as black dots and positive events as white dots\n    on a gray background\n    args :\n        - events structured numpy array\n        - img (numpy array, height x width x 3) optional array to paint events on.\n        - width int\n        - height int\n    return:\n        - img (numpy array, height x width x 3)\n    \"\"\"\n    if img is None:\n        img = 127 * np.ones((height, width, 3), dtype=np.uint8)\n    else:\n        # if an array was already allocated just paint it gray\n        img[...] 
= 127\n    if events.size:\n        assert events[\"x\"].max() < width, \"out of bound events: x = {}, w = {}\".format(\n            events[\"x\"].max(), width\n        )\n        assert events[\"y\"].max() < height, \"out of bound events: y = {}, h = {}\".format(\n            events[\"y\"].max(), height\n        )\n\n        img[events[\"y\"], events[\"x\"], :] = 255 * events[\"p\"][:, None]\n    return img\n\n\ndef draw_bboxes_bbv(img, boxes, labelmap=LABELMAP_GEN1) -> np.ndarray:\n    \"\"\"\n    draw bboxes in the image img\n    \"\"\"\n    colors = cv2.applyColorMap(np.arange(0, 255).astype(np.uint8), cv2.COLORMAP_HSV)\n    colors = [tuple(*item) for item in colors.tolist()]\n\n    if labelmap == LABELMAP_GEN1:\n        classid2colors = {\n            0: (255, 255, 0),  # car -> yellow (rgb)\n            1: (0, 0, 255),  # ped -> blue (rgb)\n        }\n        scale_multiplier = 4\n    else:\n        assert labelmap == LABELMAP_GEN4_SHORT\n        classid2colors = {\n            0: (0, 0, 255),  # ped -> blue (rgb)\n            1: (0, 255, 255),  # 2-wheeler cyan (rgb)\n            2: (255, 255, 0),  # car -> yellow (rgb)\n        }\n        scale_multiplier = 2\n\n    add_score = True\n    ht, wd, ch = img.shape\n    dim_new_wh = (int(wd * scale_multiplier), int(ht * scale_multiplier))\n    if scale_multiplier != 1:\n        img = cv2.resize(img, dim_new_wh, interpolation=cv2.INTER_AREA)\n    for i in range(boxes.shape[0]):\n        pt1 = (int(boxes[\"x\"][i]), int(boxes[\"y\"][i]))\n        size = (int(boxes[\"w\"][i]), int(boxes[\"h\"][i]))\n        pt2 = (pt1[0] + size[0], pt1[1] + size[1])\n        bbox = (pt1[0], pt1[1], pt2[0], pt2[1])\n        bbox = tuple(x * scale_multiplier for x in bbox)\n\n        score = boxes[\"class_confidence\"][i]\n        class_id = boxes[\"class_id\"][i]\n        class_name = labelmap[class_id % len(labelmap)]\n        bbox_txt = class_name\n        if add_score:\n            bbox_txt += f\" {score:.2f}\"\n        
color_tuple_rgb = classid2colors[class_id]\n        img = bbv.draw_rectangle(img, bbox, bbox_color=color_tuple_rgb)\n        img = bbv.add_label(\n            img, bbox_txt, bbox, text_bg_color=color_tuple_rgb, top=True\n        )\n\n    return img\n\n\ndef draw_bboxes(img, boxes, labelmap=LABELMAP_GEN1) -> None:\n    \"\"\"\n    draw bboxes in the image img\n    \"\"\"\n    colors = cv2.applyColorMap(np.arange(0, 255).astype(np.uint8), cv2.COLORMAP_HSV)\n    colors = [tuple(*item) for item in colors.tolist()]\n\n    for i in range(boxes.shape[0]):\n        pt1 = (int(boxes[\"x\"][i]), int(boxes[\"y\"][i]))\n        size = (int(boxes[\"w\"][i]), int(boxes[\"h\"][i]))\n        pt2 = (pt1[0] + size[0], pt1[1] + size[1])\n        score = boxes[\"class_confidence\"][i]\n        class_id = boxes[\"class_id\"][i]\n        class_name = labelmap[class_id % len(labelmap)]\n        color = colors[class_id * 60 % 255]\n        center = ((pt1[0] + pt2[0]) // 2, (pt1[1] + pt2[1]) // 2)\n        cv2.rectangle(img, pt1, pt2, color, 1)\n        cv2.putText(\n            img,\n            class_name,\n            (center[0], pt2[1] - 1),\n            cv2.FONT_HERSHEY_SIMPLEX,\n            0.5,\n            color,\n        )\n        cv2.putText(\n            img,\n            str(score),\n            (center[0], pt1[1] - 1),\n            cv2.FONT_HERSHEY_SIMPLEX,\n            0.5,\n            color,\n        )\n"
  },
  {
    "path": "RVT/utils/helpers.py",
    "content": "from typing import Union\n\nimport torch as th\n\n\ndef torch_uniform_sample_scalar(min_value: float, max_value: float):\n    assert max_value >= min_value, f\"{max_value=} is smaller than {min_value=}\"\n    if max_value == min_value:\n        return min_value\n    return min_value + (max_value - min_value) * th.rand(1).item()\n\n\ndef clamp(\n    value: Union[int, float], smallest: Union[int, float], largest: Union[int, float]\n):\n    return max(smallest, min(value, largest))\n"
  },
  {
    "path": "RVT/utils/padding.py",
    "content": "from typing import Any, List, Tuple\n\nimport torch as th\nimport torch.nn.functional as F\n\n\nclass InputPadderFromShape:\n    def __init__(\n        self,\n        desired_hw: Tuple[int, int],\n        mode: str = \"constant\",\n        value: int = 0,\n        type: str = \"corner\",\n    ):\n        \"\"\"\n        :param desired_hw: Desired height and width\n        :param mode: See torch.nn.functional.pad\n        :param value: See torch.nn.functional.pad\n        :param type: \"corner\": pad only at the bottom and right\n        \"\"\"\n        assert isinstance(desired_hw, tuple)\n        assert len(desired_hw) == 2\n        assert desired_hw[0] % 4 == 0, \"Required for token mask padding\"\n        assert desired_hw[1] % 4 == 0, \"Required for token mask padding\"\n        assert type in {\"corner\"}\n\n        self.desired_hw = desired_hw\n        self.mode = mode\n        self.value = value\n        self.type = type\n        self._pad_ev_repr = None\n        self._pad_token_mask = None\n\n    @staticmethod\n    def _pad_tensor_impl(\n        input_tensor: th.Tensor, desired_hw: Tuple[int, int], mode: str, value: Any\n    ) -> Tuple[th.Tensor, List[int]]:\n        assert isinstance(input_tensor, th.Tensor)\n\n        ht, wd = input_tensor.shape[-2:]\n        ht_des, wd_des = desired_hw\n        assert ht <= ht_des\n        assert wd <= wd_des\n\n        pad_left = 0\n        pad_right = wd_des - wd\n        pad_top = 0\n        pad_bottom = ht_des - ht\n\n        pad = [pad_left, pad_right, pad_top, pad_bottom]\n        return (\n            F.pad(\n                input_tensor,\n                pad=pad,\n                mode=mode,\n                value=value if mode == \"constant\" else None,\n            ),\n            pad,\n        )\n\n    def pad_tensor_ev_repr(self, ev_repr: th.Tensor) -> th.Tensor:\n        padded_ev_repr, pad = self._pad_tensor_impl(\n            input_tensor=ev_repr,\n            desired_hw=self.desired_hw,\n      
      mode=self.mode,\n            value=self.value,\n        )\n        if self._pad_ev_repr is None:\n            self._pad_ev_repr = pad\n        else:\n            assert self._pad_ev_repr == pad\n        return padded_ev_repr\n\n    def pad_token_mask(self, token_mask: th.Tensor):\n        assert isinstance(token_mask, th.Tensor)\n\n        desired_hw = tuple(x // 4 for x in self.desired_hw)\n        padded_token_mask, pad = self._pad_tensor_impl(\n            input_tensor=token_mask, desired_hw=desired_hw, mode=\"constant\", value=0\n        )\n        if self._pad_token_mask is None:\n            self._pad_token_mask = pad\n        else:\n            assert self._pad_token_mask == pad\n        return padded_token_mask\n"
  },
  {
    "path": "RVT/utils/preprocessing.py",
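    "content": "def _blosc_opts(complevel=1, complib=\"blosc:zstd\", shuffle=\"byte\"):\n    \"\"\"Return h5py dataset keyword arguments that enable Blosc compression.\n\n    32001 is the registered HDF5 filter id of Blosc; the filter must be available\n    at runtime (e.g. by importing hdf5plugin).\n    \"\"\"\n    # Blosc shuffle codes: 0 = none, 1 = byte shuffle, 2 = bit shuffle\n    shuffle = 2 if shuffle == \"bit\" else 1 if shuffle == \"byte\" else 0\n    compressors = [\"blosclz\", \"lz4\", \"lz4hc\", \"snappy\", \"zlib\", \"zstd\"]\n    # The compressor is passed to the filter as its index in this list\n    complib = [\"blosc:\" + c for c in compressors].index(complib)\n    args = {\n        \"compression\": 32001,\n        # The first four slots are reserved for the filter itself\n        \"compression_opts\": (0, 0, 0, 0, complevel, shuffle, complib),\n    }\n    if shuffle > 0:\n        # Do not use h5py shuffle if blosc shuffle is enabled.\n        args[\"shuffle\"] = False\n    return args\n"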
  },
  {
    "path": "RVT/utils/timers.py",
    "content": "import atexit\nimport time\nfrom functools import wraps\n\nimport numpy as np\nimport torch\n\ncuda_timers = {}\ntimers = {}\n\n\nclass CudaTimer:\n    def __init__(self, device: torch.device, timer_name: str):\n        assert isinstance(device, torch.device)\n        assert isinstance(timer_name, str)\n        self.timer_name = timer_name\n        if self.timer_name not in cuda_timers:\n            cuda_timers[self.timer_name] = []\n\n        self.device = device\n        self.start = None\n        self.end = None\n\n    def __enter__(self):\n        torch.cuda.synchronize(device=self.device)\n        self.start = time.time()\n        return self\n\n    def __exit__(self, *args):\n        assert self.start is not None\n        torch.cuda.synchronize(device=self.device)\n        end = time.time()\n        cuda_timers[self.timer_name].append(end - self.start)\n\n\ndef cuda_timer_decorator(device: torch.device, timer_name: str):\n    def decorator(func):\n        @wraps(func)\n        def wrapper(*args, **kwargs):\n            with CudaTimer(device=device, timer_name=timer_name):\n                out = func(*args, **kwargs)\n            return out\n\n        return wrapper\n\n    return decorator\n\n\nclass TimerDummy:\n    def __init__(self, *args, **kwargs):\n        pass\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, *args):\n        pass\n\n\nclass Timer:\n    def __init__(self, timer_name=\"\"):\n        self.timer_name = timer_name\n        if self.timer_name not in timers:\n            timers[self.timer_name] = []\n\n    def __enter__(self):\n        self.start = time.time()\n        return self\n\n    def __exit__(self, *args):\n        end = time.time()\n        time_diff_s = end - self.start  # measured in seconds\n        timers[self.timer_name].append(time_diff_s)\n\n\ndef print_timing_info():\n    print(\"== Timing statistics ==\")\n    skip_warmup = 10\n    for timer_name, timing_values in [*cuda_timers.items(), 
*timers.items()]:\n        if len(timing_values) <= skip_warmup:\n            continue\n        values = timing_values[skip_warmup:]\n        timing_value_s_mean = np.mean(np.array(values))\n        timing_value_s_median = np.median(np.array(values))\n        timing_value_ms_mean = timing_value_s_mean * 1000\n        timing_value_ms_median = timing_value_s_median * 1000\n        if timing_value_ms_mean > 1000:\n            print(\n                \"{}: mean={:.2f} s, median={:.2f} s\".format(\n                    timer_name, timing_value_s_mean, timing_value_s_median\n                )\n            )\n        else:\n            print(\n                \"{}: mean={:.2f} ms, median={:.2f} ms\".format(\n                    timer_name, timing_value_ms_mean, timing_value_ms_median\n                )\n            )\n\n\n# this will print all the timer values upon termination of any program that imported this file\natexit.register(print_timing_info)\n"
  },
  {
    "path": "RVT/validation.py",
    "content": "import os\n\nos.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\nos.environ[\"OMP_NUM_THREADS\"] = \"1\"\nos.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"\nos.environ[\"MKL_NUM_THREADS\"] = \"1\"\nos.environ[\"VECLIB_MAXIMUM_THREADS\"] = \"1\"\nos.environ[\"NUMEXPR_NUM_THREADS\"] = \"1\"\nfrom pathlib import Path\n\nimport torch\nfrom torch.backends import cuda, cudnn\n\ncuda.matmul.allow_tf32 = True\ncudnn.allow_tf32 = True\ntorch.multiprocessing.set_sharing_strategy(\"file_system\")\n\nimport hydra\nimport hdf5plugin\nfrom omegaconf import DictConfig, OmegaConf\nimport lightning.pytorch as pl\nfrom lightning.pytorch.loggers import CSVLogger\nfrom lightning.pytorch.callbacks import ModelSummary\n\nfrom config.modifier import dynamically_modify_train_config\nfrom modules.utils.fetch import fetch_data_module, fetch_model_module\nfrom modules.detection import Module\n\n\n@hydra.main(config_path=\"config\", config_name=\"val\", version_base=\"1.2\")\ndef main(config: DictConfig):\n    dynamically_modify_train_config(config)\n    # Just to check whether config can be resolved\n    OmegaConf.to_container(config, resolve=True, throw_on_missing=True)\n\n    print(\"------ Configuration ------\")\n    print(OmegaConf.to_yaml(config))\n    print(\"---------------------------\")\n\n    # ---------------------\n    # GPU options\n    # ---------------------\n    gpus = config.hardware.gpus\n    assert isinstance(gpus, int), \"no more than 1 GPU supported\"\n    gpus = [gpus]\n\n    # ---------------------\n    # Data\n    # ---------------------\n    data_module = fetch_data_module(config=config)\n\n    # ---------------------\n    # Logging and Checkpoints\n    # ---------------------\n    logger = CSVLogger(save_dir=\"./validation_logs\")\n    ckpt_path = Path(config.checkpoint)\n\n    # ---------------------\n    # Model\n    # ---------------------\n\n    module = fetch_model_module(config=config)\n    module = Module.load_from_checkpoint(str(ckpt_path), 
**{\"full_config\": config})\n\n    # ---------------------\n    # Callbacks and Misc\n    # ---------------------\n    callbacks = [ModelSummary(max_depth=2)]\n\n    # ---------------------\n    # Validation\n    # ---------------------\n\n    trainer = pl.Trainer(\n        accelerator=\"gpu\",\n        callbacks=callbacks,\n        default_root_dir=None,\n        devices=gpus,\n        logger=logger,\n        log_every_n_steps=100,\n        precision=config.training.precision,\n        # move_metrics_to_cpu=False,\n    )\n    with torch.inference_mode():\n        if config.use_test_set:\n            trainer.test(model=module, datamodule=data_module, ckpt_path=str(ckpt_path))\n        else:\n            trainer.validate(\n                model=module, datamodule=data_module, ckpt_path=str(ckpt_path)\n            )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "installation_details.txt",
    "content": "conda create -y -n events_signals python=3.11\nconda activate events_signals\nconda install -y pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia  # PyTorch Stable (2.2.1)\npip install lightning wandb pandas plotly opencv-python tabulate pycocotools bbox-visualizer StrEnum hydra-core einops torchdata tqdm numba h5py hdf5plugin lovely-tensors tensorboardX pykeops scikit-learn\n"
  },
  {
    "path": "scripts/1mpx/onempx_base.bash",
    "content": "#!/usr/bin/env bash\n\nsource activate events_signals\n\npython RVT/train.py model=rnndet dataset=gen4 dataset.path=/shares/rpg.ifi.uzh/nzubic/datasets/RVT/gen4_new_no_psee_filter wandb.project_name=ssms_event_cameras \\\nwandb.group_name=1mpx +experiment/gen4=base.yaml hardware.gpus=[0,1] batch_size.train=6 batch_size.eval=6 \\\nhardware.num_workers.train=12 hardware.num_workers.eval=4\n"
  },
  {
    "path": "scripts/1mpx/onempx_base.job",
    "content": "#!/usr/bin/env bash\n#SBATCH --ntasks-per-node=2\n#SBATCH --cpus-per-task=16\n#SBATCH --mem-per-cpu=8G\n#SBATCH --time=86:00:00\n#SBATCH --gres=gpu:2  # The GPU model is optional, you can simply specify 'gpu:1'\n#SBATCH --constraint=GPUMEM80GB  # This constraint is optional if you don't care about VRAM\n#SBATCH --output=final_outputs/onempx_base.txt\n\nmodule load gpu cuda\nsrun onempx_base.bash\n"
  },
  {
    "path": "scripts/1mpx/onempx_small.bash",
    "content": "#!/usr/bin/env bash\n\nsource activate events_signals\n\npython RVT/train.py model=rnndet dataset=gen4 dataset.path=/shares/rpg.ifi.uzh/nzubic/datasets/RVT/gen4_new_no_psee_filter wandb.project_name=ssms_event_cameras \\\nwandb.group_name=1mpx +experiment/gen4=small.yaml hardware.gpus=[0,1] batch_size.train=6 batch_size.eval=6 \\\nhardware.num_workers.train=12 hardware.num_workers.eval=4\n"
  },
  {
    "path": "scripts/1mpx/onempx_small.job",
    "content": "#!/usr/bin/env bash\n#SBATCH --ntasks-per-node=2\n#SBATCH --cpus-per-task=16\n#SBATCH --mem-per-cpu=8G\n#SBATCH --time=78:00:00\n#SBATCH --gres=gpu:2  # The GPU model is optional, you can simply specify 'gpu:1'\n#SBATCH --constraint=GPUMEM80GB  # This constraint is optional if you don't care about VRAM\n#SBATCH --output=final_outputs/onempx_small_2.txt\n\nmodule load gpu cuda\nsrun onempx_small.bash\n"
  },
  {
    "path": "scripts/gen1/base.txt",
    "content": "python RVT/train.py model=rnndet dataset=gen1 dataset.path=/data/scratch1/nzubic/datasets/RVT/gen1 wandb.project_name=ssms_event_cameras \\\nwandb.group_name=gen1 +experiment/gen1=base.yaml hardware.gpus=0 batch_size.train=8 batch_size.eval=8 hardware.num_workers.train=24 \\\nhardware.num_workers.eval=8\n"
  },
  {
    "path": "scripts/gen1/small.txt",
    "content": "python RVT/train.py model=rnndet dataset=gen1 dataset.path=/data/scratch1/nzubic/datasets/RVT/gen1 wandb.project_name=ssms_event_cameras \\\nwandb.group_name=gen1 +experiment/gen1=small.yaml hardware.gpus=0 batch_size.train=8 batch_size.eval=8 hardware.num_workers.train=24 \\\nhardware.num_workers.eval=8\n"
  }
]