Showing preview only (672K chars total). Download the full file or copy to clipboard to get everything.
Repository: uzh-rpg/ssms_event_cameras
Branch: master
Commit: 7c871b55a0c5
Files: 162
Total size: 623.9 KB
Directory structure:
gitextract_9qf3l4ub/
├── .gitignore
├── README.md
├── RVT/
│ ├── .gitignore
│ ├── LICENSE
│ ├── README.md
│ ├── callbacks/
│ │ ├── custom.py
│ │ ├── detection.py
│ │ ├── gradflow.py
│ │ ├── utils/
│ │ │ └── visualization.py
│ │ └── viz_base.py
│ ├── config/
│ │ ├── dataset/
│ │ │ ├── base.yaml
│ │ │ ├── gen1.yaml
│ │ │ └── gen4.yaml
│ │ ├── experiment/
│ │ │ ├── gen1/
│ │ │ │ ├── base.yaml
│ │ │ │ ├── default.yaml
│ │ │ │ └── small.yaml
│ │ │ └── gen4/
│ │ │ ├── base.yaml
│ │ │ ├── default.yaml
│ │ │ └── small.yaml
│ │ ├── general.yaml
│ │ ├── model/
│ │ │ ├── base.yaml
│ │ │ ├── maxvit_yolox/
│ │ │ │ └── default.yaml
│ │ │ └── rnndet.yaml
│ │ ├── modifier.py
│ │ ├── train.yaml
│ │ └── val.yaml
│ ├── data/
│ │ ├── genx_utils/
│ │ │ ├── collate.py
│ │ │ ├── collate_from_pytorch.py
│ │ │ ├── dataset_rnd.py
│ │ │ ├── dataset_streaming.py
│ │ │ ├── labels.py
│ │ │ ├── sequence_base.py
│ │ │ ├── sequence_for_streaming.py
│ │ │ └── sequence_rnd.py
│ │ └── utils/
│ │ ├── augmentor.py
│ │ ├── representations.py
│ │ ├── spatial.py
│ │ ├── stream_concat_datapipe.py
│ │ ├── stream_sharded_datapipe.py
│ │ └── types.py
│ ├── loggers/
│ │ ├── utils.py
│ │ └── wandb_logger.py
│ ├── models/
│ │ ├── detection/
│ │ │ ├── __init_.py
│ │ │ ├── recurrent_backbone/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ └── maxvit_rnn.py
│ │ │ ├── yolox/
│ │ │ │ ├── models/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── losses.py
│ │ │ │ │ ├── network_blocks.py
│ │ │ │ │ └── yolo_head.py
│ │ │ │ └── utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── boxes.py
│ │ │ │ └── compat.py
│ │ │ └── yolox_extension/
│ │ │ └── models/
│ │ │ ├── __init__.py
│ │ │ ├── build.py
│ │ │ ├── detector.py
│ │ │ └── yolo_pafpn.py
│ │ └── layers/
│ │ ├── maxvit/
│ │ │ ├── __init__.py
│ │ │ ├── layers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── activations.py
│ │ │ │ ├── activations_jit.py
│ │ │ │ ├── activations_me.py
│ │ │ │ ├── adaptive_avgmax_pool.py
│ │ │ │ ├── attention_pool2d.py
│ │ │ │ ├── blur_pool.py
│ │ │ │ ├── bottleneck_attn.py
│ │ │ │ ├── cbam.py
│ │ │ │ ├── classifier.py
│ │ │ │ ├── cond_conv2d.py
│ │ │ │ ├── config.py
│ │ │ │ ├── conv2d_same.py
│ │ │ │ ├── conv_bn_act.py
│ │ │ │ ├── create_act.py
│ │ │ │ ├── create_attn.py
│ │ │ │ ├── create_conv2d.py
│ │ │ │ ├── create_norm.py
│ │ │ │ ├── create_norm_act.py
│ │ │ │ ├── drop.py
│ │ │ │ ├── eca.py
│ │ │ │ ├── evo_norm.py
│ │ │ │ ├── fast_norm.py
│ │ │ │ ├── filter_response_norm.py
│ │ │ │ ├── gather_excite.py
│ │ │ │ ├── global_context.py
│ │ │ │ ├── halo_attn.py
│ │ │ │ ├── helpers.py
│ │ │ │ ├── inplace_abn.py
│ │ │ │ ├── lambda_layer.py
│ │ │ │ ├── linear.py
│ │ │ │ ├── median_pool.py
│ │ │ │ ├── mixed_conv2d.py
│ │ │ │ ├── ml_decoder.py
│ │ │ │ ├── mlp.py
│ │ │ │ ├── non_local_attn.py
│ │ │ │ ├── norm.py
│ │ │ │ ├── norm_act.py
│ │ │ │ ├── padding.py
│ │ │ │ ├── patch_embed.py
│ │ │ │ ├── pool2d_same.py
│ │ │ │ ├── pos_embed.py
│ │ │ │ ├── selective_kernel.py
│ │ │ │ ├── separable_conv.py
│ │ │ │ ├── space_to_depth.py
│ │ │ │ ├── split_attn.py
│ │ │ │ ├── split_batchnorm.py
│ │ │ │ ├── squeeze_excite.py
│ │ │ │ ├── std_conv.py
│ │ │ │ ├── test_time_pool.py
│ │ │ │ ├── trace_utils.py
│ │ │ │ └── weight_init.py
│ │ │ └── maxvit.py
│ │ ├── rnn.py
│ │ └── s5/
│ │ ├── __init__.py
│ │ ├── jax_func.py
│ │ ├── s5_init.py
│ │ ├── s5_model.py
│ │ └── triton_comparison.py
│ ├── modules/
│ │ ├── __init__.py
│ │ ├── data/
│ │ │ └── genx.py
│ │ ├── detection.py
│ │ └── utils/
│ │ ├── detection.py
│ │ └── fetch.py
│ ├── scripts/
│ │ ├── genx/
│ │ │ ├── README.md
│ │ │ ├── conf_preprocess/
│ │ │ │ ├── extraction/
│ │ │ │ │ ├── const_count.yaml
│ │ │ │ │ ├── const_duration.yaml
│ │ │ │ │ └── frequencies/
│ │ │ │ │ ├── const_duration_100hz.yaml
│ │ │ │ │ ├── const_duration_200hz.yaml
│ │ │ │ │ ├── const_duration_40hz.yaml
│ │ │ │ │ └── const_duration_80hz.yaml
│ │ │ │ ├── filter_gen1.yaml
│ │ │ │ ├── filter_gen4.yaml
│ │ │ │ └── representation/
│ │ │ │ ├── mixeddensity_stack.yaml
│ │ │ │ └── stacked_hist.yaml
│ │ │ ├── preprocess_dataset.py
│ │ │ └── preprocess_dataset.sh
│ │ └── viz/
│ │ └── viz_gt.py
│ ├── train.py
│ ├── utils/
│ │ ├── evaluation/
│ │ │ └── prophesee/
│ │ │ ├── __init__.py
│ │ │ ├── evaluation.py
│ │ │ ├── evaluator.py
│ │ │ ├── io/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── box_filtering.py
│ │ │ │ ├── box_loading.py
│ │ │ │ ├── dat_events_tools.py
│ │ │ │ ├── npy_events_tools.py
│ │ │ │ └── psee_loader.py
│ │ │ ├── metrics/
│ │ │ │ ├── __init__.py
│ │ │ │ └── coco_eval.py
│ │ │ └── visualize/
│ │ │ ├── __init__.py
│ │ │ └── vis_utils.py
│ │ ├── helpers.py
│ │ ├── padding.py
│ │ ├── preprocessing.py
│ │ └── timers.py
│ └── validation.py
├── installation_details.txt
└── scripts/
├── 1mpx/
│ ├── onempx_base.bash
│ ├── onempx_base.job
│ ├── onempx_small.bash
│ └── onempx_small.job
└── gen1/
├── base.txt
└── small.txt
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
*.pyc
*.iml
# Specific stuff
wandb/
cache_dir/
raw_datasets/
raw_data/
final_outputs/
outputs/
validation_logs/
apex/
*.ckpt
.vscode/
================================================
FILE: README.md
================================================
# [CVPR'24 Spotlight] State Space Models for Event Cameras
<p align="center">
<a href="https://www.youtube.com/watch?v=WRZZJn6Me9M">
<img src="https://github.com/uzh-rpg/ssms_event_cameras/blob/master/scripts/zubic_cvpr2024_youtube.png" alt="youtube_video"/>
</a>
</p>
This is the official PyTorch implementation of the CVPR 2024 paper [State Space Models for Event Cameras](https://arxiv.org/abs/2402.15584).
### 🖼️ Check Out Our Poster! 🖼️ [here](https://download.ifi.uzh.ch/rpg/CVPR24_Zubic/Zubic_CVPR24_poster.pdf)
## :white_check_mark: Updates
* **` June. 14th, 2024`**: Everything is updated! Poster released! Check it out above.
* **` June. 6th, 2024`**: Video released! To watch our video, simply click on the YouTube play button above.
* **` June. 1st, 2024`**: Our CVPR conference paper has also been accepted as a Spotlight presentation at "The 3rd Workshop on Transformers for Vision (T4V)."
* **` April. 19th, 2024`**: The code along with the best checkpoints is released! The poster and video will be released shortly before CVPR 2024.
## Citation
If you find this work and/or code useful, please cite our paper:
```bibtex
@InProceedings{Zubic_2024_CVPR,
author = {Zubic, Nikola and Gehrig, Mathias and Scaramuzza, Davide},
title = {State Space Models for Event Cameras},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2024},
pages = {5819-5828}
}
```
## SSM-ViT
- S5 model used in our SSM-ViT pipeline can be seen [here](https://github.com/uzh-rpg/ssms_event_cameras/tree/master/RVT/models/layers/s5).
- In particular, S5 is used instead of an RNN in a 4-stage hierarchical ViT backbone, and its forward function is exposed [here](https://github.com/uzh-rpg/ssms_event_cameras/blob/master/RVT/models/detection/recurrent_backbone/maxvit_rnn.py#L245). A nice property of this approach is that we do not need a 'for' loop over the sequence dimension; instead, we employ a parallel scan algorithm. This model assumes that a hidden state is carried over.
- For a standalone model that can be used for any sequence modeling problem, the hidden state is not carried over by default. That implementation is the same as the original JAX implementation and can be downloaded in zip format from [ssms_event_cameras/RVT/models/s5.zip](https://github.com/uzh-rpg/ssms_event_cameras/raw/master/RVT/models/s5.zip).
## Installation
### Conda
We highly recommend using [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) to reduce the installation time.
```Bash
conda create -y -n events_signals python=3.11
conda activate events_signals
conda install pytorch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install lightning wandb pandas plotly opencv-python tabulate pycocotools bbox-visualizer StrEnum hydra-core einops torchdata tqdm numba h5py hdf5plugin lovely-tensors tensorboardX pykeops scikit-learn
```
## Required Data
To evaluate or train the S5-ViT model, you will need to download the required preprocessed datasets:
<table><tbody>
<th valign="bottom"></th>
<th valign="bottom">1 Mpx</th>
<th valign="bottom">Gen1</th>
<tr><td align="left">pre-processed dataset</td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/RVT/datasets/preprocessed/gen4.tar">download</a></td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/RVT/datasets/preprocessed/gen1.tar">download</a></td>
</tr>
<tr><td align="left">crc32</td>
<td align="center"><tt>c5ec7c38</tt></td>
<td align="center"><tt>5acab6f3</tt></td>
</tr>
</tbody></table>
You may also pre-process the dataset yourself by following the [instructions](https://github.com/NikolaZubic/ssms_event_cameras/blob/master/RVT/scripts/genx/README.md).
## Pre-trained Checkpoints
### 1 Mpx
<table><tbody>
<th valign="bottom"></th>
<th valign="bottom">S5-ViT-Base</th>
<th valign="bottom">S5-ViT-Small</th>
<tr><td align="left">pre-trained checkpoint</td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/CVPR24_Zubic/gen4_base.ckpt">download</a></td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/CVPR24_Zubic/gen4_small.ckpt">download</a></td>
</tr>
</tbody></table>
### Gen1
<table><tbody>
<th valign="bottom"></th>
<th valign="bottom">S5-ViT-Base</th>
<th valign="bottom">S5-ViT-Small</th>
<tr><td align="left">pre-trained checkpoint</td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/CVPR24_Zubic/gen1_base.ckpt">download</a></td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/CVPR24_Zubic/gen1_small.ckpt">download</a></td>
</tr>
</tbody></table>
## Evaluation
- Evaluation scripts with the concrete parameters used to train our models can be found [here](https://github.com/uzh-rpg/ssms_event_cameras/tree/master/scripts).
- Set `DATA_DIR` as the path to either the 1 Mpx or Gen1 dataset directory
- Set `CKPT_PATH` to the path of the *correct* checkpoint matching the choice of the model and dataset
- Set
- `MDL_CFG=base` or
- `MDL_CFG=small`
to load either the base or small model configuration.
- Set `GPU_ID` to the PCI BUS ID of the GPU that you want to use. e.g. `GPU_ID=0`.
Only a single GPU is supported for evaluation.
### 1 Mpx
```Bash
python RVT/validation.py dataset=gen4 dataset.path=${DATA_DIR} checkpoint=${CKPT_PATH} \
use_test_set=1 hardware.gpus=${GPU_ID} +experiment/gen4="${MDL_CFG}.yaml" \
batch_size.eval=12 model.postprocess.confidence_threshold=0.001
```
### Gen1
```Bash
python RVT/validation.py dataset=gen1 dataset.path=${DATA_DIR} checkpoint=${CKPT_PATH} \
use_test_set=1 hardware.gpus=${GPU_ID} +experiment/gen1="${MDL_CFG}.yaml" \
batch_size.eval=8 model.postprocess.confidence_threshold=0.001
```
We set the same batch size for the evaluation and training: 12 for the 1 Mpx dataset, and 8 for the Gen1 dataset.
## Evaluation results
Evaluation should give the same results as shown below:
- 47.7 and 47.8 mAP on Gen1 and 1 Mpx datasets for the base model, and
- 46.6 and 46.5 mAP on Gen1 and 1 Mpx datasets for the small model.
<p align="center">
<img src="https://github.com/uzh-rpg/ssms_event_cameras/blob/master/scripts/checkpoints.png">
</p>
## Training
- Set `DATA_DIR` as the path to either the 1 Mpx or Gen1 dataset directory
- Set
- `MDL_CFG=base` or
- `MDL_CFG=small`
to load either the base or the small configuration.
- Set `GPU_IDS` to the PCI BUS IDs of the GPUs that you want to use. e.g. `GPU_IDS=[0,1]` for using GPU 0 and 1.
**Using a list of IDS will enable single-node multi-GPU training.**
Pay attention to the batch size which is defined per GPU.
- Set `BATCH_SIZE_PER_GPU` such that the effective batch size is matching the parameters below.
The **effective batch size** is (batch size per GPU)*(number of GPUs).
- If you would like to change the effective batch size, we found the following learning rate scaling to work well for
all models on both datasets:
`lr = 2e-4 * sqrt(effective_batch_size/8)`.
- The training code uses [W&B](https://wandb.ai/) for logging during the training.
Hence, we assume that you have a W&B account.
- The training script below will create a new project called `ssms_event_cameras`. Adapt the project name and group name if necessary.
### 1 Mpx
- The effective batch size for the 1 Mpx training is 12.
- For training the model on 1 Mpx dataset, we need 2x A100 80 GB GPUs and we use 12 workers per GPU for training and 4 workers per GPU for evaluation:
```Bash
GPU_IDS=[0,1]
BATCH_SIZE_PER_GPU=6
TRAIN_WORKERS_PER_GPU=12
EVAL_WORKERS_PER_GPU=4
python RVT/train.py model=rnndet dataset=gen4 dataset.path=${DATA_DIR} wandb.project_name=ssms_event_cameras \
wandb.group_name=1mpx +experiment/gen4="${MDL_CFG}.yaml" hardware.gpus=${GPU_IDS} \
batch_size.train=${BATCH_SIZE_PER_GPU} batch_size.eval=${BATCH_SIZE_PER_GPU} \
hardware.num_workers.train=${TRAIN_WORKERS_PER_GPU} hardware.num_workers.eval=${EVAL_WORKERS_PER_GPU}
```
If, for example, you want to run the training on 4 GPUs, simply adapt `GPU_IDS` and `BATCH_SIZE_PER_GPU` accordingly:
```Bash
GPU_IDS=[0,1,2,3]
BATCH_SIZE_PER_GPU=3
```
### Gen1
- The effective batch size for the Gen1 training is 8.
- For training the model on the Gen1 dataset, we need 1x A100 80 GB GPU, using 24 workers for training and 8 workers for evaluation:
```Bash
GPU_IDS=0
BATCH_SIZE_PER_GPU=8
TRAIN_WORKERS_PER_GPU=24
EVAL_WORKERS_PER_GPU=8
python RVT/train.py model=rnndet dataset=gen1 dataset.path=${DATA_DIR} wandb.project_name=ssms_event_cameras \
wandb.group_name=gen1 +experiment/gen1="${MDL_CFG}.yaml" hardware.gpus=${GPU_IDS} \
batch_size.train=${BATCH_SIZE_PER_GPU} batch_size.eval=${BATCH_SIZE_PER_GPU} \
hardware.num_workers.train=${TRAIN_WORKERS_PER_GPU} hardware.num_workers.eval=${EVAL_WORKERS_PER_GPU}
```
## Code Acknowledgments
This project has used code from the following projects:
- [RVT](https://github.com/uzh-rpg/RVT) - Recurrent Vision Transformers for Object Detection with Event Cameras in PyTorch
- [S4](https://github.com/state-spaces/s4) - Structured State Spaces for Sequence Modeling, in particular S4 and S4D models in PyTorch
- [S5](https://github.com/lindermanlab/S5) - Simplified State Space Layers for Sequence Modeling in JAX
- [S5 PyTorch](https://github.com/i404788/s5-pytorch) - S5 model in PyTorch
================================================
FILE: RVT/.gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: RVT/LICENSE
================================================
MIT License
Copyright (c) 2023 Mathias Gehrig
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: RVT/README.md
================================================
# RVT: Recurrent Vision Transformers for Object Detection with Event Cameras
<p align="center">
<img src="https://rpg.ifi.uzh.ch/img/papers/arxiv22_detection_mgehrig/combo.png" width="750">
</p>
This is the official PyTorch implementation of the CVPR 2023 paper [Recurrent Vision Transformers for Object Detection with Event Cameras](https://arxiv.org/abs/2212.05598).
Watch the [**video**](https://youtu.be/xZ-pNwHxHgY) for a quick overview.
```bibtex
@InProceedings{Gehrig_2023_CVPR,
author = {Mathias Gehrig and Davide Scaramuzza},
title = {Recurrent Vision Transformers for Object Detection with Event Cameras},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2023},
}
```
## Conda Installation
We highly recommend using [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) to reduce the installation time.
```Bash
conda create -y -n rvt python=3.9 pip
conda activate rvt
conda config --set channel_priority flexible
CUDA_VERSION=11.8
conda install -y h5py=3.8.0 blosc-hdf5-plugin=1.0.0 \
hydra-core=1.3.2 einops=0.6.0 torchdata=0.6.0 tqdm numba \
pytorch=2.0.0 torchvision=0.15.0 pytorch-cuda=$CUDA_VERSION \
-c pytorch -c nvidia -c conda-forge
python -m pip install pytorch-lightning==1.8.6 wandb==0.14.0 \
pandas==1.5.3 plotly==5.13.1 opencv-python==4.6.0.66 tabulate==0.9.0 \
pycocotools==2.0.6 bbox-visualizer==0.1.0 StrEnum==0.4.10
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
```
Detectron2 is not strictly required but speeds up the evaluation.
## Required Data
To evaluate or train RVT you will need to download the required preprocessed datasets:
<table><tbody>
<th valign="bottom"></th>
<th valign="bottom">1 Mpx</th>
<th valign="bottom">Gen1</th>
<tr><td align="left">pre-processed dataset</td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/RVT/datasets/preprocessed/gen4.tar">download</a></td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/RVT/datasets/preprocessed/gen1.tar">download</a></td>
</tr>
<tr><td align="left">crc32</td>
<td align="center"><tt>c5ec7c38</tt></td>
<td align="center"><tt>5acab6f3</tt></td>
</tr>
</tbody></table>
You may also pre-process the dataset yourself by following the [instructions](scripts/genx/README.md).
## Pre-trained Checkpoints
### 1 Mpx
<table><tbody>
<th valign="bottom"></th>
<th valign="bottom">RVT-Base</th>
<th valign="bottom">RVT-Small</th>
<th valign="bottom">RVT-Tiny</th>
<tr><td align="left">pre-trained checkpoint</td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/RVT/checkpoints/1mpx/rvt-b.ckpt">download</a></td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/RVT/checkpoints/1mpx/rvt-s.ckpt">download</a></td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/RVT/checkpoints/1mpx/rvt-t.ckpt">download</a></td>
</tr>
<tr><td align="left">md5</td>
<td align="center"><tt>72923a</tt></td>
<td align="center"><tt>a94207</tt></td>
<td align="center"><tt>5a3c78</tt></td>
</tr>
</tbody></table>
### Gen1
<table><tbody>
<th valign="bottom"></th>
<th valign="bottom">RVT-Base</th>
<th valign="bottom">RVT-Small</th>
<th valign="bottom">RVT-Tiny</th>
<tr><td align="left">pre-trained checkpoint</td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/RVT/checkpoints/gen1/rvt-b.ckpt">download</a></td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/RVT/checkpoints/gen1/rvt-s.ckpt">download</a></td>
<td align="center"><a href="https://download.ifi.uzh.ch/rpg/RVT/checkpoints/gen1/rvt-t.ckpt">download</a></td>
</tr>
<tr><td align="left">md5</td>
<td align="center"><tt>839317</tt></td>
<td align="center"><tt>840f2b</tt></td>
<td align="center"><tt>a770b9</tt></td>
</tr>
</tbody></table>
## Evaluation
- Set `DATA_DIR` as the path to either the 1 Mpx or Gen1 dataset directory
- Set `CKPT_PATH` to the path of the *correct* checkpoint matching the choice of the model and dataset.
- Set
- `MDL_CFG=base`, or
- `MDL_CFG=small`, or
- `MDL_CFG=tiny`
to load either the base, small, or tiny model configuration
- Set
- `USE_TEST=1` to evaluate on the test set, or
- `USE_TEST=0` to evaluate on the validation set
- Set `GPU_ID` to the PCI BUS ID of the GPU that you want to use. e.g. `GPU_ID=0`.
Only a single GPU is supported for evaluation.
### 1 Mpx
```Bash
python validation.py dataset=gen4 dataset.path=${DATA_DIR} checkpoint=${CKPT_PATH} \
use_test_set=${USE_TEST} hardware.gpus=${GPU_ID} +experiment/gen4="${MDL_CFG}.yaml" \
batch_size.eval=8 model.postprocess.confidence_threshold=0.001
```
### Gen1
```Bash
python validation.py dataset=gen1 dataset.path=${DATA_DIR} checkpoint=${CKPT_PATH} \
use_test_set=${USE_TEST} hardware.gpus=${GPU_ID} +experiment/gen1="${MDL_CFG}.yaml" \
batch_size.eval=8 model.postprocess.confidence_threshold=0.001
```
## Training
- Set `DATA_DIR` as the path to either the 1 Mpx or Gen1 dataset directory
- Set
- `MDL_CFG=base`, or
- `MDL_CFG=small`, or
- `MDL_CFG=tiny`
to load either the base, small, or tiny model configuration
- Set `GPU_IDS` to the PCI BUS IDs of the GPUs that you want to use. e.g. `GPU_IDS=[0,1]` for using GPU 0 and 1.
**Using a list of IDS will enable single-node multi-GPU training.**
Pay attention to the batch size which is defined per GPU:
- Set `BATCH_SIZE_PER_GPU` such that the effective batch size is matching the parameters below.
The **effective batch size** is (batch size per GPU)*(number of GPUs).
- If you would like to change the effective batch size, we found the following learning rate scaling to work well for
all models on both datasets:
`lr = 2e-4 * sqrt(effective_batch_size/8)`.
- The training code uses [W&B](https://wandb.ai/) for logging during the training.
Hence, we assume that you have a W&B account.
- The training script below will create a new project called `RVT`. Adapt the project name and group name if necessary.
### 1 Mpx
- The effective batch size for the 1 Mpx training is 24.
- To train on 2 GPUs using 6 workers per GPU for training and 2 workers per GPU for evaluation:
```Bash
GPU_IDS=[0,1]
BATCH_SIZE_PER_GPU=12
TRAIN_WORKERS_PER_GPU=6
EVAL_WORKERS_PER_GPU=2
python train.py model=rnndet dataset=gen4 dataset.path=${DATA_DIR} wandb.project_name=RVT \
wandb.group_name=1mpx +experiment/gen4="${MDL_CFG}.yaml" hardware.gpus=${GPU_IDS} \
batch_size.train=${BATCH_SIZE_PER_GPU} batch_size.eval=${BATCH_SIZE_PER_GPU} \
hardware.num_workers.train=${TRAIN_WORKERS_PER_GPU} hardware.num_workers.eval=${EVAL_WORKERS_PER_GPU}
```
If you instead want to run the training on 4 GPUs, simply adapt `GPU_IDS` and `BATCH_SIZE_PER_GPU` accordingly:
```Bash
GPU_IDS=[0,1,2,3]
BATCH_SIZE_PER_GPU=6
```
### Gen1
- The effective batch size for the Gen1 training is 8.
- To train on 1 GPU using 6 workers for training and 2 workers for evaluation:
```Bash
GPU_IDS=0
BATCH_SIZE_PER_GPU=8
TRAIN_WORKERS_PER_GPU=6
EVAL_WORKERS_PER_GPU=2
python train.py model=rnndet dataset=gen1 dataset.path=${DATA_DIR} wandb.project_name=RVT \
wandb.group_name=gen1 +experiment/gen1="${MDL_CFG}.yaml" hardware.gpus=${GPU_IDS} \
batch_size.train=${BATCH_SIZE_PER_GPU} batch_size.eval=${BATCH_SIZE_PER_GPU} \
hardware.num_workers.train=${TRAIN_WORKERS_PER_GPU} hardware.num_workers.eval=${EVAL_WORKERS_PER_GPU}
```
## Code Acknowledgments
This project has used code from the following projects:
- [timm](https://github.com/huggingface/pytorch-image-models) for the MaxViT layer implementation in PyTorch
- [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX) for the detection PAFPN/head
================================================
FILE: RVT/callbacks/custom.py
================================================
from omegaconf import DictConfig
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.callbacks import ModelCheckpoint
from callbacks.detection import DetectionVizCallback
def get_ckpt_callback(config: DictConfig) -> ModelCheckpoint:
    """Build the checkpoint callback that tracks the best validation metric.

    Only the ``rnndet`` model is supported. For it the callback monitors
    ``val/AP`` in ``max`` mode and keeps the single best checkpoint
    (``save_top_k=1``) plus the last one (``save_last=True``).

    Args:
        config: Run configuration; ``model.name`` and
            ``logging.ckpt_every_n_epochs`` are read.

    Returns:
        A configured ``ModelCheckpoint`` instance.

    Raises:
        NotImplementedError: If ``config.model.name`` is not ``"rnndet"``.
    """
    if config.model.name != "rnndet":
        raise NotImplementedError
    prefix = "val"
    metric, mode = "AP", "max"
    monitor_key = f"{prefix}/{metric}"
    monitor_label = f"{prefix}_{metric}"
    # Doubled braces become literal "{...}" placeholders in the template; the
    # monitored metric is inserted by key and formatted with two decimals.
    filename_template = (
        f"epoch={{epoch:03d}}-step={{step}}-{monitor_label}={{{monitor_key}:.2f}}"
    )
    checkpoint_cb = ModelCheckpoint(
        monitor=monitor_key,
        filename=filename_template,
        auto_insert_metric_name=False,  # because backslash would create a directory
        save_top_k=1,
        mode=mode,
        every_n_epochs=config.logging.ckpt_every_n_epochs,
        save_last=True,
        verbose=True,
    )
    checkpoint_cb.CHECKPOINT_NAME_LAST = "last_epoch={epoch:03d}-step={step}"
    return checkpoint_cb
def get_viz_callback(config: DictConfig) -> Callback:
    """Return the visualization callback matching ``config.model.name``.

    Args:
        config: Run configuration; ``model.name`` selects the callback and the
            whole config is forwarded to it.

    Returns:
        A ``DetectionVizCallback`` for the ``rnndet`` model.

    Raises:
        NotImplementedError: For any model name other than ``"rnndet"``.
    """
    if config.model.name != "rnndet":
        raise NotImplementedError
    return DetectionVizCallback(config=config)
================================================
FILE: RVT/callbacks/detection.py
================================================
from enum import Enum, auto
from typing import Any
import torch
from einops import rearrange
from omegaconf import DictConfig
from data.utils.types import ObjDetOutput
from loggers.wandb_logger import WandbLogger
from utils.evaluation.prophesee.visualize.vis_utils import (
LABELMAP_GEN1,
LABELMAP_GEN4_SHORT,
draw_bboxes,
)
from .viz_base import VizCallbackBase
class DetectionVizEnum(Enum):
    """Buffer keys used by ``DetectionVizCallback`` (passed as ``buffer_entries``)."""

    # NOTE(review): not written to anywhere in this file; presumably reserved
    # for the rendered event representation.
    EV_IMG = auto()
    # Event image with ground-truth boxes drawn on it.
    LABEL_IMG_PROPH = auto()
    # Event image with predicted boxes drawn on it.
    PRED_IMG_PROPH = auto()
class DetectionVizCallback(VizCallbackBase):
    """Log detection predictions side by side with ground truth.

    Every logged image is the prediction overlay stacked vertically on top of
    the label overlay. Both are drawn on copies of the same image obtained
    from ``self.ev_repr_to_img``.
    """

    def __init__(self, config: DictConfig):
        """Set up the buffers and pick the label map for ``config.dataset.name``.

        Raises:
            NotImplementedError: If the dataset is neither ``gen1`` nor ``gen4``.
        """
        super().__init__(config=config, buffer_entries=DetectionVizEnum)
        label_maps = {"gen1": LABELMAP_GEN1, "gen4": LABELMAP_GEN4_SHORT}
        dataset_name = config.dataset.name
        if dataset_name not in label_maps:
            raise NotImplementedError
        self.label_map = label_maps[dataset_name]

    def on_train_batch_end_custom(
        self,
        logger: WandbLogger,
        outputs: Any,
        batch: Any,
        log_n_samples: int,
        global_step: int,
    ) -> None:
        """Log overlays for the last ``log_n_samples`` samples of a train batch.

        Args:
            logger: Receives the images under the ``train/predictions`` key.
            outputs: Training-step outputs indexed by ``ObjDetOutput``. ``None``
                means the step was skipped and nothing is logged.
            batch: Unused here.
            log_n_samples: Upper bound on the number of samples to log; capped
                at the number of samples available.
            global_step: Step value forwarded to ``logger.log_images``.
        """
        if outputs is None:
            # If we tried to skip the training step (not supported in DDP in PL, atm)
            return
        ev_tensors = outputs[ObjDetOutput.EV_REPR]
        num_samples = len(ev_tensors)
        assert num_samples > 0
        n_to_log = min(num_samples, log_n_samples)

        images, captions = [], []
        # Walk backwards from the final sample: N-1, N-2, ..., N-n_to_log.
        for sample_idx in reversed(range(num_samples - n_to_log, num_samples)):
            ev_img = self.ev_repr_to_img(ev_tensors[sample_idx].cpu().numpy())

            pred_img = ev_img.copy()
            draw_bboxes(
                pred_img,
                outputs[ObjDetOutput.PRED_PROPH][sample_idx],
                labelmap=self.label_map,
            )

            label_img = ev_img.copy()
            draw_bboxes(
                label_img,
                outputs[ObjDetOutput.LABELS_PROPH][sample_idx],
                labelmap=self.label_map,
            )

            images.append(
                rearrange([pred_img, label_img], "pl H W C -> (pl H) W C", pl=2, C=3)
            )
            captions.append(f"sample_{sample_idx}")

        logger.log_images(
            key="train/predictions",
            images=images,
            caption=captions,
            step=global_step,
        )

    def on_validation_batch_end_custom(self, batch: Any, outputs: Any):
        """Buffer the prediction and label overlays for one validation batch.

        Nothing is buffered when ``outputs[ObjDetOutput.SKIP_VIZ]`` is truthy.
        """
        if outputs[ObjDetOutput.SKIP_VIZ]:
            return
        ev_tensor = outputs[ObjDetOutput.EV_REPR]
        assert isinstance(ev_tensor, torch.Tensor)
        ev_img = self.ev_repr_to_img(ev_tensor.cpu().numpy())

        # Predictions first, then labels (same order as they are buffered).
        for buffer_key, boxes_key in (
            (DetectionVizEnum.PRED_IMG_PROPH, ObjDetOutput.PRED_PROPH),
            (DetectionVizEnum.LABEL_IMG_PROPH, ObjDetOutput.LABELS_PROPH),
        ):
            overlay = ev_img.copy()
            draw_bboxes(overlay, outputs[boxes_key], labelmap=self.label_map)
            self.add_to_buffer(buffer_key, overlay)

    def on_validation_epoch_end_custom(self, logger: WandbLogger):
        """Log every buffered validation overlay pair under ``val/predictions``."""
        pred_imgs = self.get_from_buffer(DetectionVizEnum.PRED_IMG_PROPH)
        label_imgs = self.get_from_buffer(DetectionVizEnum.LABEL_IMG_PROPH)
        assert len(pred_imgs) == len(label_imgs)
        merged_img = [
            rearrange([pred, label], "pl H W C -> (pl H) W C", pl=2, C=3)
            for pred, label in zip(pred_imgs, label_imgs)
        ]
        captions = [f"sample_{idx}" for idx in range(len(merged_img))]
        logger.log_images(key="val/predictions", images=merged_img, caption=captions)
================================================
FILE: RVT/callbacks/gradflow.py
================================================
from typing import Any
import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from callbacks.utils.visualization import get_grad_flow_figure
class GradFlowLogCallback(Callback):
    """Periodically log a bar chart of per-parameter mean absolute gradients.

    Args:
        log_every_n_train_steps: log whenever ``trainer.global_step`` is a
            multiple of this value. Must be > 0.
    """

    def __init__(self, log_every_n_train_steps: int):
        super().__init__()
        assert log_every_n_train_steps > 0
        self.log_every_n_train_steps = log_every_n_train_steps

    @rank_zero_only
    def on_before_zero_grad(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Any
    ) -> None:
        # NOTE: this used to live in `on_after_backward`. That was fine for fp32,
        # but with fp16 AMP the gradients there are still multiplied by the loss
        # scale. Lightning calls `on_before_zero_grad` after the optimizer step,
        # by which point AMP has unscaled the gradients, so the logged values
        # reflect the true gradient magnitudes.
        global_step = trainer.global_step
        if global_step % self.log_every_n_train_steps != 0:
            return
        logger = trainer.logger
        if logger is None:
            # Nothing to log to (e.g. the Trainer was created with logging disabled).
            return
        figure = get_grad_flow_figure(pl_module.named_parameters())
        logger.log_metrics({"train/gradients": figure}, step=global_step)
================================================
FILE: RVT/callbacks/utils/visualization.py
================================================
import pandas as pd
import plotly.express as px
def get_grad_flow_figure(named_params):
    """Build a bar chart of the mean absolute gradient of every parameter.

    Helps spot vanishing or exploding gradients across layers. Call it after
    ``loss.backward()`` so that ``param.grad`` is populated.

    Args:
        named_params: iterable of ``(name, parameter)`` pairs, e.g.
            ``module.named_parameters()``. Parameters that do not require
            gradients, or whose ``grad`` is still ``None``, are skipped.

    Returns:
        A plotly express bar figure with column ``name`` on the x axis and
        ``grad_abs`` (mean of ``|grad|``) on the y axis.
    """
    names = []
    grad_means = []
    for param_name, tensor in named_params:
        if not tensor.requires_grad or tensor.grad is None:
            continue
        names.append(param_name)
        grad_means.append(tensor.grad.abs().mean().cpu().item())
    frame = pd.DataFrame.from_dict({"name": names, "grad_abs": grad_means})
    return px.bar(frame, x="name", y="grad_abs")
================================================
FILE: RVT/callbacks/viz_base.py
================================================
import random
from enum import Enum
from typing import Any, List, Optional, Type, Union
import numpy as np
import pytorch_lightning as pl
import torch as th
from einops import rearrange, reduce
from omegaconf import DictConfig
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from loggers.wandb_logger import WandbLogger
class VizCallbackBase(Callback):
    """Base class for callbacks that log images or other high-dimensional data.

    Subclasses implement the three ``*_custom`` hooks. This class decides when
    they run, driven by ``config.logging.train.high_dim`` and
    ``config.logging.validation.high_dim``:

    * Training: ``on_train_batch_end_custom`` runs whenever the global step is
      a multiple of ``every_n_steps``.
    * Validation: during the first validation epoch after training has started,
      every batch index is recorded. At the end of that epoch up to
      ``n_samples`` indices are drawn without replacement using a fixed seed.
      In later validation epochs (every ``every_n_epochs`` epochs) only those
      batches reach ``on_validation_batch_end_custom``, and
      ``on_validation_epoch_end_custom`` runs at the end of the epoch.

    Per-epoch data is collected in a buffer with one list per member of
    ``buffer_entries``. The buffer is cleared at the start of every
    validation epoch.
    """

    def __init__(self, config: DictConfig, buffer_entries: Type[Enum]):
        super().__init__()
        self.log_config = config.logging
        # Set on the first training batch; used to ignore PL's validation sanity check.
        self._training_has_started = False
        # Becomes True once the validation batches to visualize have been sampled.
        self._selected_val_batches = False
        self.buffer_entries = buffer_entries
        self._val_batch_indices = []
        self._buffer = None
        self._reset_buffer()

    def _reset_buffer(self):
        """Replace the buffer with one empty list per member of ``buffer_entries``."""
        self._buffer = {entry: [] for entry in self.buffer_entries}

    # Functions to be USED by subclasses --------------------------------------------------------------------------------

    def add_to_buffer(self, key: Enum, value: Union[np.ndarray, th.Tensor]):
        """Append ``value`` to the buffer list of ``key``.

        Tensors must not require grad and are moved to the CPU. Any other value
        must be a numpy array. ``key`` must be a member of ``buffer_entries``.
        """
        if isinstance(value, th.Tensor):
            assert not value.requires_grad
            value = value.cpu()
        else:
            assert isinstance(value, np.ndarray)
        assert isinstance(key, self.buffer_entries)
        assert key in self._buffer
        self._buffer[key].append(value)

    def get_from_buffer(self, key: Enum) -> List[Union[np.ndarray, th.Tensor]]:
        """Return the list buffered under ``key``. It is the live list, not a copy."""
        assert isinstance(key, self.buffer_entries)
        return self._buffer[key]

    # Functions to be IMPLEMENTED by subclasses -------------------------------------------------------------------------

    def on_train_batch_end_custom(
        self,
        logger: WandbLogger,
        outputs: Any,
        batch: Any,
        log_n_samples: int,
        global_step: int,
    ) -> None:
        """Log training visualizations for the current batch."""
        raise NotImplementedError

    def on_validation_batch_end_custom(self, batch: Any, outputs: Any) -> None:
        """Handle one selected validation batch, e.g. by adding images to the buffer."""
        raise NotImplementedError

    def on_validation_epoch_end_custom(self, logger: WandbLogger) -> None:
        """Log the buffered validation visualizations."""
        raise NotImplementedError

    # Lightning hooks ---------------------------------------------------------------------------------------------------

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        unused: int = 0,
    ) -> None:
        log_train_hd = self.log_config.train.high_dim
        if not log_train_hd.enable:
            return
        global_step = trainer.global_step
        assert log_train_hd.every_n_steps > 0
        if global_step % log_train_hd.every_n_steps != 0:
            return
        logger: Optional[WandbLogger] = trainer.logger
        assert isinstance(logger, WandbLogger)
        self.on_train_batch_end_custom(
            logger=logger,
            outputs=outputs,
            batch=batch,
            log_n_samples=log_train_hd.n_samples,
            global_step=global_step,
        )

    @rank_zero_only
    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Optional[Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        log_val_hd = self.log_config.validation.high_dim
        log_freq_val_epochs = log_val_hd.every_n_epochs
        if not log_val_hd.enable:
            return
        if dataloader_idx > 0:
            raise NotImplementedError
        if not self._training_has_started:
            # PL has a short sanity check for validation. Hence, we have to make sure that one training run is done.
            return
        if not self._selected_val_batches:
            # We only want to add validation batch indices during the first true validation run.
            self._val_batch_indices.append(batch_idx)
            return
        assert len(self._val_batch_indices) > 0
        if batch_idx not in self._val_batch_indices:
            return
        if trainer.current_epoch % log_freq_val_epochs != 0:
            return
        self.on_validation_batch_end_custom(batch, outputs)

    def on_validation_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        self._reset_buffer()

    @rank_zero_only
    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        log_val_hd = self.log_config.validation.high_dim
        log_n_samples = log_val_hd.n_samples
        log_freq_val_epochs = log_val_hd.every_n_epochs
        if len(self._val_batch_indices) == 0:
            return
        if not self._selected_val_batches:
            num_samples = min(len(self._val_batch_indices), log_n_samples)
            # Draw without replacement from a dedicated, fixed-seed RNG. The
            # previous `random.seed(0)` also reset the global `random` state used
            # by any other code on this rank. A fresh Random(0) yields the exact
            # same sample as `random.seed(0); random.sample(...)`.
            rng = random.Random(0)
            self._val_batch_indices = rng.sample(self._val_batch_indices, num_samples)
            self._selected_val_batches = True
            return
        if trainer.current_epoch % log_freq_val_epochs != 0:
            return
        logger: Optional[WandbLogger] = trainer.logger
        assert isinstance(logger, WandbLogger)
        self.on_validation_epoch_end_custom(logger)

    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
    ) -> None:
        self._training_has_started = True

    @staticmethod
    def ev_repr_to_img(x: np.ndarray):
        """Render a (C, H, W) event representation as an (H, W, 3) uint8 image.

        The first C/2 channels are summed as negative counts and the last C/2 as
        positive counts. Pixels where positive > negative are white (255),
        pixels where negative > positive are black (0), and the rest are gray (127).
        """
        ch, ht, wd = x.shape[-3:]
        assert ch > 1 and ch % 2 == 0
        ev_repr_reshaped = rearrange(x, "(posneg C) H W -> posneg C H W", posneg=2)
        img_neg = np.asarray(
            reduce(ev_repr_reshaped[0], "C H W -> H W", "sum"), dtype="int32"
        )
        img_pos = np.asarray(
            reduce(ev_repr_reshaped[1], "C H W -> H W", "sum"), dtype="int32"
        )
        img_diff = img_pos - img_neg
        img = 127 * np.ones((ht, wd, 3), dtype=np.uint8)
        img[img_diff > 0] = 255
        img[img_diff < 0] = 0
        return img
================================================
FILE: RVT/config/dataset/base.yaml
================================================
name: ???
path: ???
train:
sampling: 'mixed' # ('random', 'stream', 'mixed')
random:
weighted_sampling: False
mixed:
w_stream: 1
w_random: 1
eval:
sampling: 'stream'
data_augmentation:
random:
prob_hflip: 0.5
rotate:
prob: 0
min_angle_deg: 2
max_angle_deg: 6
zoom:
prob: 0.8
zoom_in:
weight: 8
factor:
min: 1
max: 1.5
zoom_out:
weight: 2
factor:
min: 1
max: 1.2
stream:
prob_hflip: 0.5
rotate:
prob: 0
min_angle_deg: 2
max_angle_deg: 6
zoom:
prob: 0.5
zoom_out:
factor:
min: 1
max: 1.2
================================================
FILE: RVT/config/dataset/gen1.yaml
================================================
defaults:
- base
name: gen1
ev_repr_name: 'stacked_histogram_dt=50_nbins=10'
sequence_length: 21
resolution_hw: [240, 304]
downsample_by_factor_2: False
only_load_end_labels: False
================================================
FILE: RVT/config/dataset/gen4.yaml
================================================
defaults:
- base
name: gen4
ev_repr_name: 'stacked_histogram_dt=50_nbins=10'
sequence_length: 10
resolution_hw: [720, 1280]
downsample_by_factor_2: True
only_load_end_labels: False
================================================
FILE: RVT/config/experiment/gen1/base.yaml
================================================
# @package _global_
defaults:
- default
model:
backbone:
embed_dim: 64
fpn:
depth: 0.67
================================================
FILE: RVT/config/experiment/gen1/default.yaml
================================================
# @package _global_
defaults:
- /model/maxvit_yolox: default
training:
precision: 32
max_epochs: 10000
max_steps: 400000
learning_rate: 0.0002
lr_scheduler:
use: True
total_steps: ${..max_steps}
pct_start: 0.005
div_factor: 20
final_div_factor: 10000
validation:
val_check_interval: 10000
check_val_every_n_epoch: null
batch_size:
train: 8
eval: 8
hardware:
num_workers:
train: 6
eval: 2
dataset:
train:
sampling: 'mixed'
random:
weighted_sampling: False
mixed:
w_stream: 1
w_random: 1
eval:
sampling: 'stream'
ev_repr_name: 'stacked_histogram_dt=50_nbins=10'
sequence_length: 21
downsample_by_factor_2: False
only_load_end_labels: False
model:
backbone:
partition_split_32: 1
================================================
FILE: RVT/config/experiment/gen1/small.yaml
================================================
# @package _global_
defaults:
- default
model:
backbone:
embed_dim: 48
stage:
attention:
dim_head: 24
fpn:
depth: 0.33
================================================
FILE: RVT/config/experiment/gen4/base.yaml
================================================
# @package _global_
defaults:
- default
model:
backbone:
embed_dim: 64
fpn:
depth: 0.67
================================================
FILE: RVT/config/experiment/gen4/default.yaml
================================================
# @package _global_
defaults:
- /model/maxvit_yolox: default
training:
precision: 32
max_epochs: 10000
max_steps: 400000
learning_rate: 0.0002449489742783178 # 2e-4 * sqrt(effective_batch_size/8) = 2e-4 * sqrt(12/8)
lr_scheduler:
use: True
total_steps: ${..max_steps}
pct_start: 0.005
div_factor: 20
final_div_factor: 10000
validation:
val_check_interval: 10000
check_val_every_n_epoch: null
batch_size:
train: 12
eval: 12
hardware:
num_workers:
train: 6
eval: 2
dataset:
train:
sampling: 'mixed'
random:
weighted_sampling: False
mixed:
w_stream: 1
w_random: 1
eval:
sampling: 'stream'
ev_repr_name: 'stacked_histogram_dt=50_nbins=10'
sequence_length: 10
downsample_by_factor_2: True
only_load_end_labels: False
================================================
FILE: RVT/config/experiment/gen4/small.yaml
================================================
# @package _global_
defaults:
- default
model:
backbone:
embed_dim: 48
stage:
attention:
dim_head: 24
fpn:
depth: 0.33
================================================
FILE: RVT/config/general.yaml
================================================
reproduce:
seed_everything: null # Union[int, null]
deterministic_flag: False # Must be true for fully deterministic behaviour (slows down training)
benchmark: False # Must be False for fully deterministic behaviour. Setting it to True could potentially speed up training.
training:
precision: 16
max_epochs: 10000
max_steps: 400000
learning_rate: 0.0002
weight_decay: 0
gradient_clip_val: 1.0
limit_train_batches: 1.0
lr_scheduler:
use: True
total_steps: ${..max_steps}
pct_start: 0.005
div_factor: 25 # init_lr = max_lr / div_factor
final_div_factor: 10000 # final_lr = max_lr / final_div_factor (this is different from PyTorch's OneCycleLR param)
validation:
limit_val_batches: 1.0
val_check_interval: null # Optional[int]
check_val_every_n_epoch: 1 # Optional[int]
batch_size:
train: 8
eval: 8
hardware:
num_workers:
train: 6
eval: 2
gpus: 0 # Either a single integer (e.g. 3) or a list of integers (e.g. [3,5,6])
dist_backend: "nccl"
logging:
ckpt_every_n_epochs: 1
train:
metrics:
compute: false
detection_metrics_every_n_steps: null # Optional[int] -> null: every train epoch, int: every N steps
log_model_every_n_steps: 5000
log_every_n_steps: 500
high_dim:
enable: True
every_n_steps: 5000
n_samples: 4
validation:
high_dim:
enable: True
every_n_epochs: 1
n_samples: 8
wandb:
# How to use:
# 1) resume existing wandb run: set artifact_name & wandb_runpath
# 2) resume full training state in new wandb run: set artifact_name
# 3) resume only model weights of checkpoint in new wandb run: set artifact_name & resume_only_weights=True
#
# In addition: you can specify artifact_local_file to load the checkpoint from disk.
# This is for example required for resuming training with DDP.
wandb_runpath: null # WandB run path. E.g. USERNAME/PROJECTNAME/1grv5kg6
artifact_name: null # Name of checkpoint/artifact. Required for resuming. E.g. USERNAME/PROJECTNAME/checkpoint-1grv5kg6-last:v15
artifact_local_file: null # If specified, will use the provided local filepath instead of downloading it. Required if resuming with DDP.
resume_only_weights: False
group_name: ??? # Specify group name of the run
project_name: RVT
================================================
FILE: RVT/config/model/base.yaml
================================================
name: ???
================================================
FILE: RVT/config/model/maxvit_yolox/default.yaml
================================================
# @package _global_
defaults:
- override /model: rnndet
model:
backbone:
name: MaxViTRNN
compile:
enable: False
args:
mode: reduce-overhead
input_channels: 20
enable_masking: False
partition_split_32: 2
embed_dim: 64
dim_multiplier: [1, 2, 4, 8]
num_blocks: [1, 1, 1, 1]
T_max_chrono_init: [4, 8, 16, 32]
stem:
patch_size: 4
stage:
downsample:
type: patch
overlap: True
norm_affine: True
attention:
use_torch_mha: False
partition_size: ???
dim_head: 32
attention_bias: True
mlp_activation: gelu
mlp_gated: False
mlp_bias: True
mlp_ratio: 4
drop_mlp: 0
drop_path: 0
ls_init_value: 1e-5
lstm:
dws_conv: False
dws_conv_only_hidden: True
dws_conv_kernel_size: 3
drop_cell_update: 0
s5:
dim: 80
state_dim: 80
s4:
dim: 80
state_dim: 80
fpn:
name: PAFPN
compile:
enable: False
args:
mode: reduce-overhead
depth: 0.67 # round(depth * 3) == num bottleneck blocks
# stage 1 is the first and len(num_layers) is the last
in_stages: [2, 3, 4]
depthwise: False
act: "silu"
head:
name: YoloX
compile:
enable: False
args:
mode: reduce-overhead
depthwise: False
act: "silu"
postprocess:
confidence_threshold: 0.1
nms_threshold: 0.45
================================================
FILE: RVT/config/model/rnndet.yaml
================================================
defaults:
- base
name: rnndet
backbone:
name: ???
fpn:
name: ???
head:
name: ???
postprocess:
confidence_threshold: 0.1
nms_threshold: 0.45
================================================
FILE: RVT/config/modifier.py
================================================
import os
from typing import Tuple
import math
from omegaconf import DictConfig, open_dict
from data.utils.spatial import get_dataloading_hw
def dynamically_modify_train_config(config: DictConfig):
    """Fill in config fields that depend on the dataset and model choice (in place).

    Sets:
      * ``config.slurm_job_id`` (int) if the ``SLURM_JOB_ID`` environment
        variable is set and non-empty.
      * For ``model.name == "rnndet"`` with the ``MaxViTRNN`` backbone:
          - ``model.backbone.in_res_hw``: the data-loading resolution from
            ``get_dataloading_hw``, rounded up to a multiple of
            ``32 * partition_split_32``;
          - ``model.backbone.stage.attention.partition_size``:
            ``in_res_hw // (32 * partition_split_32)`` per dimension;
          - ``model.head.num_classes``: 2 for gen1, 3 for gen4.

    Raises:
        NotImplementedError: if the model or backbone name is not supported.
    """
    with open_dict(config):
        slurm_job_id = os.environ.get("SLURM_JOB_ID")
        if slurm_job_id:  # treats both "unset" (None) and "" as absent
            config.slurm_job_id = int(slurm_job_id)

        dataset_cfg = config.dataset
        dataset_name = dataset_cfg.name
        assert dataset_name in {"gen1", "gen4"}
        dataset_hw = get_dataloading_hw(dataset_config=dataset_cfg)

        mdl_cfg = config.model
        mdl_name = mdl_cfg.name
        if mdl_name != "rnndet":
            raise NotImplementedError(f"{mdl_name=} not available")

        backbone_cfg = mdl_cfg.backbone
        backbone_name = backbone_cfg.name
        if backbone_name != "MaxViTRNN":
            raise NotImplementedError(f"{backbone_name=} not available")

        partition_split_32 = backbone_cfg.partition_split_32
        assert partition_split_32 in (1, 2, 4)
        multiple_of = 32 * partition_split_32
        mdl_hw = _get_modified_hw_multiple_of(hw=dataset_hw, multiple_of=multiple_of)
        print(f"Set {backbone_name} backbone (height, width) to {mdl_hw}")
        backbone_cfg.in_res_hw = mdl_hw

        attention_cfg = backbone_cfg.stage.attention
        partition_size = tuple(x // multiple_of for x in mdl_hw)
        assert (mdl_hw[0] // 32) % partition_size[0] == 0, f"{mdl_hw[0]=}, {partition_size[0]=}"
        assert (mdl_hw[1] // 32) % partition_size[1] == 0, f"{mdl_hw[1]=}, {partition_size[1]=}"
        print(f"Set partition sizes: {partition_size}")
        attention_cfg.partition_size = partition_size

        num_classes = 2 if dataset_name == "gen1" else 3
        mdl_cfg.head.num_classes = num_classes
        print(f"Set {num_classes=} for detection head")
def _get_modified_hw_multiple_of(
hw: Tuple[int, int], multiple_of: int
) -> Tuple[int, ...]:
assert isinstance(hw, tuple), f"{type(hw)=}, {hw=}"
assert len(hw) == 2
assert isinstance(multiple_of, int)
assert multiple_of >= 1
if multiple_of == 1:
return hw
new_hw = tuple(math.ceil(x / multiple_of) * multiple_of for x in hw)
return new_hw
================================================
FILE: RVT/config/train.yaml
================================================
defaults:
- general
- dataset: ???
- model: rnndet
- optional model/dataset: ${model}_${dataset}
================================================
FILE: RVT/config/val.yaml
================================================
defaults:
- dataset: ???
- model: rnndet
- _self_
checkpoint: ???
use_test_set: False
hardware:
num_workers:
eval: 4
gpus: 0 # GPU idx (multi-gpu not supported for validation)
batch_size:
eval: 8
training:
precision: 16
================================================
FILE: RVT/data/genx_utils/collate.py
================================================
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Type, Tuple, Union
import torch
from data.genx_utils.collate_from_pytorch import collate, default_collate_fn_map
from data.genx_utils.labels import ObjectLabels, SparselyBatchedObjectLabels
def collate_object_labels(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
):
    """Pass a batch of ObjectLabels through uncollated (returns the same list object)."""
    passthrough = batch
    return passthrough
def collate_sparsely_batched_object_labels(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
):
    """Collate SparselyBatchedObjectLabels via ``SparselyBatchedObjectLabels.transpose_list``."""
    transposed = SparselyBatchedObjectLabels.transpose_list(batch)
    return transposed
# Project collate registry: PyTorch-style defaults plus handlers for the label
# containers. Entries are appended after the defaults (insertion order matters
# for `collate`'s isinstance fallback).
custom_collate_fn_map = deepcopy(default_collate_fn_map)
custom_collate_fn_map.update(
    {
        ObjectLabels: collate_object_labels,
        SparselyBatchedObjectLabels: collate_sparsely_batched_object_labels,
    }
)
def custom_collate(batch: Any):
    """Collate ``batch`` with ``collate`` using the project registry ``custom_collate_fn_map``."""
    collated = collate(batch, collate_fn_map=custom_collate_fn_map)
    return collated
def custom_collate_rnd(batch: Any):
    """Collate samples from the random-access (map-style) dataset.

    Returns ``{"data": <collated samples>, "worker_id": <DataLoader worker id>}``,
    where the worker id is 0 when loading in the main process.
    """
    # The worker id is not really needed for map-style datasets, but it is
    # provided so both loaders yield the same dict layout.
    info = torch.utils.data.get_worker_info()
    worker_id = info.id if info is not None else 0
    return {"data": custom_collate(batch), "worker_id": worker_id}
def custom_collate_streaming(batch: Any):
    """Collate one ``(samples, worker_id)`` item produced by a streaming-datapipe worker.

    Returns ``{"data": <collated samples>, "worker_id": worker_id}``.
    """
    samples, worker_id = batch[0], batch[1]
    assert isinstance(worker_id, int)
    return {"data": custom_collate(samples), "worker_id": worker_id}
================================================
FILE: RVT/data/genx_utils/collate_from_pytorch.py
================================================
import collections
import contextlib
import re
import torch
# True on PyTorch 1.x; selects which storage API collate_tensor_fn uses.
torch_is_version_1 = int(torch.__version__.partition(".")[0]) == 1
from typing import Callable, Dict, Optional, Tuple, Type, Union
# Matches numpy dtype strings of byte-string (S, a), unicode (U) and object (O)
# arrays, which cannot be converted to tensors.
np_str_obj_array_pattern = re.compile("[SaUO]")
default_collate_err_msg_format = "default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found {}"
def collate(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
):
    r"""
    General collate function that handles collection type of element within each batch
    and opens function registry to deal with specific element types. `default_collate_fn_map`
    provides default collate functions for tensors, numpy arrays, numbers and strings.
    Args:
        batch: a single batch to be collated
        collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function.
            If the element type isn't present in this dictionary,
            this function will go through each key of the dictionary in the insertion order to
            invoke the corresponding collate function if the element type is a subclass of the key.
    Examples:
        >>> # Extend this function to handle batch of tensors
        >>> def collate_tensor_fn(batch, *, collate_fn_map):
        ...     return torch.stack(batch, 0)
        >>> def custom_collate(batch):
        ...     collate_map = {torch.Tensor: collate_tensor_fn}
        ...     return collate(batch, collate_fn_map=collate_map)
        >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
        >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})
    Note:
        Each collate function requires a positional argument for batch and a keyword argument
        for the dictionary of collate functions as `collate_fn_map`.

    Dispatch is decided by the type of ``batch[0]`` and tried in this order:
        1. exact type match in ``collate_fn_map``;
        2. first key (in insertion order) that ``batch[0]`` is an instance of;
        3. Mappings: collated key by key, rebuilt with the original mapping type if possible;
        4. namedtuples: collated field by field into the same namedtuple type;
        5. other Sequences: transposed and collated position by position
           (tuples come back as lists; other types are rebuilt when possible).
    Raises:
        RuntimeError: if the sequences in ``batch`` differ in length.
        TypeError: if no rule above applies to the element type.
    """
    elem = batch[0]
    elem_type = type(elem)
    if collate_fn_map is not None:
        # Fast path: exact type registered.
        if elem_type in collate_fn_map:
            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
        # Fallback: subclass match; the first registered key wins.
        for collate_type in collate_fn_map:
            if isinstance(elem, collate_type):
                return collate_fn_map[collate_type](
                    batch, collate_fn_map=collate_fn_map
                )
    if isinstance(elem, collections.abc.Mapping):
        try:
            return elem_type(
                {
                    key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map)
                    for key in elem
                }
            )
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            # NOTE: a TypeError raised by a nested `collate` also lands here; the
            # nested collation is then re-run and the error propagates from below.
            return {
                key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map)
                for key in elem
            }
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(
            *(
                collate(samples, collate_fn_map=collate_fn_map)
                for samples in zip(*batch)
            )
        )
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.
        if isinstance(elem, tuple):
            return [
                collate(samples, collate_fn_map=collate_fn_map)
                for samples in transposed
            ]  # Backwards compatibility.
        else:
            try:
                return elem_type(
                    [
                        collate(samples, collate_fn_map=collate_fn_map)
                        for samples in transposed
                    ]
                )
            except TypeError:
                # The sequence type may not support `__init__(iterable)` (e.g., `range`).
                return [
                    collate(samples, collate_fn_map=collate_fn_map)
                    for samples in transposed
                ]
    raise TypeError(default_collate_err_msg_format.format(elem_type))
def _new_shared_storage(elem, numel):
    """Allocate shared-memory storage for ``numel`` elements matching ``elem`` (dtype and device).

    PyTorch 1.x exposes the typed storage through ``Tensor.storage()``. Later
    versions use ``Tensor._typed_storage()``, as upstream ``torch.utils.data``
    collation does.
    """
    if torch_is_version_1:
        return elem.storage()._new_shared(numel, device=elem.device)
    return elem._typed_storage()._new_shared(numel, device=elem.device)


def collate_tensor_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
):
    """Stack equally-shaped tensors along a new leading batch dimension.

    Inside a DataLoader worker process the result is written directly into a
    shared-memory tensor, which avoids an extra copy when the batch is sent
    back to the main process.
    """
    elem = batch[0]
    out = None
    if torch.utils.data.get_worker_info() is not None:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum(x.numel() for x in batch)
        storage = _new_shared_storage(elem, numel)
        out = elem.new(storage).resize_(len(batch), *list(elem.size()))
    return torch.stack(batch, 0, out=out)
def collate_numpy_array_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
):
    """Convert every ndarray in ``batch`` to a tensor and collate the tensors.

    Raises TypeError for string/object dtypes, which have no tensor equivalent.
    """
    first_dtype = batch[0].dtype
    # array of string classes and object
    if np_str_obj_array_pattern.search(first_dtype.str) is not None:
        raise TypeError(default_collate_err_msg_format.format(first_dtype))
    tensors = [torch.as_tensor(array) for array in batch]
    return collate(tensors, collate_fn_map=collate_fn_map)
def collate_numpy_scalar_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
):
    """Collate numpy scalars (bool, numeric, object) into a 1-D tensor."""
    as_tensor = torch.as_tensor(batch)
    return as_tensor
def collate_float_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
):
    """Collate Python floats into a 1-D float64 tensor."""
    target_dtype = torch.float64
    return torch.tensor(batch, dtype=target_dtype)
def collate_int_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
):
    """Collate Python ints into a 1-D tensor (dtype inferred by ``torch.tensor``)."""
    collated = torch.tensor(batch)
    return collated
def collate_str_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
):
    """Keep strings as a plain list (returns the same list object)."""
    strings = batch
    return strings
# Registry consulted by `collate`: exact type match first, then the first key
# (in insertion order) that the element is an instance of.
default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {}
default_collate_fn_map[torch.Tensor] = collate_tensor_fn
with contextlib.suppress(ImportError):
    import numpy as np

    default_collate_fn_map.update(
        {
            # Covers both ndarray and memmap (subclass of ndarray).
            np.ndarray: collate_numpy_array_fn,
            # Scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html
            # String scalars are deliberately not included.
            (np.bool_, np.number, np.object_): collate_numpy_scalar_fn,
        }
    )
default_collate_fn_map.update(
    {float: collate_float_fn, int: collate_int_fn, str: collate_str_fn}
)
================================================
FILE: RVT/data/genx_utils/dataset_rnd.py
================================================
from collections import namedtuple
from collections.abc import Iterable
from pathlib import Path
from typing import List
import numpy as np
from omegaconf import DictConfig
from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.sampler import WeightedRandomSampler
from tqdm import tqdm
from data.genx_utils.labels import SparselyBatchedObjectLabels
from data.genx_utils.sequence_rnd import SequenceForRandomAccess
from data.utils.augmentor import RandomSpatialAugmentorGenX
from data.utils.types import DatasetMode, LoaderDataDictGenX, DatasetType, DataType
class SequenceDataset(Dataset):
    """Random-access dataset over one preprocessed sequence directory.

    Wraps ``SequenceForRandomAccess``. In ``DatasetMode.TRAIN`` every item is
    passed through ``RandomSpatialAugmentorGenX``, except when only labels are
    being loaded.

    Args:
        path: sequence directory (must exist).
        dataset_mode: train / validation / testing.
        dataset_config: reads ``sequence_length``, ``ev_repr_name``,
            ``downsample_by_factor_2``, ``only_load_end_labels``,
            ``data_augmentation``, ``name`` ("gen1" or "gen4") and, in train
            mode, ``resolution_hw``.

    Raises:
        NotImplementedError: for a dataset name other than "gen1"/"gen4".
    """

    def __init__(
        self, path: Path, dataset_mode: DatasetMode, dataset_config: DictConfig
    ):
        assert path.is_dir(), f"{str(path)}"

        ### extract settings from config ###
        sequence_length = dataset_config.sequence_length
        assert isinstance(sequence_length, int)
        assert sequence_length > 0
        self.output_seq_len = sequence_length
        ev_representation_name = dataset_config.ev_repr_name
        downsample_by_factor_2 = dataset_config.downsample_by_factor_2
        only_load_end_labels = dataset_config.only_load_end_labels
        augm_config = dataset_config.data_augmentation
        ####################################

        if dataset_config.name == "gen1":
            dataset_type = DatasetType.GEN1
        elif dataset_config.name == "gen4":
            dataset_type = DatasetType.GEN4
        else:
            raise NotImplementedError(
                f"Unsupported dataset name {dataset_config.name!r} (expected 'gen1' or 'gen4')"
            )

        self.sequence = SequenceForRandomAccess(
            path=path,
            ev_representation_name=ev_representation_name,
            sequence_length=sequence_length,
            dataset_type=dataset_type,
            downsample_by_factor_2=downsample_by_factor_2,
            only_load_end_labels=only_load_end_labels,
        )

        self.spatial_augmentor = None
        if dataset_mode == DatasetMode.TRAIN:
            resolution_hw = tuple(dataset_config.resolution_hw)
            assert len(resolution_hw) == 2
            if downsample_by_factor_2:
                # The augmentor must operate at the resolution of the loaded data.
                resolution_hw = tuple(x // 2 for x in resolution_hw)
            self.spatial_augmentor = RandomSpatialAugmentorGenX(
                dataset_hw=resolution_hw,
                automatic_randomization=True,
                augm_config=augm_config.random,
            )

    def only_load_labels(self):
        """Switch the underlying sequence to label-only loading."""
        self.sequence.only_load_labels()

    def load_everything(self):
        """Switch the underlying sequence back to full loading."""
        self.sequence.load_everything()

    def __len__(self):
        return len(self.sequence)

    def __getitem__(self, index: int) -> LoaderDataDictGenX:
        item = self.sequence[index]
        # Augmentation needs the event data, so skip it in label-only mode.
        if (
            self.spatial_augmentor is not None
            and not self.sequence.is_only_loading_labels()
        ):
            item = self.spatial_augmentor(item)
        return item
class CustomConcatDataset(ConcatDataset):
    """ConcatDataset that forwards label-only / full loading switches to every member."""

    datasets: List["SequenceDataset"]

    def __init__(self, datasets: Iterable["SequenceDataset"]):
        super().__init__(datasets=datasets)

    def only_load_labels(self):
        """Make every member dataset load only labels."""
        for dataset in self.datasets:
            dataset.only_load_labels()

    def load_everything(self):
        """Make every member dataset load events and labels again."""
        for dataset in self.datasets:
            dataset.load_everything()
def build_random_access_dataset(
    dataset_mode: DatasetMode, dataset_config: DictConfig
) -> CustomConcatDataset:
    """Concatenate one ``SequenceDataset`` per entry of ``<path>/<split>``.

    ``<split>`` is "train", "val" or "test" depending on ``dataset_mode``.
    Entries are visited in sorted order, so the global index -> sample mapping
    does not depend on the filesystem's directory listing order.
    """
    dataset_path = Path(dataset_config.path)
    assert dataset_path.is_dir(), f"{str(dataset_path)}"
    mode2str = {
        DatasetMode.TRAIN: "train",
        DatasetMode.VALIDATION: "val",
        DatasetMode.TESTING: "test",
    }
    split_path = dataset_path / mode2str[dataset_mode]
    assert split_path.is_dir(), f"{str(split_path)}"
    seq_datasets = [
        SequenceDataset(
            path=entry, dataset_mode=dataset_mode, dataset_config=dataset_config
        )
        for entry in tqdm(
            sorted(split_path.iterdir()),
            desc=f"creating rnd access {mode2str[dataset_mode]} datasets",
        )
    ]
    return CustomConcatDataset(seq_datasets)
def get_weighted_random_sampler(dataset: CustomConcatDataset) -> WeightedRandomSampler:
    """Build a sampler that favours samples containing rare classes and many boxes.

    Each class gets weight ``1 / total_count(class)`` over the whole dataset.
    A sample's weight is ``sum(class_weight * count_in_sample)``, so samples
    with more boxes are also drawn more often. Samples without any valid label
    get weight 0. Sampling is with replacement and draws ``len(dataset)``
    samples.

    The dataset is temporarily switched to label-only loading for the scan. It
    is always switched back, even if the scan fails.
    """
    ClassAndCount = namedtuple("ClassAndCount", ["class_ids", "counts"])
    class2count = dict()
    classandcount_list = list()
    print("--- START generating weighted random sampler ---")
    dataset.only_load_labels()
    try:
        for data in tqdm(dataset, desc="iterate through dataset"):
            labels: SparselyBatchedObjectLabels = data[DataType.OBJLABELS_SEQ]
            label_list, _ = labels.get_valid_labels_and_batch_indices()
            class_ids_per_label = [
                np.asarray(label.class_id.numpy(), dtype="int32") for label in label_list
            ]
            if class_ids_per_label:
                unique_ids, unique_counts = np.unique(
                    np.concatenate(class_ids_per_label), return_counts=True
                )
            else:
                # np.concatenate raises on an empty list; such a sample gets weight 0.
                unique_ids = np.empty(0, dtype="int32")
                unique_counts = np.empty(0, dtype=np.int64)
            for class_id, count in zip(unique_ids, unique_counts):
                class2count[class_id] = class2count.get(class_id, 0) + count
            classandcount_list.append(
                ClassAndCount(class_ids=unique_ids, counts=unique_counts)
            )
    finally:
        dataset.load_everything()

    class2weight = {
        class_id: 1 / max(count, 1) for class_id, count in class2count.items()
    }
    # Not only weight depending on class but also depending on number of occurrences.
    # This will bias towards sampling "frames" with more bounding boxes.
    weights = [
        sum(
            class2weight[class_id] * count
            for class_id, count in zip(entry.class_ids, entry.counts)
        )
        for entry in classandcount_list
    ]
    print("--- DONE generating weighted random sampler ---")
    return WeightedRandomSampler(
        weights=weights, num_samples=len(weights), replacement=True
    )
================================================
FILE: RVT/data/genx_utils/dataset_streaming.py
================================================
from functools import partialmethod
from pathlib import Path
from typing import List, Union
from omegaconf import DictConfig
from torchdata.datapipes.map import MapDataPipe
from tqdm import tqdm
from data.genx_utils.sequence_for_streaming import (
SequenceForIter,
RandAugmentIterDataPipe,
)
from data.utils.stream_concat_datapipe import ConcatStreamingDataPipe
from data.utils.stream_sharded_datapipe import ShardedStreamingDataPipe
from data.utils.types import DatasetMode, DatasetType
def build_streaming_dataset(
    dataset_mode: DatasetMode,
    dataset_config: DictConfig,
    batch_size: int,
    num_workers: int,
) -> Union[ConcatStreamingDataPipe, ShardedStreamingDataPipe]:
    """Build the streaming datapipe for one split of the dataset.

    Every sub-directory of ``<dataset_config.path>/<train|val|test>`` is turned
    into one or more ``SequenceForIter`` datapipes via ``get_sequences``. For
    training, sequences are split so that every sample has labels.

    Args:
        dataset_mode: Which split to load.
        dataset_config: Dataset config (``path`` plus the fields read by
            ``get_sequences`` and the builders).
        batch_size: Number of parallel streams.
        num_workers: Number of dataloader workers (training only).

    Returns:
        A ``ConcatStreamingDataPipe`` for training, or a
        ``ShardedStreamingDataPipe`` for validation/testing.

    Raises:
        NotImplementedError: For an unsupported ``dataset_mode``.
    """
    dataset_path = Path(dataset_config.path)
    assert dataset_path.is_dir(), f"{str(dataset_path)}"
    mode2str = {
        DatasetMode.TRAIN: "train",
        DatasetMode.VALIDATION: "val",
        DatasetMode.TESTING: "test",
    }
    split_path = dataset_path / mode2str[dataset_mode]
    assert split_path.is_dir(), f"{str(split_path)}"

    datapipes = list()
    num_full_sequences = 0
    num_splits = 0
    num_split_sequences = 0
    guarantee_labels = dataset_mode == DatasetMode.TRAIN
    # Path.iterdir() yields entries in arbitrary, filesystem-dependent order.
    # Sort so that datapipe (and thus stream/shard) order is reproducible.
    for entry in tqdm(
        sorted(split_path.iterdir()),
        desc=f"creating streaming {mode2str[dataset_mode]} datasets",
    ):
        new_datapipes = get_sequences(
            path=entry, dataset_config=dataset_config, guarantee_labels=guarantee_labels
        )
        if len(new_datapipes) == 1:
            num_full_sequences += 1
        else:
            num_splits += 1
            num_split_sequences += len(new_datapipes)
        datapipes.extend(new_datapipes)
    print(f"{num_full_sequences=}\n{num_splits=}\n{num_split_sequences=}")

    if dataset_mode == DatasetMode.TRAIN:
        return build_streaming_train_dataset(
            datapipes=datapipes,
            dataset_config=dataset_config,
            batch_size=batch_size,
            num_workers=num_workers,
        )
    elif dataset_mode in (DatasetMode.VALIDATION, DatasetMode.TESTING):
        return build_streaming_evaluation_dataset(
            datapipes=datapipes, batch_size=batch_size
        )
    else:
        raise NotImplementedError
def get_sequences(
    path: Path, dataset_config: DictConfig, guarantee_labels: bool
) -> List[SequenceForIter]:
    """Create the streaming datapipe(s) for one recording directory.

    Args:
        path: Directory of a single recording; must exist.
        dataset_config: Reads ``sequence_length``, ``ev_repr_name``,
            ``downsample_by_factor_2`` and ``name`` (``"gen1"`` or ``"gen4"``).
        guarantee_labels: If True, delegate to
            ``SequenceForIter.get_sequences_with_guaranteed_labels``, which may
            split the recording into several datapipes. Otherwise the whole
            recording becomes a single datapipe.

    Raises:
        NotImplementedError: If ``dataset_config.name`` is neither gen1 nor gen4.
    """
    assert path.is_dir()

    # Read the config fields in the same order as before the dataset name.
    seq_len = dataset_config.sequence_length
    repr_name = dataset_config.ev_repr_name
    halve_resolution = dataset_config.downsample_by_factor_2
    name_to_type = {"gen1": DatasetType.GEN1, "gen4": DatasetType.GEN4}
    ds_name = dataset_config.name
    if ds_name not in name_to_type:
        raise NotImplementedError

    common_kwargs = dict(
        path=path,
        ev_representation_name=repr_name,
        sequence_length=seq_len,
        dataset_type=name_to_type[ds_name],
        downsample_by_factor_2=halve_resolution,
    )
    if not guarantee_labels:
        return [SequenceForIter(**common_kwargs)]
    return SequenceForIter.get_sequences_with_guaranteed_labels(**common_kwargs)
def partialclass(cls, *args, **kwargs):
    """Return a subclass of ``cls`` whose ``__init__`` has ``args``/``kwargs`` pre-bound.

    This is the class analogue of ``functools.partial``. The result is still a
    type, so ``isinstance``/``issubclass`` checks against ``cls`` keep working.
    """
    bound_init = partialmethod(cls.__init__, *args, **kwargs)

    class NewCls(cls):
        __init__ = bound_init

    return NewCls
def build_streaming_train_dataset(
    datapipes: List[MapDataPipe],
    dataset_config: DictConfig,
    batch_size: int,
    num_workers: int,
) -> ConcatStreamingDataPipe:
    """Concatenate the training datapipes into one augmented streaming datapipe.

    ``RandAugmentIterDataPipe`` is passed as a class with ``dataset_config``
    pre-bound via ``partialclass``. How ``ConcatStreamingDataPipe`` instantiates
    it is not visible here (presumably once per stream).
    """
    assert len(datapipes) > 0
    return ConcatStreamingDataPipe(
        datapipe_list=datapipes,
        batch_size=batch_size,
        num_workers=num_workers,
        augmentation_pipeline=partialclass(
            RandAugmentIterDataPipe, dataset_config=dataset_config
        ),
        print_seed_debug=False,
    )
def build_streaming_evaluation_dataset(
    datapipes: List[MapDataPipe], batch_size: int
) -> ShardedStreamingDataPipe:
    """Distribute evaluation datapipes over ``batch_size`` sharded streams.

    The ``fill_value`` is a fully padded sample taken from the first datapipe;
    presumably it pads streams that run out of data.
    """
    assert len(datapipes) > 0
    padding_sample = datapipes[0].get_fully_padded_sample()
    return ShardedStreamingDataPipe(
        datapipe_list=datapipes,
        batch_size=batch_size,
        fill_value=padding_sample,
    )
================================================
FILE: RVT/data/genx_utils/labels.py
================================================
from __future__ import annotations
from typing import List, Tuple, Union, Optional
import math
import numpy as np
import torch as th
from einops import rearrange
from torch.nn.functional import pad
class ObjectLabelBase:
    """Tensor-backed set of bounding-box labels.

    ``object_labels`` has one row per box and one column per field in
    ``_str2idx``: timestamp ``t``, top-left corner ``x``/``y``, size ``w``/``h``,
    ``class_id`` and ``class_confidence``. Coordinates refer to a canvas of size
    ``input_size_hw`` = (height, width).
    """

    _str2idx = {
        "t": 0,
        "x": 1,
        "y": 2,
        "w": 3,
        "h": 4,
        "class_id": 5,
        "class_confidence": 6,
    }

    def __init__(self, object_labels: th.Tensor, input_size_hw: Tuple[int, int]):
        """
        Args:
            object_labels: float32/float64 tensor of shape (num_boxes, 7).
            input_size_hw: (height, width) of the canvas the boxes refer to.
        """
        assert isinstance(object_labels, th.Tensor)
        assert object_labels.dtype in {th.float32, th.float64}
        assert object_labels.ndim == 2
        assert object_labels.shape[-1] == len(self._str2idx)
        assert isinstance(input_size_hw, tuple)
        assert len(input_size_hw) == 2
        self.object_labels = object_labels
        # Assigned directly (not via the validating setter), which lets
        # create_empty use (0, 0).
        self._input_size_hw = input_size_hw
        self._is_numpy = False

    def clamp_to_frame_(self):
        """Clip all boxes to [0, width - 1] x [0, height - 1], in place.

        Raises:
            AssertionError: If any box ends up with non-positive width or height.
        """
        ht, wd = self.input_size_hw
        x0 = th.clamp(self.x, min=0, max=wd - 1)
        y0 = th.clamp(self.y, min=0, max=ht - 1)
        x1 = th.clamp(self.x + self.w, min=0, max=wd - 1)
        y1 = th.clamp(self.y + self.h, min=0, max=ht - 1)
        w = x1 - x0
        h = y1 - y0
        assert th.all(w > 0)
        assert th.all(h > 0)
        self.x = x0
        self.y = y0
        self.w = w
        self.h = h

    def remove_flat_labels_(self):
        """Drop boxes whose width or height is not positive, in place."""
        keep = (self.w > 0) & (self.h > 0)
        self.object_labels = self.object_labels[keep]

    @classmethod
    def create_empty(cls):
        # This is useful to represent cases where no labels are available.
        return ObjectLabelBase(
            object_labels=th.empty((0, len(cls._str2idx))), input_size_hw=(0, 0)
        )

    def _assert_not_numpy(self):
        # Adjacent string literals (no backslash continuation) keep the
        # message free of embedded indentation whitespace.
        assert not self._is_numpy, (
            "Labels have been converted to numpy. "
            "Numpy is not supported for the intended operations."
        )

    def to(self, *args, **kwargs):
        # This function executes torch.to on self tensors and returns self.
        self._assert_not_numpy()
        # This will be used by Pytorch Lightning to transfer to the relevant device
        self.object_labels = self.object_labels.to(*args, **kwargs)
        return self

    def numpy_(self) -> None:
        """
        In place conversion to numpy (detach + to cpu + to numpy).
        Cannot be undone.
        """
        self._is_numpy = True
        self.object_labels = self.object_labels.detach().cpu().numpy()

    @property
    def input_size_hw(self) -> Tuple[int, int]:
        return self._input_size_hw

    @input_size_hw.setter
    def input_size_hw(self, height_width: Tuple[int, int]):
        assert isinstance(height_width, tuple)
        assert len(height_width) == 2
        assert height_width[0] > 0
        assert height_width[1] > 0
        self._input_size_hw = height_width

    def get(self, request: str):
        """Return the column named ``request`` (a key of ``_str2idx``)."""
        assert request in self._str2idx
        return self.object_labels[:, self._str2idx[request]]

    @property
    def t(self):
        return self.object_labels[:, self._str2idx["t"]]

    @property
    def x(self):
        return self.object_labels[:, self._str2idx["x"]]

    @x.setter
    def x(self, value: Union[th.Tensor, np.ndarray]):
        self.object_labels[:, self._str2idx["x"]] = value

    @property
    def y(self):
        return self.object_labels[:, self._str2idx["y"]]

    @y.setter
    def y(self, value: Union[th.Tensor, np.ndarray]):
        self.object_labels[:, self._str2idx["y"]] = value

    @property
    def w(self):
        return self.object_labels[:, self._str2idx["w"]]

    @w.setter
    def w(self, value: Union[th.Tensor, np.ndarray]):
        self.object_labels[:, self._str2idx["w"]] = value

    @property
    def h(self):
        return self.object_labels[:, self._str2idx["h"]]

    @h.setter
    def h(self, value: Union[th.Tensor, np.ndarray]):
        self.object_labels[:, self._str2idx["h"]] = value

    @property
    def class_id(self):
        return self.object_labels[:, self._str2idx["class_id"]]

    @property
    def class_confidence(self):
        return self.object_labels[:, self._str2idx["class_confidence"]]

    @property
    def dtype(self):
        return self.object_labels.dtype

    @property
    def device(self):
        return self.object_labels.device
class ObjectLabelFactory(ObjectLabelBase):
    """All labels of a sequence, sliced per object frame into ``ObjectLabels``.

    Labels of all object frames are stored contiguously in ``object_labels``.
    ``objframe_idx_2_label_idx[i]`` is the first row of object frame ``i``. Its
    labels end where frame ``i + 1`` starts, or at the last row for the final
    frame.
    """

    def __init__(
        self,
        object_labels: th.Tensor,
        objframe_idx_2_label_idx: th.Tensor,
        input_size_hw: Tuple[int, int],
        downsample_factor: Optional[float] = None,
    ):
        """
        Args:
            object_labels: (num_boxes, 7) label tensor (see ``ObjectLabelBase``).
            objframe_idx_2_label_idx: 1D int64 tensor of per-frame start rows.
            input_size_hw: (height, width) of the original canvas.
            downsample_factor: If given (must be > 1), each ``ObjectLabels``
                returned by ``__getitem__`` is scaled by ``1 / downsample_factor``.
        """
        super().__init__(object_labels=object_labels, input_size_hw=input_size_hw)
        assert objframe_idx_2_label_idx.dtype == th.int64
        assert objframe_idx_2_label_idx.dim() == 1

        self.objframe_idx_2_label_idx = objframe_idx_2_label_idx
        self.downsample_factor = downsample_factor
        if self.downsample_factor is not None:
            assert self.downsample_factor > 1
        # Clamp once at the original resolution; per-frame slices inherit this.
        self.clamp_to_frame_()

    @staticmethod
    def from_structured_array(
        object_labels: np.ndarray,
        objframe_idx_2_label_idx: np.ndarray,
        input_size_hw: Tuple[int, int],
        downsample_factor: Optional[float] = None,
    ) -> ObjectLabelFactory:
        """Build a factory from a numpy structured label array.

        Args:
            object_labels: Structured array with one field per key of
                ``_str2idx`` (t, x, y, w, h, class_id, class_confidence).
                All fields are cast to float32.
            objframe_idx_2_label_idx: Start row of each object frame. There
                must be one entry per unique timestamp ``t``.
            input_size_hw: (height, width) of the original canvas.
            downsample_factor: Forwarded to ``__init__``.
        """
        # Use this class's own (inherited) field table rather than forward-
        # referencing the ObjectLabels subclass defined further below.
        np_labels = [
            object_labels[key].astype("float32")
            for key in ObjectLabelFactory._str2idx.keys()
        ]
        np_labels = rearrange(np_labels, "fields L -> L fields")
        torch_labels = th.from_numpy(np_labels)
        objframe_idx_2_label_idx = th.from_numpy(
            objframe_idx_2_label_idx.astype("int64")
        )
        assert objframe_idx_2_label_idx.numel() == np.unique(object_labels["t"]).size
        return ObjectLabelFactory(
            object_labels=torch_labels,
            objframe_idx_2_label_idx=objframe_idx_2_label_idx,
            input_size_hw=input_size_hw,
            downsample_factor=downsample_factor,
        )

    def __len__(self):
        """Number of object frames."""
        return len(self.objframe_idx_2_label_idx)

    def __getitem__(self, item: int) -> ObjectLabels:
        """Return a copy of the labels of object frame ``item`` (0 <= item < len)."""
        assert item >= 0
        length = len(self)
        assert length > 0
        assert item < length
        is_last_item = item == length - 1

        from_idx = self.objframe_idx_2_label_idx[item]
        to_idx = (
            self.object_labels.shape[0]
            if is_last_item
            else self.objframe_idx_2_label_idx[item + 1]
        )
        assert to_idx > from_idx
        object_labels = ObjectLabels(
            object_labels=self.object_labels[from_idx:to_idx].clone(),
            input_size_hw=self.input_size_hw,
        )
        if self.downsample_factor is not None:
            object_labels.scale_(scaling_multiplier=1 / self.downsample_factor)
        return object_labels
class ObjectLabels(ObjectLabelBase):
    """Labels of a single object frame, with in-place geometric transforms.

    Methods with a trailing underscore modify the labels in place. Several of
    them drop boxes that become flat (zero width or height).
    """

    def __init__(self, object_labels: th.Tensor, input_size_hw: Tuple[int, int]):
        super().__init__(object_labels=object_labels, input_size_hw=input_size_hw)

    def __len__(self) -> int:
        return self.object_labels.shape[0]

    def rotate_(self, angle_deg: float):
        """Rotate all boxes by ``angle_deg`` degrees around the canvas center.

        Each box is replaced by the axis-aligned bounding box of its four rotated
        corners, clamped to the canvas. Boxes that become flat are removed.
        """
        if len(self) == 0:
            return
        # (x0,y0)---(x1,y0)   p00---p10
        #  |             |    |     |
        #  |             |    |     |
        # (x0,y1)---(x1,y1)   p01---p11
        p00 = th.stack((self.x, self.y), dim=1)
        p10 = th.stack((self.x + self.w, self.y), dim=1)
        p01 = th.stack((self.x, self.y + self.h), dim=1)
        p11 = th.stack((self.x + self.w, self.y + self.h), dim=1)
        # points: 4 x N x 2
        points = th.stack((p00, p10, p01, p11), dim=0)

        cx = self._input_size_hw[1] // 2
        cy = self._input_size_hw[0] // 2
        # Build center and rotation matrix in the dtype of the label points:
        # th.einsum does not promote dtypes, so a default-dtype (float32) matrix
        # would fail against float64 labels.
        center = th.tensor([cx, cy], dtype=points.dtype, device=self.device)

        angle_rad = angle_deg / 180 * math.pi
        # counter-clockwise rotation
        rot_matrix = th.tensor(
            [
                [math.cos(angle_rad), math.sin(angle_rad)],
                [-math.sin(angle_rad), math.cos(angle_rad)],
            ],
            dtype=points.dtype,
            device=self.device,
        )

        points = points - center
        points = th.einsum("ij,pnj->pni", rot_matrix, points)
        points = points + center

        height, width = self.input_size_hw
        x0 = th.clamp(th.min(points[..., 0], dim=0)[0], min=0, max=width - 1)
        y0 = th.clamp(th.min(points[..., 1], dim=0)[0], min=0, max=height - 1)
        x1 = th.clamp(th.max(points[..., 0], dim=0)[0], min=0, max=width - 1)
        y1 = th.clamp(th.max(points[..., 1], dim=0)[0], min=0, max=height - 1)

        self.x = x0
        self.y = y0
        self.w = x1 - x0
        self.h = y1 - y0

        self.remove_flat_labels_()

        assert th.all(self.x >= 0)
        assert th.all(self.y >= 0)
        assert th.all(self.x + self.w <= self.input_size_hw[1] - 1)
        assert th.all(self.y + self.h <= self.input_size_hw[0] - 1)

    def zoom_in_and_rescale_(
        self, zoom_coordinates_x0y0: Tuple[int, int], zoom_in_factor: float
    ):
        """
        1) Computes a new smaller canvas size: original canvas scaled by a factor of 1/zoom_in_factor (downscaling)
        2) Places the smaller canvas inside the original canvas at the top-left coordinates zoom_coordinates_x0y0
        3) Extract the smaller canvas and rescale it back to the original resolution

        Boxes are cropped to the zoom window; boxes that become flat are removed.
        """
        if len(self) == 0:
            return
        assert len(zoom_coordinates_x0y0) == 2
        assert zoom_in_factor >= 1
        if zoom_in_factor == 1:
            return
        z_x0, z_y0 = zoom_coordinates_x0y0
        h_orig, w_orig = self.input_size_hw
        assert 0 <= z_x0 <= w_orig - 1
        assert 0 <= z_y0 <= h_orig - 1
        zoom_window_h, zoom_window_w = tuple(
            x / zoom_in_factor for x in self.input_size_hw
        )
        z_x1 = min(z_x0 + zoom_window_w, w_orig - 1)
        assert z_x1 <= w_orig - 1, f"{z_x1=} is larger than {w_orig-1=}"
        z_y1 = min(z_y0 + zoom_window_h, h_orig - 1)
        assert z_y1 <= h_orig - 1, f"{z_y1=} is larger than {h_orig-1=}"

        x0 = th.clamp(self.x, min=z_x0, max=z_x1 - 1)
        y0 = th.clamp(self.y, min=z_y0, max=z_y1 - 1)

        x1 = th.clamp(self.x + self.w, min=z_x0, max=z_x1 - 1)
        y1 = th.clamp(self.y + self.h, min=z_y0, max=z_y1 - 1)

        # Shift into the zoom window's coordinate frame.
        self.x = x0 - z_x0
        self.y = y0 - z_y0
        self.w = x1 - x0
        self.h = y1 - y0
        self.input_size_hw = (zoom_window_h, zoom_window_w)

        self.remove_flat_labels_()

        self.scale_(scaling_multiplier=zoom_in_factor)

    def zoom_out_and_rescale_(
        self, zoom_coordinates_x0y0: Tuple[int, int], zoom_out_factor: float
    ):
        """
        1) Scales the input by a factor of 1/zoom_out_factor (i.e. reduces the canvas size)
        2) Places the downscaled canvas into the original canvas at the top-left coordinates zoom_coordinates_x0y0

        The downscaling step (``scale_``) may remove boxes that become flat.
        """
        if len(self) == 0:
            return
        assert len(zoom_coordinates_x0y0) == 2
        assert zoom_out_factor >= 1
        if zoom_out_factor == 1:
            return

        h_orig, w_orig = self.input_size_hw
        self.scale_(scaling_multiplier=1 / zoom_out_factor)

        self.input_size_hw = (h_orig, w_orig)
        z_x0, z_y0 = zoom_coordinates_x0y0
        assert 0 <= z_x0 <= w_orig - 1
        assert 0 <= z_y0 <= h_orig - 1

        self.x = self.x + z_x0
        self.y = self.y + z_y0

    def scale_(self, scaling_multiplier: float):
        """Scale boxes and canvas by ``scaling_multiplier`` (> 0), in place.

        Bottom-right corners are clamped to the new canvas. Boxes that become
        flat are removed.
        """
        if len(self) == 0:
            return
        assert scaling_multiplier > 0
        if scaling_multiplier == 1:
            return
        img_ht, img_wd = self.input_size_hw
        new_img_ht = scaling_multiplier * img_ht
        new_img_wd = scaling_multiplier * img_wd
        self.input_size_hw = (new_img_ht, new_img_wd)
        x1 = th.clamp((self.x + self.w) * scaling_multiplier, max=new_img_wd - 1)
        y1 = th.clamp((self.y + self.h) * scaling_multiplier, max=new_img_ht - 1)
        self.x = self.x * scaling_multiplier
        self.y = self.y * scaling_multiplier

        self.w = x1 - self.x
        self.h = y1 - self.y

        self.remove_flat_labels_()

    def flip_lr_(self) -> None:
        """Mirror boxes horizontally within the canvas, in place."""
        if len(self) == 0:
            return
        self.x = self.input_size_hw[1] - 1 - self.x - self.w

    def get_labels_as_tensors(self, format_: str = "yolox") -> th.Tensor:
        """Return a (num_boxes, 5) float32 tensor.

        For ``format_ == "yolox"`` each row is [class_id, center_x, center_y, w, h].

        Raises:
            NotImplementedError: For any other ``format_``.
        """
        self._assert_not_numpy()

        if format_ == "yolox":
            out = th.zeros((len(self), 5), dtype=th.float32, device=self.device)
            if len(self) == 0:
                return out
            out[:, 0] = self.class_id
            out[:, 1] = self.x + 0.5 * self.w
            out[:, 2] = self.y + 0.5 * self.h
            out[:, 3] = self.w
            out[:, 4] = self.h
            return out
        else:
            raise NotImplementedError

    @staticmethod
    def get_labels_as_batched_tensor(
        obj_label_list: List[ObjectLabels], format_: str = "yolox"
    ) -> th.Tensor:
        """Stack per-frame label tensors into (num_frames, max_boxes, 5).

        Frames with fewer boxes are zero-padded. At least one frame must contain
        a box.
        """
        num_object_frames = len(obj_label_list)
        assert num_object_frames > 0
        max_num_labels_per_object_frame = max([len(x) for x in obj_label_list])
        assert max_num_labels_per_object_frame > 0
        if format_ == "yolox":
            tensor_labels = []
            for labels in obj_label_list:
                obj_labels_tensor = labels.get_labels_as_tensors(format_=format_)
                num_to_pad = max_num_labels_per_object_frame - len(labels)
                padded_labels = pad(
                    obj_labels_tensor, (0, 0, 0, num_to_pad), mode="constant", value=0
                )
                tensor_labels.append(padded_labels)
            tensor_labels = th.stack(tensors=tensor_labels, dim=0)
            return tensor_labels
        else:
            raise NotImplementedError
class SparselyBatchedObjectLabels:
    """A sequence of per-time-step labels in which missing labels are ``None``.

    Empty ``ObjectLabels`` are normalized to ``None`` at construction and after
    every transform that can drop boxes (zoom in/out, rotate, scale). A
    non-``None`` entry therefore always holds at least one box.

    Note: the input list is stored as-is (not copied) and is modified in place.
    """

    def __init__(self, sparse_object_labels_batch: List[Optional[ObjectLabels]]):
        # Can contain None elements that indicate missing labels.
        for entry in sparse_object_labels_batch:
            assert isinstance(entry, ObjectLabels) or entry is None
        self.sparse_object_labels_batch = sparse_object_labels_batch
        self.set_empty_labels_to_none_()

    def __len__(self) -> int:
        return len(self.sparse_object_labels_batch)

    def __iter__(self):
        return iter(self.sparse_object_labels_batch)

    def __getitem__(self, item: int) -> Optional[ObjectLabels]:
        if item < 0 or item >= len(self):
            raise IndexError(f"Index ({item}) out of range (0, {len(self) - 1})")
        return self.sparse_object_labels_batch[item]

    def __add__(self, other: SparselyBatchedObjectLabels):
        """Concatenate two sequences (entries are shared, not copied)."""
        sparse_object_labels_batch = (
            self.sparse_object_labels_batch + other.sparse_object_labels_batch
        )
        return SparselyBatchedObjectLabels(
            sparse_object_labels_batch=sparse_object_labels_batch
        )

    def set_empty_labels_to_none_(self):
        """Replace every ``ObjectLabels`` without boxes by ``None``, in place."""
        for idx, obj_label in enumerate(self.sparse_object_labels_batch):
            if obj_label is not None and len(obj_label) == 0:
                self.sparse_object_labels_batch[idx] = None

    @property
    def input_size_hw(self) -> Optional[Union[Tuple[int, int], Tuple[float, float]]]:
        """Canvas size of the first non-``None`` entry, or ``None`` if all are missing."""
        for obj_labels in self.sparse_object_labels_batch:
            if obj_labels is not None:
                return obj_labels.input_size_hw
        return None

    def zoom_in_and_rescale_(self, *args, **kwargs):
        for entry in self.sparse_object_labels_batch:
            if entry is not None:
                entry.zoom_in_and_rescale_(*args, **kwargs)
        # We may have deleted labels. If no labels are left, set the object to None
        self.set_empty_labels_to_none_()

    def zoom_out_and_rescale_(self, *args, **kwargs):
        for entry in self.sparse_object_labels_batch:
            if entry is not None:
                entry.zoom_out_and_rescale_(*args, **kwargs)
        # The internal downscaling (scale_) can delete flat labels as well.
        self.set_empty_labels_to_none_()

    def rotate_(self, *args, **kwargs):
        for entry in self.sparse_object_labels_batch:
            if entry is not None:
                entry.rotate_(*args, **kwargs)
        # Rotation removes boxes that become flat after clamping.
        self.set_empty_labels_to_none_()

    def scale_(self, *args, **kwargs):
        for entry in self.sparse_object_labels_batch:
            if entry is not None:
                entry.scale_(*args, **kwargs)
        # We may have deleted labels. If no labels are left, set the object to None
        self.set_empty_labels_to_none_()

    def flip_lr_(self):
        for entry in self.sparse_object_labels_batch:
            if entry is not None:
                entry.flip_lr_()

    def to(self, *args, **kwargs):
        """Move every non-``None`` entry (via ``ObjectLabels.to``) and return self."""
        for entry in self.sparse_object_labels_batch:
            if entry is not None:
                entry.to(*args, **kwargs)
        return self

    def get_valid_labels_and_batch_indices(
        self,
    ) -> Tuple[List[ObjectLabels], List[int]]:
        """Return the non-``None`` entries and their positions in the sequence."""
        out = list()
        valid_indices = list()
        for idx, label in enumerate(self.sparse_object_labels_batch):
            if label is not None:
                out.append(label)
                valid_indices.append(idx)
        return out, valid_indices

    @staticmethod
    def transpose_list(
        list_of_sparsely_batched_object_labels: List[SparselyBatchedObjectLabels],
    ) -> List[SparselyBatchedObjectLabels]:
        """Regroup by position: element ``i`` of the result holds entry ``i`` of each input.

        As with ``zip``, the result is truncated to the shortest input.
        """
        return [
            SparselyBatchedObjectLabels(list(labels_as_tuple))
            for labels_as_tuple in zip(*list_of_sparsely_batched_object_labels)
        ]
================================================
FILE: RVT/data/genx_utils/sequence_base.py
================================================
from pathlib import Path
from typing import Any, List, Optional
import h5py
import numpy as np
import torch
from torchdata.datapipes.map import MapDataPipe
from data.genx_utils.labels import ObjectLabelFactory, ObjectLabels
from data.utils.spatial import get_original_hw
from data.utils.types import DatasetType
from utils.timers import TimerDummy as Timer
def get_event_representation_dir(path: Path, ev_representation_name: str) -> Path:
    """Return ``<path>/event_representations_v2/<ev_representation_name>``.

    Raises:
        AssertionError: If that directory does not exist.
    """
    candidate = path.joinpath("event_representations_v2", ev_representation_name)
    assert candidate.is_dir(), f"{candidate}"
    return candidate
def get_objframe_idx_2_repr_idx(path: Path, ev_representation_name: str) -> np.ndarray:
    """Load ``objframe_idx_2_repr_idx.npy`` from the event-representation directory.

    Entry ``i`` is the index of the event representation associated with
    object frame ``i``.
    """
    repr_dir = get_event_representation_dir(
        path=path, ev_representation_name=ev_representation_name
    )
    return np.load(str(repr_dir / "objframe_idx_2_repr_idx.npy"))
class SequenceBase(MapDataPipe):
    """
    Structure example of a sequence:
    .
    ├── event_representations_v2
    │ └── ev_representation_name
    │     ├── event_representations.h5
    │     ├── objframe_idx_2_repr_idx.npy
    │     └── timestamps_us.npy
    └── labels_v2
        ├── labels.npz
        └── timestamps_us.npy

    Common base for sequence datapipes. It loads all labels of the sequence into
    an ``ObjectLabelFactory`` and provides helpers to read event representations
    and to look up the labels belonging to a representation index.
    Subclasses implement ``__len__``/``__getitem__``.
    """

    def __init__(
        self,
        path: Path,
        ev_representation_name: str,
        sequence_length: int,
        dataset_type: DatasetType,
        downsample_by_factor_2: bool,
        only_load_end_labels: bool,
    ):
        assert sequence_length >= 1
        assert path.is_dir()
        assert dataset_type in {
            DatasetType.GEN1,
            DatasetType.GEN4,
        }, f"{dataset_type} not implemented"

        self.only_load_end_labels = only_load_end_labels

        ev_repr_dir = get_event_representation_dir(
            path=path, ev_representation_name=ev_representation_name
        )

        labels_dir = path / "labels_v2"
        assert labels_dir.is_dir()

        height, width = get_original_hw(dataset_type)
        self.seq_len = sequence_length

        ds_factor_str = "_ds2_nearest" if downsample_by_factor_2 else ""
        self.ev_repr_file = ev_repr_dir / f"event_representations{ds_factor_str}.h5"
        assert self.ev_repr_file.exists(), f"{str(self.ev_repr_file)=}"

        with Timer(timer_name="prepare labels"):
            # np.load on an .npz returns an NpzFile holding an open file handle.
            # Use it as a context manager so the handle is closed; indexing reads
            # each array fully into memory, so the arrays stay valid afterwards.
            with np.load(str(labels_dir / "labels.npz")) as label_data:
                objframe_idx_2_label_idx = label_data["objframe_idx_2_label_idx"]
                labels = label_data["labels"]
            label_factory = ObjectLabelFactory.from_structured_array(
                object_labels=labels,
                objframe_idx_2_label_idx=objframe_idx_2_label_idx,
                input_size_hw=(height, width),
                downsample_factor=2 if downsample_by_factor_2 else None,
            )
        self.label_factory = label_factory

        with Timer(timer_name="load objframe_idx_2_repr_idx"):
            self.objframe_idx_2_repr_idx = get_objframe_idx_2_repr_idx(
                path=path, ev_representation_name=ev_representation_name
            )
        with Timer(timer_name="construct repr_idx_2_objframe_idx"):
            # Reverse mapping: representation index -> object-frame index.
            self.repr_idx_2_objframe_idx = dict(
                zip(
                    self.objframe_idx_2_repr_idx,
                    range(len(self.objframe_idx_2_repr_idx)),
                )
            )

    def _get_labels_from_repr_idx(self, repr_idx: int) -> Optional[ObjectLabels]:
        """Labels attached to representation ``repr_idx``, or ``None`` if there are none."""
        objframe_idx = self.repr_idx_2_objframe_idx.get(repr_idx, None)
        return None if objframe_idx is None else self.label_factory[objframe_idx]

    def _get_event_repr_torch(self, start_idx: int, end_idx: int) -> List[torch.Tensor]:
        """Read representations ``[start_idx, end_idx)`` from the h5 ``data`` dataset.

        Returns one tensor per index. Non-uint8 data is converted to float32.
        """
        assert end_idx > start_idx
        with h5py.File(str(self.ev_repr_file), "r") as h5f:
            ev_repr = h5f["data"][start_idx:end_idx]
        ev_repr = torch.from_numpy(ev_repr)
        if ev_repr.dtype != torch.uint8:
            ev_repr = torch.asarray(ev_repr, dtype=torch.float32)
        ev_repr = torch.split(ev_repr, 1, dim=0)
        # remove first dim that is always 1 due to how torch.split works
        ev_repr = [x[0] for x in ev_repr]
        return ev_repr

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(self, index: int) -> Any:
        raise NotImplementedError
================================================
FILE: RVT/data/genx_utils/sequence_for_streaming.py
================================================
from pathlib import Path
from typing import List, Optional, Union, Tuple
import h5py
import numpy as np
import torch
from omegaconf import DictConfig
from torchdata.datapipes.iter import IterDataPipe
from data.genx_utils.labels import SparselyBatchedObjectLabels
from data.genx_utils.sequence_base import SequenceBase, get_objframe_idx_2_repr_idx
from data.utils.augmentor import RandomSpatialAugmentorGenX
from data.utils.types import DataType, DatasetType, LoaderDataDictGenX
from utils.timers import TimerDummy as Timer
def _scalar_as_1d_array(scalar: Union[int, float]):
return np.atleast_1d(scalar)
def _get_ev_repr_range_indices(
indices: np.ndarray, max_len: int
) -> List[Tuple[int, int]]:
"""
Computes a list of index ranges based on the input array of indices and a maximum length.
The index ranges are computed such that the difference between consecutive indices
should not exceed the maximum length (max_len).
Parameters:
-----------
indices : np.ndarray
A NumPy array of indices, where the indices are sorted in ascending order.
max_len : int
The maximum allowed length between consecutive indices.
Returns:
--------
out : List[Tuple[int, int]]
A list of tuples, where each tuple contains two integers representing the start and
stop indices of the range.
"""
meta_indices_stop = np.flatnonzero(np.diff(indices) > max_len)
meta_indices_start = np.concatenate((np.atleast_1d(0), meta_indices_stop + 1))
meta_indices_stop = np.concatenate(
(meta_indices_stop, np.atleast_1d(len(indices) - 1))
)
out = list()
for meta_idx_start, meta_idx_stop in zip(meta_indices_start, meta_indices_stop):
idx_start = max(indices[meta_idx_start] - max_len + 1, 0)
idx_stop = indices[meta_idx_stop] + 1
out.append((idx_start, idx_stop))
return out
class SequenceForIter(SequenceBase):
    """Streaming view of one sequence, cut into consecutive chunks.

    The event representations in ``[repr_idx_start, repr_idx_stop)`` are split
    into consecutive chunks of ``sequence_length``. The last chunk may be
    shorter and is padded in ``__getitem__``. Only chunk 0 has
    ``IS_FIRST_SAMPLE == True``.
    """

    def __init__(
        self,
        path: Path,
        ev_representation_name: str,
        sequence_length: int,
        dataset_type: DatasetType,
        downsample_by_factor_2: bool,
        range_indices: Optional[Tuple[int, int]] = None,
    ):
        """
        Args:
            range_indices: Optional ``(start, stop)`` range of representation
                indices to stream. By default the range runs from the earliest
                start that still reaches the first label to the last
                representation.
        """
        super().__init__(
            path=path,
            ev_representation_name=ev_representation_name,
            sequence_length=sequence_length,
            dataset_type=dataset_type,
            downsample_by_factor_2=downsample_by_factor_2,
            only_load_end_labels=False,
        )

        with h5py.File(str(self.ev_repr_file), "r") as h5f:
            num_ev_repr = h5f["data"].shape[0]

        # Set start idx such that the first label is no further than the last
        # timestamp of the first sample sub-sequence. int() avoids unsigned
        # wrap-around in case the index array has an unsigned dtype.
        min_start_repr_idx = max(
            int(self.objframe_idx_2_repr_idx[0]) - sequence_length + 1, 0
        )
        if range_indices is None:
            repr_idx_start = min_start_repr_idx
            repr_idx_stop = num_ev_repr
        else:
            repr_idx_start, repr_idx_stop = range_indices
            assert (
                0 <= min_start_repr_idx <= repr_idx_start < repr_idx_stop <= num_ev_repr
            ), f"{min_start_repr_idx=}, {repr_idx_start=}, {repr_idx_stop=}, {num_ev_repr=}, {path=}"

        self.start_indices = list(range(repr_idx_start, repr_idx_stop, sequence_length))
        self.stop_indices = self.start_indices[1:] + [repr_idx_stop]
        self.length = len(self.start_indices)

        # Created lazily by the padding_representation property.
        self._padding_representation = None

    @staticmethod
    def get_sequences_with_guaranteed_labels(
        path: Path,
        ev_representation_name: str,
        sequence_length: int,
        dataset_type: DatasetType,
        downsample_by_factor_2: bool,
    ) -> List["SequenceForIter"]:
        """Generate sequences such that we do always have labels within each sample of the sequence
        This is required for training such that we are guaranteed to always have labels in the training step.
        However, for validation we don't require this if we catch the special case.
        """
        objframe_idx_2_repr_idx = get_objframe_idx_2_repr_idx(
            path=path, ev_representation_name=ev_representation_name
        )
        # max diff for repr idx is sequence length
        range_indices_list = _get_ev_repr_range_indices(
            indices=objframe_idx_2_repr_idx, max_len=sequence_length
        )
        sequence_list = list()
        for range_indices in range_indices_list:
            sequence_list.append(
                SequenceForIter(
                    path=path,
                    ev_representation_name=ev_representation_name,
                    sequence_length=sequence_length,
                    dataset_type=dataset_type,
                    downsample_by_factor_2=downsample_by_factor_2,
                    range_indices=range_indices,
                )
            )
        return sequence_list

    @property
    def padding_representation(self) -> torch.Tensor:
        """All-zero tensor shaped like one event representation (cached)."""
        if self._padding_representation is None:
            ev_repr = self._get_event_repr_torch(start_idx=0, end_idx=1)[0]
            self._padding_representation = torch.zeros_like(ev_repr)
        return self._padding_representation

    def get_fully_padded_sample(self) -> LoaderDataDictGenX:
        """Sample of ``seq_len`` padding representations without labels."""
        is_first_sample = False
        is_padded_mask = [True] * self.seq_len
        ev_repr = [self.padding_representation] * self.seq_len
        labels = [None] * self.seq_len
        sparse_labels = SparselyBatchedObjectLabels(sparse_object_labels_batch=labels)
        out = {
            DataType.EV_REPR: ev_repr,
            DataType.OBJLABELS_SEQ: sparse_labels,
            DataType.IS_FIRST_SAMPLE: is_first_sample,
            DataType.IS_PADDED_MASK: is_padded_mask,
        }
        return out

    def __len__(self):
        return self.length

    def __getitem__(self, index: int) -> LoaderDataDictGenX:
        """Return chunk ``index``, padded to ``seq_len`` if it is shorter."""
        start_idx = self.start_indices[index]
        end_idx = self.stop_indices[index]

        # sequence info ###
        sample_len = end_idx - start_idx
        assert self.seq_len >= sample_len > 0, (
            f"{self.seq_len=}, {sample_len=}, {start_idx=}, {end_idx=}, "
            f"\n{self.start_indices=}\n{self.stop_indices=}"
        )

        is_first_sample = index == 0
        is_padded_mask = [False] * sample_len
        ###################

        # event representations ###
        with Timer(timer_name="read ev reprs"):
            ev_repr = self._get_event_repr_torch(start_idx=start_idx, end_idx=end_idx)
        assert len(ev_repr) == sample_len
        ###########################

        # labels ###
        labels = [
            self._get_labels_from_repr_idx(repr_idx)
            for repr_idx in range(start_idx, end_idx)
        ]
        assert len(labels) == len(ev_repr)
        ############

        # apply padding (if necessary) ###
        if sample_len < self.seq_len:
            padding_len = self.seq_len - sample_len

            is_padded_mask.extend([True] * padding_len)
            ev_repr.extend([self.padding_representation] * padding_len)
            labels.extend([None] * padding_len)
        ##################################

        # convert labels to sparse labels for datapipes and dataloader
        sparse_labels = SparselyBatchedObjectLabels(sparse_object_labels_batch=labels)

        out = {
            DataType.EV_REPR: ev_repr,
            DataType.OBJLABELS_SEQ: sparse_labels,
            DataType.IS_FIRST_SAMPLE: is_first_sample,
            DataType.IS_PADDED_MASK: is_padded_mask,
        }
        return out
class RandAugmentIterDataPipe(IterDataPipe):
    """Wrap a datapipe and apply a spatial augmentation to every sample.

    The augmentor is built with ``automatic_randomization=False`` and is
    re-randomized once at the start of each iteration. Presumably all samples
    of one pass therefore share the same augmentation (confirm in
    ``RandomSpatialAugmentorGenX``).
    """

    def __init__(self, source_dp: IterDataPipe, dataset_config: DictConfig):
        super().__init__()
        self.source_dp = source_dp

        height_width = tuple(dataset_config.resolution_hw)
        assert len(height_width) == 2
        halve = dataset_config.downsample_by_factor_2
        if halve:
            height_width = tuple(side // 2 for side in height_width)

        augmentation_cfg = dataset_config.data_augmentation
        self.spatial_augmentor = RandomSpatialAugmentorGenX(
            dataset_hw=height_width,
            automatic_randomization=False,
            augm_config=augmentation_cfg.stream,
        )

    def __iter__(self):
        self.spatial_augmentor.randomize_augmentation()
        yield from (self.spatial_augmentor(sample) for sample in self.source_dp)
================================================
FILE: RVT/data/genx_utils/sequence_rnd.py
================================================
from pathlib import Path
from data.genx_utils.labels import SparselyBatchedObjectLabels
from data.genx_utils.sequence_base import SequenceBase
from data.utils.types import DataType, DatasetType, LoaderDataDictGenX
from utils.timers import TimerDummy as Timer
class SequenceForRandomAccess(SequenceBase):
    """Random-access view of a sequence: one sample per labeled object frame.

    Sample ``i`` covers the ``sequence_length`` event representations that end
    at the representation of the ``(i + start_idx_offset)``-th object frame.
    Leading object frames without enough preceding representations are skipped.
    """

    def __init__(
        self,
        path: Path,
        ev_representation_name: str,
        sequence_length: int,
        dataset_type: DatasetType,
        downsample_by_factor_2: bool,
        only_load_end_labels: bool,
    ):
        super().__init__(
            path=path,
            ev_representation_name=ev_representation_name,
            sequence_length=sequence_length,
            dataset_type=dataset_type,
            downsample_by_factor_2=downsample_by_factor_2,
            only_load_end_labels=only_load_end_labels,
        )

        num_objframes = len(self.label_factory)
        # First object frame whose representation has at least seq_len - 1
        # predecessors. If none does, the offset equals the number of frames,
        # which yields a length of 0.
        self.start_idx_offset = next(
            (
                objframe_idx
                for objframe_idx, repr_idx in enumerate(self.objframe_idx_2_repr_idx)
                if repr_idx - self.seq_len + 1 >= 0
            ),
            num_objframes,
        )
        self.length = num_objframes - self.start_idx_offset
        assert num_objframes == len(self.objframe_idx_2_repr_idx)

        # Useful for weighted sampler that is based on label statistics:
        self._only_load_labels = False

    def __len__(self):
        return self.length

    def __getitem__(self, index: int) -> LoaderDataDictGenX:
        """Return sample ``index``; only the labels if ``only_load_labels`` is active."""
        end_idx = self.objframe_idx_2_repr_idx[index + self.start_idx_offset] + 1
        start_idx = end_idx - self.seq_len
        assert start_idx >= 0, (
            f"{self.ev_repr_file=}, {self.start_idx_offset=}, {start_idx=}, {end_idx=}"
        )

        last_repr_idx = end_idx - 1
        labels = [
            None
            if self.only_load_end_labels and repr_idx < last_repr_idx
            else self._get_labels_from_repr_idx(repr_idx)
            for repr_idx in range(start_idx, end_idx)
        ]
        sparse_labels = SparselyBatchedObjectLabels(sparse_object_labels_batch=labels)
        if self._only_load_labels:
            return {DataType.OBJLABELS_SEQ: sparse_labels}

        with Timer(timer_name="read ev reprs"):
            ev_repr = self._get_event_repr_torch(start_idx=start_idx, end_idx=end_idx)
        assert len(sparse_labels) == len(ev_repr)

        return {
            DataType.EV_REPR: ev_repr,
            DataType.OBJLABELS_SEQ: sparse_labels,
            # Every sample is loaded independently, so it always starts a sequence.
            DataType.IS_FIRST_SAMPLE: True,
            DataType.IS_PADDED_MASK: [False] * len(ev_repr),
        }

    def is_only_loading_labels(self) -> bool:
        return self._only_load_labels

    def only_load_labels(self):
        """Make ``__getitem__`` skip reading event representations."""
        self._only_load_labels = True

    def load_everything(self):
        """Make ``__getitem__`` return full samples again."""
        self._only_load_labels = False
================================================
FILE: RVT/data/utils/augmentor.py
================================================
import collections.abc as abc
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
from warnings import filterwarnings, warn
import torch as th
import torch.distributions.categorical
from omegaconf import DictConfig
from torch.nn.functional import interpolate
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import rotate
from data.genx_utils.labels import ObjectLabels, SparselyBatchedObjectLabels
from data.utils.types import DataType, LoaderDataDictGenX
from utils.helpers import torch_uniform_sample_scalar
# Warning emitted by the zoom-in augmentation when no non-empty label frame is available to
# choose a zoom window from (the sample is then returned unaugmented).
NO_LABEL_WARN_MSG = (
    "No Labels found. This can lead to a crash and should not happen often."
)
# "always": report this warning on every occurrence instead of only once per call site
# (Python's default action), so that frequent missing-label cases stay visible.
filterwarnings("always", message=NO_LABEL_WARN_MSG)
@dataclass
class ZoomOutState:
    """Sampled parameters of the zoom-out augmentation."""

    active: bool  # whether zoom-out is applied in the current call
    x0: int  # column of the top-left corner where the shrunk input is pasted on a zero canvas
    y0: int  # row of the top-left corner where the shrunk input is pasted on a zero canvas
    zoom_out_factor: float  # height/width are divided by this factor; 1 means no-op
@dataclass
class RotationState:
    """Sampled parameters of the rotation augmentation."""

    active: bool  # whether rotation is applied in the current call
    angle_deg: float  # signed rotation angle in degrees
@dataclass
class AugmentationState:
    """All augmentation parameters that are applied consistently to one data dict."""

    apply_h_flip: bool  # horizontal flip
    rotation: RotationState
    # Zoom-in window parameters are sampled at call time from the labels, so only a flag is stored.
    apply_zoom_in: bool
    zoom_out: ZoomOutState
class RandomSpatialAugmentorGenX:
    """Random spatial augmentation applied consistently to every entry of a GenX data dict.

    Supported augmentations are horizontal flip, rotation, zoom-in and zoom-out. Zoom-in and
    zoom-out are mutually exclusive. Tensors (event representations, images, flow) and object
    labels of the same data dict are transformed with the same parameters, which are stored in
    ``self.augm_state``.

    With ``automatic_randomization=True`` new parameters are sampled on every ``__call__``.
    Otherwise ``randomize_augmentation`` must be called explicitly, and zoom-in must be disabled,
    because its parameters depend on the labels of each sample.
    """

    def __init__(
        self,
        dataset_hw: Tuple[int, int],
        automatic_randomization: bool,
        augm_config: DictConfig,
    ):
        """
        :param dataset_hw: (height, width) expected for the augmented data.
        :param automatic_randomization: if True, resample augmentation parameters on every call.
        :param augm_config: config with ``prob_hflip``, ``rotate`` (``prob``, ``max_angle_deg``,
            optional ``min_angle_deg``) and ``zoom`` (``prob``, ``zoom_out``, optional ``zoom_in``).
        """
        assert isinstance(dataset_hw, tuple)
        assert len(dataset_hw) == 2
        assert all(x > 0 for x in dataset_hw)
        assert isinstance(automatic_randomization, bool)

        self.hw_tuple = dataset_hw
        self.automatic_randomization = automatic_randomization

        self.h_flip_prob = augm_config.prob_hflip
        self.rot_prob = augm_config.rotate.prob
        self.rot_min_angle_deg = augm_config.rotate.get("min_angle_deg", 0)
        self.rot_max_angle_deg = augm_config.rotate.max_angle_deg
        self.zoom_prob = augm_config.zoom.prob
        zoom_out_weight = augm_config.zoom.zoom_out.get("weight", 1)
        self.min_zoom_out_factor = augm_config.zoom.zoom_out.factor.min
        self.max_zoom_out_factor = augm_config.zoom.zoom_out.factor.max
        # The zoom-in section is optional; without it zoom-in has weight 0 and factor 1.
        has_zoom_in = "zoom_in" in augm_config.zoom
        zoom_in_weight = augm_config.zoom.zoom_in.weight if has_zoom_in else 0
        self.min_zoom_in_factor = (
            augm_config.zoom.zoom_in.factor.min if has_zoom_in else 1
        )
        self.max_zoom_in_factor = (
            augm_config.zoom.zoom_in.factor.max if has_zoom_in else 1
        )

        assert 0 <= self.h_flip_prob <= 1
        assert 0 <= self.rot_prob <= 1
        assert 0 <= self.rot_min_angle_deg <= self.rot_max_angle_deg
        assert 0 <= self.zoom_prob <= 1
        assert 0 <= zoom_in_weight
        assert self.max_zoom_in_factor >= self.min_zoom_in_factor >= 1
        assert 0 <= zoom_out_weight
        assert self.max_zoom_out_factor >= self.min_zoom_out_factor >= 1
        if not automatic_randomization:
            # We are probably applying augmentation to a streaming dataset for which zoom in augm is not supported.
            assert zoom_in_weight == 0, f"{zoom_in_weight=}"

        # Category 0 -> zoom in, category 1 -> zoom out.
        self.zoom_in_or_out_distribution = torch.distributions.categorical.Categorical(
            probs=th.tensor([zoom_in_weight, zoom_out_weight])
        )

        self.augm_state = AugmentationState(
            apply_h_flip=False,
            rotation=RotationState(active=False, angle_deg=0.0),
            apply_zoom_in=False,
            zoom_out=ZoomOutState(active=False, x0=0, y0=0, zoom_out_factor=1.0),
        )

    def randomize_augmentation(self):
        """Sample new augmentation parameters that will be consistently applied among the items.

        This function only works with augmentations that are input-independent.
        E.g. The zoom-in augmentation parameters depend on the labels and cannot be sampled in this function.
        For the same reason, it is not a very reasonable augmentation for the streaming scenario.
        """
        self.augm_state.apply_h_flip = self.h_flip_prob > th.rand(1).item()

        self.augm_state.rotation.active = self.rot_prob > th.rand(1).item()
        if self.augm_state.rotation.active:
            sign = 1 if th.randn(1).item() >= 0 else -1
            self.augm_state.rotation.angle_deg = sign * torch_uniform_sample_scalar(
                min_value=self.rot_min_angle_deg, max_value=self.rot_max_angle_deg
            )

        # Zoom in and zoom out is mutually exclusive.
        do_zoom = self.zoom_prob > th.rand(1).item()
        do_zoom_in = self.zoom_in_or_out_distribution.sample().item() == 0
        do_zoom_out = not do_zoom_in
        do_zoom_in &= do_zoom
        do_zoom_out &= do_zoom
        self.augm_state.apply_zoom_in = do_zoom_in
        self.augm_state.zoom_out.active = do_zoom_out
        if do_zoom_out:
            rand_zoom_out_factor = torch_uniform_sample_scalar(
                min_value=self.min_zoom_out_factor, max_value=self.max_zoom_out_factor
            )
            height, width = self.hw_tuple
            zoom_window_h, zoom_window_w = int(height / rand_zoom_out_factor), int(
                width / rand_zoom_out_factor
            )
            # Top-left corner of the shrunk window, chosen so the window stays inside the frame.
            x0_sampled = int(
                torch_uniform_sample_scalar(
                    min_value=0, max_value=width - zoom_window_w
                )
            )
            y0_sampled = int(
                torch_uniform_sample_scalar(
                    min_value=0, max_value=height - zoom_window_h
                )
            )
            self.augm_state.zoom_out.x0 = x0_sampled
            self.augm_state.zoom_out.y0 = y0_sampled
            self.augm_state.zoom_out.zoom_out_factor = rand_zoom_out_factor

    def _zoom_out_and_rescale(
        self, data_dict: LoaderDataDictGenX
    ) -> LoaderDataDictGenX:
        """Apply the sampled zoom-out (shrink and paste onto a zero canvas) to every entry."""
        zoom_out_state = self.augm_state.zoom_out

        zoom_out_factor = zoom_out_state.zoom_out_factor
        if zoom_out_factor == 1:
            return data_dict
        return {
            k: self._zoom_out_and_rescale_recursive(
                v,
                zoom_coordinates_x0y0=(zoom_out_state.x0, zoom_out_state.y0),
                zoom_out_factor=zoom_out_factor,
                datatype=k,
            )
            for k, v in data_dict.items()
        }

    @staticmethod
    def _zoom_out_and_rescale_tensor(
        input_: th.Tensor,
        zoom_coordinates_x0y0: Tuple[int, int],
        zoom_out_factor: float,
        datatype: DataType,
    ) -> th.Tensor:
        """Shrink a (C, H, W) tensor by ``zoom_out_factor`` and paste it at (x0, y0) on zeros.

        :raises NotImplementedError: for datatypes other than IMAGE and EV_REPR.
        """
        assert len(zoom_coordinates_x0y0) == 2
        assert isinstance(input_, th.Tensor)

        if datatype == DataType.IMAGE or datatype == DataType.EV_REPR:
            assert input_.ndim == 3, f"{input_.shape=}"
            height, width = input_.shape[-2:]
            zoom_window_h, zoom_window_w = int(height / zoom_out_factor), int(
                width / zoom_out_factor
            )
            # interpolate expects a batch dimension; add and remove it.
            zoom_window = interpolate(
                input_.unsqueeze(0),
                size=(zoom_window_h, zoom_window_w),
                mode="nearest-exact",
            )[0]
            output = th.zeros_like(input_)

            x0, y0 = zoom_coordinates_x0y0
            assert x0 >= 0
            assert y0 >= 0
            output[:, y0 : y0 + zoom_window_h, x0 : x0 + zoom_window_w] = zoom_window
            return output
        raise NotImplementedError

    @classmethod
    def _zoom_out_and_rescale_recursive(
        cls,
        input_: Any,
        zoom_coordinates_x0y0: Tuple[int, int],
        zoom_out_factor: float,
        datatype: DataType,
    ):
        """Apply zoom-out to tensors/labels, recursing into sequences and mappings.

        Mask-like datatypes (IS_PADDED_MASK, IS_FIRST_SAMPLE) are returned unchanged.
        Labels are modified in place.
        """
        if datatype in (DataType.IS_PADDED_MASK, DataType.IS_FIRST_SAMPLE):
            return input_
        if isinstance(input_, th.Tensor):
            return cls._zoom_out_and_rescale_tensor(
                input_=input_,
                zoom_coordinates_x0y0=zoom_coordinates_x0y0,
                zoom_out_factor=zoom_out_factor,
                datatype=datatype,
            )
        if isinstance(input_, ObjectLabels) or isinstance(
            input_, SparselyBatchedObjectLabels
        ):
            assert datatype == DataType.OBJLABELS or datatype == DataType.OBJLABELS_SEQ
            input_.zoom_out_and_rescale_(
                zoom_coordinates_x0y0=zoom_coordinates_x0y0,
                zoom_out_factor=zoom_out_factor,
            )
            return input_
        if isinstance(input_, abc.Sequence):
            return [
                cls._zoom_out_and_rescale_recursive(
                    x,
                    zoom_coordinates_x0y0=zoom_coordinates_x0y0,
                    zoom_out_factor=zoom_out_factor,
                    datatype=datatype,
                )
                for x in input_
            ]
        if isinstance(input_, abc.Mapping):
            return {
                key: cls._zoom_out_and_rescale_recursive(
                    value,
                    zoom_coordinates_x0y0=zoom_coordinates_x0y0,
                    zoom_out_factor=zoom_out_factor,
                    datatype=datatype,
                )
                for key, value in input_.items()
            }
        raise NotImplementedError

    def _zoom_in_and_rescale(self, data_dict: LoaderDataDictGenX) -> LoaderDataDictGenX:
        """Crop a random window containing a label box from the latest label frame and upscale it.

        If no non-empty label frame exists, a warning is emitted and ``data_dict`` is returned
        unchanged.
        """
        rand_zoom_in_factor = torch_uniform_sample_scalar(
            min_value=self.min_zoom_in_factor, max_value=self.max_zoom_in_factor
        )
        if rand_zoom_in_factor == 1:
            return data_dict

        height, width = self._hw_from_data(data_dict=data_dict)
        assert (height, width) == self.hw_tuple
        zoom_window_h, zoom_window_w = int(height / rand_zoom_in_factor), int(
            width / rand_zoom_in_factor
        )
        latest_objframe = get_most_recent_objframe(
            data_dict=data_dict, check_if_nonempty=True
        )
        if latest_objframe is None:
            warn(message=NO_LABEL_WARN_MSG, category=UserWarning, stacklevel=2)
            return data_dict
        x0_sampled, y0_sampled = randomly_sample_zoom_window_from_objframe(
            objframe=latest_objframe,
            zoom_window_height=zoom_window_h,
            zoom_window_width=zoom_window_w,
        )

        return {
            k: self._zoom_in_and_rescale_recursive(
                v,
                zoom_coordinates_x0y0=(x0_sampled, y0_sampled),
                zoom_in_factor=rand_zoom_in_factor,
                datatype=k,
            )
            for k, v in data_dict.items()
        }

    @staticmethod
    def _zoom_in_and_rescale_tensor(
        input_: th.Tensor,
        zoom_coordinates_x0y0: Tuple[int, int],
        zoom_in_factor: float,
        datatype: DataType,
    ) -> th.Tensor:
        """Crop a (C, H, W) tensor at (x0, y0) by ``zoom_in_factor`` and resize back to (H, W).

        :raises NotImplementedError: for datatypes other than IMAGE and EV_REPR.
        """
        assert len(zoom_coordinates_x0y0) == 2
        assert isinstance(input_, th.Tensor)

        if datatype == DataType.IMAGE or datatype == DataType.EV_REPR:
            assert input_.ndim == 3, f"{input_.shape=}"
            height, width = input_.shape[-2:]
            zoom_window_h, zoom_window_w = int(height / zoom_in_factor), int(
                width / zoom_in_factor
            )

            x0, y0 = zoom_coordinates_x0y0
            assert x0 >= 0
            assert y0 >= 0
            zoom_canvas = input_[
                ..., y0 : y0 + zoom_window_h, x0 : x0 + zoom_window_w
            ].unsqueeze(0)
            output = interpolate(
                zoom_canvas, size=(height, width), mode="nearest-exact"
            )
            output = output[0]
            return output
        raise NotImplementedError

    @classmethod
    def _zoom_in_and_rescale_recursive(
        cls,
        input_: Any,
        zoom_coordinates_x0y0: Tuple[int, int],
        zoom_in_factor: float,
        datatype: DataType,
    ):
        """Apply zoom-in to tensors/labels, recursing into sequences and mappings.

        Mask-like datatypes (IS_PADDED_MASK, IS_FIRST_SAMPLE) are returned unchanged.
        Labels are modified in place.
        """
        if datatype in (DataType.IS_PADDED_MASK, DataType.IS_FIRST_SAMPLE):
            return input_
        if isinstance(input_, th.Tensor):
            return cls._zoom_in_and_rescale_tensor(
                input_=input_,
                zoom_coordinates_x0y0=zoom_coordinates_x0y0,
                zoom_in_factor=zoom_in_factor,
                datatype=datatype,
            )
        if isinstance(input_, ObjectLabels) or isinstance(
            input_, SparselyBatchedObjectLabels
        ):
            assert datatype == DataType.OBJLABELS or datatype == DataType.OBJLABELS_SEQ
            input_.zoom_in_and_rescale_(
                zoom_coordinates_x0y0=zoom_coordinates_x0y0,
                zoom_in_factor=zoom_in_factor,
            )
            return input_
        if isinstance(input_, abc.Sequence):
            return [
                cls._zoom_in_and_rescale_recursive(
                    x,
                    zoom_coordinates_x0y0=zoom_coordinates_x0y0,
                    zoom_in_factor=zoom_in_factor,
                    datatype=datatype,
                )
                for x in input_
            ]
        if isinstance(input_, abc.Mapping):
            return {
                key: cls._zoom_in_and_rescale_recursive(
                    value,
                    zoom_coordinates_x0y0=zoom_coordinates_x0y0,
                    zoom_in_factor=zoom_in_factor,
                    datatype=datatype,
                )
                for key, value in input_.items()
            }
        raise NotImplementedError

    def _rotate(self, data_dict: LoaderDataDictGenX) -> LoaderDataDictGenX:
        """Rotate every entry by the sampled angle."""
        angle_deg = self.augm_state.rotation.angle_deg
        return {
            k: self._rotate_recursive(v, angle_deg=angle_deg, datatype=k)
            for k, v in data_dict.items()
        }

    @staticmethod
    def _rotate_tensor(input_: Any, angle_deg: float, datatype: DataType):
        """Rotate an image-like tensor with nearest-neighbour interpolation.

        :raises NotImplementedError: for datatypes other than IMAGE and EV_REPR.
        """
        assert isinstance(input_, th.Tensor)
        if datatype == DataType.IMAGE or datatype == DataType.EV_REPR:
            return rotate(
                input_, angle=angle_deg, interpolation=InterpolationMode.NEAREST
            )
        raise NotImplementedError

    @classmethod
    def _rotate_recursive(cls, input_: Any, angle_deg: float, datatype: DataType):
        """Rotate tensors/labels, recursing into sequences and mappings (labels in place)."""
        if datatype in (DataType.IS_PADDED_MASK, DataType.IS_FIRST_SAMPLE):
            return input_
        if isinstance(input_, th.Tensor):
            return cls._rotate_tensor(
                input_=input_, angle_deg=angle_deg, datatype=datatype
            )
        if isinstance(input_, ObjectLabels) or isinstance(
            input_, SparselyBatchedObjectLabels
        ):
            assert datatype == DataType.OBJLABELS or datatype == DataType.OBJLABELS_SEQ
            input_.rotate_(angle_deg=angle_deg)
            return input_
        if isinstance(input_, abc.Sequence):
            return [
                cls._rotate_recursive(x, angle_deg=angle_deg, datatype=datatype)
                for x in input_
            ]
        if isinstance(input_, abc.Mapping):
            return {
                key: cls._rotate_recursive(
                    value, angle_deg=angle_deg, datatype=datatype
                )
                for key, value in input_.items()
            }
        raise NotImplementedError

    @classmethod
    def _flip(cls, data_dict: LoaderDataDictGenX, type_: str) -> LoaderDataDictGenX:
        """Flip every entry horizontally (``type_='h'``) or vertically (``type_='v'``)."""
        assert type_ in {"h", "v"}
        return {
            k: cls._flip_recursive(v, flip_type=type_, datatype=k)
            for k, v in data_dict.items()
        }

    @staticmethod
    def _flip_tensor(input_: Any, flip_type: str, datatype: DataType):
        """Flip a tensor along width ('h') or height ('v').

        For FLOW (2 components on dim -3), the sign of the flow component along the flip axis
        is also negated.

        :raises NotImplementedError: for datatypes other than IMAGE, EV_REPR and FLOW.
        """
        assert isinstance(input_, th.Tensor)
        flip_axis = -1 if flip_type == "h" else -2
        if datatype == DataType.IMAGE or datatype == DataType.EV_REPR:
            return th.flip(input_, dims=[flip_axis])
        if datatype == DataType.FLOW:
            assert input_.shape[-3] == 2
            flow_idx = 0 if flip_type == "h" else 1
            input_ = th.flip(input_, dims=[flip_axis])
            # Also flip the sign of the x (horizontal) or y (vertical) component of the flow.
            input_[..., flow_idx, :, :] = -1 * input_[..., flow_idx, :, :]
            return input_
        raise NotImplementedError

    @classmethod
    def _flip_recursive(cls, input_: Any, flip_type: str, datatype: DataType):
        """Flip tensors/labels, recursing into sequences and mappings.

        Labels support only horizontal flipping (in place); vertical raises NotImplementedError.
        """
        if datatype in (DataType.IS_PADDED_MASK, DataType.IS_FIRST_SAMPLE):
            return input_
        if isinstance(input_, th.Tensor):
            return cls._flip_tensor(
                input_=input_, flip_type=flip_type, datatype=datatype
            )
        if isinstance(input_, ObjectLabels) or isinstance(
            input_, SparselyBatchedObjectLabels
        ):
            assert datatype == DataType.OBJLABELS or datatype == DataType.OBJLABELS_SEQ
            if flip_type == "h":
                # in-place modification
                input_.flip_lr_()
                return input_
            else:
                raise NotImplementedError
        if isinstance(input_, abc.Sequence):
            return [
                cls._flip_recursive(x, flip_type=flip_type, datatype=datatype)
                for x in input_
            ]
        if isinstance(input_, abc.Mapping):
            return {
                key: cls._flip_recursive(
                    value, flip_type=flip_type, datatype=datatype
                )
                for key, value in input_.items()
            }
        raise NotImplementedError

    @staticmethod
    def _hw_from_data(data_dict: LoaderDataDictGenX) -> Tuple[int, int]:
        """Infer (height, width) from the data dict and assert that all entries agree.

        Label entries contribute their ``input_size_hw`` (skipped when None). IMAGE/FLOW/EV_REPR
        entries contribute the last two dims of their first element.
        """
        height = None
        width = None
        for k, v in data_dict.items():
            _hw = None
            if k == DataType.OBJLABELS or k == DataType.OBJLABELS_SEQ:
                # May be None, in which case this entry does not contribute.
                _hw = v.input_size_hw
            elif k in (DataType.IMAGE, DataType.FLOW, DataType.EV_REPR):
                _hw = v[0].shape[-2:]
            if _hw is not None:
                _height, _width = _hw
                if height is None:
                    assert width is None
                    height, width = _height, _width
                else:
                    assert height == _height and width == _width
        assert height is not None
        assert width is not None
        return height, width

    def __call__(self, data_dict: LoaderDataDictGenX):
        """
        :param data_dict: LoaderDataDictGenX type, image-based tensors must have (*, h, w) shape.
        :return: map with same keys but spatially augmented values.
        """
        if self.automatic_randomization:
            self.randomize_augmentation()

        if self.augm_state.apply_h_flip:
            data_dict = self._flip(data_dict, type_="h")
        if self.augm_state.rotation.active:
            data_dict = self._rotate(data_dict)
        if self.augm_state.apply_zoom_in:
            data_dict = self._zoom_in_and_rescale(data_dict=data_dict)
        if self.augm_state.zoom_out.active:
            assert not self.augm_state.apply_zoom_in
            data_dict = self._zoom_out_and_rescale(data_dict=data_dict)
        return data_dict
def get_most_recent_objframe(
    data_dict: LoaderDataDictGenX, check_if_nonempty: bool = True
) -> Optional[ObjectLabels]:
    """Return the latest non-None label frame of ``data_dict[DataType.OBJLABELS_SEQ]``.

    :param data_dict: must contain ``DataType.OBJLABELS_SEQ``.
    :param check_if_nonempty: if True, frames with zero objects are skipped as well.
    :return: the most recent matching label frame, or None if there is none.
    """
    assert (
        DataType.OBJLABELS_SEQ in data_dict
    ), f"Requires datatype {DataType.OBJLABELS_SEQ} to be present"
    sparse_obj_labels: SparselyBatchedObjectLabels = data_dict[DataType.OBJLABELS_SEQ]

    for candidate in reversed(sparse_obj_labels):
        if candidate is None:
            continue
        if not check_if_nonempty or len(candidate) > 0:
            return candidate
    # no labels found
    return None
def randomly_sample_zoom_window_from_objframe(
    objframe: ObjectLabels,
    zoom_window_height: Union[int, float],
    zoom_window_width: Union[int, float],
) -> Tuple[int, int]:
    """Sample the top-left corner of a zoom window around a randomly chosen label box.

    For every box of ``objframe`` a candidate corner is sampled such that the window contains
    the box; one candidate is then picked uniformly at random.

    :param objframe: non-empty label frame providing x, y, w, h and ``input_size_hw``.
    :param zoom_window_height: height of the zoom window.
    :param zoom_window_width: width of the zoom window.
    :return: (x0, y0) top-left corner of the zoom window.
    """
    input_height, input_width = objframe.input_size_hw
    possible_samples = []
    for idx in range(len(objframe)):
        label_xywh = (
            objframe.x[idx],
            objframe.y[idx],
            objframe.w[idx],
            objframe.h[idx],
        )
        possible_samples.append(
            randomly_sample_zoom_window_from_label_rectangle(
                label_xywh=label_xywh,
                input_height=input_height,
                input_width=input_width,
                zoom_window_height=zoom_window_height,
                zoom_window_width=zoom_window_width,
            )
        )
    assert len(possible_samples) > 0
    # Using torch to sample, to avoid potential problems with multiprocessing.
    # `high` of th.randint is exclusive, so pass the full length to make every candidate
    # (including the last one) selectable. A single candidate is taken directly without
    # consuming random numbers, as before.
    sample_idx = (
        0
        if len(possible_samples) == 1
        else th.randint(low=0, high=len(possible_samples), size=(1,)).item()
    )
    x0_sample, y0_sample = possible_samples[sample_idx]
    assert input_width > x0_sample >= 0, f"{x0_sample=}"
    assert input_height > y0_sample >= 0, f"{y0_sample=}"
    return x0_sample, y0_sample
def randomly_sample_zoom_window_from_label_rectangle(
    label_xywh: Tuple[Union[int, float, th.Tensor], ...],
    input_height: Union[int, float],
    input_width: Union[int, float],
    zoom_window_height: Union[int, float],
    zoom_window_width: Union[int, float],
) -> Tuple[int, int]:
    """Computes a set of top-left coordinates from which the top-left corner of the zoom window
    can be sampled such that the zoom window is guaranteed to contain the whole (rectangular) label.
    Return a random sample from this set.

    If the label is larger than the zoom window along an axis, the window cannot contain it
    along that axis; the computed range then keeps the window within the label's extent instead.

    Notation:
    (x0,y0)---(x1,y0)
    |               |
    |               |
    (x0,y1)---(x1,y1)

    :param label_xywh: label box as (x, y, w, h), with (x, y) its top-left corner; tensor
        entries are converted to Python scalars with ``.item()``.
    :param input_height: height of the full input.
    :param input_width: width of the full input.
    :param zoom_window_height: height of the zoom window (must not exceed ``input_height``).
    :param zoom_window_width: width of the zoom window (must not exceed ``input_width``).
    :return: (x, y) top-left corner of the zoom window, truncated to int, with
        0 <= x < input_width and 0 <= y < input_height.
    """
    assert input_height >= zoom_window_height
    assert input_width >= zoom_window_width
    label_xywh = tuple(x.item() if isinstance(x, th.Tensor) else x for x in label_xywh)
    x0_l, y0_l, w_l, h_l = label_xywh
    x1_l = x0_l + w_l
    y1_l = y0_l + h_l
    assert x0_l >= 0
    assert y0_l >= 0
    assert w_l > 0
    assert h_l > 0
    # Small tolerance for floating point label coordinates.
    assert x1_l <= input_width + 1e-2 - 1
    assert y1_l <= input_height + 1e-2 - 1

    # Lower bound: the window must reach the label's far edge (x1_l / y1_l), clipped at 0.
    x0_valid_region = max(x1_l - max(zoom_window_width, w_l), 0)
    y0_valid_region = max(y1_l - max(zoom_window_height, h_l), 0)

    # Upper bound: first the right/bottom-most pixel the window may extend to, clipped to the
    # input; subtracting the window size gives the largest top-left coordinate. The result is
    # never allowed below the lower bound, so the sampling range is non-empty.
    x1_valid_region = min(x0_l + max(zoom_window_width, w_l), input_width - 1)
    y1_valid_region = min(y0_l + max(zoom_window_height, h_l), input_height - 1)

    x1_valid_region = max(x1_valid_region - zoom_window_width, x0_valid_region)
    y1_valid_region = max(y1_valid_region - zoom_window_height, y0_valid_region)

    # torch-based sampling helper (presumably uniform in [min_value, max_value] — see utils.helpers).
    x_topleft_sample = int(
        torch_uniform_sample_scalar(
            min_value=x0_valid_region, max_value=x1_valid_region
        )
    )
    assert 0 <= x_topleft_sample < input_width
    y_topleft_sample = int(
        torch_uniform_sample_scalar(
            min_value=y0_valid_region, max_value=y1_valid_region
        )
    )
    assert 0 <= y_topleft_sample < input_height
    return x_topleft_sample, y_topleft_sample
================================================
FILE: RVT/data/utils/representations.py
================================================
from abc import ABC, abstractmethod
from typing import Optional, Tuple
import math
import numpy as np
import torch as th
class RepresentationBase(ABC):
    """Interface for event representations built from per-event x, y, polarity and time tensors."""

    @abstractmethod
    def construct(
        self, x: th.Tensor, y: th.Tensor, pol: th.Tensor, time: th.Tensor
    ) -> th.Tensor:
        """Build the representation tensor from the given events."""

    @abstractmethod
    def get_shape(self) -> Tuple[int, int, int]:
        """Return the shape of the tensor produced by ``construct``."""

    @staticmethod
    @abstractmethod
    def get_numpy_dtype() -> np.dtype:
        """Return the NumPy dtype of the representation."""

    @staticmethod
    @abstractmethod
    def get_torch_dtype() -> th.dtype:
        """Return the torch dtype of the representation."""

    @property
    def dtype(self) -> th.dtype:
        """Torch dtype of the representation (same as ``get_torch_dtype()``)."""
        torch_dtype = self.get_torch_dtype()
        return torch_dtype

    @staticmethod
    def _is_int_tensor(tensor: th.Tensor) -> bool:
        """Return True if ``tensor`` is neither floating point nor complex."""
        return not (th.is_floating_point(tensor) or th.is_complex(tensor))
class StackedHistogram(RepresentationBase):
    """Per-polarity event-count histograms stacked over ``bins`` equally long temporal bins.

    ``construct`` returns a uint8 tensor of shape (2 * bins, height, width): the first ``bins``
    channels hold counts of polarity-0 events, the remaining ``bins`` those of polarity-1 events.
    Counts are clamped to ``count_cutoff``.
    """

    def __init__(
        self,
        bins: int,
        height: int,
        width: int,
        count_cutoff: Optional[int] = None,
        fastmode: bool = True,
    ):
        """
        In case of fastmode == True: use uint8 to construct the representation, but could lead to overflow.
        In case of fastmode == False: use int16 to construct the representation, and convert to uint8 after clipping.

        Note: Overflow should not be a big problem because it happens only for hot pixels. In case of overflow,
        the value will just start accumulating from 0 again.
        """
        assert bins >= 1
        assert height >= 1
        assert width >= 1
        self.bins, self.height, self.width = bins, height, width

        # The output dtype is uint8, so the cutoff can never exceed 255.
        if count_cutoff is None:
            self.count_cutoff = 255
        else:
            assert count_cutoff >= 1
            self.count_cutoff = min(count_cutoff, 255)
        self.fastmode = fastmode
        self.channels = 2  # one channel group per polarity value (0 or 1)

    @staticmethod
    def get_numpy_dtype() -> np.dtype:
        return np.dtype("uint8")

    @staticmethod
    def get_torch_dtype() -> th.dtype:
        return th.uint8

    def merge_channel_and_bins(self, representation: th.Tensor):
        """Flatten a (channels, bins, H, W) tensor into (channels * bins, H, W)."""
        assert representation.dim() == 4
        return representation.reshape(-1, self.height, self.width)

    def get_shape(self) -> Tuple[int, int, int]:
        return (2 * self.bins, self.height, self.width)

    def construct(
        self, x: th.Tensor, y: th.Tensor, pol: th.Tensor, time: th.Tensor
    ) -> th.Tensor:
        """Accumulate events into the histogram; ``time`` is assumed to be sorted ascending."""
        device = x.device
        assert y.device == pol.device == time.device == device
        for event_field in (x, y, pol, time):
            assert self._is_int_tensor(event_field)

        accum_dtype = th.uint8 if self.fastmode else th.int16
        representation = th.zeros(
            (self.channels, self.bins, self.height, self.width),
            dtype=accum_dtype,
            device=device,
            requires_grad=False,
        )

        if x.numel() == 0:
            assert y.numel() == 0
            assert pol.numel() == 0
            assert time.numel() == 0
            return self.merge_channel_and_bins(representation.to(th.uint8))
        assert x.numel() == y.numel() == pol.numel() == time.numel()

        assert pol.min() >= 0
        assert pol.max() <= 1

        # NOTE: assume sorted time
        t_start, t_end = time[0], time[-1]
        assert t_end >= t_start
        # Map time linearly to [0, bins]; the last event (t_end) falls into the last bin.
        t_scaled = (time - t_start) / max((t_end - t_start), 1) * self.bins
        bin_idx = th.clamp(t_scaled.floor(), max=self.bins - 1)

        # Flat index into the (channels, bins, H, W) layout.
        plane = self.height * self.width
        indices = (
            pol.long() * (self.bins * plane)
            + bin_idx.long() * plane
            + y.long() * self.width
            + x.long()
        )
        ones = th.ones_like(indices, dtype=accum_dtype, device=device)
        representation.put_(indices, ones, accumulate=True)
        representation = th.clamp(representation, min=0, max=self.count_cutoff)
        if not self.fastmode:
            representation = representation.to(th.uint8)

        return self.merge_channel_and_bins(representation)
def cumsum_channel(x: th.Tensor, num_channels: int):
    """In-place cumulative sum over the first ``num_channels`` entries of dim 0.

    After the call, ``x[i]`` equals the sum of the original ``x[0], ..., x[i]`` for every
    ``i < num_channels``; later entries are untouched. Returns ``x`` itself.
    """
    # Walk backwards so each partial sum still reads the unmodified lower channels.
    channel = num_channels
    while channel > 0:
        channel -= 1
        x[channel] = x[: channel + 1].sum(dim=0)
    return x
class MixedDensityEventStack(RepresentationBase):
    """Signed event counts over logarithmically spaced temporal bins, cumulatively summed.

    ``construct`` returns an int8 tensor of shape (bins, height, width). Polarity 0/1 is mapped to
    -1/+1. Bin ``bins - 1`` covers the most recent half of the time window, bin ``bins - 2`` the
    quarter before it, and so on; bin 0 also collects all older events. The bins are then summed
    cumulatively along dim 0, so channel ``i`` holds the sum of bins ``0..i``. Values are
    optionally clamped to [-count_cutoff, count_cutoff].
    """

    def __init__(
        self,
        bins: int,
        height: int,
        width: int,
        count_cutoff: Optional[int] = None,
        allow_compilation: bool = False,
    ):
        """
        :param bins: number of temporal bins / output channels.
        :param height: output height.
        :param width: output width.
        :param count_cutoff: optional absolute clamp value; must fit in int8.
        :param allow_compilation: try to wrap the channel cumsum with ``th.compile``.
        """
        assert bins >= 1
        self.bins = bins
        assert height >= 1
        self.height = height
        assert width >= 1
        self.width = width
        self.count_cutoff = count_cutoff
        if self.count_cutoff is not None:
            assert isinstance(count_cutoff, int)
            # The output dtype is int8, so the cutoff must be representable in it.
            assert 0 <= self.count_cutoff <= 2**7 - 1

        self.cumsum_ch_opt = cumsum_channel
        if allow_compilation:
            # Will most likely not work with multiprocessing.
            try:
                self.cumsum_ch_opt = th.compile(cumsum_channel)
            except AttributeError:
                # th.compile is unavailable (older PyTorch): keep the eager implementation.
                pass

    @staticmethod
    def get_numpy_dtype() -> np.dtype:
        return np.dtype("int8")

    @staticmethod
    def get_torch_dtype() -> th.dtype:
        return th.int8

    def get_shape(self) -> Tuple[int, int, int]:
        return self.bins, self.height, self.width

    def construct(
        self, x: th.Tensor, y: th.Tensor, pol: th.Tensor, time: th.Tensor
    ) -> th.Tensor:
        """Accumulate events into the stack; ``time`` is assumed to be sorted ascending."""
        device = x.device
        assert y.device == pol.device == time.device == device
        assert self._is_int_tensor(x)
        assert self._is_int_tensor(y)
        assert self._is_int_tensor(pol)
        assert self._is_int_tensor(time)

        dtype = th.int8

        representation = th.zeros(
            (self.bins, self.height, self.width),
            dtype=dtype,
            device=device,
            requires_grad=False,
        )

        if x.numel() == 0:
            assert y.numel() == 0
            assert pol.numel() == 0
            assert time.numel() == 0
            return representation
        assert x.numel() == y.numel() == pol.numel() == time.numel()

        assert pol.min() >= 0  # maybe remove because too costly
        assert pol.max() <= 1  # maybe remove because too costly
        pol = pol * 2 - 1

        bn, ht, wd = self.bins, self.height, self.width

        # NOTE: assume sorted time
        t0_int = time[0]
        t1_int = time[-1]
        assert t1_int >= t0_int
        t_norm = (time - t0_int) / max((t1_int - t0_int), 1)
        t_norm = th.clamp(t_norm, min=1e-6, max=1 - 1e-6)
        # Let N be the number of bins. I.e. bin \in [0, N):
        # Let f(bin) = t_norm, model the relationship between bin and normalized time \in [0, 1]
        # f(bin=N) = 1
        # f(bin=N-1) = 1/2
        # f(bin=N-2) = 1/2*1/2
        # -> f(bin=N-i) = (1/2)^i
        # Also: f(bin) = t_norm
        #
        # Hence, (1/2)^(N-bin) = t_norm
        # And, bin = N - log(t_norm, base=1/2) = N - log(t_norm)/log(1/2)
        bin_float = self.bins - th.log(t_norm) / math.log(1 / 2)
        # Can go below 0 for t_norm close to 0 -> clamp to 0
        bin_float = th.clamp(bin_float, min=0)
        # For t_norm close to 1, float32 rounding of `bins - tiny` can yield exactly `bins`
        # (an out-of-range index for put_), so clamp the index to the last bin as
        # StackedHistogram does. Values strictly below `bins` are unaffected.
        t_idx = th.clamp(bin_float.floor(), max=bn - 1)

        indices = x.long() + wd * y.long() + ht * wd * t_idx.long()
        values = th.asarray(pol, dtype=dtype, device=device)
        # NOTE(review): accumulation happens in int8, so very active (hot) pixels can wrap around
        # before the cutoff clamp below is applied.
        representation.put_(indices, values, accumulate=True)
        representation = self.cumsum_ch_opt(representation, num_channels=self.bins)

        if self.count_cutoff is not None:
            representation = th.clamp(
                representation, min=-self.count_cutoff, max=self.count_cutoff
            )
        return representation
================================================
FILE: RVT/data/utils/spatial.py
================================================
from omegaconf import DictConfig
from data.utils.types import DatasetType
# Original (not downsampled) input resolution of each dataset, as (height, width).
_type_2_hw = {
    DatasetType.GEN1: (240, 304),
    DatasetType.GEN4: (720, 1280),
}

# Maps the dataset name used in configs (``dataset_config.name``) to its DatasetType.
_str_2_type = {
    "gen1": DatasetType.GEN1,
    "gen4": DatasetType.GEN4,
}
def get_original_hw(dataset_type: DatasetType):
    """Return the original (height, width) of ``dataset_type``; raises KeyError if unknown."""
    original_hw = _type_2_hw[dataset_type]
    return original_hw
def get_dataloading_hw(dataset_config: DictConfig):
    """Return the (height, width) used for data loading.

    This is the dataset's original resolution, halved (integer division) on both axes when
    ``dataset_config.downsample_by_factor_2`` is set.
    """
    original_hw = get_original_hw(dataset_type=_str_2_type[dataset_config.name])
    if not dataset_config.downsample_by_factor_2:
        return original_hw
    return tuple(side // 2 for side in original_hw)
================================================
FILE: RVT/data/utils/stream_concat_datapipe.py
================================================
from typing import Any, Iterator, List, Optional, Type
import torch as th
import torch.distributed as dist
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import (
Concater,
IterableWrapper,
IterDataPipe,
Zipper,
)
from torchdata.datapipes.map import MapDataPipe
class DummyIterDataPipe(IterDataPipe):
    """Pass-through IterDataPipe: yields the items of ``source_dp`` unchanged.

    Used as the default "augmentation" wrapper when no augmentation pipeline is given.
    """

    def __init__(self, source_dp: IterDataPipe):
        super().__init__()
        assert isinstance(source_dp, IterDataPipe)
        self.source_dp = source_dp

    def __iter__(self):
        for item in self.source_dp:
            yield item
class ConcatStreamingDataPipe(IterDataPipe):
    """This Dataset avoids the sharding problem by instantiating randomized stream concatenation at the batch and
    worker level.

    Pros:
    - Every single batch has valid samples. Consequently, the batch size is always constant.
    Cons:
    - There might be repeated samples in a batch. Although they should be different because of data augmentation.
    - Cannot be used for validation or testing because we repeat the dataset multiple times in an epoch.

    TLDR: preferred approach for training but not useful for validation or testing.
    """

    def __init__(
        self,
        datapipe_list: List[MapDataPipe],
        batch_size: int,
        num_workers: int,
        augmentation_pipeline: Optional[Type[IterDataPipe]] = None,
        print_seed_debug: bool = False,
    ):
        """
        :param datapipe_list: map-style datapipes; each is converted to an iterable stream per worker.
        :param batch_size: number of parallel concatenated streams zipped together.
        :param num_workers: accepted for API compatibility; not used by this class.
        :param augmentation_pipeline: IterDataPipe class wrapped around every stream;
            defaults to a pass-through.
        :param print_seed_debug: print per-worker seed information when iteration starts.
        """
        super().__init__()
        assert batch_size > 0

        if augmentation_pipeline is not None:
            self.augmentation_dp = augmentation_pipeline
        else:
            self.augmentation_dp = DummyIterDataPipe

        # We require MapDataPipes instead of IterDataPipes because IterDataPipes must be deepcopied in each worker.
        # Instead, MapDataPipes can be converted to IterDataPipes in each worker without requiring a deepcopy.
        self.datapipe_list = datapipe_list
        self.batch_size = batch_size

        self.print_seed_debug = print_seed_debug

    @staticmethod
    def random_torch_shuffle_list(data: List[Any]) -> Iterator[Any]:
        """Lazily yield the elements of ``data`` in a random order drawn from torch's RNG."""
        assert isinstance(data, list)
        return (data[idx] for idx in th.randperm(len(data)).tolist())

    def _get_zipped_streams(self, datapipe_list: List[MapDataPipe], batch_size: int):
        """Use it only in the iter function of this class!!!
        Reason: randomized shuffling must happen within each worker. Otherwise, the same random order will be used
        for all workers.
        """
        assert isinstance(datapipe_list, list)
        assert batch_size > 0
        # Each of the `batch_size` streams concatenates all datapipes in its own random order.
        streams = Zipper(
            *(
                Concater(
                    *(
                        self.augmentation_dp(x.to_iter_datapipe())
                        for x in self.random_torch_shuffle_list(datapipe_list)
                    )
                )
                for _ in range(batch_size)
            )
        )
        return streams

    def _print_seed_debug_info(self):
        """Print the torch seed and worker/rank ids of the current (possibly main) process."""
        worker_info = th.utils.data.get_worker_info()
        if worker_info is None:
            # Loading in the main process (num_workers == 0): there is no WorkerInfo,
            # so report the process-wide torch seed instead of crashing on `.seed`.
            local_worker_id = 0
            local_num_workers = 1
            worker_torch_seed = th.initial_seed()
        else:
            local_worker_id = worker_info.id
            local_num_workers = worker_info.num_workers
            worker_torch_seed = worker_info.seed
        if dist.is_available() and dist.is_initialized():
            global_rank = dist.get_rank()
        else:
            global_rank = 0
        global_worker_id = global_rank * local_num_workers + local_worker_id

        rnd_number = th.randn(1)
        print(
            f"{worker_torch_seed=},\t{global_worker_id=},\t{global_rank=},\t{local_worker_id=},\t{rnd_number=}",
            flush=True,
        )

    def _get_zipped_streams_with_worker_id(self):
        """Use it only in the iter function of this class!!!"""
        worker_info = th.utils.data.get_worker_info()
        local_worker_id = 0 if worker_info is None else worker_info.id
        # Infinite stream of this worker's id, zipped alongside every batch.
        worker_id_stream = IterableWrapper([local_worker_id]).cycle(count=None)
        zipped_stream = self._get_zipped_streams(
            datapipe_list=self.datapipe_list, batch_size=self.batch_size
        )
        return zipped_stream.zip(worker_id_stream)

    def __iter__(self):
        if self.print_seed_debug:
            self._print_seed_debug_info()
        return iter(self._get_zipped_streams_with_worker_id())
================================================
FILE: RVT/data/utils/stream_sharded_datapipe.py
================================================
from typing import Any, List, Optional
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import (
Concater,
IterableWrapper,
IterDataPipe,
ZipperLongest,
)
from torchdata.datapipes.map import MapDataPipe
class ShardedStreamingDataPipe(IterDataPipe):
def __init__(
self,
datapipe_list: List[MapDataPipe],
batch_size: int,
fill_value: Optional[Any] = None,
):
super().__init__()
assert batch_size > 0
# We require MapDataPipes instead of IterDataPipes because IterDataPipes must be deepcopied in each worker.
# Instead, MapDataPipes can be converted to IterDataPipes in each worker without requiring a deepcopy.
# Note: Sorting is a heuristic to get potentially better distribution of workloads than taking the data as is.
# Sort iterators from long to short.
self.datapipe_list = sorted(datapipe_list, key=lambda x: len(x), reverse=True)
self.batch_size = batch_size
self.fill_value = fill_value
@staticmethod
def yield_pyramid_indices(start_idx: int, end_idx: int):
while True:
for idx in range(start_idx, end_idx):
yield idx
for idx in range(end_idx - 1, start_idx - 1, -1):
yield idx
@classmethod
def assign_datapipes_to_worker(
cls,
sorted_datapipe_list: List[MapDataPipe],
total_num_workers: int,
global_worker_id: int,
) -> List[MapDataPipe]:
num_datapipes = len(sorted_datapipe_list)
assert (
num_datapipes >= total_num_workers > global_worker_id
), f"{num_datapipes=}, {total_num_workers=}, {global_worker_id=}"
datapipes = []
# Assumes sorted datapipes from long to short.
global_worker_id_generator = cls.yield_pyramid_indices(
start_idx=0, end_idx=total_num_workers
)
for idx, dp in enumerate(sorted_datapipe_list):
generated_global_worker_id = next(global_worker_id_generator)
if generated_global_worker_id == global_worker_id:
datapipes.append(dp)
assert len(sorted_datapipe_list) > 0
return datapipes
def get_zipped_stream_from_worker_datapipes(
self, datapipe_list: List[MapDataPipe], batch_size: int
) -> ZipperLongest:
num_datapipes = len(datapipe_list)
assert num_datapipes > 0
assert batch_size > 0
assert num_datapipes >= batch_size, (
"Each worker must at least get 'batch_size' number of datapipes. "
"Otherwise, we would have to support dynamic batch sizes. "
"As a workaround, decrease the number of workers."
)
# Sort datapipe_list from long to short.
datapipe_list = sorted(datapipe_list, key=lambda x: len(x), reverse=True)
zipped_streams = [[] for _ in range(batch_size)]
batch_id_generator = self.yield_pyramid_indices(start_idx=0, end_idx=batch_size)
for datapipe in datapipe_list:
batch_idx = next(batch_id_generator)
zipped_streams[batch_idx].append(datapipe)
for idx, streams in enumerate(zipped_streams):
zipped_streams[idx] = Concater(
*(stream.to_iter_datapipe() for stream in streams)
)
zipped_streams = ZipperLongest(*zipped_streams, fill_value=self.fill_value)
return zipped_streams
def __iter__(self):
    """Stream this worker's share of the data, tagged with its local worker id.

    The global worker id and count are derived from the DataLoader worker info and,
    when ``torch.distributed`` is initialized, the process rank and world size. Each
    element of the returned iterator is ``(zipped_samples, local_worker_id)``.
    """
    info = torch.utils.data.get_worker_info()
    if info is None:
        local_worker_id, local_num_workers = 0, 1
    else:
        local_worker_id, local_num_workers = info.id, info.num_workers

    distributed = dist.is_available() and dist.is_initialized()
    world_size = dist.get_world_size() if distributed else 1
    global_rank = dist.get_rank() if distributed else 0

    worker_datapipes = self.assign_datapipes_to_worker(
        sorted_datapipe_list=self.datapipe_list,
        total_num_workers=local_num_workers * world_size,
        global_worker_id=global_rank * local_num_workers + local_worker_id,
    )
    batched_stream = self.get_zipped_stream_from_worker_datapipes(
        datapipe_list=worker_datapipes, batch_size=self.batch_size
    )
    # The local (not global) worker id is attached because recurrent models key their
    # saved states by it. States live in each DDP process separately, so processes
    # never need to share ids.
    worker_ids = IterableWrapper([local_worker_id]).cycle(count=None)
    return iter(batched_stream.zip(worker_ids))
================================================
FILE: RVT/data/utils/types.py
================================================
from enum import auto, Enum
try:
from enum import StrEnum
except ImportError:
from strenum import StrEnum
from typing import Dict, List, Optional, Tuple, Union
import torch as th
from data.genx_utils.labels import ObjectLabels, SparselyBatchedObjectLabels
class DataType(Enum):
    """Keys of the batch dictionaries produced by the data loaders.

    Values are spelled out explicitly. They are identical to what ``auto()`` assigned
    (1, 2, ...).
    """

    EV_REPR = 1
    FLOW = 2
    IMAGE = 3
    OBJLABELS = 4
    OBJLABELS_SEQ = 5
    IS_PADDED_MASK = 6
    IS_FIRST_SAMPLE = 7
    TOKEN_MASK = 8
class DatasetType(Enum):
    """Supported event-camera detection datasets (values match the former ``auto()``)."""

    GEN1 = 1
    GEN4 = 2
class DatasetMode(Enum):
    """Dataset split selector (values match the former ``auto()``)."""

    TRAIN = 1
    VALIDATION = 2
    TESTING = 3
class DatasetSamplingMode(StrEnum):
    """String-valued sampling-mode selector; members compare equal to their values.

    Presumably ``RANDOM`` selects random-access sampling, ``STREAM`` sequential
    streaming, and ``MIXED`` a combination of both -- TODO confirm against the data
    module. Values are written out explicitly on purpose: ``auto()`` would give
    different values depending on whether ``StrEnum`` comes from ``enum`` or from the
    ``strenum`` fallback.
    """

    RANDOM = "random"
    STREAM = "stream"
    MIXED = "mixed"
class ObjDetOutput(Enum):
    """Keys for object-detection outputs (values match the former ``auto()``)."""

    LABELS_PROPH = 1
    PRED_PROPH = 2
    EV_REPR = 3
    SKIP_VIZ = 4
# Value types that may appear in a loader batch dictionary.
_LoaderDataValueGenX = Union[
    List[th.Tensor], ObjectLabels, SparselyBatchedObjectLabels, List[bool]
]
# Loader batch dictionary keyed by `DataType`.
LoaderDataDictGenX = Dict[DataType, _LoaderDataValueGenX]
# Recurrent state of one backbone stage. The S5 stages return a single tensor; an
# LSTM-style recurrence would return a tuple of tensors such as (h, c). The old alias
# `Optional[Tuple[th.Tensor]]` described a 1-tuple, which matched neither case.
LstmState = Optional[Union[th.Tensor, Tuple[th.Tensor, ...]]]
# One state per backbone stage.
LstmStates = List[LstmState]
FeatureMap = th.Tensor
# Feature maps keyed by 1-based stage number.
BackboneFeatures = Dict[int, th.Tensor]
================================================
FILE: RVT/loggers/utils.py
================================================
from pathlib import Path
from typing import Union
import wandb
from omegaconf import DictConfig, OmegaConf
from loggers.wandb_logger import WandbLogger
def get_wandb_logger(full_config: DictConfig) -> WandbLogger:
    """Create the ``WandbLogger`` for an experiment.

    If ``full_config.wandb.wandb_runpath`` is set, its last path component is used as
    the run id. Otherwise a fresh id is generated. The fully resolved config is
    attached to the run.

    :param full_config: complete experiment config with a ``wandb`` node providing
        ``wandb_runpath``, ``project_name`` and ``group_name``.
    :returns: the configured logger.
    """
    wandb_cfg = full_config.wandb
    runpath = wandb_cfg.wandb_runpath
    if runpath is not None:
        run_id = Path(runpath).name
        print(f"using provided id {run_id}")
    else:
        run_id = wandb.util.generate_id()
        print(f"new run: generating id {run_id}")

    resolved_config = OmegaConf.to_container(
        full_config, resolve=True, throw_on_missing=True
    )
    return WandbLogger(
        project=wandb_cfg.project_name,
        group=wandb_cfg.group_name,
        wandb_id=run_id,
        log_model=True,
        save_last_only_final=False,
        save_code=True,
        config_args=resolved_config,
    )
def get_ckpt_path(logger: WandbLogger, wandb_config: DictConfig) -> Union[Path, None]:
    """Resolve the local ``.ckpt`` file to resume training from.

    With a ``WandbLogger``, the path comes from ``logger.get_checkpoint``. That method
    downloads artifact ``artifact_name`` unless ``artifact_local_file`` is given. With
    any other logger, ``artifact_local_file`` is used directly.

    :param logger: the experiment logger.
    :param wandb_config: config node providing ``artifact_name`` and
        ``artifact_local_file``.
    :returns: path to an existing ``.ckpt`` file.
    :raises AssertionError: if ``artifact_name`` is missing, or the resolved file does
        not exist or is not a ``.ckpt`` file.
    :raises ValueError: if the logger is not a ``WandbLogger`` and
        ``artifact_local_file`` is not set. Previously this failed with an opaque
        ``AttributeError`` on ``None.exists()``.
    """
    cfg = wandb_config
    artifact_name = cfg.artifact_name
    assert (
        artifact_name is not None
    ), "Artifact name is required to resume from checkpoint."
    print(f"resuming checkpoint from artifact {artifact_name}")
    artifact_local_file = cfg.artifact_local_file
    if artifact_local_file is not None:
        artifact_local_file = Path(artifact_local_file)
    if isinstance(logger, WandbLogger):
        resume_path = logger.get_checkpoint(
            artifact_name=artifact_name, artifact_filepath=artifact_local_file
        )
    else:
        resume_path = artifact_local_file
    if resume_path is None:
        raise ValueError(
            "artifact_local_file must be set to resume from a checkpoint "
            "when the logger is not a WandbLogger."
        )
    assert resume_path.exists(), resume_path
    assert resume_path.suffix == ".ckpt", resume_path.suffix
    return resume_path
================================================
FILE: RVT/loggers/wandb_logger.py
================================================
"""
This is a modified version of the PyTorch Lightning Weights & Biases (wandb) logger.
"""
import time
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from weakref import ReferenceType
import numpy as np
import lightning.pytorch as pl
import torch
import torch.nn as nn
import re

# Compare numeric (major, minor) tuples. The previous `float(__version__[:3])` misread
# versions such as "1.10.0" as 1.1.
pl_is_ge_1_6 = tuple(int(part) for part in re.findall(r"\d+", pl.__version__)[:2]) >= (1, 6)
assert pl_is_ge_1_6
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import rank_zero_experiment, Logger
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.logger import (
_add_prefix,
_convert_params,
_flatten_dict,
_sanitize_callable_params,
)
import wandb
from wandb.sdk.lib import RunDisabled
from wandb.wandb_run import Run
class WandbLogger(Logger):
    """Lightning ``Logger`` backed by Weights & Biases, with checkpoint-artifact handling.

    Metrics are logged against ``STEP_METRIC`` whenever a step is given. Checkpoints
    reported by a ``ModelCheckpoint`` callback are uploaded as wandb ``model``
    artifacts named ``checkpoint-<run id>-last`` or ``checkpoint-<run id>-topK``.
    After each upload, this run's artifacts that fall outside the top-k are deleted
    (see ``_rm_but_top_k``).
    """

    LOGGER_JOIN_CHAR = "-"
    STEP_METRIC = "trainer/global_step"

    def __init__(
        self,
        name: Optional[str] = None,
        project: Optional[str] = None,
        group: Optional[str] = None,
        wandb_id: Optional[str] = None,
        prefix: Optional[str] = "",
        log_model: Optional[bool] = True,
        save_last_only_final: Optional[bool] = False,
        config_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        :param name: run display name (passed to ``wandb.init``).
        :param project: wandb project name.
        :param group: wandb group name.
        :param wandb_id: run id, passed to ``wandb.init`` together with ``resume="allow"``.
        :param prefix: prefix prepended to every logged metric key.
        :param log_model: upload checkpoints as artifacts.
        :param save_last_only_final: upload the "last" checkpoint only in ``finalize``
            instead of after every save.
        :param config_args: dictionary written into the run config.
        :param kwargs: extra ``wandb.init`` arguments; they override the defaults above.
        """
        super().__init__()
        self._experiment = None
        self._log_model = log_model
        self._prefix = prefix
        # checkpoint path -> modification time at which it was last uploaded
        self._logged_model_time = {}
        self._checkpoint_callback = None
        # Save last is determined by the checkpoint callback argument
        self._save_last = None
        # Whether to upload the last checkpoint continuously (more storage) or only once in `finalize`
        self._save_last_only_final = save_last_only_final
        # Save the configuration args (e.g. parsed arguments) and log it in wandb
        self._config_args = config_args
        # set wandb init arguments
        self._wandb_init = dict(
            name=name,
            project=project,
            group=group,
            id=wandb_id,
            resume="allow",
            save_code=True,
        )
        self._wandb_init.update(**kwargs)
        # extract parameters
        self._name = self._wandb_init.get("name")
        self._id = self._wandb_init.get("id")
        # cached public-API handle of this run (see `_get_public_run`), used for save_top_k
        self._public_run = None
        # start wandb run (to create an attach_id for distributed modes)
        wandb.require("service")
        _ = self.experiment

    def get_checkpoint(
        self, artifact_name: str, artifact_filepath: Optional[Path] = None
    ) -> Path:
        """Return a local checkpoint path for ``artifact_name``.

        The artifact is always marked as used by this run. If ``artifact_filepath`` is
        ``None``, the artifact is downloaded and the first entry of the download
        directory is returned. The artifact is expected to contain a single ``.ckpt``.

        :raises AssertionError: if the artifact cannot be fetched (e.g. on non-zero DDP
            ranks) or the resulting file is missing or not a ``.ckpt``.
        """
        artifact = self.experiment.use_artifact(artifact_name)
        if artifact_filepath is None:
            assert artifact is not None, (
                "You are probably using DDP, "
                "in which case you should provide an artifact filepath."
            )
            # TODO: specify download directory
            artifact_dir = artifact.download()
            artifact_filepath = next(Path(artifact_dir).iterdir())
        assert artifact_filepath.exists()
        assert artifact_filepath.suffix == ".ckpt"
        return artifact_filepath

    def __getstate__(self) -> Dict[str, Any]:
        """Pickle support: drop the live run but keep what is needed to reattach."""
        state = self.__dict__.copy()
        # args needed to reload correct experiment
        if self._experiment is not None:
            state["_id"] = getattr(self._experiment, "id", None)
            state["_attach_id"] = getattr(self._experiment, "_attach_id", None)
            state["_name"] = self._experiment.name
        # cannot be pickled
        state["_experiment"] = None
        return state

    @property
    @rank_zero_experiment
    def experiment(self) -> Run:
        """The wandb run, created lazily.

        An already active ``wandb.run`` is reused first. Otherwise a pickled attach id is
        used, and failing that a new run is created with ``wandb.init``.
        """
        if self._experiment is None:
            attach_id = getattr(self, "_attach_id", None)
            if wandb.run is not None:
                # wandb process already created in this instance
                rank_zero_warn(
                    "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
                )
                self._experiment = wandb.run
            elif attach_id is not None and hasattr(wandb, "_attach"):
                # attach to wandb process referenced
                self._experiment = wandb._attach(attach_id)
            else:
                # create new wandb process
                self._experiment = wandb.init(**self._wandb_init)
                if self._config_args is not None:
                    self._experiment.config.update(
                        self._config_args, allow_val_change=True
                    )
            # define default x-axis
            if isinstance(self._experiment, (Run, RunDisabled)) and getattr(
                self._experiment, "define_metric", None
            ):
                self._experiment.define_metric(self.STEP_METRIC)
                self._experiment.define_metric(
                    "*", step_metric=self.STEP_METRIC, step_sync=True
                )
        assert isinstance(self._experiment, (Run, RunDisabled))
        return self._experiment

    def watch(
        self,
        model: nn.Module,
        log: str = "all",
        log_freq: int = 100,
        log_graph: bool = True,
    ):
        """Forward to ``Run.watch`` to log gradients/parameters of ``model``."""
        self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph)

    def add_step_metric(self, input_dict: dict, step: int) -> None:
        """Insert ``step`` into ``input_dict`` under ``STEP_METRIC`` (in place)."""
        input_dict.update({self.STEP_METRIC: step})

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
        """Flatten and sanitize ``params`` and merge them into the run config."""
        params = _convert_params(params)
        params = _flatten_dict(params)
        params = _sanitize_callable_params(params)
        self.experiment.config.update(params, allow_val_change=True)

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        """Log ``metrics`` (keys prefixed), adding ``STEP_METRIC`` when ``step`` is given."""
        assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
        metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
        if step is not None:
            self.add_step_metric(metrics, step)
            self.experiment.log({**metrics}, step=step)
        else:
            self.experiment.log(metrics)

    @rank_zero_only
    def log_images(
        self,
        key: str,
        images: List[Any],
        step: Optional[int] = None,
        **kwargs: List[Any],
    ) -> None:
        """Log images (tensors, numpy arrays, PIL Images or file paths).
        Optional kwargs are lists passed to each image (ex: caption, masks, boxes).
        How to use: https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.loggers.wandb.html#weights-and-biases-logger
        Taken from: https://github.com/PyTorchLightning/pytorch-lightning/blob/11e289ad9f95f5fe23af147fa4edcc9794f9b9a7/pytorch_lightning/loggers/wandb.py#L420

        :raises TypeError: if ``images`` is not a list.
        :raises ValueError: if a kwargs list does not have one entry per image.
        """
        if not isinstance(images, list):
            raise TypeError(f'Expected a list as "images", found {type(images)}')
        n = len(images)
        for k, v in kwargs.items():
            if len(v) != n:
                raise ValueError(f"Expected {n} items but only found {len(v)} for {k}")
        kwarg_list = [{k: kwargs[k][i] for k in kwargs.keys()} for i in range(n)]
        metrics = {
            key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]
        }
        self.log_metrics(metrics, step)

    @rank_zero_only
    def log_videos(
        self,
        key: str,
        videos: List[Union[np.ndarray, str]],
        step: Optional[int] = None,
        captions: Optional[List[str]] = None,
        fps: int = 4,
        format_: Optional[str] = None,
    ):
        """
        :param videos: List[(T,C,H,W)] or List[(N,T,C,H,W)] arrays, or file paths
        :param captions: List[str] (one per video) or None
        :param fps: frames per second passed to ``wandb.Video``
        :param format_: video format passed to ``wandb.Video``
        More info: https://docs.wandb.ai/ref/python/data-types/video and
        https://docs.wandb.ai/guides/track/log/media#other-media
        """
        assert isinstance(videos, list)
        if captions is not None:
            assert isinstance(captions, list)
            assert len(captions) == len(videos)
        wandb_videos = list()
        for idx, video in enumerate(videos):
            caption = captions[idx] if captions is not None else None
            wandb_videos.append(
                wandb.Video(
                    data_or_path=video, caption=caption, fps=fps, format=format_
                )
            )
        self.log_metrics(metrics={key: wandb_videos}, step=step)

    @property
    def name(self) -> Optional[str]:
        """Project name of the live run, or the configured run name if there is no run yet."""
        # This function seems to be only relevant if LoggerCollection is used.
        # don't create an experiment if we don't have one
        return self._experiment.project_name() if self._experiment else self._name

    @property
    def version(self) -> Optional[str]:
        """Id of the live run, or the configured id if there is no run yet."""
        # This function seems to be only relevant if LoggerCollection is used.
        # don't create an experiment if we don't have one
        return self._experiment.id if self._experiment else self._id

    @rank_zero_only
    def after_save_checkpoint(
        self, checkpoint_callback: "ReferenceType[ModelCheckpoint]"
    ) -> None:
        """Upload new checkpoints after each save.

        The "last" checkpoint is included only if ``save_last_only_final`` is off.
        """
        # log checkpoints as artifacts
        if self._checkpoint_callback is None:
            self._checkpoint_callback = checkpoint_callback
            self._save_last = checkpoint_callback.save_last
        if self._log_model:
            self._scan_and_log_checkpoints(
                checkpoint_callback, self._save_last and not self._save_last_only_final
            )

    @rank_zero_only
    def finalize(self, status: str) -> None:
        """Upload any remaining checkpoints, including "last" if the callback saves it."""
        # log checkpoints as artifacts
        if self._checkpoint_callback and self._log_model:
            self._scan_and_log_checkpoints(self._checkpoint_callback, self._save_last)

    def _get_public_run(self):
        """Return (and cache) the public-API handle of this run."""
        if self._public_run is None:
            experiment = self.experiment
            runpath = (
                experiment._entity
                + "/"
                + experiment._project
                + "/"
                + experiment._run_id
            )
            api = wandb.Api()
            self._public_run = api.run(path=runpath)
        return self._public_run

    def _num_logged_artifact(self):
        """Number of artifacts the public API currently reports as logged by this run."""
        public_run = self._get_public_run()
        return len(public_run.logged_artifacts())

    def _scan_and_log_checkpoints(
        self, checkpoint_callback: "ReferenceType[ModelCheckpoint]", save_last: bool
    ) -> None:
        """Upload checkpoints not uploaded yet (or modified since), then prune old ones.

        :param checkpoint_callback: the ``ModelCheckpoint`` whose files are scanned.
        :param save_last: also consider the callback's ``last_model_path``.
        """
        assert self._log_model
        if self._checkpoint_callback is None:
            self._checkpoint_callback = checkpoint_callback
            self._save_last = checkpoint_callback.save_last
        checkpoints = {
            checkpoint_callback.best_model_path: checkpoint_callback.best_model_score,
            **checkpoint_callback.best_k_models,
        }
        top_k = checkpoint_callback.save_top_k
        # The size bound only holds for top_k > 0. With top_k == -1 (keep everything) and
        # top_k == 0, the dict still holds the `best_model_path` entry. The old bound
        # `max(top_k, 0)` therefore always failed in those modes.
        if top_k > 0:
            assert len(checkpoints) <= top_k, (len(checkpoints), top_k)
        if save_last:
            last_model_path = Path(checkpoint_callback.last_model_path)
            if last_model_path.exists():
                checkpoints.update(
                    {
                        checkpoint_callback.last_model_path: checkpoint_callback.current_score
                    }
                )
            else:
                print(
                    f"last model checkpoint not found at {checkpoint_callback.last_model_path}"
                )
        checkpoints = sorted(
            (
                (Path(path).stat().st_mtime, path, score)
                for path, score in checkpoints.items()
                if Path(path).is_file()
            ),
            key=lambda x: x[0],
        )
        # Retain only checkpoints that we have not logged before with one exception:
        # If the name is the same (e.g. last checkpoint which should be overwritten),
        # make sure that they are newer than the previously saved checkpoint by checking their modification time
        checkpoints = [
            ckpt
            for ckpt in checkpoints
            if ckpt[1] not in self._logged_model_time.keys()
            or self._logged_model_time[ckpt[1]] < ckpt[0]
        ]
        # remove checkpoints with undefined (None) score
        checkpoints = [x for x in checkpoints if x[2] is not None]
        num_ckpt_logged_before = self._num_logged_artifact()
        num_new_cktps = len(checkpoints)
        if num_new_cktps == 0:
            return
        # log iteratively all new checkpoints
        for time_, path, score in checkpoints:
            score = score.item() if isinstance(score, torch.Tensor) else score
            is_best = path == checkpoint_callback.best_model_path
            is_last = path == checkpoint_callback.last_model_path
            metadata = {
                "score": score,
                "original_filename": Path(path).name,
                "ModelCheckpoint": {
                    k: getattr(checkpoint_callback, k)
                    for k in [
                        "monitor",
                        "mode",
                        "save_last",
                        "save_top_k",
                        "save_weights_only",
                    ]
                    # ensure it does not break if `ModelCheckpoint` args change
                    if hasattr(checkpoint_callback, k)
                },
            }
            aliases = []
            if is_best:
                aliases.append("best")
            if is_last:
                aliases.append("last")
            artifact_name = f"checkpoint-{self.experiment.id}-" + (
                "last" if is_last else "topK"
            )
            artifact = wandb.Artifact(
                name=artifact_name, type="model", metadata=metadata
            )
            assert Path(path).exists()
            artifact.add_file(path, name=f"{self.experiment.id}.ckpt")
            self.experiment.log_artifact(artifact, aliases=aliases)
            # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
            self._logged_model_time[path] = time_
        # Artifact uploads are asynchronous: wait (up to `timeout` s) until the public API
        # reports them before pruning.
        timeout = 20
        time_spent = 0
        while self._num_logged_artifact() < num_ckpt_logged_before + num_new_cktps:
            time.sleep(1)
            time_spent += 1
            if time_spent >= timeout:
                rank_zero_warn(
                    "Timeout: Num logged artifacts never reached expected value."
                )
                print(f"self._num_logged_artifact() = {self._num_logged_artifact()}")
                print(f"num_ckpt_logged_before = {num_ckpt_logged_before}")
                print(f"num_new_cktps = {num_new_cktps}")
                break
        try:
            self._rm_but_top_k(
                checkpoint_callback.save_top_k,
                mode=getattr(checkpoint_callback, "mode", "max"),
            )
        except KeyError:
            # Best-effort cleanup: incomplete artifact metadata must not stop training.
            pass

    def _rm_but_top_k(self, top_k: int, mode: str = "max"):
        """Delete this run's checkpoint artifacts that fall outside the top ``top_k``.

        Artifacts with a non-finite or ``None`` score are deleted. Duplicate "last"
        files that lost the ``last`` alias are deleted. Artifacts aliased ``last`` or
        ``best`` are always kept.

        :param top_k: ``-1`` keeps everything; ``k >= 0`` keeps the ``k`` best scores.
        :param mode: ``"max"`` when higher scores are better, ``"min"`` when lower are
            better (as in ``ModelCheckpoint.mode``). Defaults to ``"max"``, the only
            behavior before this parameter existed.
        """
        # top_k == -1: save all models
        # top_k == 0: no models saved at all. The checkpoint callback does not return checkpoints.
        # top_k > 0: keep only top k models (last and best will not be deleted)
        def is_last(artifact):
            return "last" in artifact.aliases

        def is_best(artifact):
            return "best" in artifact.aliases

        def try_delete(artifact):
            try:
                artifact.delete(delete_aliases=True)
            except wandb.errors.CommError:
                print(
                    f"Failed to delete artifact {artifact.name} due to wandb.errors.CommError"
                )

        public_run = self._get_public_run()
        score2art = list()
        for artifact in public_run.logged_artifacts():
            metadata = artifact.metadata
            if "score" not in metadata:
                # Not a checkpoint artifact written by `_scan_and_log_checkpoints`.
                # Leave it alone instead of aborting the whole cleanup with a KeyError.
                continue
            score = metadata["score"]
            original_filename = metadata["original_filename"]
            # Non-finite floats are stored as strings and cannot be sorted against numbers.
            if score in ("Infinity", "-Infinity"):
                print(
                    f"removing INF artifact (name, score, original_filename): ({artifact.name}, {score}, {original_filename})"
                )
                try_delete(artifact)
                continue
            if score is None:
                print(
                    f"removing None artifact (name, score, original_filename): ({artifact.name}, {score}, {original_filename})"
                )
                try_delete(artifact)
                continue
            score2art.append((score, artifact))
        # Best score first: descending for mode "max", ascending for mode "min".
        score2art.sort(key=lambda x: x[0], reverse=mode != "min")
        count = 0
        for score, artifact in score2art:
            original_filename = artifact.metadata["original_filename"]
            if "last" in original_filename and not is_last(artifact):
                try_delete(artifact)
                continue
            if is_last(artifact):
                continue
            count += 1
            if is_best(artifact):
                continue
            # if top_k == -1, we do not delete anything
            if 0 <= top_k < count:
                try_delete(artifact)
================================================
FILE: RVT/models/detection/__init_.py
================================================
================================================
FILE: RVT/models/detection/recurrent_backbone/__init__.py
================================================
from omegaconf import DictConfig
from .maxvit_rnn import RNNDetector as MaxViTRNNDetector
def build_recurrent_backbone(backbone_cfg: DictConfig):
    """Instantiate the recurrent backbone named by ``backbone_cfg.name``.

    :param backbone_cfg: backbone config; ``name`` selects the implementation and the
        whole node is passed to its constructor.
    :returns: the backbone module (currently only ``"MaxViTRNN"`` is supported).
    :raises NotImplementedError: for any other name. The message now names the
        offending value.
    """
    name = backbone_cfg.name
    if name == "MaxViTRNN":
        return MaxViTRNNDetector(backbone_cfg)
    raise NotImplementedError(
        f"Unknown recurrent backbone {name!r}; supported: 'MaxViTRNN'"
    )
================================================
FILE: RVT/models/detection/recurrent_backbone/base.py
================================================
from typing import Tuple
import torch.nn as nn
class BaseDetector(nn.Module):
    """Interface for detection backbones that expose per-stage channel dims and strides.

    Subclasses must override both methods. ``stages`` is a tuple of stage numbers;
    the MaxViT-RNN subclass treats them as 1-based (TODO confirm for other subclasses).
    """

    def get_stage_dims(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
        """Return the output channel count of each requested stage."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_stage_dims()"
        )

    def get_strides(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
        """Return the output stride of each requested stage."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_strides()"
        )
================================================
FILE: RVT/models/detection/recurrent_backbone/maxvit_rnn.py
================================================
from typing import Dict, Optional, Tuple
import torch as th
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from einops import rearrange
try:
from torch import compile as th_compile
except ImportError:
th_compile = None
from data.utils.types import FeatureMap, BackboneFeatures, LstmState, LstmStates
# from models.layers.rnn import DWSConvLSTM2d
from models.layers.s5.s5_model import S5Block
from models.layers.maxvit.maxvit import (
PartitionAttentionCl,
nhwC_2_nChw,
get_downsample_layer_Cf2Cl,
PartitionType,
)
from .base import BaseDetector
class RNNDetector(BaseDetector):
    """Four-stage backbone built from ``RNNDetectorStage`` modules.

    Stage 1 downsamples by ``stem.patch_size``, later stages by 2. ``self.strides``
    holds the cumulative stride after each stage and ``self.stage_dims`` the channel
    count (``embed_dim * dim_multiplier[i]``).
    """

    def __init__(self, mdl_config: DictConfig):
        """
        :param mdl_config: reads ``input_channels``, ``embed_dim`` (int),
            ``dim_multiplier``, ``num_blocks``, ``T_max_chrono_init`` (4 entries each),
            ``enable_masking``, optional ``compile`` (``enable``, ``args``),
            ``stem.patch_size`` and ``stage``.
        """
        super().__init__()

        # ---- configuration ----
        in_channels = mdl_config.input_channels
        embed_dim = mdl_config.embed_dim
        dim_multiplier_per_stage = tuple(mdl_config.dim_multiplier)
        num_blocks_per_stage = tuple(mdl_config.num_blocks)
        T_max_chrono_init_per_stage = tuple(mdl_config.T_max_chrono_init)
        enable_masking = mdl_config.enable_masking

        num_stages = len(num_blocks_per_stage)
        assert num_stages == 4
        assert isinstance(embed_dim, int)
        for per_stage_values in (
            dim_multiplier_per_stage,
            num_blocks_per_stage,
            T_max_chrono_init_per_stage,
        ):
            assert num_stages == len(per_stage_values)

        # ---- optional torch.compile of forward ----
        compile_cfg = mdl_config.get("compile", None)
        if compile_cfg is not None and compile_cfg.enable:
            if th_compile is None:
                print(
                    "Could not compile backbone because torch.compile is not available"
                )
            else:
                compile_args = OmegaConf.to_container(
                    compile_cfg.args, resolve=True, throw_on_missing=True
                )
                self.forward = th_compile(self.forward, **compile_args)

        # ---- stages ----
        patch_size = mdl_config.stem.patch_size
        self.stage_dims = [embed_dim * mult for mult in dim_multiplier_per_stage]
        self.stages = nn.ModuleList()
        self.strides = []
        dim_in = in_channels
        cumulative_stride = 1
        for stage_idx in range(num_stages):
            downsample = patch_size if stage_idx == 0 else 2
            stage_dim = self.stage_dims[stage_idx]
            self.stages.append(
                RNNDetectorStage(
                    dim_in=dim_in,
                    stage_dim=stage_dim,
                    spatial_downsample_factor=downsample,
                    num_blocks=num_blocks_per_stage[stage_idx],
                    # Token masking is only supported in the first stage.
                    enable_token_masking=enable_masking and stage_idx == 0,
                    T_max_chrono_init=T_max_chrono_init_per_stage[stage_idx],
                    stage_cfg=mdl_config.stage,
                )
            )
            cumulative_stride *= downsample
            self.strides.append(cumulative_stride)
            dim_in = stage_dim
        self.num_stages = num_stages

    def _stage_indices(self, stages: Tuple[int, ...]):
        """Convert 1-based stage numbers to validated 0-based indices."""
        indices = [stage - 1 for stage in stages]
        assert min(indices) >= 0, indices
        assert max(indices) < len(self.stages), indices
        return indices

    def get_stage_dims(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
        """Channel count of each requested (1-based) stage."""
        return tuple(self.stage_dims[i] for i in self._stage_indices(stages))

    def get_strides(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
        """Cumulative stride of each requested (1-based) stage."""
        return tuple(self.strides[i] for i in self._stage_indices(stages))

    def forward(
        self,
        x: th.Tensor,
        prev_states: Optional[LstmStates] = None,
        token_mask: Optional[th.Tensor] = None,
        train_step: bool = True,
    ) -> Tuple[BackboneFeatures, LstmStates]:
        """Run all stages in sequence.

        :param x: input passed to the first stage.
        :param prev_states: one state per stage, or ``None`` to start fresh.
        :param token_mask: applied in the first stage only.
        :param train_step: forwarded to every stage.
        :returns: ``(features keyed by 1-based stage number, new per-stage states)``.
        """
        if prev_states is None:
            prev_states = [None] * self.num_stages
        assert len(prev_states) == self.num_stages
        states: LstmStates = []
        output: Dict[int, FeatureMap] = {}
        for stage_number, (stage, prev_state) in enumerate(
            zip(self.stages, prev_states), start=1
        ):
            stage_mask = token_mask if stage_number == 1 else None
            x, state = stage(x, prev_state, stage_mask, train_step)
            states.append(state)
            output[stage_number] = x
        return output, states
class MaxVitAttentionPairCl(nn.Module):
    """Window attention followed by grid attention (MaxViT block pair).

    In ``RNNDetectorStage`` it is applied to channel-last maps (B, H, W, C).
    """

    def __init__(self, dim: int, skip_first_norm: bool, attention_cfg: DictConfig):
        """
        :param dim: channel count.
        :param skip_first_norm: skip the input norm of the window attention only.
        :param attention_cfg: attention config shared by both layers.
        """
        super().__init__()
        shared_kwargs = dict(dim=dim, attention_cfg=attention_cfg)
        self.att_window = PartitionAttentionCl(
            partition_type=PartitionType.WINDOW,
            skip_first_norm=skip_first_norm,
            **shared_kwargs,
        )
        self.att_grid = PartitionAttentionCl(
            partition_type=PartitionType.GRID,
            skip_first_norm=False,
            **shared_kwargs,
        )

    def forward(self, x):
        return self.att_grid(self.att_window(x))
class RNNDetectorStage(nn.Module):
    """One backbone stage: spatial downsampling, MaxViT attention blocks, then an S5 layer.

    Input and output are channel-first sequences shaped ``(L, B, C, H, W)`` (``L`` is
    the sequence length). The S5 layer runs along ``L`` independently at every
    spatial position. The recurrent state is shaped ``(B, C_state, H', W')`` at the
    stage's output resolution.
    """

    def __init__(
        self,
        dim_in: int,
        stage_dim: int,
        spatial_downsample_factor: int,
        num_blocks: int,
        enable_token_masking: bool,
        T_max_chrono_init: Optional[int],
        stage_cfg: DictConfig,
    ):
        """
        :param dim_in: input channel count.
        :param stage_dim: channel count used throughout the stage.
        :param spatial_downsample_factor: factor of the input downsampling layer.
        :param num_blocks: number of attention pairs; must be a positive int.
        :param enable_token_masking: create the learnable mask token used by ``forward``.
        :param T_max_chrono_init: not used by this S5-based stage; kept for interface
            compatibility.
        :param stage_cfg: config with ``downsample`` and ``attention`` sub-configs.
        """
        super().__init__()
        assert isinstance(num_blocks, int) and num_blocks > 0
        downsample_cfg = stage_cfg.downsample
        attention_cfg = stage_cfg.attention
        # The former DWSConvLSTM2d recurrence (configured by `stage_cfg.lstm`) is
        # superseded by the S5 block below. The LSTM config is no longer read, and its
        # dead code, kept in a string literal, has been removed.
        self.downsample_cf2cl = get_downsample_layer_Cf2Cl(
            dim_in=dim_in,
            dim_out=stage_dim,
            downsample_factor=spatial_downsample_factor,
            downsample_cfg=downsample_cfg,
        )
        blocks = [
            MaxVitAttentionPairCl(
                dim=stage_dim,
                # The downsample layer may already normalize its output.
                skip_first_norm=i == 0 and self.downsample_cf2cl.output_is_normed(),
                attention_cfg=attention_cfg,
            )
            for i in range(num_blocks)
        ]
        self.att_blocks = nn.ModuleList(blocks)
        self.s5_block = S5Block(
            dim=stage_dim, state_dim=stage_dim, bidir=False, bandlimit=0.5
        )
        ###### Mask Token ################
        self.mask_token = (
            nn.Parameter(th.zeros(1, 1, 1, stage_dim), requires_grad=True)
            if enable_token_masking
            else None
        )
        if self.mask_token is not None:
            th.nn.init.normal_(self.mask_token, std=0.02)
        ##################################

    def forward(
        self,
        x: th.Tensor,
        states: Optional[LstmState] = None,
        token_mask: Optional[th.Tensor] = None,
        train_step: bool = True,
    ) -> Tuple[FeatureMap, LstmState]:
        """Process a sequence of feature maps.

        :param x: input shaped ``(L, B, C, H, W)``.
        :param states: previous state shaped ``(B, C_state, H', W')``, or ``None`` to
            start from ``S5.initial_state``.
        :param token_mask: boolean index into the downsampled channel-last map;
            selected positions are overwritten with the mask token (presumably shaped
            ``(L*B, H', W')`` -- TODO confirm). Requires ``enable_token_masking``.
        :param train_step: unused.
        :returns: ``(features shaped (L, B, C', H', W'), new state shaped (B, C_state, H', W'))``.
        """
        sequence_length = x.shape[0]
        batch_size = x.shape[1]
        x = rearrange(
            x, "L B C H W -> (L B) C H W"
        )  # where B' = (L B) is the new batch size
        x = self.downsample_cf2cl(x)  # B' C H W -> B' H W C
        if token_mask is not None:
            assert self.mask_token is not None, "No mask token present in this stage"
            x[token_mask] = self.mask_token
        for blk in self.att_blocks:
            x = blk(x)
        x = nhwC_2_nChw(x)  # B' H W C -> B' C H W
        new_h, new_w = x.shape[2], x.shape[3]
        # Each spatial position becomes an independent sequence of length L for S5.
        x = rearrange(x, "(L B) C H W -> (B H W) L C", L=sequence_length)
        if states is None:
            states = self.s5_block.s5.initial_state(
                batch_size=batch_size * new_h * new_w
            ).to(x.device)
        else:
            states = rearrange(states, "B C H W -> (B H W) C")
        x, states = self.s5_block(x, states)
        x = rearrange(
            x, "(B H W) L C -> L B C H W", B=batch_size, H=int(new_h), W=int(new_w)
        )
        states = rearrange(states, "(B H W) C -> B C H W", H=new_h, W=new_w)
        return x, states
================================================
FILE: RVT/models/detection/yolox/models/__init__.py
================================================
================================================
FILE: RVT/models/detection/yolox/models/losses.py
================================================
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
import torch
import torch.nn as nn
class IOUloss(nn.Module):
    """IoU-based box regression loss for boxes given as ``(cx, cy, w, h)``.

    ``loss_type="iou"`` gives ``1 - IoU**2``. ``loss_type="giou"`` gives
    ``1 - clamp(GIoU, -1, 1)``. ``reduction`` is ``"none"`` (per box), ``"mean"`` or
    ``"sum"``; any other value behaves like ``"none"``.
    """

    def __init__(self, reduction="none", loss_type="iou"):
        super().__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    @staticmethod
    def _corners(boxes):
        """Top-left and bottom-right corners of ``(N, 4)`` center-format boxes."""
        half_size = boxes[:, 2:] / 2
        return boxes[:, :2] - half_size, boxes[:, :2] + half_size

    def forward(self, pred, target):
        """
        :param pred: predicted boxes, any shape with ``N`` leading rows and 4 values per box.
        :param target: target boxes of the same layout.
        :returns: loss per box, or reduced to a scalar.
        :raises NotImplementedError: for an unknown ``loss_type``.
        """
        assert pred.shape[0] == target.shape[0]
        pred = pred.view(-1, 4)
        target = target.view(-1, 4)
        pred_tl, pred_br = self._corners(pred)
        tgt_tl, tgt_br = self._corners(target)

        inter_tl = torch.max(pred_tl, tgt_tl)
        inter_br = torch.min(pred_br, tgt_br)
        area_pred = torch.prod(pred[:, 2:], 1)
        area_tgt = torch.prod(target[:, 2:], 1)
        # 1 where the boxes overlap along both axes, 0 otherwise.
        overlaps = (inter_tl < inter_br).type(inter_tl.type()).prod(dim=1)
        area_inter = torch.prod(inter_br - inter_tl, 1) * overlaps
        area_union = area_pred + area_tgt - area_inter
        iou = (area_inter) / (area_union + 1e-16)

        if self.loss_type == "iou":
            loss = 1 - iou**2
        elif self.loss_type == "giou":
            enclose_tl = torch.min(pred_tl, tgt_tl)
            enclose_br = torch.max(pred_br, tgt_br)
            area_enclose = torch.prod(enclose_br - enclose_tl, 1)
            giou = iou - (area_enclose - area_union) / area_enclose.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)
        else:
            raise NotImplementedError

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
================================================
FILE: RVT/models/detection/yolox/models/network_blocks.py
================================================
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
import torch
import torch.nn as nn
class SiLU(nn.Module):
    """Export-friendly version of ``nn.SiLU()``: ``x * sigmoid(x)``."""

    @staticmethod
    def forward(x):
        gate = torch.sigmoid(x)
        return x * gate
def get_activation(name="silu", inplace=True):
    """Build an activation module by name.

    :param name: ``"silu"``, ``"relu"`` or ``"lrelu"`` (LeakyReLU with slope 0.1).
    :param inplace: forwarded to the module constructor.
    :raises AttributeError: for any other name.
    """
    builders = (
        ("silu", lambda: nn.SiLU(inplace=inplace)),
        ("relu", lambda: nn.ReLU(inplace=inplace)),
        ("lrelu", lambda: nn.LeakyReLU(0.1, inplace=inplace)),
    )
    for known_name, build in builders:
        if name == known_name:
            return build()
    raise AttributeError("Unsupported act type: {}".format(name))
class BaseConv(nn.Module):
    """A Conv2d -> Batchnorm -> silu/leaky relu block"""

    def __init__(
        self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"
    ):
        """
        :param ksize: square kernel size; padding ``(ksize - 1) // 2`` keeps the size
            for odd kernels at stride 1.
        :param act: activation name understood by ``get_activation``.
        """
        super().__init__()
        same_padding = (ksize - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=ksize,
            stride=stride,
            padding=same_padding,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = get_activation(act, inplace=True)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return self.act(normalized)

    def fuseforward(self, x):
        """Forward without batch norm (for use after BN has been fused into the conv)."""
        convolved = self.conv(x)
        return self.act(convolved)
class DWConv(nn.Module):
    """Depthwise Conv + Conv"""

    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
        super().__init__()
        # Depthwise: one filter per input channel.
        self.dconv = BaseConv(
            in_channels,
            in_channels,
            ksize=ksize,
            stride=stride,
            groups=in_channels,
            act=act,
        )
        # Pointwise 1x1 projection to the output channels.
        self.pconv = BaseConv(
            in_channels, out_channels, ksize=1, stride=1, groups=1, act=act
        )

    def forward(self, x):
        return self.pconv(self.dconv(x))
class Bottleneck(nn.Module):
    """Standard bottleneck: 1x1 conv then 3x3 conv, with a residual add when shapes allow.

    The residual is used only if ``shortcut`` is true and ``in_channels == out_channels``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        shortcut=True,
        expansion=0.5,
        depthwise=False,
        act="silu",
    ):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        conv_cls = DWConv if depthwise else BaseConv
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = conv_cls(hidden_channels, out_channels, 3, stride=1, act=act)
        self.use_add = shortcut and in_channels == out_channels

    def forward(self, x):
        out = self.conv2(self.conv1(x))
        return out + x if self.use_add else out
class CSPLayer(nn.Module):
    """C3 in yolov5, CSP Bottleneck with 3 convolutions"""

    def __init__(
        self,
        in_channels,
        out_channels,
        n=1,
        shortcut=True,
        expansion=0.5,
        depthwise=False,
        act="silu",
    ):
        """
        Args:
            in_channels (int): input channels.
            out_channels (int): output channels.
            n (int): number of Bottlenecks. Default value: 1.
            shortcut (bool): residual adds inside the bottlenecks.
            expansion (float): hidden channels = int(out_channels * expansion).
        """
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
        self.m = nn.Sequential(
            *(
                Bottleneck(
                    hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act
                )
                for _ in range(n)
            )
        )

    def forward(self, x):
        main_branch = self.conv1(x)
        side_branch = self.conv2(x)
        main_branch = self.m(main_branch)
        # Concatenate along channels, then fuse with a 1x1 conv.
        return self.conv3(torch.cat((main_branch, side_branch), dim=1))
================================================
FILE: RVT/models/detection/yolox/models/yolo_head.py
================================================
"""
Original Yolox Head code with slight modifications
"""
import math
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from torch import compile as th_compile
except ImportError:
th_compile = None
from models.detection.yolox.utils import bboxes_iou
from .losses import IOUloss
from .network_blocks import BaseConv, DWConv
class YOLOXHead(nn.Module):
def __init__(
    self,
    num_classes=80,
    strides=(8, 16, 32),
    in_channels=(256, 512, 1024),
    act="silu",
    depthwise=False,
    compile_cfg: Optional[Dict] = None,
):
    """Build per-level stems, cls/reg conv towers and 1x1 prediction convs.

    :param num_classes: classes predicted per location.
    :param strides: stride of each input level (stored as ``self.strides``).
    :param in_channels: channels of each input level. The last entry fixes the hidden
        width as ``int(256 * in_channels[-1] / 1024)``.
    :param act: activation name for the conv blocks.
    :param depthwise: use ``DWConv`` instead of ``BaseConv`` in the towers.
    :param compile_cfg: optional ``{"enable": bool, "args": dict}``. When enabled and
        ``torch.compile`` exists, ``forward`` is wrapped with it.
    """
    super().__init__()
    self.num_classes = num_classes
    self.decode_in_inference = True  # for deploy, set to False
    # The registration order of these containers fixes the order of `parameters()`.
    self.cls_convs = nn.ModuleList()
    self.reg_convs = nn.ModuleList()
    self.cls_preds = nn.ModuleList()
    self.reg_preds = nn.ModuleList()
    self.obj_preds = nn.ModuleList()
    self.stems = nn.ModuleList()
    Conv = DWConv if depthwise else BaseConv
    self.output_strides = None
    self.output_grids = None

    # Automatic width scaling according to the original YOLOX channel dims:
    # in[-1]/out = 4/1, so out = in[-1]/4 = 256 * width, i.e. width = in[-1]/1024.
    width = in_channels[-1] / 1024
    hidden_dim = int(256 * width)

    def conv_tower():
        # Two 3x3 conv blocks at hidden_dim.
        return nn.Sequential(
            Conv(in_channels=hidden_dim, out_channels=hidden_dim, ksize=3, stride=1, act=act),
            Conv(in_channels=hidden_dim, out_channels=hidden_dim, ksize=3, stride=1, act=act),
        )

    def prediction_conv(num_outputs):
        # 1x1 conv from hidden_dim to num_outputs channels.
        return nn.Conv2d(
            in_channels=hidden_dim,
            out_channels=num_outputs,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    # Modules are created in the same per-level order as before, so seeded weight
    # initialization is unchanged.
    for level_channels in in_channels:
        self.stems.append(
            BaseConv(
                in_channels=level_channels,
                out_channels=hidden_dim,
                ksize=1,
                stride=1,
                act=act,
            )
        )
        self.cls_convs.append(conv_tower())
        self.reg_convs.append(conv_tower())
        self.cls_preds.append(prediction_conv(self.num_classes))
        self.reg_preds.append(prediction_conv(4))
        self.obj_preds.append(prediction_conv(1))

    self.use_l1 = False
    self.l1_loss = nn.L1Loss(reduction="none")
    self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
    self.iou_loss = IOUloss(reduction="none")
    self.strides = strides
    self.grids = [torch.zeros(1)] * len(in_channels)
    # According to Focal Loss paper:
    self.initialize_biases(prior_prob=0.01)

    ###### Compile if requested ######
    if compile_cfg is not None and compile_cfg["enable"]:
        if th_compile is None:
            print(
                "Could not compile YOLOXHead because torch.compile is not available"
            )
        else:
            self.forward = th_compile(self.forward, **compile_cfg["args"])
    ##################################
def initialize_biases(self, prior_prob):
    """Initialise prediction biases so the initial sigmoid output is ``prior_prob``.

    Follows the Focal Loss paper: every bias of the class heads
    (``self.cls_preds``) and then of the objectness heads
    (``self.obj_preds``) is filled with ``-log((1 - p) / p)``. Each bias is
    replaced by a fresh trainable ``torch.nn.Parameter``.
    """
    for head_conv in [*self.cls_preds, *self.obj_preds]:
        flat_bias = head_conv.bias.view(1, -1)
        flat_bias.data.fill_(-math.log((1 - prior_prob) / prior_prob))
        head_conv.bias = torch.nn.Parameter(flat_bias.view(-1), requires_grad=True)
def forward(self, xin, labels=None):
    """Run the detection head on every pyramid level.

    Args:
        xin: sequence of per-level feature maps, iterated in lockstep with
            ``self.strides`` (presumably (B, C_i, H_i, W_i) — TODO confirm).
        labels: ground truth forwarded to ``get_losses``; only used when
            ``self.training`` is True.

    Returns:
        ``(outputs, losses)``. ``outputs`` has shape
        [batch, n_anchors_all, 5 + num_classes] with sigmoid-activated
        objectness/class scores, and is decoded via ``decode_outputs`` when
        ``self.decode_in_inference`` is set. ``losses`` is a dict of loss
        terms in training mode and ``None`` otherwise.
    """
    train_outputs = []
    inference_outputs = []
    origin_preds = []
    x_shifts = []
    y_shifts = []
    expanded_strides = []

    for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
        zip(self.cls_convs, self.reg_convs, self.strides, xin)
    ):
        x = self.stems[k](x)
        cls_x = x
        reg_x = x

        cls_feat = cls_conv(cls_x)
        cls_output = self.cls_preds[k](cls_feat)

        reg_feat = reg_conv(reg_x)
        reg_output = self.reg_preds[k](reg_feat)
        obj_output = self.obj_preds[k](reg_feat)

        if self.training:
            output = torch.cat([reg_output, obj_output, cls_output], 1)
            output, grid = self.get_output_and_grid(
                output, k, stride_this_level, xin[0].type()
            )
            x_shifts.append(grid[:, :, 0])
            y_shifts.append(grid[:, :, 1])
            expanded_strides.append(
                torch.zeros(1, grid.shape[1])
                .fill_(stride_this_level)
                .type_as(xin[0])
            )
            if self.use_l1:
                # Raw (undecoded) per-anchor box regression for the L1 loss.
                # Kept in its own variable: ``reg_output`` must retain its
                # (B, 4, H, W) layout for the inference output built below,
                # otherwise torch.cat mixes 3-D and 4-D tensors and fails.
                batch_size = reg_output.shape[0]
                hsize, wsize = reg_output.shape[-2:]
                l1_reg_output = reg_output.view(batch_size, 1, 4, hsize, wsize)
                l1_reg_output = l1_reg_output.permute(0, 1, 3, 4, 2).reshape(
                    batch_size, -1, 4
                )
                origin_preds.append(l1_reg_output.clone())
            train_outputs.append(output)

        inference_output = torch.cat(
            [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
        )
        inference_outputs.append(inference_output)

    # --------------------------------------------------------
    # Modification: return decoded output also during training
    # --------------------------------------------------------
    losses = None
    if self.training:
        losses = self.get_losses(
            x_shifts,
            y_shifts,
            expanded_strides,
            labels,
            torch.cat(train_outputs, 1),
            origin_preds,
            dtype=xin[0].dtype,
        )
        assert len(losses) == 6
        losses = {
            "loss": losses[0],
            "iou_loss": losses[1],
            "conf_loss": losses[2],  # object-ness
            "cls_loss": losses[3],  # predicted class
            "l1_loss": losses[4],
            "num_fg": losses[5],
        }

    # Remember per-level feature sizes; decode_outputs builds its grids from them.
    self.hw = [x.shape[-2:] for x in inference_outputs]
    # [batch, n_anchors_all, 5 + num_classes]
    outputs = torch.cat(
        [x.flatten(start_dim=2) for x in inference_outputs], dim=2
    ).permute(0, 2, 1)

    if self.decode_in_inference:
        return self.decode_outputs(outputs), losses
    else:
        return outputs, losses
def get_output_and_grid(self, output, k, stride, dtype):
    """Flatten one level's raw training output and decode its boxes.

    Args:
        output: raw head output for level ``k``; viewed as
            (batch, 5 + num_classes, H, W).
        k: pyramid level index into the ``self.grids`` cache.
        stride: downsampling stride of this level.
        dtype: any tensor type accepted by ``Tensor.type`` (``forward`` passes
            ``xin[0].type()``, e.g. ``"torch.cuda.FloatTensor"``).

    Returns:
        ``(output, grid)``: ``output`` reshaped to
        (batch, H * W, 5 + num_classes) with channels 0:2 mapped to
        ``(offset + cell) * stride`` and channels 2:4 to ``exp(v) * stride``.
        ``grid`` is the (1, H * W, 2) tensor of (x, y) cell indices.
    """
    grid = self.grids[k]

    batch_size = output.shape[0]
    n_ch = 5 + self.num_classes
    hsize, wsize = output.shape[-2:]
    if grid.shape[2:4] != output.shape[2:4]:
        yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
        grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2)
    # Cast on every call, not only when the grid is rebuilt: ``Tensor.type``
    # returns the tensor itself when it already has the requested type, so
    # this copies only if the cached grid's dtype/device has gone stale
    # (e.g. the module moved to another device or precision).
    grid = grid.type(dtype)
    self.grids[k] = grid

    output = output.view(batch_size, 1, n_ch, hsize, wsize)
    output = output.permute(0, 1, 3, 4, 2).reshape(batch_size, hsize * wsize, -1)
    grid = grid.view(1, -1, 2)
    output[..., :2] = (output[..., :2] + grid) * stride
    output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
    return output, grid
def decode_outputs(self, outputs):
    """Decode flattened head outputs into absolute box centers and sizes.

    Args:
        outputs: tensor of shape [batch, n_anchors_all, C] whose anchors are
            ordered level by level as described by ``self.hw`` and
            ``self.strides``. Channels 0:2 hold grid-relative center offsets
            and 2:4 log-scale sizes. Channels 4: are passed through unchanged.

    Returns:
        Tensor of the same shape with channels 0:2 replaced by
        ``(offset + cell) * stride`` and channels 2:4 by ``exp(v) * stride``.

    The per-anchor cell grid and stride tensors are cached in
    ``self.output_grids`` / ``self.output_strides``. They are rebuilt whenever
    the feature-map sizes (``self.hw``, set by ``forward``), the strides, or
    the dtype/device of ``outputs`` differ from those they were built for.
    A change of input resolution or device therefore never reuses a stale
    grid.
    """
    cache_key = (
        tuple(tuple(int(s) for s in hw) for hw in self.hw),
        tuple(self.strides),
        outputs.dtype,
        outputs.device,
    )
    if (
        self.output_grids is None
        or self.output_strides is None
        or getattr(self, "_output_grid_key", None) != cache_key
    ):
        dtype = outputs.dtype
        device = outputs.device
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.strides):
            yv, xv = torch.meshgrid(
                [
                    torch.arange(hsize, device=device, dtype=dtype),
                    torch.arange(wsize, device=device, dtype=dtype),
                ]
            )
            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            strides.append(
                torch.full((*shape, 1), stride, device=device, dtype=dtype)
            )
        self.output_grids = torch.cat(grids, dim=1)
        self.output_strides = torch.cat(strides, dim=1)
        self._output_grid_key = cache_key

    outputs = torch.cat(
        [
            (outputs[..., 0:2] + self.output_grids) * self.output_strides,
            torch.exp(outputs[..., 2:4]) * self.output_strides,
            outputs[..., 4:],
        ],
        dim=-1,
    )
    return outputs
def get_losses(
    self,
    x_shifts,
    y_shifts,
    expanded_strides,
    labels,
    outputs,
    origin_preds,
    dtype,
):
    """Compute the YOLOX training losses for one batch.

    Args:
        x_shifts, y_shifts: per-level lists of anchor grid x / y coordinates
            (``forward`` appends ``grid[:, :, 0]`` / ``grid[:, :, 1]``). They
            are concatenated below into [1, n_anchors_all].
        expanded_strides: per-level lists of per-anchor strides, concatenated
            the same way.
        labels: ground-truth tensor indexed as ``labels[b, :num_gt, 0]``
            (class id) and ``labels[b, :num_gt, 1:5]`` (box). A row counts as
            an object when its values sum to > 0, so unused rows are
            presumably zero padding (TODO confirm against the collate code).
        outputs: training outputs of shape [batch, n_anchors_all,
            5 + num_classes], with the box (decoded by ``get_output_and_grid``)
            in 0:4, the objectness logit in 4:5 and the class logits in 5:.
        origin_preds: per-level raw box regressions. Only used when
            ``self.use_l1`` is True.
        dtype: dtype the objectness targets are cast to.

    Returns:
        Tuple ``(loss, weighted_iou_loss, obj_loss, cls_loss, l1_loss,
        fg_ratio)``:
        - ``loss`` is ``5.0 * iou + obj + cls + l1``.
        - Every term is summed and divided by the number of foreground
          anchors (at least 1).
        - ``l1_loss`` is ``0.0`` when L1 is disabled.
        - ``fg_ratio`` is the number of foreground anchors divided by
          ``max(num_gts, 1)``.
    """
    bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
    obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
    cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]

    # calculate targets
    nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects

    total_num_anchors = outputs.shape[1]
    x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]
    y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all]
    expanded_strides = torch.cat(expanded_strides, 1)
    if self.use_l1:
        origin_preds = torch.cat(origin_preds, 1)

    cls_targets = []
    reg_targets = []
    l1_targets = []
    obj_targets = []
    fg_masks = []

    num_fg = 0.0
    num_gts = 0.0

    for batch_idx in range(outputs.shape[0]):
        num_gt = int(nlabel[batch_idx])
        num_gts += num_gt
        if num_gt == 0:
            # No objects in this image: empty cls/reg targets and every
            # anchor is background.
            cls_target = outputs.new_zeros((0, self.num_classes))
            reg_target = outputs.new_zeros((0, 4))
            l1_target = outputs.new_zeros((0, 4))
            obj_target = outputs.new_zeros((total_num_anchors, 1))
            fg_mask = outputs.new_zeros(total_num_anchors).bool()
        else:
            gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5]
            gt_classes = labels[batch_idx, :num_gt, 0]
            bboxes_preds_per_image = bbox_preds[batch_idx]

            # Match ground-truth objects to anchors. On a CUDA out-of-memory
            # error, free the cache and retry the assignment in "cpu" mode.
            # Any other RuntimeError is re-raised.
            try:
                (
                    gt_matched_classes,
                    fg_mask,
                    pred_ious_this_matching,
                    matched_gt_inds,
                    num_fg_img,
                ) = self.get_assignments(  # noqa
                    batch_idx,
                    num_gt,
                    gt_bboxes_per_image,
                    gt_classes,
                    bboxes_preds_per_image,
                    expanded_strides,
                    x_shifts,
                    y_shifts,
                    cls_preds,
                    obj_preds,
                )
            except RuntimeError as e:
                # TODO: the string might change, consider a better way
                if "CUDA out of memory. " not in str(e):
                    raise

                torch.cuda.empty_cache()
                (
                    gt_matched_classes,
                    fg_mask,
                    pred_ious_this_matching,
                    matched_gt_inds,
                    num_fg_img,
                ) = self.get_assignments(  # noqa
                    batch_idx,
                    num_gt,
                    gt_bboxes_per_image,
                    gt_classes,
                    bboxes_preds_per_image,
                    expanded_strides,
                    x_shifts,
                    y_shifts,
                    cls_preds,
                    obj_preds,
                    "cpu",
                )

            torch.cuda.empty_cache()
            num_fg += num_fg_img

            # Class target: one-hot of the matched class scaled by
            # pred_ious_this_matching (presumably the IoU of each matched
            # prediction with its ground truth).
            cls_target = F.one_hot(
                gt_matched_classes.to(torch.int64), self.num_classes
            ) * pred_ious_this_matching.unsqueeze(-1)
            # Objectness target: 1 for foreground anchors, 0 elsewhere.
            obj_target = fg_mask.unsqueeze(-1)
            reg_target = gt_bboxes_per_image[matched_gt_inds]
            if self.use_l1:
                l1_target = self.get_l1_target(
                    outputs.new_zeros((num_fg_img, 4)),
                    gt_bboxes_per_image[matched_gt_inds],
                    expanded_strides[0][fg_mask],
                    x_shifts=x_shifts[0][fg_mask],
                    y_shifts=y_shifts[0][fg_mask],
                )

        cls_targets.append(cls_target)
        reg_targets.append(reg_target)
        obj_targets.append(obj_target.to(dtype))
        fg_masks.append(fg_mask)
        if self.use_l1:
            l1_targets.append(l1_target)

    cls_targets = torch.cat(cls_targets, 0)
    reg_targets = torch.cat(reg_targets, 0)
    obj_targets = torch.cat(obj_targets, 0)
    fg_masks = torch.cat(fg_masks, 0)
    if self.use_l1:
        l1_targets = torch.cat(l1_targets, 0)

    # Guard against division by zero when no anchor is foreground.
    num_fg = max(num_fg, 1)
    # IoU and class losses use foreground anchors only. The objectness loss
    # covers all anchors but is still normalized by the foreground count.
    loss_iou = (
        self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)
    ).sum() / num_fg
    loss_obj = (
        self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)
    ).sum() / num_fg
    loss_cls = (
        self.bcewithlog_loss(
            cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets
        )
    ).sum() / num_fg
    if self.use_l1:
        loss_l1 = (
            self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)
        ).sum() / num_fg
    else:
        loss_l1 = 0.0

    # Fixed weight of the IoU term; the returned iou loss is already weighted.
    reg_weight = 5.0
    loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1

    return (
        loss,
        reg_weight * loss_iou,
        loss_obj,
        loss_cls,
        loss_l1,
        num_fg / max(num_gts, 1),
    )
def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):
    """Fill ``l1_target`` in place with grid-space regression targets.

    - Columns 0 and 1 receive ``gt[:, c] / stride`` minus the anchor's x / y
      grid shift (presumably the box center in grid units).
    - Columns 2 and 3 receive ``log(gt[:, c] / stride + eps)``. ``eps``
      keeps the log finite for zero-sized boxes.

    The same ``l1_target`` tensor is returned.
    """
    for col, shift in ((0, x_shifts), (1, y_shifts)):
        l1_target[:, col] = gt[:, col] / stride - shift
    for col in (2, 3):
        l1_target[:, col] = torch.log(gt[:, col] / stride + eps)
    return l1_target
@torch.no_grad()
def get_assignments(
self,
batch_idx,
num_gt,
gt_bboxes_per_image,
gt_classes,
bboxes_preds_per_image,
expanded_strides,
x_shifts,
y_shifts,
cls_preds,
obj_preds,
mode="gpu",
):
if mode == "cpu":
print("-----------Using CPU for the Current Batch-------------")
gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
gt_classes = gt_classes.cpu().float()
expanded_strides = expanded_strides.cpu().float()
x_shifts = x_shifts.cpu()
y_shifts = y_shifts.cpu()
fg_mask, geometry_relation = self.get_geometry_constraint(
gt_bboxes_per_image,
gitextract_9qf3l4ub/
├── .gitignore
├── README.md
├── RVT/
│ ├── .gitignore
│ ├── LICENSE
│ ├── README.md
│ ├── callbacks/
│ │ ├── custom.py
│ │ ├── detection.py
│ │ ├── gradflow.py
│ │ ├── utils/
│ │ │ └── visualization.py
│ │ └── viz_base.py
│ ├── config/
│ │ ├── dataset/
│ │ │ ├── base.yaml
│ │ │ ├── gen1.yaml
│ │ │ └── gen4.yaml
│ │ ├── experiment/
│ │ │ ├── gen1/
│ │ │ │ ├── base.yaml
│ │ │ │ ├── default.yaml
│ │ │ │ └── small.yaml
│ │ │ └── gen4/
│ │ │ ├── base.yaml
│ │ │ ├── default.yaml
│ │ │ └── small.yaml
│ │ ├── general.yaml
│ │ ├── model/
│ │ │ ├── base.yaml
│ │ │ ├── maxvit_yolox/
│ │ │ │ └── default.yaml
│ │ │ └── rnndet.yaml
│ │ ├── modifier.py
│ │ ├── train.yaml
│ │ └── val.yaml
│ ├── data/
│ │ ├── genx_utils/
│ │ │ ├── collate.py
│ │ │ ├── collate_from_pytorch.py
│ │ │ ├── dataset_rnd.py
│ │ │ ├── dataset_streaming.py
│ │ │ ├── labels.py
│ │ │ ├── sequence_base.py
│ │ │ ├── sequence_for_streaming.py
│ │ │ └── sequence_rnd.py
│ │ └── utils/
│ │ ├── augmentor.py
│ │ ├── representations.py
│ │ ├── spatial.py
│ │ ├── stream_concat_datapipe.py
│ │ ├── stream_sharded_datapipe.py
│ │ └── types.py
│ ├── loggers/
│ │ ├── utils.py
│ │ └── wandb_logger.py
│ ├── models/
│ │ ├── detection/
│ │ │ ├── __init_.py
│ │ │ ├── recurrent_backbone/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ └── maxvit_rnn.py
│ │ │ ├── yolox/
│ │ │ │ ├── models/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── losses.py
│ │ │ │ │ ├── network_blocks.py
│ │ │ │ │ └── yolo_head.py
│ │ │ │ └── utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── boxes.py
│ │ │ │ └── compat.py
│ │ │ └── yolox_extension/
│ │ │ └── models/
│ │ │ ├── __init__.py
│ │ │ ├── build.py
│ │ │ ├── detector.py
│ │ │ └── yolo_pafpn.py
│ │ └── layers/
│ │ ├── maxvit/
│ │ │ ├── __init__.py
│ │ │ ├── layers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── activations.py
│ │ │ │ ├── activations_jit.py
│ │ │ │ ├── activations_me.py
│ │ │ │ ├── adaptive_avgmax_pool.py
│ │ │ │ ├── attention_pool2d.py
│ │ │ │ ├── blur_pool.py
│ │ │ │ ├── bottleneck_attn.py
│ │ │ │ ├── cbam.py
│ │ │ │ ├── classifier.py
│ │ │ │ ├── cond_conv2d.py
│ │ │ │ ├── config.py
│ │ │ │ ├── conv2d_same.py
│ │ │ │ ├── conv_bn_act.py
│ │ │ │ ├── create_act.py
│ │ │ │ ├── create_attn.py
│ │ │ │ ├── create_conv2d.py
│ │ │ │ ├── create_norm.py
│ │ │ │ ├── create_norm_act.py
│ │ │ │ ├── drop.py
│ │ │ │ ├── eca.py
│ │ │ │ ├── evo_norm.py
│ │ │ │ ├── fast_norm.py
│ │ │ │ ├── filter_response_norm.py
│ │ │ │ ├── gather_excite.py
│ │ │ │ ├── global_context.py
│ │ │ │ ├── halo_attn.py
│ │ │ │ ├── helpers.py
│ │ │ │ ├── inplace_abn.py
│ │ │ │ ├── lambda_layer.py
│ │ │ │ ├── linear.py
│ │ │ │ ├── median_pool.py
│ │ │ │ ├── mixed_conv2d.py
│ │ │ │ ├── ml_decoder.py
│ │ │ │ ├── mlp.py
│ │ │ │ ├── non_local_attn.py
│ │ │ │ ├── norm.py
│ │ │ │ ├── norm_act.py
│ │ │ │ ├── padding.py
│ │ │ │ ├── patch_embed.py
│ │ │ │ ├── pool2d_same.py
│ │ │ │ ├── pos_embed.py
│ │ │ │ ├── selective_kernel.py
│ │ │ │ ├── separable_conv.py
│ │ │ │ ├── space_to_depth.py
│ │ │ │ ├── split_attn.py
│ │ │ │ ├── split_batchnorm.py
│ │ │ │ ├── squeeze_excite.py
│ │ │ │ ├── std_conv.py
│ │ │ │ ├── test_time_pool.py
│ │ │ │ ├── trace_utils.py
│ │ │ │ └── weight_init.py
│ │ │ └── maxvit.py
│ │ ├── rnn.py
│ │ └── s5/
│ │ ├── __init__.py
│ │ ├── jax_func.py
│ │ ├── s5_init.py
│ │ ├── s5_model.py
│ │ └── triton_comparison.py
│ ├── modules/
│ │ ├── __init__.py
│ │ ├── data/
│ │ │ └── genx.py
│ │ ├── detection.py
│ │ └── utils/
│ │ ├── detection.py
│ │ └── fetch.py
│ ├── scripts/
│ │ ├── genx/
│ │ │ ├── README.md
│ │ │ ├── conf_preprocess/
│ │ │ │ ├── extraction/
│ │ │ │ │ ├── const_count.yaml
│ │ │ │ │ ├── const_duration.yaml
│ │ │ │ │ └── frequencies/
│ │ │ │ │ ├── const_duration_100hz.yaml
│ │ │ │ │ ├── const_duration_200hz.yaml
│ │ │ │ │ ├── const_duration_40hz.yaml
│ │ │ │ │ └── const_duration_80hz.yaml
│ │ │ │ ├── filter_gen1.yaml
│ │ │ │ ├── filter_gen4.yaml
│ │ │ │ └── representation/
│ │ │ │ ├── mixeddensity_stack.yaml
│ │ │ │ └── stacked_hist.yaml
│ │ │ ├── preprocess_dataset.py
│ │ │ └── preprocess_dataset.sh
│ │ └── viz/
│ │ └── viz_gt.py
│ ├── train.py
│ ├── utils/
│ │ ├── evaluation/
│ │ │ └── prophesee/
│ │ │ ├── __init__.py
│ │ │ ├── evaluation.py
│ │ │ ├── evaluator.py
│ │ │ ├── io/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── box_filtering.py
│ │ │ │ ├── box_loading.py
│ │ │ │ ├── dat_events_tools.py
│ │ │ │ ├── npy_events_tools.py
│ │ │ │ └── psee_loader.py
│ │ │ ├── metrics/
│ │ │ │ ├── __init__.py
│ │ │ │ └── coco_eval.py
│ │ │ └── visualize/
│ │ │ ├── __init__.py
│ │ │ └── vis_utils.py
│ │ ├── helpers.py
│ │ ├── padding.py
│ │ ├── preprocessing.py
│ │ └── timers.py
│ └── validation.py
├── installation_details.txt
└── scripts/
├── 1mpx/
│ ├── onempx_base.bash
│ ├── onempx_base.job
│ ├── onempx_small.bash
│ └── onempx_small.job
└── gen1/
├── base.txt
└── small.txt
SYMBOL INDEX (1050 symbols across 111 files)
FILE: RVT/callbacks/custom.py
function get_ckpt_callback (line 8) | def get_ckpt_callback(config: DictConfig) -> ModelCheckpoint:
function get_viz_callback (line 41) | def get_viz_callback(config: DictConfig) -> Callback:
FILE: RVT/callbacks/detection.py
class DetectionVizEnum (line 18) | class DetectionVizEnum(Enum):
class DetectionVizCallback (line 24) | class DetectionVizCallback(VizCallbackBase):
method __init__ (line 25) | def __init__(self, config: DictConfig):
method on_train_batch_end_custom (line 36) | def on_train_batch_end_custom(
method on_validation_batch_end_custom (line 82) | def on_validation_batch_end_custom(self, batch: Any, outputs: Any):
method on_validation_epoch_end_custom (line 100) | def on_validation_epoch_end_custom(self, logger: WandbLogger):
FILE: RVT/callbacks/gradflow.py
class GradFlowLogCallback (line 10) | class GradFlowLogCallback(Callback):
method __init__ (line 11) | def __init__(self, log_every_n_train_steps: int):
method on_before_zero_grad (line 17) | def on_before_zero_grad(
FILE: RVT/callbacks/utils/visualization.py
function get_grad_flow_figure (line 5) | def get_grad_flow_figure(named_params):
FILE: RVT/callbacks/viz_base.py
class VizCallbackBase (line 16) | class VizCallbackBase(Callback):
method __init__ (line 17) | def __init__(self, config: DictConfig, buffer_entries: Type[Enum]):
method _reset_buffer (line 30) | def _reset_buffer(self):
method add_to_buffer (line 35) | def add_to_buffer(self, key: Enum, value: Union[np.ndarray, th.Tensor]):
method get_from_buffer (line 45) | def get_from_buffer(self, key: Enum) -> List[th.Tensor]:
method on_train_batch_end_custom (line 51) | def on_train_batch_end_custom(
method on_validation_batch_end_custom (line 61) | def on_validation_batch_end_custom(self, batch: Any, outputs: Any) -> ...
method on_validation_epoch_end_custom (line 64) | def on_validation_epoch_end_custom(self, logger: WandbLogger) -> None:
method on_train_batch_end (line 69) | def on_train_batch_end(
method on_validation_batch_end (line 103) | def on_validation_batch_end(
method on_validation_epoch_start (line 133) | def on_validation_epoch_start(
method on_validation_epoch_end (line 139) | def on_validation_epoch_end(
method on_train_batch_start (line 162) | def on_train_batch_start(
method ev_repr_to_img (line 172) | def ev_repr_to_img(x: np.ndarray):
FILE: RVT/config/modifier.py
function dynamically_modify_train_config (line 10) | def dynamically_modify_train_config(config: DictConfig):
function _get_modified_hw_multiple_of (line 59) | def _get_modified_hw_multiple_of(
FILE: RVT/data/genx_utils/collate.py
function collate_object_labels (line 10) | def collate_object_labels(
function collate_sparsely_batched_object_labels (line 18) | def collate_sparsely_batched_object_labels(
function custom_collate (line 33) | def custom_collate(batch: Any):
function custom_collate_rnd (line 37) | def custom_collate_rnd(batch: Any):
function custom_collate_streaming (line 48) | def custom_collate_streaming(batch: Any):
FILE: RVT/data/genx_utils/collate_from_pytorch.py
function collate (line 19) | def collate(
function collate_tensor_fn (line 117) | def collate_tensor_fn(
function collate_tensor_fn (line 134) | def collate_tensor_fn(
function collate_numpy_array_fn (line 150) | def collate_numpy_array_fn(
function collate_numpy_scalar_fn (line 163) | def collate_numpy_scalar_fn(
function collate_float_fn (line 171) | def collate_float_fn(
function collate_int_fn (line 179) | def collate_int_fn(
function collate_str_fn (line 187) | def collate_str_fn(
FILE: RVT/data/genx_utils/dataset_rnd.py
class SequenceDataset (line 18) | class SequenceDataset(Dataset):
method __init__ (line 19) | def __init__(
method only_load_labels (line 65) | def only_load_labels(self):
method load_everything (line 68) | def load_everything(self):
method __len__ (line 71) | def __len__(self):
method __getitem__ (line 74) | def __getitem__(self, index: int) -> LoaderDataDictGenX:
class CustomConcatDataset (line 86) | class CustomConcatDataset(ConcatDataset):
method __init__ (line 89) | def __init__(self, datasets: Iterable[SequenceDataset]):
method only_load_labels (line 92) | def only_load_labels(self):
method load_everything (line 96) | def load_everything(self):
function build_random_access_dataset (line 101) | def build_random_access_dataset(
function get_weighted_random_sampler (line 130) | def get_weighted_random_sampler(dataset: CustomConcatDataset) -> Weighte...
FILE: RVT/data/genx_utils/dataset_streaming.py
function build_streaming_dataset (line 18) | def build_streaming_dataset(
function get_sequences (line 70) | def get_sequences(
function partialclass (line 105) | def partialclass(cls, *args, **kwargs):
function build_streaming_train_dataset (line 112) | def build_streaming_train_dataset(
function build_streaming_evaluation_dataset (line 132) | def build_streaming_evaluation_dataset(
FILE: RVT/data/genx_utils/labels.py
class ObjectLabelBase (line 12) | class ObjectLabelBase:
method __init__ (line 23) | def __init__(self, object_labels: th.Tensor, input_size_hw: Tuple[int,...
method clamp_to_frame_ (line 35) | def clamp_to_frame_(self):
method remove_flat_labels_ (line 50) | def remove_flat_labels_(self):
method create_empty (line 55) | def create_empty(cls):
method _assert_not_numpy (line 61) | def _assert_not_numpy(self):
method to (line 67) | def to(self, *args, **kwargs):
method numpy_ (line 74) | def numpy_(self) -> None:
method input_size_hw (line 83) | def input_size_hw(self) -> Tuple[int, int]:
method input_size_hw (line 87) | def input_size_hw(self, height_width: Tuple[int, int]):
method get (line 94) | def get(self, request: str):
method t (line 99) | def t(self):
method x (line 103) | def x(self):
method x (line 107) | def x(self, value: Union[th.Tensor, np.ndarray]):
method y (line 111) | def y(self):
method y (line 115) | def y(self, value: Union[th.Tensor, np.ndarray]):
method w (line 119) | def w(self):
method w (line 123) | def w(self, value: Union[th.Tensor, np.ndarray]):
method h (line 127) | def h(self):
method h (line 131) | def h(self, value: Union[th.Tensor, np.ndarray]):
method class_id (line 135) | def class_id(self):
method class_confidence (line 139) | def class_confidence(self):
method dtype (line 143) | def dtype(self):
method device (line 147) | def device(self):
class ObjectLabelFactory (line 151) | class ObjectLabelFactory(ObjectLabelBase):
method __init__ (line 152) | def __init__(
method from_structured_array (line 170) | def from_structured_array(
method __len__ (line 192) | def __len__(self):
method __getitem__ (line 195) | def __getitem__(self, item: int) -> ObjectLabels:
class ObjectLabels (line 218) | class ObjectLabels(ObjectLabelBase):
method __init__ (line 219) | def __init__(self, object_labels: th.Tensor, input_size_hw: Tuple[int,...
method __len__ (line 222) | def __len__(self) -> int:
method rotate_ (line 225) | def rotate_(self, angle_deg: float):
method zoom_in_and_rescale_ (line 275) | def zoom_in_and_rescale_(
method zoom_out_and_rescale_ (line 317) | def zoom_out_and_rescale_(
method scale_ (line 342) | def scale_(self, scaling_multiplier: float):
method flip_lr_ (line 362) | def flip_lr_(self) -> None:
method get_labels_as_tensors (line 367) | def get_labels_as_tensors(self, format_: str = "yolox") -> th.Tensor:
method get_labels_as_batched_tensor (line 384) | def get_labels_as_batched_tensor(
class SparselyBatchedObjectLabels (line 407) | class SparselyBatchedObjectLabels:
method __init__ (line 408) | def __init__(self, sparse_object_labels_batch: List[Optional[ObjectLab...
method __len__ (line 415) | def __len__(self) -> int:
method __iter__ (line 418) | def __iter__(self):
method __getitem__ (line 421) | def __getitem__(self, item: int) -> Optional[ObjectLabels]:
method __add__ (line 426) | def __add__(self, other: SparselyBatchedObjectLabels):
method set_empty_labels_to_none_ (line 434) | def set_empty_labels_to_none_(self):
method input_size_hw (line 440) | def input_size_hw(self) -> Optional[Union[Tuple[int, int], Tuple[float...
method zoom_in_and_rescale_ (line 446) | def zoom_in_and_rescale_(self, *args, **kwargs):
method zoom_out_and_rescale_ (line 455) | def zoom_out_and_rescale_(self, *args, **kwargs):
method rotate_ (line 462) | def rotate_(self, *args, **kwargs):
method scale_ (line 467) | def scale_(self, *args, **kwargs):
method flip_lr_ (line 474) | def flip_lr_(self):
method to (line 479) | def to(self, *args, **kwargs):
method get_valid_labels_and_batch_indices (line 485) | def get_valid_labels_and_batch_indices(
method transpose_list (line 497) | def transpose_list(
FILE: RVT/data/genx_utils/sequence_base.py
function get_event_representation_dir (line 15) | def get_event_representation_dir(path: Path, ev_representation_name: str...
function get_objframe_idx_2_repr_idx (line 21) | def get_objframe_idx_2_repr_idx(path: Path, ev_representation_name: str)...
class SequenceBase (line 29) | class SequenceBase(MapDataPipe):
method __init__ (line 43) | def __init__(
method _get_labels_from_repr_idx (line 99) | def _get_labels_from_repr_idx(self, repr_idx: int) -> Optional[ObjectL...
method _get_event_repr_torch (line 103) | def _get_event_repr_torch(self, start_idx: int, end_idx: int) -> List[...
method __len__ (line 115) | def __len__(self) -> int:
method __getitem__ (line 118) | def __getitem__(self, index: int) -> Any:
FILE: RVT/data/genx_utils/sequence_for_streaming.py
function _scalar_as_1d_array (line 17) | def _scalar_as_1d_array(scalar: Union[int, float]):
function _get_ev_repr_range_indices (line 21) | def _get_ev_repr_range_indices(
class SequenceForIter (line 57) | class SequenceForIter(SequenceBase):
method __init__ (line 58) | def __init__(
method get_sequences_with_guaranteed_labels (line 100) | def get_sequences_with_guaranteed_labels(
method padding_representation (line 133) | def padding_representation(self) -> torch.Tensor:
method get_fully_padded_sample (line 139) | def get_fully_padded_sample(self) -> LoaderDataDictGenX:
method __len__ (line 153) | def __len__(self):
method __getitem__ (line 156) | def __getitem__(self, index: int) -> LoaderDataDictGenX:
class RandAugmentIterDataPipe (line 205) | class RandAugmentIterDataPipe(IterDataPipe):
method __init__ (line 206) | def __init__(self, source_dp: IterDataPipe, dataset_config: DictConfig):
method __iter__ (line 223) | def __iter__(self):
FILE: RVT/data/genx_utils/sequence_rnd.py
class SequenceForRandomAccess (line 9) | class SequenceForRandomAccess(SequenceBase):
method __init__ (line 10) | def __init__(
method __len__ (line 44) | def __len__(self):
method __getitem__ (line 47) | def __getitem__(self, index: int) -> LoaderDataDictGenX:
method is_only_loading_labels (line 83) | def is_only_loading_labels(self) -> bool:
method only_load_labels (line 86) | def only_load_labels(self):
method load_everything (line 89) | def load_everything(self):
FILE: RVT/data/utils/augmentor.py
class ZoomOutState (line 24) | class ZoomOutState:
class RotationState (line 32) | class RotationState:
class AugmentationState (line 38) | class AugmentationState:
class RandomSpatialAugmentorGenX (line 45) | class RandomSpatialAugmentorGenX:
method __init__ (line 46) | def __init__(
method randomize_augmentation (line 99) | def randomize_augmentation(self):
method _zoom_out_and_rescale (line 145) | def _zoom_out_and_rescale(
method _zoom_out_and_rescale_tensor (line 164) | def _zoom_out_and_rescale_tensor(
method _zoom_out_and_rescale_recursive (line 194) | def _zoom_out_and_rescale_recursive(
method _zoom_in_and_rescale (line 241) | def _zoom_in_and_rescale(self, data_dict: LoaderDataDictGenX) -> Loade...
method _zoom_in_and_rescale_tensor (line 276) | def _zoom_in_and_rescale_tensor(
method _zoom_in_and_rescale_recursive (line 306) | def _zoom_in_and_rescale_recursive(
method _rotate (line 353) | def _rotate(self, data_dict: LoaderDataDictGenX) -> LoaderDataDictGenX:
method _rotate_tensor (line 363) | def _rotate_tensor(input_: Any, angle_deg: float, datatype: DataType):
method _rotate_recursive (line 372) | def _rotate_recursive(cls, input_: Any, angle_deg: float, datatype: Da...
method _flip (line 402) | def _flip(data_dict: LoaderDataDictGenX, type_: str) -> LoaderDataDict...
method _flip_tensor (line 412) | def _flip_tensor(input_: Any, flip_type: str, datatype: DataType):
method _flip_recursive (line 427) | def _flip_recursive(cls, input_: Any, flip_type: str, datatype: DataTy...
method _hw_from_data (line 461) | def _hw_from_data(data_dict: LoaderDataDictGenX) -> Tuple[int, int]:
method __call__ (line 483) | def __call__(self, data_dict: LoaderDataDictGenX):
function get_most_recent_objframe (line 503) | def get_most_recent_objframe(
function randomly_sample_zoom_window_from_objframe (line 521) | def randomly_sample_zoom_window_from_objframe(
function randomly_sample_zoom_window_from_label_rectangle (line 557) | def randomly_sample_zoom_window_from_label_rectangle(
FILE: RVT/data/utils/representations.py
class RepresentationBase (line 9) | class RepresentationBase(ABC):
method construct (line 11) | def construct(
method get_shape (line 16) | def get_shape(self) -> Tuple[int, int, int]: ...
method get_numpy_dtype (line 20) | def get_numpy_dtype() -> np.dtype: ...
method get_torch_dtype (line 24) | def get_torch_dtype() -> th.dtype: ...
method dtype (line 27) | def dtype(self) -> th.dtype:
method _is_int_tensor (line 31) | def _is_int_tensor(tensor: th.Tensor) -> bool:
class StackedHistogram (line 35) | class StackedHistogram(RepresentationBase):
method __init__ (line 36) | def __init__(
method get_numpy_dtype (line 67) | def get_numpy_dtype() -> np.dtype:
method get_torch_dtype (line 71) | def get_torch_dtype() -> th.dtype:
method merge_channel_and_bins (line 74) | def merge_channel_and_bins(self, representation: th.Tensor):
method get_shape (line 78) | def get_shape(self) -> Tuple[int, int, int]:
method construct (line 81) | def construct(
function cumsum_channel (line 137) | def cumsum_channel(x: th.Tensor, num_channels: int):
class MixedDensityEventStack (line 143) | class MixedDensityEventStack(RepresentationBase):
method __init__ (line 144) | def __init__(
method get_numpy_dtype (line 173) | def get_numpy_dtype() -> np.dtype:
method get_torch_dtype (line 177) | def get_torch_dtype() -> th.dtype:
method get_shape (line 180) | def get_shape(self) -> Tuple[int, int, int]:
method construct (line 183) | def construct(
FILE: RVT/data/utils/spatial.py
function get_original_hw (line 16) | def get_original_hw(dataset_type: DatasetType):
function get_dataloading_hw (line 20) | def get_dataloading_hw(dataset_config: DictConfig):
FILE: RVT/data/utils/stream_concat_datapipe.py
class DummyIterDataPipe (line 15) | class DummyIterDataPipe(IterDataPipe):
method __init__ (line 16) | def __init__(self, source_dp: IterDataPipe):
method __iter__ (line 21) | def __iter__(self):
class ConcatStreamingDataPipe (line 25) | class ConcatStreamingDataPipe(IterDataPipe):
method __init__ (line 37) | def __init__(
method random_torch_shuffle_list (line 61) | def random_torch_shuffle_list(data: List[Any]) -> Iterator[Any]:
method _get_zipped_streams (line 65) | def _get_zipped_streams(self, datapipe_list: List[MapDataPipe], batch_...
method _print_seed_debug_info (line 85) | def _print_seed_debug_info(self):
method _get_zipped_streams_with_worker_id (line 103) | def _get_zipped_streams_with_worker_id(self):
method __iter__ (line 113) | def __iter__(self):
FILE: RVT/data/utils/stream_sharded_datapipe.py
class ShardedStreamingDataPipe (line 15) | class ShardedStreamingDataPipe(IterDataPipe):
method __init__ (line 16) | def __init__(
method yield_pyramid_indices (line 34) | def yield_pyramid_indices(start_idx: int, end_idx: int):
method assign_datapipes_to_worker (line 42) | def assign_datapipes_to_worker(
method get_zipped_stream_from_worker_datapipes (line 64) | def get_zipped_stream_from_worker_datapipes(
method __iter__ (line 89) | def __iter__(self):
FILE: RVT/data/utils/types.py
class DataType (line 14) | class DataType(Enum):
class DatasetType (line 25) | class DatasetType(Enum):
class DatasetMode (line 30) | class DatasetMode(Enum):
class DatasetSamplingMode (line 36) | class DatasetSamplingMode(StrEnum):
class ObjDetOutput (line 42) | class ObjDetOutput(Enum):
FILE: RVT/loggers/utils.py
function get_wandb_logger (line 10) | def get_wandb_logger(full_config: DictConfig) -> WandbLogger:
function get_ckpt_path (line 37) | def get_ckpt_path(logger: WandbLogger, wandb_config: DictConfig) -> Unio...
FILE: RVT/loggers/wandb_logger.py
class WandbLogger (line 34) | class WandbLogger(Logger):
method __init__ (line 38) | def __init__(
method get_checkpoint (line 82) | def get_checkpoint(
method __getstate__ (line 98) | def __getstate__(self) -> Dict[str, Any]:
method experiment (line 112) | def experiment(self) -> Run:
method watch (line 145) | def watch(
method add_step_metric (line 154) | def add_step_metric(self, input_dict: dict, step: int) -> None:
method log_hyperparams (line 158) | def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) ->...
method log_metrics (line 165) | def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = N...
method log_images (line 176) | def log_images(
method log_videos (line 198) | def log_videos(
method name (line 229) | def name(self) -> Optional[str]:
method version (line 235) | def version(self) -> Optional[str]:
method after_save_checkpoint (line 241) | def after_save_checkpoint(
method finalize (line 254) | def finalize(self, status: str) -> None:
method _get_public_run (line 259) | def _get_public_run(self):
method _num_logged_artifact (line 273) | def _num_logged_artifact(self):
method _scan_and_log_checkpoints (line 277) | def _scan_and_log_checkpoints(
method _rm_but_top_k (line 387) | def _rm_but_top_k(self, top_k: int):
FILE: RVT/models/detection/recurrent_backbone/__init__.py
function build_recurrent_backbone (line 6) | def build_recurrent_backbone(backbone_cfg: DictConfig):
FILE: RVT/models/detection/recurrent_backbone/base.py
class BaseDetector (line 6) | class BaseDetector(nn.Module):
method get_stage_dims (line 7) | def get_stage_dims(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
method get_strides (line 10) | def get_strides(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
FILE: RVT/models/detection/recurrent_backbone/maxvit_rnn.py
class RNNDetector (line 27) | class RNNDetector(BaseDetector):
method __init__ (line 28) | def __init__(self, mdl_config: DictConfig):
method get_stage_dims (line 92) | def get_stage_dims(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
method get_strides (line 98) | def get_strides(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
method forward (line 104) | def forward(
class MaxVitAttentionPairCl (line 129) | class MaxVitAttentionPairCl(nn.Module):
method __init__ (line 130) | def __init__(self, dim: int, skip_first_norm: bool, attention_cfg: Dic...
method forward (line 146) | def forward(self, x):
class RNNDetectorStage (line 152) | class RNNDetectorStage(nn.Module):
method __init__ (line 155) | def __init__(
method forward (line 212) | def forward(
FILE: RVT/models/detection/yolox/models/losses.py
class IOUloss (line 9) | class IOUloss(nn.Module):
method __init__ (line 10) | def __init__(self, reduction="none", loss_type="iou"):
method forward (line 15) | def forward(self, pred, target):
FILE: RVT/models/detection/yolox/models/network_blocks.py
class SiLU (line 9) | class SiLU(nn.Module):
method forward (line 13) | def forward(x):
function get_activation (line 17) | def get_activation(name="silu", inplace=True):
class BaseConv (line 29) | class BaseConv(nn.Module):
method __init__ (line 32) | def __init__(
method forward (line 50) | def forward(self, x):
method fuseforward (line 53) | def fuseforward(self, x):
class DWConv (line 57) | class DWConv(nn.Module):
method __init__ (line 60) | def __init__(self, in_channels, out_channels, ksize, stride=1, act="si...
method forward (line 74) | def forward(self, x):
class Bottleneck (line 79) | class Bottleneck(nn.Module):
method __init__ (line 81) | def __init__(
method forward (line 97) | def forward(self, x):
class CSPLayer (line 104) | class CSPLayer(nn.Module):
method __init__ (line 107) | def __init__(
method forward (line 137) | def forward(self, x):
FILE: RVT/models/detection/yolox/models/yolo_head.py
class YOLOXHead (line 23) | class YOLOXHead(nn.Module):
method __init__ (line 24) | def __init__(
method initialize_biases (line 158) | def initialize_biases(self, prior_prob):
method forward (line 169) | def forward(self, xin, labels=None):
method get_output_and_grid (line 250) | def get_output_and_grid(self, output, k, stride, dtype):
method decode_outputs (line 268) | def decode_outputs(self, outputs):
method get_losses (line 300) | def get_losses(
method get_l1_target (line 454) | def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps...
method get_assignments (line 462) | def get_assignments(
method get_geometry_constraint (line 550) | def get_geometry_constraint(
method simota_matching (line 589) | def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg...
FILE: RVT/models/detection/yolox/utils/boxes.py
function filter_box (line 21) | def filter_box(output, scale_range):
function postprocess (line 32) | def postprocess(
function bboxes_iou (line 82) | def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
function matrix_iou (line 108) | def matrix_iou(a, b):
function adjust_box_anns (line 121) | def adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max):
function xyxy2xywh (line 127) | def xyxy2xywh(bboxes):
function xyxy2cxcywh (line 133) | def xyxy2cxcywh(bboxes):
FILE: RVT/models/detection/yolox/utils/compat.py
function meshgrid (line 11) | def meshgrid(*tensors):
FILE: RVT/models/detection/yolox_extension/models/build.py
function build_yolox_head (line 9) | def build_yolox_head(
function build_yolox_fpn (line 24) | def build_yolox_fpn(fpn_cfg: DictConfig, in_channels: Tuple[int, ...]):
FILE: RVT/models/detection/yolox_extension/models/detector.py
class YoloXDetector (line 18) | class YoloXDetector(th.nn.Module):
method __init__ (line 19) | def __init__(self, model_cfg: DictConfig):
method forward_backbone (line 35) | def forward_backbone(
method forward_detect (line 48) | def forward_detect(
method forward (line 64) | def forward(
FILE: RVT/models/detection/yolox_extension/models/yolo_pafpn.py
class YOLOPAFPN (line 19) | class YOLOPAFPN(nn.Module):
method __init__ (line 24) | def __init__(
method forward (line 104) | def forward(self, input: BackboneFeatures):
FILE: RVT/models/layers/maxvit/layers/activations.py
function swish (line 14) | def swish(x, inplace: bool = False):
class Swish (line 19) | class Swish(nn.Module):
method __init__ (line 20) | def __init__(self, inplace: bool = False):
method forward (line 24) | def forward(self, x):
function mish (line 28) | def mish(x, inplace: bool = False):
class Mish (line 35) | class Mish(nn.Module):
method __init__ (line 38) | def __init__(self, inplace: bool = False):
method forward (line 41) | def forward(self, x):
function sigmoid (line 45) | def sigmoid(x, inplace: bool = False):
class Sigmoid (line 50) | class Sigmoid(nn.Module):
method __init__ (line 51) | def __init__(self, inplace: bool = False):
method forward (line 55) | def forward(self, x):
function tanh (line 59) | def tanh(x, inplace: bool = False):
class Tanh (line 64) | class Tanh(nn.Module):
method __init__ (line 65) | def __init__(self, inplace: bool = False):
method forward (line 69) | def forward(self, x):
function hard_swish (line 73) | def hard_swish(x, inplace: bool = False):
class HardSwish (line 78) | class HardSwish(nn.Module):
method __init__ (line 79) | def __init__(self, inplace: bool = False):
method forward (line 83) | def forward(self, x):
function hard_sigmoid (line 87) | def hard_sigmoid(x, inplace: bool = False):
class HardSigmoid (line 94) | class HardSigmoid(nn.Module):
method __init__ (line 95) | def __init__(self, inplace: bool = False):
method forward (line 99) | def forward(self, x):
function hard_mish (line 103) | def hard_mish(x, inplace: bool = False):
class HardMish (line 114) | class HardMish(nn.Module):
method __init__ (line 115) | def __init__(self, inplace: bool = False):
method forward (line 119) | def forward(self, x):
class PReLU (line 123) | class PReLU(nn.PReLU):
method __init__ (line 126) | def __init__(
method forward (line 131) | def forward(self, input: torch.Tensor) -> torch.Tensor:
function gelu (line 135) | def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
class GELU (line 139) | class GELU(nn.Module):
method __init__ (line 142) | def __init__(self, inplace: bool = False):
method forward (line 145) | def forward(self, input: torch.Tensor) -> torch.Tensor:
FILE: RVT/models/layers/maxvit/layers/activations_jit.py
function swish_jit (line 19) | def swish_jit(x, inplace: bool = False):
function mish_jit (line 25) | def mish_jit(x, _inplace: bool = False):
class SwishJit (line 30) | class SwishJit(nn.Module):
method __init__ (line 31) | def __init__(self, inplace: bool = False):
method forward (line 34) | def forward(self, x):
class MishJit (line 38) | class MishJit(nn.Module):
method __init__ (line 39) | def __init__(self, inplace: bool = False):
method forward (line 42) | def forward(self, x):
function hard_sigmoid_jit (line 47) | def hard_sigmoid_jit(x, inplace: bool = False):
class HardSigmoidJit (line 52) | class HardSigmoidJit(nn.Module):
method __init__ (line 53) | def __init__(self, inplace: bool = False):
method forward (line 56) | def forward(self, x):
function hard_swish_jit (line 61) | def hard_swish_jit(x, inplace: bool = False):
class HardSwishJit (line 68) | class HardSwishJit(nn.Module):
method __init__ (line 69) | def __init__(self, inplace: bool = False):
method forward (line 72) | def forward(self, x):
function hard_mish_jit (line 77) | def hard_mish_jit(x, inplace: bool = False):
class HardMishJit (line 85) | class HardMishJit(nn.Module):
method __init__ (line 86) | def __init__(self, inplace: bool = False):
method forward (line 89) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/activations_me.py
function swish_jit_fwd (line 18) | def swish_jit_fwd(x):
function swish_jit_bwd (line 23) | def swish_jit_bwd(x, grad_output):
class SwishJitAutoFn (line 28) | class SwishJitAutoFn(torch.autograd.Function):
method symbolic (line 35) | def symbolic(g, x):
method forward (line 39) | def forward(ctx, x):
method backward (line 44) | def backward(ctx, grad_output):
function swish_me (line 49) | def swish_me(x, inplace=False):
class SwishMe (line 53) | class SwishMe(nn.Module):
method __init__ (line 54) | def __init__(self, inplace: bool = False):
method forward (line 57) | def forward(self, x):
function mish_jit_fwd (line 62) | def mish_jit_fwd(x):
function mish_jit_bwd (line 67) | def mish_jit_bwd(x, grad_output):
class MishJitAutoFn (line 73) | class MishJitAutoFn(torch.autograd.Function):
method forward (line 79) | def forward(ctx, x):
method backward (line 84) | def backward(ctx, grad_output):
function mish_me (line 89) | def mish_me(x, inplace=False):
class MishMe (line 93) | class MishMe(nn.Module):
method __init__ (line 94) | def __init__(self, inplace: bool = False):
method forward (line 97) | def forward(self, x):
function hard_sigmoid_jit_fwd (line 102) | def hard_sigmoid_jit_fwd(x, inplace: bool = False):
function hard_sigmoid_jit_bwd (line 107) | def hard_sigmoid_jit_bwd(x, grad_output):
class HardSigmoidJitAutoFn (line 112) | class HardSigmoidJitAutoFn(torch.autograd.Function):
method forward (line 114) | def forward(ctx, x):
method backward (line 119) | def backward(ctx, grad_output):
function hard_sigmoid_me (line 124) | def hard_sigmoid_me(x, inplace: bool = False):
class HardSigmoidMe (line 128) | class HardSigmoidMe(nn.Module):
method __init__ (line 129) | def __init__(self, inplace: bool = False):
method forward (line 132) | def forward(self, x):
function hard_swish_jit_fwd (line 137) | def hard_swish_jit_fwd(x):
function hard_swish_jit_bwd (line 142) | def hard_swish_jit_bwd(x, grad_output):
class HardSwishJitAutoFn (line 148) | class HardSwishJitAutoFn(torch.autograd.Function):
method forward (line 152) | def forward(ctx, x):
method backward (line 157) | def backward(ctx, grad_output):
method symbolic (line 162) | def symbolic(g, self):
function hard_swish_me (line 180) | def hard_swish_me(x, inplace=False):
class HardSwishMe (line 184) | class HardSwishMe(nn.Module):
method __init__ (line 185) | def __init__(self, inplace: bool = False):
method forward (line 188) | def forward(self, x):
function hard_mish_jit_fwd (line 193) | def hard_mish_jit_fwd(x):
function hard_mish_jit_bwd (line 198) | def hard_mish_jit_bwd(x, grad_output):
class HardMishJitAutoFn (line 204) | class HardMishJitAutoFn(torch.autograd.Function):
method forward (line 211) | def forward(ctx, x):
method backward (line 216) | def backward(ctx, grad_output):
function hard_mish_me (line 221) | def hard_mish_me(x, inplace: bool = False):
class HardMishMe (line 225) | class HardMishMe(nn.Module):
method __init__ (line 226) | def __init__(self, inplace: bool = False):
method forward (line 229) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/adaptive_avgmax_pool.py
function adaptive_pool_feat_mult (line 18) | def adaptive_pool_feat_mult(pool_type="avg"):
function adaptive_avgmax_pool2d (line 25) | def adaptive_avgmax_pool2d(x, output_size=1):
function adaptive_catavgmax_pool2d (line 31) | def adaptive_catavgmax_pool2d(x, output_size=1):
function select_adaptive_pool2d (line 37) | def select_adaptive_pool2d(x, pool_type="avg", output_size=1):
class FastAdaptiveAvgPool2d (line 52) | class FastAdaptiveAvgPool2d(nn.Module):
method __init__ (line 53) | def __init__(self, flatten=False):
method forward (line 57) | def forward(self, x):
class AdaptiveAvgMaxPool2d (line 61) | class AdaptiveAvgMaxPool2d(nn.Module):
method __init__ (line 62) | def __init__(self, output_size=1):
method forward (line 66) | def forward(self, x):
class AdaptiveCatAvgMaxPool2d (line 70) | class AdaptiveCatAvgMaxPool2d(nn.Module):
method __init__ (line 71) | def __init__(self, output_size=1):
method forward (line 75) | def forward(self, x):
class SelectAdaptivePool2d (line 79) | class SelectAdaptivePool2d(nn.Module):
method __init__ (line 82) | def __init__(self, output_size=1, pool_type="fast", flatten=False):
method is_identity (line 105) | def is_identity(self):
method forward (line 108) | def forward(self, x):
method feat_mult (line 113) | def feat_mult(self):
method __repr__ (line 116) | def __repr__(self):
FILE: RVT/models/layers/maxvit/layers/attention_pool2d.py
class RotAttentionPool2d (line 21) | class RotAttentionPool2d(nn.Module):
method __init__ (line 32) | def __init__(
method forward (line 54) | def forward(self, x):
class AttentionPool2d (line 85) | class AttentionPool2d(nn.Module):
method __init__ (line 95) | def __init__(
method forward (line 122) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/blur_pool.py
class BlurPool2d (line 16) | class BlurPool2d(nn.Module):
method __init__ (line 30) | def __init__(self, channels, filt_size=3, stride=2) -> None:
method forward (line 45) | def forward(self, x: torch.Tensor) -> torch.Tensor:
FILE: RVT/models/layers/maxvit/layers/bottleneck_attn.py
function rel_logits_1d (line 29) | def rel_logits_1d(q, rel_k, permute_mask: List[int]):
class PosEmbedRel (line 57) | class PosEmbedRel(nn.Module):
method __init__ (line 63) | def __init__(self, feat_size, dim_head, scale):
method forward (line 72) | def forward(self, q):
class BottleneckAttn (line 88) | class BottleneckAttn(nn.Module):
method __init__ (line 111) | def __init__(
method reset_parameters (line 152) | def reset_parameters(self):
method forward (line 157) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/cbam.py
class ChannelAttn (line 20) | class ChannelAttn(nn.Module):
method __init__ (line 23) | def __init__(
method forward (line 43) | def forward(self, x):
class LightChannelAttn (line 49) | class LightChannelAttn(ChannelAttn):
method __init__ (line 52) | def __init__(
method forward (line 66) | def forward(self, x):
class SpatialAttn (line 72) | class SpatialAttn(nn.Module):
method __init__ (line 75) | def __init__(self, kernel_size=7, gate_layer="sigmoid"):
method forward (line 80) | def forward(self, x):
class LightSpatialAttn (line 88) | class LightSpatialAttn(nn.Module):
method __init__ (line 91) | def __init__(self, kernel_size=7, gate_layer="sigmoid"):
method forward (line 96) | def forward(self, x):
class CbamModule (line 102) | class CbamModule(nn.Module):
method __init__ (line 103) | def __init__(
method forward (line 126) | def forward(self, x):
class LightCbamModule (line 132) | class LightCbamModule(nn.Module):
method __init__ (line 133) | def __init__(
method forward (line 156) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/classifier.py
function _create_pool (line 12) | def _create_pool(num_features, num_classes, pool_type="avg", use_conv=Fa...
function _create_fc (line 26) | def _create_fc(num_features, num_classes, use_conv=False):
function create_classifier (line 36) | def create_classifier(num_features, num_classes, pool_type="avg", use_co...
class ClassifierHead (line 44) | class ClassifierHead(nn.Module):
method __init__ (line 47) | def __init__(
method forward (line 58) | def forward(self, x, pre_logits: bool = False):
FILE: RVT/models/layers/maxvit/layers/cond_conv2d.py
function get_condconv_initializer (line 21) | def get_condconv_initializer(initializer, num_experts, expert_shape):
class CondConv2d (line 41) | class CondConv2d(nn.Module):
method __init__ (line 51) | def __init__(
method reset_parameters (line 101) | def reset_parameters(self):
method forward (line 118) | def forward(self, x, routing_weights):
FILE: RVT/models/layers/maxvit/layers/config.py
function is_no_jit (line 31) | def is_no_jit():
class set_no_jit (line 35) | class set_no_jit:
method __init__ (line 36) | def __init__(self, mode: bool) -> None:
method __enter__ (line 41) | def __enter__(self) -> None:
method __exit__ (line 44) | def __exit__(self, *args: Any) -> bool:
function is_exportable (line 50) | def is_exportable():
class set_exportable (line 54) | class set_exportable:
method __init__ (line 55) | def __init__(self, mode: bool) -> None:
method __enter__ (line 60) | def __enter__(self) -> None:
method __exit__ (line 63) | def __exit__(self, *args: Any) -> bool:
function is_scriptable (line 69) | def is_scriptable():
class set_scriptable (line 73) | class set_scriptable:
method __init__ (line 74) | def __init__(self, mode: bool) -> None:
method __enter__ (line 79) | def __enter__(self) -> None:
method __exit__ (line 82) | def __exit__(self, *args: Any) -> bool:
class set_layer_config (line 88) | class set_layer_config:
method __init__ (line 93) | def __init__(
method __enter__ (line 114) | def __enter__(self) -> None:
method __exit__ (line 117) | def __exit__(self, *args: Any) -> bool:
FILE: RVT/models/layers/maxvit/layers/conv2d_same.py
function conv2d_same (line 14) | def conv2d_same(
class Conv2dSame (line 27) | class Conv2dSame(nn.Conv2d):
method __init__ (line 30) | def __init__(
method forward (line 45) | def forward(self, x):
function create_conv2d_pad (line 57) | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
FILE: RVT/models/layers/maxvit/layers/conv_bn_act.py
class ConvNormAct (line 13) | class ConvNormAct(nn.Module):
method __init__ (line 14) | def __init__(
method in_channels (line 48) | def in_channels(self):
method out_channels (line 52) | def out_channels(self):
method forward (line 55) | def forward(self, x):
function create_aa (line 64) | def create_aa(aa_layer, channels, stride=2, enable=True):
class ConvNormActAa (line 78) | class ConvNormActAa(nn.Module):
method __init__ (line 79) | def __init__(
method in_channels (line 117) | def in_channels(self):
method out_channels (line 121) | def out_channels(self):
method forward (line 124) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/create_act.py
function get_act_fn (line 106) | def get_act_fn(name: Union[Callable, str] = "relu"):
function get_act_layer (line 126) | def get_act_layer(name: Union[Type[nn.Module], str] = "relu"):
function create_act_layer (line 145) | def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
FILE: RVT/models/layers/maxvit/layers/create_attn.py
function get_attn (line 22) | def get_attn(attn_type):
function create_attn (line 85) | def create_attn(attn_type, channels, **kwargs):
FILE: RVT/models/layers/maxvit/layers/create_conv2d.py
function create_conv2d (line 11) | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
FILE: RVT/models/layers/maxvit/layers/create_norm.py
function create_norm_layer (line 27) | def create_norm_layer(
function get_norm_layer (line 35) | def get_norm_layer(norm_layer):
FILE: RVT/models/layers/maxvit/layers/create_norm_act.py
function create_norm_act_layer (line 51) | def create_norm_act_layer(
function get_norm_act_layer (line 61) | def get_norm_act_layer(norm_layer, act_layer=None):
FILE: RVT/models/layers/maxvit/layers/drop.py
function drop_block_2d (line 23) | def drop_block_2d(
function drop_block_fast_2d (line 92) | def drop_block_fast_2d(
class DropBlock2d (line 142) | class DropBlock2d(nn.Module):
method __init__ (line 145) | def __init__(
method forward (line 164) | def forward(self, x):
function drop_path (line 188) | def drop_path(
class DropPath (line 212) | class DropPath(nn.Module):
method __init__ (line 215) | def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
method forward (line 220) | def forward(self, x):
method extra_repr (line 223) | def extra_repr(self):
FILE: RVT/models/layers/maxvit/layers/eca.py
class EcaModule (line 46) | class EcaModule(nn.Module):
method __init__ (line 62) | def __init__(
method forward (line 100) | def forward(self, x):
class CecaModule (line 113) | class CecaModule(nn.Module):
method __init__ (line 137) | def __init__(
method forward (line 160) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/evo_norm.py
function instance_std (line 37) | def instance_std(x, eps: float = 1e-5):
function instance_std_tpu (line 48) | def instance_std_tpu(x, eps: float = 1e-5):
function instance_rms (line 56) | def instance_rms(x, eps: float = 1e-5):
function manual_var (line 61) | def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):
function group_std (line 71) | def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = Fa...
function group_std_tpu (line 96) | def group_std_tpu(
function group_rms (line 119) | def group_rms(x, groups: int = 32, eps: float = 1e-5):
class EvoNorm2dB0 (line 135) | class EvoNorm2dB0(nn.Module):
method __init__ (line 136) | def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-...
method reset_parameters (line 147) | def reset_parameters(self):
method forward (line 153) | def forward(self, x):
class EvoNorm2dB1 (line 177) | class EvoNorm2dB1(nn.Module):
method __init__ (line 178) | def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-...
method reset_parameters (line 188) | def reset_parameters(self):
method forward (line 192) | def forward(self, x):
class EvoNorm2dB2 (line 217) | class EvoNorm2dB2(nn.Module):
method __init__ (line 218) | def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-...
method reset_parameters (line 228) | def reset_parameters(self):
method forward (line 232) | def forward(self, x):
class EvoNorm2dS0 (line 257) | class EvoNorm2dS0(nn.Module):
method __init__ (line 258) | def __init__(
method reset_parameters (line 274) | def reset_parameters(self):
method forward (line 280) | def forward(self, x):
class EvoNorm2dS0a (line 292) | class EvoNorm2dS0a(EvoNorm2dS0):
method __init__ (line 293) | def __init__(
method forward (line 304) | def forward(self, x):
class EvoNorm2dS1 (line 318) | class EvoNorm2dS1(nn.Module):
method __init__ (line 319) | def __init__(
method reset_parameters (line 347) | def reset_parameters(self):
method forward (line 351) | def forward(self, x):
class EvoNorm2dS1a (line 362) | class EvoNorm2dS1a(EvoNorm2dS1):
method __init__ (line 363) | def __init__(
method forward (line 382) | def forward(self, x):
class EvoNorm2dS2 (line 392) | class EvoNorm2dS2(nn.Module):
method __init__ (line 393) | def __init__(
method reset_parameters (line 420) | def reset_parameters(self):
method forward (line 424) | def forward(self, x):
class EvoNorm2dS2a (line 435) | class EvoNorm2dS2a(EvoNorm2dS2):
method __init__ (line 436) | def __init__(
method forward (line 455) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/fast_norm.py
function is_fast_norm (line 27) | def is_fast_norm():
function set_fast_norm (line 31) | def set_fast_norm(enable=True):
function fast_group_norm (line 36) | def fast_group_norm(
function fast_layer_norm (line 58) | def fast_layer_norm(
FILE: RVT/models/layers/maxvit/layers/filter_response_norm.py
function inv_instance_rms (line 15) | def inv_instance_rms(x, eps: float = 1e-5):
class FilterResponseNormTlu2d (line 20) | class FilterResponseNormTlu2d(nn.Module):
method __init__ (line 21) | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, *...
method reset_parameters (line 31) | def reset_parameters(self):
method forward (line 37) | def forward(self, x):
class FilterResponseNormAct2d (line 52) | class FilterResponseNormAct2d(nn.Module):
method __init__ (line 53) | def __init__(
method reset_parameters (line 74) | def reset_parameters(self):
method forward (line 78) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/gather_excite.py
class GatherExcite (line 26) | class GatherExcite(nn.Module):
method __init__ (line 29) | def __init__(
method forward (line 101) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/global_context.py
class GlobalContext (line 20) | class GlobalContext(nn.Module):
method __init__ (line 21) | def __init__(
method reset_parameters (line 62) | def reset_parameters(self):
method forward (line 70) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/halo_attn.py
function rel_logits_1d (line 31) | def rel_logits_1d(q, rel_k, permute_mask: List[int]):
class PosEmbedRel (line 62) | class PosEmbedRel(nn.Module):
method __init__ (line 69) | def __init__(self, block_size, win_size, dim_head, scale):
method forward (line 83) | def forward(self, q):
class HaloAttn (line 99) | class HaloAttn(nn.Module):
method __init__ (line 128) | def __init__(
method reset_parameters (line 185) | def reset_parameters(self):
method forward (line 192) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/helpers.py
function _ntuple (line 11) | def _ntuple(n):
function make_divisible (line 27) | def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
function extend_tuple (line 36) | def extend_tuple(x, n):
FILE: RVT/models/layers/maxvit/layers/inplace_abn.py
function inplace_abn (line 11) | def inplace_abn(
function inplace_abn_sync (line 27) | def inplace_abn_sync(**kwargs):
class InplaceAbn (line 31) | class InplaceAbn(nn.Module):
method __init__ (line 52) | def __init__(
method reset_parameters (line 95) | def reset_parameters(self):
method forward (line 102) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/lambda_layer.py
function rel_pos_indices (line 32) | def rel_pos_indices(size):
class LambdaLayer (line 43) | class LambdaLayer(nn.Module):
method __init__ (line 71) | def __init__(
method reset_parameters (line 125) | def reset_parameters(self):
method forward (line 132) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/linear.py
class Linear (line 9) | class Linear(nn.Linear):
method forward (line 16) | def forward(self, input: torch.Tensor) -> torch.Tensor:
FILE: RVT/models/layers/maxvit/layers/median_pool.py
class MedianPool2d (line 10) | class MedianPool2d(nn.Module):
method __init__ (line 20) | def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
method _padding (line 27) | def _padding(self, x):
method forward (line 47) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/mixed_conv2d.py
function _split_channels (line 14) | def _split_channels(num_chan, num_groups):
class MixedConv2d (line 20) | class MixedConv2d(nn.ModuleDict):
method __init__ (line 27) | def __init__(
method forward (line 66) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/ml_decoder.py
function add_ml_decoder_head (line 9) | def add_ml_decoder_head(model):
class TransformerDecoderLayerOptimal (line 43) | class TransformerDecoderLayerOptimal(nn.Module):
method __init__ (line 44) | def __init__(
method __setstate__ (line 71) | def __setstate__(self, state):
method forward (line 76) | def forward(
class GroupFC (line 113) | class GroupFC(object):
method __init__ (line 114) | def __init__(self, embed_len_decoder: int):
method __call__ (line 117) | def __call__(
class MLDecoder (line 126) | class MLDecoder(nn.Module):
method __init__ (line 127) | def __init__(
method forward (line 171) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/mlp.py
class Mlp (line 11) | class Mlp(nn.Module):
method __init__ (line 14) | def __init__(
method forward (line 35) | def forward(self, x):
class GluMlp (line 44) | class GluMlp(nn.Module):
method __init__ (line 49) | def __init__(
method init_weights (line 71) | def init_weights(self):
method forward (line 77) | def forward(self, x):
class GatedMlp (line 87) | class GatedMlp(nn.Module):
method __init__ (line 90) | def __init__(
method forward (line 120) | def forward(self, x):
class ConvMlp (line 130) | class ConvMlp(nn.Module):
method __init__ (line 133) | def __init__(
method forward (line 154) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/non_local_attn.py
class NonLocalAttn (line 17) | class NonLocalAttn(nn.Module):
method __init__ (line 24) | def __init__(
method forward (line 44) | def forward(self, x):
method reset_parameters (line 66) | def reset_parameters(self):
class BilinearAttnTransform (line 80) | class BilinearAttnTransform(nn.Module):
method __init__ (line 81) | def __init__(
method resize_mat (line 107) | def resize_mat(self, x, t: int):
method forward (line 120) | def forward(self, x):
class BatNonLocalAttn (line 172) | class BatNonLocalAttn(nn.Module):
method __init__ (line 177) | def __init__(
method forward (line 204) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/norm.py
class GroupNorm (line 15) | class GroupNorm(nn.GroupNorm):
method __init__ (line 16) | def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
method forward (line 23) | def forward(self, x):
class GroupNorm1 (line 30) | class GroupNorm1(nn.GroupNorm):
method __init__ (line 35) | def __init__(self, num_channels, **kwargs):
method forward (line 41) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class LayerNorm (line 48) | class LayerNorm(nn.LayerNorm):
method __init__ (line 51) | def __init__(self, num_channels, eps=1e-6, affine=True):
method forward (line 57) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class LayerNorm2d (line 67) | class LayerNorm2d(nn.LayerNorm):
method __init__ (line 70) | def __init__(self, num_channels, eps=1e-6, affine=True):
method forward (line 76) | def forward(self, x: torch.Tensor) -> torch.Tensor:
function _is_contiguous (line 88) | def _is_contiguous(tensor: torch.Tensor) -> bool:
function _layer_norm_cf (line 97) | def _layer_norm_cf(
function _layer_norm_cf_sqm (line 106) | def _layer_norm_cf_sqm(
class LayerNormExp2d (line 116) | class LayerNormExp2d(nn.LayerNorm):
method __init__ (line 125) | def __init__(self, num_channels, eps=1e-6):
method forward (line 128) | def forward(self, x) -> torch.Tensor:
FILE: RVT/models/layers/maxvit/layers/norm_act.py
class BatchNormAct2d (line 27) | class BatchNormAct2d(nn.BatchNorm2d):
method __init__ (line 35) | def __init__(
method forward (line 76) | def forward(self, x):
class SyncBatchNormAct (line 131) | class SyncBatchNormAct(nn.SyncBatchNorm):
method forward (line 136) | def forward(self, x: torch.Tensor) -> torch.Tensor:
function convert_sync_batchnorm (line 147) | def convert_sync_batchnorm(module, process_group=None):
function _num_groups (line 189) | def _num_groups(num_channels, num_groups, group_size):
class GroupNormAct (line 196) | class GroupNormAct(nn.GroupNorm):
method __init__ (line 198) | def __init__(
method forward (line 225) | def forward(self, x):
class LayerNormAct (line 235) | class LayerNormAct(nn.LayerNorm):
method __init__ (line 236) | def __init__(
method forward (line 258) | def forward(self, x):
class LayerNormAct2d (line 270) | class LayerNormAct2d(nn.LayerNorm):
method __init__ (line 271) | def __init__(
method forward (line 293) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/padding.py
function get_padding (line 13) | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **...
function get_same_padding (line 19) | def get_same_padding(x: int, k: int, s: int, d: int):
function is_static_pad (line 24) | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, ...
function pad_same (line 29) | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value...
function get_padding_value (line 43) | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bo...
FILE: RVT/models/layers/maxvit/layers/patch_embed.py
class PatchEmbed (line 16) | class PatchEmbed(nn.Module):
method __init__ (line 19) | def __init__(
method forward (line 42) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/pool2d_same.py
function avg_pool2d_same (line 15) | def avg_pool2d_same(
class AvgPool2dSame (line 28) | class AvgPool2dSame(nn.AvgPool2d):
method __init__ (line 31) | def __init__(
method forward (line 45) | def forward(self, x):
function max_pool2d_same (line 57) | def max_pool2d_same(
class MaxPool2dSame (line 69) | class MaxPool2dSame(nn.MaxPool2d):
method __init__ (line 72) | def __init__(
method forward (line 82) | def forward(self, x):
function create_pool2d (line 89) | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
FILE: RVT/models/layers/maxvit/layers/pos_embed.py
function pixel_freq_bands (line 8) | def pixel_freq_bands(
function inv_freq_bands (line 24) | def inv_freq_bands(
function build_sincos2d_pos_embed (line 38) | def build_sincos2d_pos_embed(
function build_fourier_pos_embed (line 90) | def build_fourier_pos_embed(
class FourierEmbed (line 147) | class FourierEmbed(nn.Module):
method __init__ (line 148) | def __init__(
method forward (line 164) | def forward(self, x):
function rot (line 191) | def rot(x):
function apply_rot_embed (line 195) | def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
function apply_rot_embed_list (line 199) | def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
function apply_rot_embed_split (line 205) | def apply_rot_embed_split(x: torch.Tensor, emb):
function build_rotary_pos_embed (line 210) | def build_rotary_pos_embed(
class RotaryEmbedding (line 240) | class RotaryEmbedding(nn.Module):
method __init__ (line 251) | def __init__(self, dim, max_res=224, linear_bands: bool = False):
method get_embed (line 260) | def get_embed(self, shape: List[int]):
method forward (line 263) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/selective_kernel.py
function _kernel_valid (line 16) | def _kernel_valid(k):
class SelectiveKernelAttn (line 23) | class SelectiveKernelAttn(nn.Module):
method __init__ (line 24) | def __init__(
method forward (line 46) | def forward(self, x):
class SelectiveKernel (line 59) | class SelectiveKernel(nn.Module):
method __init__ (line 60) | def __init__(
method forward (line 148) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/separable_conv.py
class SeparableConvNormAct (line 15) | class SeparableConvNormAct(nn.Module):
method __init__ (line 18) | def __init__(
method in_channels (line 59) | def in_channels(self):
method out_channels (line 63) | def out_channels(self):
method forward (line 66) | def forward(self, x):
class SeparableConv2d (line 76) | class SeparableConv2d(nn.Module):
method __init__ (line 79) | def __init__(
method in_channels (line 112) | def in_channels(self):
method out_channels (line 116) | def out_channels(self):
method forward (line 119) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/space_to_depth.py
class SpaceToDepth (line 5) | class SpaceToDepth(nn.Module):
method __init__ (line 6) | def __init__(self, block_size=4):
method forward (line 11) | def forward(self, x):
class SpaceToDepthJit (line 24) | class SpaceToDepthJit(object):
method __call__ (line 25) | def __call__(self, x: torch.Tensor):
class SpaceToDepthModule (line 34) | class SpaceToDepthModule(nn.Module):
method __init__ (line 35) | def __init__(self, no_jit=False):
method forward (line 42) | def forward(self, x):
class DepthToSpace (line 46) | class DepthToSpace(nn.Module):
method __init__ (line 47) | def __init__(self, block_size):
method forward (line 51) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/split_attn.py
class RadixSoftmax (line 17) | class RadixSoftmax(nn.Module):
method __init__ (line 18) | def __init__(self, radix, cardinality):
method forward (line 23) | def forward(self, x):
class SplitAttn (line 34) | class SplitAttn(nn.Module):
method __init__ (line 37) | def __init__(
method forward (line 88) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/split_batchnorm.py
class SplitBatchNorm2d (line 19) | class SplitBatchNorm2d(torch.nn.BatchNorm2d):
method __init__ (line 20) | def __init__(
method forward (line 41) | def forward(self, input: torch.Tensor):
function convert_splitbn_model (line 56) | def convert_splitbn_model(module, num_splits=2):
FILE: RVT/models/layers/maxvit/layers/squeeze_excite.py
class SEModule (line 20) | class SEModule(nn.Module):
method __init__ (line 30) | def __init__(
method forward (line 54) | def forward(self, x):
class EffectiveSEModule (line 68) | class EffectiveSEModule(nn.Module):
method __init__ (line 73) | def __init__(self, channels, add_maxpool=False, gate_layer="hard_sigmo...
method forward (line 79) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/std_conv.py
class StdConv2d (line 27) | class StdConv2d(nn.Conv2d):
method __init__ (line 34) | def __init__(
method forward (line 60) | def forward(self, x):
class StdConv2dSame (line 75) | class StdConv2dSame(nn.Conv2d):
method __init__ (line 82) | def __init__(
method forward (line 110) | def forward(self, x):
class ScaledStdConv2d (line 127) | class ScaledStdConv2d(nn.Conv2d):
method __init__ (line 136) | def __init__(
method forward (line 166) | def forward(self, x):
class ScaledStdConv2dSame (line 181) | class ScaledStdConv2dSame(nn.Conv2d):
method __init__ (line 190) | def __init__(
method forward (line 222) | def forward(self, x):
FILE: RVT/models/layers/maxvit/layers/test_time_pool.py
class TestTimePoolHead (line 16) | class TestTimePoolHead(nn.Module):
method __init__ (line 17) | def __init__(self, base, original_pool=7):
method forward (line 32) | def forward(self, x):
function apply_test_time_pool (line 40) | def apply_test_time_pool(model, config, use_test_size=False):
FILE: RVT/models/layers/maxvit/layers/trace_utils.py
function _assert (line 5) | def _assert(condition: bool, message: str):
function _float_to_int (line 9) | def _float_to_int(x: float) -> int:
FILE: RVT/models/layers/maxvit/layers/weight_init.py
function _trunc_normal_ (line 8) | def _trunc_normal_(tensor, mean, std, a, b):
function trunc_normal_ (line 45) | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
function trunc_normal_tf_ (line 72) | def trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
function variance_scaling_ (line 101) | def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="no...
function lecun_normal_ (line 126) | def lecun_normal_(tensor):
FILE: RVT/models/layers/maxvit/maxvit.py
class PartitionType (line 26) | class PartitionType(Enum):
function nChw_2_nhwC (line 31) | def nChw_2_nhwC(x: torch.Tensor):
function nhwC_2_nChw (line 37) | def nhwC_2_nChw(x: torch.Tensor):
class LayerScale (line 43) | class LayerScale(nn.Module):
method __init__ (line 44) | def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool ...
method forward (line 49) | def forward(self, x):
class GLU (line 54) | class GLU(nn.Module):
method __init__ (line 55) | def __init__(
method forward (line 83) | def forward(self, x: torch.Tensor):
class MLP (line 88) | class MLP(nn.Module):
method __init__ (line 89) | def __init__(
method forward (line 146) | def forward(self, x):
class DownsampleBase (line 150) | class DownsampleBase(nn.Module):
method __init__ (line 151) | def __init__(self):
method output_is_normed (line 155) | def output_is_normed():
function get_downsample_layer_Cf2Cl (line 159) | def get_downsample_layer_Cf2Cl(
class ConvDownsampling_Cf2Cl (line 173) | class ConvDownsampling_Cf2Cl(DownsampleBase):
method __init__ (line 178) | def __init__(
method forward (line 209) | def forward(self, x: torch.Tensor):
method output_is_normed (line 216) | def output_is_normed():
class PartitionAttentionCl (line 220) | class PartitionAttentionCl(nn.Module):
method __init__ (line 228) | def __init__(
method _partition_attn (line 301) | def _partition_attn(self, x):
method forward (line 320) | def forward(self, x):
function window_partition (line 326) | def window_partition(x, window_size: Tuple[int, int]):
function window_reverse (line 347) | def window_reverse(windows, window_size: Tuple[int, int], img_size: Tupl...
function grid_partition (line 357) | def grid_partition(x, grid_size: Tuple[int, int]):
function grid_reverse (line 372) | def grid_reverse(windows, grid_size: Tuple[int, int], img_size: Tuple[in...
class TorchMHSAWrapperCl (line 382) | class TorchMHSAWrapperCl(nn.Module):
method __init__ (line 385) | def __init__(self, dim: int, dim_head: int = 32, bias: bool = True):
method forward (line 393) | def forward(self, x: torch.Tensor):
class SelfAttentionCl (line 402) | class SelfAttentionCl(nn.Module):
method __init__ (line 405) | def __init__(self, dim: int, dim_head: int = 32, bias: bool = True):
method forward (line 414) | def forward(self, x: torch.Tensor):
function assert_activation_string (line 433) | def assert_activation_string(
function assert_norm2d_layer_string (line 467) | def assert_norm2d_layer_string(
FILE: RVT/models/layers/rnn.py
class DWSConvLSTM2d (line 7) | class DWSConvLSTM2d(nn.Module):
method __init__ (line 10) | def __init__(
method forward (line 43) | def forward(
FILE: RVT/models/layers/s5/jax_func.py
function safe_map (line 31) | def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> List[T]: ...
function safe_map (line 35) | def safe_map(
function safe_map (line 41) | def safe_map(
function safe_map (line 50) | def safe_map(
function safe_map (line 60) | def safe_map(f, *args):
function combine (line 68) | def combine(tree, operator, a_flat, b_flat):
function _scan (line 77) | def _scan(tree, operator, elems, axis: int):
function associative_scan (line 129) | def associative_scan(operator: Callable, elems, axis: int = 0, reverse: ...
function test_associative_scan (line 155) | def test_associative_scan(shape=(1, 24, 24)):
function _interleave (line 185) | def _interleave(a, b, axis: int):
function test_interleave (line 205) | def test_interleave():
function _compute_fans (line 238) | def _compute_fans(shape, fan_in_axes=None):
function uniform (line 260) | def uniform(shape, dtype=torch.float, minval=0.0, maxval=1.0, device=None):
function _complex_uniform (line 268) | def _complex_uniform(shape: Sequence[int], dtype, device=None) -> torch....
function complex_as_float_dtype (line 278) | def complex_as_float_dtype(dtype):
function _complex_truncated_normal (line 290) | def _complex_truncated_normal(
function _truncated_normal (line 309) | def _truncated_normal(lower, upper, shape, dtype=torch.float):
function variance_scaling (line 330) | def variance_scaling(
function lecun_normal (line 384) | def lecun_normal(fan_in_axes=None, dtype=torch.float):
function test_variance_scaling (line 418) | def test_variance_scaling():
FILE: RVT/models/layers/s5/s5_init.py
function make_HiPPO (line 9) | def make_HiPPO(N):
function make_NPLR_HiPPO (line 23) | def make_NPLR_HiPPO(N):
function make_DPLR_HiPPO (line 43) | def make_DPLR_HiPPO(N):
function make_Normal_S (line 70) | def make_Normal_S(N):
function make_Normal_HiPPO (line 79) | def make_Normal_HiPPO(N, B=1):
function log_step_initializer (line 108) | def log_step_initializer(dt_min=0.001, dt_max=0.1):
function init_log_steps (line 132) | def init_log_steps(H, dt_min, dt_max):
function init_VinvB (line 149) | def init_VinvB(init_fun, Vinv):
function trunc_standard_normal (line 171) | def trunc_standard_normal(shape):
function init_CV (line 187) | def init_CV(init_fun, shape, V) -> torch.Tensor:
function init_columnwise_B (line 204) | def init_columnwise_B(shape, dtype):
function init_columnwise_VinvB (line 228) | def init_columnwise_VinvB(init_fun, Vinv):
function init_rowwise_C (line 244) | def init_rowwise_C(shape, dtype):
FILE: RVT/models/layers/s5/s5_model.py
function binary_operator (line 19) | def binary_operator(
function apply_ssm (line 35) | def apply_ssm(
function apply_ssm_liquid (line 77) | def apply_ssm_liquid(
function discretize_bilinear (line 109) | def discretize_bilinear(Lambda, B_tilde, Delta):
function discretize_zoh (line 132) | def discretize_zoh(Lambda, B_tilde, Delta):
function as_complex (line 148) | def as_complex(t: torch.Tensor, dtype=torch.complex64):
class S5SSM (line 159) | class S5SSM(torch.nn.Module):
method __init__ (line 160) | def __init__(
method initial_state (line 271) | def initial_state(self, batch_size: Optional[int]):
method get_BC_tilde (line 277) | def get_BC_tilde(self):
method forward_rnn (line 287) | def forward_rnn(self, signal, prev_state, step_scale: float | torch.Te...
method forward (line 311) | def forward(self, signal, prev_state, step_scale: float | torch.Tensor...
class S5 (line 334) | class S5(torch.nn.Module):
method __init__ (line 335) | def __init__(
method initial_state (line 390) | def initial_state(self, batch_size: Optional[int] = None):
method forward (line 393) | def forward(self, signal, prev_state, step_scale: float | torch.Tensor...
class GEGLU (line 404) | class GEGLU(torch.nn.Module):
method forward (line 405) | def forward(self, x):
class S5Block (line 410) | class S5Block(torch.nn.Module):
method __init__ (line 411) | def __init__(
method forward (line 447) | def forward(self, x, states):
function tensor_stats (line 472) | def tensor_stats(t: torch.Tensor): # Clone of lovely_tensors for comple...
FILE: RVT/models/layers/s5/triton_comparison.py
function to_triton (line 17) | def to_triton(x: np.ndarray, device="cuda", dst_type=None):
function to_numpy (line 35) | def to_numpy(x):
function sum_op (line 79) | def sum_op(a, b):
function kernel (line 83) | def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl....
function f (line 121) | def f(carry, x):
function _fake_scan (line 124) | def _fake_scan(f, init, x):
function sum_op2 (line 155) | def sum_op2(a, b):
FILE: RVT/modules/data/genx.py
function get_dataloader_kwargs (line 20) | def get_dataloader_kwargs(
class DataModule (line 82) | class DataModule(pl.LightningDataModule):
method __init__ (line 83) | def __init__(
method get_dataloading_hw (line 130) | def get_dataloading_hw(self):
method set_mixed_sampling_mode_variables_for_train (line 133) | def set_mixed_sampling_mode_variables_for_train(self):
method setup (line 174) | def setup(self, stage: Optional[str] = None) -> None:
method train_dataloader (line 229) | def train_dataloader(self):
method val_dataloader (line 250) | def val_dataloader(self):
method test_dataloader (line 262) | def test_dataloader(self):
FILE: RVT/modules/detection.py
class Module (line 30) | class Module(pl.LightningModule):
method __init__ (line 31) | def __init__(self, full_config: DictConfig):
method setup (line 48) | def setup(self, stage: Optional[str] = None) -> None:
method forward (line 105) | def forward(
method get_worker_id_from_batch (line 119) | def get_worker_id_from_batch(self, batch: Any) -> int:
method get_data_from_batch (line 122) | def get_data_from_batch(self, batch: Any):
method training_step (line 125) | def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
method _val_test_step_impl (line 272) | def _val_test_step_impl(self, batch: Any, mode: Mode) -> Optional[STEP...
method validation_step (line 379) | def validation_step(self, batch: Any, batch_idx: int) -> Optional[STEP...
method test_step (line 382) | def test_step(self, batch: Any, batch_idx: int) -> Optional[STEP_OUTPUT]:
method run_psee_evaluator (line 385) | def run_psee_evaluator(self, mode: Mode):
method on_train_epoch_end (line 446) | def on_train_epoch_end(self) -> None:
method on_validation_epoch_end (line 457) | def on_validation_epoch_end(self) -> None:
method on_test_epoch_end (line 463) | def on_test_epoch_end(self) -> None:
method configure_optimizers (line 468) | def configure_optimizers(self) -> Any:
FILE: RVT/modules/utils/detection.py
class Mode (line 11) | class Mode(Enum):
class BackboneFeatureSelector (line 24) | class BackboneFeatureSelector:
method __init__ (line 25) | def __init__(self):
method reset (line 29) | def reset(self):
method add_backbone_features (line 32) | def add_backbone_features(
method get_batched_backbone_features (line 49) | def get_batched_backbone_features(self) -> Optional[BackboneFeatures]:
class EventReprSelector (line 55) | class EventReprSelector:
method __init__ (line 56) | def __init__(self):
method reset (line 60) | def reset(self):
method __len__ (line 63) | def __len__(self):
method add_event_representations (line 66) | def add_event_representations(
method get_event_representations_as_list (line 77) | def get_event_representations_as_list(
class RNNStates (line 88) | class RNNStates:
method __init__ (line 89) | def __init__(self):
method _has_states (line 92) | def _has_states(self):
method recursive_detach (line 96) | def recursive_detach(cls, inp: Union[th.Tensor, List, Tuple, Dict]):
method recursive_reset (line 108) | def recursive_reset(
method save_states_and_detach (line 140) | def save_states_and_detach(self, worker_id: int, states: LstmStates) -...
method get_states (line 143) | def get_states(self, worker_id: int) -> Optional[LstmStates]:
method reset (line 150) | def reset(
function mixed_collate_fn (line 163) | def mixed_collate_fn(
function merge_mixed_batches (line 179) | def merge_mixed_batches(batch: Dict[str, Any]):
FILE: RVT/modules/utils/fetch.py
function fetch_model_module (line 8) | def fetch_model_module(config: DictConfig) -> pl.LightningModule:
function fetch_data_module (line 15) | def fetch_data_module(config: DictConfig) -> pl.LightningDataModule:
FILE: RVT/scripts/genx/preprocess_dataset.py
class DataKeys (line 39) | class DataKeys(Enum):
class SplitType (line 47) | class SplitType(Enum):
class NoLabelsException (line 74) | class NoLabelsException(Exception):
class H5Writer (line 79) | class H5Writer:
method __init__ (line 80) | def __init__(
method __enter__ (line 105) | def __enter__(self):
method __exit__ (line 108) | def __exit__(self, exc_type, exc_val, exc_tb):
method close_callback (line 112) | def close_callback(h5f: h5py.File):
method close (line 115) | def close(self):
method get_current_length (line 118) | def get_current_length(self):
method add_data (line 121) | def add_data(self, data: np.ndarray):
class H5Reader (line 133) | class H5Reader:
method __init__ (line 134) | def __init__(self, h5_file: Path, dataset: str = "gen4"):
method __enter__ (line 152) | def __enter__(self):
method __exit__ (line 155) | def __exit__(self, exc_type, exc_val, exc_tb):
method _close_callback (line 159) | def _close_callback(h5f: h5py.File):
method close (line 162) | def close(self):
method get_height_and_width (line 166) | def get_height_and_width(self) -> Tuple[int, int]:
method time (line 170) | def time(self) -> np.ndarray:
method _correct_time (line 181) | def _correct_time(time_array: np.ndarray):
method get_event_slice (line 190) | def get_event_slice(
function prophesee_bbox_filter (line 213) | def prophesee_bbox_filter(labels: np.ndarray, dataset_type: str) -> np.n...
function conservative_bbox_filter (line 231) | def conservative_bbox_filter(labels: np.ndarray) -> np.ndarray:
function remove_faulty_huge_bbox_filter (line 240) | def remove_faulty_huge_bbox_filter(labels: np.ndarray, dataset_type: str...
function crop_to_fov_filter (line 250) | def crop_to_fov_filter(labels: np.ndarray, dataset_type: str) -> np.ndar...
function prophesee_remove_labels_filter_gen4 (line 281) | def prophesee_remove_labels_filter_gen4(labels: np.ndarray) -> np.ndarray:
function apply_filters (line 292) | def apply_filters(
function get_base_delta_ts_for_labels_us (line 313) | def get_base_delta_ts_for_labels_us(
function save_labels (line 330) | def save_labels(
function labels_and_ev_repr_timestamps (line 370) | def labels_and_ev_repr_timestamps(
function write_event_data (line 497) | def write_event_data(
function downsample_ev_repr (line 533) | def downsample_ev_repr(x: torch.Tensor, scale_factor: float):
function write_event_representations (line 548) | def write_event_representations(
function process_sequence (line 619) | def process_sequence(
class AggregationType (line 683) | class AggregationType(Enum):
class FilterConf (line 695) | class FilterConf:
class EventWindowExtractionConf (line 701) | class EventWindowExtractionConf:
class StackedHistogramConf (line 707) | class StackedHistogramConf:
class MixedDensityEventStackConf (line 718) | class MixedDensityEventStackConf:
class EventRepresentationFactory (line 733) | class EventRepresentationFactory(ABC):
method __init__ (line 734) | def __init__(self, config: DictConfig):
method name (line 739) | def name(self) -> str: ...
method create (line 742) | def create(self, height: int, width: int) -> Any: ...
class StackedHistogramFactory (line 745) | class StackedHistogramFactory(EventRepresentationFactory):
method name (line 747) | def name(self) -> str:
method create (line 751) | def create(self, height: int, width: int) -> StackedHistogram:
class MixedDensityStackFactory (line 761) | class MixedDensityStackFactory(EventRepresentationFactory):
method name (line 763) | def name(self) -> str:
method create (line 772) | def create(self, height: int, width: int) -> MixedDensityEventStack:
function get_configuration (line 787) | def get_configuration(
FILE: RVT/scripts/viz/viz_gt.py
function draw_bboxes_bbv (line 34) | def draw_bboxes_bbv(
function draw_predictions (line 85) | def draw_predictions(
function gen_gt_generator (line 99) | def gen_gt_generator(
FILE: RVT/train.py
function main (line 35) | def main(config: DictConfig):
FILE: RVT/utils/evaluation/prophesee/evaluation.py
function evaluate_list (line 5) | def evaluate_list(
FILE: RVT/utils/evaluation/prophesee/evaluator.py
class PropheseeEvaluator (line 9) | class PropheseeEvaluator:
method __init__ (line 13) | def __init__(self, dataset: str, downsample_by_2: bool):
method _reset_buffer (line 23) | def _reset_buffer(self):
method _add_to_buffer (line 30) | def _add_to_buffer(self, key: str, value: List[np.ndarray]):
method _get_from_buffer (line 38) | def _get_from_buffer(self, key: str) -> List[np.ndarray]:
method add_predictions (line 43) | def add_predictions(self, predictions: List[np.ndarray]):
method add_labels (line 46) | def add_labels(self, labels: List[np.ndarray]):
method reset_buffer (line 49) | def reset_buffer(self) -> None:
method has_data (line 53) | def has_data(self):
method evaluate_buffer (line 56) | def evaluate_buffer(
FILE: RVT/utils/evaluation/prophesee/io/box_filtering.py
function filter_boxes (line 19) | def filter_boxes(boxes, skip_ts=int(5e5), min_box_diag=60, min_box_side=...
FILE: RVT/utils/evaluation/prophesee/io/box_loading.py
function reformat_boxes (line 33) | def reformat_boxes(boxes):
function loaded_label_to_prophesee (line 53) | def loaded_label_to_prophesee(loaded_labels: ObjectLabels) -> np.ndarray:
function to_prophesee (line 66) | def to_prophesee(
FILE: RVT/utils/evaluation/prophesee/io/dat_events_tools.py
function load_td_data (line 24) | def load_td_data(filename, ev_count=-1, ev_start=0):
function _dat_transfer (line 53) | def _dat_transfer(dat, dtype, xyp=None):
function stream_td_data (line 88) | def stream_td_data(file_handle, buffer, dtype, ev_count=-1):
function count_events (line 113) | def count_events(filename):
function parse_header (line 128) | def parse_header(f):
function write_header (line 190) | def write_header(filename, height=240, width=320, ev_type=0):
function write_event_buffer (line 220) | def write_event_buffer(f, buffers):
FILE: RVT/utils/evaluation/prophesee/io/npy_events_tools.py
function stream_td_data (line 16) | def stream_td_data(file_handle, buffer, dtype, ev_count=-1):
function parse_header (line 31) | def parse_header(fhandle):
FILE: RVT/utils/evaluation/prophesee/io/psee_loader.py
class PSEELoader (line 17) | class PSEELoader(object):
method __init__ (line 22) | def __init__(self, datfile):
method reset (line 66) | def reset(self):
method event_count (line 72) | def event_count(self):
method get_size (line 79) | def get_size(self):
method __repr__ (line 83) | def __repr__(self):
method load_n_events (line 101) | def load_n_events(self, ev_count):
method load_delta_t (line 128) | def load_delta_t(self, delta_t):
method seek_event (line 176) | def seek_event(self, ev_count):
method seek_time (line 204) | def seek_time(self, final_time, term_criterion=100000):
method total_time (line 248) | def total_time(self):
method __del__ (line 269) | def __del__(self):
FILE: RVT/utils/evaluation/prophesee/metrics/coco_eval.py
function evaluate_detection (line 26) | def evaluate_detection(
function _match_times (line 70) | def _match_times(all_ts, gt_boxes, dt_boxes, time_tol):
function _coco_eval (line 107) | def _coco_eval(
function coco_eval_return_metrics (line 164) | def coco_eval_return_metrics(coco_eval: COCOeval):
function _to_coco_format (line 168) | def _to_coco_format(gts, detections, categories, height=240, width=304):
FILE: RVT/utils/evaluation/prophesee/visualize/vis_utils.py
function make_binary_histo (line 25) | def make_binary_histo(events, img=None, width=304, height=240):
function draw_bboxes_bbv (line 54) | def draw_bboxes_bbv(img, boxes, labelmap=LABELMAP_GEN1) -> np.ndarray:
function draw_bboxes (line 103) | def draw_bboxes(img, boxes, labelmap=LABELMAP_GEN1) -> None:
FILE: RVT/utils/helpers.py
function torch_uniform_sample_scalar (line 6) | def torch_uniform_sample_scalar(min_value: float, max_value: float):
function clamp (line 13) | def clamp(
FILE: RVT/utils/padding.py
class InputPadderFromShape (line 7) | class InputPadderFromShape:
method __init__ (line 8) | def __init__(
method _pad_tensor_impl (line 35) | def _pad_tensor_impl(
method pad_tensor_ev_repr (line 61) | def pad_tensor_ev_repr(self, ev_repr: th.Tensor) -> th.Tensor:
method pad_token_mask (line 74) | def pad_token_mask(self, token_mask: th.Tensor):
FILE: RVT/utils/preprocessing.py
function _blosc_opts (line 1) | def _blosc_opts(complevel=1, complib="blosc:zstd", shuffle="byte"):
FILE: RVT/utils/timers.py
class CudaTimer (line 12) | class CudaTimer:
method __init__ (line 13) | def __init__(self, device: torch.device, timer_name: str):
method __enter__ (line 24) | def __enter__(self):
method __exit__ (line 29) | def __exit__(self, *args):
function cuda_timer_decorator (line 36) | def cuda_timer_decorator(device: torch.device, timer_name: str):
class TimerDummy (line 49) | class TimerDummy:
method __init__ (line 50) | def __init__(self, *args, **kwargs):
method __enter__ (line 53) | def __enter__(self):
method __exit__ (line 56) | def __exit__(self, *args):
class Timer (line 60) | class Timer:
method __init__ (line 61) | def __init__(self, timer_name=""):
method __enter__ (line 66) | def __enter__(self):
method __exit__ (line 70) | def __exit__(self, *args):
function print_timing_info (line 76) | def print_timing_info():
FILE: RVT/validation.py
function main (line 31) | def main(config: DictConfig):
Condensed preview — 162 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (673K chars).
[
{
"path": ".gitignore",
"chars": 3213,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "README.md",
"chars": 9379,
"preview": "# [CVPR'24 Spotlight] State Space Models for Event Cameras\n<p align=\"center\">\n <a href=\"https://www.youtube.com/watch?v="
},
{
"path": "RVT/.gitignore",
"chars": 1799,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "RVT/LICENSE",
"chars": 1071,
"preview": "MIT License\n\nCopyright (c) 2023 Mathias Gehrig\n\nPermission is hereby granted, free of charge, to any person obtaining a "
},
{
"path": "RVT/README.md",
"chars": 7601,
"preview": "# RVT: Recurrent Vision Transformers for Object Detection with Event Cameras\n<p align=\"center\">\n <img src=\"https://rpg."
},
{
"path": "RVT/callbacks/custom.py",
"chars": 1346,
"preview": "from omegaconf import DictConfig\nfrom lightning.pytorch.callbacks import Callback\nfrom lightning.pytorch.callbacks impor"
},
{
"path": "RVT/callbacks/detection.py",
"chars": 3956,
"preview": "from enum import Enum, auto\nfrom typing import Any\n\nimport torch\nfrom einops import rearrange\nfrom omegaconf import Dict"
},
{
"path": "RVT/callbacks/gradflow.py",
"chars": 1147,
"preview": "from typing import Any\n\nimport lightning.pytorch as pl\nfrom lightning.pytorch.callbacks import Callback\nfrom lightning.p"
},
{
"path": "RVT/callbacks/utils/visualization.py",
"chars": 786,
"preview": "import pandas as pd\nimport plotly.express as px\n\n\ndef get_grad_flow_figure(named_params):\n \"\"\"Creates figure to visua"
},
{
"path": "RVT/callbacks/viz_base.py",
"chars": 6218,
"preview": "import random\nfrom enum import Enum\nfrom typing import Any, List, Optional, Type, Union\n\nimport numpy as np\nimport pytor"
},
{
"path": "RVT/config/dataset/base.yaml",
"chars": 699,
"preview": "name: ???\npath: ???\ntrain:\n sampling: 'mixed' # ('random', 'stream', 'mixed')\n random:\n weighted_sampling: False\n "
},
{
"path": "RVT/config/dataset/gen1.yaml",
"chars": 183,
"preview": "defaults:\n - base\n\nname: gen1\nev_repr_name: 'stacked_histogram_dt=50_nbins=10'\nsequence_length: 21\nresolution_hw: [240,"
},
{
"path": "RVT/config/dataset/gen4.yaml",
"chars": 183,
"preview": "defaults:\n - base\n\nname: gen4\nev_repr_name: 'stacked_histogram_dt=50_nbins=10'\nsequence_length: 10\nresolution_hw: [720,"
},
{
"path": "RVT/config/experiment/gen1/base.yaml",
"chars": 103,
"preview": "# @package _global_\ndefaults:\n - default\n\nmodel:\n backbone:\n embed_dim: 64\n fpn:\n depth: 0.67\n"
},
{
"path": "RVT/config/experiment/gen1/default.yaml",
"chars": 779,
"preview": "# @package _global_\ndefaults:\n - /model/maxvit_yolox: default\n\ntraining:\n precision: 32\n max_epochs: 10000\n max_step"
},
{
"path": "RVT/config/experiment/gen1/small.yaml",
"chars": 151,
"preview": "# @package _global_\ndefaults:\n - default\n\nmodel:\n backbone:\n embed_dim: 48\n stage:\n attention:\n dim_"
},
{
"path": "RVT/config/experiment/gen4/base.yaml",
"chars": 103,
"preview": "# @package _global_\ndefaults:\n - default\n\nmodel:\n backbone:\n embed_dim: 64\n fpn:\n depth: 0.67\n"
},
{
"path": "RVT/config/experiment/gen4/default.yaml",
"chars": 808,
"preview": "# @package _global_\ndefaults:\n - /model/maxvit_yolox: default\n\ntraining:\n precision: 32\n max_epochs: 10000\n max_step"
},
{
"path": "RVT/config/experiment/gen4/small.yaml",
"chars": 151,
"preview": "# @package _global_\ndefaults:\n - default\n\nmodel:\n backbone:\n embed_dim: 48\n stage:\n attention:\n dim_"
},
{
"path": "RVT/config/general.yaml",
"chars": 2336,
"preview": "reproduce:\n seed_everything: null # Union[int, null]\n deterministic_flag: False # Must be true for fully deterministic"
},
{
"path": "RVT/config/model/base.yaml",
"chars": 9,
"preview": "name: ???"
},
{
"path": "RVT/config/model/maxvit_yolox/default.yaml",
"chars": 1496,
"preview": "# @package _global_\ndefaults:\n - override /model: rnndet\n\nmodel:\n backbone:\n name: MaxViTRNN\n compile:\n ena"
},
{
"path": "RVT/config/model/rnndet.yaml",
"chars": 152,
"preview": "defaults:\n - base\n\nname: rnndet\nbackbone:\n name: ???\nfpn:\n name: ???\nhead:\n name: ???\npostprocess:\n confidence_thre"
},
{
"path": "RVT/config/modifier.py",
"chars": 2585,
"preview": "import os\nfrom typing import Tuple\n\nimport math\nfrom omegaconf import DictConfig, open_dict\n\nfrom data.utils.spatial imp"
},
{
"path": "RVT/config/train.yaml",
"chars": 104,
"preview": "defaults:\n - general\n - dataset: ???\n - model: rnndet\n - optional model/dataset: ${model}_${dataset}"
},
{
"path": "RVT/config/val.yaml",
"chars": 239,
"preview": "defaults:\n - dataset: ???\n - model: rnndet\n - _self_\n\ncheckpoint: ???\nuse_test_set: False\nhardware:\n num_workers:\n "
},
{
"path": "RVT/data/genx_utils/collate.py",
"chars": 1657,
"preview": "from copy import deepcopy\nfrom typing import Any, Callable, Dict, Optional, Type, Tuple, Union\n\nimport torch\n\nfrom data."
},
{
"path": "RVT/data/genx_utils/collate_from_pytorch.py",
"chars": 7329,
"preview": "import collections\nimport contextlib\nimport re\n\nimport torch\n\ntorch_is_version_1 = int(torch.__version__.split(\".\")[0]) "
},
{
"path": "RVT/data/genx_utils/dataset_rnd.py",
"chars": 6081,
"preview": "from collections import namedtuple\nfrom collections.abc import Iterable\nfrom pathlib import Path\nfrom typing import List"
},
{
"path": "RVT/data/genx_utils/dataset_streaming.py",
"chars": 4531,
"preview": "from functools import partialmethod\nfrom pathlib import Path\nfrom typing import List, Union\n\nfrom omegaconf import DictC"
},
{
"path": "RVT/data/genx_utils/labels.py",
"chars": 17826,
"preview": "from __future__ import annotations\n\nfrom typing import List, Tuple, Union, Optional\n\nimport math\nimport numpy as np\nimpo"
},
{
"path": "RVT/data/genx_utils/sequence_base.py",
"chars": 4355,
"preview": "from pathlib import Path\nfrom typing import Any, List, Optional\n\nimport h5py\nimport numpy as np\nimport torch\nfrom torchd"
},
{
"path": "RVT/data/genx_utils/sequence_for_streaming.py",
"chars": 8535,
"preview": "from pathlib import Path\nfrom typing import List, Optional, Union, Tuple\n\nimport h5py\nimport numpy as np\nimport torch\nfr"
},
{
"path": "RVT/data/genx_utils/sequence_rnd.py",
"chars": 3267,
"preview": "from pathlib import Path\n\nfrom data.genx_utils.labels import SparselyBatchedObjectLabels\nfrom data.genx_utils.sequence_b"
},
{
"path": "RVT/data/utils/augmentor.py",
"chars": 23173,
"preview": "import collections.abc as abc\nfrom dataclasses import dataclass\nfrom typing import Any, Optional, Tuple, Union\nfrom warn"
},
{
"path": "RVT/data/utils/representations.py",
"chars": 7808,
"preview": "from abc import ABC, abstractmethod\nfrom typing import Optional, Tuple\n\nimport math\nimport numpy as np\nimport torch as t"
},
{
"path": "RVT/data/utils/spatial.py",
"chars": 636,
"preview": "from omegaconf import DictConfig\n\nfrom data.utils.types import DatasetType\n\n_type_2_hw = {\n DatasetType.GEN1: (240, 3"
},
{
"path": "RVT/data/utils/stream_concat_datapipe.py",
"chars": 4309,
"preview": "from typing import Any, Iterator, List, Optional, Type\n\nimport torch as th\nimport torch.distributed as dist\nfrom torch.u"
},
{
"path": "RVT/data/utils/stream_sharded_datapipe.py",
"chars": 4966,
"preview": "from typing import Any, List, Optional\n\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data import DataL"
},
{
"path": "RVT/data/utils/types.py",
"chars": 1118,
"preview": "from enum import auto, Enum\n\ntry:\n from enum import StrEnum\nexcept ImportError:\n from strenum import StrEnum\nfrom "
},
{
"path": "RVT/loggers/utils.py",
"chars": 1734,
"preview": "from pathlib import Path\nfrom typing import Union\n\nimport wandb\nfrom omegaconf import DictConfig, OmegaConf\n\nfrom logger"
},
{
"path": "RVT/loggers/wandb_logger.py",
"chars": 17065,
"preview": "\"\"\"\nThis is a modified version of the Pytorch Lightning logger\n\"\"\"\n\nimport time\nfrom argparse import Namespace\nfrom path"
},
{
"path": "RVT/models/detection/__init_.py",
"chars": 0,
"preview": ""
},
{
"path": "RVT/models/detection/recurrent_backbone/__init__.py",
"chars": 297,
"preview": "from omegaconf import DictConfig\n\nfrom .maxvit_rnn import RNNDetector as MaxViTRNNDetector\n\n\ndef build_recurrent_backbon"
},
{
"path": "RVT/models/detection/recurrent_backbone/base.py",
"chars": 295,
"preview": "from typing import Tuple\n\nimport torch.nn as nn\n\n\nclass BaseDetector(nn.Module):\n def get_stage_dims(self, stages: Tu"
},
{
"path": "RVT/models/detection/recurrent_backbone/maxvit_rnn.py",
"chars": 8686,
"preview": "from typing import Dict, Optional, Tuple\nimport torch as th\nimport torch.nn as nn\nfrom omegaconf import DictConfig, Omeg"
},
{
"path": "RVT/models/detection/yolox/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "RVT/models/detection/yolox/models/losses.py",
"chars": 1727,
"preview": "#!/usr/bin/env python\n# -*- encoding: utf-8 -*-\n# Copyright (c) Megvii Inc. All rights reserved.\n\nimport torch\nimport to"
},
{
"path": "RVT/models/detection/yolox/models/network_blocks.py",
"chars": 3960,
"preview": "#!/usr/bin/env python\n# -*- encoding: utf-8 -*-\n# Copyright (c) Megvii Inc. All rights reserved.\n\nimport torch\nimport to"
},
{
"path": "RVT/models/detection/yolox/models/yolo_head.py",
"chars": 22378,
"preview": "\"\"\"\nOriginal Yolox Head code with slight modifications\n\"\"\"\n\nimport math\nfrom typing import Dict, Optional\n\nimport torch\n"
},
{
"path": "RVT/models/detection/yolox/utils/__init__.py",
"chars": 146,
"preview": "#!/usr/bin/env python3\n# -*- coding:utf-8 -*-\n# Copyright (c) Megvii Inc. All rights reserved.\n\nfrom .boxes import *\nfro"
},
{
"path": "RVT/models/detection/yolox/utils/boxes.py",
"chars": 4499,
"preview": "#!/usr/bin/env python3\n# -*- coding:utf-8 -*-\n# Copyright (c) Megvii Inc. All rights reserved.\n\nimport numpy as np\n\nimpo"
},
{
"path": "RVT/models/detection/yolox/utils/compat.py",
"chars": 310,
"preview": "#!/usr/bin/env python3\n# -*- coding:utf-8 -*-\n\nimport torch\n\n_TORCH_VER = [int(x) for x in torch.__version__.split(\".\")["
},
{
"path": "RVT/models/detection/yolox_extension/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "RVT/models/detection/yolox_extension/models/build.py",
"chars": 1162,
"preview": "from typing import Tuple\n\nfrom omegaconf import OmegaConf, DictConfig\n\nfrom .yolo_pafpn import YOLOPAFPN\nfrom ...yolox.m"
},
{
"path": "RVT/models/detection/yolox_extension/models/detector.py",
"chars": 2937,
"preview": "from typing import Dict, Optional, Tuple, Union\n\nimport torch as th\nfrom omegaconf import DictConfig\n\ntry:\n from torc"
},
{
"path": "RVT/models/detection/yolox_extension/models/yolo_pafpn.py",
"chars": 4319,
"preview": "\"\"\"\nOriginal Yolox PAFPN code with slight modifications\n\"\"\"\n\nfrom typing import Dict, Optional, Tuple\n\nimport torch as t"
},
{
"path": "RVT/models/layers/maxvit/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "RVT/models/layers/maxvit/layers/__init__.py",
"chars": 2837,
"preview": "from .activations import *\nfrom .adaptive_avgmax_pool import (\n adaptive_avgmax_pool2d,\n select_adaptive_pool2d,\n "
},
{
"path": "RVT/models/layers/maxvit/layers/activations.py",
"chars": 4044,
"preview": "\"\"\" Activations\n\nA collection of activations fn and modules with a common interface so that they can\neasily be swapped. "
},
{
"path": "RVT/models/layers/maxvit/layers/activations_jit.py",
"chars": 2534,
"preview": "\"\"\" Activations\n\nA collection of jit-scripted activations fn and modules with a common interface so that they can\neasily"
},
{
"path": "RVT/models/layers/maxvit/layers/activations_me.py",
"chars": 6025,
"preview": "\"\"\" Activations (memory-efficient w/ custom autograd)\n\nA collection of activations fn and modules with a common interfac"
},
{
"path": "RVT/models/layers/maxvit/layers/adaptive_avgmax_pool.py",
"chars": 3967,
"preview": "\"\"\" PyTorch selectable adaptive pooling\nAdaptive pooling with the ability to select the type of pooling from:\n * 'avg"
},
{
"path": "RVT/models/layers/maxvit/layers/attention_pool2d.py",
"chars": 4966,
"preview": "\"\"\" Attention Pool 2D\n\nImplementations of 2D spatial feature pooling using multi-head attention instead of average pool."
},
{
"path": "RVT/models/layers/maxvit/layers/blur_pool.py",
"chars": 1639,
"preview": "\"\"\"\nBlurPool layer inspired by\n - Kornia's Max_BlurPool2d\n - Making Convolutional Networks Shift-Invariant Again :cite:`"
},
{
"path": "RVT/models/layers/maxvit/layers/bottleneck_attn.py",
"chars": 7167,
"preview": "\"\"\" Bottleneck Self Attention (Bottleneck Transformers)\n\nPaper: `Bottleneck Transformers for Visual Recognition` - https"
},
{
"path": "RVT/models/layers/maxvit/layers/cbam.py",
"chars": 4822,
"preview": "\"\"\" CBAM (sort-of) Attention\n\nExperimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/180"
},
{
"path": "RVT/models/layers/maxvit/layers/classifier.py",
"chars": 2407,
"preview": "\"\"\" Classifier head and layer factory\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom torch import nn as nn"
},
{
"path": "RVT/models/layers/maxvit/layers/cond_conv2d.py",
"chars": 5796,
"preview": "\"\"\" PyTorch Conditionally Parameterized Convolution (CondConv)\n\nPaper: CondConv: Conditionally Parameterized Convolution"
},
{
"path": "RVT/models/layers/maxvit/layers/config.py",
"chars": 3077,
"preview": "\"\"\" Model / Layer Config singleton state\n\"\"\"\n\nfrom typing import Any, Optional\n\n__all__ = [\n \"is_exportable\",\n \"is"
},
{
"path": "RVT/models/layers/maxvit/layers/conv2d_same.py",
"chars": 1663,
"preview": "\"\"\" Conv2d w/ Same Padding\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nim"
},
{
"path": "RVT/models/layers/maxvit/layers/conv_bn_act.py",
"chars": 3533,
"preview": "\"\"\" Conv2d + BN + Act\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport functools\nfrom torch import nn as n"
},
{
"path": "RVT/models/layers/maxvit/layers/create_act.py",
"chars": 5246,
"preview": "\"\"\" Activation Factory\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom typing import Union, Callable, Type\n\n"
},
{
"path": "RVT/models/layers/maxvit/layers/create_attn.py",
"chars": 3515,
"preview": "\"\"\" Attention Factory\n\nHacked together by / Copyright 2021 Ross Wightman\n\"\"\"\n\nimport torch\nfrom functools import partial"
},
{
"path": "RVT/models/layers/maxvit/layers/create_conv2d.py",
"chars": 1705,
"preview": "\"\"\" Create Conv2d Factory Method\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom .mixed_conv2d import Mixed"
},
{
"path": "RVT/models/layers/maxvit/layers/create_norm.py",
"chars": 1820,
"preview": "\"\"\" Norm Layer Factory\n\nCreate norm modules by string (to mirror create_act and creat_norm-act fns)\n\nCopyright 2022 Ross"
},
{
"path": "RVT/models/layers/maxvit/layers/create_norm_act.py",
"chars": 3811,
"preview": "\"\"\" NormAct (Normalizaiton + Activation Layer) Factory\n\nCreate norm + act combo modules that attempt to be backwards com"
},
{
"path": "RVT/models/layers/maxvit/layers/drop.py",
"chars": 7290,
"preview": "\"\"\" DropBlock, DropPath\n\nPyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.\n\nPa"
},
{
"path": "RVT/models/layers/maxvit/layers/eca.py",
"chars": 6583,
"preview": "\"\"\"\nECA module from ECAnet\n\npaper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks\nhttps://a"
},
{
"path": "RVT/models/layers/maxvit/layers/evo_norm.py",
"chars": 14891,
"preview": "\"\"\" EvoNorm in PyTorch\n\nBased on `Evolving Normalization-Activation Layers` - https://arxiv.org/abs/2004.02967\n@inprocee"
},
{
"path": "RVT/models/layers/maxvit/layers/fast_norm.py",
"chars": 2430,
"preview": "\"\"\" 'Fast' Normalization Functions\n\nFor GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32.\n\nA"
},
{
"path": "RVT/models/layers/maxvit/layers/filter_response_norm.py",
"chars": 2703,
"preview": "\"\"\" Filter Response Norm in PyTorch\n\nBased on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737\n\n"
},
{
"path": "RVT/models/layers/maxvit/layers/gather_excite.py",
"chars": 4416,
"preview": "\"\"\" Gather-Excite Attention Block\n\nPaper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/18"
},
{
"path": "RVT/models/layers/maxvit/layers/global_context.py",
"chars": 2671,
"preview": "\"\"\" Global Context Attention Block\n\nPaper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`\n -"
},
{
"path": "RVT/models/layers/maxvit/layers/halo_attn.py",
"chars": 11078,
"preview": "\"\"\" Halo Self Attention\n\nPaper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`\n - https://ar"
},
{
"path": "RVT/models/layers/maxvit/layers/helpers.py",
"chars": 1049,
"preview": "\"\"\" Layer/Module Helpers\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom itertools import repeat\nimport col"
},
{
"path": "RVT/models/layers/maxvit/layers/inplace_abn.py",
"chars": 3619,
"preview": "import torch\nfrom torch import nn as nn\n\ntry:\n from inplace_abn.functions import inplace_abn, inplace_abn_sync\n\n h"
},
{
"path": "RVT/models/layers/maxvit/layers/lambda_layer.py",
"chars": 6383,
"preview": "\"\"\" Lambda Layer\n\nPaper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`\n - https://arxiv.org/ab"
},
{
"path": "RVT/models/layers/maxvit/layers/linear.py",
"chars": 745,
"preview": "\"\"\" Linear layer (alternate definition)\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn as nn\n\n\nc"
},
{
"path": "RVT/models/layers/maxvit/layers/median_pool.py",
"chars": 1738,
"preview": "\"\"\" Median Pool\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch.nn as nn\nimport torch.nn.functional "
},
{
"path": "RVT/models/layers/maxvit/layers/mixed_conv2d.py",
"chars": 2064,
"preview": "\"\"\" PyTorch Mixed Convolution\n\nPaper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)\n"
},
{
"path": "RVT/models/layers/maxvit/layers/ml_decoder.py",
"chars": 7367,
"preview": "from typing import Optional\n\nimport torch\nfrom torch import nn\nfrom torch import nn, Tensor\nfrom torch.nn.modules.transf"
},
{
"path": "RVT/models/layers/maxvit/layers/mlp.py",
"chars": 4616,
"preview": "\"\"\" MLP module w/ dropout and configurable activation layer\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nfrom"
},
{
"path": "RVT/models/layers/maxvit/layers/non_local_attn.py",
"chars": 6899,
"preview": "\"\"\" Bilinear-Attention-Transform and Non-Local Attention\n\nPaper: `Non-Local Neural Networks With Grouped Bilinear Attent"
},
{
"path": "RVT/models/layers/maxvit/layers/norm.py",
"chars": 4757,
"preview": "\"\"\" Normalization layers and wrappers\n\nNorm layer definitions that support fast norm and consistent channel arg order (a"
},
{
"path": "RVT/models/layers/maxvit/layers/norm_act.py",
"chars": 10885,
"preview": "\"\"\" Normalization + Activation Layers\n\nProvides Norm+Act fns for standard PyTorch norm layers such as\n* BatchNorm\n* Grou"
},
{
"path": "RVT/models/layers/maxvit/layers/padding.py",
"chars": 2229,
"preview": "\"\"\" Padding Helpers\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport math\nfrom typing import List, Tuple\n\n"
},
{
"path": "RVT/models/layers/maxvit/layers/patch_embed.py",
"chars": 1641,
"preview": "\"\"\" Image to Patch Embedding using Conv2d\n\nA convolution based approach to patchifying a 2D image w/ embedding projectio"
},
{
"path": "RVT/models/layers/maxvit/layers/pool2d_same.py",
"chars": 3272,
"preview": "\"\"\" AvgPool2d w/ Same Padding\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport torch\nimport torch.nn as nn"
},
{
"path": "RVT/models/layers/maxvit/layers/pos_embed.py",
"chars": 7670,
"preview": "import math\nfrom typing import List, Tuple, Optional, Union\n\nimport torch\nfrom torch import nn as nn\n\n\ndef pixel_freq_ba"
},
{
"path": "RVT/models/layers/maxvit/layers/selective_kernel.py",
"chars": 5762,
"preview": "\"\"\" Selective Kernel Convolution/Attention\n\nPaper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)\n\nHacked "
},
{
"path": "RVT/models/layers/maxvit/layers/separable_conv.py",
"chars": 3022,
"preview": "\"\"\" Depthwise Separable Conv Modules\n\nBasic DWS convs. Other variations of DWS exist with batch norm or activations betw"
},
{
"path": "RVT/models/layers/maxvit/layers/space_to_depth.py",
"chars": 1831,
"preview": "import torch\nimport torch.nn as nn\n\n\nclass SpaceToDepth(nn.Module):\n def __init__(self, block_size=4):\n super("
},
{
"path": "RVT/models/layers/maxvit/layers/split_attn.py",
"chars": 3334,
"preview": "\"\"\" Split Attention Conv2d (for ResNeSt Models)\n\nPaper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/200"
},
{
"path": "RVT/models/layers/maxvit/layers/split_batchnorm.py",
"chars": 3656,
"preview": "\"\"\" Split BatchNorm\n\nA PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through\na sepa"
},
{
"path": "RVT/models/layers/maxvit/layers/squeeze_excite.py",
"chars": 3124,
"preview": "\"\"\" Squeeze-and-Excitation Channel Attention\n\nAn SE implementation originally based on PyTorch SE-Net impl.\nHas since ev"
},
{
"path": "RVT/models/layers/maxvit/layers/std_conv.py",
"chars": 6866,
"preview": "\"\"\" Convolution with Weight Standardization (StdConv and ScaledStdConv)\n\nStdConv:\n@article{weightstandardization,\n auth"
},
{
"path": "RVT/models/layers/maxvit/layers/test_time_pool.py",
"chars": 2046,
"preview": "\"\"\" Test Time Pooling (Average-Max Pool)\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport logging\nfrom tor"
},
{
"path": "RVT/models/layers/maxvit/layers/trace_utils.py",
"chars": 336,
"preview": "try:\n from torch import _assert\nexcept ImportError:\n\n def _assert(condition: bool, message: str):\n assert c"
},
{
"path": "RVT/models/layers/maxvit/layers/weight_init.py",
"chars": 4781,
"preview": "import torch\nimport math\nimport warnings\n\nfrom torch.nn.init import _calculate_fan_in_and_fan_out\n\n\ndef _trunc_normal_(t"
},
{
"path": "RVT/models/layers/maxvit/maxvit.py",
"chars": 15426,
"preview": "\"\"\"\nPart of this code stems from rwightman's MaxVit implementation:\nhttps://github.com/huggingface/pytorch-image-models/"
},
{
"path": "RVT/models/layers/rnn.py",
"chars": 2558,
"preview": "from typing import Optional, Tuple\n\nimport torch as th\nimport torch.nn as nn\n\n\nclass DWSConvLSTM2d(nn.Module):\n \"\"\"LS"
},
{
"path": "RVT/models/layers/s5/__init__.py",
"chars": 24,
"preview": "from .s5_model import *\n"
},
{
"path": "RVT/models/layers/s5/jax_func.py",
"chars": 15446,
"preview": "import torch\nimport numpy as np\nfrom torch.utils._pytree import tree_flatten, tree_unflatten\nfrom typing import (\n ov"
},
{
"path": "RVT/models/layers/s5/s5_init.py",
"chars": 8582,
"preview": "import torch\nimport numpy as np\nfrom .jax_func import variance_scaling, lecun_normal, uniform\nimport scipy.linalg\n\n# Ini"
},
{
"path": "RVT/models/layers/s5/s5_model.py",
"chars": 18273,
"preview": "import torch\nimport torch.nn.functional as F\nfrom typing import Literal, Tuple, Optional\nimport os, sys\nimport math\n\nROO"
},
{
"path": "RVT/models/layers/s5/triton_comparison.py",
"chars": 5592,
"preview": "import torch\nimport numpy as np\nimport time\nimport triton\nimport triton.language as tl\nfrom triton.runtime.jit import Te"
},
{
"path": "RVT/modules/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "RVT/modules/data/genx.py",
"chars": 11114,
"preview": "from functools import partial\nfrom typing import Any, Dict, Optional, Union\n\nimport math\nimport lightning.pytorch as pl\n"
},
{
"path": "RVT/modules/detection.py",
"chars": 20788,
"preview": "from typing import Any, Optional, Tuple, Union, Dict\nfrom warnings import warn\n\nimport numpy as np\nimport lightning.pyto"
},
{
"path": "RVT/modules/utils/detection.py",
"chars": 6227,
"preview": "from enum import Enum, auto\nfrom typing import List, Optional, Union, Tuple, Dict, Any\n\nimport torch\nimport torch as th\n"
},
{
"path": "RVT/modules/utils/fetch.py",
"chars": 1148,
"preview": "import lightning.pytorch as pl\nfrom omegaconf import DictConfig\n\nfrom modules.data.genx import DataModule as genx_data_m"
},
{
"path": "RVT/scripts/genx/README.md",
"chars": 2377,
"preview": "# Pre-Processing the Original Dataset\n\n### 1. Download the data\n<table><tbody>\n<th valign=\"bottom\"></th>\n<th valign=\"bot"
},
{
"path": "RVT/scripts/genx/conf_preprocess/extraction/const_count.yaml",
"chars": 26,
"preview": "method: COUNT\nvalue: 50000"
},
{
"path": "RVT/scripts/genx/conf_preprocess/extraction/const_duration.yaml",
"chars": 55,
"preview": "method: DURATION\n# value is in milliseconds!\nvalue: 50\n"
},
{
"path": "RVT/scripts/genx/conf_preprocess/extraction/frequencies/const_duration_100hz.yaml",
"chars": 55,
"preview": "method: DURATION\n# value is in milliseconds!\nvalue: 10\n"
},
{
"path": "RVT/scripts/genx/conf_preprocess/extraction/frequencies/const_duration_200hz.yaml",
"chars": 54,
"preview": "method: DURATION\n# value is in milliseconds!\nvalue: 5\n"
},
{
"path": "RVT/scripts/genx/conf_preprocess/extraction/frequencies/const_duration_40hz.yaml",
"chars": 55,
"preview": "method: DURATION\n# value is in milliseconds!\nvalue: 25\n"
},
{
"path": "RVT/scripts/genx/conf_preprocess/extraction/frequencies/const_duration_80hz.yaml",
"chars": 55,
"preview": "method: DURATION\n# value is in milliseconds!\nvalue: 12\n"
},
{
"path": "RVT/scripts/genx/conf_preprocess/filter_gen1.yaml",
"chars": 59,
"preview": "apply_psee_bbox_filter: True\napply_faulty_bbox_filter: True"
},
{
"path": "RVT/scripts/genx/conf_preprocess/filter_gen4.yaml",
"chars": 60,
"preview": "apply_psee_bbox_filter: False\napply_faulty_bbox_filter: True"
},
{
"path": "RVT/scripts/genx/conf_preprocess/representation/mixeddensity_stack.yaml",
"chars": 54,
"preview": "name: \"mixeddensity_stack\"\nnbins: 10\ncount_cutoff: 32\n"
},
{
"path": "RVT/scripts/genx/conf_preprocess/representation/stacked_hist.yaml",
"chars": 53,
"preview": "name: \"stacked_histogram\"\nnbins: 10\ncount_cutoff: 10\n"
},
{
"path": "RVT/scripts/genx/preprocess_dataset.py",
"chars": 32439,
"preview": "import os\n\nos.environ[\"OMP_NUM_THREADS\"] = \"1\"\nos.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"\nos.environ[\"MKL_NUM_THREADS\"] = "
},
{
"path": "RVT/scripts/genx/preprocess_dataset.sh",
"chars": 446,
"preview": "NUM_PROCESSES=20 # set to the number of parallel processes to use\nDATA_DIR=/data/scratch1/nzubic/datasets/gen1_tar/\nDES"
},
{
"path": "RVT/scripts/viz/viz_gt.py",
"chars": 6244,
"preview": "import os\n\nos.environ[\"OMP_NUM_THREADS\"] = \"1\" # export OMP_NUM_THREADS=1\nos.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\" # e"
},
{
"path": "RVT/train.py",
"chars": 5816,
"preview": "import os\n\nos.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\nos.environ[\"OMP_NUM_THREADS\"] = \"1\"\nos.environ[\"OPENBLAS_NUM_T"
},
{
"path": "RVT/utils/evaluation/prophesee/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "RVT/utils/evaluation/prophesee/evaluation.py",
"chars": 1678,
"preview": "from .io.box_filtering import filter_boxes\nfrom .metrics.coco_eval import evaluate_detection\n\n\ndef evaluate_list(\n re"
},
{
"path": "RVT/utils/evaluation/prophesee/evaluator.py",
"chars": 2442,
"preview": "from typing import Any, List, Optional, Dict\nfrom warnings import warn\n\nimport numpy as np\n\nfrom utils.evaluation.prophe"
},
{
"path": "RVT/utils/evaluation/prophesee/io/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "RVT/utils/evaluation/prophesee/io/box_filtering.py",
"chars": 1265,
"preview": "\"\"\"\nDefine same filtering that we apply in:\n\"Learning to detect objects on a 1 Megapixel Event Camera\" by Etienne Perot "
},
{
"path": "RVT/utils/evaluation/prophesee/io/box_loading.py",
"chars": 4543,
"preview": "\"\"\"\nDefines some tools to handle events.\nIn particular :\n -> defines events' types\n -> defines functions to read e"
},
{
"path": "RVT/utils/evaluation/prophesee/io/dat_events_tools.py",
"chars": 7451,
"preview": "\"\"\"\nDefines some tools to handle events.\nIn particular :\n -> defines events' types\n -> defines functions to read e"
},
{
"path": "RVT/utils/evaluation/prophesee/io/npy_events_tools.py",
"chars": 2079,
"preview": "#!/usr/bin/env python\n\n\"\"\"\nDefines some tools to handle events, mimicking dat_events_tools.py.\nIn particular :\n -> de"
},
{
"path": "RVT/utils/evaluation/prophesee/io/psee_loader.py",
"chars": 9553,
"preview": "\"\"\"\nThis class loads events from dat or npy files\n\nCopyright: (c) 2019-2020 Prophesee\n\"\"\"\n\nfrom __future__ import print_"
},
{
"path": "RVT/utils/evaluation/prophesee/metrics/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "RVT/utils/evaluation/prophesee/metrics/coco_eval.py",
"chars": 6837,
"preview": "\"\"\"\nCompute the COCO metric on bounding box files by matching timestamps\n\nCopyright: (c) 2019-2020 Prophesee\n\"\"\"\n\nfrom _"
},
{
"path": "RVT/utils/evaluation/prophesee/visualize/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "RVT/utils/evaluation/prophesee/visualize/vis_utils.py",
"chars": 4280,
"preview": "\"\"\"\nFunctions to display events and boxes\nCopyright: (c) 2019-2020 Prophesee\n\"\"\"\n\nfrom __future__ import print_function\n"
},
{
"path": "RVT/utils/helpers.py",
"chars": 467,
"preview": "from typing import Union\n\nimport torch as th\n\n\ndef torch_uniform_sample_scalar(min_value: float, max_value: float):\n "
},
{
"path": "RVT/utils/padding.py",
"chars": 2613,
"preview": "from typing import Any, List, Tuple\n\nimport torch as th\nimport torch.nn.functional as F\n\n\nclass InputPadderFromShape:\n "
},
{
"path": "RVT/utils/preprocessing.py",
"chars": 527,
"preview": "def _blosc_opts(complevel=1, complib=\"blosc:zstd\", shuffle=\"byte\"):\n shuffle = 2 if shuffle == \"bit\" else 1 if shuffl"
},
{
"path": "RVT/utils/timers.py",
"chars": 2827,
"preview": "import atexit\nimport time\nfrom functools import wraps\n\nimport numpy as np\nimport torch\n\ncuda_timers = {}\ntimers = {}\n\n\nc"
},
{
"path": "RVT/validation.py",
"chars": 2712,
"preview": "import os\n\nos.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\nos.environ[\"OMP_NUM_THREADS\"] = \"1\"\nos.environ[\"OPENBLAS_NUM_T"
},
{
"path": "installation_details.txt",
"chars": 389,
"preview": "conda create -y -n events_signals python=3.11\nconda activate events_signals\nconda install -y pytorch torchvision torchau"
},
{
"path": "scripts/1mpx/onempx_base.bash",
"chars": 382,
"preview": "#!/usr/bin/env bash\n\nsource activate events_signals\n\npython RVT/train.py model=rnndet dataset=gen4 dataset.path=/shares/"
},
{
"path": "scripts/1mpx/onempx_base.job",
"chars": 389,
"preview": "#!/usr/bin/env bash\n#SBATCH --ntasks-per-node=2\n#SBATCH --cpus-per-task=16\n#SBATCH --mem-per-cpu=8G\n#SBATCH --time=86:00"
},
{
"path": "scripts/1mpx/onempx_small.bash",
"chars": 383,
"preview": "#!/usr/bin/env bash\n\nsource activate events_signals\n\npython RVT/train.py model=rnndet dataset=gen4 dataset.path=/shares/"
},
{
"path": "scripts/1mpx/onempx_small.job",
"chars": 393,
"preview": "#!/usr/bin/env bash\n#SBATCH --ntasks-per-node=2\n#SBATCH --cpus-per-task=16\n#SBATCH --mem-per-cpu=8G\n#SBATCH --time=78:00"
},
{
"path": "scripts/gen1/base.txt",
"chars": 301,
"preview": "python RVT/train.py model=rnndet dataset=gen1 dataset.path=/data/scratch1/nzubic/datasets/RVT/gen1 wandb.project_name=ss"
},
{
"path": "scripts/gen1/small.txt",
"chars": 302,
"preview": "python RVT/train.py model=rnndet dataset=gen1 dataset.path=/data/scratch1/nzubic/datasets/RVT/gen1 wandb.project_name=ss"
}
]
About this extraction
This page contains the full source code of the uzh-rpg/ssms_event_cameras GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 162 files (623.9 KB), approximately 166.1k tokens, and a symbol index with 1050 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.