Repository: Megvii-BaseDetection/BEVDepth
Branch: main
Commit: d78c7b58b10b
Files: 59
Total size: 278.6 KB

Directory structure:
gitextract_ucu3202x/

├── .github/
│   └── workflows/
│       └── lint.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE.md
├── README.md
├── bevdepth/
│   ├── callbacks/
│   │   └── ema.py
│   ├── datasets/
│   │   └── nusc_det_dataset.py
│   ├── evaluators/
│   │   └── det_evaluators.py
│   ├── exps/
│   │   ├── base_cli.py
│   │   └── nuscenes/
│   │       ├── MatrixVT/
│   │       │   └── matrixvt_bev_depth_lss_r50_256x704_128x128_24e_ema.py
│   │       ├── base_exp.py
│   │       ├── fusion/
│   │       │   ├── bev_depth_fusion_lss_r50_256x704_128x128_24e.py
│   │       │   ├── bev_depth_fusion_lss_r50_256x704_128x128_24e_2key.py
│   │       │   ├── bev_depth_fusion_lss_r50_256x704_128x128_24e_2key_trainval.py
│   │       │   └── bev_depth_fusion_lss_r50_256x704_128x128_24e_key4.py
│   │       └── mv/
│   │           ├── bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da.py
│   │           ├── bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da_ema.py
│   │           ├── bev_depth_lss_r50_256x704_128x128_24e_2key.py
│   │           ├── bev_depth_lss_r50_256x704_128x128_24e_2key_ema.py
│   │           ├── bev_depth_lss_r50_256x704_128x128_24e_ema.py
│   │           ├── bev_depth_lss_r50_512x1408_128x128_24e_2key.py
│   │           ├── bev_depth_lss_r50_640x1600_128x128_24e_2key.py
│   │           ├── bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da.py
│   │           ├── bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da_ema.py
│   │           ├── bev_stereo_lss_r50_256x704_128x128_24e_2key.py
│   │           ├── bev_stereo_lss_r50_256x704_128x128_24e_2key_ema.py
│   │           ├── bev_stereo_lss_r50_256x704_128x128_24e_key4.py
│   │           └── bev_stereo_lss_r50_256x704_128x128_24e_key4_ema.py
│   ├── layers/
│   │   ├── __init__.py
│   │   ├── backbones/
│   │   │   ├── __init__.py
│   │   │   ├── base_lss_fpn.py
│   │   │   ├── bevstereo_lss_fpn.py
│   │   │   ├── fusion_lss_fpn.py
│   │   │   └── matrixvt.py
│   │   └── heads/
│   │       ├── __init__.py
│   │       └── bev_depth_head.py
│   ├── models/
│   │   ├── base_bev_depth.py
│   │   ├── bev_stereo.py
│   │   ├── fusion_bev_depth.py
│   │   └── matrixvt_det.py
│   ├── ops/
│   │   ├── voxel_pooling_inference/
│   │   │   ├── __init__.py
│   │   │   ├── src/
│   │   │   │   ├── voxel_pooling_inference_forward.cpp
│   │   │   │   └── voxel_pooling_inference_forward_cuda.cu
│   │   │   └── voxel_pooling_inference.py
│   │   └── voxel_pooling_train/
│   │       ├── __init__.py
│   │       ├── src/
│   │       │   ├── voxel_pooling_train_forward.cpp
│   │       │   └── voxel_pooling_train_forward_cuda.cu
│   │       └── voxel_pooling_train.py
│   └── utils/
│       └── torch_dist.py
├── requirements-dev.txt
├── requirements.txt
├── scripts/
│   ├── gen_info.py
│   └── visualize_nusc.py
├── setup.py
└── test/
    ├── test_dataset/
    │   └── test_nusc_mv_det_dataset.py
    ├── test_layers/
    │   ├── test_backbone.py
    │   ├── test_head.py
    │   └── test_matrixvt.py
    └── test_ops/
        └── test_voxel_pooling.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/lint.yml
================================================
name: lint

on: [push, pull_request]

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.7
        uses: actions/setup-python@v2
        with:
          python-version: 3.7
      - name: Install pre-commit hook
        run: |
          pip install pre-commit
          pre-commit install
      - name: Linting
        run: pre-commit run --all-files
      - name: Format c/cuda codes with clang-format
        uses: DoozyX/clang-format-lint-action@v0.11
        with:
          source: bevdepth/ops
          extensions: h,c,cpp,hpp,cu,cuh
          style: google
      - name: Check docstring coverage
        run: |
          pip install interrogate
          interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" -e 'bevdepth/exps' -e 'test/' -e 'scripts' -e 'setup.py' -e 'bevdepth/ops' -e 'bevdepth/utils/' --fail-under 50


================================================
FILE: .gitignore
================================================
### Linux ###
*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### PyCharm ###
# User-specific stuff
.idea

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

# JetBrains templates
**___jb_tmp___

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
docs/build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don’t work, or not
#   install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

### Vim ###
# Swap
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]

# Session
Session.vim

# Temporary
.netrwhist
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~

### Researcher ###
# output
train_log
docs/api
.code-workspace.code-workspace
output
outputs
instant_test_output
inference_test_output
*.pkl
*.npy
*.pth
events.out.tfevents*

# vscode
*.code-workspace
.vscode

# vim
.vim


================================================
FILE: .pre-commit-config.yaml
================================================
repos:
  - repo: https://github.com/PyCQA/flake8
    rev: 5.0.4
    hooks:
      - id: flake8
  - repo: https://github.com/PyCQA/isort
    rev: 5.10.1
    hooks:
      - id: isort
  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: v0.32.0
    hooks:
      - id: yapf
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.3.0
    hooks:
      - id: trailing-whitespace
      - id: check-yaml
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
      - id: double-quote-string-fixer
      - id: check-merge-conflict
      - id: fix-encoding-pragma
        args: ["--remove"]
      - id: mixed-line-ending
        args: ["--fix=lf"]
  - repo: https://github.com/codespell-project/codespell
    rev: v2.2.1
    hooks:
      - id: codespell


================================================
FILE: LICENSE.md
================================================
MIT License

Copyright (c) 2022 Megvii-BaseDetection

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
## BEVDepth
BEVDepth is a new 3D object detector with trustworthy depth
estimation. For more details, please refer to our [paper on Arxiv](https://arxiv.org/abs/2206.10092).

<img src="assets/bevdepth.png" width="1000" >

## BEVStereo
BEVStereo is a new multi-view 3D object detector using temporal stereo to enhance depth estimation.
<img src="assets/bevstereo.png" width="1000" >

## MatrixVT
[MatrixVT](bevdepth/exps/nuscenes/MatrixVT/matrixvt_bev_depth_lss_r50_256x704_128x128_24e_ema.py) is a novel View Transformer for the BEV paradigm with high efficiency and no customized operators. For more details, please refer to our [paper on Arxiv](https://arxiv.org/abs/2211.10593). Try MatrixVT on **CPU** by running [this file](bevdepth/layers/backbones/matrixvt.py)!
<img src="assets/matrixvt.jpg" width="1000" >

## Updates!!
* 【2022/12/06】 We released our new View Transformer (MatrixVT); the paper is on [Arxiv](https://arxiv.org/abs/2211.10593).
* 【2022/11/30】 We updated our paper (BEVDepth) on [Arxiv](https://arxiv.org/abs/2206.10092).
* 【2022/11/18】 Both BEVDepth and BEVStereo were accepted by AAAI 2023.
* 【2022/09/22】 We released our paper (BEVStereo) on [Arxiv](https://arxiv.org/abs/2209.10248).
* 【2022/08/24】 We submitted our result (BEVStereo) to the [nuScenes Detection Task](https://nuscenes.org/object-detection?externalData=all&mapData=all&modalities=Camera) and achieved SOTA.
* 【2022/06/23】 We submitted our result (BEVDepth), without extra data, to the [nuScenes Detection Task](https://nuscenes.org/object-detection?externalData=all&mapData=all&modalities=Camera) and achieved SOTA.
* 【2022/06/21】 We released our paper (BEVDepth) on [Arxiv](https://arxiv.org/abs/2206.10092).
* 【2022/04/11】 We submitted our result (BEVDepth) to the [nuScenes Detection Task](https://nuscenes.org/object-detection?externalData=all&mapData=all&modalities=Camera) and achieved SOTA.


## Quick Start
### Installation
**Step 0.** Install [PyTorch](https://pytorch.org/) (v1.9.0).

**Step 1.** Install [MMDetection3D](https://github.com/open-mmlab/mmdetection3d) (v1.0.0rc4).

**Step 2.** Install requirements.
```shell
pip install -r requirements.txt
```
**Step 3.** Install BEVDepth (GPU required).
```shell
python setup.py develop
```

### Data preparation
**Step 0.** Download nuScenes official dataset.

**Step 1.** Symlink the dataset root to `./data/`.
```
ln -s [nuscenes root] ./data/
```
The directory will be as follows.
```
BEVDepth
├── data
│   ├── nuScenes
│   │   ├── maps
│   │   ├── samples
│   │   ├── sweeps
│   │   ├── v1.0-test
│   │   ├── v1.0-trainval
```
**Step 2.** Prepare infos.
```
python scripts/gen_info.py
```

### Tutorials
**Train.**
```
python [EXP_PATH] --amp_backend native -b 8 --gpus 8
```
**Eval.**
```
python [EXP_PATH] --ckpt_path [CKPT_PATH] -e -b 8 --gpus 8
```

### Benchmark
|Exp |EMA| CBGS |mAP |mATE| mASE | mAOE |mAVE| mAAE | NDS | weights |
| ------ | :---: | :---: | :---:       |:---:     |:---:  | :---: | :----: | :----: | :----: | :----: |
|[BEVDepth](bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_24e_2key.py)| | |0.3304| 0.7021| 0.2795| 0.5346| 0.5530| 0.2274| 0.4355 | [github](https://github.com/Megvii-BaseDetection/BEVDepth/releases/download/v0.0.2/bev_depth_lss_r50_256x704_128x128_24e_2key.pth)
|[BEVDepth](bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_24e_2key_ema.py)|√ | |0.3329 |  0.6832     |0.2761 | 0.5446 | 0.5258 | 0.2259 | 0.4409 | [github](https://github.com/Megvii-BaseDetection/BEVDepth/releases/download/v0.0.2/bev_depth_lss_r50_256x704_128x128_24e_2key_ema.pth)
|[BEVDepth](bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da.py)| |√ |0.3484| 0.6159| 0.2716| 0.4144| 0.4402| 0.1954| 0.4805 | [github](https://github.com/Megvii-BaseDetection/BEVDepth/releases/download/v0.0.2/bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da.pth)
|[BEVDepth](bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da_ema.py)|√  |√ |0.3589 |  0.6119     |0.2692 | 0.5074 | 0.4086 | 0.2009 | 0.4797 | [github](https://github.com/Megvii-BaseDetection/BEVDepth/releases/download/v0.0.2/bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da_ema.pth) |
|[BEVStereo](bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_2key.py)|  | |0.3456 | 0.6589 | 0.2774 | 0.5500 | 0.4980 | 0.2278 | 0.4516 | [github](https://github.com/Megvii-BaseDetection/BEVStereo/releases/download/v0.0.2/bev_stereo_lss_r50_256x704_128x128_24e_2key.pth) |
|[BEVStereo](bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_2key_ema.py)|√  | |0.3494|	0.6671|	0.2785|	0.5606|	0.4686|	0.2295|	0.4543 | [github](https://github.com/Megvii-BaseDetection/BEVStereo/releases/download/v0.0.2/bev_stereo_lss_r50_256x704_128x128_24e_2key_ema.pth) |
|[BEVStereo](bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_key4.py)|  | |0.3427|	0.6560|	0.2784|	0.5982|	0.5347|	0.2228|	0.4423 | [github](https://github.com/Megvii-BaseDetection/BEVStereo/releases/download/v0.0.2/bev_stereo_lss_r50_256x704_128x128_24e_key4.pth) |
|[BEVStereo](bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_key4_ema.py)|√  | |0.3435|	0.6585|	0.2757|	0.5792|	0.5034|	0.2163|	0.4485 | [github](https://github.com/Megvii-BaseDetection/BEVStereo/releases/download/v0.0.2/bev_stereo_lss_r50_256x704_128x128_24e_key4_ema.pth) |
|[BEVStereo](bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da.py)|  |√ |0.3576|	0.6071|	0.2684|	0.4157|	0.3928|	0.2021|	0.4902 | [github](https://github.com/Megvii-BaseDetection/BEVStereo/releases/download/v0.0.2/bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da.pth) |
|[BEVStereo](bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da_ema.py)|√  |√ |0.3721|	0.5980|	0.2701|	0.4381|	0.3672|	0.1898|	0.4997 | [github](https://github.com/Megvii-BaseDetection/BEVStereo/releases/download/v0.0.2/bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da_ema.pth) |

## FAQ

### EMA
- The results differ between evaluation during training and evaluation from a checkpoint.

Due to the working mechanism of EMA, the model parameters saved in the checkpoint differ from the model parameters used during the training stage.

- EMA experiments are unable to resume training from a checkpoint.

We use a customized EMA callback, and this function is not supported for now.
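Conceptually, the saved EMA weights follow the standard update rule used by `ModelEMA.update` in `bevdepth/callbacks/ema.py`: `ema = d * ema + (1 - d) * model`, where the decay `d` ramps up from 0 as `decay * (1 - exp(-updates / 2000))`. A minimal sketch of that math (the function names here are illustrative, not part of the codebase):

```python
import math


def ema_decay(base_decay, updates, tau=2000.0):
    # Ramp the decay up from 0 so early updates track the raw weights closely.
    return base_decay * (1 - math.exp(-updates / tau))


def ema_update(ema_val, model_val, d):
    # Standard EMA rule, applied per parameter: ema = d * ema + (1 - d) * model.
    return d * ema_val + (1 - d) * model_val
```

Because `d` is close to 1 late in training, the EMA weights lag the raw weights, which is why evaluating the checkpointed (EMA) parameters gives different numbers than evaluating the parameters used in the training step.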

## Cite BEVDepth & BEVStereo & MatrixVT
If you use BEVDepth, BEVStereo or MatrixVT in your research, please cite our work using the following BibTeX entries:

```latex
 @article{li2022bevdepth,
  title={BEVDepth: Acquisition of Reliable Depth for Multi-view 3D Object Detection},
  author={Li, Yinhao and Ge, Zheng and Yu, Guanyi and Yang, Jinrong and Wang, Zengran and Shi, Yukang and Sun, Jianjian and Li, Zeming},
  journal={arXiv preprint arXiv:2206.10092},
  year={2022}
}
@article{li2022bevstereo,
  title={Bevstereo: Enhancing depth estimation in multi-view 3d object detection with dynamic temporal stereo},
  author={Li, Yinhao and Bao, Han and Ge, Zheng and Yang, Jinrong and Sun, Jianjian and Li, Zeming},
  journal={arXiv preprint arXiv:2209.10248},
  year={2022}
}
@article{zhou2022matrixvt,
  title={MatrixVT: Efficient Multi-Camera to BEV Transformation for 3D Perception},
  author={Zhou, Hongyu and Ge, Zheng and Li, Zeming and Zhang, Xiangyu},
  journal={arXiv preprint arXiv:2211.10593},
  year={2022}
}
```


================================================
FILE: bevdepth/callbacks/ema.py
================================================
#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import math
import os
from copy import deepcopy

import torch
import torch.nn as nn
from pytorch_lightning.callbacks import Callback

__all__ = ['ModelEMA', 'is_parallel']


def is_parallel(model):
    """check if model is in parallel mode."""
    parallel_type = (
        nn.parallel.DataParallel,
        nn.parallel.DistributedDataParallel,
    )
    return isinstance(model, parallel_type)


class ModelEMA:
    """
    Model Exponential Moving Average from https://github.com/rwightman/
    pytorch-image-models Keep a moving average of everything in
    the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/
    ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training
    schemes to perform well.
    This class is sensitive where it is initialized in the sequence
    of model init, GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        """
        Args:
            model (nn.Module): model to apply EMA.
            decay (float): ema decay.
            updates (int): counter of EMA updates.
        """
        # Create EMA(FP32)
        self.ema = deepcopy(
            model.module if is_parallel(model) else model).eval()
        self.updates = updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, trainer, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.module.state_dict() if is_parallel(
                model) else model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1.0 - d) * msd[k].detach()


class EMACallback(Callback):

    def __init__(self, len_updates) -> None:
        super().__init__()
        self.len_updates = len_updates

    def on_fit_start(self, trainer, pl_module):
        # TODO (@lizeming@megvii.com): delete manually specified device
        from torch.nn.modules.batchnorm import SyncBatchNorm

        bn_model_list = list()
        bn_model_dist_group_list = list()
        for model_ref in trainer.model.modules():
            if isinstance(model_ref, SyncBatchNorm):
                bn_model_list.append(model_ref)
                bn_model_dist_group_list.append(model_ref.process_group)
                model_ref.process_group = None
        trainer.ema_model = ModelEMA(trainer.model.module.module.model.cuda(),
                                     0.9990)

        for bn_model, dist_group in zip(bn_model_list,
                                        bn_model_dist_group_list):
            bn_model.process_group = dist_group
        trainer.ema_model.updates = self.len_updates

    def on_train_batch_end(self,
                           trainer,
                           pl_module,
                           outputs,
                           batch,
                           batch_idx,
                           unused=0):
        trainer.ema_model.update(trainer, trainer.model.module.module.model)

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        state_dict = trainer.ema_model.ema.state_dict()
        state_dict_keys = list(state_dict.keys())
        # TODO: Change to more elegant way.
        for state_dict_key in state_dict_keys:
            new_key = 'model.' + state_dict_key
            state_dict[new_key] = state_dict.pop(state_dict_key)
        checkpoint = {
            # the epoch and global step are saved for
            # compatibility but they are not relevant for restoration
            'epoch': trainer.current_epoch,
            'global_step': trainer.global_step,
            'state_dict': state_dict
        }
        torch.save(
            checkpoint,
            os.path.join(trainer.log_dir, f'{trainer.current_epoch}.pth'))


================================================
FILE: bevdepth/datasets/nusc_det_dataset.py
================================================
import os

import mmcv
import numpy as np
import torch
from mmdet3d.core.bbox.structures.lidar_box3d import LiDARInstance3DBoxes
from nuscenes.utils.data_classes import Box, LidarPointCloud
from nuscenes.utils.geometry_utils import view_points
from PIL import Image
from pyquaternion import Quaternion
from torch.utils.data import Dataset

__all__ = ['NuscDetDataset']

map_name_from_general_to_detection = {
    'human.pedestrian.adult': 'pedestrian',
    'human.pedestrian.child': 'pedestrian',
    'human.pedestrian.wheelchair': 'ignore',
    'human.pedestrian.stroller': 'ignore',
    'human.pedestrian.personal_mobility': 'ignore',
    'human.pedestrian.police_officer': 'pedestrian',
    'human.pedestrian.construction_worker': 'pedestrian',
    'animal': 'ignore',
    'vehicle.car': 'car',
    'vehicle.motorcycle': 'motorcycle',
    'vehicle.bicycle': 'bicycle',
    'vehicle.bus.bendy': 'bus',
    'vehicle.bus.rigid': 'bus',
    'vehicle.truck': 'truck',
    'vehicle.construction': 'construction_vehicle',
    'vehicle.emergency.ambulance': 'ignore',
    'vehicle.emergency.police': 'ignore',
    'vehicle.trailer': 'trailer',
    'movable_object.barrier': 'barrier',
    'movable_object.trafficcone': 'traffic_cone',
    'movable_object.pushable_pullable': 'ignore',
    'movable_object.debris': 'ignore',
    'static_object.bicycle_rack': 'ignore',
}


def get_rot(h):
    return torch.Tensor([
        [np.cos(h), np.sin(h)],
        [-np.sin(h), np.cos(h)],
    ])


def img_transform(img, resize, resize_dims, crop, flip, rotate):
    ida_rot = torch.eye(2)
    ida_tran = torch.zeros(2)
    # adjust image
    img = img.resize(resize_dims)
    img = img.crop(crop)
    if flip:
        img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
    img = img.rotate(rotate)

    # post-homography transformation
    ida_rot *= resize
    ida_tran -= torch.Tensor(crop[:2])
    if flip:
        A = torch.Tensor([[-1, 0], [0, 1]])
        b = torch.Tensor([crop[2] - crop[0], 0])
        ida_rot = A.matmul(ida_rot)
        ida_tran = A.matmul(ida_tran) + b
    A = get_rot(rotate / 180 * np.pi)
    b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
    b = A.matmul(-b) + b
    ida_rot = A.matmul(ida_rot)
    ida_tran = A.matmul(ida_tran) + b
    ida_mat = ida_rot.new_zeros(4, 4)
    ida_mat[3, 3] = 1
    ida_mat[2, 2] = 1
    ida_mat[:2, :2] = ida_rot
    ida_mat[:2, 3] = ida_tran
    return img, ida_mat


def bev_transform(gt_boxes, rotate_angle, scale_ratio, flip_dx, flip_dy):
    rotate_angle = torch.tensor(rotate_angle / 180 * np.pi)
    rot_sin = torch.sin(rotate_angle)
    rot_cos = torch.cos(rotate_angle)
    rot_mat = torch.Tensor([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0],
                            [0, 0, 1]])
    scale_mat = torch.Tensor([[scale_ratio, 0, 0], [0, scale_ratio, 0],
                              [0, 0, scale_ratio]])
    flip_mat = torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    if flip_dx:
        flip_mat = flip_mat @ torch.Tensor([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
    if flip_dy:
        flip_mat = flip_mat @ torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]])
    rot_mat = flip_mat @ (scale_mat @ rot_mat)
    if gt_boxes.shape[0] > 0:
        gt_boxes[:, :3] = (rot_mat @ gt_boxes[:, :3].unsqueeze(-1)).squeeze(-1)
        gt_boxes[:, 3:6] *= scale_ratio
        gt_boxes[:, 6] += rotate_angle
        if flip_dx:
            gt_boxes[:, 6] = 2 * torch.asin(torch.tensor(1.0)) - gt_boxes[:, 6]
        if flip_dy:
            gt_boxes[:, 6] = -gt_boxes[:, 6]
        gt_boxes[:, 7:] = (
            rot_mat[:2, :2] @ gt_boxes[:, 7:].unsqueeze(-1)).squeeze(-1)
    return gt_boxes, rot_mat


def depth_transform(cam_depth, resize, resize_dims, crop, flip, rotate):
    """Transform depth based on ida augmentation configuration.

    Args:
        cam_depth (np.ndarray): Nx3 array of (x, y, depth) points.
        resize (float): Resize factor.
        resize_dims (list): Final dimensions (H, W) of the depth map.
        crop (list): Crop region (x1, y1, x2, y2).
        flip (bool): Whether to flip horizontally.
        rotate (float): Rotation angle in degrees.

    Returns:
        torch.Tensor: Depth map of shape resize_dims (H, W).
    """

    H, W = resize_dims
    cam_depth[:, :2] = cam_depth[:, :2] * resize
    cam_depth[:, 0] -= crop[0]
    cam_depth[:, 1] -= crop[1]
    if flip:
        cam_depth[:, 0] = resize_dims[1] - cam_depth[:, 0]

    cam_depth[:, 0] -= W / 2.0
    cam_depth[:, 1] -= H / 2.0

    h = rotate / 180 * np.pi
    rot_matrix = [
        [np.cos(h), np.sin(h)],
        [-np.sin(h), np.cos(h)],
    ]
    cam_depth[:, :2] = np.matmul(rot_matrix, cam_depth[:, :2].T).T

    cam_depth[:, 0] += W / 2.0
    cam_depth[:, 1] += H / 2.0

    depth_coords = cam_depth[:, :2].astype(np.int16)

    depth_map = np.zeros(resize_dims)
    valid_mask = ((depth_coords[:, 1] < resize_dims[0])
                  & (depth_coords[:, 0] < resize_dims[1])
                  & (depth_coords[:, 1] >= 0)
                  & (depth_coords[:, 0] >= 0))
    depth_map[depth_coords[valid_mask, 1],
              depth_coords[valid_mask, 0]] = cam_depth[valid_mask, 2]

    return torch.Tensor(depth_map)


def map_pointcloud_to_image(
    lidar_points,
    img,
    lidar_calibrated_sensor,
    lidar_ego_pose,
    cam_calibrated_sensor,
    cam_ego_pose,
    min_dist: float = 0.0,
):

    # Points live in the point sensor frame. So they need to be
    # transformed via global to the image plane.
    # First step: transform the pointcloud to the ego vehicle
    # frame for the timestamp of the sweep.

    lidar_points = LidarPointCloud(lidar_points.T)
    lidar_points.rotate(
        Quaternion(lidar_calibrated_sensor['rotation']).rotation_matrix)
    lidar_points.translate(np.array(lidar_calibrated_sensor['translation']))

    # Second step: transform from ego to the global frame.
    lidar_points.rotate(Quaternion(lidar_ego_pose['rotation']).rotation_matrix)
    lidar_points.translate(np.array(lidar_ego_pose['translation']))

    # Third step: transform from global into the ego vehicle
    # frame for the timestamp of the image.
    lidar_points.translate(-np.array(cam_ego_pose['translation']))
    lidar_points.rotate(Quaternion(cam_ego_pose['rotation']).rotation_matrix.T)

    # Fourth step: transform from ego into the camera.
    lidar_points.translate(-np.array(cam_calibrated_sensor['translation']))
    lidar_points.rotate(
        Quaternion(cam_calibrated_sensor['rotation']).rotation_matrix.T)

    # Fifth step: actually take a "picture" of the point cloud.
    # Grab the depths (camera frame z axis points away from the camera).
    depths = lidar_points.points[2, :]
    coloring = depths

    # Take the actual picture (matrix multiplication with camera-matrix
    # + renormalization).
    points = view_points(lidar_points.points[:3, :],
                         np.array(cam_calibrated_sensor['camera_intrinsic']),
                         normalize=True)

    # Remove points that are either outside or behind the camera.
    # Leave a margin of 1 pixel for aesthetic reasons. Also make
    # sure points are at least 1m in front of the camera to avoid
    # seeing the lidar points on the camera casing for non-keyframes
    # which are slightly out of sync.
    mask = np.ones(depths.shape[0], dtype=bool)
    mask = np.logical_and(mask, depths > min_dist)
    mask = np.logical_and(mask, points[0, :] > 1)
    mask = np.logical_and(mask, points[0, :] < img.size[0] - 1)
    mask = np.logical_and(mask, points[1, :] > 1)
    mask = np.logical_and(mask, points[1, :] < img.size[1] - 1)
    points = points[:, mask]
    coloring = coloring[mask]

    return points, coloring


class NuscDetDataset(Dataset):

    def __init__(self,
                 ida_aug_conf,
                 bda_aug_conf,
                 classes,
                 data_root,
                 info_paths,
                 is_train,
                 use_cbgs=False,
                 num_sweeps=1,
                 img_conf=dict(img_mean=[123.675, 116.28, 103.53],
                               img_std=[58.395, 57.12, 57.375],
                               to_rgb=True),
                 return_depth=False,
                 sweep_idxes=list(),
                 key_idxes=list(),
                 use_fusion=False):
        """Dataset used for bevdetection task.
        Args:
            ida_aug_conf (dict): Config for ida augmentation.
            bda_aug_conf (dict): Config for bda augmentation.
            classes (list): Class names.
            use_cbgs (bool): Whether to use cbgs strategy,
                Default: False.
            num_sweeps (int): Number of sweeps to be used for each sample.
                default: 1.
            img_conf (dict): Config for image.
            return_depth (bool): Whether to use depth gt.
                default: False.
            sweep_idxes (list): List of sweep idxes to be used.
                default: list().
            key_idxes (list): List of key idxes to be used.
                default: list().
            use_fusion (bool): Whether to use lidar data.
                default: False.
        """
        super().__init__()
        if isinstance(info_paths, list):
            self.infos = list()
            for info_path in info_paths:
                self.infos.extend(mmcv.load(info_path))
        else:
            self.infos = mmcv.load(info_paths)
        self.is_train = is_train
        self.ida_aug_conf = ida_aug_conf
        self.bda_aug_conf = bda_aug_conf
        self.data_root = data_root
        self.classes = classes
        self.use_cbgs = use_cbgs
        if self.use_cbgs:
            self.cat2id = {name: i for i, name in enumerate(self.classes)}
            self.sample_indices = self._get_sample_indices()
        self.num_sweeps = num_sweeps
        self.img_mean = np.array(img_conf['img_mean'], np.float32)
        self.img_std = np.array(img_conf['img_std'], np.float32)
        self.to_rgb = img_conf['to_rgb']
        self.return_depth = return_depth
        assert all(sweep_idx >= 0 for sweep_idx in sweep_idxes), \
            'All `sweep_idxes` must be greater than or equal to 0.'

        self.sweeps_idx = sweep_idxes
        assert all(key_idx < 0 for key_idx in key_idxes), \
            'All `key_idxes` must be less than 0.'
        self.key_idxes = [0] + key_idxes
        self.use_fusion = use_fusion

    def _get_sample_indices(self):
        """Load annotations from ann_file.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            list[dict]: List of annotations after class sampling.
        """
        class_sample_idxs = {cat_id: [] for cat_id in self.cat2id.values()}
        for idx, info in enumerate(self.infos):
            gt_names = set(
                [ann_info['category_name'] for ann_info in info['ann_infos']])
            for gt_name in gt_names:
                gt_name = map_name_from_general_to_detection[gt_name]
                if gt_name not in self.classes:
                    continue
                class_sample_idxs[self.cat2id[gt_name]].append(idx)
        duplicated_samples = sum(
            [len(v) for _, v in class_sample_idxs.items()])
        class_distribution = {
            k: len(v) / duplicated_samples
            for k, v in class_sample_idxs.items()
        }

        sample_indices = []

        frac = 1.0 / len(self.classes)
        ratios = [frac / v for v in class_distribution.values()]
        for cls_inds, ratio in zip(list(class_sample_idxs.values()), ratios):
            sample_indices += np.random.choice(cls_inds,
                                               int(len(cls_inds) *
                                                   ratio)).tolist()
        return sample_indices
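    # The resampling above can be sketched standalone. This is an
    # illustrative re-implementation with made-up class counts, not the
    # repo's own API; a seeded generator replaces the global numpy state.
    #
    # ```python
    # import numpy as np
    #
    # def cbgs_sample_indices(class_sample_idxs, num_classes, rng=None):
    #     """Duplicate frame indices so each class nears a 1/num_classes share."""
    #     rng = rng if rng is not None else np.random.default_rng(0)
    #     duplicated_samples = sum(len(v) for v in class_sample_idxs.values())
    #     frac = 1.0 / num_classes
    #     sample_indices = []
    #     for cls_inds in class_sample_idxs.values():
    #         ratio = frac / (len(cls_inds) / duplicated_samples)
    #         sample_indices += rng.choice(cls_inds,
    #                                      int(len(cls_inds) * ratio)).tolist()
    #     return sample_indices
    #
    # # A rare class (2 frames) ends up drawn about as often as a common
    # # one (100 frames): both contribute roughly half of the resampled set.
    # counts = {0: list(range(100)), 1: [100, 101]}
    # indices = cbgs_sample_indices(counts, num_classes=2)
    # ```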

    def sample_ida_augmentation(self):
        """Generate ida augmentation values based on ida_config."""
        H, W = self.ida_aug_conf['H'], self.ida_aug_conf['W']
        fH, fW = self.ida_aug_conf['final_dim']
        if self.is_train:
            resize = np.random.uniform(*self.ida_aug_conf['resize_lim'])
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = int(
                (1 - np.random.uniform(*self.ida_aug_conf['bot_pct_lim'])) *
                newH) - fH
            crop_w = int(np.random.uniform(0, max(0, newW - fW)))
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False
            if self.ida_aug_conf['rand_flip'] and np.random.choice([0, 1]):
                flip = True
            rotate_ida = np.random.uniform(*self.ida_aug_conf['rot_lim'])
        else:
            resize = max(fH / H, fW / W)
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = int(
                (1 - np.mean(self.ida_aug_conf['bot_pct_lim'])) * newH) - fH
            crop_w = int(max(0, newW - fW) / 2)
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False
            rotate_ida = 0
        return resize, resize_dims, crop, flip, rotate_ida
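    # A numeric walk-through of the deterministic eval branch above, using
    # the nuScenes source size (900 x 1600) and the 256 x 704 final_dim that
    # appears in this repo's exp names, with an assumed bot_pct_lim of
    # (0.0, 0.0):
    #
    # ```python
    # import numpy as np
    #
    # H, W = 900, 1600
    # fH, fW = 256, 704
    # bot_pct_lim = (0.0, 0.0)
    #
    # resize = max(fH / H, fW / W)                   # 0.44: cover both dims
    # newW, newH = int(W * resize), int(H * resize)  # (704, 396)
    # crop_h = int((1 - np.mean(bot_pct_lim)) * newH) - fH  # 140: drop the top
    # crop_w = int(max(0, newW - fW) / 2)            # 0: width already matches
    # crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)     # (0, 140, 704, 396)
    # ```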

    def sample_bda_augmentation(self):
        """Generate bda augmentation values based on bda_config."""
        if self.is_train:
            rotate_bda = np.random.uniform(*self.bda_aug_conf['rot_lim'])
            scale_bda = np.random.uniform(*self.bda_aug_conf['scale_lim'])
            flip_dx = np.random.uniform() < self.bda_aug_conf['flip_dx_ratio']
            flip_dy = np.random.uniform() < self.bda_aug_conf['flip_dy_ratio']
        else:
            rotate_bda = 0
            scale_bda = 1.0
            flip_dx = False
            flip_dy = False
        return rotate_bda, scale_bda, flip_dx, flip_dy

    def get_lidar_depth(self, lidar_points, img, lidar_info, cam_info):
        lidar_calibrated_sensor = lidar_info['LIDAR_TOP']['calibrated_sensor']
        lidar_ego_pose = lidar_info['LIDAR_TOP']['ego_pose']
        cam_calibrated_sensor = cam_info['calibrated_sensor']
        cam_ego_pose = cam_info['ego_pose']
        pts_img, depth = map_pointcloud_to_image(
            lidar_points.copy(), img, lidar_calibrated_sensor.copy(),
            lidar_ego_pose.copy(), cam_calibrated_sensor, cam_ego_pose)
        return np.concatenate([pts_img[:2, :].T, depth[:, None]],
                              axis=1).astype(np.float32)
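    # Shape sketch of the array `get_lidar_depth` returns: projected pixel
    # coordinates (u, v) paired with per-point depth as an (N, 3) float32
    # array. The values below are toy numbers, not real projections.
    #
    # ```python
    # import numpy as np
    #
    # pts_img = np.array([[100.0, 200.0, 300.0],   # u for three lidar points
    #                     [50.0, 60.0, 70.0],      # v
    #                     [1.0, 1.0, 1.0]])        # homogeneous coordinate
    # depth = np.array([12.5, 30.0, 7.1])
    # packed = np.concatenate([pts_img[:2, :].T, depth[:, None]],
    #                         axis=1).astype(np.float32)
    # # packed[i] == [u_i, v_i, depth_i]
    # ```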

    def get_image(self, cam_infos, cams, lidar_infos=None):
        """Given data and cam_names, return image data needed.

        Args:
            sweeps_data (list): Raw data used to generate the data we needed.
            cams (list): Camera names.

        Returns:
            Tensor: Image data after processing.
            Tensor: Transformation matrix from camera to ego.
            Tensor: Intrinsic matrix.
            Tensor: Transformation matrix for ida.
            Tensor: Transformation matrix from key
                frame camera to sweep frame camera.
            Tensor: timestamps.
            dict: meta infos needed for evaluation.
        """
        assert len(cam_infos) > 0
        sweep_imgs = list()
        sweep_sensor2ego_mats = list()
        sweep_intrin_mats = list()
        sweep_ida_mats = list()
        sweep_sensor2sensor_mats = list()
        sweep_timestamps = list()
        sweep_lidar_depth = list()
        if self.return_depth or self.use_fusion:
            sweep_lidar_points = list()
            for lidar_info in lidar_infos:
                lidar_path = lidar_info['LIDAR_TOP']['filename']
                lidar_points = np.fromfile(os.path.join(
                    self.data_root, lidar_path),
                                           dtype=np.float32,
                                           count=-1).reshape(-1, 5)[..., :4]
                sweep_lidar_points.append(lidar_points)
        for cam in cams:
            imgs = list()
            sensor2ego_mats = list()
            intrin_mats = list()
            ida_mats = list()
            sensor2sensor_mats = list()
            timestamps = list()
            lidar_depth = list()
            key_info = cam_infos[0]
            resize, resize_dims, crop, flip, rotate_ida = \
                self.sample_ida_augmentation()
            for sweep_idx, cam_info in enumerate(cam_infos):
                img = Image.open(
                    os.path.join(self.data_root, cam_info[cam]['filename']))
                w, x, y, z = cam_info[cam]['calibrated_sensor']['rotation']
                # sweep sensor to sweep ego
                sweepsensor2sweepego_rot = torch.Tensor(
                    Quaternion(w, x, y, z).rotation_matrix)
                sweepsensor2sweepego_tran = torch.Tensor(
                    cam_info[cam]['calibrated_sensor']['translation'])
                sweepsensor2sweepego = sweepsensor2sweepego_rot.new_zeros(
                    (4, 4))
                sweepsensor2sweepego[3, 3] = 1
                sweepsensor2sweepego[:3, :3] = sweepsensor2sweepego_rot
                sweepsensor2sweepego[:3, -1] = sweepsensor2sweepego_tran
                # sweep ego to global
                w, x, y, z = cam_info[cam]['ego_pose']['rotation']
                sweepego2global_rot = torch.Tensor(
                    Quaternion(w, x, y, z).rotation_matrix)
                sweepego2global_tran = torch.Tensor(
                    cam_info[cam]['ego_pose']['translation'])
                sweepego2global = sweepego2global_rot.new_zeros((4, 4))
                sweepego2global[3, 3] = 1
                sweepego2global[:3, :3] = sweepego2global_rot
                sweepego2global[:3, -1] = sweepego2global_tran

                # global to cur (key) ego
                w, x, y, z = key_info[cam]['ego_pose']['rotation']
                keyego2global_rot = torch.Tensor(
                    Quaternion(w, x, y, z).rotation_matrix)
                keyego2global_tran = torch.Tensor(
                    key_info[cam]['ego_pose']['translation'])
                keyego2global = keyego2global_rot.new_zeros((4, 4))
                keyego2global[3, 3] = 1
                keyego2global[:3, :3] = keyego2global_rot
                keyego2global[:3, -1] = keyego2global_tran
                global2keyego = keyego2global.inverse()

                # cur ego to sensor
                w, x, y, z = key_info[cam]['calibrated_sensor']['rotation']
                keysensor2keyego_rot = torch.Tensor(
                    Quaternion(w, x, y, z).rotation_matrix)
                keysensor2keyego_tran = torch.Tensor(
                    key_info[cam]['calibrated_sensor']['translation'])
                keysensor2keyego = keysensor2keyego_rot.new_zeros((4, 4))
                keysensor2keyego[3, 3] = 1
                keysensor2keyego[:3, :3] = keysensor2keyego_rot
                keysensor2keyego[:3, -1] = keysensor2keyego_tran
                keyego2keysensor = keysensor2keyego.inverse()
                keysensor2sweepsensor = (
                    keyego2keysensor @ global2keyego @ sweepego2global
                    @ sweepsensor2sweepego).inverse()
                sweepsensor2keyego = global2keyego @ sweepego2global @\
                    sweepsensor2sweepego
                sensor2ego_mats.append(sweepsensor2keyego)
                sensor2sensor_mats.append(keysensor2sweepsensor)
                intrin_mat = torch.zeros((4, 4))
                intrin_mat[3, 3] = 1
                intrin_mat[:3, :3] = torch.Tensor(
                    cam_info[cam]['calibrated_sensor']['camera_intrinsic'])
                if self.return_depth and (self.use_fusion or sweep_idx == 0):
                    point_depth = self.get_lidar_depth(
                        sweep_lidar_points[sweep_idx], img,
                        lidar_infos[sweep_idx], cam_info[cam])
                    point_depth_augmented = depth_transform(
                        point_depth, resize, self.ida_aug_conf['final_dim'],
                        crop, flip, rotate_ida)
                    lidar_depth.append(point_depth_augmented)
                img, ida_mat = img_transform(
                    img,
                    resize=resize,
                    resize_dims=resize_dims,
                    crop=crop,
                    flip=flip,
                    rotate=rotate_ida,
                )
                ida_mats.append(ida_mat)
                img = mmcv.imnormalize(np.array(img), self.img_mean,
                                       self.img_std, self.to_rgb)
                img = torch.from_numpy(img).permute(2, 0, 1)
                imgs.append(img)
                intrin_mats.append(intrin_mat)
                timestamps.append(cam_info[cam]['timestamp'])
            sweep_imgs.append(torch.stack(imgs))
            sweep_sensor2ego_mats.append(torch.stack(sensor2ego_mats))
            sweep_intrin_mats.append(torch.stack(intrin_mats))
            sweep_ida_mats.append(torch.stack(ida_mats))
            sweep_sensor2sensor_mats.append(torch.stack(sensor2sensor_mats))
            sweep_timestamps.append(torch.tensor(timestamps))
            if self.return_depth:
                sweep_lidar_depth.append(torch.stack(lidar_depth))
        # Get mean pose of all cams.
        ego2global_rotation = np.mean(
            [key_info[cam]['ego_pose']['rotation'] for cam in cams], 0)
        ego2global_translation = np.mean(
            [key_info[cam]['ego_pose']['translation'] for cam in cams], 0)
        img_metas = dict(
            box_type_3d=LiDARInstance3DBoxes,
            ego2global_translation=ego2global_translation,
            ego2global_rotation=ego2global_rotation,
        )

        ret_list = [
            torch.stack(sweep_imgs).permute(1, 0, 2, 3, 4),
            torch.stack(sweep_sensor2ego_mats).permute(1, 0, 2, 3),
            torch.stack(sweep_intrin_mats).permute(1, 0, 2, 3),
            torch.stack(sweep_ida_mats).permute(1, 0, 2, 3),
            torch.stack(sweep_sensor2sensor_mats).permute(1, 0, 2, 3),
            torch.stack(sweep_timestamps).permute(1, 0),
            img_metas,
        ]
        if self.return_depth:
            ret_list.append(torch.stack(sweep_lidar_depth).permute(1, 0, 2, 3))
        return ret_list
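    # The 4x4 homogeneous chaining above (sweep sensor -> sweep ego ->
    # global -> key ego) can be checked on toy poses. The helper and all
    # pose values here are illustrative, not real nuScenes calibration:
    #
    # ```python
    # import math
    # import torch
    #
    # def quat_to_rot(w, x, y, z):
    #     """Rotation matrix of a unit quaternion (w, x, y, z)."""
    #     return torch.tensor([
    #         [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
    #         [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
    #         [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    #     ])
    #
    # def make_se3(rotation_wxyz, translation):
    #     mat = torch.zeros(4, 4)
    #     mat[3, 3] = 1
    #     mat[:3, :3] = quat_to_rot(*rotation_wxyz)
    #     mat[:3, -1] = torch.tensor(translation)
    #     return mat
    #
    # s = math.sqrt(0.5)  # 90 degree yaw as a (w, x, y, z) quaternion
    # sweepsensor2sweepego = make_se3([1.0, 0.0, 0.0, 0.0], [1.5, 0.0, 1.6])
    # sweepego2global = make_se3([s, 0.0, 0.0, s], [10.0, 5.0, 0.0])
    # keyego2global = make_se3([s, 0.0, 0.0, s], [12.0, 5.0, 0.0])
    # # Chain right-to-left, exactly as in `get_image`. With identical ego
    # # headings, the rotation cancels and only a translation remains.
    # sweepsensor2keyego = (keyego2global.inverse() @ sweepego2global
    #                       @ sweepsensor2sweepego)
    # ```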

    def get_gt(self, info, cams):
        """Generate gt labels from info.

        Args:
            info(dict): Infos needed to generate gt labels.
            cams(list): Camera names.

        Returns:
            Tensor: GT bboxes.
            Tensor: GT labels.
        """
        ego2global_rotation = np.mean(
            [info['cam_infos'][cam]['ego_pose']['rotation'] for cam in cams],
            0)
        ego2global_translation = np.mean([
            info['cam_infos'][cam]['ego_pose']['translation'] for cam in cams
        ], 0)
        trans = -np.array(ego2global_translation)
        rot = Quaternion(ego2global_rotation).inverse
        gt_boxes = list()
        gt_labels = list()
        for ann_info in info['ann_infos']:
            # Use ego coordinate.
            if (map_name_from_general_to_detection[ann_info['category_name']]
                    not in self.classes
                    or ann_info['num_lidar_pts'] + ann_info['num_radar_pts'] <=
                    0):
                continue
            box = Box(
                ann_info['translation'],
                ann_info['size'],
                Quaternion(ann_info['rotation']),
                velocity=ann_info['velocity'],
            )
            box.translate(trans)
            box.rotate(rot)
            box_xyz = np.array(box.center)
            box_dxdydz = np.array(box.wlh)[[1, 0, 2]]
            box_yaw = np.array([box.orientation.yaw_pitch_roll[0]])
            box_velo = np.array(box.velocity[:2])
            gt_box = np.concatenate([box_xyz, box_dxdydz, box_yaw, box_velo])
            gt_boxes.append(gt_box)
            gt_labels.append(
                self.classes.index(map_name_from_general_to_detection[
                    ann_info['category_name']]))
        return torch.Tensor(gt_boxes), torch.tensor(gt_labels)
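    # Layout sketch of the 9-DoF `gt_box` built above. nuScenes `Box.wlh`
    # stores (width, length, height); the fancy index [[1, 0, 2]] reorders
    # it to (dx, dy, dz) = (length, width, height). Numbers below are
    # illustrative, not from a real annotation:
    #
    # ```python
    # import numpy as np
    #
    # box_xyz = np.array([10.0, 2.0, 0.0])  # box center in key-frame ego coords
    # wlh = np.array([1.8, 4.5, 1.5])       # a car-sized box: w, l, h
    # box_dxdydz = wlh[[1, 0, 2]]           # -> [4.5, 1.8, 1.5]
    # box_yaw = np.array([0.3])
    # box_velo = np.array([5.0, 0.1])       # vx, vy
    # gt_box = np.concatenate([box_xyz, box_dxdydz, box_yaw, box_velo])
    # ```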

    def choose_cams(self):
        """Choose cameras randomly.

        Returns:
            list: Cameras to be used.
        """
        if self.is_train and self.ida_aug_conf['Ncams'] < len(
                self.ida_aug_conf['cams']):
            cams = np.random.choice(self.ida_aug_conf['cams'],
                                    self.ida_aug_conf['Ncams'],
                                    replace=False)
        else:
            cams = self.ida_aug_conf['cams']
        return cams

    def __getitem__(self, idx):
        if self.use_cbgs:
            idx = self.sample_indices[idx]
        cam_infos = list()
        lidar_infos = list()
        # TODO: Check if it still works when number of cameras is reduced.
        cams = self.choose_cams()
        for key_idx in self.key_idxes:
            cur_idx = key_idx + idx
            # Handle scenarios when current idx doesn't have previous key
            # frame or previous key frame is from another scene.
            if cur_idx < 0:
                cur_idx = idx
            elif self.infos[cur_idx]['scene_token'] != self.infos[idx][
                    'scene_token']:
                cur_idx = idx
            info = self.infos[cur_idx]
            cam_infos.append(info['cam_infos'])
            lidar_infos.append(info['lidar_infos'])
            lidar_sweep_timestamps = [
                lidar_sweep['LIDAR_TOP']['timestamp']
                for lidar_sweep in info['lidar_sweeps']
            ]
            for sweep_idx in self.sweeps_idx:
                if len(info['cam_sweeps']) == 0:
                    cam_infos.append(info['cam_infos'])
                    lidar_infos.append(info['lidar_infos'])
                else:
                    # Handle scenarios when current sweep doesn't have all
                    # cam keys.
                    for i in range(min(len(info['cam_sweeps']) - 1, sweep_idx),
                                   -1, -1):
                        if sum([cam in info['cam_sweeps'][i]
                                for cam in cams]) == len(cams):
                            cam_infos.append(info['cam_sweeps'][i])
                            cam_timestamp = np.mean([
                                val['timestamp']
                                for val in info['cam_sweeps'][i].values()
                            ])
                            # Find the closest lidar frame to the cam frame.
                            lidar_idx = np.abs(lidar_sweep_timestamps -
                                               cam_timestamp).argmin()
                            lidar_infos.append(info['lidar_sweeps'][lidar_idx])
                            break
        if self.return_depth or self.use_fusion:
            image_data_list = self.get_image(cam_infos, cams, lidar_infos)
        else:
            image_data_list = self.get_image(cam_infos, cams)
        (
            sweep_imgs,
            sweep_sensor2ego_mats,
            sweep_intrins,
            sweep_ida_mats,
            sweep_sensor2sensor_mats,
            sweep_timestamps,
            img_metas,
        ) = image_data_list[:7]
        img_metas['token'] = self.infos[idx]['sample_token']
        if self.is_train:
            gt_boxes, gt_labels = self.get_gt(self.infos[idx], cams)
        else:
            # Temporary solution for test.
            gt_boxes = sweep_imgs.new_zeros(0, 7)
            gt_labels = sweep_imgs.new_zeros(0, )

        rotate_bda, scale_bda, flip_dx, flip_dy = \
            self.sample_bda_augmentation()
        bda_mat = sweep_imgs.new_zeros(4, 4)
        bda_mat[3, 3] = 1
        gt_boxes, bda_rot = bev_transform(gt_boxes, rotate_bda, scale_bda,
                                          flip_dx, flip_dy)
        bda_mat[:3, :3] = bda_rot
        ret_list = [
            sweep_imgs,
            sweep_sensor2ego_mats,
            sweep_intrins,
            sweep_ida_mats,
            sweep_sensor2sensor_mats,
            bda_mat,
            sweep_timestamps,
            img_metas,
            gt_boxes,
            gt_labels,
        ]
        if self.return_depth:
            ret_list.append(image_data_list[7])
        return ret_list

    def __str__(self):
        split = 'train' if self.is_train else 'val'
        return (f'NuscData: {len(self)} samples. Split: {split}.\n'
                f'Augmentation Conf: {self.ida_aug_conf}')

    def __len__(self):
        if self.use_cbgs:
            return len(self.sample_indices)
        else:
            return len(self.infos)


def collate_fn(data, is_return_depth=False):
    imgs_batch = list()
    sensor2ego_mats_batch = list()
    intrin_mats_batch = list()
    ida_mats_batch = list()
    sensor2sensor_mats_batch = list()
    bda_mat_batch = list()
    timestamps_batch = list()
    gt_boxes_batch = list()
    gt_labels_batch = list()
    img_metas_batch = list()
    depth_labels_batch = list()
    for iter_data in data:
        (
            sweep_imgs,
            sweep_sensor2ego_mats,
            sweep_intrins,
            sweep_ida_mats,
            sweep_sensor2sensor_mats,
            bda_mat,
            sweep_timestamps,
            img_metas,
            gt_boxes,
            gt_labels,
        ) = iter_data[:10]
        if is_return_depth:
            gt_depth = iter_data[10]
            depth_labels_batch.append(gt_depth)
        imgs_batch.append(sweep_imgs)
        sensor2ego_mats_batch.append(sweep_sensor2ego_mats)
        intrin_mats_batch.append(sweep_intrins)
        ida_mats_batch.append(sweep_ida_mats)
        sensor2sensor_mats_batch.append(sweep_sensor2sensor_mats)
        bda_mat_batch.append(bda_mat)
        timestamps_batch.append(sweep_timestamps)
        img_metas_batch.append(img_metas)
        gt_boxes_batch.append(gt_boxes)
        gt_labels_batch.append(gt_labels)
    mats_dict = dict()
    mats_dict['sensor2ego_mats'] = torch.stack(sensor2ego_mats_batch)
    mats_dict['intrin_mats'] = torch.stack(intrin_mats_batch)
    mats_dict['ida_mats'] = torch.stack(ida_mats_batch)
    mats_dict['sensor2sensor_mats'] = torch.stack(sensor2sensor_mats_batch)
    mats_dict['bda_mat'] = torch.stack(bda_mat_batch)
    ret_list = [
        torch.stack(imgs_batch),
        mats_dict,
        torch.stack(timestamps_batch),
        img_metas_batch,
        gt_boxes_batch,
        gt_labels_batch,
    ]
    if is_return_depth:
        ret_list.append(torch.stack(depth_labels_batch))
    return ret_list
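# Shape sketch of what `collate_fn` assembles for a batch of B samples,
# each carrying S sweep frames and N cameras. The sizes are illustrative
# and the tensors are zeros/identities, not real data:
#
# ```python
# import torch
#
# B, S, N, C, fH, fW = 2, 2, 6, 3, 256, 704
# per_sample_mats = [torch.eye(4).repeat(S, N, 1, 1) for _ in range(B)]
# mats_dict = {
#     'sensor2ego_mats': torch.stack(per_sample_mats),        # (B, S, N, 4, 4)
#     'intrin_mats': torch.stack(per_sample_mats),
#     'ida_mats': torch.stack(per_sample_mats),
#     'sensor2sensor_mats': torch.stack(per_sample_mats),
#     'bda_mat': torch.stack([torch.eye(4) for _ in range(B)]),  # (B, 4, 4)
# }
# imgs_batch = torch.stack([torch.zeros(S, N, C, fH, fW)
#                           for _ in range(B)])                # (B, S, N, C, H, W)
# ```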


================================================
FILE: bevdepth/evaluators/det_evaluators.py
================================================
'''Modified from https://github.com/nutonomy/nuscenes-devkit/blob/57889ff20678577025326cfc24e57424a829be0a/python-sdk/nuscenes/eval/detection/evaluate.py#L222 # noqa
'''
import os.path as osp
import tempfile

import mmcv
import numpy as np
import pyquaternion
from nuscenes.utils.data_classes import Box
from pyquaternion import Quaternion

__all__ = ['DetNuscEvaluator']


class DetNuscEvaluator():
    ErrNameMapping = {
        'trans_err': 'mATE',
        'scale_err': 'mASE',
        'orient_err': 'mAOE',
        'vel_err': 'mAVE',
        'attr_err': 'mAAE',
    }

    DefaultAttribute = {
        'car': 'vehicle.parked',
        'pedestrian': 'pedestrian.moving',
        'trailer': 'vehicle.parked',
        'truck': 'vehicle.parked',
        'bus': 'vehicle.moving',
        'motorcycle': 'cycle.without_rider',
        'construction_vehicle': 'vehicle.parked',
        'bicycle': 'cycle.without_rider',
        'barrier': '',
        'traffic_cone': '',
    }

    def __init__(
        self,
        class_names,
        eval_version='detection_cvpr_2019',
        data_root='./data/nuScenes',
        version='v1.0-trainval',
        modality=dict(use_lidar=False,
                      use_camera=True,
                      use_radar=False,
                      use_map=False,
                      use_external=False),
        output_dir=None,
    ) -> None:
        self.eval_version = eval_version
        self.data_root = data_root
        if self.eval_version is not None:
            from nuscenes.eval.detection.config import config_factory

            self.eval_detection_configs = config_factory(self.eval_version)
        self.version = version
        self.class_names = class_names
        self.modality = modality
        self.output_dir = output_dir

    def _evaluate_single(self,
                         result_path,
                         logger=None,
                         metric='bbox',
                         result_name='pts_bbox'):
        """Evaluation for a single model in nuScenes protocol.

        Args:
            result_path (str): Path of the result file.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            metric (str): Metric name used for evaluation. Default: 'bbox'.
            result_name (str): Result name in the metric prefix.
                Default: 'pts_bbox'.

        Returns:
            dict: Dictionary of evaluation details.
        """
        from nuscenes import NuScenes
        from nuscenes.eval.detection.evaluate import NuScenesEval

        output_dir = osp.join(*osp.split(result_path)[:-1])
        nusc = NuScenes(version=self.version,
                        dataroot=self.data_root,
                        verbose=False)
        eval_set_map = {
            'v1.0-mini': 'mini_val',
            'v1.0-trainval': 'val',
        }
        nusc_eval = NuScenesEval(nusc,
                                 config=self.eval_detection_configs,
                                 result_path=result_path,
                                 eval_set=eval_set_map[self.version],
                                 output_dir=output_dir,
                                 verbose=False)
        nusc_eval.main(render_curves=False)

        # record metrics
        metrics = mmcv.load(osp.join(output_dir, 'metrics_summary.json'))
        detail = dict()
        metric_prefix = f'{result_name}_NuScenes'
        for class_name in self.class_names:
            for k, v in metrics['label_aps'][class_name].items():
                val = float('{:.4f}'.format(v))
                detail['{}/{}_AP_dist_{}'.format(metric_prefix, class_name,
                                                 k)] = val
            for k, v in metrics['label_tp_errors'][class_name].items():
                val = float('{:.4f}'.format(v))
                detail['{}/{}_{}'.format(metric_prefix, class_name, k)] = val
        # The mean TP errors are dataset-level, so record them once.
        for k, v in metrics['tp_errors'].items():
            val = float('{:.4f}'.format(v))
            detail['{}/{}'.format(metric_prefix, self.ErrNameMapping[k])] = val

        detail['{}/NDS'.format(metric_prefix)] = metrics['nd_score']
        detail['{}/mAP'.format(metric_prefix)] = metrics['mean_ap']
        return detail

    def format_results(self,
                       results,
                       img_metas,
                       result_names=['img_bbox'],
                       jsonfile_prefix=None,
                       **kwargs):
        """Format the results to json (standard format for COCO evaluation).

        Args:
            results (list[tuple | numpy.ndarray]): Testing results of the
                dataset.
            img_metas (list[dict]): Meta infos of each sample, including
                the sample token and ego poses.
            result_names (list[str]): Names of the results to format.
                Default: ['img_bbox'].
            jsonfile_prefix (str | None): The prefix of json files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.

        Returns:
            tuple: (result_files, tmp_dir), result_files is a dict containing \
                the json filepaths, tmp_dir is the temporal directory created \
                for saving json files when jsonfile_prefix is not specified.
        """
        assert isinstance(results, list), 'results must be a list'

        if jsonfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            jsonfile_prefix = osp.join(tmp_dir.name, 'results')
        else:
            tmp_dir = None

        # currently the output prediction results could be in two formats
        # 1. list of dict('boxes_3d': ..., 'scores_3d': ..., 'labels_3d': ...)
        # 2. list of dict('pts_bbox' or 'img_bbox':
        #     dict('boxes_3d': ..., 'scores_3d': ..., 'labels_3d': ...))
        # this is a workaround to enable evaluation of both formats on nuScenes
        # refer to https://github.com/open-mmlab/mmdetection3d/issues/449
        # should take the inner dict out of 'pts_bbox' or 'img_bbox' dict
        result_files = dict()
        # refactor this.
        for result_name in result_names:
            # not evaluate 2D predictions on nuScenes
            if '2d' in result_name:
                continue
            print(f'\nFormatting bboxes of {result_name}')
            tmp_file_ = osp.join(jsonfile_prefix, result_name)
            if self.output_dir:
                result_files.update({
                    result_name:
                    self._format_bbox(results, img_metas, self.output_dir)
                })
            else:
                result_files.update({
                    result_name:
                    self._format_bbox(results, img_metas, tmp_file_)
                })
        return result_files, tmp_dir

    def evaluate(
        self,
        results,
        img_metas,
        metric='bbox',
        logger=None,
        jsonfile_prefix=None,
        result_names=['img_bbox'],
        show=False,
        out_dir=None,
        pipeline=None,
    ):
        """Evaluation in nuScenes protocol.

        Args:
            results (list[dict]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            jsonfile_prefix (str | None): The prefix of json files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.
            show (bool): Whether to visualize.
                Default: False.
            out_dir (str): Path to save the visualization results.
                Default: None.
            pipeline (list[dict], optional): raw data loading for showing.
                Default: None.

        Returns:
            None: Metrics are printed by the nuScenes evaluator and the
                summary json is written next to the formatted results.
        """
        result_files, tmp_dir = self.format_results(results, img_metas,
                                                    result_names,
                                                    jsonfile_prefix)
        if isinstance(result_files, dict):
            for name in result_names:
                print('Evaluating bboxes of {}'.format(name))
                self._evaluate_single(result_files[name])
        elif isinstance(result_files, str):
            self._evaluate_single(result_files)

        if tmp_dir is not None:
            tmp_dir.cleanup()

    def _format_bbox(self, results, img_metas, jsonfile_prefix=None):
        """Convert the results to the standard format.

        Args:
            results (list[dict]): Testing results of the dataset.
            jsonfile_prefix (str): The prefix of the output jsonfile.
                You can specify the output directory/filename by
                modifying the jsonfile_prefix. Default: None.

        Returns:
            str: Path of the output json file.
        """
        nusc_annos = {}
        mapped_class_names = self.class_names

        print('Start to convert detection format...')

        for sample_id, det in enumerate(mmcv.track_iter_progress(results)):
            boxes, scores, labels = det
            sample_token = img_metas[sample_id]['token']
            trans = np.array(img_metas[sample_id]['ego2global_translation'])
            rot = Quaternion(img_metas[sample_id]['ego2global_rotation'])
            annos = list()
            for i, box in enumerate(boxes):
                name = mapped_class_names[labels[i]]
                center = box[:3]
                wlh = box[[4, 3, 5]]
                box_yaw = box[6]
                box_vel = box[7:].tolist()
                box_vel.append(0)
                quat = pyquaternion.Quaternion(axis=[0, 0, 1], radians=box_yaw)
                nusc_box = Box(center, wlh, quat, velocity=box_vel)
                nusc_box.rotate(rot)
                nusc_box.translate(trans)
                if np.sqrt(nusc_box.velocity[0]**2 +
                           nusc_box.velocity[1]**2) > 0.2:
                    if name in [
                            'car',
                            'construction_vehicle',
                            'bus',
                            'truck',
                            'trailer',
                    ]:
                        attr = 'vehicle.moving'
                    elif name in ['bicycle', 'motorcycle']:
                        attr = 'cycle.with_rider'
                    else:
                        attr = self.DefaultAttribute[name]
                else:
                    if name in ['pedestrian']:
                        attr = 'pedestrian.standing'
                    elif name in ['bus']:
                        attr = 'vehicle.stopped'
                    else:
                        attr = self.DefaultAttribute[name]
                nusc_anno = dict(
                    sample_token=sample_token,
                    translation=nusc_box.center.tolist(),
                    size=nusc_box.wlh.tolist(),
                    rotation=nusc_box.orientation.elements.tolist(),
                    velocity=nusc_box.velocity[:2],
                    detection_name=name,
                    detection_score=float(scores[i]),
                    attribute_name=attr,
                )
                annos.append(nusc_anno)
            # Results from other views of the same frame should be concatenated
            if sample_token in nusc_annos:
                nusc_annos[sample_token].extend(annos)
            else:
                nusc_annos[sample_token] = annos
        nusc_submissions = {
            'meta': self.modality,
            'results': nusc_annos,
        }
        mmcv.mkdir_or_exist(jsonfile_prefix)
        res_path = osp.join(jsonfile_prefix, 'results_nusc.json')
        print('Results written to', res_path)
        mmcv.dump(nusc_submissions, res_path)
        return res_path
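
Each predicted box above is lifted from the ego frame into the global frame by rotating with the ego pose quaternion and then translating. A minimal numpy sketch of that transform with a made-up pose (a quaternion about the z-axis reduces to the yaw rotation matrix used here; the pose values are illustrative, not from a real nuScenes sample record):

```python
import numpy as np

# Toy ego->global box transform mirroring nusc_box.rotate(rot) followed by
# nusc_box.translate(trans). Pose values are made up for illustration.
yaw = np.pi / 2  # ego heading rotated 90 degrees in the global frame
rot = np.array([[np.cos(yaw), -np.sin(yaw), 0.0],
                [np.sin(yaw),  np.cos(yaw), 0.0],
                [0.0,          0.0,         1.0]])
trans = np.array([100.0, 50.0, 0.0])  # ego position in the global frame

center_ego = np.array([10.0, 0.0, 1.0])  # box 10 m ahead of the ego
center_global = rot @ center_ego + trans
print(center_global)  # -> [100. 60. 1.]
```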


================================================
FILE: bevdepth/exps/base_cli.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
import os
from argparse import ArgumentParser

import pytorch_lightning as pl

from bevdepth.callbacks.ema import EMACallback
from bevdepth.utils.torch_dist import all_gather_object, synchronize

from .nuscenes.base_exp import BEVDepthLightningModel


def run_cli(model_class=BEVDepthLightningModel,
            exp_name='base_exp',
            use_ema=False,
            extra_trainer_config_args={}):
    parent_parser = ArgumentParser(add_help=False)
    parent_parser = pl.Trainer.add_argparse_args(parent_parser)
    parent_parser.add_argument('-e',
                               '--evaluate',
                               dest='evaluate',
                               action='store_true',
                               help='evaluate model on validation set')
    parent_parser.add_argument('-p',
                               '--predict',
                               dest='predict',
                               action='store_true',
                               help='predict model on testing set')
    parent_parser.add_argument('-b', '--batch_size_per_device', type=int)
    parent_parser.add_argument('--seed',
                               type=int,
                               default=0,
                               help='seed for initializing training.')
    parent_parser.add_argument('--ckpt_path', type=str)
    parser = BEVDepthLightningModel.add_model_specific_args(parent_parser)
    parser.set_defaults(profiler='simple',
                        deterministic=False,
                        max_epochs=extra_trainer_config_args.get('epochs', 24),
                        accelerator='ddp',
                        num_sanity_val_steps=0,
                        gradient_clip_val=5,
                        limit_val_batches=0,
                        enable_checkpointing=True,
                        precision=16,
                        default_root_dir=os.path.join('./outputs/', exp_name))
    args = parser.parse_args()
    if args.seed is not None:
        pl.seed_everything(args.seed)

    model = model_class(**vars(args))
    if use_ema:
        train_dataloader = model.train_dataloader()
        ema_callback = EMACallback(
            len(train_dataloader.dataset) * args.max_epochs)
        trainer = pl.Trainer.from_argparse_args(args, callbacks=[ema_callback])
    else:
        trainer = pl.Trainer.from_argparse_args(args)
    if args.evaluate:
        trainer.test(model, ckpt_path=args.ckpt_path)
    elif args.predict:
        predict_step_outputs = trainer.predict(model, ckpt_path=args.ckpt_path)
        all_pred_results = list()
        all_img_metas = list()
        for predict_step_output in predict_step_outputs:
            for i in range(len(predict_step_output)):
                all_pred_results.append(predict_step_output[i][:3])
                all_img_metas.append(predict_step_output[i][3])
        synchronize()
        len_dataset = len(model.test_dataloader().dataset)
        all_pred_results = sum(
            map(list, zip(*all_gather_object(all_pred_results))),
            [])[:len_dataset]
        all_img_metas = sum(map(list, zip(*all_gather_object(all_img_metas))),
                            [])[:len_dataset]
        model.evaluator._format_bbox(all_pred_results, all_img_metas,
                                     os.path.dirname(args.ckpt_path))
    else:
        trainer.fit(model)
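
The `sum(map(list, zip(*all_gather_object(...))), [])[:len_dataset]` idiom above undoes the round-robin sharding of a `DistributedSampler`: rank `r` holds the results for dataset indices `r, r + world_size, r + 2 * world_size, ...`, so interleaving the gathered per-rank lists restores dataset order, and the final slice drops the sampler's padding samples. A toy single-process illustration (the gathered lists are simulated, not from a real distributed run):

```python
# Simulated output of all_gather_object with world_size=2: rank 0 processed
# dataset indices 0, 2, 4 and rank 1 processed indices 1, 3, 5.
gathered = [['r0', 'r2', 'r4'], ['r1', 'r3', 'r5']]
len_dataset = 5  # the last sample on rank 1 is DistributedSampler padding

# zip(*gathered) interleaves the rank outputs, sum(..., []) flattens the
# tuples, and the slice removes the padded duplicate at the end.
restored = sum(map(list, zip(*gathered)), [])[:len_dataset]
print(restored)  # -> ['r0', 'r1', 'r2', 'r3', 'r4']
```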


================================================
FILE: bevdepth/exps/nuscenes/MatrixVT/matrixvt_bev_depth_lss_r50_256x704_128x128_24e_ema.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
# isort: skip_file
from bevdepth.exps.base_cli import run_cli
# Basic Experiment
from bevdepth.exps.nuscenes.mv.bev_depth_lss_r50_256x704_128x128_24e_ema import \
    BEVDepthLightningModel as BaseExp  # noqa
# new model
from bevdepth.models.matrixvt_det import MatrixVT_Det


class MatrixVT_Exp(BaseExp):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = MatrixVT_Det(self.backbone_conf,
                                  self.head_conf,
                                  is_train_depth=True)
        self.data_use_cbgs = True


if __name__ == '__main__':
    run_cli(
        MatrixVT_Exp,
        'matrixvt_bev_depth_lss_r50_256x704_128x128_24e_ema_cbgs',
        use_ema=True,
    )


================================================
FILE: bevdepth/exps/nuscenes/base_exp.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
import os
from functools import partial

import mmcv
import torch
import torch.nn.functional as F
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.models as models
from pytorch_lightning.core import LightningModule
from torch.cuda.amp.autocast_mode import autocast
from torch.optim.lr_scheduler import MultiStepLR

from bevdepth.datasets.nusc_det_dataset import NuscDetDataset, collate_fn
from bevdepth.evaluators.det_evaluators import DetNuscEvaluator
from bevdepth.models.base_bev_depth import BaseBEVDepth
from bevdepth.utils.torch_dist import all_gather_object, get_rank, synchronize

H = 900
W = 1600
final_dim = (256, 704)
img_conf = dict(img_mean=[123.675, 116.28, 103.53],
                img_std=[58.395, 57.12, 57.375],
                to_rgb=True)

backbone_conf = {
    'x_bound': [-51.2, 51.2, 0.8],
    'y_bound': [-51.2, 51.2, 0.8],
    'z_bound': [-5, 3, 8],
    'd_bound': [2.0, 58.0, 0.5],
    'final_dim':
    final_dim,
    'output_channels':
    80,
    'downsample_factor':
    16,
    'img_backbone_conf':
    dict(
        type='ResNet',
        depth=50,
        frozen_stages=0,
        out_indices=[0, 1, 2, 3],
        norm_eval=False,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
    ),
    'img_neck_conf':
    dict(
        type='SECONDFPN',
        in_channels=[256, 512, 1024, 2048],
        upsample_strides=[0.25, 0.5, 1, 2],
        out_channels=[128, 128, 128, 128],
    ),
    'depth_net_conf':
    dict(in_channels=512, mid_channels=512)
}
ida_aug_conf = {
    'resize_lim': (0.386, 0.55),
    'final_dim':
    final_dim,
    'rot_lim': (-5.4, 5.4),
    'H':
    H,
    'W':
    W,
    'rand_flip':
    True,
    'bot_pct_lim': (0.0, 0.0),
    'cams': [
        'CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT',
        'CAM_BACK', 'CAM_BACK_RIGHT'
    ],
    'Ncams':
    6,
}

bda_aug_conf = {
    'rot_lim': (-22.5, 22.5),
    'scale_lim': (0.95, 1.05),
    'flip_dx_ratio': 0.5,
    'flip_dy_ratio': 0.5
}

bev_backbone = dict(
    type='ResNet',
    in_channels=80,
    depth=18,
    num_stages=3,
    strides=(1, 2, 2),
    dilations=(1, 1, 1),
    out_indices=[0, 1, 2],
    norm_eval=False,
    base_channels=160,
)

bev_neck = dict(type='SECONDFPN',
                in_channels=[80, 160, 320, 640],
                upsample_strides=[1, 2, 4, 8],
                out_channels=[64, 64, 64, 64])

CLASSES = [
    'car',
    'truck',
    'construction_vehicle',
    'bus',
    'trailer',
    'barrier',
    'motorcycle',
    'bicycle',
    'pedestrian',
    'traffic_cone',
]

TASKS = [
    dict(num_class=1, class_names=['car']),
    dict(num_class=2, class_names=['truck', 'construction_vehicle']),
    dict(num_class=2, class_names=['bus', 'trailer']),
    dict(num_class=1, class_names=['barrier']),
    dict(num_class=2, class_names=['motorcycle', 'bicycle']),
    dict(num_class=2, class_names=['pedestrian', 'traffic_cone']),
]

common_heads = dict(reg=(2, 2),
                    height=(1, 2),
                    dim=(3, 2),
                    rot=(2, 2),
                    vel=(2, 2))

bbox_coder = dict(
    type='CenterPointBBoxCoder',
    post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
    max_num=500,
    score_threshold=0.1,
    out_size_factor=4,
    voxel_size=[0.2, 0.2, 8],
    pc_range=[-51.2, -51.2, -5, 51.2, 51.2, 3],
    code_size=9,
)

train_cfg = dict(
    point_cloud_range=[-51.2, -51.2, -5, 51.2, 51.2, 3],
    grid_size=[512, 512, 1],
    voxel_size=[0.2, 0.2, 8],
    out_size_factor=4,
    dense_reg=1,
    gaussian_overlap=0.1,
    max_objs=500,
    min_radius=2,
    code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.5],
)

test_cfg = dict(
    post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
    max_per_img=500,
    max_pool_nms=False,
    min_radius=[4, 12, 10, 1, 0.85, 0.175],
    score_threshold=0.1,
    out_size_factor=4,
    voxel_size=[0.2, 0.2, 8],
    nms_type='circle',
    pre_max_size=1000,
    post_max_size=83,
    nms_thr=0.2,
)

head_conf = {
    'bev_backbone_conf': bev_backbone,
    'bev_neck_conf': bev_neck,
    'tasks': TASKS,
    'common_heads': common_heads,
    'bbox_coder': bbox_coder,
    'train_cfg': train_cfg,
    'test_cfg': test_cfg,
    'in_channels': 256,  # Equal to the sum of bev_neck out_channels (4 x 64).
    'loss_cls': dict(type='GaussianFocalLoss', reduction='mean'),
    'loss_bbox': dict(type='L1Loss', reduction='mean', loss_weight=0.25),
    'gaussian_overlap': 0.1,
    'min_radius': 2,
}


class BEVDepthLightningModel(LightningModule):
    MODEL_NAMES = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith('__')
                         and callable(models.__dict__[name]))

    def __init__(self,
                 gpus: int = 1,
                 data_root='data/nuScenes',
                 eval_interval=1,
                 batch_size_per_device=8,
                 class_names=CLASSES,
                 backbone_conf=backbone_conf,
                 head_conf=head_conf,
                 ida_aug_conf=ida_aug_conf,
                 bda_aug_conf=bda_aug_conf,
                 default_root_dir='./outputs/',
                 **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.gpus = gpus
        self.eval_interval = eval_interval
        self.batch_size_per_device = batch_size_per_device
        self.data_root = data_root
        self.basic_lr_per_img = 2e-4 / 64
        self.class_names = class_names
        self.backbone_conf = backbone_conf
        self.head_conf = head_conf
        self.ida_aug_conf = ida_aug_conf
        self.bda_aug_conf = bda_aug_conf
        mmcv.mkdir_or_exist(default_root_dir)
        self.default_root_dir = default_root_dir
        self.evaluator = DetNuscEvaluator(class_names=self.class_names,
                                          output_dir=self.default_root_dir)
        self.model = BaseBEVDepth(self.backbone_conf,
                                  self.head_conf,
                                  is_train_depth=True)
        self.mode = 'valid'
        self.img_conf = img_conf
        self.data_use_cbgs = False
        self.num_sweeps = 1
        self.sweep_idxes = list()
        self.key_idxes = list()
        self.data_return_depth = True
        self.downsample_factor = self.backbone_conf['downsample_factor']
        self.dbound = self.backbone_conf['d_bound']
        self.depth_channels = int(
            (self.dbound[1] - self.dbound[0]) / self.dbound[2])
        self.use_fusion = False
        self.train_info_paths = os.path.join(self.data_root,
                                             'nuscenes_infos_train.pkl')
        self.val_info_paths = os.path.join(self.data_root,
                                           'nuscenes_infos_val.pkl')
        self.predict_info_paths = os.path.join(self.data_root,
                                               'nuscenes_infos_test.pkl')

    def forward(self, sweep_imgs, mats):
        return self.model(sweep_imgs, mats)

    def training_step(self, batch):
        (sweep_imgs, mats, _, _, gt_boxes, gt_labels, depth_labels) = batch
        if torch.cuda.is_available():
            for key, value in mats.items():
                mats[key] = value.cuda()
            sweep_imgs = sweep_imgs.cuda()
            gt_boxes = [gt_box.cuda() for gt_box in gt_boxes]
            gt_labels = [gt_label.cuda() for gt_label in gt_labels]
        preds, depth_preds = self(sweep_imgs, mats)
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            targets = self.model.module.get_targets(gt_boxes, gt_labels)
            detection_loss = self.model.module.loss(targets, preds)
        else:
            targets = self.model.get_targets(gt_boxes, gt_labels)
            detection_loss = self.model.loss(targets, preds)

        if len(depth_labels.shape) == 5:
            # Only the key frame contributes to the depth loss.
            depth_labels = depth_labels[:, 0, ...]
        depth_loss = self.get_depth_loss(depth_labels.cuda(), depth_preds)
        self.log('detection_loss', detection_loss)
        self.log('depth_loss', depth_loss)
        return detection_loss + depth_loss

    def get_depth_loss(self, depth_labels, depth_preds):
        depth_labels = self.get_downsampled_gt_depth(depth_labels)
        depth_preds = depth_preds.permute(0, 2, 3, 1).contiguous().view(
            -1, self.depth_channels)
        fg_mask = torch.max(depth_labels, dim=1).values > 0.0

        with autocast(enabled=False):
            depth_loss = (F.binary_cross_entropy(
                depth_preds[fg_mask],
                depth_labels[fg_mask],
                reduction='none',
            ).sum() / max(1.0, fg_mask.sum()))

        return 3.0 * depth_loss

    def get_downsampled_gt_depth(self, gt_depths):
        """
        Input:
            gt_depths: [B, N, H, W]
        Output:
            gt_depths: [B*N*h*w, d]
        """
        B, N, H, W = gt_depths.shape
        gt_depths = gt_depths.view(
            B * N,
            H // self.downsample_factor,
            self.downsample_factor,
            W // self.downsample_factor,
            self.downsample_factor,
            1,
        )
        gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous()
        gt_depths = gt_depths.view(
            -1, self.downsample_factor * self.downsample_factor)
        gt_depths_tmp = torch.where(gt_depths == 0.0,
                                    1e5 * torch.ones_like(gt_depths),
                                    gt_depths)
        gt_depths = torch.min(gt_depths_tmp, dim=-1).values
        gt_depths = gt_depths.view(B * N, H // self.downsample_factor,
                                   W // self.downsample_factor)

        gt_depths = (gt_depths -
                     (self.dbound[0] - self.dbound[2])) / self.dbound[2]
        gt_depths = torch.where(
            (gt_depths < self.depth_channels + 1) & (gt_depths >= 0.0),
            gt_depths, torch.zeros_like(gt_depths))
        gt_depths = F.one_hot(gt_depths.long(),
                              num_classes=self.depth_channels + 1).view(
                                  -1, self.depth_channels + 1)[:, 1:]

        return gt_depths.float()

    def eval_step(self, batch, batch_idx, prefix: str):
        (sweep_imgs, mats, _, img_metas, _, _) = batch
        if torch.cuda.is_available():
            for key, value in mats.items():
                mats[key] = value.cuda()
            sweep_imgs = sweep_imgs.cuda()
        preds = self.model(sweep_imgs, mats)
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            results = self.model.module.get_bboxes(preds, img_metas)
        else:
            results = self.model.get_bboxes(preds, img_metas)
        for i in range(len(results)):
            results[i][0] = results[i][0].detach().cpu().numpy()
            results[i][1] = results[i][1].detach().cpu().numpy()
            results[i][2] = results[i][2].detach().cpu().numpy()
            results[i].append(img_metas[i])
        return results

    def validation_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, 'val')

    def validation_epoch_end(self, validation_step_outputs):
        all_pred_results = list()
        all_img_metas = list()
        for validation_step_output in validation_step_outputs:
            for i in range(len(validation_step_output)):
                all_pred_results.append(validation_step_output[i][:3])
                all_img_metas.append(validation_step_output[i][3])
        synchronize()
        len_dataset = len(self.val_dataloader().dataset)
        all_pred_results = sum(
            map(list, zip(*all_gather_object(all_pred_results))),
            [])[:len_dataset]
        all_img_metas = sum(map(list, zip(*all_gather_object(all_img_metas))),
                            [])[:len_dataset]
        if get_rank() == 0:
            self.evaluator.evaluate(all_pred_results, all_img_metas)

    def test_epoch_end(self, test_step_outputs):
        all_pred_results = list()
        all_img_metas = list()
        for test_step_output in test_step_outputs:
            for i in range(len(test_step_output)):
                all_pred_results.append(test_step_output[i][:3])
                all_img_metas.append(test_step_output[i][3])
        synchronize()
        # TODO: find a cleaner way to obtain the dataset length.
        dataset_length = len(self.val_dataloader().dataset)
        all_pred_results = sum(
            map(list, zip(*all_gather_object(all_pred_results))),
            [])[:dataset_length]
        all_img_metas = sum(map(list, zip(*all_gather_object(all_img_metas))),
                            [])[:dataset_length]
        if get_rank() == 0:
            self.evaluator.evaluate(all_pred_results, all_img_metas)

    def configure_optimizers(self):
        lr = self.basic_lr_per_img * \
            self.batch_size_per_device * self.gpus
        optimizer = torch.optim.AdamW(self.model.parameters(),
                                      lr=lr,
                                      weight_decay=1e-7)
        scheduler = MultiStepLR(optimizer, [19, 23])
        return [[optimizer], [scheduler]]

    def train_dataloader(self):
        train_dataset = NuscDetDataset(ida_aug_conf=self.ida_aug_conf,
                                       bda_aug_conf=self.bda_aug_conf,
                                       classes=self.class_names,
                                       data_root=self.data_root,
                                       info_paths=self.train_info_paths,
                                       is_train=True,
                                       use_cbgs=self.data_use_cbgs,
                                       img_conf=self.img_conf,
                                       num_sweeps=self.num_sweeps,
                                       sweep_idxes=self.sweep_idxes,
                                       key_idxes=self.key_idxes,
                                       return_depth=self.data_return_depth,
                                       use_fusion=self.use_fusion)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.batch_size_per_device,
            num_workers=4,
            drop_last=True,
            shuffle=False,
            collate_fn=partial(collate_fn,
                               is_return_depth=self.data_return_depth
                               or self.use_fusion),
            sampler=None,
        )
        return train_loader

    def val_dataloader(self):
        val_dataset = NuscDetDataset(ida_aug_conf=self.ida_aug_conf,
                                     bda_aug_conf=self.bda_aug_conf,
                                     classes=self.class_names,
                                     data_root=self.data_root,
                                     info_paths=self.val_info_paths,
                                     is_train=False,
                                     img_conf=self.img_conf,
                                     num_sweeps=self.num_sweeps,
                                     sweep_idxes=self.sweep_idxes,
                                     key_idxes=self.key_idxes,
                                     return_depth=self.use_fusion,
                                     use_fusion=self.use_fusion)
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self.batch_size_per_device,
            shuffle=False,
            collate_fn=partial(collate_fn, is_return_depth=self.use_fusion),
            num_workers=4,
            sampler=None,
        )
        return val_loader

    def test_dataloader(self):
        return self.val_dataloader()

    def predict_dataloader(self):
        predict_dataset = NuscDetDataset(ida_aug_conf=self.ida_aug_conf,
                                         bda_aug_conf=self.bda_aug_conf,
                                         classes=self.class_names,
                                         data_root=self.data_root,
                                         info_paths=self.predict_info_paths,
                                         is_train=False,
                                         img_conf=self.img_conf,
                                         num_sweeps=self.num_sweeps,
                                         sweep_idxes=self.sweep_idxes,
                                         key_idxes=self.key_idxes,
                                         return_depth=self.use_fusion,
                                         use_fusion=self.use_fusion)
        predict_loader = torch.utils.data.DataLoader(
            predict_dataset,
            batch_size=self.batch_size_per_device,
            shuffle=False,
            collate_fn=partial(collate_fn, is_return_depth=self.use_fusion),
            num_workers=4,
            sampler=None,
        )
        return predict_loader

    def test_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, 'test')

    def predict_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, 'predict')

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        return parent_parser
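
As a standalone illustration of the label construction in `get_downsampled_gt_depth` above: each `downsample_factor x downsample_factor` patch of sparse LiDAR depth is min-pooled (empty pixels are first mapped to a huge value so the nearest valid return wins), then shifted into one-hot bins where index 0 is reserved for "no return" and dropped. This sketch uses the default `d_bound = [2.0, 58.0, 0.5]`; the patch values are made up:

```python
import torch
import torch.nn.functional as F

dbound = (2.0, 58.0, 0.5)
depth_channels = int((dbound[1] - dbound[0]) / dbound[2])  # 112 bins

# One flattened downsampling patch of sparse depths (0.0 = no LiDAR return).
patch = torch.tensor([0.0, 3.1, 2.4, 0.0])
# Replace empty pixels with 1e5 so min() picks the nearest valid return.
nearest = torch.where(patch == 0.0, torch.full_like(patch, 1e5), patch).min()

# Shift by one bin so index 0 means "no valid depth", then one-hot encode
# and drop that first column, exactly as the method above does.
bin_idx = (nearest - (dbound[0] - dbound[2])) / dbound[2]
onehot = F.one_hot(bin_idx.long(), num_classes=depth_channels + 1)[1:]
print(int(onehot.argmax()))  # -> 0, since depth 2.4 m falls in [2.0, 2.5)
```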


================================================
FILE: bevdepth/exps/nuscenes/fusion/bev_depth_fusion_lss_r50_256x704_128x128_24e.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
import torch
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed

from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.base_exp import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel
from bevdepth.models.fusion_bev_depth import FusionBEVDepth


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = FusionBEVDepth(self.backbone_conf,
                                    self.head_conf,
                                    is_train_depth=False)
        self.use_fusion = True

    def forward(self, sweep_imgs, mats, lidar_depth):
        return self.model(sweep_imgs, mats, lidar_depth)

    def training_step(self, batch):
        (sweep_imgs, mats, _, _, gt_boxes, gt_labels, lidar_depth) = batch
        if torch.cuda.is_available():
            for key, value in mats.items():
                mats[key] = value.cuda()
            sweep_imgs = sweep_imgs.cuda()
            gt_boxes = [gt_box.cuda() for gt_box in gt_boxes]
            gt_labels = [gt_label.cuda() for gt_label in gt_labels]
        preds = self(sweep_imgs, mats, lidar_depth)
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            targets = self.model.module.get_targets(gt_boxes, gt_labels)
            detection_loss = self.model.module.loss(targets, preds)
        else:
            targets = self.model.get_targets(gt_boxes, gt_labels)
            detection_loss = self.model.loss(targets, preds)

        if len(lidar_depth.shape) == 5:
            # Keep only the key frame's lidar depth (the fusion model
            # computes no depth loss).
            lidar_depth = lidar_depth[:, 0, ...]
        self.log('detection_loss', detection_loss)
        return detection_loss

    def eval_step(self, batch, batch_idx, prefix: str):
        (sweep_imgs, mats, _, img_metas, _, _, lidar_depth) = batch
        if torch.cuda.is_available():
            for key, value in mats.items():
                mats[key] = value.cuda()
            sweep_imgs = sweep_imgs.cuda()
        preds = self.model(sweep_imgs, mats, lidar_depth)
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            results = self.model.module.get_bboxes(preds, img_metas)
        else:
            results = self.model.get_bboxes(preds, img_metas)
        for i in range(len(results)):
            results[i][0] = results[i][0].detach().cpu().numpy()
            results[i][1] = results[i][1].detach().cpu().numpy()
            results[i][2] = results[i][2].detach().cpu().numpy()
            results[i].append(img_metas[i])
        return results


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_depth_fusion_lss_r50_256x704_128x128_24e')


================================================
FILE: bevdepth/exps/nuscenes/fusion/bev_depth_fusion_lss_r50_256x704_128x128_24e_2key.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.fusion.bev_depth_fusion_lss_r50_256x704_128x128_24e import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel  # noqa
from bevdepth.models.fusion_bev_depth import FusionBEVDepth


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.key_idxes = [-1]
        self.head_conf['bev_backbone_conf']['in_channels'] = 80 * (
            len(self.key_idxes) + 1)
        self.head_conf['bev_neck_conf']['in_channels'] = [
            80 * (len(self.key_idxes) + 1), 160, 320, 640
        ]
        self.head_conf['train_cfg']['code_weight'] = [
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
        ]
        self.model = FusionBEVDepth(self.backbone_conf, self.head_conf)


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_depth_fusion_lss_r50_256x704_128x128_24e_2key')
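
The channel arithmetic in `__init__` above generalizes to any number of past key frames: the BEV features of the current frame and of each frame in `key_idxes` are concatenated along the channel dimension, so the BEV backbone's `in_channels` scales with the frame count. A quick sanity check of that bookkeeping (80 BEV channels per frame, as in `backbone_conf['output_channels']`):

```python
# Channel bookkeeping for multi-frame BEV concatenation, as in the 2key exp:
# key_idxes = [-1] adds one previous key frame to the current one.
bev_channels_per_frame = 80
key_idxes = [-1]
num_frames = len(key_idxes) + 1
in_channels = bev_channels_per_frame * num_frames
print(in_channels)  # -> 160
```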


================================================
FILE: bevdepth/exps/nuscenes/fusion/bev_depth_fusion_lss_r50_256x704_128x128_24e_2key_trainval.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
from bevdepth.exps.base_cli import run_cli

from .bev_depth_fusion_lss_r50_256x704_128x128_24e_2key import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.train_info_paths = [
            'data/nuScenes/nuscenes_infos_train.pkl',
            'data/nuScenes/nuscenes_infos_val.pkl'
        ]


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_depth_fusion_lss_r50_256x704_128x128_24e_2key_trainval')


================================================
FILE: bevdepth/exps/nuscenes/fusion/bev_depth_fusion_lss_r50_256x704_128x128_24e_key4.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.fusion.bev_depth_fusion_lss_r50_256x704_128x128_24e import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel  # noqa
from bevdepth.models.fusion_bev_depth import FusionBEVDepth


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.sweep_idxes = [4]
        self.head_conf['bev_backbone_conf']['in_channels'] = 80 * (
            len(self.sweep_idxes) + 1)
        self.head_conf['bev_neck_conf']['in_channels'] = [
            80 * (len(self.sweep_idxes) + 1), 160, 320, 640
        ]
        self.head_conf['train_cfg']['code_weight'] = [
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
        ]
        self.model = FusionBEVDepth(self.backbone_conf,
                                    self.head_conf,
                                    is_train_depth=False)


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_depth_fusion_lss_r50_256x704_128x128_24e_key4')


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
"""
mAP: 0.3484
mATE: 0.6159
mASE: 0.2716
mAOE: 0.4144
mAVE: 0.4402
mAAE: 0.1954
NDS: 0.4805
Eval time: 110.7s
Per-class results:
Object Class    AP      ATE     ASE     AOE     AVE     AAE
car     0.553   0.480   0.157   0.117   0.386   0.205
truck   0.252   0.645   0.202   0.097   0.381   0.185
bus     0.378   0.674   0.197   0.090   0.871   0.298
trailer 0.163   0.932   0.230   0.409   0.543   0.098
construction_vehicle    0.076   0.878   0.495   1.015   0.103   0.344
pedestrian      0.361   0.694   0.300   0.816   0.491   0.247
motorcycle      0.319   0.569   0.252   0.431   0.552   0.181
bicycle 0.286   0.457   0.255   0.630   0.194   0.006
traffic_cone    0.536   0.438   0.339   nan     nan     nan
barrier 0.559   0.392   0.289   0.124   nan     nan
"""
import torch
from torch.optim.lr_scheduler import MultiStepLR

from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.mv.bev_depth_lss_r50_256x704_128x128_24e_2key import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel  # noqa
from bevdepth.models.base_bev_depth import BaseBEVDepth


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.backbone_conf['use_da'] = True
        self.data_use_cbgs = True
        self.model = BaseBEVDepth(self.backbone_conf,
                                  self.head_conf,
                                  is_train_depth=True)

    def configure_optimizers(self):
        lr = self.basic_lr_per_img * \
            self.batch_size_per_device * self.gpus
        optimizer = torch.optim.AdamW(self.model.parameters(),
                                      lr=lr,
                                      weight_decay=1e-7)
        scheduler = MultiStepLR(optimizer, [16, 19])
        return [[optimizer], [scheduler]]


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da',
            extra_trainer_config_args={'epochs': 20})


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da_ema.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
"""
mAP: 0.3589
mATE: 0.6119
mASE: 0.2692
mAOE: 0.5074
mAVE: 0.4086
mAAE: 0.2009
NDS: 0.4797
Eval time: 183.3s
Per-class results:
Object Class    AP      ATE     ASE     AOE     AVE     AAE
car     0.559   0.475   0.157   0.112   0.370   0.205
truck   0.270   0.659   0.196   0.103   0.356   0.181
bus     0.374   0.651   0.184   0.072   0.846   0.326
trailer 0.179   0.963   0.227   0.512   0.294   0.127
construction_vehicle    0.081   0.825   0.481   1.352   0.094   0.345
pedestrian      0.363   0.690   0.297   0.831   0.491   0.244
motorcycle      0.354   0.580   0.255   0.545   0.615   0.164
bicycle 0.301   0.447   0.280   0.920   0.203   0.015
traffic_cone    0.539   0.435   0.324   nan     nan     nan
barrier 0.569   0.394   0.293   0.120   nan     nan
"""
from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.mv.bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da import \
    BEVDepthLightningModel  # noqa

if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da_ema',
            use_ema=True,
            extra_trainer_config_args={'epochs': 20})


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_24e_2key.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
"""
mAP: 0.3304
mATE: 0.7021
mASE: 0.2795
mAOE: 0.5346
mAVE: 0.5530
mAAE: 0.2274
NDS: 0.4355
Eval time: 171.8s

Per-class results:
Object Class    AP      ATE     ASE     AOE     AVE     AAE
car     0.499   0.540   0.165   0.211   0.650   0.233
truck   0.278   0.719   0.218   0.265   0.547   0.215
bus     0.386   0.661   0.211   0.171   1.132   0.274
trailer 0.168   1.034   0.235   0.548   0.408   0.168
construction_vehicle    0.075   1.124   0.510   1.177   0.111   0.385
pedestrian      0.284   0.757   0.298   0.966   0.578   0.301
motorcycle      0.335   0.624   0.263   0.621   0.734   0.237
bicycle 0.305   0.554   0.264   0.653   0.263   0.006
traffic_cone    0.462   0.516   0.355   nan     nan     nan
barrier 0.512   0.491   0.275   0.200   nan     nan
"""
from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.base_exp import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel
from bevdepth.models.base_bev_depth import BaseBEVDepth


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.key_idxes = [-1]
        self.head_conf['bev_backbone_conf']['in_channels'] = 80 * (
            len(self.key_idxes) + 1)
        self.head_conf['bev_neck_conf']['in_channels'] = [
            80 * (len(self.key_idxes) + 1), 160, 320, 640
        ]
        self.head_conf['train_cfg']['code_weights'] = [
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
        ]
        self.model = BaseBEVDepth(self.backbone_conf,
                                  self.head_conf,
                                  is_train_depth=True)


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_depth_lss_r50_256x704_128x128_24e_2key')
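The head reconfiguration above scales the BEV input channels with the number of concatenated key frames: 80 channels per frame, for the current frame plus one frame per entry in `key_idxes`. A small check of that arithmetic (`bev_in_channels` is an illustrative helper, not repo code):

```python
def bev_in_channels(key_idxes, channels_per_frame=80):
    """Channels of the concatenated BEV feature: current + extra key frames."""
    return channels_per_frame * (len(key_idxes) + 1)

# key_idxes = [-1] (one previous key frame) doubles the channels to 160,
# matching both bev_backbone_conf and the first bev_neck_conf entry above.
assert bev_in_channels([-1]) == 160
```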


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_24e_2key_ema.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
"""
mAP: 0.3329
mATE: 0.6832
mASE: 0.2761
mAOE: 0.5446
mAVE: 0.5258
mAAE: 0.2259
NDS: 0.4409

Per-class results:
Object Class    AP      ATE     ASE     AOE     AVE     AAE
car     0.505   0.531   0.165   0.189   0.618   0.234
truck   0.274   0.731   0.206   0.211   0.546   0.223
bus     0.394   0.673   0.219   0.148   1.061   0.274
trailer 0.174   0.934   0.228   0.544   0.369   0.183
construction_vehicle    0.079   1.043   0.528   1.162   0.112   0.376
pedestrian      0.284   0.748   0.294   0.973   0.575   0.297
motorcycle      0.345   0.633   0.256   0.719   0.667   0.214
bicycle 0.314   0.544   0.252   0.778   0.259   0.007
traffic_cone    0.453   0.519   0.335   nan     nan     nan
barrier 0.506   0.475   0.279   0.178   nan     nan
"""
from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.mv.bev_depth_lss_r50_256x704_128x128_24e_2key import \
    BEVDepthLightningModel  # noqa

if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_depth_lss_r50_256x704_128x128_24e_2key_ema',
            use_ema=True)


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_24e_ema.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
import torch
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed

from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.base_exp import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def configure_optimizers(self):
        lr = self.basic_lr_per_img * \
            self.batch_size_per_device * self.gpus
        optimizer = torch.optim.AdamW(self.model.parameters(),
                                      lr=lr,
                                      weight_decay=1e-7)
        return [optimizer]


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_depth_lss_r50_256x704_128x128_24e_ema',
            use_ema=True)


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_512x1408_128x128_24e_2key.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
import torch
from torch.optim.lr_scheduler import MultiStepLR

from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.mv.bev_depth_lss_r50_256x704_128x128_24e_2key import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel  # noqa
from bevdepth.models.base_bev_depth import BaseBEVDepth


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        final_dim = (512, 1408)
        self.backbone_conf['final_dim'] = final_dim
        self.ida_aug_conf['resize_lim'] = (0.386 * 2, 0.55 * 2)
        self.ida_aug_conf['final_dim'] = final_dim
        self.model = BaseBEVDepth(self.backbone_conf,
                                  self.head_conf,
                                  is_train_depth=True)

    def configure_optimizers(self):
        lr = self.basic_lr_per_img * \
            self.batch_size_per_device * self.gpus
        optimizer = torch.optim.AdamW(self.model.parameters(),
                                      lr=lr,
                                      weight_decay=1e-3)
        scheduler = MultiStepLR(optimizer, [19, 23])
        return [[optimizer], [scheduler]]


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_depth_lss_r50_512x1408_128x128_24e_2key')
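The augmentation resize limits above are the base config's limits doubled, consistent with 512x1408 being exactly twice the base 256x704 input. A sketch of that relationship, assuming the base config uses `resize_lim = (0.386, 0.55)` as the `* 2` factors suggest (names here are illustrative, not repo code):

```python
BASE_DIM = (256, 704)
BASE_RESIZE_LIM = (0.386, 0.55)

def resize_lim_for(final_dim):
    # Assumes final_dim preserves the base aspect ratio.
    scale = final_dim[0] / BASE_DIM[0]
    return (BASE_RESIZE_LIM[0] * scale, BASE_RESIZE_LIM[1] * scale)

assert resize_lim_for((512, 1408)) == (0.386 * 2, 0.55 * 2)
```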


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_640x1600_128x128_24e_2key.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
import torch
from torch.optim.lr_scheduler import MultiStepLR

from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.mv.bev_depth_lss_r50_256x704_128x128_24e_2key import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel  # noqa
from bevdepth.models.base_bev_depth import BaseBEVDepth


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        final_dim = (640, 1600)
        self.backbone_conf['final_dim'] = final_dim
        self.ida_aug_conf['resize_lim'] = (0.386 * 2, 0.55 * 2)
        self.ida_aug_conf['final_dim'] = final_dim
        self.model = BaseBEVDepth(self.backbone_conf,
                                  self.head_conf,
                                  is_train_depth=True)

    def configure_optimizers(self):
        lr = self.basic_lr_per_img * \
            self.batch_size_per_device * self.gpus
        optimizer = torch.optim.AdamW(self.model.parameters(),
                                      lr=lr,
                                      weight_decay=1e-3)
        scheduler = MultiStepLR(optimizer, [19, 23])
        return [[optimizer], [scheduler]]


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_depth_lss_r50_640x1600_128x128_24e_2key')


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
"""
mAP: 0.3576
mATE: 0.6071
mASE: 0.2684
mAOE: 0.4157
mAVE: 0.3928
mAAE: 0.2021
NDS: 0.4902
Eval time: 129.7s
Per-class results:
Object Class    AP      ATE     ASE     AOE     AVE     AAE
car     0.559   0.465   0.157   0.110   0.350   0.205
truck   0.285   0.633   0.205   0.101   0.304   0.209
bus     0.373   0.667   0.204   0.076   0.896   0.345
trailer 0.167   0.956   0.228   0.482   0.289   0.100
construction_vehicle    0.077   0.869   0.454   1.024   0.108   0.335
pedestrian      0.402   0.652   0.299   0.821   0.493   0.253
motorcycle      0.321   0.544   0.255   0.484   0.529   0.159
bicycle 0.276   0.466   0.272   0.522   0.173   0.011
traffic_cone    0.551   0.432   0.321   nan     nan     nan
barrier 0.565   0.386   0.287   0.121   nan     nan
"""
import torch
from torch.optim.lr_scheduler import MultiStepLR

from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.mv.bev_stereo_lss_r50_256x704_128x128_24e_2key import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel  # noqa
from bevdepth.models.bev_stereo import BEVStereo


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.backbone_conf['use_da'] = True
        self.data_use_cbgs = True
        self.basic_lr_per_img = 2e-4 / 32
        self.model = BEVStereo(self.backbone_conf,
                               self.head_conf,
                               is_train_depth=True)

    def configure_optimizers(self):
        lr = self.basic_lr_per_img * \
            self.batch_size_per_device * self.gpus
        optimizer = torch.optim.AdamW(self.model.parameters(),
                                      lr=lr,
                                      weight_decay=1e-2)
        scheduler = MultiStepLR(optimizer, [16, 19])
        return [[optimizer], [scheduler]]


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da',
            extra_trainer_config_args={'epochs': 20})


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da_ema.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
"""
mAP: 0.3721
mATE: 0.5980
mASE: 0.2701
mAOE: 0.4381
mAVE: 0.3672
mAAE: 0.1898
NDS: 0.4997
Eval time: 138.0s
Per-class results:
Object Class    AP      ATE     ASE     AOE     AVE     AAE
car     0.567   0.457   0.156   0.104   0.343   0.204
truck   0.299   0.650   0.205   0.103   0.321   0.197
bus     0.394   0.613   0.203   0.106   0.643   0.252
trailer 0.178   0.991   0.239   0.433   0.345   0.070
construction_vehicle    0.102   0.826   0.458   1.055   0.114   0.372
pedestrian      0.402   0.653   0.297   0.803   0.479   0.249
motorcycle      0.356   0.553   0.251   0.450   0.512   0.168
bicycle 0.311   0.440   0.265   0.779   0.180   0.006
traffic_cone    0.552   0.420   0.336   nan     nan     nan
barrier 0.561   0.377   0.291   0.111   nan     nan
"""
from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.mv.bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da import \
    BEVDepthLightningModel  # noqa

if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da_ema',
            use_ema=True,
            extra_trainer_config_args={'epochs': 20})


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_2key.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
"""
mAP: 0.3456
mATE: 0.6589
mASE: 0.2774
mAOE: 0.5500
mAVE: 0.4980
mAAE: 0.2278
NDS: 0.4516
Eval time: 158.2s
Per-class results:
Object Class    AP      ATE     ASE     AOE     AVE     AAE
car     0.510   0.525   0.165   0.188   0.510   0.226
truck   0.288   0.698   0.220   0.205   0.443   0.227
bus     0.378   0.622   0.210   0.135   0.896   0.289
trailer 0.156   1.003   0.219   0.482   0.609   0.179
construction_vehicle    0.094   0.929   0.502   1.209   0.108   0.365
pedestrian      0.356   0.728   0.297   1.005   0.579   0.319
motorcycle      0.361   0.571   0.258   0.734   0.631   0.211
bicycle 0.318   0.533   0.269   0.793   0.208   0.007
traffic_cone    0.488   0.501   0.355   nan     nan     nan
barrier 0.506   0.478   0.277   0.200   nan     nan
"""
from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.base_exp import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel
from bevdepth.models.bev_stereo import BEVStereo


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.key_idxes = [-1]
        self.head_conf['bev_backbone_conf']['in_channels'] = 80 * (
            len(self.key_idxes) + 1)
        self.head_conf['bev_neck_conf']['in_channels'] = [
            80 * (len(self.key_idxes) + 1), 160, 320, 640
        ]
        self.head_conf['train_cfg']['code_weights'] = [
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
        ]
        self.head_conf['test_cfg']['thresh_scale'] = [
            0.6, 0.4, 0.4, 0.7, 0.8, 0.9
        ]
        self.head_conf['test_cfg']['nms_type'] = 'size_aware_circle'
        self.model = BEVStereo(self.backbone_conf,
                               self.head_conf,
                               is_train_depth=True)


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_stereo_lss_r50_256x704_128x128_24e_2key')


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_2key_ema.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
"""
mAP: 0.3494
mATE: 0.6672
mASE: 0.2785
mAOE: 0.5607
mAVE: 0.4687
mAAE: 0.2295
NDS: 0.4542
Eval time: 166.7s
Per-class results:
Object Class    AP      ATE     ASE     AOE     AVE     AAE
car     0.509   0.522   0.163   0.187   0.507   0.228
truck   0.287   0.694   0.213   0.202   0.449   0.229
bus     0.390   0.681   0.207   0.152   0.902   0.261
trailer 0.167   0.945   0.248   0.491   0.340   0.185
construction_vehicle    0.087   1.057   0.515   1.199   0.104   0.377
pedestrian      0.351   0.729   0.299   0.987   0.575   0.321
motorcycle      0.368   0.581   0.262   0.721   0.663   0.226
bicycle 0.338   0.494   0.258   0.921   0.209   0.008
traffic_cone    0.494   0.502   0.341   nan     nan     nan
barrier 0.502   0.467   0.278   0.185   nan     nan
"""
from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.mv.bev_stereo_lss_r50_256x704_128x128_24e_2key import \
    BEVDepthLightningModel  # noqa

if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_stereo_lss_r50_256x704_128x128_24e_2key_ema',
            use_ema=True)


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_key4.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
"""
mAP: 0.3427
mATE: 0.6560
mASE: 0.2784
mAOE: 0.5982
mAVE: 0.5347
mAAE: 0.2228
NDS: 0.4423
Eval time: 116.3s
Per-class results:
Object Class    AP      ATE     ASE     AOE     AVE     AAE
car     0.508   0.518   0.163   0.188   0.534   0.230
truck   0.268   0.709   0.214   0.215   0.510   0.226
bus     0.379   0.640   0.207   0.142   1.049   0.315
trailer 0.151   0.953   0.240   0.541   0.618   0.113
construction_vehicle    0.092   0.955   0.514   1.360   0.113   0.394
pedestrian      0.350   0.727   0.300   1.013   0.598   0.328
motorcycle      0.371   0.576   0.259   0.777   0.634   0.175
bicycle 0.325   0.512   0.261   0.942   0.221   0.002
traffic_cone    0.489   0.503   0.345   nan     nan     nan
barrier 0.495   0.468   0.280   0.206   nan     nan
"""
from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.mv.bev_stereo_lss_r50_256x704_128x128_24e_2key import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel  # noqa


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.num_sweeps = 2
        self.sweep_idxes = [4]
        self.key_idxes = list()


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_stereo_lss_r50_256x704_128x128_24e_key4')


================================================
FILE: bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_key4_ema.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
"""
mAP: 0.3427
mATE: 0.6560
mASE: 0.2784
mAOE: 0.5982
mAVE: 0.5347
mAAE: 0.2228
NDS: 0.4423
Eval time: 116.3s
Per-class results:
Object Class    AP      ATE     ASE     AOE     AVE     AAE
car     0.508   0.518   0.163   0.188   0.534   0.230
truck   0.268   0.709   0.214   0.215   0.510   0.226
bus     0.379   0.640   0.207   0.142   1.049   0.315
trailer 0.151   0.953   0.240   0.541   0.618   0.113
construction_vehicle    0.092   0.955   0.514   1.360   0.113   0.394
pedestrian      0.350   0.727   0.300   1.013   0.598   0.328
motorcycle      0.371   0.576   0.259   0.777   0.634   0.175
bicycle 0.325   0.512   0.261   0.942   0.221   0.002
traffic_cone    0.489   0.503   0.345   nan     nan     nan
barrier 0.495   0.468   0.280   0.206   nan     nan
"""
from bevdepth.exps.base_cli import run_cli
from bevdepth.exps.nuscenes.mv.bev_stereo_lss_r50_256x704_128x128_24e_key4 import \
    BEVDepthLightningModel as BaseBEVDepthLightningModel  # noqa


class BEVDepthLightningModel(BaseBEVDepthLightningModel):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.num_sweeps = 2
        self.sweep_idxes = [4]
        self.key_idxes = list()


if __name__ == '__main__':
    run_cli(BEVDepthLightningModel,
            'bev_stereo_lss_r50_256x704_128x128_24e_key4_ema',
            use_ema=True)


================================================
FILE: bevdepth/layers/__init__.py
================================================
from .heads.bev_depth_head import BEVDepthHead

__all__ = ['BEVDepthHead']


================================================
FILE: bevdepth/layers/backbones/__init__.py
================================================
from .base_lss_fpn import BaseLSSFPN
from .fusion_lss_fpn import FusionLSSFPN

__all__ = ['BaseLSSFPN', 'FusionLSSFPN']


================================================
FILE: bevdepth/layers/backbones/base_lss_fpn.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer
from mmdet3d.models import build_neck
from mmdet.models import build_backbone
from mmdet.models.backbones.resnet import BasicBlock
from torch import nn
from torch.cuda.amp.autocast_mode import autocast

try:
    from bevdepth.ops.voxel_pooling_inference import voxel_pooling_inference
    from bevdepth.ops.voxel_pooling_train import voxel_pooling_train
except ImportError:
    print('Failed to import VoxelPooling ops.')

__all__ = ['BaseLSSFPN']


class _ASPPModule(nn.Module):

    def __init__(self, inplanes, planes, kernel_size, padding, dilation,
                 BatchNorm):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(inplanes,
                                     planes,
                                     kernel_size=kernel_size,
                                     stride=1,
                                     padding=padding,
                                     dilation=dilation,
                                     bias=False)
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class ASPP(nn.Module):

    def __init__(self, inplanes, mid_channels=256, BatchNorm=nn.BatchNorm2d):
        super(ASPP, self).__init__()

        dilations = [1, 6, 12, 18]

        self.aspp1 = _ASPPModule(inplanes,
                                 mid_channels,
                                 1,
                                 padding=0,
                                 dilation=dilations[0],
                                 BatchNorm=BatchNorm)
        self.aspp2 = _ASPPModule(inplanes,
                                 mid_channels,
                                 3,
                                 padding=dilations[1],
                                 dilation=dilations[1],
                                 BatchNorm=BatchNorm)
        self.aspp3 = _ASPPModule(inplanes,
                                 mid_channels,
                                 3,
                                 padding=dilations[2],
                                 dilation=dilations[2],
                                 BatchNorm=BatchNorm)
        self.aspp4 = _ASPPModule(inplanes,
                                 mid_channels,
                                 3,
                                 padding=dilations[3],
                                 dilation=dilations[3],
                                 BatchNorm=BatchNorm)

        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, mid_channels, 1, stride=1, bias=False),
            BatchNorm(mid_channels),
            nn.ReLU(),
        )
        self.conv1 = nn.Conv2d(int(mid_channels * 5),
                               mid_channels,
                               1,
                               bias=False)
        self.bn1 = BatchNorm(mid_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5,
                           size=x4.size()[2:],
                           mode='bilinear',
                           align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return self.dropout(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
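Each 3x3 atrous branch in `ASPP` sets `padding=dilation`, which keeps the spatial size unchanged so all five branches can be concatenated along the channel axis. The standard conv output-size formula makes this easy to verify (`conv2d_out` is an illustrative helper, not repo code):

```python
def conv2d_out(size, kernel, padding, dilation, stride=1):
    """Output size of a (possibly dilated) convolution along one axis."""
    eff_k = dilation * (kernel - 1) + 1   # effective atrous kernel extent
    return (size + 2 * padding - eff_k) // stride + 1

# padding == dilation preserves the input size for every 3x3 branch...
assert all(conv2d_out(16, 3, d, d) == 16 for d in (6, 12, 18))
# ...and so does the 1x1 branch with padding=0, dilation=1.
assert conv2d_out(16, 1, 0, 1) == 16
```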


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.ReLU,
                 drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class SELayer(nn.Module):

    def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
        super().__init__()
        self.conv_reduce = nn.Conv2d(channels, channels, 1, bias=True)
        self.act1 = act_layer()
        self.conv_expand = nn.Conv2d(channels, channels, 1, bias=True)
        self.gate = gate_layer()

    def forward(self, x, x_se):
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        return x * self.gate(x_se)


class DepthNet(nn.Module):

    def __init__(self, in_channels, mid_channels, context_channels,
                 depth_channels):
        super(DepthNet, self).__init__()
        self.reduce_conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      mid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        self.context_conv = nn.Conv2d(mid_channels,
                                      context_channels,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0)
        self.bn = nn.BatchNorm1d(27)
        self.depth_mlp = Mlp(27, mid_channels, mid_channels)
        self.depth_se = SELayer(mid_channels)  # NOTE: add camera-aware
        self.context_mlp = Mlp(27, mid_channels, mid_channels)
        self.context_se = SELayer(mid_channels)  # NOTE: add camera-aware
        self.depth_conv = nn.Sequential(
            BasicBlock(mid_channels, mid_channels),
            BasicBlock(mid_channels, mid_channels),
            BasicBlock(mid_channels, mid_channels),
            ASPP(mid_channels, mid_channels),
            build_conv_layer(cfg=dict(
                type='DCN',
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=3,
                padding=1,
                groups=4,
                im2col_step=128,
            )),
            nn.Conv2d(mid_channels,
                      depth_channels,
                      kernel_size=1,
                      stride=1,
                      padding=0),
        )

    def forward(self, x, mats_dict):
        intrins = mats_dict['intrin_mats'][:, 0:1, ..., :3, :3]
        batch_size = intrins.shape[0]
        num_cams = intrins.shape[2]
        ida = mats_dict['ida_mats'][:, 0:1, ...]
        sensor2ego = mats_dict['sensor2ego_mats'][:, 0:1, ..., :3, :]
        bda = mats_dict['bda_mat'].view(batch_size, 1, 1, 4,
                                        4).repeat(1, 1, num_cams, 1, 1)
        mlp_input = torch.cat(
            [
                torch.stack(
                    [
                        intrins[:, 0:1, ..., 0, 0],
                        intrins[:, 0:1, ..., 1, 1],
                        intrins[:, 0:1, ..., 0, 2],
                        intrins[:, 0:1, ..., 1, 2],
                        ida[:, 0:1, ..., 0, 0],
                        ida[:, 0:1, ..., 0, 1],
                        ida[:, 0:1, ..., 0, 3],
                        ida[:, 0:1, ..., 1, 0],
                        ida[:, 0:1, ..., 1, 1],
                        ida[:, 0:1, ..., 1, 3],
                        bda[:, 0:1, ..., 0, 0],
                        bda[:, 0:1, ..., 0, 1],
                        bda[:, 0:1, ..., 1, 0],
                        bda[:, 0:1, ..., 1, 1],
                        bda[:, 0:1, ..., 2, 2],
                    ],
                    dim=-1,
                ),
                sensor2ego.view(batch_size, 1, num_cams, -1),
            ],
            -1,
        )
        mlp_input = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1]))
        x = self.reduce_conv(x)
        context_se = self.context_mlp(mlp_input)[..., None, None]
        context = self.context_se(x, context_se)
        context = self.context_conv(context)
        depth_se = self.depth_mlp(mlp_input)[..., None, None]
        depth = self.depth_se(x, depth_se)
        depth = self.depth_conv(depth)
        return torch.cat([depth, context], dim=1)
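`DepthNet` normalizes its camera-parameter vector with `nn.BatchNorm1d(27)`. The 27 comes from the 15 scalars stacked in `forward` (4 intrinsic terms, 6 image-augmentation terms, 5 BEV-augmentation terms) plus the flattened 3x4 `sensor2ego` slice. A quick check of that count:

```python
n_intrinsics = 4      # fx, fy, cx, cy
n_ida = 6             # ida_mat entries [0,0], [0,1], [0,3], [1,0], [1,1], [1,3]
n_bda = 5             # bda_mat entries [0,0], [0,1], [1,0], [1,1], [2,2]
n_sensor2ego = 3 * 4  # sensor2ego[..., :3, :] flattened per camera

# Matches the feature size expected by nn.BatchNorm1d(27) and the MLPs.
assert n_intrinsics + n_ida + n_bda + n_sensor2ego == 27
```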


class DepthAggregation(nn.Module):
    """
    pixel cloud feature extraction
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super(DepthAggregation, self).__init__()

        self.reduce_conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      mid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )

        self.conv = nn.Sequential(
            nn.Conv2d(mid_channels,
                      mid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels,
                      mid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )

        self.out_conv = nn.Sequential(
            nn.Conv2d(mid_channels,
                      out_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True),
            # nn.BatchNorm3d(out_channels),
            # nn.ReLU(inplace=True),
        )

    @autocast(False)
    def forward(self, x):
        x = self.reduce_conv(x)
        x = self.conv(x) + x
        x = self.out_conv(x)
        return x


class BaseLSSFPN(nn.Module):

    def __init__(self,
                 x_bound,
                 y_bound,
                 z_bound,
                 d_bound,
                 final_dim,
                 downsample_factor,
                 output_channels,
                 img_backbone_conf,
                 img_neck_conf,
                 depth_net_conf,
                 use_da=False):
        """Modified from `https://github.com/nv-tlabs/lift-splat-shoot`.

        Args:
            x_bound (list): Boundaries for x.
            y_bound (list): Boundaries for y.
            z_bound (list): Boundaries for z.
            d_bound (list): Boundaries for d.
            final_dim (list): Dimension for input images.
            downsample_factor (int): Downsample factor between feature map
                and input image.
            output_channels (int): Number of channels for the output
                feature map.
            img_backbone_conf (dict): Config for image backbone.
            img_neck_conf (dict): Config for image neck.
            depth_net_conf (dict): Config for depth net.
        """

        super(BaseLSSFPN, self).__init__()
        self.downsample_factor = downsample_factor
        self.d_bound = d_bound
        self.final_dim = final_dim
        self.output_channels = output_channels

        self.register_buffer(
            'voxel_size',
            torch.Tensor([row[2] for row in [x_bound, y_bound, z_bound]]))
        self.register_buffer(
            'voxel_coord',
            torch.Tensor([
                row[0] + row[2] / 2.0 for row in [x_bound, y_bound, z_bound]
            ]))
        self.register_buffer(
            'voxel_num',
            torch.LongTensor([(row[1] - row[0]) / row[2]
                              for row in [x_bound, y_bound, z_bound]]))
        self.register_buffer('frustum', self.create_frustum())
        self.depth_channels, _, _, _ = self.frustum.shape

        self.img_backbone = build_backbone(img_backbone_conf)
        self.img_neck = build_neck(img_neck_conf)
        self.depth_net = self._configure_depth_net(depth_net_conf)

        self.img_neck.init_weights()
        self.img_backbone.init_weights()
        self.use_da = use_da
        if self.use_da:
            self.depth_aggregation_net = self._configure_depth_aggregation_net(
            )
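# The three ``register_buffer`` calls above derive the voxel grid from the
# per-axis ``[start, end, step]`` bounds. A minimal stand-alone sketch of the
# same arithmetic (plain Python, with illustrative nuScenes-style bounds):
#
# ```python
# def voxel_grid_params(x_bound, y_bound, z_bound):
#     """Mirror the buffer computations: per-axis step, first-voxel center
#     and voxel count, each bound being [start, end, step]."""
#     bounds = [x_bound, y_bound, z_bound]
#     voxel_size = [row[2] for row in bounds]
#     voxel_coord = [row[0] + row[2] / 2.0 for row in bounds]
#     voxel_num = [int((row[1] - row[0]) / row[2]) for row in bounds]
#     return voxel_size, voxel_coord, voxel_num
#
# # Typical BEVDepth bounds: a 102.4 m x 102.4 m BEV plane, one z slab.
# size, coord, num = voxel_grid_params([-51.2, 51.2, 0.8],
#                                      [-51.2, 51.2, 0.8],
#                                      [-5.0, 3.0, 8.0])
# ```
#
# which yields a 128 x 128 x 1 grid with the first voxel centered at -50.8 m.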

    def _configure_depth_net(self, depth_net_conf):
        return DepthNet(
            depth_net_conf['in_channels'],
            depth_net_conf['mid_channels'],
            self.output_channels,
            self.depth_channels,
        )

    def _configure_depth_aggregation_net(self):
        """build pixel cloud feature extractor"""
        return DepthAggregation(self.output_channels, self.output_channels,
                                self.output_channels)

    def _forward_voxel_net(self, img_feat_with_depth):
        if self.use_da:
            # [n, c, d, h, w] -> [n, h, c, w, d]
            img_feat_with_depth = img_feat_with_depth.permute(
                0, 3, 1, 4, 2).contiguous()
            n, h, c, w, d = img_feat_with_depth.shape
            img_feat_with_depth = img_feat_with_depth.view(-1, c, w, d)
            img_feat_with_depth = (
                self.depth_aggregation_net(img_feat_with_depth).view(
                    n, h, c, w, d).permute(0, 2, 4, 1, 3).contiguous())
        return img_feat_with_depth

    def create_frustum(self):
        """Generate frustum"""
        # make grid in image plane
        ogfH, ogfW = self.final_dim
        fH, fW = ogfH // self.downsample_factor, ogfW // self.downsample_factor
        d_coords = torch.arange(*self.d_bound,
                                dtype=torch.float).view(-1, 1,
                                                        1).expand(-1, fH, fW)
        D, _, _ = d_coords.shape
        x_coords = torch.linspace(0, ogfW - 1, fW, dtype=torch.float).view(
            1, 1, fW).expand(D, fH, fW)
        y_coords = torch.linspace(0, ogfH - 1, fH,
                                  dtype=torch.float).view(1, fH,
                                                          1).expand(D, fH, fW)
        paddings = torch.ones_like(d_coords)

        # D x H x W x 3
        frustum = torch.stack((x_coords, y_coords, d_coords, paddings), -1)
        return frustum
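# The frustum is a ``D x fH x fW x 4`` grid of ``(u, v, d, 1)`` points in the
# (augmented) image plane. Its dimensions follow directly from ``d_bound``,
# ``final_dim`` and ``downsample_factor``; a plain-Python sketch using the
# common 256x704 / downsample-16 / ``d_bound=[2, 58, 0.5]`` configuration
# (illustrative values, not the only supported ones):
#
# ```python
# import math
#
# def frustum_shape(d_bound, final_dim, downsample_factor):
#     # fH, fW: spatial size of the image feature map.
#     ogfH, ogfW = final_dim
#     fH, fW = ogfH // downsample_factor, ogfW // downsample_factor
#     # D: number of depth bins, matching len(torch.arange(*d_bound)).
#     D = math.ceil((d_bound[1] - d_bound[0]) / d_bound[2])
#     return D, fH, fW
#
# D, fH, fW = frustum_shape([2.0, 58.0, 0.5], (256, 704), 16)
# ```
#
# giving 112 depth bins over a 16 x 44 feature grid.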

    def get_geometry(self, sensor2ego_mat, intrin_mat, ida_mat, bda_mat):
        """Transfer points from camera coord to ego coord.

        Args:
            sensor2ego_mat(Tensor): Transformation matrix from camera to ego.
            intrin_mat(Tensor): Intrinsic matrix.
            ida_mat(Tensor): Transformation matrix for ida.
            bda_mat(Tensor): Transformation matrix for bda.

        Returns:
            Tensor: Points in ego coord.
        """
        batch_size, num_cams, _, _ = sensor2ego_mat.shape

        # undo post-transformation
        # B x N x D x H x W x 3
        points = self.frustum
        ida_mat = ida_mat.view(batch_size, num_cams, 1, 1, 1, 4, 4)
        points = ida_mat.inverse().matmul(points.unsqueeze(-1))
        # cam_to_ego
        points = torch.cat(
            (points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],
             points[:, :, :, :, :, 2:]), 5)

        combine = sensor2ego_mat.matmul(torch.inverse(intrin_mat))
        points = combine.view(batch_size, num_cams, 1, 1, 1, 4,
                              4).matmul(points)
        if bda_mat is not None:
            bda_mat = bda_mat.unsqueeze(1).repeat(1, num_cams, 1, 1).view(
                batch_size, num_cams, 1, 1, 1, 4, 4)
            points = (bda_mat @ points).squeeze(-1)
        else:
            points = points.squeeze(-1)
        return points[..., :3]
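# After undoing the image-data augmentation, the ``(u*d, v*d, d)``
# multiplication followed by the inverse intrinsic matrix is the standard
# pinhole unprojection. For a single point it reduces to the following
# (plain-Python sketch with hypothetical intrinsics):
#
# ```python
# def unproject(u, v, d, fx, fy, cx, cy):
#     """Lift pixel (u, v) at depth d to camera coordinates; this is what
#     multiplying (u*d, v*d, d, 1) by the inverse intrinsic 4x4 computes."""
#     return ((u - cx) * d / fx, (v - cy) * d / fy, d)
#
# # A point at the principal point unprojects onto the optical axis.
# pt = unproject(352.0, 128.0, 10.0, fx=560.0, fy=560.0, cx=352.0, cy=128.0)
# ```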

    def get_cam_feats(self, imgs):
        """Get feature maps from images."""
        batch_size, num_sweeps, num_cams, num_channels, imH, imW = imgs.shape

        imgs = imgs.reshape(batch_size * num_sweeps * num_cams, num_channels,
                            imH, imW)
        img_feats = self.img_neck(self.img_backbone(imgs))[0]
        img_feats = img_feats.reshape(batch_size, num_sweeps, num_cams,
                                      img_feats.shape[1], img_feats.shape[2],
                                      img_feats.shape[3])
        return img_feats

    def _forward_depth_net(self, feat, mats_dict):
        return self.depth_net(feat, mats_dict)

    def _forward_single_sweep(self,
                              sweep_index,
                              sweep_imgs,
                              mats_dict,
                              is_return_depth=False):
        """Forward function for single sweep.

        Args:
            sweep_index (int): Index of sweeps.
            sweep_imgs (Tensor): Input images.
            mats_dict (dict):
                sensor2ego_mats(Tensor): Transformation matrix from
                    camera to ego with shape of (B, num_sweeps,
                    num_cameras, 4, 4).
                intrin_mats(Tensor): Intrinsic matrix with shape
                    of (B, num_sweeps, num_cameras, 4, 4).
                ida_mats(Tensor): Transformation matrix for ida with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                sensor2sensor_mats(Tensor): Transformation matrix
                    from key frame camera to sweep frame camera with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                bda_mat(Tensor): Rotation matrix for bda with shape
                    of (B, 4, 4).
            is_return_depth (bool, optional): Whether to return depth.
                Default: False.

        Returns:
            Tensor: BEV feature map.
        """
        batch_size, num_sweeps, num_cams, num_channels, img_height, \
            img_width = sweep_imgs.shape
        img_feats = self.get_cam_feats(sweep_imgs)
        source_features = img_feats[:, 0, ...]
        depth_feature = self._forward_depth_net(
            source_features.reshape(batch_size * num_cams,
                                    source_features.shape[2],
                                    source_features.shape[3],
                                    source_features.shape[4]),
            mats_dict,
        )
        depth = depth_feature[:, :self.depth_channels].softmax(
            dim=1, dtype=depth_feature.dtype)
        geom_xyz = self.get_geometry(
            mats_dict['sensor2ego_mats'][:, sweep_index, ...],
            mats_dict['intrin_mats'][:, sweep_index, ...],
            mats_dict['ida_mats'][:, sweep_index, ...],
            mats_dict.get('bda_mat', None),
        )
        geom_xyz = ((geom_xyz - (self.voxel_coord - self.voxel_size / 2.0)) /
                    self.voxel_size).int()
        if self.training or self.use_da:
            img_feat_with_depth = depth.unsqueeze(
                1) * depth_feature[:, self.depth_channels:(
                    self.depth_channels + self.output_channels)].unsqueeze(2)

            img_feat_with_depth = self._forward_voxel_net(img_feat_with_depth)

            img_feat_with_depth = img_feat_with_depth.reshape(
                batch_size,
                num_cams,
                img_feat_with_depth.shape[1],
                img_feat_with_depth.shape[2],
                img_feat_with_depth.shape[3],
                img_feat_with_depth.shape[4],
            )

            img_feat_with_depth = img_feat_with_depth.permute(0, 1, 3, 4, 5, 2)

            feature_map = voxel_pooling_train(geom_xyz,
                                              img_feat_with_depth.contiguous(),
                                              self.voxel_num.cuda())
        else:
            feature_map = voxel_pooling_inference(
                geom_xyz, depth, depth_feature[:, self.depth_channels:(
                    self.depth_channels + self.output_channels)].contiguous(),
                self.voxel_num.cuda())
        if is_return_depth:
            # final_depth has to be fp32, otherwise the depth
            # loss will collapse during the training process.
            return feature_map.contiguous(
            ), depth_feature[:, :self.depth_channels].softmax(dim=1)
        return feature_map.contiguous()
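# The ``geom_xyz`` quantization above shifts each point by the lower edge of
# the first voxel and divides by the voxel size; out-of-range indices are
# later discarded by the pooling op. Per coordinate it is (plain-Python
# sketch with illustrative grid values):
#
# ```python
# def voxel_index(p, voxel_coord, voxel_size):
#     """Mirror of ((geom_xyz - (voxel_coord - voxel_size / 2)) / voxel_size)
#     truncated toward zero like Tensor.int()."""
#     return int((p - (voxel_coord - voxel_size / 2.0)) / voxel_size)
#
# # A 1 m grid whose first voxel spans [0, 1) (center at 0.5).
# i0 = voxel_index(0.0, 0.5, 1.0)  # lower grid edge -> voxel 0
# i3 = voxel_index(3.7, 0.5, 1.0)  # 3.7 m -> voxel 3
# ```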

    def forward(self,
                sweep_imgs,
                mats_dict,
                timestamps=None,
                is_return_depth=False):
        """Forward function.

        Args:
            sweep_imgs(Tensor): Input images with shape of (B, num_sweeps,
                num_cameras, 3, H, W).
            mats_dict(dict):
                sensor2ego_mats(Tensor): Transformation matrix from
                    camera to ego with shape of (B, num_sweeps,
                    num_cameras, 4, 4).
                intrin_mats(Tensor): Intrinsic matrix with shape
                    of (B, num_sweeps, num_cameras, 4, 4).
                ida_mats(Tensor): Transformation matrix for ida with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                sensor2sensor_mats(Tensor): Transformation matrix
                    from key frame camera to sweep frame camera with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                bda_mat(Tensor): Rotation matrix for bda with shape
                    of (B, 4, 4).
            timestamps(Tensor): Timestamp for all images with the shape of
                (B, num_sweeps, num_cameras).
            is_return_depth (bool, optional): Whether to return depth.
                Default: False.

        Returns:
            Tensor: BEV feature map.
        """
        batch_size, num_sweeps, num_cams, num_channels, img_height, \
            img_width = sweep_imgs.shape

        key_frame_res = self._forward_single_sweep(
            0,
            sweep_imgs[:, 0:1, ...],
            mats_dict,
            is_return_depth=is_return_depth)
        if num_sweeps == 1:
            return key_frame_res

        key_frame_feature = key_frame_res[
            0] if is_return_depth else key_frame_res

        ret_feature_list = [key_frame_feature]
        for sweep_index in range(1, num_sweeps):
            with torch.no_grad():
                feature_map = self._forward_single_sweep(
                    sweep_index,
                    sweep_imgs[:, sweep_index:sweep_index + 1, ...],
                    mats_dict,
                    is_return_depth=False)
                ret_feature_list.append(feature_map)

        if is_return_depth:
            return torch.cat(ret_feature_list, 1), key_frame_res[1]
        else:
            return torch.cat(ret_feature_list, 1)


================================================
FILE: bevdepth/layers/backbones/bevstereo_lss_fpn.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
import math

import numpy as np
import torch
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer
from mmdet.models.backbones.resnet import BasicBlock
from scipy.special import erf
from scipy.stats import norm
from torch import nn

from bevdepth.layers.backbones.base_lss_fpn import (ASPP, BaseLSSFPN, Mlp,
                                                    SELayer)

try:
    from bevdepth.ops.voxel_pooling_inference import voxel_pooling_inference
    from bevdepth.ops.voxel_pooling_train import voxel_pooling_train
except ImportError:
    print('Import VoxelPooling failed.')

__all__ = ['BEVStereoLSSFPN']


class ConvBnReLU3D(nn.Module):
    """Implements of 3d convolution + batch normalization + ReLU."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        pad: int = 1,
        dilation: int = 1,
    ) -> None:
        """initialization method for convolution3D +
            batch normalization + relu module
        Args:
            in_channels: input channel number of convolution layer
            out_channels: output channel number of convolution layer
            kernel_size: kernel size of convolution layer
            stride: stride of convolution layer
            pad: pad of convolution layer
            dilation: dilation of convolution layer
        """
        super(ConvBnReLU3D, self).__init__()
        self.conv = nn.Conv3d(in_channels,
                              out_channels,
                              kernel_size,
                              stride=stride,
                              padding=pad,
                              dilation=dilation,
                              bias=False)
        self.bn = nn.BatchNorm3d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """forward method"""
        return F.relu(self.bn(self.conv(x)), inplace=True)


class DepthNet(nn.Module):

    def __init__(self,
                 in_channels,
                 mid_channels,
                 context_channels,
                 depth_channels,
                 d_bound,
                 num_ranges=4):
        super(DepthNet, self).__init__()
        self.reduce_conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      mid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        self.context_conv = nn.Conv2d(mid_channels,
                                      context_channels,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0)
        self.bn = nn.BatchNorm1d(27)
        self.depth_mlp = Mlp(27, mid_channels, mid_channels)
        self.depth_se = SELayer(mid_channels)  # NOTE: add camera-aware
        self.context_mlp = Mlp(27, mid_channels, mid_channels)
        self.context_se = SELayer(mid_channels)  # NOTE: add camera-aware
        self.depth_feat_conv = nn.Sequential(
            BasicBlock(mid_channels, mid_channels),
            BasicBlock(mid_channels, mid_channels),
            ASPP(mid_channels, mid_channels),
            build_conv_layer(cfg=dict(
                type='DCN',
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=3,
                padding=1,
                groups=4,
                im2col_step=128,
            )),
        )
        self.mu_sigma_range_net = nn.Sequential(
            BasicBlock(mid_channels, mid_channels),
            nn.ConvTranspose2d(mid_channels,
                               mid_channels,
                               3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(mid_channels,
                               mid_channels,
                               3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels,
                      num_ranges * 3,
                      kernel_size=1,
                      stride=1,
                      padding=0),
        )
        self.mono_depth_net = nn.Sequential(
            BasicBlock(mid_channels, mid_channels),
            nn.Conv2d(mid_channels,
                      depth_channels,
                      kernel_size=1,
                      stride=1,
                      padding=0),
        )
        self.d_bound = d_bound
        self.num_ranges = num_ranges

    # @autocast(False)
    def forward(self, x, mats_dict, scale_depth_factor=1000.0):
        B, _, H, W = x.shape
        intrins = mats_dict['intrin_mats'][:, 0:1, ..., :3, :3]
        batch_size = intrins.shape[0]
        num_cams = intrins.shape[2]
        ida = mats_dict['ida_mats'][:, 0:1, ...]
        sensor2ego = mats_dict['sensor2ego_mats'][:, 0:1, ..., :3, :]
        bda = mats_dict['bda_mat'].view(batch_size, 1, 1, 4,
                                        4).repeat(1, 1, num_cams, 1, 1)
        mlp_input = torch.cat(
            [
                torch.stack(
                    [
                        intrins[:, 0:1, ..., 0, 0],
                        intrins[:, 0:1, ..., 1, 1],
                        intrins[:, 0:1, ..., 0, 2],
                        intrins[:, 0:1, ..., 1, 2],
                        ida[:, 0:1, ..., 0, 0],
                        ida[:, 0:1, ..., 0, 1],
                        ida[:, 0:1, ..., 0, 3],
                        ida[:, 0:1, ..., 1, 0],
                        ida[:, 0:1, ..., 1, 1],
                        ida[:, 0:1, ..., 1, 3],
                        bda[:, 0:1, ..., 0, 0],
                        bda[:, 0:1, ..., 0, 1],
                        bda[:, 0:1, ..., 1, 0],
                        bda[:, 0:1, ..., 1, 1],
                        bda[:, 0:1, ..., 2, 2],
                    ],
                    dim=-1,
                ),
                sensor2ego.view(batch_size, 1, num_cams, -1),
            ],
            -1,
        )
        mlp_input = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1]))
        x = self.reduce_conv(x)
        context_se = self.context_mlp(mlp_input)[..., None, None]
        context = self.context_se(x, context_se)
        context = self.context_conv(context)
        depth_se = self.depth_mlp(mlp_input)[..., None, None]
        depth_feat = self.depth_se(x, depth_se)
        depth_feat = self.depth_feat_conv(depth_feat)
        mono_depth = self.mono_depth_net(depth_feat)
        mu_sigma_score = self.mu_sigma_range_net(depth_feat)
        d_coords = torch.arange(*self.d_bound,
                                dtype=torch.float).reshape(1, -1, 1, 1).cuda()
        d_coords = d_coords.repeat(B, 1, H, W)
        mu = mu_sigma_score[:, 0:self.num_ranges, ...]
        sigma = mu_sigma_score[:, self.num_ranges:2 * self.num_ranges, ...]
        range_score = mu_sigma_score[:,
                                     2 * self.num_ranges:3 * self.num_ranges,
                                     ...]
        sigma = F.elu(sigma) + 1.0 + 1e-10
        return x, context, mu, sigma, range_score, mono_depth


class BEVStereoLSSFPN(BaseLSSFPN):

    def __init__(self,
                 x_bound,
                 y_bound,
                 z_bound,
                 d_bound,
                 final_dim,
                 downsample_factor,
                 output_channels,
                 img_backbone_conf,
                 img_neck_conf,
                 depth_net_conf,
                 use_da=False,
                 sampling_range=3,
                 num_samples=3,
                 stereo_downsample_factor=4,
                 em_iteration=3,
                 min_sigma=1,
                 num_groups=8,
                 num_ranges=4,
                 range_list=[[2, 8], [8, 16], [16, 28], [28, 58]],
                 k_list=None,
                 use_mask=True):
        """Modified from `https://github.com/nv-tlabs/lift-splat-shoot`.
        Args:
            x_bound (list): Boundaries for x.
            y_bound (list): Boundaries for y.
            z_bound (list): Boundaries for z.
            d_bound (list): Boundaries for d.
            final_dim (list): Dimension for input images.
            downsample_factor (int): Downsample factor between feature map
                and input image.
            output_channels (int): Number of channels for the output
                feature map.
            img_backbone_conf (dict): Config for image backbone.
            img_neck_conf (dict): Config for image neck.
            depth_net_conf (dict): Config for depth net.
            use_da (bool): Whether to use depth aggregation.
                Defaults to False.
            sampling_range (int): The base range of sampling candidates.
                Defaults to 3.
            num_samples (int): Number of samples. Defaults to 3.
            stereo_downsample_factor (int): Downsample factor from input image
                and stereo depth. Defaults to 4.
            em_iteration (int): Number of iterations for em. Defaults to 3.
            min_sigma (float): Minimal value for sigma. Defaults to 1.
            num_groups (int): Number of groups to keep after inner product.
                Defaults to 8.
            num_ranges (int): Number of split ranges. Defaults to 4.
            range_list (list): Start and end of every range. Defaults to
                [[2, 8], [8, 16], [16, 28], [28, 58]].
            k_list (list): Depth of all candidates inside the range.
                Defaults to None.
            use_mask (bool): Whether to use mask_net. Defaults to True.
        """
        self.num_ranges = num_ranges
        self.sampling_range = sampling_range
        self.num_samples = num_samples
        super(BEVStereoLSSFPN,
              self).__init__(x_bound, y_bound, z_bound, d_bound, final_dim,
                             downsample_factor, output_channels,
                             img_backbone_conf, img_neck_conf, depth_net_conf,
                             use_da)

        self.depth_channels, _, _, _ = self.frustum.shape
        self.use_mask = use_mask
        if k_list is None:
            self.register_buffer('k_list', torch.Tensor(self.depth_sampling()))
        else:
            self.register_buffer('k_list', torch.Tensor(k_list))
        self.stereo_downsample_factor = stereo_downsample_factor
        self.em_iteration = em_iteration
        self.register_buffer(
            'depth_values',
            torch.arange((self.d_bound[1] - self.d_bound[0]) / self.d_bound[2],
                         dtype=torch.float))
        self.num_groups = num_groups
        self.similarity_net = nn.Sequential(
            ConvBnReLU3D(in_channels=num_groups,
                         out_channels=16,
                         kernel_size=1,
                         stride=1,
                         pad=0),
            ConvBnReLU3D(in_channels=16,
                         out_channels=8,
                         kernel_size=1,
                         stride=1,
                         pad=0),
            nn.Conv3d(in_channels=8,
                      out_channels=1,
                      kernel_size=1,
                      stride=1,
                      padding=0),
        )
        if range_list is None:
            range_length = (d_bound[1] - d_bound[0]) / num_ranges
            self.range_list = [[
                d_bound[0] + range_length * i,
                d_bound[0] + range_length * (i + 1)
            ] for i in range(num_ranges)]
        else:
            assert len(range_list) == num_ranges
            self.range_list = range_list

        self.min_sigma = min_sigma
        self.depth_downsample_net = nn.Sequential(
            nn.Conv2d(self.depth_channels, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, self.depth_channels, 1, 1, 0),
        )
        self.context_downsample_net = nn.Identity()
        if self.use_mask:
            self.mask_net = nn.Sequential(
                nn.Conv2d(224, 64, 3, 1, 1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                BasicBlock(64, 64),
                BasicBlock(64, 64),
                nn.Conv2d(64, 1, 1, 1, 0),
                nn.Sigmoid(),
            )

    def depth_sampling(self):
        """Generate sampling range of candidates.
        Returns:
            list[float]: List of all candidates.
        """
        P_total = erf(self.sampling_range /
                      np.sqrt(2))  # Probability covered by the sampling range
        idx_list = np.arange(0, self.num_samples + 1)
        p_list = (1 - P_total) / 2 + ((idx_list / self.num_samples) * P_total)
        k_list = norm.ppf(p_list)
        k_list = (k_list[1:] + k_list[:-1]) / 2
        return list(k_list)
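    # ``depth_sampling`` picks ``num_samples`` Gaussian quantile midpoints
    # inside +/- ``sampling_range`` standard deviations; these multipliers
    # later scale each pixel's sigma around mu. The same computation with
    # only the standard library (``statistics.NormalDist`` standing in for
    # ``scipy.stats.norm``):
    #
    # ```python
    # import math
    # from statistics import NormalDist
    #
    # def depth_sampling(sampling_range=3, num_samples=3):
    #     # Probability mass covered by +/- sampling_range std deviations.
    #     p_total = math.erf(sampling_range / math.sqrt(2))
    #     p_list = [(1 - p_total) / 2 + (i / num_samples) * p_total
    #               for i in range(num_samples + 1)]
    #     k_list = [NormalDist().inv_cdf(p) for p in p_list]
    #     # Midpoints of adjacent quantiles: one multiplier per sample.
    #     return [(lo + hi) / 2 for lo, hi in zip(k_list[:-1], k_list[1:])]
    #
    # ks = depth_sampling()  # three symmetric multipliers, middle one ~0
    # ```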

    def _generate_cost_volume(
        self,
        sweep_index,
        stereo_feats_all_sweeps,
        mats_dict,
        depth_sample,
        depth_sample_frustum,
        sensor2sensor_mats,
    ):
        """Generate cost volume based on depth sample.
        Args:
            sweep_index (int): Index of sweep.
            stereo_feats_all_sweeps (list[Tensor]): Stereo feature
                of all sweeps.
            mats_dict (dict):
                sensor2ego_mats (Tensor): Transformation matrix from
                    camera to ego with shape of (B, num_sweeps,
                    num_cameras, 4, 4).
                intrin_mats (Tensor): Intrinsic matrix with shape
                    of (B, num_sweeps, num_cameras, 4, 4).
                ida_mats (Tensor): Transformation matrix for ida with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                sensor2sensor_mats (Tensor): Transformation matrix
                    from key frame camera to sweep frame camera with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                bda_mat (Tensor): Rotation matrix for bda with shape
                    of (B, 4, 4).
            depth_sample (Tensor): Depth map of all candidates.
            depth_sample_frustum (Tensor): Pre-generated frustum.
            sensor2sensor_mats (Tensor): Transformation matrix from reference
                sensor to source sensor.
        Returns:
            Tensor: Depth score for all sweeps.
        """
        batch_size, num_channels, height, width = stereo_feats_all_sweeps[
            0].shape
        num_sweeps = len(stereo_feats_all_sweeps)
        depth_score_all_sweeps = list()
        for idx in range(num_sweeps):
            if idx == sweep_index:
                continue
            warped_stereo_fea = self.homo_warping(
                stereo_feats_all_sweeps[idx],
                mats_dict['intrin_mats'][:, sweep_index, ...],
                mats_dict['intrin_mats'][:, idx, ...],
                sensor2sensor_mats[idx],
                mats_dict['ida_mats'][:, sweep_index, ...],
                mats_dict['ida_mats'][:, idx, ...],
                depth_sample,
                depth_sample_frustum.type_as(stereo_feats_all_sweeps[idx]),
            )
            warped_stereo_fea = warped_stereo_fea.reshape(
                batch_size, self.num_groups, num_channels // self.num_groups,
                self.num_samples, height, width)
            ref_stereo_feat = stereo_feats_all_sweeps[sweep_index].reshape(
                batch_size, self.num_groups, num_channels // self.num_groups,
                height, width)
            feat_cost = torch.mean(
                (ref_stereo_feat.unsqueeze(3) * warped_stereo_fea), axis=2)
            depth_score = self.similarity_net(feat_cost).squeeze(1)
            depth_score_all_sweeps.append(depth_score)
        return torch.stack(depth_score_all_sweeps).mean(0)
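    # The cost volume uses group-wise correlation: channels are split into
    # ``num_groups`` groups and the reference/warped features are multiplied
    # and averaged within each group. For one pixel and one depth candidate
    # this reduces to (plain-Python sketch):
    #
    # ```python
    # def group_correlation(ref, warped, num_groups):
    #     """Per-group mean inner product of two flat channel vectors, the
    #     scalar analogue of ref.unsqueeze(3) * warped_stereo_fea followed
    #     by torch.mean(..., axis=2)."""
    #     c = len(ref)
    #     assert c == len(warped) and c % num_groups == 0
    #     gc = c // num_groups  # channels per group
    #     return [
    #         sum(ref[g * gc + i] * warped[g * gc + i] for i in range(gc)) / gc
    #         for g in range(num_groups)
    #     ]
    #
    # scores = group_correlation([1.0, 2.0, 3.0, 4.0],
    #                            [1.0, 1.0, 1.0, 1.0], num_groups=2)
    # ```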

    def homo_warping(
        self,
        stereo_feat,
        key_intrin_mats,
        sweep_intrin_mats,
        sensor2sensor_mats,
        key_ida_mats,
        sweep_ida_mats,
        depth_sample,
        frustum,
    ):
        """Used for mvs method to transfer sweep image feature to
            key image feature.
        Args:
            src_fea(Tensor): image features.
            key_intrin_mats(Tensor): Intrin matrix for key sensor.
            sweep_intrin_mats(Tensor): Intrin matrix for sweep sensor.
            sensor2sensor_mats(Tensor): Transformation matrix from key
                sensor to sweep sensor.
            key_ida_mats(Tensor): Ida matrix for key frame.
            sweep_ida_mats(Tensor): Ida matrix for sweep frame.
            depth_sample (Tensor): Depth map of all candidates.
            frustum (Tensor): Pre-generated frustum.
        Returns:
            Tensor: Warped stereo feature.
        """
        batch_size_with_num_cams, channels = stereo_feat.shape[
            0], stereo_feat.shape[1]
        height, width = stereo_feat.shape[2], stereo_feat.shape[3]
        with torch.no_grad():
            points = frustum
            points = points.reshape(points.shape[0], -1, points.shape[-1])
            points[..., 2] = 1
            # Undo ida for key frame.
            points = key_ida_mats.reshape(batch_size_with_num_cams, *
                                          key_ida_mats.shape[2:]).inverse(
                                          ).unsqueeze(1) @ points.unsqueeze(-1)
            # Convert points from pixel coord to key camera coord.
            points[..., :3, :] *= depth_sample.reshape(
                batch_size_with_num_cams, -1, 1, 1)
            num_depth = frustum.shape[1]
            points = (key_intrin_mats.reshape(
                batch_size_with_num_cams, *
                key_intrin_mats.shape[2:]).inverse().unsqueeze(1) @ points)
            points = (sensor2sensor_mats.reshape(
                batch_size_with_num_cams, *
                sensor2sensor_mats.shape[2:]).unsqueeze(1) @ points)
            # points in sweep sensor coord.
            points = (sweep_intrin_mats.reshape(
                batch_size_with_num_cams, *
                sweep_intrin_mats.shape[2:]).unsqueeze(1) @ points)
            # points in sweep pixel coord.
            points[..., :2, :] = points[..., :2, :] / points[
                ..., 2:3, :]  # [B, 2, Ndepth, H*W]

            points = (sweep_ida_mats.reshape(
                batch_size_with_num_cams, *
                sweep_ida_mats.shape[2:]).unsqueeze(1) @ points).squeeze(-1)
            neg_mask = points[..., 2] < 1e-3
            points[..., 0][neg_mask] = width * self.stereo_downsample_factor
            points[..., 1][neg_mask] = height * self.stereo_downsample_factor
            points[..., 2][neg_mask] = 1
            proj_x_normalized = points[..., 0] / (
                (width * self.stereo_downsample_factor - 1) / 2) - 1
            proj_y_normalized = points[..., 1] / (
                (height * self.stereo_downsample_factor - 1) / 2) - 1
            grid = torch.stack([proj_x_normalized, proj_y_normalized],
                               dim=2)  # [B, Ndepth, H*W, 2]

        warped_stereo_fea = F.grid_sample(
            stereo_feat,
            grid.view(batch_size_with_num_cams, num_depth * height, width, 2),
            mode='bilinear',
            padding_mode='zeros',
        )
        warped_stereo_fea = warped_stereo_fea.view(batch_size_with_num_cams,
                                                   channels, num_depth, height,
                                                   width)

        return warped_stereo_fea
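    # ``F.grid_sample`` expects coordinates normalized to [-1, 1] over the
    # sampled feature map; since ``points`` live in full-resolution pixel
    # coordinates, the projection above divides by
    # ``(size * stereo_downsample_factor - 1) / 2`` and subtracts 1.
    # A scalar sketch:
    #
    # ```python
    # def normalize_for_grid_sample(x, y, feat_width, feat_height,
    #                               stereo_downsample_factor=4):
    #     """Map full-resolution pixel coords to [-1, 1] over a feature map
    #     stereo_downsample_factor times smaller than the image."""
    #     w = feat_width * stereo_downsample_factor
    #     h = feat_height * stereo_downsample_factor
    #     return x / ((w - 1) / 2) - 1, y / ((h - 1) / 2) - 1
    #
    # # Corners of a 704x256 image over a 176x64 stereo feature map.
    # tl = normalize_for_grid_sample(0.0, 0.0, 176, 64)      # (-1, -1)
    # br = normalize_for_grid_sample(703.0, 255.0, 176, 64)  # (1, 1)
    # ```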

    def _forward_stereo(
        self,
        sweep_index,
        stereo_feats_all_sweeps,
        mono_depth_all_sweeps,
        mats_dict,
        sensor2sensor_mats,
        mu_all_sweeps,
        sigma_all_sweeps,
        range_score_all_sweeps,
        depth_feat_all_sweeps,
    ):
        """Forward function to generate stereo depth.
        Args:
            sweep_index (int): Index of sweep.
            stereo_feats_all_sweeps (list[Tensor]): Stereo feature
                of all sweeps.
            mono_depth_all_sweeps (list[Tensor]): Monocular depth of
                all sweeps.
            mats_dict (dict):
                sensor2ego_mats (Tensor): Transformation matrix from
                    camera to ego with shape of (B, num_sweeps,
                    num_cameras, 4, 4).
                intrin_mats (Tensor): Intrinsic matrix with shape
                    of (B, num_sweeps, num_cameras, 4, 4).
                ida_mats (Tensor): Transformation matrix for ida with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                sensor2sensor_mats (Tensor): Transformation matrix
                    from key frame camera to sweep frame camera with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                bda_mat (Tensor): Rotation matrix for bda with shape
                    of (B, 4, 4).
            sensor2sensor_mats (Tensor): Transformation matrix from key
                sensor to sweep sensor.
            mu_all_sweeps (list[Tensor]): List of mu for all sweeps.
            sigma_all_sweeps (list[Tensor]): List of sigma for all sweeps.
            range_score_all_sweeps (list[Tensor]): List of range scores
                for all sweeps.
            depth_feat_all_sweeps (list[Tensor]): List of depth features
                for all sweeps.
        Returns:
            Tensor: Stereo depth; if ``use_mask`` is True, a tuple of
                (stereo_depth, mask_score) is returned.
        """
        batch_size_with_cams, _, feat_height, feat_width = \
            stereo_feats_all_sweeps[0].shape
        device = stereo_feats_all_sweeps[0].device
        d_coords = torch.arange(*self.d_bound,
                                dtype=torch.float,
                                device=device).reshape(1, -1, 1, 1)
        d_coords = d_coords.repeat(batch_size_with_cams, 1, feat_height,
                                   feat_width)
        stereo_depth = stereo_feats_all_sweeps[0].new_zeros(
            batch_size_with_cams, self.depth_channels, feat_height, feat_width)
        mask_score = stereo_feats_all_sweeps[0].new_zeros(
            batch_size_with_cams,
            self.depth_channels,
            feat_height * self.stereo_downsample_factor //
            self.downsample_factor,
            feat_width * self.stereo_downsample_factor //
            self.downsample_factor,
        )
        score_all_ranges = list()
        range_score = range_score_all_sweeps[sweep_index].softmax(1)
        for range_idx in range(self.num_ranges):
            # Map mu to the corresponding interval.
            range_start = self.range_list[range_idx][0]
            mu_all_sweeps_single_range = [
                mu[:, range_idx:range_idx + 1, ...].sigmoid() *
                (self.range_list[range_idx][1] - self.range_list[range_idx][0])
                + range_start for mu in mu_all_sweeps
            ]
            sigma_all_sweeps_single_range = [
                sigma[:, range_idx:range_idx + 1, ...]
                for sigma in sigma_all_sweeps
            ]
            batch_size_with_cams, _, feat_height, feat_width =\
                stereo_feats_all_sweeps[0].shape
            mu = mu_all_sweeps_single_range[sweep_index]
            sigma = sigma_all_sweeps_single_range[sweep_index]
            for _ in range(self.em_iteration):
                depth_sample = torch.cat([mu + sigma * k for k in self.k_list],
                                         1)
                depth_sample_frustum = self.create_depth_sample_frustum(
                    depth_sample, self.stereo_downsample_factor)
                mu_score = self._generate_cost_volume(
                    sweep_index,
                    stereo_feats_all_sweeps,
                    mats_dict,
                    depth_sample,
                    depth_sample_frustum,
                    sensor2sensor_mats,
                )
                mu_score = mu_score.softmax(1)
                scale_factor = torch.clamp(
                    0.5 / (1e-4 + mu_score[:, self.num_samples //
                                           2:self.num_samples // 2 + 1, ...]),
                    min=0.1,
                    max=10)

                sigma = torch.clamp(sigma * scale_factor, min=0.1, max=10)
                mu = (depth_sample * mu_score).sum(1, keepdim=True)
                del depth_sample
                del depth_sample_frustum
            range_length = int(
                (self.range_list[range_idx][1] - self.range_list[range_idx][0])
                // self.d_bound[2])
            if self.use_mask:
                depth_sample = F.avg_pool2d(
                    mu,
                    self.downsample_factor // self.stereo_downsample_factor,
                    self.downsample_factor // self.stereo_downsample_factor,
                )
                depth_sample_frustum = self.create_depth_sample_frustum(
                    depth_sample, self.downsample_factor)
                mask = self._forward_mask(
                    sweep_index,
                    mono_depth_all_sweeps,
                    mats_dict,
                    depth_sample,
                    depth_sample_frustum,
                    sensor2sensor_mats,
                )
                mask_score[:,
                           int((range_start - self.d_bound[0]) //
                               self.d_bound[2]):range_length +
                           int((range_start - self.d_bound[0]) //
                               self.d_bound[2]), ..., ] += mask
                del depth_sample
                del depth_sample_frustum
            sigma = torch.clamp(sigma, self.min_sigma)
            mu_repeated = mu.repeat(1, range_length, 1, 1)
            eps = 1e-6
            depth_score_single_range = (-1 / 2 * (
                (d_coords[:,
                          int((range_start - self.d_bound[0]) //
                              self.d_bound[2]):range_length + int(
                                  (range_start - self.d_bound[0]) //
                                  self.d_bound[2]), ..., ] - mu_repeated) /
                torch.sqrt(sigma))**2)
            depth_score_single_range = depth_score_single_range.exp()
            score_all_ranges.append(mu_score.sum(1).unsqueeze(1))
            depth_score_single_range = depth_score_single_range / (
                sigma * math.sqrt(2 * math.pi) + eps)
            stereo_depth[:,
                         int((range_start - self.d_bound[0]) //
                             self.d_bound[2]):range_length +
                         int((range_start - self.d_bound[0]) //
                             self.d_bound[2]), ..., ] = (
                                 depth_score_single_range *
                                 range_score[:, range_idx:range_idx + 1, ...])
            del depth_score_single_range
            del mu_repeated
        if self.use_mask:
            return stereo_depth, mask_score
        else:
            return stereo_depth
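The per-range Gaussian scoring in `_forward_stereo` (mapping each pixel's `mu`/`sigma` onto the discrete depth bins) can be sketched in isolation. The `d_bound` values and tensor shapes below are made up for illustration only; the formula mirrors the code above, which treats `sigma` as a variance inside the exponent.

```python
import math

import torch

# Toy d_bound and shapes (hypothetical, for illustration only).
d_bound = [2.0, 10.0, 1.0]                              # [min, max, bin size]
d_coords = torch.arange(*d_bound).reshape(1, -1, 1, 1)  # (1, D, 1, 1)
mu = torch.full((1, 1, 2, 2), 5.0)                      # predicted depth mean
sigma = torch.full((1, 1, 2, 2), 1.5)                   # predicted spread

eps = 1e-6
# Evaluate the Gaussian density at every candidate depth bin.
score = torch.exp(-0.5 * ((d_coords - mu) / torch.sqrt(sigma)) ** 2)
score = score / (sigma * math.sqrt(2 * math.pi) + eps)
best_bin = score[0, :, 0, 0].argmax().item()            # bin closest to mu
```

Bins nearest to `mu` receive the highest score, so the resulting per-bin distribution peaks at the predicted depth and widens with `sigma`.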

    def create_depth_sample_frustum(self, depth_sample, downsample_factor=16):
        """Generate frustum"""
        # make grid in image plane
        ogfH, ogfW = self.final_dim
        fH, fW = ogfH // downsample_factor, ogfW // downsample_factor
        batch_size, num_depth, _, _ = depth_sample.shape
        x_coords = (torch.linspace(0,
                                   ogfW - 1,
                                   fW,
                                   dtype=torch.float,
                                   device=depth_sample.device).view(
                                       1, 1, 1,
                                       fW).expand(batch_size, num_depth, fH,
                                                  fW))
        y_coords = (torch.linspace(0,
                                   ogfH - 1,
                                   fH,
                                   dtype=torch.float,
                                   device=depth_sample.device).view(
                                       1, 1, fH,
                                       1).expand(batch_size, num_depth, fH,
                                                 fW))
        paddings = torch.ones_like(depth_sample)

        # D x H x W x 3
        frustum = torch.stack((x_coords, y_coords, depth_sample, paddings), -1)
        return frustum
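As a standalone illustration of `create_depth_sample_frustum`, the sketch below (toy `final_dim` and sample counts, not the repo's real sizes) pairs every candidate depth map with a pixel grid in the original image plane, yielding homogeneous (x, y, depth, 1) points:

```python
import torch

ogfH, ogfW = 8, 16                       # hypothetical final_dim
downsample = 4
fH, fW = ogfH // downsample, ogfW // downsample
depth_sample = torch.rand(2, 3, fH, fW)  # (B*cams, num_depth, fH, fW)

x_coords = torch.linspace(0, ogfW - 1, fW).view(1, 1, 1, fW).expand(
    2, 3, fH, fW)
y_coords = torch.linspace(0, ogfH - 1, fH).view(1, 1, fH, 1).expand(
    2, 3, fH, fW)
# (B*cams, num_depth, fH, fW, 4) homogeneous points.
frustum = torch.stack(
    (x_coords, y_coords, depth_sample, torch.ones_like(depth_sample)), -1)
```

Note that the x/y coordinates live in the original image resolution (`ogfH`, `ogfW`) even though the grid itself is at the downsampled feature resolution.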

    def _configure_depth_net(self, depth_net_conf):
        return DepthNet(
            depth_net_conf['in_channels'],
            depth_net_conf['mid_channels'],
            self.output_channels,
            self.depth_channels,
            self.d_bound,
            self.num_ranges,
        )

    def get_cam_feats(self, imgs):
        """Get feature maps from images."""
        batch_size, num_sweeps, num_cams, num_channels, imH, imW = imgs.shape

        imgs = imgs.reshape(batch_size * num_sweeps * num_cams, num_channels,
                            imH, imW)
        backbone_feats = self.img_backbone(imgs)
        img_feats = self.img_neck(backbone_feats)[0]
        img_feats_reshape = img_feats.reshape(batch_size, num_sweeps, num_cams,
                                              img_feats.shape[1],
                                              img_feats.shape[2],
                                              img_feats.shape[3])
        return img_feats_reshape, backbone_feats[0].detach()

    def _forward_mask(
        self,
        sweep_index,
        mono_depth_all_sweeps,
        mats_dict,
        depth_sample,
        depth_sample_frustum,
        sensor2sensor_mats,
    ):
        """Forward function to generate mask.
        Args:
            sweep_index (int): Index of sweep.
            mono_depth_all_sweeps (list[Tensor]): List of mono_depth for
                all sweeps.
            mats_dict (dict):
                sensor2ego_mats (Tensor): Transformation matrix from
                    camera to ego with shape of (B, num_sweeps,
                    num_cameras, 4, 4).
                intrin_mats (Tensor): Intrinsic matrix with shape
                    of (B, num_sweeps, num_cameras, 4, 4).
                ida_mats (Tensor): Transformation matrix for ida with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                sensor2sensor_mats (Tensor): Transformation matrix
                    from key frame camera to sweep frame camera with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                bda_mat (Tensor): Rotation matrix for bda with shape
                    of (B, 4, 4).
            depth_sample (Tensor): Depth map of all candidates.
            depth_sample_frustum (Tensor): Pre-generated frustum.
            sensor2sensor_mats (Tensor): Transformation matrix from reference
                sensor to source sensor.
        Returns:
            Tensor: Generated mask.
        """
        num_sweeps = len(mono_depth_all_sweeps)
        mask_all_sweeps = list()
        for idx in range(num_sweeps):
            if idx == sweep_index:
                continue
            warped_mono_depth = self.homo_warping(
                mono_depth_all_sweeps[idx],
                mats_dict['intrin_mats'][:, sweep_index, ...],
                mats_dict['intrin_mats'][:, idx, ...],
                sensor2sensor_mats[idx],
                mats_dict['ida_mats'][:, sweep_index, ...],
                mats_dict['ida_mats'][:, idx, ...],
                depth_sample,
                depth_sample_frustum.type_as(mono_depth_all_sweeps[idx]),
            )
            mask = self.mask_net(
                torch.cat([
                    mono_depth_all_sweeps[sweep_index].detach(),
                    warped_mono_depth.mean(2).detach()
                ], 1))
            mask_all_sweeps.append(mask)
        return torch.stack(mask_all_sweeps).mean(0)

    def _forward_single_sweep(self,
                              sweep_index,
                              context,
                              mats_dict,
                              depth_score,
                              is_return_depth=False):
        """Forward function for single sweep.
        Args:
            sweep_index (int): Index of the sweep.
            context (Tensor): Context feature map.
            mats_dict (dict):
                sensor2ego_mats(Tensor): Transformation matrix from
                    camera to ego with shape of (B, num_sweeps,
                    num_cameras, 4, 4).
                intrin_mats(Tensor): Intrinsic matrix with shape
                    of (B, num_sweeps, num_cameras, 4, 4).
                ida_mats(Tensor): Transformation matrix for ida with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                sensor2sensor_mats(Tensor): Transformation matrix
                    from key frame camera to sweep frame camera with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                bda_mat(Tensor): Rotation matrix for bda with shape
                    of (B, 4, 4).
            depth_score (Tensor): Depth distribution used to lift the
                context features.
            is_return_depth (bool, optional): Whether to return depth.
                Default: False.
        Returns:
            Tensor: BEV feature map.
        """
        batch_size, num_cams = context.shape[0], context.shape[1]
        context = context.reshape(batch_size * num_cams, *context.shape[2:])
        depth = depth_score
        geom_xyz = self.get_geometry(
            mats_dict['sensor2ego_mats'][:, sweep_index, ...],
            mats_dict['intrin_mats'][:, sweep_index, ...],
            mats_dict['ida_mats'][:, sweep_index, ...],
            mats_dict.get('bda_mat', None),
        )
        geom_xyz = ((geom_xyz - (self.voxel_coord - self.voxel_size / 2.0)) /
                    self.voxel_size).int()
        if self.training or self.use_da:
            img_feat_with_depth = depth.unsqueeze(1) * context.unsqueeze(2)

            img_feat_with_depth = self._forward_voxel_net(img_feat_with_depth)

            img_feat_with_depth = img_feat_with_depth.reshape(
                batch_size,
                num_cams,
                img_feat_with_depth.shape[1],
                img_feat_with_depth.shape[2],
                img_feat_with_depth.shape[3],
                img_feat_with_depth.shape[4],
            )
            img_feat_with_depth = img_feat_with_depth.permute(0, 1, 3, 4, 5, 2)

            feature_map = voxel_pooling_train(geom_xyz,
                                              img_feat_with_depth.contiguous(),
                                              self.voxel_num.cuda())
        else:
            feature_map = voxel_pooling_inference(geom_xyz, depth.contiguous(),
                                                  context.contiguous(),
                                                  self.voxel_num.cuda())
        if is_return_depth:
            return feature_map.contiguous(), depth
        return feature_map.contiguous()
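The lift step above (`depth.unsqueeze(1) * context.unsqueeze(2)`) is an outer product between the per-pixel depth distribution and the context features, as in Lift-Splat-Shoot. A shape-only sketch with made-up sizes:

```python
import torch

B_cams, C, D, H, W = 2, 8, 5, 4, 6               # hypothetical sizes
depth = torch.rand(B_cams, D, H, W).softmax(1)   # sums to 1 over depth bins
context = torch.rand(B_cams, C, H, W)

# (B*, 1, D, H, W) * (B*, C, 1, H, W) -> (B*, C, D, H, W)
feat_with_depth = depth.unsqueeze(1) * context.unsqueeze(2)
```

Because the depth scores sum to one per pixel, summing the volume over the depth axis recovers the original context features.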

    def forward(self,
                sweep_imgs,
                mats_dict,
                timestamps=None,
                is_return_depth=False):
        """Forward function.
        Args:
            sweep_imgs(Tensor): Input images with shape of (B, num_sweeps,
                num_cameras, 3, H, W).
            mats_dict(dict):
                sensor2ego_mats(Tensor): Transformation matrix from
                    camera to ego with shape of (B, num_sweeps,
                    num_cameras, 4, 4).
                intrin_mats(Tensor): Intrinsic matrix with shape
                    of (B, num_sweeps, num_cameras, 4, 4).
                ida_mats(Tensor): Transformation matrix for ida with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                sensor2sensor_mats(Tensor): Transformation matrix
                    from key frame camera to sweep frame camera with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                bda_mat(Tensor): Rotation matrix for bda with shape
                    of (B, 4, 4).
            timestamps(Tensor): Timestamp for all images with the shape of
                (B, num_sweeps, num_cameras).
        Return:
            Tensor: BEV feature map.
        """
        batch_size, num_sweeps, num_cams, num_channels, img_height, \
            img_width = sweep_imgs.shape
        context_all_sweeps = list()
        depth_feat_all_sweeps = list()
        img_feats_all_sweeps = list()
        stereo_feats_all_sweeps = list()
        mu_all_sweeps = list()
        sigma_all_sweeps = list()
        mono_depth_all_sweeps = list()
        range_score_all_sweeps = list()
        for sweep_index in range(0, num_sweeps):
            if sweep_index > 0:
                with torch.no_grad():
                    img_feats, stereo_feats = self.get_cam_feats(
                        sweep_imgs[:, sweep_index:sweep_index + 1, ...])
                    img_feats_all_sweeps.append(
                        img_feats.view(batch_size * num_cams,
                                       *img_feats.shape[3:]))
                    stereo_feats_all_sweeps.append(stereo_feats)
                    depth_feat, context, mu, sigma, range_score, mono_depth =\
                        self.depth_net(img_feats.view(batch_size * num_cams,
                                       *img_feats.shape[3:]), mats_dict)
                    context_all_sweeps.append(
                        self.context_downsample_net(
                            context.reshape(batch_size * num_cams,
                                            *context.shape[1:])))
                    depth_feat_all_sweeps.append(depth_feat)
            else:
                img_feats, stereo_feats = self.get_cam_feats(
                    sweep_imgs[:, sweep_index:sweep_index + 1, ...])
                img_feats_all_sweeps.append(
                    img_feats.view(batch_size * num_cams,
                                   *img_feats.shape[3:]))
                stereo_feats_all_sweeps.append(stereo_feats)
                depth_feat, context, mu, sigma, range_score, mono_depth =\
                    self.depth_net(img_feats.view(batch_size * num_cams,
                                   *img_feats.shape[3:]), mats_dict)
                depth_feat_all_sweeps.append(depth_feat)
                context_all_sweeps.append(
                    self.context_downsample_net(
                        context.reshape(batch_size * num_cams,
                                        *context.shape[1:])))
            mu_all_sweeps.append(mu)
            sigma_all_sweeps.append(sigma)
            mono_depth_all_sweeps.append(mono_depth)
            range_score_all_sweeps.append(range_score)
        depth_score_all_sweeps = list()
        final_depth = None
        for ref_idx in range(num_sweeps):
            sensor2sensor_mats = list()
            for src_idx in range(num_sweeps):
                ref2keysensor_mats = mats_dict[
                    'sensor2sensor_mats'][:, ref_idx, ...].inverse()
                key2srcsensor_mats = mats_dict['sensor2sensor_mats'][:,
                                                                     src_idx,
                                                                     ...]
                ref2srcsensor_mats = key2srcsensor_mats @ ref2keysensor_mats
                sensor2sensor_mats.append(ref2srcsensor_mats)
            if ref_idx == 0:
                # Only the key frame (ref_idx == 0) keeps gradients;
                # the remaining sweeps are processed under no_grad below.
                if self.use_mask:
                    stereo_depth, mask = self._forward_stereo(
                        ref_idx,
                        stereo_feats_all_sweeps,
                        mono_depth_all_sweeps,
                        mats_dict,
                        sensor2sensor_mats,
                        mu_all_sweeps,
                        sigma_all_sweeps,
                        range_score_all_sweeps,
                        depth_feat_all_sweeps,
                    )
                else:
                    stereo_depth = self._forward_stereo(
                        ref_idx,
                        stereo_feats_all_sweeps,
                        mono_depth_all_sweeps,
                        mats_dict,
                        sensor2sensor_mats,
                        mu_all_sweeps,
                        sigma_all_sweeps,
                        range_score_all_sweeps,
                        depth_feat_all_sweeps,
                    )
            else:
                with torch.no_grad():
                    if self.use_mask:
                        stereo_depth, mask = self._forward_stereo(
                            ref_idx,
                            stereo_feats_all_sweeps,
                            mono_depth_all_sweeps,
                            mats_dict,
                            sensor2sensor_mats,
                            mu_all_sweeps,
                            sigma_all_sweeps,
                            range_score_all_sweeps,
                            depth_feat_all_sweeps,
                        )
                    else:
                        stereo_depth = self._forward_stereo(
                            ref_idx,
                            stereo_feats_all_sweeps,
                            mono_depth_all_sweeps,
                            mats_dict,
                            sensor2sensor_mats,
                            mu_all_sweeps,
                            sigma_all_sweeps,
                            range_score_all_sweeps,
                            depth_feat_all_sweeps,
                        )
            if self.use_mask:
                depth_score = (
                    mono_depth_all_sweeps[ref_idx] +
                    self.depth_downsample_net(stereo_depth) * mask).softmax(
                        1, dtype=stereo_depth.dtype)
            else:
                depth_score = (
                    mono_depth_all_sweeps[ref_idx] +
                    self.depth_downsample_net(stereo_depth)).softmax(
                        1, dtype=stereo_depth.dtype)
            depth_score_all_sweeps.append(depth_score)
            if ref_idx == 0:
                # final_depth has to be fp32, otherwise the
                # depth loss will collapse during the training process.
                final_depth = (
                    mono_depth_all_sweeps[ref_idx] +
                    self.depth_downsample_net(stereo_depth)).softmax(1)
        key_frame_res = self._forward_single_sweep(
            0,
            context_all_sweeps[0].reshape(batch_size, num_cams,
                                          *context_all_sweeps[0].shape[1:]),
            mats_dict,
            depth_score_all_sweeps[0],
            is_return_depth=is_return_depth,
        )
        if num_sweeps == 1:
            return key_frame_res

        key_frame_feature = key_frame_res[
            0] if is_return_depth else key_frame_res

        ret_feature_list = [key_frame_feature]
        for sweep_index in range(1, num_sweeps):
            with torch.no_grad():
                feature_map = self._forward_single_sweep(
                    sweep_index,
                    context_all_sweeps[sweep_index].reshape(
                        batch_size, num_cams,
                        *context_all_sweeps[sweep_index].shape[1:]),
                    mats_dict,
                    depth_score_all_sweeps[sweep_index],
                    is_return_depth=False,
                )
                ret_feature_list.append(feature_map)

        if is_return_depth:
            return torch.cat(ret_feature_list, 1), final_depth
        else:
            return torch.cat(ret_feature_list, 1)
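The relative-pose composition in `forward` (`ref2srcsensor_mats = key2srcsensor_mats @ ref2keysensor_mats`, with the ref-to-key matrix obtained by inverting the key-to-ref transform) can be checked on toy transforms. The pure translations below are hypothetical, chosen only to make the arithmetic easy to follow:

```python
import torch


def make_translation(tx):
    # 4x4 homogeneous transform translating by tx along x.
    mat = torch.eye(4)
    mat[0, 3] = tx
    return mat


key2ref = make_translation(-1.0)   # hypothetical key frame -> ref sweep
key2src = make_translation(-3.0)   # hypothetical key frame -> src sweep
ref2key = key2ref.inverse()
ref2src = key2src @ ref2key        # same composition as in `forward`
```

A point at the ref-sensor origin maps to x = -2 in the src frame, the net offset of the two translations.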


================================================
FILE: bevdepth/layers/backbones/fusion_lss_fpn.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
import torch
import torch.nn as nn
from mmdet.models.backbones.resnet import BasicBlock

try:
    from bevdepth.ops.voxel_pooling_inference import voxel_pooling_inference
    from bevdepth.ops.voxel_pooling_train import voxel_pooling_train
except ImportError:
    print('Failed to import VoxelPooling ops.')

from .base_lss_fpn import ASPP, BaseLSSFPN, Mlp, SELayer

__all__ = ['FusionLSSFPN']


class DepthNet(nn.Module):

    def __init__(self, in_channels, mid_channels, context_channels,
                 depth_channels):
        super(DepthNet, self).__init__()
        self.reduce_conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      mid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        self.context_conv = nn.Conv2d(mid_channels,
                                      context_channels,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0)
        self.mlp = Mlp(1, mid_channels, mid_channels)
        self.se = SELayer(mid_channels)  # NOTE: add camera-aware
        self.depth_gt_conv = nn.Sequential(
            nn.Conv2d(1, mid_channels, kernel_size=1, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=1, stride=1),
        )
        self.depth_conv = nn.Sequential(
            BasicBlock(mid_channels, mid_channels),
            BasicBlock(mid_channels, mid_channels),
            BasicBlock(mid_channels, mid_channels),
        )
        self.aspp = ASPP(mid_channels, mid_channels)
        self.depth_pred = nn.Conv2d(mid_channels,
                                    depth_channels,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0)

    def forward(self, x, mats_dict, lidar_depth, scale_depth_factor=1000.0):
        x = self.reduce_conv(x)
        context = self.context_conv(x)
        inv_intrinsics = torch.inverse(mats_dict['intrin_mats'][:, 0:1, ...])
        pixel_size = torch.norm(torch.stack(
            [inv_intrinsics[..., 0, 0], inv_intrinsics[..., 1, 1]], dim=-1),
                                dim=-1).reshape(-1, 1)
        aug_scale = torch.sqrt(mats_dict['ida_mats'][:, 0, :, 0, 0]**2 +
                               mats_dict['ida_mats'][:, 0, :, 0,
                                                     1]**2).reshape(-1, 1)
        scaled_pixel_size = pixel_size * scale_depth_factor / aug_scale
        x_se = self.mlp(scaled_pixel_size)[..., None, None]
        x = self.se(x, x_se)
        depth = self.depth_gt_conv(lidar_depth)
        depth = self.depth_conv(x + depth)
        depth = self.aspp(depth)
        depth = self.depth_pred(depth)
        return torch.cat([depth, context], dim=1)


class FusionLSSFPN(BaseLSSFPN):

    def _configure_depth_net(self, depth_net_conf):
        return DepthNet(
            depth_net_conf['in_channels'],
            depth_net_conf['mid_channels'],
            self.output_channels,
            self.depth_channels,
        )

    def _forward_depth_net(self, feat, mats_dict, lidar_depth):
        return self.depth_net(feat, mats_dict, lidar_depth)

    def _forward_single_sweep(self,
                              sweep_index,
                              sweep_imgs,
                              mats_dict,
                              sweep_lidar_depth,
                              is_return_depth=False):
        """Forward function for single sweep.

        Args:
            sweep_index (int): Index of sweeps.
            sweep_imgs (Tensor): Input images.
            mats_dict (dict):
                sensor2ego_mats(Tensor): Transformation matrix from
                    camera to ego with shape of (B, num_sweeps,
                    num_cameras, 4, 4).
                intrin_mats(Tensor): Intrinsic matrix with shape
                    of (B, num_sweeps, num_cameras, 4, 4).
                ida_mats(Tensor): Transformation matrix for ida with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                sensor2sensor_mats(Tensor): Transformation matrix
                    from key frame camera to sweep frame camera with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                bda_mat(Tensor): Rotation matrix for bda with shape
                    of (B, 4, 4).
            sweep_lidar_depth (Tensor): Depth generated by lidar.
            is_return_depth (bool, optional): Whether to return depth.
                Default: False.

        Returns:
            Tensor: BEV feature map.
        """
        batch_size, num_sweeps, num_cams, num_channels, img_height, \
            img_width = sweep_imgs.shape
        img_feats = self.get_cam_feats(sweep_imgs)
        sweep_lidar_depth = sweep_lidar_depth.reshape(
            batch_size * num_cams, *sweep_lidar_depth.shape[2:])
        source_features = img_feats[:, 0, ...]
        depth_feature = self._forward_depth_net(
            source_features.reshape(batch_size * num_cams,
                                    source_features.shape[2],
                                    source_features.shape[3],
                                    source_features.shape[4]), mats_dict,
            sweep_lidar_depth)
        depth = depth_feature[:, :self.depth_channels].softmax(
            dim=1, dtype=depth_feature.dtype)
        geom_xyz = self.get_geometry(
            mats_dict['sensor2ego_mats'][:, sweep_index, ...],
            mats_dict['intrin_mats'][:, sweep_index, ...],
            mats_dict['ida_mats'][:, sweep_index, ...],
            mats_dict.get('bda_mat', None),
        )
        geom_xyz = ((geom_xyz - (self.voxel_coord - self.voxel_size / 2.0)) /
                    self.voxel_size).int()
        if self.training or self.use_da:
            img_feat_with_depth = depth.unsqueeze(
                1) * depth_feature[:, self.depth_channels:(
                    self.depth_channels + self.output_channels)].unsqueeze(2)

            img_feat_with_depth = self._forward_voxel_net(img_feat_with_depth)

            img_feat_with_depth = img_feat_with_depth.reshape(
                batch_size,
                num_cams,
                img_feat_with_depth.shape[1],
                img_feat_with_depth.shape[2],
                img_feat_with_depth.shape[3],
                img_feat_with_depth.shape[4],
            )

            img_feat_with_depth = img_feat_with_depth.permute(0, 1, 3, 4, 5, 2)

            feature_map = voxel_pooling_train(geom_xyz,
                                              img_feat_with_depth.contiguous(),
                                              self.voxel_num.cuda())
        else:
            feature_map = voxel_pooling_inference(
                geom_xyz, depth, depth_feature[:, self.depth_channels:(
                    self.depth_channels + self.output_channels)].contiguous(),
                self.voxel_num.cuda())
        if is_return_depth:
            return feature_map.contiguous(), depth.float()
        return feature_map.contiguous()

    def forward(self,
                sweep_imgs,
                mats_dict,
                lidar_depth,
                timestamps=None,
                is_return_depth=False):
        """Forward function.

        Args:
            sweep_imgs(Tensor): Input images with shape of (B, num_sweeps,
                num_cameras, 3, H, W).
            mats_dict(dict):
                sensor2ego_mats(Tensor): Transformation matrix from
                    camera to ego with shape of (B, num_sweeps,
                    num_cameras, 4, 4).
                intrin_mats(Tensor): Intrinsic matrix with shape
                    of (B, num_sweeps, num_cameras, 4, 4).
                ida_mats(Tensor): Transformation matrix for ida with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                sensor2sensor_mats(Tensor): Transformation matrix
                    from key frame camera to sweep frame camera with
                    shape of (B, num_sweeps, num_cameras, 4, 4).
                bda_mat(Tensor): Rotation matrix for bda with shape
                    of (B, 4, 4).
            lidar_depth (Tensor): Depth generated by lidar.
            timestamps(Tensor): Timestamp for all images with the shape of
                (B, num_sweeps, num_cameras).

        Return:
            Tensor: BEV feature map.
        """
        batch_size, num_sweeps, num_cams, num_channels, img_height, \
            img_width = sweep_imgs.shape
        lidar_depth = self.get_downsampled_lidar_depth(lidar_depth)
        key_frame_res = self._forward_single_sweep(
            0,
            sweep_imgs[:, 0:1, ...],
            mats_dict,
            lidar_depth[:, 0, ...],
            is_return_depth=is_return_depth)
        if num_sweeps == 1:
            return key_frame_res

        key_frame_feature = key_frame_res[
            0] if is_return_depth else key_frame_res

        ret_feature_list = [key_frame_feature]
        for sweep_index in range(1, num_sweeps):
            with torch.no_grad():
                feature_map = self._forward_single_sweep(
                    sweep_index,
                    sweep_imgs[:, sweep_index:sweep_index + 1, ...],
                    mats_dict,
                    lidar_depth[:, sweep_index, ...],
                    is_return_depth=False)
                ret_feature_list.append(feature_map)

        if is_return_depth:
            return torch.cat(ret_feature_list, 1), key_frame_res[1]
        else:
            return torch.cat(ret_feature_list, 1)

    def get_downsampled_lidar_depth(self, lidar_depth):
        batch_size, num_sweeps, num_cams, height, width = lidar_depth.shape
        lidar_depth = lidar_depth.view(
            batch_size * num_sweeps * num_cams,
            height // self.downsample_factor,
            self.downsample_factor,
            width // self.downsample_factor,
            self.downsample_factor,
            1,
        )
        lidar_depth = lidar_depth.permute(0, 1, 3, 5, 2, 4).contiguous()
        lidar_depth = lidar_depth.view(
            -1, self.downsample_factor * self.downsample_factor)
        gt_depths_tmp = torch.where(lidar_depth == 0.0, lidar_depth.max(),
                                    lidar_depth)
        lidar_depth = torch.min(gt_depths_tmp, dim=-1).values
        lidar_depth = lidar_depth.view(batch_size, num_sweeps, num_cams, 1,
                                       height // self.downsample_factor,
                                       width // self.downsample_factor)
        lidar_depth = lidar_depth / self.d_bound[1]
        return lidar_depth
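
# The reshape/permute/min chain above is easy to misread; the following
# standalone sketch (hypothetical helper and toy shapes, not the repo's API)
# shows the same min-pooling on a sparse depth map, with zeros excluded
# from the minimum:

```python
import torch


def min_pool_depth(depth, factor):
    """Min-pool an (H, W) sparse depth map by `factor`, ignoring
    zero (missing) pixels -- a toy sketch of the pattern used in
    get_downsampled_lidar_depth, not the repo's exact API."""
    h, w = depth.shape
    blocks = depth.view(h // factor, factor, w // factor, factor)
    blocks = blocks.permute(0, 2, 1, 3).reshape(h // factor, w // factor, -1)
    # Replace zeros with the global max so they never win the min.
    filled = torch.where(blocks == 0.0, blocks.max(), blocks)
    return filled.min(dim=-1).values


depth = torch.zeros(4, 4)
depth[0, 1] = 5.0
depth[1, 0] = 3.0
pooled = min_pool_depth(depth, 2)  # top-left block pools to 3.0
```

Blocks with no lidar hits at all end up at the global max rather than zero; the real method then normalizes the result by `d_bound[1]`.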


================================================
FILE: bevdepth/layers/backbones/matrixvt.py
================================================
# Copyright (c) Megvii Inc. All rights reserved.
import torch
from torch import nn
from torch.cuda.amp import autocast

from bevdepth.layers.backbones.base_lss_fpn import BaseLSSFPN


class HoriConv(nn.Module):

    def __init__(self, in_channels, mid_channels, out_channels, cat_dim=0):
        """HoriConv that reduces the image feature
            along the height dimension and refines it.

        Args:
            in_channels (int): in_channels
            mid_channels (int): mid_channels
            out_channels (int): output channels
            cat_dim (int, optional): channels of position
                embedding. Defaults to 0.
        """
        super().__init__()

        self.merger = nn.Sequential(
            nn.Conv2d(in_channels + cat_dim,
                      in_channels,
                      kernel_size=1,
                      bias=True),
            nn.Sigmoid(),
            nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=True),
        )

        self.reduce_conv = nn.Sequential(
            nn.Conv1d(
                in_channels,
                mid_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm1d(mid_channels),
            nn.ReLU(inplace=True),
        )

        self.conv1 = nn.Sequential(
            nn.Conv1d(
                mid_channels,
                mid_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm1d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                mid_channels,
                mid_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm1d(mid_channels),
            nn.ReLU(inplace=True),
        )

        self.conv2 = nn.Sequential(
            nn.Conv1d(
                mid_channels,
                mid_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm1d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                mid_channels,
                mid_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm1d(mid_channels),
            nn.ReLU(inplace=True),
        )

        self.out_conv = nn.Sequential(
            nn.Conv1d(
                mid_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            ),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
        )

    @autocast(False)
    def forward(self, x, pe=None):
        # [N,C,H,W]
        if pe is not None:
            x = self.merger(torch.cat([x, pe], 1))
        else:
            x = self.merger(x)
        x = x.max(2)[0]
        x = self.reduce_conv(x)
        x = self.conv1(x) + x
        x = self.conv2(x) + x
        x = self.out_conv(x)
        return x
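
# The shape flow of the forward pass above, as a minimal standalone sketch
# (toy sizes are assumptions):

```python
import torch

# HoriConv-style height collapse with toy shapes: the [N, C, H, W]
# image feature is max-pooled over H, then refined by 1-D convs
# operating along the width (camera-column) axis.
x = torch.rand(2, 8, 16, 44)                  # [N, C, H, W]
collapsed = x.max(2)[0]                       # [N, C, W]
conv = torch.nn.Conv1d(8, 8, kernel_size=3, padding=1)
refined = conv(collapsed)                     # still [N, C, W]
assert tuple(collapsed.shape) == (2, 8, 44)
assert tuple(refined.shape) == (2, 8, 44)
```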


class DepthReducer(nn.Module):

    def __init__(self, img_channels, mid_channels):
        """Module that compresses the predicted
            categorical depth along the height dimension.

        Args:
            img_channels (int): in_channels
            mid_channels (int): mid_channels
        """
        super().__init__()
        self.vertical_weighter = nn.Sequential(
            nn.Conv2d(img_channels,
                      mid_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, 1, kernel_size=3, stride=1, padding=1),
        )

    @autocast(False)
    def forward(self, feat, depth):
        vert_weight = self.vertical_weighter(feat).softmax(2)  # [N,1,H,W]
        depth = (depth * vert_weight).sum(2)
        return depth
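
# The softmax-weighted height collapse can be checked in isolation (toy
# tensors, assumed shapes): the softmax over dim 2 makes the per-column
# weights sum to one, so the reduction is a convex combination of the
# per-row depth distributions.

```python
import torch

# DepthReducer-style reduction on toy shapes (N=1, D=4, H=3, W=2).
depth = torch.rand(1, 4, 3, 2)     # categorical depth [N, D, H, W]
logits = torch.rand(1, 1, 3, 2)    # vertical weight logits [N, 1, H, W]
vert_weight = logits.softmax(2)    # weights sum to 1 along H
reduced = (depth * vert_weight).sum(2)  # -> [N, D, W]
assert tuple(reduced.shape) == (1, 4, 2)
assert torch.allclose(vert_weight.sum(2), torch.ones(1, 1, 2))
```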


# NOTE Modified Lift-Splat
class MatrixVT(BaseLSSFPN):

    def __init__(
        self,
        x_bound,
        y_bound,
        z_bound,
        d_bound,
        final_dim,
        downsample_factor,
        output_channels,
        img_backbone_conf,
        img_neck_conf,
        depth_net_conf,
    ):
        """Modified from LSSFPN.

        Args:
            x_bound (list): Boundaries for x.
            y_bound (list): Boundaries for y.
            z_bound (list): Boundaries for z.
            d_bound (list): Boundaries for d.
            final_dim (list): Dimension for input images.
            downsample_factor (int): Downsample factor between feature map
                and input image.
            output_channels (int): Number of channels for the output
                feature map.
            img_backbone_conf (dict): Config for image backbone.
            img_neck_conf (dict): Config for image neck.
            depth_net_conf (dict): Config for depth net.
        """
        super().__init__(
            x_bound,
            y_bound,
            z_bound,
            d_bound,
            final_dim,
            downsample_factor,
            output_channels,
            img_backbone_conf,
            img_neck_conf,
            depth_net_conf,
            use_da=False,
        )

        self.register_buffer('bev_anchors',
                             self.create_bev_anchors(x_bound, y_bound))
        self.horiconv = HoriConv(self.output_channels, 512,
                                 self.output_channels)
        self.depth_reducer = DepthReducer(self.output_channels,
                                          self.output_channels)
        self.static_mat = None

    def create_bev_anchors(self, x_bound, y_bound, ds_rate=1):
        """Create anchors in BEV space

        Args:
            x_bound (list): xbound in meters [start, end, step]
            y_bound (list): ybound in meters [start, end, step]
            ds_rate (int, optional): downsample rate. Defaults to 1.

        Returns:
            anchors: anchors in [W, H, 2]
        """
        x_coords = ((torch.linspace(
            x_bound[0],
            x_bound[1] - x_bound[2] * ds_rate,
            self.voxel_num[0] // ds_rate,
            dtype=torch.float,
        ) + x_bound[2] * ds_rate / 2).view(self.voxel_num[0] // ds_rate,
                                           1).expand(
                                               self.voxel_num[0] // ds_rate,
                                               self.voxel_num[1] // ds_rate))
        y_coords = ((torch.linspace(
            y_bound[0],
            y_bound[1] - y_bound[2] * ds_rate,
            self.voxel_num[1] // ds_rate,
            dtype=torch.float,
        ) + y_bound[2] * ds_rate / 2).view(
            1,
            self.voxel_num[1] // ds_rate).expand(self.voxel_num[0] // ds_rate,
                                                 self.voxel_num[1] // ds_rate))

        anchors = torch.stack([x_coords, y_coords]).permute(1, 2, 0)
        return anchors
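
# The linspace arithmetic above places anchors at cell midpoints, not cell
# edges; a 1-D sketch of the same formula (the nuScenes-style bound values
# are assumptions, not taken from this file) makes that visible:

```python
import torch


def cell_centers(bound, ds_rate=1):
    """1-D cell centers matching create_bev_anchors' arithmetic:
    linspace over [start, end - step], then shift by step / 2."""
    start, end, step = bound
    n = int(round((end - start) / step)) // ds_rate
    step = step * ds_rate
    return torch.linspace(start, end - step, n) + step / 2


# Assumed nuScenes-style x_bound: [-51.2, 51.2, 0.8].
centers = cell_centers([-51.2, 51.2, 0.8])
# 128 centers, running from -50.8 to 50.8 (cell midpoints).
```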

    def get_proj_mat(self, mats_dict=None):
        """Create the Ring Matrix and Ray Matrix

        Args:
            mats_dict (dict, optional): Dictionary that contains the
                intrinsic and extrinsic parameters. Defaults to None.

        Returns:
            tuple: Ring Matrix in [B, D, L*L] and Ray Matrix in [B, Nc*W, L*L]
        """
        if self.static_mat is not None:
            return self.static_mat

        bev_size = int(self.voxel_num[0])  # only consider square BEV
        geom_sep = self.get_geometry(
            mats_dict['sensor2ego_mats'][:, 0, ...],
            mats_dict['intrin_mats'][:, 0, ...],
            mats_dict['ida_mats'][:, 0, ...],
            mats_dict.get('bda_mat', None),
        )
        geom_sep = (
            geom_sep -
            (self.voxel_coord - self.voxel_size / 2.0)) / self.voxel_size
        geom_sep = geom_sep.mean(3).permute(0, 1, 3, 2,
                                            4).contiguous()  # B,Ncam,W,D,2
        B, Nc, W, D, _ = geom_sep.shape
        geom_sep = geom_sep.long().view(B, Nc * W, D, -1)[..., :2]

        invalid1 = torch.logical_or((geom_sep < 0)[..., 0], (geom_sep < 0)[...,
                                                                           1])
        invalid2 = torch.logical_or((geom_sep > (bev_size - 1))[..., 0],
                                    (geom_sep > (bev_size - 1))[..., 1])
        geom_sep[(invalid1 | invalid2)] = int(bev_size / 2)
        geom_idx = geom_sep[..., 1] * bev_size + geom_sep[..., 0]

        geom_uni = self.bev_anchors[None].repeat([B, 1, 1, 1])  # B,128,128,2
        B, L, L, _ = geom_uni.shape

        circle_map = geom_uni.new_zeros((B, D, L * L))

        ray_map = geom_uni.new_zeros((B, Nc * W, L * L))
        for b in range(B):
            for ray_idx in range(Nc * W):
                ray_map[b, ray_idx, geom_idx[b, ray_idx]] += 1
            for d in range(D):
                circle_map[b, d, geom_idx[b, :, d]] += 1
        null_point = int((bev_size / 2) * (bev_size + 1))
        circle_map[..., null_point] = 0
        ray_map[..., null_point] = 0
        circle_map = circle_map.view(B, D, L * L)
        ray_map = ray_map.view(B, -1, L * L)
        circle_map /= circle_map.max(1)[0].clip(min=1)[:, None]
        ray_map /= ray_map.max(1)[0].clip(min=1)[:, None]

        return circle_map, ray_map
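
# A toy-scale sketch of the same scatter pattern (made-up geometry indices,
# not real camera geometry) makes the Ring/Ray construction concrete:

```python
import torch

# Toy version of the accumulation loop: B=1, three camera columns
# ("rays"), D=2 depth bins, and a flattened 2x2 BEV grid (L*L = 4).
# geom_idx[b, r, d] is the flattened BEV cell hit by ray r at depth d.
B, n_rays, D, LL = 1, 3, 2, 4
geom_idx = torch.tensor([[[0, 1], [1, 3], [3, 3]]])  # [B, n_rays, D]

ray_map = torch.zeros(B, n_rays, LL)
circle_map = torch.zeros(B, D, LL)
for b in range(B):
    for r in range(n_rays):
        ray_map[b, r, geom_idx[b, r]] += 1
    for d in range(D):
        circle_map[b, d, geom_idx[b, :, d]] += 1

# Note: advanced-index `+=` counts a repeated cell once (ray 2 hits
# cell 3 at both depths yet records a single 1), matching the loop above.
assert ray_map[0].tolist() == [[1, 1, 0, 0], [0, 1, 0, 1], [0, 0, 0, 1]]
assert circle_map[0].tolist() == [[1, 1, 0, 1], [0, 1, 0, 1]]
```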

    @autocast(False)
    def reduce_and_project(self, feature, depth, mats_dict):
SYMBOL INDEX (217 symbols across 41 files)

FILE: bevdepth/callbacks/ema.py
  function is_parallel (line 14) | def is_parallel(model):
  class ModelEMA (line 23) | class ModelEMA:
    method __init__ (line 37) | def __init__(self, model, decay=0.9999, updates=0):
    method update (line 53) | def update(self, trainer, model):
  class EMACallback (line 67) | class EMACallback(Callback):
    method __init__ (line 69) | def __init__(self, len_updates) -> None:
    method on_fit_start (line 73) | def on_fit_start(self, trainer, pl_module):
    method on_train_batch_end (line 92) | def on_train_batch_end(self,
    method on_train_epoch_end (line 101) | def on_train_epoch_end(self, trainer, pl_module) -> None:

FILE: bevdepth/datasets/nusc_det_dataset.py
  function get_rot (line 42) | def get_rot(h):
  function img_transform (line 49) | def img_transform(img, resize, resize_dims, crop, flip, rotate):
  function bev_transform (line 80) | def bev_transform(gt_boxes, rotate_angle, scale_ratio, flip_dx, flip_dy):
  function depth_transform (line 107) | def depth_transform(cam_depth, resize, resize_dims, crop, flip, rotate):
  function map_pointcloud_to_image (line 155) | def map_pointcloud_to_image(
  class NuscDetDataset (line 217) | class NuscDetDataset(Dataset):
    method __init__ (line 219) | def __init__(self,
    method _get_sample_indices (line 285) | def _get_sample_indices(self):
    method sample_ida_augmentation (line 320) | def sample_ida_augmentation(self):
    method sample_bda_augmentation (line 349) | def sample_bda_augmentation(self):
    method get_lidar_depth (line 363) | def get_lidar_depth(self, lidar_points, img, lidar_info, cam_info):
    method get_image (line 374) | def get_image(self, cam_infos, cams, lidar_infos=None):
    method get_gt (line 536) | def get_gt(self, info, cams):
    method choose_cams (line 583) | def choose_cams(self):
    method __getitem__ (line 598) | def __getitem__(self, idx):
    method __str__ (line 688) | def __str__(self):
    method __len__ (line 693) | def __len__(self):
  function collate_fn (line 700) | def collate_fn(data, is_return_depth=False):

FILE: bevdepth/evaluators/det_evaluators.py
  class DetNuscEvaluator (line 15) | class DetNuscEvaluator():
    method __init__ (line 37) | def __init__(
    method _evaluate_single (line 61) | def _evaluate_single(self,
    method format_results (line 119) | def format_results(self,
    method evaluate (line 174) | def evaluate(
    method _format_bbox (line 219) | def _format_bbox(self, results, img_metas, jsonfile_prefix=None):

FILE: bevdepth/exps/base_cli.py
  function run_cli (line 13) | def run_cli(model_class=BEVDepthLightningModel,

FILE: bevdepth/exps/nuscenes/MatrixVT/matrixvt_bev_depth_lss_r50_256x704_128x128_24e_ema.py
  class MatrixVT_Exp (line 11) | class MatrixVT_Exp(BaseExp):
    method __init__ (line 13) | def __init__(self, *args, **kwargs):

FILE: bevdepth/exps/nuscenes/base_exp.py
  class BEVDepthLightningModel (line 183) | class BEVDepthLightningModel(LightningModule):
    method __init__ (line 188) | def __init__(self,
    method forward (line 238) | def forward(self, sweep_imgs, mats):
    method training_step (line 241) | def training_step(self, batch):
    method get_depth_loss (line 265) | def get_depth_loss(self, depth_labels, depth_preds):
    method get_downsampled_gt_depth (line 280) | def get_downsampled_gt_depth(self, gt_depths):
    method eval_step (line 317) | def eval_step(self, batch, batch_idx, prefix: str):
    method validation_step (line 335) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 338) | def validation_epoch_end(self, validation_step_outputs):
    method test_epoch_end (line 355) | def test_epoch_end(self, test_step_outputs):
    method configure_optimizers (line 373) | def configure_optimizers(self):
    method train_dataloader (line 382) | def train_dataloader(self):
    method val_dataloader (line 410) | def val_dataloader(self):
    method test_dataloader (line 433) | def test_dataloader(self):
    method predict_dataloader (line 436) | def predict_dataloader(self):
    method test_step (line 459) | def test_step(self, batch, batch_idx):
    method predict_step (line 462) | def predict_step(self, batch, batch_idx):
    method add_model_specific_args (line 466) | def add_model_specific_args(parent_parser):  # pragma: no-cover

FILE: bevdepth/exps/nuscenes/fusion/bev_depth_fusion_lss_r50_256x704_128x128_24e.py
  class BEVDepthLightningModel (line 13) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method __init__ (line 15) | def __init__(self, *args, **kwargs) -> None:
    method forward (line 22) | def forward(self, sweep_imgs, mats, lidar_depth):
    method training_step (line 25) | def training_step(self, batch):
    method eval_step (line 47) | def eval_step(self, batch, batch_idx, prefix: str):

FILE: bevdepth/exps/nuscenes/fusion/bev_depth_fusion_lss_r50_256x704_128x128_24e_2key.py
  class BEVDepthLightningModel (line 8) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method __init__ (line 10) | def __init__(self, **kwargs):

FILE: bevdepth/exps/nuscenes/fusion/bev_depth_fusion_lss_r50_256x704_128x128_24e_2key_trainval.py
  class BEVDepthLightningModel (line 8) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method __init__ (line 10) | def __init__(self, *args, **kwargs) -> None:

FILE: bevdepth/exps/nuscenes/fusion/bev_depth_fusion_lss_r50_256x704_128x128_24e_key4.py
  class BEVDepthLightningModel (line 8) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method __init__ (line 10) | def __init__(self, **kwargs):

FILE: bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da.py
  class BEVDepthLightningModel (line 33) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method __init__ (line 35) | def __init__(self, **kwargs):
    method configure_optimizers (line 43) | def configure_optimizers(self):

FILE: bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_24e_2key.py
  class BEVDepthLightningModel (line 31) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method __init__ (line 33) | def __init__(self, **kwargs):

FILE: bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_24e_ema.py
  class BEVDepthLightningModel (line 12) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method configure_optimizers (line 14) | def configure_optimizers(self):

FILE: bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_512x1408_128x128_24e_2key.py
  class BEVDepthLightningModel (line 11) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method __init__ (line 13) | def __init__(self, **kwargs):
    method configure_optimizers (line 23) | def configure_optimizers(self):

FILE: bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_640x1600_128x128_24e_2key.py
  class BEVDepthLightningModel (line 11) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method __init__ (line 13) | def __init__(self, **kwargs):
    method configure_optimizers (line 23) | def configure_optimizers(self):

FILE: bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da.py
  class BEVDepthLightningModel (line 33) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method __init__ (line 35) | def __init__(self, **kwargs):
    method configure_optimizers (line 44) | def configure_optimizers(self):

FILE: bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_2key.py
  class BEVDepthLightningModel (line 30) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method __init__ (line 32) | def __init__(self, **kwargs):

FILE: bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_key4.py
  class BEVDepthLightningModel (line 29) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method __init__ (line 31) | def __init__(self, **kwargs):

FILE: bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_key4_ema.py
  class BEVDepthLightningModel (line 29) | class BEVDepthLightningModel(BaseBEVDepthLightningModel):
    method __init__ (line 31) | def __init__(self, **kwargs):

FILE: bevdepth/layers/backbones/base_lss_fpn.py
  class _ASPPModule (line 20) | class _ASPPModule(nn.Module):
    method __init__ (line 22) | def __init__(self, inplanes, planes, kernel_size, padding, dilation,
    method forward (line 37) | def forward(self, x):
    method _init_weight (line 43) | def _init_weight(self):
  class ASPP (line 52) | class ASPP(nn.Module):
    method __init__ (line 54) | def __init__(self, inplanes, mid_channels=256, BatchNorm=nn.BatchNorm2d):
    method forward (line 99) | def forward(self, x):
    method _init_weight (line 117) | def _init_weight(self):
  class Mlp (line 126) | class Mlp(nn.Module):
    method __init__ (line 128) | def __init__(self,
    method forward (line 143) | def forward(self, x):
  class SELayer (line 152) | class SELayer(nn.Module):
    method __init__ (line 154) | def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
    method forward (line 161) | def forward(self, x, x_se):
  class DepthNet (line 168) | class DepthNet(nn.Module):
    method __init__ (line 170) | def __init__(self, in_channels, mid_channels, context_channels,
    method forward (line 213) | def forward(self, x, mats_dict):
  class DepthAggregation (line 258) | class DepthAggregation(nn.Module):
    method __init__ (line 263) | def __init__(self, in_channels, mid_channels, out_channels):
    method forward (line 308) | def forward(self, x):
  class BaseLSSFPN (line 315) | class BaseLSSFPN(nn.Module):
    method __init__ (line 317) | def __init__(self,
    method _configure_depth_net (line 378) | def _configure_depth_net(self, depth_net_conf):
    method _configure_depth_aggregation_net (line 386) | def _configure_depth_aggregation_net(self):
    method _forward_voxel_net (line 391) | def _forward_voxel_net(self, img_feat_with_depth):
    method create_frustum (line 404) | def create_frustum(self):
    method get_geometry (line 424) | def get_geometry(self, sensor2ego_mat, intrin_mat, ida_mat, bda_mat):
    method get_cam_feats (line 461) | def get_cam_feats(self, imgs):
    method _forward_depth_net (line 473) | def _forward_depth_net(self, feat, mats_dict):
    method _forward_single_sweep (line 476) | def _forward_single_sweep(self,
    method forward (line 559) | def forward(self,

FILE: bevdepth/layers/backbones/bevstereo_lss_fpn.py
  class ConvBnReLU3D (line 25) | class ConvBnReLU3D(nn.Module):
    method __init__ (line 28) | def __init__(
    method forward (line 57) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class DepthNet (line 62) | class DepthNet(nn.Module):
    method __init__ (line 64) | def __init__(self,
    method forward (line 141) | def forward(self, x, mats_dict, scale_depth_factor=1000.0):
  class BEVStereoLSSFPN (line 198) | class BEVStereoLSSFPN(BaseLSSFPN):
    method __init__ (line 200) | def __init__(self,
    method depth_sampling (line 322) | def depth_sampling(self):
    method _generate_cost_volume (line 335) | def _generate_cost_volume(
    method homo_warping (line 398) | def homo_warping(
    method _forward_stereo (line 477) | def _forward_stereo(
    method create_depth_sample_frustum (line 633) | def create_depth_sample_frustum(self, depth_sample, downsample_factor=...
    method _configure_depth_net (line 661) | def _configure_depth_net(self, depth_net_conf):
    method get_cam_feats (line 671) | def get_cam_feats(self, imgs):
    method _forward_mask (line 685) | def _forward_mask(
    method _forward_single_sweep (line 742) | def _forward_single_sweep(self,
    method forward (line 807) | def forward(self,

FILE: bevdepth/layers/backbones/fusion_lss_fpn.py
  class DepthNet (line 17) | class DepthNet(nn.Module):
    method __init__ (line 19) | def __init__(self, in_channels, mid_channels, context_channels,
    method forward (line 55) | def forward(self, x, mats_dict, lidar_depth, scale_depth_factor=1000.0):
  class FusionLSSFPN (line 75) | class FusionLSSFPN(BaseLSSFPN):
    method _configure_depth_net (line 77) | def _configure_depth_net(self, depth_net_conf):
    method _forward_depth_net (line 85) | def _forward_depth_net(self, feat, mats_dict, lidar_depth):
    method _forward_single_sweep (line 88) | def _forward_single_sweep(self,
    method forward (line 171) | def forward(self,
    method get_downsampled_lidar_depth (line 233) | def get_downsampled_lidar_depth(self, lidar_depth):

FILE: bevdepth/layers/backbones/matrixvt.py
  class HoriConv (line 9) | class HoriConv(nn.Module):
    method __init__ (line 11) | def __init__(self, in_channels, mid_channels, out_channels, cat_dim=0):
    method forward (line 106) | def forward(self, x, pe=None):
  class DepthReducer (line 120) | class DepthReducer(nn.Module):
    method __init__ (line 122) | def __init__(self, img_channels, mid_channels):
    method forward (line 143) | def forward(self, feat, depth):
  class MatrixVT (line 150) | class MatrixVT(BaseLSSFPN):
    method __init__ (line 152) | def __init__(
    method create_bev_anchors (line 203) | def create_bev_anchors(self, x_bound, y_bound, ds_rate=1):
    method get_proj_mat (line 236) | def get_proj_mat(self, mats_dict=None):
    method reduce_and_project (line 294) | def reduce_and_project(self, feature, depth, mats_dict):
    method _forward_single_sweep (line 329) | def _forward_single_sweep(self,

FILE: bevdepth/layers/heads/bev_depth_head.py
  function size_aware_circle_nms (line 34) | def size_aware_circle_nms(dets, thresh_scale, post_max_size=83):
  class BEVDepthHead (line 85) | class BEVDepthHead(CenterHead):
    method __init__ (line 103) | def __init__(
    method forward (line 141) | def forward(self, x):
    method get_targets_single (line 169) | def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
    method loss (line 322) | def loss(self, targets, preds_dicts, **kwargs):
    method get_bboxes (line 382) | def get_bboxes(self, preds_dicts, img_metas, img=None, rescale=False):

FILE: bevdepth/models/base_bev_depth.py
  class BaseBEVDepth (line 9) | class BaseBEVDepth(nn.Module):
    method __init__ (line 20) | def __init__(self, backbone_conf, head_conf, is_train_depth=False):
    method forward (line 26) | def forward(
    method get_targets (line 67) | def get_targets(self, gt_boxes, gt_labels):
    method loss (line 87) | def loss(self, targets, preds_dicts):
    method get_bboxes (line 101) | def get_bboxes(self, preds_dicts, img_metas=None, img=None, rescale=Fa...

FILE: bevdepth/models/bev_stereo.py
  class BEVStereo (line 7) | class BEVStereo(BaseBEVDepth):
    method __init__ (line 18) | def __init__(self, backbone_conf, head_conf, is_train_depth=False):

FILE: bevdepth/models/fusion_bev_depth.py
  class FusionBEVDepth (line 9) | class FusionBEVDepth(BaseBEVDepth):
    method __init__ (line 20) | def __init__(self, backbone_conf, head_conf, is_train_depth=False):
    method forward (line 26) | def forward(

FILE: bevdepth/models/matrixvt_det.py
  class MatrixVT_Det (line 5) | class MatrixVT_Det(BaseBEVDepth):
    method __init__ (line 15) | def __init__(self, backbone_conf, head_conf, is_train_depth=False):

FILE: bevdepth/ops/voxel_pooling_inference/src/voxel_pooling_inference_forward.cpp
  function voxel_pooling_inference_forward_wrapper (line 37) | int voxel_pooling_inference_forward_wrapper(
  function PYBIND11_MODULE (line 71) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: bevdepth/ops/voxel_pooling_inference/voxel_pooling_inference.py
  class VoxelPoolingInference (line 8) | class VoxelPoolingInference(Function):
    method forward (line 11) | def forward(ctx, geom_xyz: torch.Tensor, depth_features: torch.Tensor,

FILE: bevdepth/ops/voxel_pooling_train/src/voxel_pooling_train_forward.cpp
  function voxel_pooling_train_forward_wrapper (line 38) | int voxel_pooling_train_forward_wrapper(int batch_size, int num_points,
  function PYBIND11_MODULE (line 75) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: bevdepth/ops/voxel_pooling_train/voxel_pooling_train.py
  class VoxelPoolingTrain (line 8) | class VoxelPoolingTrain(Function):
    method forward (line 11) | def forward(ctx, geom_xyz: torch.Tensor, input_features: torch.Tensor,
    method backward (line 59) | def backward(ctx, grad_output_features):

FILE: bevdepth/utils/torch_dist.py
  function get_rank (line 8) | def get_rank() -> int:
  function get_world_size (line 16) | def get_world_size() -> int:
  function synchronize (line 24) | def synchronize():
  function all_gather_object (line 37) | def all_gather_object(obj):
  function is_available (line 46) | def is_available() -> bool:

FILE: scripts/gen_info.py
  function generate_info (line 8) | def generate_info(nusc, scenes, max_cam_sweeps=6, max_lidar_sweeps=10):
  function main (line 149) | def main():

FILE: scripts/visualize_nusc.py
  function parse_args (line 16) | def parse_args():
  function get_ego_box (line 29) | def get_ego_box(box_dict, ego2global_rotation, ego2global_translation):
  function rotate_points_along_z (line 46) | def rotate_points_along_z(points, angle):
  function get_corners (line 65) | def get_corners(boxes3d):
  function get_bev_lines (line 99) | def get_bev_lines(corners):
  function get_3d_lines (line 104) | def get_3d_lines(corners):
  function get_cam_corners (line 114) | def get_cam_corners(corners, translation, rotation, cam_intrinsics):
  function demo (line 125) | def demo(

FILE: setup.py
  function make_cuda_ext (line 12) | def make_cuda_ext(name,

FILE: test/test_dataset/test_nusc_mv_det_dataset.py
  class TestNuscMVDetData (line 54) | class TestNuscMVDetData(unittest.TestCase):
    method test_voxel_pooling (line 56) | def test_voxel_pooling(self):

FILE: test/test_layers/test_backbone.py
  class TestLSSFPN (line 9) | class TestLSSFPN(unittest.TestCase):
    method setUp (line 11) | def setUp(self) -> None:
    method test_forward (line 43) | def test_forward(self):

FILE: test/test_layers/test_head.py
  class TestLSSFPN (line 10) | class TestLSSFPN(unittest.TestCase):
    method setUp (line 12) | def setUp(self) -> None:
    method test_forward (line 102) | def test_forward(self):
    method test_get_targets (line 115) | def test_get_targets(self):
    method test_get_bboxes (line 137) | def test_get_bboxes(self):

FILE: test/test_layers/test_matrixvt.py
  class TestMatrixVT (line 8) | class TestMatrixVT(unittest.TestCase):
    method setUp (line 10) | def setUp(self) -> None:
    method test_forward (line 45) | def test_forward(self):

FILE: test/test_ops/test_voxel_pooling.py
  class TestLSSFPN (line 9) | class TestLSSFPN(unittest.TestCase):
    method test_voxel_pooling (line 13) | def test_voxel_pooling(self):
Condensed preview — 59 files, each entry listing path, character count, and a content snippet.
[
  {
    "path": ".github/workflows/lint.yml",
    "chars": 1034,
    "preview": "name: lint\n\non: [push, pull_request]\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progres"
  },
  {
    "path": ".gitignore",
    "chars": 3111,
    "preview": "### Linux ###\n*~\n\n# temporary files which can be created if a process still has a handle open of a deleted file\n.fuse_hi"
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 775,
    "preview": "repos:\n  - repo: https://github.com/PyCQA/flake8\n    rev: 5.0.4\n    hooks:\n      - id: flake8\n  - repo: https://github.c"
  },
  {
    "path": "LICENSE.md",
    "chars": 1077,
    "preview": "MIT License\n\nCopyright (c) 2022 Megvii-BaseDetection\n\nPermission is hereby granted, free of charge, to any person obtain"
  },
  {
    "path": "README.md",
    "chars": 7344,
    "preview": "## BEVDepth\nBEVDepth is a new 3D object detector with a trustworthy depth\nestimation. For more details, please refer to "
  },
  {
    "path": "bevdepth/callbacks/ema.py",
    "chars": 4261,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.\nimport math\nimport os\nfrom copy import"
  },
  {
    "path": "bevdepth/datasets/nusc_det_dataset.py",
    "chars": 30544,
    "preview": "import os\n\nimport mmcv\nimport numpy as np\nimport torch\nfrom mmdet3d.core.bbox.structures.lidar_box3d import LiDARInstanc"
  },
  {
    "path": "bevdepth/evaluators/det_evaluators.py",
    "chars": 12024,
    "preview": "'''Modified from # https://github.com/nutonomy/nuscenes-devkit/blob/57889ff20678577025326cfc24e57424a829be0a/python-sdk/"
  },
  {
    "path": "bevdepth/exps/base_cli.py",
    "chars": 3431,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nimport os\nfrom argparse import ArgumentParser\n\nimport pytorch_lightning"
  },
  {
    "path": "bevdepth/exps/nuscenes/MatrixVT/matrixvt_bev_depth_lss_r50_256x704_128x128_24e_ema.py",
    "chars": 790,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\n# isort: skip_file\nfrom bevdepth.exps.base_cli import run_cli\n# Basic E"
  },
  {
    "path": "bevdepth/exps/nuscenes/base_exp.py",
    "chars": 17401,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nimport os\nfrom functools import partial\n\nimport mmcv\nimport torch\nimpor"
  },
  {
    "path": "bevdepth/exps/nuscenes/fusion/bev_depth_fusion_lss_r50_256x704_128x128_24e.py",
    "chars": 2858,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nimport torch\nimport torch.nn.parallel\nimport torch.utils.data\nimport to"
  },
  {
    "path": "bevdepth/exps/nuscenes/fusion/bev_depth_fusion_lss_r50_256x704_128x128_24e_2key.py",
    "chars": 1026,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nfrom bevdepth.exps.base_cli import run_cli\nfrom bevdepth.exps.nuscenes."
  },
  {
    "path": "bevdepth/exps/nuscenes/fusion/bev_depth_fusion_lss_r50_256x704_128x128_24e_2key_trainval.py",
    "chars": 655,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nfrom bevdepth.exps.base_cli import run_cli\n\nfrom .bev_depth_fusion_lss_"
  },
  {
    "path": "bevdepth/exps/nuscenes/fusion/bev_depth_fusion_lss_r50_256x704_128x128_24e_key4.py",
    "chars": 1125,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nfrom bevdepth.exps.base_cli import run_cli\nfrom bevdepth.exps.nuscenes."
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da.py",
    "chars": 2084,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\n\"\"\"\nmAP: 0.3484\nmATE: 0.6159\nmASE: 0.2716\nmAOE: 0.4144\nmAVE: 0.4402\nmAA"
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_20e_cbgs_2key_da_ema.py",
    "chars": 1202,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\n\"\"\"\nmAP: 0.3589\nmATE: 0.6119\nmASE: 0.2692\nmAOE: 0.5074\nmAVE: 0.4086\nmAA"
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_24e_2key.py",
    "chars": 1823,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\n\"\"\"\nmAP: 0.3304\nmATE: 0.7021\nmASE: 0.2795\nmAOE: 0.5346\nmAVE: 0.5530\nmAA"
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_24e_2key_ema.py",
    "chars": 1115,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\n\"\"\"\nmAP: 0.3329\nmATE: 0.6832\nmASE: 0.2761\nmAOE: 0.5446\nmAVE: 0.5258\nmAA"
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_256x704_128x128_24e_ema.py",
    "chars": 821,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nimport torch\nimport torch.nn.parallel\nimport torch.utils.data\nimport to"
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_512x1408_128x128_24e_2key.py",
    "chars": 1358,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nimport torch\nfrom torch.optim.lr_scheduler import MultiStepLR\n\nfrom bev"
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_depth_lss_r50_640x1600_128x128_24e_2key.py",
    "chars": 1358,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nimport torch\nfrom torch.optim.lr_scheduler import MultiStepLR\n\nfrom bev"
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da.py",
    "chars": 2096,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\n\"\"\"\nmAP: 0.3576\nmATE: 0.6071\nmASE: 0.2684\nmAOE: 0.4157\nmAVE: 0.3928\nmAA"
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_20e_cbgs_2key_da_ema.py",
    "chars": 1204,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\n\"\"\"\nmAP: 0.3721\nmATE: 0.5980\nmASE: 0.2701\nmAOE: 0.4381\nmAVE: 0.3672\nmAA"
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_2key.py",
    "chars": 1982,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\n\"\"\"\nmAP: 0.3456\nmATE: 0.6589\nmASE: 0.2774\nmAOE: 0.5500\nmAVE: 0.4980\nmAA"
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_2key_ema.py",
    "chars": 1134,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\n\"\"\"\nmAP: 0.3494\nmATE: 0.6672\nmASE: 0.2785\nmAOE: 0.5607\nmAVE: 0.4687\nmAA"
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_key4.py",
    "chars": 1356,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\n\"\"\"\nmAP: 0.3427\nmATE: 0.6560\nmASE: 0.2784\nmAOE: 0.5982\nmAVE: 0.5347\nmAA"
  },
  {
    "path": "bevdepth/exps/nuscenes/mv/bev_stereo_lss_r50_256x704_128x128_24e_key4_ema.py",
    "chars": 1360,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\n\"\"\"\nmAP: 0.3427\nmATE: 0.6560\nmASE: 0.2784\nmAOE: 0.5982\nmAVE: 0.5347\nmAA"
  },
  {
    "path": "bevdepth/layers/__init__.py",
    "chars": 75,
    "preview": "from .heads.bev_depth_head import BEVDepthHead\n\n__all__ = ['BEVDepthHead']\n"
  },
  {
    "path": "bevdepth/layers/backbones/__init__.py",
    "chars": 120,
    "preview": "from .base_lss_fpn import BaseLSSFPN\nfrom .fusion_lss_fpn import FusionLSSFPN\n\n__all__ = ['BaseLSSFPN', 'FusionLSSFPN']\n"
  },
  {
    "path": "bevdepth/layers/backbones/base_lss_fpn.py",
    "chars": 23461,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nimport torch\nimport torch.nn.functional as F\nfrom mmcv.cnn import build"
  },
  {
    "path": "bevdepth/layers/backbones/bevstereo_lss_fpn.py",
    "chars": 43975,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nimport math\n\nimport numpy as np\nimport torch\nimport torch.nn.functional"
  },
  {
    "path": "bevdepth/layers/backbones/fusion_lss_fpn.py",
    "chars": 10894,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nimport torch\nimport torch.nn as nn\nfrom mmdet.models.backbones.resnet i"
  },
  {
    "path": "bevdepth/layers/backbones/matrixvt.py",
    "chars": 14044,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nimport torch\nfrom torch import nn\nfrom torch.cuda.amp import autocast\n\n"
  },
  {
    "path": "bevdepth/layers/heads/__init__.py",
    "chars": 69,
    "preview": "from .bev_depth_head import BEVDepthHead\n\n__all__ = ['BEVDepthHead']\n"
  },
  {
    "path": "bevdepth/layers/heads/bev_depth_head.py",
    "chars": 20265,
    "preview": "\"\"\"Inherited from `https://github.com/open-mmlab/mmdetection3d/blob/master/mmdet3d/models/dense_heads/centerpoint_head.p"
  },
  {
    "path": "bevdepth/models/base_bev_depth.py",
    "chars": 4083,
    "preview": "from torch import nn\n\nfrom bevdepth.layers.backbones.base_lss_fpn import BaseLSSFPN\nfrom bevdepth.layers.heads.bev_depth"
  },
  {
    "path": "bevdepth/models/bev_stereo.py",
    "chars": 751,
    "preview": "from bevdepth.layers.backbones.bevstereo_lss_fpn import BEVStereoLSSFPN\nfrom bevdepth.models.base_bev_depth import BaseB"
  },
  {
    "path": "bevdepth/models/fusion_bev_depth.py",
    "chars": 2326,
    "preview": "from bevdepth.layers.backbones.fusion_lss_fpn import FusionLSSFPN\nfrom bevdepth.layers.heads.bev_depth_head import BEVDe"
  },
  {
    "path": "bevdepth/models/matrixvt_det.py",
    "chars": 593,
    "preview": "from bevdepth.layers.backbones.matrixvt import MatrixVT\nfrom bevdepth.models.base_bev_depth import BaseBEVDepth\n\n\nclass "
  },
  {
    "path": "bevdepth/ops/voxel_pooling_inference/__init__.py",
    "chars": 100,
    "preview": "from .voxel_pooling_inference import voxel_pooling_inference\n\n__all__ = ['voxel_pooling_inference']\n"
  },
  {
    "path": "bevdepth/ops/voxel_pooling_inference/src/voxel_pooling_inference_forward.cpp",
    "chars": 3465,
    "preview": "// Copyright (c) Megvii Inc. All rights reserved.\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp"
  },
  {
    "path": "bevdepth/ops/voxel_pooling_inference/src/voxel_pooling_inference_forward_cuda.cu",
    "chars": 8038,
    "preview": "// Copyright (c) Megvii Inc. All rights reserved.\n#include <cuda_fp16.h>\n#include <math.h>\n#include <stdio.h>\n#include <"
  },
  {
    "path": "bevdepth/ops/voxel_pooling_inference/voxel_pooling_inference.py",
    "chars": 1951,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nimport torch\nfrom torch.autograd import Function\n\nfrom . import voxel_p"
  },
  {
    "path": "bevdepth/ops/voxel_pooling_train/__init__.py",
    "chars": 88,
    "preview": "from .voxel_pooling_train import voxel_pooling_train\n\n__all__ = ['voxel_pooling_train']\n"
  },
  {
    "path": "bevdepth/ops/voxel_pooling_train/src/voxel_pooling_train_forward.cpp",
    "chars": 3422,
    "preview": "// Copyright (c) Megvii Inc. All rights reserved.\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp"
  },
  {
    "path": "bevdepth/ops/voxel_pooling_train/src/voxel_pooling_train_forward_cuda.cu",
    "chars": 4213,
    "preview": "// Copyright (c) Megvii Inc. All rights reserved.\n#include <cuda_fp16.h>\n#include <math.h>\n#include <stdio.h>\n#include <"
  },
  {
    "path": "bevdepth/ops/voxel_pooling_train/voxel_pooling_train.py",
    "chars": 2947,
    "preview": "# Copyright (c) Megvii Inc. All rights reserved.\nimport torch\nfrom torch.autograd import Function\n\nfrom . import voxel_p"
  },
  {
    "path": "bevdepth/utils/torch_dist.py",
    "chars": 1030,
    "preview": "\"\"\"\n@author: zeming li\n@contact: zengarden2009@gmail.com\n\"\"\"\nfrom torch import distributed as dist\n\n\ndef get_rank() -> i"
  },
  {
    "path": "requirements-dev.txt",
    "chars": 351,
    "preview": "# code formatter\n# force to use same version of the formatter, can be changed only by maintainer.\n\nanybadge\nautoflake==1"
  },
  {
    "path": "requirements.txt",
    "chars": 167,
    "preview": "numba\nnumpy\nnuscenes-devkit\nopencv-python-headless\npandas\npytorch-lightning==1.6.0\nscikit-image\nscipy\nsetuptools==59.5.0"
  },
  {
    "path": "scripts/gen_info.py",
    "chars": 8120,
    "preview": "import mmcv\nimport numpy as np\nfrom nuscenes.nuscenes import NuScenes\nfrom nuscenes.utils import splits\nfrom tqdm import"
  },
  {
    "path": "scripts/visualize_nusc.py",
    "chars": 8872,
    "preview": "import os\nfrom argparse import ArgumentParser\n\nimport cv2\nimport matplotlib.cm as cm\nimport matplotlib.pyplot as plt\nimp"
  },
  {
    "path": "setup.py",
    "chars": 2393,
    "preview": "import os\n\nimport torch\nfrom setuptools import find_packages, setup\nfrom torch.utils.cpp_extension import (BuildExtensio"
  },
  {
    "path": "test/test_dataset/test_nusc_mv_det_dataset.py",
    "chars": 2144,
    "preview": "import unittest\n\nimport numpy as np\nimport torch\n\nfrom bevdepth.datasets.nusc_det_dataset import NuscDetDataset\n\nCLASSES"
  },
  {
    "path": "test/test_layers/test_backbone.py",
    "chars": 1973,
    "preview": "import unittest\n\nimport pytest\nimport torch\n\nfrom bevdepth.layers.backbones.base_lss_fpn import BaseLSSFPN\n\n\nclass TestL"
  },
  {
    "path": "test/test_layers/test_head.py",
    "chars": 5751,
    "preview": "import unittest\n\nimport pytest\nimport torch\nfrom mmdet3d.core.bbox.structures.lidar_box3d import LiDARInstance3DBoxes\n\nf"
  },
  {
    "path": "test/test_layers/test_matrixvt.py",
    "chars": 1819,
    "preview": "import unittest\n\nimport torch\n\nfrom bevdepth.layers.backbones.matrixvt import MatrixVT\n\n\nclass TestMatrixVT(unittest.Tes"
  },
  {
    "path": "test/test_ops/test_voxel_pooling.py",
    "chars": 1475,
    "preview": "import unittest\n\nimport pytest\nimport torch\n\nfrom bevdepth.ops.voxel_pooling_train import voxel_pooling_train\n\n\nclass Te"
  }
]
