Repository: kakaobrain/sparse-detr
Branch: main
Commit: f40632c3f467
Files: 64
Total size: 365.2 KB

Directory structure:
sparse-detr/

├── LICENSE
├── NOTICE
├── README.md
├── configs/
│   ├── r50_deformable_detr.sh
│   ├── r50_efficient_detr.sh
│   ├── r50_sparse_detr_rho_0.1.sh
│   ├── r50_sparse_detr_rho_0.2.sh
│   ├── r50_sparse_detr_rho_0.3.sh
│   ├── swint_deformable_detr.sh
│   ├── swint_efficient_detr.sh
│   ├── swint_sparse_detr_rho_0.1.sh
│   ├── swint_sparse_detr_rho_0.2.sh
│   └── swint_sparse_detr_rho_0.3.sh
├── datasets/
│   ├── __init__.py
│   ├── coco.py
│   ├── coco_eval.py
│   ├── coco_panoptic.py
│   ├── data_prefetcher.py
│   ├── panoptic_eval.py
│   ├── samplers.py
│   ├── torchvision_datasets/
│   │   ├── __init__.py
│   │   └── coco.py
│   └── transforms.py
├── engine.py
├── main.py
├── models/
│   ├── __init__.py
│   ├── backbone.py
│   ├── deformable_detr.py
│   ├── deformable_transformer.py
│   ├── matcher.py
│   ├── ops/
│   │   ├── functions/
│   │   │   ├── __init__.py
│   │   │   └── ms_deform_attn_func.py
│   │   ├── make.sh
│   │   ├── modules/
│   │   │   ├── __init__.py
│   │   │   └── ms_deform_attn.py
│   │   ├── setup.py
│   │   ├── src/
│   │   │   ├── cpu/
│   │   │   │   ├── ms_deform_attn_cpu.cpp
│   │   │   │   └── ms_deform_attn_cpu.h
│   │   │   ├── cuda/
│   │   │   │   ├── ms_deform_attn_cuda.cu
│   │   │   │   ├── ms_deform_attn_cuda.h
│   │   │   │   └── ms_deform_im2col_cuda.cuh
│   │   │   ├── ms_deform_attn.h
│   │   │   └── vision.cpp
│   │   └── test.py
│   ├── position_encoding.py
│   ├── segmentation.py
│   └── swin_transformer/
│       ├── __init__.py
│       ├── build.py
│       ├── config.py
│       ├── configs/
│       │   ├── default.yaml
│       │   ├── swin_base_patch4_window7_224.yaml
│       │   ├── swin_large_patch4_window7_224.yaml
│       │   ├── swin_small_patch4_window7_224.yaml
│       │   └── swin_tiny_patch4_window7_224.yaml
│       └── swin_transformer.py
├── requirements.txt
├── tools/
│   ├── launch.py
│   └── run_dist_launch.sh
└── util/
    ├── __init__.py
    ├── benchmark.py
    ├── box_ops.py
    ├── dam.py
    ├── misc.py
    └── plot_utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
                                Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2021 KAKAO BRAIN Corp. All Rights Reserved.
   
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.



Deformable DETR

Copyright 2020 SenseTime

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.



DETR

Copyright 2020 - present, Facebook, Inc

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


================================================
FILE: NOTICE
================================================
===============================================================================
Deformable DETR's Apache License 2.0
===============================================================================
The overall structure of the code is based on the implementation in 
Deformable-DETR(https://github.com/fundamentalvision/Deformable-DETR).
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Copyright (c) 2020 SenseTime

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

===============================================================================
DETR's Apache License 2.0
===============================================================================
Deformable DETR code is originally built on the implementation in DETR
(https://github.com/facebookresearch/detr).
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Copyright (c) 2020 Facebook, Inc

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


===============================================================================
Swin Transformer's MIT License
===============================================================================
The transformer backbone is based on the implementation in Swin Transformer
(https://github.com/microsoft/Swin-Transformer).
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Copyright (c) 2021 Microsoft

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
[![KakaoBrain](https://img.shields.io/badge/kakao-brain-ffcd00.svg)](http://kakaobrain.com/)
[![pytorch](https://img.shields.io/badge/pytorch-1.6.0-%23ee4c2c.svg)](https://pytorch.org/)
[![pytorch](https://img.shields.io/badge/pytorch-1.7.1-%23ee4c2c.svg)](https://pytorch.org/)

Sparse DETR (ICLR'22)
========

By [Byungseok Roh](https://scholar.google.com/citations?user=H4VWYHwAAAAJ)\*,  [Jaewoong Shin](https://scholar.google.com/citations?user=i_o_95kAAAAJ)\*,  [Wuhyun Shin](https://scholar.google.com/citations?user=bGwfkakAAAAJ)\*, and [Saehoon Kim](https://scholar.google.com/citations?user=_ZfueMIAAAAJ) at [Kakao Brain](https://www.kakaobrain.com).
(*: Equal contribution)

* This repository is an official implementation of the paper [Sparse DETR: Efficient End-to-End Object Detection with Learnable Sparsity](https://arxiv.org/abs/2111.14330). 
* The code and some instructions are built upon the official [Deformable DETR repository](https://github.com/fundamentalvision/Deformable-DETR).



# Introduction

**TL;DR.** Sparse DETR is an efficient end-to-end object detector that **sparsifies encoder tokens** using a learnable DAM (Decoder Attention Map) predictor. It achieves better performance than Deformable DETR even with only 10% of the encoder tokens on the COCO dataset.

<p align="center">
<img src="./figs/dam_creation.png" height=350>
</p>

**Abstract.** DETR is the first end-to-end object detector using a transformer encoder-decoder architecture and demonstrates competitive performance but low computational efficiency on high resolution feature maps.
The subsequent work, Deformable DETR, enhances the efficiency of DETR by replacing dense attention with deformable attention, which achieves 10x faster convergence and improved performance. 
Deformable DETR uses multiscale features to improve performance; however, the number of encoder tokens increases by 20x compared to DETR, and the computation cost of the encoder attention remains a bottleneck.
In our preliminary experiment, we observe that the detection performance hardly deteriorates even if only a part of the encoder tokens is updated.
Inspired by this observation, we propose *Sparse DETR*, which selectively updates only the tokens expected to be referenced by the decoder, thus helping the model detect objects effectively.
In addition, we show that applying an auxiliary detection loss on the selected tokens in the encoder improves the performance while minimizing computational overhead.
We validate that *Sparse DETR* achieves better performance than Deformable DETR even with only 10% encoder tokens on the COCO dataset.
Although only the encoder tokens are sparsified, the total computation cost decreases by 38% and the frames per second (FPS) increase by 42% compared to Deformable DETR.


# Installation

## Requirements

We have tested the code on the following environments: 
* Python 3.7.7 / PyTorch 1.6.0 / torchvision 0.7.0 / CUDA 10.1 / Ubuntu 18.04
* Python 3.8.3 / PyTorch 1.7.1 / torchvision 0.8.2 / CUDA 11.1 / Ubuntu 18.04

Run the following command to install dependencies:
```bash
pip install -r requirements.txt
```

## Compiling CUDA operators
```bash
cd ./models/ops
sh ./make.sh
# unit test (all checks should print True)
python test.py
```
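
If you just want to confirm that the compiled extension is importable without running the full unit test, a minimal check like the one below should suffice (run from the repository root; the constructor arguments are assumed to match the defaults in `models/ops/modules/ms_deform_attn.py`):

```python
# Minimal import sanity check for the compiled CUDA extension (a sketch; the
# constructor arguments are assumed defaults from models/ops/modules/ms_deform_attn.py).
import torch
from models.ops.modules import MSDeformAttn  # an ImportError here means the build failed

attn = MSDeformAttn(d_model=256, n_levels=4, n_heads=8, n_points=4)
print('MSDeformAttn loaded; CUDA available:', torch.cuda.is_available())
```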

# Usage

## Dataset preparation

Please download the [COCO 2017 dataset](https://cocodataset.org/) and organize it as follows:

```
code_root/
└── data/
    └── coco/
        ├── train2017/
        ├── val2017/
        └── annotations/
            ├── instances_train2017.json
            └── instances_val2017.json
```
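
To sanity-check the layout programmatically, the dataset can be built the same way `main.py` does via `datasets.build_dataset`; the sketch below is a minimal approximation, and the `Namespace` fields are assumptions mirroring the flags consumed by `datasets/coco.py`:

```python
# A sketch for verifying the dataset layout; the Namespace fields mirror the
# argparse flags read by datasets/coco.py (assumed, not taken from main.py itself).
from argparse import Namespace
from datasets import build_dataset

args = Namespace(dataset_file='coco', coco_path='data/coco', masks=False, cache_mode=False)
dataset_train = build_dataset(image_set='train', args=args)
img, target = dataset_train[0]  # img: normalized CHW tensor, target: dict with 'boxes', 'labels', ...
print(len(dataset_train), img.shape, target['boxes'].shape)
```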

## Training

### Training on a single node

For example, the command for training Sparse DETR with a keeping ratio of 10% on 8 GPUs is as follows:

```bash
$ GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 ./configs/swint_sparse_detr_rho_0.1.sh
```

### Training on multiple nodes

For example, the commands for training Sparse DETR with a keeping ratio of 10% on 2 nodes, each with 8 GPUs, are as follows:

On node 1:

```bash
$ MASTER_ADDR=<IP address of node 1> NODE_RANK=0 GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 16 ./configs/swint_sparse_detr_rho_0.1.sh
```

On node 2:

```bash
$ MASTER_ADDR=<IP address of node 1> NODE_RANK=1 GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 16 ./configs/swint_sparse_detr_rho_0.1.sh
```

### Direct argument control

```bash
# Deformable DETR (with the bounding-box-refinement and two-stage arguments, if desired)
$ GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 python main.py --with_box_refine --two_stage
# Efficient DETR (with the class-specific head as described in their paper)
$ GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 python main.py --with_box_refine --two_stage --eff_query_init --eff_specific_head
# Sparse DETR (with the keeping ratio of 10% and encoder auxiliary loss)
$ GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 python main.py --with_box_refine --two_stage --eff_query_init --eff_specific_head --rho 0.1 --use_enc_aux_loss
```

### Some tips to speed-up training
* If your file system is slow to read images, you may consider enabling the `--cache_mode` option to load the whole dataset into memory at the beginning of training.
* You may increase the batch size to maximize GPU utilization, according to your GPU memory, e.g., set `--batch_size 3` or `--batch_size 4`.

## Evaluation

You can download a pre-trained Sparse DETR model (the links are in the "Main Results" section) and run the following command to evaluate it on the COCO 2017 validation set:

```bash
# Note that you should run the command with the corresponding configuration.
$ ./configs/swint_sparse_detr_rho_0.1.sh --resume <path to pre-trained model> --eval
```

You can also run distributed evaluation by using `./tools/run_dist_launch.sh`.

# Main Results
The tables below demonstrate the detection performance of Sparse DETR on the COCO 2017 validation set when using different backbones. 
* **Top-k**: sampling the top-k object queries instead of using the learned object queries (as in Efficient DETR).
* **BBR**: performing bounding box refinement in the decoder block (as in Deformable DETR).
* The **encoder auxiliary loss** proposed in our paper is only applied to Sparse DETR.
* **FLOPs** and **FPS** are measured in the same way as used in Deformable DETR. 
* Refer to **Table 1** in the paper for more details.



## ResNet-50 backbone
| Method             | Epochs | ρ   | Top-k & BBR | AP   | #Params(M) | GFLOPs | B4FPS | Download |
|:------------------:|:------:|:---:|:-----------:|:----:|:----------:|:------:|:-----:|:--------:|
| Faster R-CNN + FPN | 109    | N/A |             | 42.0 | 42M        | 180G   | 26    |          |
| DETR               | 50     | N/A |             | 35.0 | 41M        | 86G    | 28    |          |
| DETR               | 500    | N/A |             | 42.0 | 41M        | 86G    | 28    |          |
| DETR-DC5           | 500    | N/A |             | 43.3 | 41M        | 187G   | 12    |          |
| PnP-DETR           | 500    | 33% |             | 41.1 |            |        |       |          |
|                    | 500    | 50% |             | 41.8 |            |        |       |          |
| PnP-DETR-DC5       | 500    | 33% |             | 42.7 |            |        |       |          |
|                    | 500    | 50% |             | 43.1 |            |        |       |          |
| Deformable-DETR    | 50     | N/A |             | 43.9 | 39.8M      | 172.9G | 19.1  |          |
|                    | 50     | N/A | o           | 46.0 | 40.8M      | 177.3G | 18.2  |          |
| Sparse-DETR        | 50     | 10% | o           | 45.3 | 40.9M      | 105.4G | 26.5  | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_r50_10.pth)     |
|                    | 50     | 20% | o           | 45.6 | 40.9M      | 112.9G | 24.8  | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_r50_20.pth)     |
|                    | 50     | 30% | o           | 46.0 | 40.9M      | 120.5G | 23.2  | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_r50_30.pth)     |
|                    | 50     | 40% | o           | 46.2 | 40.9M      | 128.0G | 21.8  | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_r50_40.pth)     |
|                    | 50     | 50% | o           | 46.3 | 40.9M      | 135.6G | 20.5  | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_r50_50.pth)     |



## Swin-T backbone
| Method          | Epochs | ρ   | Top-k & BBR | AP   | #Params(M) | GFLOPs | B4FPS | Download |
|:---------------:|:------:|:---:|:-----------:|:----:|:----------:|:------:|:-----:|:--------:|
| DETR            | 50     | N/A |             | 35.9 | 45.0M      | 91.6G  | 26.8  |          |
| DETR            | 500    | N/A |             | 45.4 | 45.0M      | 91.6G  | 26.8  |          |
| Deformable-DETR | 50     | N/A |             | 45.7 | 40.3M      | 180.4G | 15.9  |          |
|                 | 50     | N/A | o           | 48.0 | 41.3M      | 184.8G | 15.4  |          |
| Sparse-DETR     | 50     | 10% | o           | 48.2 | 41.4M      | 113.4G | 21.2  | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_swint_10.pth)     |
|                 | 50     | 20% | o           | 48.8 | 41.4M      | 121.0G | 20    | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_swint_20.pth)     |
|                 | 50     | 30% | o           | 49.1 | 41.4M      | 128.5G | 18.9  | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_swint_30.pth)     |
|                 | 50     | 40% | o           | 49.2 | 41.4M      | 136.1G | 18    | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_swint_40.pth)     |
|                 | 50     | 50% | o           | 49.3 | 41.4M      | 143.7G | 17.2  | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_swint_50.pth)     |


## Initializing ResNet-50 backbone with SCRL
The performance of Sparse DETR improves further when the backbone network is initialized with `SCRL` ([Spatially Consistent Representation Learning](https://arxiv.org/abs/2103.06122)), which learns dense representations in a self-supervised way, instead of the default ImageNet-supervised pre-training, denoted as `IN-sup` in the table below.
* We obtained pre-trained weights from [Torchvision](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html#sphx-glr-beginner-finetuning-torchvision-models-tutorial-py) for `IN-sup`, and the [SCRL GitHub repository](https://github.com/kakaobrain/scrl) for `SCRL`.
* To reproduce the `SCRL` results, add `--scrl_pretrained_path <downloaded_filepath>` to the training command.
 
| Method      | ρ   | AP(IN-sup) | AP(SCRL) | AP(gain) | Download |
|:-----------:|:---:|:-----------:|:--------:|:--------:|:--------:|
| Sparse DETR | 10% | 45.3        | 46.9     | +1.6     | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_r50_scrl_10.pth)     |
|             | 20% | 45.6        | 47.2     | +1.7     | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_r50_scrl_20.pth)     |
|             | 30% | 46.0        | 47.4     | +1.4     | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_r50_scrl_30.pth)     |
|             | 40% | 46.2        | 47.7     | +1.5     | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_r50_scrl_40.pth)     |
|             | 50% | 46.3        | 47.9     | +1.6     | [link](https://twg.kakaocdn.net/brainrepo/sparse_detr/sparse_detr_r50_scrl_50.pth)     |


# Citation
If you find Sparse DETR useful in your research, please consider citing:
```bibtex
@inproceedings{roh2022sparse,
  title={Sparse DETR: Efficient End-to-End Object Detection with Learnable Sparsity},
  author={Roh, Byungseok and Shin, JaeWoong and Shin, Wuhyun and Kim, Saehoon},
  booktitle={ICLR},
  year={2022}
}
```

# License

This project is released under the [Apache 2.0 license](./LICENSE).
Copyright 2021 [Kakao Brain Corp](https://www.kakaobrain.com). All Rights Reserved.


================================================
FILE: configs/r50_deformable_detr.sh
================================================
#!/usr/bin/env bash

set -x

EXP_DIR=exps/r50_deformable_detr
PY_ARGS=${@:1}

python -u main.py \
    --output_dir ${EXP_DIR} \
    ${PY_ARGS}


================================================
FILE: configs/r50_efficient_detr.sh
================================================
#!/usr/bin/env bash

set -x

EXP_DIR=exps/r50_efficient_detr
PY_ARGS=${@:1}

python -u main.py \
    --output_dir ${EXP_DIR} \
    --with_box_refine \
    --two_stage \
    --eff_query_init \
    --eff_specific_head \
    ${PY_ARGS}


================================================
FILE: configs/r50_sparse_detr_rho_0.1.sh
================================================
#!/usr/bin/env bash

set -x

EXP_DIR=exps/r50_sparse_detr_0.1
PY_ARGS=${@:1}

python -u main.py \
    --output_dir ${EXP_DIR} \
    --with_box_refine \
    --two_stage \
    --eff_query_init \
    --eff_specific_head \
    --rho 0.1 \
    --use_enc_aux_loss \
    ${PY_ARGS}


================================================
FILE: configs/r50_sparse_detr_rho_0.2.sh
================================================
#!/usr/bin/env bash

set -x

EXP_DIR=exps/r50_sparse_detr_0.2
PY_ARGS=${@:1}

python -u main.py \
    --output_dir ${EXP_DIR} \
    --with_box_refine \
    --two_stage \
    --eff_query_init \
    --eff_specific_head \
    --rho 0.2 \
    --use_enc_aux_loss \
    ${PY_ARGS}


================================================
FILE: configs/r50_sparse_detr_rho_0.3.sh
================================================
#!/usr/bin/env bash

set -x

EXP_DIR=exps/r50_sparse_detr_0.3
PY_ARGS=${@:1}

python -u main.py \
    --output_dir ${EXP_DIR} \
    --with_box_refine \
    --two_stage \
    --eff_query_init \
    --eff_specific_head \
    --rho 0.3 \
    --use_enc_aux_loss \
    ${PY_ARGS}


================================================
FILE: configs/swint_deformable_detr.sh
================================================
#!/usr/bin/env bash

set -x

EXP_DIR=exps/swint_deformable_detr
PY_ARGS=${@:1}

python -u main.py \
    --output_dir ${EXP_DIR} \
    --backbone swin-t \
    ${PY_ARGS}


================================================
FILE: configs/swint_efficient_detr.sh
================================================
#!/usr/bin/env bash

set -x

EXP_DIR=exps/swint_efficient_detr
PY_ARGS=${@:1}

python -u main.py \
    --output_dir ${EXP_DIR} \
    --backbone swin-t \
    --with_box_refine \
    --two_stage \
    --eff_query_init \
    --eff_specific_head \
    ${PY_ARGS}


================================================
FILE: configs/swint_sparse_detr_rho_0.1.sh
================================================
#!/usr/bin/env bash

set -x

EXP_DIR=exps/swint_sparse_detr_0.1
PY_ARGS=${@:1}

python -u main.py \
    --output_dir ${EXP_DIR} \
    --backbone swin-t \
    --with_box_refine \
    --two_stage \
    --eff_query_init \
    --eff_specific_head \
    --rho 0.1 \
    --use_enc_aux_loss \
    ${PY_ARGS}


================================================
FILE: configs/swint_sparse_detr_rho_0.2.sh
================================================
#!/usr/bin/env bash

set -x

EXP_DIR=exps/swint_sparse_detr_0.2
PY_ARGS=${@:1}

python -u main.py \
    --output_dir ${EXP_DIR} \
    --backbone swin-t \
    --with_box_refine \
    --two_stage \
    --eff_query_init \
    --eff_specific_head \
    --rho 0.2 \
    --use_enc_aux_loss \
    ${PY_ARGS}


================================================
FILE: configs/swint_sparse_detr_rho_0.3.sh
================================================
#!/usr/bin/env bash

set -x

EXP_DIR=exps/swint_sparse_detr_0.3
PY_ARGS=${@:1}

python -u main.py \
    --output_dir ${EXP_DIR} \
    --backbone swin-t \
    --with_box_refine \
    --two_stage \
    --eff_query_init \
    --eff_specific_head \
    --rho 0.3 \
    --use_enc_aux_loss \
    ${PY_ARGS}


================================================
FILE: datasets/__init__.py
================================================
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

import torch.utils.data
from .torchvision_datasets import CocoDetection

from .coco import build as build_coco


def get_coco_api_from_dataset(dataset):
    # unwrap up to 10 levels of torch.utils.data.Subset to reach the underlying dataset
    for _ in range(10):
        # if isinstance(dataset, torchvision.datasets.CocoDetection):
        #     break
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
    if isinstance(dataset, CocoDetection):
        return dataset.coco


def build_dataset(image_set, args):
    if args.dataset_file == 'coco':
        return build_coco(image_set, args)
    if args.dataset_file == 'coco_panoptic':
        # to avoid making panopticapi required for coco
        from .coco_panoptic import build as build_coco_panoptic
        return build_coco_panoptic(image_set, args)
    raise ValueError(f'dataset {args.dataset_file} not supported')


================================================
FILE: datasets/coco.py
================================================
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

"""
COCO dataset which returns image_id for evaluation.

Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
from pathlib import Path

import torch
import torch.utils.data
from pycocotools import mask as coco_mask

from .torchvision_datasets import CocoDetection as TvCocoDetection
from util.misc import get_local_rank, get_local_size
import datasets.transforms as T


class CocoDetection(TvCocoDetection):
    def __init__(self, img_folder, ann_file, transforms, return_masks, cache_mode=False, local_rank=0, local_size=1):
        super(CocoDetection, self).__init__(img_folder, ann_file,
                                            cache_mode=cache_mode, local_rank=local_rank, local_size=local_size)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)

    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]

        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]  # convert [x, y, w, h] to [x0, y0, x1, y1]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        return image, target


def make_coco_transforms(image_set):

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')


def build(image_set, args):
    root = Path(args.coco_path)
    assert root.exists(), f'provided COCO path {root} does not exist'
    mode = 'instances'
    PATHS = {
        "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
        "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
    }

    img_folder, ann_file = PATHS[image_set]
    dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks,
                            cache_mode=args.cache_mode, local_rank=get_local_rank(), local_size=get_local_size())
    return dataset


================================================
FILE: datasets/coco_eval.py
================================================
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

"""
COCO evaluator that works in distributed mode.

Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
The difference is that there is less copy-pasting from pycocotools
in the end of the file, as python3 can suppress prints with contextlib
"""
import os
import contextlib
import copy
import numpy as np
import torch

from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import pycocotools.mask as mask_util

from util.misc import all_gather


class CocoEvaluator(object):
    def __init__(self, coco_gt, iou_types):
        assert isinstance(iou_types, (list, tuple))
        coco_gt = copy.deepcopy(coco_gt)
        self.coco_gt = coco_gt

        self.iou_types = iou_types
        self.coco_eval = {}
        for iou_type in iou_types:
            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)

        self.img_ids = []
        self.eval_imgs = {k: [] for k in iou_types}

    def update(self, predictions):
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)

            # suppress pycocotools prints
            with open(os.devnull, 'w') as devnull:
                with contextlib.redirect_stdout(devnull):
                    coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
            coco_eval = self.coco_eval[iou_type]

            coco_eval.cocoDt = coco_dt
            coco_eval.params.imgIds = list(img_ids)
            img_ids, eval_imgs = evaluate(coco_eval)

            self.eval_imgs[iou_type].append(eval_imgs)

    def synchronize_between_processes(self):
        for iou_type in self.iou_types:
            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])

    def accumulate(self):
        for coco_eval in self.coco_eval.values():
            coco_eval.accumulate()

    def summarize(self):
        for iou_type, coco_eval in self.coco_eval.items():
            print("IoU metric: {}".format(iou_type))
            coco_eval.summarize()

    def prepare(self, predictions, iou_type):
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        elif iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        elif iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        else:
            raise ValueError("Unknown iou type {}".format(iou_type))

    def prepare_for_coco_detection(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"]
            labels = prediction["labels"]
            masks = prediction["masks"]

            masks = masks > 0.5

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
                for mask in masks
            ]
            for rle in rles:
                rle["counts"] = rle["counts"].decode("utf-8")

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "segmentation": rle,
                        "score": scores[k],
                    }
                    for k, rle in enumerate(rles)
                ]
            )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"]
            keypoints = keypoints.flatten(start_dim=1).tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        'keypoints': keypoint,
                        "score": scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results


def convert_to_xywh(boxes):
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)


def merge(img_ids, eval_imgs):
    all_img_ids = all_gather(img_ids)
    all_eval_imgs = all_gather(eval_imgs)

    merged_img_ids = []
    for p in all_img_ids:
        merged_img_ids.extend(p)

    merged_eval_imgs = []
    for p in all_eval_imgs:
        merged_eval_imgs.append(p)

    merged_img_ids = np.array(merged_img_ids)
    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)

    # keep only unique (and in sorted order) images
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    merged_eval_imgs = merged_eval_imgs[..., idx]

    return merged_img_ids, merged_eval_imgs


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    img_ids, eval_imgs = merge(img_ids, eval_imgs)
    img_ids = list(img_ids)
    eval_imgs = list(eval_imgs.flatten())

    coco_eval.evalImgs = eval_imgs
    coco_eval.params.imgIds = img_ids
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################


def evaluate(self):
    '''
    Run per-image evaluation on the given images and store the results (a list of dicts) in self.evalImgs.
    :return: (imgIds, evalImgs) -- note this copy returns the results, unlike upstream pycocotools
    '''
    # tic = time.time()
    # print('Running per image evaluation...')
    p = self.params
    # add backward compatibility if useSegm is specified in params
    if p.useSegm is not None:
        p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
        print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
    # print('Evaluate annotation type *{}*'.format(p.iouType))
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)
    self.params = p

    self._prepare()
    # loop through images, area range, max detection number
    catIds = p.catIds if p.useCats else [-1]

    if p.iouType == 'segm' or p.iouType == 'bbox':
        computeIoU = self.computeIoU
    elif p.iouType == 'keypoints':
        computeIoU = self.computeOks
    self.ious = {
        (imgId, catId): computeIoU(imgId, catId)
        for imgId in p.imgIds
        for catId in catIds}

    evaluateImg = self.evaluateImg
    maxDet = p.maxDets[-1]
    evalImgs = [
        evaluateImg(imgId, catId, areaRng, maxDet)
        for catId in catIds
        for areaRng in p.areaRng
        for imgId in p.imgIds
    ]
    # this is NOT in the pycocotools code, but could be done outside
    evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
    self._paramsEval = copy.deepcopy(self.params)
    # toc = time.time()
    # print('DONE (t={:0.2f}s).'.format(toc-tic))
    return p.imgIds, evalImgs

#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
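
# Usage sketch: the evaluation loop that CocoEvaluator is built for. Here `model`,
# `postprocessors`, `data_loader_val`, `dataset_val`, and `orig_target_sizes` are
# placeholders (in this repo the actual loop lives in engine.py / main.py):
#
#   from datasets import get_coco_api_from_dataset
#   coco_gt = get_coco_api_from_dataset(dataset_val)         # pycocotools COCO object
#   evaluator = CocoEvaluator(coco_gt, iou_types=('bbox',))
#   for samples, targets in data_loader_val:
#       outputs = model(samples)
#       results = postprocessors['bbox'](outputs, orig_target_sizes)
#       evaluator.update({t['image_id'].item(): r for t, r in zip(targets, results)})
#   evaluator.synchronize_between_processes()
#   evaluator.accumulate()
#   evaluator.summarize()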


================================================
FILE: datasets/coco_panoptic.py
================================================
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

import json
from pathlib import Path

import numpy as np
import torch
from PIL import Image

from panopticapi.utils import rgb2id
from util.box_ops import masks_to_boxes

from .coco import make_coco_transforms


class CocoPanoptic:
    def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True):
        with open(ann_file, 'r') as f:
            self.coco = json.load(f)

        # sort 'images' field so that they are aligned with 'annotations'
        # i.e., in alphabetical order
        self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id'])
        # sanity check
        if "annotations" in self.coco:
            for img, ann in zip(self.coco['images'], self.coco['annotations']):
                assert img['file_name'][:-4] == ann['file_name'][:-4]

        self.img_folder = img_folder
        self.ann_folder = ann_folder
        self.ann_file = ann_file
        self.transforms = transforms
        self.return_masks = return_masks

    def __getitem__(self, idx):
        ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx]
        img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg')
        ann_path = Path(self.ann_folder) / ann_info['file_name']

        img = Image.open(img_path).convert('RGB')
        w, h = img.size
        if "segments_info" in ann_info:
            masks = np.asarray(Image.open(ann_path), dtype=np.uint32)
            masks = rgb2id(masks)

            ids = np.array([ann['id'] for ann in ann_info['segments_info']])
            masks = masks == ids[:, None, None]

            masks = torch.as_tensor(masks, dtype=torch.uint8)
            labels = torch.tensor([ann['category_id'] for ann in ann_info['segments_info']], dtype=torch.int64)

        target = {}
        target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]])
        if self.return_masks:
            target['masks'] = masks
        target['labels'] = labels

        target["boxes"] = masks_to_boxes(masks)

        target['size'] = torch.as_tensor([int(h), int(w)])
        target['orig_size'] = torch.as_tensor([int(h), int(w)])
        if "segments_info" in ann_info:
            for name in ['iscrowd', 'area']:
                target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.coco['images'])

    def get_height_and_width(self, idx):
        img_info = self.coco['images'][idx]
        height = img_info['height']
        width = img_info['width']
        return height, width


def build(image_set, args):
    img_folder_root = Path(args.coco_path)
    ann_folder_root = Path(args.coco_panoptic_path)
    assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist'
    assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist'
    mode = 'panoptic'
    PATHS = {
        "train": ("train2017", Path("annotations") / f'{mode}_train2017.json'),
        "val": ("val2017", Path("annotations") / f'{mode}_val2017.json'),
    }

    img_folder, ann_file = PATHS[image_set]
    img_folder_path = img_folder_root / img_folder
    ann_folder = ann_folder_root / f'{mode}_{img_folder}'
    ann_file = ann_folder_root / ann_file

    dataset = CocoPanoptic(img_folder_path, ann_folder, ann_file,
                           transforms=make_coco_transforms(image_set), return_masks=args.masks)

    return dataset


================================================
FILE: datasets/data_prefetcher.py
================================================
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

import torch

def to_cuda(samples, targets, device):
    samples = samples.to(device, non_blocking=True)
    targets = [{k: v.to(device, non_blocking=True) for k, v in t.items()} for t in targets]
    return samples, targets

class data_prefetcher():
    def __init__(self, loader, device, prefetch=True):
        self.loader = iter(loader)
        self.prefetch = prefetch
        self.device = device
        if prefetch:
            self.stream = torch.cuda.Stream()
            self.preload()

    def preload(self):
        try:
            self.next_samples, self.next_targets = next(self.loader)
        except StopIteration:
            self.next_samples = None
            self.next_targets = None
            return
        # if record_stream() doesn't work, another option is to make sure device inputs are created
        # on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use by the main stream
        # at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.next_samples, self.next_targets = to_cuda(self.next_samples, self.next_targets, self.device)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:

    def next(self):
        if self.prefetch:
            torch.cuda.current_stream().wait_stream(self.stream)
            samples = self.next_samples
            targets = self.next_targets
            if samples is not None:
                samples.record_stream(torch.cuda.current_stream())
            if targets is not None:
                for t in targets:
                    for k, v in t.items():
                        v.record_stream(torch.cuda.current_stream())
            self.preload()
        else:
            try:
                samples, targets = next(self.loader)
                samples, targets = to_cuda(samples, targets, self.device)
            except StopIteration:
                samples = None
                targets = None
        return samples, targets


================================================
FILE: datasets/panoptic_eval.py
================================================
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

import json
import os

import util.misc as utils

try:
    from panopticapi.evaluation import pq_compute
except ImportError:
    pass


class PanopticEvaluator(object):
    def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"):
        self.gt_json = ann_file
        self.gt_folder = ann_folder
        if utils.is_main_process():
            if not os.path.exists(output_dir):
                os.mkdir(output_dir)
        self.output_dir = output_dir
        self.predictions = []

    def update(self, predictions):
        for p in predictions:
            with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f:
                f.write(p.pop("png_string"))

        self.predictions += predictions

    def synchronize_between_processes(self):
        all_predictions = utils.all_gather(self.predictions)
        merged_predictions = []
        for p in all_predictions:
            merged_predictions += p
        self.predictions = merged_predictions

    def summarize(self):
        if utils.is_main_process():
            json_data = {"annotations": self.predictions}
            predictions_json = os.path.join(self.output_dir, "predictions.json")
            with open(predictions_json, "w") as f:
                f.write(json.dumps(json_data))
            return pq_compute(self.gt_json, predictions_json, gt_folder=self.gt_folder, pred_folder=self.output_dir)
        return None
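

# Usage sketch (illustrative only, not part of the original file): the
# evaluator buffers per-image PNG predictions, merges them across ranks, and
# scores PQ with panopticapi, mirroring its use in engine.evaluate.
# `predictions` is a list of dicts carrying 'file_name' and 'png_string'.
def panoptic_eval_sketch(ann_file, ann_folder, predictions):
    evaluator = PanopticEvaluator(ann_file, ann_folder, output_dir="panoptic_eval")
    evaluator.update(predictions)               # writes PNGs to output_dir
    evaluator.synchronize_between_processes()   # all_gather across ranks
    return evaluator.summarize()                # pq_compute on the main process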


================================================
FILE: datasets/samplers.py
================================================
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from codes in torch.utils.data.distributed
# ------------------------------------------------------------------------

import os
import math
import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler


class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.
    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.
    .. note::
        Dataset is assumed to be of constant size.
    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


class NodeDistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.
    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.
    .. note::
        Dataset is assumed to be of constant size.
    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if local_rank is None:
            local_rank = int(os.environ.get('LOCAL_RANK', 0))
        if local_size is None:
            local_size = int(os.environ.get('LOCAL_SIZE', 1))
        self.dataset = dataset
        self.shuffle = shuffle
        self.num_replicas = num_replicas
        self.num_parts = local_size
        self.rank = rank
        self.local_rank = local_rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

        self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()
        indices = [i for i in indices if i % self.num_parts == self.local_rank]

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size_parts - len(indices))]
        assert len(indices) == self.total_size_parts

        # subsample: every node holds the same local shard, so stride through it
        # by the number of nodes (num_replicas // num_parts), starting at this
        # node's offset (rank // num_parts)
        indices = indices[self.rank // self.num_parts:self.total_size_parts:self.num_replicas // self.num_parts]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
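

# Usage sketch (illustrative only, not part of the original file): both
# samplers reshuffle deterministically from the epoch number, so set_epoch()
# must be called every epoch (as main.py does) for all processes to draw the
# same permutation before taking their rank-exclusive slices.
def sampler_epoch_sketch(dataset, num_epochs):
    sampler = DistributedSampler(dataset)  # assumes torch.distributed is initialized
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)           # reseeds the shared shuffle
        for index in sampler:              # this rank's subset of indices
            pass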


================================================
FILE: datasets/torchvision_datasets/__init__.py
================================================
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

from .coco import CocoDetection


================================================
FILE: datasets/torchvision_datasets/coco.py
================================================
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from torchvision
# ------------------------------------------------------------------------

"""
Copy-Paste from torchvision, but add utility of caching images on memory
"""
from torchvision.datasets.vision import VisionDataset
from PIL import Image
import os
import os.path
import tqdm
from io import BytesIO


class CocoDetection(VisionDataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None,
                 cache_mode=False, local_rank=0, local_size=1):
        super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))
        self.cache_mode = cache_mode
        self.local_rank = local_rank
        self.local_size = local_size
        if cache_mode:
            self.cache = {}
            self.cache_images()

    def cache_images(self):
        self.cache = {}
        for index, img_id in zip(tqdm.trange(len(self.ids)), self.ids):
            if index % self.local_size != self.local_rank:
                continue
            path = self.coco.loadImgs(img_id)[0]['file_name']
            with open(os.path.join(self.root, path), 'rb') as f:
                self.cache[path] = f.read()

    def get_image(self, path):
        if self.cache_mode:
            if path not in self.cache.keys():
                with open(os.path.join(self.root, path), 'rb') as f:
                    self.cache[path] = f.read()
            return Image.open(BytesIO(self.cache[path])).convert('RGB')
        return Image.open(os.path.join(self.root, path)).convert('RGB')

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        path = coco.loadImgs(img_id)[0]['file_name']

        img = self.get_image(path)
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.ids)
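

# Usage sketch (illustrative only, not part of the original file): with
# cache_mode=True each process pre-reads every local_size-th image into RAM
# (see cache_images), so workers on the same node split the disk I/O. The
# paths below are placeholders.
def cached_coco_sketch():
    dataset = CocoDetection(root='./data/coco/train2017',
                            annFile='./data/coco/annotations/instances_train2017.json',
                            cache_mode=True, local_rank=0, local_size=1)
    return dataset[0]  # image decoded from the in-memory byte cache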


================================================
FILE: datasets/transforms.py
================================================
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

"""
Transforms and data augmentation for both image + bbox.
"""
import random

import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F

from util.box_ops import box_xyxy_to_cxcywh
from util.misc import interpolate


def crop(image, target, region):
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    fields = ["labels", "area", "iscrowd"]

    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target['masks'] = target['masks'][:, i:i + h, j:j + w]
        fields.append("masks")

    # remove elements for which the boxes or masks that have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target['boxes'].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target['masks'].flatten(1).any(1)

        for field in fields:
            target[field] = target[field][keep]

    return cropped_image, target


def hflip(image, target):
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
        target["boxes"] = boxes

    if "masks" in target:
        target['masks'] = target['masks'].flip(-1)

    return flipped_image, target


def resize(image, target, size, max_size=None):
    # size can be min_size (scalar) or (w, h) tuple

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))

        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
        target["boxes"] = scaled_boxes

    if "area" in target:
        area = target["area"]
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        target['masks'] = interpolate(
            target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5

    return rescaled_image, target


def pad(image, target, padding):
    # assumes that we only pad on the bottom right corners
    padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    if target is None:
        return padded_image, None
    target = target.copy()
    # should we do something wrt the original size?
    target["size"] = torch.tensor(padded_image[::-1])
    if "masks" in target:
        target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
    return padded_image, target


class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        region = T.RandomCrop.get_params(img, self.size)
        return crop(img, target, region)


class RandomSizeCrop(object):
    def __init__(self, min_size: int, max_size: int):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, img: PIL.Image.Image, target: dict):
        w = random.randint(self.min_size, min(img.width, self.max_size))
        h = random.randint(self.min_size, min(img.height, self.max_size))
        region = T.RandomCrop.get_params(img, [h, w])
        return crop(img, target, region)


class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        image_width, image_height = img.size
        crop_height, crop_width = self.size
        crop_top = int(round((image_height - crop_height) / 2.))
        crop_left = int(round((image_width - crop_width) / 2.))
        return crop(img, target, (crop_top, crop_left, crop_height, crop_width))


class RandomHorizontalFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return hflip(img, target)
        return img, target


class RandomResize(object):
    def __init__(self, sizes, max_size=None):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size

    def __call__(self, img, target=None):
        size = random.choice(self.sizes)
        return resize(img, target, size, self.max_size)


class RandomPad(object):
    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        pad_x = random.randint(0, self.max_pad)
        pad_y = random.randint(0, self.max_pad)
        return pad(img, target, (pad_x, pad_y))


class RandomSelect(object):
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2
    """
    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return self.transforms1(img, target)
        return self.transforms2(img, target)


class ToTensor(object):
    def __call__(self, img, target):
        return F.to_tensor(img), target


class RandomErasing(object):

    def __init__(self, *args, **kwargs):
        self.eraser = T.RandomErasing(*args, **kwargs)

    def __call__(self, img, target):
        return self.eraser(img), target


class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        if "boxes" in target:
            boxes = target["boxes"]
            boxes = box_xyxy_to_cxcywh(boxes)
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target["boxes"] = boxes
        return image, target


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string
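

# Usage sketch (illustrative only, not part of the original file): these
# primitives compose into the DETR-style training pipeline (multi-scale
# resize with a random resize-then-crop branch, followed by normalization),
# roughly as in datasets/coco.py::make_coco_transforms.
def train_pipeline_sketch():
    normalize = Compose([
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
    return Compose([
        RandomHorizontalFlip(),
        RandomSelect(
            RandomResize(scales, max_size=1333),
            Compose([
                RandomResize([400, 500, 600]),
                RandomSizeCrop(384, 600),
                RandomResize(scales, max_size=1333),
            ]),
        ),
        normalize,
    ])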


================================================
FILE: engine.py
================================================
# ------------------------------------------------------------------------------------
# Sparse DETR
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------------------


"""
Train and eval functions used in main.py
"""
import math
import os
import sys
from typing import Iterable

import torch
import util.misc as utils
from datasets.coco_eval import CocoEvaluator
from datasets.panoptic_eval import PanopticEvaluator
from datasets.data_prefetcher import data_prefetcher

from util.misc import check_unused_parameters


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0, 
                    writer=None, total_iter=0):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    prefetcher = data_prefetcher(data_loader, device, prefetch=True)
    samples, targets = prefetcher.next()

    for i in metric_logger.log_every(range(len(data_loader)), print_freq, header):            
        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)
            
        optimizer.zero_grad()
        losses.backward()
        
        if i == 0:
            check_unused_parameters(model, loss_dict, weight_dict)
                
        if max_norm > 0:
            grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        else:
            grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
            
        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(grad_norm=grad_total_norm)
                    
        optimizer.step()

        if total_iter % (print_freq*10) == 0 and utils.is_main_process():
            writer.add_scalar('train/loss', loss_value, total_iter)
            writer.add_scalar('train/class_error', loss_dict_reduced['class_error'], total_iter)
            writer.add_scalar('lr', optimizer.param_groups[0]["lr"], total_iter)
            writer.add_scalar('train/grad_norm', grad_total_norm, total_iter)
            for key, value in loss_dict_reduced_scaled.items():
                writer.add_scalar('train/'+key, value, total_iter)
            for key, value in loss_dict_reduced_unscaled.items():
                if "corr" in key:
                    writer.add_scalar('train/'+key, value, total_iter)

        total_iter += 1
        samples, targets = prefetcher.next()

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, total_iter


@torch.no_grad()
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args):
    model.eval()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Test:'

    iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)

    panoptic_evaluator = None
    if 'panoptic' in postprocessors.keys():
        panoptic_evaluator = PanopticEvaluator(
            data_loader.dataset.ann_file,
            data_loader.dataset.ann_folder,
            output_dir=os.path.join(args.output_dir, "panoptic_eval"),
        )

    for step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, 10, header)):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)
        # computed unconditionally: the panoptic postprocessor below needs it too
        target_sizes = torch.stack([t["size"] for t in targets], dim=0)
        if 'segm' in postprocessors.keys():
            results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(res)

        if panoptic_evaluator is not None:
            res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
            for i, target in enumerate(targets):
                image_id = target["image_id"].item()
                file_name = f"{image_id:012d}.png"
                res_pano[i]["image_id"] = image_id
                res_pano[i]["file_name"] = file_name

            panoptic_evaluator.update(res_pano)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
    if panoptic_evaluator is not None:
        panoptic_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
    panoptic_res = None
    if panoptic_evaluator is not None:
        panoptic_res = panoptic_evaluator.summarize()
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if coco_evaluator is not None:
        if 'bbox' in postprocessors.keys():
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
        if 'segm' in postprocessors.keys():
            stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
    if panoptic_res is not None:
        stats['PQ_all'] = panoptic_res["All"]
        stats['PQ_th'] = panoptic_res["Things"]
        stats['PQ_st'] = panoptic_res["Stuff"]
    return stats, coco_evaluator
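

# Note (illustrative only, not part of the original file): coco_eval.stats
# follows pycocotools' fixed metric order, which is why main.py reads
# coco_eval_bbox[0:6] as AP, AP50, AP75, APs, APm and APl.
COCO_STATS_ORDER = ['AP', 'AP50', 'AP75', 'APs', 'APm', 'APl',
                    'AR@1', 'AR@10', 'AR@100', 'ARs', 'ARm', 'ARl']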


================================================
FILE: main.py
================================================
# ------------------------------------------------------------------------------------
# Sparse DETR
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------------------


import argparse
import datetime
import json
import random
import time
from tabulate import tabulate
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

import datasets
import util.misc as utils
import datasets.samplers as samplers
from datasets import build_dataset, get_coco_api_from_dataset
from engine import evaluate, train_one_epoch
from models import build_model
from util.benchmark import compute_fps, compute_gflops

from torch.utils.tensorboard import SummaryWriter


def get_args_parser():
    parser = argparse.ArgumentParser('Deformable DETR Detector', add_help=False)
    parser.add_argument('--lr', default=2e-4, type=float)
    parser.add_argument('--lr_backbone_names', default=["backbone.0"], type=str, nargs='+')
    parser.add_argument('--lr_backbone', default=2e-5, type=float)
    parser.add_argument('--lr_linear_proj_names', default=['reference_points', 'sampling_offsets'], type=str, nargs='+')
    parser.add_argument('--lr_linear_proj_mult', default=0.1, type=float)
    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument('--lr_drop', default=40, type=int)
    parser.add_argument('--lr_drop_epochs', default=None, type=int, nargs='+')
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')


    parser.add_argument('--sgd', action='store_true')

    # Variants of Deformable DETR
    parser.add_argument('--with_box_refine', default=False, action='store_true')
    parser.add_argument('--two_stage', default=False, action='store_true')

    # Model parameters
    parser.add_argument('--frozen_weights', type=str, default=None,
                        help="Path to the pretrained model. If set, only the mask head will be trained")

    # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")
    parser.add_argument('--position_embedding_scale', default=2 * np.pi, type=float,
                        help="position / size * scale")
    parser.add_argument('--num_feature_levels', default=4, type=int, help='number of feature levels')

    # * Modified architecture
    parser.add_argument('--backbone_from_scratch', default=False, action='store_true')
    parser.add_argument('--finetune_early_layers', default=False, action='store_true')
    parser.add_argument('--scrl_pretrained_path', default='', type=str)

    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=1024, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=300, type=int,
                        help="Number of query slots")
    parser.add_argument('--dec_n_points', default=4, type=int)
    parser.add_argument('--enc_n_points', default=4, type=int)
    
    # * Efficient DETR
    parser.add_argument('--eff_query_init', default=False, action='store_true')
    parser.add_argument('--eff_specific_head', default=False, action='store_true')

    # * Sparse DETR
    parser.add_argument('--use_enc_aux_loss', default=False, action='store_true')
    parser.add_argument('--rho', default=0., type=float)

    # * Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")

    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")

    # * Matcher
    parser.add_argument('--set_cost_class', default=2, type=float,
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_bbox', default=5, type=float,
                        help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_giou', default=2, type=float,
                        help="giou box coefficient in the matching cost")

    # * Loss coefficients
    parser.add_argument('--mask_loss_coef', default=1, type=float)
    parser.add_argument('--dice_loss_coef', default=1, type=float)
    parser.add_argument('--cls_loss_coef', default=2, type=float)
    parser.add_argument('--bbox_loss_coef', default=5, type=float)
    parser.add_argument('--giou_loss_coef', default=2, type=float)
    parser.add_argument('--mask_prediction_coef', default=1, type=float)
    parser.add_argument('--focal_alpha', default=0.25, type=float)

    # * dataset parameters
    parser.add_argument('--dataset_file', default='coco')
    parser.add_argument('--coco_path', default='./data/coco', type=str)
    parser.add_argument('--coco_panoptic_path', type=str)
    parser.add_argument('--remove_difficult', action='store_true')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--num_workers', default=2, type=int)
    parser.add_argument('--cache_mode', default=False, action='store_true', help='whether to cache images on memory')
    
    # * benchmark
    parser.add_argument('--approx_benchmark_only', default=False, action='store_true')
    parser.add_argument('--benchmark_only', default=False, action='store_true')
    parser.add_argument('--no_benchmark', dest='benchmark', action='store_false')

    return parser
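

# Usage sketch (illustrative only, not part of the original file): the Sparse
# DETR configs under configs/*.sh reduce to flag combinations like the one
# below; --rho sets the keeping ratio of encoder tokens.
def sparse_detr_args_sketch():
    parser = argparse.ArgumentParser('Sparse DETR', parents=[get_args_parser()])
    return parser.parse_args(['--with_box_refine', '--two_stage',
                              '--eff_query_init', '--eff_specific_head',
                              '--rho', '0.1', '--use_enc_aux_loss'])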


def main(args):
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)
    model_without_ddp = model
    
    dataset_val_org = build_dataset(image_set='val', args=args)
    
    if args.approx_benchmark_only or args.benchmark_only:
        assert not args.distributed and args.benchmark
    
    if utils.is_main_process() and args.benchmark:
        n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        if args.benchmark_only:
            gflops = compute_gflops(model, dataset_val_org, approximated=False)
        else:
            gflops = compute_gflops(model, dataset_val_org, approximated=True)
        fps = compute_fps(model, dataset_val_org, num_iters=20, batch_size=1)
        bfps = compute_fps(model, dataset_val_org, num_iters=20, batch_size=4)
        tab_keys = ["#Params(M)", "GFLOPs", "FPS", "B4FPS"]
        tab_vals = [n_params / 10 ** 6, gflops, fps, bfps]
        table = tabulate([tab_vals], headers=tab_keys, tablefmt="pipe",
                        floatfmt=".3f", stralign="center", numalign="center")
        print("===== Benchmark (Crude Approx.) =====\n" + table)
        
    if args.approx_benchmark_only or args.benchmark_only:
        import sys
        sys.exit(0)
            
    if args.distributed:
        # wait for benchmark in the main process
        torch.distributed.barrier()
        
    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = dataset_val_org

    if args.distributed:
        if args.cache_mode:
            sampler_train = samplers.NodeDistributedSampler(dataset_train)
            sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
        else:
            sampler_train = samplers.DistributedSampler(dataset_train)
            sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True)

    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers,
                                   pin_memory=True)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers,
                                 pin_memory=True)

    args = utils.scale_learning_rate(args)
    def match_name_keywords(n, name_keywords):
        return any(keyword in n for keyword in name_keywords)

    param_dicts = [
        {
            "params":
                [p for n, p in model_without_ddp.named_parameters()
                 if (not match_name_keywords(n, args.lr_backbone_names) 
                     and not match_name_keywords(n, args.lr_linear_proj_names) 
                     and p.requires_grad)],
            "lr": args.lr,
        },
        {
            "params": [p for n, p in model_without_ddp.named_parameters() 
                       if (match_name_keywords(n, args.lr_backbone_names) 
                           and not match_name_keywords(n, args.lr_linear_proj_names) 
                           and p.requires_grad)],
            "lr": args.lr_backbone,
        },
        {
            "params": [p for n, p in model_without_ddp.named_parameters() 
                       if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
            "lr": args.lr * args.lr_linear_proj_mult,
        }
    ]
    if args.sgd:
        optimizer = torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                      weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], 
                                                          find_unused_parameters=True)
        model_without_ddp = model.module

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
        unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
        if len(missing_keys) > 0:
            print('Missing Keys: {}'.format(missing_keys))
        if len(unexpected_keys) > 0:
            print('Unexpected Keys: {}'.format(unexpected_keys))
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            import copy
            p_groups = copy.deepcopy(optimizer.param_groups)
            optimizer.load_state_dict(checkpoint['optimizer'])
            for pg, pg_old in zip(optimizer.param_groups, p_groups):
                pg['lr'] = pg_old['lr']
                pg['initial_lr'] = pg_old['initial_lr']
            print(optimizer.param_groups)
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            # todo: this is a hack for doing experiment that resume from checkpoint 
            # and also modify lr scheduler (e.g., decrease lr in advance).
            args.override_resumed_lr_drop = True
            if args.override_resumed_lr_drop:
                print('Warning: (hack) args.override_resumed_lr_drop is set to True, '
                      'so args.lr_drop would override lr_drop in resumed lr_scheduler.')
                lr_scheduler.step_size = args.lr_drop
                lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
            lr_scheduler.step(lr_scheduler.last_epoch)
            args.start_epoch = checkpoint['epoch'] + 1
        # check the resumed model
        if not args.eval:
            test_stats, coco_evaluator = evaluate(
                model, criterion, postprocessors, data_loader_val, base_ds, device, args
            )
    
    if args.eval:
        print("Start evaluation")
        start_time = time.time()
        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device, args)
        if args.output_dir:
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
        print_final_result_on_master(model, dataset_val_org, args, test_stats, start_time)
        return

    if utils.is_main_process():
        writer = SummaryWriter(output_dir)
    else:
        writer = None
    total_iter = 0
    
    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats, total_iter = train_one_epoch(
            model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm, writer, total_iter)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every 5 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 5 == 0:
                checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }, checkpoint_path)

        test_stats, coco_evaluator = evaluate(
            model, criterion, postprocessors, data_loader_val, base_ds, device, args
        )

        # write test status
        if utils.is_main_process():
            writer.add_scalar('test/AP', test_stats['coco_eval_bbox'][0], epoch)
            writer.add_scalar('test/AP50', test_stats['coco_eval_bbox'][1], epoch)
            writer.add_scalar('test/AP75', test_stats['coco_eval_bbox'][2], epoch)
            writer.add_scalar('test/APs', test_stats['coco_eval_bbox'][3], epoch)
            writer.add_scalar('test/APm', test_stats['coco_eval_bbox'][4], epoch)
            writer.add_scalar('test/APl', test_stats['coco_eval_bbox'][5], epoch)
            writer.add_scalar('test/class_error', test_stats['class_error'], epoch)
            writer.add_scalar('test/loss', test_stats['loss'], epoch)
            writer.add_scalar('test/loss_ce', test_stats['loss_ce'], epoch)
            writer.add_scalar('test/loss_bbox', test_stats['loss_bbox'], epoch)
            writer.add_scalar('test/loss_giou', test_stats['loss_giou'], epoch)
            for key, value in test_stats.items():
                if "corr" in key:
                    writer.add_scalar('test/'+key, value, epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch}

        if args.output_dir and utils.is_main_process():
            if args.benchmark:
                log_stats.update({'params': n_params, 'gflops': gflops, 'fps': fps, 'bfps': bfps})

            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # for evaluation logs
            if coco_evaluator is not None:
                (output_dir / 'eval').mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ['latest.pth']
                    if epoch % 50 == 0:
                        filenames.append(f'{epoch:03}.pth')
                    for name in filenames:
                        torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                   output_dir / "eval" / name)
        
    print_final_result_on_master(model, dataset_val_org, args, test_stats, start_time)
    

def print_final_result_on_master(model, dataset_val, args, test_stats, start_time=None):
    if not utils.is_main_process():
        return
    
    # training wallclock-time / gpus-hours
    num_gpus = args.world_size if args.distributed else 1
    if start_time is not None:
        total_time = time.time() - start_time
        gpu_hours = total_time / 3600 * num_gpus
        gpu_hours_per_epoch = gpu_hours / args.epochs
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    else:
        total_time_str, gpu_hours, gpu_hours_per_epoch = ["N/A"] * 3

    # make result table
    now = datetime.datetime.now().strftime("%h%d %H:%M")
    tab_keys =  ["Time", "output_dir", "epochs", "bsz", "#GPUs"]
    tab_vals =  [now, Path(args.output_dir), args.epochs, int(args.batch_size * num_gpus), num_gpus]
    tab_keys += ["AP", "AP50", "AP75", "APs", "APm", "APl"]
    tab_vals += [v * 100 for v in test_stats['coco_eval_bbox'][:6]]

    tab_keys += ["E/T", "GPU*hrs", "GPU*hrs/ep"]
    tab_vals += [total_time_str, gpu_hours, gpu_hours_per_epoch]
    
    # add benchmark
    if args.benchmark:
        gflops = compute_gflops(model, dataset_val, approximated=False)
        fps = compute_fps(model, dataset_val, num_iters=300, batch_size=1)
        bfps = compute_fps(model, dataset_val, num_iters=300, batch_size=4)
        tab_keys += ['GFLOPs', 'FPS', 'B4FPS']
        tab_vals += [gflops, fps, bfps]
        
    table = tabulate([tab_vals], headers=tab_keys, tablefmt="pipe",
                     floatfmt=".3f", stralign="center", numalign="center")
    
    # dump to the file
    with open("log_result.txt", "a") as f:
        f.write("\n" + table + "\n")
            
    print(f"Save the final result to ./log_result.txt\n{table}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Sparse DETR training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)


================================================
FILE: models/__init__.py
================================================
# ------------------------------------------------------------------------------------
# Sparse DETR
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------------------


from .deformable_detr import build


def build_model(args):
    return build(args)



================================================
FILE: models/backbone.py
================================================
# ------------------------------------------------------------------------------------
# Sparse DETR
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------------------


"""
Backbone modules.
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from models import swin_transformer
from util.misc import NestedTensor, is_main_process

from .position_encoding import build_position_encoding


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n, eps=1e-5):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))
        self.eps = eps

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = self.eps
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias
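
# Sanity-check sketch (illustrative): the fused `x * scale + bias` above is
# algebraically the usual (x - running_mean) / sqrt(running_var + eps) * weight + bias,
# so a freshly built FrozenBatchNorm2d should match an eval-mode BatchNorm2d:
#
#   bn = nn.BatchNorm2d(8).eval()
#   fbn = FrozenBatchNorm2d(8)
#   fbn.load_state_dict(bn.state_dict())  # num_batches_tracked is dropped on load
#   x = torch.randn(2, 8, 4, 4)
#   assert torch.allclose(bn(x), fbn(x), atol=1e-5)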


class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool, args):
        # TODO: args -> duplicated args
        super().__init__()
        if 'none' in args.backbone:
            self.strides = [1]  # not actually used (only the length matters)
            self.num_channels = [3]
            return_layers = self.get_return_layers('identity', (0,))
            self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)

        elif 'resnet' in args.backbone:
            
            if not args.backbone_from_scratch and not args.finetune_early_layers:
                print("Freeze early layers.")
                for name, parameter in backbone.named_parameters():
                    if not train_backbone or all([k not in name for k in ['layer2', 'layer3', 'layer4']]):
                        parameter.requires_grad_(False)
            else:
                print('Finetune early layers as well.')
                    
            layer_name = "layer"
            if return_interm_layers:
                return_layers = self.get_return_layers(layer_name, (2, 3, 4))
                self.strides = [8, 16, 32]
                self.num_channels = [512, 1024, 2048]
            else:
                return_layers = self.get_return_layers(layer_name, (4,))
                self.strides = [32]
                self.num_channels = [2048]
            self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
                
        elif 'swin' in args.backbone:
            num_channels = [int(backbone.embed_dim * 2 ** i) for i in range(backbone.num_layers)]
            if return_interm_layers:
                return_layers = [2, 3, 4]
                self.strides = [8, 16, 32]
                self.num_channels = num_channels[1:]
            else:
                return_layers = [4]
                self.strides = [32]
                self.num_channels = [num_channels[-1]]  # wrap in a list to match the other branches
            self.body = backbone
                
        else:
            raise ValueError(f"Unknown backbone name: {args.backbone}")
        
    @staticmethod
    def get_return_layers(name: str, layer_ids):
        return {name + str(n): str(i) for i, n in enumerate(layer_ids)}
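    # e.g., get_return_layers('layer', (2, 3, 4)) == {'layer2': '0', 'layer3': '1', 'layer4': '2'},
    # the {source_module: output_key} mapping format expected by torchvision's IntermediateLayerGetter.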

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out
    
    
class DummyBackbone(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.identity0 = torch.nn.Identity()


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool,
                 args):
        print(f"Backbone: {name}")
        pretrained = is_main_process() and not args.backbone_from_scratch and not args.scrl_pretrained_path
        if not pretrained:
            print("Train backbone from scratch.")
        else:
            print("Load pretrained weights")
        
        if "none" in name:
            backbone = DummyBackbone()
        elif "resnet" in name:
            assert name not in ("resnet18", "resnet34"), "number of channels are hard coded"
            backbone = getattr(torchvision.models, name)(
                replace_stride_with_dilation=[False, False, dilation],
                pretrained=pretrained, norm_layer=FrozenBatchNorm2d)
        elif "swin" in name:
            assert not dilation, "not supported"
            if not args.backbone_from_scratch and not args.finetune_early_layers:
                print("Freeze early layers.")
                frozen_stages = 2
            else:
                print('Finetune early layers as well.')
                frozen_stages = -1
            if return_interm_layers:
                out_indices = [1, 2, 3]
            else:
                out_indices = [3]
                
            backbone = swin_transformer.build_model(
                name, out_indices=out_indices, frozen_stages=frozen_stages, pretrained=pretrained)
        else:
            raise ValueError(f"Unknown backbone name: {args.backbone}")
            
        if args.scrl_pretrained_path:
            assert "resnet" in name, "Currently only resnet50 is available."
            ckpt = torch.load(args.scrl_pretrained_path, map_location="cpu")
            translate_map = {
                "encoder.0" : "conv1",
                "encoder.1" : "bn1",
                "encoder.4" : "layer1",
                "encoder.5" : "layer2",
                "encoder.6" : "layer3",
                "encoder.7" : "layer4",
            }
            state_dict = {
                translate_map[k[:9]] + k[9:] : v
                for k, v in ckpt["online_network_state_dict"].items()
                if "encoder" in k
            }
            backbone.load_state_dict(state_dict, strict=False)
        
        super().__init__(backbone, train_backbone, return_interm_layers, args)
        if dilation and "resnet" in name:
            self.strides[-1] = self.strides[-1] // 2


class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)
        self.strides = backbone.strides
        self.num_channels = backbone.num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in sorted(xs.items()):
            out.append(x)

        # position encoding
        for x in out:
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos
    
    
def test_backbone(backbone):
    imgs = [
        torch.randn(2, 3, 633, 122),
        torch.randn(2, 3, 322, 532),
        torch.randn(2, 3, 236, 42),
    ]
    return [backbone(img).shape for img in imgs]


def build_backbone(args):
    # test_backbone(torchvision.models.resnet50())
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    return_interm_layers = args.masks or (args.num_feature_levels > 1)
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation, args)
    model = Joiner(backbone, position_embedding)
    return model


================================================
FILE: models/deformable_detr.py
================================================
# ------------------------------------------------------------------------------------
# Sparse DETR
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------------------


"""
Deformable DETR model and criterion classes.
"""
import torch
import torch.nn.functional as F
from torch import nn
import math

from util import box_ops
from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
                       accuracy, get_world_size, interpolate,
                       is_dist_avail_and_initialized, inverse_sigmoid)
from util.dam import idx_to_flat_grid, attn_map_to_flat_grid, compute_corr

from .backbone import build_backbone
from .matcher import build_matcher
from .segmentation import (DETRsegm, PostProcessPanoptic, PostProcessSegm,
                           dice_loss, sigmoid_focal_loss)
from .deformable_transformer import build_deforamble_transformer
import copy


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


class DeformableDETR(nn.Module):
    """ This is the Deformable DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
                 aux_loss=True, with_box_refine=False, two_stage=False, args=None):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_classes: number of object classes
            num_queries: number of object queries, i.e., detection slots. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            with_box_refine: iterative bounding box refinement
            two_stage: two-stage Deformable DETR
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, output_dim=4, num_layers=3)
        self.num_feature_levels = num_feature_levels
        if not two_stage:
            self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)
            # will be split into query_embed (query_pos) & tgt later
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.strides)
            input_proj_list = []
            for _ in range(num_backbone_outs):
                in_channels = backbone.num_channels[_]
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                )])
        self.backbone = backbone
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        self.two_stage = two_stage

        self.use_enc_aux_loss = args.use_enc_aux_loss
        self.rho = args.rho

        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)
 
        # hack implementation: a list of embedding heads (see the order)
        # n: dec_layers / m: enc_layers
        # [dec_0, dec_1, ..., dec_n-1, encoder, backbone, enc_0, enc_1, ..., enc_m-2]
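        # e.g., with dec_layers=6, two_stage=True and use_enc_aux_loss=True for
        # enc_layers=6 (illustrative numbers): num_pred = 6 + 1 + 5 = 12 heads.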
        
        # at each layer of decoder (by default)
        num_pred = transformer.decoder.num_layers
        if self.two_stage:
            # at the end of encoder
            num_pred += 1  
        if self.use_enc_aux_loss:
            # at each layer of encoder (excl. the last)
            num_pred += transformer.encoder.num_layers - 1  
        
        if with_box_refine or self.use_enc_aux_loss:
            # individual heads with the same initialization
            self.class_embed = _get_clones(self.class_embed, num_pred)
            self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
        else:
            # shared heads
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
            self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
            self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
            
        if two_stage:
            # hack implementation
            self.transformer.decoder.class_embed = self.class_embed
            self.transformer.decoder.bbox_embed = self.bbox_embed            
            for box_embed in self.transformer.decoder.bbox_embed:
                nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
                
        if self.use_enc_aux_loss:
            # the output of the last encoder layer is treated specially, as the input to the decoder
            num_layers_excluding_the_last = transformer.encoder.num_layers - 1
            self.transformer.encoder.aux_heads = True
            self.transformer.encoder.class_embed = self.class_embed[-num_layers_excluding_the_last:]
            self.transformer.encoder.bbox_embed = self.bbox_embed[-num_layers_excluding_the_last:] 
            for box_embed in self.transformer.encoder.bbox_embed:
                nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)

    def forward(self, samples: NestedTensor):
        """ The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits (including no-object) for all queries.
                                Shape= [batch_size x num_queries x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all queries, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
                                dictionnaries containing the two above keys for each decoder layer.
        """
        ###########
        # Backbone
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        srcs = []
        masks = []
        
        # multi-scale features projected from ~C5 with 1x1 conv
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
            
        # multi-scale features smaller than C5, projected with stride-2 3x3 convs
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    # feature scale 1/32 
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    # feature scale <1/64: recursively downsample the last projection
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)

        ###########
        # Transformer encoder & decoder
        query_embeds = None
        if not self.two_stage:
            query_embeds = self.query_embed.weight
        (hs, init_reference, inter_references, 
         enc_outputs_class, enc_outputs_coord_unact, 
         backbone_mask_prediction,
         enc_inter_outputs_class, enc_inter_outputs_coord, 
         sampling_locations_enc, attn_weights_enc, 
         sampling_locations_dec, attn_weights_dec,
         backbone_topk_proposals, spatial_shapes, level_start_index) = \
            self.transformer(srcs, masks, pos, query_embeds)

        ###########
        # Detection heads
        outputs_classes = []
        outputs_coords = []
        for lvl in range(len(hs)):
            # lvl: level of decoding layer
            outputs_class = self.class_embed[lvl](hs[lvl])
            outputs_coord = self.bbox_embed[lvl](hs[lvl])
            
            assert init_reference is not None and inter_references is not None
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            if reference.shape[-1] == 4:
                outputs_coord += reference
            else:
                assert reference.shape[-1] == 2
                outputs_coord[..., :2] += reference
            
            outputs_coord = outputs_coord.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)
            
        outputs_class = torch.stack(outputs_classes)
        outputs_coord = torch.stack(outputs_coords)

        # the topmost layer output
        out = {
            "pred_logits": outputs_class[-1],
            "pred_boxes": outputs_coord[-1],
            "sampling_locations_enc": sampling_locations_enc,
            "attn_weights_enc": attn_weights_enc,
            "sampling_locations_dec": sampling_locations_dec,
            "attn_weights_dec": attn_weights_dec,
            "spatial_shapes": spatial_shapes,
            "level_start_index": level_start_index,
        }
        if backbone_topk_proposals is not None:
            out["backbone_topk_proposals"] = backbone_topk_proposals
        
        if self.aux_loss:
            # compute losses for every intermediate layer (excluding the last one)
            out['aux_outputs'] = self._set_aux_loss(outputs_class[:-1], outputs_coord[:-1])

        if self.two_stage:
            enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
            out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord}

        if self.rho:
            out["backbone_mask_prediction"] = backbone_mask_prediction
            
        if self.use_enc_aux_loss:
            out['aux_outputs_enc'] = self._set_aux_loss(enc_inter_outputs_class, enc_inter_outputs_coord)
        
        if self.rho:
            out["sparse_token_nums"] = self.transformer.sparse_token_nums

        out['mask_flatten'] = torch.cat([m.flatten(1) for m in masks], 1)

        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class, outputs_coord)]


class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, losses, args):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            args: argument namespace; focal_alpha (alpha in Focal Loss) and eff_specific_head are read from it
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses

        self.focal_alpha = args.focal_alpha
        self.eff_specific_head = args.eff_specific_head

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

        target_classes_onehot = target_classes_onehot[:,:,:-1]
        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
        loss_ce = loss_ce * src_logits.shape[1]
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the masks: the focal loss and the dice loss.
           targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        src_masks = outputs["pred_masks"]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose()
        target_masks = target_masks.to(src_masks)

        src_masks = src_masks[src_idx]
        # upsample predictions to the target size
        src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
                                mode="bilinear", align_corners=False)
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks[tgt_idx].flatten(1)

        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
        }
        return losses
    
    def loss_mask_prediction(self, outputs, targets, indices, num_boxes, layer=None):
        assert "backbone_mask_prediction" in outputs
        assert "sampling_locations_dec" in outputs
        assert "attn_weights_dec" in outputs
        assert "spatial_shapes" in outputs
        assert "level_start_index" in outputs

        mask_prediction = outputs["backbone_mask_prediction"] 
        loss_key = "loss_mask_prediction"

        sampling_locations_dec = outputs["sampling_locations_dec"]
        attn_weights_dec = outputs["attn_weights_dec"]
        spatial_shapes = outputs["spatial_shapes"]
        level_start_index = outputs["level_start_index"]

        flat_grid_attn_map_dec = attn_map_to_flat_grid(
            spatial_shapes, level_start_index, sampling_locations_dec, attn_weights_dec).sum(dim=(1,2))

        losses = {}

        if 'mask_flatten' in outputs:
            flat_grid_attn_map_dec = flat_grid_attn_map_dec.masked_fill(
                outputs['mask_flatten'], flat_grid_attn_map_dec.min()-1)
                
        sparse_token_nums = outputs["sparse_token_nums"]
        num_topk = sparse_token_nums.max()

        topk_idx_tgt = torch.topk(flat_grid_attn_map_dec, num_topk)[1]
        target = torch.zeros_like(mask_prediction)
        for i in range(target.shape[0]):
            target[i].scatter_(0, topk_idx_tgt[i][:sparse_token_nums[i]], 1)

        losses.update({loss_key: F.multilabel_soft_margin_loss(mask_prediction, target)})

        return losses
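    # Target-construction sketch (illustrative numbers): with mask_prediction of
    # shape [B, L] and sparse_token_nums = [1301, 1201], each target[i] is a
    # multi-hot vector over the L flattened tokens with exactly
    # sparse_token_nums[i] ones at the decoder-attention top-k positions, so
    # target[0].sum() == 1301 and target[1].sum() == 1201.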

    @torch.no_grad()
    def corr(self, outputs, targets, indices, num_boxes):
        if "backbone_topk_proposals" not in outputs.keys():
            return {}

        assert "backbone_topk_proposals" in outputs
        assert "sampling_locations_dec" in outputs
        assert "attn_weights_dec" in outputs
        assert "spatial_shapes" in outputs
        assert "level_start_index" in outputs

        backbone_topk_proposals = outputs["backbone_topk_proposals"]
        sampling_locations_dec = outputs["sampling_locations_dec"]
        attn_weights_dec = outputs["attn_weights_dec"]
        spatial_shapes = outputs["spatial_shapes"]
        level_start_index = outputs["level_start_index"]

        flat_grid_topk = idx_to_flat_grid(spatial_shapes, backbone_topk_proposals)
        flat_grid_attn_map_dec = attn_map_to_flat_grid(
            spatial_shapes, level_start_index, sampling_locations_dec, attn_weights_dec).sum(dim=(1,2))
        corr = compute_corr(flat_grid_topk, flat_grid_attn_map_dec, spatial_shapes)

        losses = {}
        losses["corr_mask_attn_map_dec_all"] = corr[0].mean()
        for i, _corr in enumerate(corr[1:]):
            losses[f"corr_mask_attn_map_dec_{i}"] = _corr.mean()
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
            'masks': self.loss_masks,
            "mask_prediction": self.loss_mask_prediction,
            "corr": self.corr,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() 
                               if k not in ['aux_outputs', 'enc_outputs', 'backbone_outputs', 'mask_flatten']}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            kwargs = {}
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss in ['masks', "mask_prediction", "corr"]:
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs['log'] = False
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        if 'enc_outputs' in outputs:
            enc_outputs = outputs['enc_outputs']
            bin_targets = copy.deepcopy(targets)
            if not self.eff_specific_head:
                for bt in bin_targets:
                    bt['labels'] = torch.zeros_like(bt['labels'])  # all labels are zero (meaning foreground)
            indices = self.matcher(enc_outputs, bin_targets)
            for loss in self.losses:
                if loss in ['masks', "mask_prediction", "corr"]:
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                kwargs = {}
                if loss == 'labels':
                    # Logging is enabled only for the last layer
                    kwargs['log'] = False
                l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
                l_dict = {k + '_enc': v for k, v in l_dict.items()}
                losses.update(l_dict)

        if 'backbone_outputs' in outputs:
            backbone_outputs = outputs['backbone_outputs']
            bin_targets = copy.deepcopy(targets)
            if not self.eff_specific_head:
                for bt in bin_targets:
                    bt['labels'] = torch.zeros_like(bt['labels'])  # all labels are zero (meaning foreground)
            indices = self.matcher(backbone_outputs, bin_targets)
            for loss in self.losses:
                if loss in ['masks', "mask_prediction", "corr"]:
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                kwargs = {}
                if loss == 'labels':
                    # Logging is enabled only for the last layer
                    kwargs['log'] = False
                l_dict = self.get_loss(loss, backbone_outputs, bin_targets, indices, num_boxes, **kwargs)
                l_dict = {k + '_backbone': v for k, v in l_dict.items()}
                losses.update(l_dict)
                
        if 'aux_outputs_enc' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs_enc']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss in ['masks', "mask_prediction", "corr"]:
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs['log'] = False
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_enc_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses


class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""

    @torch.no_grad()
    def forward(self, outputs, target_sizes):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image in the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augmentation, but before padding
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = out_logits.sigmoid()
        topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
        scores = topk_values
        topk_boxes = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]

        return results
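
# Usage sketch (illustrative; assumes COCO-style targets carrying an
# "orig_size" field, as produced in datasets/coco.py):
#
#   postprocess = PostProcess()
#   orig_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)  # (h, w) rows
#   results = postprocess(outputs, orig_sizes)
#   # -> list of {'scores', 'labels', 'boxes'} dicts, boxes in absolute xyxy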


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
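
# e.g., MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3) is the
# box-regression head used above: Linear(256, 256) -> ReLU -> Linear(256, 256)
# -> ReLU -> Linear(256, 4).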


def build(args):
    num_classes = 20 if args.dataset_file != 'coco' else 91
    if args.dataset_file == "coco_panoptic":
        num_classes = 250
    device = torch.device(args.device)

    backbone = build_backbone(args)

    transformer = build_deforamble_transformer(args)
    model = DeformableDETR(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        num_feature_levels=args.num_feature_levels,
        aux_loss=args.aux_loss,
        with_box_refine=args.with_box_refine,
        two_stage=args.two_stage,
        args=args,
    )
    if args.masks:
        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
    matcher = build_matcher(args)
    weight_dict = {'loss_ce': args.cls_loss_coef, 'loss_bbox': args.bbox_loss_coef}
    weight_dict['loss_giou'] = args.giou_loss_coef

    if args.masks:
        weight_dict["loss_mask"] = args.mask_loss_coef
        weight_dict["loss_dice"] = args.dice_loss_coef
        
    # TODO this is a hack
    aux_weight_dict = {}
    
    if args.aux_loss:
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
            
    if args.two_stage:
        aux_weight_dict.update({k + '_enc': v for k, v in weight_dict.items()})
        
    if args.use_enc_aux_loss:
        for i in range(args.enc_layers - 1):
            aux_weight_dict.update({k + f'_enc_{i}': v for k, v in weight_dict.items()})
            
    if args.rho:
        aux_weight_dict.update({k + '_backbone': v for k, v in weight_dict.items()})
        
    if aux_weight_dict:
        weight_dict.update(aux_weight_dict)

    weight_dict['loss_mask_prediction'] = args.mask_prediction_coef
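    # Resulting keys sketch (illustrative, with dec_layers=6, enc_layers=6 and
    # two_stage / use_enc_aux_loss / rho all enabled): each base key such as
    # 'loss_ce' fans out to 'loss_ce_0'..'loss_ce_4' (decoder aux), 'loss_ce_enc'
    # (two-stage), 'loss_ce_enc_0'..'loss_ce_enc_4' (encoder aux) and
    # 'loss_ce_backbone', matching the suffixes appended in SetCriterion.forward.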

    losses = ['labels', 'boxes', 'cardinality', "corr"]
    if args.masks:
        losses += ["masks"]
    if args.rho:
        losses += ["mask_prediction"]
    
    # SetCriterion(num_classes, matcher, weight_dict, losses, args)
    criterion = SetCriterion(num_classes, matcher, weight_dict, losses, args)
    criterion.to(device)
    postprocessors = {'bbox': PostProcess()}
    if args.masks:
        postprocessors['segm'] = PostProcessSegm()
        if args.dataset_file == "coco_panoptic":
            is_thing_map = {i: i <= 90 for i in range(201)}
            postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)

    return model, criterion, postprocessors


================================================
FILE: models/deformable_transformer.py
================================================
# ------------------------------------------------------------------------------------
# Sparse DETR
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------------------


import copy
from typing import Optional, List
import math

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_

from util.misc import inverse_sigmoid
from models.ops.modules import MSDeformAttn


class DeformableTransformer(nn.Module):
    def __init__(self, d_model=256, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
                 activation="relu", return_intermediate_dec=False,
                 num_feature_levels=4, dec_n_points=4,  enc_n_points=4,
                 two_stage=False, two_stage_num_proposals=300,
                 args=None):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead
        self.two_stage = two_stage
        self.two_stage_num_proposals = two_stage_num_proposals
        self.eff_query_init = args.eff_query_init
        self.eff_specific_head = args.eff_specific_head
        # there's no need to compute reference points when the two flags above are both enabled
        self._log_args('eff_query_init', 'eff_specific_head')

        self.rho = args.rho
        self.use_enc_aux_loss = args.use_enc_aux_loss
        self.sparse_enc_head = 1 if self.two_stage and self.rho else 0

        if self.rho:
            self.enc_mask_predictor = MaskPredictor(self.d_model, self.d_model)
        else:
            self.enc_mask_predictor = None

        encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, dropout, activation, 
                                                            num_feature_levels, nhead, enc_n_points)
        self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers, self.d_model)

        decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
                                                          dropout, activation,
                                                          num_feature_levels, nhead, dec_n_points)
        self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec)

        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

        if self.two_stage:
            self.enc_output = nn.Linear(d_model, d_model)
            self.enc_output_norm = nn.LayerNorm(d_model)
            
        if self.two_stage:
            self.pos_trans = nn.Linear(d_model * 2, d_model * (1 if self.eff_query_init else 2))
            self.pos_trans_norm = nn.LayerNorm(d_model * (1 if self.eff_query_init else 2))
    
        if not self.two_stage:
            self.reference_points = nn.Linear(d_model, 2)

        self._reset_parameters()
        
    def _log_args(self, *names):
        print('==============')
        print("\n".join([f"{name}: {getattr(self, name)}" for name in names]))
        print('==============')

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if hasattr(self, 'reference_points'):
            xavier_uniform_(self.reference_points.weight.data, gain=1.0)
            constant_(self.reference_points.bias.data, 0.)
        normal_(self.level_embed)

    def get_proposal_pos_embed(self, proposals):
        # proposals: N, L(top_k), 4(bbox coords.)
        num_pos_feats = 128
        temperature = 10000
        scale = 2 * math.pi

        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)  # 128
        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
        proposals = proposals.sigmoid() * scale  # N, L, 4
        pos = proposals[:, :, :, None] / dim_t  # N, L, 4, 128
        # apply sin/cos alternatively
        pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4)  # N, L, 4, 64, 2
        pos = pos.flatten(2)  # N, L, 512 (4 x 128)
        return pos
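    # e.g., proposals of shape [N, 300, 4] yield pos of shape [N, 300, 512]:
    # 4 box coordinates x 128 sine/cosine features each, flattened together.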

    def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes, process_output=True):
        """Make region proposals for each multi-scale features considering their shapes and padding masks, 
        and project & normalize the encoder outputs corresponding to these proposals.
            - center points: relative grid coordinates in the range of [0.01, 0.99] (additional mask)
            - width/height:  2^(layer_id) * s (s=0.05) / see the appendix A.4
        
        Tensor shape example:
            Args:
                memory: torch.Size([2, 15060, 256])
                memory_padding_mask: torch.Size([2, 15060])
                spatial_shapes: torch.Size([4, 2])
            Returns:
                output_memory: torch.Size([2, 15060, 256])
                    - same shape with memory ( + additional mask + linear layer + layer norm )
                output_proposals: torch.Size([2, 15060, 4]) 
                    - x, y, w, h
        """
        N_, S_, C_ = memory.shape
        proposals = []
        _cur = 0
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            # level of encoded feature scale
            mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

            grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                                            torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

            scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
            grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
            wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
            proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
            proposals.append(proposal)
            _cur += (H_ * W_)
            
        output_proposals = torch.cat(proposals, 1)
        output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)  
        output_proposals = torch.log(output_proposals / (1 - output_proposals))  # inverse of sigmoid
        output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) 
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))  # sigmoid(inf) = 1

        output_memory = memory
        if process_output:
            output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
            output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
            output_memory = self.enc_output_norm(self.enc_output(output_memory))
        return output_memory, output_proposals, (~memory_padding_mask).sum(axis=-1)
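    # Proposal sketch (illustrative): proposals are kept in logit space, so
    # output_proposals.sigmoid() recovers (cx, cy, w, h); a token on level
    # lvl=2 gets w = h = 0.05 * 2**2 = 0.2, while padded / out-of-range tokens
    # are pushed to +inf, i.e. sigmoid(+inf) = 1 marks them as degenerate.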

    def get_valid_ratio(self, mask):
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio
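    # e.g., a [N, 4, 8] mask whose bottom two rows and rightmost four columns
    # are padding gives valid_H = 2, valid_W = 4, i.e. a per-image valid_ratio
    # of (w, h) = (0.5, 0.5).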

    def forward(self, srcs, masks, pos_embeds, query_embed=None):
        assert self.two_stage or query_embed is not None

        ###########
        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        # valid ratios across multi-scale features of the same image can vary,
        # since the padding masks are interpolated and binarized at different resolutions.
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        ###########
        # prepare for sparse encoder
        if self.rho or self.use_enc_aux_loss:
            backbone_output_memory, backbone_output_proposals, valid_token_nums = self.gen_encoder_output_proposals(
                src_flatten+lvl_pos_embed_flatten, mask_flatten, spatial_shapes, 
                process_output=bool(self.rho))
            self.valid_token_nums = valid_token_nums

        if self.rho:
            sparse_token_nums = (valid_token_nums * self.rho).int() + 1
            backbone_topk = int(max(sparse_token_nums))
            self.sparse_token_nums = sparse_token_nums

            backbone_topk = min(backbone_topk, backbone_output_memory.shape[1])

            backbone_mask_prediction = self.enc_mask_predictor(backbone_output_memory).squeeze(-1)
            # excluding pad area
            backbone_mask_prediction = backbone_mask_prediction.masked_fill(mask_flatten, backbone_mask_prediction.min())
            backbone_topk_proposals = torch.topk(backbone_mask_prediction, backbone_topk, dim=1)[1]
        else:
            backbone_topk_proposals = None
            backbone_outputs_class = None
            backbone_outputs_coord_unact = None
            sparse_token_nums = None
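        # Worked example for the rho branch above (illustrative): with rho=0.1
        # and per-image valid token counts [13000, 12000], sparse_token_nums =
        # [1301, 1201] and backbone_topk = 1301; roughly 10% of the encoder
        # tokens are selected for refinement, the rest pass through unchanged.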

        ###########
        # encoder
        if self.encoder:       
            output_proposals = backbone_output_proposals if self.use_enc_aux_loss else None    
            encoder_output = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, 
                                  pos=lvl_pos_embed_flatten, padding_mask=mask_flatten, 
                                  topk_inds=backbone_topk_proposals, output_proposals=output_proposals,
                                  sparse_token_nums=sparse_token_nums)
            
            memory, sampling_locations_enc, attn_weights_enc = encoder_output[:3]

            if self.use_enc_aux_loss:
                enc_inter_outputs_class, enc_inter_outputs_coord_unact = encoder_output[3:5]            
        else:
            memory = src_flatten + lvl_pos_embed_flatten

        ###########
        # prepare input for decoder
        bs, _, c = memory.shape  # torch.Size([N, L, 256])
        topk_proposals = None
        if self.two_stage:
            # finalize the first stage output
            # project & normalize the memory and make proposal bounding boxes on them
            output_memory, output_proposals, _ = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)

            # hack implementation for two-stage Deformable DETR (using the last layer registered in class/bbox_embed)
            # 1) a linear projection for bounding box binary classification (fore/background)
            enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
            # 2) 3-layer FFN for bounding box regression
            enc_outputs_coord_offset = self.decoder.bbox_embed[self.decoder.num_layers](output_memory)
            enc_outputs_coord_unact = output_proposals + enc_outputs_coord_offset  # appendix A.4

            # top scoring bounding boxes are picked as the final region proposals. 
            # these proposals are fed into the decoder as initial boxes for the iterative bounding box refinement.
            topk = self.two_stage_num_proposals
            # enc_outputs_class: torch.Size([N, L, 91])
            
            if self.eff_specific_head:
                # take the best score for judging objectness with class specific head
                enc_outputs_fg_class = enc_outputs_class.topk(1, dim=2).values[... , 0]
            else:
                # take the score from the binary (fore/background) classifier;
                # though the head has 91 output dims, only the 1st dim is used for the loss computation.
                enc_outputs_fg_class = enc_outputs_class[..., 0]
                
            topk_proposals = torch.topk(enc_outputs_fg_class, topk, dim=1)[1]
            topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()

            init_reference_out = reference_points
            # pos_embed -> linear layer -> layer norm
            pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
            
            if self.eff_query_init:
                # Efficient-DETR uses top-k memory as the initialization of `tgt` (query vectors)
                tgt = torch.gather(memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, memory.size(-1)))
                query_embed = pos_trans_out
            else:
                query_embed, tgt = torch.split(pos_trans_out, c, dim=2)

        else:
            query_embed, tgt = torch.split(query_embed, c, dim=1)
            query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
            tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
            reference_points = self.reference_points(query_embed).sigmoid()
            init_reference_out = reference_points

        ###########
        # decoder
        hs, inter_references, sampling_locations_dec, attn_weights_dec = self.decoder(tgt, reference_points, src=memory, src_spatial_shapes=spatial_shapes, 
                                            src_level_start_index=level_start_index, src_valid_ratios=valid_ratios, 
                                            query_pos=query_embed, src_padding_mask=mask_flatten,
                                            topk_inds=topk_proposals)

        inter_references_out = inter_references
        
        ret = []
        ret += [hs, init_reference_out, inter_references_out]
        ret += [enc_outputs_class, enc_outputs_coord_unact] if self.two_stage else [None] * 2        
        if self.rho:
            ret += [backbone_mask_prediction]
        else:
            ret += [None]
        ret += [enc_inter_outputs_class, enc_inter_outputs_coord_unact] if self.use_enc_aux_loss else [None] * 2
        ret += [sampling_locations_enc, attn_weights_enc, sampling_locations_dec, attn_weights_dec]
        ret += [backbone_topk_proposals, spatial_shapes, level_start_index]
        return ret


class DeformableTransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None, tgt=None):
        if tgt is None:
            # self attention
            src2, sampling_locations, attn_weights = self.self_attn(self.with_pos_embed(src, pos),
                                reference_points, src, spatial_shapes,
                                level_start_index, padding_mask)
            src = src + self.dropout1(src2)
            src = self.norm1(src)
            # torch.Size([2, 13101, 256])

            # ffn
            src = self.forward_ffn(src)

            return src, sampling_locations, attn_weights
        else:
            # self attention
            tgt2, sampling_locations, attn_weights = self.self_attn(self.with_pos_embed(tgt, pos),
                                reference_points, src, spatial_shapes,
                                level_start_index, padding_mask)
            tgt = tgt + self.dropout1(tgt2)
            tgt = self.norm1(tgt)

            # ffn
            tgt = self.forward_ffn(tgt)

            return tgt, sampling_locations, attn_weights



class DeformableTransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, mask_predictor_dim=256):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        # hack implementation
        self.aux_heads = False
        self.class_embed = None
        self.bbox_embed = None

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Make reference points for every single point on the multi-scale feature maps.
        Each point has K reference points on every the multi-scale features.
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):

            ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                                          torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            # points outside the valid (unpadded) region get relative coords larger than 1
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        # >>> reference_points[:, :, None].shape
        # torch.Size([2, 13101, 1, 2])
        # >>> valid_ratios[:, None].shape
        # torch.Size([2, 1, 4, 2])
        return reference_points

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, 
                pos=None, padding_mask=None, topk_inds=None, output_proposals=None, sparse_token_nums=None):
        if self.aux_heads:
            assert output_proposals is not None
        else:
            assert output_proposals is None
            
        output = src
        sparsified_keys = topk_inds is not None
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        reference_points_orig = reference_points
        pos_orig = pos
        output_proposals_orig = output_proposals
        sampling_locations_all = []
        attn_weights_all = []
        if self.aux_heads:
            enc_inter_outputs_class = []
            enc_inter_outputs_coords = []
                    
        if sparsified_keys:
            assert topk_inds is not None
            B_, N_, S_, P_ = reference_points.shape
            reference_points = torch.gather(reference_points.view(B_, N_, -1), 1, topk_inds.unsqueeze(-1).repeat(1, 1, S_*P_)).view(B_, -1, S_, P_)
            tgt = torch.gather(output, 1, topk_inds.unsqueeze(-1).repeat(1, 1, output.size(-1)))
            pos = torch.gather(pos, 1, topk_inds.unsqueeze(-1).repeat(1, 1, pos.size(-1)))
            if output_proposals is not None:
                output_proposals = output_proposals.gather(1, topk_inds.unsqueeze(-1).repeat(1, 1, output_proposals.size(-1)))
        else:
            tgt = None

        for lid, layer in enumerate(self.layers):
            # if tgt is None: self-attention / if tgt is not None: cross-attention w.r.t. the target queries
            tgt, sampling_locations, attn_weights = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask, 
                        tgt=tgt if sparsified_keys else None)
            sampling_locations_all.append(sampling_locations)
            attn_weights_all.append(attn_weights)
            if sparsified_keys:                
                if sparse_token_nums is None:
                    output = output.scatter(1, topk_inds.unsqueeze(-1).repeat(1, 1, tgt.size(-1)), tgt)
                else:
                    outputs = []
                    for i in range(topk_inds.shape[0]):
                        outputs.append(output[i].scatter(0, topk_inds[i][:sparse_token_nums[i]].unsqueeze(-1).repeat(1, tgt.size(-1)), tgt[i][:sparse_token_nums[i]]))
                    output = torch.stack(outputs)
            else:
                output = tgt
            
            if self.aux_heads and lid < self.num_layers - 1:
                # feed outputs to aux. heads
                output_class = self.class_embed[lid](tgt)
                output_offset = self.bbox_embed[lid](tgt)
                output_coords_unact = output_proposals + output_offset
                # values to be used for loss computation
                enc_inter_outputs_class.append(output_class)
                enc_inter_outputs_coords.append(output_coords_unact.sigmoid())

        # Change dimension from [num_layer, batch_size, ...] to [batch_size, num_layer, ...]
        sampling_locations_all = torch.stack(sampling_locations_all, dim=1)
        attn_weights_all = torch.stack(attn_weights_all, dim=1)
        
        ret = [output, sampling_locations_all, attn_weights_all]

        if self.aux_heads:
            ret += [enc_inter_outputs_class, enc_inter_outputs_coords]
        
        return ret
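

# Two minimal sketches (not part of the repository; shapes are illustrative) of what the
# encoder above does.
#
# (1) `get_reference_points` for a single level: pixel centers at (x + 0.5, y + 0.5),
# divided by (W, H), give per-point coordinates in (0, 1); the real method stacks these
# over levels and rescales by `valid_ratios`.
def _sketch_single_level_grid(H=4, W=6):
    import torch
    ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H - 0.5, H),
                                  torch.linspace(0.5, W - 0.5, W))
    ref = torch.stack((ref_x.reshape(-1) / W, ref_y.reshape(-1) / H), -1)  # (H*W, 2)
    assert ref.min() > 0 and ref.max() < 1
    return ref


# (2) the gather/scatter pattern of the sparsified forward pass: only the top-k tokens
# are refined by a layer, then scattered back into the dense sequence so the remaining
# tokens pass through unchanged.
def _sketch_sparse_token_update():
    import torch
    bs, n, c, k = 2, 12, 8, 4
    output = torch.randn(bs, n, c)
    topk_inds = torch.topk(torch.randn(bs, n), k, dim=1)[1]  # (bs, k)
    idx = topk_inds.unsqueeze(-1).repeat(1, 1, c)            # (bs, k, c)
    tgt = torch.gather(output, 1, idx)                       # sparse tokens
    tgt = tgt + 1.0                                          # stand-in for an encoder layer
    output = output.scatter(1, idx, tgt)                     # write refined tokens back
    assert output.shape == (bs, n, c)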


class DeformableTransformerDecoderLayer(nn.Module):
    def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # cross attention
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, 
                level_start_index, src_padding_mask=None):
        # self attention
        q = k = self.with_pos_embed(tgt, query_pos) 
        tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # cross attention
        assert reference_points is not None, "deformable attention needs reference points!"
        tgt2, sampling_locations, attn_weights = self.cross_attn(self.with_pos_embed(tgt, query_pos),
                                reference_points,
                                src, src_spatial_shapes, level_start_index, src_padding_mask)
            
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)
        # tgt shape here, e.g. torch.Size([2, 300, 256])

        return tgt, sampling_locations, attn_weights
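

# A minimal sketch (not part of the repository) of the transposes above:
# nn.MultiheadAttention with the default batch_first=False expects (seq, batch, dim),
# so the (batch, seq, dim) queries/keys/values are transposed in and the result is
# transposed back.
def _sketch_mha_transpose():
    import torch
    from torch import nn
    attn = nn.MultiheadAttention(embed_dim=256, num_heads=8)
    tgt = torch.randn(2, 300, 256)        # (batch, seq, dim)
    q = k = v = tgt.transpose(0, 1)       # (seq, batch, dim)
    out = attn(q, k, v)[0].transpose(0, 1)
    assert out.shape == tgt.shape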


class DeformableTransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
        self.bbox_embed = None
        self.class_embed = None



    def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, 
                src_valid_ratios, query_pos=None, src_padding_mask=None, topk_inds=None):
        """
        Args:
            tgt: torch.Size([2, 300, 256]) (query vectors)
            reference_points: torch.Size([2, 300, 2])
            src: torch.Size([2, 13101, 256]) (last MS feature map from the encoder)
            query_pos: torch.Size([2, 300, 256]) (learned positional embedding of query vectors)
            - `tgt` and `query_pos` originate from the same query embedding.
            - `tgt` changes through the forward pass as the object query vector,
               while `query_pos` stays fixed and is added as a positional embedding.
            
        Returns: (when return_intermediate=True)
            output: torch.Size([6, 2, 300, 256])
            reference_points: torch.Size([6, 2, 300, 2])
        """
        output = tgt

        intermediate = []
        intermediate_reference_points = []
        sampling_locations_all = []
        attn_weights_all = []
        for lid, layer in enumerate(self.layers):
            
            if reference_points is None:
                reference_points_input = None
            elif reference_points.shape[-1] == 4:
                # output from iterative bounding box refinement
                # reference_points: N, top_k, 4(x/y/w/h)
                # src_valid_ratios: N, num_feature_levels, 2(w/h)
                # reference_points_input: N, top_k, num_feature_levels, 4(x/y/w/h)
                reference_points_input = reference_points[:, :, None] \
                                        * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
                
            output, sampling_locations, attn_weights = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, 
                           src_level_start_index, src_padding_mask)
            sampling_locations_all.append(sampling_locations)
            attn_weights_all.append(attn_weights)

            # hack implementation for iterative bounding box refinement
            if self.bbox_embed is not None:
                assert reference_points is not None, "box refinement needs reference points!"
                tmp = self.bbox_embed[lid](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        # Change dimension from [num_layer, batch_size, ...] to [batch_size, num_layer, ...]
        sampling_locations_all = torch.stack(sampling_locations_all, dim=1)
        attn_weights_all = torch.stack(attn_weights_all, dim=1)

        if self.return_intermediate:
            intermediate_outputs = torch.stack(intermediate)
            if intermediate_reference_points[0] is None:
                intermediate_reference_points = None
            else:
                intermediate_reference_points = torch.stack(intermediate_reference_points)

            return intermediate_outputs, intermediate_reference_points, sampling_locations_all, attn_weights_all

        return output, reference_points, sampling_locations_all, attn_weights_all
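

# A minimal sketch (not part of the repository) of one iterative box-refinement step
# as performed above: the layer predicts an offset in logit space, which is added to
# the inverse-sigmoid of the current reference points and squashed back into (0, 1);
# detach() stops gradients from flowing across decoder layers through the references.
def _sketch_box_refinement():
    import torch

    def inv_sigmoid(x, eps=1e-5):  # local stand-in for util.misc.inverse_sigmoid
        x = x.clamp(min=eps, max=1 - eps)
        return torch.log(x / (1 - x))

    ref = torch.rand(2, 300, 4)            # current reference boxes in (0, 1)
    delta = 0.1 * torch.randn(2, 300, 4)   # stand-in for self.bbox_embed[lid](output)
    new_ref = (delta + inv_sigmoid(ref)).sigmoid().detach()
    assert ((new_ref > 0) & (new_ref < 1)).all()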


class MaskPredictor(nn.Module):
    def __init__(self, in_dim, h_dim):
        super().__init__()
        self.h_dim = h_dim
        self.layer1 = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, h_dim),
            nn.GELU()
        )
        self.layer2 = nn.Sequential(
            nn.Linear(h_dim, h_dim // 2),
            nn.GELU(),
            nn.Linear(h_dim // 2, h_dim // 4),
            nn.GELU(),
            nn.Linear(h_dim // 4, 1)
        )
    
    def forward(self, x):
        z = self.layer1(x)
        z_local, z_global = torch.split(z, self.h_dim // 2, dim=-1)
        z_global = z_global.mean(dim=1, keepdim=True).expand(-1, z_local.shape[1], -1)
        z = torch.cat([z_local, z_global], dim=-1)
        out = self.layer2(z)
        return out
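

# A minimal sketch (not part of the repository) of MaskPredictor's shape contract:
# the hidden features are split into a local half and a global half, the global half
# is mean-pooled over tokens and broadcast back, and one saliency logit is produced
# per token (used to select which tokens the encoder should refine).
def _sketch_mask_predictor():
    import torch
    m = MaskPredictor(in_dim=256, h_dim=256)
    x = torch.randn(2, 100, 256)   # (bs, num_tokens, in_dim)
    out = m(x)
    assert out.shape == (2, 100, 1)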
    

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


def build_deforamble_transformer(args):
    return DeformableTransformer(
        d_model=args.hidden_dim,
        nhead=args.nheads,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        activation="relu",
        return_intermediate_dec=True,
        num_feature_levels=args.num_feature_levels,
        dec_n_points=args.dec_n_points,
        enc_n_points=args.enc_n_points,
        two_stage=args.two_stage,
        two_stage_num_proposals=args.num_queries,
        args=args)


================================================
FILE: models/matcher.py
================================================
# ------------------------------------------------------------------------------------
# Sparse DETR
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------------------


"""
Modules to compute the matching cost and solve the corresponding LSAP.
"""
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn

from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self,
                 cost_class: float = 1,
                 cost_bbox: float = 1,
                 cost_giou: float = 1):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"

    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        with torch.no_grad():
            bs, num_queries = outputs["pred_logits"].shape[:2]

            # We flatten to compute the cost matrices in a batch
            out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

            # Also concat the target labels and boxes
            tgt_ids = torch.cat([v["labels"] for v in targets])
            tgt_bbox = torch.cat([v["boxes"] for v in targets])

            # Compute the classification cost.
            alpha = 0.25
            gamma = 2.0
            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

            # Compute the L1 cost between boxes
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

            # Compute the giou cost between boxes
            cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
                                             box_cxcywh_to_xyxy(tgt_bbox))

            # Final cost matrix
            C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
            C = C.view(bs, num_queries, -1).cpu()

            sizes = [len(v["boxes"]) for v in targets]
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor([_j % size for _j in j], dtype=torch.int64)) 
                    for (i, j), size in zip(indices, sizes)]


def build_matcher(args):
    return HungarianMatcher(cost_class=args.set_cost_class,
                            cost_bbox=args.set_cost_bbox,
                            cost_giou=args.set_cost_giou)
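

# A minimal sketch (not part of the repository) of the assignment step above on a toy
# cost matrix: scipy's linear_sum_assignment returns prediction/target index pairs that
# minimize the total cost, one pair per ground-truth box (rows are predictions, columns
# are targets, exactly as in HungarianMatcher.forward).
def _sketch_hungarian():
    import torch
    from scipy.optimize import linear_sum_assignment
    num_queries, num_targets = 5, 2
    C = torch.rand(num_queries, num_targets)
    i, j = linear_sum_assignment(C.numpy())
    assert len(i) == len(j) == min(num_queries, num_targets)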


================================================
FILE: models/ops/functions/__init__.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from .ms_deform_attn_func import MSDeformAttnFunction



================================================
FILE: models/ops/functions/ms_deform_attn_func.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable

import MultiScaleDeformableAttention as MSDA


class MSDeformAttnFunction(Function):
    @staticmethod
    def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = \
            MSDA.ms_deform_attn_backward(
                value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)

        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    # for debug and test only,
    # need to use cuda version instead
    N_, S_, M_, D_ = value.shape
    _, Lq_, M_, L_, P_, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
                                          mode='bilinear', padding_mode='zeros', align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
    return output.transpose(1, 2).contiguous()
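

# A minimal sketch (not part of the repository) of the shape contract above: for N
# batches, M heads of D channels each, L levels and P points, normalized sampling
# locations in (0, 1) are bilinearly sampled per level and combined with the attention
# weights into an (N, Lq, M*D) output.
def _sketch_core_pytorch():
    import torch
    shapes = torch.as_tensor([[8, 8], [4, 4]], dtype=torch.long)  # L = 2 levels
    N, M, D, Lq, P = 2, 8, 32, 10, 4
    S = int((shapes[:, 0] * shapes[:, 1]).sum())                  # total number of keys
    value = torch.randn(N, S, M, D)
    loc = torch.rand(N, Lq, M, len(shapes), P, 2)
    w = torch.rand(N, Lq, M, len(shapes), P)
    w = w / w.sum((-2, -1), keepdim=True)                         # weights sum to 1
    out = ms_deform_attn_core_pytorch(value, shapes, loc, w)
    assert out.shape == (N, Lq, M * D)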


================================================
FILE: models/ops/make.sh
================================================
#!/usr/bin/env bash
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

python setup.py build install


================================================
FILE: models/ops/modules/__init__.py
================================================
# ------------------------------------------------------------------------------------
# Sparse DETR
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from .ms_deform_attn import MSDeformAttn


================================================
FILE: models/ops/modules/ms_deform_attn.py
================================================
# ------------------------------------------------------------------------------------
# Sparse DETR
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import warnings
import math

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_

from ..functions import MSDeformAttnFunction


def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
    return (n & (n-1) == 0) and n != 0


class MSDeformAttn(nn.Module):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
        _d_per_head = d_model // n_heads
        # setting _d_per_head to a power of 2 makes the CUDA implementation more efficient
        if not _is_power_of_2(_d_per_head):
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head "
                          "a power of 2, which is more efficient in our CUDA implementation.")

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)
        self.python_ops_for_test = False

        self._reset_parameters()

    def _reset_parameters(self):
        constant_(self.sampling_offsets.weight.data, 0.)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.)
        constant_(self.attention_weights.bias.data, 0.)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
        """
        :param query                       (N, Length_{query}, C)
        :param reference_points            (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
                                        or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param input_flatten               (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
        :param input_spatial_shapes        (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        :param input_level_start_index     (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
        :param input_padding_mask          (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements

        :return output                     (N, Length_{query}, C)
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
        # N, Len_q, n_heads, n_levels, n_points, 2
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
        else:
            raise ValueError(
                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
        if not self.python_ops_for_test:
            output = MSDeformAttnFunction.apply(
                value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
        else:
            # pure-PyTorch fallback for tests/benchmarks (no compiled CUDA extension required)
            output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)
        return output, sampling_locations, attention_weights
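

# A minimal sketch (not part of the repository) of the sampling-location arithmetic in
# forward() above for 2-d reference points: offsets are predicted in "pixel" units and
# divided per level by (W, H), so the resulting sampling locations remain in normalized
# (0, 1) coordinates regardless of the feature-map resolution.
def _sketch_sampling_locations():
    import torch
    N, Lq, n_heads, n_levels, n_points = 2, 10, 4, 2, 2
    shapes = torch.as_tensor([[8, 8], [4, 4]], dtype=torch.long)    # (H, W) per level
    ref = torch.rand(N, Lq, n_levels, 2)
    offsets = torch.randn(N, Lq, n_heads, n_levels, n_points, 2)
    normalizer = torch.stack([shapes[..., 1], shapes[..., 0]], -1)  # (W, H) per level
    loc = ref[:, :, None, :, None, :] + offsets / normalizer[None, None, None, :, None, :]
    assert loc.shape == (N, Lq, n_heads, n_levels, n_points, 2)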


================================================
SYMBOL INDEX (341 symbols across 34 files)
================================================

FILE: datasets/__init__.py
  function get_coco_api_from_dataset (line 16) | def get_coco_api_from_dataset(dataset):
  function build_dataset (line 26) | def build_dataset(image_set, args):

FILE: datasets/coco.py
  class CocoDetection (line 26) | class CocoDetection(TvCocoDetection):
    method __init__ (line 27) | def __init__(self, img_folder, ann_file, transforms, return_masks, cac...
    method __getitem__ (line 33) | def __getitem__(self, idx):
  function convert_coco_poly_to_mask (line 43) | def convert_coco_poly_to_mask(segmentations, height, width):
  class ConvertCocoPolysToMask (line 60) | class ConvertCocoPolysToMask(object):
    method __init__ (line 61) | def __init__(self, return_masks=False):
    method __call__ (line 64) | def __call__(self, image, target):
  function make_coco_transforms (line 125) | def make_coco_transforms(image_set):
  function build (line 157) | def build(image_set, args):

FILE: datasets/coco_eval.py
  class CocoEvaluator (line 30) | class CocoEvaluator(object):
    method __init__ (line 31) | def __init__(self, coco_gt, iou_types):
    method update (line 44) | def update(self, predictions):
    method synchronize_between_processes (line 63) | def synchronize_between_processes(self):
    method accumulate (line 68) | def accumulate(self):
    method summarize (line 72) | def summarize(self):
    method prepare (line 77) | def prepare(self, predictions, iou_type):
    method prepare_for_coco_detection (line 87) | def prepare_for_coco_detection(self, predictions):
    method prepare_for_coco_segmentation (line 111) | def prepare_for_coco_segmentation(self, predictions):
    method prepare_for_coco_keypoint (line 146) | def prepare_for_coco_keypoint(self, predictions):
  function convert_to_xywh (line 173) | def convert_to_xywh(boxes):
  function merge (line 178) | def merge(img_ids, eval_imgs):
  function create_common_coco_eval (line 200) | def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
  function evaluate (line 216) | def evaluate(self):

FILE: datasets/coco_panoptic.py
  class CocoPanoptic (line 23) | class CocoPanoptic:
    method __init__ (line 24) | def __init__(self, img_folder, ann_folder, ann_file, transforms=None, ...
    method __getitem__ (line 42) | def __getitem__(self, idx):
    method __len__ (line 78) | def __len__(self):
    method get_height_and_width (line 81) | def get_height_and_width(self, idx):
  function build (line 88) | def build(image_set, args):

FILE: datasets/data_prefetcher.py
  function to_cuda (line 9) | def to_cuda(samples, targets, device):
  class data_prefetcher (line 14) | class data_prefetcher():
    method __init__ (line 15) | def __init__(self, loader, device, prefetch=True):
    method preload (line 23) | def preload(self):
    method next (line 51) | def next(self):

FILE: datasets/panoptic_eval.py
  class PanopticEvaluator (line 21) | class PanopticEvaluator(object):
    method __init__ (line 22) | def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"):
    method update (line 31) | def update(self, predictions):
    method synchronize_between_processes (line 38) | def synchronize_between_processes(self):
    method summarize (line 45) | def summarize(self):

FILE: datasets/samplers.py
  class DistributedSampler (line 16) | class DistributedSampler(Sampler):
    method __init__ (line 31) | def __init__(self, dataset, num_replicas=None, rank=None, local_rank=N...
    method __iter__ (line 48) | def __iter__(self):
    method __len__ (line 68) | def __len__(self):
    method set_epoch (line 71) | def set_epoch(self, epoch):
  class NodeDistributedSampler (line 75) | class NodeDistributedSampler(Sampler):
    method __init__ (line 90) | def __init__(self, dataset, num_replicas=None, rank=None, local_rank=N...
    method __iter__ (line 115) | def __iter__(self):
    method __len__ (line 135) | def __len__(self):
    method set_epoch (line 138) | def set_epoch(self, epoch):

FILE: datasets/torchvision_datasets/coco.py
  class CocoDetection (line 20) | class CocoDetection(VisionDataset):
    method __init__ (line 33) | def __init__(self, root, annFile, transform=None, target_transform=Non...
    method cache_images (line 46) | def cache_images(self):
    method get_image (line 55) | def get_image(self, path):
    method __getitem__ (line 63) | def __getitem__(self, index):
    method __len__ (line 83) | def __len__(self):

FILE: datasets/transforms.py
  function crop (line 24) | def crop(image, target, region):
  function hflip (line 67) | def hflip(image, target):
  function resize (line 84) | def resize(image, target, size, max_size=None):
  function pad (line 143) | def pad(image, target, padding):
  class RandomCrop (line 156) | class RandomCrop(object):
    method __init__ (line 157) | def __init__(self, size):
    method __call__ (line 160) | def __call__(self, img, target):
  class RandomSizeCrop (line 165) | class RandomSizeCrop(object):
    method __init__ (line 166) | def __init__(self, min_size: int, max_size: int):
    method __call__ (line 170) | def __call__(self, img: PIL.Image.Image, target: dict):
  class CenterCrop (line 177) | class CenterCrop(object):
    method __init__ (line 178) | def __init__(self, size):
    method __call__ (line 181) | def __call__(self, img, target):
  class RandomHorizontalFlip (line 189) | class RandomHorizontalFlip(object):
    method __init__ (line 190) | def __init__(self, p=0.5):
    method __call__ (line 193) | def __call__(self, img, target):
  class RandomResize (line 199) | class RandomResize(object):
    method __init__ (line 200) | def __init__(self, sizes, max_size=None):
    method __call__ (line 205) | def __call__(self, img, target=None):
  class RandomPad (line 210) | class RandomPad(object):
    method __init__ (line 211) | def __init__(self, max_pad):
    method __call__ (line 214) | def __call__(self, img, target):
  class RandomSelect (line 220) | class RandomSelect(object):
    method __init__ (line 225) | def __init__(self, transforms1, transforms2, p=0.5):
    method __call__ (line 230) | def __call__(self, img, target):
  class ToTensor (line 236) | class ToTensor(object):
    method __call__ (line 237) | def __call__(self, img, target):
  class RandomErasing (line 241) | class RandomErasing(object):
    method __init__ (line 243) | def __init__(self, *args, **kwargs):
    method __call__ (line 246) | def __call__(self, img, target):
  class Normalize (line 250) | class Normalize(object):
    method __init__ (line 251) | def __init__(self, mean, std):
    method __call__ (line 255) | def __call__(self, image, target=None):
  class Compose (line 269) | class Compose(object):
    method __init__ (line 270) | def __init__(self, transforms):
    method __call__ (line 273) | def __call__(self, image, target):
    method __repr__ (line 278) | def __repr__(self):

FILE: engine.py
  function train_one_epoch (line 31) | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
  function evaluate (line 107) | def evaluate(model, criterion, postprocessors, data_loader, base_ds, dev...

FILE: main.py
  function get_args_parser (line 37) | def get_args_parser():
  function main (line 156) | def main(args):
  function print_final_result_on_master (line 403) | def print_final_result_on_master(model, dataset_val, args, test_stats, s...

FILE: models/__init__.py
  function build_model (line 17) | def build_model(args):

FILE: models/backbone.py
  class FrozenBatchNorm2d (line 32) | class FrozenBatchNorm2d(torch.nn.Module):
    method __init__ (line 41) | def __init__(self, n, eps=1e-5):
    method _load_from_state_dict (line 49) | def _load_from_state_dict(self, state_dict, prefix, local_metadata, st...
    method forward (line 59) | def forward(self, x):
  class BackboneBase (line 72) | class BackboneBase(nn.Module):
    method __init__ (line 74) | def __init__(self, backbone: nn.Module, train_backbone: bool, return_i...
    method get_return_layers (line 120) | def get_return_layers(name: str, layer_ids):
    method forward (line 123) | def forward(self, tensor_list: NestedTensor):
  class DummyBackbone (line 134) | class DummyBackbone(torch.nn.Module):
    method __init__ (line 135) | def __init__(self):
  class Backbone (line 140) | class Backbone(BackboneBase):
    method __init__ (line 142) | def __init__(self, name: str,
  class Joiner (line 202) | class Joiner(nn.Sequential):
    method __init__ (line 203) | def __init__(self, backbone, position_embedding):
    method forward (line 208) | def forward(self, tensor_list: NestedTensor):
  function test_backbone (line 222) | def test_backbone(backbone):
  function build_backbone (line 231) | def build_backbone(args):

FILE: models/deformable_detr.py
  function _get_clones (line 36) | def _get_clones(module, N):
  class DeformableDETR (line 40) | class DeformableDETR(nn.Module):
    method __init__ (line 42) | def __init__(self, backbone, transformer, num_classes, num_queries, nu...
    method forward (line 144) | def forward(self, samples: NestedTensor):
    method _set_aux_loss (line 270) | def _set_aux_loss(self, outputs_class, outputs_coord):
  class SetCriterion (line 278) | class SetCriterion(nn.Module):
    method __init__ (line 284) | def __init__(self, num_classes, matcher, weight_dict, losses, args):
    method loss_labels (line 302) | def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
    method loss_cardinality (line 330) | def loss_cardinality(self, outputs, targets, indices, num_boxes):
    method loss_boxes (line 343) | def loss_boxes(self, outputs, targets, indices, num_boxes):
    method loss_masks (line 364) | def loss_masks(self, outputs, targets, indices, num_boxes):
    method loss_mask_prediction (line 393) | def loss_mask_prediction(self, outputs, targets, indices, num_boxes, l...
    method corr (line 430) | def corr(self, outputs, targets, indices, num_boxes):
    method _get_src_permutation_idx (line 457) | def _get_src_permutation_idx(self, indices):
    method _get_tgt_permutation_idx (line 463) | def _get_tgt_permutation_idx(self, indices):
    method get_loss (line 469) | def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
    method forward (line 481) | def forward(self, outputs, targets):
  class PostProcess (line 579) | class PostProcess(nn.Module):
    method forward (line 583) | def forward(self, outputs, target_sizes):
  class MLP (line 614) | class MLP(nn.Module):
    method __init__ (line 617) | def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
    method forward (line 623) | def forward(self, x):
  function build (line 629) | def build(args):

FILE: models/deformable_transformer.py
  class DeformableTransformer (line 27) | class DeformableTransformer(nn.Module):
    method __init__ (line 28) | def __init__(self, d_model=256, nhead=8,
    method _log_args (line 78) | def _log_args(self, *names):
    method _reset_parameters (line 83) | def _reset_parameters(self):
    method get_proposal_pos_embed (line 95) | def get_proposal_pos_embed(self, proposals):
    method gen_encoder_output_proposals (line 110) | def gen_encoder_output_proposals(self, memory, memory_padding_mask, sp...
    method get_valid_ratio (line 160) | def get_valid_ratio(self, mask):
    method forward (line 169) | def forward(self, srcs, masks, pos_embeds, query_embed=None):
  class DeformableTransformerEncoderLayer (line 313) | class DeformableTransformerEncoderLayer(nn.Module):
    method __init__ (line 314) | def __init__(self,
    method with_pos_embed (line 334) | def with_pos_embed(tensor, pos):
    method forward_ffn (line 337) | def forward_ffn(self, src):
    method forward (line 343) | def forward(self, src, pos, reference_points, spatial_shapes, level_st...
  class DeformableTransformerEncoder (line 372) | class DeformableTransformerEncoder(nn.Module):
    method __init__ (line 373) | def __init__(self, encoder_layer, num_layers, mask_predictor_dim=256):
    method get_reference_points (line 383) | def get_reference_points(spatial_shapes, valid_ratios, device):
    method forward (line 405) | def forward(self, src, spatial_shapes, level_start_index, valid_ratios,
  class DeformableTransformerDecoderLayer (line 473) | class DeformableTransformerDecoderLayer(nn.Module):
    method __init__ (line 474) | def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="r...
    method with_pos_embed (line 497) | def with_pos_embed(tensor, pos):
    method forward_ffn (line 500) | def forward_ffn(self, tgt):
    method forward (line 506) | def forward(self, tgt, query_pos, reference_points, src, src_spatial_s...
  class DeformableTransformerDecoder (line 530) | class DeformableTransformerDecoder(nn.Module):
    method __init__ (line 531) | def __init__(self, decoder_layer, num_layers, return_intermediate=False):
    method forward (line 542) | def forward(self, tgt, reference_points, src, src_spatial_shapes, src_...
  class MaskPredictor (line 618) | class MaskPredictor(nn.Module):
    method __init__ (line 619) | def __init__(self, in_dim, h_dim):
    method forward (line 635) | def forward(self, x):
  function _get_clones (line 644) | def _get_clones(module, N):
  function _get_activation_fn (line 648) | def _get_activation_fn(activation):
  function build_deforamble_transformer (line 659) | def build_deforamble_transformer(args):

FILE: models/matcher.py
  class HungarianMatcher (line 24) | class HungarianMatcher(nn.Module):
    method __init__ (line 32) | def __init__(self,
    method forward (line 49) | def forward(self, outputs, targets):
  function build_matcher (line 104) | def build_matcher(args):

FILE: models/ops/functions/ms_deform_attn_func.py
  class MSDeformAttnFunction (line 21) | class MSDeformAttnFunction(Function):
    method forward (line 23) | def forward(ctx, value, value_spatial_shapes, value_level_start_index,...
    method backward (line 32) | def backward(ctx, grad_output):
  function ms_deform_attn_core_pytorch (line 41) | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_lo...

FILE: models/ops/modules/ms_deform_attn.py
  function _is_power_of_2 (line 27) | def _is_power_of_2(n):
  class MSDeformAttn (line 33) | class MSDeformAttn(nn.Module):
    method __init__ (line 34) | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
    method _reset_parameters (line 66) | def _reset_parameters(self):
    method forward (line 82) | def forward(self, query, reference_points, input_flatten, input_spatia...
  function ms_deform_attn_core_pytorch (line 125) | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_lo...

FILE: models/ops/setup.py
  function get_extensions (line 23) | def get_extensions():

FILE: models/ops/src/cpu/ms_deform_attn_cpu.cpp
  function ms_deform_attn_cpu_forward (line 17) | at::Tensor
  function ms_deform_attn_cpu_backward (line 29) | std::vector<at::Tensor>

FILE: models/ops/src/ms_deform_attn.h
  function im2col_step (line 27) | int im2col_step)

FILE: models/ops/src/vision.cpp
  function PYBIND11_MODULE (line 13) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: models/ops/test.py
  function check_forward_equal_with_pytorch_double (line 32) | def check_forward_equal_with_pytorch_double():
  function check_forward_equal_with_pytorch_float (line 48) | def check_forward_equal_with_pytorch_float():
  function check_gradient_numerical (line 63) | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_...

FILE: models/position_encoding.py
  class PositionEmbeddingSine (line 24) | class PositionEmbeddingSine(nn.Module):
    method __init__ (line 29) | def __init__(self, num_pos_feats=64, temperature=10000, normalize=Fals...
    method forward (line 40) | def forward(self, tensor_list: NestedTensor):
  class PositionEmbeddingLearned (line 63) | class PositionEmbeddingLearned(nn.Module):
    method __init__ (line 67) | def __init__(self, num_pos_feats=256):
    method reset_parameters (line 73) | def reset_parameters(self):
    method forward (line 77) | def forward(self, tensor_list: NestedTensor):
  function build_position_encoding (line 91) | def build_position_encoding(args):

FILE: models/segmentation.py
  class DETRsegm (line 34) | class DETRsegm(nn.Module):
    method __init__ (line 35) | def __init__(self, detr, freeze_detr=False):
    method forward (line 47) | def forward(self, samples: NestedTensor):
  class MaskHeadSmallConv (line 76) | class MaskHeadSmallConv(nn.Module):
    method __init__ (line 82) | def __init__(self, dim, fpn_dims, context_dim):
    method forward (line 109) | def forward(self, x, bbox_mask, fpns):
  class MHAttentionMap (line 150) | class MHAttentionMap(nn.Module):
    method __init__ (line 153) | def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=T...
    method forward (line 168) | def forward(self, q, k, mask=None):
  function dice_loss (line 182) | def dice_loss(inputs, targets, num_boxes):
  function sigmoid_focal_loss (line 200) | def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, ...
  class PostProcessSegm (line 229) | class PostProcessSegm(nn.Module):
    method __init__ (line 230) | def __init__(self, threshold=0.5):
    method forward (line 235) | def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
  class PostProcessPanoptic (line 252) | class PostProcessPanoptic(nn.Module):
    method __init__ (line 256) | def __init__(self, is_thing_map, threshold=0.85):
    method forward (line 267) | def forward(self, outputs, processed_sizes, target_sizes=None):

FILE: models/swin_transformer/build.py
  function build_model (line 31) | def build_model(name, out_indices, frozen_stages, pretrained):
  function _update_dict (line 67) | def _update_dict(tar, src):
  function load_config_yaml (line 77) | def load_config_yaml(cfg_file, config=None):

FILE: models/swin_transformer/config.py
  class Config (line 18) | class Config(SimpleNamespace):
    method __init__ (line 48) | def __init__(self, _dict=None, **kwargs):
    method _set_with_nested_dict (line 55) | def _set_with_nested_dict(self, _dict):
    method freezed (line 64) | def freezed(self):
    method from_yaml (line 68) | def from_yaml(cls, yaml_file):
    method __repr__ (line 73) | def __repr__(self):
    method __getitem__ (line 76) | def __getitem__(self, item):
    method __getattr__ (line 79) | def __getattr__(self, item):
    method __setattr__ (line 91) | def __setattr__(self, item, value):
    method __bool__ (line 96) | def __bool__(self):
    method __len__ (line 100) | def __len__(self):
    method __getstate__ (line 103) | def __getstate__(self):
    method __setstate__ (line 106) | def __setstate__(self, state):
    method __contains__ (line 109) | def __contains__(self, item):
    method __deepcopy__ (line 112) | def __deepcopy__(self, memodict={}):
    method __iter__ (line 115) | def __iter__(self):
    method pformat (line 119) | def pformat(self):
    method pprint (line 123) | def pprint(self):
    method freeze (line 126) | def freeze(self):
    method defrost (line 134) | def defrost(self):
    method get (line 141) | def get(self, *args, **kwargs):
    method keys (line 144) | def keys(self):
    method values (line 147) | def values(self):
    method items (line 150) | def items(self):
    method clone (line 153) | def clone(self):
    method update (line 156) | def update(self, dict_, delimiter='/'):
    method _update (line 160) | def _update(self, key, value, delimiter='/'):
    method to_dict (line 167) | def to_dict(self):

FILE: models/swin_transformer/swin_transformer.py
  class Mlp (line 21) | class Mlp(nn.Module):
    method __init__ (line 22) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 31) | def forward(self, x):
  function window_partition (line 40) | def window_partition(x, window_size):
  function window_reverse (line 55) | def window_reverse(windows, window_size, H, W):
  class WindowAttention (line 72) | class WindowAttention(nn.Module):
    method __init__ (line 85) | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scal...
    method forward (line 119) | def forward(self, x, mask=None):
  class SwinTransformerBlock (line 153) | class SwinTransformerBlock(nn.Module):
    method __init__ (line 170) | def __init__(self, dim, num_heads, window_size=7, shift_size=0,
    method forward (line 194) | def forward(self, x, mask_matrix):
  class PatchMerging (line 253) | class PatchMerging(nn.Module):
    method __init__ (line 259) | def __init__(self, dim, norm_layer=nn.LayerNorm):
    method forward (line 265) | def forward(self, x, H, W):
  class BasicLayer (line 294) | class BasicLayer(nn.Module):
    method __init__ (line 312) | def __init__(self,
    method forward (line 354) | def forward(self, x, H, W):
  class PatchEmbed (line 396) | class PatchEmbed(nn.Module):
    method __init__ (line 405) | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=...
    method forward (line 419) | def forward(self, x):
  class SwinTransformer (line 438) | class SwinTransformer(nn.Module):
    method __init__ (line 466) | def __init__(self,
    method _freeze_stages (line 546) | def _freeze_stages(self):
    method _init_weights (line 563) | def _init_weights(self, m):
    method forward (line 572) | def forward(self, x):

FILE: tools/launch.py
  function parse_args (line 116) | def parse_args():
  function main (line 159) | def main():

FILE: util/benchmark.py
  function measure_average_inference_time (line 17) | def measure_average_inference_time(model, inputs, num_iters=100, warm_it...
  function python_ops_mode_for_deform_attn (line 31) | def python_ops_mode_for_deform_attn(model, ops_mode):
  function compute_fps (line 39) | def compute_fps(model, dataset, num_iters=300, warm_iters=5, batch_size=4):
  function compute_gflops (line 54) | def compute_gflops(model, dataset, approximated=True):
  function flop_count_without_warnings (line 93) | def flop_count_without_warnings(

FILE: util/box_ops.py
  function box_cxcywh_to_xyxy (line 17) | def box_cxcywh_to_xyxy(x):
  function box_xyxy_to_cxcywh (line 24) | def box_xyxy_to_cxcywh(x):
  function box_iou (line 32) | def box_iou(boxes1, boxes2):
  function generalized_box_iou (line 48) | def generalized_box_iou(boxes1, boxes2):
  function masks_to_boxes (line 72) | def masks_to_boxes(masks):

FILE: util/dam.py
  function idx_to_flat_grid (line 21) | def idx_to_flat_grid(spatial_shapes, idx):
  function attn_map_to_flat_grid (line 29) | def attn_map_to_flat_grid(spatial_shapes, level_start_index, sampling_lo...
  function compute_corr (line 70) | def compute_corr(flat_grid_topk, flat_grid_attn_map, spatial_shapes):

FILE: util/misc.py
  function _check_size_scale_factor (line 39) | def _check_size_scale_factor(dim, size, scale_factor):
  function _output_size (line 50) | def _output_size(dim, input, size, scale_factor):
  class SmoothedValue (line 68) | class SmoothedValue(object):
    method __init__ (line 73) | def __init__(self, window_size=20, fmt=None):
    method update (line 81) | def update(self, value, n=1):
    method synchronize_between_processes (line 86) | def synchronize_between_processes(self):
    method median (line 100) | def median(self):
    method avg (line 105) | def avg(self):
    method global_avg (line 110) | def global_avg(self):
    method max (line 114) | def max(self):
    method value (line 118) | def value(self):
    method __str__ (line 121) | def __str__(self):
  function unwrap (line 130) | def unwrap(wrapped_module):
  function check_unused_parameters (line 138) | def check_unused_parameters(model, loss_dict, weight_dict):
  function all_gather (line 152) | def all_gather(data):
  function reduce_dict (line 195) | def reduce_dict(input_dict, average=True):
  class MetricLogger (line 222) | class MetricLogger(object):
    method __init__ (line 223) | def __init__(self, delimiter="\t"):
    method update (line 227) | def update(self, **kwargs):
    method __getattr__ (line 234) | def __getattr__(self, attr):
    method __str__ (line 242) | def __str__(self):
    method synchronize_between_processes (line 250) | def synchronize_between_processes(self):
    method add_meter (line 254) | def add_meter(self, name, meter):
    method log_every (line 257) | def log_every(self, iterable, print_freq, header=None):
  function get_sha (line 312) | def get_sha():
  function collate_fn (line 332) | def collate_fn(batch):
  function _max_by_axis (line 338) | def _max_by_axis(the_list):
  function nested_tensor_from_tensor_list (line 347) | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
  class NestedTensor (line 367) | class NestedTensor(object):
    method __init__ (line 368) | def __init__(self, tensors, mask: Optional[Tensor]):
    method to (line 372) | def to(self, device, non_blocking=False):
    method record_stream (line 383) | def record_stream(self, *args, **kwargs):
    method decompose (line 388) | def decompose(self):
    method __repr__ (line 391) | def __repr__(self):
  function setup_for_distributed (line 395) | def setup_for_distributed(is_master):
  function is_dist_avail_and_initialized (line 410) | def is_dist_avail_and_initialized():
  function get_world_size (line 418) | def get_world_size():
  function get_rank (line 424) | def get_rank():
  function get_local_size (line 430) | def get_local_size():
  function get_local_rank (line 436) | def get_local_rank():
  function is_main_process (line 442) | def is_main_process():
  function save_on_master (line 446) | def save_on_master(*args, **kwargs):
  function _check_if_valid_ip (line 450) | def _check_if_valid_ip(ip):
  function _maybe_gethostbyname (line 460) | def _maybe_gethostbyname(addr):
  function init_distributed_mode (line 483) | def init_distributed_mode(args):
  function accuracy (line 527) | def accuracy(output, target, topk=(1,)):
  function interpolate (line 545) | def interpolate(input, size=None, scale_factor=None, mode="nearest", ali...
  function get_total_grad_norm (line 567) | def get_total_grad_norm(parameters, norm_type=2):
  function inverse_sigmoid (line 576) | def inverse_sigmoid(x, eps=1e-5):
  function scale_learning_rate (line 583) | def scale_learning_rate(args):
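
Of these, inverse_sigmoid deserves a note: decoder reference points live in sigmoid space, and each refinement step maps them back with a clamped logit so the log stays finite near 0 and 1. In Deformable DETR-derived code this helper conventionally reads:

import torch

def inverse_sigmoid(x, eps=1e-5):
    # Clamped logit: inverse of sigmoid, numerically safe at the boundaries.
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)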

FILE: util/plot_utils.py
  function plot_logs (line 21) | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'),...
  function plot_precision_recall (line 77) | def plot_precision_recall(files, naming_scheme='iter'):
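
Both plotting utilities consume DETR-style training logs, where main.py appends one JSON object per epoch to log.txt in each experiment directory; pandas can load that format directly. A small sketch (load_log is a hypothetical helper name, not part of the file):

from pathlib import Path
import pandas as pd

def load_log(exp_dir):
    # Each line of log.txt is a JSON dict of epoch-level metrics.
    return pd.read_json(Path(exp_dir) / 'log.txt', lines=True)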

================================================
FILE SIZES
================================================
Per-file character counts, as reported by the extraction:

LICENSE: 12512
NOTICE: 3358
README.md: 12069
configs/r50_deformable_detr.sh: 143
configs/r50_efficient_detr.sh: 233
configs/r50_sparse_detr_rho_0.1.sh: 275
configs/r50_sparse_detr_rho_0.2.sh: 275
configs/r50_sparse_detr_rho_0.3.sh: 275
configs/swint_deformable_detr.sh: 169
configs/swint_efficient_detr.sh: 259
configs/swint_sparse_detr_rho_0.1.sh: 301
configs/swint_sparse_detr_rho_0.2.sh: 301
configs/swint_sparse_detr_rho_0.3.sh: 301
datasets/__init__.py: 1341
datasets/coco.py: 6044
datasets/coco_eval.py: 9171
datasets/coco_panoptic.py: 4159
datasets/data_prefetcher.py: 3085
datasets/panoptic_eval.py: 1929
datasets/samplers.py: 5608
datasets/torchvision_datasets/__init__.py: 329
datasets/torchvision_datasets/coco.py: 3285
datasets/transforms.py: 8955
engine.py: 8624
main.py: 21407
models/__init__.py: 852
models/backbone.py: 9312
models/deformable_detr.py: 32307
models/deformable_transformer.py: 31649
models/matcher.py: 5288
models/ops/functions/__init__.py: 598
models/ops/functions/ms_deform_attn_func.py: 3298
models/ops/make.sh: 593
models/ops/modules/__init__.py: 739
models/ops/modules/ms_deform_attn.py: 8087
models/ops/setup.py: 2559
models/ops/src/cpu/ms_deform_attn_cpu.cpp: 1256
models/ops/src/cpu/ms_deform_attn_cpu.h: 1139
models/ops/src/cuda/ms_deform_attn_cuda.cu: 7316
models/ops/src/cuda/ms_deform_attn_cuda.h: 1140
models/ops/src/cuda/ms_deform_im2col_cuda.cuh: 54694
models/ops/src/ms_deform_attn.h: 1838
models/ops/src/vision.cpp: 799
models/ops/test.py: 4087
models/position_encoding.py: 4049
models/segmentation.py: 16270
models/swin_transformer/__init__.py: 350
models/swin_transformer/build.py: 3370
models/swin_transformer/config.py: 5585
models/swin_transformer/configs/default.yaml: 346
models/swin_transformer/configs/swin_base_patch4_window7_224.yaml: 208
models/swin_transformer/configs/swin_large_patch4_window7_224.yaml: 188
models/swin_transformer/configs/swin_small_patch4_window7_224.yaml: 208
models/swin_transformer/configs/swin_tiny_patch4_window7_224.yaml: 206
models/swin_transformer/swin_transformer.py: 23925
requirements.txt: 47
tools/launch.py: 9306
tools/run_dist_launch.sh: 812
util/__init__.py: 506
util/benchmark.py: 4644
util/box_ops.py: 2997
util/dam.py: 3968
util/misc.py: 20396
util/plot_utils.py: 4664

About this extraction

This page contains the full source code of the kakaobrain/sparse-detr GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 64 files (365.2 KB), approximately 88.7k tokens, and a symbol index of 341 functions, classes, methods, constants, and types. Use it with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input; the full output can be copied to the clipboard or downloaded as a .txt file.

Extracted by GitExtract, a free GitHub-repository-to-text converter for AI, built by Nikandr Surkov.
