Full Code of JMoonr/LATR for AI

Repository: JMoonr/LATR
Branch: main
Commit: 1c9e62646274
Files: 43
Total size: 330.6 KB

Directory structure:
gitextract_q_wu7p65/

├── .gitignore
├── LICENSE
├── README.md
├── config/
│   ├── _base_/
│   │   ├── base_res101_bs16xep100.py
│   │   ├── base_res101_bs16xep100_apollo.py
│   │   ├── once_eval_config.json
│   │   └── optimizer.py
│   └── release_iccv/
│       ├── apollo_illu.py
│       ├── apollo_rare.py
│       ├── apollo_standard.py
│       ├── latr_1000_baseline.py
│       ├── latr_1000_baseline_lite.py
│       └── once.py
├── data/
│   ├── Load_Data.py
│   ├── __init__.py
│   ├── apollo_dataset.py
│   └── transform.py
├── docs/
│   ├── data_preparation.md
│   ├── install.md
│   └── train_eval.md
├── experiments/
│   ├── __init__.py
│   ├── ddp.py
│   ├── gpu_utils.py
│   └── runner.py
├── main.py
├── models/
│   ├── __init__.py
│   ├── latr.py
│   ├── latr_head.py
│   ├── ms2one.py
│   ├── scatter_utils.py
│   ├── sparse_ins.py
│   ├── sparse_inst_loss.py
│   ├── transformer_bricks.py
│   └── utils.py
├── pretrained_models/
│   └── .gitkeep
├── requirements.txt
├── utils/
│   ├── MinCostFlow.py
│   ├── __init__.py
│   ├── eval_3D_lane.py
│   ├── eval_3D_lane_apollo.py
│   ├── eval_3D_once.py
│   └── utils.py
└── work_dirs/
    └── .gitkeep

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*pyc
*pth

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 JMoonr

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
<br />
<p align="center">
  
  <h3 align="center"><strong>LATR: 3D Lane Detection from Monocular Images with Transformer</strong></h3>

<p align="center">
  <a href="https://arxiv.org/abs/2308.04583" target='_blank'>
    <!-- <img src="https://img.shields.io/badge/arXiv-%F0%9F%93%83-yellow"> -->
    <img src="https://img.shields.io/badge/arXiv-2308.04583-b31b1b.svg">
  </a>
  <a href="" target='_blank'>
    <img src="https://visitor-badge.laobi.icu/badge?page_id=JMoonr.LATR&left_color=gray&right_color=yellow">
  </a>
    <a href="https://github.com/JMoonr/LATR" target='_blank'>
     <img src="https://img.shields.io/github/stars/JMoonr/LATR?style=social">
  </a>
  
</p>


This is the official PyTorch implementation of [LATR: 3D Lane Detection from Monocular Images with Transformer](https://arxiv.org/abs/2308.04583).

![fig2](/assets/fig2.png)  

## News
  - **2024-01-15** :confetti_ball: Our new work [DV-3DLane: End-to-end Multi-modal 3D Lane Detection with Dual-view Representation](https://github.com/JMoonr/dv-3dlane) is accepted by ICLR2024.

  - **2023-08-12** :tada: LATR is accepted as an Oral presentation at ICCV2023! :sparkles:


## Environments
To set up the required packages, please refer to the [installation guide](./docs/install.md).

## Data
Please follow the [data preparation](./docs/data_preparation.md) guide to download and set up the datasets.

## Pretrained Models
Note that the performance of the pretrained models is higher than reported in our paper due to code refactoring and optimization. All models are uploaded to [google drive](https://drive.google.com/drive/folders/1AhvLvE84vayzFxa0teRHYRdXz34ulzjB?usp=sharing).

| Dataset | Pretrained | Metrics | md5 |
| - | - | - | - |
| OpenLane-1000 | [Google Drive](https://drive.google.com/file/d/1jThvqnJ2cUaAuKdlTuRKjhLCH0Zq62A1/view?usp=sharing) | F1=0.6297 | d8ecb900c34fd23a9e7af840aff00843 |
| OpenLane-1000 (Lite version) | [Google Drive](https://drive.google.com/file/d/1WD5dxa6SI2oR9popw3kO2-7eGM2z-IHY/view?usp=sharing) | F1=0.6212 | 918de41d0d31dbfbecff3001c49dc296 |
| ONCE | [Google Drive](https://drive.google.com/file/d/12kXkJ9tDxm13CyFbB1ddt82lJZkYEicd/view?usp=sharing) | F1=0.8125 | 65a6958c162e3c7be0960bceb3f54650 |
| Apollo-balance | [Google Drive](https://drive.google.com/file/d/1hGyNrYi3wAQaKbC1mD_18NG35gdmMUiM/view?usp=sharing) | F1=0.9697 | 551967e8654a8a522bdb0756d74dd1a2 |
| Apollo-rare | [Google Drive](https://drive.google.com/file/d/19VVBaWBnWiEqGx1zJaeXF_1CKn88G5v0/view?usp=sharing) | F1=0.9641 | 184cfff1d3097a9009011f79f4594138 |
| Apollo-visual | [Google Drive](https://drive.google.com/file/d/1ZzaUODYK2dyiG_2bDXe5tiutxNvc71M2/view?usp=sharing) | F1=0.9611 | cec4aa567c264c84808f3c32f5aace82 |
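
To verify a download against the table, you can compare its md5 digest. A minimal sketch (the checkpoint filename below is illustrative, not the actual name on Google Drive):

```python
import hashlib

def md5sum(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream a file through hashlib.md5 and return the hex digest."""
    h = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

# e.g. for the OpenLane-1000 checkpoint (illustrative filename):
print(md5sum('pretrained_models/openlane_1000.pth'))
# expected: d8ecb900c34fd23a9e7af840aff00843
```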


## Evaluation
You can download the [pretrained models](#pretrained-models) to the `./pretrained_models` directory and refer to the [eval guide](./docs/train_eval.md#evaluation) for evaluation.

## Train
Please follow the steps in [training](./docs/train_eval.md#train) to train the model.

## Benchmark

### OpenLane

| Models | F1 | Accuracy | X error (m) <br> near \| far | Z error (m) <br> near \| far |
| ----- | -- | -------- | ------- | ------- |
| 3DLaneNet | 44.1 | - | 0.479 \| 0.572 | 0.367 \| 0.443 |
| GenLaneNet | 32.3 | - | 0.593 \| 0.494 | 0.140 \| 0.195 |
| Cond-IPM | 36.3 | - | 0.563 \| 1.080 | 0.421 \| 0.892 |
| PersFormer | 50.5 | 89.5 | 0.319 \| 0.325 | 0.112 \| 0.141 |
| CurveFormer | 50.5 | - | 0.340 \| 0.772 | 0.207 \| 0.651 |
| PersFormer-Res50 | 53.0 | 89.2 | 0.321 \| 0.303 | 0.085 \| 0.118 |
| **LATR-Lite** | 61.5 | 91.9 | 0.225 \| 0.249 | 0.073 \| 0.106 |
| **LATR** | 61.9 | 92.0 | 0.219 \| 0.259 | 0.075 \| 0.104 |


### Apollo

Please refer to our paper for the performance on other scenes.

<table>
    <tr>
        <td>Scene</td>
        <td>Models</td>
        <td>F1</td>
        <td>AP</td>
        <td>X error (m) <br> near | far </td>
        <td>Z error (m) <br> near | far </td>
    </tr>
    <tr>
        <td rowspan="8">Balanced Scene</td>
        <td>3DLaneNet</td>
        <td>86.4</td>
        <td>89.3</td>
        <td>0.068 | 0.477</td>
        <td>0.015 | 0.202</td>
    </tr>
    <tr>
        <td>GenLaneNet</td>
        <td>88.1</td>
        <td>90.1</td>
        <td>0.061 | 0.496</td>
        <td>0.012 | 0.214</td>
    </tr>
    <tr>
        <td>CLGo</td>
        <td>91.9</td>
        <td>94.2</td>
        <td>0.061 | 0.361</td>
        <td>0.029 | 0.250</td>
    </tr>
    <tr>
        <td>PersFormer</td>
        <td>92.9</td>
        <td>-</td>
        <td>0.054 | 0.356</td>
        <td>0.010 | 0.234</td>
    </tr>
    <tr>
        <td>GP</td>
        <td>91.9</td>
        <td>93.8</td>
        <td>0.049 | 0.387</td>
        <td>0.008 | 0.213</td>
    </tr>
    <tr>
        <td>CurveFormer</td>
        <td>95.8</td>
        <td>97.3</td>
        <td>0.078 | 0.326</td>
        <td>0.018 | 0.219</td>
    </tr>
    <tr>
        <td><b>LATR-Lite</b></td>
        <td>96.5</td>
        <td>97.8</td>
        <td>0.035 | 0.283</td>
        <td>0.012 | 0.209</td>
    </tr>
    <tr>
        <td><b>LATR</b></td>
        <td>96.8</td>
        <td>97.9</td>
        <td>0.022 | 0.253</td>
        <td>0.007 | 0.202</td>
    </tr>
</table>


### ONCE

| Method     | F1(%) | Precision(%) | Recall(%) | CD error(m) |
| :- | :- | :- | :- | :- |
| 3DLaneNet  | 44.73 | 61.46 | 35.16 | 0.127 |
| GenLaneNet | 45.59 | 63.95 | 35.42 | 0.121 |
| SALAD (ONCE-3DLanes) | 64.07 | 75.90 | 55.42 | 0.098 |
| PersFormer | 72.07 | 77.82 | 67.11 | 0.086 |
| **LATR** | 80.59 | 86.12 | 75.73 | 0.052 |

## Acknowledgment

This library is inspired by [OpenLane](https://github.com/OpenDriveLab/PersFormer_3DLane), [GenLaneNet](https://github.com/yuliangguo/Pytorch_Generalized_3D_Lane_Detection), [mmdetection3d](https://github.com/open-mmlab/mmdetection3d), [SparseInst](https://github.com/hustvl/SparseInst), [ONCE](https://github.com/once-3dlanes/once_3dlanes_benchmark) and many other related works. We thank the authors for sharing their code and datasets.


## Citation
If you find LATR useful for your research, please consider citing the paper:

```tex
@article{luo2023latr,
  title={LATR: 3D Lane Detection from Monocular Images with Transformer},
  author={Luo, Yueru and Zheng, Chaoda and Yan, Xu and Kun, Tang and Zheng, Chao and Cui, Shuguang and Li, Zhen},
  journal={arXiv preprint arXiv:2308.04583},
  year={2023}
}
```

================================================
FILE: config/_base_/base_res101_bs16xep100.py
================================================
import os
import os.path as osp
import numpy as np


dataset_name = 'openlane'
dataset = '300' # '300' | '1000'

#  The path of dataset json files (annotations)
data_dir = './data/openlane/lane3d_300/'
# The path of dataset image files (images)
dataset_dir = './data/openlane/images/'
output_dir = dataset_name

org_h = 1280
org_w = 1920
crop_y = 0

ipm_h = 208
ipm_w = 128
resize_h = 360
resize_w = 480

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

cam_height = 1.55
pitch = 3
fix_cam = False
pred_cam = False

model_name = 'LATR'
weight_init = 'normal'
mod = None

position_embedding = 'learned'
max_lanes = 20
num_category = 21
prob_th = 0.5
num_class = 21 # 1 bgd + 20 lane categories

# top view
top_view_region = np.array([[-10, 103], [10, 103], [-10, 3], [10, 3]])
anchor_y_steps = np.linspace(3, 103, 25)
num_y_steps = len(anchor_y_steps)
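# top_view_region gives the BEV corner points (x, y) in metres, ordered
# [far-left, far-right, near-left, near-right]; anchor_y_steps are the 25
# longitudinal positions at which lane x/z offsets and visibility are sampled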

# placeholder, not used
K = np.array([[1000., 0., 960.],
            [0., 1000., 640.],
            [0., 0., 1.]])

# persformer anchor
use_default_anchor = False

batch_size = 16
nepochs = 100

no_cuda = False
nworkers = 16

start_epoch = 0
channels_in = 3

# args input
test_mode = False # 'store_true' # TODO 
evaluate = False # TODO
resume = '' # resume latest saved run.

# tensorboard
no_tb = False

# print & save
print_freq = 50
save_freq = 50

# ddp setting
dist = True
sync_bn = True
cudnn = True

distributed = True
local_rank = None #TODO
gpu = 0
world_size = 1
nodes = 1

# for reload ckpt
eval_ckpt = ''
resume_from = ''
output_dir = 'openlane'
evaluate_case = ''
eval_freq = 8 # eval freq during training

save_json_path = None
save_root = 'work_dirs'
save_prefix = osp.join(os.getcwd(), save_root)
save_path = osp.join(save_prefix, output_dir)

================================================
FILE: config/_base_/base_res101_bs16xep100_apollo.py
================================================
import os
import os.path as osp
import numpy as np

# ========DATA SETTING======== #
dataset_name = 'apollo'
dataset = 'standard'

data_dir = osp.join('./data/apollosyn_gen-lanenet/data_splits', dataset)
dataset_dir = './data/apollosyn_gen-lanenet/Apollo_Sim_3D_Lane_Release'

output_dir = 'apollo'

rewrite_pred = True
save_best = False

output_dir = dataset_name

org_h = 1080
org_w = 1920
crop_y = 0

cam_height = 1.55
pitch = 3
fix_cam = False
pred_cam = False

model_name = 'LATR'
mod = None

ipm_h = 208
ipm_w = 128
resize_h = 360
resize_w = 480

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

K = np.array([[2015., 0., 960.],
            [0., 2015., 540.],
            [0., 0., 1.]])

position_embedding = 'learned'

max_lanes = 6
num_category = 2
prob_th = 0.5
num_class = 2 # 1 bgd | 1 lanes

batch_size = 16
nepochs = 210
nworkers = 16

# ddp setting
dist = True
sync_bn = True
cudnn = True

distributed = True
local_rank = None #TODO
gpu = 0
world_size = 1
nodes = 1

# for reload ckpt
eval_ckpt = ''
resume = '' # ckpt number as input
resume_from = '' # ckpt path as input

no_cuda = False

# tensorboard
no_tb = False

start_epoch = 0
channels_in = 3

# args input
test_mode = False # 'store_true' # TODO 
evaluate = False # TODO
evaluate_case = ''

# print & save
print_freq = 50
save_freq = 50
eval_freq = 20 # eval freq during training

# top view
top_view_region = np.array([[-10, 103], [10, 103], [-10, 3], [10, 3]])
anchor_y_steps = np.linspace(3, 103, 25)
num_y_steps = len(anchor_y_steps)

save_path = None
save_json_path = None



================================================
FILE: config/_base_/once_eval_config.json
================================================
{"side_range_l": -10, "side_range_h": 10, "fwd_range_l": 0, "fwd_range_h": 50, "height_range_l": 0, "height_range_h": 5, "res": 0.05, "lane_width_x": 30, "lane_width_y": 10, "iou_thresh": 0.3, "distance_thresh": 0.3, "process_num": 10, "score_l": 0.10, "score_h": 1, "score_step": 0.05, "exp_name": "evaluation"}

================================================
FILE: config/_base_/optimizer.py
================================================
# opt setting
optimizer = 'adam'
learning_rate = 2e-4

weight_decay = 0.001
lr_decay = False # TODO 'store_true'
niter = 900 # num of iters at the starting learning rate
niter_decay = 400 # num of iters to linearly decay the learning rate to zero
lr_policy = 'cosine'
gamma = 0.1 # multiplicative factor of learning rate decay
lr_decay_iters = 10 # multiply by a gamma every lr_decay_iters iterations
T_max = 8 # maximum number of iterations
T_0 = 8
T_mult = 2
eta_min = 1e-5 # minimum learning rate
clip_grad_norm = 35.0 # grad clipping
loss_threshold = 1e5
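
# Note (assumption, not part of the original file): these presumably map onto
# the standard torch schedulers, depending on which lr_policy the runner picks:
#   torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)
#   torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
#       optimizer, T_0=T_0, T_mult=T_mult, eta_min=eta_min)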



================================================
FILE: config/release_iccv/apollo_illu.py
================================================
import numpy as np
from mmcv.utils import Config
import os.path as osp

_base_ = [
    '../_base_/base_res101_bs16xep100_apollo.py',
    '../_base_/optimizer.py',
]

mod = 'release_iccv/apollo_illu'
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


dataset_name = 'apollo'
dataset = 'illus_chg'
data_dir = osp.join('./data/apollosyn_gen-lanenet/data_splits', dataset)
dataset_dir = './data/apollosyn_gen-lanenet/Apollo_Sim_3D_Lane_Release'
output_dir = 'apollo'
num_category = 2
max_lanes = 6

T_max = 30
eta_min = 1e-6
clip_grad_norm = 20
nepochs = 210
eval_freq = 1

h_org, w_org = 1080, 1920

batch_size = 8
nworkers = 10
pos_threshold = 0.5
top_view_region = np.array([
    [-10, 103], [10, 103], [-10, 3], [10, 3]])
enlarge_length = 20
position_range = [
    top_view_region[0][0] - enlarge_length,
    top_view_region[2][1] - enlarge_length,
    -5,
    top_view_region[1][0] + enlarge_length,
    top_view_region[0][1] + enlarge_length,
    5.]
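# position_range is [x_min, y_min, z_min, x_max, y_max, z_max] in metres:
# the top_view_region enlarged by enlarge_length in x/y, with z in [-5, 5]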
anchor_y_steps = np.linspace(3, 103, 20)
num_y_steps = len(anchor_y_steps)

photo_aug = dict(
    brightness_delta=32,
    contrast_range=(0.5, 1.5),
    saturation_range=(0.5, 1.5),
    hue_delta=18)

_dim_ = 256
num_query = 12
num_pt_per_line = 20
latr_cfg = dict(
    fpn_dim = _dim_,
    num_query = num_query,
    num_group = 1,
    sparse_num_group = 4,
    encoder = dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=-1,
        norm_cfg=dict(type='BN2d', requires_grad=False),
        norm_eval=True,
        style='caffe',
        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, True, True),
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')
    ),
    neck = dict(
        type='FPN',
        in_channels=[512, 1024, 2048],
        out_channels=_dim_,
        start_level=0,
        add_extra_convs='on_output',
        num_outs=4,
        relu_before_extra_convs=True
    ),
    head=dict(
        xs_loss_weight=2.0,
        zs_loss_weight=10.0,
        vis_loss_weight=1.0,
        cls_loss_weight=10,
        project_loss_weight=1.0,
        pt_as_query=True,
        num_pt_per_line=num_pt_per_line,
    ),
    trans_params=dict(init_z=0, bev_h=150, bev_w=70),
)

ms2one=dict(
    type='DilateNaive',
    inc=_dim_, outc=_dim_, num_scales=4,
    dilations=(1, 2, 5, 9))
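# ms2one fuses the num_scales FPN levels into a single feature map via
# dilated convolutions with rates (1, 2, 5, 9)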

transformer=dict(
    type='LATRTransformer',
    decoder=dict(
        type='LATRTransformerDecoder',
        embed_dims=_dim_,
        num_layers=6,
        enlarge_length=enlarge_length,
        M_decay_ratio=1,
        num_query=num_query,
        num_anchor_per_query=num_pt_per_line,
        anchor_y_steps=anchor_y_steps,
        transformerlayers=dict(
            type='LATRDecoderLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=_dim_,
                    num_heads=4,
                    dropout=0.1),
                dict(
                    type='MSDeformableAttention3D',
                    embed_dims=_dim_,
                    num_heads=4,
                    num_levels=1,
                    num_points=8,
                    batch_first=False,
                    num_query=num_query,
                    num_anchor_per_query=num_pt_per_line,
                    anchor_y_steps=anchor_y_steps,
                    dropout=0.1),
                ],
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=_dim_,
                feedforward_channels=_dim_*8,
                num_fcs=2,
                ffn_drop=0.1,
                act_cfg=dict(type='ReLU', inplace=True),
            ),
            feedforward_channels=_dim_ * 8,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                            'ffn', 'norm')),
))
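# In the decoder above, each of the num_query lane queries is decomposed into
# num_pt_per_line point-level sub-queries (cf. pt_as_query in latr_cfg), so
# attention operates on num_query * num_pt_per_line tokens; the
# MSDeformableAttention3D layer samples image features around each point's
# 3D position along its anchor y-step.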

sparse_ins_decoder=Config(
    dict(
        encoder=dict(
            out_dims=_dim_),
        decoder=dict(
            num_query=latr_cfg['num_query'],
            num_group=latr_cfg['num_group'],
            sparse_num_group=latr_cfg['sparse_num_group'],
            hidden_dim=_dim_,
            kernel_dim=_dim_,
            num_classes=num_category,
            num_convs=4,
            output_iam=True,
            scale_factor=1.,
            ce_weight=2.0,
            mask_weight=5.0,
            dice_weight=2.0,
            objectness_weight=1.0,
        ),
        sparse_decoder_weight=5.0,
))
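# The block above configures the auxiliary SparseInst-style instance decoder
# used for 2D lane-instance supervision; ce/mask/dice/objectness weight its
# individual loss terms, and sparse_decoder_weight scales the whole branch.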

resize_h = 720
resize_w = 960

optimizer_cfg = dict(
    type='AdamW',
    lr=2e-4,
    weight_decay=0.01)

================================================
FILE: config/release_iccv/apollo_rare.py
================================================
import numpy as np
from mmcv.utils import Config
import os.path as osp

_base_ = [
    '../_base_/base_res101_bs16xep100_apollo.py',
    '../_base_/optimizer.py',
]

mod = 'release_iccv/apollo_rare'
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


dataset_name = 'apollo'
dataset = 'rare_subset'
data_dir = osp.join('./data/apollosyn_gen-lanenet/data_splits', dataset)
dataset_dir = './data/apollosyn_gen-lanenet/Apollo_Sim_3D_Lane_Release'
output_dir = 'apollo'
num_category = 2
max_lanes = 6

T_max = 30
eta_min = 1e-8
clip_grad_norm = 20
nepochs = 210
eval_freq = 1

h_org, w_org = 1080, 1920

batch_size = 8
nworkers = 10
pos_threshold = 0.5
top_view_region = np.array([
    [-10, 103], [10, 103], [-10, 3], [10, 3]])
enlarge_length = 20
position_range = [
    top_view_region[0][0] - enlarge_length,
    top_view_region[2][1] - enlarge_length,
    -5,
    top_view_region[1][0] + enlarge_length,
    top_view_region[0][1] + enlarge_length,
    5.]
anchor_y_steps = np.linspace(3, 103, 20)
num_y_steps = len(anchor_y_steps)

photo_aug = dict(
    brightness_delta=32,
    contrast_range=(0.5, 1.5),
    saturation_range=(0.5, 1.5),
    hue_delta=18)

_dim_ = 256
num_query = 12
num_pt_per_line = 20
latr_cfg = dict(
    fpn_dim = _dim_,
    num_query = num_query,
    num_group = 1,
    sparse_num_group = 4,
    encoder = dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=-1,
        norm_cfg=dict(type='BN2d', requires_grad=False),
        norm_eval=True,
        style='caffe',
        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, True, True),
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')
    ),
    neck = dict(
        type='FPN',
        in_channels=[512, 1024, 2048],
        out_channels=_dim_,
        start_level=0,
        add_extra_convs='on_output',
        num_outs=4,
        relu_before_extra_convs=True
    ),
    head=dict(
        xs_loss_weight=2.0,
        zs_loss_weight=10.0,
        vis_loss_weight=1.0,
        cls_loss_weight=10,
        project_loss_weight=1.0,
        pt_as_query=True,
        num_pt_per_line=num_pt_per_line,
    ),
    trans_params=dict(init_z=0, bev_h=150, bev_w=70),
)

ms2one=dict(
    type='DilateNaive',
    inc=_dim_, outc=_dim_, num_scales=4,
    dilations=(1, 2, 5, 9))

transformer=dict(
    type='LATRTransformer',
    decoder=dict(
        type='LATRTransformerDecoder',
        embed_dims=_dim_,
        num_layers=6,
        enlarge_length=enlarge_length,
        M_decay_ratio=1,
        num_query=num_query,
        num_anchor_per_query=num_pt_per_line,
        anchor_y_steps=anchor_y_steps,
        transformerlayers=dict(
            type='LATRDecoderLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=_dim_,
                    num_heads=4,
                    dropout=0.1),
                dict(
                    type='MSDeformableAttention3D',
                    embed_dims=_dim_,
                    num_heads=4,
                    num_levels=1,
                    num_points=8,
                    batch_first=False,
                    num_query=num_query,
                    num_anchor_per_query=num_pt_per_line,
                    anchor_y_steps=anchor_y_steps,
                    dropout=0.1),
                ],
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=_dim_,
                feedforward_channels=_dim_*8,
                num_fcs=2,
                ffn_drop=0.1,
                act_cfg=dict(type='ReLU', inplace=True),
            ),
            feedforward_channels=_dim_ * 8,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                            'ffn', 'norm')),
))

sparse_ins_decoder=Config(
    dict(
        encoder=dict(
            out_dims=_dim_),
        decoder=dict(
            num_query=latr_cfg['num_query'],
            num_group=latr_cfg['num_group'],
            sparse_num_group=latr_cfg['sparse_num_group'],
            hidden_dim=_dim_,
            kernel_dim=_dim_,
            num_classes=num_category,
            num_convs=4,
            output_iam=True,
            scale_factor=1.,
            ce_weight=2.0,
            mask_weight=5.0,
            dice_weight=2.0,
            objectness_weight=1.0,
        ),
        sparse_decoder_weight=5.0,
))

resize_h = 720
resize_w = 960

optimizer_cfg = dict(
    type='AdamW',
    lr=2e-4,
    weight_decay=0.01)

================================================
FILE: config/release_iccv/apollo_standard.py
================================================
import numpy as np
from mmcv.utils import Config
import os.path as osp

_base_ = [
    '../_base_/base_res101_bs16xep100_apollo.py',
    '../_base_/optimizer.py',
]

mod = 'release_iccv/apollo_standard'
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


dataset_name = 'apollo'
dataset = 'standard'
data_dir = osp.join('./data/apollosyn_gen-lanenet/data_splits', dataset)
dataset_dir = './data/apollosyn_gen-lanenet/Apollo_Sim_3D_Lane_Release'
output_dir = 'apollo'
num_category = 2
max_lanes = 6

T_max = 30
eta_min = 1e-6
clip_grad_norm = 20
nepochs = 210
eval_freq = 1

h_org, w_org = 1080, 1920

batch_size = 8
nworkers = 10
pos_threshold = 0.3
top_view_region = np.array([
    [-10, 103], [10, 103], [-10, 3], [10, 3]])
enlarge_length = 20
position_range = [
    top_view_region[0][0] - enlarge_length,
    top_view_region[2][1] - enlarge_length,
    -5,
    top_view_region[1][0] + enlarge_length,
    top_view_region[0][1] + enlarge_length,
    5.]
anchor_y_steps = np.linspace(3, 103, 20)
num_y_steps = len(anchor_y_steps)

_dim_ = 256
num_query = 12
num_pt_per_line = 20
latr_cfg = dict(
    fpn_dim = _dim_,
    num_query = num_query,
    num_group = 1,
    sparse_num_group = 4,
    encoder = dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=-1,
        norm_cfg=dict(type='BN2d', requires_grad=False),
        norm_eval=True,
        style='caffe',
        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, True, True),
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')
    ),
    neck = dict(
        type='FPN',
        in_channels=[512, 1024, 2048],
        out_channels=_dim_,
        start_level=0,
        add_extra_convs='on_output',
        num_outs=4,
        relu_before_extra_convs=True
    ),
    head=dict(
        xs_loss_weight=2.0,
        zs_loss_weight=10.0,
        vis_loss_weight=1.0,
        cls_loss_weight=10,
        project_loss_weight=1.0,
        pt_as_query=True,
        num_pt_per_line=num_pt_per_line,
    ),
    trans_params=dict(init_z=0, bev_h=150, bev_w=70),
)

ms2one=dict(
    type='DilateNaive',
    inc=_dim_, outc=_dim_, num_scales=4,
    dilations=(1, 2, 5, 9))

transformer=dict(
    type='LATRTransformer',
    decoder=dict(
        type='LATRTransformerDecoder',
        embed_dims=_dim_,
        num_layers=6,
        enlarge_length=enlarge_length,
        M_decay_ratio=1,
        num_query=num_query,
        num_anchor_per_query=num_pt_per_line,
        anchor_y_steps=anchor_y_steps,
        transformerlayers=dict(
            type='LATRDecoderLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=_dim_,
                    num_heads=4,
                    dropout=0.1),
                dict(
                    type='MSDeformableAttention3D',
                    embed_dims=_dim_,
                    num_heads=4,
                    num_levels=1,
                    num_points=8,
                    batch_first=False,
                    num_query=num_query,
                    num_anchor_per_query=num_pt_per_line,
                    anchor_y_steps=anchor_y_steps,
                    dropout=0.1),
                ],
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=_dim_,
                feedforward_channels=_dim_*8,
                num_fcs=2,
                ffn_drop=0.1,
                act_cfg=dict(type='ReLU', inplace=True),
            ),
            feedforward_channels=_dim_ * 8,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                            'ffn', 'norm')),
))

sparse_ins_decoder=Config(
    dict(
        encoder=dict(
            out_dims=_dim_),
        decoder=dict(
            num_query=latr_cfg['num_query'],
            num_group=latr_cfg['num_group'],
            sparse_num_group=latr_cfg['sparse_num_group'],
            hidden_dim=_dim_,
            kernel_dim=_dim_,
            num_classes=num_category,
            num_convs=4,
            output_iam=True,
            scale_factor=1.,
            ce_weight=2.0,
            mask_weight=5.0,
            dice_weight=2.0,
            objectness_weight=1.0,
        ),
        sparse_decoder_weight=5.0,
))

resize_h = 720
resize_w = 960
optimizer_cfg = dict(
    type='AdamW',
    lr=2e-4,
    weight_decay=0.01)

================================================
FILE: config/release_iccv/latr_1000_baseline.py
================================================
import numpy as np
from mmcv.utils import Config

_base_ = [
    '../_base_/base_res101_bs16xep100.py',
    '../_base_/optimizer.py'
]

mod = 'release_iccv/latr_1000_baseline'
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

dataset = '1000'
dataset_dir = './data/openlane/images/'
data_dir = './data/openlane/lane3d_1000/'

batch_size = 8
nworkers = 10
num_category = 21
pos_threshold = 0.3
top_view_region = np.array([
    [-10, 103], [10, 103], [-10, 3], [10, 3]])
enlarge_length = 20
position_range = [
    top_view_region[0][0] - enlarge_length,
    top_view_region[2][1] - enlarge_length,
    -5,
    top_view_region[1][0] + enlarge_length,
    top_view_region[0][1] + enlarge_length,
    5.]
anchor_y_steps = np.linspace(3, 103, 20)
num_y_steps = len(anchor_y_steps)

# extra aug
photo_aug = dict(
    brightness_delta=32 // 2,
    contrast_range=(0.5, 1.5),
    saturation_range=(0.5, 1.5),
    hue_delta=9)

clip_grad_norm = 20.0

_dim_ = 256
num_query = 40
num_pt_per_line = 20
latr_cfg = dict(
    fpn_dim = _dim_,
    num_query = num_query,
    num_group = 1,
    sparse_num_group = 4,
    encoder = dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN2d', requires_grad=False),
        norm_eval=True,
        style='caffe',
        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, True, True),
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')
    ),
    neck = dict(
        type='FPN',
        in_channels=[512, 1024, 2048],
        out_channels=_dim_,
        start_level=0,
        add_extra_convs='on_output',
        num_outs=4,
        relu_before_extra_convs=True
    ),
    head=dict(
        pt_as_query=True,
        num_pt_per_line=num_pt_per_line,
        xs_loss_weight=2.0,
        zs_loss_weight=10.0,
        vis_loss_weight=1.0,
        cls_loss_weight=10,
        project_loss_weight=1.0,
    ),
    trans_params=dict(init_z=0, bev_h=150, bev_w=70),
)

ms2one=dict(
    type='DilateNaive',
    inc=_dim_, outc=_dim_, num_scales=4,
    dilations=(1, 2, 5, 9))

transformer=dict(
    type='LATRTransformer',
    decoder=dict(
        type='LATRTransformerDecoder',
        embed_dims=_dim_,
        num_layers=6,
        enlarge_length=enlarge_length,
        M_decay_ratio=1,
        num_query=num_query,
        num_anchor_per_query=num_pt_per_line,
        anchor_y_steps=anchor_y_steps,
        transformerlayers=dict(
            type='LATRDecoderLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=_dim_,
                    num_heads=4,
                    dropout=0.1),
                dict(
                    type='MSDeformableAttention3D',
                    embed_dims=_dim_,
                    num_heads=4,
                    num_levels=1,
                    num_points=8,
                    batch_first=False,
                    num_query=num_query,
                    num_anchor_per_query=num_pt_per_line,
                    anchor_y_steps=anchor_y_steps,
                    dropout=0.1),
                ],
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=_dim_,
                feedforward_channels=_dim_*8,
                num_fcs=2,
                ffn_drop=0.1,
                act_cfg=dict(type='ReLU', inplace=True),
            ),
            feedforward_channels=_dim_ * 8,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                            'ffn', 'norm')),
))

sparse_ins_decoder=Config(
    dict(
        encoder=dict(
            out_dims=_dim_),
        decoder=dict(
            num_query=latr_cfg['num_query'],
            num_group=latr_cfg['num_group'],
            sparse_num_group=latr_cfg['sparse_num_group'],
            hidden_dim=_dim_,
            kernel_dim=_dim_,
            num_classes=num_category,
            num_convs=4,
            output_iam=True,
            scale_factor=1.,
            ce_weight=2.0,
            mask_weight=5.0,
            dice_weight=2.0,
            objectness_weight=1.0,
        ),
        sparse_decoder_weight=5.0,
))

nepochs = 24
resize_h = 720
resize_w = 960

eval_freq = 8
optimizer_cfg = dict(
    type='AdamW',
    lr=2e-4,
    paramwise_cfg=dict(
        custom_keys={
            'sampling_offsets': dict(lr_mult=0.1),
        }),
    weight_decay=0.01)

================================================
FILE: config/release_iccv/latr_1000_baseline_lite.py
================================================
import numpy as np
from mmcv.utils import Config

_base_ = [
    '../_base_/base_res101_bs16xep100.py',
    '../_base_/optimizer.py'
]

mod = 'release_iccv/latr_1000_baseline_lite'
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

dataset = '1000'
dataset_dir = './data/openlane/images/'
data_dir = './data/openlane/lane3d_1000/'

batch_size = 8
nworkers = 10
num_category = 21
pos_threshold = 0.3

clip_grad_norm = 20

top_view_region = np.array([
    [-10, 103], [10, 103], [-10, 3], [10, 3]])
enlarge_length = 20
position_range = [
    top_view_region[0][0] - enlarge_length,
    top_view_region[2][1] - enlarge_length,
    -5,
    top_view_region[1][0] + enlarge_length,
    top_view_region[0][1] + enlarge_length,
    5.]
anchor_y_steps = np.linspace(3, 103, 20)
num_y_steps = len(anchor_y_steps)

# extra aug
photo_aug = dict(
    brightness_delta=32 // 2,
    contrast_range=(0.5, 1.5),
    saturation_range=(0.5, 1.5),
    hue_delta=9)

_dim_ = 256
num_query = 40
num_pt_per_line = 20
latr_cfg = dict(
    fpn_dim = _dim_,
    num_query = num_query,
    num_group = 1,
    sparse_num_group = 4,
    encoder = dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN2d', requires_grad=False),
        norm_eval=True,
        style='caffe',
        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, True, True),
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')
    ),
    neck = dict(
        type='FPN',
        in_channels=[512, 1024, 2048],
        out_channels=_dim_,
        start_level=0,
        add_extra_convs='on_output',
        num_outs=4,
        relu_before_extra_convs=True
    ),
    head=dict(
        pt_as_query=True,
        num_pt_per_line=num_pt_per_line,
        xs_loss_weight=2.0,
        zs_loss_weight=10.0,
        vis_loss_weight=1.0,
        cls_loss_weight=10,
        project_loss_weight=1.0,
    ),
    trans_params=dict(init_z=0, bev_h=150, bev_w=70),
)

ms2one=dict(
    type='DilateNaive',
    inc=_dim_, outc=_dim_, num_scales=4,
    dilations=(1, 2, 5, 9))

transformer=dict(
    type='LATRTransformer',
    decoder=dict(
        type='LATRTransformerDecoder',
        embed_dims=_dim_,
        num_layers=2,
        enlarge_length=enlarge_length,
        M_decay_ratio=1,
        num_query=num_query,
        num_anchor_per_query=num_pt_per_line,
        anchor_y_steps=anchor_y_steps,
        transformerlayers=dict(
            type='LATRDecoderLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=_dim_,
                    num_heads=4,
                    dropout=0.1),
                dict(
                    type='MSDeformableAttention3D',
                    embed_dims=_dim_,
                    num_heads=4,
                    num_levels=1,
                    num_points=8,
                    batch_first=False,
                    num_query=num_query,
                    num_anchor_per_query=num_pt_per_line,
                    anchor_y_steps=anchor_y_steps,
                    dropout=0.1),
                ],
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=_dim_,
                feedforward_channels=_dim_*8,
                num_fcs=2,
                ffn_drop=0.1,
                act_cfg=dict(type='ReLU', inplace=True),
            ),
            feedforward_channels=_dim_ * 8,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                            'ffn', 'norm')),
))

sparse_ins_decoder=Config(
    dict(
        encoder=dict(
            out_dims=_dim_),
        decoder=dict(
            num_query=latr_cfg['num_query'],
            num_group=latr_cfg['num_group'],
            sparse_num_group=latr_cfg['sparse_num_group'],
            hidden_dim=_dim_,
            kernel_dim=_dim_,
            num_classes=num_category,
            num_convs=4,
            output_iam=True,
            scale_factor=1.,
            ce_weight=2.0,
            mask_weight=5.0,
            dice_weight=2.0,
            objectness_weight=1.0,
        ),
        sparse_decoder_weight=5.0,
))


resize_h = 720
resize_w = 960

nepochs = 24
eval_freq = 8
optimizer_cfg = dict(
    type='AdamW',
    lr=2e-4,
    paramwise_cfg=dict(
        custom_keys={
            'sampling_offsets': dict(lr_mult=0.1),
        }),
    weight_decay=0.01)

================================================
FILE: config/release_iccv/once.py
================================================
import numpy as np
from mmcv.utils import Config
import os.path as osp

_base_ = [
    '../_base_/base_res101_bs16xep100.py',
    '../_base_/optimizer.py'
]

mod = 'release_iccv/once'
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


dataset = 'once'
dataset_name = 'once'
data_dir = 'data/once/'
dataset_dir = 'data/once/data/'
eval_config_dir = 'config/_base_/once_eval_config.json'

save_path = osp.join('./work_dirs', dataset)

max_lanes = 8
num_pt_per_line = 20

eta_min = 1e-6
clip_grad_norm = 20

batch_size = 8
nworkers = 10
num_category = 2
pos_threshold = 0.3

top_view_region = np.array([
    [-10, 65], [10, 65], [-10, 0.5], [10, 0.5]])

enlarge_length = 10
position_range = [
    top_view_region[0][0] - enlarge_length,
    top_view_region[2][1] - enlarge_length,
    -5,
    top_view_region[1][0] + enlarge_length,
    top_view_region[0][1] + enlarge_length,
    5.]

anchor_y_steps = np.linspace(0.5, 65, num_pt_per_line)
num_y_steps = len(anchor_y_steps)

_dim_ = 256
num_query = 12
num_pt_per_line = 20
latr_cfg = dict(
    fpn_dim = _dim_,
    num_query = num_query,
    num_group = 1,
    sparse_num_group = 4,
    encoder = dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN2d', requires_grad=False),
        norm_eval=True,
        style='caffe',
        # with_cp=True,
        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, True, True),
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')
    ),
    neck = dict(
        type='FPN',
        in_channels=[512, 1024, 2048],
        out_channels=_dim_,
        start_level=0,
        add_extra_convs='on_output',
        num_outs=4,
        relu_before_extra_convs=True
    ),
    head=dict(
        xs_loss_weight=2.0,
        zs_loss_weight=10.0,
        vis_loss_weight=1.0,
        cls_loss_weight=10,
        project_loss_weight=1.0,
        pt_as_query=True,
        num_pt_per_line=num_pt_per_line,
    ),
    trans_params=dict(init_z=0, bev_h=150, bev_w=70),
)

ms2one=dict(
    type='DilateNaive',
    inc=_dim_, outc=_dim_, num_scales=4,
    dilations=(1, 2, 5, 9))

transformer=dict(
    type='LATRTransformer',
    decoder=dict(
        type='LATRTransformerDecoder',
        embed_dims=_dim_,
        num_layers=6,
        enlarge_length=enlarge_length,
        M_decay_ratio=1,
        num_query=num_query,
        num_anchor_per_query=num_pt_per_line,
        anchor_y_steps=anchor_y_steps,
        transformerlayers=dict(
            type='LATRDecoderLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=_dim_,
                    num_heads=4,
                    dropout=0.1),
                dict(
                    type='MSDeformableAttention3D',
                    embed_dims=_dim_,
                    num_heads=4,
                    num_levels=1,
                    num_points=8,
                    batch_first=False,
                    num_query=num_query,
                    num_anchor_per_query=num_pt_per_line,
                    anchor_y_steps=anchor_y_steps,
                    dropout=0.1),
                ],
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=_dim_,
                feedforward_channels=_dim_*8,
                num_fcs=2,
                ffn_drop=0.1,
                act_cfg=dict(type='ReLU', inplace=True),
            ),
            feedforward_channels=_dim_ * 8,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                            'ffn', 'norm')),
))

sparse_ins_decoder=Config(
    dict(
        encoder=dict(
            out_dims=_dim_),
        decoder=dict(
            num_query=latr_cfg['num_query'],
            num_group=latr_cfg['num_group'],
            sparse_num_group=latr_cfg['sparse_num_group'],
            hidden_dim=_dim_,
            kernel_dim=_dim_,
            num_classes=num_category,
            num_convs=4,
            output_iam=True,
            scale_factor=1.,
            ce_weight=2.0,
            mask_weight=5.0,
            dice_weight=2.0,
            objectness_weight=1.0,
        ),
        sparse_decoder_weight=5.0,
))

resize_h = 720
resize_w = 960
nepochs = 24
eval_freq = 8

optimizer_cfg = dict(
    type='AdamW',
    lr=2e-4,
    paramwise_cfg=dict(
        custom_keys={
            'sampling_offsets': dict(lr_mult=0.1),
        }),
    weight_decay=0.01)

================================================
FILE: data/Load_Data.py
================================================
import re
import os
import os.path as ops
import sys
import copy
import json
import glob
import random
import warnings
import yaml
import numpy as np
import cv2
import matplotlib
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode
from utils.utils import *
from experiments.gpu_utils import is_main_process

from .transform import PhotoMetricDistortionMultiViewImage

sys.path.append('./')
warnings.simplefilter('ignore', np.RankWarning)
matplotlib.use('Agg')

class LaneDataset(Dataset):
    """
    Dataset with labeled lanes
        This implementation considers:
        w/o laneline 3D attributes
        w/o centerline annotations
        default considers 3D laneline, including centerlines

        This new version of data loader prepare ground-truth anchor tensor in flat ground space.
        It is assumed the dataset provides accurate visibility labels. Preparing ground-truth tensor depends on it.
    """
    def __init__(self, dataset_base_dir, json_file_path, args, data_aug=False):
        """
        :param dataset_base_dir: root directory of the dataset images
        :param json_file_path: root directory of the json annotation files
        :param args: config object with dataset and anchor parameters
        :param data_aug: whether to apply data augmentation
        """
        self.totensor = transforms.ToTensor()
        mean = [0.485, 0.456, 0.406] if args.mean is None else args.mean
        std = [0.229, 0.224, 0.225] if args.std is None else args.std
        self.normalize = transforms.Normalize(mean, std)
        self.data_aug = data_aug
        if data_aug:
            if hasattr(args, 'photo_aug'):
                self.photo_aug = PhotoMetricDistortionMultiViewImage(**args.photo_aug)
            else:
                self.photo_aug = False

        self.dataset_base_dir = dataset_base_dir
        self.json_file_path = json_file_path

        # dataset parameters
        self.dataset_name = args.dataset_name
        self.num_category = args.num_category

        self.h_org = args.org_h
        self.w_org = args.org_w
        self.h_crop = args.crop_y

        # parameters related to service network
        self.h_net = args.resize_h
        self.w_net = args.resize_w
        self.u_ratio = float(self.w_net) / float(self.w_org)
        self.v_ratio = float(self.h_net) / float(self.h_org - self.h_crop)
        self.top_view_region = args.top_view_region
        self.max_lanes = args.max_lanes

        self.K = args.K
        self.H_crop = homography_crop_resize([args.org_h, args.org_w], args.crop_y, [args.resize_h, args.resize_w])

        if args.fix_cam:
            self.fix_cam = True
            # compute the homography between image and IPM, and crop transformation
            self.cam_height = args.cam_height
            self.cam_pitch = np.pi / 180 * args.pitch
            self.P_g2im = projection_g2im(self.cam_pitch, self.cam_height, args.K)
        else:
            self.fix_cam = False

        # compute anchor steps
        self.use_default_anchor = args.use_default_anchor
        
        self.x_min, self.x_max = self.top_view_region[0, 0], self.top_view_region[1, 0]
        self.y_min, self.y_max = self.top_view_region[2, 1], self.top_view_region[0, 1]
        
        self.anchor_y_steps = args.anchor_y_steps
        self.num_y_steps = len(self.anchor_y_steps)

        self.anchor_y_steps_dense = args.get(
            'anchor_y_steps_dense',
            np.linspace(3, 103, 200))
        args.anchor_y_steps_dense = self.anchor_y_steps_dense
        self.num_y_steps_dense = len(self.anchor_y_steps_dense)
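        # per-lane target layout: num_y_steps x-offsets, num_y_steps z-values,
        # num_y_steps visibility flags, plus a one-hot category vector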
        self.anchor_dim = 3 * self.num_y_steps + args.num_category
        self.save_json_path = args.save_json_path

        # parse ground-truth file
        if 'openlane' in self.dataset_name:
            label_list = glob.glob(json_file_path + '**/*.json', recursive=True)
            self._label_list = label_list
        elif 'once' in self.dataset_name:
            label_list = glob.glob(json_file_path + '*/*/*.json', recursive=True)
            self._label_list = []
            for js_label_file in label_list:
                if not os.path.getsize(js_label_file):
                    continue
                image_path = map_once_json2img(js_label_file)
                if not os.path.exists(image_path):
                    continue
                self._label_list.append(js_label_file)
        else:
            raise ValueError(
                "apollo data is handled by ApolloDataset, not LaneDataset")
        
        if hasattr(self, '_label_list'):
            self.n_samples = len(self._label_list)
        else:
            self.n_samples = self._label_image_path.shape[0]

    def preprocess_data_from_json_once(self, idx_json_file):
        _label_image_path = None
        _label_cam_height = None
        _label_cam_pitch = None
        cam_extrinsics = None
        cam_intrinsics = None
        _label_laneline_org = None
        _gt_laneline_category_org = None

        image_path = map_once_json2img(idx_json_file)

        assert ops.exists(image_path), '{:s} not exist'.format(image_path)
        _label_image_path = image_path

        with open(idx_json_file, 'r') as file:
            file_lines = [line for line in file]
            if len(file_lines) != 0:
                info_dict = json.loads(file_lines[0])
            else:
                print('Empty label_file:', idx_json_file)
                return

            if not self.fix_cam:
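                # ONCE labels carry no camera extrinsics, so a nominal pitch
                # (0.3 deg) and height (1.5 m) are assumed here to construct
                # a camera-to-ground transform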
                cam_pitch = 0.3/180*np.pi
                cam_height = 1.5
                cam_extrinsics = np.array([[np.cos(cam_pitch), 0, -np.sin(cam_pitch), 0],
                                            [0, 1, 0, 0],
                                            [np.sin(cam_pitch), 0,  np.cos(cam_pitch), cam_height],
                                            [0, 0, 0, 1]], dtype=float)
                R_vg = np.array([[0, 1, 0],
                                    [-1, 0, 0],
                                    [0, 0, 1]], dtype=float)
                R_gc = np.array([[1, 0, 0],
                                    [0, 0, 1],
                                    [0, -1, 0]], dtype=float)
                cam_extrinsics[:3, :3] = np.matmul(np.matmul(
                                            np.matmul(np.linalg.inv(R_vg), cam_extrinsics[:3, :3]),
                                                R_vg), R_gc)
                cam_extrinsics[0:2, 3] = 0.0

                gt_cam_height = cam_extrinsics[2, 3] 
                gt_cam_pitch = 0

                if 'calibration' in info_dict:
                    cam_intrinsics = info_dict['calibration']
                    cam_intrinsics = np.array(cam_intrinsics)
                    cam_intrinsics = cam_intrinsics[:, :3]
                else:
                    cam_intrinsics = self.K

            _label_cam_height = gt_cam_height
            _label_cam_pitch = gt_cam_pitch

            gt_lanes_packed = info_dict['lanes']
            gt_lane_pts, gt_lane_visibility, gt_laneline_category = [], [], []
            for i, gt_lane_packed in enumerate(gt_lanes_packed):
                lane = np.array(gt_lane_packed).T

                # Coordinate conversion: lift to homogeneous coordinates and
                # map camera-frame points to the ground frame via extrinsics
                lane = np.vstack((lane, np.ones((1, lane.shape[1]))))
                lane = np.matmul(cam_extrinsics, lane)

                lane = lane[0:3, :].T
                lane = lane[lane[:, 1].argsort()]  # sort so y is monotonically increasing
                gt_lane_pts.append(lane)
                gt_lane_visibility.append(1.0)
                gt_laneline_category.append(1)

        _gt_laneline_category_org = copy.deepcopy(np.array(gt_laneline_category))

        if not self.fix_cam:
            cam_K = cam_intrinsics
            if 'openlane' in self.dataset_name or 'once' in self.dataset_name:
                cam_E = cam_extrinsics
                P_g2im = projection_g2im_extrinsic(cam_E, cam_K)
                H_g2im = homograpthy_g2im_extrinsic(cam_E, cam_K)
            else:
                gt_cam_height = _label_cam_height
                gt_cam_pitch = _label_cam_pitch
                P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, cam_K)
                H_g2im = homograpthy_g2im(gt_cam_pitch, gt_cam_height, cam_K)
            H_im2g = np.linalg.inv(H_g2im)
        else:
            P_g2im = self.P_g2im
            H_im2g = self.H_im2g
        P_g2gflat = np.matmul(H_im2g, P_g2im)

        gt_lanes = gt_lane_pts
        gt_visibility = gt_lane_visibility
        gt_category = gt_laneline_category

        # prune gt lanes by visibility labels
        gt_lanes = [prune_3d_lane_by_visibility(gt_lane, gt_visibility[k]).squeeze(0) for k, gt_lane in enumerate(gt_lanes)]
        _label_laneline_org = copy.deepcopy(gt_lanes)
        return _label_image_path, _label_cam_height, _label_cam_pitch, \
               cam_extrinsics, cam_intrinsics, \
               _label_laneline_org, \
               _gt_laneline_category_org, info_dict
               #    _label_laneline, \
               #    _gt_laneline_visibility, _gt_laneline_category, \

    def preprocess_data_from_json_openlane(self, idx_json_file):
        _label_image_path = None
        _label_cam_height = None
        _label_cam_pitch = None
        cam_extrinsics = None
        cam_intrinsics = None
        # _label_laneline = None
        _label_laneline_org = None
        # _gt_laneline_visibility = None
        # _gt_laneline_category = None
        _gt_laneline_category_org = None
        # _laneline_ass_id = None

        with open(idx_json_file, 'r') as file:
            file_lines = [line for line in file]
            info_dict = json.loads(file_lines[0])

            image_path = ops.join(self.dataset_base_dir, info_dict['file_path'])
            assert ops.exists(image_path), '{:s} not exist'.format(image_path)
            _label_image_path = image_path

            if not self.fix_cam:
                cam_extrinsics = np.array(info_dict['extrinsic'])
                # Re-calculate extrinsic matrix based on ground coordinate
                R_vg = np.array([[0, 1, 0],
                                    [-1, 0, 0],
                                    [0, 0, 1]], dtype=float)
                R_gc = np.array([[1, 0, 0],
                                    [0, 0, 1],
                                    [0, -1, 0]], dtype=float)
                cam_extrinsics[:3, :3] = np.matmul(np.matmul(
                                            np.matmul(np.linalg.inv(R_vg), cam_extrinsics[:3, :3]),
                                                R_vg), R_gc)
                cam_extrinsics[0:2, 3] = 0.0
                
                # gt_cam_height = info_dict['cam_height']
                gt_cam_height = cam_extrinsics[2, 3]
                if 'cam_pitch' in info_dict:
                    gt_cam_pitch = info_dict['cam_pitch']
                else:
                    gt_cam_pitch = 0

                if 'intrinsic' in info_dict:
                    cam_intrinsics = info_dict['intrinsic']
                    cam_intrinsics = np.array(cam_intrinsics)
                else:
                    cam_intrinsics = self.K  

            _label_cam_height = gt_cam_height
            _label_cam_pitch = gt_cam_pitch

            gt_lanes_packed = info_dict['lane_lines']
            gt_lane_pts, gt_lane_visibility, gt_laneline_category = [], [], []
            for i, gt_lane_packed in enumerate(gt_lanes_packed):
                # A GT lane can be either 2D or 3D
                # if a GT lane is 3D, the height is intact from 3D GT, so keep it intact here too
                lane = np.array(gt_lane_packed['xyz'])
                lane_visibility = np.array(gt_lane_packed['visibility'])

                # Coordinate conversion for openlane_300 data
                lane = np.vstack((lane, np.ones((1, lane.shape[1]))))
                cam_representation = np.linalg.inv(
                                        np.array([[0, 0, 1, 0],
                                                    [-1, 0, 0, 0],
                                                    [0, -1, 0, 0],
                                                    [0, 0, 0, 1]], dtype=float))  # transformation from apollo camera to openlane camera
                lane = np.matmul(cam_extrinsics, np.matmul(cam_representation, lane))

                lane = lane[0:3, :].T
                gt_lane_pts.append(lane)
                gt_lane_visibility.append(lane_visibility)

                if 'category' in gt_lane_packed:
                    lane_cate = gt_lane_packed['category']
                    if lane_cate == 21:  # merge left and right road edge into road edge
                        lane_cate = 20
                    gt_laneline_category.append(lane_cate)
                else:
                    gt_laneline_category.append(1)
        
        # _label_laneline_org = copy.deepcopy(gt_lane_pts)
        _gt_laneline_category_org = copy.deepcopy(np.array(gt_laneline_category))

        gt_lanes = gt_lane_pts
        gt_visibility = gt_lane_visibility
        gt_category = gt_laneline_category

        # prune gt lanes by visibility labels
        gt_lanes = [prune_3d_lane_by_visibility(gt_lane, gt_visibility[k]) for k, gt_lane in enumerate(gt_lanes)]
        _label_laneline_org = copy.deepcopy(gt_lanes)

        return _label_image_path, _label_cam_height, _label_cam_pitch, \
               cam_extrinsics, cam_intrinsics, \
               _label_laneline_org, \
               _gt_laneline_category_org, info_dict

    def __len__(self):
        """
        Conventional len method
        """
        return self.n_samples

    # new getitem, WIP
    def WIP__getitem__(self, idx):
        """
        Args: idx (int): Index in list to load image
        """
        extra_dict = {}

        idx_json_file = self._label_list[idx]
        # preprocess data from json file
        if 'openlane' in self.dataset_name:
            _label_image_path, _label_cam_height, _label_cam_pitch, \
            cam_extrinsics, cam_intrinsics, \
            _label_laneline_org, \
            _gt_laneline_category_org, info_dict = self.preprocess_data_from_json_openlane(idx_json_file)
        elif 'once' in self.dataset_name:
            _label_image_path, _label_cam_height, _label_cam_pitch, \
            cam_extrinsics, cam_intrinsics, \
            _label_laneline_org, \
            _gt_laneline_category_org, info_dict = self.preprocess_data_from_json_once(idx_json_file)

        # fetch camera height and pitch
        if not self.fix_cam:
            gt_cam_height = _label_cam_height
            gt_cam_pitch = _label_cam_pitch
            if 'openlane' in self.dataset_name or 'once' in self.dataset_name:
                intrinsics = cam_intrinsics
                extrinsics = cam_extrinsics
            else:
                # should not be used
                intrinsics = self.K
                extrinsics = np.zeros((3,4))
                extrinsics[2,3] = gt_cam_height
        else:
            gt_cam_height = self.cam_height
            gt_cam_pitch = self.cam_pitch
            # should not be used
            intrinsics = self.K
            extrinsics = np.zeros((3,4))
            extrinsics[2,3] = gt_cam_height

        img_name = _label_image_path
        with open(img_name, 'rb') as f:
            image = (Image.open(f).convert('RGB'))

        # image preprocess with crop and resize
        image = F.crop(image, self.h_crop, 0, self.h_org-self.h_crop, self.w_org)
        image = F.resize(image, size=(self.h_net, self.w_net), interpolation=InterpolationMode.BILINEAR)

        gt_category_2d = _gt_laneline_category_org
        if self.data_aug:
            img_rot, aug_mat = data_aug_rotate(image)
            if self.photo_aug:
                img_rot = self.photo_aug(
                    dict(img=img_rot.copy().astype(np.float32))
                )['img']
            image = Image.fromarray(
                np.clip(img_rot, 0, 255).astype(np.uint8))
        image = self.totensor(image).float()
        image = self.normalize(image)
        intrinsics = torch.from_numpy(intrinsics)
        extrinsics = torch.from_numpy(extrinsics)

        # prepare binary segmentation label map
        seg_label = np.zeros((self.h_net, self.w_net), dtype=np.int8)
        # seg idx has the same order as gt_lanes
        seg_idx_label = np.zeros((self.max_lanes, self.h_net, self.w_net), dtype=np.uint8)
        ground_lanes = np.zeros((self.max_lanes, self.anchor_dim), dtype=np.float32)
        ground_lanes_dense = np.zeros(
            (self.max_lanes, self.num_y_steps_dense * 3), dtype=np.float32)
        gt_lanes = _label_laneline_org # ground
        gt_laneline_img = [[0]] * len(gt_lanes)

        H_g2im, P_g2im, H_crop = self.transform_mats_impl(cam_extrinsics, \
                                            cam_intrinsics, _label_cam_pitch, _label_cam_height)
        M = np.matmul(H_crop, P_g2im)
        # update transformation with image augmentation
        if self.data_aug:
            M = np.matmul(aug_mat, M)

        lidar2img = np.eye(4).astype(np.float32)
        lidar2img[:3] = M

        SEG_WIDTH = 80
        thickness = int(SEG_WIDTH / 2650 * self.h_net / 2)

        for i, lane in enumerate(gt_lanes):
            if i >= self.max_lanes:
                break

            if lane.shape[0] <= 2:
                continue

            if _gt_laneline_category_org[i] >= self.num_category:
                continue

            xs, zs = resample_laneline_in_y(lane, self.anchor_y_steps)
            vis = np.logical_and(
                self.anchor_y_steps > lane[:, 1].min() - 5,
                self.anchor_y_steps < lane[:, 1].max() + 5)

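            # ground_lanes layout per lane: [x @ anchor ys | z @ anchor ys |
            # 0/1 visibility | one-hot category]; anchor_dim = 3*num_y_steps + num_category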
            ground_lanes[i][0: self.num_y_steps] = xs
            ground_lanes[i][self.num_y_steps:2*self.num_y_steps] = zs
            ground_lanes[i][2*self.num_y_steps:3*self.num_y_steps] = vis * 1.0
            ground_lanes[i][self.anchor_dim - self.num_category] = 0.0
            ground_lanes[i][self.anchor_dim - self.num_category + _gt_laneline_category_org[i]] = 1.0

            xs_dense, zs_dense = resample_laneline_in_y(
                lane, self.anchor_y_steps_dense)
            vis_dense = np.logical_and(
                self.anchor_y_steps_dense > lane[:, 1].min(),
                self.anchor_y_steps_dense < lane[:, 1].max())
            ground_lanes_dense[i][0: self.num_y_steps_dense] = xs_dense
            ground_lanes_dense[i][1*self.num_y_steps_dense: 2*self.num_y_steps_dense] = zs_dense
            ground_lanes_dense[i][2*self.num_y_steps_dense: 3*self.num_y_steps_dense] = vis_dense * 1.0

            x_2d, y_2d = projective_transformation(M, lane[:, 0],
                                                   lane[:, 1], lane[:, 2])
            gt_laneline_img[i] = np.array([x_2d, y_2d]).T.tolist()
            for j in range(len(x_2d) - 1):
                seg_label = cv2.line(seg_label,
                                     (int(x_2d[j]), int(y_2d[j])), (int(x_2d[j+1]), int(y_2d[j+1])),
                                     color=1,
                                     thickness=thickness)
                seg_idx_label[i] = cv2.line(
                    seg_idx_label[i],
                    (int(x_2d[j]), int(y_2d[j])), (int(x_2d[j+1]), int(y_2d[j+1])),
                    color=gt_category_2d[i].item(),
                    thickness=thickness)

        seg_label = torch.from_numpy(seg_label.astype(np.float32))
        seg_label.unsqueeze_(0)
        extra_dict['seg_label'] = seg_label
        extra_dict['seg_idx_label'] = seg_idx_label
        extra_dict['ground_lanes'] = ground_lanes
        extra_dict['ground_lanes_dense'] = ground_lanes_dense
        extra_dict['lidar2img'] = lidar2img
        extra_dict['pad_shape'] = torch.Tensor(seg_idx_label.shape[-2:]).float()
        extra_dict['idx_json_file'] = idx_json_file
        extra_dict['image'] = image
        if self.data_aug:
            aug_mat = torch.from_numpy(aug_mat.astype(np.float32))
            extra_dict['aug_mat'] = aug_mat
        return extra_dict

    # default __getitem__: delegates to the WIP implementation above
    def __getitem__(self, idx):
        """
        Args: idx (int): Index in list to load image
        """
        return self.WIP__getitem__(idx)

    def transform_mats_impl(self, cam_extrinsics, cam_intrinsics, cam_pitch, cam_height):
        if not self.fix_cam:
            if 'openlane' in self.dataset_name or 'once' in self.dataset_name:
                H_g2im = homograpthy_g2im_extrinsic(cam_extrinsics, cam_intrinsics)
                P_g2im = projection_g2im_extrinsic(cam_extrinsics, cam_intrinsics)
            else:
                H_g2im = homograpthy_g2im(cam_pitch, cam_height, self.K)
                P_g2im = projection_g2im(cam_pitch, cam_height, self.K)
            return H_g2im, P_g2im, self.H_crop
        else:
            return self.H_g2im, self.P_g2im, self.H_crop
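
    # Naming note: H_g2im / P_g2im map ground-plane coordinates to image pixels
    # (a homography for points on the ground, a full projection for 3D points);
    # H_crop accounts for the crop-and-resize applied to form the network input.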

def make_lane_y_mono_inc(lane):
    """
        Due to the loss of the height dimension, lanes projected onto the flat ground plane may not have monotonically increasing y.
        This function keeps only the points whose y increases monotonically and outputs the pruned lane.
    :param lane:
    :return:
    """
    idx2del = []
    max_y = lane[0, 1]
    for i in range(1, lane.shape[0]):
        # hard-code a minimum y step so the far-away, near-horizontal tail gets pruned
        if lane[i, 1] <= max_y + 3:
            idx2del.append(i)
        else:
            max_y = lane[i, 1]
    lane = np.delete(lane, idx2del, 0)
    return lane
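
# Worked example: for lane ys [5, 20, 21, 40], the y=21 point is dropped
# (21 <= 20 + 3), leaving ys [5, 20, 40], i.e. each kept y advances by more than the step.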

def data_aug_rotate(img):
    # assume img is a PIL image
    # note: cv2.getRotationMatrix2D expects the angle in degrees, while `rot`
    # here is sampled in radians (about +/-0.17), so the effective rotation is
    # well under one degree
    rot = random.uniform(-np.pi/18, np.pi/18)
    center_x = img.width / 2
    center_y = img.height / 2
    rot_mat = cv2.getRotationMatrix2D((center_x, center_y), rot, 1.0)
    img_rot = np.array(img)
    img_rot = cv2.warpAffine(img_rot, rot_mat, (img.width, img.height), flags=cv2.INTER_LINEAR)
    rot_mat = np.vstack([rot_mat, [0, 0, 1]])
    return img_rot, rot_mat


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def get_loader(transformed_dataset, args):
    """
        create a DataLoader over the ground-truth dataset;
        in distributed mode, also return the batch sampler
    """

    # transformed_dataset = LaneDataset(dataset_base_dir, json_file_path, args)
    sample_idx = range(transformed_dataset.n_samples)

    g = torch.Generator()
    g.manual_seed(0)

    discarded_sample_start = len(sample_idx) // args.batch_size * args.batch_size
    if is_main_process():
        print("Discarding images:")
        if hasattr(transformed_dataset, '_label_image_path'):
            print(transformed_dataset._label_image_path[discarded_sample_start: len(sample_idx)])
        else:
            print(len(sample_idx) - discarded_sample_start)
    sample_idx = sample_idx[0 : discarded_sample_start]
    
    if args.dist:
        if is_main_process():
            print('use distributed sampler')
        if 'standard' in args.dataset_name or 'rare_subset' in args.dataset_name or 'illus_chg' in args.dataset_name:
            data_sampler = torch.utils.data.distributed.DistributedSampler(transformed_dataset, shuffle=True, drop_last=True)
            data_loader = DataLoader(transformed_dataset,
                                        batch_size=args.batch_size, 
                                        sampler=data_sampler,
                                        num_workers=args.nworkers, 
                                        pin_memory=True,
                                        persistent_workers=args.nworkers > 0,
                                        worker_init_fn=seed_worker,
                                        generator=g,
                                        drop_last=True)
        else:
            data_sampler = torch.utils.data.distributed.DistributedSampler(transformed_dataset)
            data_loader = DataLoader(transformed_dataset,
                                        batch_size=args.batch_size, 
                                        sampler=data_sampler,
                                        num_workers=args.nworkers, 
                                        pin_memory=True,
                                        persistent_workers=args.nworkers > 0,
                                        worker_init_fn=seed_worker,
                                        generator=g)
    else:
        if is_main_process():
            print("use default sampler")
        data_sampler = torch.utils.data.sampler.SubsetRandomSampler(sample_idx)
        data_loader = DataLoader(transformed_dataset,
                                batch_size=args.batch_size, sampler=data_sampler,
                                num_workers=args.nworkers, pin_memory=True,
                                persistent_workers=args.nworkers > 0,
                                worker_init_fn=seed_worker,
                                generator=g)

    if args.dist:
        return data_loader, data_sampler
    return data_loader
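
# Usage sketch (hypothetical `args`; in distributed mode the sampler is returned
# too, so the training loop can call sampler.set_epoch(epoch) each epoch):
#   dataset = LaneDataset(dataset_base_dir, json_file_path, args)
#   loader, sampler = get_loader(dataset, args)   # args.dist == True
#   loader = get_loader(dataset, args)            # args.dist == False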

def map_once_json2img(json_label_file):
    if 'train' in json_label_file:
        split_name = 'train'
    elif 'val' in json_label_file:
        split_name = 'val'
    elif 'test' in json_label_file:
        split_name = 'test'
    else:
        raise ValueError("train/val/test not in the json path")
    image_path = json_label_file.replace(split_name, 'data').replace('.json', '.jpg')
    return image_path
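
# Example (a sketch; note str.replace substitutes every occurrence, so this
# assumes the split name appears only once in the path):
#   map_once_json2img('once/val/000027/frame.json') -> 'once/data/000027/frame.jpg'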


================================================
FILE: data/__init__.py
================================================


================================================
FILE: data/apollo_dataset.py
================================================
# ==============================================================================
# Copyright (c) 2022 The PersFormer Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import re
import os
import sys
import copy
import json
import glob
import random
import pickle
import warnings
from pathlib import Path
import numpy as np
from numpy import int32#, result_type
import cv2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode
from utils.utils import *
sys.path.append('./')
warnings.simplefilter('ignore', np.RankWarning)
matplotlib.use('Agg')

from .transform import PhotoMetricDistortionMultiViewImage 

from tqdm import tqdm


class ApolloLaneDataset(Dataset):
    def __init__(self, dataset_base_dir, json_file_path, args, data_aug=False, **kwargs):
        # define image pre-processor
        self.totensor = transforms.ToTensor()
        # expect same mean/std for all torchvision models
        mean = [0.485, 0.456, 0.406] if args.mean is None else args.mean
        std = [0.229, 0.224, 0.225] if args.std is None else args.std
        self.normalize = transforms.Normalize(mean, std)

        self.data_aug = data_aug
        if data_aug:
            if hasattr(args, 'photo_aug'):
                self.photo_aug = PhotoMetricDistortionMultiViewImage(**args.photo_aug)
            else:
                self.photo_aug = False
        
        self.dataset_base_dir = dataset_base_dir
        self.json_file_path = json_file_path

        # dataset parameters
        self.dataset_name = args.dataset_name
        self.num_category = args.num_category

        self.h_org = args.org_h
        self.w_org = args.org_w
        self.h_crop = args.crop_y

        # parameters related to service network
        self.h_net = args.resize_h
        self.w_net = args.resize_w
        self.ipm_h = args.ipm_h
        self.ipm_w = args.ipm_w
        self.u_ratio = float(self.w_net) / float(self.w_org)
        self.v_ratio = float(self.h_net) / float(self.h_org - self.h_crop)
        self.top_view_region = args.top_view_region
        
        self.max_lanes = args.max_lanes

        self.K = args.K
        self.H_crop = homography_crop_resize([args.org_h, args.org_w], args.crop_y, [args.resize_h, args.resize_w])
        self.fix_cam = False
        
        self.x_min, self.x_max = self.top_view_region[0, 0], self.top_view_region[1, 0]
        self.y_min, self.y_max = self.top_view_region[2, 1], self.top_view_region[0, 1]
        
        self.anchor_y_steps = args.anchor_y_steps
        self.num_y_steps = len(self.anchor_y_steps)

        self.anchor_y_steps_dense = args.get(
            'anchor_y_steps_dense',
            np.linspace(3, 103, 200))
        args.anchor_y_steps_dense = self.anchor_y_steps_dense
        self.num_y_steps_dense = len(self.anchor_y_steps_dense)

        self.anchor_dim = 3 * self.num_y_steps + args.num_category

        self.save_json_path = args.save_json_path

        # parse ground-truth file
        self.processed_info_dict = None

        self.label_list = self.gen_single_file_json()
        self.n_samples = len(self.label_list)
        self.processed_info_dict = self.init_dataset_3D(dataset_base_dir, json_file_path)

    def gen_single_file_json(self):
        gt_labels_json_dict = [json.loads(line) for line in open(self.json_file_path, 'r').readlines()]
        
        # e.g., xxx/standard/train
        json_save_dir = self.json_file_path.split('.json')[0]

        mkdir_if_missing(json_save_dir)
        
        label_list_path = self.json_file_path.rsplit('/', 1)[0] + '/%s_json_list.txt' % self.json_file_path.rsplit('/', 1)[-1].split('.json')[0]

        if os.path.isfile(label_list_path):
            with open(label_list_path, 'r') as f:
                label_list = f.readlines()
                label_list = list(map(lambda x: os.path.join(json_save_dir, x.strip()), label_list))
        else:
            label_list = []
            
            for single_info in tqdm(gt_labels_json_dict):
                img_p = Path(single_info['raw_file'])
                json_dir = os.path.join(json_save_dir, img_p.parent.name)
                mkdir_if_missing(json_dir)
                json_p = os.path.join(json_dir, img_p.stem + '.json')
                single_info['file_path'] = json_p.split(json_save_dir + '/')[-1]
                json.dump(single_info, open(json_p, 'w'), separators=(',', ': '), indent=4)
                label_list.append(json_p)
        
            with open(label_list_path, 'w') as f:
                for label_js in label_list:
                    f.write(label_js)
                    f.write('\n')
        
        return label_list
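
    # Layout written by gen_single_file_json (a sketch of its side effects):
    #   <json_file_path minus '.json'>/<image parent dir>/<image stem>.json   one json per frame
    #   <same dir as json_file_path>/<split name>_json_list.txt               list of those json paths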
    
        
    def parse_processed_info_dict_apollo(self, idx):
        keys = self.processed_info_dict.keys()
        keys = list(keys)
        res = []
        for k in keys:
            res.append(self.processed_info_dict[k][idx])
        return res


    def __len__(self):
        """
        Conventional len method
        """
        return len(self.label_list)

    # new getitem, WIP
    def WIP__getitem__(self, idx):
        """
        Args: idx (int): Index in list to load image
        """
        
        # preprocess data from json file

        _label_image_path, _label_cam_height, _label_cam_pitch, \
        cam_extrinsics, cam_intrinsics, \
        _label_laneline, _label_laneline_org, \
        _gt_laneline_visibility, _gt_laneline_category, \
        _gt_laneline_category_org, gt_laneline_img = self.parse_processed_info_dict_apollo(idx)
        
        if not self.fix_cam:
            gt_cam_height = _label_cam_height
            gt_cam_pitch = _label_cam_pitch
            intrinsics = cam_intrinsics
            extrinsics = cam_extrinsics
        else:
            raise ValueError('check release with training, fix_cam=False')
        img_name = _label_image_path

        with open(img_name, 'rb') as f:
            image = (Image.open(f).convert('RGB'))

        # image preprocess with crop and resize
        image = F.crop(image, self.h_crop, 0, self.h_org-self.h_crop, self.w_org)
        image = F.resize(image, size=(self.h_net, self.w_net), interpolation=InterpolationMode.BILINEAR)

        gt_category_2d = _gt_laneline_category_org

        if self.data_aug:
            img_rot, aug_mat = data_aug_rotate(image)
            if self.photo_aug:
                img_rot = self.photo_aug(
                    dict(img=img_rot.copy().astype(np.float32))
                )['img']
            
            image = Image.fromarray(np.clip(img_rot, 0, 255).astype(np.uint8))
        image = self.totensor(image).float()
        image = self.normalize(image)
        gt_cam_height = torch.tensor(gt_cam_height, dtype=torch.float32)
        gt_cam_pitch = torch.tensor(gt_cam_pitch, dtype=torch.float32)
        intrinsics = torch.from_numpy(intrinsics)
        extrinsics = torch.from_numpy(extrinsics)

        # prepare binary segmentation label map
        seg_label = np.zeros((self.h_net, self.w_net), dtype=np.uint8)
        seg_idx_label = np.zeros((self.max_lanes, self.h_net, self.w_net), dtype=np.uint8)
        ground_lanes = np.zeros((self.max_lanes, self.anchor_dim), dtype=np.float32)
        ground_lanes_dense = np.zeros(
            (self.max_lanes, self.num_y_steps_dense * 3), dtype=np.float32)
        
        gt_lanes = _label_laneline_org
        H_g2im, P_g2im, H_crop = self.transform_mats_impl(_label_cam_pitch, 
                                                                    _label_cam_height)
        M = np.matmul(H_crop, P_g2im)
        # update transformation with image augmentation
        if self.data_aug:
            M = np.matmul(aug_mat, M)
        
        lidar2img = np.eye(4).astype(np.float32)
        lidar2img[:3] = M
            
        SEG_WIDTH = 80
        thickness_st = int(SEG_WIDTH / 2550 * self.h_net / 2)

        for i, lane in enumerate(gt_lanes):
            if i >= self.max_lanes:
                break

            # TODO remove this
            if lane.shape[0] < 2:
                continue

            if _gt_laneline_category_org[i] > self.num_category:
                continue

            xs, zs = resample_laneline_in_y(lane, self.anchor_y_steps)
            vis = np.logical_and(
                self.anchor_y_steps > lane[:, 1].min() - 5,
                self.anchor_y_steps < lane[:, 1].max() + 5)

            ground_lanes[i][0: self.num_y_steps] = xs
            ground_lanes[i][self.num_y_steps:2*self.num_y_steps] = zs
            ground_lanes[i][2*self.num_y_steps:3*self.num_y_steps] = vis * 1.0
            ground_lanes[i][self.anchor_dim - self.num_category] = 0.0
            ground_lanes[i][self.anchor_dim - self.num_category + 1] = 1.0

            xs_dense, zs_dense = resample_laneline_in_y(
                lane, self.anchor_y_steps_dense)
            vis_dense = np.logical_and(
                self.anchor_y_steps_dense > lane[:, 1].min(),
                self.anchor_y_steps_dense < lane[:, 1].max())
            ground_lanes_dense[i][0: self.num_y_steps_dense] = xs_dense
            ground_lanes_dense[i][1*self.num_y_steps_dense: 2*self.num_y_steps_dense] = zs_dense
            ground_lanes_dense[i][2*self.num_y_steps_dense: 3*self.num_y_steps_dense] = vis_dense * 1.0

            x_2d, y_2d = projective_transformation(M, 
                                                   lane[:, 0],
                                                   lane[:, 1], 
                                                   lane[:, 2])
            
            for j in range(len(x_2d) - 1):
                # empirical setting: thin the drawn line towards the top of the
                # image (smaller y_2d), i.e. farther from the camera
                k = 2.7e-2 - ((2.5e-2 - 5e-5) / 600) * y_2d[j]
                thickness = max(round(thickness_st - k * (self.h_net - y_2d[j])), 2)
                if thickness >= 6:
                    thickness += 1

                seg_label = cv2.line(seg_label,
                                    (int(x_2d[j]), int(y_2d[j])), 
                                    (int(x_2d[j+1]), int(y_2d[j+1])),
                                    color=1,
                                    thickness=thickness)
                seg_idx_label[i] = cv2.line(seg_idx_label[i],
                                        (int(x_2d[j]), int(y_2d[j])),
                                        (int(x_2d[j+1]), int(y_2d[j+1])),
                                        color=gt_category_2d[i].item(),
                                        thickness=thickness,
                                        lineType=cv2.LINE_AA
                                        )

        seg_label = torch.from_numpy(seg_label.astype(np.float32))
        seg_label.unsqueeze_(0)
        
        extra_dict = {}
        
        extra_dict['seg_label'] = seg_label
        extra_dict['seg_idx_label'] = seg_idx_label
        extra_dict['ground_lanes'] = ground_lanes
        extra_dict['ground_lanes_dense'] = ground_lanes_dense
        extra_dict['lidar2img'] = lidar2img
        extra_dict['pad_shape'] = torch.Tensor(seg_idx_label.shape[-2:]).float()
        extra_dict['idx_json_file'] = self.label_list[idx]

        extra_dict['image'] = image
        if self.data_aug:
            aug_mat = torch.from_numpy(aug_mat.astype(np.float32))
            extra_dict['aug_mat'] = aug_mat
        
        extra_dict['cam_extrinsics'] = cam_extrinsics
        extra_dict['cam_intrinsics'] = cam_intrinsics
        return extra_dict

    # default __getitem__: delegates to the WIP implementation above
    def __getitem__(self, idx):
        """
        Args: idx (int): Index in list to load image
        """
        return self.WIP__getitem__(idx)

    def init_dataset_3D(self, dataset_base_dir, json_file_path):
        """
        :param dataset_base_dir: dataset root directory
        :param json_file_path: path to the ground-truth label file
        :return: image paths, labels in unnormalized net input coordinates

        data processing:
        ground-truth label maps are scaled w.r.t. the network input size
        """

        # load image path, and lane pts
        label_image_path = []
        gt_laneline_pts_all = []
        gt_centerline_pts_all = []
        gt_laneline_visibility_all = []
        gt_centerline_visibility_all = []
        gt_laneline_category_all = []
        gt_cam_height_all = []
        gt_cam_pitch_all = []

        assert ops.exists(json_file_path), '{:s} does not exist'.format(json_file_path)
        
        with open(json_file_path, 'r') as file:
            for idx, line in enumerate(file):
                
                info_dict = json.loads(line)
                # print('load json : %s | %s' % (idx, info_dict['raw_file']))
                image_path = ops.join(dataset_base_dir, info_dict['raw_file'])
                assert ops.exists(image_path), '{:s} does not exist'.format(image_path)

                label_image_path.append(image_path)

                gt_lane_pts = info_dict['laneLines']
                gt_lane_visibility = info_dict['laneLines_visibility']
                for i, lane in enumerate(gt_lane_pts):
                    lane = np.array(lane)
                    gt_lane_pts[i] = lane
                    gt_lane_visibility[i] = np.array(gt_lane_visibility[i])
                
                gt_laneline_pts_all.append(gt_lane_pts)
                gt_laneline_visibility_all.append(gt_lane_visibility)
                
                if 'category' in info_dict:
                    gt_laneline_category = info_dict['category']
                    gt_laneline_category_all.append(np.array(gt_laneline_category, dtype=np.int32))
                else:
                    gt_laneline_category_all.append(np.ones(len(gt_lane_pts), dtype=np.int32))

                if not self.fix_cam:
                    gt_cam_height = info_dict['cam_height']
                    gt_cam_height_all.append(gt_cam_height)
                    gt_cam_pitch = info_dict['cam_pitch']
                    gt_cam_pitch_all.append(gt_cam_pitch)
        
        label_image_path = np.array(label_image_path)
        gt_cam_height_all = np.array(gt_cam_height_all)
        gt_cam_pitch_all = np.array(gt_cam_pitch_all)
        gt_laneline_pts_all_org = copy.deepcopy(gt_laneline_pts_all)
        gt_laneline_category_all_org = copy.deepcopy(gt_laneline_category_all)
        
        visibility_all_flat = []
        gt_laneline_im_all = []
        gt_centerline_im_all = []
        cam_extrinsics_all = []
        cam_intrinsics_all = []
        for idx in range(len(gt_laneline_pts_all)):
            # fetch camera height and pitch
            gt_cam_height = gt_cam_height_all[idx]
            gt_cam_pitch = gt_cam_pitch_all[idx]
            if not self.fix_cam:
                P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, self.K)
                H_g2im = homograpthy_g2im(gt_cam_pitch, gt_cam_height, self.K)
                H_im2g = np.linalg.inv(H_g2im)
            else:
                P_g2im = self.P_g2im
                H_im2g = self.H_im2g

            gt_lanes = gt_laneline_pts_all[idx]
            gt_visibility = gt_laneline_visibility_all[idx]

            # prune gt lanes by visibility labels
            gt_lanes = [prune_3d_lane_by_visibility(gt_lane, gt_visibility[k]) for k, gt_lane in enumerate(gt_lanes)]
            gt_laneline_pts_all_org[idx] = gt_lanes
            
            # project gt laneline to image plane
            gt_laneline_im = []
            for gt_lane in gt_lanes:
                x_vals, y_vals = projective_transformation(P_g2im, gt_lane[:,0], gt_lane[:,1], gt_lane[:,2])
                gt_laneline_im_oneline = np.array([x_vals, y_vals]).T.tolist()
                gt_laneline_im.append(gt_laneline_im_oneline)
            gt_laneline_im_all.append(gt_laneline_im)

            # generate ex/in from apollo
            cam_intrinsics = self.K
            cam_extrinsics = np.zeros((4,4))
            cam_extrinsics[-1, -1] = 1
            cam_extrinsics[2,3] = gt_cam_height
            cam_extrinsics_all.append(cam_extrinsics)
            cam_intrinsics_all.append(cam_intrinsics)
        
        visibility_all_flat = np.array(visibility_all_flat)
        
        processed_info_dict = {}
        processed_info_dict['label_json_path'] = label_image_path
        
        processed_info_dict['gt_cam_height_all'] = gt_cam_height_all
        processed_info_dict['gt_cam_pitch_all'] = gt_cam_pitch_all
        processed_info_dict['cam_extrinsics'] = cam_extrinsics_all
        processed_info_dict['cam_intrinsics'] = cam_intrinsics_all

        processed_info_dict['gt_laneline_pts_all'] = gt_laneline_pts_all
        processed_info_dict['gt_laneline_pts_all_org'] = gt_laneline_pts_all_org
        processed_info_dict['gt_laneline_visibility_all'] = gt_laneline_visibility_all

        processed_info_dict['gt_laneline_category_all'] = gt_laneline_category_all
        processed_info_dict['gt_laneline_category_all_org'] = gt_laneline_category_all_org
        processed_info_dict['gt_laneline_im_all'] = gt_laneline_im_all
        return processed_info_dict

    def transform_mats_impl(self, cam_pitch, cam_height):
        if not self.fix_cam:
            H_g2im = homograpthy_g2im(cam_pitch, cam_height, self.K)
            P_g2im = projection_g2im(cam_pitch, cam_height, self.K)
            return H_g2im, P_g2im, self.H_crop
        else:
            return self.H_g2im, self.P_g2im, self.H_crop

def data_aug_rotate(img):
    # assume img is a PIL image
    # note: cv2.getRotationMatrix2D expects the angle in degrees, while `rot`
    # here is sampled in radians (about +/-0.17), so the effective rotation is
    # well under one degree
    rot = random.uniform(-np.pi/18, np.pi/18)
    center_x = img.width / 2
    center_y = img.height / 2
    rot_mat = cv2.getRotationMatrix2D((center_x, center_y), rot, 1.0)
    img_rot = np.array(img)
    img_rot = cv2.warpAffine(img_rot, rot_mat, (img.width, img.height), flags=cv2.INTER_LINEAR)
    rot_mat = np.vstack([rot_mat, [0, 0, 1]])
    return img_rot, rot_mat


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def get_loader(transformed_dataset, args):
    """
        create a DataLoader over the ground-truth dataset;
        in distributed mode, also return the batch sampler
    """

    sample_idx = range(transformed_dataset.n_samples)

    g = torch.Generator()
    g.manual_seed(0)

    discarded_sample_start = len(sample_idx) // args.batch_size * args.batch_size
    if args.proc_id == 0:
        print("Discarding images:")
        if hasattr(transformed_dataset, '_label_image_path'):
            print(transformed_dataset._label_image_path[discarded_sample_start: len(sample_idx)])
        else:
            print(len(sample_idx) - discarded_sample_start)
    sample_idx = sample_idx[0 : discarded_sample_start]
    
    if args.dist:
        if args.proc_id == 0:
            print('use distributed sampler')
        if 'standard' in args.dataset_name or 'rare_subset' in args.dataset_name or 'illus_chg' in args.dataset_name:
            data_sampler = torch.utils.data.distributed.DistributedSampler(transformed_dataset, shuffle=True, drop_last=True)
            data_loader = DataLoader(transformed_dataset,
                                        batch_size=args.batch_size, 
                                        sampler=data_sampler,
                                        num_workers=args.nworkers, 
                                        pin_memory=True,
                                        persistent_workers=True,
                                        worker_init_fn=seed_worker,
                                        generator=g,
                                        drop_last=True)
        else:
            data_sampler = torch.utils.data.distributed.DistributedSampler(transformed_dataset)
            data_loader = DataLoader(transformed_dataset,
                                        batch_size=args.batch_size, 
                                        sampler=data_sampler,
                                        num_workers=args.nworkers, 
                                        pin_memory=True,
                                        persistent_workers=True,
                                        worker_init_fn=seed_worker,
                                        generator=g)
    else:
        if args.proc_id == 0:
            print("use default sampler")
        data_sampler = torch.utils.data.sampler.SubsetRandomSampler(sample_idx)
        data_loader = DataLoader(transformed_dataset,
                                batch_size=args.batch_size, sampler=data_sampler,
                                num_workers=args.nworkers, pin_memory=True,
                                persistent_workers=True,
                                worker_init_fn=seed_worker,
                                generator=g)

    if args.dist:
        return data_loader, data_sampler
    return data_loader


================================================
FILE: data/transform.py
================================================
import numpy as np
import mmcv
import torch
import torch.nn.functional as F
import PIL
import random


def get_random_state() -> np.random.RandomState:
    return np.random.RandomState(random.randint(0, (1 << 32) - 1))


def normal(
    loc=0.0,
    scale=1.0,
    size=None,
    random_state=None,
):
    if random_state is None:
        random_state = get_random_state()
    return random_state.normal(loc, scale, size)


class PhotoMetricDistortionMultiViewImage:
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The position of random contrast is in
    second or second to last.
    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels
    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, results):
        """Call function to perform photometric distortion on images.
        Args:
            results (dict): Result dict from loading pipeline.
        Returns:
            dict: Result dict with images distorted.
        """
        imgs = results['img']
        if not isinstance(imgs, list):
            imgs = [imgs]

        new_imgs = []
        for img in imgs:
            assert img.dtype == np.float32, \
                'PhotoMetricDistortion needs the input image of dtype np.float32,'\
                ' please set "to_float32=True" in "LoadImageFromFile" pipeline'
            # random brightness
            if np.random.randint(2):
                delta = np.random.uniform(-self.brightness_delta,
                                    self.brightness_delta)
                img += delta

            # mode == 0 --> do random contrast first
            # mode == 1 --> do random contrast last
            mode = np.random.randint(2)
            if mode == 1:
                if np.random.randint(2):
                    alpha = np.random.uniform(self.contrast_lower,
                                        self.contrast_upper)
                    img *= alpha

            # convert color from BGR to HSV
            img = mmcv.bgr2hsv(img)

            # random saturation
            if np.random.randint(2):
                img[..., 1] *= np.random.uniform(self.saturation_lower,
                                            self.saturation_upper)

            # random hue
            if np.random.randint(2):
                img[..., 0] += np.random.uniform(-self.hue_delta, self.hue_delta)
                img[..., 0][img[..., 0] > 360] -= 360
                img[..., 0][img[..., 0] < 0] += 360

            # convert color from HSV to BGR
            img = mmcv.hsv2bgr(img)

            # random contrast
            if mode == 0:
                if np.random.randint(2):
                    alpha = np.random.uniform(self.contrast_lower,
                                        self.contrast_upper)
                    img *= alpha

            # randomly swap channels
            if np.random.randint(2):
                img = img[..., np.random.permutation(3)]
            new_imgs.append(img)
        if not isinstance(results['img'], list):
            new_imgs = new_imgs[0]

        results['img'] = new_imgs
        return results
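

# Usage sketch (mirrors how the dataset classes call it above; `img_bgr` is
# assumed to be a float32 BGR array):
#   aug = PhotoMetricDistortionMultiViewImage()
#   distorted = aug(dict(img=img_bgr))['img']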


================================================
FILE: docs/data_preparation.md
================================================
# Data Preparation

## OpenLane

Follow [OpenLane](https://github.com/OpenDriveLab/PersFormer_3DLane#dataset) to download the dataset and then link it under the `data` directory.

```bash
cd data && mkdir openlane && cd openlane
ln -s ${OPENLANE_PATH}/images .
ln -s ${OPENLANE_PATH}/lane3d_1000 .
```

## ONCE

Follow [ONCE](https://github.com/once-3dlanes/once_3dlanes_benchmark#data-preparation) to download the dataset and then link it under the `data` directory.

```bash
cd data
ln -s ${once} .
```

## Apollo

Follow [Apollo](https://github.com/yuliangguo/Pytorch_Generalized_3D_Lane_Detection#data-preparation) to download the dataset and link it under the `data` directory.

```bash
cd data && mkdir apollosyn_gen-lanenet
cd apollosyn_gen-lanenet
ln -s ${Apollo_Sim_3D_Lane_Release} .
ln -s ${data_splits} .
```


Your `data` directory should look like this:

```bash
|-- apollosyn_gen-lanenet
|   |-- Apollo_Sim_3D_Lane_Release
|   |   |-- depth
|   |   |-- images
|   |   |-- img_list.txt
|   |   |-- labels
|   |   |-- laneline_label.json
|   |   `-- segmentation
|   `-- data_splits
|       |-- illus_chg
|       |-- rare_subset
|       `-- standard
|-- Load_Data.py
|-- __init__.py
|-- apollo_dataset.py
|-- once -> ${once}
|   |-- annotation
|   |-- data
|   |-- data_check.py
|   |-- list
|   |-- raw_cam01
|   |-- raw_cam_multi
|   |-- train
|   `-- val
|-- openlane
|   |-- images -> ${openlane}/images/
|   |   |-- training
|   |   `-- validation
|   `-- lane3d_1000 -> ${openlane}/lane3d_1000/
|       |-- test
|       |-- training
|       `-- validation
`-- transform.py
```

================================================
FILE: docs/install.md
================================================
# Environment

It is recommended to create a new virtual environment.

## 1. Install PyTorch and requirements

```bash
# first install pytorch
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.1 -c pytorch

# then clone LATR and change directory to it to install requirements
cd ${LATR_PATH}
python -m pip install -r requirements.txt
```

## 2. Install mm packages

### 2.1 Install `mmcv`

```bash
git clone https://github.com/open-mmlab/mmcv.git
cd mmcv && git checkout v1.5.0
FORCE_CUDA=1 MMCV_WITH_OPS=1 python -m pip install .
```

### 2.2 Install other mm packages

Install [mmdet](https://github.com/open-mmlab/mmdetection) and [mmdet3d](https://github.com/open-mmlab/mmdetection3d). Note that we use `mmdet==2.24.0` and `mmdet3d==1.0.0rc3`.
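
A minimal sketch of the corresponding pip commands (the exact steps may differ depending on your CUDA / PyTorch build):

```bash
python -m pip install mmdet==2.24.0
python -m pip install mmdet3d==1.0.0rc3
```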


================================================
FILE: docs/train_eval.md
================================================
# Training and Evaluation

## Train

### OpenLane

- Base version:

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/latr_1000_baseline.py
```

- Lite version:

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/latr_1000_baseline_lite.py
```

### ONCE

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/once.py
```

### Apollo

- Balanced Scene

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/apollo_standard.py
```

- Rare Subset

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/apollo_rare.py
```

- Visual Variations

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port=29284 --nproc_per_node 4 main.py --config config/release_iccv/apollo_illu.py
```

## Evaluation

### OpenLane

- Base version:

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/latr_1000_baseline.py --cfg-options evaluate=true eval_ckpt=pretrained_models/openlane.pth
```

- Lite version:

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/latr_1000_baseline_lite.py --cfg-options evaluate=true eval_ckpt=pretrained_models/openlane_lite.pth
```

### ONCE

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/once.py --cfg-options evaluate=true eval_ckpt=pretrained_models/once.pth
```

### Apollo

- Balanced Scene

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/apollo_standard.py --cfg-options evaluate=true eval_ckpt=pretrained_models/apollo_standard.pth
```

- Rare Subset

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/apollo_rare.py --cfg-options evaluate=true eval_ckpt=pretrained_models/apollo_rare.pth
```

- Visual Variations

```bash
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port=29284 --nproc_per_node 4 main.py --config config/release_iccv/apollo_illu.py --cfg-options evaluate=true eval_ckpt=pretrained_models/apollo_illu.pth
```
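
For quick debugging, a single-GPU run follows the same launcher pattern (a sketch, assuming the conventions above; pass a free `--master_port` if another job is already running on the machine):

```bash
CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch --nproc_per_node 1 main.py --config config/release_iccv/latr_1000_baseline.py
```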

================================================
FILE: experiments/__init__.py
================================================


================================================
FILE: experiments/ddp.py
================================================
# ==============================================================================
# Copyright (c) 2022 The PersFormer Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os
import subprocess
import numpy as np
import random

def setup_dist_launch(args):
    args.proc_id = args.local_rank
    world_size = int(os.getenv('WORLD_SIZE', 1))*args.nodes
    print("proc_id: " + str(args.proc_id))
    print("world size: " + str(world_size))
    print("local_rank: " + str(args.local_rank))

    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(args.proc_id)
    os.environ['LOCAL_RANK'] = str(args.local_rank)

def setup_slurm(args):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    args.proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    local_rank = args.proc_id % num_gpus
    args.local_rank = local_rank

    print("proc_id: " + str(args.proc_id))
    print("world size: " + str(ntasks))
    print("local_rank: " + str(local_rank))

    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    os.environ['MASTER_PORT'] = str(args.port)
    os.environ['MASTER_ADDR'] = addr

    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(args.proc_id)
    os.environ['LOCAL_RANK'] = str(local_rank)

def setup_distributed(args):
    args.gpu = args.local_rank
    torch.cuda.set_device(args.gpu)
    dist.init_process_group(backend='nccl')
    args.world_size = dist.get_world_size()
    torch.set_printoptions(precision=10)
    print('args.world_size', args.world_size)

def ddp_init(args):
    args.proc_id, args.gpu, args.world_size = 0, 0, 1

    if args.use_slurm:
        setup_slurm(args)
    else:
        setup_dist_launch(args)

    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) >= 1

    if args.distributed:
        setup_distributed(args)

    # deterministic
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.proc_id)
    np.random.seed(args.proc_id)
    random.seed(args.proc_id)

def to_python_float(t):
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]

def reduce_tensor(tensor, world_size):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt


def reduce_tensors(*tensors, world_size):
    return [reduce_tensor(tensor, world_size) for tensor in tensors]
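
# Usage sketch (hypothetical tensors; averages values across all ranks, e.g. for logging):
#   loss_avg = reduce_tensor(loss.detach(), args.world_size)
#   l_cls, l_reg = reduce_tensors(l_cls, l_reg, world_size=args.world_size)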

================================================
FILE: experiments/gpu_utils.py
================================================
import torch
import torch.distributed as dist

def get_rank() -> int:
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main_process() -> bool:
    return get_rank() == 0


def gpu_available() -> bool:
    return torch.cuda.is_available()

================================================
FILE: experiments/runner.py
================================================
import torch
import torch.optim
import torch.nn as nn
import numpy as np
import glob
import time
import os
from tqdm import tqdm
from tensorboardX import SummaryWriter
import traceback
import shutil

from data.Load_Data import *
from data.apollo_dataset import ApolloLaneDataset
from models.latr import LATR
from experiments.gpu_utils import is_main_process
from utils import eval_3D_lane, eval_3D_once
from utils import eval_3D_lane_apollo
from utils.utils import *

# ddp related
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from .ddp import *
import os.path as osp
from .gpu_utils import gpu_available
from mmcv.runner.optimizer import build_optimizer


class Runner:
    def __init__(self, args):
        self.args = args
        set_work_dir(self.args)
        self.logger = create_logger(args)

        # Check GPU availability
        if is_main_process():
            if not gpu_available():
                raise Exception("No GPU available for use")
            if int(os.getenv('WORLD_SIZE', 1)) >= 1:
                self.logger.info("Let's use %s GPUs!" % os.environ['WORLD_SIZE'])
                torch.cuda.empty_cache()
        
        # Get Dataset
        if is_main_process():
            self.logger.info("Loading Dataset ...")

        self.val_gt_file = ops.join(args.save_path, 'test.json')
        if not args.evaluate:
            self.train_dataset, self.train_loader, self.train_sampler = self._get_train_dataset()
        else:
            self.train_dataset, self.train_loader, self.train_sampler = [],[],[]
        self.valid_dataset, self.valid_loader, self.valid_sampler = self._get_valid_dataset()

        if 'openlane' in args.dataset_name:
            self.evaluator = eval_3D_lane.LaneEval(args, logger=self.logger)
        elif 'apollo' in args.dataset_name:
            self.evaluator = eval_3D_lane_apollo.LaneEval(args, logger=self.logger)
        elif 'once' in args.dataset_name:
            self.evaluator = eval_3D_once.LaneEval()
        else:
            raise ValueError('unsupported dataset_name: %s' % args.dataset_name)
        # Tensorboard writer
        if not args.no_tb and is_main_process():
            tensorboard_path = os.path.join(args.save_path, 'Tensorboard/')
            mkdir_if_missing(tensorboard_path)
            self.writer = SummaryWriter(tensorboard_path)
        
        if is_main_process():
            self.logger.info("Init Done!")
        
        self.is_apollo = False
        if 'apollo' in args.dataset_name:
            self.is_apollo = True

    def train(self):
        args = self.args

        # Get Dataset
        train_loader = self.train_loader
        train_sampler = self.train_sampler

        global lowest_loss, best_f1_epoch, best_val_f1, best_epoch
        # Define model or resume
        
        model, optimizer, scheduler, best_epoch, \
            lowest_loss, best_f1_epoch, best_val_f1 = self._get_model_ddp()
        
        self._log_model_info(model)
        
        def save_cur_ckpt(
                loss,
                with_eval=True,
                eval_stats=None):
            # Save model
            if not with_eval:
                self.save_checkpoint({
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, False, epoch+1, self.args.save_path)
            else:
                total_score = loss.item() # loss_list[0].avg
                if is_main_process():
                    # File to keep latest epoch
                    with open(os.path.join(args.save_path, 'first_run.txt'), 'w') as f:
                        f.write(str(epoch + 1))
                global best_val_f1, best_f1_epoch, lowest_loss, best_epoch

                to_copy, to_save = False, True # False if args.save_best else True

                if total_score < lowest_loss:
                    best_epoch = epoch + 1
                    lowest_loss = total_score
                if eval_stats[0] > best_val_f1:
                    to_copy = True
                    best_f1_epoch = epoch + 1
                    best_val_f1 = eval_stats[0]
                    to_save = True
                self.log_eval_stats(eval_stats)
                self.logger.info("===> Last best F1 was {:.8f} in epoch {}".format(best_val_f1, best_f1_epoch))
                if not to_save:
                    return
                self.save_checkpoint({
                        'state_dict': model.module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict()
                    }, to_copy, epoch+1, self.args.save_path)

        # Start training and validation for nepochs
        for epoch in range(args.start_epoch, args.nepochs):
            if is_main_process():
                self.logger.info("\n => Start train set for EPOCH {}".format(epoch + 1))
                self.logger.info('lr is set to {}'.format(optimizer.param_groups[0]['lr']))
            
            if args.distributed:
                train_sampler.set_epoch(epoch)

            # Define container objects to keep track of multiple losses/metrics
            batch_time = AverageMeter()
            data_time = AverageMeter()         # compute FPS
            epoch_time = AverageMeter()
            
            loss = 0

            # Specify operation modules
            model.train()
            # compute timing
            end = time.time()
            epoch_time.start = end
            # Start training loop
            train_pbar = tqdm(total=len(train_loader), ncols=60)
            
            for i, extra_dict in enumerate(train_loader):
                train_pbar.update(1)
                data_time.update(time.time() - end)
                if gpu_available():
                    json_files = extra_dict.pop('idx_json_file')
                    for k, v in extra_dict.items():
                        extra_dict[k] = v.cuda()
                    image = extra_dict['image']
                image = image.contiguous().float()
                # Run model
                optimizer.zero_grad()

                output = model(image=image, extra_dict=extra_dict, is_training=True)
                
                loss, loss_info = self._log_training_loss(
                    output, epoch, step=i, data_loader=train_loader)

                train_pbar.set_postfix(loss=loss.item())
                
                if is_main_process():
                    self.writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
                
                # Setup backward pass
                loss.backward()

                # Clip gradients (useful against instabilities or mistakes in the ground truth)
                if args.clip_grad_norm != 0:
                    nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)

                # update params
                optimizer.step()

                if args.lr_policy == 'cosine_warmup':
                    scheduler.step(epoch + i / len(train_loader))
                elif args.lr_policy == 'PolyLR':
                    scheduler.step()

                # Time the training iteration
                batch_time.update(time.time() - end)
                end = time.time()

                # Print info
                if (i + 1) % args.print_freq == 0 and is_main_process():
                    self.logger.info('Epoch: [{0}][{1}/{2}]\t'
                        'Batch Time / Avg Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss:.8f} {loss_info}'.format(
                            epoch+1, i+1, len(train_loader), 
                            batch_time=batch_time, data_time=data_time,
                            loss=loss.item(), loss_info=loss_info))
            train_pbar.close()

            epoch_time.update(time.time() - epoch_time.start)

            if is_main_process():
                self.logger.info('Epoch time : {:.3f} hours.'.format(epoch_time.val / 60 / 60))
            
            # Adjust learning rate
            if args.lr_policy != 'cosine_warmup':
                scheduler.step()
            
            meet_eval_freq = args.eval_freq > 0 and (epoch + 1) % args.eval_freq == 0
            last_ep = (epoch == args.nepochs - 1)

            if meet_eval_freq or last_ep:
                loss_valid_list, eval_stats = self.validate(model)
                if eval_stats[0] >= best_val_f1:
                    self.logger.info(' >>> to save new best model at ep : %s with F1 %s' % ((epoch+1), eval_stats[0]))
                    save_cur_ckpt(loss, with_eval=True, eval_stats=eval_stats)
                elif last_ep:
                    self.logger.info(' >>> to save the last model at ep : %s with F1 %s' % ((epoch+1), eval_stats[0]))
                    save_cur_ckpt(loss, with_eval=True, eval_stats=eval_stats)
                else:
                    self.logger.info(' >>> skip model at ep : %s with lower F1 : %s' % ((epoch+1), eval_stats[0]))
            
                self.log_eval_stats(eval_stats)

            dist.barrier()
            torch.cuda.empty_cache()

        # at the end of training
        if not args.no_tb and is_main_process():
            self.writer.close()

    def _log_model_info(self, model):
        args = self.args
        if not is_main_process():
            return
        
        self.logger.info(40*"="+"\nArgs:{}\n".format(args)+40*"=")
        self.logger.info("Init model: '{}'".format(args.mod))
        self.logger.info("Number of parameters in model {} is {:.3f}M".format(args.mod, sum(tensor.numel() for tensor in model.parameters())/1e6))

    def _log_training_loss(self, output, epoch, step, data_loader):
        loss = 0.0
        loss_info = ''
        for k, v in output.items():
            if 'loss' in k:
                loss = loss + v
                loss_info = loss_info + '| %s:%.4f ' % (k, v.item() if isinstance(v, torch.Tensor) else v)
                if isinstance(v, torch.Tensor):
                    v = v.item()
                if is_main_process():
                    self.writer.add_scalar(k, v, epoch*len(data_loader) + step)
        return loss, loss_info
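
    # Example: if output = {'loss_cls': t1, 'loss_reg': t2, 'preds': ...}, then
    # `loss` is t1 + t2 and loss_info reads "| loss_cls:0.xxxx | loss_reg:0.xxxx ".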

    def save_checkpoint(self, state, to_copy, epoch, save_path):
        if is_main_process():
            self.logger.info('Saving checkpoint to {}'.format(save_path))

            if to_copy:
                file_pre = f'model_best_epoch_{epoch}.pth.tar'
                self.logger.info('save the best model : %s' % epoch)
            else:
                file_pre = f'checkpoint_model_epoch_{epoch}.pth.tar'

            filepath = os.path.join(save_path, file_pre)
            torch.save(state, filepath)

    def validate(self, model, **kwargs):
        args = self.args
        loader = self.valid_loader
        
        pred_lines_sub = []
        gt_lines_sub = []

        model.eval()

        # Start validation loop
        with torch.no_grad():
            val_pbar = tqdm(total=len(loader), ncols=50)
            
            for i, extra_dict in enumerate(loader):
                val_pbar.update(1)

                json_files = extra_dict.pop('idx_json_file')
                if not args.no_cuda:
                    for k, v in extra_dict.items():
                        extra_dict[k] = v.cuda()
                image = extra_dict['image'].contiguous().float()
                
                output = model(image=image, extra_dict=extra_dict, is_training=False)
                all_line_preds = output['all_line_preds'] # in ground coordinate system
                all_cls_scores = output['all_cls_scores']

                all_line_preds = all_line_preds[-1]
                all_cls_scores = all_cls_scores[-1]
                num_el = all_cls_scores.shape[0]
                if 'cam_extrinsics' in extra_dict:
                    cam_extrinsics_all = extra_dict['cam_extrinsics']
                    cam_intrinsics_all = extra_dict['cam_intrinsics']
                else:
                    cam_extrinsics_all, cam_intrinsics_all = None, None

                # Print info
                if (i + 1) % args.print_freq == 0 and is_main_process():
                    self.logger.info('Test: [{0}/{1}]'.format(i+1, len(loader)))

                # Write results
                for j in range(num_el):
                    json_file = json_files[j]
                    if cam_extrinsics_all is not None:
                        extrinsic = cam_extrinsics_all[j].cpu().numpy()
                        intrinsic = cam_intrinsics_all[j].cpu().numpy()

                    with open(json_file, 'r') as file:
                        if 'apollo' in args.dataset_name:
                            json_line = json.loads(file.read())
                            if 'extrinsic' not in json_line:
                                json_line['extrinsic'] = extrinsic
                            if 'intrinsic' not in json_line:
                                json_line['intrinsic'] = intrinsic
                        else:
                            file_lines = [line for line in file]
                            json_line = json.loads(file_lines[0])

                    json_line['json_file'] = json_file
                    if 'once' in args.dataset_name:
                        if 'train' in json_file:
                            img_path = json_file.replace('train', 'data').replace('.json', '.jpg')
                        elif 'val' in json_file:
                            img_path = json_file.replace('val', 'data').replace('.json', '.jpg')
                        elif 'test' in json_file:
                            img_path = json_file.replace('test', 'data').replace('.json', '.jpg')
                        json_line["file_path"] = img_path

                    gt_lines_sub.append(copy.deepcopy(json_line))

                    # pred in ground
                    lane_pred = all_line_preds[j].cpu().numpy()
                    cls_pred = torch.argmax(all_cls_scores[j], dim=-1).cpu().numpy()
                    pos_lanes = lane_pred[cls_pred > 0]

                    if self.args.num_category > 1:
                        scores_pred = torch.softmax(all_cls_scores[j][cls_pred > 0], dim=-1).cpu().numpy()
                    else:
                        scores_pred = torch.sigmoid(all_cls_scores[j][cls_pred > 0]).cpu().numpy()

                    if pos_lanes.shape[0]:
                        lanelines_pred = []
                        lanelines_prob = []
                        xs = pos_lanes[:, 0:args.num_y_steps]
                        ys = np.tile(args.anchor_y_steps.copy()[None, :], (xs.shape[0], 1))
                        zs = pos_lanes[:, args.num_y_steps:2*args.num_y_steps]
                        vis = pos_lanes[:, 2*args.num_y_steps:]

                        for tmp_idx in range(pos_lanes.shape[0]):
                            cur_vis = vis[tmp_idx] > 0
                            cur_xs = xs[tmp_idx][cur_vis]
                            cur_ys = ys[tmp_idx][cur_vis]
                            cur_zs = zs[tmp_idx][cur_vis]

                            if cur_vis.sum() < 2:
                                continue

                            lanelines_pred.append([])
                            for tmp_inner_idx in range(cur_xs.shape[0]):
                                lanelines_pred[-1].append(
                                    [cur_xs[tmp_inner_idx],
                                     cur_ys[tmp_inner_idx],
                                     cur_zs[tmp_inner_idx]])
                            lanelines_prob.append(scores_pred[tmp_idx].tolist())
                    else:
                        lanelines_pred = []
                        lanelines_prob = []

                    json_line["pred_laneLines"] = lanelines_pred
                    json_line["pred_laneLines_prob"] = lanelines_prob

                    pred_lines_sub.append(copy.deepcopy(json_line))
                    img_path = json_line['file_path']
                    
                    if args.dataset_name == 'once':
                        self.save_eval_result_once(args, img_path, lanelines_pred, lanelines_prob)
            val_pbar.close()

            if 'openlane' in args.dataset_name:
                eval_stats = self.evaluator.bench_one_submit_ddp(
                    pred_lines_sub, gt_lines_sub, args.model_name,
                    args.pos_threshold, vis=False)
            elif 'once' in args.dataset_name:
                eval_stats = self.evaluator.lane_evaluation(
                    args.data_dir + 'val', '%s/once_pred/test' % (args.save_path),
                    args.eval_config_dir, args)
            elif 'apollo' in args.dataset_name:
                self.logger.info(' >>> eval mAP | [0.05, 0.95]')
                eval_stats = self.evaluator.bench_one_submit_ddp(
                    pred_lines_sub, gt_lines_sub,
                    args.model_name, args.pos_threshold, vis=False)
            else:
                raise ValueError('unsupported dataset_name: %s' % args.dataset_name)
                
            if any(name in args.dataset_name for name in ['openlane', 'apollo']):
                gather_output = [None for _ in range(args.world_size)]
                # all_gather all eval_stats and calculate mean
                dist.all_gather_object(gather_output, eval_stats)
                dist.barrier()
                eval_stats = self._recal_gpus_val(gather_output, eval_stats)

                loss_list = []
                return loss_list, eval_stats
            elif 'once' in args.dataset_name:
                loss_list = []
                return loss_list, eval_stats

    def _recal_gpus_val(self, gather_output, eval_stats):
        args = self.args

        apollo_metrics = {
            'r_lane': 0, 
            'p_lane': 0, 
            'cnt_gt': 0, 
            'cnt_pred': 0
        }
        openlane_metrics = {
            'r_lane': 0, 
            'p_lane': 0, 
            'c_lane': 0, 
            'cnt_gt': 0, 
            'cnt_pred': 0,
            'match_num': 0
        }

        if 'apollo' in self.args.dataset_name:
            # apollo no category accuracy.
            start_idx = 7
            gather_metrics = apollo_metrics
        else:
            start_idx = 8
            gather_metrics = openlane_metrics
        
        for i, k in enumerate(gather_metrics.keys()):
            gather_metrics[k] = np.sum(
                [eval_stats_sub[start_idx + i] for eval_stats_sub in gather_output])

        # micro-average across GPUs, guarding against empty counts
        Recall = gather_metrics['r_lane'] / max(gather_metrics['cnt_gt'], 1e-6)
        Precision = gather_metrics['p_lane'] / max(gather_metrics['cnt_pred'], 1e-6)
        f1_score = 2 * Recall * Precision / max(Recall + Precision, 1e-6)
        
        if 'apollo' not in self.args.dataset_name:
            category_accuracy = gather_metrics['c_lane'] / max(gather_metrics['match_num'], 1e-6)
        
        eval_stats[0] = f1_score
        eval_stats[1] = Recall
        eval_stats[2] = Precision
        if self.is_apollo:
            err_start_idx = 3
        else:
            eval_stats[3] = category_accuracy
            err_start_idx = 4
        for i in range(4):
            err_idx = err_start_idx + i
            eval_stats[err_idx] = np.sum([eval_stats_sub[err_idx] for eval_stats_sub in gather_output]) / args.world_size
        return eval_stats
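
    # Illustrative sketch (not part of the training flow; counts are made up):
    # the reduction above is a micro-average -- per-GPU match counts are
    # summed first and the ratios taken once, which is not the same as
    # averaging per-GPU F1 scores.
    @staticmethod
    def _example_micro_f1():
        per_gpu = [dict(r_lane=8, p_lane=9, cnt_gt=10, cnt_pred=12),
                   dict(r_lane=5, p_lane=5, cnt_gt=10, cnt_pred=8)]
        recall = sum(s['r_lane'] for s in per_gpu) / max(sum(s['cnt_gt'] for s in per_gpu), 1e-6)
        precision = sum(s['p_lane'] for s in per_gpu) / max(sum(s['cnt_pred'] for s in per_gpu), 1e-6)
        return 2 * recall * precision / max(recall + precision, 1e-6)  # ~0.674 here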

    def _get_model_from_cfg(self):
        args = self.args
        model = LATR(args)
        
        if args.sync_bn:
            if is_main_process():
                self.logger.info("Convert model with Sync BatchNorm")
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            
        if gpu_available():
            device = torch.device("cuda", args.local_rank)
            model = model.to(device)

        return model

    def _load_ckpt_from_workdir(self, model):
        args = self.args
        if args.eval_ckpt:
            best_file_name = args.eval_ckpt
        else:
            best_file_name = glob.glob(os.path.join(args.save_path, 'model_best*'))
            if len(best_file_name) > 0:
                best_file_name = best_file_name[0]
            else:
                best_file_name = ''
        if os.path.isfile(best_file_name):
            checkpoint = torch.load(best_file_name, map_location='cpu')
            if is_main_process():
                self.logger.info("=> loading checkpoint '{}'".format(best_file_name))
                # rank 0 loads the weights; the DDP wrapper created afterwards
                # broadcasts rank-0 parameters and buffers to the other ranks
                model.load_state_dict(checkpoint['state_dict'])
        else:
            self.logger.info("=> no checkpoint found at '{}'".format(best_file_name))

    def eval(self):
        self.logger.info('>>>>>  start eval <<<<< \n')
        args = self.args
        
        model = self._get_model_from_cfg()
        self._load_ckpt_from_workdir(model)

        dist.barrier()
        # DDP setting
        if args.distributed:
            model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
        _, eval_stats = self.validate(model)

        if is_main_process() and (eval_stats is not None):
            self.log_eval_stats(eval_stats)

    def _get_train_dataset(self):
        args = self.args
        if 'openlane' in args.dataset_name:
            train_dataset = LaneDataset(args.dataset_dir, args.data_dir + 'training/', args, data_aug=True)

        elif 'once' in args.dataset_name:
            train_dataset = LaneDataset(args.dataset_dir, ops.join(args.data_dir, 'train/'), args, data_aug=True)
        else:
            self.logger.info('using Apollo Dataset')
            train_dataset = ApolloLaneDataset(args.dataset_dir, ops.join(args.data_dir, 'train.json'), args, data_aug=True)
        
        train_loader, train_sampler = get_loader(train_dataset, args)

        return train_dataset, train_loader, train_sampler

    def _get_model_ddp(self):
        args = self.args
        # define network
        model = LATR(args)
        
        # if args.sync_bn:
        if is_main_process():
            self.logger.info("Convert model with Sync BatchNorm")
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        
        if gpu_available():
            # Load model on gpu before passing params to optimizer
            device = torch.device("cuda", args.local_rank)
            model = model.to(device)

        """
            first load param to model, then model = DDP(model)
        """

        # resume model
        args.resume = first_run(args.save_path)

        model, best_epoch, lowest_loss, best_f1_epoch, best_val_f1, \
            optim_saved_state, schedule_saved_state = self.resume_model(model)
        dist.barrier()
        # DDP setting
        if args.distributed:
            model = DDP(
                model, device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True
            )

        # Define optimizer and scheduler
        optimizer = build_optimizer(
            model,
            args.optimizer_cfg)
        scheduler = define_scheduler(
            optimizer, args, dataset_size=len(self.train_loader))

        return model, optimizer, scheduler, best_epoch, lowest_loss, best_f1_epoch, best_val_f1

    def resume_model(self, model, path=''):
        args = self.args
        
        best_epoch = 0
        lowest_loss = np.inf
        best_f1_epoch = 0
        best_val_f1 = -1e-5
        optim_saved_state = None
        schedule_saved_state = None
            
        if len(path) == 0 and args.resume:
            # try the latest ckpt
            path = os.path.join(args.save_path, 'checkpoint_model_epoch_{}.pth.tar'.format(int(args.resume)))
            # try the best ckpt saved
            if not os.path.isfile(path):
                path = os.path.join(args.save_path, f'model_best_epoch_{args.resume}.pth.tar')
            
        if os.path.isfile(path):
            self.logger.info("=> loading checkpoint from {}".format(path))
            checkpoint = torch.load(path, map_location='cpu')
            args.start_epoch = int(args.resume)
            if is_main_process():
                # rank 0 loads the weights; the DDP wrapper created afterwards
                # broadcasts rank-0 parameters and buffers to the other ranks
                model.load_state_dict(checkpoint['state_dict'])
                self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, args.start_epoch))

            optim_saved_state = checkpoint['optimizer']
            schedule_saved_state = checkpoint['scheduler']
        else:
            if is_main_process():
                self.logger.info("=> Warning: no checkpoint found at '{}'".format(path))
            
        return model, best_epoch, lowest_loss, best_f1_epoch, best_val_f1, optim_saved_state, schedule_saved_state

    def _get_valid_dataset(self):
        args = self.args
        if 'openlane' in args.dataset_name:
            if not args.evaluate_case:
                valid_dataset = LaneDataset(args.dataset_dir, args.data_dir + 'validation/', args)
            else:
                # TODO eval case
                valid_dataset = LaneDataset(args.dataset_dir, args.data_dir + 'test/up_down_case/', args)

        elif 'once' in args.dataset_name:
            valid_dataset = LaneDataset(args.dataset_dir, ops.join(args.data_dir, 'val/'), args)
        else:
            valid_dataset = ApolloLaneDataset(args.dataset_dir, os.path.join(args.data_dir, 'test.json'), args)

        valid_loader, valid_sampler = get_loader(valid_dataset, args)
        return valid_dataset, valid_loader, valid_sampler

    def save_eval_result_once(self, args, img_path, lanelines_pred, lanelines_prob):
        # 3d eval result
        result = {}
        result_dir = os.path.join(args.save_path, 'once_pred/')
        mkdir_if_missing(result_dir)
        result_dir = os.path.join(result_dir, 'test/')
        mkdir_if_missing(result_dir)
        file_path_splited = img_path.split('/')
        result_dir = os.path.join(result_dir, file_path_splited[-3]) # sequence
        mkdir_if_missing(result_dir)
        result_dir = os.path.join(result_dir, 'cam01/')
        mkdir_if_missing(result_dir)
        result_file_path = ops.join(result_dir, file_path_splited[-1][:-4]+'.json')

        # ONCE results are written in camera coordinates: rebuild the fixed
        # ground->camera transform used for this dataset (pitch 0.3 deg,
        # height 1.5 m), including the R_vg / R_gc axis-convention changes
        cam_pitch = 0.3/180*np.pi
        cam_height = 1.5
        cam_extrinsics = np.array([[np.cos(cam_pitch), 0, -np.sin(cam_pitch), 0],
                                    [0, 1, 0, 0],
                                    [np.sin(cam_pitch), 0,  np.cos(cam_pitch), cam_height],
                                    [0, 0, 0, 1]], dtype=float)
        R_vg = np.array([[0, 1, 0],
                            [-1, 0, 0],
                            [0, 0, 1]], dtype=float)
        R_gc = np.array([[1, 0, 0],
                            [0, 0, 1],
                            [0, -1, 0]], dtype=float)
        cam_extrinsics[:3, :3] = np.matmul(np.matmul(
                                    np.matmul(np.linalg.inv(R_vg), cam_extrinsics[:3, :3]),
                                        R_vg), R_gc)
        cam_extrinsics[0:2, 3] = 0.0

        # write lane result
        lane_lines = []
        for k in range(len(lanelines_pred)):
            lane = np.array(lanelines_pred[k])
            lane = np.flip(lane, axis=0)
            lane = lane.T
            lane = np.vstack((lane, np.ones((1, lane.shape[1]))))
            lane = np.matmul(np.linalg.inv(cam_extrinsics), lane)
            lane = lane[0:3,:].T
            lane_lines.append({'points': lane.tolist(),
                               'score': np.max(lanelines_prob[k])})
        result['lanes'] = lane_lines

        with open(result_file_path, 'w') as result_file:
            json.dump(result, result_file)

    def log_eval_stats(self, eval_stats):
        if self.is_apollo:
            return self._log_genlane_eval_info(eval_stats)

        if is_main_process():
            self.logger.info("===> Evaluation laneline F-measure: {:.8f}".format(eval_stats[0]))
            self.logger.info("===> Evaluation laneline Recall: {:.8f}".format(eval_stats[1]))
            self.logger.info("===> Evaluation laneline Precision: {:.8f}".format(eval_stats[2]))
            self.logger.info("===> Evaluation laneline Category Accuracy: {:.8f}".format(eval_stats[3]))
            self.logger.info("===> Evaluation laneline x error (close): {:.8f} m".format(eval_stats[4]))
            self.logger.info("===> Evaluation laneline x error (far): {:.8f} m".format(eval_stats[5]))
            self.logger.info("===> Evaluation laneline z error (close): {:.8f} m".format(eval_stats[6]))
            self.logger.info("===> Evaluation laneline z error (far): {:.8f} m".format(eval_stats[7]))

    def _log_genlane_eval_info(self, eval_stats):
        if is_main_process():
            self.logger.info("===> Evaluation on validation set: \n"
                "laneline F-measure {:.8} \n"
                "laneline Recall  {:.8} \n"
                "laneline Precision  {:.8} \n"
                "laneline x error (close)  {:.8} m\n"
                "laneline x error (far)  {:.8} m\n"
                "laneline z error (close)  {:.8} m\n"
                "laneline z error (far)  {:.8} m\n".format(eval_stats[0], eval_stats[1],
                                                            eval_stats[2], eval_stats[3],
                                                            eval_stats[4], eval_stats[5],
                                                            eval_stats[6]))


def set_work_dir(cfg):
    # =========output path========== #
    save_prefix = osp.join(os.getcwd(), 'work_dirs')
    save_root = osp.join(save_prefix, cfg.output_dir)

    # cur work dirname
    cfg_path = Path(cfg.config)

    if cfg.mod is None:
        cfg.mod = os.path.join(cfg_path.parent.name, cfg_path.stem)
    
    save_ppath = Path(save_root, cfg.mod)
    save_ppath.mkdir(parents=True, exist_ok=True)

    cfg.save_path = save_ppath.as_posix()
    cfg.save_json_path = cfg.save_path
    
    seg_output_dir = Path(cfg.save_path, 'seg_vis')
    seg_output_dir.mkdir(parents=True, exist_ok=True)

    # cp config into cur_work_dir
    shutil.copy(cfg_path.as_posix(), cfg.save_path)
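
# Resulting layout (output_dir name assumed for illustration): with
# --config config/release_iccv/latr_1000_baseline.py and cfg.output_dir set
# to 'openlane', cfg.save_path resolves to
#   <cwd>/work_dirs/openlane/release_iccv/latr_1000_baseline/
# which receives a copy of the config file plus a seg_vis/ subdirectory.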
    

================================================
FILE: main.py
================================================
import argparse
from mmcv.utils import Config, DictAction

from utils.utils import *
from experiments.ddp import *
from experiments.runner import *


def get_args():
    parser = argparse.ArgumentParser()
    # DDP setting
    parser.add_argument('--distributed', action='store_true')
    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--nodes', type=int, default=1)
    parser.add_argument('--use_slurm', default=False, action='store_true')

    # exp setting
    parser.add_argument('--config', type=str, help='config file path')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='overwrite config param.')
    return parser.parse_args()


def main():
    args = get_args()
    # define runner to begin training or evaluation
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # initialize distributed data parallel set
    ddp_init(args)
    cfg.merge_from_dict(vars(args))
    
    runner = Runner(cfg)
    if not cfg.evaluate:
        runner.train()
    else:
        runner.eval()


if __name__ == '__main__':
    main()
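
# The --local_rank flag indicates this script is meant to be spawned once per
# GPU by torch.distributed.launch. A typical invocation (an assumption here;
# see docs/train_eval.md for the authoritative commands):
#   python -m torch.distributed.launch --nproc_per_node=8 main.py \
#       --config config/release_iccv/latr_1000_baseline.py --distributed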

================================================
FILE: models/__init__.py
================================================


================================================
FILE: models/latr.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.utils import *
from mmdet3d.models import build_backbone, build_neck
from .latr_head import LATRHead
from mmcv.utils import Config
from .ms2one import build_ms2one
from .utils import deepFeatureExtractor_EfficientNet

from mmdet.models.builder import BACKBONES


# overall network
class LATR(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.no_cuda = args.no_cuda
        self.batch_size = args.batch_size
        self.num_lane_type = 1  # no centerline
        self.num_y_steps = args.num_y_steps
        self.max_lanes = args.max_lanes
        self.num_category = args.num_category
        _dim_ = args.latr_cfg.fpn_dim
        num_query = args.latr_cfg.num_query
        num_group = args.latr_cfg.num_group
        sparse_num_group = args.latr_cfg.sparse_num_group

        self.encoder = build_backbone(args.latr_cfg.encoder)
        if getattr(args.latr_cfg, 'neck', None):
            self.neck = build_neck(args.latr_cfg.neck)
        else:
            self.neck = None
        self.encoder.init_weights()
        self.ms2one = build_ms2one(args.ms2one)

        # build 2d query-based instance seg
        self.head = LATRHead(
            args=args,
            dim=_dim_,
            num_group=num_group,
            num_convs=4,
            in_channels=_dim_,
            kernel_dim=_dim_,
            position_range=args.position_range,
            top_view_region=args.top_view_region,
            positional_encoding=dict(
                type='SinePositionalEncoding',
                num_feats=_dim_// 2, normalize=True),
            num_query=num_query,
            pred_dim=self.num_y_steps,
            num_classes=args.num_category,
            embed_dims=_dim_,
            transformer=args.transformer,
            sparse_ins_decoder=args.sparse_ins_decoder,
            **args.latr_cfg.get('head', {}),
            trans_params=args.latr_cfg.get('trans_params', {})
        )

    def forward(self, image, _M_inv=None, is_training=True, extra_dict=None):
        out_featList = self.encoder(image)
        # the neck is optional (see __init__); fall back to the raw
        # encoder features when it is absent
        neck_out = self.neck(out_featList) if self.neck is not None else out_featList
        neck_out = self.ms2one(neck_out)

        output = self.head(
            dict(
                x=neck_out,
                lane_idx=extra_dict['seg_idx_label'],
                seg=extra_dict['seg_label'],
                lidar2img=extra_dict['lidar2img'],
                pad_shape=extra_dict['pad_shape'],
                ground_lanes=extra_dict['ground_lanes'] if is_training else None,
                ground_lanes_dense=extra_dict['ground_lanes_dense'] if is_training else None,
                image=image,
            ),
            is_training=is_training,
        )
        return output
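
# Overall flow of this module:
#   image -> encoder (multi-scale features) -> neck -> ms2one (single map)
#   -> LATRHead (sparse-instance 2D queries + transformer decoder -> 3D lanes).
# extra_dict supplies the seg labels, lidar2img projection and pad_shape the
# head needs; ground-truth lanes are only consumed when is_training is True.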

================================================
FILE: models/latr_head.py
================================================
import numpy as np
import math
import cv2

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from torch.nn.init import normal_

from mmcv.cnn import bias_init_with_prob
from mmdet.models.builder import build_loss
from mmdet.models.utils import build_transformer
from mmdet.core import multi_apply

from mmcv.utils import Config
from models.sparse_ins import SparseInsDecoder
from .utils import inverse_sigmoid
from .transformer_bricks import *


class LATRHead(nn.Module):
    def __init__(self, args,
                 dim=128,
                 num_group=1,
                 num_convs=4,
                 in_channels=128,
                 kernel_dim=128,
                 positional_encoding=dict(
                    type='SinePositionalEncoding',
                    num_feats=128 // 2, normalize=True),
                 num_classes=21,
                 num_query=30,
                 embed_dims=128,
                 transformer=None,
                 num_reg_fcs=2,
                 depth_num=50,
                 depth_start=3,
                 top_view_region=None,
                 position_range=[-50, 3, -10, 50, 103, 10.],
                 pred_dim=10,
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=2.0),
                 loss_reg=dict(type='L1Loss', loss_weight=2.0),
                 loss_vis=dict(type='BCEWithLogitsLoss', reduction='mean'),
                 sparse_ins_decoder=Config(
                    dict(
                        encoder=dict(
                            out_dims=64),# neck output feature channels
                        decoder=dict(
                            num_group=1,
                            output_iam=True,
                            scale_factor=1.),
                        sparse_decoder_weight=1.0,
                        )),
                 xs_loss_weight=1.0,
                 zs_loss_weight=5.0,
                 vis_loss_weight=1.0,
                 cls_loss_weight=20,
                 project_loss_weight=1.0,
                 trans_params=dict(
                     init_z=0, bev_h=250, bev_w=100),
                 pt_as_query=False,
                 num_pt_per_line=5,
                 num_feature_levels=1,
                 gt_project_h=20,
                 gt_project_w=30,
                 project_crit=dict(
                     type='SmoothL1Loss',
                     reduction='none'),
                 ):
        super().__init__()
        self.trans_params = dict(
            top_view_region=top_view_region,
            z_region=[position_range[2], position_range[5]])
        self.trans_params.update(trans_params)
        self.gt_project_h = gt_project_h
        self.gt_project_w = gt_project_w

        self.num_y_steps = args.num_y_steps
        self.register_buffer('anchor_y_steps',
            torch.from_numpy(args.anchor_y_steps).float())
        self.register_buffer('anchor_y_steps_dense',
            torch.from_numpy(args.anchor_y_steps_dense).float())

        project_crit['reduction'] = 'none'
        self.project_crit = getattr(
            nn, project_crit.pop('type'))(**project_crit)

        self.num_classes = num_classes
        self.embed_dims = embed_dims
        # points num along y-axis.
        self.code_size = pred_dim
        self.num_query = num_query
        self.num_group = num_group
        self.num_pred = transformer['decoder']['num_layers']
        self.pc_range = position_range
        self.xs_loss_weight = xs_loss_weight
        self.zs_loss_weight = zs_loss_weight
        self.vis_loss_weight = vis_loss_weight
        self.cls_loss_weight = cls_loss_weight
        self.project_loss_weight = project_loss_weight

        loss_reg['reduction'] = 'none'
        self.reg_crit = build_loss(loss_reg)
        self.cls_crit = build_loss(loss_cls)
        self.bce_loss = build_nn_loss(loss_vis)
        self.sparse_ins = SparseInsDecoder(cfg=sparse_ins_decoder)

        self.depth_num = depth_num
        self.position_dim = 3 * self.depth_num
        self.position_range = position_range
        self.depth_start = depth_start
        self.adapt_pos3d = nn.Sequential(
            nn.Conv2d(self.embed_dims, self.embed_dims*4, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(self.embed_dims*4, self.embed_dims, kernel_size=1, stride=1, padding=0),
        )
        self.positional_encoding = build_positional_encoding(positional_encoding)
        self.position_encoder = nn.Sequential(
            nn.Conv2d(self.position_dim, self.embed_dims*4, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(self.embed_dims*4, self.embed_dims, kernel_size=1, stride=1, padding=0),
        )
        self.transformer = build_transformer(transformer)
        self.query_embedding = nn.Sequential(
            nn.Linear(self.embed_dims, self.embed_dims),
            nn.ReLU(),
            nn.Linear(self.embed_dims, self.embed_dims),
        )

        # build pred layer: cls, reg, vis
        self.num_reg_fcs = num_reg_fcs
        cls_branch = []
        for _ in range(self.num_reg_fcs):
            cls_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
            cls_branch.append(nn.LayerNorm(self.embed_dims))
            cls_branch.append(nn.ReLU(inplace=True))
        cls_branch.append(nn.Linear(self.embed_dims, self.num_classes))
        fc_cls = nn.Sequential(*cls_branch)

        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU())
        reg_branch.append(
            nn.Linear(
                self.embed_dims,
                3 * self.code_size // num_pt_per_line))
        reg_branch = nn.Sequential(*reg_branch)

        # NOTE: the same head instances are repeated in the lists below, so
        # the cls/reg branches share parameters across all decoder layers
        self.cls_branches = nn.ModuleList(
            [fc_cls for _ in range(self.num_pred)])
        self.reg_branches = nn.ModuleList(
            [reg_branch for _ in range(self.num_pred)])

        self.num_pt_per_line = num_pt_per_line
        self.point_embedding = nn.Embedding(
            self.num_pt_per_line, self.embed_dims)

        self.reference_points = nn.Sequential(
            nn.Linear(self.embed_dims, self.embed_dims),
            nn.ReLU(True),
            nn.Linear(self.embed_dims, self.embed_dims),
            nn.ReLU(True),
            nn.Linear(self.embed_dims, 2 * self.code_size // num_pt_per_line))
        self.num_feature_levels = num_feature_levels
        self.level_embeds = nn.Parameter(torch.Tensor(
            self.num_feature_levels, self.embed_dims))

        self._init_weights()

    def _init_weights(self):
        self.transformer.init_weights()
        xavier_init(self.reference_points, distribution='uniform', bias=0)
        if self.cls_crit.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            for m in self.cls_branches:
                nn.init.constant_(m[-1].bias, bias_init)
        normal_(self.level_embeds)

    def forward(self, input_dict, is_training=True):
        output_dict = {}
        img_feats = input_dict['x']

        if not isinstance(img_feats, (list, tuple)):
            img_feats = [img_feats]

        sparse_output = self.sparse_ins(
            img_feats[0],
            lane_idx_map=input_dict['lane_idx'],
            input_shape=input_dict['seg'].shape[-2:],
            is_training=is_training)
        # generate 2d pos emb
        B, C, H, W = img_feats[0].shape
        masks = img_feats[0].new_zeros((B, H, W))

        # TODO use actual mask if using padding or other aug
        sin_embed = self.positional_encoding(masks)
        sin_embed = self.adapt_pos3d(sin_embed)

        # init queries and reference points: each of the N instance queries
        # from the sparse-instance branch is expanded with num_pt_per_line
        # learned point embeddings, giving N * num_pt_per_line decoder queries
        query = sparse_output['inst_features']  # B x N x C
        # B, N, C -> B, N, num_pt_per_line, C
        query = query.unsqueeze(2) + self.point_embedding.weight[None, None, ...]

        query_embeds = self.query_embedding(query).flatten(1, 2)
        query = torch.zeros_like(query_embeds)
        reference_points = self.reference_points(query_embeds)
        reference_points = reference_points.sigmoid()
        mlvl_feats = img_feats

        feat_flatten = []
        spatial_shapes = []
        mlvl_masks = []

        assert self.num_feature_levels == len(mlvl_feats)
        for lvl, feat in enumerate(mlvl_feats):
            bs, c, h, w = feat.shape
            spatial_shape = (h, w)
            feat = feat.flatten(2).permute(2, 0, 1) # NxBxC
            feat = feat + self.level_embeds[None, lvl:lvl+1, :].to(feat.device)
            spatial_shapes.append(spatial_shape)
            feat_flatten.append(feat)
            mlvl_masks.append(torch.zeros((bs, *spatial_shape),
                                           dtype=torch.bool,
                                           device=feat.device))

        if self.transformer.with_encoder:
            mlvl_positional_encodings = []
            pos_embed2d = []
            for lvl, feat in enumerate(mlvl_feats):
                mlvl_positional_encodings.append(
                    self.positional_encoding(mlvl_masks[lvl]))
                pos_embed2d.append(
                    mlvl_positional_encodings[-1].flatten(2).permute(2, 0, 1))
            pos_embed2d = torch.cat(pos_embed2d, 0)
        else:
            mlvl_positional_encodings = None
            pos_embed2d = None

        feat_flatten = torch.cat(feat_flatten, 0)

        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=query.device)
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1, )),
             spatial_shapes.prod(1).cumsum(0)[:-1])
        )

        # head
        pos_embed = None
        outs_dec, project_results, outputs_classes, outputs_coords = \
            self.transformer(
                feat_flatten, None,
                query, query_embeds, pos_embed,
                reference_points=reference_points,
                reg_branches=self.reg_branches,
                cls_branches=self.cls_branches,
                img_feats=img_feats,
                lidar2img=input_dict['lidar2img'],
                pad_shape=input_dict['pad_shape'],
                sin_embed=sin_embed,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                mlvl_masks=mlvl_masks,
                mlvl_positional_encodings=mlvl_positional_encodings,
                pos_embed2d=pos_embed2d,
                image=input_dict['image'],
                **self.trans_params)

        all_cls_scores = torch.stack(outputs_classes)
        all_line_preds = torch.stack(outputs_coords)
        # de-normalize from sigmoid space: channel 0 is x (pc_range[0]..[3]),
        # channel 1 is z (pc_range[2]..[5]); y comes from the fixed anchor steps
        all_line_preds[..., 0] = (all_line_preds[..., 0]
            * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0])
        all_line_preds[..., 1] = (all_line_preds[..., 1]
            * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2])

        # reshape to original format
        all_line_preds = all_line_preds.view(
            len(outputs_classes), bs, self.num_query,
            self.transformer.decoder.num_anchor_per_query,
            self.transformer.decoder.num_points_per_anchor, 2 + 1 # xz+vis
        )
        all_line_preds = all_line_preds.permute(0, 1, 2, 5, 3, 4)
        all_line_preds = all_line_preds.flatten(3, 5)

        output_dict.update({
            'all_cls_scores': all_cls_scores,
            'all_line_preds': all_line_preds,
        })
        output_dict.update(sparse_output)

        if is_training:
            losses = self.get_loss(output_dict, input_dict)
            project_loss = self.get_project_loss(
                project_results, input_dict,
                h=self.gt_project_h, w=self.gt_project_w)
            losses['project_loss'] = \
                self.project_loss_weight * project_loss
            output_dict.update(losses)
        return output_dict
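
    # After the permute/flatten above, each lane prediction is laid out as
    # [x_0..x_{K-1} | z_0..z_{K-1} | vis_0..vis_{K-1}] with K = num_y_steps;
    # the validation loop in experiments/runner.py slices it exactly this way.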

    def get_project_loss(self, results, input_dict, h=20, w=30):
        gt_lane = input_dict['ground_lanes_dense']
        gt_ys = self.anchor_y_steps_dense.clone()
        code_size = gt_ys.shape[0]
        gt_xs = gt_lane[..., :code_size]
        gt_zs = gt_lane[..., code_size : 2*code_size]
        gt_vis = gt_lane[..., 2*code_size:3*code_size]
        gt_ys = gt_ys[None, None, :].expand_as(gt_xs)
        gt_points = torch.stack([gt_xs, gt_ys, gt_zs], dim=-1)

        B = results[0].shape[0]
        ref_3d_home = F.pad(gt_points, (0, 1), value=1)
        coords_img = ground2img(
            ref_3d_home,
            h, w,
            input_dict['lidar2img'],
            input_dict['pad_shape'], mask=gt_vis)

        all_loss = 0.
        for project_result in results:
            project_result = F.interpolate(
                project_result,
                size=(h, w),
                mode='nearest')
            gt_proj = coords_img.clone()

            mask = (gt_proj[:, -1, ...] > 0) * (project_result[:, -1, ...] > 0)
            diff_loss = self.project_crit(
                project_result[:, :3, ...],
                gt_proj[:, :3, ...],
            )
            diff_y_loss = diff_loss[:, 1, ...]
            diff_z_loss = diff_loss[:, 2, ...]
            diff_loss = diff_y_loss * 0.1 + diff_z_loss
            diff_loss = (diff_loss * mask).sum() / torch.clamp(mask.sum(), 1)
            all_loss = all_loss + diff_loss

        return all_loss / len(results)

    def get_loss(self, output_dict, input_dict):
        all_cls_pred = output_dict['all_cls_scores']
        all_lane_pred = output_dict['all_line_preds']
        gt_lanes = input_dict['ground_lanes']
        all_xs_loss = 0.0
        all_zs_loss = 0.0
        all_vis_loss = 0.0
        all_cls_loss = 0.0
        matched_indices = output_dict['matched_indices']
        num_layers = all_lane_pred.shape[0]

        def single_layer_loss(layer_idx):
            gcls_pred = all_cls_pred[layer_idx]
            glane_pred = all_lane_pred[layer_idx]

            glane_pred = glane_pred.view(
                glane_pred.shape[0],
                self.num_group,
                self.num_query,
                glane_pred.shape[-1])
            gcls_pred = gcls_pred.view(
                gcls_pred.shape[0],
                self.num_group,
                self.num_query,
                gcls_pred.shape[-1])

            per_xs_loss = 0.0
            per_zs_loss = 0.0
            per_vis_loss = 0.0
            per_cls_loss = 0.0
            batch_size = len(matched_indices[0])

            for b_idx in range(batch_size):
                for group_idx in range(self.num_group):
                    pred_idx = matched_indices[group_idx][b_idx][0]
                    gt_idx = matched_indices[group_idx][b_idx][1]

                    cls_pred = gcls_pred[:, group_idx, ...]
                    lane_pred = glane_pred[:, group_idx, ...]

                    if gt_idx.shape[0] < 1:
                        cls_target = cls_pred.new_zeros(cls_pred[b_idx].shape[0]).long()
                        cls_loss = self.cls_crit(cls_pred[b_idx], cls_target)
                        per_cls_loss = per_cls_loss + cls_loss
                        per_xs_loss = per_xs_loss + 0.0 * lane_pred[b_idx].mean()
                        continue

                    pos_lane_pred = lane_pred[b_idx][pred_idx]
                    gt_lane = gt_lanes[b_idx][gt_idx]

                    pred_xs = pos_lane_pred[:, :self.code_size]
                    pred_zs = pos_lane_pred[:, self.code_size : 2*self.code_size]
                    pred_vis = pos_lane_pred[:, 2*self.code_size:]
                    gt_xs = gt_lane[:, :self.code_size]
                    gt_zs = gt_lane[:, self.code_size : 2*self.code_size]
                    gt_vis = gt_lane[:, 2*self.code_size:3*self.code_size]

                    loc_mask = gt_vis > 0
                    xs_loss = self.reg_crit(pred_xs, gt_xs)
                    zs_loss = self.reg_crit(pred_zs, gt_zs)
                    xs_loss = (xs_loss * loc_mask).sum() / torch.clamp(loc_mask.sum(), 1)
                    zs_loss = (zs_loss * loc_mask).sum() / torch.clamp(loc_mask.sum(), 1)
                    vis_loss = self.bce_loss(pred_vis, gt_vis)

                    cls_target = cls_pred.new_zeros(cls_pred[b_idx].shape[0]).long()
                    cls_target[pred_idx] = torch.argmax(
                        gt_lane[:, 3*self.code_size:], dim=1)
                    cls_loss = self.cls_crit(cls_pred[b_idx], cls_target)

                    per_xs_loss += xs_loss
                    per_zs_loss += zs_loss
                    per_vis_loss += vis_loss
                    per_cls_loss += cls_loss

            return tuple(map(lambda x: x / batch_size / self.num_group,
                             [per_xs_loss, per_zs_loss, per_vis_loss, per_cls_loss]))

        all_xs_loss, all_zs_loss, all_vis_loss, all_cls_loss = multi_apply(
            single_layer_loss, range(all_lane_pred.shape[0]))
        all_xs_loss = sum(all_xs_loss) / num_layers
        all_zs_loss = sum(all_zs_loss) / num_layers
        all_vis_loss = sum(all_vis_loss) / num_layers
        all_cls_loss = sum(all_cls_loss) / num_layers

        return dict(
            all_xs_loss=self.xs_loss_weight * all_xs_loss,
            all_zs_loss=self.zs_loss_weight * all_zs_loss,
            all_vis_loss=self.vis_loss_weight * all_vis_loss,
            all_cls_loss=self.cls_loss_weight * all_cls_loss,
        )

    @staticmethod
    def get_reference_points(H, W, bs=1, device='cuda', dtype=torch.float):
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(
                0.5, H - 0.5, H, dtype=dtype, device=device),
            torch.linspace(
                0.5, W - 0.5, W, dtype=dtype, device=device)
        )
        ref_y = ref_y.reshape(-1)[None] / H
        ref_x = ref_x.reshape(-1)[None] / W
        ref_2d = torch.stack((ref_x, ref_y), -1)
        ref_2d = ref_2d.repeat(bs, 1, 1) 
        return ref_2d
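
    # Returns normalized pixel-center coordinates of shape (bs, H*W, 2) with
    # values in (0, 1); e.g. H = W = 2 yields
    # [[0.25, 0.25], [0.75, 0.25], [0.25, 0.75], [0.75, 0.75]].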

def build_nn_loss(loss_cfg):
    crit_t = loss_cfg.pop('type')
    return getattr(nn, crit_t)(**loss_cfg)
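
# Example: build_nn_loss(dict(type='BCEWithLogitsLoss', reduction='mean'))
# returns nn.BCEWithLogitsLoss(reduction='mean'). Beware that pop('type')
# mutates the cfg dict passed in.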

================================================
FILE: models/ms2one.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from mmcv.cnn import ConvModule
from mmseg.ops import resize


def build_ms2one(config):
    config = copy.deepcopy(config)
    t = config.pop('type')
    if t == 'Naive':
        return Naive(**config)
    elif t == 'DilateNaive':
        return DilateNaive(**config)
    raise ValueError(f'unknown ms2one type: {t}')


class Naive(nn.Module):
    def __init__(self, inc, outc, kernel_size=1):
        super().__init__()
        self.layer = nn.Conv2d(inc, outc, kernel_size=kernel_size,
                               padding=kernel_size // 2)

    def forward(self, ms_feats):
        out = self.layer(torch.cat([
            F.interpolate(tmp, ms_feats[0].shape[-2:],
                          mode='bilinear') for tmp in ms_feats], dim=1))
        return out


class DilateNaive(nn.Module):
    def __init__(self, inc, outc, num_scales=4,
                 dilations=(1, 2, 5, 9),
                 merge=True, fpn=False,
                 target_shape=None,
                 one_layer_before=False):
        super().__init__()
        self.dilations = dilations
        self.num_scales = num_scales
        if not isinstance(inc, (tuple, list)):
            inc = [inc for _ in range(num_scales)]
        self.inc = inc
        self.outc = outc
        self.merge = merge
        self.fpn = fpn
        self.target_shape = target_shape
        self.layers = nn.ModuleList()
        for i in range(num_scales):
            layers = []
            if one_layer_before:
                layers.extend([
                    nn.Conv2d(inc[i], outc, kernel_size=1, bias=False),
                    nn.BatchNorm2d(outc),
                    nn.ReLU(True)
                ])
            # NB: dilations[:-0] is an empty slice, so scale i == 0 gets no
            # dilated convs and passes through unchanged (its channel count
            # must then already equal outc when merge is enabled)
            for j in range(len(dilations[:-i])):
                d = dilations[j]
                layers.append(nn.Sequential(
                    nn.Conv2d(inc[i] if j == 0 and not one_layer_before else outc, outc,
                              kernel_size=1 if d == 1 else 3,
                              stride=1,
                              padding=0 if d == 1 else d,
                              dilation=d,
                              bias=False),
                    nn.BatchNorm2d(outc),
                    nn.ReLU(True)))
            self.layers.append(nn.Sequential(*layers))
        if self.merge:
            self.final_layer = nn.Sequential(
                nn.Conv2d(outc, outc, 3, 1, padding=1, bias=False),
                nn.BatchNorm2d(outc),
                nn.ReLU(True),
                nn.Conv2d(outc, outc, 1))

    def forward(self, x):
        outs = []

        for i in range(self.num_scales - 1, -1, -1):
            if self.fpn and i < self.num_scales - 1:
                tmp = self.layers[i](x[i] + F.interpolate(
                    x[i + 1], x[i].shape[2:],
                    mode='bilinear', align_corners=True))
            else:
                tmp = self.layers[i](x[i])

            if self.target_shape is None:
                if i > 0 and self.merge:
                    tmp = F.interpolate(tmp, x[0].shape[2:],
                        mode='bilinear', align_corners=True)
            else:
                tmp = F.interpolate(tmp, self.target_shape,
                        mode='bilinear', align_corners=True)
            outs.append(tmp)
        if self.merge:
            out = torch.sum(torch.stack(outs, dim=-1), dim=-1)
            out = self.final_layer(out)
            
            return out
        else:
            return outs
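
# Minimal smoke test for DilateNaive (a sketch, not exercised anywhere in the
# repo; sizes are assumed). Because of the empty dilations[:-0] slice noted
# above, the finest level passes through untouched, so this only type-checks
# when inc == outc.
if __name__ == '__main__':
    feats = [torch.randn(1, 64, 64 // 2 ** i, 96 // 2 ** i) for i in range(4)]
    m = build_ms2one(dict(type='DilateNaive', inc=64, outc=64, num_scales=4))
    print(m(feats).shape)  # torch.Size([1, 64, 64, 96])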

================================================
FILE: models/scatter_utils.py
================================================
# Copy from https://github.com/rusty1s/pytorch_scatter

from typing import Optional, Tuple

import torch


def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        for _ in range(0, dim):
            src = src.unsqueeze(0)
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand(other.size())
    return src


def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
        return out.scatter_add_(dim, index, src)
    else:
        return out.scatter_add_(dim, index, src)


def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    return scatter_sum(src, index, dim, out, dim_size)


def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)


def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    index_dim = dim
    if index_dim < 0:
        index_dim = index_dim + src.dim()
    if index.dim() <= index_dim:
        index_dim = index.dim() - 1

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, index_dim, None, dim_size)
    count[count < 1] = 1
    count = broadcast(count, out, dim)
    if out.is_floating_point():
        out.true_divide_(count)
    else:
        out.div_(count, rounding_mode='floor')
    return out


def scatter_min(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)


def scatter_max(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)


def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
            reduce: str = "sum") -> torch.Tensor:
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/add.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :attr:`dim` and by the
    corresponding value in :attr:`index` for dimension :attr:`dim`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
    tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
    and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
    tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`index` must be between :math:`0` and
    :math:`y - 1`, although no specific ordering of indices is required.
    The :attr:`index` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    .. note::

        This operation is implemented via atomic operations on the GPU and is
        therefore **non-deterministic** since the order of parallel operations
        to the same value is undetermined.
        For floating-point variables, this results in a source of variance in
        the result.

    :param src: The source tensor.
    :param index: The indices of elements to scatter.
    :param dim: The axis along which to index. (default: :obj:`-1`)
    :param out: The destination tensor.
    :param dim_size: If :attr:`out` is not given, automatically create output
        with size :attr:`dim_size` at dimension :attr:`dim`.
        If :attr:`dim_size` is not given, a minimal sized output tensor
        according to :obj:`index.max() + 1` is returned.
    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`,
        :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)

    :rtype: :class:`Tensor`

    .. code-block:: python

        from torch_scatter import scatter

        src = torch.randn(10, 6, 64)
        index = torch.tensor([0, 1, 0, 1, 2, 1])

        # Broadcasting in the first and last dim.
        out = scatter(src, index, dim=1, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    if reduce == 'sum' or reduce == 'add':
        return scatter_sum(src, index, dim, out, dim_size)
    elif reduce == 'mul':
        return scatter_mul(src, index, dim, out, dim_size)
    elif reduce == 'mean':
        return scatter_mean(src, index, dim, out, dim_size)
    elif reduce == 'min':
        return scatter_min(src, index, dim, out, dim_size)[0]
    elif reduce == 'max':
        return scatter_max(src, index, dim, out, dim_size)[0]
    else:
        raise ValueError(f"unsupported reduce: '{reduce}'")

================================================
FILE: models/sparse_ins.py
================================================
import math

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from fvcore.nn.weight_init import c2_msra_fill, c2_xavier_fill
from .sparse_inst_loss import SparseInstCriterion, SparseInstMatcher

def _make_stack_3x3_convs(num_convs, in_channels, out_channels):
    convs = []
    for _ in range(num_convs):
        convs.append(
            nn.Conv2d(in_channels, out_channels, 3, padding=1))
        convs.append(nn.ReLU(True))
        in_channels = out_channels
    return nn.Sequential(*convs)


class MaskBranch(nn.Module):
    def __init__(self, cfg, in_channels):
        super().__init__()
        dim = cfg.hidden_dim
        num_convs = cfg.num_convs
        kernel_dim = cfg.kernel_dim
        self.mask_convs = _make_stack_3x3_convs(num_convs, in_channels, dim)
        self.projection = nn.Conv2d(dim, kernel_dim, kernel_size=1)
        self._init_weights()

    def _init_weights(self):
        for m in self.mask_convs.modules():
            if isinstance(m, nn.Conv2d):
                c2_msra_fill(m)
        c2_msra_fill(self.projection)

    def forward(self, features):
        # mask features (x4 convs)
        features = self.mask_convs(features)
        return self.projection(features)


class InstanceBranch(nn.Module):
    def __init__(self, cfg, in_channels, **kwargs):
        super().__init__()
        num_mask = cfg.num_query
        dim = cfg.hidden_dim
        num_classes = cfg.num_classes
        kernel_dim = cfg.kernel_dim
        num_convs = cfg.num_convs
        num_group = cfg.get('num_group', 1)
        sparse_num_group = cfg.get('sparse_num_group', 1)
        self.num_group = num_group
        self.sparse_num_group = sparse_num_group
        self.num_mask = num_mask
        self.inst_convs = _make_stack_3x3_convs(
                            num_convs=num_convs, 
                            in_channels=in_channels, 
                            out_channels=dim)

        self.iam_conv = nn.Conv2d(
            dim * num_group,
            num_group * num_mask * sparse_num_group,
            3, padding=1, groups=num_group * sparse_num_group)
        self.fc = nn.Linear(dim * sparse_num_group, dim)
        # output
        self.mask_kernel = nn.Linear(
            dim, kernel_dim)
        self.cls_score = nn.Linear(
            dim, num_classes)
        self.objectness = nn.Linear(
            dim, 1)
        self.prior_prob = 0.01
        self._init_weights()

    def _init_weights(self):
        for m in self.inst_convs.modules():
            if isinstance(m, nn.Conv2d):
                c2_msra_fill(m)
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for module in [self.iam_conv, self.cls_score]:
            init.constant_(module.bias, bias_value)
        init.normal_(self.iam_conv.weight, std=0.01)
        init.normal_(self.cls_score.weight, std=0.01)

        init.normal_(self.mask_kernel.weight, std=0.01)
        init.constant_(self.mask_kernel.bias, 0.0)
        c2_xavier_fill(self.fc)

    def forward(self, seg_features, is_training=True):
        out = {}
        # SparseInst part
        seg_features = self.inst_convs(seg_features)
        # predict instance activation maps
        iam = self.iam_conv(seg_features.tile(
            (1, self.num_group, 1, 1)))
        if not is_training:
            iam = iam.view(
                iam.shape[0],
                self.num_group,
                self.num_mask * self.sparse_num_group,
                *iam.shape[-2:])
            iam = iam[:, 0, ...]
            num_group = 1
        else:
            num_group = self.num_group

        iam_prob = iam.sigmoid()
        B, N = iam_prob.shape[:2]
        C = seg_features.size(1)
        # BxNxHxW -> BxNx(HW)
        iam_prob = iam_prob.view(B, N, -1)
        normalizer = iam_prob.sum(-1).clamp(min=1e-6)
        iam_prob_norm_hw = iam_prob / normalizer[:, :, None]

        # aggregate features: BxCxHxW -> Bx(HW)xC
        # (B x N x HW) @ (B x HW x C) -> B x N x C
        all_inst_features = torch.bmm(
            iam_prob_norm_hw,
            seg_features.view(B, C, -1).permute(0, 2, 1)) #BxNxC

        # concat sparse group features
        inst_features = all_inst_features.reshape(
            B, num_group,
            self.sparse_num_group,
            self.num_mask, -1
        ).permute(0, 1, 3, 2, 4).reshape(
            B, num_group,
            self.num_mask, -1)
        inst_features = F.relu_(
            self.fc(inst_features))

        # avg over sparse group
        iam_prob = iam_prob.view(
            B, num_group,
            self.sparse_num_group,
            self.num_mask,
            iam_prob.shape[-1])
        iam_prob = iam_prob.mean(dim=2).flatten(1, 2)
        inst_features = inst_features.flatten(1, 2)
        out.update(dict(
            iam_prob=iam_prob,
            inst_features=inst_features))

        if self.training:
            pred_logits = self.cls_score(inst_features)
            pred_kernel = self.mask_kernel(inst_features)
            pred_scores = self.objectness(inst_features)
            out.update(dict(
                pred_logits=pred_logits,
                pred_kernel=pred_kernel,
                pred_scores=pred_scores))
        return out
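
# -----------------------------------------------------------------------------
# [Editor's sketch, not part of the original file] InstanceBranch uses each
# instance activation map (IAM) as spatial attention weights: every map is
# normalized over H*W and matrix-multiplied with the feature map to pool one
# C-dim vector per instance. Minimal shape check (B=2, C=8, N=4, H=5, W=6):
#
#   feats = torch.randn(2, 8, 5, 6)                    # B x C x H x W
#   iam_prob = torch.rand(2, 4, 5, 6).view(2, 4, -1)   # B x N x (H*W)
#   iam_prob = iam_prob / iam_prob.sum(-1, keepdim=True).clamp(min=1e-6)
#   inst = torch.bmm(iam_prob, feats.view(2, 8, -1).permute(0, 2, 1))
#   assert inst.shape == (2, 4, 8)                     # B x N x C
# -----------------------------------------------------------------------------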

class SparseInsDecoder(nn.Module):
    def __init__(self, cfg, **kargs) -> None:
        super().__init__()
        in_channels = cfg.encoder.out_dims + 2  # +2: (x, y) grid from compute_coordinates
        self.output_iam = cfg.decoder.output_iam
        self.scale_factor = cfg.decoder.scale_factor
        self.sparse_decoder_weight = cfg.sparse_decoder_weight
        self.inst_branch = InstanceBranch(cfg.decoder, in_channels)
        # dim, num_convs, kernel_dim, in_channels
        self.mask_branch = MaskBranch(cfg.decoder, in_channels)
        self.sparse_inst_crit = SparseInstCriterion(
            num_classes=cfg.decoder.num_classes,
            matcher=SparseInstMatcher(),
            cfg=cfg)
        self._init_weights()

    def _init_weights(self):
        self.inst_branch._init_weights()
        self.mask_branch._init_weights()

    @torch.no_grad()
    def compute_coordinates(self, x):
        h, w = x.size(2), x.size(3)
        y_loc = -1.0 + 2.0 * torch.arange(h, device=x.device) / (h - 1)
        x_loc = -1.0 + 2.0 * torch.arange(w, device=x.device) / (w - 1)
        y_loc, x_loc = torch.meshgrid(y_loc, x_loc)
        y_loc = y_loc.expand([x.shape[0], 1, -1, -1])
        x_loc = x_loc.expand([x.shape[0], 1, -1, -1])
        locations = torch.cat([x_loc, y_loc], 1)
        return locations.to(x)
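
    # -------------------------------------------------------------------------
    # [Editor's sketch, not part of the original file] compute_coordinates
    # builds a CoordConv-style 2-channel grid of normalized (x, y) positions
    # in [-1, 1]; forward() concatenates it to the features, which is why
    # __init__ uses `cfg.encoder.out_dims + 2` input channels. E.g. for a
    # (1, 8, 3, 4) input the grid has shape (1, 2, 3, 4), channel 0 = x over
    # W, channel 1 = y over H:
    #
    #   y_loc = -1.0 + 2.0 * torch.arange(3) / (3 - 1)  # tensor([-1., 0., 1.])
    # -------------------------------------------------------------------------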

    def forward(self, features, is_training=True, **kwargs):
        output = {}
        coord_features = self.compute_coordinates(features)
        features = torch.cat([coord_features, features], dim=1)
        inst_output = self.inst_branch(
            features, is_training=is_training)
        output.update(inst_output)

        if is_training:
            mask_features = self.mask_branch(features)
            pred_kernel = inst_output['pred_kernel']
            N = pred_kernel.shape[1]
            B, C, H, W = mask_features.shape

            pred_masks = torch.bmm(pred_kernel, mask_features.view(
                B, C, H * W)).view(B, N, H, W)
            pred_masks = F.interpolate(
                pred_masks, scale_factor=self.scale_factor,
                mode='bilinear', align_corners=False)
            output.update(dict(
                pred_masks=pred_masks))
        
        if self.training:
            sparse_inst_losses, matched_indices = self.loss(
                    output,
                    lane_idx_map=kwargs.get('lane_idx_map'),
                    input_shape=kwargs.get('input_shape')
            )
            for k, v in sparse_inst_losses.items():
                sparse_inst_losses[k] = self.sparse_decoder_weight * v
            output.update(sparse_inst_losses)
            output['matched_indices'] = matched_indices
        return output

    def loss(self, output, lane_idx_map, input_shape):
        """
        output : from self.forward
        lane_idx_map : instance-level segmentation map, [20, H, W] where 20=max_lanes
        """
        pred_masks = output['pred_masks']
        pred_masks = output['pred_masks'].view(
            pred_masks.shape[0],
            self.inst_branch.num_group,
            self.inst_branch.num_mask,
            *pred_masks.shape[2:])
        pred_logits = output['pred_logits']
        pred_logits = output['pred_logits'].view(
            pred_logits.shape[0],
            self.inst_branch.num_group,
            self.inst_branch.num_mask,
            *pred_logits.shape[2:])
        pred_scores = output['pred_scores']
        pred_scores = output['pred_scores'].view(
            pred_scores.shape[0],
            self.inst_branch.num_group,
            self.inst_branch.num_mask,
            *pred_scores.shape[2:])

        out = {}
        all_matched_indices = []
        for group_idx in range(self.inst_branch.num_group):
            sparse_inst_losses, matched_indices = \
                self.sparse_inst_crit(
                    outputs=dict(
                        pred_masks=pred_masks[:, group_idx, ...].contiguous(),
                        pred_logits=pred_logits[:, group_idx, ...].contiguous(),
                        pred_scores=pred_scores[:, group_idx, ...].contiguous(),
                    ),
                    targets=self.prepare_targets(lane_idx_map),
                    input_shape=input_shape, # seg_bev
                )
            for k, v in sparse_inst_losses.items():
                out['%s_%d' % (k, group_idx)] = v
            all_matched_indices.append(matched_indices)
        return out, all_matched_indices

    def prepare_targets(self, targets):
        new_targets = []
        for targets_per_image in targets:
            target = {}
            cls_labels = targets_per_image.flatten(-2).max(-1)[0]
            pos_mask = cls_labels > 0

            target["labels"] = cls_labels[pos_mask].long()
            target["masks"] = targets_per_image[pos_mask] > 0
            new_targets.append(target)
        return new_targets
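
# -----------------------------------------------------------------------------
# [Editor's sketch, not part of the original file] prepare_targets turns an
# instance-index map of shape (max_lanes, H, W) into DETR-style targets: an
# instance is present if its channel has any positive pixel, its label is the
# channel's max value, and its mask is the channel's positive region. E.g.:
#
#   lane_idx_map = torch.zeros(20, 90, 120)
#   lane_idx_map[0, 40:50, 10:20] = 3               # one lane of category 3
#   t = decoder.prepare_targets([lane_idx_map])[0]  # `decoder` is hypothetical
#   # t["labels"] -> tensor([3]); t["masks"] -> (1, 90, 120) boolean mask
# -----------------------------------------------------------------------------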


================================================
FILE: models/sparse_inst_loss.py
================================================
# Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved

from typing import Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torchvision
from torch import Tensor
from torch.cuda.amp import autocast
from scipy.optimize import linear_sum_assignment
from fvcore.nn import sigmoid_focal_loss_jit


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.


@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(torch.stack([img.shape[i]
                                            for img in tensor_list]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)
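
# -----------------------------------------------------------------------------
# [Editor's sketch, not part of the original file] As in DETR, this pads a list
# of CxHxW images to a common size and records padded pixels in a boolean mask
# (True = padding). For example:
#
#   nt = nested_tensor_from_tensor_list(
#       [torch.ones(3, 4, 6), torch.ones(3, 5, 4)])
#   # nt.tensors: (2, 3, 5, 6) zero-padded batch
#   # nt.mask:    (2, 5, 6) bool, False on real pixels, True on padding
# -----------------------------------------------------------------------------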


def nested_masks_from_list(tensor_list: List[Tensor], input_shape=None):
    if tensor_list[0].ndim == 3:
        dim_size = sum([img.shape[0] for img in tensor_list])
        if input_shape is None:
            max_size = _max_by_axis([list(img.shape[-2:]) for img in tensor_list])
        else:
            max_size = [input_shape[0], input_shape[1]]
        batch_shape = [dim_size] + max_size
        # b, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.zeros(batch_shape, dtype=torch.bool, device=device)
        idx = 0
        for img in tensor_list:
            c = img.shape[0]
            c_ = idx + c
            tensor[idx: c_, :img.shape[1], : img.shape[2]].copy_(img)
            mask[idx: c_, :img.shape[1], :img.shape[2]] = True
            idx = c_
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)
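
# -----------------------------------------------------------------------------
# [Editor's sketch, not part of the original file] Unlike the image variant
# above, this concatenates per-image mask stacks along dim 0 (one row per
# instance, not per image), and here the returned mask marks *valid* pixels as
# True. E.g. two images with 2 and 3 masks each, input_shape=(16, 16):
#
#   nt = nested_masks_from_list(
#       [torch.ones(2, 10, 12), torch.ones(3, 10, 12)], input_shape=(16, 16))
#   # nt.tensors: (5, 16, 16); nt.mask: True exactly on each 10x12 region
# -----------------------------------------------------------------------------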


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def aligned_bilinear(tensor, factor):
    # borrowed from Adelaidet: https://github1s.com/aim-uofa/AdelaiDet/blob/HEAD/adet/utils/comm.py
    assert tensor.dim() == 4
    assert factor >= 1
    assert int(factor) == factor

    if factor == 1:
        return tensor

    h, w = tensor.size()[2:]
    tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
    oh = factor * h + 1
    ow = factor * w + 1
    tensor = F.interpolate(
        tensor, size=(oh, ow),
        mode='bilinear',
        align_corners=True
    )
    tensor = F.pad(
        tensor, pad=(factor // 2, 0, factor // 2, 0),
        mode="replicate"
    )

    return tensor[:, :, :oh - 1, :ow - 1]
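
# -----------------------------------------------------------------------------
# [Editor's sketch, not part of the original file] aligned_bilinear upsamples
# by an integer factor while keeping pixel centers aligned: replicate-pad by
# one on the bottom/right, interpolate to factor*size + 1 with
# align_corners=True, shift by factor // 2 via replicate padding, then crop
# back to exactly factor*size:
#
#   out = aligned_bilinear(torch.arange(4.0).view(1, 1, 2, 2), 2)
#   assert out.shape == (1, 1, 4, 4)
# -----------------------------------------------------------------------------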



def compute_mask_iou(inputs, targets):
    inputs = inputs.sigmoid()
    # thresholding
    binarized_inputs = (inputs >= 0.4).float()
    targets = (targets > 0.5).float()
    intersection = (binarized_inputs * targets).sum(-1)
    union = targets.sum(-1) + binarized_inputs.sum(-1) - intersection
    score = intersection / (union + 1e-6)
    return score


def dice_score(inputs, targets):
    inputs = inputs.sigmoid()
    numerator = 2 * torch.matmul(inputs, targets.t())
    denominator = (
        inputs * inputs).sum(-1)[:, None] + (targets * targets).sum(-1)
    score = numerator / (denominator + 1e-4)
    return score


def dice_loss(inputs, targets, reduction='sum'):
    inputs = inputs.sigmoid()
    assert inputs.shape == targets.shape
    numerator = 2 * (inputs * targets).sum(1)
    denominator = (inputs * inputs).sum(-1) + (targets * targets).sum(-1)
    loss = 1 - (numerator) / (denominator + 1e-4)
    if reduction == 'none':
        return loss
    return loss.sum()
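
# -----------------------------------------------------------------------------
# [Editor's sketch, not part of the original file] The two dice functions pair
# differently: dice_score compares every input against every target via matmul
# (an N x M score matrix, used by the matcher), while dice_loss compares
# inputs[i] with targets[i] elementwise (already-matched pairs). Shape check:
#
#   logits = torch.randn(4, 100)   # 4 predicted masks, flattened
#   gts = torch.rand(3, 100)       # 3 ground-truth masks, flattened
#   assert dice_score(logits, gts).shape == (4, 3)
#   assert dice_loss(torch.randn(3, 100), gts, reduction='none').shape == (3,)
# -----------------------------------------------------------------------------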


# @SPARSE_INST_CRITERION_REGISTRY.register()
class SparseInstCriterion(nn.Module):
    # This part is partially derived from: https://github.com/facebookresearch/detr/blob/main/models/detr.py

    def __init__(self, num_classes=4, cfg=None, matcher=None):
        super().__init__()
        self.matcher = matcher
        self.losses = ("labels", "masks") # cfg.MODEL.SPARSE_INST.LOSS.ITEMS
        self.weight_dict = self.get_weight_dict(cfg)
        self.num_classes = num_classes # cfg.MODEL.SPARSE_INST.DECODER.NUM_CLASSES

    def get_weight_dict(self, cfg):
        losses = ("loss_ce", "loss_mask", "loss_dice", "loss_objectness")
        weight_dict = {}

        ce_weight = cfg.get('ce_weight', 2.0)
        mask_weight = cfg.get('mask_weight', 5.0)
        dice_weight = cfg.get('dice_weight', 2.0)
        objectness_weight = cfg.get('objectness_weight', 1.0)
        weight_dict = dict(
            zip(losses, (ce_weight, mask_weight, dice_weight, objectness_weight)))
        return weight_dict

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i)
                              for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i)
                              for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def loss_labels(self, outputs, targets, indices, num_instances, input_shape=None):
        assert "pred_logits" in outputs
        src_logits = outputs['pred_logits']
        target_classes = torch.full(src_logits.shape[:2], 0, # self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        if sum([tmp[0].shape[0] for tmp in indices]) > 0:
            idx = self._get_src_permutation_idx(indices)
            target_classes_o = torch.cat([t["labels"][J]
                                         for t, (_, J) in zip(targets, indices)])
            target_classes[idx] = target_classes_o

        src_logits = src_logits.flatten(0, 1)
        # prepare one_hot target. Note: target_classes was filled with 0 above
        # (not self.num_classes), so this filter keeps every position and
        # unmatched queries are supervised toward class 0 (background).
        target_classes = target_classes.flatten(0, 1)
        pos_inds = torch.nonzero(
            target_classes != self.num_classes, as_tuple=True)[0]
        labels = torch.zeros_like(src_logits)
        labels[pos_inds, target_classes[pos_inds]] = 1
        # comp focal loss.
        class_loss = sigmoid_focal_loss_jit(
            src_logits,
            labels,
            alpha=0.25,
            gamma=2.0,
            reduction="sum",
        ) / num_instances
        losses = {'loss_ce': class_loss}
        return losses

    def loss_masks_with_iou_objectness(self, outputs, targets, indices, num_instances, input_shape):
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        # Bx100xHxW
        assert "pred_masks" in outputs
        assert "pred_scores" in outputs
        src_iou_scores = outputs["pred_scores"]
        src_masks = outputs["pred_masks"]
        with torch.no_grad():
            target_masks, _ = nested_masks_from_list(
                [t["masks"] for t in targets], input_shape).decompose()
        num_masks = [len(t["masks"]) for t in targets]
        target_masks = target_masks.to(src_masks)
        if len(target_masks) == 0:
            losses = {
                "loss_dice": src_masks.sum() * 0.0,
                "loss_mask": src_masks.sum() * 0.0,
                "loss_objectness": src_iou_scores.sum() * 0.0
            }
            return losses

        src_masks = src_masks[src_idx]
        target_masks = F.interpolate(
            target_masks[:, None], size=src_masks.shape[-2:], mode='bilinear', align_corners=False).squeeze(1)

        src_masks = src_masks.flatten(1)
        # FIXME: tgt_idx
        mix_tgt_idx = torch.zeros_like(tgt_idx[1])
        cum_sum = 0
        for num_mask in num_masks:
            mix_tgt_idx[cum_sum: cum_sum + num_mask] = cum_sum
            cum_sum += num_mask
        mix_tgt_idx += tgt_idx[1]

        target_masks = target_masks[mix_tgt_idx].flatten(1)

        with torch.no_grad():
            ious = compute_mask_iou(src_masks, target_masks)

        tgt_iou_scores = ious
        src_iou_scores = src_iou_scores[src_idx]
        tgt_iou_scores = tgt_iou_scores.flatten(0)
        src_iou_scores = src_iou_scores.flatten(0)

        losses = {
            "loss_objectness": F.binary_cross_entropy_with_logits(src_iou_scores, tgt_iou_scores, reduction='mean'),
            "loss_dice": dice_loss(src_masks, target_masks) / num_instances,
            "loss_mask": F.binary_cross_entropy_with_logits(src_masks, target_masks, reduction='mean')
        }
        return losses

    def get_loss(self, loss, outputs, targets, indices, num_instances, **kwargs):
        loss_map = {
            "labels": self.loss_labels,
            "masks": self.loss_masks_with_iou_objectness,
        }
        if loss == "loss_objectness":
            # NOTE: loss_objectness will be calculated in `loss_masks_with_iou_objectness`
            return {}
        assert loss in loss_map
        return loss_map[loss](outputs, targets, indices, num_instances, **kwargs)

    def forward(self, outputs, targets, input_shape):

        outputs_without_aux = {k: v for k,
                               v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets, input_shape)
        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_instances = sum(len(t["labels"]) for t in targets)
        num_instances = torch.as_tensor(
            [num_instances], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_instances)
        num_instances = torch.clamp(
            num_instances / get_world_size(), min=1).item()
        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices,
                                        num_instances, input_shape=input_shape))

        for k in losses.keys():
            if k in self.weight_dict:
                losses[k] *= self.weight_dict[k]
        return losses, indices


# @SPARSE_INST_MATCHER_REGISTRY.register()
class SparseInstMatcherV1(nn.Module):

    def __init__(self, cfg=None):
        super().__init__()
        self.alpha = 0.8 # cfg.MODEL.SPARSE_INST.MATCHER.ALPHA
        self.beta = 0.2 # cfg.MODEL.SPARSE_INST.MATCHER.BETA
        self.mask_score = dice_score

    @torch.no_grad()
    def forward(self, outputs, targets, input_shape):
        B, N, H, W = outputs["pred_masks"].shape
        pred_masks = outputs['pred_masks']
        pred_logits = outputs['pred_logits'].sigmoid()

        indices = []

        for i in range(B):
            tgt_ids = targets[i]["labels"]
            # no annotations
            if tgt_ids.shape[0] == 0:
                indices.append((torch.as_tensor([]),
                                torch.as_tensor([])))
                continue

            tgt_masks = targets[i]['masks'].tensor.to(pred_masks)
            pred_logit = pred_logits[i]
            out_masks = pred_masks[i]

            # upsampling:
            # (1) padding
            # (2) upsampling to 1x input size (input_shape)
            # (3) downsampling to 0.25x input size (output mask size)
            ori_h, ori_w = tgt_masks.size(1), tgt_masks.size(2)
            tgt_masks_ = torch.zeros(
                (1, tgt_masks.size(0), input_shape[0], input_shape[1])).to(pred_masks)
            tgt_masks_[0, :, :ori_h, :ori_w] = tgt_masks
            tgt_masks = F.interpolate(
                tgt_masks_, size=out_masks.shape[-2:], mode='bilinear', align_corners=False)[0]

            # compute dice score and classification score
            tgt_masks = tgt_masks.flatten(1)
            out_masks = out_masks.flatten(1)

            mask_score = self.mask_score(out_masks, tgt_masks)
            # Nx(Number of gts)
            matching_prob = pred_logit[:, tgt_ids]
            C = (mask_score ** self.alpha) * (matching_prob ** self.beta)
            # hungarian matching
            inds = linear_sum_assignment(C.cpu(), maximize=True)
            indices.append(inds)
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


# @SPARSE_INST_MATCHER_REGISTRY.register()
class SparseInstMatcher(nn.Module):

    def __init__(self, cfg=None):
        super().__init__()
        self.alpha = 0.8 # cfg.MODEL.SPARSE_INST.MATCHER.ALPHA
        self.beta = 0.2 # cfg.MODEL.SPARSE_INST.MATCHER.BETA
        self.mask_score = dice_score

    def forward(self, outputs, targets, input_shape):
        with torch.no_grad():
            # B x 40 x 90 x 120 
            B, N, H, W = outputs["pred_masks"].shape
            pred_masks = outputs['pred_masks']
            pred_logits = outputs['pred_logits'].sigmoid()
            tgt_ids = torch.cat([v["labels"] for v in targets])

            if tgt_ids.shape[0] == 0:
                return [(torch.as_tensor([]).to(pred_logits), torch.as_tensor([]).to(pred_logits))] * B
            tgt_masks, _ = nested_masks_from_list(
                [t["masks"] for t in targets], input_shape).decompose()
            device = pred_masks.device
            tgt_masks = tgt_masks.to(pred_masks)

            tgt_masks = F.interpolate(
                tgt_masks[:, None], size=pred_masks.shape[-2:], mode="bilinear", align_corners=False).squeeze(1)

            pred_masks = pred_masks.view(B * N, -1)
            tgt_masks = tgt_masks.flatten(1)
            with autocast(enabled=False):
                pred_masks = pred_masks.float()
                tgt_masks = tgt_masks.float()
                pred_logits = pred_logits.float()
                mask_score = self.mask_score(pred_masks, tgt_masks)
                # Nx(Number of gts)
                matching_prob = pred_logits.view(B * N, -1)[:, tgt_ids]
                C = (mask_score ** self.alpha) * (matching_prob ** self.beta)

            C = C.view(B, N, -1).cpu()
            # hungarian matching
            sizes = [len(v["masks"]) for v in targets]
            indices = [linear_sum_assignment(c[i], maximize=True)
                       for i, c in enumerate(C.split(sizes, -1))]
            indices = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(
                j, dtype=torch.int64)) for i, j in indices]
            return indices
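
# -----------------------------------------------------------------------------
# [Editor's sketch, not part of the original file] Both matchers score every
# (prediction, ground-truth) pair as
#     C = dice_score ** alpha * class_prob ** beta    (alpha=0.8, beta=0.2)
# and run Hungarian matching with maximize=True. Toy cost matrix:
#
#   C = torch.tensor([[0.9, 0.1],
#                     [0.2, 0.8],
#                     [0.3, 0.3]])   # 3 queries x 2 ground truths
#   rows, cols = linear_sum_assignment(C, maximize=True)
#   # rows = [0, 1], cols = [0, 1]: query 0 <-> gt 0, query 1 <-> gt 1
# -----------------------------------------------------------------------------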


# def build_sparse_inst_matcher(cfg):
#     name = cfg.MODEL.SPARSE_INST.MATCHER.NAME
#     return SPARSE_INST_MATCHER_REGISTRY.get(name)(cfg)


# def build_sparse_inst_criterion(cfg):
#     matcher = build_sparse_inst_matcher(cfg)
#     name = cfg.MODEL.SPARSE_INST.LOSS.NAME
#     return SPARSE_INST_CRITERION_REGISTRY.get(name)(cfg, matcher)


================================================
FILE: models/transformer_bricks.py
================================================
import numpy as np
import math
import warnings

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F

from fvcore.nn.weight_init import c2_msra_fill, c2_xavier_fill

from mmdet.models.utils.builder import TRANSFORMER
from mmcv.cnn.bricks.transformer import FFN, build_positional_encoding
from mmcv.cnn import (build_activation_layer, build_conv_layer,
                      build_norm_layer, xavier_init, constant_init)
from mmcv.runner.base_module import BaseModule
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
                                         TransformerLayerSequence,
                                         build_transformer_layer_sequence)
from mmcv.cnn.bricks.registry import (ATTENTION, TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction

from .scatter_utils import scatter_mean
from .utils import inverse_sigmoid


def pos2posemb3d(pos, num_pos_feats=128, temperature=10000):
    scale = 2 * math.pi
    pos = pos * scale
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
    pos_x = pos[..., 0, None] / dim_t
    pos_y = pos[..., 1, None] / dim_t
    pos_z = pos[..., 2, None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
    pos_z = torch.stack((pos_z[..., 0::2].sin(), pos_z[..., 1::2].cos()), dim=-1).flatten(-2)
    posemb = torch.cat((pos_y, pos_x, pos_z), dim=-1)
    return posemb
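
# -----------------------------------------------------------------------------
# [Editor's sketch, not part of the original file] pos2posemb3d maps normalized
# (x, y, z) positions to sinusoidal embeddings as in PETR: each coordinate gets
# num_pos_feats channels of interleaved sin/cos at geometric frequencies,
# concatenated in (y, x, z) order. Shape check:
#
#   pts = torch.rand(2, 40, 3)                 # B x num_points x 3, in [0, 1]
#   emb = pos2posemb3d(pts, num_pos_feats=128)
#   assert emb.shape == (2, 40, 384)           # 3 * 128 channels
# -----------------------------------------------------------------------------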


def generate_ref_pt(minx, miny, maxx, maxy, z, nx, ny, device='cuda'):
    if isinstance(z, list):
        nz = z[-1]
        # minx, miny, maxx, maxy : in ground coords
        xs = torch.linspace(minx, maxx, nx, dtype=torch.float, device=device
                ).view(1, -1, 1).expand(ny, nx, nz)
        ys = torch.linspace(miny, maxy, ny, dtype=torch.float, device=device
                ).view(-1, 1, 1).expand(ny, nx, nz)
        zs = torch.linspace(z[0], z[1], nz, dtype=torch.float, device=device
                ).view(1, 1, -1).expand(ny, nx, nz)
        ref_3d = torch.stack([xs, ys, zs], dim=-1)
        ref_3d = ref_3d.flatten(1, 2)
    else:
        # minx, miny, maxx, maxy : in ground coords
        xs = torch.linspace(minx, maxx, nx, dtype=torch.float, device=device
                ).view(1, -1, 1).expand(ny, nx, 1)
        ys = torch.linspace(miny, maxy, ny, dtype=torch.float, device=device
                ).view(-1, 1, 1).expand(ny, nx, 1)
        ref_3d = F.pad(torch.cat([xs, ys], dim=-1), (0, 1), mode='constant', value=z)
    return ref_3d
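
# -----------------------------------------------------------------------------
# [Editor's sketch, not part of the original file] generate_ref_pt builds a
# regular grid of 3D reference points in ground coordinates. A scalar z gives
# a flat plane; z = [z_min, z_max, nz] gives nz points per (x, y) location,
# flattened over the last two grid axes:
#
#   flat = generate_ref_pt(-10, 3, 10, 103, 0.0, nx=24, ny=20, device='cpu')
#   assert flat.shape == (20, 24, 3)       # ny x nx x (x, y, z)
#   pil = generate_ref_pt(-10, 3, 10, 103, [0., 2., 4], 24, 20, device='cpu')
#   assert pil.shape == (20, 24 * 4, 3)    # nz folded into the second axis
# -----------------------------------------------------------------------------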


def ground2img(coords3d, H, W, lidar2img, ori_shape, mask=None, return_img_pts=False):
    coords3d = coords3d.clone()
    img_pt = coords3d.flatten(1, 2) @ lidar2img.permute(0, 2, 1)
    img_pt = torch.cat([
        img_pt[..., :2] / torch.maximum(
            img_pt[..., 2:3], torch.ones_like(img_pt[..., 2:3]) * 1e-5),
        img_pt[..., 2:]
    ], dim=-1)

    # rescale to feature_map size
    x = img_pt[..., 0] / ori_shape[0][1] * (W - 1)
    y = img_pt[..., 1] / ori_shape[0][0] * (H - 1)
    valid = (x >= 0) * (y >= 0) * (x <= (W - 1)) \
          * (y <= (H - 1)) * (img_pt[..., 2] > 0)
    if return_img_pts:
        return x, y, valid

    if mask is not None:
        valid = valid * mask.flatten(1, 2).float()

    # B, C, H, W = img_feats.shape
    B = coords3d.shape[0]
    canvas = torch.zeros((B, H, W, 3 + 1),
                         dtype=torch.float32,
                         device=coords3d.device)
    x = x.long()
    y = y.long()
    ind = (x + y * W) * valid.long()
    # ind = torch.clamp(ind, 0, H * W - 1)
    ind = ind.long().unsqueeze(-1).repeat(1, 1, canvas.shape[-1])
    canvas = canvas.flatten(1, 2)
    target = coords3d.flatten(1, 2).clone()
    scatter_mean(target, ind, out=canvas, dim=1)
    canvas = canvas.view(B, H, W, canvas.shape[-1]
        ).permute(0, 3, 1, 2).contiguous()
    canvas[:, :, 0, 0] = 0
    return canvas
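
# -----------------------------------------------------------------------------
# [Editor's sketch, not part of the original file] ground2img projects ground
# points into the image with the 4x4 lidar2img matrix, then scatter-averages
# each point's 4 channels into a B x 4 x H x W canvas at its projected pixel.
# Out-of-view or behind-camera points get index 0, which is why
# canvas[:, :, 0, 0] is zeroed at the end. Shape sketch (toy values; the
# identity "projection" is only a placeholder):
#
#   coords3d = torch.rand(1, 20, 24, 4)    # B x ny x nx x 4 (homogeneous)
#   canvas = ground2img(coords3d, H=45, W=60,
#                       lidar2img=torch.eye(4).unsqueeze(0),
#                       ori_shape=[(360, 480)])
#   assert canvas.shape == (1, 4, 45, 60)
# -----------------------------------------------------------------------------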


@ATTENTION.register_module()
class MSDeformableAttention3D(BaseModule):
    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=8,
                 im2col_step=64,
                 dropout=0.1,
                 num_query=None,
                 num_anchor_per_query=None,
                 anchor_y_steps=None,
                 batch_first=False,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.batch_first = batch_first
        self.output_proj = None
        self.fp16_enabled = False

        self.num_query = num_query
        self.num_anchor_per_query = num_anchor_per_query
        self.register_buffer('anchor_y_steps',
            torch.from_numpy(anchor_y_steps).float())
        self.num_points_per_anchor = len(anchor_y_steps) // num_anchor_per_query

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.sampling_offsets = nn.Linear(
            embed_dims,
            num_heads * num_levels * num_points * 2 * self.num_points_per_anchor)
 
# [Editor's note: models/transformer_bricks.py is truncated at this point in
#  the preview -- download the full .txt for the remainder of the file.]
SYMBOL INDEX (230 symbols across 20 files)

FILE: data/Load_Data.py
  class LaneDataset (line 27) | class LaneDataset(Dataset):
    method __init__ (line 39) | def __init__(self, dataset_base_dir, json_file_path, args, data_aug=Fa...
    method preprocess_data_from_json_once (line 125) | def preprocess_data_from_json_once(self, idx_json_file):
    method preprocess_data_from_json_openlane (line 226) | def preprocess_data_from_json_openlane(self, idx_json_file):
    method __len__ (line 322) | def __len__(self):
    method WIP__getitem__ (line 329) | def WIP__getitem__(self, idx):
    method __getitem__ (line 473) | def __getitem__(self, idx):
    method transform_mats_impl (line 479) | def transform_mats_impl(self, cam_extrinsics, cam_intrinsics, cam_pitc...
  function make_lane_y_mono_inc (line 491) | def make_lane_y_mono_inc(lane):
  function data_aug_rotate (line 509) | def data_aug_rotate(img):
  function seed_worker (line 521) | def seed_worker(worker_id):
  function get_loader (line 527) | def get_loader(transformed_dataset, args):
  function map_once_json2img (line 587) | def map_once_json2img(json_label_file):

FILE: data/apollo_dataset.py
  class ApolloLaneDataset (line 45) | class ApolloLaneDataset(Dataset):
    method __init__ (line 46) | def __init__(self, dataset_base_dir, json_file_path, args, data_aug=Fa...
    method gen_single_file_json (line 110) | def gen_single_file_json(self):
    method parse_processed_info_dict_apollo (line 144) | def parse_processed_info_dict_apollo(self, idx):
    method __len__ (line 153) | def __len__(self):
    method WIP__getitem__ (line 160) | def WIP__getitem__(self, idx):
    method __getitem__ (line 306) | def __getitem__(self, idx):
    method init_dataset_3D (line 312) | def init_dataset_3D(self, dataset_base_dir, json_file_path):
    method transform_mats_impl (line 430) | def transform_mats_impl(self, cam_pitch, cam_height):
  function data_aug_rotate (line 438) | def data_aug_rotate(img):
  function seed_worker (line 450) | def seed_worker(worker_id):
  function get_loader (line 456) | def get_loader(transformed_dataset, args):

FILE: data/transform.py
  function get_random_state (line 9) | def get_random_state() -> np.random.RandomState:
  function normal (line 13) | def normal(
  class PhotoMetricDistortionMultiViewImage (line 24) | class PhotoMetricDistortionMultiViewImage:
    method __init__ (line 43) | def __init__(self,
    method __call__ (line 53) | def __call__(self, results):

FILE: experiments/ddp.py
  function setup_dist_launch (line 25) | def setup_dist_launch(args):
  function setup_slurm (line 36) | def setup_slurm(args):
  function setup_distributed (line 60) | def setup_distributed(args):
  function ddp_init (line 68) | def ddp_init(args):
  function to_python_float (line 89) | def to_python_float(t):
  function reduce_tensor (line 95) | def reduce_tensor(tensor, world_size):
  function reduce_tensors (line 102) | def reduce_tensors(*tensors, world_size):

FILE: experiments/gpu_utils.py
  function get_rank (line 4) | def get_rank() -> int:
  function is_main_process (line 12) | def is_main_process() -> bool:
  function gpu_available (line 16) | def gpu_available() -> bool:

FILE: experiments/runner.py
  class Runner (line 30) | class Runner:
    method __init__ (line 31) | def __init__(self, args):
    method train (line 76) | def train(self):
    method _log_model_info (line 237) | def _log_model_info(self, model):
    method _log_training_loss (line 246) | def _log_training_loss(self, output, epoch, step, data_loader):
    method save_checkpoint (line 259) | def save_checkpoint(self, state, to_copy, epoch, save_path):
    method validate (line 272) | def validate(self, model, **kwargs):
    method _recal_gpus_val (line 419) | def _recal_gpus_val(self, gather_output, eval_stats):
    method _get_model_from_cfg (line 481) | def _get_model_from_cfg(self):
    method _load_ckpt_from_workdir (line 496) | def _load_ckpt_from_workdir(self, model):
    method eval (line 514) | def eval(self):
    method _get_train_dataset (line 530) | def _get_train_dataset(self):
    method _get_model_ddp (line 545) | def _get_model_ddp(self):
    method resume_model (line 587) | def resume_model(self, model, path=''):
    method _get_valid_dataset (line 621) | def _get_valid_dataset(self):
    method save_eval_result_once (line 638) | def save_eval_result_once(self, args, img_path, lanelines_pred, laneli...
    method log_eval_stats (line 685) | def log_eval_stats(self, eval_stats):
    method _log_genlane_eval_info (line 699) | def _log_genlane_eval_info(self, eval_stats):
  function set_work_dir (line 714) | def set_work_dir(cfg):

FILE: main.py
  function get_args (line 9) | def get_args():
  function main (line 29) | def main():

FILE: models/latr.py
  class LATR (line 15) | class LATR(nn.Module):
    method __init__ (line 16) | def __init__(self, args):
    method forward (line 60) | def forward(self, image, _M_inv=None, is_training=True, extra_dict=None):

FILE: models/latr_head.py
  class LATRHead (line 22) | class LATRHead(nn.Module):
    method __init__ (line 23) | def __init__(self, args,
    method _init_weights (line 177) | def _init_weights(self):
    method forward (line 186) | def forward(self, input_dict, is_training=True):
    method get_project_loss (line 308) | def get_project_loss(self, results, input_dict, h=20, w=30):
    method get_loss (line 347) | def get_loss(self, output_dict, input_dict):
    method get_reference_points (line 439) | def get_reference_points(H, W, bs=1, device='cuda', dtype=torch.float):
  function build_nn_loss (line 452) | def build_nn_loss(loss_cfg):

FILE: models/ms2one.py
  function build_ms2one (line 9) | def build_ms2one(config):
  class Naive (line 18) | class Naive(nn.Module):
    method __init__ (line 19) | def __init__(self, inc, outc, kernel_size=1):
    method forward (line 23) | def forward(self, ms_feats):
  class DilateNaive (line 30) | class DilateNaive(nn.Module):
    method __init__ (line 31) | def __init__(self, inc, outc, num_scales=4,
    method forward (line 74) | def forward(self, x):

FILE: models/scatter_utils.py
  function broadcast (line 8) | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
  function scatter_sum (line 20) | def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
  function scatter_add (line 38) | def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
  function scatter_mul (line 44) | def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
  function scatter_mean (line 50) | def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
  function scatter_min (line 73) | def scatter_min(
  function scatter_max (line 80) | def scatter_max(
  function scatter (line 87) | def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,

FILE: models/sparse_ins.py
  function _make_stack_3x3_convs (line 10) | def _make_stack_3x3_convs(num_convs, in_channels, out_channels):
  class MaskBranch (line 20) | class MaskBranch(nn.Module):
    method __init__ (line 21) | def __init__(self, cfg, in_channels):
    method _init_weights (line 30) | def _init_weights(self):
    method forward (line 36) | def forward(self, features):
  class InstanceBranch (line 42) | class InstanceBranch(nn.Module):
    method __init__ (line 43) | def __init__(self, cfg, in_channels, **kwargs):
    method _init_weights (line 75) | def _init_weights(self):
    method forward (line 89) | def forward(self, seg_features, is_training=True):
  class SparseInsDecoder (line 154) | class SparseInsDecoder(nn.Module):
    method __init__ (line 155) | def __init__(self, cfg, **kargs) -> None:
    method _init_weights (line 170) | def _init_weights(self):
    method compute_coordinates (line 175) | def compute_coordinates(self, x):
    method forward (line 185) | def forward(self, features, is_training=True, **kwargs):
    method loss (line 219) | def loss(self, output, lane_idx_map, input_shape):
    method prepare_targets (line 261) | def prepare_targets(self, targets):

FILE: models/sparse_inst_loss.py
  function _max_by_axis (line 19) | def _max_by_axis(the_list):
  class NestedTensor (line 28) | class NestedTensor(object):
    method __init__ (line 29) | def __init__(self, tensors, mask: Optional[Tensor]):
    method to (line 33) | def to(self, device):
    method decompose (line 43) | def decompose(self):
    method __repr__ (line 46) | def __repr__(self):
  function _onnx_nested_tensor_from_tensor_list (line 54) | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> N...
  function nested_tensor_from_tensor_list (line 83) | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
  function nested_masks_from_list (line 108) | def nested_masks_from_list(tensor_list: List[Tensor], input_shape=None):
  function is_dist_avail_and_initialized (line 133) | def is_dist_avail_and_initialized():
  function get_world_size (line 141) | def get_world_size():
  function aligned_bilinear (line 147) | def aligned_bilinear(tensor, factor):
  function compute_mask_iou (line 174) | def compute_mask_iou(inputs, targets):
  function dice_score (line 185) | def dice_score(inputs, targets):
  function dice_loss (line 194) | def dice_loss(inputs, targets, reduction='sum'):
  class SparseInstCriterion (line 206) | class SparseInstCriterion(nn.Module):
    method __init__ (line 209) | def __init__(self, num_classes=4, cfg=None, matcher=None):
    method get_weight_dict (line 216) | def get_weight_dict(self, cfg):
    method _get_src_permutation_idx (line 228) | def _get_src_permutation_idx(self, indices):
    method _get_tgt_permutation_idx (line 235) | def _get_tgt_permutation_idx(self, indices):
    method loss_labels (line 242) | def loss_labels(self, outputs, targets, indices, num_instances, input_...
    method loss_masks_with_iou_objectness (line 271) | def loss_masks_with_iou_objectness(self, outputs, targets, indices, nu...
    method get_loss (line 322) | def get_loss(self, loss, outputs, targets, indices, num_instances, **k...
    method forward (line 333) | def forward(self, outputs, targets, input_shape):
  class SparseInstMatcherV1 (line 364) | class SparseInstMatcherV1(nn.Module):
    method __init__ (line 366) | def __init__(self, cfg=None):
    method forward (line 373) | def forward(self, outputs, targets, input_shape):
  class SparseInstMatcher (line 418) | class SparseInstMatcher(nn.Module):
    method __init__ (line 420) | def __init__(self, cfg=None):
    method forward (line 426) | def forward(self, outputs, targets, input_shape):

FILE: models/transformer_bricks.py
  function pos2posemb3d (line 28) | def pos2posemb3d(pos, num_pos_feats=128, temperature=10000):
  function generate_ref_pt (line 43) | def generate_ref_pt(minx, miny, maxx, maxy, z, nx, ny, device='cuda'):
  function ground2img (line 65) | def ground2img(coords3d, H, W, lidar2img, ori_shape, mask=None, return_i...
  class MSDeformableAttention3D (line 105) | class MSDeformableAttention3D(BaseModule):
    method __init__ (line 106) | def __init__(self,
    method init_weights (line 165) | def init_weights(self):
    method ref_to_lidar (line 185) | def ref_to_lidar(self, reference_points, pc_range, not_y=True):
    method point_sampling (line 196) | def point_sampling(self, reference_points, lidar2img, ori_shape):
    method forward (line 203) | def forward(self,
  class LATRDecoderLayer (line 311) | class LATRDecoderLayer(BaseTransformerLayer):
    method __init__ (line 312) | def __init__(self,
    method forward (line 331) | def forward(self,
  class LATRTransformerDecoder (line 350) | class LATRTransformerDecoder(TransformerLayerSequence):
    method __init__ (line 351) | def __init__(self,
    method init_weights (line 397) | def init_weights(self):
    method pred2M (line 402) | def pred2M(self, pitch_z):
    method forward (line 417) | def forward(self, query, key, value,
  class LATRTransformer (line 525) | class LATRTransformer(BaseModule):
    method __init__ (line 526) | def __init__(self, encoder=None, decoder=None, init_cfg=None):
    method init_weights (line 536) | def init_weights(self):
    method get_reference_points (line 544) | def get_reference_points(spatial_shapes, valid_ratios, device):
    method get_valid_ratio (line 574) | def get_valid_ratio(self, mask):
    method with_encoder (line 585) | def with_encoder(self):
    method forward (line 588) | def forward(self, x, mask, query,

FILE: models/utils.py
  function inverse_sigmoid (line 6) | def inverse_sigmoid(x, eps=1e-5):
  class deepFeatureExtractor_EfficientNet (line 25) | class deepFeatureExtractor_EfficientNet(nn.Module):
    method __init__ (line 26) | def __init__(self, architecture="EfficientNet-B5", lv6=False, lv5=Fals...
    method forward (line 100) | def forward(self, x):
    method freeze_bn (line 129) | def freeze_bn(self, enable=False):

FILE: utils/MinCostFlow.py
  function SolveMinCostFlow (line 32) | def SolveMinCostFlow(adj_mat, cost_mat):
  function main (line 99) | def main():

FILE: utils/eval_3D_lane.py
  class LaneEval (line 43) | class LaneEval(object):
    method __init__ (line 44) | def __init__(self, args, logger):
    method bench (line 61) | def bench(self, pred_lanes, pred_category, gt_lanes, gt_visibility, gt...
    method bench_one_submit (line 259) | def bench_one_submit(self, pred_dir, gt_dir, test_txt, prob_th=0.5, vi...
    method bench_one_submit_ddp (line 423) | def bench_one_submit_ddp(self, pred_lines_sub, gt_lines_sub, model_nam...

FILE: utils/eval_3D_lane_apollo.py
  class LaneEval (line 48) | class LaneEval(object):
    method __init__ (line 49) | def __init__(self, args, logger=None):
    method log_eval_info (line 68) | def log_eval_info(self):
    method bench (line 75) | def bench(self, pred_lanes, gt_lanes, gt_visibility, raw_file, gt_cam_...
    method bench_one_submit (line 280) | def bench_one_submit(self, pred_file, gt_file, prob_th=0.5, vis=False):
    method bench_one_submit_ddp (line 479) | def bench_one_submit_ddp(self, pred_lines_sub, gt_lines_sub, model_nam...
    method bench_PR (line 595) | def bench_PR(self, pred_lanes, gt_lanes, gt_visibility):
    method bench_one_submit_varying_probs (line 701) | def bench_one_submit_varying_probs(self, pred_file, gt_file, eval_out_...

FILE: utils/eval_3D_once.py
  class Bev_Projector (line 19) | class Bev_Projector:
    method __init__ (line 20) | def __init__(self, side_range, fwd_range, height_range, res, lane_widt...
    method proj_oneline_zx (line 32) | def proj_oneline_zx(self, one_lane):
  class LaneEval (line 56) | class LaneEval:
    method file_parser (line 58) | def file_parser(gt_root_path, pred_root_path):
    method summarize (line 79) | def summarize(res):
    method lane_evaluation (line 105) | def lane_evaluation(self, gt_root_path, pred_root_path, config_path, a...
  function evaluate_list (line 188) | def evaluate_list(gt_path_list, pred_path_list, config):
  class LaneEvalOneFile (line 222) | class LaneEvalOneFile:
    method __init__ (line 223) | def __init__(self, gt_path, pred_path, bev_projector, iou_thresh, dist...
    method preprocess (line 233) | def preprocess(self, store_spec):
    method calc_iou (line 246) | def calc_iou(self, lane1, lane2):
    method cal_mean_dist (line 262) | def cal_mean_dist(self, src_line, dst_line):
    method sort_lanes_z (line 278) | def sort_lanes_z(self, lanes):
    method eval (line 286) | def eval(self):
    method cal_tp (line 303) | def cal_tp(self, gt_num, pred_num, gt_lanes, pred_lanes):
  function parse_config (line 329) | def parse_config():

FILE: utils/utils.py
  function create_logger (line 41) | def create_logger(args):
  function define_args (line 66) | def define_args():
  function prune_3d_lane_by_visibility (line 107) | def prune_3d_lane_by_visibility(lane_3d, visibility):
  function prune_3d_lane_by_range (line 112) | def prune_3d_lane_by_range(lane_3d, x_min, x_max):
  function resample_laneline_in_y (line 125) | def resample_laneline_in_y(input_lane, y_steps, out_vis=False):
  function resample_laneline_in_y_with_vis (line 156) | def resample_laneline_in_y_with_vis(input_lane, y_steps, vis_vec):
  function homograpthy_g2im (line 186) | def homograpthy_g2im(cam_pitch, cam_height, K):
  function projection_g2im (line 195) | def projection_g2im(cam_pitch, cam_height, K):
  function homograpthy_g2im_extrinsic (line 203) | def homograpthy_g2im_extrinsic(E, K):
  function projection_g2im_extrinsic (line 211) | def projection_g2im_extrinsic(E, K):
  function homography_crop_resize (line 217) | def homography_crop_resize(org_img_size, crop_y, resize_img_size):
  function homographic_transformation (line 234) | def homographic_transformation(Matrix, x, y):
  function projective_transformation (line 252) | def projective_transformation(Matrix, x, y, z):
  function first_run (line 271) | def first_run(save_path):
  function mkdir_if_missing (line 284) | def mkdir_if_missing(directory):
  function str2bool (line 294) | def str2bool(argument):
  class Logger (line 303) | class Logger(object):
    method __init__ (line 307) | def __init__(self, fpath=None):
    method __del__ (line 315) | def __del__(self):
    method __enter__ (line 318) | def __enter__(self):
    method __exit__ (line 321) | def __exit__(self, *args):
    method write (line 324) | def write(self, msg):
    method flush (line 329) | def flush(self):
    method close (line 335) | def close(self):
  class AverageMeter (line 341) | class AverageMeter(object):
    method __init__ (line 343) | def __init__(self):
    method reset (line 346) | def reset(self):
    method update (line 352) | def update(self, val, n=1):
  function define_optim (line 359) | def define_optim(optim, params, lr, weight_decay):
  function cosine_schedule_with_warmup (line 373) | def cosine_schedule_with_warmup(k, args, dataset_size=None):
  function define_scheduler (line 392) | def define_scheduler(optimizer, args, dataset_size=None):
  function define_init_weights (line 436) | def define_init_weights(model, init_w='normal', activation='relu'):
  function weights_init_normal (line 450) | def weights_init_normal(m):
  function weights_init_xavier (line 474) | def weights_init_xavier(m):
  function weights_init_kaiming (line 490) | def weights_init_kaiming(m):
  function weights_init_orthogonal (line 506) | def weights_init_orthogonal(m):
