Repository: uzh-rpg/ssms_event_cameras Branch: master Commit: 7c871b55a0c5 Files: 162 Total size: 623.9 KB Directory structure: gitextract_9qf3l4ub/ ├── .gitignore ├── README.md ├── RVT/ │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── callbacks/ │ │ ├── custom.py │ │ ├── detection.py │ │ ├── gradflow.py │ │ ├── utils/ │ │ │ └── visualization.py │ │ └── viz_base.py │ ├── config/ │ │ ├── dataset/ │ │ │ ├── base.yaml │ │ │ ├── gen1.yaml │ │ │ └── gen4.yaml │ │ ├── experiment/ │ │ │ ├── gen1/ │ │ │ │ ├── base.yaml │ │ │ │ ├── default.yaml │ │ │ │ └── small.yaml │ │ │ └── gen4/ │ │ │ ├── base.yaml │ │ │ ├── default.yaml │ │ │ └── small.yaml │ │ ├── general.yaml │ │ ├── model/ │ │ │ ├── base.yaml │ │ │ ├── maxvit_yolox/ │ │ │ │ └── default.yaml │ │ │ └── rnndet.yaml │ │ ├── modifier.py │ │ ├── train.yaml │ │ └── val.yaml │ ├── data/ │ │ ├── genx_utils/ │ │ │ ├── collate.py │ │ │ ├── collate_from_pytorch.py │ │ │ ├── dataset_rnd.py │ │ │ ├── dataset_streaming.py │ │ │ ├── labels.py │ │ │ ├── sequence_base.py │ │ │ ├── sequence_for_streaming.py │ │ │ └── sequence_rnd.py │ │ └── utils/ │ │ ├── augmentor.py │ │ ├── representations.py │ │ ├── spatial.py │ │ ├── stream_concat_datapipe.py │ │ ├── stream_sharded_datapipe.py │ │ └── types.py │ ├── loggers/ │ │ ├── utils.py │ │ └── wandb_logger.py │ ├── models/ │ │ ├── detection/ │ │ │ ├── __init_.py │ │ │ ├── recurrent_backbone/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ └── maxvit_rnn.py │ │ │ ├── yolox/ │ │ │ │ ├── models/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── losses.py │ │ │ │ │ ├── network_blocks.py │ │ │ │ │ └── yolo_head.py │ │ │ │ └── utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── boxes.py │ │ │ │ └── compat.py │ │ │ └── yolox_extension/ │ │ │ └── models/ │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ ├── detector.py │ │ │ └── yolo_pafpn.py │ │ └── layers/ │ │ ├── maxvit/ │ │ │ ├── __init__.py │ │ │ ├── layers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── activations.py │ │ │ │ ├── activations_jit.py │ │ │ │ ├── 
activations_me.py │ │ │ │ ├── adaptive_avgmax_pool.py │ │ │ │ ├── attention_pool2d.py │ │ │ │ ├── blur_pool.py │ │ │ │ ├── bottleneck_attn.py │ │ │ │ ├── cbam.py │ │ │ │ ├── classifier.py │ │ │ │ ├── cond_conv2d.py │ │ │ │ ├── config.py │ │ │ │ ├── conv2d_same.py │ │ │ │ ├── conv_bn_act.py │ │ │ │ ├── create_act.py │ │ │ │ ├── create_attn.py │ │ │ │ ├── create_conv2d.py │ │ │ │ ├── create_norm.py │ │ │ │ ├── create_norm_act.py │ │ │ │ ├── drop.py │ │ │ │ ├── eca.py │ │ │ │ ├── evo_norm.py │ │ │ │ ├── fast_norm.py │ │ │ │ ├── filter_response_norm.py │ │ │ │ ├── gather_excite.py │ │ │ │ ├── global_context.py │ │ │ │ ├── halo_attn.py │ │ │ │ ├── helpers.py │ │ │ │ ├── inplace_abn.py │ │ │ │ ├── lambda_layer.py │ │ │ │ ├── linear.py │ │ │ │ ├── median_pool.py │ │ │ │ ├── mixed_conv2d.py │ │ │ │ ├── ml_decoder.py │ │ │ │ ├── mlp.py │ │ │ │ ├── non_local_attn.py │ │ │ │ ├── norm.py │ │ │ │ ├── norm_act.py │ │ │ │ ├── padding.py │ │ │ │ ├── patch_embed.py │ │ │ │ ├── pool2d_same.py │ │ │ │ ├── pos_embed.py │ │ │ │ ├── selective_kernel.py │ │ │ │ ├── separable_conv.py │ │ │ │ ├── space_to_depth.py │ │ │ │ ├── split_attn.py │ │ │ │ ├── split_batchnorm.py │ │ │ │ ├── squeeze_excite.py │ │ │ │ ├── std_conv.py │ │ │ │ ├── test_time_pool.py │ │ │ │ ├── trace_utils.py │ │ │ │ └── weight_init.py │ │ │ └── maxvit.py │ │ ├── rnn.py │ │ └── s5/ │ │ ├── __init__.py │ │ ├── jax_func.py │ │ ├── s5_init.py │ │ ├── s5_model.py │ │ └── triton_comparison.py │ ├── modules/ │ │ ├── __init__.py │ │ ├── data/ │ │ │ └── genx.py │ │ ├── detection.py │ │ └── utils/ │ │ ├── detection.py │ │ └── fetch.py │ ├── scripts/ │ │ ├── genx/ │ │ │ ├── README.md │ │ │ ├── conf_preprocess/ │ │ │ │ ├── extraction/ │ │ │ │ │ ├── const_count.yaml │ │ │ │ │ ├── const_duration.yaml │ │ │ │ │ └── frequencies/ │ │ │ │ │ ├── const_duration_100hz.yaml │ │ │ │ │ ├── const_duration_200hz.yaml │ │ │ │ │ ├── const_duration_40hz.yaml │ │ │ │ │ └── const_duration_80hz.yaml │ │ │ │ ├── filter_gen1.yaml │ │ │ │ ├── 
filter_gen4.yaml │ │ │ │ └── representation/ │ │ │ │ ├── mixeddensity_stack.yaml │ │ │ │ └── stacked_hist.yaml │ │ │ ├── preprocess_dataset.py │ │ │ └── preprocess_dataset.sh │ │ └── viz/ │ │ └── viz_gt.py │ ├── train.py │ ├── utils/ │ │ ├── evaluation/ │ │ │ └── prophesee/ │ │ │ ├── __init__.py │ │ │ ├── evaluation.py │ │ │ ├── evaluator.py │ │ │ ├── io/ │ │ │ │ ├── __init__.py │ │ │ │ ├── box_filtering.py │ │ │ │ ├── box_loading.py │ │ │ │ ├── dat_events_tools.py │ │ │ │ ├── npy_events_tools.py │ │ │ │ └── psee_loader.py │ │ │ ├── metrics/ │ │ │ │ ├── __init__.py │ │ │ │ └── coco_eval.py │ │ │ └── visualize/ │ │ │ ├── __init__.py │ │ │ └── vis_utils.py │ │ ├── helpers.py │ │ ├── padding.py │ │ ├── preprocessing.py │ │ └── timers.py │ └── validation.py ├── installation_details.txt └── scripts/ ├── 1mpx/ │ ├── onempx_base.bash │ ├── onempx_base.job │ ├── onempx_small.bash │ └── onempx_small.job └── gen1/ ├── base.txt └── small.txt ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ *.pyc *.iml # Specific stuff wandb/ cache_dir/ raw_datasets/ raw_data/ final_outputs/ outputs/ validation_logs/ apex/ *.ckpt .vscode/ ================================================ FILE: README.md ================================================ # [CVPR'24 Spotlight] State Space Models for Event Cameras

youtube_video

This is the official PyTorch implementation of the CVPR 2024 paper [State Space Models for Event Cameras](https://arxiv.org/abs/2402.15584). ### 🖼️ Check Out Our Poster! 🖼️ [here](https://download.ifi.uzh.ch/rpg/CVPR24_Zubic/Zubic_CVPR24_poster.pdf) ## :white_check_mark: Updates * **` June. 14th, 2024`**: Everything is updated! Poster released! Check it above. * **` June. 6th, 2024`**: Video released! To watch our video, simply click on the YouTube play button above. * **` June. 1st, 2024`**: Our CVPR conference paper has also been accepted as a Spotlight presentation at "The 3rd Workshop on Transformers for Vision (T4V)." * **` April. 19th, 2024`**: The code along with the best checkpoints is released! The poster and video will be released shortly before CVPR 2024. ## Citation If you find this work and/or code useful, please cite our paper: ```bibtex @InProceedings{Zubic_2024_CVPR, author = {Zubic, Nikola and Gehrig, Mathias and Scaramuzza, Davide}, title = {State Space Models for Event Cameras}, booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, month = {June}, year = {2024}, pages = {5819-5828} } ``` ## SSM-ViT - The S5 model used in our SSM-ViT pipeline can be seen [here](https://github.com/uzh-rpg/ssms_event_cameras/tree/master/RVT/models/layers/s5). - In particular, S5 is used instead of an RNN in a 4-stage hierarchical ViT backbone, and its forward function is exposed [here](https://github.com/uzh-rpg/ssms_event_cameras/blob/master/RVT/models/detection/recurrent_backbone/maxvit_rnn.py#L245). What is nice about this approach is that we do not need a 'for' loop over the sequence dimension, but instead we employ a parallel scanning algorithm. This model assumes that a hidden state is being carried over. - For a standalone model that can be used for any sequence modeling problem, this formulation, in which the hidden state is carried over, is not used by default. 
The implementation is the same as the original JAX implementation and can be downloaded in zip format from [ssms_event_cameras/RVT/models/s5.zip](https://github.com/uzh-rpg/ssms_event_cameras/raw/master/RVT/models/s5.zip). ## Installation ### Conda We highly recommend using [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) to reduce the installation time. ```Bash conda create -y -n events_signals python=3.11 conda activate events_signals conda install pytorch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 pytorch-cuda=11.8 -c pytorch -c nvidia pip install lightning wandb pandas plotly opencv-python tabulate pycocotools bbox-visualizer StrEnum hydra-core einops torchdata tqdm numba h5py hdf5plugin lovely-tensors tensorboardX pykeops scikit-learn ``` ## Required Data To evaluate or train the S5-ViT model, you will need to download the required preprocessed datasets:
1 Mpx Gen1
pre-processed dataset download download
crc32 c5ec7c38 5acab6f3
You may also pre-process the dataset yourself by following the [instructions](https://github.com/NikolaZubic/ssms_event_cameras/blob/master/RVT/scripts/genx/README.md). ## Pre-trained Checkpoints ### 1 Mpx
S5-ViT-Base S5-ViT-Small
pre-trained checkpoint download download
### Gen1
S5-ViT-Base S5-ViT-Small
pre-trained checkpoint download download
## Evaluation - Evaluation scripts with the concrete parameters that we used to train our models can be found [here](https://github.com/uzh-rpg/ssms_event_cameras/tree/master/scripts). - Set `DATA_DIR` as the path to either the 1 Mpx or Gen1 dataset directory - Set `CKPT_PATH` to the path of the *correct* checkpoint matching the choice of the model and dataset - Set - `MDL_CFG=base` or - `MDL_CFG=small` to load either the base or small model configuration. - Set `GPU_ID` to the PCI BUS ID of the GPU that you want to use, e.g. `GPU_ID=0`. Only a single GPU is supported for evaluation ### 1 Mpx ```Bash python RVT/validation.py dataset=gen4 dataset.path=${DATA_DIR} checkpoint=${CKPT_PATH} \ use_test_set=1 hardware.gpus=${GPU_ID} +experiment/gen4="${MDL_CFG}.yaml" \ batch_size.eval=12 model.postprocess.confidence_threshold=0.001 ``` ### Gen1 ```Bash python RVT/validation.py dataset=gen1 dataset.path=${DATA_DIR} checkpoint=${CKPT_PATH} \ use_test_set=1 hardware.gpus=${GPU_ID} +experiment/gen1="${MDL_CFG}.yaml" \ batch_size.eval=8 model.postprocess.confidence_threshold=0.001 ``` We set the same batch size for the evaluation and training: 12 for the 1 Mpx dataset, and 8 for the Gen1 dataset. ## Evaluation results Evaluation should give the same results as shown below: - 47.7 and 47.8 mAP on Gen1 and 1 Mpx datasets for the base model, and - 46.6 and 46.5 mAP on Gen1 and 1 Mpx datasets for the small model.

## Training - Set `DATA_DIR` as the path to either the 1 Mpx or Gen1 dataset directory - Set - `MDL_CFG=base` or - `MDL_CFG=small` to load either the base or the small configuration. - Set `GPU_IDS` to the PCI BUS IDs of the GPUs that you want to use. e.g. `GPU_IDS=[0,1]` for using GPU 0 and 1. **Using a list of IDS will enable single-node multi-GPU training.** Pay attention to the batch size which is defined per GPU. - Set `BATCH_SIZE_PER_GPU` such that the effective batch size is matching the parameters below. The **effective batch size** is (batch size per GPU)*(number of GPUs). - If you would like to change the effective batch size, we found the following learning rate scaling to work well for all models on both datasets: `lr = 2e-4 * sqrt(effective_batch_size/8)`. - The training code uses [W&B](https://wandb.ai/) for logging during the training. Hence, we assume that you have a W&B account. - The training script below will create a new project called `ssms_event_cameras`. Adapt the project name and group name if necessary. ### 1 Mpx - The effective batch size for the 1 Mpx training is 12. - For training the model on 1 Mpx dataset, we need 2x A100 80 GB GPUs and we use 12 workers per GPU for training and 4 workers per GPU for evaluation: ```Bash GPU_IDS=[0,1] BATCH_SIZE_PER_GPU=6 TRAIN_WORKERS_PER_GPU=12 EVAL_WORKERS_PER_GPU=4 python RVT/train.py model=rnndet dataset=gen4 dataset.path=${DATA_DIR} wandb.project_name=ssms_event_cameras \ wandb.group_name=1mpx +experiment/gen4="${MDL_CFG}.yaml" hardware.gpus=${GPU_IDS} \ batch_size.train=${BATCH_SIZE_PER_GPU} batch_size.eval=${BATCH_SIZE_PER_GPU} \ hardware.num_workers.train=${TRAIN_WORKERS_PER_GPU} hardware.num_workers.eval=${EVAL_WORKERS_PER_GPU} ``` If you for example want to execute the training on 4 GPUs simply adapt `GPU_IDS` and `BATCH_SIZE_PER_GPU` accordingly: ```Bash GPU_IDS=[0,1,2,3] BATCH_SIZE_PER_GPU=3 ``` ### Gen1 - The effective batch size for the Gen1 training is 8. 
- For training the model on the Gen1 dataset, we need 1x A100 80 GB GPU using 24 workers for training and 8 workers for evaluation: ```Bash GPU_IDS=0 BATCH_SIZE_PER_GPU=8 TRAIN_WORKERS_PER_GPU=24 EVAL_WORKERS_PER_GPU=8 python RVT/train.py model=rnndet dataset=gen1 dataset.path=${DATA_DIR} wandb.project_name=ssms_event_cameras \ wandb.group_name=gen1 +experiment/gen1="${MDL_CFG}.yaml" hardware.gpus=${GPU_IDS} \ batch_size.train=${BATCH_SIZE_PER_GPU} batch_size.eval=${BATCH_SIZE_PER_GPU} \ hardware.num_workers.train=${TRAIN_WORKERS_PER_GPU} hardware.num_workers.eval=${EVAL_WORKERS_PER_GPU} ``` ## Code Acknowledgments This project has used code from the following projects: - [RVT](https://github.com/uzh-rpg/RVT) - Recurrent Vision Transformers for Object Detection with Event Cameras in PyTorch - [S4](https://github.com/state-spaces/s4) - Structured State Spaces for Sequence Modeling, in particular S4 and S4D models in PyTorch - [S5](https://github.com/lindermanlab/S5) - Simplified State Space Layers for Sequence Modeling in JAX - [S5 PyTorch](https://github.com/i404788/s5-pytorch) - S5 model in PyTorch ================================================ FILE: RVT/.gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ ================================================ FILE: RVT/LICENSE ================================================ MIT License Copyright (c) 2023 Mathias Gehrig Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial 
portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: RVT/README.md ================================================ # RVT: Recurrent Vision Transformers for Object Detection with Event Cameras

This is the official Pytorch implementation of the CVPR 2023 paper [Recurrent Vision Transformers for Object Detection with Event Cameras](https://arxiv.org/abs/2212.05598). Watch the [**video**](https://youtu.be/xZ-pNwHxHgY) for a quick overview. ```bibtex @InProceedings{Gehrig_2023_CVPR, author = {Mathias Gehrig and Davide Scaramuzza}, title = {Recurrent Vision Transformers for Object Detection with Event Cameras}, booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, year = {2023}, } ``` ## Conda Installation We highly recommend to use [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) to reduce the installation time. ```Bash conda create -y -n rvt python=3.9 pip conda activate rvt conda config --set channel_priority flexible CUDA_VERSION=11.8 conda install -y h5py=3.8.0 blosc-hdf5-plugin=1.0.0 \ hydra-core=1.3.2 einops=0.6.0 torchdata=0.6.0 tqdm numba \ pytorch=2.0.0 torchvision=0.15.0 pytorch-cuda=$CUDA_VERSION \ -c pytorch -c nvidia -c conda-forge python -m pip install pytorch-lightning==1.8.6 wandb==0.14.0 \ pandas==1.5.3 plotly==5.13.1 opencv-python==4.6.0.66 tabulate==0.9.0 \ pycocotools==2.0.6 bbox-visualizer==0.1.0 StrEnum==0.4.10 python -m pip install 'git+https://github.com/facebookresearch/detectron2.git' ``` Detectron2 is not strictly required but speeds up the evaluation. ## Required Data To evaluate or train RVT you will need to download the required preprocessed datasets:
1 Mpx Gen1
pre-processed dataset download download
crc32 c5ec7c38 5acab6f3
You may also pre-process the dataset yourself by following the [instructions](scripts/genx/README.md). ## Pre-trained Checkpoints ### 1 Mpx
RVT-Base RVT-Small RVT-Tiny
pre-trained checkpoint download download download
md5 72923a a94207 5a3c78
### Gen1
RVT-Base RVT-Small RVT-Tiny
pre-trained checkpoint download download download
md5 839317 840f2b a770b9
## Evaluation - Set `DATA_DIR` as the path to either the 1 Mpx or Gen1 dataset directory - Set `CKPT_PATH` to the path of the *correct* checkpoint matching the choice of the model and dataset. - Set - `MDL_CFG=base`, or - `MDL_CFG=small`, or - `MDL_CFG=tiny` to load either the base, small, or tiny model configuration - Set - `USE_TEST=1` to evaluate on the test set, or - `USE_TEST=0` to evaluate on the validation set - Set `GPU_ID` to the PCI BUS ID of the GPU that you want to use. e.g. `GPU_ID=0`. Only a single GPU is supported for evaluation ### 1 Mpx ```Bash python validation.py dataset=gen4 dataset.path=${DATA_DIR} checkpoint=${CKPT_PATH} \ use_test_set=${USE_TEST} hardware.gpus=${GPU_ID} +experiment/gen4="${MDL_CFG}.yaml" \ batch_size.eval=8 model.postprocess.confidence_threshold=0.001 ``` ### Gen1 ```Bash python validation.py dataset=gen1 dataset.path=${DATA_DIR} checkpoint=${CKPT_PATH} \ use_test_set=${USE_TEST} hardware.gpus=${GPU_ID} +experiment/gen1="${MDL_CFG}.yaml" \ batch_size.eval=8 model.postprocess.confidence_threshold=0.001 ``` ## Training - Set `DATA_DIR` as the path to either the 1 Mpx or Gen1 dataset directory - Set - `MDL_CFG=base`, or - `MDL_CFG=small`, or - `MDL_CFG=tiny` to load either the base, small, or tiny model configuration - Set `GPU_IDS` to the PCI BUS IDs of the GPUs that you want to use. e.g. `GPU_IDS=[0,1]` for using GPU 0 and 1. **Using a list of IDS will enable single-node multi-GPU training.** Pay attention to the batch size which is defined per GPU: - Set `BATCH_SIZE_PER_GPU` such that the effective batch size is matching the parameters below. The **effective batch size** is (batch size per gpu)*(number of GPUs). - If you would like to change the effective batch size, we found the following learning rate scaling to work well for all models on both datasets: `lr = 2e-4 * sqrt(effective_batch_size/8)`. - The training code uses [W&B](https://wandb.ai/) for logging during the training. Hence, we assume that you have a W&B account. 
- The training script below will create a new project called `RVT`. Adapt the project name and group name if necessary. ### 1 Mpx - The effective batch size for the 1 Mpx training is 24. - To train on 2 GPUs using 6 workers per GPU for training and 2 workers per GPU for evaluation: ```Bash GPU_IDS=[0,1] BATCH_SIZE_PER_GPU=12 TRAIN_WORKERS_PER_GPU=6 EVAL_WORKERS_PER_GPU=2 python train.py model=rnndet dataset=gen4 dataset.path=${DATA_DIR} wandb.project_name=RVT \ wandb.group_name=1mpx +experiment/gen4="${MDL_CFG}.yaml" hardware.gpus=${GPU_IDS} \ batch_size.train=${BATCH_SIZE_PER_GPU} batch_size.eval=${BATCH_SIZE_PER_GPU} \ hardware.num_workers.train=${TRAIN_WORKERS_PER_GPU} hardware.num_workers.eval=${EVAL_WORKERS_PER_GPU} ``` If you instead want to execute the training on 4 GPUs simply adapt `GPU_IDS` and `BATCH_SIZE_PER_GPU` accordingly: ```Bash GPU_IDS=[0,1,2,3] BATCH_SIZE_PER_GPU=6 ``` ### Gen1 - The effective batch size for the Gen1 training is 8. - To train on 1 GPU using 6 workers for training and 2 workers for evaluation: ```Bash GPU_IDS=0 BATCH_SIZE_PER_GPU=8 TRAIN_WORKERS_PER_GPU=6 EVAL_WORKERS_PER_GPU=2 python train.py model=rnndet dataset=gen1 dataset.path=${DATA_DIR} wandb.project_name=RVT \ wandb.group_name=gen1 +experiment/gen1="${MDL_CFG}.yaml" hardware.gpus=${GPU_IDS} \ batch_size.train=${BATCH_SIZE_PER_GPU} batch_size.eval=${BATCH_SIZE_PER_GPU} \ hardware.num_workers.train=${TRAIN_WORKERS_PER_GPU} hardware.num_workers.eval=${EVAL_WORKERS_PER_GPU} ``` ## Code Acknowledgments This project has used code from the following projects: - [timm](https://github.com/huggingface/pytorch-image-models) for the MaxViT layer implementation in Pytorch - [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX) for the detection PAFPN/head ================================================ FILE: RVT/callbacks/custom.py ================================================ from omegaconf import DictConfig from lightning.pytorch.callbacks import Callback from 
def get_ckpt_callback(config: DictConfig) -> ModelCheckpoint:
    """Build the checkpoint callback for the configured model.

    Only the ``"rnndet"`` model is supported; it is tracked on the ``val/AP``
    metric, keeping the single best checkpoint (highest value) plus the most
    recent one.

    Args:
        config: Experiment configuration. Reads ``config.model.name`` and
            ``config.logging.ckpt_every_n_epochs``.

    Returns:
        The configured ``ModelCheckpoint`` instance.

    Raises:
        NotImplementedError: If ``config.model.name`` is not ``"rnndet"``.
    """
    model_name = config.model.name
    if model_name != "rnndet":
        raise NotImplementedError

    prefix = "val"
    metric = "AP"
    mode = "max"

    # Key under which the metric is logged, and a slash-free variant of it
    # that is safe to embed in a file name.
    monitor_key = f"{prefix}/{metric}"
    monitor_tag = f"{prefix}_{metric}"
    filename_template = (
        "epoch={epoch:03d}-step={step}-" + monitor_tag + "={" + monitor_key + ":.2f}"
    )

    checkpoint_callback = ModelCheckpoint(
        monitor=monitor_key,
        filename=filename_template,
        # The monitored key contains a forward slash; letting the callback
        # insert the metric name into the file name would create a directory.
        auto_insert_metric_name=False,
        save_top_k=1,
        mode=mode,
        every_n_epochs=config.logging.ckpt_every_n_epochs,
        save_last=True,
        verbose=True,
    )
    # Name template used for the "last" checkpoint written by save_last=True.
    checkpoint_callback.CHECKPOINT_NAME_LAST = "last_epoch={epoch:03d}-step={step}"
    return checkpoint_callback
def on_validation_batch_end_custom(self, batch: Any, outputs: Any):
    """Render prediction and label overlays for one validation batch and buffer them.

    Does nothing when ``outputs[ObjDetOutput.SKIP_VIZ]`` is truthy. Otherwise
    the event representation is converted to an image once, and a separate
    copy is annotated with the predicted boxes and with the label boxes. The
    prediction image is buffered first, then the label image.

    Args:
        batch: The validation batch (not used here).
        outputs: Mapping keyed by ``ObjDetOutput`` entries; must provide
            ``SKIP_VIZ``, ``EV_REPR`` (a ``torch.Tensor``), ``PRED_PROPH`` and
            ``LABELS_PROPH``.
    """
    if outputs[ObjDetOutput.SKIP_VIZ]:
        return

    ev_tensor = outputs[ObjDetOutput.EV_REPR]
    assert isinstance(ev_tensor, torch.Tensor)
    background = self.ev_repr_to_img(ev_tensor.cpu().numpy())

    # (source of boxes, buffer slot) pairs, processed in the same order as
    # before: predictions first, labels second.
    overlays = (
        (ObjDetOutput.PRED_PROPH, DetectionVizEnum.PRED_IMG_PROPH),
        (ObjDetOutput.LABELS_PROPH, DetectionVizEnum.LABEL_IMG_PROPH),
    )
    for boxes_key, buffer_key in overlays:
        boxes = outputs[boxes_key]
        # Draw on a copy so both overlays start from the same clean image.
        canvas = background.copy()
        draw_bboxes(canvas, boxes, labelmap=self.label_map)
        self.add_to_buffer(buffer_key, canvas)
class GradFlowLogCallback(Callback):
    """Periodically log a gradient-flow figure of the module's parameters.

    Every ``log_every_n_train_steps`` global steps, a figure built by
    ``get_grad_flow_figure`` from the module's named parameters is sent to the
    trainer's logger under the key ``"train/gradients"``.
    """

    def __init__(self, log_every_n_train_steps: int):
        """Store the logging interval.

        Args:
            log_every_n_train_steps: Logging period in global steps; must be
                strictly positive.
        """
        super().__init__()
        assert log_every_n_train_steps > 0
        self.log_every_n_train_steps = log_every_n_train_steps

    @rank_zero_only
    def on_before_zero_grad(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Any
    ) -> None:
        """Log the gradient-flow figure if the current global step is due."""
        # NOTE (kept from the original author): this used to live in
        # on_after_backward. That was fine for fp32 but showed unscaled
        # gradients for fp16, hence the move to on_before_zero_grad.
        step = trainer.global_step
        if step % self.log_every_n_train_steps == 0:
            figure = get_grad_flow_figure(pl_module.named_parameters())
            trainer.logger.log_metrics({"train/gradients": figure}, step=step)
class VizCallbackBase(Callback):
    """Base callback for logging high-dimensional visualizations.

    The class owns a buffer (one list per member of ``buffer_entries``) that
    subclasses fill during validation and read back at the end of the
    validation epoch. It also decides *when* to log:

    * training: every ``logging.train.high_dim.every_n_steps`` global steps;
    * validation: only for a fixed subset of batch indices that is chosen
      once, after the first validation run that follows real training.

    Subclasses implement the three ``*_custom`` hooks.
    """

    def __init__(self, config: DictConfig, buffer_entries: Type[Enum]):
        """
        Args:
            config: full config; only ``config.logging`` is kept.
            buffer_entries: Enum class whose members are the buffer keys.
        """
        super().__init__()

        self.log_config = config.logging

        # Set to True by on_train_batch_start; used to skip the validation
        # pass that runs before any training batch.
        self._training_has_started = False
        # Becomes True once the validation batch indices have been sampled.
        self._selected_val_batches = False

        self.buffer_entries = buffer_entries
        self._val_batch_indices = list()
        self._buffer = None
        self._reset_buffer()

    def _reset_buffer(self):
        """Replace the buffer with an empty list for every buffer entry."""
        self._buffer = {entry: [] for entry in self.buffer_entries}

    # Helpers intended to be called by subclasses ----------------------------------------------------------------------
    def add_to_buffer(self, key: Enum, value: Union[np.ndarray, th.Tensor]):
        """Append ``value`` to the buffer list stored under ``key``.

        Tensors must not require grad and are moved to the CPU before being
        stored; anything else must be a numpy array.
        """
        if isinstance(value, th.Tensor):
            assert not value.requires_grad
            value = value.cpu()
        else:
            assert isinstance(value, np.ndarray)
        assert type(key) == self.buffer_entries
        assert key in self._buffer
        self._buffer[key].append(value)

    def get_from_buffer(self, key: Enum) -> List[Union[np.ndarray, th.Tensor]]:
        """Return the list stored under ``key`` (the list itself, not a copy)."""
        assert type(key) == self.buffer_entries
        return self._buffer[key]

    # Hooks that subclasses have to implement --------------------------------------------------------------------------
    def on_train_batch_end_custom(
        self,
        logger: WandbLogger,
        outputs: Any,
        batch: Any,
        log_n_samples: int,
        global_step: int,
    ) -> None:
        """Called from on_train_batch_end on steps selected for logging."""
        raise NotImplementedError

    def on_validation_batch_end_custom(self, batch: Any, outputs: Any) -> None:
        """Called from on_validation_batch_end for the selected validation batches."""
        raise NotImplementedError

    def on_validation_epoch_end_custom(self, logger: WandbLogger) -> None:
        """Called from on_validation_epoch_end on epochs selected for logging."""
        raise NotImplementedError

    # Lightning hooks --------------------------------------------------------------------------------------------------
    # NOTE(review): unlike the validation hooks below, this hook is not
    # decorated with @rank_zero_only -- confirm that this is intended.
    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        unused: int = 0,
    ) -> None:
        """Forward to ``on_train_batch_end_custom`` every ``every_n_steps`` global steps.

        Does nothing when high-dimensional train logging is disabled.
        Requires ``trainer.logger`` to be a ``WandbLogger``.
        """
        log_train_hd = self.log_config.train.high_dim
        if not log_train_hd.enable:
            return

        step = trainer.global_step
        assert log_train_hd.every_n_steps > 0
        if step % log_train_hd.every_n_steps != 0:
            return

        n_samples = log_train_hd.n_samples

        logger: Optional[WandbLogger] = trainer.logger
        assert isinstance(logger, WandbLogger)

        # Same value as `step` above; read again from the trainer.
        global_step = trainer.global_step

        self.on_train_batch_end_custom(
            logger=logger,
            outputs=outputs,
            batch=batch,
            log_n_samples=n_samples,
            global_step=global_step,
        )

    @rank_zero_only
    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Optional[Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Collect candidate batch indices, or forward selected batches to the subclass.

        Until the indices have been sampled (see ``on_validation_epoch_end``)
        every batch index is merely recorded. Afterwards only batches whose
        index was sampled are forwarded, and only on epochs that are a
        multiple of ``every_n_epochs``.

        Raises:
            NotImplementedError: if ``dataloader_idx`` is greater than 0.
        """
        log_val_hd = self.log_config.validation.high_dim
        log_freq_val_epochs = log_val_hd.every_n_epochs
        if not log_val_hd.enable:
            return
        if dataloader_idx > 0:
            raise NotImplementedError
        if not self._training_has_started:
            # PL has a short sanity check for validation. Hence, we have to make sure that one training run is done.
            return
        if not self._selected_val_batches:
            # We only want to add validation batch indices during the first true validation run.
            self._val_batch_indices.append(batch_idx)
            return
        assert len(self._val_batch_indices) > 0
        if batch_idx not in self._val_batch_indices:
            return
        if trainer.current_epoch % log_freq_val_epochs != 0:
            return

        self.on_validation_batch_end_custom(batch, outputs)

    def on_validation_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Clear the buffer before each validation epoch."""
        self._reset_buffer()

    @rank_zero_only
    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Sample the batch indices once, then forward to the subclass on logging epochs.

        The first time indices are available, at most ``n_samples`` of them
        are drawn and nothing is logged in that epoch. On later calls the
        subclass hook is invoked on epochs that are a multiple of
        ``every_n_epochs``; this requires ``trainer.logger`` to be a
        ``WandbLogger``.
        """
        log_val_hd = self.log_config.validation.high_dim
        log_n_samples = log_val_hd.n_samples
        log_freq_val_epochs = log_val_hd.every_n_epochs

        # No indices are recorded when validation logging is disabled or when
        # this is the validation pass preceding training, so this also covers
        # both of those cases.
        if len(self._val_batch_indices) == 0:
            return
        if not self._selected_val_batches:
            # NOTE: this reseeds Python's global `random` generator.
            random.seed(0)
            num_samples = min(len(self._val_batch_indices), log_n_samples)
            # draw without replacement
            sampled_indices = random.sample(self._val_batch_indices, num_samples)
            self._val_batch_indices = sampled_indices
            self._selected_val_batches = True
            return
        if trainer.current_epoch % log_freq_val_epochs != 0:
            return

        logger: Optional[WandbLogger] = trainer.logger
        assert isinstance(logger, WandbLogger)
        self.on_validation_epoch_end_custom(logger)

    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Remember that at least one training batch has been started."""
        self._training_has_started = True

    @staticmethod
    def ev_repr_to_img(x: np.ndarray) -> np.ndarray:
        """Render an event representation as an ``(H, W, 3)`` uint8 image.

        The (even) channel axis is split into two halves that are each summed
        over their channels; the first half is used as the negative count and
        the second half as the positive count. Pixels where the positive sum
        dominates become 255, pixels where the negative sum dominates become
        0, and all remaining pixels stay at 127.

        The rearrange pattern names exactly three axes, so ``x`` is expected
        to be ``(channels, H, W)``.
        """
        ch, ht, wd = x.shape[-3:]
        assert ch > 1 and ch % 2 == 0
        ev_repr_reshaped = rearrange(x, "(posneg C) H W -> posneg C H W", posneg=2)
        img_neg = np.asarray(
            reduce(ev_repr_reshaped[0], "C H W -> H W", "sum"), dtype="int32"
        )
        img_pos = np.asarray(
            reduce(ev_repr_reshaped[1], "C H W -> H W", "sum"), dtype="int32"
        )
        img_diff = img_pos - img_neg
        # Start from a uniform gray image and overwrite the pixels with a
        # non-zero polarity difference.
        img = 127 * np.ones((ht, wd, 3), dtype=np.uint8)
        img[img_diff > 0] = 255
        img[img_diff < 0] = 0
        return img
train: sampling: 'mixed' # ('random', 'stream', 'mixed') random: weighted_sampling: False mixed: w_stream: 1 w_random: 1 eval: sampling: 'stream' data_augmentation: random: prob_hflip: 0.5 rotate: prob: 0 min_angle_deg: 2 max_angle_deg: 6 zoom: prob: 0.8 zoom_in: weight: 8 factor: min: 1 max: 1.5 zoom_out: weight: 2 factor: min: 1 max: 1.2 stream: prob_hflip: 0.5 rotate: prob: 0 min_angle_deg: 2 max_angle_deg: 6 zoom: prob: 0.5 zoom_out: factor: min: 1 max: 1.2 ================================================ FILE: RVT/config/dataset/gen1.yaml ================================================ defaults: - base name: gen1 ev_repr_name: 'stacked_histogram_dt=50_nbins=10' sequence_length: 21 resolution_hw: [240, 304] downsample_by_factor_2: False only_load_end_labels: False ================================================ FILE: RVT/config/dataset/gen4.yaml ================================================ defaults: - base name: gen4 ev_repr_name: 'stacked_histogram_dt=50_nbins=10' sequence_length: 10 resolution_hw: [720, 1280] downsample_by_factor_2: True only_load_end_labels: False ================================================ FILE: RVT/config/experiment/gen1/base.yaml ================================================ # @package _global_ defaults: - default model: backbone: embed_dim: 64 fpn: depth: 0.67 ================================================ FILE: RVT/config/experiment/gen1/default.yaml ================================================ # @package _global_ defaults: - /model/maxvit_yolox: default training: precision: 32 max_epochs: 10000 max_steps: 400000 learning_rate: 0.0002 lr_scheduler: use: True total_steps: ${..max_steps} pct_start: 0.005 div_factor: 20 final_div_factor: 10000 validation: val_check_interval: 10000 check_val_every_n_epoch: null batch_size: train: 8 eval: 8 hardware: num_workers: train: 6 eval: 2 dataset: train: sampling: 'mixed' random: weighted_sampling: False mixed: w_stream: 1 w_random: 1 eval: sampling: 'stream' ev_repr_name: 
'stacked_histogram_dt=50_nbins=10' sequence_length: 21 downsample_by_factor_2: False only_load_end_labels: False model: backbone: partition_split_32: 1 ================================================ FILE: RVT/config/experiment/gen1/small.yaml ================================================ # @package _global_ defaults: - default model: backbone: embed_dim: 48 stage: attention: dim_head: 24 fpn: depth: 0.33 ================================================ FILE: RVT/config/experiment/gen4/base.yaml ================================================ # @package _global_ defaults: - default model: backbone: embed_dim: 64 fpn: depth: 0.67 ================================================ FILE: RVT/config/experiment/gen4/default.yaml ================================================ # @package _global_ defaults: - /model/maxvit_yolox: default training: precision: 32 max_epochs: 10000 max_steps: 400000 learning_rate: 0.0002449489742783178 # 2e-4 * sqrt(effective_batch_size/8) = 2e-4 * sqrt(12/8) lr_scheduler: use: True total_steps: ${..max_steps} pct_start: 0.005 div_factor: 20 final_div_factor: 10000 validation: val_check_interval: 10000 check_val_every_n_epoch: null batch_size: train: 12 eval: 12 hardware: num_workers: train: 6 eval: 2 dataset: train: sampling: 'mixed' random: weighted_sampling: False mixed: w_stream: 1 w_random: 1 eval: sampling: 'stream' ev_repr_name: 'stacked_histogram_dt=50_nbins=10' sequence_length: 10 downsample_by_factor_2: True only_load_end_labels: False ================================================ FILE: RVT/config/experiment/gen4/small.yaml ================================================ # @package _global_ defaults: - default model: backbone: embed_dim: 48 stage: attention: dim_head: 24 fpn: depth: 0.33 ================================================ FILE: RVT/config/general.yaml ================================================ reproduce: seed_everything: null # Union[int, null] deterministic_flag: False # Must be true for fully 
deterministic behaviour (slows down training) benchmark: False # Should be set to false for fully deterministic behaviour. Could potentially speed up training. training: precision: 16 max_epochs: 10000 max_steps: 400000 learning_rate: 0.0002 weight_decay: 0 gradient_clip_val: 1.0 limit_train_batches: 1.0 lr_scheduler: use: True total_steps: ${..max_steps} pct_start: 0.005 div_factor: 25 # init_lr = max_lr / div_factor final_div_factor: 10000 # final_lr = max_lr / final_div_factor (this is different from Pytorch' OneCycleLR param) validation: limit_val_batches: 1.0 val_check_interval: null # Optional[int] check_val_every_n_epoch: 1 # Optional[int] batch_size: train: 8 eval: 8 hardware: num_workers: train: 6 eval: 2 gpus: 0 # Either a single integer (e.g. 3) or a list of integers (e.g. [3,5,6]) dist_backend: "nccl" logging: ckpt_every_n_epochs: 1 train: metrics: compute: false detection_metrics_every_n_steps: null # Optional[int] -> null: every train epoch, int: every N steps log_model_every_n_steps: 5000 log_every_n_steps: 500 high_dim: enable: True every_n_steps: 5000 n_samples: 4 validation: high_dim: enable: True every_n_epochs: 1 n_samples: 8 wandb: # How to use: # 1) resume existing wandb run: set artifact_name & wandb_runpath # 2) resume full training state in new wandb run: set artifact_name # 3) resume only model weights of checkpoint in new wandb run: set artifact_name & resume_only_weights=True # # In addition: you can specify artifact_local_file to load the checkpoint from disk. # This is for example required for resuming training with DDP. wandb_runpath: null # WandB run path. E.g. USERNAME/PROJECTNAME/1grv5kg6 artifact_name: null # Name of checkpoint/artifact. Required for resuming. E.g. USERNAME/PROJECTNAME/checkpoint-1grv5kg6-last:v15 artifact_local_file: null # If specified, will use the provided local filepath instead of downloading it. Required if resuming with DDP. resume_only_weights: False group_name: ??? 
# Specify group name of the run project_name: RVT ================================================ FILE: RVT/config/model/base.yaml ================================================ name: ??? ================================================ FILE: RVT/config/model/maxvit_yolox/default.yaml ================================================ # @package _global_ defaults: - override /model: rnndet model: backbone: name: MaxViTRNN compile: enable: False args: mode: reduce-overhead input_channels: 20 enable_masking: False partition_split_32: 2 embed_dim: 64 dim_multiplier: [1, 2, 4, 8] num_blocks: [1, 1, 1, 1] T_max_chrono_init: [4, 8, 16, 32] stem: patch_size: 4 stage: downsample: type: patch overlap: True norm_affine: True attention: use_torch_mha: False partition_size: ??? dim_head: 32 attention_bias: True mlp_activation: gelu mlp_gated: False mlp_bias: True mlp_ratio: 4 drop_mlp: 0 drop_path: 0 ls_init_value: 1e-5 lstm: dws_conv: False dws_conv_only_hidden: True dws_conv_kernel_size: 3 drop_cell_update: 0 s5: dim: 80 state_dim: 80 s4: dim: 80 state_dim: 80 fpn: name: PAFPN compile: enable: False args: mode: reduce-overhead depth: 0.67 # round(depth * 3) == num bottleneck blocks # stage 1 is the first and len(num_layers) is the last in_stages: [2, 3, 4] depthwise: False act: "silu" head: name: YoloX compile: enable: False args: mode: reduce-overhead depthwise: False act: "silu" postprocess: confidence_threshold: 0.1 nms_threshold: 0.45 ================================================ FILE: RVT/config/model/rnndet.yaml ================================================ defaults: - base name: rnndet backbone: name: ??? fpn: name: ??? head: name: ??? 
def _get_modified_hw_multiple_of(
    hw: Tuple[int, int], multiple_of: int
) -> Tuple[int, ...]:
    """Round height and width up to the next multiple of ``multiple_of``.

    Args:
        hw: ``(height, width)`` tuple.
        multiple_of: positive integer both entries should be divisible by.

    Returns:
        ``hw`` itself when ``multiple_of`` is 1, otherwise a new tuple in
        which every entry is the smallest multiple of ``multiple_of`` that is
        not below the original entry.
    """
    assert isinstance(hw, tuple), f"{type(hw)=}, {hw=}"
    assert len(hw) == 2
    assert isinstance(multiple_of, int)
    assert multiple_of >= 1

    if multiple_of != 1:

        def round_up(extent):
            # Number of blocks needed to cover `extent`, times the block size.
            return math.ceil(extent / multiple_of) * multiple_of

        return tuple(map(round_up, hw))

    # Every integer is a multiple of 1: nothing to change.
    return hw
def collate(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
):
    r"""
    Collate a batch by dispatching on the type of its first element.

    The lookup in ``collate_fn_map`` happens first: an exact type match wins,
    otherwise the keys are tried in insertion order and the first key the
    element is an instance of is used. If no registered function applies,
    containers are handled structurally and their contents are collated
    recursively:

    * mappings are collated key by key (keys are taken from the first element);
    * namedtuples are collated field by field and rebuilt with the same type;
    * other sequences must all have the same length and are collated position
      by position; plain tuples yield a list.

    Args:
        batch: a single batch to be collated
        collate_fn_map: Optional dictionary mapping from element type (or a
            tuple of types) to the corresponding collate function.

    Raises:
        RuntimeError: if the sequences inside ``batch`` differ in length.
        TypeError: if the element type is not supported.

    Note:
        Each collate function requires a positional argument for batch and a
        keyword argument for the dictionary of collate functions as
        `collate_fn_map`.
    """
    first = batch[0]
    first_type = type(first)

    def recurse(samples):
        return collate(samples, collate_fn_map=collate_fn_map)

    if collate_fn_map is not None:
        # An exact type match takes precedence over the isinstance scan.
        if first_type in collate_fn_map:
            return collate_fn_map[first_type](batch, collate_fn_map=collate_fn_map)

        for registered_type, handler in collate_fn_map.items():
            if isinstance(first, registered_type):
                return handler(batch, collate_fn_map=collate_fn_map)

    if isinstance(first, collections.abc.Mapping):

        def collate_per_key():
            return {key: recurse([sample[key] for sample in batch]) for key in first}

        try:
            return first_type(collate_per_key())
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return collate_per_key()

    if isinstance(first, tuple) and hasattr(first, "_fields"):  # namedtuple
        return first_type(*map(recurse, zip(*batch)))

    if isinstance(first, collections.abc.Sequence):
        # All sequences in the batch have to agree on their length.
        sample_iter = iter(batch)
        expected_len = len(next(sample_iter))
        if any(len(sample) != expected_len for sample in sample_iter):
            raise RuntimeError("each element in list of batch should be of equal size")

        # Materialized as a list because it may be traversed twice.
        columns = list(zip(*batch))

        def collate_columns():
            return [recurse(column) for column in columns]

        if isinstance(first, tuple):
            return collate_columns()  # Backwards compatibility.
        try:
            return first_type(collate_columns())
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return collate_columns()

    raise TypeError(default_collate_err_msg_format.format(first_type))
def collate_str_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
):
    """Collate a batch of strings by returning it unchanged.

    ``collate_fn_map`` is accepted only so that the signature matches the
    other collate functions; it is not used.
    """
    return batch
class CustomConcatDataset(ConcatDataset):
    """ConcatDataset that forwards the loading-mode switches to its members."""

    datasets: List[SequenceDataset]

    def __init__(self, datasets: Iterable[SequenceDataset]):
        super().__init__(datasets=datasets)

    def only_load_labels(self):
        """Call ``only_load_labels`` on every wrapped dataset, in order."""
        for member in self.datasets:
            member.only_load_labels()

    def load_everything(self):
        """Call ``load_everything`` on every wrapped dataset, in order."""
        for member in self.datasets:
            member.load_everything()
def get_weighted_random_sampler(dataset: CustomConcatDataset) -> WeightedRandomSampler:
    """Build a sampler that favours samples containing rare classes.

    The dataset is iterated once in label-only mode to count, per sample and
    overall, how often each class id occurs. Each class gets the weight
    ``1 / total_count``; the weight of a sample is the sum over its classes of
    ``class_weight * count_in_sample``. A sample without any valid label gets
    the weight 0.

    Args:
        dataset: dataset to compute the sampling weights for. It is switched
            to label-only loading for the duration of the scan and is always
            switched back to full loading afterwards, even if the scan fails.

    Returns:
        A ``WeightedRandomSampler`` that draws ``len(dataset)`` samples with
        replacement.
    """
    class2count = dict()
    ClassAndCount = namedtuple("ClassAndCount", ["class_ids", "counts"])
    classandcount_list = list()
    print("--- START generating weighted random sampler ---")
    dataset.only_load_labels()
    try:
        for data in tqdm(dataset, desc="iterate through dataset"):
            labels: SparselyBatchedObjectLabels = data[DataType.OBJLABELS_SEQ]
            label_list, _ = labels.get_valid_labels_and_batch_indices()
            per_label_ids = [
                np.asarray(label.class_id.numpy(), dtype="int32")
                for label in label_list
            ]
            if per_label_ids:
                class_ids = np.concatenate(per_label_ids)
            else:
                # np.concatenate rejects an empty list; an empty id array
                # yields empty unique/count arrays and hence a weight of 0.
                class_ids = np.asarray([], dtype="int32")
            class_ids_seq, counts_seq = np.unique(class_ids, return_counts=True)
            for class_id, count in zip(class_ids_seq, counts_seq):
                class2count[class_id] = class2count.get(class_id, 0) + count
            classandcount_list.append(
                ClassAndCount(class_ids=class_ids_seq, counts=counts_seq)
            )
    finally:
        # Never leave the dataset in label-only mode, even when the scan raises.
        dataset.load_everything()

    class2weight = {
        class_id: 1 / max(count, 1) for class_id, count in class2count.items()
    }

    weights = []
    for classandcount in classandcount_list:
        weight = 0
        for class_id, count in zip(classandcount.class_ids, classandcount.counts):
            # Not only weight depending on class but also depending on number of occurrences.
            # This will bias towards sampling "frames" with more bounding boxes.
            weight += class2weight[class_id] * count
        weights.append(weight)

    print("--- DONE generating weighted random sampler ---")
    return WeightedRandomSampler(
        weights=weights, num_samples=len(weights), replacement=True
    )
def partialclass(cls, *args, **kwargs):
    """Return a subclass of ``cls`` whose constructor has arguments pre-filled.

    The given positional and keyword arguments are bound to ``cls.__init__``;
    arguments passed when instantiating the returned class are added to them.
    """
    prefilled_init = partialmethod(cls.__init__, *args, **kwargs)

    class NewCls(cls):
        __init__ = prefilled_init

    return NewCls
class ObjectLabelBase:
    """Container for bounding-box labels stored as an ``N x 7`` tensor.

    The columns are addressed through ``_str2idx``:
    ``t, x, y, w, h, class_id, class_confidence``. ``(x, y)`` is used as the
    top-left box corner and ``(w, h)`` as the box extent by the geometric
    operations of the subclasses.
    """

    _str2idx = {
        "t": 0,
        "x": 1,
        "y": 2,
        "w": 3,
        "h": 4,
        "class_id": 5,
        "class_confidence": 6,
    }

    def __init__(self, object_labels: th.Tensor, input_size_hw: Tuple[int, int]):
        """Store the labels and the (height, width) of the canvas they live on.

        Args:
            object_labels: Float32 or float64 tensor of shape ``(N, 7)``.
            input_size_hw: ``(height, width)`` tuple of the label canvas.
        """
        assert isinstance(object_labels, th.Tensor)
        assert object_labels.dtype in {th.float32, th.float64}
        assert object_labels.ndim == 2
        assert object_labels.shape[-1] == len(self._str2idx)
        assert isinstance(input_size_hw, tuple)
        assert len(input_size_hw) == 2
        self.object_labels = object_labels
        self._input_size_hw = input_size_hw
        self._is_numpy = False

    def clamp_to_frame_(self):
        """Clamp all boxes to ``[0, size - 1]`` in place.

        Asserts that every box keeps a strictly positive width and height
        after clamping.
        """
        ht, wd = self.input_size_hw
        x0 = th.clamp(self.x, min=0, max=wd - 1)
        y0 = th.clamp(self.y, min=0, max=ht - 1)
        x1 = th.clamp(self.x + self.w, min=0, max=wd - 1)
        y1 = th.clamp(self.y + self.h, min=0, max=ht - 1)
        w = x1 - x0
        h = y1 - y0
        assert th.all(w > 0)
        assert th.all(h > 0)
        self.x = x0
        self.y = y0
        self.w = w
        self.h = h

    def remove_flat_labels_(self):
        """Drop, in place, every label whose width or height is not positive."""
        keep = (self.w > 0) & (self.h > 0)
        self.object_labels = self.object_labels[keep]

    @classmethod
    def create_empty(cls):
        """Return an ``ObjectLabelBase`` with zero labels and size ``(0, 0)``.

        This is useful to represent cases where no labels are available.
        NOTE(review): the result is always an ``ObjectLabelBase`` instance,
        even when called on a subclass.
        """
        return ObjectLabelBase(
            object_labels=th.empty((0, len(cls._str2idx))), input_size_hw=(0, 0)
        )

    def _assert_not_numpy(self):
        """Fail if the labels were converted to numpy via ``numpy_``."""
        assert not self._is_numpy, (
            "Labels have been converted to numpy. "
            "Numpy is not supported for the intended operations."
        )

    def to(self, *args, **kwargs):
        """Apply ``torch.Tensor.to`` to the label tensor and return ``self``."""
        self._assert_not_numpy()
        # This will be used by Pytorch Lightning to transfer to the relevant device
        self.object_labels = self.object_labels.to(*args, **kwargs)
        return self

    def numpy_(self) -> None:
        """
        In place conversion to numpy (detach + to cpu + to numpy).
        Cannot be undone.
        """
        self._is_numpy = True
        self.object_labels = self.object_labels.detach().cpu().numpy()

    @property
    def input_size_hw(self) -> Tuple[int, int]:
        return self._input_size_hw

    @input_size_hw.setter
    def input_size_hw(self, height_width: Tuple[int, int]):
        assert isinstance(height_width, tuple)
        assert len(height_width) == 2
        assert height_width[0] > 0
        assert height_width[1] > 0
        self._input_size_hw = height_width

    def get(self, request: str):
        """Return the label column named ``request`` (a key of ``_str2idx``)."""
        assert request in self._str2idx
        return self.object_labels[:, self._str2idx[request]]

    @property
    def t(self):
        return self.object_labels[:, self._str2idx["t"]]

    @property
    def x(self):
        return self.object_labels[:, self._str2idx["x"]]

    @x.setter
    def x(self, value: Union[th.Tensor, np.ndarray]):
        self.object_labels[:, self._str2idx["x"]] = value

    @property
    def y(self):
        return self.object_labels[:, self._str2idx["y"]]

    @y.setter
    def y(self, value: Union[th.Tensor, np.ndarray]):
        self.object_labels[:, self._str2idx["y"]] = value

    @property
    def w(self):
        return self.object_labels[:, self._str2idx["w"]]

    @w.setter
    def w(self, value: Union[th.Tensor, np.ndarray]):
        self.object_labels[:, self._str2idx["w"]] = value

    @property
    def h(self):
        return self.object_labels[:, self._str2idx["h"]]

    @h.setter
    def h(self, value: Union[th.Tensor, np.ndarray]):
        self.object_labels[:, self._str2idx["h"]] = value

    @property
    def class_id(self):
        return self.object_labels[:, self._str2idx["class_id"]]

    @property
    def class_confidence(self):
        return self.object_labels[:, self._str2idx["class_confidence"]]

    @property
    def dtype(self):
        return self.object_labels.dtype

    @property
    def device(self):
        return self.object_labels.device


class ObjectLabelFactory(ObjectLabelBase):
    """Holds all labels of a sequence and hands out per-frame ``ObjectLabels``.

    ``objframe_idx_2_label_idx[i]`` is the row at which the labels of object
    frame ``i`` start; they end where the next frame starts (or at the last
    row for the final frame).
    """

    def __init__(
        self,
        object_labels: th.Tensor,
        objframe_idx_2_label_idx: th.Tensor,
        input_size_hw: Tuple[int, int],
        downsample_factor: Optional[float] = None,
    ):
        """Validate the inputs and clamp all boxes to the frame.

        Args:
            object_labels: ``(N, 7)`` label tensor for the whole sequence.
            objframe_idx_2_label_idx: 1D int64 tensor of start rows per frame.
            input_size_hw: ``(height, width)`` of the label canvas.
            downsample_factor: Optional factor (> 1) by which labels returned
                from ``__getitem__`` are scaled down.
        """
        super().__init__(object_labels=object_labels, input_size_hw=input_size_hw)
        assert objframe_idx_2_label_idx.dtype == th.int64
        assert objframe_idx_2_label_idx.dim() == 1
        self.objframe_idx_2_label_idx = objframe_idx_2_label_idx
        self.downsample_factor = downsample_factor
        if self.downsample_factor is not None:
            assert self.downsample_factor > 1
        self.clamp_to_frame_()

    @staticmethod
    def from_structured_array(
        object_labels: np.ndarray,
        objframe_idx_2_label_idx: np.ndarray,
        input_size_hw: Tuple[int, int],
        downsample_factor: Optional[float] = None,
    ) -> ObjectLabelFactory:
        """Build a factory from a numpy structured array.

        Every field named in ``_str2idx`` is read from ``object_labels``,
        cast to float32 and stacked into an ``(L, fields)`` tensor. The number
        of frame start indices must equal the number of unique ``t`` values.
        """
        np_labels = [
            object_labels[key].astype("float32")
            for key in ObjectLabels._str2idx.keys()
        ]
        np_labels = rearrange(np_labels, "fields L -> L fields")
        torch_labels = th.from_numpy(np_labels)
        objframe_idx_2_label_idx = th.from_numpy(
            objframe_idx_2_label_idx.astype("int64")
        )
        assert objframe_idx_2_label_idx.numel() == np.unique(object_labels["t"]).size
        return ObjectLabelFactory(
            object_labels=torch_labels,
            objframe_idx_2_label_idx=objframe_idx_2_label_idx,
            input_size_hw=input_size_hw,
            downsample_factor=downsample_factor,
        )

    def __len__(self):
        return len(self.objframe_idx_2_label_idx)

    def __getitem__(self, item: int) -> ObjectLabels:
        """Return a cloned ``ObjectLabels`` for object frame ``item``.

        The labels are scaled by ``1 / downsample_factor`` when a downsample
        factor was configured.
        """
        assert item >= 0
        length = len(self)
        assert length > 0
        assert item < length
        is_last_item = item == length - 1
        from_idx = self.objframe_idx_2_label_idx[item]
        to_idx = (
            self.object_labels.shape[0]
            if is_last_item
            else self.objframe_idx_2_label_idx[item + 1]
        )
        assert to_idx > from_idx
        object_labels = ObjectLabels(
            object_labels=self.object_labels[from_idx:to_idx].clone(),
            input_size_hw=self.input_size_hw,
        )
        if self.downsample_factor is not None:
            object_labels.scale_(scaling_multiplier=1 / self.downsample_factor)
        return object_labels


class ObjectLabels(ObjectLabelBase):
    """Labels of a single frame with in-place geometric augmentations."""

    def __init__(self, object_labels: th.Tensor, input_size_hw: Tuple[int, int]):
        super().__init__(object_labels=object_labels, input_size_hw=input_size_hw)

    def __len__(self) -> int:
        return self.object_labels.shape[0]

    def rotate_(self, angle_deg: float):
        """Rotate all boxes around the canvas center, in place.

        Each box is replaced by the axis-aligned bounding box of its four
        rotated corners, clamped to the canvas. Boxes that become flat are
        removed.
        """
        if len(self) == 0:
            return
        # (x0,y0)---(x1,y0)   p00---p10
        #  |             |    |       |
        #  |             |    |       |
        # (x0,y1)---(x1,y1)   p01---p11
        p00 = th.stack((self.x, self.y), dim=1)
        p10 = th.stack((self.x + self.w, self.y), dim=1)
        p01 = th.stack((self.x, self.y + self.h), dim=1)
        p11 = th.stack((self.x + self.w, self.y + self.h), dim=1)
        # points: 4 x N x 2
        points = th.stack((p00, p10, p01, p11), dim=0)

        cx = self._input_size_hw[1] // 2
        cy = self._input_size_hw[0] // 2
        center = th.tensor([cx, cy], device=self.device)

        angle_rad = angle_deg / 180 * math.pi
        # counter-clockwise rotation
        # The matrix is created in the dtype of the labels: einsum does not
        # promote mixed dtypes, so a float32 matrix fails on float64 labels.
        rot_matrix = th.tensor(
            [
                [math.cos(angle_rad), math.sin(angle_rad)],
                [-math.sin(angle_rad), math.cos(angle_rad)],
            ],
            dtype=self.dtype,
            device=self.device,
        )

        points = points - center
        points = th.einsum("ij,pnj->pni", rot_matrix, points)
        points = points + center

        height, width = self.input_size_hw
        x0 = th.clamp(th.min(points[..., 0], dim=0)[0], min=0, max=width - 1)
        y0 = th.clamp(th.min(points[..., 1], dim=0)[0], min=0, max=height - 1)
        x1 = th.clamp(th.max(points[..., 0], dim=0)[0], min=0, max=width - 1)
        y1 = th.clamp(th.max(points[..., 1], dim=0)[0], min=0, max=height - 1)

        self.x = x0
        self.y = y0
        self.w = x1 - x0
        self.h = y1 - y0

        self.remove_flat_labels_()

        assert th.all(self.x >= 0)
        assert th.all(self.y >= 0)
        assert th.all(self.x + self.w <= self.input_size_hw[1] - 1)
        assert th.all(self.y + self.h <= self.input_size_hw[0] - 1)

    def zoom_in_and_rescale_(
        self, zoom_coordinates_x0y0: Tuple[int, int], zoom_in_factor: float
    ):
        """
        1) Computes a new smaller canvas size: original canvas scaled by a factor of 1/zoom_in_factor (downscaling)
        2) Places the smaller canvas inside the original canvas at the top-left coordinates zoom_coordinates_x0y0
        3) Extract the smaller canvas and rescale it back to the original resolution

        NOTE(review): if every label is removed in step 3, ``scale_`` returns
        early and ``input_size_hw`` stays at the zoom-window size.
        """
        if len(self) == 0:
            return
        assert len(zoom_coordinates_x0y0) == 2
        assert zoom_in_factor >= 1
        if zoom_in_factor == 1:
            return
        z_x0, z_y0 = zoom_coordinates_x0y0
        h_orig, w_orig = self.input_size_hw
        assert 0 <= z_x0 <= w_orig - 1
        assert 0 <= z_y0 <= h_orig - 1
        zoom_window_h, zoom_window_w = tuple(
            x / zoom_in_factor for x in self.input_size_hw
        )
        z_x1 = min(z_x0 + zoom_window_w, w_orig - 1)
        assert z_x1 <= w_orig - 1, f"{z_x1=} is larger than {w_orig-1=}"
        z_y1 = min(z_y0 + zoom_window_h, h_orig - 1)
        assert z_y1 <= h_orig - 1, f"{z_y1=} is larger than {h_orig-1=}"

        # Clip the boxes to the zoom window before shifting them to its origin.
        x0 = th.clamp(self.x, min=z_x0, max=z_x1 - 1)
        y0 = th.clamp(self.y, min=z_y0, max=z_y1 - 1)
        x1 = th.clamp(self.x + self.w, min=z_x0, max=z_x1 - 1)
        y1 = th.clamp(self.y + self.h, min=z_y0, max=z_y1 - 1)

        self.x = x0 - z_x0
        self.y = y0 - z_y0
        self.w = x1 - x0
        self.h = y1 - y0
        self.input_size_hw = (zoom_window_h, zoom_window_w)

        self.remove_flat_labels_()

        self.scale_(scaling_multiplier=zoom_in_factor)

    def zoom_out_and_rescale_(
        self, zoom_coordinates_x0y0: Tuple[int, int], zoom_out_factor: float
    ):
        """
        1) Scales the input by a factor of 1/zoom_out_factor (i.e. reduces the canvas size)
        2) Places the downscaled canvas into the original canvas at the top-left coordinates zoom_coordinates_x0y0
        """
        if len(self) == 0:
            return
        assert len(zoom_coordinates_x0y0) == 2
        assert zoom_out_factor >= 1
        if zoom_out_factor == 1:
            return

        h_orig, w_orig = self.input_size_hw
        self.scale_(scaling_multiplier=1 / zoom_out_factor)

        # scale_ shrinks the canvas size as well; restore the original size.
        self.input_size_hw = (h_orig, w_orig)
        z_x0, z_y0 = zoom_coordinates_x0y0
        assert 0 <= z_x0 <= w_orig - 1
        assert 0 <= z_y0 <= h_orig - 1

        self.x = self.x + z_x0
        self.y = self.y + z_y0

    def scale_(self, scaling_multiplier: float):
        """Scale boxes and canvas size by ``scaling_multiplier``, in place.

        Does nothing when there are no labels or the multiplier is 1. The
        far box corner is clamped to the new canvas and flat boxes are removed.
        """
        if len(self) == 0:
            return
        assert scaling_multiplier > 0
        if scaling_multiplier == 1:
            return
        img_ht, img_wd = self.input_size_hw
        new_img_ht = scaling_multiplier * img_ht
        new_img_wd = scaling_multiplier * img_wd
        self.input_size_hw = (new_img_ht, new_img_wd)
        x1 = th.clamp((self.x + self.w) * scaling_multiplier, max=new_img_wd - 1)
        y1 = th.clamp((self.y + self.h) * scaling_multiplier, max=new_img_ht - 1)
        self.x = self.x * scaling_multiplier
        self.y = self.y * scaling_multiplier

        self.w = x1 - self.x
        self.h = y1 - self.y

        self.remove_flat_labels_()

    def flip_lr_(self) -> None:
        """Mirror all boxes horizontally, in place."""
        if len(self) == 0:
            return
        self.x = self.input_size_hw[1] - 1 - self.x - self.w

    def get_labels_as_tensors(self, format_: str = "yolox") -> th.Tensor:
        """Return the labels as a float32 ``(N, 5)`` tensor.

        For ``"yolox"`` the columns are ``class_id, center_x, center_y, w, h``.

        Raises:
            NotImplementedError: For any other ``format_``.
        """
        self._assert_not_numpy()

        if format_ == "yolox":
            out = th.zeros((len(self), 5), dtype=th.float32, device=self.device)
            if len(self) == 0:
                return out
            out[:, 0] = self.class_id
            out[:, 1] = self.x + 0.5 * self.w
            out[:, 2] = self.y + 0.5 * self.h
            out[:, 3] = self.w
            out[:, 4] = self.h
            return out
        else:
            raise NotImplementedError

    @staticmethod
    def get_labels_as_batched_tensor(
        obj_label_list: List[ObjectLabels], format_: str = "yolox"
    ) -> th.Tensor:
        """Stack per-frame labels into one ``(frames, max_labels, 5)`` tensor.

        Frames with fewer labels are zero-padded at the end. The list must be
        non-empty and at least one frame must contain a label.

        Raises:
            NotImplementedError: For any ``format_`` other than ``"yolox"``.
        """
        num_object_frames = len(obj_label_list)
        assert num_object_frames > 0
        max_num_labels_per_object_frame = max(len(x) for x in obj_label_list)
        assert max_num_labels_per_object_frame > 0
        if format_ == "yolox":
            tensor_labels = []
            for labels in obj_label_list:
                obj_labels_tensor = labels.get_labels_as_tensors(format_=format_)
                num_to_pad = max_num_labels_per_object_frame - len(labels)
                padded_labels = pad(
                    obj_labels_tensor, (0, 0, 0, num_to_pad), mode="constant", value=0
                )
                tensor_labels.append(padded_labels)
            tensor_labels = th.stack(tensors=tensor_labels, dim=0)
            return tensor_labels
        else:
            raise NotImplementedError


class SparselyBatchedObjectLabels:
    """A list of per-step labels in which ``None`` marks missing labels.

    Empty ``ObjectLabels`` entries are normalized to ``None``.
    """

    def __init__(self, sparse_object_labels_batch: List[Optional[ObjectLabels]]):
        # Can contain None elements that indicate missing labels.
        for entry in sparse_object_labels_batch:
            assert isinstance(entry, ObjectLabels) or entry is None
        # Work on a private copy so that normalizing empty entries to None
        # does not modify the list owned by the caller.
        self.sparse_object_labels_batch = list(sparse_object_labels_batch)
        self.set_empty_labels_to_none_()

    def __len__(self) -> int:
        return len(self.sparse_object_labels_batch)

    def __iter__(self):
        return iter(self.sparse_object_labels_batch)

    def __getitem__(self, item: int) -> Optional[ObjectLabels]:
        if item < 0 or item >= len(self):
            raise IndexError(f"Index ({item}) out of range (0, {len(self) - 1})")
        return self.sparse_object_labels_batch[item]

    def __add__(self, other: SparselyBatchedObjectLabels):
        sparse_object_labels_batch = (
            self.sparse_object_labels_batch + other.sparse_object_labels_batch
        )
        return SparselyBatchedObjectLabels(
            sparse_object_labels_batch=sparse_object_labels_batch
        )

    def set_empty_labels_to_none_(self):
        """Replace every entry that holds zero labels with ``None``."""
        for idx, obj_label in enumerate(self.sparse_object_labels_batch):
            if obj_label is not None and len(obj_label) == 0:
                self.sparse_object_labels_batch[idx] = None

    @property
    def input_size_hw(self) -> Optional[Union[Tuple[int, int], Tuple[float, float]]]:
        """Canvas size of the first non-``None`` entry, or ``None`` if there is none."""
        for obj_labels in self.sparse_object_labels_batch:
            if obj_labels is not None:
                return obj_labels.input_size_hw
        return None

    def zoom_in_and_rescale_(self, *args, **kwargs):
        for entry in self.sparse_object_labels_batch:
            if entry is not None:
                entry.zoom_in_and_rescale_(*args, **kwargs)
        # We may have deleted labels. If no labels are left, set the object to None
        self.set_empty_labels_to_none_()

    def zoom_out_and_rescale_(self, *args, **kwargs):
        for entry in self.sparse_object_labels_batch:
            if entry is not None:
                entry.zoom_out_and_rescale_(*args, **kwargs)

    def rotate_(self, *args, **kwargs):
        for entry in self.sparse_object_labels_batch:
            if entry is not None:
                entry.rotate_(*args, **kwargs)

    def scale_(self, *args, **kwargs):
        for entry in self.sparse_object_labels_batch:
            if entry is not None:
                entry.scale_(*args, **kwargs)
        # We may have deleted labels. If no labels are left, set the object to None
        self.set_empty_labels_to_none_()

    def flip_lr_(self):
        for entry in self.sparse_object_labels_batch:
            if entry is not None:
                entry.flip_lr_()

    def to(self, *args, **kwargs):
        for entry in self.sparse_object_labels_batch:
            if entry is not None:
                entry.to(*args, **kwargs)
        return self

    def get_valid_labels_and_batch_indices(
        self,
    ) -> Tuple[List[ObjectLabels], List[int]]:
        """Return the non-``None`` entries together with their positions."""
        out = list()
        valid_indices = list()
        for idx, label in enumerate(self.sparse_object_labels_batch):
            if label is not None:
                out.append(label)
                valid_indices.append(idx)
        return out, valid_indices

    @staticmethod
    def transpose_list(
        list_of_sparsely_batched_object_labels: List[SparselyBatchedObjectLabels],
    ) -> List[SparselyBatchedObjectLabels]:
        """Regroup the inputs position-wise (``zip`` over the given batches)."""
        return [
            SparselyBatchedObjectLabels(list(labels_as_tuple))
            for labels_as_tuple in zip(*list_of_sparsely_batched_object_labels)
        ]
def get_event_representation_dir(path: Path, ev_representation_name: str) -> Path:
    """Return ``path / "event_representations_v2" / ev_representation_name``.

    Asserts that the resulting directory exists.
    """
    ev_repr_dir = path / "event_representations_v2" / ev_representation_name
    assert ev_repr_dir.is_dir(), f"{ev_repr_dir}"
    return ev_repr_dir


def get_objframe_idx_2_repr_idx(path: Path, ev_representation_name: str) -> np.ndarray:
    """Load ``objframe_idx_2_repr_idx.npy`` from the event representation directory."""
    ev_repr_dir = get_event_representation_dir(
        path=path, ev_representation_name=ev_representation_name
    )
    objframe_idx_2_repr_idx = np.load(str(ev_repr_dir / "objframe_idx_2_repr_idx.npy"))
    return objframe_idx_2_repr_idx


class SequenceBase(MapDataPipe):
    """
    Base class for one recording: loads the labels and the mapping between
    object-frame indices and event-representation indices, and reads event
    representations from the HDF5 file on demand.

    Structure example of a sequence:
    .
    ├── event_representations_v2
    │   └── ev_representation_name
    │       ├── event_representations.h5
    │       ├── objframe_idx_2_repr_idx.npy
    │       └── timestamps_us.npy
    └── labels_v2
        ├── labels.npz
        └── timestamps_us.npy
    """

    def __init__(
        self,
        path: Path,
        ev_representation_name: str,
        sequence_length: int,
        dataset_type: DatasetType,
        downsample_by_factor_2: bool,
        only_load_end_labels: bool,
    ):
        assert sequence_length >= 1
        assert path.is_dir()
        assert dataset_type in {
            DatasetType.GEN1,
            DatasetType.GEN4,
        }, f"{dataset_type} not implemented"

        # Only stored here; subclasses decide how to use it.
        self.only_load_end_labels = only_load_end_labels

        ev_repr_dir = get_event_representation_dir(
            path=path, ev_representation_name=ev_representation_name
        )

        labels_dir = path / "labels_v2"
        assert labels_dir.is_dir()

        height, width = get_original_hw(dataset_type)
        self.seq_len = sequence_length

        # The downsampled representations live in a separate file with a suffix.
        ds_factor_str = "_ds2_nearest" if downsample_by_factor_2 else ""
        self.ev_repr_file = ev_repr_dir / f"event_representations{ds_factor_str}.h5"
        assert self.ev_repr_file.exists(), f"{str(self.ev_repr_file)=}"

        with Timer(timer_name="prepare labels"):
            label_data = np.load(str(labels_dir / "labels.npz"))
            objframe_idx_2_label_idx = label_data["objframe_idx_2_label_idx"]
            labels = label_data["labels"]
            label_factory = ObjectLabelFactory.from_structured_array(
                object_labels=labels,
                objframe_idx_2_label_idx=objframe_idx_2_label_idx,
                input_size_hw=(height, width),
                downsample_factor=2 if downsample_by_factor_2 else None,
            )
            self.label_factory = label_factory

        with Timer(timer_name="load objframe_idx_2_repr_idx"):
            self.objframe_idx_2_repr_idx = get_objframe_idx_2_repr_idx(
                path=path, ev_representation_name=ev_representation_name
            )
        with Timer(timer_name="construct repr_idx_2_objframe_idx"):
            # Inverse lookup: representation index -> object-frame index.
            self.repr_idx_2_objframe_idx = dict(
                zip(
                    self.objframe_idx_2_repr_idx,
                    range(len(self.objframe_idx_2_repr_idx)),
                )
            )

    def _get_labels_from_repr_idx(self, repr_idx: int) -> Optional[ObjectLabels]:
        """Return the labels attached to ``repr_idx``, or ``None`` if it has none."""
        objframe_idx = self.repr_idx_2_objframe_idx.get(repr_idx, None)
        return None if objframe_idx is None else self.label_factory[objframe_idx]

    def _get_event_repr_torch(self, start_idx: int, end_idx: int) -> List[torch.Tensor]:
        """Read representations ``[start_idx, end_idx)`` and return one tensor per index.

        Data that is not uint8 is converted to float32.
        """
        assert end_idx > start_idx
        # The file is opened per call and closed again right after reading.
        with h5py.File(str(self.ev_repr_file), "r") as h5f:
            ev_repr = h5f["data"][start_idx:end_idx]
        ev_repr = torch.from_numpy(ev_repr)
        if ev_repr.dtype != torch.uint8:
            ev_repr = torch.asarray(ev_repr, dtype=torch.float32)
        ev_repr = torch.split(ev_repr, 1, dim=0)
        # remove first dim that is always 1 due to how torch.split works
        ev_repr = [x[0] for x in ev_repr]
        return ev_repr

    def __len__(self) -> int:
        # Must be provided by subclasses.
        raise NotImplementedError

    def __getitem__(self, index: int) -> Any:
        # Must be provided by subclasses.
        raise NotImplementedError
DataType, DatasetType, LoaderDataDictGenX from utils.timers import TimerDummy as Timer def _scalar_as_1d_array(scalar: Union[int, float]): return np.atleast_1d(scalar) def _get_ev_repr_range_indices( indices: np.ndarray, max_len: int ) -> List[Tuple[int, int]]: """ Computes a list of index ranges based on the input array of indices and a maximum length. The index ranges are computed such that the difference between consecutive indices should not exceed the maximum length (max_len). Parameters: ----------- indices : np.ndarray A NumPy array of indices, where the indices are sorted in ascending order. max_len : int The maximum allowed length between consecutive indices. Returns: -------- out : List[Tuple[int, int]] A list of tuples, where each tuple contains two integers representing the start and stop indices of the range. """ meta_indices_stop = np.flatnonzero(np.diff(indices) > max_len) meta_indices_start = np.concatenate((np.atleast_1d(0), meta_indices_stop + 1)) meta_indices_stop = np.concatenate( (meta_indices_stop, np.atleast_1d(len(indices) - 1)) ) out = list() for meta_idx_start, meta_idx_stop in zip(meta_indices_start, meta_indices_stop): idx_start = max(indices[meta_idx_start] - max_len + 1, 0) idx_stop = indices[meta_idx_stop] + 1 out.append((idx_start, idx_stop)) return out class SequenceForIter(SequenceBase): def __init__( self, path: Path, ev_representation_name: str, sequence_length: int, dataset_type: DatasetType, downsample_by_factor_2: bool, range_indices: Optional[Tuple[int, int]] = None, ): super().__init__( path=path, ev_representation_name=ev_representation_name, sequence_length=sequence_length, dataset_type=dataset_type, downsample_by_factor_2=downsample_by_factor_2, only_load_end_labels=False, ) with h5py.File(str(self.ev_repr_file), "r") as h5f: num_ev_repr = h5f["data"].shape[0] if range_indices is None: repr_idx_start = max( self.objframe_idx_2_repr_idx[0] - sequence_length + 1, 0 ) repr_idx_stop = num_ev_repr else: repr_idx_start, 
class SequenceForIter(SequenceBase):
    """Map-style view of a recording as consecutive chunks of ``sequence_length`` steps.

    Item ``i`` covers the representation indices
    ``[start_indices[i], stop_indices[i])``; a shorter final chunk is padded.
    """

    def __init__(
        self,
        path: Path,
        ev_representation_name: str,
        sequence_length: int,
        dataset_type: DatasetType,
        downsample_by_factor_2: bool,
        range_indices: Optional[Tuple[int, int]] = None,
    ):
        super().__init__(
            path=path,
            ev_representation_name=ev_representation_name,
            sequence_length=sequence_length,
            dataset_type=dataset_type,
            downsample_by_factor_2=downsample_by_factor_2,
            only_load_end_labels=False,
        )

        # Only the number of stored representations is needed here.
        with h5py.File(str(self.ev_repr_file), "r") as h5f:
            num_ev_repr = h5f["data"].shape[0]
        if range_indices is None:
            # Default range: from (first labeled index - sequence_length + 1),
            # clipped at 0, to the end of the recording.
            repr_idx_start = max(
                self.objframe_idx_2_repr_idx[0] - sequence_length + 1, 0
            )
            repr_idx_stop = num_ev_repr
        else:
            repr_idx_start, repr_idx_stop = range_indices
        # Set start idx such that the first label is no further than the last timestamp of the first sample sub-sequence
        min_start_repr_idx = max(
            self.objframe_idx_2_repr_idx[0] - sequence_length + 1, 0
        )
        assert (
            0 <= min_start_repr_idx <= repr_idx_start < repr_idx_stop <= num_ev_repr
        ), f"{min_start_repr_idx=}, {repr_idx_start=}, {repr_idx_stop=}, {num_ev_repr=}, {path=}"

        # Chunk boundaries: every chunk stops where the next one starts,
        # the last chunk stops at repr_idx_stop.
        self.start_indices = list(range(repr_idx_start, repr_idx_stop, sequence_length))
        self.stop_indices = self.start_indices[1:] + [repr_idx_stop]
        self.length = len(self.start_indices)

        # Created lazily by the padding_representation property.
        self._padding_representation = None

    @staticmethod
    def get_sequences_with_guaranteed_labels(
        path: Path,
        ev_representation_name: str,
        sequence_length: int,
        dataset_type: DatasetType,
        downsample_by_factor_2: bool,
    ) -> List["SequenceForIter"]:
        """Generate sequences such that we do always have labels within each sample of the sequence.

        This is required for training such that we are guaranteed to always have labels in the training step.
        However, for validation we don't require this if we catch the special case.

        One ``SequenceForIter`` is created per index range returned by
        ``_get_ev_repr_range_indices``.
        """
        objframe_idx_2_repr_idx = get_objframe_idx_2_repr_idx(
            path=path, ev_representation_name=ev_representation_name
        )
        # max diff for repr idx is sequence length
        range_indices_list = _get_ev_repr_range_indices(
            indices=objframe_idx_2_repr_idx, max_len=sequence_length
        )
        sequence_list = list()
        for range_indices in range_indices_list:
            sequence_list.append(
                SequenceForIter(
                    path=path,
                    ev_representation_name=ev_representation_name,
                    sequence_length=sequence_length,
                    dataset_type=dataset_type,
                    downsample_by_factor_2=downsample_by_factor_2,
                    range_indices=range_indices,
                )
            )
        return sequence_list

    @property
    def padding_representation(self) -> torch.Tensor:
        """All-zero tensor shaped like the first event representation (cached)."""
        if self._padding_representation is None:
            ev_repr = self._get_event_repr_torch(start_idx=0, end_idx=1)[0]
            self._padding_representation = torch.zeros_like(ev_repr)
        return self._padding_representation

    def get_fully_padded_sample(self) -> LoaderDataDictGenX:
        """Return a sample of ``seq_len`` padding steps without any labels."""
        is_first_sample = False
        is_padded_mask = [True] * self.seq_len
        # The same padding tensor object is referenced seq_len times.
        ev_repr = [self.padding_representation] * self.seq_len
        labels = [None] * self.seq_len
        sparse_labels = SparselyBatchedObjectLabels(sparse_object_labels_batch=labels)

        out = {
            DataType.EV_REPR: ev_repr,
            DataType.OBJLABELS_SEQ: sparse_labels,
            DataType.IS_FIRST_SAMPLE: is_first_sample,
            DataType.IS_PADDED_MASK: is_padded_mask,
        }
        return out

    def __len__(self):
        return self.length

    def __getitem__(self, index: int) -> LoaderDataDictGenX:
        """Return chunk ``index``, padded at the end to ``seq_len`` steps."""
        start_idx = self.start_indices[index]
        end_idx = self.stop_indices[index]

        # sequence info ###
        sample_len = end_idx - start_idx
        assert self.seq_len >= sample_len > 0, (
            f"{self.seq_len=}, {sample_len=}, {start_idx=}, {end_idx=}, "
            f"\n{self.start_indices=}\n{self.stop_indices=}"
        )

        is_first_sample = True if index == 0 else False
        is_padded_mask = [False] * sample_len
        ###################

        # event representations ###
        with Timer(timer_name="read ev reprs"):
            ev_repr = self._get_event_repr_torch(start_idx=start_idx, end_idx=end_idx)
        assert len(ev_repr) == sample_len
        ###########################

        # labels ###
        labels = list()
        for repr_idx in range(start_idx, end_idx):
            labels.append(self._get_labels_from_repr_idx(repr_idx))
        assert len(labels) == len(ev_repr)
        ############

        # apply padding (if necessary) ###
        if sample_len < self.seq_len:
            padding_len = self.seq_len - sample_len

            is_padded_mask.extend([True] * padding_len)
            ev_repr.extend([self.padding_representation] * padding_len)
            labels.extend([None] * padding_len)
        ##################################

        # convert labels to sparse labels for datapipes and dataloader
        sparse_labels = SparselyBatchedObjectLabels(sparse_object_labels_batch=labels)

        out = {
            DataType.EV_REPR: ev_repr,
            DataType.OBJLABELS_SEQ: sparse_labels,
            DataType.IS_FIRST_SAMPLE: is_first_sample,
            DataType.IS_PADDED_MASK: is_padded_mask,
        }
        return out


class RandAugmentIterDataPipe(IterDataPipe):
    """Applies one randomly drawn spatial augmentation to every item of a source datapipe."""

    def __init__(self, source_dp: IterDataPipe, dataset_config: DictConfig):
        super().__init__()
        self.source_dp = source_dp

        resolution_hw = tuple(dataset_config.resolution_hw)
        assert len(resolution_hw) == 2
        ds_by_factor_2 = dataset_config.downsample_by_factor_2
        if ds_by_factor_2:
            # The augmentor has to operate on the downsampled resolution.
            resolution_hw = tuple(x // 2 for x in resolution_hw)

        augm_config = dataset_config.data_augmentation
        self.spatial_augmentor = RandomSpatialAugmentorGenX(
            dataset_hw=resolution_hw,
            automatic_randomization=False,
            augm_config=augm_config.stream,
        )

    def __iter__(self):
        # Parameters are drawn once per iteration pass and then reused for all items.
        self.spatial_augmentor.randomize_augmentation()
        for x in self.source_dp:
            yield self.spatial_augmentor(x)
class SequenceForRandomAccess(SequenceBase):
    """Random-access view of a recording: every item ends at a labeled representation.

    Item ``i`` spans ``seq_len`` representations whose last one carries the
    labels of object frame ``i + start_idx_offset``.
    """

    def __init__(
        self,
        path: Path,
        ev_representation_name: str,
        sequence_length: int,
        dataset_type: DatasetType,
        downsample_by_factor_2: bool,
        only_load_end_labels: bool,
    ):
        super().__init__(
            path=path,
            ev_representation_name=ev_representation_name,
            sequence_length=sequence_length,
            dataset_type=dataset_type,
            downsample_by_factor_2=downsample_by_factor_2,
            only_load_end_labels=only_load_end_labels,
        )

        # Find the first object frame with enough preceding representations
        # to fill a whole sequence.
        self.start_idx_offset = None
        for objframe_idx, repr_idx in enumerate(self.objframe_idx_2_repr_idx):
            if repr_idx - self.seq_len + 1 >= 0:
                # We can fit the sequence length to the label
                self.start_idx_offset = objframe_idx
                break
        if self.start_idx_offset is None:
            # This leads to actual length of 0:
            self.start_idx_offset = len(self.label_factory)

        self.length = len(self.label_factory) - self.start_idx_offset
        assert len(self.label_factory) == len(self.objframe_idx_2_repr_idx)

        # Useful for weighted sampler that is based on label statistics:
        self._only_load_labels = False

    def __len__(self):
        return self.length

    def __getitem__(self, index: int) -> LoaderDataDictGenX:
        """Return the sample ending at the labels of object frame ``index + start_idx_offset``.

        In label-only mode the returned dict contains just the labels.
        """
        corrected_idx = index + self.start_idx_offset
        labels_repr_idx = self.objframe_idx_2_repr_idx[corrected_idx]

        # The window ends (exclusively) right after the labeled representation.
        end_idx = labels_repr_idx + 1
        start_idx = end_idx - self.seq_len
        assert_msg = (
            f"{self.ev_repr_file=}, {self.start_idx_offset=}, {start_idx=}, {end_idx=}"
        )
        assert start_idx >= 0, assert_msg

        labels = list()
        for repr_idx in range(start_idx, end_idx):
            if self.only_load_end_labels and repr_idx < end_idx - 1:
                # Labels before the final step are dropped on request.
                labels.append(None)
            else:
                labels.append(self._get_labels_from_repr_idx(repr_idx))
        sparse_labels = SparselyBatchedObjectLabels(sparse_object_labels_batch=labels)
        if self._only_load_labels:
            return {DataType.OBJLABELS_SEQ: sparse_labels}

        with Timer(timer_name="read ev reprs"):
            ev_repr = self._get_event_repr_torch(start_idx=start_idx, end_idx=end_idx)
        assert len(sparse_labels) == len(ev_repr)

        is_first_sample = True  # Due to random loading
        is_padded_mask = [False] * len(ev_repr)
        out = {
            DataType.EV_REPR: ev_repr,
            DataType.OBJLABELS_SEQ: sparse_labels,
            DataType.IS_FIRST_SAMPLE: is_first_sample,
            DataType.IS_PADDED_MASK: is_padded_mask,
        }
        return out

    def is_only_loading_labels(self) -> bool:
        """Whether ``__getitem__`` currently returns labels only."""
        return self._only_load_labels

    def only_load_labels(self):
        """Switch ``__getitem__`` to label-only mode."""
        self._only_load_labels = True

    def load_everything(self):
        """Switch ``__getitem__`` back to returning the full sample."""
        self._only_load_labels = False
) filterwarnings("always", message=NO_LABEL_WARN_MSG) @dataclass class ZoomOutState: active: bool x0: int y0: int zoom_out_factor: float @dataclass class RotationState: active: bool angle_deg: float @dataclass class AugmentationState: apply_h_flip: bool rotation: RotationState apply_zoom_in: bool zoom_out: ZoomOutState class RandomSpatialAugmentorGenX: def __init__( self, dataset_hw: Tuple[int, int], automatic_randomization: bool, augm_config: DictConfig, ): assert isinstance(dataset_hw, tuple) assert len(dataset_hw) == 2 assert all(x > 0 for x in dataset_hw) assert isinstance(automatic_randomization, bool) self.hw_tuple = dataset_hw self.automatic_randomization = automatic_randomization self.h_flip_prob = augm_config.prob_hflip self.rot_prob = augm_config.rotate.prob self.rot_min_angle_deg = augm_config.rotate.get("min_angle_deg", 0) self.rot_max_angle_deg = augm_config.rotate.max_angle_deg self.zoom_prob = augm_config.zoom.prob zoom_out_weight = augm_config.zoom.zoom_out.get("weight", 1) self.min_zoom_out_factor = augm_config.zoom.zoom_out.factor.min self.max_zoom_out_factor = augm_config.zoom.zoom_out.factor.max has_zoom_in = "zoom_in" in augm_config.zoom zoom_in_weight = augm_config.zoom.zoom_in.weight if has_zoom_in else 0 self.min_zoom_in_factor = ( augm_config.zoom.zoom_in.factor.min if has_zoom_in else 1 ) self.max_zoom_in_factor = ( augm_config.zoom.zoom_in.factor.max if has_zoom_in else 1 ) assert 0 <= self.h_flip_prob <= 1 assert 0 <= self.rot_prob <= 1 assert 0 <= self.rot_min_angle_deg <= self.rot_max_angle_deg assert 0 <= self.zoom_prob <= 1 assert 0 <= zoom_in_weight assert self.max_zoom_in_factor >= self.min_zoom_in_factor >= 1 assert 0 <= zoom_out_weight assert self.max_zoom_out_factor >= self.min_zoom_out_factor >= 1 if not automatic_randomization: # We are probably applying augmentation to a streaming dataset for which zoom in augm is not supported. 
assert zoom_in_weight == 0, f"{zoom_in_weight=}" self.zoom_in_or_out_distribution = torch.distributions.categorical.Categorical( probs=th.tensor([zoom_in_weight, zoom_out_weight]) ) self.augm_state = AugmentationState( apply_h_flip=False, rotation=RotationState(active=False, angle_deg=0.0), apply_zoom_in=False, zoom_out=ZoomOutState(active=False, x0=0, y0=0, zoom_out_factor=1.0), ) def randomize_augmentation(self): """Sample new augmentation parameters that will be consistently applied among the items. This function only works with augmentations that are input-independent. E.g. The zoom-in augmentation parameters depend on the labels and cannot be sampled in this function. For the same reason, it is not a very reasonable augmentation for the streaming scenario. """ self.augm_state.apply_h_flip = self.h_flip_prob > th.rand(1).item() self.augm_state.rotation.active = self.rot_prob > th.rand(1).item() if self.augm_state.rotation.active: sign = 1 if th.randn(1).item() >= 0 else -1 self.augm_state.rotation.angle_deg = sign * torch_uniform_sample_scalar( min_value=self.rot_min_angle_deg, max_value=self.rot_max_angle_deg ) # Zoom in and zoom out is mutually exclusive. 
do_zoom = self.zoom_prob > th.rand(1).item() do_zoom_in = self.zoom_in_or_out_distribution.sample().item() == 0 do_zoom_out = not do_zoom_in do_zoom_in &= do_zoom do_zoom_out &= do_zoom self.augm_state.apply_zoom_in = do_zoom_in self.augm_state.zoom_out.active = do_zoom_out if do_zoom_out: rand_zoom_out_factor = torch_uniform_sample_scalar( min_value=self.min_zoom_out_factor, max_value=self.max_zoom_out_factor ) height, width = self.hw_tuple zoom_window_h, zoom_window_w = int(height / rand_zoom_out_factor), int( width / rand_zoom_out_factor ) x0_sampled = int( torch_uniform_sample_scalar( min_value=0, max_value=width - zoom_window_w ) ) y0_sampled = int( torch_uniform_sample_scalar( min_value=0, max_value=height - zoom_window_h ) ) self.augm_state.zoom_out.x0 = x0_sampled self.augm_state.zoom_out.y0 = y0_sampled self.augm_state.zoom_out.zoom_out_factor = rand_zoom_out_factor def _zoom_out_and_rescale( self, data_dict: LoaderDataDictGenX ) -> LoaderDataDictGenX: zoom_out_state = self.augm_state.zoom_out zoom_out_factor = zoom_out_state.zoom_out_factor if zoom_out_factor == 1: return data_dict return { k: RandomSpatialAugmentorGenX._zoom_out_and_rescale_recursive( v, zoom_coordinates_x0y0=(zoom_out_state.x0, zoom_out_state.y0), zoom_out_factor=zoom_out_factor, datatype=k, ) for k, v in data_dict.items() } @staticmethod def _zoom_out_and_rescale_tensor( input_: th.Tensor, zoom_coordinates_x0y0: Tuple[int, int], zoom_out_factor: float, datatype: DataType, ) -> th.Tensor: assert len(zoom_coordinates_x0y0) == 2 assert isinstance(input_, th.Tensor) if datatype == DataType.IMAGE or datatype == DataType.EV_REPR: assert input_.ndim == 3, f"{input_.shape=}" height, width = input_.shape[-2:] zoom_window_h, zoom_window_w = int(height / zoom_out_factor), int( width / zoom_out_factor ) zoom_window = interpolate( input_.unsqueeze(0), size=(zoom_window_h, zoom_window_w), mode="nearest-exact", )[0] output = th.zeros_like(input_) x0, y0 = zoom_coordinates_x0y0 assert x0 >= 0 assert 
@classmethod
def _zoom_out_and_rescale_recursive(
    cls,
    input_: Any,
    zoom_coordinates_x0y0: Tuple[int, int],
    zoom_out_factor: float,
    datatype: DataType,
):
    """Apply the zoom-out augmentation to ``input_``, descending into containers.

    Dispatch by type:
      * entries tagged IS_PADDED_MASK / IS_FIRST_SAMPLE are returned untouched,
      * tensors are handled by ``_zoom_out_and_rescale_tensor``,
      * label containers are modified in place (``zoom_out_and_rescale_``) and returned,
      * sequences / mappings are rebuilt as list / dict with every element processed.

    :param input_: value stored under ``datatype`` in the data dict (or a nested element).
    :param zoom_coordinates_x0y0: top-left corner (x0, y0) of the zoom window.
    :param zoom_out_factor: zoom-out factor forwarded to the handlers.
    :param datatype: key under which ``input_`` is stored.
    :raises NotImplementedError: for any other input type.
    """
    if datatype == DataType.IS_PADDED_MASK or datatype == DataType.IS_FIRST_SAMPLE:
        return input_

    zoom_kwargs = dict(
        zoom_coordinates_x0y0=zoom_coordinates_x0y0,
        zoom_out_factor=zoom_out_factor,
    )

    if isinstance(input_, th.Tensor):
        return cls._zoom_out_and_rescale_tensor(
            input_=input_, datatype=datatype, **zoom_kwargs
        )

    if isinstance(input_, (ObjectLabels, SparselyBatchedObjectLabels)):
        assert datatype in (DataType.OBJLABELS, DataType.OBJLABELS_SEQ)
        # In-place modification of the labels.
        input_.zoom_out_and_rescale_(**zoom_kwargs)
        return input_

    recurse = RandomSpatialAugmentorGenX._zoom_out_and_rescale_recursive
    if isinstance(input_, abc.Sequence):
        return [recurse(item, datatype=datatype, **zoom_kwargs) for item in input_]
    if isinstance(input_, abc.Mapping):
        return {
            key: recurse(value, datatype=datatype, **zoom_kwargs)
            for key, value in input_.items()
        }
    raise NotImplementedError
@staticmethod
def _zoom_in_and_rescale_tensor(
    input_: th.Tensor,
    zoom_coordinates_x0y0: Tuple[int, int],
    zoom_in_factor: float,
    datatype: DataType,
) -> th.Tensor:
    """Crop a zoom window out of a 3-D tensor and resize it back to the full frame.

    The window has size ``int(height / zoom_in_factor)`` x ``int(width / zoom_in_factor)``
    and its top-left corner at ``zoom_coordinates_x0y0``. Resizing uses "nearest-exact"
    interpolation.

    :param input_: tensor with exactly three dimensions; the last two are (height, width).
    :param zoom_coordinates_x0y0: non-negative (x0, y0) of the window's top-left corner.
    :param zoom_in_factor: zoom factor; the window shrinks as it grows.
    :param datatype: only IMAGE and EV_REPR are supported.
    :raises NotImplementedError: for any other datatype.
    """
    assert len(zoom_coordinates_x0y0) == 2
    assert isinstance(input_, th.Tensor)

    if datatype not in (DataType.IMAGE, DataType.EV_REPR):
        raise NotImplementedError

    assert input_.ndim == 3, f"{input_.shape=}"
    height, width = input_.shape[-2:]
    window_h = int(height / zoom_in_factor)
    window_w = int(width / zoom_in_factor)

    x0, y0 = zoom_coordinates_x0y0
    assert x0 >= 0
    assert y0 >= 0

    window = input_[..., y0 : y0 + window_h, x0 : x0 + window_w]
    # interpolate expects a leading batch dimension; add it and strip it again.
    resized = interpolate(
        window.unsqueeze(0), size=(height, width), mode="nearest-exact"
    )
    return resized[0]
@classmethod
def _rotate_recursive(cls, input_: Any, angle_deg: float, datatype: DataType):
    """Rotate ``input_`` by ``angle_deg`` degrees, descending into containers.

    Dispatch by type:
      * entries tagged IS_PADDED_MASK / IS_FIRST_SAMPLE are returned untouched,
      * tensors are handled by ``_rotate_tensor``,
      * label containers are rotated in place (``rotate_``) and returned,
      * sequences / mappings are rebuilt as list / dict with every element processed.

    :raises NotImplementedError: for any other input type.
    """
    if datatype == DataType.IS_PADDED_MASK or datatype == DataType.IS_FIRST_SAMPLE:
        return input_

    if isinstance(input_, th.Tensor):
        return cls._rotate_tensor(input_=input_, angle_deg=angle_deg, datatype=datatype)

    if isinstance(input_, (ObjectLabels, SparselyBatchedObjectLabels)):
        assert datatype in (DataType.OBJLABELS, DataType.OBJLABELS_SEQ)
        # In-place modification of the labels.
        input_.rotate_(angle_deg=angle_deg)
        return input_

    recurse = RandomSpatialAugmentorGenX._rotate_recursive
    if isinstance(input_, abc.Sequence):
        return [recurse(item, angle_deg=angle_deg, datatype=datatype) for item in input_]
    if isinstance(input_, abc.Mapping):
        return {
            key: recurse(value, angle_deg=angle_deg, datatype=datatype)
            for key, value in input_.items()
        }
    raise NotImplementedError
@staticmethod
def _flip_tensor(input_: Any, flip_type: str, datatype: DataType):
    """Mirror a tensor along its last axis ("h") or second-to-last axis (any other value).

    IMAGE and EV_REPR tensors are simply flipped. FLOW tensors must have two channels on
    axis -3; they are flipped and the flow component along the mirrored direction
    (channel 0 for "h", channel 1 otherwise) changes sign.

    :raises NotImplementedError: for any other datatype.
    """
    assert isinstance(input_, th.Tensor)
    is_horizontal = flip_type == "h"
    flip_axis = -1 if is_horizontal else -2

    if datatype in (DataType.IMAGE, DataType.EV_REPR):
        return th.flip(input_, dims=[flip_axis])

    if datatype != DataType.FLOW:
        raise NotImplementedError

    assert input_.shape[-3] == 2
    component = 0 if is_horizontal else 1
    # th.flip returns a new tensor, so the sign change below does not touch the caller's data.
    mirrored = th.flip(input_, dims=[flip_axis])
    # Mirroring reverses the direction of motion along the flipped axis.
    mirrored[..., component, :, :] = -1 * mirrored[..., component, :, :]
    return mirrored
def randomly_sample_zoom_window_from_objframe(
    objframe: ObjectLabels,
    zoom_window_height: Union[int, float],
    zoom_window_width: Union[int, float],
) -> Tuple[int, int]:
    """Sample the top-left corner of a zoom window anchored on one random label.

    One candidate corner is computed per label of ``objframe`` with
    ``randomly_sample_zoom_window_from_label_rectangle``; one of the candidates is then
    picked uniformly at random.

    :param objframe: non-empty set of labels for a single frame.
    :param zoom_window_height: height of the zoom window.
    :param zoom_window_width: width of the zoom window.
    :return: (x0, y0) of the sampled window corner, inside the input frame.
    """
    input_height, input_width = objframe.input_size_hw

    possible_samples = [
        randomly_sample_zoom_window_from_label_rectangle(
            label_xywh=(
                objframe.x[idx],
                objframe.y[idx],
                objframe.w[idx],
                objframe.h[idx],
            ),
            input_height=input_height,
            input_width=input_width,
            zoom_window_height=zoom_window_height,
            zoom_window_width=zoom_window_width,
        )
        for idx in range(len(objframe))
    ]
    num_samples = len(possible_samples)
    assert num_samples > 0

    # Using torch to sample, to avoid potential problems with multiprocessing.
    # NOTE: `high` is exclusive in th.randint, so it must be the number of candidates
    # (not number - 1); otherwise the last label could never be selected.
    sample_idx = (
        0
        if num_samples == 1
        else th.randint(low=0, high=num_samples, size=(1,)).item()
    )
    x0_sample, y0_sample = possible_samples[sample_idx]
    assert input_width > x0_sample >= 0, f"{x0_sample=}"
    assert input_height > y0_sample >= 0, f"{y0_sample=}"
    return x0_sample, y0_sample
class RepresentationBase(ABC):
    """Interface for classes that turn event streams into dense tensors."""

    @abstractmethod
    def construct(
        self, x: th.Tensor, y: th.Tensor, pol: th.Tensor, time: th.Tensor
    ) -> th.Tensor:
        """Build the representation from per-event x, y, polarity and time tensors."""
        ...

    @abstractmethod
    def get_shape(self) -> Tuple[int, int, int]:
        """Return the shape of the tensor produced by ``construct``."""
        ...

    @staticmethod
    @abstractmethod
    def get_numpy_dtype() -> np.dtype:
        """Return the numpy dtype of the representation."""
        ...

    @staticmethod
    @abstractmethod
    def get_torch_dtype() -> th.dtype:
        """Return the torch dtype of the representation."""
        ...

    @property
    def dtype(self) -> th.dtype:
        """Torch dtype of the representation (same as ``get_torch_dtype()``)."""
        return self.get_torch_dtype()

    @staticmethod
    def _is_int_tensor(tensor: th.Tensor) -> bool:
        """Return True when ``tensor`` is neither floating point nor complex."""
        return not (th.is_floating_point(tensor) or th.is_complex(tensor))
def merge_channel_and_bins(self, representation: th.Tensor):
    """Reshape a 4-D representation to ``(-1, self.height, self.width)``.

    For the (channels, bins, height, width) tensor built in ``construct`` this merges
    the channel and bin axes into a single leading axis.
    """
    assert representation.ndim == 4
    merged_shape = (-1, self.height, self.width)
    return representation.reshape(merged_shape)
def cumsum_channel(x: th.Tensor, num_channels: int):
    """Replace the first ``num_channels`` entries along axis 0 by their running sum, in place.

    After the call ``x[i]`` equals the sum of the original ``x[0] ... x[i]`` for every
    ``i < num_channels``. The (mutated) input tensor is returned.
    """
    # Walk from the last channel down so every slice still holds the un-summed values.
    for channel in range(num_channels - 1, -1, -1):
        x[channel] = x[: channel + 1].sum(dim=0)
    return x
def get_dataloading_hw(dataset_config: DictConfig):
    """Return the (height, width) at which the configured dataset is loaded.

    This is the sensor resolution of the dataset named by ``dataset_config.name``,
    with both sides integer-halved when ``dataset_config.downsample_by_factor_2`` is set.
    """
    full_hw = get_original_hw(dataset_type=_str_2_type[dataset_config.name])
    if not dataset_config.downsample_by_factor_2:
        return full_hw
    return tuple(side // 2 for side in full_hw)
class DummyIterDataPipe(IterDataPipe):
    """Pass-through datapipe: yields the items of the wrapped datapipe unchanged."""

    def __init__(self, source_dp: IterDataPipe):
        super().__init__()
        assert isinstance(source_dp, IterDataPipe)
        self.source_dp = source_dp

    def __iter__(self):
        for sample in self.source_dp:
            yield sample
Otherwise, the same random order will be used for all workers. """ assert isinstance(datapipe_list, List) assert batch_size > 0 streams = Zipper( *( Concater( *( self.augmentation_dp(x.to_iter_datapipe()) for x in self.random_torch_shuffle_list(datapipe_list) ) ) for _ in range(batch_size) ) ) return streams def _print_seed_debug_info(self): worker_info = th.utils.data.get_worker_info() local_worker_id = 0 if worker_info is None else worker_info.id worker_torch_seed = worker_info.seed local_num_workers = 1 if worker_info is None else worker_info.num_workers if dist.is_available() and dist.is_initialized(): global_rank = dist.get_rank() else: global_rank = 0 global_worker_id = global_rank * local_num_workers + local_worker_id rnd_number = th.randn(1) print( f"{worker_torch_seed=},\t{global_worker_id=},\t{global_rank=},\t{local_worker_id=},\t{rnd_number=}", flush=True, ) def _get_zipped_streams_with_worker_id(self): """Use it only in the iter function of this class!!!""" worker_info = th.utils.data.get_worker_info() local_worker_id = 0 if worker_info is None else worker_info.id worker_id_stream = IterableWrapper([local_worker_id]).cycle(count=None) zipped_stream = self._get_zipped_streams( datapipe_list=self.datapipe_list, batch_size=self.batch_size ) return zipped_stream.zip(worker_id_stream) def __iter__(self): if self.print_seed_debug: self._print_seed_debug_info() return iter(self._get_zipped_streams_with_worker_id()) ================================================ FILE: RVT/data/utils/stream_sharded_datapipe.py ================================================ from typing import Any, List, Optional import torch import torch.distributed as dist from torch.utils.data import DataLoader from torchdata.datapipes.iter import ( Concater, IterableWrapper, IterDataPipe, ZipperLongest, ) from torchdata.datapipes.map import MapDataPipe class ShardedStreamingDataPipe(IterDataPipe): def __init__( self, datapipe_list: List[MapDataPipe], batch_size: int, fill_value: 
Optional[Any] = None, ): super().__init__() assert batch_size > 0 # We require MapDataPipes instead of IterDataPipes because IterDataPipes must be deepcopied in each worker. # Instead, MapDataPipes can be converted to IterDataPipes in each worker without requiring a deepcopy. # Note: Sorting is a heuristic to get potentially better distribution of workloads than taking the data as is. # Sort iterators from long to short. self.datapipe_list = sorted(datapipe_list, key=lambda x: len(x), reverse=True) self.batch_size = batch_size self.fill_value = fill_value @staticmethod def yield_pyramid_indices(start_idx: int, end_idx: int): while True: for idx in range(start_idx, end_idx): yield idx for idx in range(end_idx - 1, start_idx - 1, -1): yield idx @classmethod def assign_datapipes_to_worker( cls, sorted_datapipe_list: List[MapDataPipe], total_num_workers: int, global_worker_id: int, ) -> List[MapDataPipe]: num_datapipes = len(sorted_datapipe_list) assert ( num_datapipes >= total_num_workers > global_worker_id ), f"{num_datapipes=}, {total_num_workers=}, {global_worker_id=}" datapipes = [] # Assumes sorted datapipes from long to short. global_worker_id_generator = cls.yield_pyramid_indices( start_idx=0, end_idx=total_num_workers ) for idx, dp in enumerate(sorted_datapipe_list): generated_global_worker_id = next(global_worker_id_generator) if generated_global_worker_id == global_worker_id: datapipes.append(dp) assert len(sorted_datapipe_list) > 0 return datapipes def get_zipped_stream_from_worker_datapipes( self, datapipe_list: List[MapDataPipe], batch_size: int ) -> ZipperLongest: num_datapipes = len(datapipe_list) assert num_datapipes > 0 assert batch_size > 0 assert num_datapipes >= batch_size, ( "Each worker must at least get 'batch_size' number of datapipes. " "Otherwise, we would have to support dynamic batch sizes. " "As a workaround, decrease the number of workers." ) # Sort datapipe_list from long to short. 
def __iter__(self):
    """Return an iterator over this worker's share of the streams.

    The datapipes are split across all dataloader workers of all distributed ranks,
    the share of the current worker is zipped into ``batch_size`` parallel streams, and
    every item is paired with the local worker id.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        local_worker_id, local_num_workers = 0, 1
    else:
        local_worker_id = worker_info.id
        local_num_workers = worker_info.num_workers

    distributed = dist.is_available() and dist.is_initialized()
    world_size = dist.get_world_size() if distributed else 1
    global_rank = dist.get_rank() if distributed else 0

    worker_datapipes = self.assign_datapipes_to_worker(
        sorted_datapipe_list=self.datapipe_list,
        total_num_workers=local_num_workers * world_size,
        global_worker_id=global_rank * local_num_workers + local_worker_id,
    )
    batch_stream = self.get_zipped_stream_from_worker_datapipes(
        datapipe_list=worker_datapipes, batch_size=self.batch_size
    )

    # The local worker id is streamed along for recurrent networks that keep one state
    # per worker. The global worker id is not needed: states live in each DDP process
    # (per GPU) separately and do not need to communicate with each other.
    worker_id_stream = IterableWrapper([local_worker_id]).cycle(count=None)
    return iter(batch_stream.zip(worker_id_stream))
def get_ckpt_path(logger: WandbLogger, wandb_config: DictConfig) -> Union[Path, None]:
    """Resolve the local ``.ckpt`` file from which training should resume.

    With a ``WandbLogger`` the path is obtained from ``logger.get_checkpoint`` (passing
    the optional local file along); with any other logger ``artifact_local_file`` from
    the config is used directly.

    :param logger: logger of the current run.
    :param wandb_config: config providing ``artifact_name`` and ``artifact_local_file``.
    :return: path to an existing file with the ``.ckpt`` suffix.
    :raises AssertionError: if the artifact name is missing, no path can be determined,
        the file does not exist, or it is not a ``.ckpt`` file.
    """
    artifact_name = wandb_config.artifact_name
    assert (
        artifact_name is not None
    ), "Artifact name is required to resume from checkpoint."
    print(f"resuming checkpoint from artifact {artifact_name}")

    artifact_local_file = wandb_config.artifact_local_file
    if artifact_local_file is not None:
        artifact_local_file = Path(artifact_local_file)

    if isinstance(logger, WandbLogger):
        resume_path = logger.get_checkpoint(
            artifact_name=artifact_name, artifact_filepath=artifact_local_file
        )
    else:
        resume_path = artifact_local_file

    # Without a WandbLogger there is nothing to download from, so a missing local file
    # would otherwise surface as an opaque AttributeError on None below.
    assert (
        resume_path is not None
    ), "artifact_local_file is required when the logger is not a WandbLogger."
    assert resume_path.exists(), f"Checkpoint file does not exist: {resume_path}"
    assert resume_path.suffix == ".ckpt", resume_path.suffix
    return resume_path
def get_checkpoint(
    self, artifact_name: str, artifact_filepath: Optional[Path] = None
) -> Path:
    """Return the path of the checkpoint file belonging to a wandb artifact.

    The artifact is always registered as used by the current run. When no
    ``artifact_filepath`` is given, the artifact is downloaded and the first entry of
    the download directory is taken as the checkpoint.

    :param artifact_name: name of the wandb artifact.
    :param artifact_filepath: optional path to an already available checkpoint file.
    :return: path to an existing ``.ckpt`` file.
    """
    artifact = self.experiment.use_artifact(artifact_name)

    checkpoint_path = artifact_filepath
    if checkpoint_path is None:
        assert (
            artifact is not None
        ), "You are probably using DDP, in which case you should provide an artifact filepath."
        # TODO: specify download directory
        download_dir = Path(artifact.download())
        checkpoint_path = next(download_dir.iterdir())

    assert checkpoint_path.exists()
    assert checkpoint_path.suffix == ".ckpt"
    return checkpoint_path
@rank_zero_only
def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
    """Log a dictionary of metrics to the wandb run.

    :param metrics: Metric names mapped to values. The logger prefix (if any)
        is prepended to every key using ``LOGGER_JOIN_CHAR``.
    :param step: Optional global step. When given it is additionally stored
        under ``STEP_METRIC`` so that it can serve as the x-axis in wandb.
    :raises AssertionError: If called from a process whose rank is not 0.
    """
    assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

    # Always work on a private copy. `_add_prefix` may hand back the caller's
    # own dict (presumably when no prefix is configured), and
    # `add_step_metric` updates its argument in place, which would leak the
    # step entry into the caller's dictionary.
    metrics = dict(_add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR))
    if step is not None:
        self.add_step_metric(metrics, step)
        self.experiment.log(metrics, step=step)
    else:
        self.experiment.log(metrics)
@rank_zero_only
def log_videos(
    self,
    key: str,
    videos: List[Union[np.ndarray, str]],
    step: Optional[int] = None,
    captions: Optional[List[str]] = None,
    fps: int = 4,
    format_: Optional[str] = None,
):
    """Log a list of videos under a single metric key.

    :param key: Metric name under which the videos are logged.
    :param videos: List[(T,C,H,W)] or List[(N,T,C,H,W)] arrays, or file paths.
    :param step: Optional global step forwarded to ``log_metrics``.
    :param captions: List[str] with one caption per video, or None.
    :param fps: Frame rate passed to ``wandb.Video``.
    :param format_: Container format passed to ``wandb.Video``.

    More info: https://docs.wandb.ai/ref/python/data-types/video
    and https://docs.wandb.ai/guides/track/log/media#other-media
    """
    assert isinstance(videos, list)
    if captions is None:
        # No captions supplied: every video gets `None` as caption.
        per_video_captions = [None] * len(videos)
    else:
        assert isinstance(captions, list)
        assert len(captions) == len(videos)
        per_video_captions = captions

    payload = [
        wandb.Video(data_or_path=clip, caption=text, fps=fps, format=format_)
        for clip, text in zip(videos, per_video_captions)
    ]
    self.log_metrics(metrics={key: payload}, step=step)
# don't create an experiment if we don't have one return self._experiment.id if self._experiment else self._id @rank_zero_only def after_save_checkpoint( self, checkpoint_callback: "ReferenceType[ModelCheckpoint]" ) -> None: # log checkpoints as artifacts if self._checkpoint_callback is None: self._checkpoint_callback = checkpoint_callback self._save_last = checkpoint_callback.save_last if self._log_model: self._scan_and_log_checkpoints( checkpoint_callback, self._save_last and not self._save_last_only_final ) @rank_zero_only def finalize(self, status: str) -> None: # log checkpoints as artifacts if self._checkpoint_callback and self._log_model: self._scan_and_log_checkpoints(self._checkpoint_callback, self._save_last) def _get_public_run(self): if self._public_run is None: experiment = self.experiment runpath = ( experiment._entity + "/" + experiment._project + "/" + experiment._run_id ) api = wandb.Api() self._public_run = api.run(path=runpath) return self._public_run def _num_logged_artifact(self): public_run = self._get_public_run() return len(public_run.logged_artifacts()) def _scan_and_log_checkpoints( self, checkpoint_callback: "ReferenceType[ModelCheckpoint]", save_last: bool ) -> None: assert self._log_model if self._checkpoint_callback is None: self._checkpoint_callback = checkpoint_callback self._save_last = checkpoint_callback.save_last checkpoints = { checkpoint_callback.best_model_path: checkpoint_callback.best_model_score, **checkpoint_callback.best_k_models, } assert len(checkpoints) <= max(checkpoint_callback.save_top_k, 0) if save_last: last_model_path = Path(checkpoint_callback.last_model_path) if last_model_path.exists(): checkpoints.update( { checkpoint_callback.last_model_path: checkpoint_callback.current_score } ) else: print( f"last model checkpoint not found at {checkpoint_callback.last_model_path}" ) checkpoints = sorted( ( (Path(path).stat().st_mtime, path, score) for path, score in checkpoints.items() if Path(path).is_file() ), 
key=lambda x: x[0], ) # Retain only checkpoints that we have not logged before with one exception: # If the name is the same (e.g. last checkpoint which should be overwritten), # make sure that they are newer than the previously saved checkpoint by checking their modification time checkpoints = [ ckpt for ckpt in checkpoints if ckpt[1] not in self._logged_model_time.keys() or self._logged_model_time[ckpt[1]] < ckpt[0] ] # remove checkpoints with undefined (None) score checkpoints = [x for x in checkpoints if x[2] is not None] num_ckpt_logged_before = self._num_logged_artifact() num_new_cktps = len(checkpoints) if num_new_cktps == 0: return # log iteratively all new checkpoints for time_, path, score in checkpoints: score = score.item() if isinstance(score, torch.Tensor) else score is_best = path == checkpoint_callback.best_model_path is_last = path == checkpoint_callback.last_model_path metadata = { "score": score, "original_filename": Path(path).name, "ModelCheckpoint": { k: getattr(checkpoint_callback, k) for k in [ "monitor", "mode", "save_last", "save_top_k", "save_weights_only", ] # ensure it does not break if `ModelCheckpoint` args change if hasattr(checkpoint_callback, k) }, } aliases = [] if is_best: aliases.append("best") if is_last: aliases.append("last") artifact_name = f"checkpoint-{self.experiment.id}-" + ( "last" if is_last else "topK" ) artifact = wandb.Artifact( name=artifact_name, type="model", metadata=metadata ) assert Path(path).exists() artifact.add_file(path, name=f"{self.experiment.id}.ckpt") self.experiment.log_artifact(artifact, aliases=aliases) # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name) self._logged_model_time[path] = time_ timeout = 20 time_spent = 0 while self._num_logged_artifact() < num_ckpt_logged_before + num_new_cktps: time.sleep(1) time_spent += 1 if time_spent >= timeout: rank_zero_warn( "Timeout: Num logged artifacts never reached expected value." 
) print(f"self._num_logged_artifact() = {self._num_logged_artifact()}") print(f"num_ckpt_logged_before = {num_ckpt_logged_before}") print(f"num_new_cktps = {num_new_cktps}") break try: self._rm_but_top_k(checkpoint_callback.save_top_k) except KeyError: pass def _rm_but_top_k(self, top_k: int): # top_k == -1: save all models # top_k == 0: no models saved at all. The checkpoint callback does not return checkpoints. # top_k > 0: keep only top k models (last and best will not be deleted) def is_last(artifact): return "last" in artifact.aliases def is_best(artifact): return "best" in artifact.aliases def try_delete(artifact): try: artifact.delete(delete_aliases=True) except wandb.errors.CommError: print( f"Failed to delete artifact {artifact.name} due to wandb.errors.CommError" ) public_run = self._get_public_run() score2art = list() for artifact in public_run.logged_artifacts(): score = artifact.metadata["score"] original_filename = artifact.metadata["original_filename"] if score == "Infinity": print( f"removing INF artifact (name, score, original_filename): ({artifact.name}, {score}, {original_filename})" ) try_delete(artifact) continue if score is None: print( f"removing None artifact (name, score, original_filename): ({artifact.name}, {score}, {original_filename})" ) try_delete(artifact) continue score2art.append((score, artifact)) # From high score to low score score2art.sort(key=lambda x: x[0], reverse=True) count = 0 for score, artifact in score2art: original_filename = artifact.metadata["original_filename"] if "last" in original_filename and not is_last(artifact): try_delete(artifact) continue if is_last(artifact): continue count += 1 if is_best(artifact): continue # if top_k == -1, we do not delete anything if 0 <= top_k < count: try_delete(artifact) ================================================ FILE: RVT/models/detection/__init_.py ================================================ ================================================ FILE: 
def build_recurrent_backbone(backbone_cfg: DictConfig):
    """Instantiate the recurrent backbone selected by ``backbone_cfg.name``.

    :param backbone_cfg: Backbone configuration; its ``name`` field selects
        the implementation and the whole config is passed to its constructor.
    :return: The constructed backbone module.
    :raises NotImplementedError: If ``backbone_cfg.name`` is not a supported
        backbone name.
    """
    name = backbone_cfg.name
    if name == "MaxViTRNN":
        return MaxViTRNNDetector(backbone_cfg)
    # Name the offending value so a misconfiguration can be spotted directly.
    raise NotImplementedError(
        f"Unsupported recurrent backbone: {name!r} (supported: 'MaxViTRNN')"
    )
def get_stage_dims(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return the channel dimension of each requested stage.

    :param stages: 1-based stage numbers.
    :return: ``self.stage_dims`` entries for those stages, in request order.
    :raises AssertionError: If a stage number is below 1 or exceeds the
        number of stages.
    """
    zero_based = [number - 1 for number in stages]
    lowest = min(zero_based)
    assert lowest >= 0, zero_based
    highest = max(zero_based)
    assert highest < len(self.stages), zero_based
    return tuple(self.stage_dims[index] for index in zero_based)
def forward(
    self,
    x: th.Tensor,
    prev_states: Optional[LstmStates] = None,
    token_mask: Optional[th.Tensor] = None,
    train_step: bool = True,
) -> Tuple[BackboneFeatures, LstmStates]:
    """Run the input through all stages in sequence.

    :param x: Input passed to the first stage; each stage consumes the
        output of the previous one.
    :param prev_states: One recurrent state per stage, or None to start every
        stage without a previous state.
    :param token_mask: Optional mask, handed to the first stage only.
    :param train_step: Flag forwarded unchanged to every stage.
    :return: ``(features, states)`` where ``features`` maps the 1-based stage
        number to that stage's output and ``states`` lists the new state of
        each stage in order.
    """
    if prev_states is None:
        prev_states = [None] * self.num_stages
    assert len(prev_states) == self.num_stages

    features: Dict[int, FeatureMap] = {}
    new_states: LstmStates = []
    for stage_number, (stage, prev_state) in enumerate(
        zip(self.stages, prev_states), start=1
    ):
        # Token masking is only applied by the very first stage.
        mask_for_stage = token_mask if stage_number == 1 else None
        x, stage_state = stage(x, prev_state, mask_for_stage, train_step)
        features[stage_number] = x
        new_states.append(stage_state)
    return features, new_states
def forward(
    self,
    x: th.Tensor,
    states: Optional[LstmState] = None,
    token_mask: Optional[th.Tensor] = None,
    train_step: bool = True,
) -> Tuple[FeatureMap, LstmState]:
    """Apply downsampling, attention blocks and the S5 block to a sequence.

    :param x: Input laid out as ``L B C H W`` (sequence, batch, channels,
        height, width), as fixed by the first rearrange pattern.
    :param states: Previous S5 state laid out as ``B C H W``, or None to
        request an initial state from the S5 layer.
    :param token_mask: Optional index mask; masked positions of the
        downsampled features are overwritten with the learned mask token.
    :param train_step: Accepted for interface compatibility; not used here.
    :return: ``(features, states)`` with features as ``L B C H W`` and the
        new state as ``B C H W``.
    :raises AssertionError: If a token mask is given but this stage has no
        mask token.
    """
    seq_len, n_batch = x.shape[0], x.shape[1]

    # Fold time into the batch axis: B' = (L B) is the new batch size.
    feats = rearrange(x, "L B C H W -> (L B) C H W")
    feats = self.downsample_cf2cl(feats)  # B' C H W -> B' H W C
    if token_mask is not None:
        assert self.mask_token is not None, "No mask token present in this stage"
        feats[token_mask] = self.mask_token

    for attention_pair in self.att_blocks:
        feats = attention_pair(feats)
    feats = nhwC_2_nChw(feats)  # B' H W C -> B' C H W

    new_h, new_w = feats.shape[2], feats.shape[3]

    # Every spatial location becomes its own sequence for the S5 block.
    feats = rearrange(feats, "(L B) C H W -> (B H W) L C", L=seq_len)
    if states is not None:
        states = rearrange(states, "B C H W -> (B H W) C")
    else:
        states = self.s5_block.s5.initial_state(
            batch_size=n_batch * new_h * new_w
        ).to(feats.device)

    feats, states = self.s5_block(feats, states)

    feats = rearrange(
        feats, "(B H W) L C -> L B C H W", B=n_batch, H=int(new_h), W=int(new_w)
    )
    states = rearrange(states, "(B H W) C -> B C H W", H=new_h, W=new_w)
    return feats, states
class IOUloss(nn.Module):
    """IoU / GIoU loss for boxes given as (center_x, center_y, width, height)."""

    def __init__(self, reduction="none", loss_type="iou"):
        """
        Args:
            reduction (str): "mean" or "sum" to reduce the per-box loss; any
                other value leaves it unreduced.
            loss_type (str): "iou" for ``1 - iou**2`` or "giou" for
                ``1 - giou``.
        """
        super().__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        """Compute the loss between predicted and target boxes.

        Both inputs are flattened to shape (-1, 4). Raises
        NotImplementedError for an unknown ``loss_type``.
        """
        assert pred.shape[0] == target.shape[0]

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)

        # Corner coordinates derived from center and size.
        pred_tl = pred[:, :2] - pred[:, 2:] / 2
        pred_br = pred[:, :2] + pred[:, 2:] / 2
        target_tl = target[:, :2] - target[:, 2:] / 2
        target_br = target[:, :2] + target[:, 2:] / 2

        inter_tl = torch.max(pred_tl, target_tl)
        inter_br = torch.min(pred_br, target_br)

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        # 1 where the boxes overlap along both axes, 0 otherwise.
        overlaps = (inter_tl < inter_br).type(inter_tl.type()).prod(dim=1)
        area_i = torch.prod(inter_br - inter_tl, 1) * overlaps
        area_u = area_p + area_g - area_i
        iou = (area_i) / (area_u + 1e-16)

        if self.loss_type == "iou":
            loss = 1 - iou**2
        elif self.loss_type == "giou":
            # Smallest box enclosing both prediction and target.
            hull_tl = torch.min(pred_tl, target_tl)
            hull_br = torch.max(pred_br, target_br)
            area_c = torch.prod(hull_br - hull_tl, 1)
            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)
        else:
            raise NotImplementedError

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
import torch
import torch.nn as nn


class SiLU(nn.Module):
    """export-friendly version of nn.SiLU()"""

    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)


def get_activation(name="silu", inplace=True):
    """Build the activation module for ``name`` ("silu", "relu" or "lrelu").

    Raises AttributeError for any other name.
    """
    if name == "silu":
        return nn.SiLU(inplace=inplace)
    if name == "relu":
        return nn.ReLU(inplace=inplace)
    if name == "lrelu":
        return nn.LeakyReLU(0.1, inplace=inplace)
    raise AttributeError("Unsupported act type: {}".format(name))


class BaseConv(nn.Module):
    """A Conv2d -> Batchnorm -> silu/leaky relu block"""

    def __init__(
        self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=ksize,
            stride=stride,
            # "same" padding for odd kernel sizes
            padding=(ksize - 1) // 2,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = get_activation(act, inplace=True)

    def forward(self, x):
        """Convolution, batch norm, activation."""
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)

    def fuseforward(self, x):
        """Convolution followed by the activation, skipping batch norm."""
        y = self.conv(x)
        return self.act(y)


class DWConv(nn.Module):
    """Depthwise Conv + Conv"""

    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
        super().__init__()
        # Depthwise: one group per input channel.
        self.dconv = BaseConv(
            in_channels,
            in_channels,
            ksize=ksize,
            stride=stride,
            groups=in_channels,
            act=act,
        )
        # Pointwise 1x1 convolution mixes the channels.
        self.pconv = BaseConv(
            in_channels, out_channels, ksize=1, stride=1, groups=1, act=act
        )

    def forward(self, x):
        return self.pconv(self.dconv(x))


class Bottleneck(nn.Module):
    """Standard bottleneck: 1x1 conv, 3x3 conv, optional residual add."""

    def __init__(
        self,
        in_channels,
        out_channels,
        shortcut=True,
        expansion=0.5,
        depthwise=False,
        act="silu",
    ):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        conv3x3_cls = DWConv if depthwise else BaseConv
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = conv3x3_cls(hidden_channels, out_channels, 3, stride=1, act=act)
        # The residual connection needs matching input/output channels.
        self.use_add = shortcut and in_channels == out_channels

    def forward(self, x):
        out = self.conv2(self.conv1(x))
        if not self.use_add:
            return out
        return out + x


class CSPLayer(nn.Module):
    """C3 in yolov5, CSP Bottleneck with 3 convolutions"""

    def __init__(
        self,
        in_channels,
        out_channels,
        n=1,
        shortcut=True,
        expansion=0.5,
        depthwise=False,
        act="silu",
    ):
        """
        Args:
            in_channels (int): input channels.
            out_channels (int): output channels.
            n (int): number of Bottlenecks. Default value: 1.
            shortcut (bool): passed to every Bottleneck.
            expansion (float): hidden channels = out_channels * expansion.
            depthwise (bool): passed to every Bottleneck.
            act (str): activation name used by all convolutions.
        """
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
        self.m = nn.Sequential(
            *(
                Bottleneck(
                    hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act
                )
                for _ in range(n)
            )
        )

    def forward(self, x):
        # Branch 1 runs through the bottleneck stack, branch 2 is a plain
        # 1x1 convolution; both are concatenated along the channel axis.
        deep_branch = self.m(self.conv1(x))
        skip_branch = self.conv2(x)
        merged = torch.cat((deep_branch, skip_branch), dim=1)
        return self.conv3(merged)
def initialize_biases(self, prior_prob):
    """Set the bias of every class and objectness prediction conv.

    Each bias is filled with ``-log((1 - prior_prob) / prior_prob)`` and
    re-registered as a trainable parameter.

    :param prior_prob: Prior probability used to derive the bias value.
    """
    for prediction_convs in (self.cls_preds, self.obj_preds):
        for conv in prediction_convs:
            flat_bias = conv.bias.view(1, -1)
            flat_bias.data.fill_(-math.log((1 - prior_prob) / prior_prob))
            conv.bias = torch.nn.Parameter(flat_bias.view(-1), requires_grad=True)
def decode_outputs(self, outputs):
    """Convert raw head outputs to box coordinates in input-image scale.

    The first two channels are offset by the cell position and multiplied by
    the stride of their level. Channels 2:4 are exponentiated and multiplied
    by the stride. All remaining channels are passed through unchanged.

    The per-anchor grid and stride tensors are cached on the module
    (``output_grids`` / ``output_strides``). The cache is rebuilt whenever the
    feature-map sizes in ``self.hw``, the strides, or the dtype/device of
    ``outputs`` differ from what the cache was built for, so that a change of
    input resolution or device does not reuse stale tensors.

    :param outputs: Tensor of shape [batch, n_anchors_all, channels] whose
        anchor axis is ordered level by level as described by ``self.hw``.
    :return: Tensor of the same shape with decoded box channels.
    """
    dtype = outputs.dtype
    device = outputs.device
    cache_key = (
        tuple((int(hsize), int(wsize)) for hsize, wsize in self.hw),
        tuple(self.strides),
        dtype,
        device,
    )
    cache_is_stale = (
        self.output_grids is None
        or self.output_strides is None
        # The key attribute is created lazily by this method.
        or getattr(self, "_output_grid_key", None) != cache_key
    )
    if cache_is_stale:
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.strides):
            yv, xv = torch.meshgrid(
                torch.arange(hsize, device=device, dtype=dtype),
                torch.arange(wsize, device=device, dtype=dtype),
                indexing="ij",
            )
            # (x, y) cell index per anchor, flattened row by row.
            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            strides.append(torch.full((*shape, 1), stride, device=device, dtype=dtype))
        self.output_grids = torch.cat(grids, dim=1)
        self.output_strides = torch.cat(strides, dim=1)
        self._output_grid_key = cache_key

    outputs = torch.cat(
        [
            (outputs[..., 0:2] + self.output_grids) * self.output_strides,
            torch.exp(outputs[..., 2:4]) * self.output_strides,
            outputs[..., 4:],
        ],
        dim=-1,
    )
    return outputs
origin_preds, dtype, ): bbox_preds = outputs[:, :, :4] # [batch, n_anchors_all, 4] obj_preds = outputs[:, :, 4:5] # [batch, n_anchors_all, 1] cls_preds = outputs[:, :, 5:] # [batch, n_anchors_all, n_cls] # calculate targets nlabel = (labels.sum(dim=2) > 0).sum(dim=1) # number of objects total_num_anchors = outputs.shape[1] x_shifts = torch.cat(x_shifts, 1) # [1, n_anchors_all] y_shifts = torch.cat(y_shifts, 1) # [1, n_anchors_all] expanded_strides = torch.cat(expanded_strides, 1) if self.use_l1: origin_preds = torch.cat(origin_preds, 1) cls_targets = [] reg_targets = [] l1_targets = [] obj_targets = [] fg_masks = [] num_fg = 0.0 num_gts = 0.0 for batch_idx in range(outputs.shape[0]): num_gt = int(nlabel[batch_idx]) num_gts += num_gt if num_gt == 0: cls_target = outputs.new_zeros((0, self.num_classes)) reg_target = outputs.new_zeros((0, 4)) l1_target = outputs.new_zeros((0, 4)) obj_target = outputs.new_zeros((total_num_anchors, 1)) fg_mask = outputs.new_zeros(total_num_anchors).bool() else: gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5] gt_classes = labels[batch_idx, :num_gt, 0] bboxes_preds_per_image = bbox_preds[batch_idx] try: ( gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img, ) = self.get_assignments( # noqa batch_idx, num_gt, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts, cls_preds, obj_preds, ) except RuntimeError as e: # TODO: the string might change, consider a better way if "CUDA out of memory. 
" not in str(e): raise torch.cuda.empty_cache() ( gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img, ) = self.get_assignments( # noqa batch_idx, num_gt, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts, cls_preds, obj_preds, "cpu", ) torch.cuda.empty_cache() num_fg += num_fg_img cls_target = F.one_hot( gt_matched_classes.to(torch.int64), self.num_classes ) * pred_ious_this_matching.unsqueeze(-1) obj_target = fg_mask.unsqueeze(-1) reg_target = gt_bboxes_per_image[matched_gt_inds] if self.use_l1: l1_target = self.get_l1_target( outputs.new_zeros((num_fg_img, 4)), gt_bboxes_per_image[matched_gt_inds], expanded_strides[0][fg_mask], x_shifts=x_shifts[0][fg_mask], y_shifts=y_shifts[0][fg_mask], ) cls_targets.append(cls_target) reg_targets.append(reg_target) obj_targets.append(obj_target.to(dtype)) fg_masks.append(fg_mask) if self.use_l1: l1_targets.append(l1_target) cls_targets = torch.cat(cls_targets, 0) reg_targets = torch.cat(reg_targets, 0) obj_targets = torch.cat(obj_targets, 0) fg_masks = torch.cat(fg_masks, 0) if self.use_l1: l1_targets = torch.cat(l1_targets, 0) num_fg = max(num_fg, 1) loss_iou = ( self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets) ).sum() / num_fg loss_obj = ( self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets) ).sum() / num_fg loss_cls = ( self.bcewithlog_loss( cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets ) ).sum() / num_fg if self.use_l1: loss_l1 = ( self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets) ).sum() / num_fg else: loss_l1 = 0.0 reg_weight = 5.0 loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1 return ( loss, reg_weight * loss_iou, loss_obj, loss_cls, loss_l1, num_fg / max(num_gts, 1), ) def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8): l1_target[:, 0] = gt[:, 0] / stride - x_shifts l1_target[:, 1] = gt[:, 1] / stride - y_shifts l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps) 
l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps) return l1_target @torch.no_grad() def get_assignments( self, batch_idx, num_gt, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts, cls_preds, obj_preds, mode="gpu", ): if mode == "cpu": print("-----------Using CPU for the Current Batch-------------") gt_bboxes_per_image = gt_bboxes_per_image.cpu().float() bboxes_preds_per_image = bboxes_preds_per_image.cpu().float() gt_classes = gt_classes.cpu().float() expanded_strides = expanded_strides.cpu().float() x_shifts = x_shifts.cpu() y_shifts = y_shifts.cpu() fg_mask, geometry_relation = self.get_geometry_constraint( gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, ) bboxes_preds_per_image = bboxes_preds_per_image[fg_mask] cls_preds_ = cls_preds[batch_idx][fg_mask] obj_preds_ = obj_preds[batch_idx][fg_mask] num_in_boxes_anchor = bboxes_preds_per_image.shape[0] if mode == "cpu": gt_bboxes_per_image = gt_bboxes_per_image.cpu() bboxes_preds_per_image = bboxes_preds_per_image.cpu() pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False) gt_cls_per_image = F.one_hot( gt_classes.to(torch.int64), self.num_classes ).float() pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) if mode == "cpu": cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu() with torch.cuda.amp.autocast(enabled=False): cls_preds_ = ( cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_() ).sqrt() pair_wise_cls_loss = F.binary_cross_entropy( cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1), gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1), reduction="none", ).sum(-1) del cls_preds_ cost = ( pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + float(1e6) * (~geometry_relation) ) ( num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds, ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask) del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss if mode == "cpu": 
def get_geometry_constraint(
    self,
    gt_bboxes_per_image,
    expanded_strides,
    x_shifts,
    y_shifts,
):
    """
    Select the anchors whose centre lies inside a stride-scaled square
    around at least one ground-truth centre.

    Columns 0 and 1 of ``gt_bboxes_per_image`` are used as the ground-truth
    centre coordinates. Only the first row of ``expanded_strides``,
    ``x_shifts`` and ``y_shifts`` is read. The half side length of the square
    is 1.5 times the stride of the anchor being tested.

    Returns:
        anchor_filter: boolean vector, one entry per anchor, true when the
            anchor is inside the square of at least one ground truth.
        geometry_relation: boolean matrix (ground truths x kept anchors)
            telling which ground truth each kept anchor is close to.
    """
    strides = expanded_strides[0]

    # Anchor centres in input-image units, with a leading axis for broadcasting
    # against the per-ground-truth column vectors below.
    anchor_cx = ((x_shifts[0] + 0.5) * strides).unsqueeze(0)
    anchor_cy = ((y_shifts[0] + 0.5) * strides).unsqueeze(0)

    # Half side length of the admissible square, one value per anchor.
    radius = strides.unsqueeze(0) * 1.5

    gt_cx = gt_bboxes_per_image[:, 0:1]
    gt_cy = gt_bboxes_per_image[:, 1:2]

    # Signed distances to the left, top, right and bottom edge of the square;
    # an anchor is inside only if all four are strictly positive.
    margins = torch.stack(
        [
            anchor_cx - (gt_cx - radius),
            anchor_cy - (gt_cy - radius),
            (gt_cx + radius) - anchor_cx,
            (gt_cy + radius) - anchor_cy,
        ],
        2,
    )
    inside = margins.min(dim=-1).values > 0.0

    anchor_filter = inside.any(dim=0)
    return anchor_filter, inside[:, anchor_filter]
torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx], largest=False) matching_matrix[gt_idx][pos_idx] = 1 del topk_ious, dynamic_ks, pos_idx anchor_matching_gt = matching_matrix.sum(0) # deal with the case that one anchor matches multiple ground-truths if anchor_matching_gt.max() > 1: multiple_match_mask = anchor_matching_gt > 1 _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0) matching_matrix[:, multiple_match_mask] *= 0 matching_matrix[cost_argmin, multiple_match_mask] = 1 fg_mask_inboxes = anchor_matching_gt > 0 num_fg = fg_mask_inboxes.sum().item() fg_mask[fg_mask.clone()] = fg_mask_inboxes matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0) gt_matched_classes = gt_classes[matched_gt_inds] pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[ fg_mask_inboxes ] return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds ================================================ FILE: RVT/models/detection/yolox/utils/__init__.py ================================================ #!/usr/bin/env python3 # -*- coding:utf-8 -*- # Copyright (c) Megvii Inc. All rights reserved. from .boxes import * from .compat import meshgrid ================================================ FILE: RVT/models/detection/yolox/utils/boxes.py ================================================ #!/usr/bin/env python3 # -*- coding:utf-8 -*- # Copyright (c) Megvii Inc. All rights reserved. 
def postprocess(
    prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False
):
    """
    Filter raw detections by confidence and apply non-maximum suppression.

    Args:
        prediction: tensor indexed as ``[image, candidate, column]``. Columns
            0-3 are read as a centre/size box (cx, cy, w, h), column 4 as the
            objectness score and columns ``5 : 5 + num_classes`` as class
            scores. NOTE: the box columns are converted to corner format
            (x1, y1, x2, y2) *in place*, so the caller's tensor is modified.
        num_classes: number of class-score columns to consider.
        conf_thre: minimum ``objectness * best class score`` to keep a box.
        nms_thre: IoU threshold passed to torchvision NMS.
        class_agnostic: if true, run one NMS over all boxes; otherwise NMS is
            performed per predicted class.

    Returns:
        A list with one entry per image: ``None`` when nothing survives,
        otherwise a tensor whose columns are
        (x1, y1, x2, y2, obj_conf, class_conf, class_pred).
    """
    # Centre/size -> corner conversion. The operands are materialised before
    # the in-place writes so that no column is read after it was overwritten.
    half_wh = prediction[:, :, 2:4] / 2
    centers = prediction[:, :, :2].clone()
    prediction[:, :, :2] = centers - half_wh
    prediction[:, :, 2:4] = centers + half_wh

    output = [None] * len(prediction)
    for i, image_pred in enumerate(prediction):
        # If none are remaining => process next image
        if not image_pred.size(0):
            continue

        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(
            image_pred[:, 5 : 5 + num_classes], 1, keepdim=True
        )

        # Squeeze only the class axis. A bare ``squeeze()`` would also drop the
        # candidate axis when exactly one candidate is present, turning the
        # mask into a 0-dim tensor and corrupting the indexing below.
        conf_mask = image_pred[:, 4] * class_conf.squeeze(1) >= conf_thre

        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.size(0):
            continue

        scores = detections[:, 4] * detections[:, 5]
        if class_agnostic:
            nms_out_index = torchvision.ops.nms(detections[:, :4], scores, nms_thre)
        else:
            nms_out_index = torchvision.ops.batched_nms(
                detections[:, :4], scores, detections[:, 6], nms_thre
            )

        output[i] = detections[nms_out_index]

    return output
def matrix_iou(a, b):
    """
    Pairwise IoU between two sets of corner-format boxes (NumPy version,
    intended for data augmentation).

    ``a`` and ``b`` are 2-D arrays whose rows are (x1, y1, x2, y2). The result
    has one row per box in ``a`` and one column per box in ``b``. A small
    constant is added to the union so that the division never hits zero.
    """
    boxes_a = a[:, np.newaxis, :]

    # Corners of the intersection rectangle for every (a, b) pair.
    top_left = np.maximum(boxes_a[..., :2], b[:, :2])
    bottom_right = np.minimum(boxes_a[..., 2:], b[:, 2:])

    # Pairs that do not overlap on both axes contribute zero intersection.
    overlaps = (top_left < bottom_right).all(axis=2)
    inter = np.prod(bottom_right - top_left, axis=2) * overlaps

    size_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    size_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
    union = size_a[:, np.newaxis] + size_b - inter + 1e-12
    return inter / union
def build_yolox_fpn(fpn_cfg: DictConfig, in_channels: Tuple[int, ...]):
    """
    Build the feature pyramid network described by ``fpn_cfg``.

    The config is converted to a plain container (interpolations resolved,
    missing values rejected). Its ``name`` entry selects the implementation;
    only ``"PAFPN"`` / ``"pafpn"`` is available. The optional ``compile``
    entry is forwarded to the network as ``compile_cfg`` and ``in_channels``
    is injected; every remaining entry is passed through as a keyword
    argument.

    Raises:
        NotImplementedError: if ``name`` is not a supported FPN.
    """
    fpn_cfg_dict = OmegaConf.to_container(fpn_cfg, resolve=True, throw_on_missing=True)
    fpn_name = fpn_cfg_dict.pop("name")

    # Reject unknown names up front and say which one was requested, so that a
    # typo in the config is immediately visible in the traceback.
    if fpn_name not in {"PAFPN", "pafpn"}:
        raise NotImplementedError(
            f"Unsupported FPN name {fpn_name!r}; expected 'PAFPN' or 'pafpn'"
        )

    fpn_cfg_dict["in_channels"] = in_channels
    fpn_cfg_dict["compile_cfg"] = fpn_cfg_dict.pop("compile", None)
    return YOLOPAFPN(**fpn_cfg_dict)
TimerDummy as CudaTimer from data.utils.types import BackboneFeatures, LstmStates class YoloXDetector(th.nn.Module): def __init__(self, model_cfg: DictConfig): super().__init__() backbone_cfg = model_cfg.backbone fpn_cfg = model_cfg.fpn head_cfg = model_cfg.head self.backbone = build_recurrent_backbone(backbone_cfg) in_channels = self.backbone.get_stage_dims(fpn_cfg.in_stages) self.fpn = build_yolox_fpn(fpn_cfg, in_channels=in_channels) strides = self.backbone.get_strides(fpn_cfg.in_stages) self.yolox_head = build_yolox_head( head_cfg, in_channels=in_channels, strides=strides ) def forward_backbone( self, x: th.Tensor, previous_states: Optional[LstmStates] = None, token_mask: Optional[th.Tensor] = None, train_step: bool = True, ) -> Tuple[BackboneFeatures, LstmStates]: with CudaTimer(device=x.device, timer_name="Backbone"): backbone_features, states = self.backbone( x, previous_states, token_mask, train_step ) return backbone_features, states def forward_detect( self, backbone_features: BackboneFeatures, targets: Optional[th.Tensor] = None ) -> Tuple[th.Tensor, Union[Dict[str, th.Tensor], None]]: device = next(iter(backbone_features.values())).device with CudaTimer(device=device, timer_name="FPN"): fpn_features = self.fpn(backbone_features) if self.training: assert targets is not None with CudaTimer(device=device, timer_name="HEAD + Loss"): outputs, losses = self.yolox_head(fpn_features, targets) return outputs, losses with CudaTimer(device=device, timer_name="HEAD"): outputs, losses = self.yolox_head(fpn_features) assert losses is None return outputs, losses def forward( self, x: th.Tensor, previous_states: Optional[LstmStates] = None, retrieve_detections: bool = True, targets: Optional[th.Tensor] = None, ) -> Tuple[Union[th.Tensor, None], Union[Dict[str, th.Tensor], None], LstmStates]: backbone_features, states = self.forward_backbone(x, previous_states) outputs, losses = None, None if not retrieve_detections: assert targets is None return outputs, losses, states 
class YOLOPAFPN(nn.Module):
    """
    YOLOX PAFPN neck with the direct dependency on the backbone removed.

    The module picks three feature maps out of the backbone output (by the
    stage keys given in ``in_stages``), fuses them in a top-down pass followed
    by a bottom-up pass, and returns three fused maps.
    """

    def __init__(
        self,
        depth: float = 1.0,
        in_stages: Tuple[int, ...] = (2, 3, 4),
        in_channels: Tuple[int, ...] = (256, 512, 1024),
        depthwise: bool = False,
        act: str = "silu",
        compile_cfg: Optional[Dict] = None,
    ):
        """
        Args:
            depth: multiplier for the number of blocks in each CSP layer
                (``round(3 * depth)`` blocks are requested).
            in_stages: keys used to look up the three input feature maps.
            in_channels: channel count of each of those feature maps.
            depthwise: use ``DWConv`` instead of ``BaseConv`` for the
                down-sampling convolutions and inside the CSP layers.
            act: activation name forwarded to the convolution blocks.
            compile_cfg: optional dict with keys ``"enable"`` and ``"args"``;
                when enabled and ``torch.compile`` is available, ``forward``
                is wrapped with it.
        """
        super().__init__()
        assert len(in_stages) == len(in_channels)
        assert len(in_channels) == 3, "Current implementation only for 3 feature maps"

        self.in_features = in_stages
        self.in_channels = in_channels
        Conv = DWConv if depthwise else BaseConv

        self.upsample = lambda x: nn.functional.interpolate(
            x, scale_factor=2, mode="nearest-exact"
        )

        # top-down path
        self.lateral_conv0 = BaseConv(in_channels[2], in_channels[1], 1, 1, act=act)
        self.C3_p4 = CSPLayer(
            2 * in_channels[1],
            in_channels[1],
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )  # cat

        self.reduce_conv1 = BaseConv(in_channels[1], in_channels[0], 1, 1, act=act)
        self.C3_p3 = CSPLayer(
            2 * in_channels[0],
            in_channels[0],
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

        # bottom-up conv
        self.bu_conv2 = Conv(in_channels[0], in_channels[0], 3, 2, act=act)
        self.C3_n3 = CSPLayer(
            2 * in_channels[0],
            in_channels[1],
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

        # bottom-up conv
        self.bu_conv1 = Conv(in_channels[1], in_channels[1], 3, 2, act=act)
        self.C3_n4 = CSPLayer(
            2 * in_channels[1],
            in_channels[2],
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

        ###### Compile if requested ######
        # Done exactly once, after all submodules exist. Wrapping ``forward``
        # a second time would nest one compiled callable inside another.
        if compile_cfg is not None:
            compile_mdl = compile_cfg["enable"]
            if compile_mdl and th_compile is not None:
                self.forward = th_compile(self.forward, **compile_cfg["args"])
            elif compile_mdl:
                print("Could not compile PAFPN because torch.compile is not available")
        ##################################

    def forward(self, input: BackboneFeatures):
        """
        Args:
            input: Feature maps from the backbone, indexable by the stage keys
                stored in ``self.in_features``.
        Returns:
            Tuple[Tensor]: the three fused FPN feature maps, ordered from the
            first to the last entry of ``in_stages``.

        The trailing ``channels/stride`` comments below describe the default
        configuration (256/512/1024 channels at strides 8/16/32).
        """
        features = [input[f] for f in self.in_features]
        x2, x1, x0 = features

        fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32
        f_out0 = self.upsample(fpn_out0)  # 512/16
        f_out0 = th.cat([f_out0, x1], 1)  # 512->1024/16
        f_out0 = self.C3_p4(f_out0)  # 1024->512/16

        fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16
        f_out1 = self.upsample(fpn_out1)  # 256/8
        f_out1 = th.cat([f_out1, x2], 1)  # 256->512/8
        pan_out2 = self.C3_p3(f_out1)  # 512->256/8

        p_out1 = self.bu_conv2(pan_out2)  # 256->256/16
        p_out1 = th.cat([p_out1, fpn_out1], 1)  # 256->512/16
        pan_out1 = self.C3_n3(p_out1)  # 512->512/16

        p_out0 = self.bu_conv1(pan_out1)  # 512->512/32
        p_out0 = th.cat([p_out0, fpn_out0], 1)  # 512->1024/32
        pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32

        outputs = (pan_out2, pan_out1, pan_out0)
        return outputs
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d, ) from .blur_pool import BlurPool2d from .classifier import ClassifierHead, create_classifier from .cond_conv2d import CondConv2d, get_condconv_initializer from .config import ( is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit, set_layer_config, ) from .conv2d_same import Conv2dSame, conv2d_same from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct from .create_act import create_act_layer, get_act_layer, get_act_fn from .create_attn import get_attn, create_attn from .create_conv2d import create_conv2d from .create_norm import get_norm_layer, create_norm_layer from .create_norm_act import ( get_norm_act_layer, create_norm_act_layer, get_norm_act_layer, ) from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .eca import ( EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn, ) from .evo_norm import ( EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2, EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, ) from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d from .gather_excite import GatherExcite from .global_context import GlobalContext from .helpers import ( to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple, ) from .inplace_abn import InplaceAbn from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed from .pool2d_same import AvgPool2dSame, create_pool2d from .squeeze_excite import ( SEModule, 
def swish(x, inplace: bool = False):
    """Swish activation, ``x * sigmoid(x)`` (https://arxiv.org/abs/1710.05941).

    With ``inplace=True`` the product is written back into ``x`` and ``x``
    itself is returned; otherwise a new tensor is returned.
    """
    gate = x.sigmoid()
    if inplace:
        return x.mul_(gate)
    return x.mul(gate)
def hard_swish(x, inplace: bool = False):
    """Hard swish, ``x * relu6(x + 3) / 6``.

    With ``inplace=True`` the result overwrites ``x`` and ``x`` is returned;
    otherwise ``x`` is left untouched and a new tensor is returned.
    """
    # ``x + 3.0`` already allocates a fresh tensor, so scaling the gate in
    # place never touches the caller's data.
    gate = F.relu6(x + 3.0)
    gate.div_(6.0)
    if inplace:
        return x.mul_(gate)
    return x.mul(gate)
@torch.jit.script
def hard_sigmoid_jit(x, inplace: bool = False):
    """Hard sigmoid, ``clamp(x + 3, 0, 6) / 6``; ``inplace`` is accepted but not used."""
    # Equivalent to F.relu6(x + 3.) / 6.; the original note says clamp seemed
    # ever so slightly faster.
    shifted = x + 3
    bounded = shifted.clamp(min=0, max=6)
    return bounded.div(6.0)
@torch.jit.script
def hard_mish_jit(x, inplace: bool = False):
    """Hard Mish, ``0.5 * x * clamp(x + 2, 0, 2)``; ``inplace`` is accepted but not used.

    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """
    gate = (x + 2).clamp(min=0, max=2)
    return 0.5 * x * gate
@torch.jit.script
def swish_jit_bwd(x, grad_output):
    """Backward of swish: scale ``grad_output`` by ``s * (1 + x * (1 - s))`` with ``s = sigmoid(x)``."""
    sig = torch.sigmoid(x)
    local_grad = sig * (1 + x * (1 - sig))
    return grad_output * local_grad
@torch.jit.script
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
    """Scripted hard-sigmoid forward: ``clamp(x + 3, 0, 6) / 6``; ``inplace`` is unused."""
    return torch.clamp(x + 3, min=0, max=6) / 6.0


@torch.jit.script
def hard_sigmoid_jit_bwd(x, grad_output):
    """Scripted hard-sigmoid gradient: ``1/6`` for ``-3 <= x <= 3``, else ``0``."""
    in_linear_region = (x >= -3.0) & (x <= 3.0)
    slope = torch.ones_like(x) * in_linear_region / 6.0
    return grad_output * slope


class HardSigmoidJitAutoFn(torch.autograd.Function):
    """Hard sigmoid that saves only its input and recomputes the slope in backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_sigmoid_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_x,) = ctx.saved_tensors
        return hard_sigmoid_jit_bwd(saved_x, grad_output)


def hard_sigmoid_me(x, inplace: bool = False):
    """Functional memory-efficient hard sigmoid; ``inplace`` is accepted but unused."""
    return HardSigmoidJitAutoFn.apply(x)


class HardSigmoidMe(nn.Module):
    """Module wrapper around :class:`HardSigmoidJitAutoFn`; ``inplace`` is ignored."""

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return HardSigmoidJitAutoFn.apply(x)


@torch.jit.script
def hard_swish_jit_fwd(x):
    """Scripted hard-swish forward: ``x * clamp(x + 3, 0, 6) / 6``."""
    return x * (x + 3).clamp(min=0, max=6).div(6.0)


@torch.jit.script
def hard_swish_jit_bwd(x, grad_output):
    """Scripted hard-swish gradient: 0 below -3, ``x / 3 + 0.5`` on [-3, 3], 1 above."""
    slope = torch.ones_like(x) * (x >= 3.0)
    slope = torch.where((x >= -3.0) & (x <= 3.0), x / 3.0 + 0.5, slope)
    return grad_output * slope


class HardSwishJitAutoFn(torch.autograd.Function):
    """A memory efficient, jit-scripted HardSwish activation"""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_x,) = ctx.saved_tensors
        return hard_swish_jit_bwd(saved_x, grad_output)

    @staticmethod
    def symbolic(g, self):
        # ONNX export: Mul(x, Div(Clip(Add(x, 3), 0, 6), 6)).
        def const(value):
            return g.op("Constant", value_t=torch.tensor(value, dtype=torch.float))

        shifted = g.op("Add", self, const(3))
        clipped = g.op("Clip", shifted, const(0), const(6))
        scaled = g.op("Div", clipped, const(6))
        return g.op("Mul", self, scaled)


def hard_swish_me(x, inplace=False):
    """Functional memory-efficient hard swish; ``inplace`` is accepted but unused."""
    return HardSwishJitAutoFn.apply(x)


class HardSwishMe(nn.Module):
    """Module wrapper around :class:`HardSwishJitAutoFn`; ``inplace`` is ignored."""

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return HardSwishJitAutoFn.apply(x)
@torch.jit.script
def hard_mish_jit_fwd(x):
    """Scripted hard-mish forward: ``0.5 * x * clamp(x + 2, 0, 2)``."""
    gate = (x + 2).clamp(min=0, max=2)
    return 0.5 * x * gate


@torch.jit.script
def hard_mish_jit_bwd(x, grad_output):
    """Scripted hard-mish gradient: 0 below -2, ``x + 1`` on [-2, 0], 1 above 0."""
    slope = torch.ones_like(x) * (x >= -2.0)
    slope = torch.where((x >= -2.0) & (x <= 0.0), x + 1.0, slope)
    return grad_output * slope


class HardMishJitAutoFn(torch.autograd.Function):
    """A memory efficient, jit scripted variant of Hard Mish
    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_x,) = ctx.saved_tensors
        return hard_mish_jit_bwd(saved_x, grad_output)


def hard_mish_me(x, inplace: bool = False):
    """Functional memory-efficient hard mish; ``inplace`` is accepted but unused."""
    return HardMishJitAutoFn.apply(x)


class HardMishMe(nn.Module):
    """Module wrapper around :class:`HardMishJitAutoFn`; ``inplace`` is ignored."""

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return HardMishJitAutoFn.apply(x)
def adaptive_pool_feat_mult(pool_type="avg"):
    """Return the channel multiplier a pooling mode applies to its input.

    Only ``"catavgmax"`` concatenates two pooled maps along dim 1 and so
    doubles the feature count; every other value yields 1.
    """
    return 2 if pool_type == "catavgmax" else 1


def adaptive_avgmax_pool2d(x, output_size=1):
    """Average of adaptive average pooling and adaptive max pooling."""
    x_avg = F.adaptive_avg_pool2d(x, output_size)
    x_max = F.adaptive_max_pool2d(x, output_size)
    return 0.5 * (x_avg + x_max)


def adaptive_catavgmax_pool2d(x, output_size=1):
    """Concatenate adaptive average and max pooling results along dim 1."""
    x_avg = F.adaptive_avg_pool2d(x, output_size)
    x_max = F.adaptive_max_pool2d(x, output_size)
    return torch.cat((x_avg, x_max), 1)


def select_adaptive_pool2d(x, pool_type="avg", output_size=1):
    """Selectable global pooling function with dynamic input kernel size

    Raises:
        AssertionError: if ``pool_type`` is not one of ``avg``, ``avgmax``,
            ``catavgmax`` or ``max``.
    """
    if pool_type == "avg":
        return F.adaptive_avg_pool2d(x, output_size)
    if pool_type == "avgmax":
        return adaptive_avgmax_pool2d(x, output_size)
    if pool_type == "catavgmax":
        return adaptive_catavgmax_pool2d(x, output_size)
    if pool_type == "max":
        return F.adaptive_max_pool2d(x, output_size)
    # Raised explicitly (not `assert False`) so the check survives `python -O`;
    # the exception type is unchanged for existing callers.
    raise AssertionError("Invalid pool type: %s" % pool_type)


class FastAdaptiveAvgPool2d(nn.Module):
    """Global average pool implemented as a mean over dims (2, 3).

    With ``flatten=True`` the reduced dims are dropped; otherwise they are
    kept as size-1 dims.
    """

    def __init__(self, flatten=False):
        super().__init__()
        self.flatten = flatten

    def forward(self, x):
        return x.mean((2, 3), keepdim=not self.flatten)


class AdaptiveAvgMaxPool2d(nn.Module):
    """Module form of :func:`adaptive_avgmax_pool2d`."""

    def __init__(self, output_size=1):
        super().__init__()
        self.output_size = output_size

    def forward(self, x):
        return adaptive_avgmax_pool2d(x, self.output_size)


class AdaptiveCatAvgMaxPool2d(nn.Module):
    """Module form of :func:`adaptive_catavgmax_pool2d`."""

    def __init__(self, output_size=1):
        super().__init__()
        self.output_size = output_size

    def forward(self, x):
        return adaptive_catavgmax_pool2d(x, self.output_size)


class SelectAdaptivePool2d(nn.Module):
    """Selectable global pooling layer with dynamic input kernel size

    Args:
        output_size: target spatial size handed to the adaptive pooling op
            (must be 1 for ``pool_type="fast"``).
        pool_type: one of ``"fast"``, ``"avg"``, ``"avgmax"``, ``"catavgmax"``,
            ``"max"``; any falsy value (``""``, ``None``, ``False``) disables
            pooling and makes the layer a pass-through.
        flatten: flatten the pooled output from dim 1 onwards.

    Raises:
        AssertionError: on an unknown ``pool_type``.
    """

    def __init__(self, output_size=1, pool_type="fast", flatten=False):
        super().__init__()
        # Convert other falsy values to empty string for consistent TS typing,
        # and branch on the *normalised* value so that None/False really select
        # the pass-through instead of falling into the "invalid" branch.
        pool_type = pool_type or ""
        self.pool_type = pool_type
        self.flatten = nn.Flatten(1) if flatten else nn.Identity()
        if not pool_type:
            self.pool = nn.Identity()  # pass through
        elif pool_type == "fast":
            assert output_size == 1
            # The fast pool flattens by itself, so no separate flatten layer.
            self.pool = FastAdaptiveAvgPool2d(flatten)
            self.flatten = nn.Identity()
        elif pool_type == "avg":
            self.pool = nn.AdaptiveAvgPool2d(output_size)
        elif pool_type == "avgmax":
            self.pool = AdaptiveAvgMaxPool2d(output_size)
        elif pool_type == "catavgmax":
            self.pool = AdaptiveCatAvgMaxPool2d(output_size)
        elif pool_type == "max":
            self.pool = nn.AdaptiveMaxPool2d(output_size)
        else:
            # Explicit raise keeps the check alive under `python -O`.
            raise AssertionError("Invalid pool type: %s" % pool_type)

    def is_identity(self):
        """Return True when pooling is disabled (empty ``pool_type``)."""
        return not self.pool_type

    def forward(self, x):
        x = self.pool(x)
        x = self.flatten(x)
        return x

    def feat_mult(self):
        """Channel multiplier of the configured pooling mode."""
        return adaptive_pool_feat_mult(self.pool_type)

    def __repr__(self):
        return (
            self.__class__.__name__
            + " ("
            + "pool_type="
            + self.pool_type
            + ", flatten="
            + str(self.flatten)
            + ")"
        )
class RotAttentionPool2d(nn.Module):
    """Attention based 2D feature pooling w/ rotary (relative) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
    train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW

    Args:
        in_features: channel count of the input feature map.
        out_features: output width of the final projection (defaults to ``in_features``).
        embed_dim: total attention width across heads (defaults to ``in_features``);
            must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias term.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int = None,
        embed_dim: int = None,
        num_heads: int = 4,
        qkv_bias: bool = True,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5
        self.pos_embed = RotaryEmbedding(self.head_dim)

        trunc_normal_(self.qkv.weight, std=in_features**-0.5)
        # With qkv_bias=False the Linear has no bias tensor (bias is None), and
        # zero-initialising None would raise; only touch the bias when it exists.
        if self.qkv.bias is not None:
            nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        """Pool a ``(B, C, H, W)`` map into one vector per sample.

        Returns the projected output of the extra mean token, i.e. ``x[:, 0]``
        after attention and projection.
        """
        B, _, H, W = x.shape
        N = H * W
        # (B, C, H, W) -> (B, N, C), then prepend the spatial mean as token 0.
        x = x.reshape(B, -1, N).permute(0, 2, 1)
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)

        # Fused projection -> (3, B, num_heads, N + 1, head_dim).
        x = (
            self.qkv(x)
            .reshape(B, N + 1, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = x[0], x[1], x[2]

        # Rotary embedding is applied to the N spatial tokens only; the mean
        # token (index 0) is split off and re-attached unchanged.
        qc, q = q[:, :, :1], q[:, :, 1:]
        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
        q = apply_rot_embed(q, sin_emb, cos_emb)
        q = torch.cat([qc, q], dim=2)

        kc, k = k[:, :, :1], k[:, :, 1:]
        k = apply_rot_embed(k, sin_emb, cos_emb)
        k = torch.cat([kc, k], dim=2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        return x[:, 0]
class AttentionPool2d(nn.Module):
    """Attention based 2D feature pooling w/ learned (absolute) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    It was based on impl in CLIP by OpenAI
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.

    Args:
        in_features: channel count of the input feature map.
        feat_size: expected spatial size of the input, an int or ``(H, W)`` pair;
            ``forward`` asserts the input matches it.
        out_features: output width of the final projection (defaults to ``in_features``).
        embed_dim: total attention width across heads (defaults to ``in_features``);
            must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias term.
    """

    def __init__(
        self,
        in_features: int,
        feat_size: Union[int, Tuple[int, int]],
        out_features: int = None,
        embed_dim: int = None,
        num_heads: int = 4,
        qkv_bias: bool = True,
    ):
        super().__init__()

        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.feat_size = to_2tuple(feat_size)
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5

        # One learned embedding per spatial position plus one for the mean token.
        spatial_dim = self.feat_size[0] * self.feat_size[1]
        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
        trunc_normal_(self.pos_embed, std=in_features**-0.5)
        trunc_normal_(self.qkv.weight, std=in_features**-0.5)
        # With qkv_bias=False the Linear has no bias tensor (bias is None), and
        # zero-initialising None would raise; only touch the bias when it exists.
        if self.qkv.bias is not None:
            nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        """Pool a ``(B, C, H, W)`` map into one vector per sample.

        Returns the projected output of the extra mean token, i.e. ``x[:, 0]``
        after attention and projection.
        """
        B, _, H, W = x.shape
        N = H * W
        assert self.feat_size[0] == H
        assert self.feat_size[1] == W
        # (B, C, H, W) -> (B, N, C), then prepend the spatial mean as token 0.
        x = x.reshape(B, -1, N).permute(0, 2, 1)
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        # Fused projection -> (3, B, num_heads, N + 1, head_dim).
        x = (
            self.qkv(x)
            .reshape(B, N + 1, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = x[0], x[1], x[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        return x[:, 0]
class BlurPool2d(nn.Module):
    r"""Blur a feature map with a fixed binomial filter and downsample it.

    See :cite:`zhang2019shiftinvar` for more details.
    Corresponds to the Downsample class, which does blurring and subsampling

    Args:
        channels: number of input channels (the filter is applied depthwise).
        filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
        stride (int): downsampling filter stride

    Returns:
        torch.Tensor: the transformed tensor.
    """

    def __init__(self, channels, filt_size=3, stride=2) -> None:
        super().__init__()
        assert filt_size > 1
        self.channels = channels
        self.filt_size = filt_size
        self.stride = stride
        # Same padding amount on all four sides, as expected by F.pad.
        self.padding = [get_padding(filt_size, stride, dilation=1)] * 4

        # 1D binomial coefficients come from expanding (0.5 + 0.5 z)^(filt_size - 1).
        binomial = np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)
        coeffs = torch.tensor(binomial.coeffs.astype(np.float32))
        # Outer product gives the 2D kernel; one copy per channel for a grouped conv.
        kernel_2d = coeffs[:, None] * coeffs[None, :]
        blur_filter = kernel_2d[None, None, :, :].repeat(self.channels, 1, 1, 1)
        # Non-persistent: the filter is constant and is not stored in the state dict.
        self.register_buffer("filt", blur_filter, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        padded = F.pad(x, self.padding, "reflect")
        return F.conv2d(padded, self.filt, stride=self.stride, groups=self.channels)
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
    """Compute relative logits along one dimension

    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925

    Args:
        q: 4-D tensor unpacked as (batch, height, width, dim)
        rel_k: (2 * width - 1, dim)
        permute_mask: permute the 5-D (batch, height, height, width, width) result according to this
    """
    batch, height, width, _ = q.shape
    rel_span = 2 * width - 1
    logits = q @ rel_k.transpose(-1, -2)
    logits = logits.reshape(-1, width, rel_span)

    # pad to shift from relative to absolute indexing
    padded = F.pad(logits, [0, 1]).flatten(1)
    padded = F.pad(padded, [0, width - 1])

    # reshape and slice out the padded elements
    padded = padded.reshape(-1, width + 1, rel_span)
    logits = padded[:, :width, width - 1 :]

    # reshape and tile
    logits = logits.reshape(batch, height, 1, width, width)
    logits = logits.expand(-1, -1, height, -1, -1)
    return logits.permute(permute_mask)


class PosEmbedRel(nn.Module):
    """Relative Position Embedding
    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925

    Holds one learned table per axis, of shape (2 * size - 1, dim_head), each
    initialised as ``randn * scale``.
    """

    def __init__(self, feat_size, dim_head, scale):
        super().__init__()
        self.height, self.width = to_2tuple(feat_size)
        self.dim_head = dim_head
        self.height_rel = nn.Parameter(
            torch.randn(self.height * 2 - 1, dim_head) * scale
        )
        self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale)

    def forward(self, q):
        """Return the (B, HW, HW) sum of height and width relative logits for ``q`` of shape (B, HW, dim)."""
        B, HW, _ = q.shape

        # relative logits in width dimension.
        q_grid = q.reshape(B, self.height, self.width, -1)
        logits_w = rel_logits_1d(q_grid, self.width_rel, permute_mask=(0, 1, 3, 2, 4))

        # relative logits in height dimension (swap the two spatial axes first).
        q_grid_t = q_grid.transpose(1, 2)
        logits_h = rel_logits_1d(
            q_grid_t, self.height_rel, permute_mask=(0, 3, 1, 4, 2)
        )

        combined = logits_h + logits_w
        return combined.reshape(B, HW, HW)
class BottleneckAttn(nn.Module):
    """Bottleneck Attention
    Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605

    The internal dimensions of the attention module are controlled by the interaction of several arguments.
      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
      * the query and key (qk) dimensions are determined by
        * num_heads * dim_head if dim_head is not None
        * num_heads * (make_divisible(dim_out * qk_ratio, 8) // num_heads) if dim_head is None
      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not used

    Args:
        dim (int): input dimension to the module
        dim_out (int): output dimension of the module, same as dim if not set
        feat_size (int or tuple): expected input (H, W); required, sizes the relative position embedding
        stride (int): output stride of the module, avg pool used if stride == 2 (default: 1).
        num_heads (int): parallel attention heads (default: 4)
        dim_head (int): dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set
        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
        qkv_bias (bool): add bias to q, k, and v projections
        scale_pos_embed (bool): scale the position embedding as well as Q @ K
    """

    def __init__(
        self,
        dim,
        dim_out=None,
        feat_size=None,
        stride=1,
        num_heads=4,
        dim_head=None,
        qk_ratio=1.0,
        qkv_bias=False,
        scale_pos_embed=False,
    ):
        super().__init__()
        assert (
            feat_size is not None
        ), "A concrete feature size matching expected input (H, W) is required"
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0

        self.num_heads = num_heads
        # Per-head q/k width: explicit dim_head wins, otherwise derived from qk_ratio.
        self.dim_head_qk = (
            dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        )
        self.dim_head_v = dim_out // self.num_heads
        self.dim_out_qk = num_heads * self.dim_head_qk
        self.dim_out_v = num_heads * self.dim_head_v
        self.scale = self.dim_head_qk**-0.5
        self.scale_pos_embed = scale_pos_embed

        # Single 1x1 conv produces q, k and v stacked along the channel dim.
        qkv_channels = self.dim_out_qk * 2 + self.dim_out_v
        self.qkv = nn.Conv2d(dim, qkv_channels, 1, bias=qkv_bias)

        # NOTE I'm only supporting relative pos embedding for now
        self.pos_embed = PosEmbedRel(
            feat_size, dim_head=self.dim_head_qk, scale=self.scale
        )

        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the qkv weight (fan-in std) and both relative position tables."""
        fan_in = self.qkv.weight.shape[1]
        trunc_normal_(self.qkv.weight, std=fan_in**-0.5)
        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
        trunc_normal_(self.pos_embed.width_rel, std=self.scale)

    def forward(self, x):
        B, C, H, W = x.shape
        _assert(H == self.pos_embed.height, "")
        _assert(W == self.pos_embed.width, "")

        x = self.qkv(x)  # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W

        # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v
        # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted.
        split_sizes = [self.dim_out_qk, self.dim_out_qk, self.dim_out_v]
        q, k, v = torch.split(x, split_sizes, dim=1)
        heads_batch = B * self.num_heads
        q = q.reshape(heads_batch, self.dim_head_qk, -1).transpose(-1, -2)
        k = k.reshape(heads_batch, self.dim_head_qk, -1)  # no transpose, for q @ k
        v = v.reshape(heads_batch, self.dim_head_v, -1).transpose(-1, -2)

        content_logits = q @ k  # B * num_heads, H * W, H * W
        pos_logits = self.pos_embed(q)
        if self.scale_pos_embed:
            attn = (content_logits + pos_logits) * self.scale
        else:
            attn = content_logits * self.scale + pos_logits
        attn = attn.softmax(dim=-1)

        out = attn @ v
        out = out.transpose(-1, -2).reshape(B, self.dim_out_v, H, W)  # B, dim_out, H, W
        out = self.pool(out)
        return out
class ChannelAttn(nn.Module):
    """Original CBAM channel attention module, currently avg + max pool variant only.

    A shared two-layer 1x1-conv MLP is applied to the spatially average-pooled
    and max-pooled descriptors; their sum is gated and multiplies the input.
    When ``rd_channels`` is not given it is derived via
    ``make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.0)``.
    """

    def __init__(
        self,
        channels,
        rd_ratio=1.0 / 16,
        rd_channels=None,
        rd_divisor=1,
        act_layer=nn.ReLU,
        gate_layer="sigmoid",
        mlp_bias=False,
    ):
        super(ChannelAttn, self).__init__()
        if not rd_channels:
            rd_channels = make_divisible(
                channels * rd_ratio, rd_divisor, round_limit=0.0
            )
        self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias)
        self.act = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True))))
        x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True))))
        return x * self.gate(x_avg + x_max)


class LightChannelAttn(ChannelAttn):
    """An experimental 'lightweight' that sums avg + max pool first"""

    def __init__(
        self,
        channels,
        rd_ratio=1.0 / 16,
        rd_channels=None,
        rd_divisor=1,
        act_layer=nn.ReLU,
        gate_layer="sigmoid",
        mlp_bias=False,
    ):
        super(LightChannelAttn, self).__init__(
            channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias
        )

    def forward(self, x):
        # Blend the two pooled descriptors first so the MLP runs only once.
        x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True)
        x_attn = self.fc2(self.act(self.fc1(x_pool)))
        # Use the configured gate (as ChannelAttn does) rather than a hard-coded,
        # deprecated F.sigmoid, so that `gate_layer` is actually honoured.
        # NOTE(review): assumes the default gate_layer="sigmoid" resolves to a
        # sigmoid via create_act_layer, which keeps default behaviour unchanged.
        return x * self.gate(x_attn)


class SpatialAttn(nn.Module):
    """Original CBAM spatial attention module

    Channel-wise mean and max maps are concatenated, passed through a
    2 -> 1 channel conv + norm (no activation), gated, and multiply the input.
    """

    def __init__(self, kernel_size=7, gate_layer="sigmoid"):
        super(SpatialAttn, self).__init__()
        self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_attn = torch.cat(
            [x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1
        )
        x_attn = self.conv(x_attn)
        return x * self.gate(x_attn)
class LightSpatialAttn(nn.Module):
    """An experimental 'lightweight' variant that sums avg_pool and max_pool results."""

    def __init__(self, kernel_size=7, gate_layer="sigmoid"):
        super(LightSpatialAttn, self).__init__()
        self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # Blend channel-wise mean and max into a single 1-channel map.
        x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True)
        x_attn = self.conv(x_attn)
        return x * self.gate(x_attn)


class CbamModule(nn.Module):
    """CBAM block: channel attention followed by spatial attention."""

    def __init__(
        self,
        channels,
        rd_ratio=1.0 / 16,
        rd_channels=None,
        rd_divisor=1,
        spatial_kernel_size=7,
        act_layer=nn.ReLU,
        gate_layer="sigmoid",
        mlp_bias=False,
    ):
        super(CbamModule, self).__init__()
        self.channel = ChannelAttn(
            channels,
            rd_ratio=rd_ratio,
            rd_channels=rd_channels,
            rd_divisor=rd_divisor,
            act_layer=act_layer,
            gate_layer=gate_layer,
            mlp_bias=mlp_bias,
        )
        self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer)

    def forward(self, x):
        x = self.channel(x)
        x = self.spatial(x)
        return x


class LightCbamModule(nn.Module):
    """Lightweight CBAM block: light channel attention followed by light spatial attention."""

    def __init__(
        self,
        channels,
        rd_ratio=1.0 / 16,
        rd_channels=None,
        rd_divisor=1,
        spatial_kernel_size=7,
        act_layer=nn.ReLU,
        gate_layer="sigmoid",
        mlp_bias=False,
    ):
        super(LightCbamModule, self).__init__()
        self.channel = LightChannelAttn(
            channels,
            rd_ratio=rd_ratio,
            rd_channels=rd_channels,
            rd_divisor=rd_divisor,
            act_layer=act_layer,
            gate_layer=gate_layer,
            mlp_bias=mlp_bias,
        )
        # Forward gate_layer to the spatial branch too (as CbamModule does), so a
        # non-default gate applies to both branches; the default is unchanged
        # because LightSpatialAttn's own default is also "sigmoid".
        self.spatial = LightSpatialAttn(spatial_kernel_size, gate_layer=gate_layer)

    def forward(self, x):
        x = self.channel(x)
        x = self.spatial(x)
        return x
def _create_pool(num_features, num_classes, pool_type="avg", use_conv=False):
    """Build the global pooling layer and report the feature count it produces.

    Returns:
        ``(global_pool, num_pooled_features)`` where the count is
        ``num_features`` times the pool's ``feat_mult()``.
    """
    if not pool_type:
        assert (
            num_classes == 0 or use_conv
        ), "Pooling can only be disabled if classifier is also removed or conv classifier is used"
    # Flatten inside the pool only when a Linear layer follows real pooling;
    # pass-through pooling (no pool_type) never flattens.
    flatten_in_pool = bool(pool_type) and not use_conv
    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
    num_pooled_features = num_features * global_pool.feat_mult()
    return global_pool, num_pooled_features


def _create_fc(num_features, num_classes, use_conv=False):
    """Build the classifier layer: Identity, 1x1 Conv2d, or Linear."""
    if num_classes <= 0:
        return nn.Identity()  # pass-through (no classifier)
    if use_conv:
        return nn.Conv2d(num_features, num_classes, 1, bias=True)
    return nn.Linear(num_features, num_classes, bias=True)


def create_classifier(num_features, num_classes, pool_type="avg", use_conv=False):
    """Create a ``(global_pool, fc)`` pair sized consistently with each other."""
    global_pool, pooled_features = _create_pool(
        num_features, num_classes, pool_type, use_conv=use_conv
    )
    classifier = _create_fc(pooled_features, num_classes, use_conv=use_conv)
    return global_pool, classifier


class ClassifierHead(nn.Module):
    """Classifier head w/ configurable global pooling and dropout."""

    def __init__(
        self, in_chs, num_classes, pool_type="avg", drop_rate=0.0, use_conv=False
    ):
        super().__init__()
        self.drop_rate = drop_rate
        self.global_pool, pooled_features = _create_pool(
            in_chs, num_classes, pool_type, use_conv=use_conv
        )
        self.fc = _create_fc(pooled_features, num_classes, use_conv=use_conv)
        # A conv classifier after real pooling still emits (B, C, 1, 1); flatten it.
        needs_flatten = use_conv and pool_type
        self.flatten = nn.Flatten(1) if needs_flatten else nn.Identity()

    def forward(self, x, pre_logits: bool = False):
        """Pool, optionally drop out, then classify.

        With ``pre_logits=True`` the pooled (and dropped-out) features are
        returned flattened from dim 1, skipping the classifier.
        """
        x = self.global_pool(x)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        if pre_logits:
            return x.flatten(1)
        logits = self.fc(x)
        return self.flatten(logits)
def get_condconv_initializer(initializer, num_experts, expert_shape):
    """Wrap a per-tensor initializer so it initialises each expert row separately.

    The returned function expects a 2-D ``[num_experts, prod(expert_shape)]``
    weight and calls ``initializer`` on every row viewed as ``expert_shape``.
    """

    def condconv_initializer(weight):
        """CondConv initializer function."""
        num_params = np.prod(expert_shape)
        shape_ok = (
            len(weight.shape) == 2
            and weight.shape[0] == num_experts
            and weight.shape[1] == num_params
        )
        if not shape_ok:
            raise ValueError(
                "CondConv variables must have shape [num_experts, num_params]"
            )
        for expert_idx in range(num_experts):
            expert_weight = weight[expert_idx].view(expert_shape)
            initializer(expert_weight)

    return condconv_initializer
class CondConv2d(nn.Module):
    """Conditionally Parameterized Convolution
    Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

    Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
    https://github.com/pytorch/pytorch/issues/17983

    Expert kernels are stored flattened as ``[num_experts, prod(weight_shape)]``
    and mixed per sample in ``forward`` using ``routing_weights``.
    """

    __constants__ = ["in_channels", "out_channels", "dynamic_padding"]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding="",
        dilation=1,
        groups=1,
        bias=False,
        num_experts=4,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = to_2tuple(kernel_size)
        self.stride = to_2tuple(stride)
        padding_val, is_padding_dynamic = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation
        )
        # Kept as an attribute (branched on in forward) to work with torchscript.
        self.dynamic_padding = is_padding_dynamic
        self.padding = to_2tuple(padding_val)
        self.dilation = to_2tuple(dilation)
        self.groups = groups
        self.num_experts = num_experts

        # Shape of ONE expert's conv kernel; stored flattened, one row per expert.
        self.weight_shape = (
            self.out_channels,
            self.in_channels // self.groups,
        ) + self.kernel_size
        weight_num_param = math.prod(self.weight_shape)
        self.weight = torch.nn.Parameter(
            torch.Tensor(self.num_experts, weight_num_param)
        )

        if bias:
            self.bias_shape = (self.out_channels,)
            self.bias = torch.nn.Parameter(
                torch.Tensor(self.num_experts, self.out_channels)
            )
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every expert's kernel (kaiming uniform) and bias (uniform in +-1/sqrt(fan_in))."""
        weight_init = get_condconv_initializer(
            partial(nn.init.kaiming_uniform_, a=math.sqrt(5)),
            self.num_experts,
            self.weight_shape,
        )
        weight_init(self.weight)
        if self.bias is None:
            return
        fan_in = np.prod(self.weight_shape[1:])
        bound = 1 / math.sqrt(fan_in)
        bias_init = get_condconv_initializer(
            partial(nn.init.uniform_, a=-bound, b=bound),
            self.num_experts,
            self.bias_shape,
        )
        bias_init(self.bias)

    def forward(self, x, routing_weights):
        """Convolve each sample of ``x`` (B, C, H, W) with its own routed kernel.

        ``routing_weights`` is matrix-multiplied with the expert table, so it
        presumably has shape (B, num_experts) -- TODO confirm against callers.
        """
        B, C, H, W = x.shape
        # Per-sample kernels: mix the flattened experts, then stack along out-channels.
        mixed_weight = torch.matmul(routing_weights, self.weight)
        batched_weight_shape = (
            B * self.out_channels,
            self.in_channels // self.groups,
        ) + self.kernel_size
        mixed_weight = mixed_weight.view(batched_weight_shape)
        mixed_bias = None
        if self.bias is not None:
            mixed_bias = torch.matmul(routing_weights, self.bias)
            mixed_bias = mixed_bias.view(B * self.out_channels)
        # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
        # reshape instead of view to work with channels_last input
        x = x.reshape(1, B * C, H, W)
        if self.dynamic_padding:
            out = conv2d_same(
                x,
                mixed_weight,
                mixed_bias,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups * B,
            )
        else:
            out = F.conv2d(
                x,
                mixed_weight,
                mixed_bias,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups * B,
            )
        # (1, B * out, H', W') -> (B, out, H', W')
        out_h, out_w = out.shape[-2], out.shape[-1]
        out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out_h, out_w)

        # NOTE: a literal port of the TF definition would instead split x and the
        # mixed weights per sample and run one convolution per batch element.
        return out
# --- Model / layer config singleton state -------------------------------------

# Set to True if prefer to have layers with no jit optimization (includes activations)
_NO_JIT = False

# Set to True if prefer to have activation layers with no jit optimization
# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
# the jit flags so far are activations. This will change as more layers are updated and/or added.
_NO_ACTIVATION_JIT = False

# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False

# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = False


def is_no_jit():
    """Return the current value of the no-jit flag."""
    return _NO_JIT


class set_no_jit:
    """Set the no-jit flag on construction; restore the previous value on ``with`` exit."""

    def __init__(self, mode: bool) -> None:
        global _NO_JIT
        self.prev = _NO_JIT
        _NO_JIT = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *exc_info: Any) -> bool:
        global _NO_JIT
        _NO_JIT = self.prev
        return False  # never suppress exceptions


def is_exportable():
    """Return the current value of the exportable flag."""
    return _EXPORTABLE


class set_exportable:
    """Set the exportable flag on construction; restore the previous value on ``with`` exit."""

    def __init__(self, mode: bool) -> None:
        global _EXPORTABLE
        self.prev = _EXPORTABLE
        _EXPORTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *exc_info: Any) -> bool:
        global _EXPORTABLE
        _EXPORTABLE = self.prev
        return False  # never suppress exceptions


def is_scriptable():
    """Return the current value of the scriptable flag."""
    return _SCRIPTABLE


class set_scriptable:
    """Set the scriptable flag on construction; restore the previous value on ``with`` exit."""

    def __init__(self, mode: bool) -> None:
        global _SCRIPTABLE
        self.prev = _SCRIPTABLE
        _SCRIPTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *exc_info: Any) -> bool:
        global _SCRIPTABLE
        _SCRIPTABLE = self.prev
        return False  # never suppress exceptions


class set_layer_config:
    """Layer config context manager that allows setting all layer config flags at once.
    If a flag arg is None, it will not change the current value.

    The flags are changed on construction and all four are restored on exit.
    """

    def __init__(
        self,
        scriptable: Optional[bool] = None,
        exportable: Optional[bool] = None,
        no_jit: Optional[bool] = None,
        no_activation_jit: Optional[bool] = None,
    ):
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        # Snapshot everything first so __exit__ can restore all flags at once.
        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
        if scriptable is not None:
            _SCRIPTABLE = scriptable
        if exportable is not None:
            _EXPORTABLE = exportable
        if no_jit is not None:
            _NO_JIT = no_jit
        if no_activation_jit is not None:
            _NO_ACTIVATION_JIT = no_activation_jit

    def __enter__(self) -> None:
        pass

    def __exit__(self, *exc_info: Any) -> bool:
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
        return False  # never suppress exceptions


# --- Conv2d w/ Same Padding ---------------------------------------------------


def conv2d_same(
    x,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (0, 0),
    dilation: Tuple[int, int] = (1, 1),
    groups: int = 1,
):
    """Pad ``x`` with ``pad_same`` for the kernel's spatial size, then convolve with zero padding.

    ``padding`` is accepted for signature compatibility but is not used.
    """
    kernel_hw = weight.shape[-2:]
    padded = pad_same(x, kernel_hw, stride, dilation)
    return F.conv2d(padded, weight, bias, stride, (0, 0), dilation, groups)


class Conv2dSame(nn.Conv2d):
    """Tensorflow like 'SAME' convolution wrapper for 2D convolutions"""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        # The base conv is always built with padding 0; the `padding` argument
        # is ignored because padding is applied dynamically in forward.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias
        )

    def forward(self, x):
        return conv2d_same(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    """Create a conv layer, choosing ``Conv2dSame`` when the padding is dynamic.

    ``padding`` defaults to ``""`` and ``bias`` to ``False`` when not supplied.
    """
    requested_padding = kwargs.pop("padding", "")
    kwargs.setdefault("bias", False)
    padding, is_dynamic = get_padding_value(requested_padding, kernel_size, **kwargs)
    if is_dynamic:
        return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
    return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
def create_aa(aa_layer, channels, stride=2, enable=True):
    """Build an anti-aliasing (downsampling) layer, or a no-op.

    Args:
        aa_layer: Layer class or ``functools.partial`` of one. A falsy value
            disables anti-aliasing.
        channels: Channel count handed to layers that are not average pools.
        stride: Stride handed to an unbound layer class.
        enable: When false, an ``nn.Identity`` is returned regardless of
            ``aa_layer``.

    Returns:
        The instantiated layer, or ``nn.Identity()`` when disabled.
    """
    if not (aa_layer and enable):
        return nn.Identity()

    prebound = isinstance(aa_layer, functools.partial)
    # For a partial the wrapped callable decides how the layer is invoked.
    target = aa_layer.func if prebound else aa_layer
    is_avg_pool = issubclass(target, nn.AvgPool2d)

    if prebound:
        # A partial already carries its configuration; avg-pools take no channels.
        return aa_layer() if is_avg_pool else aa_layer(channels)
    if is_avg_pool:
        return aa_layer(stride)
    return aa_layer(channels=channels, stride=stride)
_has_silu = "silu" in dir(torch.nn.functional) _has_hardswish = "hardswish" in dir(torch.nn.functional) _has_hardsigmoid = "hardsigmoid" in dir(torch.nn.functional) _has_mish = "mish" in dir(torch.nn.functional) _ACT_FN_DEFAULT = dict( silu=F.silu if _has_silu else swish, swish=F.silu if _has_silu else swish, mish=F.mish if _has_mish else mish, relu=F.relu, relu6=F.relu6, leaky_relu=F.leaky_relu, elu=F.elu, celu=F.celu, selu=F.selu, gelu=gelu, sigmoid=sigmoid, tanh=tanh, hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, hard_swish=F.hardswish if _has_hardswish else hard_swish, hard_mish=hard_mish, ) _ACT_FN_JIT = dict( silu=F.silu if _has_silu else swish_jit, swish=F.silu if _has_silu else swish_jit, mish=F.mish if _has_mish else mish_jit, hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, hard_mish=hard_mish_jit, ) _ACT_FN_ME = dict( silu=F.silu if _has_silu else swish_me, swish=F.silu if _has_silu else swish_me, mish=F.mish if _has_mish else mish_me, hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, hard_swish=F.hardswish if _has_hardswish else hard_swish_me, hard_mish=hard_mish_me, ) _ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) for a in _ACT_FNS: a.setdefault("hardsigmoid", a.get("hard_sigmoid")) a.setdefault("hardswish", a.get("hard_swish")) _ACT_LAYER_DEFAULT = dict( silu=nn.SiLU if _has_silu else Swish, swish=nn.SiLU if _has_silu else Swish, mish=nn.Mish if _has_mish else Mish, relu=nn.ReLU, relu6=nn.ReLU6, leaky_relu=nn.LeakyReLU, elu=nn.ELU, prelu=PReLU, celu=nn.CELU, selu=nn.SELU, gelu=GELU, sigmoid=Sigmoid, tanh=Tanh, hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, hard_swish=nn.Hardswish if _has_hardswish else HardSwish, hard_mish=HardMish, ) _ACT_LAYER_JIT = dict( silu=nn.SiLU if _has_silu else SwishJit, swish=nn.SiLU if _has_silu else SwishJit, mish=nn.Mish if _has_mish else MishJit, hard_sigmoid=nn.Hardsigmoid 
def get_act_layer(name: Union[Type[nn.Module], str] = "relu"):
    """Activation Layer Factory

    Resolve an activation layer by name. Which lookup tables are consulted
    depends on the current layer config flags, so export- or script-friendly
    layers are handed out when those modes are active.

    A falsy ``name`` yields ``None``; anything that is not a string (a class,
    callable, module, ...) is returned unchanged.
    """
    if not name:
        return None
    if not isinstance(name, str):
        return name

    jit_blocked = is_no_jit() or is_exportable()
    candidates = []
    if not (jit_blocked or is_scriptable()):
        # Memory-efficient (custom autograd) variants take priority when allowed.
        candidates.append(_ACT_LAYER_ME)
    if not jit_blocked:
        candidates.append(_ACT_LAYER_JIT)

    for table in candidates:
        if name in table:
            return table[name]
    return _ACT_LAYER_DEFAULT[name]
if attn_type == "se": module_cls = SEModule elif attn_type == "ese": module_cls = EffectiveSEModule elif attn_type == "eca": module_cls = EcaModule elif attn_type == "ecam": module_cls = partial(EcaModule, use_mlp=True) elif attn_type == "ceca": module_cls = CecaModule elif attn_type == "ge": module_cls = GatherExcite elif attn_type == "gc": module_cls = GlobalContext elif attn_type == "gca": module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False) elif attn_type == "cbam": module_cls = CbamModule elif attn_type == "lcbam": module_cls = LightCbamModule # Attention / attention-like modules w/ significant params # Typically replace some of the existing workhorse convs in a network architecture. # All of these accept a stride argument and can spatially downsample the input. elif attn_type == "sk": module_cls = SelectiveKernel elif attn_type == "splat": module_cls = SplitAttn # Self-attention / attention-like modules w/ significant compute and/or params # Typically replace some of the existing workhorse convs in a network architecture. # All of these accept a stride argument and can spatially downsample the input. elif attn_type == "lambda": return LambdaLayer elif attn_type == "bottleneck": return BottleneckAttn elif attn_type == "halo": return HaloAttn elif attn_type == "nl": module_cls = NonLocalAttn elif attn_type == "bat": module_cls = BatNonLocalAttn # Woops! 
def create_attn(attn_type, channels, **kwargs):
    """Instantiate the attention module selected by ``attn_type``.

    Returns ``None`` when ``get_attn`` resolves no module for ``attn_type``.
    """
    module_cls = get_attn(attn_type)
    if module_cls is None:
        return None
    # NOTE: every attention layer is expected to take the input channels as
    # its first positional argument.
    return module_cls(channels, **kwargs)
def create_norm_layer(
    layer_name, num_features, act_layer=None, apply_act=True, **kwargs
):
    """Instantiate a plain normalization layer selected by name or type.

    Args:
        layer_name: Anything ``get_norm_layer`` accepts: a string name, a norm
            class, a function, or a ``functools.partial`` of one.
        num_features: Passed as the first positional argument of the layer.
        act_layer: Accepted only so existing callers keep working; ignored,
            because plain norm layers have no activation.
        apply_act: Accepted only so existing callers keep working; ignored
            for the same reason.
        **kwargs: Forwarded to the layer constructor.

    Returns:
        The constructed normalization layer.
    """
    # ``get_norm_layer`` takes only the norm spec; passing ``act_layer=`` to it
    # raised a TypeError on every call.
    layer = get_norm_layer(layer_name)
    # ``apply_act`` is an argument of the combined norm+act layers; plain norm
    # constructors such as ``nn.BatchNorm2d`` reject it, so it is not forwarded.
    return layer(num_features, **kwargs)
norm_layer else: type_name = norm_layer.__name__.lower().replace("_", "") norm_layer = _NORM_MAP.get(type_name, None) assert norm_layer is not None, f"No equivalent norm layer for {type_name}" if norm_kwargs: norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args return norm_layer ================================================ FILE: RVT/models/layers/maxvit/layers/create_norm_act.py ================================================ """ NormAct (Normalizaiton + Activation Layer) Factory Create norm + act combo modules that attempt to be backwards compatible with separate norm + act isntances in models. Where these are used it will be possible to swap separate BN + act layers with combined modules like IABN or EvoNorms. Hacked together by / Copyright 2020 Ross Wightman """ import types import functools from .evo_norm import * from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d from .inplace_abn import InplaceAbn _NORM_ACT_MAP = dict( batchnorm=BatchNormAct2d, batchnorm2d=BatchNormAct2d, groupnorm=GroupNormAct, groupnorm1=functools.partial(GroupNormAct, num_groups=1), layernorm=LayerNormAct, layernorm2d=LayerNormAct2d, evonormb0=EvoNorm2dB0, evonormb1=EvoNorm2dB1, evonormb2=EvoNorm2dB2, evonorms0=EvoNorm2dS0, evonorms0a=EvoNorm2dS0a, evonorms1=EvoNorm2dS1, evonorms1a=EvoNorm2dS1a, evonorms2=EvoNorm2dS2, evonorms2a=EvoNorm2dS2a, frn=FilterResponseNormAct2d, frntlu=FilterResponseNormTlu2d, inplaceabn=InplaceAbn, iabn=InplaceAbn, ) _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} # has act_layer arg to define act type _NORM_ACT_REQUIRES_ARG = { BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn, } def create_norm_act_layer( layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs ): layer = get_norm_act_layer(layer_name, act_layer=act_layer) layer_instance = 
def get_norm_act_layer(norm_layer, act_layer=None):
    """Map a norm layer spec to the matching combined norm+activation layer.

    Args:
        norm_layer: String name, layer class, function, or
            ``functools.partial`` of one. Keywords bound on a partial are
            re-bound onto the returned layer.
        act_layer: Activation spec. It is bound as ``act_layer`` for the layer
            types listed in ``_NORM_ACT_REQUIRES_ARG`` unless the partial
            already supplied one.

    Returns:
        The norm+act layer class/function, wrapped in ``functools.partial``
        when there are keyword arguments to bind. A string name missing from
        ``_NORM_ACT_MAP`` resolves to ``None``.

    Raises:
        AssertionError: If an argument has an unsupported type, or a layer
            type has no norm+act equivalent.
    """
    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
    assert act_layer is None or isinstance(
        act_layer, (type, str, types.FunctionType, functools.partial)
    )
    norm_act_kwargs = {}

    # unbind partial fn, so args can be rebound later
    if isinstance(norm_layer, functools.partial):
        norm_act_kwargs.update(norm_layer.keywords)
        norm_layer = norm_layer.func

    if isinstance(norm_layer, str):
        layer_name = norm_layer.replace("_", "").lower().split("-")[0]
        norm_act_layer = _NORM_ACT_MAP.get(layer_name, None)
    elif norm_layer in _NORM_ACT_TYPES:
        norm_act_layer = norm_layer
    elif isinstance(norm_layer, types.FunctionType):
        # if function type, must be a lambda/fn that creates a norm_act layer
        norm_act_layer = norm_layer
    else:
        type_name = norm_layer.__name__.lower()
        # More specific prefixes are tested first: "groupnorm1" also starts
        # with "groupnorm" and "layernorm2d" also starts with "layernorm", so
        # the reverse order would make the specific branch unreachable.
        if type_name.startswith("batchnorm"):
            norm_act_layer = BatchNormAct2d
        elif type_name.startswith("groupnorm1"):
            norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
        elif type_name.startswith("groupnorm"):
            norm_act_layer = GroupNormAct
        elif type_name.startswith("layernorm2d"):
            norm_act_layer = LayerNormAct2d
        elif type_name.startswith("layernorm"):
            norm_act_layer = LayerNormAct
        else:
            assert False, f"No equivalent norm_act layer for {type_name}"

    if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
        # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
        # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
        norm_act_kwargs.setdefault("act_layer", act_layer)
    if norm_act_kwargs:
        # bind/rebind args
        norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs)
    return norm_act_layer
def drop_block_fast_2d(
    x: torch.Tensor,
    drop_prob: float = 0.1,
    block_size: int = 7,
    gamma_scale: float = 1.0,
    with_noise: bool = False,
    inplace: bool = False,
):
    """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. Simplified variant
    of ``drop_block_2d`` that does not restrict block centres to positions
    where the whole block fits inside the feature map.

    Args:
        x: 4-D input; the last two dimensions are treated as height and width.
        drop_prob: Target fraction of activations to drop.
        block_size: Side length of the dropped squares; clipped to the smaller
            spatial dimension of ``x``.
        gamma_scale: Multiplier applied to the block-seed probability.
        with_noise: Fill dropped positions with gaussian noise instead of
            zeroing them and rescaling the kept activations.
        inplace: Modify ``x`` in place instead of allocating the result.

    Returns:
        Tensor with the same shape as ``x`` (``x`` itself when ``inplace``).
    """
    _, _, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # Use the clipped size throughout. With the raw ``block_size`` the factors
    # ``W - block_size + 1`` / ``H - block_size + 1`` reach zero or go negative
    # for feature maps smaller than the block, giving a division by zero or an
    # invalid bernoulli probability. With the clipped size both factors are >= 1.
    gamma = (
        gamma_scale
        * drop_prob
        * total_size
        / clipped_block_size**2
        / ((W - clipped_block_size + 1) * (H - clipped_block_size + 1))
    )
    block_mask = torch.empty_like(x).bernoulli_(gamma)
    # Grow every seed into a block by max-pooling the seed mask.
    block_mask = F.max_pool2d(
        block_mask.to(x.dtype),
        kernel_size=clipped_block_size,
        stride=1,
        padding=clipped_block_size // 2,
    )
    # An even kernel with padding k // 2 produces one extra row and column;
    # crop so the mask lines up with ``x`` (a no-op for odd kernels).
    block_mask = block_mask[..., :H, :W]

    if with_noise:
        normal_noise = torch.empty_like(x).normal_()
        if inplace:
            x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
        else:
            x = x * (1.0 - block_mask) + normal_noise * block_mask
    else:
        # Turn the drop mask into a keep mask and rescale the kept activations
        # by (total elements / kept elements).
        block_mask = 1 - block_mask
        normalize_scale = (
            block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)
        ).to(dtype=x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x
class DropPath(nn.Module):
    """Stochastic depth: drops whole samples' paths when used in the main path of residual blocks.

    Thin module wrapper that stores the configuration and forwards to
    ``drop_path`` together with the module's training flag.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(
            x,
            self.drop_prob,
            self.training,
            self.scale_by_keep,
        )

    def extra_repr(self):
        # Shown inside the module's repr, with three decimals.
        rounded = round(self.drop_prob, 3)
        return f"drop_prob={rounded:0.3f}"
modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ import math from torch import nn import torch.nn.functional as F from .create_act import create_act_layer from .helpers import make_divisible class EcaModule(nn.Module): """Constructs an ECA module. Args: channels: Number of channels of the input feature map for use in adaptive kernel sizes for actual calculations according to channel. gamma, beta: when channel is given parameters of mapping function refer to original paper https://arxiv.org/pdf/1910.03151.pdf (default=None. if channel size not given, use k_size given for kernel size.) 
class CecaModule(nn.Module):
    """Circular ECA module.

    Efficient channel attention whose 1d convolution over the channel axis
    uses circular padding instead of zero padding. Channels have no inherent
    ordering or locality, so there is no reason to stop the channels at
    either "edge" from attending to each other. This increases connectivity
    and may improve accuracy/robustness without a meaningful cost in
    parameters, throughput or latency.

    Args:
        channels: Number of input channels. When given, the kernel size is
            derived from it with the mapping of the original paper
            (https://arxiv.org/pdf/1910.03151.pdf); otherwise ``kernel_size``
            is used as is.
        kernel_size: Kernel size used when ``channels`` is None (default=3).
        gamma: used in the adaptive kernel size calculation, see above
        beta: used in the adaptive kernel size calculation, see above
        act_layer: when not None, the conv gets a bias term (an experiment)
        gate_layer: gating non-linearity to use
    """

    def __init__(
        self,
        channels=None,
        kernel_size=3,
        gamma=2,
        beta=1,
        act_layer=None,
        gate_layer="sigmoid",
    ):
        super().__init__()
        if channels is not None:
            t = int(abs(math.log(channels, 2) + beta) / gamma)
            if t % 2 == 0:
                # bump even estimates to the next odd number
                t += 1
            kernel_size = max(t, 3)
        assert kernel_size % 2 == 1

        # PyTorch circular padding mode is buggy as of pytorch 1.4
        # (see https://github.com/pytorch/pytorch/pull/17240), so the conv is
        # built without padding and the input is padded manually in forward().
        self.padding = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(
            1, 1, kernel_size=kernel_size, padding=0, bias=act_layer is not None
        )
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        batch = x.shape[0]
        # Pool over the two trailing dims and lay the channels out as a 1d sequence.
        pooled = x.mean((2, 3)).view(batch, 1, -1)
        # Manual circular padding via F.pad
        wrapped = F.pad(pooled, (self.padding, self.padding), mode="circular")
        attn = self.gate(self.conv(wrapped)).view(batch, -1, 1, 1)
        return x * attn.expand_as(x)
def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):
    """Compute the population variance of ``x`` over ``dim`` from mean reductions only.

    Args:
        x: Input tensor.
        dim: Dimension or dimensions to reduce; reduced dims are kept with size 1.
        diff_sqm: When true, compute ``E[x^2] - E[x]^2`` clamped at zero
            (faster on TPU, can be less stable); otherwise compute
            ``E[(x - E[x])^2]``.

    Returns:
        Variance tensor with the reduced dimensions kept.
    """
    mean = x.mean(dim=dim, keepdim=True)
    if not diff_sqm:
        centered = x - mean
        return (centered * centered).mean(dim=dim, keepdim=True)
    # difference of mean-of-squares and squared mean; clamp guards against
    # small negative values caused by rounding
    mean_of_squares = (x * x).mean(dim=dim, keepdim=True)
    return (mean_of_squares - (mean * mean)).clamp(0)
def group_rms(x, groups: int = 32, eps: float = 1e-5):
    """Return the per-group root-mean-square of ``x``, broadcast to ``x``'s shape.

    The channel axis of the 4D input is split into ``groups`` groups and the mean
    of squares is taken over each group's channels and both spatial dimensions.
    The statistic is computed in float32 and cast back to the input dtype.

    Args:
        x: tensor of shape (B, C, H, W); C must be divisible by ``groups``.
        groups: number of channel groups.
        eps: constant added to the mean of squares before the square root.

    Returns:
        Tensor of shape (B, C, H, W) holding each element's group RMS.
    """
    batch, channels, height, width = x.shape
    _assert(channels % groups == 0, "")
    in_dtype = x.dtype
    grouped = x.reshape(batch, groups, channels // groups, height, width)
    mean_sq = grouped.float().square().mean(dim=(2, 3, 4), keepdim=True)
    rms = mean_sq.add(eps).sqrt_().to(in_dtype)
    return rms.expand(grouped.shape).reshape(batch, channels, height, width)
class EvoNorm2dB1(nn.Module):
    """EvoNorm-B1 normalization-activation layer for 4D (N, C, H, W) input.

    When ``apply_act`` is true, the input is divided by
    ``max(sqrt(batch_var + eps), (x + 1) * instance_rms(x, eps))`` and the result
    is scaled and shifted by the learnable per-channel ``weight`` and ``bias``.
    When ``apply_act`` is false only the affine transform is applied.

    In training mode the batch variance is computed from the input and an
    exponential moving average is stored in the ``running_var`` buffer; in eval
    mode ``running_var`` is used instead.

    Args:
        num_features: number of channels C.
        apply_act: apply the normalizing non-linearity.
        momentum: weight given to the new batch variance in the running average.
        eps: constant added to variances for numerical stability.
    """

    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
        super().__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        self.momentum = momentum
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the affine parameters to the identity transform."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Apply EvoNorm-B1 to ``x`` of shape (N, C, H, W); output has the same shape."""
        _assert(x.dim() == 4, "expected 4D input")
        x_dtype = x.dtype
        v_shape = (1, -1, 1, 1)
        if self.apply_act:
            if self.training:
                var = x.float().var(dim=(0, 2, 3), unbiased=False)
                # number of values contributing to each channel statistic
                n = x.numel() / x.shape[1]
                # Bessel correction n / (n - 1) turns the biased batch variance
                # into an unbiased estimate for the running average. With a
                # single value per channel (n == 1) the factor is undefined and
                # the original expression raised ZeroDivisionError, so the
                # denominator is clamped to 1 in that case.
                bessel = n / max(n - 1, 1.0)
                self.running_var.copy_(
                    self.running_var * (1 - self.momentum)
                    + var.detach().to(self.running_var.dtype)
                    * self.momentum
                    * bessel
                )
            else:
                var = self.running_var
            var = var.to(x_dtype).view(v_shape)
            left = var.add(self.eps).sqrt_()
            right = (x + 1) * instance_rms(x, self.eps)
            x = x / left.max(right)
        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(
            x_dtype
        )
class EvoNorm2dS0(nn.Module):
    """EvoNorm-S0 normalization-activation layer for 4D (N, C, H, W) input.

    With ``apply_act`` enabled the layer computes
    ``x * sigmoid(v * x) / group_std(x, groups, eps)`` using a learnable
    per-channel ``v``, followed by a per-channel affine transform
    (``weight``, ``bias``). Without it, only the affine transform is applied.

    Args:
        num_features: number of channels C.
        groups: number of channel groups (used when ``group_size`` is falsy).
        group_size: channels per group; when set it overrides ``groups`` and
            must divide ``num_features``.
        apply_act: apply the normalizing non-linearity.
        eps: constant passed to ``group_std`` for numerical stability.
    """

    def __init__(
        self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_
    ):
        super().__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        if not group_size:
            self.groups = groups
        else:
            assert num_features % group_size == 0
            self.groups = num_features // group_size
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        # v only exists when the non-linearity is enabled
        self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``weight``/``v`` to ones and ``bias`` to zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.v is None:
            return
        nn.init.ones_(self.v)

    def forward(self, x):
        """Apply EvoNorm-S0 to ``x`` of shape (N, C, H, W); output has the same shape."""
        _assert(x.dim() == 4, "expected 4D input")
        in_dtype = x.dtype
        per_channel = (1, -1, 1, 1)
        if self.v is not None:
            gate = (x * self.v.view(per_channel).to(in_dtype)).sigmoid()
            x = x * gate / group_std(x, self.groups, self.eps)
        scale = self.weight.view(per_channel).to(in_dtype)
        shift = self.bias.view(per_channel).to(in_dtype)
        return x * scale + shift
class EvoNorm2dS2(nn.Module):
    """EvoNorm-S2 normalization-activation layer for 4D (N, C, H, W) input.

    With ``apply_act`` enabled the layer computes
    ``act(x) / group_rms(x, groups, eps)`` and then applies a per-channel affine
    transform (``weight``, ``bias``). Without it, only the affine transform is
    applied and ``act`` is an identity module.

    Args:
        num_features: number of channels C.
        groups: number of channel groups (used when ``group_size`` is falsy).
        group_size: channels per group; when set it overrides ``groups`` and
            must divide ``num_features``.
        apply_act: apply the normalizing non-linearity.
        act_layer: activation passed to ``create_act_layer``; ``nn.SiLU`` is used
            when this is falsy.
        eps: constant passed to ``group_rms`` for numerical stability.
    """

    def __init__(
        self,
        num_features,
        groups=32,
        group_size=None,
        apply_act=True,
        act_layer=None,
        eps=1e-5,
        **_
    ):
        super().__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        if apply_act:
            # a falsy act_layer falls back to SiLU
            self.act = create_act_layer(act_layer or nn.SiLU)
        else:
            self.act = nn.Identity()
        if not group_size:
            self.groups = groups
        else:
            assert num_features % group_size == 0
            self.groups = num_features // group_size
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the affine parameters to the identity transform."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Apply EvoNorm-S2 to ``x`` of shape (N, C, H, W); output has the same shape."""
        _assert(x.dim() == 4, "expected 4D input")
        in_dtype = x.dtype
        per_channel = (1, -1, 1, 1)
        if self.apply_act:
            x = self.act(x) / group_rms(x, self.groups, self.eps)
        scale = self.weight.view(per_channel).to(in_dtype)
        shift = self.bias.view(per_channel).to(in_dtype)
        return x * scale + shift
def fast_group_norm(
    x: torch.Tensor,
    num_groups: int,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
) -> torch.Tensor:
    """Group normalization that stays in the autocast dtype instead of upcasting.

    Under TorchScript this simply defers to ``F.group_norm``. Otherwise, when
    autocast is enabled, the input and any affine parameters are cast to the
    GPU autocast dtype and ``F.group_norm`` is run with autocast disabled so the
    op is not promoted back to float32.

    Args:
        x: input tensor.
        num_groups: number of groups to split the channels into.
        weight: optional per-channel scale; may be ``None``.
        bias: optional per-channel shift; may be ``None``.
        eps: value added to the variance for numerical stability.

    Returns:
        The normalized tensor.
    """
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.group_norm(x, num_groups, weight, bias, eps)

    if torch.is_autocast_enabled():
        # normally native AMP casts GN inputs to float32
        # here we use the low precision autocast dtype
        # FIXME what to do re CPU autocast?
        dt = torch.get_autocast_gpu_dtype()
        x = x.to(dt)
        # weight / bias are Optional: only cast the ones that were provided
        # (the unconditional cast raised AttributeError on None)
        weight = weight.to(dt) if weight is not None else None
        bias = bias.to(dt) if bias is not None else None

    with torch.cuda.amp.autocast(enabled=False):
        return F.group_norm(x, num_groups, weight, bias, eps)
def inv_instance_rms(x, eps: float = 1e-5):
    """Return ``1 / sqrt(mean(x**2) + eps)`` per sample and channel, expanded to ``x``'s shape.

    The mean of squares is taken over dims (2, 3) of the input. The input is
    upcast to float32 *before* squaring so that half-precision inputs do not
    overflow (e.g. values above ~255 in float16 square to inf, which would make
    the reciprocal RMS collapse to 0); the result is cast back to ``x.dtype``.

    Args:
        x: input tensor; dims 2 and 3 are reduced.
        eps: constant added to the mean of squares before the reciprocal sqrt.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    mean_sq = x.float().square().mean(dim=(2, 3), keepdim=True)
    inv_rms = mean_sq.add(eps).rsqrt().to(x.dtype)
    return inv_rms.expand(x.shape)
self.act = nn.Identity() self.rms = rms self.eps = eps self.weight = nn.Parameter(torch.ones(num_features)) self.bias = nn.Parameter(torch.zeros(num_features)) self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) def forward(self, x): _assert(x.dim() == 4, "expected 4D input") x_dtype = x.dtype v_shape = (1, -1, 1, 1) x = x * inv_instance_rms(x, self.eps) x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view( v_shape ).to(dtype=x_dtype) return self.act(x) ================================================ FILE: RVT/models/layers/maxvit/layers/gather_excite.py ================================================ """ Gather-Excite Attention Block Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another impl that covers all of the cases. 
class GatherExcite(nn.Module):
    """Gather-Excite Attention Module.

    A spatial context ("gather") is computed from the input, optionally passed
    through a ``ConvMlp``, resized back to the input's spatial size when needed,
    and used through ``gate_layer`` to rescale the input ("excite").

    Args:
        channels: number of input (and output) channels.
        feat_size: spatial feature size; required when ``extra_params`` is true
            and ``extent == 0`` (it is the kernel size of the global depthwise conv).
        extra_params: gather with learned depthwise convolutions instead of pooling.
        extent: gather extent; 0 means global. Non-zero values must be even.
        use_mlp: apply a ``ConvMlp`` to the gathered features (identity otherwise).
        rd_ratio: channel reduction ratio for the MLP when ``rd_channels`` is unset.
        rd_channels: explicit number of reduced MLP channels.
        rd_divisor: divisor passed to ``make_divisible`` for the reduced channels.
        add_maxpool: blend a max-pooled context with the average (pooling path only).
        act_layer: activation used in the gather convs and the MLP.
        norm_layer: normalization layer constructor applied after each gather
            conv; a falsy value disables normalization.
        gate_layer: gating activation passed to ``create_act_layer``.
    """

    def __init__(
        self,
        channels,
        feat_size=None,
        extra_params=False,
        extent=0,
        use_mlp=True,
        rd_ratio=1.0 / 16,
        rd_channels=None,
        rd_divisor=1,
        add_maxpool=False,
        act_layer=nn.ReLU,
        norm_layer=nn.BatchNorm2d,
        gate_layer="sigmoid",
    ):
        super(GatherExcite, self).__init__()
        self.add_maxpool = add_maxpool
        act_layer = get_act_layer(act_layer)
        self.extent = extent
        if extra_params:
            self.gather = nn.Sequential()
            if extent == 0:
                assert (
                    feat_size is not None
                ), "spatial feature size must be specified for global extent w/ params"
                self.gather.add_module(
                    "conv1",
                    create_conv2d(
                        channels,
                        channels,
                        kernel_size=feat_size,
                        stride=1,
                        depthwise=True,
                    ),
                )
                if norm_layer:
                    # honour the norm_layer argument (was hard-coded to BatchNorm2d)
                    self.gather.add_module("norm1", norm_layer(channels))
            else:
                assert extent % 2 == 0
                # one stride-2 depthwise conv per halving of the resolution
                num_conv = int(math.log2(extent))
                for i in range(num_conv):
                    self.gather.add_module(
                        f"conv{i + 1}",
                        create_conv2d(
                            channels, channels, kernel_size=3, stride=2, depthwise=True
                        ),
                    )
                    if norm_layer:
                        # honour the norm_layer argument (was hard-coded to BatchNorm2d)
                        self.gather.add_module(f"norm{i + 1}", norm_layer(channels))
                    if i != num_conv - 1:
                        self.gather.add_module(f"act{i + 1}", act_layer(inplace=True))
        else:
            self.gather = None
            if self.extent == 0:
                self.gk = 0
                self.gs = 0
            else:
                assert extent % 2 == 0
                # pooling kernel / stride used by forward() for a local extent
                self.gk = self.extent * 2 - 1
                self.gs = self.extent

        if not rd_channels:
            rd_channels = make_divisible(
                channels * rd_ratio, rd_divisor, round_limit=0.0
            )
        self.mlp = (
            ConvMlp(channels, rd_channels, act_layer=act_layer)
            if use_mlp
            else nn.Identity()
        )
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        """Return ``x`` rescaled by the gated gather context (same shape as ``x``)."""
        size = x.shape[-2:]
        if self.gather is not None:
            x_ge = self.gather(x)
        else:
            if self.extent == 0:
                # global extent
                x_ge = x.mean(dim=(2, 3), keepdim=True)
                if self.add_maxpool:
                    # experimental codepath, may remove or change
                    x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
            else:
                x_ge = F.avg_pool2d(
                    x,
                    kernel_size=self.gk,
                    stride=self.gs,
                    padding=self.gk // 2,
                    count_include_pad=False,
                )
                if self.add_maxpool:
                    # experimental codepath, may remove or change
                    x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(
                        x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2
                    )
        x_ge = self.mlp(x_ge)
        if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
            # a non-global context has to be resized back to the input resolution
            x_ge = F.interpolate(x_ge, size=size)
        return x * self.gate(x_ge)
norm_layer=LayerNorm2d ) else: self.mlp_scale = None self.gate = create_act_layer(gate_layer) self.init_last_zero = init_last_zero self.reset_parameters() def reset_parameters(self): if self.conv_attn is not None: nn.init.kaiming_normal_( self.conv_attn.weight, mode="fan_in", nonlinearity="relu" ) if self.mlp_add is not None: nn.init.zeros_(self.mlp_add.fc2.weight) def forward(self, x): B, C, H, W = x.shape if self.conv_attn is not None: attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) context = x.reshape(B, C, H * W).unsqueeze(1) @ attn context = context.view(B, C, 1, 1) else: context = x.mean(dim=(2, 3), keepdim=True) if self.mlp_scale is not None: mlp_x = self.mlp_scale(context) x = x * self.gate(mlp_x) if self.mlp_add is not None: mlp_x = self.mlp_add(context) x = x + mlp_x return x ================================================ FILE: RVT/models/layers/maxvit/layers/halo_attn.py ================================================ """ Halo Self Attention Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` - https://arxiv.org/abs/2103.12731 @misc{2103.12731, Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and Jonathon Shlens}, Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones}, Year = {2021}, } Status: This impl is a WIP, there is no official ref impl and some details in paper weren't clear to me. The attention mechanism works but it's slow as implemented. 
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
    """Compute relative-position logits along one spatial dimension.

    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925

    Args:
        q: (batch, height, width, dim)
        rel_k: (2 * window - 1, dim)
        permute_mask: order in which the five output dims are permuted

    Returns:
        A 5D tensor: the (batch, height, window, width, window) logits permuted
        by ``permute_mask``.
    """
    batch, height, width, _ = q.shape
    num_rel = rel_k.shape[0]
    window = (num_rel + 1) // 2

    logits = (q @ rel_k.transpose(-1, -2)).reshape(-1, width, num_rel)

    # pad, flatten and pad again so that re-chunking the flat buffer into rows
    # of length num_rel shifts each row from relative to absolute indexing
    skewed = F.pad(logits, [0, 1]).flatten(1)
    skewed = F.pad(skewed, [0, num_rel - width])

    # re-chunk and drop the padded elements
    skewed = skewed.reshape(-1, width + 1, num_rel)
    logits = skewed[:, :width, window - 1 :]

    # add the tiled window axis, then reorder as requested by the caller
    tiled = logits.reshape(batch, height, 1, width, window).expand(
        -1, -1, window, -1, -1
    )
    return tiled.permute(permute_mask)
class HaloAttn(nn.Module):
    """Halo Attention

    Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
        - https://arxiv.org/abs/2103.12731

    Queries are taken from non-overlapping blocks of the feature map while keys
    and values are taken from a larger window around each block (the block plus
    a halo of ``halo_size`` on every side).

    The internal dimensions of the attention module are controlled by the interaction of several arguments.
      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
      * the query and key (qk) dimensions are determined by
        * num_heads * dim_head if dim_head is not None
        * num_heads * (make_divisible(dim_out * qk_ratio, 8) // num_heads) if dim_head is None
      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not used

    Args:
        dim (int): input dimension to the module
        dim_out (int): output dimension of the module, same as dim if not set
        feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda)
        stride: output stride of the module, query downscaled if > 1 (default: 1).
        num_heads: parallel attention heads (default: 8).
        dim_head: dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set
        block_size (int): size of blocks. (default: 8)
        halo_size (int): size of halo overlap. (default: 3)
        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
        qkv_bias (bool) : add bias to q, k, and v projections
        avg_down (bool): use average pool downsample instead of strided query blocks
        scale_pos_embed (bool): scale the position embedding as well as Q @ K
    """

    def __init__(
        self,
        dim,
        dim_out=None,
        feat_size=None,
        stride=1,
        num_heads=8,
        dim_head=None,
        block_size=8,
        halo_size=3,
        qk_ratio=1.0,
        qkv_bias=False,
        avg_down=False,
        scale_pos_embed=False,
    ):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0
        assert stride in (1, 2)
        self.num_heads = num_heads
        # per-head q/k dim: explicit dim_head, else derived from dim_out * qk_ratio
        self.dim_head_qk = (
            dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        )
        self.dim_head_v = dim_out // self.num_heads
        self.dim_out_qk = num_heads * self.dim_head_qk
        self.dim_out_v = num_heads * self.dim_head_v
        self.scale = self.dim_head_qk**-0.5
        self.scale_pos_embed = scale_pos_embed
        # block_size_ds is the query block size after any strided downsampling
        self.block_size = self.block_size_ds = block_size
        self.halo_size = halo_size
        self.win_size = block_size + halo_size * 2  # neighbourhood window size
        self.block_stride = 1
        use_avg_pool = False
        if stride > 1:
            # fall back to average pooling when requested or when the block size
            # is not divisible by the stride; otherwise stride the query conv
            use_avg_pool = avg_down or block_size % stride != 0
            self.block_stride = 1 if use_avg_pool else stride
            self.block_size_ds = self.block_size // self.block_stride

        # FIXME not clear if this stride behaviour is what the paper intended
        # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
        # data in unfolded block form. I haven't wrapped my head around how that'd look.
        self.q = nn.Conv2d(
            dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias
        )
        # keys and values share a single 1x1 projection and are split in forward()
        self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)

        self.pos_embed = PosEmbedRel(
            block_size=self.block_size_ds,
            win_size=self.win_size,
            dim_head=self.dim_head_qk,
            scale=self.scale,
        )

        self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the projection weights and relative position tables."""
        std = self.q.weight.shape[1] ** -0.5  # fan-in
        trunc_normal_(self.q.weight, std=std)
        trunc_normal_(self.kv.weight, std=std)
        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
        trunc_normal_(self.pos_embed.width_rel, std=self.scale)

    def forward(self, x):
        """Apply halo attention.

        Args:
            x: tensor of shape (B, C, H, W); H and W must be divisible by
                ``block_size``.

        Returns:
            Tensor viewed as (B, dim_out_v, H // block_stride, W // block_stride)
            and then passed through ``self.pool`` (average pool or identity).
        """
        B, C, H, W = x.shape
        _assert(H % self.block_size == 0, "")
        _assert(W % self.block_size == 0, "")
        num_h_blocks = H // self.block_size
        num_w_blocks = W // self.block_size
        num_blocks = num_h_blocks * num_w_blocks

        q = self.q(x)
        # unfold
        q = q.reshape(
            -1,
            self.dim_head_qk,
            num_h_blocks,
            self.block_size_ds,
            num_w_blocks,
            self.block_size_ds,
        ).permute(0, 1, 3, 5, 2, 4)
        # B, num_heads * dim_head * block_size ** 2, num_blocks
        q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(
            1, 3
        )
        # B * num_heads, num_blocks, block_size ** 2, dim_head

        kv = self.kv(x)
        # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not
        # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach.
        # FIXME figure out how to switch impl between this and conv2d if XLA being used.
        kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
        kv = (
            kv.unfold(2, self.win_size, self.block_size)
            .unfold(3, self.win_size, self.block_size)
            .reshape(
                B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1
            )
            .permute(0, 2, 3, 1)
        )
        k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
        # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v

        # the position embedding is either scaled together with q @ k or added unscaled
        if self.scale_pos_embed:
            attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale
        else:
            attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q)
        # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(
            1, 3
        )  # B * num_heads, dim_head_v, block_size ** 2, num_blocks
        # fold
        out = out.reshape(
            -1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks
        )
        out = (
            out.permute(0, 3, 1, 4, 2)
            .contiguous()
            .view(B, self.dim_out_v, H // self.block_stride, W // self.block_stride)
        )
        # B, dim_out, H // block_stride, W // block_stride
        out = self.pool(out)
        return out
# From PyTorch internals
def _ntuple(n):
    """Build a converter that repeats a single value ``n`` times.

    The returned function passes any non-string iterable through unchanged
    (the very same object is returned) and turns anything else, including
    strings, into an ``n``-tuple of that value.
    """

    def parse(x):
        treat_as_scalar = isinstance(x, str) or not isinstance(
            x, collections.abc.Iterable
        )
        if treat_as_scalar:
            return (x,) * n
        return x

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def extend_tuple(x, n):
    """Return ``x`` as a tuple of length ``n``.

    A tuple or list is converted to a tuple; any other value becomes a
    one-element tuple. Shorter inputs are padded by repeating their last value,
    inputs that are already long enough are cut with ``[:n]``.
    """
    values = tuple(x) if isinstance(x, (tuple, list)) else (x,)
    missing = n - len(values)
    if missing > 0:
        # pad with the last value
        return values + (values[-1],) * missing
    return values[:n]
""" def __init__( self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, act_layer="leaky_relu", act_param=0.01, drop_layer=None, ): super(InplaceAbn, self).__init__() self.num_features = num_features self.affine = affine self.eps = eps self.momentum = momentum if apply_act: if isinstance(act_layer, str): assert act_layer in ("leaky_relu", "elu", "identity", "") self.act_name = act_layer if act_layer else "identity" else: # convert act layer passed as type to string if act_layer == nn.ELU: self.act_name = "elu" elif act_layer == nn.LeakyReLU: self.act_name = "leaky_relu" elif act_layer is None or act_layer == nn.Identity: self.act_name = "identity" else: assert False, f"Invalid act layer {act_layer.__name__} for IABN" else: self.act_name = "identity" self.act_param = act_param if self.affine: self.weight = nn.Parameter(torch.ones(num_features)) self.bias = nn.Parameter(torch.zeros(num_features)) else: self.register_parameter("weight", None) self.register_parameter("bias", None) self.register_buffer("running_mean", torch.zeros(num_features)) self.register_buffer("running_var", torch.ones(num_features)) self.reset_parameters() def reset_parameters(self): nn.init.constant_(self.running_mean, 0) nn.init.constant_(self.running_var, 1) if self.affine: nn.init.constant_(self.weight, 1) nn.init.constant_(self.bias, 0) def forward(self, x): output = inplace_abn( x, self.weight, self.bias, self.running_mean, self.running_var, self.training, self.momentum, self.eps, self.act_name, self.act_param, ) if isinstance(output, tuple): output = output[0] return output ================================================ FILE: RVT/models/layers/maxvit/layers/lambda_layer.py ================================================ """ Lambda Layer Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` - https://arxiv.org/abs/2102.08602 @misc{2102.08602, Author = {Irwan Bello}, Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention}, Year 
def rel_pos_indices(size):
    """Build lookup indices into a 2D relative-position table.

    Args:
        size: feature map size, an int or an (H, W) pair (normalised via ``to_2tuple``).

    Returns:
        Integer tensor of shape (2, H * W, H * W). Entry ``[c, i, j]`` is the offset of
        position ``j`` from position ``i`` along axis ``c``, shifted by ``size[c] - 1``
        so that every value is a valid non-negative index.
    """
    size = to_2tuple(size)
    pos = torch.stack(
        torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))
    ).flatten(1)
    rel_pos = pos[:, None, :] - pos[:, :, None]
    # shift offsets from [-(s - 1), s - 1] to [0, 2 * s - 2]
    rel_pos[0] += size[0] - 1
    rel_pos[1] += size[1] - 1
    return rel_pos  # 2, H * W, H * W


class LambdaLayer(nn.Module):
    """Lambda Layer

    Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
        - https://arxiv.org/abs/2102.08602

    NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.

    The internal dimensions of the lambda module are controlled via the interaction of several arguments.
      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
      * the query (q) and key (k) dimension are determined by
        * dim_head = (dim_out * qk_ratio // num_heads) if dim_head is None
        * q = num_heads * dim_head, k = dim_head
      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not set

    Args:
        dim (int): input dimension to the module
        dim_out (int): output dimension of the module, same as dim if not set
        feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W
        stride (int): output stride of the module, avg pool used if stride == 2
        num_heads (int): parallel attention heads.
        dim_head (int): dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set
        r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9)
        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
        qkv_bias (bool): add bias to q, k, and v projections
    """

    def __init__(
        self,
        dim,
        dim_out=None,
        feat_size=None,
        stride=1,
        num_heads=4,
        dim_head=16,
        r=9,
        qk_ratio=1.0,
        qkv_bias=False,
    ):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0, " should be divided by num_heads"
        self.dim_qk = (
            dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        )
        self.num_heads = num_heads
        self.dim_v = dim_out // num_heads

        # single 1x1 projection producing q (num_heads * K), k (K) and v (V) channels
        self.qkv = nn.Conv2d(
            dim,
            num_heads * self.dim_qk + self.dim_qk + self.dim_v,
            kernel_size=1,
            bias=qkv_bias,
        )
        self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)
        self.norm_v = nn.BatchNorm2d(self.dim_v)

        if r is not None:
            # local lambda convolution for pos
            self.conv_lambda = nn.Conv3d(
                1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0)
            )
            self.pos_emb = None
            self.rel_pos_indices = None
        else:
            # relative pos embedding
            assert feat_size is not None
            feat_size = to_2tuple(feat_size)
            rel_size = [2 * s - 1 for s in feat_size]
            self.conv_lambda = None
            self.pos_emb = nn.Parameter(
                torch.zeros(rel_size[0], rel_size[1], self.dim_qk)
            )
            self.register_buffer(
                "rel_pos_indices", rel_pos_indices(feat_size), persistent=False
            )

        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the qkv projection and whichever positional component is in use."""
        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)  # fan-in
        if self.conv_lambda is not None:
            trunc_normal_(self.conv_lambda.weight, std=self.dim_qk**-0.5)
        if self.pos_emb is not None:
            trunc_normal_(self.pos_emb, std=0.02)

    def forward(self, x):
        """Apply the lambda layer.

        Args:
            x: input tensor of shape (B, dim, H, W).

        Returns:
            Tensor of shape (B, num_heads * dim_v, H, W), i.e. (B, dim_out, H, W),
            passed through the average pool when the layer was built with stride == 2.
        """
        B, _, H, W = x.shape
        M = H * W
        qkv = self.qkv(x)
        q, k, v = torch.split(
            qkv, [self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1
        )
        q = (
            self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2)
        )  # B, num_heads, M, K
        v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2)  # B, M, V
        k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1)  # B, K, M

        content_lam = k @ v  # B, K, V
        content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V

        if self.pos_emb is None:
            position_lam = self.conv_lambda(
                v.reshape(B, 1, H, W, self.dim_v)
            )  # B, K, H, W, V
            position_lam = position_lam.reshape(
                B, 1, self.dim_qk, H * W, self.dim_v
            ).transpose(
                2, 3
            )  # B, 1, M, K, V
        else:
            # FIXME relative pos embedding path not fully verified
            pos_emb = self.pos_emb[
                self.rel_pos_indices[0], self.rel_pos_indices[1]
            ].expand(B, -1, -1, -1)
            position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(
                1
            )  # B, 1, M, K, V
        position_out = (q.unsqueeze(-2) @ position_lam).squeeze(
            -2
        )  # B, num_heads, M, V

        # Reshape to the module's own output width (num_heads * V == dim_out), not the
        # input channel count: the two only coincide when dim_out == dim.
        out = (
            (content_out + position_out)
            .transpose(-1, -2)
            .reshape(B, self.num_heads * self.dim_v, H, W)
        )  # B, dim_out (num_heads * V), H, W
        out = self.pool(out)
        return out
class MedianPool2d(nn.Module):
    """Median pooling over sliding windows (acts as a median filter when stride=1).

    Args:
        kernel_size: size of pooling kernel, int or 2-tuple
        stride: pool stride, int or 2-tuple
        padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
        same: override padding and enforce same padding, boolean
    """

    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super(MedianPool2d, self).__init__()
        self.k = to_2tuple(kernel_size)
        self.stride = to_2tuple(stride)
        self.padding = to_4tuple(padding)  # convert to l, r, t, b
        self.same = same

    def _padding(self, x):
        """Return the (left, right, top, bottom) padding to apply to ``x``."""
        if not self.same:
            return self.padding

        in_h, in_w = x.size()[2:]
        rem_h = in_h % self.stride[0]
        rem_w = in_w % self.stride[1]
        # when the size divides evenly the window overhang is kernel - stride,
        # otherwise it is kernel - remainder; never negative
        total_h = max(self.k[0] - (self.stride[0] if rem_h == 0 else rem_h), 0)
        total_w = max(self.k[1] - (self.stride[1] if rem_w == 0 else rem_w), 0)

        left = total_w // 2
        top = total_h // 2
        # any odd pixel goes to the right / bottom side
        return (left, total_w - left, top, total_h - top)

    def forward(self, x):
        """Reflect-pad ``x``, gather every kernel window and take its median."""
        padded = F.pad(x, self._padding(x), mode="reflect")
        windows = padded.unfold(2, self.k[0], self.stride[0])
        windows = windows.unfold(3, self.k[1], self.stride[1])
        # merge the two trailing window dims so the median runs over the whole patch
        flat = windows.contiguous().view(windows.size()[:4] + (-1,))
        return flat.median(dim=-1)[0]
Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) Hacked together by / Copyright 2020 Ross Wightman """ import torch from torch import nn as nn from .conv2d_same import create_conv2d_pad def _split_channels(num_chan, num_groups): split = [num_chan // num_groups for _ in range(num_groups)] split[0] += num_chan - sum(split) return split class MixedConv2d(nn.ModuleDict): """Mixed Grouped Convolution Based on MDConv and GroupedConv in MixNet impl: https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py """ def __init__( self, in_channels, out_channels, kernel_size=3, stride=1, padding="", dilation=1, depthwise=False, **kwargs ): super(MixedConv2d, self).__init__() kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] num_groups = len(kernel_size) in_splits = _split_channels(in_channels, num_groups) out_splits = _split_channels(out_channels, num_groups) self.in_channels = sum(in_splits) self.out_channels = sum(out_splits) for idx, (k, in_ch, out_ch) in enumerate( zip(kernel_size, in_splits, out_splits) ): conv_groups = in_ch if depthwise else 1 # use add_module to keep key space clean self.add_module( str(idx), create_conv2d_pad( in_ch, out_ch, k, stride=stride, padding=padding, dilation=dilation, groups=conv_groups, **kwargs ), ) self.splits = in_splits def forward(self, x): x_split = torch.split(x, self.splits, 1) x_out = [c(x_split[i]) for i, c in enumerate(self.values())] x = torch.cat(x_out, 1) return x ================================================ FILE: RVT/models/layers/maxvit/layers/ml_decoder.py ================================================ from typing import Optional import torch from torch import nn from torch import nn, Tensor from torch.nn.modules.transformer import _get_activation_fn def add_ml_decoder_head(model): if hasattr(model, "global_pool") and hasattr( model, "fc" ): # most CNN models, like Resnet50 model.global_pool = nn.Identity() del 
class TransformerDecoderLayerOptimal(nn.Module):
    """Decoder layer with cross-attention and a feed-forward block, no self-attention.

    The first residual step adds the dropped-out target to itself in place of a
    self-attention block; the target then attends to ``memory`` and goes through
    the feed-forward network. Each step is followed by a LayerNorm.

    Args:
        d_model: embedding width of target and memory.
        nhead: number of heads for the cross-attention.
        dim_feedforward: hidden width of the feed-forward block.
        dropout: dropout probability used by every dropout in the layer.
        activation: activation name, resolved via ``_get_activation_fn``.
        layer_norm_eps: epsilon for the three LayerNorm modules.
    """

    def __init__(
        self,
        d_model,
        nhead=8,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        layer_norm_eps=1e-5,
    ) -> None:
        super(TransformerDecoderLayerOptimal, self).__init__()
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.activation = _get_activation_fn(activation)

    @property
    def self_attn(self):
        """Alias of ``multihead_attn``.

        NOTE(review): newer ``nn.TransformerDecoder`` versions appear to read
        ``layers[0].self_attn.batch_first``; this layer has no self-attention, so the
        cross-attention module is exposed under that name. A property adds no
        parameters and leaves the state dict unchanged.
        """
        return self.multihead_attn

    def __setstate__(self, state):
        # states pickled without an activation entry fall back to ReLU
        if "activation" not in state:
            state["activation"] = torch.nn.functional.relu
        super(TransformerDecoderLayerOptimal, self).__setstate__(state)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> Tensor:
        """Run the layer.

        Args:
            tgt: query sequence fed to the cross-attention.
            memory: sequence used as key and value of the cross-attention.
            tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask:
                accepted for signature compatibility with ``nn.TransformerDecoder``;
                none of them is used.
            tgt_is_causal, memory_is_causal: accepted so that ``nn.TransformerDecoder``
                versions which forward these keywords can call the layer; ignored.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        tgt = tgt + self.dropout1(tgt)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt
TransformerDecoderLayerOptimal( d_model=decoder_embedding, dim_feedforward=dim_feedforward, dropout=decoder_dropout, ) self.decoder = nn.TransformerDecoder( layer_decode, num_layers=num_layers_decoder ) # non-learnable queries self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding) self.query_embed.requires_grad_(False) # group fully-connected self.num_classes = num_classes self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999) self.duplicate_pooling = torch.nn.Parameter( torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor) ) self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes)) torch.nn.init.xavier_normal_(self.duplicate_pooling) torch.nn.init.constant_(self.duplicate_pooling_bias, 0) self.group_fc = GroupFC(embed_len_decoder) def forward(self, x): if len(x.shape) == 4: # [bs,2048, 7,7] embedding_spatial = x.flatten(2).transpose(1, 2) else: # [bs, 197,468] embedding_spatial = x embedding_spatial_786 = self.embed_standart(embedding_spatial) embedding_spatial_786 = torch.nn.functional.relu( embedding_spatial_786, inplace=True ) bs = embedding_spatial_786.shape[0] query_embed = self.query_embed.weight # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1) tgt = query_embed.unsqueeze(1).expand( -1, bs, -1 ) # no allocation of memory with expand h = self.decoder( tgt, embedding_spatial_786.transpose(0, 1) ) # [embed_len_decoder, batch, 768] h = h.transpose(0, 1) out_extrap = torch.zeros( h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype, ) self.group_fc(h, self.duplicate_pooling, out_extrap) h_out = out_extrap.flatten(1)[:, : self.num_classes] h_out += self.duplicate_pooling_bias logits = h_out return logits ================================================ FILE: RVT/models/layers/maxvit/layers/mlp.py ================================================ """ MLP module w/ dropout and configurable activation layer Hacked together by / Copyright 2020 Ross Wightman """ from torch 
class GluMlp(nn.Module):
    """MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202

    ``fc1`` produces ``hidden_features`` channels that are split in two halves: the
    first half is the value, the second half is passed through ``act_layer`` and
    used as the gate. ``fc2`` therefore consumes ``hidden_features // 2`` channels.

    Args:
        in_features: input width.
        hidden_features: width of ``fc1``'s output (must be even); defaults to ``in_features``.
        out_features: output width; defaults to ``in_features``.
        act_layer: gate activation class, instantiated without arguments.
        bias: bool or pair of bools enabling the bias of ``fc1`` / ``fc2``.
        drop: dropout probability or pair of probabilities for ``drop1`` / ``drop2``.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.Sigmoid,
        bias=True,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert hidden_features % 2 == 0
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        """Override init of fc1 w/ gate portion set to weight near zero, bias=1."""
        # Take the split point from the weight: the bias is None when fc1 was
        # built with bias disabled, and the weight rows carry the same size.
        fc1_mid = self.fc1.weight.shape[0] // 2
        if self.fc1.bias is not None:
            nn.init.ones_(self.fc1.bias[fc1_mid:])
        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)

    def forward(self, x):
        """Return ``fc2(value * act(gate))`` with dropout after each linear stage."""
        x = self.fc1(x)
        x, gates = x.chunk(2, dim=-1)
        x = x * self.act(gates)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
class ConvMlp(nn.Module):
    """Two-layer MLP built from 1x1 convolutions, so spatial dims are preserved.

    Stages in order: ``fc1`` -> ``norm`` -> ``act`` -> ``drop`` -> ``fc2``.
    ``hidden_features`` and ``out_features`` default to ``in_features``; ``bias`` may be
    a single flag or a pair, one per convolution.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.ReLU,
        norm_layer=None,
        bias=True,
        drop=0.0,
    ):
        super().__init__()
        out_width = out_features or in_features
        hidden_width = hidden_features or in_features
        bias_flags = to_2tuple(bias)

        self.fc1 = nn.Conv2d(
            in_features, hidden_width, kernel_size=1, bias=bias_flags[0]
        )
        # normalisation is optional; fall back to a no-op
        if norm_layer:
            self.norm = norm_layer(hidden_width)
        else:
            self.norm = nn.Identity()
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Conv2d(
            hidden_width, out_width, kernel_size=1, bias=bias_flags[1]
        )

    def forward(self, x):
        """Apply the five stages in sequence and return the result."""
        hidden = self.fc1(x)
        hidden = self.act(self.norm(hidden))
        return self.fc2(self.drop(hidden))
class NonLocalAttn(nn.Module):
    """Spatial NL block for image classification.

    This was adapted from https://github.com/BA-Transform/BAT-Image-Classification
    Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net.

    Three 1x1 convs (``t``, ``p``, ``g``) project the input to ``rd_channels``; the
    softmax of ``t^T p`` weights ``g``, and ``z`` projects back to ``in_channels``
    before a BatchNorm and the residual add. The BatchNorm affine parameters are
    zero-initialised by ``reset_parameters``.
    """

    def __init__(
        self,
        in_channels,
        use_scale=True,
        rd_ratio=1 / 8,
        rd_channels=None,
        rd_divisor=8,
        **kwargs
    ):
        super(NonLocalAttn, self).__init__()
        if rd_channels is None:
            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
        if use_scale:
            self.scale = in_channels**-0.5
        else:
            self.scale = 1.0
        proj_args = dict(kernel_size=1, stride=1, bias=True)
        self.t = nn.Conv2d(in_channels, rd_channels, **proj_args)
        self.p = nn.Conv2d(in_channels, rd_channels, **proj_args)
        self.g = nn.Conv2d(in_channels, rd_channels, **proj_args)
        self.z = nn.Conv2d(rd_channels, in_channels, **proj_args)
        self.norm = nn.BatchNorm2d(in_channels)
        self.reset_parameters()

    def forward(self, x):
        """Return ``norm(z(attention(x))) + x``; output has the shape of ``x``."""
        theta = self.t(x)
        phi = self.p(x)
        value = self.g(x)

        B, C, H, W = theta.size()
        tokens = H * W
        theta = theta.view(B, C, tokens).permute(0, 2, 1)  # B, HW, C
        phi = phi.view(B, C, tokens)  # B, C, HW
        value = value.view(B, C, tokens).permute(0, 2, 1)  # B, HW, C

        weights = F.softmax(torch.bmm(theta, phi) * self.scale, dim=2)  # B, HW, HW
        mixed = torch.bmm(weights, value).permute(0, 2, 1).reshape(B, C, H, W)

        return self.norm(self.z(mixed)) + x

    def reset_parameters(self):
        """Kaiming-init every conv (zero bias) and zero the norm affine parameters."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                # a Conv2d owns two parameters exactly when it has a bias
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 0)
                nn.init.constant_(module.bias, 0)
BilinearAttnTransform(nn.Module): def __init__( self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, ): super(BilinearAttnTransform, self).__init__() self.conv1 = ConvNormAct( in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer ) self.conv_p = nn.Conv2d( groups, block_size * block_size * groups, kernel_size=(block_size, 1) ) self.conv_q = nn.Conv2d( groups, block_size * block_size * groups, kernel_size=(1, block_size) ) self.conv2 = ConvNormAct( in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer ) self.block_size = block_size self.groups = groups self.in_channels = in_channels def resize_mat(self, x, t: int): B, C, block_size, block_size1 = x.shape _assert(block_size == block_size1, "") if t <= 1: return x x = x.view(B * C, -1, 1, 1) x = x * torch.eye(t, t, dtype=x.dtype, device=x.device) x = x.view(B * C, block_size, block_size, t, t) x = torch.cat(torch.split(x, 1, dim=1), dim=3) x = torch.cat(torch.split(x, 1, dim=2), dim=4) x = x.view(B, C, block_size * t, block_size * t) return x def forward(self, x): _assert(x.shape[-1] % self.block_size == 0, "") _assert(x.shape[-2] % self.block_size == 0, "") B, C, H, W = x.shape out = self.conv1(x) rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) cp = F.adaptive_max_pool2d(out, (1, self.block_size)) p = ( self.conv_p(rp) .view(B, self.groups, self.block_size, self.block_size) .sigmoid() ) q = ( self.conv_q(cp) .view(B, self.groups, self.block_size, self.block_size) .sigmoid() ) p = p / p.sum(dim=3, keepdim=True) q = q / q.sum(dim=2, keepdim=True) p = ( p.view(B, self.groups, 1, self.block_size, self.block_size) .expand( x.size(0), self.groups, C // self.groups, self.block_size, self.block_size, ) .contiguous() ) p = p.view(B, C, self.block_size, self.block_size) q = ( q.view(B, self.groups, 1, self.block_size, self.block_size) .expand( x.size(0), self.groups, C // self.groups, self.block_size, self.block_size, ) .contiguous() ) q = q.view(B, C, 
class BatNonLocalAttn(nn.Module):
    """BAT
    Adapted from: https://github.com/BA-Transform/BAT-Image-Classification

    Pipeline: ``conv1`` (in_channels -> rd_channels) -> ``ba`` (BilinearAttnTransform)
    -> ``conv2`` (rd_channels -> in_channels) -> ``dropout``, added to the input.
    """

    def __init__(
        self,
        in_channels,
        block_size=7,
        groups=2,
        rd_ratio=0.25,
        rd_channels=None,
        rd_divisor=8,
        drop_rate=0.2,
        act_layer=nn.ReLU,
        norm_layer=nn.BatchNorm2d,
        **_
    ):
        super().__init__()
        if rd_channels is None:
            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
        # every sub-block shares the same activation / normalisation choice
        layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
        self.conv1 = ConvNormAct(in_channels, rd_channels, 1, **layer_args)
        self.ba = BilinearAttnTransform(rd_channels, block_size, groups, **layer_args)
        self.conv2 = ConvNormAct(rd_channels, in_channels, 1, **layer_args)
        self.dropout = nn.Dropout2d(p=drop_rate)

    def forward(self, x):
        """Return the attention branch output plus the residual input."""
        branch = self.ba(self.conv1(x))
        branch = self.dropout(self.conv2(branch))
        return branch + x
class GroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` taking the channel count first, with an optional fast-norm path."""

    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
        # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
        super().__init__(
            num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine
        )
        # can't script unless we have these flags here (no globals)
        self.fast_norm = is_fast_norm()

    def forward(self, x):
        """Normalise ``x`` with the regular or the fast group-norm function."""
        if not self.fast_norm:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
elementwise_affine=affine) self._fast_norm = ( is_fast_norm() ) # can't script unless we have these flags here (no globals) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.permute(0, 2, 3, 1) if self._fast_norm: x = fast_layer_norm( x, self.normalized_shape, self.weight, self.bias, self.eps ) else: x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.permute(0, 3, 1, 2) return x def _is_contiguous(tensor: torch.Tensor) -> bool: # jit is oh so lovely :/ if torch.jit.is_scripting(): return tensor.is_contiguous() else: return tensor.is_contiguous(memory_format=torch.contiguous_format) @torch.jit.script def _layer_norm_cf( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float ): s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) x = (x - u) * torch.rsqrt(s + eps) x = x * weight[:, None, None] + bias[:, None, None] return x def _layer_norm_cf_sqm( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float ): u = x.mean(dim=1, keepdim=True) s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0) x = (x - u) * torch.rsqrt(s + eps) x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1) return x class LayerNormExp2d(nn.LayerNorm): """LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). Experimental implementation w/ manual norm for tensors non-contiguous tensors. This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last layout. However, benefits are not always clear and can perform worse on other GPUs. 
class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation

    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.

    Args:
        num_features, eps, momentum, affine, track_running_stats: forwarded unchanged
            to ``nn.BatchNorm2d.__init__``.
        apply_act: when False the activation is replaced by ``nn.Identity``.
        act_layer: activation spec resolved through ``get_act_layer``
            (presumably a name or a module class -- confirm against create_act.py).
        inplace: when True the activation is constructed with ``inplace=True``.
        drop_layer: optional zero-argument callable; its result is applied between
            the normalisation and the activation.
        device, dtype: factory kwargs forwarded to ``nn.BatchNorm2d`` when supported.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        apply_act=True,
        act_layer=nn.ReLU,
        inplace=True,
        drop_layer=None,
        device=None,
        dtype=None,
    ):
        try:
            factory_kwargs = {"device": device, "dtype": dtype}
            super(BatchNormAct2d, self).__init__(
                num_features,
                eps=eps,
                momentum=momentum,
                affine=affine,
                track_running_stats=track_running_stats,
                **factory_kwargs,
            )
        except TypeError:
            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
            super(BatchNormAct2d, self).__init__(
                num_features,
                eps=eps,
                momentum=momentum,
                affine=affine,
                track_running_stats=track_running_stats,
            )
        # optional drop layer, a no-op when none is given
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
            # only pass inplace when requested so the activation's default applies otherwise
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Batch-normalise a 4D input, then apply ``self.drop`` and ``self.act``."""
        # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
        _assert(x.ndim == 4, f"expected 4D input (got {x.ndim}D input)")

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        x = F.batch_norm(
            x,
            # If buffers are not to be tracked, ensure that they won't be updated
            (
                self.running_mean
                if not self.training or self.track_running_stats
                else None
            ),
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
        # post-norm stages: drop first, activation last
        x = self.drop(x)
        x = self.act(x)
        return x
def convert_sync_batchnorm(module, process_group=None):
    """Recursively replace BatchNorm layers with synchronized variants.

    ``BatchNormAct2d`` layers become ``SyncBatchNormAct`` (their ``act`` and
    ``drop`` sub-modules are carried over); every other ``_BatchNorm`` becomes
    a plain ``torch.nn.SyncBatchNorm``. Affine parameters, running statistics,
    the training flag and an optional ``qconfig`` are copied from the source
    layer. Modules that are not batch-norm layers are kept, with their
    children converted in place.

    Args:
        module: root module to convert.
        process_group: forwarded to the synchronized layers.

    Returns:
        The converted module (``module`` itself when it is not a BatchNorm).
    """
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        if isinstance(module, BatchNormAct2d):
            # convert timm norm + act layer
            module_output = SyncBatchNormAct(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group=process_group,
            )
            # set act and drop attr from the original module
            module_output.act = module.act
            module_output.drop = module.drop
        else:
            # convert standard BatchNorm layers
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        # A freshly built module is always in training mode; keep the mode of
        # the layer being replaced so converting an eval-mode model does not
        # silently switch its norm layers back to batch statistics.
        module_output.training = module.training
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
    del module
    return module_output
class LayerNormAct(nn.LayerNorm):
    """LayerNorm followed by an optional drop layer and activation."""

    def __init__(
        self,
        normalization_shape: Union[int, List[int], torch.Size],
        eps=1e-5,
        affine=True,
        apply_act=True,
        act_layer=nn.ReLU,
        inplace=True,
        drop_layer=None,
    ):
        super().__init__(normalization_shape, eps=eps, elementwise_affine=affine)
        # Sub-modules are registered in the order drop -> act.
        self.drop = nn.Identity() if drop_layer is None else drop_layer()
        act_cls = get_act_layer(act_layer)  # string -> nn.Module
        if act_cls is None or not apply_act:
            self.act = nn.Identity()
        elif inplace:
            self.act = act_cls(inplace=True)
        else:
            self.act = act_cls()
        # Decided once at construction time.
        self._fast_norm = is_fast_norm()

    def forward(self, x):
        """Normalize ``x``, then apply the drop layer and the activation."""
        if self._fast_norm:
            normed = fast_layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        else:
            normed = F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        dropped = self.drop(normed)
        return self.act(dropped)
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    """Return the symmetric padding for a convolution with the given args."""
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2


def get_same_padding(x: int, k: int, s: int, d: int):
    """Return the total TensorFlow-style 'SAME' padding along one dimension.

    ``x`` is the input size, ``k`` the kernel size, ``s`` the stride and ``d``
    the dilation. The result is never negative.
    """
    out_size = math.ceil(x / s)
    needed = (out_size - 1) * s + (k - 1) * d + 1 - x
    return max(needed, 0)


def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    """Tell whether 'SAME' padding can be computed without knowing the input size."""
    if stride != 1:
        return False
    return (dilation * (kernel_size - 1)) % 2 == 0


def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
    """Pad the last two dims of ``x`` with 'SAME' padding for the given conv args.

    Any odd amount of padding puts the extra element on the bottom/right side.
    ``x`` is returned unchanged when no padding is required.
    """
    in_h, in_w = x.size()[-2:]
    pad_h = get_same_padding(in_h, k[0], s[0], d[0])
    pad_w = get_same_padding(in_w, k[1], s[1], d[1])
    if pad_h <= 0 and pad_w <= 0:
        return x
    left, top = pad_w // 2, pad_h // 2
    return F.pad(x, [left, pad_w - left, top, pad_h - top], value=value)


def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    """Resolve a padding spec into ``(padding, is_dynamic)``.

    Non-string values pass through untouched. For strings (case-insensitive):
    ``"same"`` yields static symmetric padding when that is possible and
    otherwise ``(0, True)`` to request dynamic padding; ``"valid"`` yields 0;
    any other string falls back to PyTorch-style symmetric padding.
    """
    if not isinstance(padding, str):
        return padding, False
    mode = padding.lower()
    if mode == "valid":
        # 'VALID' padding, same as padding=0
        return 0, False
    if mode == "same" and not is_static_pad(kernel_size, **kwargs):
        # dynamic 'SAME' padding, has runtime/GPU memory overhead
        return 0, True
    # static 'SAME' case and the default symmetric padding share one formula
    return get_padding(kernel_size, **kwargs), False
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        grid_h = self.img_size[0] // self.patch_size[0]
        grid_w = self.img_size[1] // self.patch_size[1]
        self.grid_size = (grid_h, grid_w)
        self.num_patches = grid_h * grid_w
        self.flatten = flatten

        # A strided convolution whose kernel equals the patch size projects
        # every non-overlapping patch to ``embed_dim`` channels.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x):
        """Project ``x`` to patch embeddings; the input must match ``img_size``."""
        _, _, height, width = x.shape
        expected_h, expected_w = self.img_size[0], self.img_size[1]
        _assert(
            height == expected_h,
            f"Input image height ({height}) doesn't match model ({self.img_size[0]}).",
        )
        _assert(
            width == expected_w,
            f"Input image width ({width}) doesn't match model ({self.img_size[1]}).",
        )
        tokens = self.proj(x)
        if self.flatten:
            tokens = tokens.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return self.norm(tokens)
class AvgPool2dSame(nn.AvgPool2d):
    """Tensorflow like 'SAME' wrapper for 2D average pooling

    The input is padded with ``pad_same`` before pooling, so the ``padding``
    constructor argument is accepted for signature compatibility only and the
    underlying pool always uses ``(0, 0)``.
    """

    def __init__(
        self,
        kernel_size: int,
        stride=None,
        padding=0,
        ceil_mode=False,
        count_include_pad=True,
    ):
        kernel_size = to_2tuple(kernel_size)
        # Follow nn.AvgPool2d: a missing stride means "same as kernel_size".
        # Expanding ``None`` into a tuple first would hide it from the parent
        # class (presumably giving (None, None), which pad_same cannot divide
        # by), so resolve the default before the tuple conversion.
        stride = kernel_size if stride is None else to_2tuple(stride)
        super(AvgPool2dSame, self).__init__(
            kernel_size, stride, (0, 0), ceil_mode, count_include_pad
        )

    def forward(self, x):
        """Pad ``x`` for 'SAME' output size, then average-pool it."""
        x = pad_same(x, self.kernel_size, self.stride)
        return F.avg_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.ceil_mode,
            self.count_include_pad,
        )


def max_pool2d_same(
    x,
    kernel_size: List[int],
    stride: List[int],
    padding: List[int] = (0, 0),
    dilation: List[int] = (1, 1),
    ceil_mode: bool = False,
):
    """Max-pool ``x`` after 'SAME' padding with ``-inf``.

    ``padding`` is ignored; the padding amount is derived from ``kernel_size``
    and ``stride`` by ``pad_same``.
    """
    x = pad_same(x, kernel_size, stride, value=-float("inf"))
    return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)


class MaxPool2dSame(nn.MaxPool2d):
    """Tensorflow like 'SAME' wrapper for 2D max pooling

    The input is padded with ``-inf`` by ``pad_same`` before pooling; the
    ``padding`` constructor argument is accepted but not used.
    """

    def __init__(
        self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False
    ):
        kernel_size = to_2tuple(kernel_size)
        # Same default as nn.MaxPool2d: no stride given -> stride = kernel_size.
        stride = kernel_size if stride is None else to_2tuple(stride)
        dilation = to_2tuple(dilation)
        super(MaxPool2dSame, self).__init__(
            kernel_size, stride, (0, 0), dilation, ceil_mode
        )

    def forward(self, x):
        """Pad ``x`` for 'SAME' output size, then max-pool it."""
        x = pad_same(x, self.kernel_size, self.stride, value=-float("inf"))
        return F.max_pool2d(
            x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode
        )
def pixel_freq_bands(
    num_bands: int,
    max_freq: float = 224.0,
    linear_bands: bool = True,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
):
    """Return ``num_bands`` frequency bands scaled by pi.

    With ``linear_bands`` the bands are spaced linearly from 1 to
    ``max_freq / 2``; otherwise they are powers of two with exponents spaced
    linearly from 0 to ``log2(max_freq) - 1``.
    """
    if linear_bands:
        bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
    else:
        top_exponent = math.log(max_freq, 2) - 1
        exponents = torch.linspace(
            0, top_exponent, num_bands, dtype=dtype, device=device
        )
        bands = 2**exponents
    return bands * torch.pi


def inv_freq_bands(
    num_bands: int,
    temperature: float = 100000.0,
    step: int = 2,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Return inverse frequencies ``1 / temperature ** (i / num_bands)``.

    ``i`` runs over ``range(0, num_bands, step)``.
    """
    indices = torch.arange(0, num_bands, step, dtype=dtype, device=device)
    exponents = indices / num_bands
    return 1.0 / (temperature**exponents)
for s in feat_shape] ) ) .flatten(1) .transpose(0, 1) ) pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0) # FIXME add support for unflattened spatial dim? stack_dim = ( 2 if interleave_sin_cos else 1 ) # stack sin, cos, sin, cos instead of sin sin cos cos pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1) return pos_emb def build_fourier_pos_embed( feat_shape: List[int], bands: Optional[torch.Tensor] = None, num_bands: int = 64, max_res: int = 224, linear_bands: bool = False, include_grid: bool = False, concat_out: bool = True, in_pixels: bool = True, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ) -> List[torch.Tensor]: if bands is None: if in_pixels: bands = pixel_freq_bands( num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device, ) else: bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device) else: if device is None: device = bands.device if dtype is None: dtype = bands.dtype if in_pixels: grid = torch.stack( torch.meshgrid( [ torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=dtype) for s in feat_shape ] ), dim=-1, ) else: grid = torch.stack( torch.meshgrid( [torch.arange(s, device=device, dtype=dtype) for s in feat_shape] ), dim=-1, ) grid = grid.unsqueeze(-1) pos = grid * bands pos_sin, pos_cos = pos.sin(), pos.cos() out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos) # FIXME torchscript doesn't like multiple return types, probably need to always cat? 
class FourierEmbed(nn.Module):
    """Concatenate Fourier position features to a channel-first feature map.

    Args:
        max_res: maximum resolution, used as ``max_freq`` of the pixel bands.
        num_bands: number of frequency bands.
        concat_grid: also include the raw coordinate grid in the embedding.
        keep_spatial: if True the output stays channel-first with the
            embedding concatenated on the channel dim; otherwise the spatial
            dims are flattened and the output is (batch, tokens, channels).
    """

    def __init__(
        self,
        max_res: int = 224,
        num_bands: int = 64,
        concat_grid=True,
        keep_spatial=False,
    ):
        super().__init__()
        self.max_res = max_res
        self.num_bands = num_bands
        self.concat_grid = concat_grid
        self.keep_spatial = keep_spatial
        # pixel_freq_bands takes (num_bands, max_freq). The previous call
        # passed the two values positionally in the opposite order, producing
        # ``max_res`` bands capped at ``num_bands / 2``. Keywords make the
        # mapping explicit.
        self.register_buffer(
            "bands",
            pixel_freq_bands(num_bands=num_bands, max_freq=max_res),
            persistent=False,
        )

    def forward(self, x):
        """Return ``x`` with the position embedding for its spatial shape appended."""
        B, C = x.shape[:2]
        feat_shape = x.shape[2:]
        emb = build_fourier_pos_embed(
            feat_shape,
            self.bands,
            include_grid=self.concat_grid,
            dtype=x.dtype,
            device=x.device,
        )
        # Merge the per-coordinate axis and the band axis into one feature dim.
        emb = emb.transpose(-1, -2).flatten(len(feat_shape))
        batch_expand = (B,) + (-1,) * (x.ndim - 1)

        # FIXME support nD
        if self.keep_spatial:
            x = torch.cat(
                [x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1
            )
        else:
            x = torch.cat(
                [x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1
            )
            x = x.reshape(B, feat_shape.numel(), -1)
        return x
-1).repeat_interleave(2, -1) cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1) return sin_emb, cos_emb class RotaryEmbedding(nn.Module): """Rotary position embedding NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not been well tested, and will likely change. It will be moved to its own file. The following impl/resources were referenced for this impl: * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py * https://blog.eleuther.ai/rotary-embeddings/ """ def __init__(self, dim, max_res=224, linear_bands: bool = False): super().__init__() self.dim = dim self.register_buffer( "bands", pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False, ) def get_embed(self, shape: List[int]): return build_rotary_pos_embed(shape, self.bands) def forward(self, x): # assuming channel-first tensor where spatial dim are >= 2 sin_emb, cos_emb = self.get_embed(x.shape[2:]) return apply_rot_embed(x, sin_emb, cos_emb) ================================================ FILE: RVT/models/layers/maxvit/layers/selective_kernel.py ================================================ """ Selective Kernel Convolution/Attention Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) Hacked together by / Copyright 2020 Ross Wightman """ import torch from torch import nn as nn from .conv_bn_act import ConvNormActAa from .helpers import make_divisible from .trace_utils import _assert def _kernel_valid(k): if isinstance(k, (list, tuple)): for ki in k: return _kernel_valid(ki) assert k >= 3 and k % 2 class SelectiveKernelAttn(nn.Module): def __init__( self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, ): """Selective Kernel Attention Module Selective Kernel attention mechanism factored out into its own module. 
""" super(SelectiveKernelAttn, self).__init__() self.num_paths = num_paths self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) self.bn = norm_layer(attn_channels) self.act = act_layer(inplace=True) self.fc_select = nn.Conv2d( attn_channels, channels * num_paths, kernel_size=1, bias=False ) def forward(self, x): _assert(x.shape[1] == self.num_paths, "") x = x.sum(1).mean((2, 3), keepdim=True) x = self.fc_reduce(x) x = self.bn(x) x = self.act(x) x = self.fc_select(x) B, C, H, W = x.shape x = x.view(B, self.num_paths, C // self.num_paths, H, W) x = torch.softmax(x, dim=1) return x class SelectiveKernel(nn.Module): def __init__( self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, rd_ratio=1.0 / 16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_layer=None, ): """Selective Kernel Convolution Module As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. Largest change is the input split, which divides the input channels across each convolution path, this can be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps the parameter count from ballooning when the convolutions themselves don't have groups, but still provides a noteworthy increase in performance over similar param count models without this attention layer. 
-Ross W Args: in_channels (int): module input (feature) channel count out_channels (int): module output (feature) channel count kernel_size (int, list): kernel size for each convolution branch stride (int): stride for convolutions dilation (int): dilation for module as a whole, impacts dilation of each branch groups (int): number of groups for each branch rd_ratio (int, float): reduction factor for attention features keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, can be viewed as grouping by path, output expands to module out_channels count act_layer (nn.Module): activation layer to use norm_layer (nn.Module): batchnorm/norm layer to use aa_layer (nn.Module): anti-aliasing module drop_layer (nn.Module): spatial drop module in convs (drop block, etc) """ super(SelectiveKernel, self).__init__() out_channels = out_channels or in_channels kernel_size = kernel_size or [ 3, 5, ] # default to one 3x3 and one 5x5 branch. 
5x5 -> 3x3 + dilation _kernel_valid(kernel_size) if not isinstance(kernel_size, list): kernel_size = [kernel_size] * 2 if keep_3x3: dilation = [dilation * (k - 1) // 2 for k in kernel_size] kernel_size = [3] * len(kernel_size) else: dilation = [dilation] * len(kernel_size) self.num_paths = len(kernel_size) self.in_channels = in_channels self.out_channels = out_channels self.split_input = split_input if self.split_input: assert in_channels % self.num_paths == 0 in_channels = in_channels // self.num_paths groups = min(out_channels, groups) conv_kwargs = dict( stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_layer=drop_layer, ) self.paths = nn.ModuleList( [ ConvNormActAa( in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs ) for k, d in zip(kernel_size, dilation) ] ) attn_channels = rd_channels or make_divisible( out_channels * rd_ratio, divisor=rd_divisor ) self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) def forward(self, x): if self.split_input: x_split = torch.split(x, self.in_channels // self.num_paths, 1) x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] else: x_paths = [op(x) for op in self.paths] x = torch.stack(x_paths, dim=1) x_attn = self.attn(x) x = x * x_attn x = torch.sum(x, dim=1) return x ================================================ FILE: RVT/models/layers/maxvit/layers/separable_conv.py ================================================ """ Depthwise Separable Conv Modules Basic DWS convs. Other variations of DWS exist with batch norm or activations between the DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 
class SeparableConvNormAct(nn.Module):
    """Separable Conv w/ trailing Norm and Activation"""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        dilation=1,
        padding="",
        bias=False,
        channel_multiplier=1.0,
        pw_kernel_size=1,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.ReLU,
        apply_act=True,
        drop_layer=None,
    ):
        super().__init__()
        # Channel count between the depthwise and the pointwise convolution.
        mid_channels = int(in_channels * channel_multiplier)

        self.conv_dw = create_conv2d(
            in_channels,
            mid_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            depthwise=True,
        )
        self.conv_pw = create_conv2d(
            mid_channels,
            out_channels,
            pw_kernel_size,
            padding=padding,
            bias=bias,
        )

        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        # Only forward drop_layer when one was supplied.
        norm_kwargs = {} if drop_layer is None else dict(drop_layer=drop_layer)
        self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)

    @property
    def in_channels(self):
        """Input channel count of the depthwise convolution."""
        return self.conv_dw.in_channels

    @property
    def out_channels(self):
        """Output channel count of the pointwise convolution."""
        return self.conv_pw.out_channels

    def forward(self, x):
        """Apply depthwise conv, pointwise conv and the norm+act layer in turn."""
        depthwise_out = self.conv_dw(x)
        pointwise_out = self.conv_pw(depthwise_out)
        return self.bn(pointwise_out)


SeparableConvBnAct = SeparableConvNormAct
class SpaceToDepth(nn.Module):
    """Rearrange ``bs x bs`` spatial blocks into channels.

    An input of shape (N, C, H, W) becomes (N, C * bs^2, H // bs, W // bs).
    H and W must be divisible by the block size.

    Args:
        block_size: side length of the spatial blocks (positive int).
            Defaults to 4.

    Raises:
        ValueError: if ``block_size`` is not a positive integer.
    """

    def __init__(self, block_size=4):
        super().__init__()
        # The reshape below works for any block size, so only reject values
        # that cannot describe a block at all.
        if not isinstance(block_size, int) or block_size < 1:
            raise ValueError(
                f"block_size must be a positive integer, got {block_size!r}"
            )
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(
            N, C, H // self.bs, self.bs, W // self.bs, self.bs
        )  # (N, C, H//bs, bs, W//bs, bs)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        x = x.view(
            N, C * (self.bs**2), H // self.bs, W // self.bs
        )  # (N, C*bs^2, H//bs, W//bs)
        return x
class RadixSoftmax(nn.Module):
    """Softmax across the radix groups, or a sigmoid when ``radix`` is 1."""

    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        """Normalize ``x`` across radix splits; the result is flattened to 2-D
        when ``radix > 1`` and keeps the input shape otherwise."""
        if self.radix <= 1:
            return torch.sigmoid(x)
        n = x.size(0)
        # (n, cardinality, radix, rest) -> (n, radix, cardinality, rest)
        grouped = x.view(n, self.cardinality, self.radix, -1).transpose(1, 2)
        weights = F.softmax(grouped, dim=1)
        return weights.reshape(n, -1)
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
    """BatchNorm2d that normalizes equal batch splits with separate BN layers.

    During training the first split goes through this layer itself and each
    remaining split through one of the ``aux_bn`` layers. In eval mode the
    whole batch goes through this layer only.

    Args:
        num_splits: number of batch splits (at least 2); ``num_splits - 1``
            auxiliary BatchNorm2d layers are created.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        num_splits=2,
    ):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        assert (
            num_splits > 1
        ), "Should have at least one aux BN layer (num_splits at least 2)"
        self.num_splits = num_splits
        self.aux_bn = nn.ModuleList(
            [
                nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
                for _ in range(num_splits - 1)
            ]
        )

    def forward(self, input: torch.Tensor):
        if self.training:  # aux BN only relevant while training
            split_size = input.shape[0] // self.num_splits
            assert (
                input.shape[0] == split_size * self.num_splits
            ), "batch size must be evenly divisible by num_splits"
            split_input = input.split(split_size)
            x = [super().forward(split_input[0])]
            for i, a in enumerate(self.aux_bn):
                x.append(a(split_input[i + 1]))
            return torch.cat(x, dim=0)
        else:
            return super().forward(input)


def convert_splitbn_model(module, num_splits=2):
    """
    Recursively traverse module and its children to replace all instances of
    ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`.
    Instance-norm layers are returned unchanged.

    The main BN of each replacement shares the running statistics of the layer
    it replaces; every auxiliary BN receives its own copies of the statistics
    and affine parameters. Layers built with ``track_running_stats=False``
    have no statistics to copy and are converted as well.

    Args:
        module (torch.nn.Module): input module
        num_splits: number of separate batchnorm layers to split input across
    Example::
        >>> # model is an instance of torch.nn.Module
        >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
    """
    mod = module
    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
        return module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        mod = SplitBatchNorm2d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
            num_splits=num_splits,
        )
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        mod.num_batches_tracked = module.num_batches_tracked
        if module.affine:
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()
        # The statistic buffers are None when running stats are not tracked;
        # only clone the ones that exist.
        stat_names = ("running_mean", "running_var", "num_batches_tracked")
        for aux in mod.aux_bn:
            for stat_name in stat_names:
                stat = getattr(module, stat_name)
                if stat is not None:
                    setattr(aux, stat_name, stat.clone())
            if module.affine:
                aux.weight.data = module.weight.data.clone().detach()
                aux.bias.data = module.bias.data.clone().detach()
    for name, child in module.named_children():
        mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
    del module
    return mod
class SEModule(nn.Module):
    """Squeeze-and-Excitation channel attention.

    The input is pooled over its two trailing spatial dims (mean, optionally
    blended 50/50 with the spatial max), pushed through a 1x1-conv bottleneck
    (fc1 -> norm -> act -> fc2) and a gate activation; the resulting
    per-channel weights rescale the input.

    Additions over the original SE-Net formulation:
        * the bottleneck width can be given directly via ``rd_channels``
        * otherwise it is derived from ``channels * rd_ratio`` (default 1/16)
          and rounded with ``make_divisible`` using ``rd_divisor`` (default 8)
        * global max pooling can be added to the squeeze aggregation
        * activation, normalization and gate layers are configurable
    """

    def __init__(
        self,
        channels,
        rd_ratio=1.0 / 16,
        rd_channels=None,
        rd_divisor=8,
        add_maxpool=False,
        bias=True,
        act_layer=nn.ReLU,
        norm_layer=None,
        gate_layer="sigmoid",
    ):
        super().__init__()
        self.add_maxpool = add_maxpool

        # Fall back to the ratio-derived width when none was given explicitly.
        reduced = rd_channels
        if not reduced:
            reduced = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.0)

        self.fc1 = nn.Conv2d(channels, reduced, kernel_size=1, bias=bias)
        if norm_layer:
            self.bn = norm_layer(reduced)
        else:
            self.bn = nn.Identity()
        self.act = create_act_layer(act_layer, inplace=True)
        self.fc2 = nn.Conv2d(reduced, channels, kernel_size=1, bias=bias)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        spatial_dims = (2, 3)
        pooled = x.mean(spatial_dims, keepdim=True)
        if self.add_maxpool:
            # experimental codepath, may remove or change
            pooled = 0.5 * pooled + 0.5 * x.amax(spatial_dims, keepdim=True)
        hidden = self.act(self.bn(self.fc1(pooled)))
        channel_gate = self.gate(self.fc2(hidden))
        return x * channel_gate
class StdConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.

    On every forward pass the kernel is standardized per output channel
    (zero mean, unit variance over all remaining weight dims) before the
    convolution is applied; the stored parameter itself is left untouched.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2
    """

    def __init__(
        self,
        in_channel,
        out_channels,
        kernel_size,
        stride=1,
        padding=None,
        dilation=1,
        groups=1,
        bias=False,
        eps=1e-6,
    ):
        # No explicit padding: derive it from kernel size / stride / dilation.
        resolved_padding = (
            get_padding(kernel_size, stride, dilation) if padding is None else padding
        )
        conv_kwargs = dict(
            stride=stride,
            padding=resolved_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        super().__init__(in_channel, out_channels, kernel_size, **conv_kwargs)
        self.eps = eps

    def forward(self, x):
        # Treat each output channel's weights as one batch-norm "channel" so
        # F.batch_norm (in training mode, no running stats) standardizes them.
        flat_weight = self.weight.reshape(1, self.out_channels, -1)
        standardized = F.batch_norm(
            flat_weight,
            None,
            None,
            training=True,
            momentum=0.0,
            eps=self.eps,
        )
        kernel = standardized.reshape_as(self.weight)
        return F.conv2d(
            x, kernel, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
class ScaledStdConv2dSame(nn.Conv2d):
    """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support

    The kernel is standardized per output channel on every forward pass and
    multiplied by a learnable per-channel ``gain`` times a fixed ``scale`` of
    ``gamma / sqrt(fan_in)``. When the resolved padding is dynamic, the input
    is padded with ``pad_same`` before the convolution.

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
        - https://arxiv.org/abs/2101.08692

    NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding="SAME",
        dilation=1,
        groups=1,
        bias=True,
        gamma=1.0,
        eps=1e-6,
        gain_init=1.0,
    ):
        static_padding, needs_dynamic_pad = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation
        )
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=static_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        gain_shape = (self.out_channels, 1, 1, 1)
        self.gain = nn.Parameter(torch.full(gain_shape, gain_init))
        # gamma * 1 / sqrt(fan-in), fan-in = weights feeding one output channel
        fan_in = self.weight[0].numel()
        self.scale = gamma * fan_in ** -0.5
        self.same_pad = needs_dynamic_pad
        self.eps = eps

    def forward(self, x):
        if self.same_pad:
            x = pad_same(x, self.kernel_size, self.stride, self.dilation)

        # Standardize per output channel; the affine weight of the batch-norm
        # call applies gain * scale in the same pass.
        flat_weight = self.weight.reshape(1, self.out_channels, -1)
        channel_gain = (self.gain * self.scale).view(-1)
        standardized = F.batch_norm(
            flat_weight,
            None,
            None,
            weight=channel_gain,
            training=True,
            momentum=0.0,
            eps=self.eps,
        )
        kernel = standardized.reshape_as(self.weight)
        return F.conv2d(
            x, kernel, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
def apply_test_time_pool(model, config, use_test_size=False):
    """Wrap ``model`` in a ``TestTimePoolHead`` when the target input is larger.

    The pretrained reference size is ``default_cfg["test_input_size"]`` when
    ``use_test_size`` is set and that key exists, otherwise
    ``default_cfg["input_size"]``. The wrapper is only applied when both of
    the last two entries of ``config["input_size"]`` are strictly greater
    than the reference.

    Returns:
        ``(model, applied)``: the (possibly wrapped) model and a bool telling
        whether test-time pooling was enabled. A model without a non-empty
        ``default_cfg`` is returned unchanged with ``False``.
    """
    default_cfg = getattr(model, "default_cfg", None)
    if not default_cfg:
        return model, False

    prefer_test_size = use_test_size and "test_input_size" in default_cfg
    size_key = "test_input_size" if prefer_test_size else "input_size"
    pretrained_size = default_cfg[size_key]
    target_size = config["input_size"]

    larger_in_both_dims = (
        target_size[-1] > pretrained_size[-1]
        and target_size[-2] > pretrained_size[-2]
    )
    if not larger_in_both_dims:
        return model, False

    _logger.info(
        "Target input size %s > pretrained default %s, using test time pooling"
        % (str(target_size[-2:]), str(pretrained_size[-2:]))
    )
    wrapped = TestTimePoolHead(model, original_pool=default_cfg["pool_size"])
    return wrapped, True
def _trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with truncated-normal samples and return it.

    Samples follow N(mean, std^2) restricted to ``[a, b]``. A warning is
    emitted when ``mean`` lies more than two standard deviations outside the
    interval. The caller is expected to disable autograd around this call.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    sqrt_two = math.sqrt(2.0)

    def norm_cdf(x):
        # Standard normal cumulative distribution function.
        return (1.0 + math.erf(x / sqrt_two)) / 2.0

    mean_below_range = mean < a - 2 * std
    mean_above_range = mean > b + 2 * std
    if mean_below_range or mean_above_range:
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # CDF values of the two cut-offs for the standardized variable.
    cdf_low = norm_cdf((a - mean) / std)
    cdf_high = norm_cdf((b - mean) / std)

    # Inverse-CDF sampling, all in place:
    #   uniform on [2*cdf_low - 1, 2*cdf_high - 1]
    #   -> erfinv gives a truncated standard normal (up to a sqrt(2) factor)
    #   -> scale by std * sqrt(2) and shift by mean
    #   -> clamp so the result is guaranteed to be within [a, b]
    return (
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)
        .erfinv_()
        .mul_(std * sqrt_two)
        .add_(mean)
        .clamp_(min=a, max=b)
    )
NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are applied while sampling the normal with mean/std applied, therefore a, b args should be adjusted to match the range of mean, std args. Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution std: the standard deviation of the normal distribution a: the minimum cutoff value b: the maximum cutoff value Examples: >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w) """ with torch.no_grad(): return _trunc_normal_(tensor, mean, std, a, b) def trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): # type: (Tensor, float, float, float, float) -> Tensor r"""Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \leq \text{mean} \leq b`. NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 and the result is subsquently scaled and shifted by the mean and std args. 
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialize ``tensor`` in place with variance ``scale / denom``.

    Args:
        tensor: tensor to fill; its fan-in / fan-out are computed with
            ``torch.nn.init._calculate_fan_in_and_fan_out``.
        scale: numerator of the target variance.
        mode: which fan to divide by - ``"fan_in"``, ``"fan_out"`` or
            ``"fan_avg"`` (mean of the two).
        distribution: ``"truncated_normal"`` (sampled via ``trunc_normal_tf_``),
            ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: if ``mode`` or ``distribution`` is not one of the values
            listed above.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Previously an unknown mode fell through and crashed below with an
        # UnboundLocalError on `denom`; fail with a clear message instead.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # A uniform on [-bound, bound] has variance bound^2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
class GLU(nn.Module):
    """Gated linear unit: ``value * act(gate)`` from a single projection.

    The input is projected to ``2 * dim_out`` channels, split in half along
    the channel dim, and the first half is multiplied by the activated
    second half. The projection is an ``nn.Linear`` for channel-last inputs
    and a 1x1 ``nn.Conv2d`` for channel-first inputs.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        channel_last: bool,
        act_layer: Type[nn.Module],
        bias: bool = True,
    ):
        super().__init__()
        # The gate activation selects the GLU variant:
        # - ReGLU: Relu
        # - SwiGLU: Swish/SiLU
        # - GeGLU: GELU
        # - GLU: Sigmoid
        # These seem to be the most promising ones.
        # Extensive quantitative eval in table 1: https://arxiv.org/abs/2102.11972
        # Section 2 for explanation and implementation details: https://arxiv.org/abs/2002.05202
        # NOTE: Pytorch has a native GLU implementation: https://pytorch.org/docs/stable/generated/torch.nn.GLU.html?highlight=glu#torch.nn.GLU
        doubled_width = dim_out * 2
        if channel_last:
            self.proj = nn.Linear(dim_in, doubled_width, bias=bias)
            self.channel_dim = -1
        else:
            self.proj = nn.Conv2d(
                dim_in, doubled_width, kernel_size=1, stride=1, bias=bias
            )
            self.channel_dim = 1
        self.act_layer = act_layer()

    def forward(self, x: torch.Tensor):
        projected = self.proj(x)
        value, gate = torch.tensor_split(projected, 2, dim=self.channel_dim)
        return value * self.act_layer(gate)
class ConvDownsampling_Cf2Cl(DownsampleBase):
    """Strided-convolution downsampling followed by a LayerNorm.

    Input in NCHW [channel-first] format.
    Output in NHWC [channel-last] format.

    Config keys read from ``downsample_cfg`` (both default to ``True``):
        * ``norm_affine``: passed to the LayerNorm as ``affine``
        * ``overlap``: use an overlapping kernel of size ``2 * factor - 1``
          with "same"-style padding instead of a non-overlapping
          ``factor``-sized kernel without padding
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        downsample_factor: int,
        downsample_cfg: DictConfig,
    ):
        super().__init__()
        assert isinstance(dim_out, int)
        assert isinstance(dim_in, int)
        assert downsample_factor in (2, 4, 8)

        norm_affine = downsample_cfg.get("norm_affine", True)
        overlap = downsample_cfg.get("overlap", True)

        # Non-overlapping patches by default; widen the kernel when neighbouring
        # patches should share pixels.
        kernel_size = downsample_factor
        padding = 0
        if overlap:
            kernel_size = (downsample_factor - 1) * 2 + 1
            padding = kernel_size // 2

        self.conv = nn.Conv2d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            padding=padding,
            stride=downsample_factor,
            bias=False,
        )
        self.norm = LayerNorm(num_channels=dim_out, eps=1e-5, affine=norm_affine)

    def forward(self, x: torch.Tensor):
        # conv in NCHW, then switch to NHWC before normalizing
        channels_last = nChw_2_nhwC(self.conv(x))
        return self.norm(channels_last)

    @staticmethod
    def output_is_normed():
        # forward() always ends with self.norm
        return True
def window_partition(x, window_size: Tuple[int, int]):
    """Cut a channels-last feature map into non-overlapping windows.

    Args:
        x: tensor of shape (B, H, W, C).
        window_size: (window height, window width); H and W must be
            divisible by the respective entry.

    Returns:
        Tensor of shape (B * H/wh * W/ww, wh, ww, C). Windows are ordered by
        batch, then window row, then window column.
    """
    batch, height, width, channels = x.shape
    win_h = window_size[0]
    win_w = window_size[1]
    _assert(
        height % win_h == 0,
        f"height ({height}) must be divisible by window ({win_h})",
    )
    _assert(
        width % win_w == 0,
        f"width ({width}) must be divisible by window ({win_w})",
    )

    # (B, H, W, C) -> (B, H/wh, wh, W/ww, ww, C)
    tiled = x.view(batch, height // win_h, win_h, width // win_w, win_w, channels)
    # Bring the two window-index dims next to batch: (B, H/wh, W/ww, wh, ww, C)
    grouped = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grouped.view(-1, win_h, win_w, channels)
class SelfAttentionCl(nn.Module):
    """Channels-last multi-head self-attention (B, ..., C)

    All dims between the first (batch) and last (channel) dim are flattened
    into one token axis for the attention and restored afterwards.

    Args:
        dim: channel dimension C of the input; must be a multiple of
            ``dim_head``.
        dim_head: channels per attention head.
        bias: whether the qkv and output projections use a bias.
    """

    def __init__(self, dim: int, dim_head: int = 32, bias: bool = True):
        super().__init__()
        # Same requirement TorchMHSAWrapperCl enforces. Without it the
        # `view` in forward() can still succeed for some token counts and
        # silently regroup channels across tokens instead of failing.
        assert dim % dim_head == 0, (
            f"dim ({dim}) must be divisible by dim_head ({dim_head})"
        )
        self.num_heads = dim // dim_head
        self.dim_head = dim_head
        self.scale = dim_head**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)

    def forward(self, x: torch.Tensor):
        B = x.shape[0]
        restore_shape = x.shape[:-1]

        # (B, ..., 3C) -> (B, tokens, heads, 3 * dim_head) -> (B, heads, tokens, 3 * dim_head),
        # then split the last dim into q, k, v of dim_head channels each.
        q, k, v = (
            self.qkv(x)
            .view(B, -1, self.num_heads, self.dim_head * 3)
            .transpose(1, 2)
            .chunk(3, dim=3)
        )

        # Scaled dot-product attention over the token axis.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # (B, heads, tokens, dim_head) -> (B, tokens, heads, dim_head) -> (B, ..., C)
        x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,))
        x = self.proj(x)
        return x
class DWSConvLSTM2d(nn.Module):
    """LSTM with (depthwise-separable) Conv option in NCHW [channel-first] format.

    A depthwise conv (``conv3x3_dws``, or identity when ``dws_conv`` is off)
    is applied either to the previous hidden state only or to the
    concatenated input + hidden state; a 1x1 conv (``conv1x1``) then produces
    the three gates and the cell candidate in one pass.
    """

    def __init__(
        self,
        dim: int,
        dws_conv: bool = True,
        dws_conv_only_hidden: bool = True,
        dws_conv_kernel_size: int = 3,
        cell_update_dropout: float = 0.0,
    ):
        super().__init__()
        assert isinstance(dws_conv, bool)
        assert isinstance(dws_conv_only_hidden, bool)
        self.dim = dim

        input_plus_hidden = dim * 2
        all_gates = dim * 4  # forget, input, output gates + cell candidate

        if dws_conv:
            # Depthwise: one filter per channel (groups == channels).
            dws_channels = dim if dws_conv_only_hidden else input_plus_hidden
            self.conv3x3_dws = nn.Conv2d(
                in_channels=dws_channels,
                out_channels=dws_channels,
                kernel_size=dws_conv_kernel_size,
                padding=dws_conv_kernel_size // 2,
                groups=dws_channels,
            )
        else:
            self.conv3x3_dws = nn.Identity()

        self.conv1x1 = nn.Conv2d(
            in_channels=input_plus_hidden, out_channels=all_gates, kernel_size=1
        )
        self.conv_only_hidden = dws_conv_only_hidden
        self.cell_update_dropout = nn.Dropout(p=cell_update_dropout)

    def forward(
        self,
        x: th.Tensor,
        h_and_c_previous: Optional[Tuple[th.Tensor, th.Tensor]] = None,
    ) -> Tuple[th.Tensor, th.Tensor]:
        """
        :param x: (N C H W)
        :param h_and_c_previous: ((N C H W), (N C H W)); zero states are used when None
        :return: ((N C H W), (N C H W)) new hidden state and new cell state
        """
        if h_and_c_previous is None:
            # No state yet: start from zeros shaped like the input.
            h_and_c_previous = (th.zeros_like(x), th.zeros_like(x))
        prev_hidden, prev_cell = h_and_c_previous

        if self.conv_only_hidden:
            stacked = th.cat((x, self.conv3x3_dws(prev_hidden)), dim=1)
        else:
            stacked = self.conv3x3_dws(th.cat((x, prev_hidden), dim=1))
        mixed = self.conv1x1(stacked)

        # First 3 * dim channels are the gates, the remaining dim the candidate.
        gate_logits, candidate = th.tensor_split(mixed, [self.dim * 3], dim=1)
        assert gate_logits.shape[1] == candidate.shape[1] * 3

        gate_values = th.sigmoid(gate_logits)
        forget_gate, input_gate, output_gate = th.tensor_split(gate_values, 3, dim=1)
        assert forget_gate.shape == input_gate.shape == output_gate.shape

        candidate = self.cell_update_dropout(th.tanh(candidate))

        new_cell = forget_gate * prev_cell + input_gate * candidate
        new_hidden = output_gate * th.tanh(new_cell)
        return new_hidden, new_cell
def safe_map(f, *args):
    """Apply ``f`` element-wise over several iterables of equal length.

    Every iterable is materialised first, so generators are accepted. Unlike
    the builtin ``map`` (which silently stops at the shortest input), a length
    mismatch triggers an ``AssertionError`` before ``f`` is called at all.

    Returns:
        A list with one result of ``f`` per position.
    """
    args = [list(arg) for arg in args]
    expected_len = len(args[0])
    for sequence in args[1:]:
        assert (
            len(sequence) == expected_len
        ), f"length mismatch: {list(map(len, args))}"
    return [f(*items) for items in zip(*args)]
def associative_scan(operator: Callable, elems, axis: int = 0, reverse: bool = False):
    """PyTorch port of ``jax.lax.associative_scan``.

    Computes all inclusive prefix combinations of ``elems`` along ``axis``
    under the associative binary ``operator``.

    Args:
        operator: associative callable ``operator(a, b)`` taking and returning
            pytrees with the same structure as ``elems``.
        elems: a tensor or pytree of tensors; all leaves must have the same
            length along ``axis``.
        axis: dimension to scan over. Negative values count from the last
            dimension and are normalised here, because the internal
            interleaving step does index arithmetic (``axis + 1``, padding
            positions) that is only valid for a non-negative axis.
        reverse: scan from the end of ``axis`` towards its start.

    Returns:
        A pytree with the structure of ``elems`` holding the scanned values.

    Raises:
        IndexError: if ``axis`` is outside the dimensions of the first leaf.
        ValueError: if the leaves disagree on their length along ``axis``.
    """
    elems_flat, tree = tree_flatten(elems)

    # The original assertion used `or` and therefore could never fail.
    ndim = elems_flat[0].ndim
    if not -ndim <= axis < ndim:
        raise IndexError(
            f"Axis should be within bounds of input (got axis={axis} "
            f"for a {ndim}-dimensional input)"
        )
    if axis < 0:
        axis += ndim

    if reverse:
        elems_flat = [torch.flip(elem, [axis]) for elem in elems_flat]

    num_elems = int(elems_flat[0].shape[axis])
    if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
        raise ValueError(
            "Array inputs to associative_scan must have the same "
            "first dimension. (saw: {})".format([elem.shape for elem in elems_flat])
        )

    scans = _scan(tree, operator, elems_flat, axis)

    if reverse:
        scans = [torch.flip(scanned, [axis]) for scanned in scans]

    return tree_unflatten(scans, tree)
def test_interleave():
    """Smoke-test ``_interleave`` on equal-length and off-by-one inputs.

    Checks the output shape and that the first five positions along the
    interleaved axis alternate between the first and the second argument.
    """

    def assert_alternates(merged, evens, odds, axis):
        # merged[0], merged[2], ... must come from `evens`,
        # merged[1], merged[3], ... from `odds`.
        for position in range(5):
            origin = odds if position % 2 else evens
            expected = origin.select(axis, position // 2)
            assert (merged.select(axis, position) == expected).all()

    x, y = torch.randn(1, 32, 32), torch.randn(1, 32, 32)
    for axis, expected_shape in ((1, (1, 64, 32)), (2, (1, 32, 64))):
        v = _interleave(x, y, axis=axis)
        assert v.shape == expected_shape
        assert_alternates(v, x, y, axis)

    x, y = torch.randn(1, 24, 24), torch.randn(1, 24, 24)
    assert _interleave(x, y, axis=1).shape == (1, 48, 24)
    assert _interleave(x, y, axis=2).shape == (1, 24, 48)

    # Second argument one element shorter than the first.
    x, y = torch.randn(3, 96), torch.randn(2, 96)
    v = _interleave(x, y, axis=0)
    assert v.shape == (5, 96)
    assert_alternates(v, x, y, 0)

    print("Interleave working as expected!")
def complex_as_float_dtype(dtype):
    """Map a complex torch dtype to a real floating-point dtype.

    ``complex64`` maps to ``float32`` and ``complex128`` to ``float64``.
    Any other value is returned unchanged.
    """
    # NOTE: complex32 is not well supported yet, so it also maps to float32.
    if dtype == torch.complex32 or dtype == torch.complex64:
        return torch.float32
    if dtype == torch.complex128:
        return torch.float64
    return dtype
""" real_dtype = torch.tensor(0, dtype=dtype).real.dtype t = ( 1 - torch.exp(torch.tensor(-(upper**2), dtype=dtype, device=device)) ) * torch.rand(shape, dtype=real_dtype, device=device).type(dtype) r = torch.sqrt(-torch.log(1 - t)) theta = ( 2 * torch.pi * torch.rand(shape, dtype=real_dtype, device=device).type(dtype) ) return r * torch.exp(1j * theta) def _truncated_normal(lower, upper, shape, dtype=torch.float): if shape is None: shape = torch.broadcast_shapes(np.shape(lower), np.shape(upper)) sqrt2 = math.sqrt(2) a = math.erf(lower / sqrt2) b = math.erf(upper / sqrt2) # a>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_normal() >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.46700746, 0.8414632 , 0.8518669 ], [-0.61677957, -0.67402434, 0.09683388]], dtype=float32) .. _Lecun normal initializer: https://arxiv.org/abs/1706.02515 """ return variance_scaling( 1.0, "fan_in", "truncated_normal", fan_in_axes=fan_in_axes, dtype=dtype ) def test_variance_scaling(): v = variance_scaling(1.0, distribution="normal") n_f32 = v((1, 10000), dtype=torch.float) assert np.isclose( n_f32.std().item(), 1.0, rtol=0.015, atol=0.015 ), f"std for f32 normal[0,1.0] is {n_f32.std()} != 1.0" del n_f32 # NOTE: this is used in the original as `complex_normal` (but with stddev=0.5**0.5) n_c64 = v((1, 10000), dtype=torch.complex64) assert np.isclose( n_c64.std().item(), 1.0, rtol=0.015, atol=0.015 ), f"std for c64 normal[0,1.0] is {n_c64.std()} != 1.0" del n_c64 # Truncated normal v = variance_scaling(1.0, distribution="truncated_normal") tn_f32 = v((1, 10000), dtype=torch.float) assert np.isclose( tn_f32.std().item(), 0.775, rtol=0.015, atol=0.015 ), f"std for f32 truncated normal[0,1.0] is {tn_f32.std()} != 0.775" del tn_f32 # NOTE: this is used in the original (both trunc_standard_normal & lecun_normal it seems), # seems that they are using the fan-in/out feature to 'hide the low variance initialization' # The actual std 
def make_HiPPO(N):
    """Create a HiPPO-LegS matrix.

    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py

    The matrix is the negated lower triangle of the outer product of
    ``sqrt(2n + 1)`` with itself, with ``n`` subtracted on the diagonal
    before negation.

    Args:
        N (int32): state size
    Returns:
        N x N HiPPO LegS matrix
    """
    index = np.arange(N)
    scale = np.sqrt(2 * index + 1)
    lower = np.tril(np.outer(scale, scale))
    return -(lower - np.diag(index))
def make_DPLR_HiPPO(N):
    """Build the components of the DPLR representation of HiPPO-LegS.

    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
    Note, we will only use the diagonal part.

    The rank-1 term ``P P^T`` is added to the HiPPO matrix; the real part of
    the eigenvalues is the mean of the resulting diagonal, the imaginary part
    comes from a Hermitian eigendecomposition of ``-1j * S``. ``P`` and ``B``
    are then projected with the conjugate-transposed eigenvectors.

    Args:
        N: state size
    Returns:
        eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B,
        eigenvectors V, HiPPO B pre-conjugation
    """
    hippo, low_rank, B_orig = make_NPLR_HiPPO(N)

    S = hippo + np.outer(low_rank, low_rank)

    diagonal = np.diagonal(S)
    Lambda_real = np.mean(diagonal) * np.ones_like(diagonal)

    # Diagonalize S to V Lambda V^*
    Lambda_imag, V = np.linalg.eigh(S * -1j)
    V_herm = V.conj().T

    Lambda = Lambda_real + 1j * Lambda_imag
    return Lambda, V_herm @ low_rank, V_herm @ B_orig, V, B_orig
def init_log_steps(H, dt_min, dt_max):
    """Initialize an array of learnable timescale parameters.

    One log-step is drawn per entry with ``log_step_initializer``; every draw
    uses shape ``(1,)`` and the draws are collected into a single tensor.

    Args:
        H: number of timescale parameters to sample
        dt_min: lower bound passed to ``log_step_initializer``
        dt_max: upper bound passed to ``log_step_initializer``
    Returns:
        initialized array of timescales (float32): (H,)
    """
    sample_log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)
    return torch.tensor([sample_log_step(shape=(1,)) for _ in range(H)])
def trunc_standard_normal(shape):
    """Sample C row by row from the ``lecun_normal`` initializer.

    Each of the ``H`` rows is drawn independently with shape ``(1, P, 2)`` so
    that the fan-in seen by the initializer is that of a single row.

    Args:
        shape (tuple): desired shape, of length 3, (H,P,_); the last entry is
            ignored and the trailing dimension of the result is always 2
    Returns:
        sampled C matrix (float32) of shape (H,P,2) (for complex parameterization)
    """
    H, P, _ = shape
    sample_row = lecun_normal()
    rows = [sample_row(shape=(1, P, 2)) for _ in range(H)]
    # torch.tensor() cannot build a tensor from a list of multi-element
    # tensors; stack them to (H, 1, P, 2) and drop the singleton axis.
    return torch.stack(rows)[:, 0]
def init_columnwise_VinvB(init_fun, Vinv):
    """Build an initializer for B_tilde = V^{-1} B with a column-wise sampled B.

    Counterpart of ``init_VinvB`` for column-wise initializers: only the first
    two entries of the requested shape are forwarded to ``init_fun``. Kept for
    consistency with certain random seeds until experiments are rerun.

    Args:
        init_fun: callable ``init_fun(shape, dtype)`` returning the sampled B
        Vinv: (complex64) the inverse eigenvectors used for initialization
    Returns:
        ``init(shape, dtype)`` returning B_tilde with real and imaginary parts
        stacked in a trailing dimension of size 2
    """

    def init(shape, dtype):
        B = init_fun(shape[:2], dtype)
        # Vinv is complex while B is sampled as a real tensor; matmul requires
        # matching dtypes, so cast B first (as init_VinvB does).
        VinvB = Vinv @ B.type(Vinv.dtype)
        return torch.stack((VinvB.real, VinvB.imag), dim=-1)

    return init
def apply_ssm(
    Lambda_bars: torch.Tensor,
    B_bars,
    C_tilde,
    D,
    input_sequence,
    prev_state,
    bidir: bool = False,
):
    """Evaluate the diagonal linear SSM over a whole sequence with a parallel scan.

    Args:
        Lambda_bars: discretized diagonal state matrix with real/imaginary
            parts packed in a trailing dimension of size 2 (unpacked with
            ``as_complex``). After unpacking it is either 1-D (shared by all
            steps) or has one entry per step along the first dimension.
        B_bars: discretized input matrix, packed the same way. 3-D after
            unpacking means one matrix per timestep.
        C_tilde: output matrix, packed the same way.
        D: feed-through term, multiplied element-wise with every input step.
        input_sequence: inputs, iterated along the first dimension
            (presumably (L, H) - TODO confirm against callers).
        prev_state: state carried over from the previous call; see the
            review note in the body.
        bidir: additionally run the scan in reverse and concatenate both state
            sequences along the last dimension.

    Returns:
        Tuple ``(outputs, xs[-1])``: the real part of ``C_tilde @ x`` plus the
        feed-through for every step, and the last entry of the state sequence.
    """
    # Unpack the (..., 2) real views into complex tensors.
    B_bars = as_complex(B_bars)
    C_tilde = as_complex(C_tilde)
    Lambda_bars = as_complex(Lambda_bars)

    cinput_sequence = input_sequence.type(
        Lambda_bars.dtype
    )  # Cast to correct complex type

    if B_bars.ndim == 3:
        # Dynamic timesteps (significantly more expensive): one B_bar per step
        Bu_elements = torch.vmap(lambda B_bar, u: B_bar @ u)(B_bars, cinput_sequence)
    else:
        # Static timesteps: the same B_bars is applied to every step
        Bu_elements = torch.vmap(lambda u: B_bars @ u)(cinput_sequence)

    if Lambda_bars.ndim == 1:  # Repeat for associative_scan
        Lambda_bars = Lambda_bars.tile(input_sequence.shape[0], 1)

    # NOTE(review): binary_operator combines (A_i, b_i) and (A_j, b_j) into
    # (A_j * A_i, A_j * b_i + b_j), and only the second component of the scan
    # result is kept below. The A of the first element never multiplies any b
    # term, so this update looks like it cannot influence `xs`. Presumably the
    # intent was x_0 = Lambda_bar_0 * prev_state + Bu_0 - confirm before
    # relying on state carry-over.
    Lambda_bars[0] = Lambda_bars[0] * prev_state

    _, xs = associative_scan(binary_operator, (Lambda_bars, Bu_elements))

    if bidir:
        _, xs2 = associative_scan(
            binary_operator, (Lambda_bars, Bu_elements), reverse=True
        )
        xs = torch.cat((xs, xs2), axis=-1)

    Du = torch.vmap(lambda u: D * u)(input_sequence)
    # The last element of xs is returned as the new hidden state; in the bidir
    # case it contains the concatenated forward/reverse entries.
    return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du, xs[-1]
def discretize_zoh(Lambda, B_tilde, Delta):
    """Discretize a diagonalized, continuous-time linear SSM
    using zero-order hold method.

    ``Lambda`` may be given either as a complex tensor or, like
    ``discretize_bilinear`` expects, as a real tensor whose trailing dimension
    of size 2 packs the real and imaginary parts. In the packed case the
    results are packed the same way, so both discretizations can be used
    interchangeably by callers that store ``Lambda`` as a real view.

    Args:
        Lambda: diagonal state matrix, complex (P,) or real view (P, 2)
        B_tilde (complex64): input matrix (P, H)
        Delta (float32): discretization step sizes (P,)
    Returns:
        discretized Lambda_bar, B_bar: complex (P,), (P, H) for complex
        ``Lambda``; real views (P, 2), (P, H, 2) for packed ``Lambda``
    """
    # A real tensor with at least two dims and a trailing size of 2 is treated
    # as a packed complex tensor; anything else keeps the original behaviour.
    packed = (
        not torch.is_complex(Lambda) and Lambda.ndim > 1 and Lambda.shape[-1] == 2
    )
    if packed:
        Lambda = torch.view_as_complex(Lambda)

    Lambda_bar = torch.exp(Lambda * Delta)
    B_bar = (1 / Lambda * (Lambda_bar - 1))[..., None] * B_tilde

    if packed:
        return torch.view_as_real(Lambda_bar), torch.view_as_real(B_bar)
    return Lambda_bar, B_bar
eigenvectors used for init (P,P) h (int32): Number of features of input seq p (int32): state size k (int32): rank of low-rank factorization (if used) bcInit (string): Specifies How B and C are initialized Options: [factorized: low-rank factorization, dense: dense matrix drawn from Lecun_normal] dense_columns: dense matrix where the columns of B and the rows of C are each drawn from Lecun_normal separately (i.e. different fan-in then the dense option). We found this initialization to be helpful for Pathx. discretization: (string) Specifies discretization method options: [zoh: zero-order hold method, bilinear: bilinear transform] liquid: (bool): use liquid_ssm from LiquidS4 dt_min: (float32): minimum value to draw timescale values from when initializing log_step dt_max: (float32): maximum value to draw timescale values from when initializing log_step step_scale: (float32): allows for changing the step size, e.g. after training on a different resolution for the speech commands benchmark """ super().__init__() self.Lambda = torch.nn.Parameter(torch.view_as_real(lambdaInit)) self.degree = degree self.liquid = liquid self.bcInit = bcInit self.bidir = bidir self.bandlimit = bandlimit cp = p if self.bidir: cp *= 2 match bcInit: case "complex_normal": self.C = torch.nn.Parameter( torch.normal(0, 0.5**0.5, (h, cp), dtype=torch.complex64) ) self.B = torch.nn.Parameter( init_VinvB(lecun_normal(), Vinv)((p, h), torch.float) ) case "dense_columns" | "dense": if bcInit == "dense_columns": B_eigen_init = init_columnwise_VinvB B_init = init_columnwise_B C_init = init_rowwise_C elif bcInit == "dense": B_eigen_init = init_VinvB B_init = C_init = lecun_normal() # TODO: make init_*VinvB all a the same interface self.B = torch.nn.Parameter( B_eigen_init(B_init, Vinv)((p, h), torch.float) ) if self.bidir: C = torch.cat( [init_CV(C_init, (h, p), V), init_CV(C_init, (h, p), V)], axis=-1, ) else: C = init_CV(C_init, (h, p), V) self.C = torch.nn.Parameter(torch.view_as_real(C)) case _: raise 
def initial_state(self, batch_size: Optional[int]):
    """Return an all-zero state for this SSM.

    The state width is taken from the second-to-last dimension of
    ``C_tilde`` as returned by ``get_BC_tilde``. The tensor is allocated on
    the device of ``C_tilde`` so it can be combined with the module's
    parameters without an explicit transfer.

    Args:
        batch_size: leading batch dimension, or ``None`` for an unbatched state
    Returns:
        zeros of shape ``(batch_size, state)`` or ``(state,)``
    """
    batch_shape = (batch_size,) if batch_size is not None else ()
    _, C_tilde = self.get_BC_tilde()
    return torch.zeros(
        (*batch_shape, C_tilde.shape[-2]), device=C_tilde.device
    )
# NOTE: can only be used as RNN OR S5(MIMO) (no mixing)
def forward(self, signal, prev_state, step_scale: float | torch.Tensor = 1.0):
    """Apply the SSM to a whole sequence.

    Args:
        signal: input sequence handed to ``apply_ssm`` / ``apply_ssm_liquid``
        prev_state: state carried over from the previous call
        step_scale: scalar (float or 0-dim tensor) applied to every step size,
            or a 1-D tensor with one scale per timestep
    Returns:
        whatever the selected apply function returns; for ``apply_ssm`` this
        is the tuple ``(outputs, last_state)``
    """
    B_tilde, C_tilde = self.get_BC_tilde()

    if not torch.is_tensor(step_scale) or step_scale.ndim == 0:
        step = step_scale * torch.exp(self.log_step)
    else:
        # TODO: This is very expensive due to individual steps being multiplied by B_tilde in self.discretize
        step = step_scale[:, None] * torch.exp(self.log_step)

    Lambda_bars, B_bars = self.discretize(self.Lambda, B_tilde, step)

    if self.degree != 1:
        # The discretized input matrix only exists from here on; the original
        # code raised it to `degree` before it was assigned.
        # discretize_bilinear returns a real view with a trailing dimension of
        # 2, so take the power in the complex domain and pack the result again.
        packed = not torch.is_complex(B_bars)
        B_bar = torch.view_as_complex(B_bars) if packed else B_bars
        assert (
            B_bar.shape[-2] == B_bar.shape[-1]
        ), "higher-order input operators must be full-rank"
        B_bar = B_bar**self.degree
        B_bars = torch.view_as_real(B_bar) if packed else B_bar

    # NOTE(review): apply_ssm_liquid's signature has no `prev_state` parameter,
    # so this positional call looks incompatible with it when `self.liquid`
    # is set - confirm before enabling the liquid variant.
    forward = apply_ssm_liquid if self.liquid else apply_ssm
    return forward(
        Lambda_bars, B_bars, C_tilde, self.D, signal, prev_state, bidir=self.bidir
    )
class GEGLU(torch.nn.Module):
    """GELU-gated linear unit.

    Splits the last dimension into two halves and multiplies the first half by
    the GELU of the second half, halving the feature dimension.
    """

    def forward(self, x):
        values, gates = torch.chunk(x, 2, dim=-1)
        return values * F.gelu(gates)
if __name__ == "__main__":
    # Manual smoke test; requires a CUDA device and `lovely_tensors`.
    import lovely_tensors as lt

    lt.monkey_patch()

    def tensor_stats(t: torch.Tensor):  # Clone of lovely_tensors for complex support
        return f"tensor[{t.shape}] n={t.shape.numel()}, u={t.mean()}, s={round(t.std().item(), 3)} var={round(t.var().item(), 3)}\n"

    # Example 1: bare S5 layer on (B, L, H) inputs
    x = torch.rand([2, 256, 32]).cuda()
    model = S5(32, 32, factor_rank=None).cuda()
    print("B", tensor_stats(model.seq.B.data))
    print("C", tensor_stats(model.seq.C.data))

    # FIXME (kept from the original script): unstable initialization
    # S5.forward requires the previous state (one per batch element) and
    # returns the new state next to the output.
    state = model.initial_state(x.shape[0]).to(x.device)
    res, state = model(x, prev_state=state)  # warm-up
    print(res.shape, res.dtype, res)

    # Example 2: (B, L, H) inputs
    x = torch.rand([2, 256, 32]).cuda()
    model = S5Block(32, 32, False).cuda()
    state = model.s5.initial_state(x.shape[0]).to(x.device)
    res, state = model(x, state)
    print(res.shape, res.dtype, res)
def to_numpy(x):
    """Convert a torch tensor or a Triton ``TensorWrapper`` to a NumPy array.

    Args:
        x: Either a ``torch.Tensor`` or a ``TensorWrapper`` (as produced by
            ``reinterpret``).

    Returns:
        A NumPy array holding the data of ``x`` on the host. bfloat16 tensors
        are widened to float32 first. For a ``TensorWrapper`` the wrapped
        ``base`` tensor is converted and then cast to the NumPy dtype that has
        the same name as the wrapper's dtype.

    Raises:
        ValueError: If ``x`` is neither a tensor nor a ``TensorWrapper``.
    """
    if isinstance(x, torch.Tensor):
        if x.dtype is torch.bfloat16:
            # Widen before converting: the bfloat16 tensor is not handed to
            # .numpy() directly.
            return x.cpu().float().numpy()
        return x.cpu().numpy()
    if isinstance(x, TensorWrapper):
        # The original code called an undefined `torch_dtype_name` helper here.
        # Derive the name directly instead: use the dtype's `name` attribute if
        # it has one, otherwise its string form, and keep only the last dotted
        # component (e.g. "torch.int64" -> "int64").
        # NOTE(review): dtypes without a NumPy counterpart (e.g. float8) will
        # still fail in getattr below - confirm whether they must be supported.
        dtype_name = getattr(x.dtype, "name", str(x.dtype)).rpartition(".")[2]
        return x.base.cpu().numpy().astype(getattr(np, dtype_name))
    raise ValueError(f"Not a triton-compatible tensor: {x}")
z_dtype_str = dtype_str z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) # triton result z_tri = to_triton(z, device=device) val = kernel[(1,)]( x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps ) out_triton = to_numpy(z_tri) for _ in range(n_timings): # print('.', end='', flush=True) start = time.monotonic_ns() kernel[(1,)]( x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps, ) stop = time.monotonic_ns() triton_times.append((stop - start) / (10**9)) print("\nFake scan") def f(carry, x): return carry + x, carry + x def _fake_scan(f, init, x): zs = [] carry = init for xp in x: carry, out = f(carry, xp) zs.append(out) return carry, torch.stack(zs) expected_carry_out, expected_ys = _fake_scan(f, init, inp_scan) for _ in range(n_timings): # print('.', end='', flush=True) start = time.monotonic_ns() expected_carry_out, expected_ys = _fake_scan(f, init, inp_scan) stop = time.monotonic_ns() loop_times.append((stop - start) / (10**9)) # _fake_scan_comp = torch.compile(_fake_scan, mode='reduce-overhead', fullgraph=True, dynamic=False) # # Warm-up cycles # print("\nFake scan-compiled") # for _ in range(5): # expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan) # for _ in range(n_timings): # print('.', end='', flush=True) # start = time.monotonic_ns() # expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan) # stop = time.monotonic_ns() # loop_comp_times.append((stop - start) / (10 ** 9)) def sum_op2(a, b): return a + b, a + b # Warm-up print("\njax_compat") for _ in range(5): expected_ys_comp = associative_scan(sum_op2, inp_scan, axis=-1) for _ in range(n_timings): # print('.', end='', flush=True) start = time.monotonic_ns() expected_ys_comp = associative_scan(sum_op2, inp_scan, axis=-1) stop = time.monotonic_ns() jax_compat_times.append((stop - start) / (10**9)) print() print("Times regular loop " + str(np.array(loop_times).mean())) # print('Times 
def get_dataloader_kwargs(
    dataset: Union[Dataset, CustomConcatDataset],
    sampling_mode: DatasetSamplingMode,
    dataset_mode: DatasetMode,
    dataset_config: DictConfig,
    batch_size: int,
    num_workers: int,
) -> Dict[str, Any]:
    """Assemble the keyword arguments for a ``DataLoader``.

    Args:
        dataset: Dataset the loader will wrap.
        sampling_mode: STREAM or RANDOM; any other value is rejected.
        dataset_mode: TRAIN, VALIDATION or TESTING; any other value is rejected.
        dataset_config: Only read for random training, where
            ``train.random.weighted_sampling`` selects a weighted sampler.
        batch_size: Used for random sampling; streaming loaders get
            ``batch_size=None``.
        num_workers: Forwarded unchanged.

    Returns:
        A dict that can be splatted into ``DataLoader(**kwargs)``.

    Raises:
        NotImplementedError: For an unsupported dataset or sampling mode.
    """
    training = dataset_mode == DatasetMode.TRAIN
    if not training and dataset_mode not in (
        DatasetMode.VALIDATION,
        DatasetMode.TESTING,
    ):
        raise NotImplementedError

    if sampling_mode == DatasetSamplingMode.STREAM:
        # Streaming loaders are configured identically for train and eval.
        return {
            "dataset": dataset,
            "batch_size": None,
            "shuffle": False,  # Done already in the streaming datapipe
            "num_workers": num_workers,
            "pin_memory": False,
            "drop_last": False,  # Cannot be done with streaming datapipes
            "collate_fn": custom_collate_streaming,
        }

    if sampling_mode == DatasetSamplingMode.RANDOM:
        order_kwargs = {"shuffle": False}
        if training:
            # With a weighted sampler the loader must not shuffle on its own.
            weighted = dataset_config.train.random.weighted_sampling
            sampler = get_weighted_random_sampler(dataset) if weighted else None
            order_kwargs = {"shuffle": sampler is None, "sampler": sampler}
        return {
            "dataset": dataset,
            "batch_size": batch_size,
            **order_kwargs,
            "num_workers": num_workers,
            "pin_memory": False,
            "drop_last": True,  # Maintain the same batch size for logging
            "collate_fn": custom_collate_rnd,
        }

    raise NotImplementedError
def set_mixed_sampling_mode_variables_for_train(self):
    """Split the training batch size and worker count between random and stream sampling.

    The batch size is divided according to ``dataset_config.train.mixed.w_random``
    and ``w_stream``; the workers are then divided proportionally to the batch
    sizes. Both sampling modes always receive at least one sample and one
    worker. Results are stored in ``sampling_mode_2_train_batch_size`` and
    ``sampling_mode_2_train_workers`` and printed.

    Raises:
        AssertionError: If fewer than 2 samples or 2 workers are available, or
            if a weight is not positive.
    """
    total_bs = self.overall_batch_size_train
    total_workers = self.overall_num_workers_train
    assert total_bs >= 2, "Cannot use mixed mode with batch size smaller than 2"
    assert total_workers >= 2, "Cannot use mixed mode with num workers smaller than 2"
    weight_random = self.dataset_config.train.mixed.w_random
    weight_stream = self.dataset_config.train.mixed.w_stream
    assert weight_random > 0
    assert weight_stream > 0

    # Set batch size according to weights. Clamp to [1, total_bs - 1]: without
    # the lower bound a small w_random rounds to 0, which gives the random
    # loader a batch size of 0 (and, below, 0 workers).
    bs_rnd = round(total_bs * weight_random / (weight_stream + weight_random))
    bs_rnd = max(1, min(bs_rnd, total_bs - 1))
    bs_str = total_bs - bs_rnd
    self.sampling_mode_2_train_batch_size[DatasetSamplingMode.RANDOM] = bs_rnd
    self.sampling_mode_2_train_batch_size[DatasetSamplingMode.STREAM] = bs_str

    # Set num workers according to batch size. Random sampling typically takes longer than stream sampling!
    # Rounding up favours the random loader; bs_rnd >= 1 guarantees >= 1 worker.
    workers_rnd = min(
        math.ceil(total_workers * bs_rnd / total_bs),
        total_workers - 1,
    )
    workers_str = total_workers - workers_rnd
    self.sampling_mode_2_train_workers[DatasetSamplingMode.RANDOM] = workers_rnd
    self.sampling_mode_2_train_workers[DatasetSamplingMode.STREAM] = workers_str

    print(
        f"[Train] Local batch size for:\nstream sampling:\t{bs_str}\nrandom sampling:\t{bs_rnd}\n"
        f"[Train] Local num workers for:\nstream sampling:\t{workers_str}\nrandom sampling:\t{workers_rnd}"
    )
def train_dataloader(self):
    """Build the training dataloader(s) for the datasets created in ``setup``.

    Returns:
        A single ``DataLoader`` when exactly one sampling mode is registered in
        ``sampling_mode_2_dataset``; otherwise a dict mapping each of the two
        sampling modes to its ``DataLoader``.
    """
    loaders = {
        mode: DataLoader(
            **get_dataloader_kwargs(
                dataset=dataset,
                sampling_mode=mode,
                dataset_mode=DatasetMode.TRAIN,
                dataset_config=self.dataset_config,
                batch_size=self.sampling_mode_2_train_batch_size[mode],
                num_workers=self.sampling_mode_2_train_workers[mode],
            )
        )
        for mode, dataset in self.sampling_mode_2_dataset.items()
    }
    if len(loaders) != 1:
        # Mixed sampling: return the mapping from sampling mode to dataloader.
        assert len(loaders) == 2
        return loaders
    # Only one sampling mode: hand back the bare dataloader.
    (single_loader,) = loaders.values()
    return single_loader
def setup(self, stage: Optional[str] = None) -> None:
    """Initialise per-mode bookkeeping for the given Lightning stage.

    Resets the per-mode dictionaries (input resolution, batch size, Prophesee
    evaluator, sampling mode) and fills them for the modes used by ``stage``:
    TRAIN and VAL for "fit", VAL for "validate", TEST for "test".

    Raises:
        NotImplementedError: For any other stage.
    """
    dataset_cfg = self.full_config.dataset
    dataset_name = dataset_cfg.name
    self.mode_2_hw: Dict[Mode, Optional[Tuple[int, int]]] = {}
    self.mode_2_batch_size: Dict[Mode, Optional[int]] = {}
    self.mode_2_psee_evaluator: Dict[Mode, Optional[PropheseeEvaluator]] = {}
    self.mode_2_sampling_mode: Dict[Mode, DatasetSamplingMode] = {}

    self.started_training = True

    train_sampling = dataset_cfg.train.sampling
    eval_sampling = dataset_cfg.eval.sampling
    assert train_sampling in iter(DatasetSamplingMode)
    assert eval_sampling in (
        DatasetSamplingMode.STREAM,
        DatasetSamplingMode.RANDOM,
    )

    def make_evaluator():
        # Every mode gets its own evaluator instance with the same settings.
        return PropheseeEvaluator(
            dataset=dataset_name,
            downsample_by_2=dataset_cfg.downsample_by_factor_2,
        )

    def clear_shape_cache(target_mode):
        # Resolution and batch size are filled in lazily by the step functions.
        self.mode_2_hw[target_mode] = None
        self.mode_2_batch_size[target_mode] = None

    if stage == "fit":  # train + val
        self.train_config = self.full_config.training
        self.train_metrics_config = self.full_config.logging.train.metrics

        if self.train_metrics_config.compute:
            self.mode_2_psee_evaluator[Mode.TRAIN] = make_evaluator()
        self.mode_2_psee_evaluator[Mode.VAL] = make_evaluator()
        self.mode_2_sampling_mode[Mode.TRAIN] = train_sampling
        self.mode_2_sampling_mode[Mode.VAL] = eval_sampling

        for fit_mode in (Mode.TRAIN, Mode.VAL):
            clear_shape_cache(fit_mode)
        # Flipped back to True by the first training step.
        self.started_training = False
        return

    if stage not in ("validate", "test"):
        raise NotImplementedError

    eval_mode = Mode.VAL if stage == "validate" else Mode.TEST
    self.mode_2_psee_evaluator[eval_mode] = make_evaluator()
    self.mode_2_sampling_mode[eval_mode] = eval_sampling
    clear_shape_cache(eval_mode)
torch.stack(token_mask_sequence) token_mask_sequence = token_mask_sequence.to(dtype=self.dtype) token_mask_sequence = self.input_padder.pad_token_mask( token_mask=token_mask_sequence ) else: token_mask_sequence = None if self.mode_2_hw[mode] is None: self.mode_2_hw[mode] = tuple(ev_tensor_sequence.shape[-2:]) else: assert self.mode_2_hw[mode] == ev_tensor_sequence.shape[-2:] backbone_features, states = self.mdl.forward_backbone( x=ev_tensor_sequence, previous_states=prev_states, token_mask=token_mask_sequence, train_step=True, ) prev_states = states for tidx, curr_labels in enumerate(sparse_obj_labels): ( current_labels, valid_batch_indices, ) = curr_labels.get_valid_labels_and_batch_indices() # Store backbone features that correspond to the available labels. if len(current_labels) > 0: backbone_feature_selector.add_backbone_features( backbone_features={ k: v[tidx] for k, v in backbone_features.items() }, selected_indices=valid_batch_indices, ) obj_labels.extend(current_labels) ev_repr_selector.add_event_representations( event_representations=ev_tensor_sequence[tidx], selected_indices=valid_batch_indices, ) self.mode_2_rnn_states[mode].save_states_and_detach( worker_id=worker_id, states=prev_states ) assert len(obj_labels) > 0 # Batch the backbone features and labels to parallelize the detection code. selected_backbone_features = ( backbone_feature_selector.get_batched_backbone_features() ) labels_yolox = ObjectLabels.get_labels_as_batched_tensor( obj_label_list=obj_labels, format_="yolox" ) labels_yolox = labels_yolox.to(dtype=self.dtype) predictions, losses = self.mdl.forward_detect( backbone_features=selected_backbone_features, targets=labels_yolox ) if self.mode_2_sampling_mode[mode] in ( DatasetSamplingMode.MIXED, DatasetSamplingMode.RANDOM, ): # We only want to evaluate the last batch_size samples if we use random sampling (or mixed). # This is because otherwise we would mostly evaluate the init phase of the sequence. 
predictions = predictions[-batch_size:] obj_labels = obj_labels[-batch_size:] pred_processed = postprocess( prediction=predictions, num_classes=self.mdl_config.head.num_classes, conf_thre=self.mdl_config.postprocess.confidence_threshold, nms_thre=self.mdl_config.postprocess.nms_threshold, ) loaded_labels_proph, yolox_preds_proph = to_prophesee( obj_labels, pred_processed ) assert losses is not None assert "loss" in losses # For visualization, we only use the last batch_size items. output = { ObjDetOutput.LABELS_PROPH: loaded_labels_proph[-batch_size:], ObjDetOutput.PRED_PROPH: yolox_preds_proph[-batch_size:], ObjDetOutput.EV_REPR: ev_repr_selector.get_event_representations_as_list( start_idx=-batch_size ), ObjDetOutput.SKIP_VIZ: False, "loss": losses["loss"], } # Logging prefix = f"{mode_2_string[mode]}/" log_dict = {f"{prefix}{k}": v for k, v in losses.items()} self.log_dict( log_dict, on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True ) if mode in self.mode_2_psee_evaluator: self.mode_2_psee_evaluator[mode].add_labels(loaded_labels_proph) self.mode_2_psee_evaluator[mode].add_predictions(yolox_preds_proph) if ( self.train_metrics_config.detection_metrics_every_n_steps is not None and step > 0 and step % self.train_metrics_config.detection_metrics_every_n_steps == 0 ): self.run_psee_evaluator(mode=mode) return output def _val_test_step_impl(self, batch: Any, mode: Mode) -> Optional[STEP_OUTPUT]: data = self.get_data_from_batch(batch) worker_id = self.get_worker_id_from_batch(batch) assert mode in (Mode.VAL, Mode.TEST) ev_tensor_sequence = data[DataType.EV_REPR] sparse_obj_labels = data[DataType.OBJLABELS_SEQ] is_first_sample = data[DataType.IS_FIRST_SAMPLE] self.mode_2_rnn_states[mode].reset( worker_id=worker_id, indices_or_bool_tensor=is_first_sample ) sequence_len = len(ev_tensor_sequence) assert sequence_len > 0 batch_size = len(sparse_obj_labels[0]) if self.mode_2_batch_size[mode] is None: self.mode_2_batch_size[mode] = batch_size else: assert 
self.mode_2_batch_size[mode] == batch_size prev_states = self.mode_2_rnn_states[mode].get_states(worker_id=worker_id) backbone_feature_selector = BackboneFeatureSelector() ev_repr_selector = EventReprSelector() obj_labels = list() ev_tensor_sequence = torch.stack( ev_tensor_sequence ) # shape: (sequence_len, batch_size, channels, height, width) = (L, B, C, H, W) ev_tensor_sequence = ev_tensor_sequence.to(dtype=self.dtype) ev_tensor_sequence = self.input_padder.pad_tensor_ev_repr(ev_tensor_sequence) if self.mode_2_hw[mode] is None: self.mode_2_hw[mode] = tuple(ev_tensor_sequence.shape[-2:]) else: assert self.mode_2_hw[mode] == ev_tensor_sequence.shape[-2:] backbone_features, states = self.mdl.forward_backbone( x=ev_tensor_sequence, previous_states=prev_states, train_step=False, ) prev_states = states for tidx in range(sequence_len): collect_predictions = (tidx == sequence_len - 1) or ( self.mode_2_sampling_mode[mode] == DatasetSamplingMode.STREAM ) if collect_predictions: current_labels, valid_batch_indices = sparse_obj_labels[ tidx ].get_valid_labels_and_batch_indices() # Store backbone features that correspond to the available labels. 
if len(current_labels) > 0: backbone_feature_selector.add_backbone_features( backbone_features={ k: v[tidx] for k, v in backbone_features.items() }, selected_indices=valid_batch_indices, ) obj_labels.extend(current_labels) ev_repr_selector.add_event_representations( event_representations=ev_tensor_sequence[tidx], selected_indices=valid_batch_indices, ) self.mode_2_rnn_states[mode].save_states_and_detach( worker_id=worker_id, states=prev_states ) if len(obj_labels) == 0: return {ObjDetOutput.SKIP_VIZ: True} selected_backbone_features = ( backbone_feature_selector.get_batched_backbone_features() ) predictions, _ = self.mdl.forward_detect( backbone_features=selected_backbone_features ) pred_processed = postprocess( prediction=predictions, num_classes=self.mdl_config.head.num_classes, conf_thre=self.mdl_config.postprocess.confidence_threshold, nms_thre=self.mdl_config.postprocess.nms_threshold, ) loaded_labels_proph, yolox_preds_proph = to_prophesee( obj_labels, pred_processed ) # For visualization, we only use the last item (per batch). 
def run_psee_evaluator(self, mode: Mode):
    """Evaluate the buffered detections of ``mode`` and log the metrics.

    Emits a warning and returns early when the evaluator is ``None`` or holds
    no data. Otherwise the buffer is evaluated at the stored input resolution,
    every metric is logged as a scalar tensor through ``self.log_dict``, the
    values are averaged onto rank 0 when torch.distributed is initialised,
    rank 0 forwards them to ``self.logger`` and the buffer is reset.
    """
    psee_evaluator = self.mode_2_psee_evaluator[mode]
    batch_size = self.mode_2_batch_size[mode]
    hw_tuple = self.mode_2_hw[mode]
    if psee_evaluator is None:
        warn(f"psee_evaluator is None in {mode=}", UserWarning, stacklevel=2)
        return
    assert batch_size is not None
    assert hw_tuple is not None
    if not psee_evaluator.has_data():
        warn(f"psee_evaluator has not data in {mode=}", UserWarning, stacklevel=2)
        return

    height, width = hw_tuple[0], hw_tuple[1]
    metrics = psee_evaluator.evaluate_buffer(img_height=height, img_width=width)
    assert metrics is not None

    prefix = f"{mode_2_string[mode]}/"
    step = self.trainer.global_step
    log_dict = {}
    for k, v in metrics.items():
        # Normalise every metric to a scalar tensor.
        if isinstance(v, (int, float)):
            value = torch.tensor(v)
        elif isinstance(v, np.ndarray):
            value = torch.from_numpy(v)
        elif isinstance(v, torch.Tensor):
            value = v
        else:
            raise NotImplementedError
        assert (
            value.ndim == 0
        ), f"tensor must be a scalar.\n{v=}\n{type(v)=}\n{value=}\n{type(value)=}"
        # put them on the current device to avoid this error: https://github.com/Lightning-AI/lightning/discussions/2529
        log_dict[f"{prefix}{k}"] = value.to(self.device)

    # Somehow self.log does not work when we eval during the training epoch.
    self.log_dict(
        log_dict,
        on_step=False,
        on_epoch=True,
        batch_size=batch_size,
        sync_dist=True,
    )

    if dist.is_available() and dist.is_initialized():
        # We now have to manually sync (average the metrics) across processes in case of distributed training.
        # NOTE: This is necessary to ensure that we have the same numbers for the checkpoint metric (metadata)
        # and wandb metric:
        # - checkpoint callback is using the self.log function which uses global sync (avg across ranks)
        # - wandb uses log_metrics that we reduce manually to global rank 0
        dist.barrier()
        for key in log_dict:
            # Sum onto rank 0, then turn the sum into a mean there.
            dist.reduce(log_dict[key], dst=0, op=dist.ReduceOp.SUM)
            if dist.get_rank() == 0:
                log_dict[key] /= dist.get_world_size()

    if self.trainer.is_global_zero:
        # For some reason we need to increase the step by 2 to enable consistent logging in wandb here.
        # I might not understand wandb login correctly. This works reasonably well for now.
        wandb_step_offset = 2
        self.logger.log_metrics(metrics=log_dict, step=step + wandb_step_offset)

    psee_evaluator.reset_buffer()
def configure_optimizers(self) -> Any:
    """Create the AdamW optimizer and, if enabled, a one-cycle LR schedule.

    Returns:
        The optimizer alone when ``train_config.lr_scheduler.use`` is false;
        otherwise a dict with the optimizer and a step-wise scheduler config
        named "learning_rate".
    """
    train_cfg = self.train_config
    peak_lr = train_cfg.learning_rate
    optimizer = th.optim.AdamW(
        self.mdl.parameters(), lr=peak_lr, weight_decay=train_cfg.weight_decay
    )

    sched_cfg = train_cfg.lr_scheduler
    if not sched_cfg.use:
        return optimizer

    total_steps = sched_cfg.total_steps
    assert total_steps is not None
    assert total_steps > 0
    # Here we interpret the final lr as max_lr/final_div_factor.
    # Note that Pytorch OneCycleLR interprets it as initial_lr/final_div_factor:
    pytorch_final_div = sched_cfg.final_div_factor / sched_cfg.div_factor
    one_cycle = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=peak_lr,
        div_factor=sched_cfg.div_factor,
        final_div_factor=pytorch_final_div,
        total_steps=total_steps,
        pct_start=sched_cfg.pct_start,
        cycle_momentum=False,
        anneal_strategy="linear",
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": one_cycle,
            "interval": "step",
            "frequency": 1,
            "strict": True,
            "name": "learning_rate",
        },
    }
class EventReprSelector:
    """Collects individual event representations, one list entry per sample."""

    def __init__(self):
        self.repr_list = None
        self.reset()

    def reset(self):
        """Drop every stored representation."""
        self.repr_list = list()

    def __len__(self):
        return len(self.repr_list)

    def add_event_representations(
        self,
        event_representations: th.Tensor,
        selected_indices: Optional[List[int]] = None,
    ) -> None:
        """Append samples of a batched tensor to the list.

        Args:
            event_representations: Tensor whose first axis indexes the samples.
            selected_indices: Indices along the first axis to keep. ``None``
                keeps every sample.

        Raises:
            AssertionError: If ``selected_indices`` is given but empty.
        """
        if selected_indices is None:
            # Indexing a tensor with None would insert a new leading axis and
            # store the whole batch as a single entry; keep the batch as is.
            selected = event_representations
        else:
            assert len(selected_indices) > 0
            selected = event_representations[selected_indices]
        # One entry per sample, with the sample axis removed.
        self.repr_list.extend(selected.unbind(0))

    def get_event_representations_as_list(
        self, start_idx: int = 0, end_idx: Optional[int] = None
    ) -> Optional[List[th.Tensor]]:
        """Return ``repr_list[start_idx:end_idx]``, or ``None`` if nothing is stored.

        ``end_idx`` defaults to the number of stored representations.
        ``start_idx`` may be negative and is then interpreted by the slice.
        """
        if len(self) == 0:
            return None
        if end_idx is None:
            end_idx = len(self)
        assert start_idx < end_idx, f"{start_idx=}, {end_idx=}"
        return self.repr_list[start_idx:end_idx]
def mixed_collate_fn(
    x1: Union[th.Tensor, List[th.Tensor]], x2: Union[th.Tensor, List[th.Tensor]]
):
    """Merge two batch entries of the same kind into one.

    Tensors are concatenated along the first axis, ``SparselyBatchedObjectLabels``
    are combined with ``+`` and equally long lists are merged element by
    element (recursively). ``x2`` must be of the same kind as ``x1``.

    Raises:
        NotImplementedError: If ``x1`` is none of the supported kinds.
    """
    if isinstance(x1, th.Tensor):
        assert isinstance(x2, th.Tensor)
        return th.cat([x1, x2], dim=0)

    if isinstance(x1, SparselyBatchedObjectLabels):
        assert isinstance(x2, SparselyBatchedObjectLabels)
        return x1 + x2

    if not isinstance(x1, list):
        raise NotImplementedError

    assert isinstance(x2, list)
    assert len(x1) == len(x2)
    merged = []
    for first, second in zip(x1, x2):
        merged.append(mixed_collate_fn(x1=first, x2=second))
    return merged
def fetch_data_module(config: DictConfig) -> pl.LightningDataModule:
    """Create the data module for the dataset named by ``config.dataset.name``.

    The names "gen1" and "gen4" are both served by ``genx_data_module``; any
    other name raises ``NotImplementedError``. Batch sizes are read from
    ``config.batch_size`` and worker counts from ``config.hardware.num_workers``.
    """
    batch_size_cfg = config.batch_size
    train_batch_size = batch_size_cfg.train
    eval_batch_size = batch_size_cfg.eval

    hardware_cfg = config.hardware
    # Value handed out when "train"/"eval" is missing in hardware.num_workers.
    default_workers = hardware_cfg.get("num_workers", None)
    workers_cfg = hardware_cfg.num_workers
    train_workers = workers_cfg.get("train", default_workers)
    eval_workers = workers_cfg.get("eval", default_workers)

    dataset_name = config.dataset.name
    if dataset_name not in ("gen1", "gen4"):
        raise NotImplementedError

    return genx_data_module(
        config.dataset,
        num_workers_train=train_workers,
        num_workers_eval=eval_workers,
        batch_size_train=train_batch_size,
        batch_size_eval=eval_batch_size,
    )
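|       | train      | validation | test       |
|:------|:----------:|:----------:|:----------:|
| 1 Mpx | download   | download   | download   |
| crc32 | `d677488a` | `72f13c3e` | `643e61ef` |
| Gen1  | download   | download   | download   |
| crc32 | `3d23bd30` | `cc802022` | `cdd4fd69` |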
### 2. Extract the tar files The following directory structure is assumed: ``` data_dir ├── test │ ├── ..._bbox.npy │ ├── ..._td.dat.h5 │ ... │ ├── train │ ├── ....npy │ ├── ..._td.dat.h5 │ ... │ └── val ├── ..._bbox.npy ├── ..._td.dat.h5 ... ``` ### 3. Run the pre-processing script `${DATA_DIR}` should point to the directory structure mentioned above. `${DEST_DIR}` should point to the directory to which the data will be written. For the 1 Mpx dataset: ```Bash NUM_PROCESSES=20 # set to the number of parallel processes to use python preprocess_dataset.py ${DATA_DIR} ${DEST_DIR} conf_preprocess/representation/stacked_hist.yaml \ conf_preprocess/extraction/const_duration.yaml conf_preprocess/filter_gen4.yaml -ds gen4 -np ${NUM_PROCESSES} ``` For the Gen1 dataset: ```Bash NUM_PROCESSES=20 # set to the number of parallel processes to use python preprocess_dataset.py ${DATA_DIR} ${DEST_DIR} conf_preprocess/representation/stacked_hist.yaml \ conf_preprocess/extraction/const_duration.yaml conf_preprocess/filter_gen1.yaml -ds gen1 -np ${NUM_PROCESSES} ``` ================================================ FILE: RVT/scripts/genx/conf_preprocess/extraction/const_count.yaml ================================================ method: COUNT value: 50000 ================================================ FILE: RVT/scripts/genx/conf_preprocess/extraction/const_duration.yaml ================================================ method: DURATION # value is in milliseconds! value: 50 ================================================ FILE: RVT/scripts/genx/conf_preprocess/extraction/frequencies/const_duration_100hz.yaml ================================================ method: DURATION # value is in milliseconds! value: 10 ================================================ FILE: RVT/scripts/genx/conf_preprocess/extraction/frequencies/const_duration_200hz.yaml ================================================ method: DURATION # value is in milliseconds! 
value: 5 ================================================ FILE: RVT/scripts/genx/conf_preprocess/extraction/frequencies/const_duration_40hz.yaml ================================================ method: DURATION # value is in milliseconds! value: 25 ================================================ FILE: RVT/scripts/genx/conf_preprocess/extraction/frequencies/const_duration_80hz.yaml ================================================ method: DURATION # value is in milliseconds! value: 12 ================================================ FILE: RVT/scripts/genx/conf_preprocess/filter_gen1.yaml ================================================ apply_psee_bbox_filter: True apply_faulty_bbox_filter: True ================================================ FILE: RVT/scripts/genx/conf_preprocess/filter_gen4.yaml ================================================ apply_psee_bbox_filter: False apply_faulty_bbox_filter: True ================================================ FILE: RVT/scripts/genx/conf_preprocess/representation/mixeddensity_stack.yaml ================================================ name: "mixeddensity_stack" nbins: 10 count_cutoff: 32 ================================================ FILE: RVT/scripts/genx/conf_preprocess/representation/stacked_hist.yaml ================================================ name: "stacked_histogram" nbins: 10 count_cutoff: 10 ================================================ FILE: RVT/scripts/genx/preprocess_dataset.py ================================================ import os os.environ["OMP_NUM_THREADS"] = "1" os.environ["OPENBLAS_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1" os.environ["VECLIB_MAXIMUM_THREADS"] = "1" os.environ["NUMEXPR_NUM_THREADS"] = "1" from abc import ABC, abstractmethod import argparse from dataclasses import dataclass, field from enum import Enum, auto from functools import partial from multiprocessing import get_context from pathlib import Path import shutil import sys sys.path.append("../..") from typing 
class DataKeys(Enum):
    """Keys of the per-sequence dictionary consumed by the processing step."""

    InNPY = auto()
    InH5 = auto()
    OutLabelDir = auto()
    OutEvReprDir = auto()
    SplitType = auto()


class SplitType(Enum):
    """Dataset split a sequence belongs to."""

    TRAIN = auto()
    VAL = auto()
    TEST = auto()


# Lower-case split name ("train", "val", "test") -> SplitType member.
split_name_2_type = {member.name.lower(): member for member in SplitType}

# Frame size in pixels for each supported dataset.
dataset_2_height = dict(gen1=240, gen4=720)
dataset_2_width = dict(gen1=304, gen4=1280)

# The following sequences would be discarded because all the labels would be removed after filtering:
dirs_to_ignore = dict(
    gen1=(
        "17-04-06_09-57-37_6344500000_6404500000",
        "17-04-13_19-17-27_976500000_1036500000",
        "17-04-06_15-14-36_1159500000_1219500000",
        "17-04-11_15-13-23_122500000_182500000",
    ),
    gen4=(),
)


class NoLabelsException(Exception):
    """Raised when no labels are present anymore in the sequence after filtering."""
class H5Writer:
    """Writes equally-shaped arrays one after another into a single HDF5 dataset.

    The dataset is created with an unlimited first dimension and one array per
    chunk; :meth:`add_data` grows it by one entry per call. The class can be
    used as a context manager, which closes the file on exit.
    """

    def __init__(
        self, outfile: Path, key: str, ev_repr_shape: Tuple, numpy_dtype: np.dtype
    ):
        assert len(ev_repr_shape) == 3
        self.h5f = h5py.File(str(outfile), "w")
        # Make sure the file is closed at the latest when this object is garbage collected.
        self._finalizer = weakref.finalize(self, self.close_callback, self.h5f)
        self.key = key  # name of the dataset inside the HDF5 file
        self.numpy_dtype = numpy_dtype

        # One array per chunk; the first axis may grow without bound.
        single_entry_shape = (1,) + ev_repr_shape
        self.maxshape = (None,) + ev_repr_shape
        self.h5f.create_dataset(
            key,
            dtype=self.numpy_dtype.name,
            shape=single_entry_shape,
            chunks=single_entry_shape,
            maxshape=self.maxshape,
            **_blosc_opts(complevel=1, shuffle="byte"),
        )
        self.t_idx = 0  # index of the next entry to be written

    # __enter__ and __exit__ allow using the class as a context manager
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._finalizer()

    @staticmethod
    def close_callback(h5f: h5py.File):
        h5f.close()

    def close(self):
        self.h5f.close()

    def get_current_length(self):
        """Return how many arrays have been written so far."""
        return self.t_idx

    def add_data(self, data: np.ndarray):
        """Append ``data`` as the new last entry of the dataset.

        ``data`` must have the dtype and per-entry shape given at construction.
        """
        assert data.dtype == self.numpy_dtype, f"{data.dtype=}, {self.numpy_dtype=}"
        assert data.shape == self.maxshape[1:]
        write_idx = self.t_idx
        next_idx = write_idx + 1
        # Grow the first axis by one and fill the freshly created slot.
        self.h5f[self.key].resize(next_idx, axis=0)
        self.h5f[self.key][write_idx:next_idx] = data
        self.t_idx = next_idx
def prophesee_bbox_filter(labels: np.ndarray, dataset_type: str) -> np.ndarray:
    """Keep only boxes that are large enough by both diagonal and side length.

    A box is kept when ``w**2 + h**2 >= min_diag**2`` and both ``w`` and ``h``
    are at least ``min_side``. Thresholds are (60, 20) for "gen4" and
    (30, 10) for "gen1".
    """
    assert dataset_type in {"gen1", "gen4"}
    # Default values taken from:
    # https://github.com/prophesee-ai/prophesee-automotive-dataset-toolbox/blob/0393adea2bf22d833893c8cb1d986fcbe4e6f82d/src/psee_evaluator.py#L23-L24
    # The minimum side lengths are the corrected values from the supplementary material of the paper.
    if dataset_type == "gen4":
        min_diag, min_side = 60, 20
    else:
        min_diag, min_side = 30, 10

    widths, heights = labels["w"], labels["h"]
    mask = (widths >= min_side) & (heights >= min_side)
    mask &= widths**2 + heights**2 >= min_diag**2
    return labels[mask]
def prophesee_remove_labels_filter_gen4(labels: np.ndarray) -> np.ndarray:
    """Drop every label whose ``class_id`` is greater than 2.

    Original gen4 labels, with class_id in {0, ..., 6} in this order:
    pedestrian, two wheeler, car, truck, bus, traffic sign, traffic light.
    Kept: pedestrian, two wheeler, car.
    Removed: truck, bus, traffic sign, traffic light.
    """
    highest_kept_class_id = 2
    is_kept = labels["class_id"] <= highest_kept_class_id
    return labels[is_kept]
def save_labels(
    out_labels_dir: Path,
    labels_per_frame: List[np.ndarray],
    frame_timestamps_us: np.ndarray,
    match_if_exists: bool = True,
) -> None:
    """Write the labels of a sequence and their frame timestamps to disk.

    Two files are produced inside ``out_labels_dir``:

    * ``labels.npz`` with ``labels`` (all per-frame label arrays concatenated)
      and ``objframe_idx_2_label_idx`` (for every frame, the index of its
      first label inside ``labels``),
    * ``timestamps_us.npy`` with ``frame_timestamps_us``.

    Args:
        out_labels_dir: existing directory to write into.
        labels_per_frame: one label array per frame; must be non-empty and
            have the same length as ``frame_timestamps_us``.
        frame_timestamps_us: timestamp of every frame.
        match_if_exists: if True and a target file already exists, it is not
            rewritten; instead its content is asserted to equal the new data.
            If False, existing files are overwritten.

    Raises:
        AssertionError: on inconsistent inputs, or when an existing file does
            not match the data to be written.
    """
    assert len(labels_per_frame) == len(frame_timestamps_us)
    assert len(labels_per_frame) > 0

    # Start offset of every frame's labels inside the concatenated array.
    objframe_idx_2_label_idx = list()
    start_idx = 0
    for labels in labels_per_frame:
        objframe_idx_2_label_idx.append(start_idx)
        start_idx += len(labels)
    labels_v2 = np.concatenate(labels_per_frame)

    outfile_labels = out_labels_dir / "labels.npz"
    if outfile_labels.exists() and match_if_exists:
        # np.load on an .npz returns a lazily-reading NpzFile holding an open
        # file handle; the context manager releases it deterministically.
        with np.load(str(outfile_labels)) as data_existing:
            labels_existing = data_existing["labels"]
            oi_2_li_existing = data_existing["objframe_idx_2_label_idx"]
        assert np.array_equal(labels_existing, labels_v2)
        assert np.array_equal(oi_2_li_existing, objframe_idx_2_label_idx)
    else:
        np.savez(
            str(outfile_labels),
            labels=labels_v2,
            objframe_idx_2_label_idx=objframe_idx_2_label_idx,
        )

    out_labels_ts_file = out_labels_dir / "timestamps_us.npy"
    if out_labels_ts_file.exists() and match_if_exists:
        frame_timestamps_us_existing = np.load(str(out_labels_ts_file))
        assert np.array_equal(frame_timestamps_us_existing, frame_timestamps_us)
    else:
        np.save(str(out_labels_ts_file), frame_timestamps_us)
ts_step_frame_ms = 100 assert ts_step_frame_ms >= ts_step_ev_repr_ms assert ts_step_frame_ms % ts_step_ev_repr_ms == 0 and ts_step_ev_repr_ms > 0 align_t_us = align_t_ms * 1000 delta_t_us = ts_step_ev_repr_ms * 1000 sequence_labels = np.load(str(npy_file)) assert len(sequence_labels) > 0 sequence_labels = apply_filters( labels=sequence_labels, split_type=split_type, filter_cfg=filter_cfg, dataset_type=dataset_type, ) if sequence_labels.size == 0: raise NoLabelsException unique_ts_us = np.unique(np.asarray(sequence_labels["t"], dtype="int64")) base_delta_ts_labels_us = get_base_delta_ts_for_labels_us( unique_label_ts_us=unique_ts_us, dataset_type=dataset_type ) # We extract the first label at or after align_t_us to keep it as the reference for the label extraction. unique_ts_idx_first = np.searchsorted(unique_ts_us, align_t_us, side="left") # Extract "frame" timestamps from labels and prepare ev repr ts computation num_ev_reprs_between_frame_ts = [] frame_timestamps_us = [unique_ts_us[unique_ts_idx_first]] for unique_ts_idx in range(unique_ts_idx_first + 1, len(unique_ts_us)): reference_time = frame_timestamps_us[-1] ts = unique_ts_us[unique_ts_idx] diff_to_ref = ts - reference_time base_delta_count = round(diff_to_ref / base_delta_ts_labels_us) diff_to_ref_rounded = base_delta_count * base_delta_ts_labels_us if np.abs(diff_to_ref - diff_to_ref_rounded) <= 2000: assert base_delta_count > 0 # We accept up to 2 millisecond of jitter frame_timestamps_us.append(ts) num_ev_reprs_between_frame_ts.append( base_delta_count * (ts_step_frame_ms // ts_step_ev_repr_ms) ) frame_timestamps_us = np.asarray(frame_timestamps_us, dtype="int64") assert len(frame_timestamps_us) > 0, f"{npy_file=}" start_indices_per_label = np.searchsorted( sequence_labels["t"], frame_timestamps_us, side="left" ) end_indices_per_label = np.searchsorted( sequence_labels["t"], frame_timestamps_us, side="right" ) # Create labels per "frame" labels_per_frame = [] for idx_start, idx_end in 
zip(start_indices_per_label, end_indices_per_label): labels = sequence_labels[idx_start:idx_end] label_time_us = labels["t"][0] assert np.all(labels["t"] == label_time_us) labels_per_frame.append(labels) if len(frame_timestamps_us) > 1: assert ( np.diff(frame_timestamps_us).min() > 98000 ), f"{np.diff(frame_timestamps_us).min()=}" # Event repr timestamps generation ev_repr_timestamps_us_end = list( reversed(range(frame_timestamps_us[0], 0, -delta_t_us)) )[1:-1] assert ( len(num_ev_reprs_between_frame_ts) == len(frame_timestamps_us) - 1 ), f"{len(num_ev_reprs_between_frame_ts)=}, {len(frame_timestamps_us)=}" for idx, (num_ev_repr_between, frame_ts_us_start, frame_ts_us_end) in enumerate( zip( num_ev_reprs_between_frame_ts, frame_timestamps_us[:-1], frame_timestamps_us[1:], ) ): new_edge_timestamps = np.asarray( np.linspace(frame_ts_us_start, frame_ts_us_end, num_ev_repr_between + 1), dtype="int64", ).tolist() is_last_iter = idx == len(num_ev_reprs_between_frame_ts) - 1 if not is_last_iter: new_edge_timestamps = new_edge_timestamps[:-1] ev_repr_timestamps_us_end.extend(new_edge_timestamps) if len(frame_timestamps_us) == 1: # special case not handled in above for loop (no iter in this case) # yes, it's hacky ... 
def downsample_ev_repr(x: torch.Tensor, scale_factor: float):
    """Spatially downsample ``x`` with "nearest-exact" interpolation.

    ``scale_factor`` must lie strictly between 0 and 1. int8 inputs are
    shifted by +128 into uint8 before interpolating and shifted back to int8
    afterwards; all other dtypes are passed to ``interpolate`` unchanged.
    """
    assert 0 < scale_factor < 1
    is_int8 = x.dtype == torch.int8

    if is_int8:
        # Go through int16 so that adding the offset cannot overflow int8.
        shifted = x.to(torch.int16) + 128
        x = shifted.to(torch.uint8)

    x = torch.nn.functional.interpolate(
        x, scale_factor=scale_factor, mode="nearest-exact"
    )

    if is_int8:
        unshifted = x.to(torch.int16) - 128
        x = unshifted.to(torch.int8)
    return x
def process_sequence(
    dataset: str,
    filter_cfg: DictConfig,
    event_representation: RepresentationBase,
    ev_repr_num_events: Optional[int],
    ev_repr_delta_ts_ms: Optional[int],
    ts_step_ev_repr_ms: int,
    downsample_by_2: bool,
    sequence_data: Dict[DataKeys, Union[Path, SplitType]],
):
    """Pre-process one sequence: filter its labels, then write labels and event representations.

    Exactly one of ``ev_repr_num_events`` and ``ev_repr_delta_ts_ms`` must be
    given. If no label survives the filtering, the parent directory of the
    label output directory is deleted and nothing is written.
    """
    npy_path = sequence_data[DataKeys.InNPY]
    h5_path = sequence_data[DataKeys.InH5]
    labels_out_dir = sequence_data[DataKeys.OutLabelDir]
    ev_repr_out_dir = sequence_data[DataKeys.OutEvReprDir]
    split = sequence_data[DataKeys.SplitType]

    assert labels_out_dir.is_dir()
    assert ts_step_ev_repr_ms > 0
    # Exactly one of the two window definitions has to be set.
    assert (ev_repr_num_events is None) != (
        ev_repr_delta_ts_ms is None
    ), f"{ev_repr_num_events=}, {ev_repr_delta_ts_ms=}"

    # 1) extract: labels_per_frame, frame_timestamps_us, ev_repr_timestamps_us, frameidx2repridx
    align_t_ms = 100
    try:
        extracted = labels_and_ev_repr_timestamps(
            npy_file=npy_path,
            split_type=split,
            filter_cfg=filter_cfg,
            align_t_ms=align_t_ms,
            ts_step_ev_repr_ms=ts_step_ev_repr_ms,
            dataset_type=dataset,
        )
    except NoLabelsException:
        sequence_out_dir = labels_out_dir.parent
        print(f"No labels after filtering. Deleting {str(sequence_out_dir)}")
        shutil.rmtree(sequence_out_dir)
        return
    frame_labels, frame_ts_us, repr_ts_us, frame_2_repr_idx = extracted

    # 2) save: labels_per_frame, frame_timestamps_us
    save_labels(
        out_labels_dir=labels_out_dir,
        labels_per_frame=frame_labels,
        frame_timestamps_us=frame_ts_us,
    )

    # 3) retrieve event data, compute event representations and save them
    write_event_data(
        in_h5_file=h5_path,
        ev_out_dir=ev_repr_out_dir,
        dataset=dataset,
        event_representation=event_representation,
        ev_repr_num_events=ev_repr_num_events,
        ev_repr_delta_ts_ms=ev_repr_delta_ts_ms,
        ev_repr_timestamps_us=repr_ts_us,
        downsample_by_2=downsample_by_2,
        frameidx2repridx=frame_2_repr_idx,
    )
def get_configuration(
    ev_repr_yaml_config: Path, extraction_yaml_config: Path
) -> DictConfig:
    """Load and validate the event-representation config.

    The extraction YAML is merged onto the ``EventWindowExtractionConf``
    schema and attached as ``event_window_extraction``. The combined config
    is then merged onto the structured schema registered for ``config.name``
    in ``name_2_structured_config``.
    """
    repr_cfg = OmegaConf.load(ev_repr_yaml_config)

    raw_extraction_cfg = OmegaConf.load(extraction_yaml_config)
    extraction_schema = OmegaConf.structured(EventWindowExtractionConf)
    repr_cfg.event_window_extraction = OmegaConf.merge(
        extraction_schema, raw_extraction_cfg
    )

    # Pick the schema that matches the representation named in the YAML file.
    schema_cls = name_2_structured_config[repr_cfg.name]
    return OmegaConf.merge(OmegaConf.structured(schema_cls), repr_cfg)
window extraction yaml config file", ) parser.add_argument( "bbox_filter_yaml_config", help="Path to bbox filter yaml config file" ) parser.add_argument("-ds", "--dataset", default="gen1", help="gen1 or gen4") parser.add_argument( "-np", "--num_processes", type=int, default=1, help="Num proceesses to run in parallel", ) args = parser.parse_args() num_processes = args.num_processes dataset = args.dataset assert dataset in ("gen1", "gen4") downsample_by_2 = True if dataset == "gen4" else False config = get_configuration( ev_repr_yaml_config=Path(args.ev_repr_yaml_config), extraction_yaml_config=Path(args.extraction_yaml_config), ) bbox_filter_yaml_config = Path(args.bbox_filter_yaml_config) assert bbox_filter_yaml_config.exists() filter_cfg = OmegaConf.load(str(bbox_filter_yaml_config)) filter_cfg = OmegaConf.merge(OmegaConf.structured(FilterConf), filter_cfg) print("") print(OmegaConf.to_yaml(config)) ev_repr_factory: EventRepresentationFactory = name_2_ev_repr_factory[config.name]( config ) height = dataset_2_height[args.dataset] width = dataset_2_width[args.dataset] ev_repr = ev_repr_factory.create(height=height, width=width) ev_repr_string = ev_repr_factory.name dataset_input_path = Path(args.input_dir) train_path = dataset_input_path / "train" val_path = dataset_input_path / "val" test_path = dataset_input_path / "test" target_dir = Path(args.target_dir) os.makedirs(target_dir, exist_ok=True) assert train_path.exists(), f"{train_path=}" assert val_path.exists(), f"{val_path=}" assert test_path.exists(), f"{test_path=}" seq_data_list = list() for split in [train_path, val_path, test_path]: split_out_dir = target_dir / split.name os.makedirs(split_out_dir, exist_ok=True) for npy_file in split.iterdir(): if npy_file.suffix != ".npy": continue h5f_path = npy_file.parent / ( npy_file.stem.split("bbox")[0] + f"td{'.dat' if dataset == 'gen1' else ''}.h5" ) assert h5f_path.exists(), f"{h5f_path=}" dir_name = npy_file.stem.split("_bbox")[0] if dir_name in 
dirs_to_ignore[dataset]: continue out_seq_path = split_out_dir / dir_name out_labels_path = out_seq_path / "labels_v2" os.makedirs(out_labels_path, exist_ok=True) out_ev_repr_parent_path = out_seq_path / "event_representations_v2" out_ev_repr_path = out_ev_repr_parent_path / ev_repr_string os.makedirs(out_ev_repr_path, exist_ok=True) sequence_data = { DataKeys.InNPY: npy_file, DataKeys.InH5: h5f_path, DataKeys.OutLabelDir: out_labels_path, DataKeys.OutEvReprDir: out_ev_repr_path, DataKeys.SplitType: split_name_2_type[split.name], } seq_data_list.append(sequence_data) ev_repr_num_events = None ev_repr_delta_ts_ms = None if config.event_window_extraction.method == AggregationType.COUNT: ev_repr_num_events = config.event_window_extraction.value else: assert config.event_window_extraction.method == AggregationType.DURATION ev_repr_delta_ts_ms = config.event_window_extraction.value ts_step_ev_repr_ms = 50 # Could be an argument of the script. if num_processes > 1: chunksize = 1 func = partial( process_sequence, dataset, filter_cfg, ev_repr, ev_repr_num_events, ev_repr_delta_ts_ms, ts_step_ev_repr_ms, downsample_by_2, ) with get_context("spawn").Pool(num_processes) as pool: with tqdm(total=len(seq_data_list), desc="sequences") as pbar: for _ in pool.imap_unordered( func, iterable=seq_data_list, chunksize=chunksize ): pbar.update() else: for entry in tqdm(seq_data_list, desc="sequences"): process_sequence( dataset=dataset, filter_cfg=filter_cfg, event_representation=ev_repr, ev_repr_num_events=ev_repr_num_events, ev_repr_delta_ts_ms=ev_repr_delta_ts_ms, ts_step_ev_repr_ms=ts_step_ev_repr_ms, downsample_by_2=downsample_by_2, sequence_data=entry, ) ================================================ FILE: RVT/scripts/genx/preprocess_dataset.sh ================================================ NUM_PROCESSES=20 # set to the number of parallel processes to use DATA_DIR=/data/scratch1/nzubic/datasets/gen1_tar/ DEST_DIR=/data/scratch1/nzubic/datasets/RVT/gen1_frequencies/gen1_200hz/ 
def draw_bboxes_bbv(
    img, boxes, labelmap=LABELMAP_GEN1, hd_resolution: bool = False
) -> np.ndarray:
    """Draw labelled bounding boxes on ``img`` and return the annotated image.

    Args:
        img: image with a 3-element shape ``(height, width, channels)``.
        boxes: structured array read through the fields ``x``, ``y``, ``w``,
            ``h``, ``class_id`` and ``class_confidence``.
        labelmap: ``LABELMAP_GEN1`` or ``LABELMAP_GEN4_SHORT``; selects the
            class colours and the upscaling factor.
        hd_resolution: for the gen4 label map, skip upscaling when True.

    Returns:
        The (possibly resized) image with rectangles and text labels drawn.
    """
    if labelmap == LABELMAP_GEN1:
        classid2colors = {
            0: (255, 255, 0),  # car -> yellow (rgb)
            1: (0, 0, 255),  # ped -> blue (rgb)
        }
        scale_multiplier = 4
    else:
        assert labelmap == LABELMAP_GEN4_SHORT
        classid2colors = {
            0: (0, 0, 255),  # ped -> blue (rgb)
            1: (0, 255, 255),  # 2-wheeler cyan (rgb)
            2: (255, 255, 0),  # car -> yellow (rgb)
        }
        scale_multiplier = 1 if hd_resolution else 2

    add_score = True
    ht, wd, _ = img.shape
    if scale_multiplier != 1:
        dim_new_wh = (int(wd * scale_multiplier), int(ht * scale_multiplier))
        img = cv2.resize(img, dim_new_wh, interpolation=cv2.INTER_AREA)

    for box in boxes:
        left, top = int(box["x"]), int(box["y"])
        right, bottom = left + int(box["w"]), top + int(box["h"])
        # Box coordinates refer to the unscaled image, so scale them with it.
        bbox = tuple(
            coord * scale_multiplier for coord in (left, top, right, bottom)
        )

        # Use the same wrapped index for the name and the colour so that an
        # out-of-range class id cannot name a class and then fail the colour lookup.
        class_idx = int(box["class_id"]) % len(labelmap)
        bbox_txt = labelmap[class_idx]
        if add_score:
            score = box["class_confidence"]
            bbox_txt += f" {score:.2f}"

        color_tuple_rgb = classid2colors[class_idx]
        img = bbv.draw_rectangle(img, bbox, bbox_color=color_tuple_rgb)
        img = bbv.add_label(
            img, bbox_txt, bbox, text_bg_color=color_tuple_rgb, top=True
        )

    return img
if __name__ == "__main__":
    # Hard-coded run configuration for this visualisation script.
    SEQUENCE_PATH = "/data/scratch1/nzubic/datasets/RVT/gen1_frequencies/gen1_40hz/test/17-04-04_11-00-13_cut_15_500000_60500000/"
    OUT_DIR_PATH = "/data/scratch1/nzubic/out_viz/"
    DOWNSAMPLE = False
    EV_REPR_NAME = "stacked_histogram_dt=25_nbins=10"  # dt varies depending on different frequencies
    DATASET_TYPE = DatasetType.GEN1

    seq_path = Path(SEQUENCE_PATH)
    out_dir = Path(OUT_DIR_PATH)
    # exist_ok=False: refuse to write into an already existing output directory.
    os.makedirs(out_dir, exist_ok=False)

    labelmap = (
        LABELMAP_GEN1 if DATASET_TYPE == DatasetType.GEN1 else LABELMAP_GEN4_SHORT
    )

    frame_stream = gen_gt_generator(
        seq_path=seq_path,
        ev_representation_name=EV_REPR_NAME,
        downsample_by_factor_2=DOWNSAMPLE,
        dataset_type=DATASET_TYPE,
    )

    viz_at_hd_resolution = None
    prev_img_with_labels = None
    for idx, (ev_repr, labels) in enumerate(tqdm(frame_stream)):
        # The resolution class is decided once, from the first frame.
        if viz_at_hd_resolution is None:
            height, width = ev_repr.shape[-2:]
            viz_at_hd_resolution = height * width > 9e5

        if labels is None:
            labels_proph = None
        else:
            labels_proph = loaded_label_to_prophesee(labels)

        img = draw_predictions(
            ev_repr=ev_repr,
            predictions_proph=labels_proph,
            hd_resolution=viz_at_hd_resolution,
            labelmap=labelmap,
        )

        # Frames without labels repeat the most recent labelled image, if any.
        img_to_write = img
        if labels is None and prev_img_with_labels is not None:
            img_to_write = prev_img_with_labels
        iio.imwrite(str(out_dir / f"{idx:06d}.png"), img_to_write)

        if labels_proph is not None:
            prev_img_with_labels = img
@hydra.main(config_path="config", config_name="train", version_base="1.2")
def main(config: DictConfig):
    """Training entry point: build data, model, logger and callbacks, then fit.

    The Hydra config named ``train`` under ``config/`` is passed in as
    ``config``. The steps run in order: config post-processing and resolution
    check, optional seeding, DDP strategy selection, data module, W&B logger
    and checkpoint lookup, model construction (optionally weight-only resume),
    callbacks, and finally ``pl.Trainer.fit``.
    """
    dynamically_modify_train_config(config)
    # Just to check whether config can be resolved
    OmegaConf.to_container(config, resolve=True, throw_on_missing=True)

    print("------ Configuration ------")
    print(OmegaConf.to_yaml(config))
    print("---------------------------")

    # ---------------------
    # Reproducibility
    # ---------------------
    dataset_train_sampling = config.dataset.train.sampling
    assert dataset_train_sampling in iter(DatasetSamplingMode)
    # Seeding is skipped for the STREAM and MIXED sampling modes (see message below).
    disable_seed_everything = dataset_train_sampling in (
        DatasetSamplingMode.STREAM,
        DatasetSamplingMode.MIXED,
    )
    if disable_seed_everything:
        print(
            "Disabling PL seed everything because of unresolved issues with shuffling during training on streaming "
            "datasets"
        )
    seed = config.reproduce.seed_everything
    if seed is not None and not disable_seed_everything:
        assert isinstance(seed, int)
        print(f"USING pl.seed_everything WITH {seed=}")
        pl.seed_everything(seed=seed, workers=True)

    # ---------------------
    # DDP
    # ---------------------
    gpu_config = config.hardware.gpus
    # Convert an OmegaConf node to a plain container, then normalise to a list.
    gpus = (
        OmegaConf.to_container(gpu_config)
        if OmegaConf.is_config(gpu_config)
        else gpu_config
    )
    gpus = gpus if isinstance(gpus, list) else [gpus]
    distributed_backend = config.hardware.dist_backend
    assert distributed_backend in ("nccl", "gloo"), f"{distributed_backend=}"
    # A DDP strategy is only built when more than one device entry is configured.
    strategy = (
        DDPStrategy(
            process_group_backend=distributed_backend,
            find_unused_parameters=True,
            gradient_as_bucket_view=True,
        )
        if len(gpus) > 1
        else "auto"
    )

    # ---------------------
    # Data
    # ---------------------
    data_module = fetch_data_module(config=config)

    # ---------------------
    # Logging and Checkpoints
    # ---------------------
    logger = get_wandb_logger(config)
    ckpt_path = None
    if config.wandb.artifact_name is not None:
        ckpt_path = get_ckpt_path(logger, wandb_config=config.wandb)

    # ---------------------
    # Model
    # ---------------------
    module = fetch_model_module(config=config)
    if ckpt_path is not None and config.wandb.resume_only_weights:
        print("Resuming only the weights instead of the full training state")
        module = Module.load_from_checkpoint(
            str(ckpt_path), **{"full_config": config}, strict=False
        )
        # Cleared so that trainer.fit below is not given a checkpoint to resume from.
        ckpt_path = None

    # ---------------------
    # Callbacks and Misc
    # ---------------------
    callbacks = list()
    callbacks.append(get_ckpt_callback(config))
    callbacks.append(GradFlowLogCallback(config.logging.train.log_model_every_n_steps))
    if config.training.lr_scheduler.use:
        callbacks.append(LearningRateMonitor(logging_interval="step"))
    # The visualisation callback is added if either train or validation high-dim logging is on.
    if (
        config.logging.train.high_dim.enable
        or config.logging.validation.high_dim.enable
    ):
        viz_callback = get_viz_callback(config=config)
        callbacks.append(viz_callback)
    callbacks.append(ModelSummary(max_depth=2))

    logger.watch(
        model=module,
        log="all",
        log_freq=config.logging.train.log_model_every_n_steps,
        log_graph=True,
    )

    # ---------------------
    # Training
    # ---------------------
    val_check_interval = config.validation.val_check_interval
    check_val_every_n_epoch = config.validation.check_val_every_n_epoch
    # At most one of the two validation schedules may be configured.
    assert val_check_interval is None or check_val_every_n_epoch is None

    trainer = pl.Trainer(
        accelerator="gpu",
        callbacks=callbacks,
        enable_checkpointing=True,
        val_check_interval=val_check_interval,
        check_val_every_n_epoch=check_val_every_n_epoch,
        default_root_dir=None,
        devices=gpus,
        gradient_clip_val=config.training.gradient_clip_val,
        gradient_clip_algorithm="value",
        limit_train_batches=config.training.limit_train_batches,
        limit_val_batches=config.validation.limit_val_batches,
        logger=logger,
        log_every_n_steps=config.logging.train.log_every_n_steps,
        plugins=None,
        precision=config.training.precision,
        max_epochs=config.training.max_epochs,
        max_steps=config.training.max_steps,
        strategy=strategy,
        # Batch-norm sync is enabled exactly when the DDP strategy is in use.
        sync_batchnorm=False if strategy == "auto" else True,
        # move_metrics_to_cpu=False,
        benchmark=config.reproduce.benchmark,
        deterministic=config.reproduce.deterministic_flag,
    )
    trainer.fit(model=module, ckpt_path=ckpt_path, datamodule=data_module)
class PropheseeEvaluator:
    """Accumulate labels and predictions, then evaluate them with ``evaluate_list``.

    Labels and predictions are appended to an internal buffer as lists of
    ``np.ndarray``. ``evaluate_buffer`` forwards both lists to
    ``evaluate_list`` with bbox filtering enabled. The buffer is only
    cleared by ``reset_buffer``.
    """

    # NOTE: the value is misspelled but only serves as an internal buffer key;
    # it is left unchanged so that code relying on the literal keeps working.
    LABELS = "lables"
    PREDICTIONS = "predictions"

    def __init__(self, dataset: str, downsample_by_2: bool):
        """Create an evaluator for ``dataset`` ("gen1" or "gen4") with an empty buffer."""
        super().__init__()
        assert dataset in {"gen1", "gen4"}
        self.dataset = dataset
        self.downsample_by_2 = downsample_by_2

        self._buffer: Optional[Dict[str, List[np.ndarray]]] = None
        self._buffer_empty = True
        self._reset_buffer()

    def _reset_buffer(self):
        """Replace the buffer with fresh empty lists for labels and predictions."""
        self._buffer_empty = True
        self._buffer = {
            self.LABELS: list(),
            self.PREDICTIONS: list(),
        }

    def _add_to_buffer(self, key: str, value: List[np.ndarray]):
        """Append the arrays in ``value`` to the list stored under ``key``."""
        assert isinstance(value, list)
        for entry in value:
            assert isinstance(entry, np.ndarray)
        assert self._buffer is not None
        # Only real entries count as data: appending an empty list must not
        # make has_data() report a non-empty buffer.
        if value:
            self._buffer_empty = False
        self._buffer[key].extend(value)

    def _get_from_buffer(self, key: str) -> List[np.ndarray]:
        """Return the buffered list for ``key``; the buffer must not be empty."""
        assert not self._buffer_empty
        assert self._buffer is not None
        return self._buffer[key]

    def add_predictions(self, predictions: List[np.ndarray]):
        """Buffer a list of prediction arrays."""
        self._add_to_buffer(self.PREDICTIONS, predictions)

    def add_labels(self, labels: List[np.ndarray]):
        """Buffer a list of ground-truth label arrays."""
        self._add_to_buffer(self.LABELS, labels)

    def reset_buffer(self) -> None:
        """Drop all buffered data (e.g. call in on_validation_epoch_start)."""
        self._reset_buffer()

    def has_data(self):
        """Return True once at least one array has been buffered."""
        return not self._buffer_empty

    def evaluate_buffer(
        self, img_height: int, img_width: int
    ) -> Optional[Dict[str, Any]]:
        """Evaluate the buffered data (e.g. call in on_validation_epoch_end).

        Returns:
            The result of ``evaluate_list``, or ``None`` (after emitting a
            ``UserWarning``) when nothing has been buffered.
        """
        if self._buffer_empty:
            warn(
                "Attempt to use prophesee evaluation buffer, but it is empty",
                UserWarning,
                stacklevel=2,
            )
            return None

        labels = self._get_from_buffer(self.LABELS)
        predictions = self._get_from_buffer(self.PREDICTIONS)
        assert len(labels) == len(predictions)
        metrics = evaluate_list(
            result_boxes_list=predictions,
            gt_boxes_list=labels,
            height=img_height,
            width=img_width,
            apply_bbox_filters=True,
            downsampled_by_2=self.downsample_by_2,
            camera=self.dataset,
        )
        return metrics
def to_prophesee(
    loaded_label_list: LOADED_LABELS, yolox_pred_list: YOLOX_PRED_PROCESSED
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Convert paired labels and YoloX predictions to Prophesee box arrays.

    Each label entry is converted with ``loaded_label_to_prophesee``. Each
    prediction entry is either ``None`` or an ``(N, 7)`` tensor of
    post-processed YoloX detections ordered as
    ``(x1, y1, x2, y2, obj_conf, class_conf, class_pred)``.
    See postprocessing: https://github.com/Megvii-BaseDetection/YOLOX/blob/a5bb5ab12a61b8a25a5c3c11ae6f06397eb9b296/yolox/utils/boxes.py#L32

    Predictions carry no timestamp of their own; they receive the single
    timestamp shared by all labels of the same entry.

    Returns:
        ``(labels, predictions)``: two lists of ``BBOX_DTYPE`` arrays with one
        array per input entry.
    """
    assert len(loaded_label_list) == len(yolox_pred_list)
    loaded_label_list_proph = []
    yolox_pred_list_proph = []
    for loaded_labels, yolox_preds in zip(loaded_label_list, yolox_pred_list):
        # --- LOADED LABELS ---
        loaded_label_proph = loaded_label_to_prophesee(loaded_labels)
        loaded_label_list_proph.append(loaded_label_proph)

        # All labels of one entry must share one timestamp; it is reused below.
        time = np.unique(loaded_label_proph["t"])
        assert time.size == 1, f"expected exactly one label timestamp, got {time.size}"
        time = time.item()

        # --- YOLOX PREDICTIONS ---
        num_pred = 0 if yolox_preds is None else yolox_preds.shape[0]
        yolox_pred_proph = np.zeros((num_pred,), dtype=BBOX_DTYPE)
        if num_pred > 0:
            yolox_preds = yolox_preds.detach().cpu().numpy()
            assert yolox_preds.shape == (num_pred, 7)
            x1, y1, x2, y2 = (yolox_preds[:, col] for col in range(4))
            yolox_pred_proph["t"] = np.full(
                (num_pred,), time, dtype=BBOX_DTYPE["t"]
            )
            yolox_pred_proph["x"] = np.asarray(x1, dtype=BBOX_DTYPE["x"])
            yolox_pred_proph["y"] = np.asarray(y1, dtype=BBOX_DTYPE["y"])
            # Corner format -> width/height.
            yolox_pred_proph["w"] = np.asarray(x2 - x1, dtype=BBOX_DTYPE["w"])
            yolox_pred_proph["h"] = np.asarray(y2 - y1, dtype=BBOX_DTYPE["h"])
            yolox_pred_proph["class_id"] = np.asarray(
                yolox_preds[:, 6], dtype=BBOX_DTYPE["class_id"]
            )
            yolox_pred_proph["class_confidence"] = np.asarray(
                yolox_preds[:, 5], dtype=BBOX_DTYPE["class_confidence"]
            )
        yolox_pred_list_proph.append(yolox_pred_proph)

    return loaded_label_list_proph, yolox_pred_list_proph
def load_td_data(filename, ev_count=-1, ev_start=0):
    """
    Loads TD data from files generated by the StreamLogger consumer for Event2D
    events [ts,x,y,p]. The type ID in the file header must be 0.
    args :
        - path to a dat file
        - number of event (all if set to the default -1)
        - index of the first event

    return :
        - dat, a dictionary like structure containing the fields ts, x, y, p
    """
    # Bit layout of the packed '_' field, as implied by the masks and shifts:
    # x in bits 0-13, y in bits 14-27, p in bit 28.
    x_mask = 0x3FFF  # 16383
    y_mask = 0x0FFFC000  # 268419072
    p_mask = 0x10000000  # 268435456

    with open(filename, "rb") as f:
        _, _, ev_size, _ = parse_header(f)
        if ev_start > 0:
            # parse_header may hand back a NumPy uint8; multiply as Python ints
            # so that large offsets cannot overflow the 8-bit type.
            f.seek(int(ev_start) * int(ev_size), 1)

        dtype = EV_TYPE
        dat = np.fromfile(f, dtype=dtype, count=ev_count)

        xyp = None
        if ("_", "i4") in dtype:
            packed = dat["_"]
            x = np.bitwise_and(packed, x_mask)
            y = np.right_shift(np.bitwise_and(packed, y_mask), 14)
            p = np.right_shift(np.bitwise_and(packed, p_mask), 28)
            xyp = (x, y, p)
        return _dat_transfer(dat, dtype, xyp=xyp)
def parse_header(f):
    """
    Parses the header of a dat file
    Args:
        - f file handle to a dat file, opened in binary mode
    return :
        - int position of the file cursor after the header
        - int type of event
        - int size of event in bytes
        - size [height, width] list of int; an entry stays None when the
          header does not declare it

    The event type and size are returned as plain Python ints so that callers
    can safely use them in integer arithmetic with file offsets.
    """
    f.seek(0, os.SEEK_SET)
    size = [None, None]
    has_comment_lines = False

    # Header lines start with "% "; `bod` ends up at the first non-header line.
    while True:
        bod = f.tell()
        line = f.readline()
        if line.decode("latin-1")[:2] != "% ":
            break
        has_comment_lines = True

        words = line.split()
        if len(words) > 1:
            if words[1] == b"Height":
                size[0] = int(words[2])
            elif words[1] == b"Width":
                size[1] = int(words[2])

    # Rewind to the start of the first non-header line before reading the data.
    f.seek(bod, os.SEEK_SET)
    if has_comment_lines:
        # Two bytes follow the comment lines: event type, then event size.
        # Indexing bytes yields Python ints; a short read raises IndexError.
        type_and_size = f.read(2)
        ev_type = type_and_size[0]
        ev_size = type_and_size[1]
    else:
        # Ensure compatibility with previous files that have no header:
        # fall back to the Event2D layout described by EV_TYPE.
        ev_type = 0
        ev_size = sum(int(fmt[-1]) for _, fmt in EV_TYPE)

    bod = f.tell()
    return bod, ev_type, ev_size, size
RVT/utils/evaluation/prophesee/io/npy_events_tools.py ================================================ #!/usr/bin/env python """ Defines some tools to handle events, mimicking dat_events_tools.py. In particular : -> defines functions to read events from binary .npy files using numpy -> defines functions to write events to binary .dat files using numpy (TODO later) Copyright: (c) 2015-2019 Prophesee """ from __future__ import print_function import numpy as np def stream_td_data(file_handle, buffer, dtype, ev_count=-1): """ Streams data from opened file_handle args : - file_handle: file object - buffer: pre-allocated buffer to fill with events - dtype: expected fields - ev_count: number of events """ dat = np.fromfile(file_handle, dtype=dtype, count=ev_count) count = len(dat["t"]) for name, _ in dtype: buffer[name][:count] = dat[name] def parse_header(fhandle): """ Parses the header of a .npy file Args: - f file handle to a .npy file return : - int position of the file cursor after the header - int type of event - int size of event in bytes - size (height, width) tuple of int or (None, None) """ version = np.lib.format.read_magic(fhandle) shape, fortran, dtype = np.lib.format._read_array_header(fhandle, version) assert not fortran, "Fortran order arrays not supported" # Get the number of elements in one 'row' by taking # a product over all other dimensions. 
class PSEELoader(object):
    """
    PSEELoader loads a dat or npy file and streams events from it.

    The loader keeps a file cursor and a `current_time`: if the current time
    is t, the next event that will be loaded has a timestamp greater than or
    equal to t.
    """

    def __init__(self, datfile):
        """
        ctor
        :param datfile: path to a binary dat or npy file
        """
        self._extension = datfile.split(".")[-1]
        assert self._extension in ["dat", "npy"], "input file path = {}".format(datfile)
        if self._extension == "dat":
            self._binary_format = dat
        elif self._extension == "npy":
            self._binary_format = npy_format
        self._file = open(datfile, "rb")
        (
            self._start,
            self.ev_type,
            self._ev_size,
            self._size,
        ) = self._binary_format.parse_header(self._file)
        assert self._ev_size != 0
        if self._extension == "dat":
            self._dtype = self._binary_format.EV_TYPE
        elif self._extension == "npy":
            self._dtype = self.ev_type
        else:
            assert False, "unsupported extension"

        # The packed "_" field is exposed to callers as separate x, y, p fields.
        self._decode_dtype = []
        for dtype in self._dtype:
            if dtype[0] == "_":
                self._decode_dtype += [("x", "u2"), ("y", "u2"), ("p", "u1")]
            else:
                self._decode_dtype.append(dtype)

        # size
        self._file.seek(0, os.SEEK_END)
        self._end = self._file.tell()
        self._ev_count = (self._end - self._start) // self._ev_size
        self.done = False
        self._file.seek(self._start)
        # If the current time is t, it means that next event that will be loaded has a
        # timestamp superior or equal to t (event with timestamp exactly t is not loaded yet)
        self.current_time = 0
        self.duration_s = self.total_time() * 1e-6

    def reset(self):
        """reset at beginning of file"""
        self._file.seek(self._start)
        self.done = False
        self.current_time = 0

    def event_count(self):
        """
        getter on event_count
        :return: number of events in the file
        """
        return self._ev_count

    def get_size(self):
        """(height, width) of the imager, might be (None, None)"""
        return self._size

    def __repr__(self):
        """
        prints properties
        :return: multi-line description (event type, size, count, duration)
        """
        wrd = ""
        wrd += "PSEELoader:" + "\n"
        wrd += "-----------" + "\n"
        if self._extension == "dat":
            wrd += "Event Type: " + str(self._binary_format.EV_STRING) + "\n"
        elif self._extension == "npy":
            wrd += "Event Type: numpy array element\n"
        wrd += "Event Size: " + str(self._ev_size) + " bytes\n"
        wrd += "Event Count: " + str(self._ev_count) + "\n"
        wrd += "Duration: " + str(self.duration_s) + " s \n"
        wrd += "-----------" + "\n"
        return wrd

    def load_n_events(self, ev_count):
        """
        load batch of n events
        :param ev_count: number of events that will be loaded
        :return: events (fewer than ev_count if the end of the file is reached)
        Note that current time will be incremented to reach the timestamp of the first event not loaded yet.
        When no event is left to load, an empty array is returned and current time is left unchanged.
        """
        # One extra slot: the event after the batch is read to learn its timestamp.
        event_buffer = np.empty((ev_count + 1,), dtype=self._decode_dtype)

        pos = self._file.tell()
        count = (self._end - pos) // self._ev_size
        if ev_count >= count:
            self.done = True
            ev_count = count
            self._binary_format.stream_td_data(
                self._file, event_buffer, self._dtype, ev_count
            )
            # At EOF nothing was read: the buffer is uninitialised, so the
            # current time must not be derived from it.
            if ev_count > 0:
                self.current_time = event_buffer["t"][ev_count - 1] + 1
        else:
            self._binary_format.stream_td_data(
                self._file, event_buffer, self._dtype, ev_count + 1
            )
            self.current_time = event_buffer["t"][ev_count]
            # Step back over the look-ahead event.
            self._file.seek(pos + ev_count * self._ev_size)

        return event_buffer[:ev_count]

    def load_delta_t(self, delta_t):
        """
        loads a slice of time.
        :param delta_t: (us) slice thickness
        :return: events
        :raises ValueError: if delta_t is smaller than 1
        Note that current time will be incremented by delta_t.
        If an event is timestamped at exactly current_time it will not be loaded.
        """
        if delta_t < 1:
            raise ValueError(
                "load_delta_t(): delta_t must be at least 1 micro-second: {}".format(
                    delta_t
                )
            )

        if self.done or (self._file.tell() >= self._end):
            self.done = True
            return np.empty((0,), dtype=self._decode_dtype)

        final_time = self.current_time + delta_t
        tmp_time = self.current_time
        start = self._file.tell()
        pos = start
        batch = 100000
        event_buffer = []
        # data is read by buffers until enough events are read or until the end of the file
        while tmp_time < final_time and pos < self._end:
            count = (min(self._end, pos + batch * self._ev_size) - pos) // self._ev_size
            buffer = np.empty((count,), dtype=self._decode_dtype)
            self._binary_format.stream_td_data(self._file, buffer, self._dtype, count)
            tmp_time = buffer["t"][-1]
            event_buffer.append(buffer)
            pos = self._file.tell()
        if tmp_time >= final_time:
            self.current_time = final_time
        else:
            self.current_time = tmp_time + 1
        assert len(event_buffer) > 0
        # Only the last batch can contain events at or after final_time.
        idx = np.searchsorted(event_buffer[-1]["t"], final_time)
        event_buffer[-1] = event_buffer[-1][:idx]
        event_buffer = np.concatenate(event_buffer)
        idx = len(event_buffer)
        # Put the cursor right after the last event that is returned.
        self._file.seek(start + idx * self._ev_size)
        self.done = self._file.tell() >= self._end
        return event_buffer

    def seek_event(self, ev_count):
        """
        seek in the file by ev_count events
        :param ev_count: seek in the file after ev_count events
        Note that current time will be set to the timestamp of the next event.
        """
        if ev_count <= 0:
            self._file.seek(self._start)
            self.current_time = 0
        elif ev_count >= self._ev_count:
            # we put the cursor one event before and read the last event
            # which puts the file cursor at the right place
            # current_time is set to the last event timestamp + 1
            self._file.seek(self._start + (self._ev_count - 1) * self._ev_size)
            self.current_time = (
                np.fromfile(self._file, dtype=self._dtype, count=1)["t"][0] + 1
            )
        else:
            # we put the cursor at the *ev_count*nth event
            self._file.seek(self._start + (ev_count) * self._ev_size)
            # we read the timestamp of the following event (this change the position in the file)
            self.current_time = np.fromfile(self._file, dtype=self._dtype, count=1)[
                "t"
            ][0]
            # this is why we go back at the right position here
            self._file.seek(self._start + (ev_count) * self._ev_size)
        self.done = self._file.tell() >= self._end

    def seek_time(self, final_time, term_criterion=100000):
        """
        go to the time final_time inside the file. This is implemented using a binary search algorithm
        :param final_time: expected time
        :param term_criterion: (nb event) binary search termination criterion
        it will load those events in a buffer and do a numpy searchsorted so the result is always exact

        After the call the cursor is on the first event whose timestamp is
        greater than or equal to final_time.
        """
        if final_time > self.total_time():
            self._file.seek(self._end)
            self.done = True
            self.current_time = self.total_time() + 1
            return

        if final_time <= 0:
            self.reset()
            return

        low = 0
        high = self._ev_count

        # binary search
        # Invariant: every event before `low` has t < final_time and the first
        # event with t >= final_time sits at an index <= `high`.
        while high - low > term_criterion:
            middle = (low + high) // 2

            self.seek_event(middle)
            mid = np.fromfile(self._file, dtype=self._dtype, count=1)["t"][0]

            if mid < final_time:
                low = middle + 1
            else:
                # An exact hit is handled like "too late": earlier events may
                # carry the same timestamp, and the cursor now sits one event
                # past `middle`, so returning here would skip events.
                high = middle

        # we now know that it is between low and high
        self.seek_event(low)
        final_buffer = np.fromfile(self._file, dtype=self._dtype, count=high - low)["t"]
        final_index = np.searchsorted(final_buffer, final_time)

        self.seek_event(low + final_index)
        self.current_time = final_time
        self.done = self._file.tell() >= self._end

    def total_time(self):
        """
        get total duration of video in mus, providing there is no overflow
        :return: timestamp of the last event, or 0 if the file holds no event
        """
        if not self._ev_count:
            return 0
        # save the state of the class
        pos = self._file.tell()
        current_time = self.current_time
        done = self.done
        # read the last event's timestamp
        self.seek_event(self._ev_count - 1)
        time = np.fromfile(self._file, dtype=self._dtype, count=1)["t"][0]
        # restore the state
        self._file.seek(pos)
        self.current_time = current_time
        self.done = done

        return time

    def __del__(self):
        # __init__ may have failed before the file was opened.
        file = getattr(self, "_file", None)
        if file is not None:
            file.close()
of classes names :param height: int for box size statistics :param width: int for box size statistics :param time_tol: int size of the temporal window in micro seconds to look for a detection around a gt box """ flattened_gt = [] flattened_dt = [] for gt_boxes, dt_boxes in zip(gt_boxes_list, dt_boxes_list): assert np.all(gt_boxes["t"][1:] >= gt_boxes["t"][:-1]) assert np.all(dt_boxes["t"][1:] >= dt_boxes["t"][:-1]) all_ts = np.unique(gt_boxes["t"]) n_steps = len(all_ts) gt_win, dt_win = _match_times(all_ts, gt_boxes, dt_boxes, time_tol) flattened_gt = flattened_gt + gt_win flattened_dt = flattened_dt + dt_win return _coco_eval( flattened_gt, flattened_dt, height, width, labelmap=classes, return_aps=return_aps, ) def _match_times(all_ts, gt_boxes, dt_boxes, time_tol): """ match ground truth boxes and ground truth detections at all timestamps using a specified tolerance return a list of boxes vectors """ gt_size = len(gt_boxes) dt_size = len(dt_boxes) windowed_gt = [] windowed_dt = [] low_gt, high_gt = 0, 0 low_dt, high_dt = 0, 0 for ts in all_ts: while low_gt < gt_size and gt_boxes[low_gt]["t"] < ts: low_gt += 1 # the high index is at least as big as the low one high_gt = max(low_gt, high_gt) while high_gt < gt_size and gt_boxes[high_gt]["t"] <= ts: high_gt += 1 # detection are allowed to be inside a window around the right detection timestamp low = ts - time_tol high = ts + time_tol while low_dt < dt_size and dt_boxes[low_dt]["t"] < low: low_dt += 1 # the high index is at least as big as the low one high_dt = max(low_dt, high_dt) while high_dt < dt_size and dt_boxes[high_dt]["t"] <= high: high_dt += 1 windowed_gt.append(gt_boxes[low_gt:high_gt]) windowed_dt.append(dt_boxes[low_dt:high_dt]) return windowed_gt, windowed_dt def _coco_eval( gts, detections, height, width, labelmap=("car", "pedestrian"), return_aps: bool = True, ): """simple helper function wrapping around COCO's Python API :params: gts iterable of numpy boxes for the ground truth :params: detections 
iterable of numpy boxes for the detections :params: height int :params: width int :params: labelmap iterable of class labels """ categories = [ {"id": id + 1, "name": class_name, "supercategory": "none"} for id, class_name in enumerate(labelmap) ] num_detections = 0 for detection in detections: num_detections += detection.size # Meaning: https://cocodataset.org/#detection-eval out_keys = ("AP", "AP_50", "AP_75", "AP_S", "AP_M", "AP_L") out_dict = {k: 0.0 for k in out_keys} if num_detections == 0: # Corner case at the very beginning of the training. print("no detections for evaluation found.") return out_dict if return_aps else None dataset, results = _to_coco_format( gts, detections, categories, height=height, width=width ) coco_gt = COCO() coco_gt.dataset = dataset coco_gt.createIndex() coco_pred = coco_gt.loadRes(results) coco_eval = COCOeval(coco_gt, coco_pred, "bbox") coco_eval.params.imgIds = np.arange(1, len(gts) + 1, dtype=int) coco_eval.evaluate() coco_eval.accumulate() if return_aps: with open(os.devnull, "w") as f, contextlib.redirect_stdout(f): # info: https://stackoverflow.com/questions/8391411/how-to-block-calls-to-print coco_eval.summarize() for idx, key in enumerate(out_keys): out_dict[key] = coco_eval.stats[idx] return out_dict # Print the whole summary instead without return coco_eval.summarize() def coco_eval_return_metrics(coco_eval: COCOeval): pass def _to_coco_format(gts, detections, categories, height=240, width=304): """ utilitary function producing our data in a COCO usable format """ annotations = [] results = [] images = [] # to dictionary for image_id, (gt, pred) in enumerate(zip(gts, detections)): im_id = image_id + 1 images.append( { "date_captured": "2019", "file_name": "n.a", "id": im_id, "license": 1, "url": "", "height": height, "width": width, } ) for bbox in gt: x1, y1 = bbox["x"], bbox["y"] w, h = bbox["w"], bbox["h"] area = w * h annotation = { "area": float(area), "iscrowd": False, "image_id": im_id, "bbox": [x1, y1, w, h], 
"category_id": int(bbox["class_id"]) + 1, "id": len(annotations) + 1, } annotations.append(annotation) for bbox in pred: image_result = { "image_id": im_id, "category_id": int(bbox["class_id"]) + 1, "score": float(bbox["class_confidence"]), "bbox": [bbox["x"], bbox["y"], bbox["w"], bbox["h"]], } results.append(image_result) dataset = { "info": {}, "licenses": [], "type": "instances", "images": images, "annotations": annotations, "categories": categories, } return dataset, results ================================================ FILE: RVT/utils/evaluation/prophesee/visualize/__init__.py ================================================ ================================================ FILE: RVT/utils/evaluation/prophesee/visualize/vis_utils.py ================================================ """ Functions to display events and boxes Copyright: (c) 2019-2020 Prophesee """ from __future__ import print_function import bbox_visualizer as bbv import cv2 import numpy as np LABELMAP_GEN1 = ("car", "pedestrian") LABELMAP_GEN4 = ( "pedestrian", "two wheeler", "car", "truck", "bus", "traffic sign", "traffic light", ) LABELMAP_GEN4_SHORT = ("pedestrian", "two wheeler", "car") def make_binary_histo(events, img=None, width=304, height=240): """ simple display function that shows negative events as blacks dots and positive as white one on a gray background args : - events structured numpy array - img (numpy array, height x width x 3) optional array to paint event on. - width int - height int return: - img numpy array, height x width x 3) """ if img is None: img = 127 * np.ones((height, width, 3), dtype=np.uint8) else: # if an array was already allocated just paint it grey img[...] 
def draw_bboxes_bbv(img, boxes, labelmap=LABELMAP_GEN1) -> np.ndarray:
    """
    Draw labelled bounding boxes on an upscaled copy of img.

    The image is enlarged by a factor of 4 for LABELMAP_GEN1 and by 2 for
    LABELMAP_GEN4_SHORT; box coordinates are scaled by the same factor. Every
    box is labelled with its class name and its confidence (two decimals).

    :param img: image array with three dimensions (height, width, channels)
    :param boxes: structured array with fields x, y, w, h, class_id and
        class_confidence
    :param labelmap: LABELMAP_GEN1 or LABELMAP_GEN4_SHORT
    :return: the resized image with the boxes drawn
    """
    if labelmap == LABELMAP_GEN1:
        classid2colors = {
            0: (255, 255, 0),  # car -> yellow (rgb)
            1: (0, 0, 255),  # ped -> blue (rgb)
        }
        scale_multiplier = 4
    else:
        assert labelmap == LABELMAP_GEN4_SHORT
        classid2colors = {
            0: (0, 0, 255),  # ped -> blue (rgb)
            1: (0, 255, 255),  # 2-wheeler cyan (rgb)
            2: (255, 255, 0),  # car -> yellow (rgb)
        }
        scale_multiplier = 2

    add_score = True
    ht, wd, _ = img.shape
    dim_new_wh = (int(wd * scale_multiplier), int(ht * scale_multiplier))
    if scale_multiplier != 1:
        img = cv2.resize(img, dim_new_wh, interpolation=cv2.INTER_AREA)
    for i in range(boxes.shape[0]):
        pt1 = (int(boxes["x"][i]), int(boxes["y"][i]))
        size = (int(boxes["w"][i]), int(boxes["h"][i]))
        pt2 = (pt1[0] + size[0], pt1[1] + size[1])
        bbox = (pt1[0], pt1[1], pt2[0], pt2[1])
        # Boxes are given in the original resolution; move them to the resized image.
        bbox = tuple(x * scale_multiplier for x in bbox)

        score = boxes["class_confidence"][i]
        class_id = boxes["class_id"][i]
        class_name = labelmap[class_id % len(labelmap)]
        bbox_txt = class_name
        if add_score:
            bbox_txt += f" {score:.2f}"
        color_tuple_rgb = classid2colors[class_id]
        img = bbv.draw_rectangle(img, bbox, bbox_color=color_tuple_rgb)
        img = bbv.add_label(
            img, bbox_txt, bbox, text_bg_color=color_tuple_rgb, top=True
        )

    return img
def torch_uniform_sample_scalar(min_value: float, max_value: float):
    """Draw a float uniformly from [min_value, max_value) with the torch RNG.

    When both bounds are equal, min_value is returned and no random number is
    drawn.
    """
    assert max_value >= min_value, f"{max_value=} is smaller than {min_value=}"
    if min_value == max_value:
        return min_value
    unit_sample = th.rand(1).item()
    return min_value + (max_value - min_value) * unit_sample


def clamp(
    value: Union[int, float], smallest: Union[int, float], largest: Union[int, float]
):
    """Limit value to the interval [smallest, largest]."""
    upper_bounded = min(value, largest)
    return max(smallest, upper_bounded)


class InputPadderFromShape:
    """Pads tensors at the bottom and right up to a fixed height and width."""

    def __init__(
        self,
        desired_hw: Tuple[int, int],
        mode: str = "constant",
        value: int = 0,
        type: str = "corner",
    ):
        """
        :param desired_hw: Desired height and width
        :param mode: See torch.nn.functional.pad
        :param value: See torch.nn.functional.pad
        :param type: "corner": add zero to bottom and right
        """
        assert isinstance(desired_hw, tuple)
        assert len(desired_hw) == 2
        for extent in desired_hw:
            assert extent % 4 == 0, "Required for token mask padding"
        assert type in {"corner"}

        self.desired_hw = desired_hw
        self.mode = mode
        self.value = value
        self.type = type
        # Padding used by the first call of each kind; later calls must match it.
        self._pad_ev_repr = None
        self._pad_token_mask = None

    @staticmethod
    def _pad_tensor_impl(
        input_tensor: th.Tensor, desired_hw: Tuple[int, int], mode: str, value: Any
    ) -> Tuple[th.Tensor, List[int]]:
        """Pad the last two dimensions of input_tensor to desired_hw.

        Returns the padded tensor and the pad list [left, right, top, bottom].
        """
        assert isinstance(input_tensor, th.Tensor)
        cur_h, cur_w = input_tensor.shape[-2:]
        target_h, target_w = desired_hw
        assert cur_h <= target_h
        assert cur_w <= target_w

        # Only the right and bottom edges grow.
        pad = [0, target_w - cur_w, 0, target_h - cur_h]
        fill = value if mode == "constant" else None
        padded = F.pad(input_tensor, pad=pad, mode=mode, value=fill)
        return padded, pad

    def _remember_pad(self, attr_name: str, pad: List[int]) -> None:
        """Store pad on first use, afterwards assert that it did not change."""
        known = getattr(self, attr_name)
        if known is None:
            setattr(self, attr_name, pad)
        else:
            assert known == pad

    def pad_tensor_ev_repr(self, ev_repr: th.Tensor) -> th.Tensor:
        """Pad an event representation to desired_hw using the configured mode and value."""
        padded_ev_repr, pad = self._pad_tensor_impl(
            input_tensor=ev_repr,
            desired_hw=self.desired_hw,
            mode=self.mode,
            value=self.value,
        )
        self._remember_pad("_pad_ev_repr", pad)
        return padded_ev_repr

    def pad_token_mask(self, token_mask: th.Tensor):
        """Zero-pad a token mask to a quarter of desired_hw in each dimension."""
        assert isinstance(token_mask, th.Tensor)

        quarter_hw = tuple(extent // 4 for extent in self.desired_hw)
        padded_token_mask, pad = self._pad_tensor_impl(
            input_tensor=token_mask, desired_hw=quarter_hw, mode="constant", value=0
        )
        self._remember_pad("_pad_token_mask", pad)
        return padded_token_mask
# Durations in seconds, keyed by timer name; filled by CudaTimer and Timer.
cuda_timers = {}
timers = {}


class CudaTimer:
    """Context manager that times a block and synchronizes the CUDA device before and after it.

    Each measured duration (seconds) is appended to cuda_timers[timer_name].
    """

    def __init__(self, device: torch.device, timer_name: str):
        assert isinstance(device, torch.device)
        assert isinstance(timer_name, str)
        self.timer_name = timer_name
        if self.timer_name not in cuda_timers:
            cuda_timers[self.timer_name] = []

        self.device = device
        self.start = None
        self.end = None

    def __enter__(self):
        torch.cuda.synchronize(device=self.device)
        # perf_counter is monotonic, unlike time.time().
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        assert self.start is not None
        torch.cuda.synchronize(device=self.device)
        self.end = time.perf_counter()
        cuda_timers[self.timer_name].append(self.end - self.start)


def cuda_timer_decorator(device: torch.device, timer_name: str):
    """Return a decorator that runs the wrapped function inside a CudaTimer."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            with CudaTimer(device=device, timer_name=timer_name):
                out = func(*args, **kwargs)
            return out

        return wrapper

    return decorator


class TimerDummy:
    """Context manager that accepts any arguments and measures nothing."""

    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        # Return self like Timer and CudaTimer so `with ... as t` binds an object.
        return self

    def __exit__(self, *args):
        pass


class Timer:
    """Context manager that appends the duration of a block (seconds) to timers[timer_name]."""

    def __init__(self, timer_name=""):
        self.timer_name = timer_name
        if self.timer_name not in timers:
            timers[self.timer_name] = []

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        end = time.perf_counter()
        time_diff_s = end - self.start  # measured in seconds
        timers[self.timer_name].append(time_diff_s)


def print_timing_info():
    """Print mean and median of every timer that has more than 10 measurements.

    The first 10 measurements of each timer are skipped as warm-up. Values are
    printed in seconds when the mean exceeds one second, else in milliseconds.
    """
    print("== Timing statistics ==")
    skip_warmup = 10
    for timer_name, timing_values in [*cuda_timers.items(), *timers.items()]:
        if len(timing_values) <= skip_warmup:
            continue
        values = timing_values[skip_warmup:]
        timing_value_s_mean = np.mean(np.array(values))
        timing_value_s_median = np.median(np.array(values))
        timing_value_ms_mean = timing_value_s_mean * 1000
        timing_value_ms_median = timing_value_s_median * 1000
        if timing_value_ms_mean > 1000:
            print(
                "{}: mean={:.2f} s, median={:.2f} s".format(
                    timer_name, timing_value_s_mean, timing_value_s_median
                )
            )
        else:
            print(
                "{}: mean={:.2f} ms, median={:.2f} ms".format(
                    timer_name, timing_value_ms_mean, timing_value_ms_median
                )
            )


# this will print all the timer values upon termination of any program that imported this file
atexit.register(print_timing_info)
Logging and Checkpoints # --------------------- logger = CSVLogger(save_dir="./validation_logs") ckpt_path = Path(config.checkpoint) # --------------------- # Model # --------------------- module = fetch_model_module(config=config) module = Module.load_from_checkpoint(str(ckpt_path), **{"full_config": config}) # --------------------- # Callbacks and Misc # --------------------- callbacks = [ModelSummary(max_depth=2)] # --------------------- # Validation # --------------------- trainer = pl.Trainer( accelerator="gpu", callbacks=callbacks, default_root_dir=None, devices=gpus, logger=logger, log_every_n_steps=100, precision=config.training.precision, # move_metrics_to_cpu=False, ) with torch.inference_mode(): if config.use_test_set: trainer.test(model=module, datamodule=data_module, ckpt_path=str(ckpt_path)) else: trainer.validate( model=module, datamodule=data_module, ckpt_path=str(ckpt_path) ) if __name__ == "__main__": main() ================================================ FILE: installation_details.txt ================================================ conda create -y -n events_signals python=3.11 conda activate events_signals conda install -y pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia (Stable (2.2.1)) pip install lightning wandb pandas plotly opencv-python tabulate pycocotools bbox-visualizer StrEnum hydra-core einops torchdata tqdm numba h5py hdf5plugin lovely-tensors tensorboardX pykeops scikit-learn ================================================ FILE: scripts/1mpx/onempx_base.bash ================================================ #!/usr/bin/env bash source activate events_signals python RVT/train.py model=rnndet dataset=gen4 dataset.path=/shares/rpg.ifi.uzh/nzubic/datasets/RVT/gen4_new_no_psee_filter wandb.project_name=ssms_event_cameras \ wandb.group_name=1mpx +experiment/gen4=base.yaml hardware.gpus=[0,1] batch_size.train=6 batch_size.eval=6 \ hardware.num_workers.train=12 hardware.num_workers.eval=4 
================================================ FILE: scripts/1mpx/onempx_base.job ================================================ #!/usr/bin/env bash #SBATCH --ntasks-per-node=2 #SBATCH --cpus-per-task=16 #SBATCH --mem-per-cpu=8G #SBATCH --time=86:00:00 #SBATCH --gres=gpu:2 # The GPU model is optional, you can simply specify 'gpu:1' #SBATCH --constraint=GPUMEM80GB # This constraint is optional if you don't care about VRAM #SBATCH --output=final_outputs/onempx_base.txt module load gpu cuda srun onempx_base.bash ================================================ FILE: scripts/1mpx/onempx_small.bash ================================================ #!/usr/bin/env bash source activate events_signals python RVT/train.py model=rnndet dataset=gen4 dataset.path=/shares/rpg.ifi.uzh/nzubic/datasets/RVT/gen4_new_no_psee_filter wandb.project_name=ssms_event_cameras \ wandb.group_name=1mpx +experiment/gen4=small.yaml hardware.gpus=[0,1] batch_size.train=6 batch_size.eval=6 \ hardware.num_workers.train=12 hardware.num_workers.eval=4 ================================================ FILE: scripts/1mpx/onempx_small.job ================================================ #!/usr/bin/env bash #SBATCH --ntasks-per-node=2 #SBATCH --cpus-per-task=16 #SBATCH --mem-per-cpu=8G #SBATCH --time=78:00:00 #SBATCH --gres=gpu:2 # The GPU model is optional, you can simply specify 'gpu:1' #SBATCH --constraint=GPUMEM80GB # This constraint is optional if you don't care about VRAM #SBATCH --output=final_outputs/onempx_small_2.txt module load gpu cuda srun onempx_small.bash ================================================ FILE: scripts/gen1/base.txt ================================================ python RVT/train.py model=rnndet dataset=gen1 dataset.path=/data/scratch1/nzubic/datasets/RVT/gen1 wandb.project_name=ssms_event_cameras \ wandb.group_name=gen1 +experiment/gen1=base.yaml hardware.gpus=0 batch_size.train=8 batch_size.eval=8 hardware.num_workers.train=24 \ hardware.num_workers.eval=8 
================================================ FILE: scripts/gen1/small.txt ================================================ python RVT/train.py model=rnndet dataset=gen1 dataset.path=/data/scratch1/nzubic/datasets/RVT/gen1 wandb.project_name=ssms_event_cameras \ wandb.group_name=gen1 +experiment/gen1=small.yaml hardware.gpus=0 batch_size.train=8 batch_size.eval=8 hardware.num_workers.train=24 \ hardware.num_workers.eval=8