[
  {
    "path": ".gitignore",
    "content": "*pyc\n*pth"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 JMoonr\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "<br />\n<p align=\"center\">\n  \n  <h3 align=\"center\"><strong>LATR: 3D Lane Detection from Monocular Images with Transformer</strong></h3>\n\n<p align=\"center\">\n  <a href=\"https://arxiv.org/abs/2308.04583\" target='_blank'>\n    <!-- <img src=\"https://img.shields.io/badge/arXiv-%F0%9F%93%83-yellow\"> -->\n    <img src=\"https://img.shields.io/badge/arXiv-2308.04583-b31b1b.svg\">\n  </a>\n  <a href=\"\" target='_blank'>\n    <img src=\"https://visitor-badge.laobi.icu/badge?page_id=JMoonr.LATR&left_color=gray&right_color=yellow\">\n  </a>\n    <a href=\"https://github.com/JMoonr/LATR\" target='_blank'>\n     <img src=\"https://img.shields.io/github/stars/JMoonr/LATR?style=social\">\n  </a>\n  \n</p>\n\n\nThis is the official PyTorch implementation of [LATR: 3D Lane Detection from Monocular Images with Transformer](https://arxiv.org/abs/2308.04583).\n\n![fig2](/assets/fig2.png)  \n\n## News\n  - **2024-01-15** :confetti_ball: Our new work [DV-3DLane: End-to-end Multi-modal 3D Lane Detection with Dual-view Representation](https://github.com/JMoonr/dv-3dlane) is accepted by ICLR2024.\n\n  - **2023-08-12** :tada: LATR is accepted as an Oral presentation at ICCV2023! :sparkles:\n\n\n## Environments\nTo set up the required packages, please refer to the [installation guide](./docs/install.md).\n\n## Data\nPlease follow [data preparation](./docs/data_preparation.md) to download dataset.\n\n## Pretrained Models\nNote that the performance of pretrained model is higher than our paper due to code refactoration and optimization. All models are uploaded to [google drive](https://drive.google.com/drive/folders/1AhvLvE84vayzFxa0teRHYRdXz34ulzjB?usp=sharing).\n\n| Dataset | Pretrained | Metrics | md5 |\n| - | - | - | - |\n| OpenLane-1000 | [Google Drive](https://drive.google.com/file/d/1jThvqnJ2cUaAuKdlTuRKjhLCH0Zq62A1/view?usp=sharing) | F1=0.6297 | d8ecb900c34fd23a9e7af840aff00843 |\n| OpenLane-1000 (Lite version) | [Google Drive](https://drive.google.com/file/d/1WD5dxa6SI2oR9popw3kO2-7eGM2z-IHY/view?usp=sharing) | F1=0.6212 | 918de41d0d31dbfbecff3001c49dc296 |\n| ONCE | [Google Drive](https://drive.google.com/file/d/12kXkJ9tDxm13CyFbB1ddt82lJZkYEicd/view?usp=sharing) | F1=0.8125 | 65a6958c162e3c7be0960bceb3f54650 |\n| Apollo-balance | [Google Drive](https://drive.google.com/file/d/1hGyNrYi3wAQaKbC1mD_18NG35gdmMUiM/view?usp=sharing) | F1=0.9697 | 551967e8654a8a522bdb0756d74dd1a2 |\n| Apollo-rare | [Google Drive](https://drive.google.com/file/d/19VVBaWBnWiEqGx1zJaeXF_1CKn88G5v0/view?usp=sharing) | F1=0.9641 | 184cfff1d3097a9009011f79f4594138 |\n| Apollo-visual | [Google Drive](https://drive.google.com/file/d/1ZzaUODYK2dyiG_2bDXe5tiutxNvc71M2/view?usp=sharing) | F1=0.9611 | cec4aa567c264c84808f3c32f5aace82 |\n\n\n## Evaluation\nYou can download the [pretrained models](#pretrained-models) to `./pretrained_models` directory and refer to the [eval guide](./docs/train_eval.md#evaluation) for evaluation.\n\n## Train\nPlease follow the steps in [training](./docs/train_eval.md#train) to train the model.\n\n## Benchmark\n\n### OpenLane\n\n| Models | F1 | Accuracy | X error <br> near \\| far | Z-error <br> near \\| far |\n| ----- | -- | -------- | ------- | ------- |\n| 3DLaneNet | 44.1 | - | 0.479 \\| 0.572 | 0.367 \\| 0.443 |\n| GenLaneNet | 32.3 | - | 0.593 \\| 0.494 | 0.140 \\| 0.195 |\n| Cond-IPM | 36.3 | - | 0.563 \\| 1.080 | 0.421 \\| 0.892 |\n| PersFormer | 50.5 | 89.5 | 0.319 \\| 0.325 | 0.112 \\| 0.141 |\n| CurveFormer | 50.5 | - | 0.340 \\| 0.772 | 0.207 \\| 0.651 |\n| 
PersFormer-Res50 | 53.0 | 89.2 | 0.321 \\| 0.303 | 0.085 \\| 0.118 |\n| **LATR-Lite** | 61.5 | 91.9 | 0.225 \\| 0.249 | 0.073 \\| 0.106 |\n| **LATR** | 61.9 | 92.0 | 0.219 \\| 0.259 | 0.075 \\| 0.104 |\n\n\n### Apollo\n\nPlaes kindly refer to our paper for the performance on other scenes.\n\n<table>\n    <tr>\n        <td>Scene</td>\n        <td>Models</td>\n        <td>F1</td>\n        <td>AP</td>\n        <td>X error <br> near | far </td>\n        <td>Z error <br> near | far </td>\n    </tr>\n    <tr>\n        <td rowspan=\"8\">Balanced Scene</td>\n        <td>3DLaneNet</td>\n        <td>86.4</td>\n        <td>89.3</td>\n        <td>0.068 | 0.477</td>\n        <td>0.015 | 0.202</td>\n    </tr>\n    <tr>\n        <td>GenLaneNet</td>\n        <td>88.1</td>\n        <td>90.1</td>\n        <td>0.061 | 0.496</td>\n        <td>0.012 | 0.214</td>\n    </tr>\n    <tr>\n        <td>CLGo</td>\n        <td>91.9</td>\n        <td>94.2</td>\n        <td>0.061 | 0.361</td>\n        <td>0.029 | 0.250</td>\n    </tr>\n    <tr>\n        <td>PersFormer</td>\n        <td>92.9</td>\n        <td>-</td>\n        <td>0.054 | 0.356</td>\n        <td>0.010 | 0.234</td>\n    </tr>\n    <tr>\n        <td>GP</td>\n        <td>91.9</td>\n        <td>93.8</td>\n        <td>0.049 | 0.387</td>\n        <td>0.008 | 0.213</td>\n    </tr>\n    <tr>\n        <td>CurveFormer</td>\n        <td>95.8</td>\n        <td>97.3</td>\n        <td>0.078 | 0.326</td>\n        <td>0.018 | 0.219</td>\n    </tr>\n    <tr>\n        <td><b>LATR-Lite</b></td>\n        <td>96.5</td>\n        <td>97.8</td>\n        <td>0.035 | 0.283</td>\n        <td>0.012 | 0.209</td>\n    </tr>\n    <tr>\n        <td><b>LATR</b?</td>\n        <td>96.8</td>\n        <td>97.9</td>\n        <td>0.022 | 0.253</td>\n        <td>0.007 | 0.202</td>\n    </tr>\n</table>\n\n\n### ONCE\n\n| Method     | F1  | Precision(%) | Recall(%) | CD error(m) |\n| :- | :- | :- | :- | :- |   \n| 3DLaneNet  | 44.73 | 61.46 | 35.16 | 0.127 |\n| GenLaneNet | 45.59 | 63.95 | 35.42 | 0.121 |\n| SALAD <ONCE-3DLane> | 64.07 | 75.90 | 55.42 | 0.098 |\n| PersFormer | 72.07 | 77.82 | 67.11 | 0.086 |\n| **LATR** | 80.59 | 86.12 | 75.73 | 0.052 |\n\n## Acknowledgment\n\nThis library is inspired by [OpenLane](https://github.com/OpenDriveLab/PersFormer_3DLane), [GenLaneNet](https://github.com/yuliangguo/Pytorch_Generalized_3D_Lane_Detection), [mmdetection3d](https://github.com/open-mmlab/mmdetection3d), [SparseInst](https://github.com/hustvl/SparseInst), [ONCE](https://github.com/once-3dlanes/once_3dlanes_benchmark) and many other related works, we thank them for sharing the code and datasets.\n\n\n## Citation\nIf you find LATR is useful for your research, please consider citing the paper:\n\n```tex\n@article{luo2023latr,\n  title={LATR: 3D Lane Detection from Monocular Images with Transformer},\n  author={Luo, Yueru and Zheng, Chaoda and Yan, Xu and Kun, Tang and Zheng, Chao and Cui, Shuguang and Li, Zhen},\n  journal={arXiv preprint arXiv:2308.04583},\n  year={2023}\n}\n```"
  },
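  {
    "path": "examples/verify_pretrained_md5.py",
    "content": "\"\"\"Illustrative sketch, not part of the original release: verify the md5 of a\ndownloaded checkpoint against the values in the README's pretrained-model table.\nThe `examples/` path and this CLI are assumptions made for illustration only.\n\nUsage:\n    python examples/verify_pretrained_md5.py <ckpt.pth> <expected_md5>\n\"\"\"\nimport hashlib\nimport sys\n\n\ndef md5_of(path, chunk_size=1 << 20):\n    # stream the file in 1 MiB chunks so large checkpoints need not fit in memory\n    h = hashlib.md5()\n    with open(path, 'rb') as f:\n        for chunk in iter(lambda: f.read(chunk_size), b''):\n            h.update(chunk)\n    return h.hexdigest()\n\n\nif __name__ == '__main__':\n    ckpt, expected = sys.argv[1], sys.argv[2]\n    actual = md5_of(ckpt)\n    print('OK' if actual == expected else 'MISMATCH: ' + actual)\n"
  },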
  {
    "path": "config/_base_/base_res101_bs16xep100.py",
    "content": "import os\nimport os.path as osp\nimport numpy as np\n\n\ndataset_name = 'openlane'\ndataset = '300' # '300' | '1000'\n\n#  The path of dataset json files (annotations)\ndata_dir = './data/openlane/lane3d_300/'\n# The path of dataset image files (images)\ndataset_dir = './data/openlane/images/'\noutput_dir = dataset_name\n\norg_h = 1280\norg_w = 1920\ncrop_y = 0\n\nipm_h = 208\nipm_w = 128\nresize_h = 360\nresize_w = 480\n\nmean = [0.485, 0.456, 0.406]\nstd = [0.229, 0.224, 0.225]\n\ncam_height = 1.55\npitch = 3\nfix_cam = False\npred_cam = False\n\nmodel_name = 'LATR'\nweight_init = 'normal'\nmod = None\n\nposition_embedding = 'learned'\nmax_lanes = 20\nnum_category = 21\nprob_th = 0.5\nnum_class = 21 # 1 bgd | 1 lanes\n\n# top view\ntop_view_region = np.array([[-10, 103], [10, 103], [-10, 3], [10, 3]])\nanchor_y_steps = np.linspace(3, 103, 25)\nnum_y_steps = len(anchor_y_steps)\n\n# placeholder, not used\nK = np.array([[1000., 0., 960.],\n            [0., 1000., 640.],\n            [0., 0., 1.]])\n\n# persformer anchor\nuse_default_anchor = False\n\nbatch_size = 16\nnepochs = 100\n\nno_cuda = False\nnworkers = 16\n\nstart_epoch = 0\nchannels_in = 3\n\n# args input\ntest_mode = False # 'store_true' # TODO \nevaluate = False # TODO\nresume = '' # resume latest saved run.\n\n# tensorboard\nno_tb = False\n\n# print & save\nprint_freq = 50\nsave_freq = 50\n\n# ddp setting\ndist = True\nsync_bn = True\ncudnn = True\n\ndistributed = True\nlocal_rank = None #TODO\ngpu = 0\nworld_size = 1\nnodes = 1\n\n# for reload ckpt\neval_ckpt = ''\nresume_from = ''\noutput_dir = 'openlane'\nevaluate_case = ''\neval_freq = 8 # eval freq during training\n\nsave_json_path = None\nsave_root = 'work_dirs'\nsave_prefix = osp.join(os.getcwd(), save_root)\nsave_path = osp.join(save_prefix, output_dir)"
  },
  {
    "path": "config/_base_/base_res101_bs16xep100_apollo.py",
    "content": "import os\nimport os.path as osp\nimport numpy as np\n\n# ========DATA SETTING======== #\ndataset_name = 'apollo'\ndataset = 'standard'\n\ndata_dir = osp.join('./data/apollosyn_gen-lanenet/data_splits', dataset)\ndataset_dir = './data/apollosyn_gen-lanenet/Apollo_Sim_3D_Lane_Release'\n\noutput_dir = 'apollo'\n\nrewrite_pred = True\nsave_best = False\n\noutput_dir = dataset_name\n\norg_h = 1080\norg_w = 1920\ncrop_y = 0\n\ncam_height = 1.55\npitch = 3\nfix_cam = False\npred_cam = False\n\nmodel_name = 'LATR'\nmod = None\n\nipm_h = 208\nipm_w = 128\nresize_h = 360\nresize_w = 480\n\nmean = [0.485, 0.456, 0.406]\nstd = [0.229, 0.224, 0.225]\n\nK = np.array([[2015., 0., 960.],\n            [0., 2015., 540.],\n            [0., 0., 1.]])\n\nposition_embedding = 'learned'\n\nmax_lanes = 6\nnum_category = 2\nprob_th = 0.5\nnum_class = 2 # 1 bgd | 1 lanes\n\nbatch_size = 16\nnepochs = 210\nnworkers = 16\n\n# ddp setting\ndist = True\nsync_bn = True\ncudnn = True\n\ndistributed = True\nlocal_rank = None #TODO\ngpu = 0\nworld_size = 1\nnodes = 1\n\n# for reload ckpt\neval_ckpt = ''\nresume = '' # ckpt number as input\nresume_from = '' # ckpt path as input\n\nno_cuda = False\n\n# tensorboard\nno_tb = False\n\nstart_epoch = 0\nchannels_in = 3\n\n# args input\ntest_mode = False # 'store_true' # TODO \nevaluate = False # TODO\nevaluate_case = ''\n\n# print & save\nprint_freq = 50\nsave_freq = 50\neval_freq = 20 # eval freq during training\n\n# top view\ntop_view_region = np.array([[-10, 103], [10, 103], [-10, 3], [10, 3]])\nanchor_y_steps = np.linspace(3, 103, 25)\nnum_y_steps = len(anchor_y_steps)\n\nsave_path = None\nsave_json_path = None\n\n"
  },
  {
    "path": "config/_base_/once_eval_config.json",
    "content": "{\"side_range_l\": -10, \"side_range_h\": 10, \"fwd_range_l\": 0, \"fwd_range_h\": 50, \"height_range_l\": 0, \"height_range_h\": 5, \"res\": 0.05, \"lane_width_x\": 30, \"lane_width_y\": 10, \"iou_thresh\": 0.3, \"distance_thresh\": 0.3, \"process_num\": 10, \"score_l\": 0.10, \"score_h\": 1, \"score_step\": 0.05, \"exp_name\": \"evaluation\"}"
  },
  {
    "path": "config/_base_/optimizer.py",
    "content": "# opt setting\noptimizer = 'adam'\nlearning_rate = 2e-4\n\nweight_decay = 0.001\nlr_decay = False # TODO 'store_true'\nniter = 900 # num of iter at starting learning rate\nniter_decay = 400 # '# of iter to linearly decay learning rate to zero'\nlr_policy = 'cosine'\ngamma = 0.1 # multiplicative factor of learning rate decay\nlr_decay_iters = 10 # multiply by a gamma every lr_decay_iters iterations\nT_max = 8 # maximum number of iterations\nT_0 = 8\nT_mult = 2\neta_min = 1e-5 # minimum learning rate\nclip_grad_norm = 35.0 # grad clipping\nloss_threshold = 1e5\n\n"
  },
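  {
    "path": "examples/cosine_lr_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original release: one plausible mapping of\nthe cosine settings in config/_base_/optimizer.py (learning_rate=2e-4, T_max=8,\neta_min=1e-5) onto torch.optim.lr_scheduler.CosineAnnealingLR. The training code\nmay wire the scheduler differently; T_0/T_mult in the same file would instead\nfeed CosineAnnealingWarmRestarts.\"\"\"\nimport torch\n\nparam = torch.nn.Parameter(torch.zeros(1))  # dummy parameter to schedule\noptimizer = torch.optim.Adam([param], lr=2e-4)\nscheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n    optimizer, T_max=8, eta_min=1e-5)\n\nfor epoch in range(8):\n    optimizer.step()   # one (dummy) epoch of training\n    scheduler.step()   # anneal the lr from 2e-4 toward eta_min\n    print(epoch, scheduler.get_last_lr())\n"
  },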
  {
    "path": "config/release_iccv/apollo_illu.py",
    "content": "import numpy as np\nfrom mmcv.utils import Config\nimport os.path as osp\n\n_base_ = [\n    '../_base_/base_res101_bs16xep100_apollo.py',\n    '../_base_/optimizer.py',\n]\n\nmod = 'release_iccv/apollo_illu'\nmean = [0.485, 0.456, 0.406]\nstd = [0.229, 0.224, 0.225]\n\n\ndataset_name = 'apollo'\ndataset = 'illus_chg'\ndata_dir = osp.join('./data/apollosyn_gen-lanenet/data_splits', dataset)\ndataset_dir = './data/apollosyn_gen-lanenet/Apollo_Sim_3D_Lane_Release'\noutput_dir = 'apollo'\nnum_category = 2\nmax_lanes = 6\n\nT_max = 30\neta_min = 1e-6\nclip_grad_norm = 20\nnepochs = 210\neval_freq = 1\n\nh_org, w_org = 1080, 1920\n\nbatch_size = 8\nnworkers = 10\npos_threshold = 0.5\ntop_view_region = np.array([\n    [-10, 103], [10, 103], [-10, 3], [10, 3]])\nenlarge_length = 20\nposition_range = [\n    top_view_region[0][0] - enlarge_length,\n    top_view_region[2][1] - enlarge_length,\n    -5,\n    top_view_region[1][0] + enlarge_length,\n    top_view_region[0][1] + enlarge_length,\n    5.]\nanchor_y_steps = np.linspace(3, 103, 20)\nnum_y_steps = len(anchor_y_steps)\n\nphoto_aug = dict(\n    brightness_delta=32,\n    contrast_range=(0.5, 1.5),\n    saturation_range=(0.5, 1.5),\n    hue_delta=18)\n\n_dim_ = 256\nnum_query = 12\nnum_pt_per_line = 20\nlatr_cfg = dict(\n    fpn_dim = _dim_,\n    num_query = num_query,\n    num_group = 1,\n    sparse_num_group = 4,\n    encoder = dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(1, 2, 3),\n        frozen_stages=-1,\n        norm_cfg=dict(type='BN2d', requires_grad=False),\n        norm_eval=True,\n        style='caffe',\n        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),\n        stage_with_dcn=(False, False, True, True),\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')\n    ),\n    neck = dict(\n        type='FPN',\n        in_channels=[512, 1024, 2048],\n        out_channels=_dim_,\n        start_level=0,\n        add_extra_convs='on_output',\n        num_outs=4,\n        relu_before_extra_convs=True\n    ),\n    head=dict(\n        xs_loss_weight=2.0,\n        zs_loss_weight=10.0,\n        vis_loss_weight=1.0,\n        cls_loss_weight=10,\n        project_loss_weight=1.0,\n        pt_as_query=True,\n        num_pt_per_line=num_pt_per_line,\n    ),\n    trans_params=dict(init_z=0, bev_h=150, bev_w=70),\n)\n\nms2one=dict(\n    type='DilateNaive',\n    inc=_dim_, outc=_dim_, num_scales=4,\n    dilations=(1, 2, 5, 9))\n\ntransformer=dict(\n    type='LATRTransformer',\n    decoder=dict(\n        type='LATRTransformerDecoder',\n        embed_dims=_dim_,\n        num_layers=6,\n        enlarge_length=enlarge_length,\n        M_decay_ratio=1,\n        num_query=num_query,\n        num_anchor_per_query=num_pt_per_line,\n        anchor_y_steps=anchor_y_steps,\n        transformerlayers=dict(\n            type='LATRDecoderLayer',\n            attn_cfgs=[\n                dict(\n                    type='MultiheadAttention',\n                    embed_dims=_dim_,\n                    num_heads=4,\n                    dropout=0.1),\n                dict(\n                    type='MSDeformableAttention3D',\n                    embed_dims=_dim_,\n                    num_heads=4,\n                    num_levels=1,\n                    num_points=8,\n                    batch_first=False,\n                    num_query=num_query,\n                    num_anchor_per_query=num_pt_per_line,\n                    anchor_y_steps=anchor_y_steps,\n         
           dropout=0.1),\n                ],\n            ffn_cfgs=dict(\n                type='FFN',\n                embed_dims=_dim_,\n                feedforward_channels=_dim_*8,\n                num_fcs=2,\n                ffn_drop=0.1,\n                act_cfg=dict(type='ReLU', inplace=True),\n            ),\n            feedforward_channels=_dim_ * 8,\n            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',\n                            'ffn', 'norm')),\n))\n\nsparse_ins_decoder=Config(\n    dict(\n        encoder=dict(\n            out_dims=_dim_),\n        decoder=dict(\n            num_query=latr_cfg['num_query'],\n            num_group=latr_cfg['num_group'],\n            sparse_num_group=latr_cfg['sparse_num_group'],\n            hidden_dim=_dim_,\n            kernel_dim=_dim_,\n            num_classes=num_category,\n            num_convs=4,\n            output_iam=True,\n            scale_factor=1.,\n            ce_weight=2.0,\n            mask_weight=5.0,\n            dice_weight=2.0,\n            objectness_weight=1.0,\n        ),\n        sparse_decoder_weight=5.0,\n))\n\nresize_h = 720\nresize_w = 960\n\noptimizer_cfg = dict(\n    type='AdamW',\n    lr=2e-4,\n    weight_decay=0.01)"
  },
  {
    "path": "config/release_iccv/apollo_rare.py",
    "content": "import numpy as np\nfrom mmcv.utils import Config\nimport os.path as osp\n\n_base_ = [\n    '../_base_/base_res101_bs16xep100_apollo.py',\n    '../_base_/optimizer.py',\n]\n\nmod = 'release_iccv/apollo_rare'\nmean = [0.485, 0.456, 0.406]\nstd = [0.229, 0.224, 0.225]\n\n\ndataset_name = 'apollo'\ndataset = 'rare_subset'\ndata_dir = osp.join('./data/apollosyn_gen-lanenet/data_splits', dataset)\ndataset_dir = './data/apollosyn_gen-lanenet/Apollo_Sim_3D_Lane_Release'\noutput_dir = 'apollo'\nnum_category = 2\nmax_lanes = 6\n\nT_max = 30\neta_min = 1e-8\nclip_grad_norm = 20\nnepochs = 210\neval_freq = 1\n\nh_org, w_org = 1080, 1920\n\nbatch_size = 8\nnworkers = 10\npos_threshold = 0.5\ntop_view_region = np.array([\n    [-10, 103], [10, 103], [-10, 3], [10, 3]])\nenlarge_length = 20\nposition_range = [\n    top_view_region[0][0] - enlarge_length,\n    top_view_region[2][1] - enlarge_length,\n    -5,\n    top_view_region[1][0] + enlarge_length,\n    top_view_region[0][1] + enlarge_length,\n    5.]\nanchor_y_steps = np.linspace(3, 103, 20)\nnum_y_steps = len(anchor_y_steps)\n\nphoto_aug = dict(\n    brightness_delta=32,\n    contrast_range=(0.5, 1.5),\n    saturation_range=(0.5, 1.5),\n    hue_delta=18)\n\n_dim_ = 256\nnum_query = 12\nnum_pt_per_line = 20\nlatr_cfg = dict(\n    fpn_dim = _dim_,\n    num_query = num_query,\n    num_group = 1,\n    sparse_num_group = 4,\n    encoder = dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(1, 2, 3),\n        frozen_stages=-1,\n        norm_cfg=dict(type='BN2d', requires_grad=False),\n        norm_eval=True,\n        style='caffe',\n        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),\n        stage_with_dcn=(False, False, True, True),\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')\n    ),\n    neck = dict(\n        type='FPN',\n        in_channels=[512, 1024, 2048],\n        out_channels=_dim_,\n        start_level=0,\n        add_extra_convs='on_output',\n        num_outs=4,\n        relu_before_extra_convs=True\n    ),\n    head=dict(\n        xs_loss_weight=2.0,\n        zs_loss_weight=10.0,\n        vis_loss_weight=1.0,\n        cls_loss_weight=10,\n        project_loss_weight=1.0,\n        pt_as_query=True,\n        num_pt_per_line=num_pt_per_line,\n    ),\n    trans_params=dict(init_z=0, bev_h=150, bev_w=70),\n)\n\nms2one=dict(\n    type='DilateNaive',\n    inc=_dim_, outc=_dim_, num_scales=4,\n    dilations=(1, 2, 5, 9))\n\ntransformer=dict(\n    type='LATRTransformer',\n    decoder=dict(\n        type='LATRTransformerDecoder',\n        embed_dims=_dim_,\n        num_layers=6,\n        enlarge_length=enlarge_length,\n        M_decay_ratio=1,\n        num_query=num_query,\n        num_anchor_per_query=num_pt_per_line,\n        anchor_y_steps=anchor_y_steps,\n        transformerlayers=dict(\n            type='LATRDecoderLayer',\n            attn_cfgs=[\n                dict(\n                    type='MultiheadAttention',\n                    embed_dims=_dim_,\n                    num_heads=4,\n                    dropout=0.1),\n                dict(\n                    type='MSDeformableAttention3D',\n                    embed_dims=_dim_,\n                    num_heads=4,\n                    num_levels=1,\n                    num_points=8,\n                    batch_first=False,\n                    num_query=num_query,\n                    num_anchor_per_query=num_pt_per_line,\n                    anchor_y_steps=anchor_y_steps,\n       
             dropout=0.1),\n                ],\n            ffn_cfgs=dict(\n                type='FFN',\n                embed_dims=_dim_,\n                feedforward_channels=_dim_*8,\n                num_fcs=2,\n                ffn_drop=0.1,\n                act_cfg=dict(type='ReLU', inplace=True),\n            ),\n            feedforward_channels=_dim_ * 8,\n            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',\n                            'ffn', 'norm')),\n))\n\nsparse_ins_decoder=Config(\n    dict(\n        encoder=dict(\n            out_dims=_dim_),\n        decoder=dict(\n            num_query=latr_cfg['num_query'],\n            num_group=latr_cfg['num_group'],\n            sparse_num_group=latr_cfg['sparse_num_group'],\n            hidden_dim=_dim_,\n            kernel_dim=_dim_,\n            num_classes=num_category,\n            num_convs=4,\n            output_iam=True,\n            scale_factor=1.,\n            ce_weight=2.0,\n            mask_weight=5.0,\n            dice_weight=2.0,\n            objectness_weight=1.0,\n        ),\n        sparse_decoder_weight=5.0,\n))\n\nresize_h = 720\nresize_w = 960\n\noptimizer_cfg = dict(\n    type='AdamW',\n    lr=2e-4,\n    weight_decay=0.01)"
  },
  {
    "path": "config/release_iccv/apollo_standard.py",
    "content": "import numpy as np\nfrom mmcv.utils import Config\nimport os.path as osp\n\n_base_ = [\n    '../_base_/base_res101_bs16xep100_apollo.py',\n    '../_base_/optimizer.py',\n]\n\nmod = 'release_iccv/apollo_standard'\nmean = [0.485, 0.456, 0.406]\nstd = [0.229, 0.224, 0.225]\n\n\ndataset_name = 'apollo'\ndataset = 'standard'\ndata_dir = osp.join('./data/apollosyn_gen-lanenet/data_splits', dataset)\ndataset_dir = './data/apollosyn_gen-lanenet/Apollo_Sim_3D_Lane_Release'\noutput_dir = 'apollo'\nnum_category = 2\nmax_lanes = 6\n\nT_max = 30\neta_min = 1e-6\nclip_grad_norm = 20\nnepochs = 210\neval_freq = 1\n\nh_org, w_org = 1080, 1920\n\nbatch_size = 8\nnworkers = 10\npos_threshold = 0.3\ntop_view_region = np.array([\n    [-10, 103], [10, 103], [-10, 3], [10, 3]])\nenlarge_length = 20\nposition_range = [\n    top_view_region[0][0] - enlarge_length,\n    top_view_region[2][1] - enlarge_length,\n    -5,\n    top_view_region[1][0] + enlarge_length,\n    top_view_region[0][1] + enlarge_length,\n    5.]\nanchor_y_steps = np.linspace(3, 103, 20)\nnum_y_steps = len(anchor_y_steps)\n\n_dim_ = 256\nnum_query = 12\nnum_pt_per_line = 20\nlatr_cfg = dict(\n    fpn_dim = _dim_,\n    num_query = num_query,\n    num_group = 1,\n    sparse_num_group = 4,\n    encoder = dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(1, 2, 3),\n        frozen_stages=-1,\n        norm_cfg=dict(type='BN2d', requires_grad=False),\n        norm_eval=True,\n        style='caffe',\n        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),\n        stage_with_dcn=(False, False, True, True),\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')\n    ),\n    neck = dict(\n        type='FPN',\n        in_channels=[512, 1024, 2048],\n        out_channels=_dim_,\n        start_level=0,\n        add_extra_convs='on_output',\n        num_outs=4,\n        relu_before_extra_convs=True\n    ),\n    head=dict(\n        xs_loss_weight=2.0,\n        zs_loss_weight=10.0,\n        vis_loss_weight=1.0,\n        cls_loss_weight=10,\n        project_loss_weight=1.0,\n        pt_as_query=True,\n        num_pt_per_line=num_pt_per_line,\n    ),\n    trans_params=dict(init_z=0, bev_h=150, bev_w=70),\n)\n\nms2one=dict(\n    type='DilateNaive',\n    inc=_dim_, outc=_dim_, num_scales=4,\n    dilations=(1, 2, 5, 9))\n\ntransformer=dict(\n    type='LATRTransformer',\n    decoder=dict(\n        type='LATRTransformerDecoder',\n        embed_dims=_dim_,\n        num_layers=6,\n        enlarge_length=enlarge_length,\n        M_decay_ratio=1,\n        num_query=num_query,\n        num_anchor_per_query=num_pt_per_line,\n        anchor_y_steps=anchor_y_steps,\n        transformerlayers=dict(\n            type='LATRDecoderLayer',\n            attn_cfgs=[\n                dict(\n                    type='MultiheadAttention',\n                    embed_dims=_dim_,\n                    num_heads=4,\n                    dropout=0.1),\n                dict(\n                    type='MSDeformableAttention3D',\n                    embed_dims=_dim_,\n                    num_heads=4,\n                    num_levels=1,\n                    num_points=8,\n                    batch_first=False,\n                    num_query=num_query,\n                    num_anchor_per_query=num_pt_per_line,\n                    anchor_y_steps=anchor_y_steps,\n                    dropout=0.1),\n                ],\n            ffn_cfgs=dict(\n                type='FFN',\n                
embed_dims=_dim_,\n                feedforward_channels=_dim_*8,\n                num_fcs=2,\n                ffn_drop=0.1,\n                act_cfg=dict(type='ReLU', inplace=True),\n            ),\n            feedforward_channels=_dim_ * 8,\n            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',\n                            'ffn', 'norm')),\n))\n\nsparse_ins_decoder=Config(\n    dict(\n        encoder=dict(\n            out_dims=_dim_),\n        decoder=dict(\n            num_query=latr_cfg['num_query'],\n            num_group=latr_cfg['num_group'],\n            sparse_num_group=latr_cfg['sparse_num_group'],\n            hidden_dim=_dim_,\n            kernel_dim=_dim_,\n            num_classes=num_category,\n            num_convs=4,\n            output_iam=True,\n            scale_factor=1.,\n            ce_weight=2.0,\n            mask_weight=5.0,\n            dice_weight=2.0,\n            objectness_weight=1.0,\n        ),\n        sparse_decoder_weight=5.0,\n))\n\nresize_h = 720\nresize_w = 960\noptimizer_cfg = dict(\n    type='AdamW',\n    lr=2e-4,\n    weight_decay=0.01)"
  },
  {
    "path": "config/release_iccv/latr_1000_baseline.py",
    "content": "import numpy as np\nfrom mmcv.utils import Config\n\n_base_ = [\n    '../_base_/base_res101_bs16xep100.py',\n    '../_base_/optimizer.py'\n]\n\nmod = 'release_iccv/latr_1000_baseline'\nmean = [0.485, 0.456, 0.406]\nstd = [0.229, 0.224, 0.225]\n\ndataset = '1000'\ndataset_dir = './data/openlane/images/'\ndata_dir = './data/openlane/lane3d_1000/'\n\nbatch_size = 8\nnworkers = 10\nnum_category = 21\npos_threshold = 0.3\ntop_view_region = np.array([\n    [-10, 103], [10, 103], [-10, 3], [10, 3]])\nenlarge_length = 20\nposition_range = [\n    top_view_region[0][0] - enlarge_length,\n    top_view_region[2][1] - enlarge_length,\n    -5,\n    top_view_region[1][0] + enlarge_length,\n    top_view_region[0][1] + enlarge_length,\n    5.]\nanchor_y_steps = np.linspace(3, 103, 20)\nnum_y_steps = len(anchor_y_steps)\n\n# extra aug\nphoto_aug = dict(\n    brightness_delta=32 // 2,\n    contrast_range=(0.5, 1.5),\n    saturation_range=(0.5, 1.5),\n    hue_delta=9)\n\nclip_grad_norm = 20.0\n\n_dim_ = 256\nnum_query = 40\nnum_pt_per_line = 20\nlatr_cfg = dict(\n    fpn_dim = _dim_,\n    num_query = num_query,\n    num_group = 1,\n    sparse_num_group = 4,\n    encoder = dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN2d', requires_grad=False),\n        norm_eval=True,\n        style='caffe',\n        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),\n        stage_with_dcn=(False, False, True, True),\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')\n    ),\n    neck = dict(\n        type='FPN',\n        in_channels=[512, 1024, 2048],\n        out_channels=_dim_,\n        start_level=0,\n        add_extra_convs='on_output',\n        num_outs=4,\n        relu_before_extra_convs=True\n    ),\n    head=dict(\n        pt_as_query=True,\n        num_pt_per_line=num_pt_per_line,\n        xs_loss_weight=2.0,\n        zs_loss_weight=10.0,\n        vis_loss_weight=1.0,\n        cls_loss_weight=10,\n        project_loss_weight=1.0,\n    ),\n    trans_params=dict(init_z=0, bev_h=150, bev_w=70),\n)\n\nms2one=dict(\n    type='DilateNaive',\n    inc=_dim_, outc=_dim_, num_scales=4,\n    dilations=(1, 2, 5, 9))\n\ntransformer=dict(\n    type='LATRTransformer',\n    decoder=dict(\n        type='LATRTransformerDecoder',\n        embed_dims=_dim_,\n        num_layers=6,\n        enlarge_length=enlarge_length,\n        M_decay_ratio=1,\n        num_query=num_query,\n        num_anchor_per_query=num_pt_per_line,\n        anchor_y_steps=anchor_y_steps,\n        transformerlayers=dict(\n            type='LATRDecoderLayer',\n            attn_cfgs=[\n                dict(\n                    type='MultiheadAttention',\n                    embed_dims=_dim_,\n                    num_heads=4,\n                    dropout=0.1),\n                dict(\n                    type='MSDeformableAttention3D',\n                    embed_dims=_dim_,\n                    num_heads=4,\n                    num_levels=1,\n                    num_points=8,\n                    batch_first=False,\n                    num_query=num_query,\n                    num_anchor_per_query=num_pt_per_line,\n                    anchor_y_steps=anchor_y_steps,\n                    dropout=0.1),\n                ],\n            ffn_cfgs=dict(\n                type='FFN',\n                embed_dims=_dim_,\n                feedforward_channels=_dim_*8,\n                num_fcs=2,\n            
    ffn_drop=0.1,\n                act_cfg=dict(type='ReLU', inplace=True),\n            ),\n            feedforward_channels=_dim_ * 8,\n            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',\n                            'ffn', 'norm')),\n))\n\nsparse_ins_decoder=Config(\n    dict(\n        encoder=dict(\n            out_dims=_dim_),\n        decoder=dict(\n            num_query=latr_cfg['num_query'],\n            num_group=latr_cfg['num_group'],\n            sparse_num_group=latr_cfg['sparse_num_group'],\n            hidden_dim=_dim_,\n            kernel_dim=_dim_,\n            num_classes=num_category,\n            num_convs=4,\n            output_iam=True,\n            scale_factor=1.,\n            ce_weight=2.0,\n            mask_weight=5.0,\n            dice_weight=2.0,\n            objectness_weight=1.0,\n        ),\n        sparse_decoder_weight=5.0,\n))\n\nnepochs = 24\nresize_h = 720\nresize_w = 960\n\neval_freq = 8\noptimizer_cfg = dict(\n    type='AdamW',\n    lr=2e-4,\n    paramwise_cfg=dict(\n        custom_keys={\n            'sampling_offsets': dict(lr_mult=0.1),\n        }),\n    weight_decay=0.01)"
  },
  {
    "path": "config/release_iccv/latr_1000_baseline_lite.py",
    "content": "import numpy as np\nfrom mmcv.utils import Config\n\n_base_ = [\n    '../_base_/base_res101_bs16xep100.py',\n    '../_base_/optimizer.py'\n]\n\nmod = 'release_iccv/latr_1000_baseline_lite'\nmean = [0.485, 0.456, 0.406]\nstd = [0.229, 0.224, 0.225]\n\ndataset = '1000'\ndataset_dir = './data/openlane/images/'\ndata_dir = './data/openlane/lane3d_1000/'\n\nbatch_size = 8\nnworkers = 10\nnum_category = 21\npos_threshold = 0.3\n\nclip_grad_norm = 20\n\ntop_view_region = np.array([\n    [-10, 103], [10, 103], [-10, 3], [10, 3]])\nenlarge_length = 20\nposition_range = [\n    top_view_region[0][0] - enlarge_length,\n    top_view_region[2][1] - enlarge_length,\n    -5,\n    top_view_region[1][0] + enlarge_length,\n    top_view_region[0][1] + enlarge_length,\n    5.]\nanchor_y_steps = np.linspace(3, 103, 20)\nnum_y_steps = len(anchor_y_steps)\n\n# extra aug\nphoto_aug = dict(\n    brightness_delta=32 // 2,\n    contrast_range=(0.5, 1.5),\n    saturation_range=(0.5, 1.5),\n    hue_delta=9)\n\n_dim_ = 256\nnum_query = 40\nnum_pt_per_line = 20\nlatr_cfg = dict(\n    fpn_dim = _dim_,\n    num_query = num_query,\n    num_group = 1,\n    sparse_num_group = 4,\n    encoder = dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN2d', requires_grad=False),\n        norm_eval=True,\n        style='caffe',\n        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),\n        stage_with_dcn=(False, False, True, True),\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')\n    ),\n    neck = dict(\n        type='FPN',\n        in_channels=[512, 1024, 2048],\n        out_channels=_dim_,\n        start_level=0,\n        add_extra_convs='on_output',\n        num_outs=4,\n        relu_before_extra_convs=True\n    ),\n    head=dict(\n        pt_as_query=True,\n        num_pt_per_line=num_pt_per_line,\n        xs_loss_weight=2.0,\n        zs_loss_weight=10.0,\n        vis_loss_weight=1.0,\n        cls_loss_weight=10,\n        project_loss_weight=1.0,\n    ),\n    trans_params=dict(init_z=0, bev_h=150, bev_w=70),\n)\n\nms2one=dict(\n    type='DilateNaive',\n    inc=_dim_, outc=_dim_, num_scales=4,\n    dilations=(1, 2, 5, 9))\n\ntransformer=dict(\n    type='LATRTransformer',\n    decoder=dict(\n        type='LATRTransformerDecoder',\n        embed_dims=_dim_,\n        num_layers=2,\n        enlarge_length=enlarge_length,\n        M_decay_ratio=1,\n        num_query=num_query,\n        num_anchor_per_query=num_pt_per_line,\n        anchor_y_steps=anchor_y_steps,\n        transformerlayers=dict(\n            type='LATRDecoderLayer',\n            attn_cfgs=[\n                dict(\n                    type='MultiheadAttention',\n                    embed_dims=_dim_,\n                    num_heads=4,\n                    dropout=0.1),\n                dict(\n                    type='MSDeformableAttention3D',\n                    embed_dims=_dim_,\n                    num_heads=4,\n                    num_levels=1,\n                    num_points=8,\n                    batch_first=False,\n                    num_query=num_query,\n                    num_anchor_per_query=num_pt_per_line,\n                    anchor_y_steps=anchor_y_steps,\n                    dropout=0.1),\n                ],\n            ffn_cfgs=dict(\n                type='FFN',\n                embed_dims=_dim_,\n                feedforward_channels=_dim_*8,\n                num_fcs=2,\n       
         ffn_drop=0.1,\n                act_cfg=dict(type='ReLU', inplace=True),\n            ),\n            feedforward_channels=_dim_ * 8,\n            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',\n                            'ffn', 'norm')),\n))\n\nsparse_ins_decoder=Config(\n    dict(\n        encoder=dict(\n            out_dims=_dim_),\n        decoder=dict(\n            num_query=latr_cfg['num_query'],\n            num_group=latr_cfg['num_group'],\n            sparse_num_group=latr_cfg['sparse_num_group'],\n            hidden_dim=_dim_,\n            kernel_dim=_dim_,\n            num_classes=num_category,\n            num_convs=4,\n            output_iam=True,\n            scale_factor=1.,\n            ce_weight=2.0,\n            mask_weight=5.0,\n            dice_weight=2.0,\n            objectness_weight=1.0,\n        ),\n        sparse_decoder_weight=5.0,\n))\n\n\nresize_h = 720\nresize_w = 960\n\nnepochs = 24\neval_freq = 8\noptimizer_cfg = dict(\n    type='AdamW',\n    lr=2e-4,\n    paramwise_cfg=dict(\n        custom_keys={\n            'sampling_offsets': dict(lr_mult=0.1),\n        }),\n    weight_decay=0.01)"
  },
  {
    "path": "config/release_iccv/once.py",
    "content": "import numpy as np\nfrom mmcv.utils import Config\nimport os.path as osp\n\n_base_ = [\n    '../_base_/base_res101_bs16xep100.py',\n    '../_base_/optimizer.py'\n]\n\nmod = 'release_iccv/once'\nmean = [0.485, 0.456, 0.406]\nstd = [0.229, 0.224, 0.225]\n\n\ndataset = 'once'\ndataset_name = 'once'\ndata_dir = 'data/once/'\ndataset_dir = 'data/once/data/'\neval_config_dir = 'config/_base_/once_eval_config.json'\n\nsave_path = osp.join('./work_dirs', dataset)\n\nmax_lanes = 8\nnum_pt_per_line = 20\n\neta_min = 1e-6\nclip_grad_norm = 20\n\nbatch_size = 8\nnworkers = 10\nnum_category = 2\npos_threshold = 0.3\n\ntop_view_region = np.array([\n    [-10, 65], [10, 65], [-10, 0.5], [10, 0.5]])\n\nenlarge_length = 10\nposition_range = [\n    top_view_region[0][0] - enlarge_length,\n    top_view_region[2][1] - enlarge_length,\n    -5,\n    top_view_region[1][0] + enlarge_length,\n    top_view_region[0][1] + enlarge_length,\n    5.]\n\nanchor_y_steps = np.linspace(0.5, 65, num_pt_per_line)\nnum_y_steps = len(anchor_y_steps)\n\n_dim_ = 256\nnum_query = 12\nnum_pt_per_line = 20\nlatr_cfg = dict(\n    fpn_dim = _dim_,\n    num_query = num_query,\n    num_group = 1,\n    sparse_num_group = 4,\n    encoder = dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN2d', requires_grad=False),\n        norm_eval=True,\n        style='caffe',\n        # with_cp=True,\n        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),\n        stage_with_dcn=(False, False, True, True),\n        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')\n    ),\n    neck = dict(\n        type='FPN',\n        in_channels=[512, 1024, 2048],\n        out_channels=_dim_,\n        start_level=0,\n        add_extra_convs='on_output',\n        num_outs=4,\n        relu_before_extra_convs=True\n    ),\n    head=dict(\n        xs_loss_weight=2.0,\n        zs_loss_weight=10.0,\n        vis_loss_weight=1.0,\n        cls_loss_weight=10,\n        project_loss_weight=1.0,\n        pt_as_query=True,\n        num_pt_per_line=num_pt_per_line,\n    ),\n    trans_params=dict(init_z=0, bev_h=150, bev_w=70),\n)\n\nms2one=dict(\n    type='DilateNaive',\n    inc=_dim_, outc=_dim_, num_scales=4,\n    dilations=(1, 2, 5, 9))\n\ntransformer=dict(\n    type='LATRTransformer',\n    decoder=dict(\n        type='LATRTransformerDecoder',\n        embed_dims=_dim_,\n        num_layers=6,\n        enlarge_length=enlarge_length,\n        M_decay_ratio=1,\n        num_query=num_query,\n        num_anchor_per_query=num_pt_per_line,\n        anchor_y_steps=anchor_y_steps,\n        transformerlayers=dict(\n            type='LATRDecoderLayer',\n            attn_cfgs=[\n                dict(\n                    type='MultiheadAttention',\n                    embed_dims=_dim_,\n                    num_heads=4,\n                    dropout=0.1),\n                dict(\n                    type='MSDeformableAttention3D',\n                    embed_dims=_dim_,\n                    num_heads=4,\n                    num_levels=1,\n                    num_points=8,\n                    batch_first=False,\n                    num_query=num_query,\n                    num_anchor_per_query=num_pt_per_line,\n                    anchor_y_steps=anchor_y_steps,\n                    dropout=0.1),\n                ],\n            ffn_cfgs=dict(\n                type='FFN',\n                embed_dims=_dim_,\n                
feedforward_channels=_dim_*8,\n                num_fcs=2,\n                ffn_drop=0.1,\n                act_cfg=dict(type='ReLU', inplace=True),\n            ),\n            feedforward_channels=_dim_ * 8,\n            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',\n                            'ffn', 'norm')),\n))\n\nsparse_ins_decoder=Config(\n    dict(\n        encoder=dict(\n            out_dims=_dim_),\n        decoder=dict(\n            num_query=latr_cfg['num_query'],\n            num_group=latr_cfg['num_group'],\n            sparse_num_group=latr_cfg['sparse_num_group'],\n            hidden_dim=_dim_,\n            kernel_dim=_dim_,\n            num_classes=num_category,\n            num_convs=4,\n            output_iam=True,\n            scale_factor=1.,\n            ce_weight=2.0,\n            mask_weight=5.0,\n            dice_weight=2.0,\n            objectness_weight=1.0,\n        ),\n        sparse_decoder_weight=5.0,\n))\n\nresize_h = 720\nresize_w = 960\nnepochs = 24\neval_freq = 8\n\noptimizer_cfg = dict(\n    type='AdamW',\n    lr=2e-4,\n    paramwise_cfg=dict(\n        custom_keys={\n            'sampling_offsets': dict(lr_mult=0.1),\n        }),\n    weight_decay=0.01)"
  },
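  {
    "path": "examples/load_config_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original release: the files under config/\nare mmcv-style Python configs, so their `_base_` inheritance can be resolved with\nmmcv's Config loader (keys in the child file override the merged `_base_` files).\nThe repo's actual entry point may parse configs differently.\"\"\"\nfrom mmcv.utils import Config\n\ncfg = Config.fromfile('config/release_iccv/latr_1000_baseline.py')\n\nprint(cfg.batch_size)             # 8, set in the child config\nprint(cfg.optimizer)              # 'adam', inherited from _base_/optimizer.py\nprint(cfg.latr_cfg['num_query'])  # 40\n"
  },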
  {
    "path": "data/Load_Data.py",
    "content": "import re\nimport os\nimport sys\nimport copy\nimport json\nimport glob\nimport random\nimport warnings\nimport numpy as np\nimport cv2\nfrom PIL import Image\nfrom torch.utils.data import Dataset, DataLoader\nfrom torchvision import transforms\nimport torchvision.transforms.functional as F\nfrom torchvision.transforms import InterpolationMode\nfrom utils.utils import *\nfrom experiments.gpu_utils import is_main_process\n\nfrom .transform import PhotoMetricDistortionMultiViewImage\n\nsys.path.append('./')\nwarnings.simplefilter('ignore', np.RankWarning)\nmatplotlib.use('Agg')\n\nimport yaml\n\nclass LaneDataset(Dataset):\n    \"\"\"\n    Dataset with labeled lanes\n        This implementation considers:\n        w/o laneline 3D attributes\n        w/o centerline annotations\n        default considers 3D laneline, including centerlines\n\n        This new version of data loader prepare ground-truth anchor tensor in flat ground space.\n        It is assumed the dataset provides accurate visibility labels. Preparing ground-truth tensor depends on it.\n    \"\"\"\n    # dataset_base_dir is image path, json_file_path is json file path,\n    def __init__(self, dataset_base_dir, json_file_path, args, data_aug=False):\n        \"\"\"\n\n        :param dataset_info_file: json file list\n        \"\"\"\n        self.totensor = transforms.ToTensor()\n        mean = [0.485, 0.456, 0.406] if args.mean is None else args.mean\n        std = [0.229, 0.224, 0.225] if args.std is None else args.std\n        self.normalize = transforms.Normalize(mean, std)\n        self.data_aug = data_aug\n        if data_aug:\n            if hasattr(args, 'photo_aug'):\n                self.photo_aug = PhotoMetricDistortionMultiViewImage(**args.photo_aug)\n            else:\n                self.photo_aug = False\n\n        self.dataset_base_dir = dataset_base_dir\n        self.json_file_path = json_file_path\n\n        # dataset parameters\n        self.dataset_name = args.dataset_name\n        self.num_category = args.num_category\n\n        self.h_org = args.org_h\n        self.w_org = args.org_w\n        self.h_crop = args.crop_y\n\n        # parameters related to service network\n        self.h_net = args.resize_h\n        self.w_net = args.resize_w\n        self.u_ratio = float(self.w_net) / float(self.w_org)\n        self.v_ratio = float(self.h_net) / float(self.h_org - self.h_crop)\n        self.top_view_region = args.top_view_region\n        self.max_lanes = args.max_lanes\n\n        self.K = args.K\n        self.H_crop = homography_crop_resize([args.org_h, args.org_w], args.crop_y, [args.resize_h, args.resize_w])\n\n        if args.fix_cam:\n            self.fix_cam = True\n            # compute the homography between image and IPM, and crop transformation\n            self.cam_height = args.cam_height\n            self.cam_pitch = np.pi / 180 * args.pitch\n            self.P_g2im = projection_g2im(self.cam_pitch, self.cam_height, args.K)\n        else:\n            self.fix_cam = False\n\n        # compute anchor steps\n        self.use_default_anchor = args.use_default_anchor\n        \n        self.x_min, self.x_max = self.top_view_region[0, 0], self.top_view_region[1, 0]\n        self.y_min, self.y_max = self.top_view_region[2, 1], self.top_view_region[0, 1]\n        \n        self.anchor_y_steps = args.anchor_y_steps\n        self.num_y_steps = len(self.anchor_y_steps)\n\n        self.anchor_y_steps_dense = args.get(\n            'anchor_y_steps_dense',\n            np.linspace(3, 103, 
200))\n        args.anchor_y_steps_dense = self.anchor_y_steps_dense\n        self.num_y_steps_dense = len(self.anchor_y_steps_dense)\n        self.anchor_dim = 3 * self.num_y_steps + args.num_category\n        self.save_json_path = args.save_json_path\n\n        # parse ground-truth file\n        if 'openlane' in self.dataset_name:\n            label_list = glob.glob(json_file_path + '**/*.json', recursive=True)\n            self._label_list = label_list\n        elif 'once' in self.dataset_name:\n            label_list = glob.glob(json_file_path + '*/*/*.json', recursive=True)\n            self._label_list = []\n            for js_label_file in label_list:\n                if not os.path.getsize(js_label_file):\n                    continue\n                image_path = map_once_json2img(js_label_file)\n                if not os.path.exists(image_path):\n                    continue\n                self._label_list.append(js_label_file)\n        else: \n            raise ValueError(\"to use ApolloDataset for apollo\")\n        \n        if hasattr(self, '_label_list'):\n            self.n_samples = len(self._label_list)\n        else:\n            self.n_samples = self._label_image_path.shape[0]\n\n    def preprocess_data_from_json_once(self, idx_json_file):\n        _label_image_path = None\n        _label_cam_height = None\n        _label_cam_pitch = None\n        cam_extrinsics = None\n        cam_intrinsics = None\n        _label_laneline_org = None\n        _gt_laneline_category_org = None\n\n        image_path = map_once_json2img(idx_json_file)\n\n        assert ops.exists(image_path), '{:s} not exist'.format(image_path)\n        _label_image_path = image_path\n\n        with open(idx_json_file, 'r') as file:\n            file_lines = [line for line in file]\n            if len(file_lines) != 0:\n                info_dict = json.loads(file_lines[0])\n            else:\n                print('Empty label_file:', idx_json_file)\n                return\n\n            if not self.fix_cam:\n                cam_pitch = 0.3/180*np.pi\n                cam_height = 1.5\n                cam_extrinsics = np.array([[np.cos(cam_pitch), 0, -np.sin(cam_pitch), 0],\n                                            [0, 1, 0, 0],\n                                            [np.sin(cam_pitch), 0,  np.cos(cam_pitch), cam_height],\n                                            [0, 0, 0, 1]], dtype=float)\n                R_vg = np.array([[0, 1, 0],\n                                    [-1, 0, 0],\n                                    [0, 0, 1]], dtype=float)\n                R_gc = np.array([[1, 0, 0],\n                                    [0, 0, 1],\n                                    [0, -1, 0]], dtype=float)\n                cam_extrinsics[:3, :3] = np.matmul(np.matmul(\n                                            np.matmul(np.linalg.inv(R_vg), cam_extrinsics[:3, :3]),\n                                                R_vg), R_gc)\n                cam_extrinsics[0:2, 3] = 0.0\n\n                gt_cam_height = cam_extrinsics[2, 3] \n                gt_cam_pitch = 0\n\n                if 'calibration' in info_dict:\n                    cam_intrinsics = info_dict['calibration']\n                    cam_intrinsics = np.array(cam_intrinsics)\n                    cam_intrinsics = cam_intrinsics[:, :3]\n                else:\n                    cam_intrinsics = self.K\n\n            _label_cam_height = gt_cam_height\n            _label_cam_pitch = gt_cam_pitch\n\n            gt_lanes_packed = 
info_dict['lanes']\n            gt_lane_pts, gt_lane_visibility, gt_laneline_category = [], [], []\n            for i, gt_lane_packed in enumerate(gt_lanes_packed):\n                lane = np.array(gt_lane_packed).T\n\n                # Coordinate convertion for openlane_300 data\n                lane = np.vstack((lane, np.ones((1, lane.shape[1]))))\n                lane = np.matmul(cam_extrinsics, lane)\n\n                lane = lane[0:3, :].T\n                lane = lane[lane[:,1].argsort()] #TODO:make y mono increase\n                gt_lane_pts.append(lane)\n                gt_lane_visibility.append(1.0)\n                gt_laneline_category.append(1)\n\n        _gt_laneline_category_org = copy.deepcopy(np.array(gt_laneline_category))\n\n        if not self.fix_cam:\n            cam_K = cam_intrinsics\n            if 'openlane' in self.dataset_name or 'once' in self.dataset_name:\n                cam_E = cam_extrinsics\n                P_g2im = projection_g2im_extrinsic(cam_E, cam_K)\n                H_g2im = homograpthy_g2im_extrinsic(cam_E, cam_K)\n            else:\n                gt_cam_height = _label_cam_height\n                gt_cam_pitch = _label_cam_pitch\n                P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, cam_K)\n                H_g2im = homograpthy_g2im(gt_cam_pitch, gt_cam_height, cam_K)\n            H_im2g = np.linalg.inv(H_g2im)\n        else:\n            P_g2im = self.P_g2im\n            H_im2g = self.H_im2g\n        P_g2gflat = np.matmul(H_im2g, P_g2im)\n\n        gt_lanes = gt_lane_pts\n        gt_visibility = gt_lane_visibility\n        gt_category = gt_laneline_category\n\n        # prune gt lanes by visibility labels\n        gt_lanes = [prune_3d_lane_by_visibility(gt_lane, gt_visibility[k]).squeeze(0) for k, gt_lane in enumerate(gt_lanes)]\n        _label_laneline_org = copy.deepcopy(gt_lanes)\n        return _label_image_path, _label_cam_height, _label_cam_pitch, \\\n               cam_extrinsics, cam_intrinsics, \\\n               _label_laneline_org, \\\n               _gt_laneline_category_org, info_dict\n               #    _label_laneline, \\\n               #    _gt_laneline_visibility, _gt_laneline_category, \\\n\n    def preprocess_data_from_json_openlane(self, idx_json_file):\n        _label_image_path = None\n        _label_cam_height = None\n        _label_cam_pitch = None\n        cam_extrinsics = None\n        cam_intrinsics = None\n        # _label_laneline = None\n        _label_laneline_org = None\n        # _gt_laneline_visibility = None\n        # _gt_laneline_category = None\n        _gt_laneline_category_org = None\n        # _laneline_ass_id = None\n\n        with open(idx_json_file, 'r') as file:\n            file_lines = [line for line in file]\n            info_dict = json.loads(file_lines[0])\n\n            image_path = ops.join(self.dataset_base_dir, info_dict['file_path'])\n            assert ops.exists(image_path), '{:s} not exist'.format(image_path)\n            _label_image_path = image_path\n\n            if not self.fix_cam:\n                cam_extrinsics = np.array(info_dict['extrinsic'])\n                # Re-calculate extrinsic matrix based on ground coordinate\n                R_vg = np.array([[0, 1, 0],\n                                    [-1, 0, 0],\n                                    [0, 0, 1]], dtype=float)\n                R_gc = np.array([[1, 0, 0],\n                                    [0, 0, 1],\n                                    [0, -1, 0]], dtype=float)\n                cam_extrinsics[:3, :3] 
= np.matmul(np.matmul(\n                                            np.matmul(np.linalg.inv(R_vg), cam_extrinsics[:3, :3]),\n                                                R_vg), R_gc)\n                cam_extrinsics[0:2, 3] = 0.0\n                \n                # gt_cam_height = info_dict['cam_height']\n                gt_cam_height = cam_extrinsics[2, 3]\n                if 'cam_pitch' in info_dict:\n                    gt_cam_pitch = info_dict['cam_pitch']\n                else:\n                    gt_cam_pitch = 0\n\n                if 'intrinsic' in info_dict:\n                    cam_intrinsics = info_dict['intrinsic']\n                    cam_intrinsics = np.array(cam_intrinsics)\n                else:\n                    cam_intrinsics = self.K  \n\n            _label_cam_height = gt_cam_height\n            _label_cam_pitch = gt_cam_pitch\n\n            gt_lanes_packed = info_dict['lane_lines']\n            gt_lane_pts, gt_lane_visibility, gt_laneline_category = [], [], []\n            for i, gt_lane_packed in enumerate(gt_lanes_packed):\n                # A GT lane can be either 2D or 3D\n                # if a GT lane is 3D, the height is intact from 3D GT, so keep it intact here too\n                lane = np.array(gt_lane_packed['xyz'])\n                lane_visibility = np.array(gt_lane_packed['visibility'])\n\n                # Coordinate convertion for openlane_300 data\n                lane = np.vstack((lane, np.ones((1, lane.shape[1]))))\n                cam_representation = np.linalg.inv(\n                                        np.array([[0, 0, 1, 0],\n                                                    [-1, 0, 0, 0],\n                                                    [0, -1, 0, 0],\n                                                    [0, 0, 0, 1]], dtype=float))  # transformation from apollo camera to openlane camera\n                lane = np.matmul(cam_extrinsics, np.matmul(cam_representation, lane))\n\n                lane = lane[0:3, :].T\n                gt_lane_pts.append(lane)\n                gt_lane_visibility.append(lane_visibility)\n\n                if 'category' in gt_lane_packed:\n                    lane_cate = gt_lane_packed['category']\n                    if lane_cate == 21:  # merge left and right road edge into road edge\n                        lane_cate = 20\n                    gt_laneline_category.append(lane_cate)\n                else:\n                    gt_laneline_category.append(1)\n        \n        # _label_laneline_org = copy.deepcopy(gt_lane_pts)\n        _gt_laneline_category_org = copy.deepcopy(np.array(gt_laneline_category))\n\n        gt_lanes = gt_lane_pts\n        gt_visibility = gt_lane_visibility\n        gt_category = gt_laneline_category\n\n        # prune gt lanes by visibility labels\n        gt_lanes = [prune_3d_lane_by_visibility(gt_lane, gt_visibility[k]) for k, gt_lane in enumerate(gt_lanes)]\n        _label_laneline_org = copy.deepcopy(gt_lanes)\n\n        return _label_image_path, _label_cam_height, _label_cam_pitch, \\\n               cam_extrinsics, cam_intrinsics, \\\n               _label_laneline_org, \\\n               _gt_laneline_category_org, info_dict\n\n    def __len__(self):\n        \"\"\"\n        Conventional len method\n        \"\"\"\n        return self.n_samples\n\n    # new getitem, WIP\n    def WIP__getitem__(self, idx):\n        \"\"\"\n        Args: idx (int): Index in list to load image\n        \"\"\"\n        extra_dict = {}\n\n        idx_json_file = self._label_list[idx]\n        # 
preprocess data from json file\n        if 'openlane' in self.dataset_name:\n            _label_image_path, _label_cam_height, _label_cam_pitch, \\\n            cam_extrinsics, cam_intrinsics, \\\n            _label_laneline_org, \\\n            _gt_laneline_category_org, info_dict = self.preprocess_data_from_json_openlane(idx_json_file)\n        elif 'once' in self.dataset_name:\n            _label_image_path, _label_cam_height, _label_cam_pitch, \\\n            cam_extrinsics, cam_intrinsics, \\\n            _label_laneline_org, \\\n            _gt_laneline_category_org, info_dict = self.preprocess_data_from_json_once(idx_json_file)\n\n        # fetch camera height and pitch\n        if not self.fix_cam:\n            gt_cam_height = _label_cam_height\n            gt_cam_pitch = _label_cam_pitch\n            if 'openlane' in self.dataset_name or 'once' in self.dataset_name:\n                intrinsics = cam_intrinsics\n                extrinsics = cam_extrinsics\n            else:\n                # should not be used\n                intrinsics = self.K\n                extrinsics = np.zeros((3,4))\n                extrinsics[2,3] = gt_cam_height\n        else:\n            gt_cam_height = self.cam_height\n            gt_cam_pitch = self.cam_pitch\n            # should not be used\n            intrinsics = self.K\n            extrinsics = np.zeros((3,4))\n            extrinsics[2,3] = gt_cam_height\n\n        img_name = _label_image_path\n        with open(img_name, 'rb') as f:\n            image = (Image.open(f).convert('RGB'))\n\n        # image preprocess with crop and resize\n        image = F.crop(image, self.h_crop, 0, self.h_org-self.h_crop, self.w_org)\n        image = F.resize(image, size=(self.h_net, self.w_net), interpolation=InterpolationMode.BILINEAR)\n\n        gt_category_2d = _gt_laneline_category_org\n        if self.data_aug:\n            img_rot, aug_mat = data_aug_rotate(image)\n            if self.photo_aug:\n                img_rot = self.photo_aug(\n                    dict(img=img_rot.copy().astype(np.float32))\n                )['img']\n            image = Image.fromarray(\n                np.clip(img_rot, 0, 255).astype(np.uint8))\n        image = self.totensor(image).float()\n        image = self.normalize(image)\n        intrinsics = torch.from_numpy(intrinsics)\n        extrinsics = torch.from_numpy(extrinsics)\n\n        # prepare binary segmentation label map\n        seg_label = np.zeros((self.h_net, self.w_net), dtype=np.int8)\n        # seg idx has the same order as gt_lanes\n        seg_idx_label = np.zeros((self.max_lanes, self.h_net, self.w_net), dtype=np.uint8)\n        ground_lanes = np.zeros((self.max_lanes, self.anchor_dim), dtype=np.float32)\n        ground_lanes_dense = np.zeros(\n            (self.max_lanes, self.num_y_steps_dense * 3), dtype=np.float32)\n        gt_lanes = _label_laneline_org # ground\n        gt_laneline_img = [[0]] * len(gt_lanes)\n\n        H_g2im, P_g2im, H_crop = self.transform_mats_impl(cam_extrinsics, \\\n                                            cam_intrinsics, _label_cam_pitch, _label_cam_height)\n        M = np.matmul(H_crop, P_g2im)\n        # update transformation with image augmentation\n        if self.data_aug:\n            M = np.matmul(aug_mat, M)\n\n        lidar2img = np.eye(4).astype(np.float32)\n        lidar2img[:3] = M\n\n        SEG_WIDTH = 80\n        thickness = int(SEG_WIDTH / 2650 * self.h_net / 2)\n\n        for i, lane in enumerate(gt_lanes):\n            if i >= self.max_lanes:\n                
                break\n\n            if lane.shape[0] <= 2:\n                continue\n\n            if _gt_laneline_category_org[i] >= self.num_category:\n                continue\n\n            xs, zs = resample_laneline_in_y(lane, self.anchor_y_steps)\n            vis = np.logical_and(\n                self.anchor_y_steps > lane[:, 1].min() - 5,\n                self.anchor_y_steps < lane[:, 1].max() + 5)\n\n            ground_lanes[i][0: self.num_y_steps] = xs\n            ground_lanes[i][self.num_y_steps:2*self.num_y_steps] = zs\n            ground_lanes[i][2*self.num_y_steps:3*self.num_y_steps] = vis * 1.0\n            ground_lanes[i][self.anchor_dim - self.num_category] = 0.0\n            ground_lanes[i][self.anchor_dim - self.num_category + _gt_laneline_category_org[i]] = 1.0\n\n            xs_dense, zs_dense = resample_laneline_in_y(\n                lane, self.anchor_y_steps_dense)\n            vis_dense = np.logical_and(\n                self.anchor_y_steps_dense > lane[:, 1].min(),\n                self.anchor_y_steps_dense < lane[:, 1].max())\n            ground_lanes_dense[i][0: self.num_y_steps_dense] = xs_dense\n            ground_lanes_dense[i][1*self.num_y_steps_dense: 2*self.num_y_steps_dense] = zs_dense\n            ground_lanes_dense[i][2*self.num_y_steps_dense: 3*self.num_y_steps_dense] = vis_dense * 1.0\n\n            x_2d, y_2d = projective_transformation(M, lane[:, 0],\n                                                   lane[:, 1], lane[:, 2])\n            gt_laneline_img[i] = np.array([x_2d, y_2d]).T.tolist()\n            for j in range(len(x_2d) - 1):\n                seg_label = cv2.line(seg_label,\n                                     (int(x_2d[j]), int(y_2d[j])), (int(x_2d[j+1]), int(y_2d[j+1])),\n                                     color=1,  # binary mask value; np.asscalar was removed from recent NumPy\n                                     thickness=thickness)\n                seg_idx_label[i] = cv2.line(\n                    seg_idx_label[i],\n                    (int(x_2d[j]), int(y_2d[j])), (int(x_2d[j+1]), int(y_2d[j+1])),\n                    color=gt_category_2d[i].item(),\n                    thickness=thickness)\n\n        seg_label = torch.from_numpy(seg_label.astype(np.float32))\n        seg_label.unsqueeze_(0)\n        extra_dict['seg_label'] = seg_label\n        extra_dict['seg_idx_label'] = seg_idx_label\n        extra_dict['ground_lanes'] = ground_lanes\n        extra_dict['ground_lanes_dense'] = ground_lanes_dense\n        extra_dict['lidar2img'] = lidar2img\n        extra_dict['pad_shape'] = torch.Tensor(seg_idx_label.shape[-2:]).float()\n        extra_dict['idx_json_file'] = idx_json_file\n        extra_dict['image'] = image\n        if self.data_aug:\n            aug_mat = torch.from_numpy(aug_mat.astype(np.float32))\n            extra_dict['aug_mat'] = aug_mat\n        return extra_dict\n\n    # __getitem__ delegates to the (workable) WIP implementation above\n    def __getitem__(self, idx):\n        \"\"\"\n        Args: idx (int): Index in list to load image\n        \"\"\"\n        return self.WIP__getitem__(idx)\n\n    def transform_mats_impl(self, cam_extrinsics, cam_intrinsics, cam_pitch, cam_height):\n        if not self.fix_cam:\n            if 'openlane' in self.dataset_name or 'once' in self.dataset_name:\n                H_g2im = homograpthy_g2im_extrinsic(cam_extrinsics, cam_intrinsics)\n                P_g2im = projection_g2im_extrinsic(cam_extrinsics, cam_intrinsics)\n            else:\n                H_g2im = homograpthy_g2im(cam_pitch, cam_height, self.K)\n                P_g2im = projection_g2im(cam_pitch, 
cam_height, self.K)\n            return H_g2im, P_g2im, self.H_crop\n        else:\n            return self.H_g2im, self.P_g2im, self.H_crop\n\ndef make_lane_y_mono_inc(lane):\n    \"\"\"\n        Due to the loss of the height dim, lanes projected to the flat ground plane may not have monotonically increasing y.\n        This function traces the points with monotonically increasing y and outputs the pruned lane.\n    :param lane:\n    :return:\n    \"\"\"\n    idx2del = []\n    max_y = lane[0, 1]\n    for i in range(1, lane.shape[0]):\n        # hard-coded smallest step so the far-away, near-horizontal tail can be pruned\n        if lane[i, 1] <= max_y + 3:\n            idx2del.append(i)\n        else:\n            max_y = lane[i, 1]\n    lane = np.delete(lane, idx2del, 0)\n    return lane\n\ndef data_aug_rotate(img):\n    # assume img in PIL image format\n    # note: cv2.getRotationMatrix2D takes the angle in degrees, so sampling in\n    # [-np.pi/18, np.pi/18] yields a tiny rotation of roughly +/-0.17 degree\n    rot = random.uniform(-np.pi/18, np.pi/18)\n    center_x = img.width / 2\n    center_y = img.height / 2\n    rot_mat = cv2.getRotationMatrix2D((center_x, center_y), rot, 1.0)\n    img_rot = np.array(img)\n    img_rot = cv2.warpAffine(img_rot, rot_mat, (img.width, img.height), flags=cv2.INTER_LINEAR)\n    rot_mat = np.vstack([rot_mat, [0, 0, 1]])\n    return img_rot, rot_mat\n\n\ndef seed_worker(worker_id):\n    worker_seed = torch.initial_seed() % 2**32\n    np.random.seed(worker_seed)\n    random.seed(worker_seed)\n\n\ndef get_loader(transformed_dataset, args):\n    \"\"\"\n        create dataset from ground-truth\n        return a batch sampler based on the dataset\n    \"\"\"\n\n    # transformed_dataset = LaneDataset(dataset_base_dir, json_file_path, args)\n    sample_idx = range(transformed_dataset.n_samples)\n\n    g = torch.Generator()\n    g.manual_seed(0)\n\n    discarded_sample_start = len(sample_idx) // args.batch_size * args.batch_size\n    if is_main_process():\n        print(\"Discarding images:\")\n        if hasattr(transformed_dataset, '_label_image_path'):\n            print(transformed_dataset._label_image_path[discarded_sample_start: len(sample_idx)])\n        else:\n            print(len(sample_idx) - discarded_sample_start)\n    sample_idx = sample_idx[0 : discarded_sample_start]\n    \n    if args.dist:\n        if is_main_process():\n            print('use distributed sampler')\n        if 'standard' in args.dataset_name or 'rare_subset' in args.dataset_name or 'illus_chg' in args.dataset_name:\n            data_sampler = torch.utils.data.distributed.DistributedSampler(transformed_dataset, shuffle=True, drop_last=True)\n            data_loader = DataLoader(transformed_dataset,\n                                        batch_size=args.batch_size, \n                                        sampler=data_sampler,\n                                        num_workers=args.nworkers, \n                                        pin_memory=True,\n                                        persistent_workers=args.nworkers > 0,\n                                        worker_init_fn=seed_worker,\n                                        generator=g,\n                                        drop_last=True)\n        else:\n            data_sampler = torch.utils.data.distributed.DistributedSampler(transformed_dataset)\n            data_loader = DataLoader(transformed_dataset,\n                                        batch_size=args.batch_size, \n                                        sampler=data_sampler,\n                                        num_workers=args.nworkers, \n                                        pin_memory=True,\n                    
                    persistent_workers=args.nworkers > 0,\n                                        worker_init_fn=seed_worker,\n                                        generator=g)\n    else:\n        if is_main_process():\n            print(\"use default sampler\")\n        data_sampler = torch.utils.data.sampler.SubsetRandomSampler(sample_idx)\n        data_loader = DataLoader(transformed_dataset,\n                                batch_size=args.batch_size, sampler=data_sampler,\n                                num_workers=args.nworkers, pin_memory=True,\n                                persistent_workers=args.nworkers > 0,\n                                worker_init_fn=seed_worker,\n                                generator=g)\n\n    if args.dist:\n        return data_loader, data_sampler\n    return data_loader\n\ndef map_once_json2img(json_label_file):\n    if 'train' in json_label_file:\n        split_name = 'train'\n    elif 'val' in json_label_file:\n        split_name = 'val'\n    elif 'test' in json_label_file:\n        split_name = 'test'\n    else:\n        raise ValueError(\"train/val/test not in the json path\")\n    image_path = json_label_file.replace(split_name, 'data').replace('.json', '.jpg')\n    return image_path\n"
  },
  {
    "path": "data/__init__.py",
    "content": ""
  },
  {
    "path": "data/apollo_dataset.py",
    "content": "# ==============================================================================\n# Copyright (c) 2022 The PersFormer Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nimport re\nimport os\nimport sys\nimport copy\nimport json\nimport glob\nimport random\nimport pickle\nimport warnings\nfrom pathlib import Path\nimport numpy as np\nfrom numpy import int32#, result_type\nimport cv2\nfrom PIL import Image\nfrom torch.utils.data import Dataset, DataLoader\nfrom torchvision import transforms\nimport torchvision.transforms.functional as F\nfrom torchvision.transforms import InterpolationMode\nfrom utils.utils import *\nsys.path.append('./')\nwarnings.simplefilter('ignore', np.RankWarning)\nmatplotlib.use('Agg')\n\nfrom .transform import PhotoMetricDistortionMultiViewImage \n\nfrom tqdm import tqdm\n\n\nclass ApolloLaneDataset(Dataset):\n    def __init__(self, dataset_base_dir, json_file_path, args, data_aug=False, **kwargs):\n        # define image pre-processor\n        self.totensor = transforms.ToTensor()\n        # expect same mean/std for all torchvision models\n        mean = [0.485, 0.456, 0.406] if args.mean is None else args.mean\n        std = [0.229, 0.224, 0.225] if args.std is None else args.std\n        self.normalize = transforms.Normalize(mean, std)\n\n        self.data_aug = data_aug\n        if data_aug:\n            if hasattr(args, 'photo_aug'):\n                self.photo_aug = PhotoMetricDistortionMultiViewImage(**args.photo_aug)\n            else:\n                self.photo_aug = False\n        \n        self.dataset_base_dir = dataset_base_dir\n        self.json_file_path = json_file_path\n\n        # dataset parameters\n        self.dataset_name = args.dataset_name\n        self.num_category = args.num_category\n\n        self.h_org = args.org_h\n        self.w_org = args.org_w\n        self.h_crop = args.crop_y\n\n        # parameters related to service network\n        self.h_net = args.resize_h\n        self.w_net = args.resize_w\n        self.ipm_h = args.ipm_h\n        self.ipm_w = args.ipm_w\n        self.u_ratio = float(self.w_net) / float(self.w_org)\n        self.v_ratio = float(self.h_net) / float(self.h_org - self.h_crop)\n        self.top_view_region = args.top_view_region\n        \n        self.max_lanes = args.max_lanes\n\n        self.K = args.K\n        self.H_crop = homography_crop_resize([args.org_h, args.org_w], args.crop_y, [args.resize_h, args.resize_w])\n        self.fix_cam = False\n        \n        self.x_min, self.x_max = self.top_view_region[0, 0], self.top_view_region[1, 0]\n        self.y_min, self.y_max = self.top_view_region[2, 1], self.top_view_region[0, 1]\n        \n        self.anchor_y_steps = args.anchor_y_steps\n        self.num_y_steps = len(self.anchor_y_steps)\n\n        self.anchor_y_steps_dense = args.get(\n            'anchor_y_steps_dense',\n            np.linspace(3, 103, 200))\n        
        args.anchor_y_steps_dense = self.anchor_y_steps_dense\n        self.num_y_steps_dense = len(self.anchor_y_steps_dense)\n\n        self.anchor_dim = 3 * self.num_y_steps + args.num_category\n\n        self.save_json_path = args.save_json_path\n\n        # parse ground-truth file\n        self.processed_info_dict = None\n\n        self.label_list = self.gen_single_file_json()\n        self.n_samples = len(self.label_list)\n        self.processed_info_dict = self.init_dataset_3D(dataset_base_dir, json_file_path)\n\n    def gen_single_file_json(self):\n        \"\"\"Split the aggregate ground-truth json (one sample per line) into per-image json files and return their paths; the list is cached next to the source json.\"\"\"\n        gt_labels_json_dict = [json.loads(line) for line in open(self.json_file_path, 'r').readlines()]\n        \n        # e.g., xxx/standard/train\n        json_save_dir = self.json_file_path.split('.json')[0]\n\n        mkdir_if_missing(json_save_dir)\n        \n        label_list_path = self.json_file_path.rsplit('/', 1)[0] + '/%s_json_list.txt' % self.json_file_path.rsplit('/', 1)[-1].split('.json')[0]\n\n        if os.path.isfile(label_list_path):\n            with open(label_list_path, 'r') as f:\n                label_list = f.readlines()\n                label_list = list(map(lambda x: os.path.join(json_save_dir, x.strip()), label_list))\n        else:\n            label_list = []\n            \n            for single_info in tqdm(gt_labels_json_dict):\n                img_p = Path(single_info['raw_file'])\n                json_dir = os.path.join(json_save_dir, img_p.parent.name)\n                mkdir_if_missing(json_dir)\n                json_p = os.path.join(json_dir, img_p.stem + '.json')\n                single_info['file_path'] = json_p.split(json_save_dir + '/')[-1]\n                json.dump(single_info, open(json_p, 'w'), separators=(',', ': '), indent=4)\n                label_list.append(json_p)\n        \n            with open(label_list_path, 'w') as f:\n                for label_js in label_list:\n                    f.write(label_js)\n                    f.write('\\n')\n        \n        return label_list\n    \n        \n    def parse_processed_info_dict_apollo(self, idx):\n        keys = self.processed_info_dict.keys()\n        keys = list(keys)\n        res = []\n        for k in keys:\n            res.append(self.processed_info_dict[k][idx])\n        return res\n\n\n    def __len__(self):\n        \"\"\"\n        Conventional len method\n        \"\"\"\n        return len(self.label_list)\n\n    # new getitem, WIP\n    def WIP__getitem__(self, idx):\n        \"\"\"\n        Args: idx (int): Index in list to load image\n        \"\"\"\n        \n        # preprocess data from json file\n\n        _label_image_path, _label_cam_height, _label_cam_pitch, \\\n        cam_extrinsics, cam_intrinsics, \\\n        _label_laneline, _label_laneline_org, \\\n        _gt_laneline_visibility, _gt_laneline_category, \\\n        _gt_laneline_category_org, gt_laneline_img = self.parse_processed_info_dict_apollo(idx)\n        \n        if not self.fix_cam:\n            gt_cam_height = _label_cam_height\n            gt_cam_pitch = _label_cam_pitch\n            intrinsics = cam_intrinsics\n            extrinsics = cam_extrinsics\n        else:\n            raise ValueError('fix_cam=True is not supported in this release; train with fix_cam=False')\n        img_name = _label_image_path\n\n        with open(img_name, 'rb') as f:\n            image = (Image.open(f).convert('RGB'))\n\n        # image preprocess with crop and resize\n        image = F.crop(image, self.h_crop, 0, self.h_org-self.h_crop, self.w_org)\n        image = F.resize(image, size=(self.h_net, 
self.w_net), interpolation=InterpolationMode.BILINEAR)\n\n        gt_category_2d = _gt_laneline_category_org\n\n        if self.data_aug:\n            img_rot, aug_mat = data_aug_rotate(image)\n            if self.photo_aug:\n                img_rot = self.photo_aug(\n                    dict(img=img_rot.copy().astype(np.float32))\n                )['img']\n            \n            image = Image.fromarray(np.clip(img_rot, 0, 255).astype(np.uint8))\n        image = self.totensor(image).float()\n        image = self.normalize(image)\n        gt_cam_height = torch.tensor(gt_cam_height, dtype=torch.float32)\n        gt_cam_pitch = torch.tensor(gt_cam_pitch, dtype=torch.float32)\n        intrinsics = torch.from_numpy(intrinsics)\n        extrinsics = torch.from_numpy(extrinsics)\n\n        # prepare binary segmentation label map\n        seg_label = np.zeros((self.h_net, self.w_net), dtype=np.uint8)\n        seg_idx_label = np.zeros((self.max_lanes, self.h_net, self.w_net), dtype=np.uint8)\n        ground_lanes = np.zeros((self.max_lanes, self.anchor_dim), dtype=np.float32)\n        ground_lanes_dense = np.zeros(\n            (self.max_lanes, self.num_y_steps_dense * 3), dtype=np.float32)\n        \n        gt_lanes = _label_laneline_org\n        H_g2im, P_g2im, H_crop = self.transform_mats_impl(_label_cam_pitch, \n                                                                    _label_cam_height)\n        M = np.matmul(H_crop, P_g2im)\n        # update transformation with image augmentation\n        if self.data_aug:\n            M = np.matmul(aug_mat, M)\n        \n        lidar2img = np.eye(4).astype(np.float32)\n        lidar2img[:3] = M\n            \n        SEG_WIDTH = 80\n        thickness_st = int(SEG_WIDTH / 2550 * self.h_net / 2)\n\n        for i, lane in enumerate(gt_lanes):\n            if i >= self.max_lanes:\n                break\n\n            # TODO remove this\n            if lane.shape[0] < 2:\n                continue\n\n            if _gt_laneline_category_org[i] > self.num_category:\n                continue\n\n            xs, zs = resample_laneline_in_y(lane, self.anchor_y_steps)\n            vis = np.logical_and(\n                self.anchor_y_steps > lane[:, 1].min() - 5,\n                self.anchor_y_steps < lane[:, 1].max() + 5)\n\n            ground_lanes[i][0: self.num_y_steps] = xs\n            ground_lanes[i][self.num_y_steps:2*self.num_y_steps] = zs\n            ground_lanes[i][2*self.num_y_steps:3*self.num_y_steps] = vis * 1.0\n            ground_lanes[i][self.anchor_dim - self.num_category] = 0.0\n            ground_lanes[i][self.anchor_dim - self.num_category + 1] = 1.0\n\n            xs_dense, zs_dense = resample_laneline_in_y(\n                lane, self.anchor_y_steps_dense)\n            vis_dense = np.logical_and(\n                self.anchor_y_steps_dense > lane[:, 1].min(),\n                self.anchor_y_steps_dense < lane[:, 1].max())\n            ground_lanes_dense[i][0: self.num_y_steps_dense] = xs_dense\n            ground_lanes_dense[i][1*self.num_y_steps_dense: 2*self.num_y_steps_dense] = zs_dense\n            ground_lanes_dense[i][2*self.num_y_steps_dense: 3*self.num_y_steps_dense] = vis_dense * 1.0\n\n            x_2d, y_2d = projective_transformation(M, \n                                                   lane[:, 0],\n                                                   lane[:, 1], \n                                                   lane[:, 2])\n            \n            for j in range(len(x_2d) - 1):\n                # empirical setting.\n 
               # taper the stroke width with the image row: far-away points\n                # (small y_2d) get thinner lines, clamped to at least 2 px\n                k = 2.7e-2 - ((2.5e-2 - 5e-5) / 600) * y_2d[j]\n                thickness = max(round(thickness_st - k * (self.h_net - y_2d[j])), 2)\n                if thickness >= 6:\n                    thickness += 1\n\n                seg_label = cv2.line(seg_label,\n                                    (int(x_2d[j]), int(y_2d[j])), \n                                    (int(x_2d[j+1]), int(y_2d[j+1])),\n                                    color=1,\n                                    thickness=thickness)\n                seg_idx_label[i] = cv2.line(seg_idx_label[i],\n                                        (int(x_2d[j]), int(y_2d[j])),\n                                        (int(x_2d[j+1]), int(y_2d[j+1])),\n                                        color=gt_category_2d[i].item(),\n                                        thickness=thickness,\n                                        lineType=cv2.LINE_AA\n                                        )\n\n        seg_label = torch.from_numpy(seg_label.astype(np.float32))\n        seg_label.unsqueeze_(0)\n        \n        extra_dict = {}\n        \n        extra_dict['seg_label'] = seg_label\n        extra_dict['seg_idx_label'] = seg_idx_label\n        extra_dict['ground_lanes'] = ground_lanes\n        extra_dict['ground_lanes_dense'] = ground_lanes_dense\n        extra_dict['lidar2img'] = lidar2img\n        extra_dict['pad_shape'] = torch.Tensor(seg_idx_label.shape[-2:]).float()\n        extra_dict['idx_json_file'] = self.label_list[idx]\n\n        extra_dict['image'] = image\n        if self.data_aug:\n            aug_mat = torch.from_numpy(aug_mat.astype(np.float32))\n            extra_dict['aug_mat'] = aug_mat\n        \n        extra_dict['cam_extrinsics'] = cam_extrinsics\n        extra_dict['cam_intrinsics'] = cam_intrinsics\n        return extra_dict\n\n    # __getitem__ delegates to the (workable) WIP implementation above\n    def __getitem__(self, idx):\n        \"\"\"\n        Args: idx (int): Index in list to load image\n        \"\"\"\n        return self.WIP__getitem__(idx)\n\n    def init_dataset_3D(self, dataset_base_dir, json_file_path):\n        \"\"\"\n        :param dataset_info_file:\n        :return: image paths, labels in unnormalized net input coordinates\n\n        data processing:\n        ground-truth label maps are scaled w.r.t. the network input size\n        \"\"\"\n\n        # load image path, and lane pts\n        label_image_path = []\n        gt_laneline_pts_all = []\n        gt_centerline_pts_all = []\n        gt_laneline_visibility_all = []\n        gt_centerline_visibility_all = []\n        gt_laneline_category_all = []\n        gt_cam_height_all = []\n        gt_cam_pitch_all = []\n\n        assert ops.exists(json_file_path), '{:s} does not exist'.format(json_file_path)\n        \n        with open(json_file_path, 'r') as file:\n            for idx, line in enumerate(file):\n                \n                info_dict = json.loads(line)\n                # print('load json : %s | %s' % (idx, info_dict['raw_file']))\n                image_path = ops.join(dataset_base_dir, info_dict['raw_file'])\n                assert ops.exists(image_path), '{:s} does not exist'.format(image_path)\n\n                label_image_path.append(image_path)\n\n                gt_lane_pts = info_dict['laneLines']\n                gt_lane_visibility = info_dict['laneLines_visibility']\n                for i, lane in enumerate(gt_lane_pts):\n                    lane = np.array(lane)\n                    gt_lane_pts[i] = lane\n                    gt_lane_visibility[i] = 
np.array(gt_lane_visibility[i])\n                \n                gt_laneline_pts_all.append(gt_lane_pts)\n                gt_laneline_visibility_all.append(gt_lane_visibility)\n                \n                if 'category' in info_dict:\n                    gt_laneline_category = info_dict['category']\n                    gt_laneline_category_all.append(np.array(gt_laneline_category, dtype=np.int32))\n                else:\n                    gt_laneline_category_all.append(np.ones(len(gt_lane_pts), dtype=np.int32))\n\n                if not self.fix_cam:\n                    gt_cam_height = info_dict['cam_height']\n                    gt_cam_height_all.append(gt_cam_height)\n                    gt_cam_pitch = info_dict['cam_pitch']\n                    gt_cam_pitch_all.append(gt_cam_pitch)\n        \n        label_image_path = np.array(label_image_path)\n        gt_cam_height_all = np.array(gt_cam_height_all)\n        gt_cam_pitch_all = np.array(gt_cam_pitch_all)\n        gt_laneline_pts_all_org = copy.deepcopy(gt_laneline_pts_all)\n        gt_laneline_category_all_org = copy.deepcopy(gt_laneline_category_all)\n        \n        visibility_all_flat = []\n        gt_laneline_im_all = []\n        gt_centerline_im_all = []\n        cam_extrinsics_all = []\n        cam_intrinsics_all = []\n        for idx in range(len(gt_laneline_pts_all)):\n            # fetch camera height and pitch\n            gt_cam_height = gt_cam_height_all[idx]\n            gt_cam_pitch = gt_cam_pitch_all[idx]\n            if not self.fix_cam:\n                P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, self.K)\n                H_g2im = homograpthy_g2im(gt_cam_pitch, gt_cam_height, self.K)\n                H_im2g = np.linalg.inv(H_g2im)\n            else:\n                P_g2im = self.P_g2im\n                H_im2g = self.H_im2g\n\n            gt_lanes = gt_laneline_pts_all[idx]\n            gt_visibility = gt_laneline_visibility_all[idx]\n\n            # prune gt lanes by visibility labels\n            gt_lanes = [prune_3d_lane_by_visibility(gt_lane, gt_visibility[k]) for k, gt_lane in enumerate(gt_lanes)]\n            gt_laneline_pts_all_org[idx] = gt_lanes\n            \n            # project gt laneline to image plane\n            gt_laneline_im = []\n            for gt_lane in gt_lanes:\n                x_vals, y_vals = projective_transformation(P_g2im, gt_lane[:,0], gt_lane[:,1], gt_lane[:,2])\n                gt_laneline_im_oneline = np.array([x_vals, y_vals]).T.tolist()\n                gt_laneline_im.append(gt_laneline_im_oneline)\n            gt_laneline_im_all.append(gt_laneline_im)\n\n            # generate ex/in from apollo\n            cam_intrinsics = self.K\n            cam_extrinsics = np.zeros((4,4))\n            cam_extrinsics[-1, -1] = 1\n            cam_extrinsics[2,3] = gt_cam_height\n            cam_extrinsics_all.append(cam_extrinsics)\n            cam_intrinsics_all.append(cam_intrinsics)\n        \n        visibility_all_flat = np.array(visibility_all_flat)\n        \n        processed_info_dict = {}\n        processed_info_dict['label_json_path'] = label_image_path\n        \n        processed_info_dict['gt_cam_height_all'] = gt_cam_height_all\n        processed_info_dict['gt_cam_pitch_all'] = gt_cam_pitch_all\n        processed_info_dict['cam_extrinsics'] = cam_extrinsics_all\n        processed_info_dict['cam_intrinsics'] = cam_intrinsics_all\n\n        processed_info_dict['gt_laneline_pts_all'] = gt_laneline_pts_all\n        processed_info_dict['gt_laneline_pts_all_org'] = 
gt_laneline_pts_all_org\n        processed_info_dict['gt_laneline_visibility_all'] = gt_laneline_visibility_all\n\n        processed_info_dict['gt_laneline_category_all'] = gt_laneline_category_all\n        processed_info_dict['gt_laneline_category_all_org'] = gt_laneline_category_all_org\n        processed_info_dict['gt_laneline_im_all'] = gt_laneline_im_all\n        return processed_info_dict\n\n    def transform_mats_impl(self, cam_pitch, cam_height):\n        if not self.fix_cam:\n            H_g2im = homograpthy_g2im(cam_pitch, cam_height, self.K)\n            P_g2im = projection_g2im(cam_pitch, cam_height, self.K)\n            return H_g2im, P_g2im, self.H_crop\n        else:\n            return self.H_g2im, self.P_g2im, self.H_crop\n\ndef data_aug_rotate(img):\n    # assume img in PIL image format\n    # note: cv2.getRotationMatrix2D takes the angle in degrees, so sampling in\n    # [-np.pi/18, np.pi/18] yields a tiny rotation of roughly +/-0.17 degree\n    rot = random.uniform(-np.pi/18, np.pi/18)\n    center_x = img.width / 2\n    center_y = img.height / 2\n    rot_mat = cv2.getRotationMatrix2D((center_x, center_y), rot, 1.0)\n    img_rot = np.array(img)\n    img_rot = cv2.warpAffine(img_rot, rot_mat, (img.width, img.height), flags=cv2.INTER_LINEAR)\n    rot_mat = np.vstack([rot_mat, [0, 0, 1]])\n    return img_rot, rot_mat\n\n\ndef seed_worker(worker_id):\n    worker_seed = torch.initial_seed() % 2**32\n    np.random.seed(worker_seed)\n    random.seed(worker_seed)\n\n\ndef get_loader(transformed_dataset, args):\n    \"\"\"\n        create dataset from ground-truth\n        return a batch sampler based on the dataset\n    \"\"\"\n\n    sample_idx = range(transformed_dataset.n_samples)\n\n    g = torch.Generator()\n    g.manual_seed(0)\n\n    discarded_sample_start = len(sample_idx) // args.batch_size * args.batch_size\n    if args.proc_id == 0:\n        print(\"Discarding images:\")\n        if hasattr(transformed_dataset, '_label_image_path'):\n            print(transformed_dataset._label_image_path[discarded_sample_start: len(sample_idx)])\n        else:\n            print(len(sample_idx) - discarded_sample_start)\n    sample_idx = sample_idx[0 : discarded_sample_start]\n    \n    if args.dist:\n        if args.proc_id == 0:\n            print('use distributed sampler')\n        if 'standard' in args.dataset_name or 'rare_subset' in args.dataset_name or 'illus_chg' in args.dataset_name:\n            data_sampler = torch.utils.data.distributed.DistributedSampler(transformed_dataset, shuffle=True, drop_last=True)\n            data_loader = DataLoader(transformed_dataset,\n                                        batch_size=args.batch_size, \n                                        sampler=data_sampler,\n                                        num_workers=args.nworkers, \n                                        pin_memory=True,\n                                        persistent_workers=True,\n                                        worker_init_fn=seed_worker,\n                                        generator=g,\n                                        drop_last=True)\n        else:\n            data_sampler = torch.utils.data.distributed.DistributedSampler(transformed_dataset)\n            data_loader = DataLoader(transformed_dataset,\n                                        batch_size=args.batch_size, \n                                        sampler=data_sampler,\n                                        num_workers=args.nworkers, \n                                        pin_memory=True,\n                                        persistent_workers=True,\n                                        
worker_init_fn=seed_worker,\n                                        generator=g)\n    else:\n        if args.proc_id == 0:\n            print(\"use default sampler\")\n        data_sampler = torch.utils.data.sampler.SubsetRandomSampler(sample_idx)\n        data_loader = DataLoader(transformed_dataset,\n                                batch_size=args.batch_size, sampler=data_sampler,\n                                num_workers=args.nworkers, pin_memory=True,\n                                persistent_workers=True,\n                                worker_init_fn=seed_worker,\n                                generator=g)\n\n    if args.dist:\n        return data_loader, data_sampler\n    return data_loader\n"
  },
  {
    "path": "data/transform.py",
    "content": "import numpy as np\nimport mmcv\nimport torch\nimport torch.nn.functional as F\nimport PIL\nimport random\n\n\ndef get_random_state() -> np.random.RandomState:\n    return np.random.RandomState(random.randint(0, (1 << 32) - 1))\n\n\ndef normal(\n    loc=0.0,\n    scale=1.0,\n    size=None,\n    random_state=None,\n):\n    if random_state is None:\n        random_state = get_random_state()\n    return random_state.normal(loc, scale, size)\n\n\nclass PhotoMetricDistortionMultiViewImage:\n    \"\"\"Apply photometric distortion to image sequentially, every transformation\n    is applied with a probability of 0.5. The position of random contrast is in\n    second or second to last.\n    1. random brightness\n    2. random contrast (mode 0)\n    3. convert color from BGR to HSV\n    4. random saturation\n    5. random hue\n    6. convert color from HSV to BGR\n    7. random contrast (mode 1)\n    8. randomly swap channels\n    Args:\n        brightness_delta (int): delta of brightness.\n        contrast_range (tuple): range of contrast.\n        saturation_range (tuple): range of saturation.\n        hue_delta (int): delta of hue.\n    \"\"\"\n\n    def __init__(self,\n                 brightness_delta=32,\n                 contrast_range=(0.5, 1.5),\n                 saturation_range=(0.5, 1.5),\n                 hue_delta=18):\n        self.brightness_delta = brightness_delta\n        self.contrast_lower, self.contrast_upper = contrast_range\n        self.saturation_lower, self.saturation_upper = saturation_range\n        self.hue_delta = hue_delta\n\n    def __call__(self, results):\n        \"\"\"Call function to perform photometric distortion on images.\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Result dict with images distorted.\n        \"\"\"\n        imgs = results['img']\n        if not isinstance(imgs, list):\n            imgs = [imgs]\n\n        new_imgs = []\n        for img in imgs:\n            assert img.dtype == np.float32, \\\n                'PhotoMetricDistortion needs the input image of dtype np.float32,'\\\n                ' please set \"to_float32=True\" in \"LoadImageFromFile\" pipeline'\n            # random brightness\n            if np.random.randint(2):\n                delta = np.random.uniform(-self.brightness_delta,\n                                    self.brightness_delta)\n                img += delta\n\n            # mode == 0 --> do random contrast first\n            # mode == 1 --> do random contrast last\n            mode = np.random.randint(2)\n            if mode == 1:\n                if np.random.randint(2):\n                    alpha = np.random.uniform(self.contrast_lower,\n                                        self.contrast_upper)\n                    img *= alpha\n\n            # convert color from BGR to HSV\n            img = mmcv.bgr2hsv(img)\n\n            # random saturation\n            if np.random.randint(2):\n                img[..., 1] *= np.random.uniform(self.saturation_lower,\n                                            self.saturation_upper)\n\n            # random hue\n            if np.random.randint(2):\n                img[..., 0] += np.random.uniform(-self.hue_delta, self.hue_delta)\n                img[..., 0][img[..., 0] > 360] -= 360\n                img[..., 0][img[..., 0] < 0] += 360\n\n            # convert color from HSV to BGR\n            img = mmcv.hsv2bgr(img)\n\n            # random contrast\n            if mode == 0:\n            
    if np.random.randint(2):\n                    alpha = np.random.uniform(self.contrast_lower,\n                                        self.contrast_upper)\n                    img *= alpha\n\n            # randomly swap channels\n            if np.random.randint(2):\n                img = img[..., np.random.permutation(3)]\n            new_imgs.append(img)\n        if not isinstance(results['img'], list):\n            new_imgs = new_imgs[0]\n\n        results['img'] = new_imgs\n        return results\n"
  },
  {
    "path": "docs/data_preparation.md",
    "content": "# Data Preparation\n\n## OpenLane\n\nFollow [OpenLane](https://github.com/OpenDriveLab/PersFormer_3DLane#dataset) to download dataset and then link it under `data` directory.\n\n```bash\ncd data && mkdir openlane && cd openlane\nln -s ${OPENLANE_PATH}/images .\nln -s ${OPENLANE_PATH}/lane3d_1000 .\n```\n\n## ONCE\n\nFollow [ONCE](https://github.com/once-3dlanes/once_3dlanes_benchmark#data-preparation) to download dataset, and then link it under `data` directory.\n\n```bash\ncd data\nln -s ${once} .\n```\n\n## Apollo\n\nFollow [Apollo](https://github.com/yuliangguo/Pytorch_Generalized_3D_Lane_Detection#data-preparation) to download dataset and link it under `data` directory.\n\n```bash\ncd data && mkdir apollosyn_gen-lanenet\ncd apollosyn_gen-lanenet\nln -s ${Apollo_Sim_3D_Lane_Release} .\nln -s ${data_splits} .\n```\n\n\nYour data directory should be like:\n\n```bash\n|-- apollosyn_gen-lanenet\n    |-- Apollo_Sim_3D_Lane_Release\n    |   |-- depth\n    |   |-- images\n    |   |-- img_list.txt\n    |   |-- labels\n    |   |-- laneline_label.json\n    |   `-- segmentation\n    `-- data_splits\n        |-- illus_chg\n        |-- rare_subset\n        `-- standard\n|-- Load_Data.py\n|-- __init__.py\n|-- apollo_dataset.py\n|-- once -> ${once}\n    |-- annotation\n    |-- data\n    |-- data_check.py\n    |-- list\n    |-- raw_cam01\n    |-- raw_cam_multi\n    |-- train\n    `-- val\n|-- openlane\n|   |-- images -> ${openlane}/images/\n        |-- training\n        `-- validation\n|   `-- lane3d_1000 -> ${openlane}/lane3d_1000/\n        |-- test\n        |-- training\n        `-- validation\n`-- transform.py\n```"
  },
  {
    "path": "docs/install.md",
    "content": "# Environment\n\nIt is recommanded to build a new virtual environment.\n\n## 1. Install pytorch and requirements.\n\n```bash\n# first install pytorch\nconda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.1 -c pytorch\n\n# then clone LATR and change directory to it to install requirements\ncd ${LATR_PATH}\npython -m pip install -r requirements.txt\n```\n\n## 2. Install mm packages\n\n### 2.1 Install `mmcv`\n\n```bash\ngit clone https://github.com/open-mmlab/mmcv.git\ncd mmcv && git checkout v1.5.0\nFORCE_CUDA=1 MMCV_WITH_OPS=1 python -m pip install .\n```\n\n### 2.2 Install other mm packages\n\nInstall [mmdet](https://github.com/open-mmlab/mmdetection) and [mmdet3d](https://github.com/open-mmlab/mmdetection3d). Note that we use `mmdet==2.24.0` and `mmdet3d==1.0.0rc3`.\n"
  },
  {
    "path": "docs/train_eval.md",
    "content": "# Training and Evaluation\n\n## Train\n\n### Openlane\n\n- Base version:\n\n```bash\nCUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/latr_1000_baseline.py\n```\n\n- lite version:\n\n```bash\nCUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/latr_1000_baseline_lite.py\n```\n\n### ONCE\n\n```bash\nCUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/once.py\n```\n\n### Apollo\n\n- Balanced Scene\n\n```bash\nCUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/apollo_standard.py\n```\n\n- Rare Subset\n\n```bash\nCUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/apollo_rare.py\n```\n\n- Visual Variations\n\n```bash\nCUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port=29284 --nproc_per_node 4 main.py --config config/release_iccv/apollo_illu.py\n```\n\n## Evaluation\n\n### Openlane\n\n- Base version:\n\n```bash\nCUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/latr_1000_baseline.py --cfg-options evaluate=true eval_ckpt=pretrained_models/openlane.pth\n```\n\n- lite version:\n\n```bash\nCUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/latr_1000_baseline_lite.py --cfg-options evaluate=true eval_ckpt=pretrained_models/openlane_lite.pth\n```\n\n### ONCE\n\n```bash\nCUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/once.py --cfg-options evaluate=true eval_ckpt=pretrained_models/once.pth\n```\n\n### Apollo\n\n- Balanced Scene\n\n```bash\nCUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/apollo_standard.py --cfg-options evaluate=true eval_ckpt=pretrained_models/apollo_standard.pth\n```\n\n- Rare Subset\n\n```bash\nCUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node 4 main.py --config config/release_iccv/apollo_rare.py --cfg-options evaluate=true eval_ckpt=pretrained_models/apollo_rare.pth\n```\n\n- Visual Variations\n\n```bash\nCUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port=29284 --nproc_per_node 4 main.py --config config/release_iccv/apollo_illu.py --cfg-options evaluate=true eval_ckpt=pretrained_models/apollo_illu.pth\n```"
  },
  {
    "path": "experiments/__init__.py",
    "content": ""
  },
  {
    "path": "experiments/ddp.py",
    "content": "# ==============================================================================\n# Copyright (c) 2022 The PersFormer Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport os\nimport subprocess\nimport numpy as np\nimport random\n\ndef setup_dist_launch(args):\n    args.proc_id = args.local_rank\n    world_size = int(os.getenv('WORLD_SIZE', 1))*args.nodes\n    print(\"proc_id: \" + str(args.proc_id))\n    print(\"world size: \" + str(world_size))\n    print(\"local_rank: \" + str(args.local_rank))\n\n    os.environ['WORLD_SIZE'] = str(world_size)\n    os.environ['RANK'] = str(args.proc_id)\n    os.environ['LOCAL_RANK'] = str(args.local_rank)\n\ndef setup_slurm(args):\n    if mp.get_start_method(allow_none=True) is None:\n        mp.set_start_method('spawn')\n\n    args.proc_id = int(os.environ['SLURM_PROCID'])\n    ntasks = int(os.environ['SLURM_NTASKS'])\n    node_list = os.environ['SLURM_NODELIST']\n    num_gpus = torch.cuda.device_count()\n    local_rank = args.proc_id % num_gpus\n    args.local_rank = local_rank\n\n    print(\"proc_id: \" + str(args.proc_id))\n    print(\"world size: \" + str(ntasks))\n    print(\"local_rank: \" + str(local_rank))\n\n    addr = subprocess.getoutput(\n        f'scontrol show hostname {node_list} | head -n1')\n    os.environ['MASTER_PORT'] = str(args.port)\n    os.environ['MASTER_ADDR'] = addr\n\n    os.environ['WORLD_SIZE'] = str(ntasks)\n    os.environ['RANK'] = str(args.proc_id)\n    os.environ['LOCAL_RANK'] = str(local_rank)\n\ndef setup_distributed(args):\n    args.gpu = args.local_rank\n    torch.cuda.set_device(args.gpu)\n    dist.init_process_group(backend='nccl')\n    args.world_size = dist.get_world_size()\n    torch.set_printoptions(precision=10)\n    print('args.world_size', args.world_size)\n\ndef ddp_init(args):\n    args.proc_id, args.gpu, args.world_size = 0, 0, 1\n\n    if args.use_slurm == True:\n        setup_slurm(args)\n    else:\n        setup_dist_launch(args)\n\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) >= 1\n\n    if args.distributed:\n        setup_distributed(args)\n\n    # deterministic\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n    torch.manual_seed(args.proc_id)\n    np.random.seed(args.proc_id)\n    random.seed(args.proc_id)\n\ndef to_python_float(t):\n    if hasattr(t, 'item'):\n        return t.item()\n    else:\n        return t[0]\n\ndef reduce_tensor(tensor, world_size):\n    rt = tensor.clone()\n    dist.all_reduce(rt, op=dist.ReduceOp.SUM)\n    rt /= world_size\n    return rt\n\n\ndef reduce_tensors(*tensors, world_size):\n    return [reduce_tensor(tensor, world_size) for tensor in tensors]"
  },
  {
    "path": "experiments/gpu_utils.py",
    "content": "import torch\nimport torch.distributed as dist\n\ndef get_rank() -> int:\n    if not dist.is_available():\n        return 0\n    if not dist.is_initialized():\n        return 0\n    return dist.get_rank()\n\n\ndef is_main_process() -> bool:\n    return get_rank() == 0\n\n\ndef gpu_available() -> bool:\n    return torch.cuda.is_available()"
  },
  {
    "path": "experiments/runner.py",
    "content": "import torch\nimport torch.optim\nimport torch.nn as nn\nimport numpy as np\nimport glob\nimport time\nimport os\nfrom tqdm import tqdm\nfrom tensorboardX import SummaryWriter\nimport traceback\nimport shutil\n\nfrom data.Load_Data import *\nfrom data.apollo_dataset import ApolloLaneDataset\nfrom models.latr import LATR\nfrom experiments.gpu_utils import is_main_process\nfrom utils import eval_3D_lane, eval_3D_once\nfrom utils import eval_3D_lane_apollo\nfrom utils.utils import *\n\n# ddp related\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom .ddp import *\nimport os.path as osp\nfrom .gpu_utils import gpu_available\nfrom mmcv.runner.optimizer import build_optimizer\n\n\nclass Runner:\n    def __init__(self, args):\n        self.args = args\n        set_work_dir(self.args)\n        self.logger = create_logger(args)\n\n        # Check GPU availability\n        if is_main_process():\n            if not gpu_available():\n                raise Exception(\"No gpu available for usage\")\n            if int(os.getenv('WORLD_SIZE', 1)) >= 1:\n                self.logger.info(\"Let's use %s\" % os.environ['WORLD_SIZE'] + \"GPUs!\")\n                torch.cuda.empty_cache()\n        \n        # Get Dataset\n        if is_main_process():\n            self.logger.info(\"Loading Dataset ...\")\n\n        self.val_gt_file = ops.join(args.save_path, 'test.json')\n        if not args.evaluate:\n            self.train_dataset, self.train_loader, self.train_sampler = self._get_train_dataset()\n        else:\n            self.train_dataset, self.train_loader, self.train_sampler = [],[],[]\n        self.valid_dataset, self.valid_loader, self.valid_sampler = self._get_valid_dataset()\n\n        if 'openlane' in args.dataset_name:\n            self.evaluator = eval_3D_lane.LaneEval(args, logger=self.logger)\n        elif 'apollo' in args.dataset_name:\n            self.evaluator = eval_3D_lane_apollo.LaneEval(args, logger=self.logger)\n        elif 'once' in args.dataset_name:\n            self.evaluator = eval_3D_once.LaneEval()\n        else:\n            assert False\n        # Tensorboard writer\n        if not args.no_tb and is_main_process():\n            tensorboard_path = os.path.join(args.save_path, 'Tensorboard/')\n            mkdir_if_missing(tensorboard_path)\n            self.writer = SummaryWriter(tensorboard_path)\n        \n        if is_main_process():\n            self.logger.info(\"Init Done!\")\n        \n        self.is_apollo = False\n        if 'apollo' in args.dataset_name:\n            self.is_apollo = True\n\n    def train(self):\n        args = self.args\n\n        # Get Dataset\n        train_loader = self.train_loader\n        train_sampler = self.train_sampler\n\n        global lowest_loss, best_f1_epoch, best_val_f1, best_epoch\n        # Define model or resume\n        \n        model, optimizer, scheduler, best_epoch, \\\n            lowest_loss, best_f1_epoch, best_val_f1 = self._get_model_ddp()\n        \n        self._log_model_info(model)\n        \n        def save_cur_ckpt(\n                loss,\n                with_eval=True,\n                eval_stats=None):\n            # Save model\n            if not with_eval:\n                self.save_checkpoint({\n                    'state_dict': model.module.state_dict(),\n                    'optimizer': optimizer.state_dict(),\n                    'scheduler': scheduler.state_dict()\n                }, False, epoch+1, self.args.save_path)\n            
            else:\n                total_score = loss.item() # loss_list[0].avg\n                if is_main_process():\n                    # File to keep latest epoch\n                    with open(os.path.join(args.save_path, 'first_run.txt'), 'w') as f:\n                        f.write(str(epoch + 1))\n                global best_val_f1, best_f1_epoch, lowest_loss, best_epoch\n\n                to_copy, to_save = False, True\n\n                if total_score < lowest_loss:\n                    best_epoch = epoch + 1\n                    lowest_loss = total_score\n                if eval_stats[0] > best_val_f1:\n                    to_copy = True\n                    best_f1_epoch = epoch + 1\n                    best_val_f1 = eval_stats[0]\n                    to_save = True\n                self.log_eval_stats(eval_stats)\n                self.logger.info(\"===> Last best F1 was {:.8f} in epoch {}\".format(best_val_f1, best_f1_epoch))\n                if not to_save:\n                    return\n                self.save_checkpoint({\n                        'state_dict': model.module.state_dict(),\n                        'optimizer': optimizer.state_dict(),\n                        'scheduler': scheduler.state_dict()\n                    }, to_copy, epoch+1, self.args.save_path)\n\n        # Start training and validation for nepochs\n        for epoch in range(args.start_epoch, args.nepochs):\n            if is_main_process():\n                self.logger.info(\"\\n => Start train set for EPOCH {}\".format(epoch + 1))\n                self.logger.info('lr is set to {}'.format(optimizer.param_groups[0]['lr']))\n            \n            if args.distributed:\n                train_sampler.set_epoch(epoch)\n\n            # Define container objects to keep track of multiple losses/metrics\n            batch_time = AverageMeter()\n            data_time = AverageMeter()         # compute FPS\n            epoch_time = AverageMeter()\n            \n            loss = 0\n\n            # Specify operation modules\n            model.train()\n            # compute timing\n            end = time.time()\n            epoch_time.start = end\n            # Start training loop\n            train_pbar = tqdm(total=len(train_loader), ncols=60)\n            \n            for i, extra_dict in enumerate(train_loader):\n                train_pbar.update(1)\n                data_time.update(time.time() - end)\n                if gpu_available():\n                    json_files = extra_dict.pop('idx_json_file')\n                    for k, v in extra_dict.items():\n                        extra_dict[k] = v.cuda()\n                    image = extra_dict['image']\n                image = image.contiguous().float()\n                # Run model\n                optimizer.zero_grad()\n\n                output = model(image=image, extra_dict=extra_dict, is_training=True)\n                \n                loss, loss_info = self._log_training_loss(\n                    output, epoch, step=i, data_loader=train_loader)\n\n                train_pbar.set_postfix(loss=loss.item())\n                \n                if is_main_process():\n                    self.writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)\n                \n                # Setup backward pass\n                loss.backward()\n\n                # Clip gradients (useful for instabilities or mistakes in ground truth)\n                if args.clip_grad_norm != 0:\n
                    nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)\n\n                # update params\n                optimizer.step()\n\n                if args.lr_policy == 'cosine_warmup':\n                    scheduler.step(epoch + i / len(train_loader))\n                elif args.lr_policy == 'PolyLR':\n                    scheduler.step()\n\n                # Time the training iteration\n                batch_time.update(time.time() - end)\n                end = time.time()\n\n                # Print info\n                if (i + 1) % args.print_freq == 0 and is_main_process():\n                    self.logger.info('Epoch: [{0}][{1}/{2}]\\t'\n                        'Batch Time / Avg Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                        'Loss {loss:.8f} {loss_info}'.format(\n                            epoch+1, i+1, len(train_loader), \n                            batch_time=batch_time, data_time=data_time,\n                            loss=loss.item(), loss_info=loss_info))\n            train_pbar.close()\n\n            epoch_time.update(time.time() - epoch_time.start)\n\n            if is_main_process():\n                self.logger.info('Epoch time: {:.3f} hours.'.format(epoch_time.val / 60 / 60))\n            \n            # Adjust learning rate\n            if args.lr_policy != 'cosine_warmup':\n                scheduler.step()\n            \n            meet_eval_freq = args.eval_freq > 0 and (epoch + 1) % args.eval_freq == 0\n            last_ep = (epoch == args.nepochs - 1)\n\n            if meet_eval_freq or last_ep:\n                loss_valid_list, eval_stats = self.validate(model)\n                if eval_stats[0] >= best_val_f1:\n                    self.logger.info(' >>> to save new best model at ep : %s with F1 %s' % ((epoch+1), eval_stats[0]))\n                    save_cur_ckpt(loss, with_eval=True, eval_stats=eval_stats)\n                elif last_ep:\n                    self.logger.info(' >>> to save the last model at ep : %s with F1 %s' % ((epoch+1), eval_stats[0]))\n                    save_cur_ckpt(loss, with_eval=True, eval_stats=eval_stats)\n                else:\n                    self.logger.info(' >>> skip model at ep : %s with lower F1 : %s' % ((epoch+1), eval_stats[0]))\n            \n                self.log_eval_stats(eval_stats)\n\n            dist.barrier()\n            torch.cuda.empty_cache()\n\n        # at the end of training\n        if not args.no_tb and is_main_process():\n            self.writer.close()\n\n    def _log_model_info(self, model):\n        args = self.args\n        if not is_main_process():\n            return\n        \n        self.logger.info(40*\"=\"+\"\\nArgs:{}\\n\".format(args)+40*\"=\")\n        self.logger.info(\"Init model: '{}'\".format(args.mod))\n        self.logger.info(\"Number of parameters in model {} is {:.3f}M\".format(args.mod, sum(tensor.numel() for tensor in model.parameters())/1e6))\n\n    def _log_training_loss(self, output, epoch, step, data_loader):\n        loss = 0.0\n        loss_info = ''\n        for k, v in output.items():\n            if 'loss' in k:\n                loss = loss + v\n                loss_info = loss_info + '| %s:%.4f ' % (k, v.item() if isinstance(v, torch.Tensor) else v)\n                if isinstance(v, torch.Tensor):\n                    v = v.item()\n                if is_main_process():\n                    self.writer.add_scalar(k, v, epoch*len(data_loader) + step)\n        return loss, loss_info\n\n    def save_checkpoint(self, state, to_copy, epoch, 
save_path):\n        if is_main_process():\n            self.logger.info('Saving checkpoint to {}'.format(save_path))\n\n            if to_copy:\n                file_pre = f'model_best_epoch_{epoch}.pth.tar'\n                self.logger.info('save the best model : %s' % epoch)\n            else:\n                file_pre = f'checkpoint_model_epoch_{epoch}.path.tar'\n\n            filepath = os.path.join(save_path, file_pre)\n            torch.save(state, filepath)\n\n    def validate(self, model, **kwargs):\n        args = self.args\n        loader = self.valid_loader\n        \n        pred_lines_sub = []\n        gt_lines_sub = []\n\n        model.eval()\n\n        # Start validation loop\n        with torch.no_grad():\n            val_pbar = tqdm(total=len(loader), ncols=50)\n            \n            for i, extra_dict in enumerate(loader):\n                val_pbar.update(1)\n\n                if not args.no_cuda:\n                    json_files = extra_dict.pop('idx_json_file')\n                    for k, v in extra_dict.items():\n                        extra_dict[k] = v.cuda()\n                    image = extra_dict['image']\n                image = image.contiguous().float()\n                \n                output = model(image=image, extra_dict=extra_dict, is_training=False)\n                all_line_preds = output['all_line_preds'] # in ground coordinate system\n                all_cls_scores = output['all_cls_scores']\n\n                all_line_preds = all_line_preds[-1]\n                all_cls_scores = all_cls_scores[-1]\n                num_el = all_cls_scores.shape[0]\n                if 'cam_extrinsics' in extra_dict:\n                    cam_extrinsics_all = extra_dict['cam_extrinsics']\n                    cam_intrinsics_all = extra_dict['cam_intrinsics']\n                else:\n                    cam_extrinsics_all, cam_intrinsics_all = None, None\n\n                # Print info\n                if (i + 1) % args.print_freq == 0 and is_main_process():\n                    self.logger.info('Test: [{0}/{1}]'.format(i+1, len(loader)))\n\n                # Write results\n                for j in range(num_el):\n                    json_file = json_files[j]\n                    if cam_extrinsics_all is not None:\n                        extrinsic = cam_extrinsics_all[j].cpu().numpy()\n                        intrinsic = cam_intrinsics_all[j].cpu().numpy()\n\n                    with open(json_file, 'r') as file:\n                        if 'apollo' in args.dataset_name:\n                            json_line = json.loads(file.read())\n                            if 'extrinsic' not in json_line:\n                                json_line['extrinsic'] = extrinsic\n                            if 'intrinsic' not in json_line:\n                                json_line['intrinsic'] = intrinsic\n                        else:\n                            file_lines = [line for line in file]\n                            json_line = json.loads(file_lines[0])\n\n                    json_line['json_file'] = json_file\n                    if 'once' in args.dataset_name:\n                        if 'train' in json_file:\n                            img_path = json_file.replace('train', 'data').replace('.json', '.jpg')\n                        elif 'val' in json_file:\n                            img_path = json_file.replace('val', 'data').replace('.json', '.jpg')\n                        elif 'test' in json_file:\n                            img_path = json_file.replace('test', 
'data').replace('.json', '.jpg')\n                        json_line[\"file_path\"] = img_path\n\n                    gt_lines_sub.append(copy.deepcopy(json_line))\n\n                    # pred in ground\n                    lane_pred = all_line_preds[j].cpu().numpy()\n                    cls_pred = torch.argmax(all_cls_scores[j], dim=-1).cpu().numpy()\n                    pos_lanes = lane_pred[cls_pred > 0]\n\n                    if self.args.num_category > 1:\n                        scores_pred = torch.softmax(all_cls_scores[j][cls_pred > 0], dim=-1).cpu().numpy()\n                    else:\n                        scores_pred = torch.sigmoid(all_cls_scores[j][cls_pred > 0]).cpu().numpy()\n\n                    if pos_lanes.shape[0]:\n                        lanelines_pred = []\n                        lanelines_prob = []\n                        xs = pos_lanes[:, 0:args.num_y_steps]\n                        ys = np.tile(args.anchor_y_steps.copy()[None, :], (xs.shape[0], 1))\n                        zs = pos_lanes[:, args.num_y_steps:2*args.num_y_steps]\n                        vis = pos_lanes[:, 2*args.num_y_steps:]\n\n                        for tmp_idx in range(pos_lanes.shape[0]):\n                            cur_vis = vis[tmp_idx] > 0\n                            cur_xs = xs[tmp_idx][cur_vis]\n                            cur_ys = ys[tmp_idx][cur_vis]\n                            cur_zs = zs[tmp_idx][cur_vis]\n\n                            if cur_vis.sum() < 2:\n                                continue\n\n                            lanelines_pred.append([])\n                            for tmp_inner_idx in range(cur_xs.shape[0]):\n                                lanelines_pred[-1].append(\n                                    [cur_xs[tmp_inner_idx],\n                                     cur_ys[tmp_inner_idx],\n                                     cur_zs[tmp_inner_idx]])\n                            lanelines_prob.append(scores_pred[tmp_idx].tolist())\n                    else:\n                        lanelines_pred = []\n                        lanelines_prob = []\n\n                    json_line[\"pred_laneLines\"] = lanelines_pred\n                    json_line[\"pred_laneLines_prob\"] = lanelines_prob\n\n                    pred_lines_sub.append(copy.deepcopy(json_line))\n                    img_path = json_line['file_path']\n                    \n                    if args.dataset_name == 'once':\n                        self.save_eval_result_once(args, img_path, lanelines_pred, lanelines_prob)\n            val_pbar.close()\n\n            if 'openlane' in args.dataset_name:\n                eval_stats = self.evaluator.bench_one_submit_ddp(\n                    pred_lines_sub, gt_lines_sub, args.model_name,\n                    args.pos_threshold, vis=False)\n            elif 'once' in args.dataset_name:\n                eval_stats = self.evaluator.lane_evaluation(\n                    args.data_dir + 'val', '%s/once_pred/test' % (args.save_path),\n                    args.eval_config_dir, args)\n            elif 'apollo' in args.dataset_name:\n                self.logger.info(' >>> eval mAP | [0.05, 0.95]')\n                eval_stats = self.evaluator.bench_one_submit_ddp(\n                    pred_lines_sub, gt_lines_sub,\n                    args.model_name, args.pos_threshold, vis=False)\n            else:\n                assert False\n                \n            if any(name in args.dataset_name for name in ['openlane', 'apollo']):\n                gather_output = [None 
for _ in range(args.world_size)]\n                # all_gather all eval_stats and calculate mean\n                dist.all_gather_object(gather_output, eval_stats)\n                dist.barrier()\n                eval_stats = self._recal_gpus_val(gather_output, eval_stats)\n\n                loss_list = []\n                return loss_list, eval_stats\n            elif 'once' in args.dataset_name:\n                loss_list = []\n                return loss_list, eval_stats\n\n    def _recal_gpus_val(self, gather_output, eval_stats):\n        args = self.args\n\n        apollo_metrics = {\n            'r_lane': 0,\n            'p_lane': 0,\n            'cnt_gt': 0,\n            'cnt_pred': 0\n        }\n        openlane_metrics = {\n            'r_lane': 0,\n            'p_lane': 0,\n            'c_lane': 0,\n            'cnt_gt': 0,\n            'cnt_pred': 0,\n            'match_num': 0\n        }\n\n        if 'apollo' in self.args.dataset_name:\n            # apollo has no category accuracy.\n            start_idx = 7\n            gather_metrics = apollo_metrics\n        else:\n            start_idx = 8\n            gather_metrics = openlane_metrics\n\n        for i, k in enumerate(gather_metrics.keys()):\n            gather_metrics[k] = np.sum(\n                [eval_stats_sub[start_idx + i] for eval_stats_sub in gather_output])\n\n        # guard zero denominators with a small epsilon; the counts are\n        # non-negative ints, so the epsilon only matters when a count is zero\n        Recall = gather_metrics['r_lane'] / max(gather_metrics['cnt_gt'], 1e-6)\n        Precision = gather_metrics['p_lane'] / max(gather_metrics['cnt_pred'], 1e-6)\n        f1_score = 2 * Recall * Precision / max(Recall + Precision, 1e-6)\n\n        if 'apollo' not in self.args.dataset_name:\n            category_accuracy = gather_metrics['c_lane'] / max(gather_metrics['match_num'], 1e-6)\n\n        eval_stats[0] = f1_score\n        eval_stats[1] = Recall\n        eval_stats[2] = Precision\n        if self.is_apollo:\n            err_start_idx = 3\n        else:\n            eval_stats[3] = category_accuracy\n            err_start_idx = 4\n        for i in range(4):\n            err_idx = err_start_idx + i\n            eval_stats[err_idx] = np.sum([eval_stats_sub[err_idx] for eval_stats_sub in gather_output]) / args.world_size\n        return eval_stats\n\n    def _get_model_from_cfg(self):\n        args = self.args\n        model = LATR(args)\n\n        if args.sync_bn:\n            if is_main_process():\n                self.logger.info(\"Convert model with Sync BatchNorm\")\n            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)\n\n        if gpu_available():\n            device = torch.device(\"cuda\", args.local_rank)\n            model = model.to(device)\n\n        return model\n\n    def _load_ckpt_from_workdir(self, model):\n        args = self.args\n        if args.eval_ckpt:\n            best_file_name = args.eval_ckpt\n        else:\n            best_file_name = glob.glob(os.path.join(args.save_path, 'model_best*'))\n            if len(best_file_name) > 0:\n                best_file_name = best_file_name[0]\n            else:\n                best_file_name = ''\n        if os.path.isfile(best_file_name):\n            # load on cpu first; weights are moved to the right device later\n            checkpoint = torch.load(best_file_name, map_location='cpu')\n            if is_main_process():\n                self.logger.info(\"=> loading checkpoint '{}'\".format(best_file_name))\n                model.load_state_dict(checkpoint['state_dict'])\n        else:\n            self.logger.info(\"=> no checkpoint found at '{}'\".format(best_file_name))\n\n    def eval(self):\n        self.logger.info('>>>>> start eval <<<<<\\n')\n        args = self.args\n\n        model = self._get_model_from_cfg()\n        self._load_ckpt_from_workdir(model)\n\n        dist.barrier()\n        # DDP setting\n        if args.distributed:\n            model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)\n        _, eval_stats = self.validate(model)\n\n        if is_main_process() and (eval_stats is not None):\n            self.log_eval_stats(eval_stats)\n\n    def _get_train_dataset(self):\n        args = self.args\n        if 'openlane' in args.dataset_name:\n            train_dataset = LaneDataset(args.dataset_dir, args.data_dir + 'training/', args, data_aug=True)\n        elif 'once' in args.dataset_name:\n            train_dataset = LaneDataset(args.dataset_dir, ops.join(args.data_dir, 'train/'), args, data_aug=True)\n        else:\n            self.logger.info('using Apollo Dataset')\n            train_dataset = ApolloLaneDataset(args.dataset_dir, ops.join(args.data_dir, 'train.json'), args, data_aug=True)\n\n        train_loader, train_sampler = get_loader(train_dataset, args)\n\n        return train_dataset, train_loader, train_sampler\n\n    def _get_model_ddp(self):\n        args = self.args\n        # define network\n        model = LATR(args)\n\n        # convert BN to SyncBN for DDP training\n        if is_main_process():\n            self.logger.info(\"Convert model with Sync BatchNorm\")\n        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)\n\n        if gpu_available():\n            # Load model on gpu before passing params to optimizer\n            device = torch.device(\"cuda\", args.local_rank)\n            model = model.to(device)\n\n        # load params into the model first, then wrap it with DDP\n        args.resume = first_run(args.save_path)\n\n        model, best_epoch, lowest_loss, best_f1_epoch, best_val_f1, \\\n            optim_saved_state, schedule_saved_state = self.resume_model(model)\n        dist.barrier()\n        # DDP setting\n        if args.distributed:\n            model = DDP(\n                model, device_ids=[args.local_rank],\n                output_device=args.local_rank,\n                find_unused_parameters=True\n            )\n\n        # Define optimizer and scheduler\n        optimizer = build_optimizer(\n            model,\n            args.optimizer_cfg)\n        scheduler = define_scheduler(\n            optimizer, args, dataset_size=len(self.train_loader))\n\n        return model, optimizer, scheduler, best_epoch, lowest_loss, best_f1_epoch, best_val_f1\n\n    def resume_model(self, model, path=''):\n        args = self.args\n\n        best_epoch = 0\n        lowest_loss = np.inf\n        best_f1_epoch = 0\n        best_val_f1 = -1e-5\n        optim_saved_state = None\n        schedule_saved_state = None\n\n        if len(path) == 0 and args.resume:\n            # try the latest ckpt\n            path = os.path.join(args.save_path, 'checkpoint_model_epoch_{}.pth.tar'.format(int(args.resume)))\n            # try the best ckpt saved\n            if not os.path.isfile(path):\n                path = os.path.join(args.save_path, f'model_best_epoch_{args.resume}.pth.tar')\n\n        if os.path.isfile(path):\n            self.logger.info(\"=> loading checkpoint from {}\".format(path))\n            checkpoint = torch.load(path, map_location='cpu')\n            # record the resume epoch before logging it\n            args.start_epoch = int(args.resume)\n            if is_main_process():\n                model.load_state_dict(checkpoint['state_dict'])\n                self.logger.info(\"=> loaded checkpoint '{}' (epoch {})\".format(args.resume, args.start_epoch))\n\n            optim_saved_state = checkpoint['optimizer']\n            schedule_saved_state = checkpoint['scheduler']\n        else:\n            if is_main_process():\n                self.logger.info(\"=> Warning: no checkpoint found at '{}'\".format(path))\n\n        return model, best_epoch, lowest_loss, best_f1_epoch, best_val_f1, optim_saved_state, schedule_saved_state\n\n    def _get_valid_dataset(self):\n        args = self.args\n        if 'openlane' in args.dataset_name:\n            if not args.evaluate_case:\n                valid_dataset = LaneDataset(args.dataset_dir, args.data_dir + 'validation/', args)\n            else:\n                # TODO eval case\n                valid_dataset = LaneDataset(args.dataset_dir, args.data_dir + 'test/up_down_case/', args)\n\n        elif 'once' in args.dataset_name:\n            valid_dataset = LaneDataset(args.dataset_dir, ops.join(args.data_dir, 'val/'), args)\n        else:\n            valid_dataset = ApolloLaneDataset(args.dataset_dir, os.path.join(args.data_dir, 'test.json'), args)\n\n        valid_loader, valid_sampler = get_loader(valid_dataset, args)\n        return valid_dataset, valid_loader, valid_sampler\n\n    def save_eval_result_once(self, args, img_path, lanelines_pred, lanelines_prob):\n        # 3d eval result\n        result = {}\n        result_dir = os.path.join(args.save_path, 'once_pred/')\n        mkdir_if_missing(result_dir)\n        result_dir = os.path.join(result_dir, 'test/')\n        mkdir_if_missing(result_dir)\n        file_path_split = img_path.split('/')\n        result_dir = os.path.join(result_dir, file_path_split[-3]) # sequence\n        mkdir_if_missing(result_dir)\n        result_dir = os.path.join(result_dir, 'cam01/')\n        mkdir_if_missing(result_dir)\n        result_file_path = ops.join(result_dir, file_path_split[-1][:-4]+'.json')\n\n        cam_pitch = 0.3/180*np.pi\n        cam_height = 1.5\n        cam_extrinsics = np.array([[np.cos(cam_pitch), 0, -np.sin(cam_pitch), 0],\n                                   [0, 1, 0, 0],\n                                   [np.sin(cam_pitch), 0, np.cos(cam_pitch), cam_height],\n                                   [0, 0, 0, 1]], dtype=float)\n        R_vg = np.array([[0, 1, 0],\n                         [-1, 0, 0],\n                         [0, 0, 1]], dtype=float)\n        R_gc = np.array([[1, 0, 0],\n                         [0, 0, 1],\n                         [0, -1, 0]], dtype=float)\n        cam_extrinsics[:3, :3] = np.matmul(np.matmul(\n            np.matmul(np.linalg.inv(R_vg), cam_extrinsics[:3, :3]),\n            R_vg), R_gc)\n        cam_extrinsics[0:2, 3] = 
0.0\n\n        # write lane result\n        lane_lines = []\n        for k in range(len(lanelines_pred)):\n            lane = np.array(lanelines_pred[k])\n            lane = np.flip(lane, axis=0)\n            lane = lane.T\n            lane = np.vstack((lane, np.ones((1, lane.shape[1]))))\n            lane = np.matmul(np.linalg.inv(cam_extrinsics), lane)\n            lane = lane[0:3,:].T\n            lane_lines.append({'points': lane.tolist(),\n                               'score': np.max(lanelines_prob[k])})\n        result['lanes'] = lane_lines\n\n        with open(result_file_path, 'w') as result_file:\n            json.dump(result, result_file)\n\n    def log_eval_stats(self, eval_stats):\n        if self.is_apollo:\n            return self._log_genlane_eval_info(eval_stats)\n\n        if is_main_process():\n            self.logger.info(\"===> Evaluation laneline F-measure: {:.8f}\".format(eval_stats[0]))\n            self.logger.info(\"===> Evaluation laneline Recall: {:.8f}\".format(eval_stats[1]))\n            self.logger.info(\"===> Evaluation laneline Precision: {:.8f}\".format(eval_stats[2]))\n            self.logger.info(\"===> Evaluation laneline Category Accuracy: {:.8f}\".format(eval_stats[3]))\n            self.logger.info(\"===> Evaluation laneline x error (close): {:.8f} m\".format(eval_stats[4]))\n            self.logger.info(\"===> Evaluation laneline x error (far): {:.8f} m\".format(eval_stats[5]))\n            self.logger.info(\"===> Evaluation laneline z error (close): {:.8f} m\".format(eval_stats[6]))\n            self.logger.info(\"===> Evaluation laneline z error (far): {:.8f} m\".format(eval_stats[7]))\n\n    def _log_genlane_eval_info(self, eval_stats):\n        if is_main_process():\n            self.logger.info(\"===> Evaluation on validation set: \\n\"\n                \"laneline F-measure {:.8} \\n\"\n                \"laneline Recall  {:.8} \\n\"\n                \"laneline Precision  {:.8} \\n\"\n                \"laneline x error (close)  {:.8} m\\n\"\n                \"laneline x error (far)  {:.8} m\\n\"\n                \"laneline z error (close)  {:.8} m\\n\"\n                \"laneline z error (far)  {:.8} m\\n\".format(eval_stats[0], eval_stats[1],\n                                                            eval_stats[2], eval_stats[3],\n                                                            eval_stats[4], eval_stats[5],\n                                                            eval_stats[6]))\n\n\ndef set_work_dir(cfg):\n    # =========output path========== #\n    save_prefix = osp.join(os.getcwd(), 'work_dirs')\n    save_root = osp.join(save_prefix, cfg.output_dir)\n\n    # cur work dirname\n    cfg_path = Path(cfg.config)\n\n    if cfg.mod is None:\n        cfg.mod = os.path.join(cfg_path.parent.name, cfg_path.stem)\n    \n    save_ppath = Path(save_root, cfg.mod)\n    save_ppath.mkdir(parents=True, exist_ok=True)\n\n    cfg.save_path = save_ppath.as_posix()\n    cfg.save_json_path = cfg.save_path\n    \n    seg_output_dir = Path(cfg.save_path, 'seg_vis')\n    seg_output_dir.mkdir(parents=True, exist_ok=True)\n\n    # cp config into cur_work_dir\n    shutil.copy(cfg_path.as_posix(), cfg.save_path)\n    "
  },
  {
    "path": "main.py",
    "content": "import argparse\nfrom mmcv.utils import Config, DictAction\n\nfrom utils.utils import *\nfrom experiments.ddp import *\nfrom experiments.runner import *\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    # DDP setting\n    parser.add_argument('--distributed', action='store_true')\n    parser.add_argument(\"--local_rank\", type=int)\n    parser.add_argument('--gpu', type=int, default=0)\n    parser.add_argument('--world_size', type=int, default=1)\n    parser.add_argument('--nodes', type=int, default=1)\n    parser.add_argument('--use_slurm', default=False, action='store_true')\n\n    # exp setting\n    parser.add_argument('--config', type=str, help='config file path')\n    parser.add_argument(\n        '--cfg-options',\n        nargs='+',\n        action=DictAction,\n        help='overwrite config param.')\n    return parser.parse_args()\n\n\ndef main():\n    args = get_args()\n    # define runner to begin training or evaluation\n    cfg = Config.fromfile(args.config)\n    if args.cfg_options is not None:\n        cfg.merge_from_dict(args.cfg_options)\n\n    # initialize distributed data parallel set\n    ddp_init(args)\n    cfg.merge_from_dict(vars(args))\n    \n    runner = Runner(cfg)\n    if not cfg.evaluate:\n        runner.train()\n    else:\n        runner.eval()\n\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "models/__init__.py",
    "content": ""
  },
  {
    "path": "models/latr.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom utils.utils import *\nfrom mmdet3d.models import build_backbone, build_neck\nfrom .latr_head import LATRHead\nfrom mmcv.utils import Config\nfrom .ms2one import build_ms2one\nfrom .utils import deepFeatureExtractor_EfficientNet\n\nfrom mmdet.models.builder import BACKBONES\n\n\n# overall network\nclass LATR(nn.Module):\n    def __init__(self, args):\n        super().__init__()\n        self.no_cuda = args.no_cuda\n        self.batch_size = args.batch_size\n        self.num_lane_type = 1  # no centerline\n        self.num_y_steps = args.num_y_steps\n        self.max_lanes = args.max_lanes\n        self.num_category = args.num_category\n        _dim_ = args.latr_cfg.fpn_dim\n        num_query = args.latr_cfg.num_query\n        num_group = args.latr_cfg.num_group\n        sparse_num_group = args.latr_cfg.sparse_num_group\n\n        self.encoder = build_backbone(args.latr_cfg.encoder)\n        if getattr(args.latr_cfg, 'neck', None):\n            self.neck = build_neck(args.latr_cfg.neck)\n        else:\n            self.neck = None\n        self.encoder.init_weights()\n        self.ms2one = build_ms2one(args.ms2one)\n\n        # build 2d query-based instance seg\n        self.head = LATRHead(\n            args=args,\n            dim=_dim_,\n            num_group=num_group,\n            num_convs=4,\n            in_channels=_dim_,\n            kernel_dim=_dim_,\n            position_range=args.position_range,\n            top_view_region=args.top_view_region,\n            positional_encoding=dict(\n                type='SinePositionalEncoding',\n                num_feats=_dim_// 2, normalize=True),\n            num_query=num_query,\n            pred_dim=self.num_y_steps,\n            num_classes=args.num_category,\n            embed_dims=_dim_,\n            transformer=args.transformer,\n            sparse_ins_decoder=args.sparse_ins_decoder,\n            **args.latr_cfg.get('head', {}),\n            trans_params=args.latr_cfg.get('trans_params', {})\n        )\n\n    def forward(self, image, _M_inv=None, is_training=True, extra_dict=None):\n        out_featList = self.encoder(image)\n        neck_out = self.neck(out_featList)\n        neck_out = self.ms2one(neck_out)\n\n        output = self.head(\n            dict(\n                x=neck_out,\n                lane_idx=extra_dict['seg_idx_label'],\n                seg=extra_dict['seg_label'],\n                lidar2img=extra_dict['lidar2img'],\n                pad_shape=extra_dict['pad_shape'],\n                ground_lanes=extra_dict['ground_lanes'] if is_training else None,\n                ground_lanes_dense=extra_dict['ground_lanes_dense'] if is_training else None,\n                image=image,\n            ),\n            is_training=is_training,\n        )\n        return output"
  },
  {
    "path": "models/latr_head.py",
    "content": "import numpy as np\nimport math\nimport cv2\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport torch.nn.functional as F\nfrom torch.nn.init import normal_\n\nfrom mmcv.cnn import bias_init_with_prob\nfrom mmdet.models.builder import build_loss\nfrom mmdet.models.utils import build_transformer\nfrom mmdet.core import multi_apply\n\nfrom mmcv.utils import Config\nfrom models.sparse_ins import SparseInsDecoder\nfrom .utils import inverse_sigmoid\nfrom .transformer_bricks import *\n\n\nclass LATRHead(nn.Module):\n    def __init__(self, args,\n                 dim=128,\n                 num_group=1,\n                 num_convs=4,\n                 in_channels=128,\n                 kernel_dim=128,\n                 positional_encoding=dict(\n                    type='SinePositionalEncoding',\n                    num_feats=128 // 2, normalize=True),\n                 num_classes=21,\n                 num_query=30,\n                 embed_dims=128,\n                 transformer=None,\n                 num_reg_fcs=2,\n                 depth_num=50,\n                 depth_start=3,\n                 top_view_region=None,\n                 position_range=[-50, 3, -10, 50, 103, 10.],\n                 pred_dim=10,\n                 loss_cls=dict(\n                     type='FocalLoss',\n                     use_sigmoid=True,\n                     gamma=2.0,\n                     alpha=0.25,\n                     loss_weight=2.0),\n                 loss_reg=dict(type='L1Loss', loss_weight=2.0),\n                 loss_vis=dict(type='BCEWithLogitsLoss', reduction='mean'),\n                 sparse_ins_decoder=Config(\n                    dict(\n                        encoder=dict(\n                            out_dims=64),# neck output feature channels\n                        decoder=dict(\n                            num_group=1,\n                            output_iam=True,\n                            scale_factor=1.),\n                        sparse_decoder_weight=1.0,\n                        )),\n                 xs_loss_weight=1.0,\n                 zs_loss_weight=5.0,\n                 vis_loss_weight=1.0,\n                 cls_loss_weight=20,\n                 project_loss_weight=1.0,\n                 trans_params=dict(\n                     init_z=0, bev_h=250, bev_w=100),\n                 pt_as_query=False,\n                 num_pt_per_line=5,\n                 num_feature_levels=1,\n                 gt_project_h=20,\n                 gt_project_w=30,\n                 project_crit=dict(\n                     type='SmoothL1Loss',\n                     reduction='none'),\n                 ):\n        super().__init__()\n        self.trans_params = dict(\n            top_view_region=top_view_region,\n            z_region=[position_range[2], position_range[5]])\n        self.trans_params.update(trans_params)\n        self.gt_project_h = gt_project_h\n        self.gt_project_w = gt_project_w\n\n        self.num_y_steps = args.num_y_steps\n        self.register_buffer('anchor_y_steps',\n            torch.from_numpy(args.anchor_y_steps).float())\n        self.register_buffer('anchor_y_steps_dense',\n            torch.from_numpy(args.anchor_y_steps_dense).float())\n\n        project_crit['reduction'] = 'none'\n        self.project_crit = getattr(\n            nn, project_crit.pop('type'))(**project_crit)\n\n        self.num_classes = num_classes\n        self.embed_dims = embed_dims\n        # points num along y-axis.\n        self.code_size = pred_dim\n     
   self.num_query = num_query\n        self.num_group = num_group\n        self.num_pred = transformer['decoder']['num_layers']\n        self.pc_range = position_range\n        self.xs_loss_weight = xs_loss_weight\n        self.zs_loss_weight = zs_loss_weight\n        self.vis_loss_weight = vis_loss_weight\n        self.cls_loss_weight = cls_loss_weight\n        self.project_loss_weight = project_loss_weight\n\n        loss_reg['reduction'] = 'none'\n        self.reg_crit = build_loss(loss_reg)\n        self.cls_crit = build_loss(loss_cls)\n        self.bce_loss = build_nn_loss(loss_vis)\n        self.sparse_ins = SparseInsDecoder(cfg=sparse_ins_decoder)\n\n        self.depth_num = depth_num\n        self.position_dim = 3 * self.depth_num\n        self.position_range = position_range\n        self.depth_start = depth_start\n        self.adapt_pos3d = nn.Sequential(\n            nn.Conv2d(self.embed_dims, self.embed_dims*4, kernel_size=1, stride=1, padding=0),\n            nn.ReLU(),\n            nn.Conv2d(self.embed_dims*4, self.embed_dims, kernel_size=1, stride=1, padding=0),\n        )\n        self.positional_encoding = build_positional_encoding(positional_encoding)\n        self.position_encoder = nn.Sequential(\n            nn.Conv2d(self.position_dim, self.embed_dims*4, kernel_size=1, stride=1, padding=0),\n            nn.ReLU(),\n            nn.Conv2d(self.embed_dims*4, self.embed_dims, kernel_size=1, stride=1, padding=0),\n        )\n        self.transformer = build_transformer(transformer)\n        self.query_embedding = nn.Sequential(\n            nn.Linear(self.embed_dims, self.embed_dims),\n            nn.ReLU(),\n            nn.Linear(self.embed_dims, self.embed_dims),\n        )\n\n        # build pred layer: cls, reg, vis\n        self.num_reg_fcs = num_reg_fcs\n        cls_branch = []\n        for _ in range(self.num_reg_fcs):\n            cls_branch.append(nn.Linear(self.embed_dims, self.embed_dims))\n            cls_branch.append(nn.LayerNorm(self.embed_dims))\n            cls_branch.append(nn.ReLU(inplace=True))\n        cls_branch.append(nn.Linear(self.embed_dims, self.num_classes))\n        fc_cls = nn.Sequential(*cls_branch)\n\n        reg_branch = []\n        for _ in range(self.num_reg_fcs):\n            reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims))\n            reg_branch.append(nn.ReLU())\n        reg_branch.append(\n            nn.Linear(\n                self.embed_dims,\n                3 * self.code_size // num_pt_per_line))\n        reg_branch = nn.Sequential(*reg_branch)\n\n        self.cls_branches = nn.ModuleList(\n            [fc_cls for _ in range(self.num_pred)])\n        self.reg_branches = nn.ModuleList(\n            [reg_branch for _ in range(self.num_pred)])\n\n        self.num_pt_per_line = num_pt_per_line\n        self.point_embedding = nn.Embedding(\n            self.num_pt_per_line, self.embed_dims)\n\n        self.reference_points = nn.Sequential(\n            nn.Linear(self.embed_dims, self.embed_dims),\n            nn.ReLU(True),\n            nn.Linear(self.embed_dims, self.embed_dims),\n            nn.ReLU(True),\n            nn.Linear(self.embed_dims, 2 * self.code_size // num_pt_per_line))\n        self.num_feature_levels = num_feature_levels\n        self.level_embeds = nn.Parameter(torch.Tensor(\n            self.num_feature_levels, self.embed_dims))\n\n        self._init_weights()\n\n    def _init_weights(self):\n        self.transformer.init_weights()\n        xavier_init(self.reference_points, distribution='uniform', 
bias=0)\n        if self.cls_crit.use_sigmoid:\n            bias_init = bias_init_with_prob(0.01)\n            for m in self.cls_branches:\n                nn.init.constant_(m[-1].bias, bias_init)\n        normal_(self.level_embeds)\n\n    def forward(self, input_dict, is_training=True):\n        output_dict = {}\n        img_feats = input_dict['x']\n\n        if not isinstance(img_feats, (list, tuple)):\n            img_feats = [img_feats]\n\n        sparse_output = self.sparse_ins(\n            img_feats[0],\n            lane_idx_map=input_dict['lane_idx'],\n            input_shape=input_dict['seg'].shape[-2:],\n            is_training=is_training)\n        # generate 2d pos emb\n        B, C, H, W = img_feats[0].shape\n        masks = img_feats[0].new_zeros((B, H, W))\n\n        # TODO use actual mask if using padding or other aug\n        sin_embed = self.positional_encoding(masks)\n        sin_embed = self.adapt_pos3d(sin_embed)\n\n        # init query and reference pt\n        query = sparse_output['inst_features'] # BxNxC\n        # B, N, C -> B, N, num_anchor_per_line, C\n        query = query.unsqueeze(2) + self.point_embedding.weight[None, None, ...]\n       \n        query_embeds = self.query_embedding(query).flatten(1, 2)\n        query = torch.zeros_like(query_embeds)\n        reference_points = self.reference_points(query_embeds)\n        reference_points = reference_points.sigmoid()\n        mlvl_feats = img_feats\n\n        feat_flatten = []\n        spatial_shapes = []\n        mlvl_masks = []\n\n        assert self.num_feature_levels == len(mlvl_feats)\n        for lvl, feat in enumerate(mlvl_feats):\n            bs, c, h, w = feat.shape\n            spatial_shape = (h, w)\n            feat = feat.flatten(2).permute(2, 0, 1) # NxBxC\n            feat = feat + self.level_embeds[None, lvl:lvl+1, :].to(feat.device)\n            spatial_shapes.append(spatial_shape)\n            feat_flatten.append(feat)\n            mlvl_masks.append(torch.zeros((bs, *spatial_shape),\n                                           dtype=torch.bool,\n                                           device=feat.device))\n\n        if self.transformer.with_encoder:\n            mlvl_positional_encodings = []\n            pos_embed2d = []\n            for lvl, feat in enumerate(mlvl_feats):\n                mlvl_positional_encodings.append(\n                    self.positional_encoding(mlvl_masks[lvl]))\n                pos_embed2d.append(\n                    mlvl_positional_encodings[-1].flatten(2).permute(2, 0, 1))\n            pos_embed2d = torch.cat(pos_embed2d, 0)\n        else:\n            mlvl_positional_encodings = None\n            pos_embed2d = None\n\n        feat_flatten = torch.cat(feat_flatten, 0)\n\n        spatial_shapes = torch.as_tensor(\n            spatial_shapes, dtype=torch.long, device=query.device)\n        level_start_index = torch.cat(\n            (spatial_shapes.new_zeros((1, )),\n             spatial_shapes.prod(1).cumsum(0)[:-1])\n        )\n\n        # head\n        pos_embed = None\n        outs_dec, project_results, outputs_classes, outputs_coords = \\\n            self.transformer(\n                feat_flatten, None,\n                query, query_embeds, pos_embed,\n                reference_points=reference_points,\n                reg_branches=self.reg_branches,\n                cls_branches=self.cls_branches,\n                img_feats=img_feats,\n                lidar2img=input_dict['lidar2img'],\n                pad_shape=input_dict['pad_shape'],\n                
sin_embed=sin_embed,\n                spatial_shapes=spatial_shapes,\n                level_start_index=level_start_index,\n                mlvl_masks=mlvl_masks,\n                mlvl_positional_encodings=mlvl_positional_encodings,\n                pos_embed2d=pos_embed2d,\n                image=input_dict['image'],\n                **self.trans_params)\n\n        all_cls_scores = torch.stack(outputs_classes)\n        all_line_preds = torch.stack(outputs_coords)\n        all_line_preds[..., 0] = (all_line_preds[..., 0]\n            * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0])\n        all_line_preds[..., 1] = (all_line_preds[..., 1]\n            * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2])\n\n        # reshape to original format\n        all_line_preds = all_line_preds.view(\n            len(outputs_classes), bs, self.num_query,\n            self.transformer.decoder.num_anchor_per_query,\n            self.transformer.decoder.num_points_per_anchor, 2 + 1 # xz+vis\n        )\n        all_line_preds = all_line_preds.permute(0, 1, 2, 5, 3, 4)\n        all_line_preds = all_line_preds.flatten(3, 5)\n\n        output_dict.update({\n            'all_cls_scores': all_cls_scores,\n            'all_line_preds': all_line_preds,\n        })\n        output_dict.update(sparse_output)\n\n        if is_training:\n            losses = self.get_loss(output_dict, input_dict)\n            project_loss = self.get_project_loss(\n                project_results, input_dict,\n                h=self.gt_project_h, w=self.gt_project_w)\n            losses['project_loss'] = \\\n                self.project_loss_weight * project_loss\n            output_dict.update(losses)\n        return output_dict\n\n    def get_project_loss(self, results, input_dict, h=20, w=30):\n        gt_lane = input_dict['ground_lanes_dense']\n        gt_ys = self.anchor_y_steps_dense.clone()\n        code_size = gt_ys.shape[0]\n        gt_xs = gt_lane[..., :code_size]\n        gt_zs = gt_lane[..., code_size : 2*code_size]\n        gt_vis = gt_lane[..., 2*code_size:3*code_size]\n        gt_ys = gt_ys[None, None, :].expand_as(gt_xs)\n        gt_points = torch.stack([gt_xs, gt_ys, gt_zs], dim=-1)\n\n        B = results[0].shape[0]\n        ref_3d_home = F.pad(gt_points, (0, 1), value=1)\n        coords_img = ground2img(\n            ref_3d_home,\n            h, w,\n            input_dict['lidar2img'],\n            input_dict['pad_shape'], mask=gt_vis)\n\n        all_loss = 0.\n        for projct_result in results:\n            projct_result = F.interpolate(\n                projct_result,\n                size=(h, w),\n                mode='nearest')\n            gt_proj = coords_img.clone()\n\n            mask = (gt_proj[:, -1, ...] > 0) * (projct_result[:, -1, ...] 
> 0)\n            diff_loss = self.project_crit(\n                projct_result[:, :3, ...],\n                gt_proj[:, :3, ...],\n            )\n            diff_y_loss = diff_loss[:, 1, ...]\n            diff_z_loss = diff_loss[:, 2, ...]\n            diff_loss = diff_y_loss * 0.1 + diff_z_loss\n            diff_loss = (diff_loss * mask).sum() / torch.clamp(mask.sum(), 1)\n            all_loss = all_loss + diff_loss\n\n        return all_loss / len(results)\n\n    def get_loss(self, output_dict, input_dict):\n        all_cls_pred = output_dict['all_cls_scores']\n        all_lane_pred = output_dict['all_line_preds']\n        gt_lanes = input_dict['ground_lanes']\n        all_xs_loss = 0.0\n        all_zs_loss = 0.0\n        all_vis_loss = 0.0\n        all_cls_loss = 0.0\n        matched_indices = output_dict['matched_indices']\n        num_layers = all_lane_pred.shape[0]\n\n        def single_layer_loss(layer_idx):\n            gcls_pred = all_cls_pred[layer_idx]\n            glane_pred = all_lane_pred[layer_idx]\n\n            glane_pred = glane_pred.view(\n                glane_pred.shape[0],\n                self.num_group,\n                self.num_query,\n                glane_pred.shape[-1])\n            gcls_pred = gcls_pred.view(\n                gcls_pred.shape[0],\n                self.num_group,\n                self.num_query,\n                gcls_pred.shape[-1])\n\n            per_xs_loss = 0.0\n            per_zs_loss = 0.0\n            per_vis_loss = 0.0\n            per_cls_loss = 0.0\n            batch_size = len(matched_indices[0])\n\n            for b_idx in range(len(matched_indices[0])):\n                for group_idx in range(self.num_group):\n                    pred_idx = matched_indices[group_idx][b_idx][0]\n                    gt_idx = matched_indices[group_idx][b_idx][1]\n\n                    cls_pred = gcls_pred[:, group_idx, ...]\n                    lane_pred = glane_pred[:, group_idx, ...]\n\n                    if gt_idx.shape[0] < 1:\n                        cls_target = cls_pred.new_zeros(cls_pred[b_idx].shape[0]).long()\n                        cls_loss = self.cls_crit(cls_pred[b_idx], cls_target)\n                        per_cls_loss = per_cls_loss + cls_loss\n                        per_xs_loss = per_xs_loss + 0.0 * lane_pred[b_idx].mean()\n                        continue\n\n                    pos_lane_pred = lane_pred[b_idx][pred_idx]\n                    gt_lane = gt_lanes[b_idx][gt_idx]\n\n                    pred_xs = pos_lane_pred[:, :self.code_size]\n                    pred_zs = pos_lane_pred[:, self.code_size : 2*self.code_size]\n                    pred_vis = pos_lane_pred[:, 2*self.code_size:]\n                    gt_xs = gt_lane[:, :self.code_size]\n                    gt_zs = gt_lane[:, self.code_size : 2*self.code_size]\n                    gt_vis = gt_lane[:, 2*self.code_size:3*self.code_size]\n\n                    loc_mask = gt_vis > 0\n                    xs_loss = self.reg_crit(pred_xs, gt_xs)\n                    zs_loss = self.reg_crit(pred_zs, gt_zs)\n                    xs_loss = (xs_loss * loc_mask).sum() / torch.clamp(loc_mask.sum(), 1)\n                    zs_loss = (zs_loss * loc_mask).sum() / torch.clamp(loc_mask.sum(), 1)\n                    vis_loss = self.bce_loss(pred_vis, gt_vis)\n\n                    cls_target = cls_pred.new_zeros(cls_pred[b_idx].shape[0]).long()\n                    cls_target[pred_idx] = torch.argmax(\n                        gt_lane[:, 3*self.code_size:], dim=1)\n                    cls_loss 
= self.cls_crit(cls_pred[b_idx], cls_target)\n\n                    per_xs_loss += xs_loss\n                    per_zs_loss += zs_loss\n                    per_vis_loss += vis_loss\n                    per_cls_loss += cls_loss\n\n            return tuple(map(lambda x: x / batch_size / self.num_group,\n                             [per_xs_loss, per_zs_loss, per_vis_loss, per_cls_loss]))\n\n        all_xs_loss, all_zs_loss, all_vis_loss, all_cls_loss = multi_apply(\n            single_layer_loss, range(all_lane_pred.shape[0]))\n        all_xs_loss = sum(all_xs_loss) / num_layers\n        all_zs_loss = sum(all_zs_loss) / num_layers\n        all_vis_loss = sum(all_vis_loss) / num_layers\n        all_cls_loss = sum(all_cls_loss) / num_layers\n\n        return dict(\n            all_xs_loss=self.xs_loss_weight * all_xs_loss,\n            all_zs_loss=self.zs_loss_weight * all_zs_loss,\n            all_vis_loss=self.vis_loss_weight * all_vis_loss,\n            all_cls_loss=self.cls_loss_weight * all_cls_loss,\n        )\n\n    @staticmethod\n    def get_reference_points(H, W, bs=1, device='cuda', dtype=torch.float):\n        ref_y, ref_x = torch.meshgrid(\n            torch.linspace(\n                0.5, H - 0.5, H, dtype=dtype, device=device),\n            torch.linspace(\n                0.5, W - 0.5, W, dtype=dtype, device=device)\n        )\n        ref_y = ref_y.reshape(-1)[None] / H\n        ref_x = ref_x.reshape(-1)[None] / W\n        ref_2d = torch.stack((ref_x, ref_y), -1)\n        ref_2d = ref_2d.repeat(bs, 1, 1) \n        return ref_2d\n\ndef build_nn_loss(loss_cfg):\n    crit_t = loss_cfg.pop('type')\n    return getattr(nn, crit_t)(**loss_cfg)"
  },
  {
    "path": "models/ms2one.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport copy\nfrom mmcv.cnn import ConvModule\nfrom mmseg.ops import resize\n\n\ndef build_ms2one(config):\n    config = copy.deepcopy(config)\n    t = config.pop('type')\n    if t == 'Naive':\n        return Naive(**config)\n    elif t == 'DilateNaive':\n        return DilateNaive(**config)\n\n\nclass Naive(nn.Module):\n    def __init__(self, inc, outc, kernel_size=1):\n        super().__init__()\n        self.layer = nn.Conv2d(inc, outc, kernel_size=1)\n\n    def forward(self, ms_feats):\n        out = self.layer(torch.cat([\n            F.interpolate(tmp, ms_feats[0].shape[-2:],\n                          mode='bilinear') for tmp in ms_feats], dim=1))\n        return out\n\n\nclass DilateNaive(nn.Module):\n    def __init__(self, inc, outc, num_scales=4,\n                 dilations=(1, 2, 5, 9),\n                 merge=True, fpn=False,\n                 target_shape=None,\n                 one_layer_before=False):\n        super().__init__()\n        self.dilations = dilations\n        self.num_scales = num_scales\n        if not isinstance(inc, (tuple, list)):\n            inc = [inc for _ in range(num_scales)]\n        self.inc = inc\n        self.outc = outc\n        self.merge = merge\n        self.fpn = fpn\n        self.target_shape = target_shape\n        self.layers = nn.ModuleList()\n        for i in range(num_scales):\n            layers = []\n            if one_layer_before:\n                layers.extend([\n                    nn.Conv2d(inc[i], outc, kernel_size=1, bias=False),\n                    nn.BatchNorm2d(outc),\n                    nn.ReLU(True)\n                ])\n            for j in range(len(dilations[:-i])):\n                d = dilations[j]\n                layers.append(nn.Sequential(\n                    nn.Conv2d(inc[i] if j == 0 and not one_layer_before else outc, outc,\n                              kernel_size=1 if d == 1 else 3,\n                              stride=1,\n                              padding=0 if d == 1 else d,\n                              dilation=d,\n                              bias=False),\n                    nn.BatchNorm2d(outc),\n                    nn.ReLU(True)))\n            self.layers.append(nn.Sequential(*layers))\n        if self.merge:\n            self.final_layer = nn.Sequential(\n                nn.Conv2d(outc, outc, 3, 1, padding=1, bias=False),\n                nn.BatchNorm2d(outc),\n                nn.ReLU(True),\n                nn.Conv2d(outc, outc, 1))\n\n    def forward(self, x):\n        outs = []\n\n        for i in range(self.num_scales - 1, -1, -1):\n            if self.fpn and i < self.num_scales - 1:\n                tmp = self.layers[i](x[i] + F.interpolate(\n                    x[i + 1], x[i].shape[2:],\n                    mode='bilinear', align_corners=True))\n            else:\n                tmp = self.layers[i](x[i])\n\n            if self.target_shape is None:\n                if i > 0 and self.merge:\n                    tmp = F.interpolate(tmp, x[0].shape[2:],\n                        mode='bilinear', align_corners=True)\n            else:\n                tmp = F.interpolate(tmp, self.target_shape,\n                        mode='bilinear', align_corners=True)\n            outs.append(tmp)\n        if self.merge:\n            out = torch.sum(torch.stack(outs, dim=-1), dim=-1)\n            out = self.final_layer(out)\n            \n            return out\n        else:\n            return outs"
  },
  {
    "path": "models/scatter_utils.py",
    "content": "# Copy from https://github.com/rusty1s/pytorch_scatter\n\nfrom typing import Optional, Tuple\n\nimport torch\n\n\ndef broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):\n    if dim < 0:\n        dim = other.dim() + dim\n    if src.dim() == 1:\n        for _ in range(0, dim):\n            src = src.unsqueeze(0)\n    for _ in range(src.dim(), other.dim()):\n        src = src.unsqueeze(-1)\n    src = src.expand(other.size())\n    return src\n\n\ndef scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,\n                out: Optional[torch.Tensor] = None,\n                dim_size: Optional[int] = None) -> torch.Tensor:\n    index = broadcast(index, src, dim)\n    if out is None:\n        size = list(src.size())\n        if dim_size is not None:\n            size[dim] = dim_size\n        elif index.numel() == 0:\n            size[dim] = 0\n        else:\n            size[dim] = int(index.max()) + 1\n        out = torch.zeros(size, dtype=src.dtype, device=src.device)\n        return out.scatter_add_(dim, index, src)\n    else:\n        return out.scatter_add_(dim, index, src)\n\n\ndef scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,\n                out: Optional[torch.Tensor] = None,\n                dim_size: Optional[int] = None) -> torch.Tensor:\n    return scatter_sum(src, index, dim, out, dim_size)\n\n\ndef scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,\n                out: Optional[torch.Tensor] = None,\n                dim_size: Optional[int] = None) -> torch.Tensor:\n    return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)\n\n\ndef scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,\n                 out: Optional[torch.Tensor] = None,\n                 dim_size: Optional[int] = None) -> torch.Tensor:\n    out = scatter_sum(src, index, dim, out, dim_size)\n    dim_size = out.size(dim)\n\n    index_dim = dim\n    if index_dim < 0:\n        index_dim = index_dim + src.dim()\n    if index.dim() <= index_dim:\n        index_dim = index.dim() - 1\n\n    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)\n    count = scatter_sum(ones, index, index_dim, None, dim_size)\n    count[count < 1] = 1\n    count = broadcast(count, out, dim)\n    if out.is_floating_point():\n        out.true_divide_(count)\n    else:\n        out.div_(count, rounding_mode='floor')\n    return out\n\n\ndef scatter_min(\n        src: torch.Tensor, index: torch.Tensor, dim: int = -1,\n        out: Optional[torch.Tensor] = None,\n        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:\n    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)\n\n\ndef scatter_max(\n        src: torch.Tensor, index: torch.Tensor, dim: int = -1,\n        out: Optional[torch.Tensor] = None,\n        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:\n    return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)\n\n\ndef scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,\n            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,\n            reduce: str = \"sum\") -> torch.Tensor:\n    r\"\"\"\n    |\n\n    .. 
image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/\n            master/docs/source/_figures/add.svg?sanitize=true\n        :align: center\n        :width: 400px\n\n    |\n\n    Reduces all values from the :attr:`src` tensor into :attr:`out` at the\n    indices specified in the :attr:`index` tensor along a given axis\n    :attr:`dim`.\n    For each value in :attr:`src`, its output index is specified by its index\n    in :attr:`src` for dimensions outside of :attr:`dim` and by the\n    corresponding value in :attr:`index` for dimension :attr:`dim`.\n    The applied reduction is defined via the :attr:`reduce` argument.\n\n    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional\n    tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`\n    and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional\n    tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.\n    Moreover, the values of :attr:`index` must be between :math:`0` and\n    :math:`y - 1`, although no specific ordering of indices is required.\n    The :attr:`index` tensor supports broadcasting in case its dimensions do\n    not match with :attr:`src`.\n\n    For one-dimensional tensors with :obj:`reduce=\"sum\"`, the operation\n    computes\n\n    .. math::\n        \\mathrm{out}_i = \\mathrm{out}_i + \\sum_j~\\mathrm{src}_j\n\n    where :math:`\\sum_j` is over :math:`j` such that\n    :math:`\\mathrm{index}_j = i`.\n\n    .. note::\n\n        This operation is implemented via atomic operations on the GPU and is\n        therefore **non-deterministic** since the order of parallel operations\n        to the same value is undetermined.\n        For floating-point variables, this results in a source of variance in\n        the result.\n\n    :param src: The source tensor.\n    :param index: The indices of elements to scatter.\n    :param dim: The axis along which to index. (default: :obj:`-1`)\n    :param out: The destination tensor.\n    :param dim_size: If :attr:`out` is not given, automatically create output\n        with size :attr:`dim_size` at dimension :attr:`dim`.\n        If :attr:`dim_size` is not given, a minimal sized output tensor\n        according to :obj:`index.max() + 1` is returned.\n    :param reduce: The reduce operation (:obj:`\"sum\"`, :obj:`\"mul\"`,\n        :obj:`\"mean\"`, :obj:`\"min\"` or :obj:`\"max\"`). (default: :obj:`\"sum\"`)\n\n    :rtype: :class:`Tensor`\n\n    .. code-block:: python\n\n        from torch_scatter import scatter\n\n        src = torch.randn(10, 6, 64)\n        index = torch.tensor([0, 1, 0, 1, 2, 1])\n\n        # Broadcasting in the first and last dim.\n        out = scatter(src, index, dim=1, reduce=\"sum\")\n\n        print(out.size())\n\n    .. code-block::\n\n        torch.Size([10, 3, 64])\n    \"\"\"\n    if reduce == 'sum' or reduce == 'add':\n        return scatter_sum(src, index, dim, out, dim_size)\n    if reduce == 'mul':\n        return scatter_mul(src, index, dim, out, dim_size)\n    elif reduce == 'mean':\n        return scatter_mean(src, index, dim, out, dim_size)\n    elif reduce == 'min':\n        return scatter_min(src, index, dim, out, dim_size)[0]\n    elif reduce == 'max':\n        return scatter_max(src, index, dim, out, dim_size)[0]\n    else:\n        raise ValueError"
  },
  {
    "path": "models/sparse_ins.py",
    "content": "import math\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport torch.nn.functional as F\nfrom fvcore.nn.weight_init import c2_msra_fill, c2_xavier_fill\nfrom .sparse_inst_loss import SparseInstCriterion, SparseInstMatcher\n\ndef _make_stack_3x3_convs(num_convs, in_channels, out_channels):\n    convs = []\n    for _ in range(num_convs):\n        convs.append(\n            nn.Conv2d(in_channels, out_channels, 3, padding=1))\n        convs.append(nn.ReLU(True))\n        in_channels = out_channels\n    return nn.Sequential(*convs)\n\n\nclass MaskBranch(nn.Module):\n    def __init__(self, cfg, in_channels):\n        super().__init__()\n        dim = cfg.hidden_dim\n        num_convs = cfg.num_convs\n        kernel_dim = cfg.kernel_dim\n        self.mask_convs = _make_stack_3x3_convs(num_convs, in_channels, dim)\n        self.projection = nn.Conv2d(dim, kernel_dim, kernel_size=1)\n        self._init_weights()\n\n    def _init_weights(self):\n        for m in self.mask_convs.modules():\n            if isinstance(m, nn.Conv2d):\n                c2_msra_fill(m)\n        c2_msra_fill(self.projection)\n\n    def forward(self, features):\n        # mask features (x4 convs)\n        features = self.mask_convs(features)\n        return self.projection(features)\n\n\nclass InstanceBranch(nn.Module):\n    def __init__(self, cfg, in_channels, **kwargs):\n        super().__init__()\n        num_mask = cfg.num_query\n        dim = cfg.hidden_dim\n        num_classes = cfg.num_classes\n        kernel_dim = cfg.kernel_dim\n        num_convs = cfg.num_convs\n        num_group = cfg.get('num_group', 1)\n        sparse_num_group = cfg.get('sparse_num_group', 1)\n        self.num_group = num_group\n        self.sparse_num_group = sparse_num_group\n        self.num_mask = num_mask\n        self.inst_convs = _make_stack_3x3_convs(\n                            num_convs=num_convs, \n                            in_channels=in_channels, \n                            out_channels=dim)\n\n        self.iam_conv = nn.Conv2d(\n            dim * num_group,\n            num_group * num_mask * sparse_num_group,\n            3, padding=1, groups=num_group * sparse_num_group)\n        self.fc = nn.Linear(dim * sparse_num_group, dim)\n        # output\n        self.mask_kernel = nn.Linear(\n            dim, kernel_dim)\n        self.cls_score = nn.Linear(\n            dim, num_classes)\n        self.objectness = nn.Linear(\n            dim, 1)\n        self.prior_prob = 0.01\n        self._init_weights()\n\n    def _init_weights(self):\n        for m in self.inst_convs.modules():\n            if isinstance(m, nn.Conv2d):\n                c2_msra_fill(m)\n        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)\n        for module in [self.iam_conv, self.cls_score]:\n            init.constant_(module.bias, bias_value)\n        init.normal_(self.iam_conv.weight, std=0.01)\n        init.normal_(self.cls_score.weight, std=0.01)\n\n        init.normal_(self.mask_kernel.weight, std=0.01)\n        init.constant_(self.mask_kernel.bias, 0.0)\n        c2_xavier_fill(self.fc)\n\n    def forward(self, seg_features, is_training=True):\n        out = {}\n        # SparseInst part\n        seg_features = self.inst_convs(seg_features)\n        # predict instance activation maps\n        iam = self.iam_conv(seg_features.tile(\n            (1, self.num_group, 1, 1)))\n        if not is_training:\n            iam = iam.view(\n                iam.shape[0],\n                self.num_group,\n       
         self.num_mask * self.sparse_num_group,\n                *iam.shape[-2:])\n            iam = iam[:, 0, ...]\n            num_group = 1\n        else:\n            num_group = self.num_group\n\n        iam_prob = iam.sigmoid()\n        B, N = iam_prob.shape[:2]\n        C = seg_features.size(1)\n        # BxNxHxW -> BxNx(HW)\n        iam_prob = iam_prob.view(B, N, -1)\n        normalizer = iam_prob.sum(-1).clamp(min=1e-6)\n        iam_prob_norm_hw = iam_prob / normalizer[:, :, None]\n\n        # aggregate features: BxCxHxW -> Bx(HW)xC\n        # (B x N x HW) @ (B x HW x C) -> B x N x C\n        all_inst_features = torch.bmm(\n            iam_prob_norm_hw,\n            seg_features.view(B, C, -1).permute(0, 2, 1)) #BxNxC\n\n        # concat sparse group features\n        inst_features = all_inst_features.reshape(\n            B, num_group,\n            self.sparse_num_group,\n            self.num_mask, -1\n        ).permute(0, 1, 3, 2, 4).reshape(\n            B, num_group,\n            self.num_mask, -1)\n        inst_features = F.relu_(\n            self.fc(inst_features))\n\n        # avg over sparse group\n        iam_prob = iam_prob.view(\n            B, num_group,\n            self.sparse_num_group,\n            self.num_mask,\n            iam_prob.shape[-1])\n        iam_prob = iam_prob.mean(dim=2).flatten(1, 2)\n        inst_features = inst_features.flatten(1, 2)\n        out.update(dict(\n            iam_prob=iam_prob,\n            inst_features=inst_features))\n\n        if self.training:\n            pred_logits = self.cls_score(inst_features)\n            pred_kernel = self.mask_kernel(inst_features)\n            pred_scores = self.objectness(inst_features)\n            out.update(dict(\n                pred_logits=pred_logits,\n                pred_kernel=pred_kernel,\n                pred_scores=pred_scores))\n        return out\n\nclass SparseInsDecoder(nn.Module):\n    def __init__(self, cfg, **kargs) -> None:\n        super().__init__()\n        in_channels = cfg.encoder.out_dims + 2\n        self.output_iam = cfg.decoder.output_iam\n        self.scale_factor = cfg.decoder.scale_factor\n        self.sparse_decoder_weight = cfg.sparse_decoder_weight\n        self.inst_branch = InstanceBranch(cfg.decoder, in_channels)\n        # dim, num_convs, kernel_dim, in_channels\n        self.mask_branch = MaskBranch(cfg.decoder, in_channels)\n        self.sparse_inst_crit = SparseInstCriterion(\n            num_classes=cfg.decoder.num_classes,\n            matcher=SparseInstMatcher(),\n            cfg=cfg)\n        self._init_weights()\n\n    def _init_weights(self):\n        self.inst_branch._init_weights()\n        self.mask_branch._init_weights()\n\n    @torch.no_grad()\n    def compute_coordinates(self, x):\n        h, w = x.size(2), x.size(3)\n        y_loc = -1.0 + 2.0 * torch.arange(h, device=x.device) / (h - 1)\n        x_loc = -1.0 + 2.0 * torch.arange(w, device=x.device) / (w - 1)\n        y_loc, x_loc = torch.meshgrid(y_loc, x_loc)\n        y_loc = y_loc.expand([x.shape[0], 1, -1, -1])\n        x_loc = x_loc.expand([x.shape[0], 1, -1, -1])\n        locations = torch.cat([x_loc, y_loc], 1)\n        return locations.to(x)\n\n    def forward(self, features, is_training=True, **kwargs):\n        output = {}\n        coord_features = self.compute_coordinates(features)\n        features = torch.cat([coord_features, features], dim=1)\n        inst_output = self.inst_branch(\n            features, is_training=is_training)\n        output.update(inst_output)\n\n        if 
is_training:\n            mask_features = self.mask_branch(features)\n            pred_kernel = inst_output['pred_kernel']\n            N = pred_kernel.shape[1]\n            B, C, H, W = mask_features.shape\n\n            pred_masks = torch.bmm(pred_kernel, mask_features.view(\n                B, C, H * W)).view(B, N, H, W)\n            pred_masks = F.interpolate(\n                pred_masks, scale_factor=self.scale_factor,\n                mode='bilinear', align_corners=False)\n            output.update(dict(\n                pred_masks=pred_masks))\n\n        if self.training:\n            sparse_inst_losses, matched_indices = self.loss(\n                    output,\n                    lane_idx_map=kwargs.get('lane_idx_map'),\n                    input_shape=kwargs.get('input_shape')\n            )\n            for k, v in sparse_inst_losses.items():\n                sparse_inst_losses[k] = self.sparse_decoder_weight * v\n            output.update(sparse_inst_losses)\n            output['matched_indices'] = matched_indices\n        return output\n\n    def loss(self, output, lane_idx_map, input_shape):\n        \"\"\"\n        output : from self.forward\n        lane_idx_map : instance-level segmentation map, [20, H, W] where 20=max_lanes\n        \"\"\"\n        pred_masks = output['pred_masks']\n        pred_masks = output['pred_masks'].view(\n            pred_masks.shape[0],\n            self.inst_branch.num_group,\n            self.inst_branch.num_mask,\n            *pred_masks.shape[2:])\n        pred_logits = output['pred_logits']\n        pred_logits = output['pred_logits'].view(\n            pred_logits.shape[0],\n            self.inst_branch.num_group,\n            self.inst_branch.num_mask,\n            *pred_logits.shape[2:])\n        pred_scores = output['pred_scores']\n        pred_scores = output['pred_scores'].view(\n            pred_scores.shape[0],\n            self.inst_branch.num_group,\n            self.inst_branch.num_mask,\n            *pred_scores.shape[2:])\n\n        out = {}\n        all_matched_indices = []\n        for group_idx in range(self.inst_branch.num_group):\n            sparse_inst_losses, matched_indices = \\\n                self.sparse_inst_crit(\n                    outputs=dict(\n                        pred_masks=pred_masks[:, group_idx, ...].contiguous(),\n                        pred_logits=pred_logits[:, group_idx, ...].contiguous(),\n                        pred_scores=pred_scores[:, group_idx, ...].contiguous(),\n                    ),\n                    targets=self.prepare_targets(lane_idx_map),\n                    input_shape=input_shape, # seg_bev\n                )\n            for k, v in sparse_inst_losses.items():\n                out['%s_%d' % (k, group_idx)] = v\n            all_matched_indices.append(matched_indices)\n        return out, all_matched_indices\n\n    def prepare_targets(self, targets):\n        new_targets = []\n        for targets_per_image in targets:\n            target = {}\n            cls_labels = targets_per_image.flatten(-2).max(-1)[0]\n            pos_mask = cls_labels > 0\n\n            target[\"labels\"] = cls_labels[pos_mask].long()\n            target[\"masks\"] = targets_per_image[pos_mask] > 0\n            new_targets.append(target)\n        return new_targets\n"
  },
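  {
    "path": "examples/sparse_inst_decoder_demo.py",
    "content": "\"\"\"Illustrative usage sketch for `SparseInsDecoder` (not part of the original LATR\nrelease). The config keys mirror what models/sparse_inst.py reads; every concrete\nvalue below (dims, query count, feature size) is an assumption chosen only so the\nsnippet runs end to end on CPU.\"\"\"\nimport torch\nfrom mmcv import Config\n\nfrom models.sparse_inst import SparseInsDecoder\n\ncfg = Config(dict(\n    sparse_decoder_weight=1.0,\n    encoder=dict(out_dims=64),\n    decoder=dict(\n        hidden_dim=64, kernel_dim=32, num_convs=2,\n        num_query=20, num_classes=21,\n        output_iam=False, scale_factor=2.0),\n))\n\ndecoder = SparseInsDecoder(cfg).eval()\n# BxCxHxW BEV features; two coordinate channels are concatenated inside forward()\nfeats = torch.randn(2, 64, 45, 60)\nwith torch.no_grad():\n    out = decoder(feats, is_training=False)\n# iam_prob: (B, num_query, H*W) instance activation maps; inst_features: (B, num_query, hidden_dim)\nprint(out['iam_prob'].shape, out['inst_features'].shape)\n"
  },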
  {
    "path": "models/sparse_inst_loss.py",
    "content": "# Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.cuda.amp import autocast\nfrom scipy.optimize import linear_sum_assignment\nfrom fvcore.nn import sigmoid_focal_loss_jit\n\nfrom typing import Optional, List\n\nimport torch\nfrom torch import Tensor\nimport torch.distributed as dist\nimport torch.nn.functional as F\nimport torchvision\n\n\ndef _max_by_axis(the_list):\n    # type: (List[List[int]]) -> List[int]\n    maxes = the_list[0]\n    for sublist in the_list[1:]:\n        for index, item in enumerate(sublist):\n            maxes[index] = max(maxes[index], item)\n    return maxes\n\n\nclass NestedTensor(object):\n    def __init__(self, tensors, mask: Optional[Tensor]):\n        self.tensors = tensors\n        self.mask = mask\n\n    def to(self, device):\n        cast_tensor = self.tensors.to(device)\n        mask = self.mask\n        if mask is not None:\n            assert mask is not None\n            cast_mask = mask.to(device)\n        else:\n            cast_mask = None\n        return NestedTensor(cast_tensor, cast_mask)\n\n    def decompose(self):\n        return self.tensors, self.mask\n\n    def __repr__(self):\n        return str(self.tensors)\n\n# _onnx_nested_tensor_from_tensor_list() is an implementation of\n# nested_tensor_from_tensor_list() that is supported by ONNX tracing.\n\n\n@torch.jit.unused\ndef _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:\n    max_size = []\n    for i in range(tensor_list[0].dim()):\n        max_size_i = torch.max(torch.stack([img.shape[i]\n                                            for img in tensor_list]).to(torch.float32)).to(torch.int64)\n        max_size.append(max_size_i)\n    max_size = tuple(max_size)\n\n    # work around for\n    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)\n    # m[: img.shape[1], :img.shape[2]] = False\n    # which is not yet supported in onnx\n    padded_imgs = []\n    padded_masks = []\n    for img in tensor_list:\n        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]\n        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))\n        padded_imgs.append(padded_img)\n\n        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)\n        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), \"constant\", 1)\n        padded_masks.append(padded_mask.to(torch.bool))\n\n    tensor = torch.stack(padded_imgs)\n    mask = torch.stack(padded_masks)\n\n    return NestedTensor(tensor, mask=mask)\n\n\ndef nested_tensor_from_tensor_list(tensor_list: List[Tensor]):\n    # TODO make this more general\n    if tensor_list[0].ndim == 3:\n        if torchvision._is_tracing():\n            # nested_tensor_from_tensor_list() does not export well to ONNX\n            # call _onnx_nested_tensor_from_tensor_list() instead\n            return _onnx_nested_tensor_from_tensor_list(tensor_list)\n\n        # TODO make it support different-sized images\n        max_size = _max_by_axis([list(img.shape) for img in tensor_list])\n        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))\n        batch_shape = [len(tensor_list)] + max_size\n        b, c, h, w = batch_shape\n        dtype = tensor_list[0].dtype\n        device = tensor_list[0].device\n        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)\n        mask = torch.ones((b, h, 
w), dtype=torch.bool, device=device)\n        for img, pad_img, m in zip(tensor_list, tensor, mask):\n            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)\n            m[: img.shape[1], :img.shape[2]] = False\n    else:\n        raise ValueError('not supported')\n    return NestedTensor(tensor, mask)\n\n\ndef nested_masks_from_list(tensor_list: List[Tensor], input_shape=None):\n    if tensor_list[0].ndim == 3:\n        dim_size = sum([img.shape[0] for img in tensor_list])\n        if input_shape is None:\n            max_size = _max_by_axis([list(img.shape[-2:]) for img in tensor_list])\n        else:\n            max_size = [input_shape[0], input_shape[1]]\n        batch_shape = [dim_size] + max_size\n        # b, h, w = batch_shape\n        dtype = tensor_list[0].dtype\n        device = tensor_list[0].device\n        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)\n        mask = torch.zeros(batch_shape, dtype=torch.bool, device=device)\n        idx = 0\n        for img in tensor_list:\n            c = img.shape[0]\n            c_ = idx + c\n            tensor[idx: c_, :img.shape[1], : img.shape[2]].copy_(img)\n            mask[idx: c_, :img.shape[1], :img.shape[2]] = True\n            idx = c_\n    else:\n        raise ValueError('not supported')\n    return NestedTensor(tensor, mask)\n\n\ndef is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\ndef get_world_size():\n    if not is_dist_avail_and_initialized():\n        return 1\n    return dist.get_world_size()\n\n\ndef aligned_bilinear(tensor, factor):\n    # borrowed from AdelaiDet: https://github1s.com/aim-uofa/AdelaiDet/blob/HEAD/adet/utils/comm.py\n    assert tensor.dim() == 4\n    assert factor >= 1\n    assert int(factor) == factor\n\n    if factor == 1:\n        return tensor\n\n    h, w = tensor.size()[2:]\n    tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode=\"replicate\")\n    oh = factor * h + 1\n    ow = factor * w + 1\n    tensor = F.interpolate(\n        tensor, size=(oh, ow),\n        mode='bilinear',\n        align_corners=True\n    )\n    tensor = F.pad(\n        tensor, pad=(factor // 2, 0, factor // 2, 0),\n        mode=\"replicate\"\n    )\n\n    return tensor[:, :, :oh - 1, :ow - 1]\n\n\ndef compute_mask_iou(inputs, targets):\n    inputs = inputs.sigmoid()\n    # thresholding\n    binarized_inputs = (inputs >= 0.4).float()\n    targets = (targets > 0.5).float()\n    intersection = (binarized_inputs * targets).sum(-1)\n    union = targets.sum(-1) + binarized_inputs.sum(-1) - intersection\n    score = intersection / (union + 1e-6)\n    return score\n\n\ndef dice_score(inputs, targets):\n    inputs = inputs.sigmoid()\n    numerator = 2 * torch.matmul(inputs, targets.t())\n    denominator = (\n        inputs * inputs).sum(-1)[:, None] + (targets * targets).sum(-1)\n    score = numerator / (denominator + 1e-4)\n    return score\n\n\ndef dice_loss(inputs, targets, reduction='sum'):\n    inputs = inputs.sigmoid()\n    assert inputs.shape == targets.shape\n    numerator = 2 * (inputs * targets).sum(1)\n    denominator = (inputs * inputs).sum(-1) + (targets * targets).sum(-1)\n    loss = 1 - (numerator) / (denominator + 1e-4)\n    if reduction == 'none':\n        return loss\n    return loss.sum()\n\n\n# @SPARSE_INST_CRITERION_REGISTRY.register()\nclass SparseInstCriterion(nn.Module):\n    # This part is partially derived from: 
https://github.com/facebookresearch/detr/blob/main/models/detr.py\n\n    def __init__(self, num_classes=4, cfg=None, matcher=None):\n        super().__init__()\n        self.matcher = matcher\n        self.losses = (\"labels\", \"masks\") # cfg.MODEL.SPARSE_INST.LOSS.ITEMS\n        self.weight_dict = self.get_weight_dict(cfg)\n        self.num_classes = num_classes # cfg.MODEL.SPARSE_INST.DECODER.NUM_CLASSES\n\n    def get_weight_dict(self, cfg):\n        losses = (\"loss_ce\", \"loss_mask\", \"loss_dice\", \"loss_objectness\")\n        weight_dict = {}\n\n        ce_weight = cfg.get('ce_weight', 2.0)\n        mask_weight = cfg.get('mask_weight', 5.0)\n        dice_weight = cfg.get('dice_weight', 2.0)\n        objectness_weight = cfg.get('objectness_weight', 1.0)\n        weight_dict = dict(\n            zip(losses, (ce_weight, mask_weight, dice_weight, objectness_weight)))\n        return weight_dict\n\n    def _get_src_permutation_idx(self, indices):\n        # permute predictions following indices\n        batch_idx = torch.cat([torch.full_like(src, i)\n                              for i, (src, _) in enumerate(indices)])\n        src_idx = torch.cat([src for (src, _) in indices])\n        return batch_idx, src_idx\n\n    def _get_tgt_permutation_idx(self, indices):\n        # permute targets following indices\n        batch_idx = torch.cat([torch.full_like(tgt, i)\n                              for i, (_, tgt) in enumerate(indices)])\n        tgt_idx = torch.cat([tgt for (_, tgt) in indices])\n        return batch_idx, tgt_idx\n\n    def loss_labels(self, outputs, targets, indices, num_instances, input_shape=None):\n        assert \"pred_logits\" in outputs\n        src_logits = outputs['pred_logits']\n        target_classes = torch.full(src_logits.shape[:2], 0, # self.num_classes,\n                                    dtype=torch.int64, device=src_logits.device)\n        if sum([tmp[0].shape[0] for tmp in indices]) > 0:\n            idx = self._get_src_permutation_idx(indices)\n            target_classes_o = torch.cat([t[\"labels\"][J]\n                                         for t, (_, J) in zip(targets, indices)])\n            target_classes[idx] = target_classes_o\n\n        src_logits = src_logits.flatten(0, 1)\n        # prepare one_hot target.\n        target_classes = target_classes.flatten(0, 1)\n        pos_inds = torch.nonzero(\n            target_classes != self.num_classes, as_tuple=True)[0]\n        labels = torch.zeros_like(src_logits)\n        labels[pos_inds, target_classes[pos_inds]] = 1\n        # comp focal loss.\n        class_loss = sigmoid_focal_loss_jit(\n            src_logits,\n            labels,\n            alpha=0.25,\n            gamma=2.0,\n            reduction=\"sum\",\n        ) / num_instances\n        losses = {'loss_ce': class_loss}\n        return losses\n\n    def loss_masks_with_iou_objectness(self, outputs, targets, indices, num_instances, input_shape):\n        src_idx = self._get_src_permutation_idx(indices)\n        tgt_idx = self._get_tgt_permutation_idx(indices)\n        # Bx100xHxW\n        assert \"pred_masks\" in outputs\n        assert \"pred_scores\" in outputs\n        src_iou_scores = outputs[\"pred_scores\"]\n        src_masks = outputs[\"pred_masks\"]\n        with torch.no_grad():\n            target_masks, _ = nested_masks_from_list(\n                [t[\"masks\"] for t in targets], input_shape).decompose()\n        num_masks = [len(t[\"masks\"]) for t in targets]\n        target_masks = target_masks.to(src_masks)\n        if 
len(target_masks) == 0:\n            losses = {\n                \"loss_dice\": src_masks.sum() * 0.0,\n                \"loss_mask\": src_masks.sum() * 0.0,\n                \"loss_objectness\": src_iou_scores.sum() * 0.0\n            }\n            return losses\n\n        src_masks = src_masks[src_idx]\n        target_masks = F.interpolate(\n            target_masks[:, None], size=src_masks.shape[-2:], mode='bilinear', align_corners=False).squeeze(1)\n\n        src_masks = src_masks.flatten(1)\n        # FIXME: tgt_idx\n        mix_tgt_idx = torch.zeros_like(tgt_idx[1])\n        cum_sum = 0\n        for num_mask in num_masks:\n            mix_tgt_idx[cum_sum: cum_sum + num_mask] = cum_sum\n            cum_sum += num_mask\n        mix_tgt_idx += tgt_idx[1]\n\n        target_masks = target_masks[mix_tgt_idx].flatten(1)\n\n        with torch.no_grad():\n            ious = compute_mask_iou(src_masks, target_masks)\n\n        tgt_iou_scores = ious\n        src_iou_scores = src_iou_scores[src_idx]\n        tgt_iou_scores = tgt_iou_scores.flatten(0)\n        src_iou_scores = src_iou_scores.flatten(0)\n\n        losses = {\n            \"loss_objectness\": F.binary_cross_entropy_with_logits(src_iou_scores, tgt_iou_scores, reduction='mean'),\n            \"loss_dice\": dice_loss(src_masks, target_masks) / num_instances,\n            \"loss_mask\": F.binary_cross_entropy_with_logits(src_masks, target_masks, reduction='mean')\n        }\n        return losses\n\n    def get_loss(self, loss, outputs, targets, indices, num_instances, **kwargs):\n        loss_map = {\n            \"labels\": self.loss_labels,\n            \"masks\": self.loss_masks_with_iou_objectness,\n        }\n        if loss == \"loss_objectness\":\n            # NOTE: loss_objectness will be calculated in `loss_masks_with_iou_objectness`\n            return {}\n        assert loss in loss_map\n        return loss_map[loss](outputs, targets, indices, num_instances, **kwargs)\n\n    def forward(self, outputs, targets, input_shape):\n\n        outputs_without_aux = {k: v for k,\n                               v in outputs.items() if k != 'aux_outputs'}\n\n        # Retrieve the matching between the outputs of the last layer and the targets\n        indices = self.matcher(outputs_without_aux, targets, input_shape)\n        # Compute the average number of target boxes across all nodes, for normalization purposes\n        num_instances = sum(len(t[\"labels\"]) for t in targets)\n        num_instances = torch.as_tensor(\n            [num_instances], dtype=torch.float, device=next(iter(outputs.values())).device)\n        if is_dist_avail_and_initialized():\n            torch.distributed.all_reduce(num_instances)\n        num_instances = torch.clamp(\n            num_instances / get_world_size(), min=1).item()\n        # Compute all the requested losses\n        losses = {}\n        for loss in self.losses:\n            losses.update(self.get_loss(loss, outputs, targets, indices,\n                                        num_instances, input_shape=input_shape))\n\n        for k in losses.keys():\n            if k in self.weight_dict:\n                losses[k] *= self.weight_dict[k]\n        return losses, indices\n\n\n# @SPARSE_INST_MATCHER_REGISTRY.register()\nclass SparseInstMatcherV1(nn.Module):\n\n    def __init__(self, cfg=None):\n        super().__init__()\n        self.alpha = 0.8 # cfg.MODEL.SPARSE_INST.MATCHER.ALPHA\n       
 self.beta = 0.2 # cfg.MODEL.SPARSE_INST.MATCHER.BETA\n        self.mask_score = dice_score\n\n    @torch.no_grad()\n    def forward(self, outputs, targets, input_shape):\n        B, N, H, W = outputs[\"pred_masks\"].shape\n        pred_masks = outputs['pred_masks']\n        pred_logits = outputs['pred_logits'].sigmoid()\n\n        indices = []\n\n        for i in range(B):\n            tgt_ids = targets[i][\"labels\"]\n            # no annotations\n            if tgt_ids.shape[0] == 0:\n                indices.append((torch.as_tensor([]),\n                                torch.as_tensor([])))\n                continue\n\n            tgt_masks = targets[i]['masks'].tensor.to(pred_masks)\n            pred_logit = pred_logits[i]\n            out_masks = pred_masks[i]\n\n            # upsampling:\n            # (1) padding/\n            # (2) upsampling to 1x input size (input_shape)\n            # (3) downsampling to 0.25x input size (output mask size)\n            ori_h, ori_w = tgt_masks.size(1), tgt_masks.size(2)\n            tgt_masks_ = torch.zeros(\n                (1, tgt_masks.size(0), input_shape[0], input_shape[1])).to(pred_masks)\n            tgt_masks_[0, :, :ori_h, :ori_w] = tgt_masks\n            tgt_masks = F.interpolate(\n                tgt_masks_, size=out_masks.shape[-2:], mode='bilinear', align_corners=False)[0]\n\n            # compute dice score and classification score\n            tgt_masks = tgt_masks.flatten(1)\n            out_masks = out_masks.flatten(1)\n\n            mask_score = self.mask_score(out_masks, tgt_masks)\n            # Nx(Number of gts)\n            matching_prob = pred_logit[:, tgt_ids]\n            C = (mask_score ** self.alpha) * (matching_prob ** self.beta)\n            # hungarian matching\n            inds = linear_sum_assignment(C.cpu(), maximize=True)\n            indices.append(inds)\n        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]\n\n\n# @SPARSE_INST_MATCHER_REGISTRY.register()\nclass SparseInstMatcher(nn.Module):\n\n    def __init__(self, cfg=None):\n        super().__init__()\n        self.alpha = 0.8 # cfg.MODEL.SPARSE_INST.MATCHER.ALPHA\n        self.beta = 0.2 # cfg.MODEL.SPARSE_INST.MATCHER.BETA\n        self.mask_score = dice_score\n\n    def forward(self, outputs, targets, input_shape):\n        with torch.no_grad():\n            # B x 40 x 90 x 120 \n            B, N, H, W = outputs[\"pred_masks\"].shape\n            pred_masks = outputs['pred_masks']\n            pred_logits = outputs['pred_logits'].sigmoid()\n            tgt_ids = torch.cat([v[\"labels\"] for v in targets])\n\n            if tgt_ids.shape[0] == 0:\n                return [(torch.as_tensor([]).to(pred_logits), torch.as_tensor([]).to(pred_logits))] * B\n            tgt_masks, _ = nested_masks_from_list(\n                [t[\"masks\"] for t in targets], input_shape).decompose()\n            device = pred_masks.device\n            tgt_masks = tgt_masks.to(pred_masks)\n\n            tgt_masks = F.interpolate(\n                tgt_masks[:, None], size=pred_masks.shape[-2:], mode=\"bilinear\", align_corners=False).squeeze(1)\n\n            pred_masks = pred_masks.view(B * N, -1)\n            tgt_masks = tgt_masks.flatten(1)\n            with autocast(enabled=False):\n                pred_masks = pred_masks.float()\n                tgt_masks = tgt_masks.float()\n                pred_logits = pred_logits.float()\n                mask_score = self.mask_score(pred_masks, tgt_masks)\n                # 
Nx(Number of gts)\n                matching_prob = pred_logits.view(B * N, -1)[:, tgt_ids]\n                C = (mask_score ** self.alpha) * (matching_prob ** self.beta)\n\n            C = C.view(B, N, -1).cpu()\n            # hungarian matching\n            sizes = [len(v[\"masks\"]) for v in targets]\n            indices = [linear_sum_assignment(c[i], maximize=True)\n                       for i, c in enumerate(C.split(sizes, -1))]\n            indices = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(\n                j, dtype=torch.int64)) for i, j in indices]\n            return indices\n\n\n# def build_sparse_inst_matcher(cfg):\n#     name = cfg.MODEL.SPARSE_INST.MATCHER.NAME\n#     return SPARSE_INST_MATCHER_REGISTRY.get(name)(cfg)\n\n\n# def build_sparse_inst_criterion(cfg):\n#     matcher = build_sparse_inst_matcher(cfg)\n#     name = cfg.MODEL.SPARSE_INST.LOSS.NAME\n#     return SPARSE_INST_CRITERION_REGISTRY.get(name)(cfg, matcher)\n"
  },
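  {
    "path": "examples/sparse_inst_matching_demo.py",
    "content": "\"\"\"Minimal sketch of the dice-based Hungarian matching used by `SparseInstMatcher`\n(models/sparse_inst_loss.py). All tensors below are random stand-ins: `pred` mimics\nflattened mask logits, `prob` mimics per-query class probabilities gathered at the\nGT labels; alpha/beta follow the defaults hard-coded in the matcher.\"\"\"\nimport torch\nfrom scipy.optimize import linear_sum_assignment\n\nfrom models.sparse_inst_loss import dice_score\n\ntorch.manual_seed(0)\npred = torch.randn(4, 100)                 # 4 predicted masks, flattened to H*W=100\ntgt = (torch.rand(2, 100) > 0.5).float()   # 2 ground-truth binary masks\n\nscore = dice_score(pred, tgt)              # 4x2 soft-dice matrix (sigmoid applied inside)\nprob = torch.rand(4, 2)                    # stand-in for pred_logits.sigmoid()[:, tgt_ids]\nC = score ** 0.8 * prob ** 0.2             # alpha=0.8, beta=0.2\n\nrows, cols = linear_sum_assignment(C.numpy(), maximize=True)\nprint(list(zip(rows.tolist(), cols.tolist())))  # prediction index -> GT index\n"
  },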
  {
    "path": "models/transformer_bricks.py",
    "content": "import numpy as np\nimport math\nimport warnings\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport torch.nn.functional as F\n\nfrom fvcore.nn.weight_init import c2_msra_fill, c2_xavier_fill\n\nfrom mmdet.models.utils.builder import TRANSFORMER\nfrom mmcv.cnn.bricks.transformer import FFN, build_positional_encoding\nfrom mmcv.cnn import (build_activation_layer, build_conv_layer,\n                      build_norm_layer, xavier_init, constant_init)\nfrom mmcv.runner.base_module import BaseModule\nfrom mmcv.cnn.bricks.transformer import (BaseTransformerLayer,\n                                         TransformerLayerSequence,\n                                         build_transformer_layer_sequence)\nfrom mmcv.cnn.bricks.registry import (ATTENTION,TRANSFORMER_LAYER,\n                                      TRANSFORMER_LAYER_SEQUENCE)\nfrom mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction\n\nfrom .scatter_utils import scatter_mean\nfrom .utils import inverse_sigmoid\n\n\ndef pos2posemb3d(pos, num_pos_feats=128, temperature=10000):\n    scale = 2 * math.pi\n    pos = pos * scale\n    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)\n    dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)\n    pos_x = pos[..., 0, None] / dim_t\n    pos_y = pos[..., 1, None] / dim_t\n    pos_z = pos[..., 2, None] / dim_t\n    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)\n    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)\n    pos_z = torch.stack((pos_z[..., 0::2].sin(), pos_z[..., 1::2].cos()), dim=-1).flatten(-2)\n    posemb = torch.cat((pos_y, pos_x, pos_z), dim=-1)\n    return posemb\n\n\ndef generate_ref_pt(minx, miny, maxx, maxy, z, nx, ny, device='cuda'):\n    if isinstance(z, list):\n        nz = z[-1]\n        # minx, miny, maxx, maxy : in ground coords\n        xs = torch.linspace(minx, maxx, nx, dtype=torch.float, device=device\n                ).view(1, -1, 1).expand(ny, nx, nz)\n        ys = torch.linspace(miny, maxy, ny, dtype=torch.float, device=device\n                ).view(-1, 1, 1).expand(ny, nx, nz)\n        zs = torch.linspace(z[0], z[1], nz, dtype=torch.float, device=device\n                ).view(1, 1, -1).expand(ny, nx, nz)\n        ref_3d = torch.stack([xs, ys, zs], dim=-1)\n        ref_3d = ref_3d.flatten(1, 2)\n    else:\n        # minx, miny, maxx, maxy : in ground coords\n        xs = torch.linspace(minx, maxx, nx, dtype=torch.float, device=device\n                ).view(1, -1, 1).expand(ny, nx, 1)\n        ys = torch.linspace(miny, maxy, ny, dtype=torch.float, device=device\n                ).view(-1, 1, 1).expand(ny, nx, 1)\n        ref_3d = F.pad(torch.cat([xs, ys], dim=-1), (0, 1), mode='constant', value=z)\n    return ref_3d\n\n\ndef ground2img(coords3d, H, W, lidar2img, ori_shape, mask=None, return_img_pts=False):\n    coords3d = coords3d.clone()\n    img_pt = coords3d.flatten(1, 2) @ lidar2img.permute(0, 2, 1)\n    img_pt = torch.cat([\n        img_pt[..., :2] / torch.maximum(\n            img_pt[..., 2:3], torch.ones_like(img_pt[..., 2:3]) * 1e-5),\n        img_pt[..., 2:]\n    ], dim=-1)\n\n    # rescale to feature_map size\n    x = img_pt[..., 0] / ori_shape[0][1] * (W - 1)\n    y = img_pt[..., 1] / ori_shape[0][0] * (H - 1)\n    valid = (x >= 0) * (y >= 0) * (x <= (W - 1)) \\\n          * (y <= (H - 1)) * (img_pt[..., 2] > 0)\n    if return_img_pts:\n        return x, y, valid\n\n    if mask is 
not None:\n        valid = valid * mask.flatten(1, 2).float()\n\n    # B, C, H, W = img_feats.shape\n    B = coords3d.shape[0]\n    canvas = torch.zeros((B, H, W, 3 + 1),\n                         dtype=torch.float32,\n                         device=coords3d.device)\n    x = x.long()\n    y = y.long()\n    ind = (x + y * W) * valid.long()\n    # ind = torch.clamp(ind, 0, H * W - 1)\n    ind = ind.long().unsqueeze(-1).repeat(1, 1, canvas.shape[-1])\n    canvas = canvas.flatten(1, 2)\n    target = coords3d.flatten(1, 2).clone()\n    scatter_mean(target, ind, out=canvas, dim=1)\n    canvas = canvas.view(B, H, W, canvas.shape[-1]\n        ).permute(0, 3, 1, 2).contiguous()\n    canvas[:, :, 0, 0] = 0\n    return canvas\n\n\n@ATTENTION.register_module()\nclass MSDeformableAttention3D(BaseModule):\n    def __init__(self,\n                 embed_dims=256,\n                 num_heads=8,\n                 num_levels=4,\n                 num_points=8,\n                 im2col_step=64,\n                 dropout=0.1,\n                 num_query=None,\n                 num_anchor_per_query=None,\n                 anchor_y_steps=None,\n                 batch_first=False,\n                 norm_cfg=None,\n                 init_cfg=None):\n        super().__init__(init_cfg)\n        if embed_dims % num_heads != 0:\n            raise ValueError(f'embed_dims must be divisible by num_heads, '\n                             f'but got {embed_dims} and {num_heads}')\n        dim_per_head = embed_dims // num_heads\n        self.norm_cfg = norm_cfg\n        self.batch_first = batch_first\n        self.output_proj = None\n        self.fp16_enabled = False\n\n        self.num_query = num_query\n        self.num_anchor_per_query = num_anchor_per_query\n        self.register_buffer('anchor_y_steps',\n            torch.from_numpy(anchor_y_steps).float())\n        self.num_points_per_anchor = len(anchor_y_steps) // num_anchor_per_query\n\n        # you'd better set dim_per_head to a power of 2\n        # which is more efficient in the CUDA implementation\n        def _is_power_of_2(n):\n            if (not isinstance(n, int)) or (n < 0):\n                raise ValueError(\n                    'invalid input for _is_power_of_2: {} (type: {})'.format(\n                        n, type(n)))\n            return (n & (n - 1) == 0) and n != 0\n\n        if not _is_power_of_2(dim_per_head):\n            warnings.warn(\n                \"You'd better set embed_dims in \"\n                'MultiScaleDeformAttention to make '\n                'the dimension of each attention head a power of 2 '\n                'which is more efficient in our CUDA implementation.')\n\n        self.im2col_step = im2col_step\n        self.embed_dims = embed_dims\n        self.num_levels = num_levels\n        self.num_heads = num_heads\n        self.num_points = num_points\n        self.sampling_offsets = nn.Linear(\n            embed_dims,\n            num_heads * num_levels * num_points * 2 * self.num_points_per_anchor)\n        self.attention_weights = nn.Linear(embed_dims,\n            num_heads * num_levels * num_points * self.num_points_per_anchor)\n        self.value_proj = nn.Linear(embed_dims, embed_dims)\n\n        self.init_weights()\n\n    def init_weights(self):\n        \"\"\"Default initialization for Parameters of Module.\"\"\"\n        constant_init(self.sampling_offsets, 0.)\n        thetas = torch.arange(\n            self.num_heads,\n            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)\n        grid_init = 
torch.stack([thetas.cos(), thetas.sin()], -1)\n        grid_init = (grid_init /\n                     grid_init.abs().max(-1, keepdim=True)[0]).view(\n            self.num_heads, 1, 1, 1,\n            2).repeat(1, self.num_points_per_anchor, self.num_levels, self.num_points, 1)\n        for i in range(self.num_points):\n            grid_init[..., i, :] *= i + 1\n\n        self.sampling_offsets.bias.data = grid_init.view(-1)\n        constant_init(self.attention_weights, val=0., bias=0.)\n        xavier_init(self.value_proj, distribution='uniform', bias=0.)\n        xavier_init(self.output_proj, distribution='uniform', bias=0.)\n        self._is_init = True\n\n    def ref_to_lidar(self, reference_points, pc_range, not_y=True):\n        reference_points = reference_points.clone()\n        reference_points[..., 0:1] = reference_points[..., 0:1] * \\\n            (pc_range[3] - pc_range[0]) + pc_range[0]\n        if not not_y:\n            reference_points[..., 1:2] = reference_points[..., 1:2] * \\\n                (pc_range[4] - pc_range[1]) + pc_range[1]\n        reference_points[..., 2:3] = reference_points[..., 2:3] * \\\n            (pc_range[5] - pc_range[2]) + pc_range[2]\n        return reference_points\n\n    def point_sampling(self, reference_points, lidar2img, ori_shape):\n        x, y, mask = ground2img(\n            reference_points, H=2, W=2,\n            lidar2img=lidar2img, ori_shape=ori_shape,\n            mask=None, return_img_pts=True)\n        return torch.stack([x, y], -1), mask\n\n    def forward(self,\n                query,\n                key=None,\n                value=None,\n                identity=None,\n                query_pos=None,\n                key_padding_mask=None,\n                reference_points=None,\n                spatial_shapes=None,\n                level_start_index=None,\n                pc_range=None,\n                lidar2img=None,\n                pad_shape=None,\n                key_pos=None,\n                **kwargs):\n        if value is None:\n            assert False\n            value = key\n        if identity is None:\n            identity = query\n        if query_pos is not None:\n            query = query + query_pos\n\n        if key_pos is not None:\n            value = value + key_pos\n\n        if not self.batch_first:\n            # change to (bs, num_query ,embed_dims)\n            query = query.permute(1, 0, 2)\n            value = value.permute(1, 0, 2)\n\n        bs, num_query, _ = query.shape\n        bs, num_value, _ = value.shape\n        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value\n\n        value = self.value_proj(value)\n        if key_padding_mask is not None:\n            value = value.masked_fill(key_padding_mask[..., None], 0.0)\n        value = value.view(bs, num_value, self.num_heads, -1)\n\n        sampling_offsets = self.sampling_offsets(query).view(\n            bs, num_query, self.num_points_per_anchor,\n            self.num_heads, self.num_levels, self.num_points, 2)\n        attention_weights = self.attention_weights(query).view(\n            bs, num_query, self.num_points_per_anchor,\n            self.num_heads, self.num_levels * self.num_points)\n\n        attention_weights = attention_weights.softmax(-1)\n        attention_weights = attention_weights.view(bs, num_query,\n                                                   self.num_points_per_anchor,\n                                                   self.num_heads,\n                                                   
self.num_levels,\n                                                   self.num_points)\n\n        reference_points = reference_points.view(\n            bs, self.num_query, self.num_anchor_per_query, -1, 2)\n        ref_pt3d = torch.cat([\n            reference_points[..., 0:1], # x\n            self.anchor_y_steps.view(1, 1, self.num_anchor_per_query, -1, 1\n                ).expand_as(reference_points[..., 0:1]), # y\n            reference_points[..., 1:2] # z\n        ], dim=-1)\n\n        sampling_locations = self.ref_to_lidar(ref_pt3d, pc_range, not_y=True)\n        sampling_locations2d, mask = self.point_sampling(\n            F.pad(sampling_locations.flatten(1, 2), (0, 1), value=1),\n            lidar2img=lidar2img, ori_shape=pad_shape,\n        )\n        sampling_locations2d = sampling_locations2d.view(\n            *sampling_locations.shape[:-1], 2)\n\n        offset_normalizer = torch.stack(\n            [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)\n        sampling_offsets = sampling_offsets / \\\n            offset_normalizer[None, None, None, :, None, :]\n\n        sampling_locations2d = sampling_locations2d.view(\n                bs, self.num_query, self.num_anchor_per_query, -1, 1, 1, 1, 2) \\\n            + sampling_offsets.view(\n                bs, self.num_query, self.num_anchor_per_query, self.num_points_per_anchor,\n                *sampling_offsets.shape[3:]\n            )\n\n        # reshape, move self.num_anchor_per_query to last axis\n        sampling_locations2d = sampling_locations2d.permute(0, 1, 2, 4, 5, 6, 3, 7)\n        attention_weights = attention_weights.permute(0, 1, 3, 4, 5, 2)\n        sampling_locations2d = sampling_locations2d.flatten(-3, -2)\n        attention_weights = attention_weights.flatten(-2) / self.num_points_per_anchor\n\n        xy = 2\n        num_all_points = sampling_locations2d.shape[-2]\n\n        sampling_locations2d = sampling_locations2d.view(\n            bs, num_query, self.num_heads, self.num_levels, num_all_points, xy)\n\n        if torch.cuda.is_available() and value.is_cuda:\n            output = MultiScaleDeformableAttnFunction.apply(\n                value, spatial_shapes, level_start_index, sampling_locations2d,\n                attention_weights, self.im2col_step)\n        else:\n            output = multi_scale_deformable_attn_pytorch(\n                value, spatial_shapes, sampling_locations2d, attention_weights)\n        if not self.batch_first:\n            output = output.permute(1, 0, 2)\n        return output\n\n\n@TRANSFORMER_LAYER.register_module()\nclass LATRDecoderLayer(BaseTransformerLayer):\n    def __init__(self,\n                 attn_cfgs,\n                 feedforward_channels,\n                 ffn_dropout=0.0,\n                 operation_order=None,\n                 act_cfg=dict(type='ReLU', inplace=True),\n                 norm_cfg=dict(type='LN'),\n                 ffn_num_fcs=2,\n                 **kwargs):\n        super(LATRDecoderLayer, self).__init__(\n            attn_cfgs=attn_cfgs,\n            feedforward_channels=feedforward_channels,\n            ffn_dropout=ffn_dropout,\n            operation_order=operation_order,\n            act_cfg=act_cfg,\n            norm_cfg=norm_cfg,\n            ffn_num_fcs=ffn_num_fcs,\n            **kwargs)\n\n    def forward(self,\n                query,\n                key=None,\n                value=None,\n                query_pos=None,\n                key_pos=None,\n                attn_masks=None,\n                
query_key_padding_mask=None,\n                key_padding_mask=None,\n                **kwargs):\n        query = super().forward(\n            query=query, key=key, value=value,\n            query_pos=query_pos, key_pos=key_pos,\n            attn_masks=attn_masks, query_key_padding_mask=query_key_padding_mask,\n            key_padding_mask=key_padding_mask, **kwargs)\n        return query\n\n\n@TRANSFORMER_LAYER_SEQUENCE.register_module()\nclass LATRTransformerDecoder(TransformerLayerSequence):\n    def __init__(self,\n                 *args, embed_dims=None,\n                 post_norm_cfg=dict(type='LN'),\n                 enlarge_length=10,\n                 M_decay_ratio=10,\n                 num_query=None,\n                 num_anchor_per_query=None,\n                 anchor_y_steps=None,\n                 **kwargs):\n        super(LATRTransformerDecoder, self).__init__(*args, **kwargs)\n        if post_norm_cfg is not None:\n            self.post_norm = build_norm_layer(post_norm_cfg,\n                                              self.embed_dims)[1]\n        else:\n            self.post_norm = None\n\n        self.num_query = num_query\n        self.num_anchor_per_query = num_anchor_per_query\n        self.anchor_y_steps = anchor_y_steps\n        self.num_points_per_anchor = len(anchor_y_steps) // num_anchor_per_query\n\n        self.embed_dims = embed_dims\n        self.gflat_pred_layer = nn.Sequential(\n            nn.Conv2d(embed_dims + 4, embed_dims, 3, stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(embed_dims),\n            nn.ReLU(True),\n            nn.Conv2d(embed_dims, embed_dims, 3, stride=2, padding=1, bias=False),\n            nn.BatchNorm2d(embed_dims),\n            nn.ReLU(True),\n            nn.AdaptiveAvgPool2d(1),\n            nn.Conv2d(embed_dims, embed_dims, 1),\n            nn.BatchNorm2d(embed_dims),\n            nn.ReLU(True),\n            nn.Conv2d(embed_dims, embed_dims // 4, 1),\n            nn.BatchNorm2d(embed_dims // 4),\n            nn.ReLU(True),\n            nn.Conv2d(embed_dims // 4, 2, 1))\n\n        self.position_encoder = nn.Sequential(\n            nn.Conv2d(3, self.embed_dims*4, kernel_size=1, stride=1, padding=0),\n            nn.ReLU(),\n            nn.Conv2d(self.embed_dims*4, self.embed_dims, kernel_size=1, stride=1, padding=0),\n        )\n        self.M_decay_ratio = M_decay_ratio\n        self.enlarge_length = enlarge_length\n\n    def init_weights(self):\n        super().init_weights()\n        for l in self.gflat_pred_layer:\n            xavier_init(l, gain=0.01)\n\n    def pred2M(self, pitch_z):\n        pitch_z = pitch_z / self.M_decay_ratio\n        t = pitch_z[:, 0] / 100\n        z = pitch_z[:, 1]\n        one = torch.ones_like(t)\n        zero = torch.zeros_like(t)\n\n        # rot first, then translate\n        M = torch.stack([\n            one, zero, zero, zero,\n            zero, t.cos(), -t.sin(), zero,\n            zero, t.sin(), t.cos(), z,\n            zero, zero, zero, one], dim=-1).view(t.shape[0], 4, 4)\n        return M\n\n    def forward(self, query, key, value,\n                top_view_region=None, z_region=None,\n                bev_h=None, bev_w=None,\n                init_z=0, img_feats=None,\n                lidar2img=None, pad_shape=None,\n                key_pos=None, key_padding_mask=None,\n                sin_embed=None, reference_points=None,\n                reg_branches=None, cls_branches=None,\n                query_pos=None,\n                **kwargs):\n        assert key_pos is None\n   
     assert key_padding_mask is None\n\n        # init pts and M to generate pos embed for key/value\n        batch_size = query.shape[1]\n        xmin = top_view_region[0][0] - self.enlarge_length\n        xmax = top_view_region[1][0] + self.enlarge_length\n        ymin = top_view_region[2][1] - self.enlarge_length\n        ymax = top_view_region[0][1] + self.enlarge_length\n        zmin = z_region[0]\n        zmax = z_region[1]\n        init_ref_3d = generate_ref_pt(\n            xmin, ymin, xmax, ymax, init_z,\n            bev_w, bev_h, query.device)\n        init_ref_3d = init_ref_3d[None, ...].repeat(batch_size, 1, 1, 1)\n        ref_3d_homo = F.pad(init_ref_3d, (0, 1), value=1)\n        init_ref_3d_homo = ref_3d_homo.clone()\n        init_M = torch.eye(4, device=query.device).float()\n        M = init_M[None, ...].repeat(batch_size, 1, 1)\n\n        intermediate = []\n        project_results = []\n        outputs_classes = []\n        outputs_coords = []\n        for layer_idx, layer in enumerate(self.layers):\n            coords_img = ground2img(\n                ref_3d_homo, *img_feats[0].shape[-2:],\n                lidar2img, pad_shape)\n            if layer_idx > 0:\n                project_results.append(coords_img.clone())\n            coords_img_key_pos = coords_img.clone()\n            ground_coords = coords_img_key_pos[:, :3, ...]\n            img_mask = coords_img_key_pos[:, -1, ...]\n\n            ground_coords[:, 0, ...] = (ground_coords[:, 0, ...] - xmin) / (xmax - xmin)\n            ground_coords[:, 1, ...] = (ground_coords[:, 1, ...] - ymin) / (ymax - ymin)\n            ground_coords[:, 2, ...] = (ground_coords[:, 2, ...] - zmin) / (zmax - zmin)\n            ground_coords = inverse_sigmoid(ground_coords)\n            key_pos = self.position_encoder(ground_coords)\n\n            query = layer(query, key=key, value=value,\n                          key_pos=(key_pos + sin_embed\n                                  ).flatten(2, 3).permute(2, 0, 1).contiguous(),\n                          reference_points=reference_points,\n                          pc_range=[xmin, ymin, zmin, xmax, ymax, zmax],\n                          pad_shape=pad_shape,\n                          lidar2img=lidar2img,\n                          query_pos=query_pos,\n                          **kwargs)\n\n            # update M\n            if layer_idx < len(self.layers) - 1:\n                input_feat = torch.cat([img_feats[0], coords_img], dim=1)\n                M = M.detach() @ self.pred2M(self.gflat_pred_layer(input_feat).squeeze(-1).squeeze(-1))\n                ref_3d_homo = (init_ref_3d_homo.flatten(1, 2) @ M.permute(0, 2, 1)\n                              ).view(*ref_3d_homo.shape)\n\n            if self.post_norm is not None:\n                intermediate.append(self.post_norm(query))\n            else:\n                intermediate.append(query)\n\n            query = query.permute(1, 0, 2)\n            tmp = reg_branches[layer_idx](query)\n\n            bs = tmp.shape[0]\n            # iterative update\n            tmp = tmp.view(bs, self.num_query,\n                self.num_anchor_per_query, -1, 3)\n            reference_points = reference_points.view(\n                bs, self.num_query, self.num_anchor_per_query,\n                self.num_points_per_anchor, 2\n            )\n            reference_points = inverse_sigmoid(reference_points)\n            new_reference_points = torch.stack([\n                reference_points[..., 0] + tmp[..., 0],\n                reference_points[..., 1] + 
tmp[..., 1],\n            ], dim=-1)\n            reference_points = new_reference_points.sigmoid()\n\n            cls_feat = query.view(bs, self.num_query, self.num_anchor_per_query, -1)\n            cls_feat = torch.max(cls_feat, dim=2)[0]\n            outputs_class = cls_branches[layer_idx](cls_feat)\n\n            outputs_classes.append(outputs_class)\n            outputs_coords.append(torch.cat([\n                reference_points, tmp[..., -1:]\n            ], dim=-1))\n\n            reference_points = reference_points.view(\n                bs, self.num_query * self.num_anchor_per_query,\n                self.num_points_per_anchor * 2\n            ).detach()\n            query = query.permute(1, 0, 2)\n        return torch.stack(intermediate), project_results, outputs_classes, outputs_coords\n\n\n@TRANSFORMER.register_module()\nclass LATRTransformer(BaseModule):\n    def __init__(self, encoder=None, decoder=None, init_cfg=None):\n        super(LATRTransformer, self).__init__(init_cfg=init_cfg)\n        if encoder is not None:\n            self.encoder = build_transformer_layer_sequence(encoder)\n        else:\n            self.encoder = None\n        self.decoder = build_transformer_layer_sequence(decoder)\n        self.embed_dims = self.decoder.embed_dims\n        self.init_weights()\n\n    def init_weights(self):\n        # follow the official DETR to init parameters\n        for m in self.modules():\n            if hasattr(m, 'weight') and m.weight.dim() > 1:\n                xavier_init(m, distribution='uniform')\n        self._is_init = True\n\n    @staticmethod\n    def get_reference_points(spatial_shapes, valid_ratios, device):\n        \"\"\"Get the reference points used in decoder.\n        Args:\n            spatial_shapes (Tensor): The shape of all\n                feature maps, has shape (num_level, 2).\n            valid_ratios (Tensor): The radios of valid\n                points on the feature map, has shape\n                (bs, num_levels, 2)\n            device (obj:`device`): The device where\n                reference_points should be.\n        Returns:\n            Tensor: reference points used in decoder, has \\\n                shape (bs, num_keys, num_levels, 2).\n        \"\"\"\n        reference_points_list = []\n        for lvl, (H, W) in enumerate(spatial_shapes):\n            #  TODO  check this 0.5\n            ref_y, ref_x = torch.meshgrid(\n                torch.linspace(\n                    0.5, H - 0.5, H, dtype=torch.float32, device=device),\n                torch.linspace(\n                    0.5, W - 0.5, W, dtype=torch.float32, device=device))\n            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)\n            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)\n            ref = torch.stack((ref_x, ref_y), -1)\n            reference_points_list.append(ref)\n        reference_points = torch.cat(reference_points_list, 1)\n        reference_points = reference_points[:, :, None] * valid_ratios[:, None]\n        return reference_points\n\n    def get_valid_ratio(self, mask):\n        \"\"\"Get the valid radios of feature maps of all  level.\"\"\"\n        _, H, W = mask.shape\n        valid_H = torch.sum(~mask[:, :, 0], 1)\n        valid_W = torch.sum(~mask[:, 0, :], 1)\n        valid_ratio_h = valid_H.float() / H\n        valid_ratio_w = valid_W.float() / W\n        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)\n        return valid_ratio\n\n    @property\n    def with_encoder(self):\n  
      return hasattr(self, 'encoder') and self.encoder\n\n    def forward(self, x, mask, query,\n                query_embed, pos_embed,\n                reference_points=None,\n                reg_branches=None, cls_branches=None,\n                spatial_shapes=None,\n                level_start_index=None,\n                mlvl_masks=None,\n                mlvl_positional_encodings=None,\n                pos_embed2d=None,\n                **kwargs):\n        assert pos_embed is None\n        memory = x\n        # encoder\n        if self.with_encoder:\n            valid_ratios = torch.stack(\n                [self.get_valid_ratio(m) for m in mlvl_masks], 1)\n            reference_points_2d = \\\n                self.get_reference_points(spatial_shapes,\n                                          valid_ratios,\n                                          device=x.device)\n            memory = self.encoder(\n                query=memory,\n                key=memory,\n                value=memory,\n                key_pos=pos_embed2d,\n                query_pos=pos_embed2d,\n                spatial_shapes=spatial_shapes,\n                level_start_index=level_start_index,\n                reference_points=reference_points_2d,\n                valid_ratios=valid_ratios,\n            )\n\n        query_embed = query_embed.permute(1, 0, 2)\n        if mask is not None:\n            # [bs, n, h, w] -> [bs, n*h*w] (n=1)\n            mask = mask.view(mask.shape[0], -1).to(torch.bool)\n        target = query.permute(1, 0, 2)\n\n        # out_dec: [num_layers, num_query, bs, dim]\n        out_dec, project_results, outputs_classes, outputs_coords = \\\n            self.decoder(\n                query=target,\n                key=memory,\n                value=memory,\n                key_pos=pos_embed,\n                query_pos=query_embed,\n                key_padding_mask=mask,\n                reg_branches=reg_branches,\n                cls_branches=cls_branches,\n                reference_points=reference_points,\n                spatial_shapes=spatial_shapes,\n                level_start_index=level_start_index,\n                **kwargs\n            )\n        return out_dec.permute(0, 2, 1, 3), project_results, \\\n               outputs_classes, outputs_coords"
  },
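  {
    "path": "examples/ref_points_demo.py",
    "content": "\"\"\"Sketch of the BEV reference-point helpers in models/transformer_bricks.py (not\npart of the original release). The ground-range numbers below are illustrative\nassumptions, not the values used in the LATR configs.\"\"\"\nimport torch\n\nfrom models.transformer_bricks import generate_ref_pt, pos2posemb3d\n\n# A ny x nx grid of (x, y, z) ground points over an assumed top-view region,\n# with a flat-ground prior z=0.\nref = generate_ref_pt(-10, 3, 10, 103, z=0.0, nx=60, ny=45, device='cpu')\nprint(ref.shape)  # torch.Size([45, 60, 3])\n\n# Sinusoidal 3D position embedding for each reference point (128 feats per axis).\nemb = pos2posemb3d(ref.reshape(1, -1, 3), num_pos_feats=128)\nprint(emb.shape)  # torch.Size([1, 2700, 384])\n"
  },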
  {
    "path": "models/utils.py",
    "content": "import torch\nimport geffnet\nimport torch.nn as nn\n\n\ndef inverse_sigmoid(x, eps=1e-5):\n    \"\"\"Inverse function of sigmoid.\n\n    Args:\n        x (Tensor): The tensor to do the\n            inverse.\n        eps (float): EPS avoid numerical\n            overflow. Defaults 1e-5.\n    Returns:\n        Tensor: The x has passed the inverse\n            function of sigmoid, has same\n            shape with input.\n    \"\"\"\n    x = x.clamp(min=0, max=1)\n    x1 = x.clamp(min=eps)\n    x2 = (1 - x).clamp(min=eps)\n    return torch.log(x1 / x2)\n\n\nclass deepFeatureExtractor_EfficientNet(nn.Module):\n    def __init__(self, architecture=\"EfficientNet-B5\", lv6=False, lv5=False, lv4=False, lv3=False):\n        super(deepFeatureExtractor_EfficientNet, self).__init__()\n        assert architecture in [\"EfficientNet-B0\", \"EfficientNet-B1\", \"EfficientNet-B2\", \"EfficientNet-B3\", \n                                    \"EfficientNet-B4\", \"EfficientNet-B5\", \"EfficientNet-B6\", \"EfficientNet-B7\"]\n        \n        if architecture == \"EfficientNet-B0\":\n            self.encoder = geffnet.tf_efficientnet_b0_ns(pretrained=True)\n            self.dimList = [16, 24, 40, 112, 1280] #5th feature is extracted after conv_head or bn2\n            #self.dimList = [16, 24, 40, 112, 320] #5th feature is extracted after blocks[6]\n        elif architecture == \"EfficientNet-B1\":\n            self.encoder = geffnet.tf_efficientnet_b1_ns(pretrained=True)\n            self.dimList = [16, 24, 40, 112, 1280] #5th feature is extracted after conv_head or bn2\n            #self.dimList = [16, 24, 40, 112, 320] #5th feature is extracted after blocks[6]\n        elif architecture == \"EfficientNet-B2\":\n            self.encoder = geffnet.tf_efficientnet_b2_ns(pretrained=True)\n            self.dimList = [16, 24, 48, 120, 1408] #5th feature is extracted after conv_head or bn2\n            #self.dimList = [16, 24, 48, 120, 352] #5th feature is extracted after blocks[6]\n        elif architecture == \"EfficientNet-B3\":\n            self.encoder = geffnet.tf_efficientnet_b3_ns(pretrained=True)\n            self.dimList = [24, 32, 48, 136, 1536] #5th feature is extracted after conv_head or bn2\n            #self.dimList = [24, 32, 48, 136, 384] #5th feature is extracted after blocks[6]\n        elif architecture == \"EfficientNet-B4\":\n            self.encoder = geffnet.tf_efficientnet_b4_ns(pretrained=True)\n            self.dimList = [24, 32, 56, 160, 1792] #5th feature is extracted after conv_head or bn2\n            #self.dimList = [24, 32, 56, 160, 448] #5th feature is extracted after blocks[6]\n        elif architecture == \"EfficientNet-B5\":\n            self.encoder = geffnet.tf_efficientnet_b5_ns(pretrained=True)\n            self.dimList = [24, 40, 64, 176, 2048] #5th feature is extracted after conv_head or bn2\n            #self.dimList = [24, 40, 64, 176, 512] #5th feature is extracted after blocks[6]\n        elif architecture == \"EfficientNet-B6\":\n            self.encoder = geffnet.tf_efficientnet_b6_ns(pretrained=True)\n            self.dimList = [32, 40, 72, 200, 2304] #5th feature is extracted after conv_head or bn2\n            #self.dimList = [32, 40, 72, 200, 576] #5th feature is extracted after blocks[6]\n        elif architecture == \"EfficientNet-B7\":\n            self.encoder = geffnet.tf_efficientnet_b7_ns(pretrained=True)\n            self.dimList = [32, 48, 80, 224, 2560] #5th feature is extracted after conv_head or bn2\n            #self.dimList = [32, 
48, 80, 224, 640] #5th feature is extracted after blocks[6]\n        del self.encoder.global_pool\n        del self.encoder.classifier\n        #self.block_idx = [3, 4, 5, 7, 9] #5th feature is extracted after blocks[6]\n        #self.block_idx = [3, 4, 5, 7, 10] #5th feature is extracted after conv_head\n        self.block_idx = [3, 4, 5, 7, 11] #5th feature is extracted after bn2\n        if lv6 is False:\n            del self.encoder.blocks[6]\n            del self.encoder.conv_head\n            del self.encoder.bn2\n            del self.encoder.act2\n            self.block_idx = self.block_idx[:4]\n            self.dimList = self.dimList[:4]\n        if lv5 is False:\n            del self.encoder.blocks[5]\n            self.block_idx = self.block_idx[:3]\n            self.dimList = self.dimList[:3]\n        if lv4 is False:\n            del self.encoder.blocks[4]\n            self.block_idx = self.block_idx[:2]\n            self.dimList = self.dimList[:2]\n        if lv3 is False:\n            del self.encoder.blocks[3]\n            self.block_idx = self.block_idx[:1]\n            self.dimList = self.dimList[:1]\n        # after passing blocks[3]    : H/2  x W/2\n        # after passing blocks[4]    : H/4  x W/4\n        # after passing blocks[5]    : H/8  x W/8\n        # after passing blocks[7]    : H/16 x W/16\n        # after passing conv_stem    : H/32 x W/32\n        self.fixList = ['blocks.0.0','bn']\n\n        for name, parameters in self.encoder.named_parameters():\n            if name == 'conv_stem.weight':\n                parameters.requires_grad = False\n            if any(x in name for x in self.fixList):\n                parameters.requires_grad = False\n\n    def forward(self, x):\n        out_featList = []\n        feature = x\n        cnt = 0\n        block_cnt = 0\n        for k, v in self.encoder._modules.items():\n            if k == 'act2':\n                break\n            if k == 'blocks':\n                for m, n in v._modules.items():\n                    feature = n(feature)\n                    try:\n                        if self.block_idx[block_cnt] == cnt:\n                            out_featList.append(feature)\n                            block_cnt += 1\n                            break\n                        cnt += 1\n                    except IndexError:\n                        # all requested feature levels have been collected\n                        continue\n            else:\n                feature = v(feature)\n                if self.block_idx[block_cnt] == cnt:\n                    out_featList.append(feature)\n                    block_cnt += 1\n                    break\n                cnt += 1\n\n        return out_featList\n\n    def freeze_bn(self, enable=False):\n        \"\"\" Adapted from https://discuss.pytorch.org/t/how-to-train-with-frozen-batchnorm/12106/8 \"\"\"\n        for module in self.modules():\n            if isinstance(module, nn.BatchNorm2d):\n                module.train() if enable else module.eval()\n                module.weight.requires_grad = enable\n                module.bias.requires_grad = enable\n"
  },
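  {
    "path": "examples/inverse_sigmoid_demo.py",
    "content": "\"\"\"Tiny sketch of `inverse_sigmoid` from models/utils.py: it is the logit function\nwith clamping, so sigmoid(inverse_sigmoid(p)) recovers p for p away from 0/1. The\nprobe values below are arbitrary.\"\"\"\nimport torch\n\nfrom models.utils import inverse_sigmoid\n\np = torch.tensor([0.1, 0.5, 0.9])\nlogits = inverse_sigmoid(p)\nassert torch.allclose(torch.sigmoid(logits), p, atol=1e-4)\nprint(logits)  # tensor([-2.1972,  0.0000,  2.1972])\n"
  },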
  {
    "path": "pretrained_models/.gitkeep",
    "content": ""
  },
  {
    "path": "requirements.txt",
    "content": "opencv_python==4.2.0.32\nShapely==1.7.0\nxmljson==0.2.0\nthop==0.0.31.post2005241907\nmatplotlib==3.5.1\nscipy==1.4.1\np_tqdm==1.3.3\nlxml==4.5.0\ntqdm==4.43.0\nujson==1.35\nPyYAML==5.3.1\nscikit_learn==0.23.2\ntensorboard==2.3.0\ngdown==4.4.0\nortools==9.2.9972\ngeffnet==1.0.2\ntensorboardX==2.5"
  },
  {
    "path": "utils/MinCostFlow.py",
    "content": "# ==============================================================================\n# Binaries and/or source for the following packages or projects are presented under one or more of the following open\n# source licenses:\n# MinCostFlow.py       The PersFormer Authors        Apache License, Version 2.0\n#\n# Contact simachonghao@pjlab.org.cn if you have any issue\n# \n# See:\n# https://github.com/yuliangguo/Pytorch_Generalized_3D_Lane_Detection/blob/master/tools/MinCostFlow.py\n#\n# Copyright (c) 2022 The PersFormer Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nfrom __future__ import print_function\nimport numpy as np\nfrom ortools.graph import pywrapgraph\nimport time\n\n\ndef SolveMinCostFlow(adj_mat, cost_mat):\n    \"\"\"\n        Solving an Assignment Problem with MinCostFlow\"\n    :param adj_mat: adjacency matrix with binary values indicating possible matchings between two sets\n    :param cost_mat: cost matrix recording the matching cost of every possible pair of items from two sets\n    :return:\n    \"\"\"\n\n    # Instantiate a SimpleMinCostFlow solver.\n    min_cost_flow = pywrapgraph.SimpleMinCostFlow()\n    # Define the directed graph for the flow.\n\n    cnt_1, cnt_2 = adj_mat.shape\n    cnt_nonzero_row = int(np.sum(np.sum(adj_mat, axis=1) > 0))\n    cnt_nonzero_col = int(np.sum(np.sum(adj_mat, axis=0) > 0))\n\n    # prepare directed graph for the flow\n    start_nodes = np.zeros(cnt_1, dtype=np.int).tolist() +\\\n                  np.repeat(np.array(range(1, cnt_1+1)), cnt_2).tolist() + \\\n                  [i for i in range(cnt_1+1, cnt_1 + cnt_2 + 1)]\n    end_nodes = [i for i in range(1, cnt_1+1)] + \\\n                np.repeat(np.array([i for i in range(cnt_1+1, cnt_1 + cnt_2 + 1)]).reshape([1, -1]), cnt_1, axis=0).flatten().tolist() + \\\n                [cnt_1 + cnt_2 + 1 for i in range(cnt_2)]\n    capacities = np.ones(cnt_1, dtype=np.int).tolist() + adj_mat.flatten().astype(np.int).tolist() + np.ones(cnt_2, dtype=np.int).tolist()\n    costs = (np.zeros(cnt_1, dtype=np.int).tolist() + cost_mat.flatten().astype(np.int).tolist() + np.zeros(cnt_2, dtype=np.int).tolist())\n    # Define an array of supplies at each node.\n    supplies = [min(cnt_nonzero_row, cnt_nonzero_col)] + np.zeros(cnt_1 + cnt_2, dtype=np.int).tolist() + [-min(cnt_nonzero_row, cnt_nonzero_col)]\n    # supplies = [min(cnt_1, cnt_2)] + np.zeros(cnt_1 + cnt_2, dtype=np.int).tolist() + [-min(cnt_1, cnt_2)]\n    source = 0\n    sink = cnt_1 + cnt_2 + 1\n\n    # Add each arc.\n    for i in range(len(start_nodes)):\n        min_cost_flow.AddArcWithCapacityAndUnitCost(start_nodes[i], end_nodes[i],\n                                                    capacities[i], costs[i])\n\n    # Add node supplies.\n    for i in range(len(supplies)):\n        min_cost_flow.SetNodeSupply(i, supplies[i])\n\n    match_results = []\n    # Find the minimum cost flow between node 0 and node 
\n\ndef main():\n    \"\"\"Solving an Assignment Problem with MinCostFlow\"\"\"\n\n    # Instantiate a SimpleMinCostFlow solver.\n    min_cost_flow = pywrapgraph.SimpleMinCostFlow()\n    # Define the directed graph for the flow.\n\n    start_nodes = [0, 0, 0, 0] + [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4] + [5, 6, 7, 8]\n    end_nodes = [1, 2, 3, 4] + [5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8] + [9, 9, 9, 9]\n    capacities = [1, 1, 1, 1] + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + [1, 1, 1, 1]\n    costs = ([0, 0, 0, 0] + [90, 76, 75, 70, 35, 85, 55, 65, 125, 95, 90, 105, 45, 110, 95, 115] + [0, 0, 0, 0])\n    # Define an array of supplies at each node.\n    supplies = [4, 0, 0, 0, 0, 0, 0, 0, 0, -4]\n    source = 0\n    sink = 9\n    tasks = 4\n\n    # Add each arc.\n    for i in range(len(start_nodes)):\n        min_cost_flow.AddArcWithCapacityAndUnitCost(start_nodes[i], end_nodes[i],\n                                                    capacities[i], costs[i])\n\n    # Add node supplies.\n    for i in range(len(supplies)):\n        min_cost_flow.SetNodeSupply(i, supplies[i])\n    # Find the minimum cost flow from node 0 (source) to node 9 (sink).\n    if min_cost_flow.Solve() == min_cost_flow.OPTIMAL:\n        print('Total cost = ', min_cost_flow.OptimalCost())\n        print()\n        for arc in range(min_cost_flow.NumArcs()):\n\n            # Can ignore arcs leading out of source or into sink.\n            if min_cost_flow.Tail(arc) != source and min_cost_flow.Head(arc) != sink:\n\n                # Arcs in the solution have a flow value of 1. Their start and end nodes\n                # give an assignment of worker to task.\n\n                if min_cost_flow.Flow(arc) > 0:\n                    print('Worker %d assigned to task %d.  Cost = %d' % (\n                        min_cost_flow.Tail(arc),\n                        min_cost_flow.Head(arc),\n                        min_cost_flow.UnitCost(arc)))\n    else:\n        print('There was an issue with the min cost flow input.')\n\n\nif __name__ == '__main__':\n    # time.clock() was removed in Python 3.8; use time.perf_counter() instead\n    start_time = time.perf_counter()\n    main()\n    print()\n    print(\"Time =\", time.perf_counter() - start_time, \"seconds\")"
  },
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/eval_3D_lane.py",
    "content": "# ==============================================================================\n# Binaries and/or source for the following packages or projects are presented under one or more of the following open\n# source licenses:\n# eval_3D_lane.py       The OpenLane Dataset Authors        Apache License, Version 2.0\n#\n# Contact simachonghao@pjlab.org.cn if you have any issue\n# \n# See:\n# https://github.com/yuliangguo/Pytorch_Generalized_3D_Lane_Detection/blob/master/tools/eval_3D_lane.py\n#\n# Copyright (c) 2022 The OpenLane Dataset Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\n\"\"\"\nDescription: This code is to evaluate 3D lane detection. The optimal matching between ground-truth set and predicted \n    set of lanes are sought via solving a min cost flow.\n\nEvaluation metrics includes:\n    F-scores\n    x error close (0 - 40 m)\n    x error far (0 - 100 m)\n    z error close (0 - 40 m)\n    z error far (0 - 100 m)\n\"\"\"\n\nfrom copy import deepcopy\nimport numpy as np\nfrom utils.utils import *\nfrom utils.MinCostFlow import SolveMinCostFlow\n\nclass LaneEval(object):\n    def __init__(self, args, logger):        \n        self.dataset_name = args.dataset_name\n        self.dataset_dir = args.dataset_dir\n\n        self.x_min = args.top_view_region[0, 0]\n        self.x_max = args.top_view_region[1, 0]\n        self.y_min = args.top_view_region[2, 1]\n        self.y_max = args.top_view_region[0, 1]\n        self.y_samples = np.linspace(self.y_min, self.y_max, num=100, endpoint=False)\n        self.dist_th = 1.5\n        self.ratio_th = 0.75\n        self.close_range = 40\n\n        self.is_apollo = 'apollo' in self.dataset_name\n        self.args = args\n        self.logger = logger\n\n    def bench(self, pred_lanes, pred_category, gt_lanes, gt_visibility, gt_category, raw_file, gt_cam_height, gt_cam_pitch, vis, P_g2im=None):\n        \"\"\"\n            Matching predicted lanes and ground-truth lanes in their IPM projection, ignoring z attributes.\n            x error, y_error, and z error are all considered, although the matching does not rely on z\n            The input of prediction and ground-truth lanes are in ground coordinate, x-right, y-forward, z-up\n            The fundamental assumption is: 1. there are no two points from different lanes with identical x, y\n                                              but different z's\n                                           2. 
 there are no two points from a single lane having identical x, y\n                                              but different z's\n            If the area of interest is within the current drivable road, the above assumptions are almost always valid.\n\n        :param pred_lanes: N X 2 or N X 3 lists depending on 2D or 3D\n        :param gt_lanes: N X 2 or N X 3 lists depending on 2D or 3D\n        :param raw_file: file path rooted in dataset folder\n        :param gt_cam_height: camera height given in ground-truth data\n        :param gt_cam_pitch: camera pitch given in ground-truth data\n        :return:\n        \"\"\"\n        \n        # index of the first y sample beyond the close range\n        close_range_idx = np.where(self.y_samples > self.close_range)[0][0]\n\n        r_lane, p_lane, c_lane = 0., 0., 0.\n        x_error_close = []\n        x_error_far = []\n        z_error_close = []\n        z_error_far = []\n\n        # only keep the visible portion\n        gt_lanes = [prune_3d_lane_by_visibility(np.array(gt_lane), np.array(gt_visibility[k])) for k, gt_lane in\n                    enumerate(gt_lanes)]\n        if 'openlane' in self.dataset_name:\n            gt_category = [gt_category[k] for k, lane in enumerate(gt_lanes) if lane.shape[0] > 1]\n        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]\n\n        # only consider those pred lanes overlapping with the sampling range\n        pred_category = [pred_category[k] for k, lane in enumerate(pred_lanes)\n                        if np.array(lane)[0, 1] < self.y_samples[-1] and np.array(lane)[-1, 1] > self.y_samples[0]]\n        pred_lanes = [lane for lane in pred_lanes if np.array(lane)[0, 1] < self.y_samples[-1] and np.array(lane)[-1, 1] > self.y_samples[0]]\n\n        pred_lanes = [prune_3d_lane_by_range(np.array(lane), self.x_min, self.x_max) for lane in pred_lanes]\n\n        pred_category = [pred_category[k] for k, lane in enumerate(pred_lanes) if np.array(lane).shape[0] > 1]\n        pred_lanes = [lane for lane in pred_lanes if np.array(lane).shape[0] > 1]\n\n        # only consider those gt lanes overlapping with the sampling range\n        gt_category = [gt_category[k] for k, lane in enumerate(gt_lanes)\n                        if lane[0, 1] < self.y_samples[-1] and lane[-1, 1] > self.y_samples[0]]\n        gt_lanes = [lane for lane in gt_lanes if lane[0, 1] < self.y_samples[-1] and lane[-1, 1] > self.y_samples[0]]\n\n        gt_lanes = [prune_3d_lane_by_range(np.array(lane), self.x_min, self.x_max) for lane in gt_lanes]\n\n        gt_category = [gt_category[k] for k, lane in enumerate(gt_lanes) if lane.shape[0] > 1]\n        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]\n\n        cnt_gt = len(gt_lanes)\n        cnt_pred = len(pred_lanes)\n\n        gt_visibility_mat = np.zeros((cnt_gt, 100))\n        pred_visibility_mat = np.zeros((cnt_pred, 100))\n\n        # resample gt and pred at y_samples\n        for i in range(cnt_gt):\n            min_y = np.min(np.array(gt_lanes[i])[:, 1])\n            max_y = np.max(np.array(gt_lanes[i])[:, 1])\n            x_values, z_values, visibility_vec = resample_laneline_in_y(np.array(gt_lanes[i]), self.y_samples, out_vis=True)\n            gt_lanes[i] = np.vstack([x_values, z_values]).T\n            gt_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min, np.logical_and(x_values <= self.x_max,\n                                                     np.logical_and(self.y_samples >= min_y, self.y_samples <= max_y)))\n            gt_visibility_mat[i, :] =
 np.logical_and(gt_visibility_mat[i, :], visibility_vec)\n\n        for i in range(cnt_pred):\n            # # ATTENTION: ensure y mono increase before interpolation: but it can reduce size\n            # pred_lanes[i] = make_lane_y_mono_inc(np.array(pred_lanes[i]))\n            # pred_lane = prune_3d_lane_by_range(np.array(pred_lanes[i]), self.x_min, self.x_max)\n            min_y = np.min(np.array(pred_lanes[i])[:, 1])\n            max_y = np.max(np.array(pred_lanes[i])[:, 1])\n            x_values, z_values, visibility_vec = resample_laneline_in_y(np.array(pred_lanes[i]), self.y_samples, out_vis=True)\n            pred_lanes[i] = np.vstack([x_values, z_values]).T\n            pred_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min, np.logical_and(x_values <= self.x_max,\n                                                       np.logical_and(self.y_samples >= min_y, self.y_samples <= max_y)))\n            pred_visibility_mat[i, :] = np.logical_and(pred_visibility_mat[i, :], visibility_vec)\n            # pred_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min, x_values <= self.x_max)\n\n        # at least two points for both gt and pred\n        gt_lanes = [gt_lanes[k] for k in range(cnt_gt) if np.sum(gt_visibility_mat[k, :]) > 1]\n        gt_category = [gt_category[k] for k in range(cnt_gt) if np.sum(gt_visibility_mat[k, :]) > 1]\n        gt_visibility_mat = gt_visibility_mat[np.sum(gt_visibility_mat, axis=-1) > 1, :]\n        cnt_gt = len(gt_lanes)\n\n        pred_lanes = [pred_lanes[k] for k in range(cnt_pred) if np.sum(pred_visibility_mat[k, :]) > 1]\n        pred_category = [pred_category[k] for k in range(cnt_pred) if np.sum(pred_visibility_mat[k, :]) > 1]\n        pred_visibility_mat = pred_visibility_mat[np.sum(pred_visibility_mat, axis=-1) > 1, :]\n        cnt_pred = len(pred_lanes)\n\n        adj_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)\n        cost_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)\n        cost_mat.fill(1000)\n        num_match_mat = np.zeros((cnt_gt, cnt_pred), dtype=float)\n        x_dist_mat_close = np.zeros((cnt_gt, cnt_pred), dtype=float)\n        x_dist_mat_close.fill(1000.)\n        x_dist_mat_far = np.zeros((cnt_gt, cnt_pred), dtype=float)\n        x_dist_mat_far.fill(1000.)\n        z_dist_mat_close = np.zeros((cnt_gt, cnt_pred), dtype=float)\n        z_dist_mat_close.fill(1000.)\n        z_dist_mat_far = np.zeros((cnt_gt, cnt_pred), dtype=float)\n        z_dist_mat_far.fill(1000.)\n\n        # compute curve to curve distance\n        for i in range(cnt_gt):\n            for j in range(cnt_pred):\n                x_dist = np.abs(gt_lanes[i][:, 0] - pred_lanes[j][:, 0])\n                z_dist = np.abs(gt_lanes[i][:, 1] - pred_lanes[j][:, 1])\n\n                # apply visibility to penalize different partial matching accordingly\n                both_visible_indices = np.logical_and(gt_visibility_mat[i, :] >= 0.5, pred_visibility_mat[j, :] >= 0.5)\n                both_invisible_indices = np.logical_and(gt_visibility_mat[i, :] < 0.5, pred_visibility_mat[j, :] < 0.5)\n                other_indices = np.logical_not(np.logical_or(both_visible_indices, both_invisible_indices))\n                \n                euclidean_dist = np.sqrt(x_dist ** 2 + z_dist ** 2)\n                euclidean_dist[both_invisible_indices] = 0\n                euclidean_dist[other_indices] = self.dist_th\n\n                # if np.average(euclidean_dist) < 2*self.dist_th: # don't prune here to encourage finding perfect match\n                num_match_mat[i,
 j] = np.sum(euclidean_dist < self.dist_th) - np.sum(both_invisible_indices)\n                adj_mat[i, j] = 1\n                # ATTENTION: use the sum as int type to meet the requirements of min cost flow optimization (int type)\n                # using num_match_mat as cost does not work?\n                cost_ = np.sum(euclidean_dist)\n                if 0 < cost_ < 1:\n                    cost_ = 1\n                else:\n                    cost_ = int(cost_)\n                cost_mat[i, j] = cost_\n\n                # use the both visible portion to calculate distance error\n                if np.sum(both_visible_indices[:close_range_idx]) > 0:\n                    x_dist_mat_close[i, j] = np.sum(\n                        x_dist[:close_range_idx] * both_visible_indices[:close_range_idx]) / np.sum(\n                        both_visible_indices[:close_range_idx])\n                    z_dist_mat_close[i, j] = np.sum(\n                        z_dist[:close_range_idx] * both_visible_indices[:close_range_idx]) / np.sum(\n                        both_visible_indices[:close_range_idx])\n                else:\n                    x_dist_mat_close[i, j] = -1\n                    z_dist_mat_close[i, j] = -1\n\n                if np.sum(both_visible_indices[close_range_idx:]) > 0:\n                    x_dist_mat_far[i, j] = np.sum(\n                        x_dist[close_range_idx:] * both_visible_indices[close_range_idx:]) / np.sum(\n                        both_visible_indices[close_range_idx:])\n                    z_dist_mat_far[i, j] = np.sum(\n                        z_dist[close_range_idx:] * both_visible_indices[close_range_idx:]) / np.sum(\n                        both_visible_indices[close_range_idx:])\n                else:\n                    x_dist_mat_far[i, j] = -1\n                    z_dist_mat_far[i, j] = -1\n\n        # solve bipartite matching via a min cost flow solver\n        match_results = SolveMinCostFlow(adj_mat, cost_mat)\n        match_results = np.array(match_results)\n\n        # only a match with avg cost < self.dist_th is considered a valid one\n        match_gt_ids = []\n        match_pred_ids = []\n        match_num = 0\n        if match_results.shape[0] > 0:\n            for i in range(len(match_results)):\n                if match_results[i, 2] < self.dist_th * self.y_samples.shape[0]:\n                    match_num += 1\n                    gt_i = match_results[i, 0]\n                    pred_i = match_results[i, 1]\n                    # count a match when the ratio of matched points is above the threshold\n                    if num_match_mat[gt_i, pred_i] / np.sum(gt_visibility_mat[gt_i, :]) >= self.ratio_th:\n                        r_lane += 1\n                        match_gt_ids.append(gt_i)\n                    if num_match_mat[gt_i, pred_i] / np.sum(pred_visibility_mat[pred_i, :]) >= self.ratio_th:\n                        p_lane += 1\n                        match_pred_ids.append(pred_i)\n\n                    if pred_category != []:\n                        if self.args.num_category == 2:\n                            true_pred = pred_category[pred_i] == 1 \n                        else:\n                            true_pred = (pred_category[pred_i] == gt_category[gt_i]) or (pred_category[pred_i]==20 and gt_category[gt_i]==21)\n                        if true_pred:\n                            c_lane += 1    # category matched num\n                    x_error_close.append(x_dist_mat_close[gt_i, pred_i])\n                    x_error_far.append(x_dist_mat_far[gt_i, pred_i])\n                    z_error_close.append(z_dist_mat_close[gt_i, pred_i])\n                    z_error_far.append(z_dist_mat_far[gt_i, pred_i])\n        # # Visualization to be added\n        # if vis:\n        #     pass \n        return r_lane, p_lane, c_lane, cnt_gt, cnt_pred, match_num, x_error_close, x_error_far, z_error_close, z_error_far\n
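\n    # Illustrative helper (not called by the pipeline): the F-score formula\n    # applied to the counts returned by bench() above, mirroring the\n    # aggregation performed in bench_one_submit_ddp() below.\n    @staticmethod\n    def _f_score(r_lane, p_lane, cnt_gt, cnt_pred):\n        recall = r_lane / (cnt_gt + 1e-6)\n        precision = p_lane / (cnt_pred + 1e-6)\n        return 2 * recall * precision / (recall + precision + 1e-6)\n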
\n\n    def bench_one_submit(self, pred_dir, gt_dir, test_txt, prob_th=0.5, vis=False):\n        test_list = open(test_txt).readlines()\n\n        json_pred = []\n        json_gt = []\n\n        print(\"Loading pred json ...\")\n        for file_path in test_list:\n            pred_json_path = pred_dir + file_path.strip('\\n').replace('jpg', 'json')\n\n            with open(pred_json_path, 'r') as fp:\n                json_pred.append(json.load(fp))\n\n        print(\"Loading gt json ...\")\n        for file_path in test_list:\n            gt_json_path = gt_dir + file_path.strip('\\n').replace('jpg', 'json')\n\n            with open(gt_json_path, 'r') as fp:\n                json_gt.append(json.load(fp))\n                \n        if len(json_gt) != len(json_pred):\n            raise Exception('The number of predictions does not match the number of test samples')\n\n        gts = {l['file_path']: l for l in json_gt}\n\n        laneline_stats = []\n        laneline_x_error_close = []\n        laneline_x_error_far = []\n        laneline_z_error_close = []\n        laneline_z_error_far = []\n        for i, pred in enumerate(json_pred):\n            if i % 1000 == 0 or i == len(json_pred)-1:\n                self.logger.info('eval:{}/{}'.format(i+1,len(json_pred)))\n            if 'file_path' not in pred or 'lane_lines' not in pred:\n                raise Exception('file_path or lane_lines not in some predictions.')\n            raw_file = pred['file_path']\n\n            pred_lanelines = pred['lane_lines']\n            pred_lanes = [np.array(lane['xyz']) for i, lane in enumerate(pred_lanelines)]\n            pred_category = [int(lane['category']) for i, lane in enumerate(pred_lanelines)]\n            \n            if raw_file not in gts:\n                raise Exception('Some raw_file from your predictions does not exist in the test tasks.')\n            gt = gts[raw_file]\n\n            # evaluate lanelines\n            cam_extrinsics = np.array(gt['extrinsic'])\n            # Re-calculate extrinsic matrix based on ground coordinate\n            R_vg = np.array([[0, 1, 0],\n                                [-1, 0, 0],\n                                [0, 0, 1]], dtype=float)\n            R_gc = np.array([[1, 0, 0],\n                                [0, 0, 1],\n                                [0, -1, 0]], dtype=float)\n            cam_extrinsics[:3, :3] = np.matmul(np.matmul(\n                                        np.matmul(np.linalg.inv(R_vg), cam_extrinsics[:3, :3]),\n                                            R_vg), R_gc)\n            gt_cam_height = cam_extrinsics[2, 3]\n            gt_cam_pitch = 0\n\n            cam_extrinsics[0:2, 3] = 0.0\n            # cam_extrinsics[2, 3] = gt_cam_height\n\n            cam_intrinsics = gt['intrinsic']\n            cam_intrinsics = np.array(cam_intrinsics)\n\n            try:\n                gt_lanes_packed = gt['lane_lines']\n            except KeyError:\n                print(\"error 'lane_lines' in gt: \", gt['file_path'])\n                gt_lanes_packed = []\n\n            gt_lanes, gt_visibility, gt_category = [], [], []\n            for j, gt_lane_packed in enumerate(gt_lanes_packed):\n
                # A GT lane can be either 2D or 3D\n                # if a GT lane is 3D, the height is intact from 3D GT, so keep it intact here too\n                lane = np.array(gt_lane_packed['xyz'])\n                lane_visibility = np.array(gt_lane_packed['visibility'])\n\n                lane = np.vstack((lane, np.ones((1, lane.shape[1]))))\n                cam_representation = np.linalg.inv(\n                                        np.array([[0, 0, 1, 0],\n                                                  [-1, 0, 0, 0],\n                                                  [0, -1, 0, 0],\n                                                  [0, 0, 0, 1]], dtype=float))\n                lane = np.matmul(cam_extrinsics, np.matmul(cam_representation, lane))\n                lane = lane[0:3, :].T\n\n                gt_lanes.append(lane)\n                gt_visibility.append(lane_visibility)\n                gt_category.append(gt_lane_packed['category'])\n            if self.is_apollo:\n                P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, cam_intrinsics)\n            else:\n                P_g2im = projection_g2im_extrinsic(cam_extrinsics, cam_intrinsics)\n\n            # N to N matching of lanelines\n            r_lane, p_lane, c_lane, cnt_gt, cnt_pred, match_num, \\\n            x_error_close, x_error_far, \\\n            z_error_close, z_error_far = self.bench(pred_lanes,\n                                                    pred_category, \n                                                    gt_lanes,\n                                                    gt_visibility,\n                                                    gt_category,\n                                                    raw_file,\n                                                    gt_cam_height,\n                                                    gt_cam_pitch,\n                                                    vis,\n                                                    P_g2im)\n            laneline_stats.append(np.array([r_lane, p_lane, c_lane, cnt_gt, cnt_pred, match_num]))\n\n            laneline_x_error_close.extend(x_error_close)\n            laneline_x_error_far.extend(x_error_far)\n            laneline_z_error_close.extend(z_error_close)\n            laneline_z_error_far.extend(z_error_far)\n\n\n        output_stats = []\n        laneline_stats = np.array(laneline_stats)\n        laneline_x_error_close = np.array(laneline_x_error_close)\n        laneline_x_error_far = np.array(laneline_x_error_far)\n        laneline_z_error_close = np.array(laneline_z_error_close)\n        laneline_z_error_far = np.array(laneline_z_error_far)\n\n        if np.sum(laneline_stats[:, 3])!= 0:\n            R_lane = np.sum(laneline_stats[:, 0]) / (np.sum(laneline_stats[:, 3]))\n        else:\n            R_lane = np.sum(laneline_stats[:, 0]) / (np.sum(laneline_stats[:, 3]) + 1e-6)   # recall = TP / (TP+FN)\n        if np.sum(laneline_stats[:, 4]) != 0:\n            P_lane = np.sum(laneline_stats[:, 1]) / (np.sum(laneline_stats[:, 4]))\n        else:\n            P_lane = np.sum(laneline_stats[:, 1]) / (np.sum(laneline_stats[:, 4]) + 1e-6)   # precision = TP / (TP+FP)\n        if np.sum(laneline_stats[:, 5]) != 0:\n            C_lane = np.sum(laneline_stats[:, 2]) / (np.sum(laneline_stats[:, 5]))\n        else:\n            C_lane = np.sum(laneline_stats[:, 2]) / (np.sum(laneline_stats[:, 5]) + 1e-6)   # category_accuracy\n        if R_lane + P_lane != 0:\n            F_lane = 2 * R_lane * P_lane / (R_lane + P_lane)\n        else:\n            F_lane = 2 * R_lane *
 P_lane / (R_lane + P_lane + 1e-6)\n        x_error_close_avg = np.average(laneline_x_error_close[laneline_x_error_close > -1 + 1e-6])\n        x_error_far_avg = np.average(laneline_x_error_far[laneline_x_error_far > -1 + 1e-6])\n        z_error_close_avg = np.average(laneline_z_error_close[laneline_z_error_close > -1 + 1e-6])\n        z_error_far_avg = np.average(laneline_z_error_far[laneline_z_error_far > -1 + 1e-6])\n\n        output_stats.append(F_lane)\n        output_stats.append(R_lane)\n        output_stats.append(P_lane)\n        output_stats.append(C_lane)\n        output_stats.append(x_error_close_avg)\n        output_stats.append(x_error_far_avg)\n        output_stats.append(z_error_close_avg)\n        output_stats.append(z_error_far_avg)\n        output_stats.append(np.sum(laneline_stats[:, 0]))   # 8: matched gt lanes (recall TP)\n        output_stats.append(np.sum(laneline_stats[:, 1]))   # 9: matched pred lanes (precision TP)\n        output_stats.append(np.sum(laneline_stats[:, 2]))   # 10: correctly categorized matches\n        output_stats.append(np.sum(laneline_stats[:, 3]))   # 11: total gt lanes\n        output_stats.append(np.sum(laneline_stats[:, 4]))   # 12: total pred lanes\n        output_stats.append(np.sum(laneline_stats[:, 5]))   # 13: total matched pairs\n\n        return output_stats\n\n    \n    # compare predicted set and ground-truth set using a fixed lane probability threshold\n    def bench_one_submit_ddp(self, pred_lines_sub, gt_lines_sub, model_name, prob_th=0.5, vis=False):\n        json_gt = gt_lines_sub\n        json_pred = pred_lines_sub\n\n        gts = {l['file_path']: l for l in json_gt}\n\n        laneline_stats = []\n        laneline_x_error_close = []\n        laneline_x_error_far = []\n        laneline_z_error_close = []\n        laneline_z_error_far = []\n\n        gt_num_all, pred_num_all = 0, 0\n        for i, pred in enumerate(json_pred):\n            if 'file_path' not in pred or 'pred_laneLines' not in pred:\n                raise Exception('file_path or pred_laneLines not in some predictions.')\n            raw_file = pred['file_path']\n\n            pred_lanes = pred['pred_laneLines']\n            pred_lanes_prob = pred['pred_laneLines_prob']\n            if model_name == \"GenLaneNet\":\n                pred_lanes = [pred_lanes[ii] for ii in range(len(pred_lanes_prob)) if\n                              pred_lanes_prob[ii] > prob_th]\n                pred_category = np.zeros(len(pred_lanes_prob))\n            # Note: non-lane class is already filtered out in compute_3d_lanes_all_category()\n            else:\n                pred_lanes = [pred_lanes[ii] for ii in range(len(pred_lanes_prob)) if max(pred_lanes_prob[ii]) > prob_th]\n                pred_lanes_prob = [prob for k, prob in enumerate(pred_lanes_prob) if max(prob) > prob_th]\n\n                if pred_lanes_prob:\n                    pred_category = np.argmax(pred_lanes_prob, 1)\n                else:\n                    pred_category = []\n\n            if raw_file not in gts:\n                raise Exception('Some raw_file from your predictions does not exist in the test tasks.')\n            \n            gt = gts[raw_file]\n\n            # evaluate lanelines\n            assert 'extrinsic' in gt and 'intrinsic' in gt\n            cam_extrinsics = np.array(gt['extrinsic'])\n            cam_intrinsics = gt['intrinsic']\n            cam_intrinsics = np.array(cam_intrinsics)\n            # else:\n            #     assert all(cam_param in gt for cam_param in ['cam_pitch', 'cam_height']), \"without 'extrinsic' in gt json_file dict, AND not 'cam_height' & 'cam_pitch' as well.\"\n            #     cam_extrinsics = \n            #
     cam_intrinsics = np.array(self.args.K)\n            \n            if self.is_apollo:\n                gt_cam_height = gt['cam_height']\n                gt_cam_pitch = gt['cam_pitch']\n                gt_lanes_packed = gt['laneLines']\n                gt_lane_visibility = gt['laneLines_visibility']\n\n                gt_lanes, gt_visibility, gt_category = [], [], []\n                \n                for j, gt_lane_packed in enumerate(gt_lanes_packed):\n                    lane = np.array(gt_lane_packed)\n                    lane_visibility = np.array(gt_lane_visibility[j])\n                    gt_lanes.append(lane)\n                    gt_visibility.append(lane_visibility)\n                    gt_category.append(1)\n            else:\n                # Re-calculate extrinsic matrix based on ground coordinate\n                R_vg = np.array([[0, 1, 0],\n                                    [-1, 0, 0],\n                                    [0, 0, 1]], dtype=float)\n                R_gc = np.array([[1, 0, 0],\n                                    [0, 0, 1],\n                                    [0, -1, 0]], dtype=float)\n                cam_extrinsics[:3, :3] = np.matmul(np.matmul(\n                                            np.matmul(np.linalg.inv(R_vg), cam_extrinsics[:3, :3]),\n                                                R_vg), R_gc)\n                gt_cam_height = cam_extrinsics[2, 3]\n                gt_cam_pitch = 0\n\n                cam_extrinsics[0:2, 3] = 0.0\n                # cam_extrinsics[2, 3] = gt_cam_height\n\n                cam_intrinsics = gt['intrinsic']\n                cam_intrinsics = np.array(cam_intrinsics)\n                try:\n                    gt_lanes_packed = gt['lane_lines']\n                except KeyError:\n                    print(\"error 'lane_lines' in gt: \", gt['file_path'])\n                    gt_lanes_packed = []\n\n                gt_lanes, gt_visibility, gt_category = [], [], []\n                for j, gt_lane_packed in enumerate(gt_lanes_packed):\n                    # A GT lane can be either 2D or 3D\n                    # if a GT lane is 3D, the height is intact from 3D GT, so keep it intact here too\n                    lane = np.array(gt_lane_packed['xyz'])\n                    lane_visibility = np.array(gt_lane_packed['visibility'])\n\n                    lane = np.vstack((lane, np.ones((1, lane.shape[1]))))\n                    cam_representation = np.linalg.inv(\n                                            np.array([[0, 0, 1, 0],\n                                                    [-1, 0, 0, 0],\n                                                    [0, -1, 0, 0],\n                                                    [0, 0, 0, 1]], dtype=float))\n                    lane = np.matmul(cam_extrinsics, np.matmul(cam_representation, lane))\n                    lane = lane[0:3, :].T\n\n                    gt_lanes.append(lane)\n                    gt_visibility.append(lane_visibility)\n                    gt_category.append(gt_lane_packed['category'])\n            \n            \n            if self.is_apollo:\n                P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, cam_intrinsics)\n            else:\n                P_g2im = projection_g2im_extrinsic(cam_extrinsics, cam_intrinsics)\n\n            # N to N matching of lanelines\n            gt_num_all += len(gt_lanes)\n            pred_num_all += len(pred_lanes)\n            r_lane, p_lane, c_lane, cnt_gt, cnt_pred, match_num, \\\n            x_error_close, x_error_far, \\\n            
z_error_close, z_error_far = self.bench(pred_lanes,\n                                                    pred_category, \n                                                    gt_lanes,\n                                                    gt_visibility,\n                                                    gt_category,\n                                                    raw_file,\n                                                    gt_cam_height,\n                                                    gt_cam_pitch,\n                                                    vis,\n                                                    P_g2im)\n            laneline_stats.append(np.array([r_lane, p_lane, c_lane, cnt_gt, cnt_pred, match_num]))\n            # consider x_error z_error only for the matched lanes\n            # if r_lane > 0 and p_lane > 0:\n            laneline_x_error_close.extend(x_error_close)\n            laneline_x_error_far.extend(x_error_far)\n            laneline_z_error_close.extend(z_error_close)\n            laneline_z_error_far.extend(z_error_far)\n\n        output_stats = []\n        laneline_stats = np.array(laneline_stats)\n\n        laneline_x_error_close = np.array(laneline_x_error_close)\n        laneline_x_error_far = np.array(laneline_x_error_far)\n        laneline_z_error_close = np.array(laneline_z_error_close)\n        laneline_z_error_far = np.array(laneline_z_error_far)\n\n        self.logger.info(\"match num:\"+(str(np.sum(laneline_stats[:,5]))))\n        self.logger.info(\"cnt_gt_all:\"+str(gt_num_all))\n        self.logger.info(\"cnt_pred_all:\"+str(pred_num_all))\n        self.logger.info(\"cnt_gt_matched:\"+str(np.sum(laneline_stats[:,3])))\n        self.logger.info(\"cnt_pred_matched:\"+str(np.sum(laneline_stats[:,4])))\n\n        \n        if np.sum(laneline_stats[:, 3])!= 0:\n            R_lane = np.sum(laneline_stats[:, 0]) / (np.sum(laneline_stats[:, 3]))\n        else:\n            R_lane = np.sum(laneline_stats[:, 0]) / (np.sum(laneline_stats[:, 3]) + 1e-6)   # recall = TP / (TP+FN)\n        if np.sum(laneline_stats[:, 4]) != 0:\n            P_lane = np.sum(laneline_stats[:, 1]) / (np.sum(laneline_stats[:, 4]))\n        else:\n            P_lane = np.sum(laneline_stats[:, 1]) / (np.sum(laneline_stats[:, 4]) + 1e-6)   # precision = TP / (TP+FP)\n        if np.sum(laneline_stats[:, 5]) != 0:\n            C_lane = np.sum(laneline_stats[:, 2]) / (np.sum(laneline_stats[:, 5]))\n        else:\n            C_lane = np.sum(laneline_stats[:, 2]) / (np.sum(laneline_stats[:, 5]) + 1e-6)   # category_accuracy\n        if (R_lane + P_lane) != 0:\n            F_lane = 2 * R_lane * P_lane / (R_lane + P_lane)\n        else:\n            F_lane = 2 * R_lane * P_lane / (R_lane + P_lane + 1e-6)\n        \n        if laneline_x_error_close.shape[0] > 0:\n            x_error_close_avg = np.average(laneline_x_error_close[laneline_x_error_close > -1 + 1e-6])\n        else:\n            x_error_close_avg = -1\n        if laneline_x_error_far.shape[0] > 0:\n            x_error_far_avg = np.average(laneline_x_error_far[laneline_x_error_far > -1 + 1e-6])\n        else:\n            x_error_far_avg = -1\n        if laneline_z_error_close.shape[0] > 0:\n            z_error_close_avg = np.average(laneline_z_error_close[laneline_z_error_close > -1 + 1e-6])\n        else:\n            z_error_close_avg = -1\n        if laneline_z_error_far.shape[0] > 0:\n            z_error_far_avg = np.average(laneline_z_error_far[laneline_z_error_far > -1 + 1e-6])\n        else:\n            
z_error_far_avg = -1\n        \n        output_stats.append(F_lane)\n        output_stats.append(R_lane)\n        output_stats.append(P_lane)\n        output_stats.append(C_lane)\n        output_stats.append(x_error_close_avg)\n        output_stats.append(x_error_far_avg)\n        output_stats.append(z_error_close_avg)\n        output_stats.append(z_error_far_avg)\n        output_stats.append(np.sum(laneline_stats[:, 0]))   # 8: matched gt lanes (recall TP)\n        output_stats.append(np.sum(laneline_stats[:, 1]))   # 9: matched pred lanes (precision TP)\n        output_stats.append(np.sum(laneline_stats[:, 2]))   # 10: correctly categorized matches\n        output_stats.append(np.sum(laneline_stats[:, 3]))   # 11: total gt lanes\n        output_stats.append(np.sum(laneline_stats[:, 4]))   # 12: total pred lanes\n        output_stats.append(np.sum(laneline_stats[:, 5]))   # 13: total matched pairs\n\n        return output_stats\n
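\n\n# Minimal usage sketch (illustrative only): SimpleNamespace stands in for the\n# real config object, and every field value below is an assumption for\n# demonstration; pred_lines_sub / gt_lines_sub are lists of per-frame result\n# dicts in the format consumed by bench_one_submit_ddp above.\ndef _example_eval(pred_lines_sub, gt_lines_sub, logger):\n    from types import SimpleNamespace\n    args = SimpleNamespace(\n        dataset_name='openlane',\n        dataset_dir='/path/to/openlane/',\n        top_view_region=np.array([[-10, 103], [10, 103], [-10, 3], [10, 3]]),\n        num_category=21,\n    )\n    evaluator = LaneEval(args, logger)\n    stats = evaluator.bench_one_submit_ddp(pred_lines_sub, gt_lines_sub,\n                                           model_name='LATR', prob_th=0.5)\n    # stats[0:4] = F-score, recall, precision, category accuracy\n    return stats\n"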
  },
  {
    "path": "utils/eval_3D_lane_apollo.py",
    "content": "\"\"\"\nDescription: This code is to evaluate 3D lane detection. The optimal matching between ground-truth set and predicted\nset of lanes are sought via solving a min cost flow.\n\nEvaluation metrics includes:\n    Average Precision (AP)\n    Max F-scores\n    x error close (0 - 40 m)\n    x error far (0 - 100 m)\n    z error close (0 - 40 m)\n    z error far (0 - 100 m)\n\nReference: \"Gen-LaneNet: Generalized and Scalable Approach for 3D Lane Detection\". Y. Guo. etal. 2020\n\nAuthor: Yuliang Guo (33yuliangguo@gmail.com)\nDate: March, 2020\n\"\"\"\n\nimport numpy as np\nimport cv2\nimport os\nimport os.path as ops\nimport copy\nimport math\nimport ujson as json\nfrom scipy.interpolate import interp1d\nimport matplotlib\nfrom utils.utils import *\nfrom utils.MinCostFlow import SolveMinCostFlow\nfrom mpl_toolkits.mplot3d import Axes3D\n\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\n\nplt.rcParams['figure.figsize'] = (35, 30)\nplt.rcParams.update({'font.size': 25})\nplt.rcParams.update({'font.weight': 'semibold'})\n\ncolor = [[0, 0, 255],  # red\n         [0, 255, 0],  # green\n         [255, 0, 255],  # purple\n         [255, 255, 0]]  # cyan\n\nvis_min_y = 5\nvis_max_y = 80\n\n\nclass LaneEval(object):\n    def __init__(self, args, logger=None):\n        self.dataset_dir = args.dataset_dir\n        self.K = args.K\n        self.no_centerline = True\n        self.resize_h = args.resize_h\n        self.resize_w = args.resize_w\n        self.H_crop = homography_crop_resize([args.org_h, args.org_w], args.crop_y, [args.resize_h, args.resize_w])\n\n        self.x_min = args.top_view_region[0, 0]\n        self.x_max = args.top_view_region[1, 0]\n        self.y_min = args.top_view_region[2, 1]\n        self.y_max = args.top_view_region[0, 1]\n        self.y_samples = np.linspace(self.y_min, self.y_max, num=100, endpoint=False)\n        self.dist_th = 1.5\n        self.ratio_th = 0.75\n        self.close_range = 40\n        self.logger = logger\n        self.log_eval_info()\n\n    def log_eval_info(self):\n        self.logger.info('eval x range : [%s, %s]' % (self.x_min, self.x_max))\n        self.logger.info('eval y range : [%s, %s)' % (self.y_min, self.y_max))\n        self.logger.info('eval distance thresh: %s meter' % self.dist_th)\n        self.logger.info('eval points ratio : %s' % self.ratio_th)\n        self.logger.info('eval close range: %s' % self.close_range)\n\n    def bench(self, pred_lanes, gt_lanes, gt_visibility, raw_file, gt_cam_height, gt_cam_pitch, vis, ax1, ax2):\n        \"\"\"\n            Matching predicted lanes and ground-truth lanes in their IPM projection, ignoring z attributes.\n            x error, y_error, and z error are all considered, although the matching does not rely on z\n            The input of prediction and ground-truth lanes are in ground coordinate, x-right, y-forward, z-up\n            The fundamental assumption is: 1. there are no two points from different lanes with identical x, y\n                                              but different z's\n                                           2. 
\n\nclass LaneEval(object):\n    def __init__(self, args, logger=None):\n        self.dataset_dir = args.dataset_dir\n        self.K = args.K\n        self.no_centerline = True\n        self.resize_h = args.resize_h\n        self.resize_w = args.resize_w\n        self.H_crop = homography_crop_resize([args.org_h, args.org_w], args.crop_y, [args.resize_h, args.resize_w])\n\n        self.x_min = args.top_view_region[0, 0]\n        self.x_max = args.top_view_region[1, 0]\n        self.y_min = args.top_view_region[2, 1]\n        self.y_max = args.top_view_region[0, 1]\n        self.y_samples = np.linspace(self.y_min, self.y_max, num=100, endpoint=False)\n        self.dist_th = 1.5\n        self.ratio_th = 0.75\n        self.close_range = 40\n        self.logger = logger\n        self.log_eval_info()\n\n    def log_eval_info(self):\n        # guard against the default logger=None to avoid an AttributeError\n        if self.logger is None:\n            return\n        self.logger.info('eval x range : [%s, %s]' % (self.x_min, self.x_max))\n        self.logger.info('eval y range : [%s, %s)' % (self.y_min, self.y_max))\n        self.logger.info('eval distance thresh: %s meter' % self.dist_th)\n        self.logger.info('eval points ratio : %s' % self.ratio_th)\n        self.logger.info('eval close range: %s' % self.close_range)\n\n    def bench(self, pred_lanes, gt_lanes, gt_visibility, raw_file, gt_cam_height, gt_cam_pitch, vis, ax1, ax2):\n        \"\"\"\n            Matching predicted lanes and ground-truth lanes in their IPM projection, ignoring z attributes.\n            x error, y error, and z error are all considered, although the matching does not rely on z\n            The input of prediction and ground-truth lanes are in ground coordinate, x-right, y-forward, z-up\n            The fundamental assumption is: 1. there are no two points from different lanes with identical x, y\n                                              but different z's\n                                           2. there are no two points from a single lane having identical x, y\n                                              but different z's\n            If the area of interest is within the current drivable road, the above assumptions are almost always valid.\n\n        :param pred_lanes: N X 2 or N X 3 lists depending on 2D or 3D\n        :param gt_lanes: N X 2 or N X 3 lists depending on 2D or 3D\n        :param raw_file: file path rooted in dataset folder\n        :param gt_cam_height: camera height given in ground-truth data\n        :param gt_cam_pitch: camera pitch given in ground-truth data\n        :return:\n        \"\"\"\n\n        # index of the first y sample beyond the close range\n        close_range_idx = np.where(self.y_samples > self.close_range)[0][0]\n\n        r_lane, p_lane = 0., 0.\n        x_error_close = []\n        x_error_far = []\n        z_error_close = []\n        z_error_far = []\n\n        # only keep the visible portion\n        gt_lanes = [prune_3d_lane_by_visibility(np.array(gt_lane), np.array(gt_visibility[k])) for k, gt_lane in\n                    enumerate(gt_lanes)]\n        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]\n        # only consider those gt lanes overlapping with the sampling range\n        gt_lanes = [lane for lane in gt_lanes if lane[0, 1] < self.y_samples[-1] and lane[-1, 1] > self.y_samples[0]]\n        gt_lanes = [prune_3d_lane_by_range(np.array(gt_lane), 3 * self.x_min, 3 * self.x_max) for gt_lane in gt_lanes]\n        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]\n        cnt_gt = len(gt_lanes)\n        cnt_pred = len(pred_lanes)\n\n        gt_visibility_mat = np.zeros((cnt_gt, 100))\n        pred_visibility_mat = np.zeros((cnt_pred, 100))\n        # resample gt and pred at y_samples\n        for i in range(cnt_gt):\n            min_y = np.min(np.array(gt_lanes[i])[:, 1])\n            max_y = np.max(np.array(gt_lanes[i])[:, 1])\n            x_values, z_values, visibility_vec = resample_laneline_in_y(np.array(gt_lanes[i]), self.y_samples,\n                                                                        out_vis=True)\n            gt_lanes[i] = np.vstack([x_values, z_values]).T\n            gt_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min,\n                                                     np.logical_and(x_values <= self.x_max,\n                                                                    np.logical_and(self.y_samples >= min_y,\n                                                                                   self.y_samples <= max_y)))\n            gt_visibility_mat[i, :] = np.logical_and(gt_visibility_mat[i, :], visibility_vec)\n\n        for i in range(cnt_pred):\n            # # ATTENTION: ensure y mono increase before interpolation: but it can reduce size\n            # pred_lanes[i] = make_lane_y_mono_inc(np.array(pred_lanes[i]))\n            # pred_lane = prune_3d_lane_by_range(np.array(pred_lanes[i]), self.x_min, self.x_max)\n            min_y = np.min(np.array(pred_lanes[i])[:, 1])\n            max_y = np.max(np.array(pred_lanes[i])[:, 1])\n            x_values, z_values, visibility_vec = resample_laneline_in_y(np.array(pred_lanes[i]), self.y_samples,\n                                                                        out_vis=True)\n            pred_lanes[i] = np.vstack([x_values, z_values]).T\n            pred_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min,\n                                                       np.logical_and(x_values <= self.x_max,\n
                                                                      np.logical_and(self.y_samples >= min_y,\n                                                                                     self.y_samples <= max_y)))\n            pred_visibility_mat[i, :] = np.logical_and(pred_visibility_mat[i, :], visibility_vec)\n            # pred_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min, x_values <= self.x_max)\n\n        # np.int and np.float were removed in NumPy 1.24; use the built-in types\n        adj_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)\n        cost_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)\n        cost_mat.fill(1000)\n        num_match_mat = np.zeros((cnt_gt, cnt_pred), dtype=float)\n        x_dist_mat_close = np.zeros((cnt_gt, cnt_pred), dtype=float)\n        x_dist_mat_close.fill(1000.)\n        x_dist_mat_far = np.zeros((cnt_gt, cnt_pred), dtype=float)\n        x_dist_mat_far.fill(1000.)\n        z_dist_mat_close = np.zeros((cnt_gt, cnt_pred), dtype=float)\n        z_dist_mat_close.fill(1000.)\n        z_dist_mat_far = np.zeros((cnt_gt, cnt_pred), dtype=float)\n        z_dist_mat_far.fill(1000.)\n        # compute curve to curve distance\n        for i in range(cnt_gt):\n            for j in range(cnt_pred):\n                x_dist = np.abs(gt_lanes[i][:, 0] - pred_lanes[j][:, 0])\n                z_dist = np.abs(gt_lanes[i][:, 1] - pred_lanes[j][:, 1])\n                euclidean_dist = np.sqrt(x_dist ** 2 + z_dist ** 2)\n\n                # apply visibility to penalize different partial matching accordingly\n                euclidean_dist[\n                    np.logical_or(gt_visibility_mat[i, :] < 0.5, pred_visibility_mat[j, :] < 0.5)] = self.dist_th\n\n                # if np.average(euclidean_dist) < 2*self.dist_th: # don't prune here to encourage finding perfect match\n                num_match_mat[i, j] = np.sum(euclidean_dist < self.dist_th)\n                adj_mat[i, j] = 1\n                # ATTENTION: use the sum as int type to meet the requirements of min cost flow optimization (int type)\n                # using num_match_mat as cost does not work?\n                cost_mat[i, j] = np.sum(euclidean_dist).astype(int)\n                # cost_mat[i, j] = num_match_mat[i, j]\n\n                # use the both visible portion to calculate distance error\n                both_visible_indices = np.logical_and(gt_visibility_mat[i, :] > 0.5, pred_visibility_mat[j, :] > 0.5)\n                if np.sum(both_visible_indices[:close_range_idx]) > 0:\n                    x_dist_mat_close[i, j] = np.sum(\n                        x_dist[:close_range_idx] * both_visible_indices[:close_range_idx]) / np.sum(\n                        both_visible_indices[:close_range_idx])\n                    z_dist_mat_close[i, j] = np.sum(\n                        z_dist[:close_range_idx] * both_visible_indices[:close_range_idx]) / np.sum(\n                        both_visible_indices[:close_range_idx])\n                else:\n                    x_dist_mat_close[i, j] = self.dist_th\n                    z_dist_mat_close[i, j] = self.dist_th\n\n                if np.sum(both_visible_indices[close_range_idx:]) > 0:\n                    x_dist_mat_far[i, j] = np.sum(\n                        x_dist[close_range_idx:] * both_visible_indices[close_range_idx:]) / np.sum(\n                        both_visible_indices[close_range_idx:])\n                    z_dist_mat_far[i, j] = np.sum(\n                        z_dist[close_range_idx:] * both_visible_indices[close_range_idx:]) / np.sum(\n                        both_visible_indices[close_range_idx:])\n                else:\n
                    x_dist_mat_far[i, j] = self.dist_th\n                    z_dist_mat_far[i, j] = self.dist_th\n\n        # solve bipartite matching via a min cost flow solver\n        match_results = SolveMinCostFlow(adj_mat, cost_mat)\n        match_results = np.array(match_results)\n\n        # only a match with avg cost < self.dist_th is considered a valid one\n        match_gt_ids = []\n        match_pred_ids = []\n        \n        if match_results.shape[0] > 0:\n            for i in range(len(match_results)):\n                if match_results[i, 2] < self.dist_th * self.y_samples.shape[0]:\n                    gt_i = match_results[i, 0]\n                    pred_i = match_results[i, 1]\n                    # count a match when the ratio of matched points is above the threshold\n                    if num_match_mat[gt_i, pred_i] / np.sum(gt_visibility_mat[gt_i, :]) >= self.ratio_th:\n                        r_lane += 1\n                        match_gt_ids.append(gt_i)\n                    if num_match_mat[gt_i, pred_i] / np.sum(pred_visibility_mat[pred_i, :]) >= self.ratio_th:\n                        p_lane += 1\n                        match_pred_ids.append(pred_i)\n                    x_error_close.append(x_dist_mat_close[gt_i, pred_i])\n                    x_error_far.append(x_dist_mat_far[gt_i, pred_i])\n                    z_error_close.append(z_dist_mat_close[gt_i, pred_i])\n                    z_error_far.append(z_dist_mat_far[gt_i, pred_i])\n\n        # visualize lanelines and matching results both in image and 3D\n        if vis:\n            P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, self.K)\n            P_gt = np.matmul(self.H_crop, P_g2im)\n            img = cv2.imread(ops.join(self.dataset_dir, raw_file))\n            img = cv2.warpPerspective(img, self.H_crop, (self.resize_w, self.resize_h))\n            img = img.astype(float) / 255\n\n            for i in range(cnt_gt):\n                x_values = gt_lanes[i][:, 0]\n                z_values = gt_lanes[i][:, 1]\n                x_2d, y_2d = projective_transformation(P_gt, x_values, self.y_samples, z_values)\n                x_2d = x_2d.astype(int)\n                y_2d = y_2d.astype(int)\n\n                if i in match_gt_ids:\n                    color = [0, 0, 1]\n                else:\n                    color = [0, 1, 1]\n                for k in range(1, x_2d.shape[0]):\n                    # only draw the visible portion\n                    if gt_visibility_mat[i, k - 1] and gt_visibility_mat[i, k]:\n                        img = cv2.line(img, (x_2d[k - 1], y_2d[k - 1]), (x_2d[k], y_2d[k]), color[-1::-1], 3)\n                ax2.plot(x_values[np.where(gt_visibility_mat[i, :])],\n                         self.y_samples[np.where(gt_visibility_mat[i, :])],\n                         z_values[np.where(gt_visibility_mat[i, :])], color=color, linewidth=5)\n\n            for i in range(cnt_pred):\n                x_values = pred_lanes[i][:, 0]\n                z_values = pred_lanes[i][:, 1]\n                x_2d, y_2d = projective_transformation(P_gt, x_values, self.y_samples, z_values)\n                x_2d = x_2d.astype(int)\n                y_2d = y_2d.astype(int)\n\n                if i in match_pred_ids:\n                    color = [1, 0, 0]\n                else:\n                    color = [1, 0, 1]\n                for k in range(1, x_2d.shape[0]):\n                    # only draw the visible portion\n                    if pred_visibility_mat[i, k - 1] and pred_visibility_mat[i, k]:\n                        img = cv2.line(img, (x_2d[k - 1], y_2d[k - 1]), (x_2d[k], y_2d[k]), color[-1::-1], 2)\n                ax2.plot(x_values[np.where(pred_visibility_mat[i, :])],\n                         self.y_samples[np.where(pred_visibility_mat[i, :])],\n                         z_values[np.where(pred_visibility_mat[i, :])], color=color, linewidth=5)\n\n            cv2.putText(img, 'Recall: {:.3f}'.format(r_lane / (cnt_gt + 1e-6)),\n                        (5, 30), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, color=(0, 0, 1), thickness=2)\n            cv2.putText(img, 'Precision: {:.3f}'.format(p_lane / (cnt_pred + 1e-6)),\n                        (5, 60), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, color=(0, 0, 1), thickness=2)\n            ax1.imshow(img[:, :, [2, 1, 0]])\n\n        return r_lane, p_lane, cnt_gt, cnt_pred, x_error_close, x_error_far, z_error_close, z_error_far\n
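\n    # Illustrative helper (not called by the pipeline): collapses a per-lane\n    # class-probability vector to the single confidence score that\n    # bench_one_submit() below compares against prob_th.\n    @staticmethod\n    def _lane_score(prob_2cls):\n        max_cls = int(np.argmax(prob_2cls))\n        return prob_2cls[max_cls] if max_cls > 0 else prob_2cls[0]\n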
\n    # compare predicted set and ground-truth set using a fixed lane probability threshold\n    def bench_one_submit(self, pred_file, gt_file, prob_th=0.5, vis=False):\n        if vis:\n            save_path = pred_file[:pred_file.rfind('/')]\n            save_path += '/vis'\n            if not os.path.exists(save_path):\n                try:\n                    os.makedirs(save_path)\n                except OSError as e:\n                    # OSError has no .message attribute in Python 3\n                    print(e)\n        pred_lines = open(pred_file).readlines()\n        json_pred = [json.loads(line) for line in pred_lines]\n        json_gt = [json.loads(line) for line in open(gt_file).readlines()]\n        if len(json_gt) != len(json_pred):\n            raise Exception('The number of predictions does not match the number of test samples')\n        gts = {l['raw_file']: l for l in json_gt}\n\n        laneline_stats = []\n        laneline_x_error_close = []\n        laneline_x_error_far = []\n        laneline_z_error_close = []\n        laneline_z_error_far = []\n        centerline_stats = []\n        centerline_x_error_close = []\n        centerline_x_error_far = []\n        centerline_z_error_close = []\n        centerline_z_error_far = []\n\n        print('prob th: %s' % prob_th)\n        for i, pred in enumerate(json_pred):\n            if 'raw_file' not in pred or 'laneLines' not in pred:\n                raise Exception('raw_file or laneLines not in some predictions.')\n            raw_file = pred['raw_file']\n            pred_lanelines = pred['laneLines']\n            pred_laneLines_prob = pred['laneLines_prob']\n            pred_laneLines_prob_one_cls = []\n            for prob_2cls in pred_laneLines_prob:\n                max_cls = np.argmax(prob_2cls)\n                if max_cls > 0:\n                    pred_laneLines_prob_one_cls.append(prob_2cls[max_cls])\n                else:\n                    pred_laneLines_prob_one_cls.append(prob_2cls[0])\n            \n            pred_lanelines = [pred_lanelines[ii] for ii in range(len(pred_laneLines_prob_one_cls)) if\n                              pred_laneLines_prob_one_cls[ii] > prob_th]\n\n            if raw_file not in gts:\n                raise Exception('Some raw_file from your predictions does not exist in the test tasks.')\n            gt = gts[raw_file]\n            gt_cam_height = gt['cam_height']\n            gt_cam_pitch = gt['cam_pitch']\n\n            if vis:\n                fig = plt.figure()\n                ax1 = fig.add_subplot(221)\n                ax2 = fig.add_subplot(222, projection='3d')\n                ax3 = fig.add_subplot(223)\n                ax4 = fig.add_subplot(224, projection='3d')\n   
         else:\n                ax1 = 0\n                ax2 = 0\n                ax3 = 0\n                ax4 = 0\n\n            # evaluate lanelines\n            gt_lanelines = gt['laneLines']\n            gt_visibility = gt['laneLines_visibility']\n            # N to N matching of lanelines\n            r_lane, p_lane, cnt_gt, cnt_pred, \\\n            x_error_close, x_error_far, \\\n            z_error_close, z_error_far = self.bench(pred_lanelines,\n                                                    gt_lanelines,\n                                                    gt_visibility,\n                                                    raw_file,\n                                                    gt_cam_height,\n                                                    gt_cam_pitch,\n                                                    vis, ax1, ax2)\n            c_lane = 0.0\n            laneline_stats.append(np.array([r_lane, p_lane, cnt_gt, cnt_pred]))\n            laneline_x_error_close.extend(x_error_close)\n            laneline_x_error_far.extend(x_error_far)\n            laneline_z_error_close.extend(z_error_close)\n            laneline_z_error_far.extend(z_error_far)\n\n            # evaluate centerlines\n            if not self.no_centerline:\n                pred_centerlines = pred['centerLines']\n                pred_centerlines_prob = pred['centerLines_prob']\n                pred_centerlines = [pred_centerlines[ii] for ii in range(len(pred_centerlines_prob)) if\n                                    pred_centerlines_prob[ii] > prob_th]\n\n                gt_centerlines = gt['centerLines']\n                gt_visibility = gt['centerLines_visibility']\n\n                # N to N matching of lanelines\n                r_lane, p_lane, cnt_gt, cnt_pred, \\\n                x_error_close, x_error_far, \\\n                z_error_close, z_error_far = self.bench(pred_centerlines,\n                                                        gt_centerlines,\n                                                        gt_visibility,\n                                                        raw_file,\n                                                        gt_cam_height,\n                                                        gt_cam_pitch,\n                                                        vis, ax3, ax4)\n                centerline_stats.append(np.array([r_lane, p_lane, cnt_gt, cnt_pred]))\n                centerline_x_error_close.extend(x_error_close)\n                centerline_x_error_far.extend(x_error_far)\n                centerline_z_error_close.extend(z_error_close)\n                centerline_z_error_far.extend(z_error_far)\n\n            if vis:\n                ax1.set_xticks([])\n                ax1.set_yticks([])\n                # ax2.set_xlabel('x axis')\n                # ax2.set_ylabel('y axis')\n                # ax2.set_zlabel('z axis')\n                bottom, top = ax2.get_zlim()\n                left, right = ax2.get_xlim()\n                ax2.set_zlim(min(bottom, -0.1), max(top, 0.1))\n                ax2.set_xlim(left, right)\n                ax2.set_ylim(vis_min_y, vis_max_y)\n                ax2.locator_params(nbins=5, axis='x')\n                ax2.locator_params(nbins=5, axis='z')\n                ax2.tick_params(pad=18)\n\n                ax3.set_xticks([])\n                ax3.set_yticks([])\n                # ax4.set_xlabel('x axis')\n                # ax4.set_ylabel('y axis')\n                # ax4.set_zlabel('z axis')\n                bottom, top = 
ax4.get_zlim()\n                left, right = ax4.get_xlim()\n                ax4.set_zlim(min(bottom, -0.1), max(top, 0.1))\n                ax4.set_xlim(left, right)\n                ax4.set_ylim(vis_min_y, vis_max_y)\n                ax4.locator_params(nbins=5, axis='x')\n                ax4.locator_params(nbins=5, axis='z')\n                ax4.tick_params(pad=18)\n\n                fig.subplots_adjust(wspace=0, hspace=0.01)\n                fig.savefig(ops.join(save_path, raw_file.replace(\"/\", \"_\")))\n                plt.close(fig)\n                print('processed sample: {}  {}'.format(i, raw_file))\n\n        output_stats = []\n        laneline_stats = np.array(laneline_stats)\n        laneline_x_error_close = np.array(laneline_x_error_close)\n        laneline_x_error_far = np.array(laneline_x_error_far)\n        laneline_z_error_close = np.array(laneline_z_error_close)\n        laneline_z_error_far = np.array(laneline_z_error_far)\n\n        R_lane = np.sum(laneline_stats[:, 0]) / (np.sum(laneline_stats[:, 2]) + 1e-6)\n        P_lane = np.sum(laneline_stats[:, 1]) / (np.sum(laneline_stats[:, 3]) + 1e-6)\n        F_lane = 2 * R_lane * P_lane / (R_lane + P_lane + 1e-6)\n        x_error_close_avg = np.average(laneline_x_error_close)\n        x_error_far_avg = np.average(laneline_x_error_far)\n        z_error_close_avg = np.average(laneline_z_error_close)\n        z_error_far_avg = np.average(laneline_z_error_far)\n\n        output_stats.append(F_lane)\n        output_stats.append(R_lane)\n        output_stats.append(P_lane)\n        # output_stats.append(0.0)\n        output_stats.append(x_error_close_avg)\n        output_stats.append(x_error_far_avg)\n        output_stats.append(z_error_close_avg)\n        output_stats.append(z_error_far_avg)\n\n        # raw counts kept for multi-GPU result reduction\n        output_stats.append(np.sum(laneline_stats[:, 0]))   # 7: matched gt lanes\n        output_stats.append(np.sum(laneline_stats[:, 1]))   # 8: matched pred lanes\n        output_stats.append(np.sum(laneline_stats[:, 2]))   # 9: total gt lanes\n        output_stats.append(np.sum(laneline_stats[:, 3]))   # 10: total pred lanes\n        \n\n        if not self.no_centerline:\n            centerline_stats = np.array(centerline_stats)\n            centerline_x_error_close = np.array(centerline_x_error_close)\n            centerline_x_error_far = np.array(centerline_x_error_far)\n            centerline_z_error_close = np.array(centerline_z_error_close)\n            centerline_z_error_far = np.array(centerline_z_error_far)\n\n            R_lane = np.sum(centerline_stats[:, 0]) / (np.sum(centerline_stats[:, 2]) + 1e-6)\n            P_lane = np.sum(centerline_stats[:, 1]) / (np.sum(centerline_stats[:, 3]) + 1e-6)\n            F_lane = 2 * R_lane * P_lane / (R_lane + P_lane + 1e-6)\n            x_error_close_avg = np.average(centerline_x_error_close)\n            x_error_far_avg = np.average(centerline_x_error_far)\n            z_error_close_avg = np.average(centerline_z_error_close)\n            z_error_far_avg = np.average(centerline_z_error_far)\n\n            output_stats.append(F_lane)\n            output_stats.append(R_lane)\n            output_stats.append(P_lane)\n            output_stats.append(x_error_close_avg)\n            output_stats.append(x_error_far_avg)\n            output_stats.append(z_error_close_avg)\n            output_stats.append(z_error_far_avg)\n\n        return output_stats\n    \n    def bench_one_submit_ddp(self, pred_lines_sub, gt_lines_sub, model_name, prob_th=0.5, vis=False):\n        gts = {l['file_path']: l for l in gt_lines_sub}\n\n
        laneline_stats = []\n        laneline_x_error_close = []\n        laneline_x_error_far = []\n        laneline_z_error_close = []\n        laneline_z_error_far = []\n\n        gt_num_all, pred_num_all = 0, 0\n        for i, pred in enumerate(pred_lines_sub):\n            if 'file_path' not in pred or 'pred_laneLines' not in pred:\n                raise Exception('file_path or pred_laneLines not in some predictions.')\n            raw_file = pred['file_path']\n\n            pred_lanes = pred['pred_laneLines']\n            pred_lanes_prob = pred['pred_laneLines_prob']\n\n            pred_lanes = [pred_lanes[ii] for ii in range(len(pred_lanes_prob)) if max(pred_lanes_prob[ii]) > prob_th]\n            pred_lanes_prob = [prob for prob in pred_lanes_prob if max(prob) > prob_th]\n\n            if len(pred_lanes_prob) > 0:\n                pred_category = np.argmax(pred_lanes_prob, 1)\n            else:\n                pred_category = []\n\n            if raw_file not in gts:\n                raise Exception('Some raw_file from your predictions do not exist in the test tasks.')\n\n            gt = gts[raw_file]\n\n            gt_cam_height = gt['cam_height']\n            gt_cam_pitch = gt['cam_pitch']\n            cam_intrinsics = gt['intrinsic']\n            gt_lanes_packed = gt['laneLines']\n            gt_lane_visibility = gt['laneLines_visibility']\n\n            gt_lanes, gt_visibility, gt_category = [], [], []\n            for j, gt_lane_packed in enumerate(gt_lanes_packed):\n                lane = np.array(gt_lane_packed)\n                lane_visibility = np.array(gt_lane_visibility[j])\n                gt_lanes.append(lane)\n                gt_visibility.append(lane_visibility)\n                gt_category.append(1)\n\n            P_g2im = projection_g2im(gt_cam_pitch, gt_cam_height, cam_intrinsics)\n\n            if vis:\n                fig = plt.figure()\n                ax1 = fig.add_subplot(221)\n                ax2 = fig.add_subplot(222, projection='3d')\n                ax3 = fig.add_subplot(223)\n                ax4 = fig.add_subplot(224, projection='3d')\n            else:\n                ax1 = 0\n                ax2 = 0\n                ax3 = 0\n                ax4 = 0\n\n            # N to N matching of lanelines\n            gt_num_all += len(gt_lanes)\n            pred_num_all += len(pred_lanes)\n\n            r_lane, p_lane, cnt_gt, cnt_pred, \\\n            x_error_close, x_error_far, \\\n            z_error_close, z_error_far = self.bench(pred_lanes,\n                                                    gt_lanes,\n                                                    gt_visibility,\n                                                    raw_file,\n                                                    gt_cam_height,\n                                                    gt_cam_pitch,\n                                                    vis, ax1, ax2)\n            laneline_stats.append(np.array([r_lane, p_lane, cnt_gt, cnt_pred]))\n            laneline_x_error_close.extend(x_error_close)\n            laneline_x_error_far.extend(x_error_far)\n            laneline_z_error_close.extend(z_error_close)\n            laneline_z_error_far.extend(z_error_far)\n\n        output_stats = []\n        laneline_stats = np.array(laneline_stats)\n\n        laneline_x_error_close = np.array(laneline_x_error_close)\n        laneline_x_error_far = np.array(laneline_x_error_far)\n
        laneline_z_error_close = np.array(laneline_z_error_close)\n        laneline_z_error_far = np.array(laneline_z_error_far)\n\n        R_lane = np.sum(laneline_stats[:, 0]) / (np.sum(laneline_stats[:, 2]) + 1e-6)   # recall = TP / (TP+FN)\n        P_lane = np.sum(laneline_stats[:, 1]) / (np.sum(laneline_stats[:, 3]) + 1e-6)   # precision = TP / (TP+FP)\n        F_lane = 2 * R_lane * P_lane / (R_lane + P_lane + 1e-6)\n\n        x_error_close_avg = np.average(laneline_x_error_close)\n        x_error_far_avg = np.average(laneline_x_error_far)\n        z_error_close_avg = np.average(laneline_z_error_close)\n        z_error_far_avg = np.average(laneline_z_error_far)\n\n        output_stats.append(F_lane)\n        output_stats.append(R_lane)\n        output_stats.append(P_lane)\n        output_stats.append(x_error_close_avg)\n        output_stats.append(x_error_far_avg)\n        output_stats.append(z_error_close_avg)\n        output_stats.append(z_error_far_avg)\n\n        # raw sums for cross-GPU aggregation (same tail layout as the single-process path above)\n        output_stats.append(np.sum(laneline_stats[:, 0]))\n        output_stats.append(np.sum(laneline_stats[:, 1]))\n        output_stats.append(np.sum(laneline_stats[:, 2]))\n        output_stats.append(np.sum(laneline_stats[:, 3]))\n\n        return output_stats\n\n\n    def bench_PR(self, pred_lanes, gt_lanes, gt_visibility):\n        \"\"\"\n            Matching predicted lanes and ground-truth lanes in their IPM projection, ignoring z attributes.\n            x error, y error, and z error are all considered, although the matching does not rely on z\n            The input of prediction and ground-truth lanes are in ground coordinate, x-right, y-forward, z-up\n            The fundamental assumption is: 1. there are no two points from different lanes with identical x, y\n                                              but different z's\n
                                           2. there are no two points from a single lane having identical x, y\n                                              but different z's\n            If the interest area is within the current drivable road, the above assumptions are almost always valid.\n\n        :param pred_lanes: N X 2 or N X 3 lists depending on 2D or 3D\n        :param gt_lanes: N X 2 or N X 3 lists depending on 2D or 3D\n        :return:\n        \"\"\"\n\n        r_lane, p_lane = 0., 0.\n\n        # only keep the visible portion\n        gt_lanes = [prune_3d_lane_by_visibility(np.array(gt_lane), np.array(gt_visibility[k])) for k, gt_lane in\n                    enumerate(gt_lanes)]\n        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]\n        # only consider those gt lanes overlapping with sampling range\n        gt_lanes = [lane for lane in gt_lanes if lane[0, 1] < self.y_samples[-1] and lane[-1, 1] > self.y_samples[0]]\n        gt_lanes = [prune_3d_lane_by_range(np.array(gt_lane), 3 * self.x_min, 3 * self.x_max) for gt_lane in gt_lanes]\n        gt_lanes = [lane for lane in gt_lanes if lane.shape[0] > 1]\n        cnt_gt = len(gt_lanes)\n        cnt_pred = len(pred_lanes)\n\n        gt_visibility_mat = np.zeros((cnt_gt, len(self.y_samples)))\n        pred_visibility_mat = np.zeros((cnt_pred, len(self.y_samples)))\n        # resample gt and pred at y_samples\n        for i in range(cnt_gt):\n            min_y = np.min(np.array(gt_lanes[i])[:, 1])\n            max_y = np.max(np.array(gt_lanes[i])[:, 1])\n            x_values, z_values, visibility_vec = resample_laneline_in_y(np.array(gt_lanes[i]), self.y_samples,\n                                                                        out_vis=True)\n            gt_lanes[i] = np.vstack([x_values, z_values]).T\n            gt_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min,\n                                                     np.logical_and(x_values <= self.x_max,\n                                                                    np.logical_and(self.y_samples >= min_y,\n                                                                                   self.y_samples <= max_y)))\n            gt_visibility_mat[i, :] = np.logical_and(gt_visibility_mat[i, :], visibility_vec)\n\n        for i in range(cnt_pred):\n            # # ATTENTION: ensure y mono increase before interpolation: but it can reduce size\n            # pred_lanes[i] = make_lane_y_mono_inc(np.array(pred_lanes[i]))\n            # pred_lane = prune_3d_lane_by_range(np.array(pred_lanes[i]), self.x_min, self.x_max)\n            min_y = np.min(np.array(pred_lanes[i])[:, 1])\n            max_y = np.max(np.array(pred_lanes[i])[:, 1])\n            x_values, z_values, visibility_vec = resample_laneline_in_y(np.array(pred_lanes[i]), self.y_samples,\n                                                                        out_vis=True)\n            pred_lanes[i] = np.vstack([x_values, z_values]).T\n            pred_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min,\n                                                       np.logical_and(x_values <= self.x_max,\n                                                                      np.logical_and(self.y_samples >= min_y,\n                                                                                     self.y_samples <= max_y)))\n            pred_visibility_mat[i, :] = np.logical_and(pred_visibility_mat[i, :], visibility_vec)\n            # pred_visibility_mat[i, :] = np.logical_and(x_values >= self.x_min, x_values <= self.x_max)\n\n
        adj_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)\n        cost_mat = np.zeros((cnt_gt, cnt_pred), dtype=int)\n        cost_mat.fill(1000)\n        num_match_mat = np.zeros((cnt_gt, cnt_pred), dtype=float)\n        # compute curve to curve distance\n        for i in range(cnt_gt):\n            for j in range(cnt_pred):\n                x_dist = np.abs(gt_lanes[i][:, 0] - pred_lanes[j][:, 0])\n                z_dist = np.abs(gt_lanes[i][:, 1] - pred_lanes[j][:, 1])\n                euclidean_dist = np.sqrt(x_dist ** 2 + z_dist ** 2)\n\n                # apply visibility to penalize different partial matching accordingly\n                euclidean_dist[\n                    np.logical_or(gt_visibility_mat[i, :] < 0.5, pred_visibility_mat[j, :] < 0.5)] = self.dist_th\n\n                # if np.average(euclidean_dist) < 2*self.dist_th: # don't prune here to encourage finding perfect match\n                num_match_mat[i, j] = np.sum(euclidean_dist < self.dist_th)\n                adj_mat[i, j] = 1\n                # ATTENTION: use the sum as int type to meet the requirements of min cost flow optimization (int type)\n                # why using num_match_mat as cost does not work?\n                cost_mat[i, j] = np.sum(euclidean_dist).astype(int)\n                # cost_mat[i, j] = num_match_mat[i, j]\n\n        # solve bipartite matching via min cost flow solver\n        match_results = SolveMinCostFlow(adj_mat, cost_mat)\n        match_results = np.array(match_results)\n\n        # only a match with avg cost < self.dist_th is considered valid\n        match_gt_ids = []\n        match_pred_ids = []\n        if match_results.shape[0] > 0:\n            for i in range(len(match_results)):\n                if match_results[i, 2] < self.dist_th * self.y_samples.shape[0]:\n                    gt_i = match_results[i, 0]\n                    pred_i = match_results[i, 1]\n                    # count a match only when the fraction of matched points is above ratio_th\n                    if num_match_mat[gt_i, pred_i] / np.sum(gt_visibility_mat[gt_i, :]) >= self.ratio_th:\n                        r_lane += 1\n                        match_gt_ids.append(gt_i)\n                    if num_match_mat[gt_i, pred_i] / np.sum(pred_visibility_mat[pred_i, :]) >= self.ratio_th:\n                        p_lane += 1\n                        match_pred_ids.append(pred_i)\n\n        return r_lane, p_lane, cnt_gt, cnt_pred\n\n    # evaluate the two datasets at varying lane probability thresholds to calculate AP\n    def bench_one_submit_varying_probs(self, pred_file, gt_file, eval_out_file=None, eval_fig_file=None):\n        varying_th = np.linspace(0.05, 0.95, 19)\n        # try:\n        pred_lines = open(pred_file).readlines()\n        json_pred = [json.loads(line) for line in pred_lines]\n        # except BaseException as e:\n        #     raise Exception('Fail to load json file of the prediction.')\n        json_gt = [json.loads(line) for line in open(gt_file).readlines()]\n        if len(json_gt) != len(json_pred):\n            raise Exception('We do not get the predictions of all the test tasks')\n        gts = {l['raw_file']: l for l in json_gt}\n\n        laneline_r_all = []\n        laneline_p_all = []\n        laneline_gt_cnt_all = []\n        laneline_pred_cnt_all = []\n        centerline_r_all = []\n        centerline_p_all = []\n        centerline_gt_cnt_all = []\n        centerline_pred_cnt_all = []\n        for i, pred in enumerate(json_pred):\n
            self.logger.info('Evaluating sample {} / {}'.format(i, len(json_pred)))\n            if 'raw_file' not in pred or 'laneLines' not in pred:\n                raise Exception('raw_file or laneLines not in some predictions.')\n            raw_file = pred['raw_file']\n\n            pred_lanelines = pred['laneLines']\n            pred_laneLines_prob = pred['laneLines_prob']\n            if raw_file not in gts:\n                raise Exception('Some raw_file from your predictions do not exist in the test tasks.')\n            gt = gts[raw_file]\n            gt_cam_height = gt['cam_height']\n            gt_cam_pitch = gt['cam_pitch']\n\n            # evaluate lanelines\n            gt_lanelines = gt['laneLines']\n            gt_visibility = gt['laneLines_visibility']\n            r_lane_vec = []\n            p_lane_vec = []\n            cnt_gt_vec = []\n            cnt_pred_vec = []\n\n            for prob_th in varying_th:\n                # filter from the original lists so each threshold is applied independently\n                pred_lanelines_tmp = [pred_lanelines[ii] for ii in range(len(pred_laneLines_prob)) if\n                                      pred_laneLines_prob[ii][1] > prob_th]\n                pred_laneLines_prob_tmp = [prob[1] for prob in pred_laneLines_prob if prob[1] > prob_th]\n\n                pred_lanelines_copy = copy.deepcopy(pred_lanelines_tmp)\n                # N to N matching of lanelines\n                r_lane, p_lane, cnt_gt, cnt_pred = self.bench_PR(pred_lanelines_copy,\n                                                                 gt_lanelines,\n                                                                 gt_visibility)\n                r_lane_vec.append(r_lane)\n                p_lane_vec.append(p_lane)\n                cnt_gt_vec.append(cnt_gt)\n                cnt_pred_vec.append(cnt_pred)\n\n            laneline_r_all.append(r_lane_vec)\n            laneline_p_all.append(p_lane_vec)\n            laneline_gt_cnt_all.append(cnt_gt_vec)\n            laneline_pred_cnt_all.append(cnt_pred_vec)\n\n            # evaluate centerlines\n            if not self.no_centerline:\n                pred_centerlines = pred['centerLines']\n                pred_centerLines_prob = pred['centerLines_prob']\n                gt_centerlines = gt['centerLines']\n                gt_visibility = gt['centerLines_visibility']\n                r_lane_vec = []\n                p_lane_vec = []\n                cnt_gt_vec = []\n                cnt_pred_vec = []\n\n                for prob_th in varying_th:\n                    # use tmp lists here as well; re-assigning pred_centerlines would\n                    # shrink the candidate pool cumulatively across thresholds\n                    pred_centerlines_tmp = [pred_centerlines[ii] for ii in range(len(pred_centerLines_prob)) if\n                                            pred_centerLines_prob[ii] > prob_th]\n                    pred_centerlines_copy = copy.deepcopy(pred_centerlines_tmp)\n                    # N to N matching of centerlines\n                    r_lane, p_lane, cnt_gt, cnt_pred = self.bench_PR(pred_centerlines_copy,\n                                                                     gt_centerlines,\n                                                                     gt_visibility)\n                    r_lane_vec.append(r_lane)\n                    p_lane_vec.append(p_lane)\n                    cnt_gt_vec.append(cnt_gt)\n                    cnt_pred_vec.append(cnt_pred)\n                centerline_r_all.append(r_lane_vec)\n                centerline_p_all.append(p_lane_vec)\n                centerline_gt_cnt_all.append(cnt_gt_vec)\n                centerline_pred_cnt_all.append(cnt_pred_vec)\n\n
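        # P/R sums are accumulated per threshold across all samples; AP is computed\n        # below by interpolating precision over recall on the resulting PR curve.\n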
        output_stats = []\n        # compute precision, recall\n        laneline_r_all = np.array(laneline_r_all)\n        laneline_p_all = np.array(laneline_p_all)\n        laneline_gt_cnt_all = np.array(laneline_gt_cnt_all)\n        laneline_pred_cnt_all = np.array(laneline_pred_cnt_all)\n\n        R_lane = np.sum(laneline_r_all, axis=0) / (np.sum(laneline_gt_cnt_all, axis=0) + 1e-6)\n        P_lane = np.sum(laneline_p_all, axis=0) / (np.sum(laneline_pred_cnt_all, axis=0) + 1e-6)\n        F_lane = 2 * R_lane * P_lane / (R_lane + P_lane + 1e-6)\n\n        output_stats.append(F_lane)\n        output_stats.append(R_lane)\n        output_stats.append(P_lane)\n\n        # if self.args.dist:\n        #     output_stats.append(laneline_r_all)\n        #     output_stats.append(laneline_p_all)\n        #     output_stats.append(laneline_gt_cnt_all)\n        #     output_stats.append(laneline_pred_cnt_all)\n\n        if not self.no_centerline:\n            centerline_r_all = np.array(centerline_r_all)\n            centerline_p_all = np.array(centerline_p_all)\n            centerline_gt_cnt_all = np.array(centerline_gt_cnt_all)\n            centerline_pred_cnt_all = np.array(centerline_pred_cnt_all)\n\n            R_lane = np.sum(centerline_r_all, axis=0) / (np.sum(centerline_gt_cnt_all, axis=0) + 1e-6)\n            P_lane = np.sum(centerline_p_all, axis=0) / (np.sum(centerline_pred_cnt_all, axis=0) + 1e-6)\n            F_lane = 2 * R_lane * P_lane / (R_lane + P_lane + 1e-6)\n\n            output_stats.append(F_lane)\n            output_stats.append(R_lane)\n            output_stats.append(P_lane)\n\n        # calculate metrics\n        laneline_F = output_stats[0]\n        laneline_F_max = np.max(laneline_F)\n        laneline_max_i = np.argmax(laneline_F)\n        laneline_R = output_stats[1]\n        laneline_P = output_stats[2]\n\n        if not self.no_centerline:\n            centerline_F = output_stats[3]\n            # centerline metrics are reported at the threshold maximizing the laneline F-score\n            centerline_F_max = centerline_F[laneline_max_i]\n            centerline_max_i = laneline_max_i\n            centerline_R = output_stats[4]\n            centerline_P = output_stats[5]\n\n        laneline_R = np.array([1.] + laneline_R.tolist() + [0.])\n        laneline_P = np.array([0.] + laneline_P.tolist() + [1.])\n        f_laneline = interp1d(laneline_R, laneline_P)\n        r_range = np.linspace(0.05, 0.95, 19)\n        laneline_AP = np.mean(f_laneline(r_range))\n        if not self.no_centerline:\n            centerline_R = np.array([1.] + centerline_R.tolist() + [0.])\n            centerline_P = np.array([0.] + centerline_P.tolist() + [1.])\n
            f_centerline = interp1d(centerline_R, centerline_P)\n            centerline_AP = np.mean(f_centerline(r_range))\n\n        if eval_fig_file is not None:\n            # plot PR curve\n            fig = plt.figure()\n            if not self.no_centerline:\n                ax1 = fig.add_subplot(121)\n                ax2 = fig.add_subplot(122)\n                ax1.plot(laneline_R, laneline_P, '-s')\n                ax2.plot(centerline_R, centerline_P, '-s')\n            else:\n                ax1 = fig.add_subplot(111)\n                ax1.plot(laneline_R, laneline_P, '-s')\n\n            ax1.set_xlim(0, 1)\n            ax1.set_ylim(0, 1)\n            ax1.set_title('Lane Line')\n            ax1.set_xlabel('Recall')\n            ax1.set_ylabel('Precision')\n            ax1.set_aspect('equal')\n            # legend expects a sequence of labels; a bare string would be split per character\n            ax1.legend(['Max F-measure {:.3}'.format(laneline_F_max)])\n\n            if not self.no_centerline:\n                ax2.set_xlim(0, 1)\n                ax2.set_ylim(0, 1)\n                ax2.set_title('Center Line')\n                ax2.set_xlabel('Recall')\n                ax2.set_ylabel('Precision')\n                ax2.set_aspect('equal')\n                ax2.legend(['Max F-measure {:.3}'.format(centerline_F_max)])\n\n            # fig.subplots_adjust(wspace=0.1, hspace=0.01)\n            fig.savefig(eval_fig_file)\n            plt.close(fig)\n\n        json_out = {}\n        json_out['laneline_R'] = laneline_R[1:-1].astype(np.float32).tolist()\n        json_out['laneline_P'] = laneline_P[1:-1].astype(np.float32).tolist()\n        json_out['laneline_F_max'] = laneline_F_max\n        json_out['laneline_max_i'] = laneline_max_i.tolist()\n        json_out['laneline_AP'] = laneline_AP\n\n        if not self.no_centerline:\n            json_out['centerline_R'] = centerline_R[1:-1].astype(np.float32).tolist()\n            json_out['centerline_P'] = centerline_P[1:-1].astype(np.float32).tolist()\n            json_out['centerline_F_max'] = centerline_F_max\n            json_out['centerline_max_i'] = centerline_max_i.tolist()\n            json_out['centerline_AP'] = centerline_AP\n\n        json_out['max_F_prob_th'] = varying_th[laneline_max_i]\n\n        if eval_out_file is not None:\n            with open(eval_out_file, 'w') as jsonFile:\n                jsonFile.write(json.dumps(json_out))\n                jsonFile.write('\\n')\n        for k, v in json_out.items():\n            self.logger.info('%s: %s' % (k, v))\n        return json_out\n
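\n\ndef reduce_ddp_lane_stats(partial_stats):\n    \"\"\"Minimal usage sketch (illustrative; not part of the original evaluator).\n\n    bench_one_submit_ddp appends the raw sums (r_lane, p_lane, cnt_gt, cnt_pred)\n    at indices 7-10 of its output; a multi-GPU caller is assumed to gather the\n    per-rank lists and recompute the global metrics along these lines.\n    \"\"\"\n    sums = np.sum([np.asarray(stats[7:11]) for stats in partial_stats], axis=0)\n    r_sum, p_sum, gt_sum, pred_sum = sums\n    R = r_sum / (gt_sum + 1e-6)    # recall = TP / (TP + FN)\n    P = p_sum / (pred_sum + 1e-6)  # precision = TP / (TP + FP)\n    F = 2 * R * P / (R + P + 1e-6)\n    return F, R, P\n"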
  },
  {
    "path": "utils/eval_3D_once.py",
    "content": "import argparse\nimport numpy as np\nfrom multiprocessing import Process\nimport cv2\n# from jarvis.eload import load_json\nfrom munkres import Munkres\nimport os\nfrom shapely.geometry import LineString\n# from jarvis.edump import ptable_to_csv\n# from jarvis.epath import inherit\nimport time\n# from jarvis.eload import load_json\nimport tempfile\nimport json\nfrom prettytable import PrettyTable\nimport torch\n\n\nclass Bev_Projector:\n    def __init__(self, side_range, fwd_range, height_range, res, lane_width_x, lane_width_y):\n        self.side_range = side_range\n        self.fwd_range = fwd_range\n        self.height_range = height_range\n        self.res = res\n        self.lane_width_x = lane_width_x\n        self.lane_width_y = lane_width_y\n        self.zx_xmax = int((self.side_range[1] - self.side_range[0]) / self.res)\n        self.zx_ymax = int((self.fwd_range[1] - self.fwd_range[0]) / self.res)\n        self.zy_xmax = self.zx_ymax\n        self.zy_ymax = int((self.height_range[1] - self.height_range[0]) / self.res)\n\n    def proj_oneline_zx(self, one_lane):\n        \"\"\"\n        :param one_lane: N*3,[[x,y,z],...]\n        :return:\n        \"\"\"\n        img = np.zeros([self.zx_ymax, self.zx_xmax], dtype=np.uint8)\n\n        one_lane = np.array(one_lane)\n        one_lane = one_lane[one_lane[:, 2] < 10]\n        lane_x = one_lane[:, 0]\n        lane_z = one_lane[:, 2]\n\n        x_img = (lane_x / self.res).astype(np.int32)\n        y_img = (-lane_z / self.res).astype(np.int32)\n\n        x_img += int(self.side_range[1] / self.res)\n        y_img += int(self.fwd_range[1] / self.res)\n\n        # img[y_img, x_img] = 255\n        for i in range(y_img.shape[0]-1):\n            cv2.line(img, (x_img[i], y_img[i]), (x_img[i+1], y_img[i+1]), 255, self.lane_width_x)\n        return img\n\n\nclass LaneEval:\n    @staticmethod\n    def file_parser(gt_root_path, pred_root_path):\n        gt_files_list = list()\n        pred_files_list = list()\n        for segment in os.listdir(gt_root_path):\n            gt_segment_path = os.path.join(gt_root_path, segment)\n            gt_segment_path = os.path.join(gt_segment_path, 'cam01')\n            pred_segment_path = os.path.join(pred_root_path, segment)\n\n            if not os.path.exists(pred_segment_path):\n                assert False\n                print('%s Missed from pred' % segment)\n                continue\n            pred_segment_path = os.path.join(pred_segment_path, 'cam01')\n            gt_files_list.extend([os.path.join(gt_segment_path, filename) for filename in\n                                  os.listdir(gt_segment_path) if filename.endswith(\".json\")])\n            pred_files_list.extend([os.path.join(pred_segment_path, filename) for filename in\n                                    os.listdir(gt_segment_path) if filename.endswith(\".json\")])\n        \n        return gt_files_list, pred_files_list\n\n    @staticmethod\n    def summarize(res):\n        gt_all = 0.\n        pred_all = 0.\n        tp_all = 0.\n        distance_mean = 0.\n        for res_spec in res:\n            gt_all += res_spec[0]\n            pred_all += res_spec[1]\n            tp_all += res_spec[2]\n            distance_mean += res_spec[3]\n\n        precision = tp_all / pred_all\n        recall = tp_all / gt_all\n        if precision + recall == 0:\n            F_value = 0.\n        else:\n            F_value = 2 * precision * recall / (precision + recall)\n\n        distance_mean /= tp_all + 1e-5\n        return dict(\n          
            F_value=F_value,\n            precision=precision,\n            recall=recall,\n            distance_error=distance_mean,\n        )\n\n    def lane_evaluation(self, gt_root_path, pred_root_path, config_path, args=None):\n        gt_files_list, pred_files_list = self.file_parser(gt_root_path, pred_root_path)\n        with open(config_path, 'r') as file:\n            file_lines = [line for line in file]\n            if len(file_lines) != 0:\n                config = json.loads(file_lines[0])\n        # config = json.loads(config_path)\n        process_num = config['process_num']\n        score_l = int(config[\"score_l\"] * 100)\n        score_h = int(config[\"score_h\"] * 100)\n        score_step = int(config[\"score_step\"] * 100)\n        score_num = int((score_h - score_l) / score_step)\n        config['score_num'] = score_num\n\n        tempfile.tempdir = './tmp'\n        if not os.path.exists('./tmp'):\n            os.mkdir('./tmp')\n        tmp_dir = tempfile.mkdtemp()\n        if not os.path.exists(tmp_dir):\n            os.mkdir(tmp_dir)\n\n        gt_in_process = [[] for n in range(process_num)]\n        pr_in_process = [[] for n in range(process_num)]\n        n_file = 0\n        for gt_file, pred_file in zip(gt_files_list, pred_files_list):\n            gt_in_process[n_file % process_num].append(gt_file)\n            pr_in_process[n_file % process_num].append(pred_file)\n            n_file += 1\n        process_list = list()\n        for n in range(process_num):\n            tmp_file = os.path.join(tmp_dir, str(n) + \".json\")\n            config[\"tmp_file\"] = tmp_file\n            p = Process(target=evaluate_list, args=(gt_in_process[n], pr_in_process[n], config))\n            process_list.append(p)\n            p.start()\n        for p in process_list:\n            p.join()\n\n        torch.distributed.barrier()\n\n        gt_all = np.zeros((score_num,), dtype=np.float32)\n        pr_all = np.zeros((score_num,), dtype=np.float32)\n        tp_all = np.zeros((score_num,), dtype=np.float32)\n        distance_error = np.zeros((score_num,), dtype=np.float32)\n        for n in range(process_num):\n            tmp_file = os.path.join(tmp_dir, str(n) + \".json\")\n            json_data = json.load(open(tmp_file))\n            gt_all += json_data['gt_all']\n            pr_all += json_data['pr_all']\n            tp_all += json_data['tp_all']\n            distance_error += json_data['distance_error']\n\n        # small eps guards against empty splits\n        precision = tp_all / (pr_all + 1e-5)\n        recall = tp_all / (gt_all + 1e-5)\n\n        F_value = 2 * precision * recall / (precision + recall + 1e-5)\n\n        distance_error /= tp_all + 1e-5\n\n        pt = PrettyTable()\n        title_file = f'evaluate by {__file__}'\n        pt.title = f'{title_file}'\n        pt.field_names = ['prob_thresh', 'F1', 'precision', 'recall', 'D error']\n        for i in range(score_l, score_h, score_step):\n            index = int((i - score_l) / score_step)\n            pt.add_row([str(i / 100),\n                        F_value[index],\n                        precision[index],\n                        recall[index],\n                        distance_error[index]\n                        ])\n        if args.proc_id == 0:\n            print(pt)\n        result_dir = os.path.join(os.path.dirname(__file__), 'eval_results')\n        os.makedirs(result_dir, exist_ok=True)\n        result_file_name = config['exp_name']\n        # result_path = inherit(dirname=__file__, filename=result_file_name, middlename='eval_results', suffix='.csv')\n
        # ptable_to_csv(table=pt, filename=result_path)\n        print(f'''legacy evaluate end at {time.strftime('%Y-%m-%d @ %H:%M:%S')}''')\n\n        return F_value\n\n\ndef evaluate_list(gt_path_list, pred_path_list, config):\n    bev_projector = Bev_Projector(side_range=(config['side_range_l'], config['side_range_h']),\n                                  fwd_range=(config['fwd_range_l'], config['fwd_range_h']),\n                                  height_range=(config['height_range_l'], config['height_range_h']),\n                                  res=config['res'], lane_width_x=config['lane_width_x'],\n                                  lane_width_y=config['lane_width_y'])\n    score_num = config[\"score_num\"]\n    tmp_file = config[\"tmp_file\"]\n    iou_thresh = config['iou_thresh']\n    distance_thresh = config['distance_thresh']\n\n    score_l = int(config[\"score_l\"] * 100)\n    score_h = int(config[\"score_h\"] * 100)\n    score_step = int(config[\"score_step\"] * 100)\n\n    gt_all = np.zeros((score_num,), dtype=np.float32)\n    pr_all = np.zeros((score_num,), dtype=np.float32)\n    tp_all = np.zeros((score_num,), dtype=np.float32)\n    distance_error = np.zeros((score_num,), dtype=np.float32)\n\n    for gt_path, pred_path in zip(gt_path_list, pred_path_list):\n        leof = LaneEvalOneFile(gt_path, pred_path, bev_projector, iou_thresh, distance_thresh, score_l, score_h, score_step)\n        gt_num, pr_num, tp_num, distance_tmp = leof.eval()\n        gt_all += gt_num\n        pr_all += pr_num\n        tp_all += tp_num\n        distance_error += distance_tmp\n    json_out_data = {\"gt_all\": gt_all.tolist(), \"pr_all\": pr_all.tolist(),\n                     \"tp_all\": tp_all.tolist(), \"distance_error\": distance_error.tolist()}\n    fid_tmp_out = open(tmp_file, 'w')\n    json.dump(json_out_data, fid_tmp_out, indent=4)\n    fid_tmp_out.close()\n\n\nclass LaneEvalOneFile:\n    def __init__(self, gt_path, pred_path, bev_projector, iou_thresh, distance_thresh, score_l, score_h, score_step):\n        self.gt_path = gt_path\n        self.pred_path = pred_path\n        self.bev_projector = bev_projector\n        self.iou_thresh = iou_thresh\n        self.distance_thresh = distance_thresh\n        self.score_l = score_l\n        self.score_h = score_h\n        self.score_step = score_step\n\n    def preprocess(self, store_spec):\n        gt_json = json.load(open(self.gt_path))\n        pred_json = json.load(open(self.pred_path))\n        gt_lanes3d = gt_json['lanes']\n        gt_lanes3d = [gt_lanespec3d for gt_lanespec3d in gt_lanes3d if len(gt_lanespec3d) >= 2]\n        gt_num = len(gt_lanes3d)\n        pred_lanes3d = pred_json['lanes']\n        pred_lanes3d = [pred_lanespec3d[\"points\"] for pred_lanespec3d in pred_lanes3d if len(pred_lanespec3d) >= 2\n                        and float(pred_lanespec3d[\"score\"]) > store_spec]\n        # pred_lanes3d = [pred_lanespec3d for pred_lanespec3d in pred_lanes3d if len(pred_lanespec3d) >= 2]\n        pred_num = len(pred_lanes3d)\n        return gt_lanes3d, gt_num, pred_lanes3d, pred_num\n\n    def calc_iou(self, lane1, lane2):\n        \"\"\"\n        :param lane1: N*3 lane points [[x, y, z], ...]\n        :param lane2: N*3 lane points [[x, y, z], ...]\n        :return: IoU of the two lanes rasterized in BEV\n        \"\"\"\n        img1, img2 = self.bev_projector.proj_oneline_zx(lane1), self.bev_projector.proj_oneline_zx(lane2)\n\n        union_im = cv2.bitwise_or(img1, img2)\n        union_sum = union_im.sum()\n        inter_sum = img1.sum() + img2.sum() - union_sum\n        if union_sum == 0:\n            return 0\n        else:\n            return inter_sum / float(union_sum)\n
\n    def cal_mean_dist(self, src_line, dst_line):\n        \"\"\"\n        :param src_line: gt\n        :param dst_line: pred\n        :return: mean distance from samples on src_line to dst_line\n        \"\"\"\n        src_line = LineString(np.array(src_line))\n        dst_line = LineString(np.array(dst_line))\n\n        total_distance = 0\n        samples = np.arange(0.05, 1, 0.1)\n        for sample in samples:\n            total_distance += src_line.interpolate(sample, normalized=True).distance(dst_line)\n        mean_distance = total_distance / samples.shape[0]\n        return mean_distance\n\n    def sort_lanes_z(self, lanes):\n        sorted_lanes = list()\n        for lane_spec in lanes:\n            if lane_spec[0][-1] > lane_spec[1][-1]:\n                lane_spec = lane_spec[::-1]\n            sorted_lanes.append(lane_spec)\n        return sorted_lanes\n\n    def eval(self):\n        gt_num = list()\n        pred_num = list()\n        tp = list()\n        distance_error = list()\n        for store in range(self.score_l, self.score_h, self.score_step):\n            store_spec = store * 0.01\n            gt_lanes, gt_num_spec, pred_lanes, pred_num_spec = self.preprocess(store_spec)\n            gt_lanes = self.sort_lanes_z(gt_lanes)\n            pred_lanes = self.sort_lanes_z(pred_lanes)\n            tp_spec, distance_error_spec = self.cal_tp(gt_num_spec, pred_num_spec, gt_lanes, pred_lanes)\n            gt_num.append(gt_num_spec)\n            pred_num.append(pred_num_spec)\n            tp.append(tp_spec)\n            distance_error.append(distance_error_spec)\n        return gt_num, pred_num, tp, distance_error\n\n    def cal_tp(self, gt_num, pred_num, gt_lanes, pred_lanes):\n        tp = 0\n        distance_error = 0\n        if gt_num > 0 and pred_num > 0:\n            iou_mat = [[0 for col in range(pred_num)] for row in range(gt_num)]\n            for i in range(gt_num):\n                for j in range(pred_num):\n                    iou_mat[i][j] = self.calc_iou(gt_lanes[i], pred_lanes[j])\n            cost_mat = []\n            for row in iou_mat:\n                cost_row = list()\n                for col in row:\n                    cost_row.append(1.0 - col)\n                cost_mat.append(cost_row)\n            m = Munkres()\n            match_idx = m.compute(cost_mat)\n            for row, col in match_idx:\n                gt_lane = gt_lanes[row]\n                pred_lane = pred_lanes[col]\n                cur_distance = self.cal_mean_dist(gt_lane, pred_lane)\n                if cur_distance < self.distance_thresh:\n                    distance_error += cur_distance\n                    tp += 1\n        return tp, distance_error\n\n\ndef parse_config():\n    parser = argparse.ArgumentParser(description='arg parser')\n    parser.add_argument('--cfg_file', type=str, default='/home/dingzihan/PersFormer_3DLane/config/once_eval_config.json', help='specify the config for evaluation')\n    # parser.add_argument('--gt_path', type=str, default=None, required=True, help='')\n    # parser.add_argument('--pred_path', type=str, default=None, required=True, help='')\n    args, unknown_args = parser.parse_known_args()\n    return args, unknown_args\n
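\n\ndef _munkres_matching_sketch():\n    \"\"\"Minimal sketch (illustrative; not called by the evaluator): one-to-one lane\n    assignment via Munkres on a toy (1 - IoU) cost matrix, mirroring the matching\n    step in LaneEvalOneFile.cal_tp above. The toy costs are assumptions, not data.\n    \"\"\"\n    cost_mat = [[0.2, 0.9],  # rows: gt lanes, cols: pred lanes\n                [0.8, 0.1]]\n    m = Munkres()\n    return m.compute(cost_mat)  # -> [(0, 0), (1, 1)]\n"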
  },
  {
    "path": "utils/utils.py",
    "content": "# ==============================================================================\r\n# Copyright (c) 2022 The PersFormer Authors. All Rights Reserved.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n# ==============================================================================\r\n\r\nimport argparse\r\nimport errno\r\nimport os\r\nimport sys\r\nfrom pathlib import Path\r\n\r\nimport cv2\r\nimport matplotlib\r\nimport numpy as np\r\nimport torch\r\nimport torch.nn.init as init\r\nimport torch.optim\r\nfrom torch.optim import lr_scheduler\r\nimport os.path as ops\r\n\r\n\r\nfrom scipy.interpolate import interp1d\r\nfrom scipy.special import softmax\r\nimport logging, datetime\r\n\r\nfrom experiments.gpu_utils import is_main_process\r\nfrom mmdet.utils import get_root_logger as get_mmdet_root_logger\r\n\r\n\r\ndef create_logger(args):\r\n    datenow = datetime.datetime.now()\r\n    ymd = '-'.join(list(map(str, [datenow.year, datenow.month, datenow.day])))\r\n    hms = ':'.join(list(map(str, [datenow.hour, datenow.minute, datenow.second])))\r\n    logname = '%s_%s' % (ymd, hms)\r\n    logdir = os.path.join(args.save_path, 'logs')\r\n    os.makedirs(logdir, exist_ok=True)\r\n\r\n    ckpt_name = Path(args.eval_ckpt).stem.split('checkpoint_model_epoch_')[-1]\r\n    logtype = 'eval_{}'.format(ckpt_name)  if args.evaluate else 'train'\r\n    filename = os.path.join(logdir, '%s_%s.log' % (logtype, logname))\r\n\r\n    logging.basicConfig(level=logging.INFO, \r\n                        format ='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',\r\n                        datefmt='%a, %d-%b-%Y %H:%W:%S',\r\n                        filename=filename,\r\n                        filemode= 'w'\r\n                        )\r\n\r\n    # logger = logging.getLogger(filename)\r\n    logger = get_mmdet_root_logger(log_file=filename, log_level=logging.INFO)\r\n\r\n    return logger\r\n\r\n\r\ndef define_args():\r\n    parser = argparse.ArgumentParser(description='PersFormer_3DLane_Detection')\r\n    \r\n    # CUDNN usage\r\n    parser.add_argument(\"--cudnn\", type=str2bool, nargs='?', const=True, default=True, help=\"cudnn optimization active\")\r\n    \r\n    # DDP setting\r\n    parser.add_argument('--distributed', action='store_true')\r\n    parser.add_argument(\"--local_rank\", type=int)\r\n    parser.add_argument('--gpu', type=int, default = 0)\r\n    parser.add_argument('--world_size', type=int, default = 1)\r\n    parser.add_argument('--nodes', type=int, default = 1)\r\n    parser.add_argument('--eval_ckpt', type=str, default='')\r\n    parser.add_argument('--resume_from', type=str, default='')\r\n    parser.add_argument('--no_eval', action='store_true')\r\n\r\n    # General model settings\r\n    parser.add_argument('--nworkers', type=int, default=0, help='num of threads')\r\n    parser.add_argument('--test_mode', action='store_true', help='prevents loading latest saved model')\r\n    parser.add_argument('--start_epoch', type=int, 
    parser.add_argument('--start_epoch', type=int, default=0, help='epoch to start training from')\r\n    parser.add_argument('--evaluate', action='store_true', default=False, help='only perform evaluation')\r\n    parser.add_argument('--vis', action='store_true')\r\n    parser.add_argument('--resume', type=str, default='', help='resume latest saved run')\r\n    parser.add_argument('--output_dir', default='openlane', type=str,\r\n                        help='output_dir name under `work_dirs`')\r\n    parser.add_argument('--evaluate_case', default='', type=str,\r\n                        help='scene name, some are in short form')\r\n    parser.add_argument('--eval_freq', type=int, default=2,\r\n                        help='evaluation frequency during training, 0 means no eval')\r\n\r\n    # eval using gen-laneNet\r\n    parser.add_argument('--rewrite_pred', default=False, action='store_true', help='whether to rewrite an existing pred .json file')\r\n    parser.add_argument('--save_best', default=False, action='store_true', help='only save best ckpt')\r\n\r\n    # workdir\r\n    parser.add_argument('--save_root', default='work_dirs', type=str)\r\n    # dataset\r\n    parser.add_argument('--dataset', default='300', type=str, help='1000 | 300 openlane dataset')\r\n    return parser\r\n\r\n\r\ndef prune_3d_lane_by_visibility(lane_3d, visibility):\r\n    lane_3d = lane_3d[visibility > 0, ...]\r\n    return lane_3d\r\n\r\n\r\ndef prune_3d_lane_by_range(lane_3d, x_min, x_max):\r\n    # TODO: solve hard coded range later\r\n    # remove points with y out of range\r\n    # 3D labels may describe a very long straight lane with only two points; the gt only needs a minimum step, not 200 points\r\n    # 2D dataset requires this to rule out those points projected to ground, but out of meaningful range\r\n    lane_3d = lane_3d[np.logical_and(lane_3d[:, 1] > 0, lane_3d[:, 1] < 200), ...]\r\n\r\n    # remove lane points out of x range\r\n    lane_3d = lane_3d[np.logical_and(lane_3d[:, 0] > x_min,\r\n                                     lane_3d[:, 0] < x_max), ...]\r\n    return lane_3d\r\n\r\n\r\ndef resample_laneline_in_y(input_lane, y_steps, out_vis=False):\r\n    \"\"\"\r\n        Interpolate x, z values at each anchor grid, including those beyond the input lane's y range\r\n    :param input_lane: N x 2 or N x 3 ndarray, one row for a point (x, y, z-optional).\r\n                       It requires y values of input lane in ascending order\r\n    :param y_steps: a vector of steps in y\r\n    :param out_vis: whether to output visibility indicator which only depends on input y range\r\n    :return:\r\n    \"\"\"\r\n\r\n    # at least two points are included\r\n    assert(input_lane.shape[0] >= 2)\r\n\r\n    y_min = np.min(input_lane[:, 1])-5\r\n    y_max = np.max(input_lane[:, 1])+5\r\n\r\n    if input_lane.shape[1] < 3:\r\n        input_lane = np.concatenate([input_lane, np.zeros([input_lane.shape[0], 1], dtype=np.float32)], axis=1)\r\n\r\n    f_x = interp1d(input_lane[:, 1], input_lane[:, 0], fill_value=\"extrapolate\")\r\n    f_z = interp1d(input_lane[:, 1], input_lane[:, 2], fill_value=\"extrapolate\")\r\n\r\n    x_values = f_x(y_steps)\r\n    z_values = f_z(y_steps)\r\n\r\n    if out_vis:\r\n        output_visibility = np.logical_and(y_steps >= y_min, y_steps <= y_max)\r\n        return x_values, z_values, output_visibility.astype(np.float32) + 1e-9\r\n    return x_values, z_values\r\n\r\n\r\ndef resample_laneline_in_y_with_vis(input_lane, y_steps, vis_vec):\r\n    \"\"\"\r\n        Interpolate x, z values at each anchor grid, including those beyond the input lane's y range\r\n
    :param input_lane: N x 2 or N x 3 ndarray, one row for a point (x, y, z-optional).\r\n                       It requires y values of input lane in ascending order\r\n    :param y_steps: a vector of steps in y\r\n    :param vis_vec: per-point visibility of the input lane, interpolated onto y_steps\r\n    :return:\r\n    \"\"\"\r\n\r\n    # at least two points are included\r\n    assert(input_lane.shape[0] >= 2)\r\n\r\n    if input_lane.shape[1] < 3:\r\n        input_lane = np.concatenate([input_lane, np.zeros([input_lane.shape[0], 1], dtype=np.float32)], axis=1)\r\n\r\n    f_x = interp1d(input_lane[:, 1], input_lane[:, 0], fill_value=\"extrapolate\")\r\n    f_z = interp1d(input_lane[:, 1], input_lane[:, 2], fill_value=\"extrapolate\")\r\n    f_vis = interp1d(input_lane[:, 1], vis_vec, fill_value=\"extrapolate\")\r\n\r\n    x_values = f_x(y_steps)\r\n    z_values = f_z(y_steps)\r\n    vis_values = f_vis(y_steps)\r\n\r\n    x_values = x_values[vis_values > 0.5]\r\n    y_values = y_steps[vis_values > 0.5]\r\n    z_values = z_values[vis_values > 0.5]\r\n    return np.array([x_values, y_values, z_values]).T\r\n\r\n\r\ndef homograpthy_g2im(cam_pitch, cam_height, K):\r\n    # transform top-view region to original image region\r\n    R_g2c = np.array([[1, 0, 0],\r\n                      [0, np.cos(np.pi / 2 + cam_pitch), -np.sin(np.pi / 2 + cam_pitch)],\r\n                      [0, np.sin(np.pi / 2 + cam_pitch), np.cos(np.pi / 2 + cam_pitch)]])\r\n    H_g2im = np.matmul(K, np.concatenate([R_g2c[:, 0:2], [[0], [cam_height], [0]]], 1))\r\n    return H_g2im\r\n\r\n\r\ndef projection_g2im(cam_pitch, cam_height, K):\r\n    P_g2c = np.array([[1,                             0,                              0,          0],\r\n                      [0, np.cos(np.pi / 2 + cam_pitch), -np.sin(np.pi / 2 + cam_pitch), cam_height],\r\n                      [0, np.sin(np.pi / 2 + cam_pitch),  np.cos(np.pi / 2 + cam_pitch),          0]])\r\n    P_g2im = np.matmul(K, P_g2c)\r\n    return P_g2im\r\n\r\n\r\ndef homograpthy_g2im_extrinsic(E, K):\r\n    \"\"\"E: extrinsic matrix, 4*4\"\"\"\r\n    E_inv = np.linalg.inv(E)[0:3, :]\r\n    H_g2c = E_inv[:, [0,1,3]]\r\n    H_g2im = np.matmul(K, H_g2c)\r\n    return H_g2im\r\n\r\n\r\ndef projection_g2im_extrinsic(E, K):\r\n    E_inv = np.linalg.inv(E)[0:3, :]\r\n    P_g2im = np.matmul(K, E_inv)\r\n    return P_g2im\r\n\r\n\r\ndef homography_crop_resize(org_img_size, crop_y, resize_img_size):\r\n    \"\"\"\r\n        compute the homography matrix transform original image to cropped and resized image\r\n    :param org_img_size: [org_h, org_w]\r\n    :param crop_y:\r\n    :param resize_img_size: [resize_h, resize_w]\r\n    :return:\r\n    \"\"\"\r\n    # transform original image region to network input region\r\n    ratio_x = resize_img_size[1] / org_img_size[1]\r\n    ratio_y = resize_img_size[0] / (org_img_size[0] - crop_y)\r\n    H_c = np.array([[ratio_x, 0, 0],\r\n                    [0, ratio_y, -ratio_y*crop_y],\r\n                    [0, 0, 1]])\r\n    return H_c\r\n\r\n\r\ndef homographic_transformation(Matrix, x, y):\r\n    \"\"\"\r\n    Helper function to transform coordinates defined by transformation matrix\r\n\r\n    Args:\r\n            Matrix (multi dim - array): 3x3 homography matrix\r\n            x (array): original x coordinates\r\n            y (array): original y coordinates\r\n    \"\"\"\r\n    ones = np.ones((1, len(y)))\r\n    coordinates = np.vstack((x, y, ones))\r\n    trans = np.matmul(Matrix, coordinates)\r\n
\r\n    x_vals = trans[0, :]/trans[2, :]\r\n    y_vals = trans[1, :]/trans[2, :]\r\n    return x_vals, y_vals\r\n\r\n\r\ndef projective_transformation(Matrix, x, y, z):\r\n    \"\"\"\r\n    Helper function to transform coordinates defined by transformation matrix\r\n\r\n    Args:\r\n            Matrix (multi dim - array): 3x4 projection matrix\r\n            x (array): original x coordinates\r\n            y (array): original y coordinates\r\n            z (array): original z coordinates\r\n    \"\"\"\r\n    ones = np.ones((1, len(z)))\r\n    coordinates = np.vstack((x, y, z, ones))\r\n    trans = np.matmul(Matrix, coordinates)\r\n\r\n    x_vals = trans[0, :]/trans[2, :]\r\n    y_vals = trans[1, :]/trans[2, :]\r\n    return x_vals, y_vals\r\n\r\n\r\ndef first_run(save_path):\r\n    txt_file = os.path.join(save_path, 'first_run.txt')\r\n    if not os.path.exists(txt_file):\r\n        open(txt_file, 'w').close()\r\n    else:\r\n        saved_epoch = open(txt_file).read()\r\n        # read() returns an empty string (never None) for an empty file\r\n        if not saved_epoch:\r\n            print('You forgot to delete [first run file]')\r\n            return ''\r\n        return saved_epoch\r\n    return ''\r\n\r\n\r\ndef mkdir_if_missing(directory):\r\n    if not os.path.exists(directory):\r\n        try:\r\n            os.makedirs(directory)\r\n        except OSError as e:\r\n            if e.errno != errno.EEXIST:\r\n                raise\r\n\r\n\r\n# trick from stackoverflow\r\ndef str2bool(argument):\r\n    if argument.lower() in ('yes', 'true', 't', 'y', '1'):\r\n        return True\r\n    elif argument.lower() in ('no', 'false', 'f', 'n', '0'):\r\n        return False\r\n    else:\r\n        raise argparse.ArgumentTypeError('Wrong argument in argparse, should be a boolean')\r\n\r\n\r\nclass Logger(object):\r\n    \"\"\"\r\n    Source https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.\r\n    \"\"\"\r\n    def __init__(self, fpath=None):\r\n        self.console = sys.stdout\r\n        self.file = None\r\n        self.fpath = fpath\r\n        if fpath is not None:\r\n            mkdir_if_missing(os.path.dirname(fpath))\r\n            self.file = open(fpath, 'w')\r\n\r\n    def __del__(self):\r\n        self.close()\r\n\r\n    def __enter__(self):\r\n        pass\r\n\r\n    def __exit__(self, *args):\r\n        self.close()\r\n\r\n    def write(self, msg):\r\n        self.console.write(msg)\r\n        if self.file is not None:\r\n            self.file.write(msg)\r\n\r\n    def flush(self):\r\n        self.console.flush()\r\n        if self.file is not None:\r\n            self.file.flush()\r\n            os.fsync(self.file.fileno())\r\n\r\n    def close(self):\r\n        self.console.close()\r\n        if self.file is not None:\r\n            self.file.close()\r\n\r\n\r\nclass AverageMeter(object):\r\n    \"\"\"Computes and stores the average and current value\"\"\"\r\n    def __init__(self):\r\n        self.reset()\r\n\r\n    def reset(self):\r\n        self.val = 0\r\n        self.avg = 0\r\n        self.sum = 0\r\n        self.count = 0\r\n\r\n    def update(self, val, n=1):\r\n        self.val = val\r\n        self.sum += val * n\r\n        self.count += n\r\n        self.avg = self.sum / self.count\r\n\r\n\r\ndef define_optim(optim, params, lr, weight_decay):\r\n    if optim == 'adam':\r\n        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)\r\n    elif optim == 'adamw':\r\n        optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)\r\n    elif optim == 'sgd':\r\n
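        # momentum=0.9 below is hard-coded for sgd and rmsprop (assumed as a common default)\r\n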
        optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)\r\n    elif optim == 'rmsprop':\r\n        optimizer = torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=weight_decay)\r\n    else:\r\n        raise KeyError(\"The requested optimizer: {} is not implemented\".format(optim))\r\n    return optimizer\r\n\r\n\r\ndef cosine_schedule_with_warmup(k, args, dataset_size=None):\r\n    # k : iter num\r\n    num_gpu = args.world_size\r\n    batch_size = args.batch_size\r\n    num_epochs = args.nepochs\r\n\r\n    if num_gpu == 1:\r\n        warmup_iters = 0\r\n    else:\r\n        warmup_iters = 1000 // num_gpu\r\n\r\n    if k < warmup_iters:\r\n        return (k + 1) / warmup_iters\r\n    else:\r\n        iter_per_epoch = (dataset_size + batch_size - 1) // batch_size\r\n        return 0.5 * (1 + np.cos(np.pi * (k - warmup_iters) / (num_epochs * iter_per_epoch)))\r\n\r\n\r\ndef define_scheduler(optimizer, args, dataset_size=None):\r\n    if args.lr_policy == 'lambda':\r\n        def lambda_rule(epoch):\r\n            lr_l = 1.0 - max(0, epoch + 1 - args.niter) / float(args.niter_decay + 1)\r\n            return lr_l\r\n        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)\r\n    elif args.lr_policy == 'step':\r\n        scheduler = lr_scheduler.StepLR(optimizer,\r\n                                        step_size=args.lr_decay_iters, gamma=args.gamma)\r\n    elif args.lr_policy == 'multi_step':\r\n        # MultiStepLR takes a list of milestones, not a step_size\r\n        scheduler = lr_scheduler.MultiStepLR(\r\n            optimizer, milestones=args.lr_multi_steps, gamma=args.gamma)\r\n    elif args.lr_policy == 'cosine':\r\n        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,\r\n                                                   T_max=args.T_max, eta_min=args.eta_min)\r\n    elif args.lr_policy == 'cosine_warm':\r\n        '''\r\n        lr_config = dict(\r\n            policy='CosineAnnealing',\r\n            warmup='linear',\r\n            warmup_iters=500,\r\n            warmup_ratio=1.0 / 3,\r\n            min_lr_ratio=1e-3)\r\n        '''\r\n        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,\r\n                                                             T_0=args.T_0, T_mult=args.T_mult, eta_min=args.eta_min)\r\n\r\n    # elif args.lr_policy == 'plateau':\r\n    #     scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',\r\n    #                                                factor=args.gamma,\r\n    #                                                threshold=0.0001,\r\n    #                                                patience=args.lr_decay_iters)\r\n    elif args.lr_policy == 'cosine_warmup':\r\n        from functools import partial\r\n        cosine_warmup = partial(cosine_schedule_with_warmup, args=args, dataset_size=dataset_size)\r\n        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_warmup)\r\n\r\n    elif args.lr_policy == 'None':\r\n        scheduler = None\r\n    else:\r\n        raise NotImplementedError('learning rate policy [%s] is not implemented' % args.lr_policy)\r\n    return scheduler\r\n\r\n\r\ndef define_init_weights(model, init_w='normal', activation='relu'):\r\n    # print('Init weights in network with [{}]'.format(init_w))\r\n    if init_w == 'normal':\r\n        model.apply(weights_init_normal)\r\n    elif init_w == 'xavier':\r\n        model.apply(weights_init_xavier)\r\n    elif init_w == 'kaiming':\r\n        model.apply(weights_init_kaiming)\r\n    elif init_w == 'orthogonal':\r\n        model.apply(weights_init_orthogonal)\r\n
    else:\r\n        raise NotImplementedError('initialization method [{}] is not implemented'.format(init_w))\r\n\r\n\r\ndef weights_init_normal(m):\r\n    classname = m.__class__.__name__\r\n    # print(classname)\r\n    if classname.find('Conv') != -1 or classname.find('ConvTranspose') != -1:\r\n        try:\r\n            init.normal_(m.weight.data, 0.0, 0.02)\r\n            if m.bias is not None:\r\n                m.bias.data.zero_()\r\n        except Exception:\r\n            print(\"{} does not support init\".format(str(classname)))\r\n    elif classname.find('Linear') != -1:\r\n        try:\r\n            init.normal_(m.weight.data, 0.0, 0.02)\r\n            if m.bias is not None:\r\n                m.bias.data.zero_()\r\n        except Exception:\r\n            print(\"{} does not support init\".format(str(classname)))\r\n    elif classname.find('BatchNorm2d') != -1:\r\n        try:\r\n            init.normal_(m.weight.data, 1.0, 0.02)\r\n            init.constant_(m.bias.data, 0.0)\r\n        except Exception:\r\n            print(\"{} does not support init\".format(str(classname)))\r\n\r\n\r\ndef weights_init_xavier(m):\r\n    classname = m.__class__.__name__\r\n    # print(classname)\r\n    if classname.find('Conv') != -1 or classname.find('ConvTranspose') != -1:\r\n        init.xavier_normal_(m.weight.data, gain=0.02)\r\n        if m.bias is not None:\r\n            m.bias.data.zero_()\r\n    elif classname.find('Linear') != -1:\r\n        init.xavier_normal_(m.weight.data, gain=0.02)\r\n        if m.bias is not None:\r\n            m.bias.data.zero_()\r\n    elif classname.find('BatchNorm2d') != -1:\r\n        init.normal_(m.weight.data, 1.0, 0.02)\r\n        init.constant_(m.bias.data, 0.0)\r\n\r\n\r\ndef weights_init_kaiming(m):\r\n    classname = m.__class__.__name__\r\n    # print(classname)\r\n    if classname.find('Conv') != -1 or classname.find('ConvTranspose') != -1:\r\n        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu')\r\n        if m.bias is not None:\r\n            m.bias.data.zero_()\r\n    elif classname.find('Linear') != -1:\r\n        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu')\r\n        if m.bias is not None:\r\n            m.bias.data.zero_()\r\n    elif classname.find('BatchNorm2d') != -1:\r\n        init.normal_(m.weight.data, 1.0, 0.02)\r\n        init.constant_(m.bias.data, 0.0)\r\n\r\n\r\ndef weights_init_orthogonal(m):\r\n    classname = m.__class__.__name__\r\n    # print(classname)\r\n    if classname.find('Conv') != -1 or classname.find('ConvTranspose') != -1:\r\n        # init.orthogonal was removed from recent torch releases; use the in-place variant\r\n        init.orthogonal_(m.weight.data, gain=1)\r\n        if m.bias is not None:\r\n            m.bias.data.zero_()\r\n    elif classname.find('Linear') != -1:\r\n        init.orthogonal_(m.weight.data, gain=1)\r\n        if m.bias is not None:\r\n            m.bias.data.zero_()\r\n    elif classname.find('BatchNorm2d') != -1:\r\n        init.normal_(m.weight.data, 1.0, 0.02)\r\n        init.constant_(m.bias.data, 0.0)\r\n
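\r\n\r\nif __name__ == '__main__':\r\n    # Minimal sketch (illustrative): resample a toy 3D lane at fixed y anchors with\r\n    # resample_laneline_in_y; the toy points below are assumptions, not project data.\r\n    toy_lane = np.array([[0.0, 5.0, 0.0], [1.0, 50.0, 0.5]])\r\n    y_steps = np.linspace(5.0, 50.0, 10)\r\n    xs, zs, vis = resample_laneline_in_y(toy_lane, y_steps, out_vis=True)\r\n    print(xs.shape, zs.shape, vis.sum())\r\n"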
  },
  {
    "path": "work_dirs/.gitkeep",
    "content": ""
  }
]