Repository: ZJULearning/resa
Branch: main
Commit: 2d4e6312673e
Files: 48
Total size: 152.3 KB

Directory structure:
gitextract_wg61ge90/

├── .gitignore
├── INSTALL.md
├── LICENSE
├── README.md
├── configs/
│   ├── culane.py
│   └── tusimple.py
├── datasets/
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── culane.py
│   ├── registry.py
│   └── tusimple.py
├── main.py
├── models/
│   ├── __init__.py
│   ├── decoder.py
│   ├── registry.py
│   ├── resa.py
│   └── resnet.py
├── requirement.txt
├── runner/
│   ├── __init__.py
│   ├── evaluator/
│   │   ├── __init__.py
│   │   ├── culane/
│   │   │   ├── culane.py
│   │   │   ├── lane_evaluation/
│   │   │   │   ├── .gitignore
│   │   │   │   ├── Makefile
│   │   │   │   ├── include/
│   │   │   │   │   ├── counter.hpp
│   │   │   │   │   ├── hungarianGraph.hpp
│   │   │   │   │   ├── lane_compare.hpp
│   │   │   │   │   └── spline.hpp
│   │   │   │   └── src/
│   │   │   │       ├── counter.cpp
│   │   │   │       ├── evaluate.cpp
│   │   │   │       ├── lane_compare.cpp
│   │   │   │       └── spline.cpp
│   │   │   └── prob2lines.py
│   │   └── tusimple/
│   │       ├── getLane.py
│   │       ├── lane.py
│   │       └── tusimple.py
│   ├── logger.py
│   ├── net_utils.py
│   ├── optimizer.py
│   ├── recorder.py
│   ├── registry.py
│   ├── resa_trainer.py
│   ├── runner.py
│   └── scheduler.py
├── tools/
│   └── generate_seg_tusimple.py
└── utils/
    ├── __init__.py
    ├── config.py
    ├── registry.py
    └── transforms.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
work_dirs/
predicts/
output/
data/
data

__pycache__/
*/*.un~
.*.swp



*.egg-info/
*.egg

output.txt
.vscode/*
.DS_Store
tmp.*
*.pt
*.pth
*.un~


================================================
FILE: INSTALL.md
================================================

# Install

1. Clone the RESA repository
    ```
    git clone https://github.com/zjulearning/resa.git
    ```
    We refer to this directory as `$RESA_ROOT`.

2. Create a conda virtual environment and activate it (conda is optional)

    ```Shell
    conda create -n resa python=3.8 -y
    conda activate resa
    ```

3. Install dependencies

    ```Shell
    # Install PyTorch first; the cudatoolkit version should match the CUDA version installed on your system.
    conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

    # Or you can install via pip
    pip install torch torchvision

    # Install the Python packages
    pip install -r requirement.txt
    ```

4. Data preparation

    Download [CULane](https://xingangpan.github.io/projects/CULane.html) and [Tusimple](https://github.com/TuSimple/tusimple-benchmark/issues/3). Extract them to `$CULANEROOT` and `$TUSIMPLEROOT`, then create symbolic links to them in the `data` directory.
    
    ```Shell
    cd $RESA_ROOT
    mkdir -p data
    ln -s $CULANEROOT data/CULane
    ln -s $TUSIMPLEROOT data/tusimple
    ```

    Tusimple does not provide segmentation annotations, so they must be generated from the JSON annotations:

    ```Shell
    python tools/generate_seg_tusimple.py --root $TUSIMPLEROOT
    # this generates the seg_label directory
    ```

    For CULane, the directory structure should look like this:
    ```
    $RESA_ROOT/data/CULane/driver_xx_xxframe    # data folders x6
    $RESA_ROOT/data/CULane/laneseg_label_w16    # lane segmentation labels
    $RESA_ROOT/data/CULane/list                 # data lists
    ```

    For Tusimple, the directory structure should look like this:
    ```
    $RESA_ROOT/data/tusimple/clips # data folders
    $RESA_ROOT/data/tusimple/label_data_xxxx.json # label json file x4
    $RESA_ROOT/data/tusimple/test_tasks_0627.json # test tasks json file
    $RESA_ROOT/data/tusimple/test_label.json # test label json file
    ```

5. Install the CULane evaluation tool.

    This tool requires OpenCV for C++. Follow the [OpenCV installation guide](https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html) to install it, or simply run `sudo apt-get install libopencv-dev`.

    Then compile the CULane evaluation tool:
    ```Shell
    cd $RESA_ROOT/runner/evaluator/culane/lane_evaluation
    make
    cd -
    ```
    
    Note that the default OpenCV version is 3. If you use OpenCV 2, change `OPENCV_VERSION := 3` to `OPENCV_VERSION := 2` in the `Makefile`.
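
Step 4 derives Tusimple segmentation labels from the JSON annotations. In the TuSimple label files each line is one JSON object: `raw_file` is the image path, `h_samples` lists the sampled y coordinates, and `lanes` gives, for every lane, one x coordinate per entry of `h_samples`, with `-2` where the lane is absent. The sketch below is not the code of `tools/generate_seg_tusimple.py`; the helper `tusimple_lanes` and the sample line are hypothetical, and it only recovers the per-lane point lists that a mask generator would then draw:

```python
import json


def tusimple_lanes(line):
    """Parse one line of a TuSimple label file (illustrative helper).

    Returns (raw_file, lanes): one list of (x, y) points per annotated lane,
    in annotation order. Samples with negative x (-2 = lane absent) are
    dropped, so a lane may end up empty.
    """
    ann = json.loads(line)
    lanes = []
    for xs in ann["lanes"]:
        # Pair each x with the y coordinate it was sampled at.
        pts = [(x, y) for x, y in zip(xs, ann["h_samples"]) if x >= 0]
        lanes.append(pts)
    return ann["raw_file"], lanes


# Hypothetical, shortened annotation line.
sample = ('{"lanes": [[-2, 632, 625], [-2, -2, -2]], '
          '"h_samples": [240, 250, 260], "raw_file": "clips/0313-1/6040/20.jpg"}')
raw_file, lanes = tusimple_lanes(sample)
print(raw_file, lanes)  # clips/0313-1/6040/20.jpg [[(632, 250), (625, 260)], []]
```

Keeping empty lanes preserves the lane order, which matters if lane indices are later used as class IDs in the mask.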

================================================
FILE: LICENSE
================================================
Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2021 Tu Zheng

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: README.md
================================================
# RESA 
PyTorch implementation of the paper "[RESA: Recurrent Feature-Shift Aggregator for Lane Detection](https://arxiv.org/abs/2008.13719)".

Our paper has been accepted by AAAI 2021.

**News**: RESA is also available in [LaneDet](https://github.com/Turoad/lanedet), which we recommend trying as well.

## Introduction
![intro](intro.png "intro")
- RESA recurrently shifts slices of the feature map in the vertical and horizontal directions, enabling each pixel to gather global information.
- RESA achieves state-of-the-art results on the CULane and Tusimple datasets.
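
The feature-shift idea above can be illustrated with a simplified NumPy sketch. It is not the repository's `models/resa.py`: the helpers `conv1d` and `resa_sketch` are hypothetical, a fixed averaging kernel shared by all channels stands in for the learned convolutions, and no channel mixing is done. In each of `iters` steps, the map is shifted by a stride that doubles from `H / 2**iters` up to `H / 2` (likewise for `W`), convolved, passed through ReLU, scaled by `alpha` and added back. `iters`, `alpha` and `kernel_size` loosely correspond to `iter`, `alpha` and `conv_stride` in the configs.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def conv1d(x, kernel, axis):
    """Zero-padded 'same' 1-D correlation of a (C, H, W) array along `axis`.

    Simplification: one kernel shared by every channel, no channel mixing.
    """
    pad = len(kernel) // 2
    pad_width = [(0, 0), (0, 0), (0, 0)]
    pad_width[axis] = (pad, pad)
    xp = np.pad(x, pad_width)
    n = x.shape[axis]
    out = np.zeros_like(x)
    for t, w in enumerate(kernel):
        # Window of length n starting at offset t along `axis`.
        out += w * np.take(xp, np.arange(t, t + n), axis=axis)
    return out


def resa_sketch(x, iters=4, alpha=2.0, kernel_size=9):
    """Recurrent feature-shift aggregation on a (C, H, W) array (simplified)."""
    x = np.array(x, dtype=float)  # float copy; the caller's array is untouched
    _, H, W = x.shape
    kernel = np.full(kernel_size, 1.0 / kernel_size)  # stand-in for learned weights
    # Vertical passes (one per direction): shift along H, convolve along W.
    for sign in (1, -1):
        for k in range(iters):
            stride = H // 2 ** (iters - k)  # H/2**iters, ..., H/4, H/2
            idx = (np.arange(H) + sign * stride) % H
            x = x + alpha * relu(conv1d(x[:, idx, :], kernel, axis=2))
    # Horizontal passes: shift along W, convolve along H.
    for sign in (1, -1):
        for k in range(iters):
            stride = W // 2 ** (iters - k)
            idx = (np.arange(W) + sign * stride) % W
            x = x + alpha * relu(conv1d(x[:, :, idx], kernel, axis=1))
    return x


# A single active pixel spreads to the whole 16x16 map after all passes.
feat = np.zeros((1, 16, 16))
feat[0, 5, 7] = 1.0
out = resa_sketch(feat)
print((out > 0).all())  # True
```

With strides 1, 2, 4 and 8 on a 16-row map, one directional pass already lets every row receive information from all other rows, which is why a few iterations are enough to give each pixel a global view.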

## Get started
1. Clone the RESA repository
    ```
    git clone https://github.com/zjulearning/resa.git
    ```
    We refer to this directory as `$RESA_ROOT`.

2. Create a conda virtual environment and activate it (conda is optional)

    ```Shell
    conda create -n resa python=3.8 -y
    conda activate resa
    ```

3. Install dependencies

    ```Shell
    # Install PyTorch first; the cudatoolkit version should match the CUDA version installed on your system.
    conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

    # Or you can install via pip
    pip install torch torchvision

    # Install the Python packages
    pip install -r requirement.txt
    ```

4. Data preparation

    Download [CULane](https://xingangpan.github.io/projects/CULane.html) and [Tusimple](https://github.com/TuSimple/tusimple-benchmark/issues/3). Extract them to `$CULANEROOT` and `$TUSIMPLEROOT`, then create symbolic links to them in the `data` directory.
    
    ```Shell
    cd $RESA_ROOT
    mkdir -p data
    ln -s $CULANEROOT data/CULane
    ln -s $TUSIMPLEROOT data/tusimple
    ```

    For CULane, the directory structure should look like this:
    ```
    $CULANEROOT/driver_xx_xxframe    # data folders x6
    $CULANEROOT/laneseg_label_w16    # lane segmentation labels
    $CULANEROOT/list                 # data lists
    ```

    For Tusimple, the directory structure should look like this:
    ```
    $TUSIMPLEROOT/clips # data folders
    $TUSIMPLEROOT/label_data_xxxx.json # label json file x4
    $TUSIMPLEROOT/test_tasks_0627.json # test tasks json file
    $TUSIMPLEROOT/test_label.json # test label json file
    ```

    Tusimple does not provide segmentation annotations, so they must be generated from the JSON annotations:

    ```Shell
    python tools/generate_seg_tusimple.py --root $TUSIMPLEROOT
    # this generates the seg_label directory
    ```

5. Install the CULane evaluation tool.

    This tool requires OpenCV for C++. Follow the [OpenCV installation guide](https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html) to install it, or simply run `sudo apt-get install libopencv-dev`.

    Then compile the CULane evaluation tool:
    ```Shell
    cd $RESA_ROOT/runner/evaluator/culane/lane_evaluation
    make
    cd -
    ```
    
    Note that the default OpenCV version is 3. If you use OpenCV 2, change `OPENCV_VERSION := 3` to `OPENCV_VERSION := 2` in the `Makefile`.
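
For reference, `CULane.init` in `datasets/culane.py` reads each line of a training list such as `list/train_gt.txt` as an image path, a label path and four 0/1 lane-existence flags. Test lists such as `test.txt` only need the image path, because `BaseDataset` treats a list as training data only when its name contains `train`. Full paths are built by plain string concatenation with `img_path`, so list entries are expected to start with `/`. A stand-alone sketch of that parsing; the helper `parse_culane_line` and the sample line are illustrative:

```python
def parse_culane_line(line, img_path, is_training=True):
    """Mirror the per-line parsing of CULane.init (sketch, not the repo code)."""
    parts = line.strip().split(" ")
    item = {"img_name": parts[0], "full_img_path": img_path + parts[0]}
    if is_training:
        # Training lists also carry a label path and 4 lane-existence flags.
        item["label_path"] = img_path + parts[1]
        item["exist"] = [int(v) for v in parts[2:6]]
    return item


# Hypothetical list entry in the train_gt.txt layout.
line = ("/driver_23_30frame/05151649_0422.MP4/00000.jpg "
        "/laneseg_label_w16/driver_23_30frame/05151649_0422.MP4/00000.png 1 1 1 0")
item = parse_culane_line(line, "./data/CULane")
print(item["full_img_path"], item["exist"])
```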


## Training

For training, run

```Shell
python main.py [configs/path_to_your_config] --gpus [gpu_ids]
```


For example, run
```Shell
python main.py configs/culane.py --gpus 0 1 2 3
```
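
Both configs pair SGD with a polynomial ("poly") decay through `LambdaLR`, whose factor multiplies the base `lr`: `lr_lambda = lambda _iter: math.pow(1 - _iter/total_iter, 0.9)`. Assuming the scheduler is stepped once per training iteration (as the `_iter` argument and `total_iter` suggest; the stepping code in `runner/` is not shown here), the resulting learning rate can be computed directly. `poly_lr` is a hypothetical helper:

```python
import math


def poly_lr(base_lr, it, total_iter, power=0.9):
    """Learning rate after `it` scheduler steps under the configs' poly decay."""
    return base_lr * math.pow(1 - it / total_iter, power)


# Values from configs/culane.py: lr=0.025, batch_size=8, epochs=12.
total_iter = (88880 // 8) * 12  # 11110 iterations per epoch -> 133320 in total
for it in (0, total_iter // 2, total_iter):
    print(it, round(poly_lr(0.025, it, total_iter), 5))
```

The rate starts at the configured `lr` and reaches exactly zero at `total_iter`, so `total_iter` must match the real number of scheduler steps.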

## Testing
For testing, run
```Shell
python main.py [configs/path_to_your_config] --validate --load_from [path_to_your_model] --gpus [gpu_ids]
```

For example, run
```Shell
python main.py configs/culane.py --validate --load_from culane_resnet50.pth --gpus 0 1 2 3

python main.py configs/tusimple.py --validate --load_from tusimple_resnet34.pth --gpus 0 1 2 3
```
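
When testing on CULane, `CULane.probmap2lane` in `datasets/culane.py` converts each lane's probability map back into 18 points spaced 20 px apart, starting from the bottom of the original 1640×590 frame. The network sees a crop starting at `cut_height` that is resized to `img_height`, so each sample's row index has to be rescaled, and the argmax column is scaled back to the original width. The helpers below (names hypothetical) restate that arithmetic with the `configs/culane.py` values:

```python
def culane_sample_rows(img_height=288, ori_imgh=590, cut_height=240, pts=18):
    """(row in the resized prob map, y in the original image) for each sample."""
    rows = []
    for i in range(pts):
        # The i-th sample lies 20*i px above the bottom of the original image.
        row = round(img_height - i * 20 / (ori_imgh - cut_height) * img_height) - 1
        y_orig = round(ori_imgh - i * 20 - 1)
        rows.append((row, y_orig))
    return rows


def to_original_x(col, img_width=800, ori_imgw=1640):
    """Map a 1-based column index (argmax + 1) back to the original width."""
    return round(col * ori_imgw / img_width - 1)


rows = culane_sample_rows()
print(rows[:2], to_original_x(400))
```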


We provide two trained ResNet models, one for CULane and one for Tusimple. Download our best-performing models here:

- Tusimple: [GoogleDrive](https://drive.google.com/file/d/1M1xi82y0RoWUwYYG9LmZHXWSD2D60o0D/view?usp=sharing)/[BaiduDrive(code:s5ii)](https://pan.baidu.com/s/1CgJFrt9OHe-RUNooPpHRGA)
- CULane: [GoogleDrive](https://drive.google.com/file/d/1pcqq9lpJ4ixJgFVFndlPe42VgVsjgn0Q/view?usp=sharing)/[BaiduDrive(code:rlwj)](https://pan.baidu.com/s/1ODKAZxpKrZIPXyaNnxcV3g)

## Visualization
Add the `--view` flag.

For example:
```Shell
python main.py configs/culane.py --validate --load_from culane_resnet50.pth --gpus 0 1 2 3 --view
```
The visualized results are saved in `work_dirs/[DATASET]/xxx/vis`.

## Citation
If you use our method, please consider citing:
```BibTeX
@inproceedings{zheng2021resa,
  title={RESA: Recurrent Feature-Shift Aggregator for Lane Detection},
  author={Zheng, Tu and Fang, Hao and Zhang, Yi and Tang, Wenjian and Yang, Zheng and Liu, Haifeng and Cai, Deng},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={35},
  number={4},
  pages={3547--3554},
  year={2021}
}
```

<!-- ## Thanks

The evaluation code is modified from [SCNN](https://github.com/XingangPan/SCNN) and [Tusimple Benchmark](https://github.com/TuSimple/tusimple-benchmark). -->


================================================
FILE: configs/culane.py
================================================
net = dict(
    type='RESANet',
)

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet50',
    pretrained=True,
    replace_stride_with_dilation=[False, True, True],
    out_conv=True,
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,
    iter=4,
    input_channel=128,
    conv_stride=9,
)

decoder = 'PlainDecoder'        

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='CULane',        
)

optimizer = dict(
  type='sgd',
  lr=0.025,
  weight_decay=1e-4,
  momentum=0.9
)

epochs = 12
batch_size = 8
total_iter = (88880 // batch_size) * epochs  # 88880 = number of CULane training images
import math
scheduler = dict(
    type = 'LambdaLR',
    lr_lambda = lambda _iter : math.pow(1 - _iter/total_iter, 0.9)
)

loss_type = 'dice_loss'
seg_loss_weight = 2.
eval_ep = 6
save_ep = epochs

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 288
img_width = 800
cut_height = 240 

dataset_path = './data/CULane'
dataset = dict(
    train=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='train_gt.txt',
    ),
    val=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    ),
    test=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    )
)


workers = 12
num_classes = 4 + 1
ignore_label = 255
log_interval = 500


================================================
FILE: configs/tusimple.py
================================================
net = dict(
    type='RESANet',
)

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet34',
    pretrained=True,
    replace_stride_with_dilation=[False, True, True],
    out_conv=True,
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,
    iter=5,
    input_channel=128,
    conv_stride=9,
)

decoder = 'BUSD'        

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='Tusimple',        
    thresh = 0.60
)

optimizer = dict(
  type='sgd',
  lr=0.020,
  weight_decay=1e-4,
  momentum=0.9
)

total_iter = 80000
import math
scheduler = dict(
    type = 'LambdaLR',
    lr_lambda = lambda _iter : math.pow(1 - _iter/total_iter, 0.9)
)

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 368
img_width = 640
cut_height = 160
seg_label = "seg_label"

dataset_path = './data/tusimple'
test_json_file = './data/tusimple/test_label.json'

dataset = dict(
    train=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='train_val_gt.txt',
    ),
    val=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='test_gt.txt'
    ),
    test=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='test_gt.txt'
    )
)


loss_type = 'cross_entropy'
seg_loss_weight = 1.0


batch_size = 4
workers = 12
num_classes = 6 + 1
ignore_label = 255
epochs = 300
log_interval = 100
eval_ep = 1
save_ep = epochs
log_note = ''


================================================
FILE: datasets/__init__.py
================================================
from .registry import build_dataset, build_dataloader

from .tusimple import TuSimple
from .culane import CULane


================================================
FILE: datasets/base_dataset.py
================================================
import os.path as osp
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import torchvision
import utils.transforms as tf
from .registry import DATASETS


@DATASETS.register_module
class BaseDataset(Dataset):
    def __init__(self, img_path, data_list, list_path='list', cfg=None):
        self.cfg = cfg
        self.img_path = img_path
        self.list_path = osp.join(img_path, list_path)
        self.data_list = data_list
        self.is_training = ('train' in data_list)

        self.img_name_list = []
        self.full_img_path_list = []
        self.label_list = []
        self.exist_list = []

        self.transform = self.transform_train() if self.is_training else self.transform_val()

        self.init()

    def transform_train(self):
        raise NotImplementedError()

    def transform_val(self):
        val_transform = torchvision.transforms.Compose([
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return val_transform

    def view(self, img, coords, file_path=None):
        for coord in coords:
            for x, y in coord:
                if x <= 0 or y <= 0:
                    continue
                x, y = int(x), int(y)
                cv2.circle(img, (x, y), 4, (255, 0, 0), 2)

        if file_path is not None:
            # Create the output directory if needed (no-op when it already exists).
            os.makedirs(osp.dirname(file_path), exist_ok=True)
            cv2.imwrite(file_path, img)


    def init(self):
        raise NotImplementedError()


    def __len__(self):
        return len(self.full_img_path_list)

    def __getitem__(self, idx):
        img = cv2.imread(self.full_img_path_list[idx]).astype(np.float32)
        img = img[self.cfg.cut_height:, :, :]

        if self.is_training:
            label = cv2.imread(self.label_list[idx], cv2.IMREAD_UNCHANGED)
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()
            label = label[self.cfg.cut_height:, :]
            exist = self.exist_list[idx]
            if self.transform:
                img, label = self.transform((img, label))
            label = torch.from_numpy(label).contiguous().long()
        else:
            img, = self.transform((img,))

        img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
        meta = {'full_img_path': self.full_img_path_list[idx],
                'img_name': self.img_name_list[idx]}

        data = {'img': img, 'meta': meta}
        if self.is_training:
            data.update({'label': label, 'exist': exist})
        return data


================================================
FILE: datasets/culane.py
================================================
import os
import os.path as osp
import numpy as np
import torchvision
import utils.transforms as tf
from .base_dataset import BaseDataset
from .registry import DATASETS
import cv2
import torch


@DATASETS.register_module
class CULane(BaseDataset):
    def __init__(self, img_path, data_list, cfg=None):
        super().__init__(img_path, data_list, cfg=cfg)
        self.ori_imgh = 590
        self.ori_imgw = 1640

    def init(self):
        with open(osp.join(self.list_path, self.data_list)) as f:
            for line in f:
                line_split = line.strip().split(" ")
                self.img_name_list.append(line_split[0])
                self.full_img_path_list.append(self.img_path + line_split[0])
                if not self.is_training:
                    continue
                self.label_list.append(self.img_path + line_split[1])
                self.exist_list.append(
                    np.array([int(line_split[2]), int(line_split[3]),
                              int(line_split[4]), int(line_split[5])]))

    def transform_train(self):
        train_transform = torchvision.transforms.Compose([
            tf.GroupRandomRotation(degree=(-2, 2)),
            tf.GroupRandomHorizontalFlip(),
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return train_transform

    def probmap2lane(self, probmaps, exists, pts=18):
        coords = []
        probmaps = probmaps[1:, ...]
        exists = exists > 0.5
        for probmap, exist in zip(probmaps, exists):
            if exist == 0:
                continue
            probmap = cv2.blur(probmap, (9, 9), borderType=cv2.BORDER_REPLICATE)
            thr = 0.3
            coordinate = np.zeros(pts)
            cut_height = self.cfg.cut_height
            for i in range(pts):
                line = probmap[round(
                    self.cfg.img_height-i*20/(self.ori_imgh-cut_height)*self.cfg.img_height)-1]

                if np.max(line) > thr:
                    coordinate[i] = np.argmax(line)+1
            if np.sum(coordinate > 0) < 2:
                continue
    
            img_coord = np.zeros((pts, 2))
            img_coord[:, :] = -1
            for idx, value in enumerate(coordinate):
                if value > 0:
                    img_coord[idx][0] = round(value*self.ori_imgw/self.cfg.img_width-1)
                    img_coord[idx][1] = round(self.ori_imgh-idx*20-1)
    
            img_coord = img_coord.astype(int)
            coords.append(img_coord)
    
        return coords


================================================
FILE: datasets/registry.py
================================================
from utils import Registry, build_from_cfg

import torch
from torch import nn  # needed by build() for nn.Sequential

DATASETS = Registry('datasets')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_dataset(split_cfg, cfg):
    args = split_cfg.copy()
    args.pop('type')
    args = args.to_dict()
    args['cfg'] = cfg
    return build(split_cfg, DATASETS, default_args=args)

def build_dataloader(split_cfg, cfg, is_train=True):
    shuffle = is_train  # shuffle only the training split

    dataset = build_dataset(split_cfg, cfg)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size = cfg.batch_size, shuffle = shuffle,
        num_workers = cfg.workers, pin_memory = False, drop_last = False)

    return data_loader


================================================
FILE: datasets/tusimple.py
================================================
import os.path as osp
import numpy as np
import cv2
import torchvision
import utils.transforms as tf
from .base_dataset import BaseDataset
from .registry import DATASETS


@DATASETS.register_module
class TuSimple(BaseDataset):
    def __init__(self, img_path, data_list, cfg=None):
        super().__init__(img_path, data_list, 'seg_label/list', cfg)

    def transform_train(self):
        train_transform = torchvision.transforms.Compose([
            tf.GroupRandomRotation(),
            tf.GroupRandomHorizontalFlip(),
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return train_transform


    def init(self):
        with open(osp.join(self.list_path, self.data_list)) as f:
            for line in f:
                line_split = line.strip().split(" ")
                self.img_name_list.append(line_split[0])
                self.full_img_path_list.append(self.img_path + line_split[0])
                if not self.is_training:
                    continue
                self.label_list.append(self.img_path + line_split[1])
                # Columns 2..7 hold the 0/1 existence flags of the six lanes.
                self.exist_list.append(
                    np.array([int(v) for v in line_split[2:8]]))
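`init()` reads each list-file line as an image path, a label path, and six 0/1 lane-existence flags, with both paths appended to `img_path` by plain string concatenation. The following sketch of that parsing uses a hypothetical helper `parse_list_line` and a made-up example line; the column layout is inferred from the indices used in `init()`.

```python
import numpy as np


def parse_list_line(line, img_root):
    """Split one TuSimple list-file line (hypothetical helper).

    Assumed layout, inferred from TuSimple.init():
    <image> <label> <e1> ... <e6>, where e_k is 1 if lane k exists.
    """
    parts = line.strip().split(" ")
    img_path = img_root + parts[0]    # string concatenation, as in init()
    label_path = img_root + parts[1]
    exist = np.array([int(v) for v in parts[2:8]])
    return img_path, label_path, exist


img, label, exist = parse_list_line(
    "/clips/a/1/20.jpg /seg_label/a/1/20.png 0 1 1 1 1 0\n",  # made-up line
    "/data/tusimple")
print(img, label, exist)
```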

    def fix_gap(self, coordinate):
        if any(x > 0 for x in coordinate):
            start = [i for i, x in enumerate(coordinate) if x > 0][0]
            end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
            lane = coordinate[start:end+1]
            if any(x < 0 for x in lane):
                gap_start = [i for i, x in enumerate(
                    lane[:-1]) if x > 0 and lane[i+1] < 0]
                gap_end = [i+1 for i,
                           x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0]
                gap_id = [i for i, x in enumerate(lane) if x < 0]
                if len(gap_start) == 0 or len(gap_end) == 0:
                    return coordinate
                for id in gap_id:
                    for i in range(len(gap_start)):
                        if i >= len(gap_end):
                            return coordinate
                        if id > gap_start[i] and id < gap_end[i]:
                            gap_width = float(gap_end[i] - gap_start[i])
                            lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + (
                                gap_end[i] - id) / gap_width * lane[gap_start[i]])
                if not all(x > 0 for x in lane):
                    print("Gaps still exist!")
                coordinate[start:end+1] = lane
        return coordinate
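`fix_gap` fills each run of negative x values between the first and last detected point by linear interpolation between the two surrounding detections. The following sketch expresses the same idea with `np.interp`. The helper name `fill_gaps_interp` is hypothetical, and unlike `fix_gap` it keeps float results instead of truncating each filled value with `int()`.

```python
import numpy as np


def fill_gaps_interp(coords):
    """Linearly fill negative entries lying between the first and last
    positive x coordinate (sketch of the idea behind fix_gap)."""
    coords = np.asarray(coords, dtype=float).copy()
    valid = np.flatnonzero(coords > 0)
    if valid.size < 2:
        return coords  # nothing to interpolate between
    start, end = valid[0], valid[-1]
    inner = np.arange(start, end + 1)
    gaps = inner[coords[inner] < 0]          # missing points inside the lane
    coords[gaps] = np.interp(gaps, valid, coords[valid])
    return coords


print(fill_gaps_interp([-1, 10, -1, -1, 40, -1]))
# expected values: -1, 10, 20, 30, 40, -1 (leading/trailing -1 are untouched)
```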

    def is_short(self, lane):
        # True when the lane has no positive (detected) x coordinate at all.
        return not any(x > 0 for x in lane)

    def get_lane(self, prob_map, y_px_gap, pts, thresh, resize_shape=None):
        """
        Arguments:
        ----------
        prob_map: prob map for single lane, np array size (h, w)
        resize_shape:  reshape size target, (H, W)
    
        Return:
        ----------
        coords: x coords bottom up every y_px_gap px, 0 for non-exist, in resized shape
        """
        if resize_shape is None:
            resize_shape = prob_map.shape
        h, w = prob_map.shape
        H, W = resize_shape
        H -= self.cfg.cut_height
    
        coords = np.zeros(pts)
        coords[:] = -1.0
        for i in range(pts):
            y = int((H - 10 - i * y_px_gap) * h / H)
            if y < 0:
                break
            line = prob_map[y, :]
            id = np.argmax(line)
            if line[id] > thresh:
                coords[i] = int(id / w * W)
        if (coords > 0).sum() < 2:
            coords = np.zeros(pts)
        self.fix_gap(coords)  # fills interior gaps of coords in place

        return coords
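The core of `get_lane` is a loop that maps each output row back to a probability-map row, takes the column with the highest probability, and keeps it only if it exceeds `thresh`. The following NumPy sketch reproduces that loop under two simplifying assumptions: `cut_height` is 0, and gap filling is omitted. `sample_lane_xs` is a hypothetical name.

```python
import numpy as np


def sample_lane_xs(prob_map, out_h, out_w, pts=56, y_px_gap=10, thresh=0.6):
    """Sketch of get_lane()'s sampling loop (cut_height assumed 0,
    no gap filling). Returns x coords in the (out_h, out_w) frame,
    -1 where the row maximum does not exceed thresh."""
    h, w = prob_map.shape
    coords = np.full(pts, -1.0)
    for i in range(pts):
        # Output row, bottom-up, starting 10 px above the bottom edge.
        y_out = out_h - 10 - i * y_px_gap
        y = int(y_out * h / out_h)          # matching probability-map row
        if y < 0:
            break
        row = prob_map[y]
        col = int(np.argmax(row))
        if row[col] > thresh:
            coords[i] = int(col / w * out_w)  # rescale column to output width
    return coords


# Synthetic map: a vertical lane at column 10 of a 36 x 100 map.
prob = np.zeros((36, 100))
prob[:, 10] = 0.9
xs = sample_lane_xs(prob, out_h=720, out_w=1280)
print(xs[:3])  # every sampled row maps column 10 to x = 128
```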

    def probmap2lane(self, seg_pred, exist, resize_shape=(720, 1280), smooth=True, y_px_gap=10, pts=56, thresh=0.6):
        """
        Arguments:
        ----------
        seg_pred:      np.array size (5, h, w)
        resize_shape:  reshape size target, (H, W)
        exist:       list of existence, e.g. [0, 1, 1, 0]
        smooth:      whether to smooth the probability or not
        y_px_gap:    y pixel gap for sampling
        pts:     how many points for one lane
        thresh:  probability threshold
    
        Return:
        ----------
        coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ]
        """
        if resize_shape is None:
            resize_shape = seg_pred.shape[1:]  # seg_pred (5, h, w)
        _, h, w = seg_pred.shape
        H, W = resize_shape
        coordinates = []
    
        for i in range(self.cfg.num_classes - 1):
            prob_map = seg_pred[i + 1]
            if smooth:
                prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
            coords = self.get_lane(prob_map, y_px_gap, pts, thresh, resize_shape)
            if self.is_short(coords):
                continue
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])

        # No lane survived: emit one all-missing lane so the result is never empty.
        if len(coordinates) == 0:
            coords = np.zeros(pts)
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])
    
        return coordinates
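`probmap2lane` pairs each sampled x with the output row `H - 10 - j * y_px_gap`, replacing non-positive x values with -1. The following is a small sketch of just that pairing step; `coords_to_points` is a hypothetical helper.

```python
def coords_to_points(coords, out_h=720, y_px_gap=10):
    """Pair each sampled x with its row, as probmap2lane() does:
    point j sits at y = out_h - 10 - j * y_px_gap; missing x becomes -1."""
    return [[x if x > 0 else -1, out_h - 10 - j * y_px_gap]
            for j, x in enumerate(coords)]


points = coords_to_points([128.0, -1.0, 130.0])
print(points)  # [[128.0, 710], [-1, 700], [130.0, 690]]
```

With the defaults (H = 720, pts = 56, y_px_gap = 10), the rows run from 710 up to 160 in steps of 10.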


================================================
FILE: main.py
================================================
import os
import os.path as osp
import time
import shutil
import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim
import cv2
import numpy as np
import models
import argparse
from utils.config import Config
from runner.runner import Runner 
from datasets import build_dataloader


def main():
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in args.gpus)

    cfg = Config.fromfile(args.config)
    cfg.gpus = len(args.gpus)

    cfg.load_from = args.load_from
    cfg.finetune_from = args.finetune_from
    cfg.view = args.view

    cfg.work_dirs = osp.join(args.work_dirs, cfg.dataset.train.type)

    cudnn.benchmark = True
    cudnn.fastest = True

    runner = Runner(cfg)

    if args.validate:
        val_loader = build_dataloader(cfg.dataset.val, cfg, is_train=False)
        runner.validate(val_loader)
    else:
        runner.train()

def parse_args():
    parser = argparse.ArgumentParser(description='Train or validate a RESA lane detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--work_dirs', type=str, default='work_dirs',
        help='work dirs')
    parser.add_argument(
        '--load_from', default=None,
        help='the checkpoint file to resume from')
    parser.add_argument(
        '--finetune_from', default=None,
        help='the checkpoint file to finetune from')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='run validation on the val split instead of training')
    parser.add_argument(
        '--view',
        action='store_true',
        help='whether to show visualization result')
    parser.add_argument('--gpus', nargs='+', type=int, default=[0],
                        help='GPU ids to expose via CUDA_VISIBLE_DEVICES')
    parser.add_argument('--seed', type=int,
                        default=None, help='random seed (parsed, but not used in main())')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    main()
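`main()` turns the parsed `--gpus` list into the `CUDA_VISIBLE_DEVICES` string and uses its length as `cfg.gpus`. The following is a standalone sketch of that step. With `nargs='+'`, a list default such as `[0]` keeps the parsed value a list of ints whether or not the flag is given. The config path below is only an example argument.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('config')
parser.add_argument('--gpus', nargs='+', type=int, default=[0])

args = parser.parse_args(['configs/culane.py', '--gpus', '0', '1'])
visible = ','.join(str(gpu) for gpu in args.gpus)  # CUDA_VISIBLE_DEVICES value
print(visible, len(args.gpus))  # 0,1 2
```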


================================================
FILE: models/__init__.py
================================================
from .resa import *


================================================
FILE: models/decoder.py
================================================
from torch import nn
import torch.nn.functional as F

class PlainDecoder(nn.Module):
    def __init__(self, cfg):
        super(PlainDecoder, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)
        x = F.interpolate(x, size=[self.cfg.img_height,  self.cfg.img_width],
                           mode='bilinear', align_corners=False)

        return x


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class non_bottleneck_1d(nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        self.conv3x1_1 = nn.Conv2d(
            chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)

        self.conv1x3_1 = nn.Conv2d(
            chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)

        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
                                   dilation=(dilated, 1))

        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
                                   dilation=(1, dilated))

        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        output = self.conv3x1_1(input)
        output = F.relu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.relu(output)

        output = self.conv3x1_2(output)
        output = F.relu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if (self.dropout.p != 0):
            output = self.dropout(output)

        # +input = identity (residual connection)
        return F.relu(output + input)


class UpsamplerBlock(nn.Module):
    def __init__(self, ninput, noutput, up_width, up_height):
        super().__init__()

        self.conv = nn.ConvTranspose2d(
            ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)

        self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)

        self.follows = nn.ModuleList()
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))

        # interpolate
        self.up_width = up_width
        self.up_height = up_height
        self.interpolate_conv = conv1x1(ninput, noutput)
        self.interpolate_bn = nn.BatchNorm2d(
            noutput, eps=1e-3, track_running_stats=True)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        out = F.relu(output)
        for follow in self.follows:
            out = follow(out)

        interpolate_output = self.interpolate_conv(input)
        interpolate_output = self.interpolate_bn(interpolate_output)
        interpolate_output = F.relu(interpolate_output)

        interpolate = F.interpolate(interpolate_output, size=[self.up_height,  self.up_width],
                                    mode='bilinear', align_corners=False)

        return out + interpolate

class BUSD(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        img_height = cfg.img_height
        img_width = cfg.img_width
        num_classes = cfg.num_classes

        self.layers = nn.ModuleList()

        self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
                                          up_height=int(img_height)//4, up_width=int(img_width)//4))
        self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
                                          up_height=int(img_height)//2, up_width=int(img_width)//2))
        self.layers.append(UpsamplerBlock(ninput=32, noutput=16,
                                          up_height=int(img_height)//1, up_width=int(img_width)//1))

        self.output_conv = conv1x1(16, num_classes)

    def forward(self, input):
        output = input

        for layer in self.layers:
            output = layer(output)

        output = self.output_conv(output)

        return output
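Each `UpsamplerBlock` adds a transposed-convolution branch to a bilinear branch resized to a fixed `up_height x up_width`. The two branches only line up if the transposed convolution exactly doubles the spatial size. The following quick check uses the output-size formula from the PyTorch `nn.ConvTranspose2d` documentation and assumes the decoder input is at 1/8 resolution (fea_stride 8); `conv_transpose_out` and `busd_sizes` are hypothetical helpers.

```python
def conv_transpose_out(n, kernel=3, stride=2, padding=1, output_padding=1, dilation=1):
    # Output-size formula from the PyTorch nn.ConvTranspose2d documentation.
    return (n - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def busd_sizes(img_h, fea_stride=8):
    """(transposed-conv size, interpolation target) after each block."""
    n = img_h // fea_stride
    sizes = []
    for target in (img_h // 4, img_h // 2, img_h // 1):
        n = conv_transpose_out(n)  # k=3, s=2, p=1, output_padding=1 doubles n
        sizes.append((n, target))
    return sizes


print(busd_sizes(288))  # [(72, 72), (144, 144), (288, 288)]
```

With a height that is not divisible by 8 (e.g. 300), the first block already yields 74 against a target of 75, so the `out + interpolate` sum would fail.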


================================================
FILE: models/registry.py
================================================
from torch import nn

from utils import Registry, build_from_cfg

NET = Registry('net')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_net(cfg):
    return build(cfg.net, NET, default_args=dict(cfg=cfg))


================================================
FILE: models/resa.py
================================================
import torch.nn as nn
import torch
import torch.nn.functional as F

from models.registry import NET
from .resnet import ResNetWrapper 
from .decoder import BUSD, PlainDecoder 


class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.iter = cfg.resa.iter
        chan = cfg.resa.input_channel
        fea_stride = cfg.backbone.fea_stride
        self.height = cfg.img_height // fea_stride
        self.width = cfg.img_width // fea_stride
        self.alpha = cfg.resa.alpha
        conv_stride = cfg.resa.conv_stride

        for i in range(self.iter):
            conv_vert1 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)
            conv_vert2 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)

            setattr(self, 'conv_d'+str(i), conv_vert1)
            setattr(self, 'conv_u'+str(i), conv_vert2)

            conv_hori1 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)
            conv_hori2 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)

            setattr(self, 'conv_r'+str(i), conv_hori1)
            setattr(self, 'conv_l'+str(i), conv_hori2)

            idx_d = (torch.arange(self.height) + self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_d'+str(i), idx_d)

            idx_u = (torch.arange(self.height) - self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_u'+str(i), idx_u)

            idx_r = (torch.arange(self.width) + self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_r'+str(i), idx_r)

            idx_l = (torch.arange(self.width) - self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_l'+str(i), idx_l)

    def forward(self, x):
        x = x.clone()

        for direction in ['d', 'u']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx, :])))

        for direction in ['r', 'l']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx])))

        return x
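At iteration `i`, `RESA.forward` shifts the feature map cyclically by `size // 2**(iter - i)` rows (or columns) and adds `alpha * relu(conv(shifted))`. The following NumPy sketch covers only the 'd' pass, and a scalar weight stands in for the learned `(1, conv_stride)` convolution, which is a simplification for illustration. `resa_shifts` and `propagate_down` are hypothetical names. The test also checks that the `idx_d` indexing equals a cyclic `np.roll`.

```python
import numpy as np


def resa_shifts(size, iters):
    # Shift at iteration i is size // 2**(iters - i): smallest first.
    return [size // 2 ** (iters - i) for i in range(iters)]


def propagate_down(x, iters, alpha=2.0, weight=0.5):
    """Simplified 'd' pass of RESA on an (H, W) map: each step adds
    alpha * relu(weight * x shifted by s rows, with wrap-around).
    A scalar weight replaces the learned convolution (assumption)."""
    h = x.shape[0]
    x = x.copy()
    for s in resa_shifts(h, iters):
        idx = (np.arange(h) + s) % h  # same indexing as idx_d
        x += alpha * np.maximum(weight * x[idx, :], 0)
    return x


demo = np.zeros((8, 3))
demo[7] = 1.0                      # signal only in the last row
out = propagate_down(demo, iters=3)
print(out[:, 0])                   # every row has received the signal
```

With shifts of 1, 2 and 4 on an 8-row map, the signal in the last row reaches every row after three iterations. This suggests why the shifts grow geometrically: `iter` steps can cover a span of `2**iter` rows.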



class ExistHead(nn.Module):
    def __init__(self, cfg=None):
        super(ExistHead, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)  # channel-wise dropout before conv8
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

        stride = cfg.backbone.fea_stride * 2
        self.fc9 = nn.Linear(
            int(cfg.num_classes * cfg.img_width / stride * cfg.img_height / stride), 128)
        self.fc10 = nn.Linear(128, cfg.num_classes-1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)

        x = F.softmax(x, dim=1)
        x = F.avg_pool2d(x, 2, stride=2, padding=0)
        x = x.view(-1, x.numel() // x.shape[0])
        x = self.fc9(x)
        x = F.relu(x)
        x = self.fc10(x)
        x = torch.sigmoid(x)

        return x
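The `fc9` input size counts the `num_classes` maps produced by `conv8` at `1/fea_stride` resolution, halved again by `avg_pool2d(x, 2, stride=2)`, hence `stride = fea_stride * 2`. The following is a worked sketch of that arithmetic; the 5 classes, 800x288 input, and fea_stride 8 are example values, not values read from the configs.

```python
def exist_fc9_in_features(num_classes, img_w, img_h, fea_stride):
    # conv8 yields num_classes maps at 1/fea_stride; avg_pool2d(2) halves
    # them again, so ExistHead uses stride = fea_stride * 2.
    stride = fea_stride * 2
    return int(num_classes * img_w / stride * img_h / stride)


print(exist_fc9_in_features(5, 800, 288, 8))  # 4500
```

`avg_pool2d` floors odd sizes, but this formula uses float division. Image dimensions should therefore be divisible by `2 * fea_stride`; otherwise the flattened size may not match `fc9`.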


@NET.register_module
class RESANet(nn.Module):
    def __init__(self, cfg):
        super(RESANet, self).__init__()
        self.cfg = cfg
        self.backbone = ResNetWrapper(cfg)
        self.resa = RESA(cfg)
        self.decoder = eval(cfg.decoder)(cfg)
        self.heads = ExistHead(cfg) 

    def forward(self, batch):
        fea = self.backbone(batch)
        fea = self.resa(fea)
        seg = self.decoder(fea)
        exist = self.heads(fea)

        output = {'seg': seg, 'exist': exist}

        return output


================================================
FILE: models/resnet.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url


# This code is borrowed from torchvision.

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        # if dilation > 1:
        #     raise NotImplementedError(
        #         "Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, dilation=dilation)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNetWrapper(nn.Module):

    def __init__(self, cfg):
        super(ResNetWrapper, self).__init__()
        self.cfg = cfg
        self.in_channels = [64, 128, 256, 512]
        if 'in_channels' in cfg.backbone:
            self.in_channels = cfg.backbone.in_channels
        self.model = eval(cfg.backbone.resnet)(
            pretrained=cfg.backbone.pretrained,
            replace_stride_with_dilation=cfg.backbone.replace_stride_with_dilation, in_channels=self.in_channels)
        self.out = None
        if cfg.backbone.out_conv:
            out_channel = 512
            for chan in reversed(self.in_channels):
                if chan < 0: continue
                out_channel = chan
                break
            self.out = conv1x1(
                out_channel * self.model.expansion, 128)

    def forward(self, x):
        x = self.model(x)
        if self.out:
            x = self.out(x)
        return x


class ResNet(nn.Module):

    def __init__(self, block, layers, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, in_channels=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.in_channels = in_channels
        self.layer1 = self._make_layer(block, in_channels[0], layers[0])
        self.layer2 = self._make_layer(block, in_channels[1], layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, in_channels[2], layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        if in_channels[3] > 0:
            self.layer4 = self._make_layer(block, in_channels[3], layers[3], stride=2,
                                           dilate=replace_stride_with_dilation[2])
        self.expansion = block.expansion

        # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.in_channels[3] > 0:
            x = self.layer4(x)

        # x = self.avgpool(x)
        # x = torch.flatten(x, 1)
        # x = self.fc(x)

        return x
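`replace_stride_with_dilation` trades a layer's stride of 2 for a doubled dilation in `_make_layer`. Setting `in_channels[3] <= 0` drops `layer4` entirely. The following sketch mirrors that bookkeeping to compute the backbone's output stride and the final dilation of each layer; `resnet_output_stride` is a hypothetical helper.

```python
def resnet_output_stride(replace_stride_with_dilation, use_layer4=True):
    """Mirror ResNet.__init__/_make_layer stride and dilation bookkeeping.
    Returns (output_stride, final dilation of each layer)."""
    stride = 4             # conv1 (stride 2) followed by maxpool (stride 2)
    dilation = 1
    dilations = [dilation]  # layer1: stride 1, never dilated
    flags = list(replace_stride_with_dilation)
    if not use_layer4:
        flags = flags[:2]   # in_channels[3] <= 0 skips layer4
    for dilate in flags:
        if dilate:
            dilation *= 2   # stride 2 replaced by dilation
        else:
            stride *= 2
        dilations.append(dilation)
    return stride, dilations


print(resnet_output_stride([False, True, True]))  # (8, [1, 1, 2, 4])
```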


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict, strict=False)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


================================================
FILE: requirement.txt
================================================
pandas
addict
scikit-learn
opencv-python
pytorch_warmup
scikit-image
tqdm
termcolor


================================================
FILE: runner/__init__.py
================================================
from .evaluator import *
from .resa_trainer import *

from .registry import build_evaluator 


================================================
FILE: runner/evaluator/__init__.py
================================================
from .tusimple.tusimple import Tusimple
from .culane.culane import CULane


================================================
FILE: runner/evaluator/culane/culane.py
================================================
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.logger import get_logger

from runner.registry import EVALUATOR 
import json
import os
import subprocess
from shutil import rmtree
import cv2
import numpy as np

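def check():
    import sys
    try:
        result = subprocess.call(
            './runner/evaluator/culane/lane_evaluation/evaluate',
            stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except OSError:
        # The binary is missing or not executable (e.g. it was never compiled).
        result = 2
    # Without arguments, a working binary exits with 0, or with 1 when its
    # default image list does not exist; anything larger signals a broken tool.
    if result > 1:
        print('There is something wrong with the evaluation tool, please compile it.')
        sys.exit(1)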

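def read_helper(path):
    # Parse an output file of the evaluate tool, skipping its "file: ..." header,
    # into a dict keyed by 'tp', 'fp', 'fn', 'precision', 'recall', 'Fmeasure'.
    # Values are kept as raw strings (some still end with '\n').
    with open(path, 'r') as f:
        lines = f.readlines()[1:]
    lines = ' '.join(lines)
    values = lines.split(' ')[1::2]
    keys = lines.split(' ')[0::2]
    keys = [key[:-1] for key in keys]  # drop the trailing ':'
    res = {k: v for k, v in zip(keys, values)}
    return res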
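

# Illustrative sketch (hypothetical helper, not called by the runner): a
# regex-based alternative to read_helper that returns typed values. It assumes
# the output-file layout written by lane_evaluation/src/evaluate.cpp:
#   file: <path>
#   tp: <int> fp: <int> fn: <int>
#   precision: <float>
#   recall: <float>
#   Fmeasure: <float>
def parse_eval_output(text):
    """Return {'tp': int, 'fp': int, 'fn': int, 'precision': float, ...}."""
    import re
    result = {}
    # Skip the "file: <path>" header line, then collect "key: value" pairs.
    body = text.split('\n', 1)[1] if '\n' in text else ''
    for key, value in re.findall(r'(\w+):\s*(\S+)', body):
        # Counters stay ints; metrics (which may be "nan" or "-1") become floats.
        result[key] = int(value) if key in ('tp', 'fp', 'fn') else float(value)
    return result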

def call_culane_eval(data_dir, output_path='./output'):
    if data_dir[-1] != '/':
        data_dir = data_dir + '/'
    detect_dir = os.path.join(output_path, 'lines') + '/'

    w_lane = 30
    iou = 0.5  # Set iou to 0.3 or 0.5
    im_w = 1640
    im_h = 590
    frame = 1
    # CULane test splits in file-index order: test0_normal.txt ... test8_night.txt
    splits = ['normal', 'crowd', 'hlight', 'shadow', 'noline',
              'arrow', 'curve', 'cross', 'night']
    txt_dir = os.path.join(output_path, 'txt')
    if not os.path.exists(txt_dir):
        os.mkdir(txt_dir)

    eval_cmd = './runner/evaluator/culane/lane_evaluation/evaluate'

    out_files = {}
    for idx, name in enumerate(splits):
        list_file = os.path.join(
            data_dir, 'list/test_split/test%d_%s.txt' % (idx, name))
        out_file = os.path.join(txt_dir, 'out%d_%s.txt' % (idx, name))
        os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s' % (
            eval_cmd, data_dir, detect_dir, data_dir, list_file,
            w_lane, iou, im_w, im_h, frame, out_file))
        out_files[name] = out_file

    # Categories are reported in this order in the summary log.
    report_order = ['normal', 'crowd', 'night', 'noline', 'shadow',
                    'arrow', 'hlight', 'curve', 'cross']
    return {name: read_helper(out_files[name]) for name in report_order}
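

# Illustrative sketch (hypothetical helper, not called by the runner): the same
# evaluate invocation as in call_culane_eval, built as an argument list so it
# could be passed to subprocess.run() without going through a shell. Flag
# meanings follow help() in lane_evaluation/src/evaluate.cpp; the defaults
# mirror the values hard-coded in call_culane_eval.
def build_eval_argv(eval_cmd, data_dir, detect_dir, list_file, out_file,
                    w_lane=30, iou=0.5, im_w=1640, im_h=590, frame=1):
    return [eval_cmd,
            '-a', data_dir,      # annotation directory
            '-d', detect_dir,    # directory with the predicted *.lines.txt files
            '-i', data_dir,      # image directory
            '-l', list_file,     # list of images to evaluate
            '-w', str(w_lane),   # lane width in pixels used when rasterizing lanes
            '-t', str(iou),      # IoU threshold for counting a true positive
            '-c', str(im_w),     # image width (cols)
            '-r', str(im_h),     # image height (rows)
            '-f', str(frame),    # start frame
            '-o', out_file]      # file that receives the tp/fp/fn summary

# Possible usage: subprocess.run(build_eval_argv(...)) instead of os.system(...).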

@EVALUATOR.register_module
class CULane(nn.Module):
    def __init__(self, cfg):
        super(CULane, self).__init__()
        # Firstly, check the evaluation tool
        check()
        self.cfg = cfg 
        self.blur = torch.nn.Conv2d(
            5, 5, 9, padding=4, bias=False, groups=5).cuda()
        torch.nn.init.constant_(self.blur.weight, 1 / 81)
        self.logger = get_logger('resa')
        self.out_dir = os.path.join(self.cfg.work_dir, 'lines')
        if cfg.view:
            self.view_dir = os.path.join(self.cfg.work_dir, 'vis')

    def evaluate(self, dataset, output, batch):
        seg, exists = output['seg'], output['exist']
        predictmaps = F.softmax(seg, dim=1).cpu().numpy()
        exists = exists.cpu().numpy()
        batch_size = seg.size(0)
        img_name = batch['meta']['img_name']
        img_path = batch['meta']['full_img_path']
        for i in range(batch_size):
            coords = dataset.probmap2lane(predictmaps[i], exists[i])
            outname = self.out_dir + img_name[i][:-4] + '.lines.txt'
            outdir = os.path.dirname(outname)
            os.makedirs(outdir, exist_ok=True)
            # Write one lane per line as "x1 y1 x2 y2 ...", skipping points
            # whose x and y are both negative.
            with open(outname, 'w') as f:
                for coord in coords:
                    for x, y in coord:
                        if x < 0 and y < 0:
                            continue
                        f.write('%d %d ' % (x, y))
                    f.write('\n')

            if self.cfg.view:
                img = cv2.imread(img_path[i]).astype(np.float32)
                dataset.view(img, coords, self.view_dir+img_name[i])


    def summarize(self):
        self.logger.info('summarize result...')
        eval_list_path = os.path.join(
            self.cfg.dataset_path, "list", self.cfg.dataset.val.data_list)
        #prob2lines(self.prob_dir, self.out_dir, eval_list_path, self.cfg)
        res = call_culane_eval(self.cfg.dataset_path, output_path=self.cfg.work_dir)
        TP,FP,FN = 0,0,0
        out_str = 'Copypaste: '
        for k, v in res.items():
            val = float(v['Fmeasure']) if 'nan' not in v['Fmeasure'] else 0
            val_tp, val_fp, val_fn = int(v['tp']), int(v['fp']), int(v['fn'])
            val_p, val_r, val_f1 = float(v['precision']), float(v['recall']), float(v['Fmeasure'])
            TP += val_tp
            FP += val_fp
            FN += val_fn
            self.logger.info(k + ': ' + str(v))
            out_str += k
            for metric, value in v.items():
                out_str += ' ' + str(value).rstrip('\n')
            out_str += ' '
        P = TP * 1.0 / (TP + FP + 1e-9)
        R = TP * 1.0 / (TP + FN + 1e-9)
        F = 2*P*R/(P + R + 1e-9)
        overall_result_str = ('Overall Precision: %f Recall: %f F1: %f' % (P, R, F))
        self.logger.info(overall_result_str)
        out_str = out_str + overall_result_str
        self.logger.info(out_str)

        # delete the tmp output
        rmtree(self.out_dir)
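

# Illustrative sketch (hypothetical helper, not called by the runner): the
# overall numbers logged by summarize() are micro-averaged, i.e. tp/fp/fn are
# summed over all CULane categories first and precision/recall/F1 are computed
# once from the totals. The 1e-9 terms only guard against division by zero.
def micro_f1(counts):
    """counts: iterable of (tp, fp, fn) tuples, one per category."""
    counts = list(counts)  # allow generators, since we iterate three times
    tp = sum(c[0] for c in counts)
    fp = sum(c[1] for c in counts)
    fn = sum(c[2] for c in counts)
    precision = tp / (tp + fp + 1e-9)
    recall = tp / (tp + fn + 1e-9)
    f1 = 2 * precision * recall / (precision + recall + 1e-9)
    return precision, recall, f1

# Example: micro_f1([(8, 2, 2), (2, 0, 3)]) sums to tp=10, fp=2, fn=5, giving
# precision 10/12, recall 10/15 and F1 of about 0.7407.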


================================================
FILE: runner/evaluator/culane/lane_evaluation/.gitignore
================================================
build/
evaluate


================================================
FILE: runner/evaluator/culane/lane_evaluation/Makefile
================================================
PROJECT_NAME:= evaluate

# config ----------------------------------
OPENCV_VERSION := 3

INCLUDE_DIRS := include
LIBRARY_DIRS := lib /usr/local/lib

COMMON_FLAGS := -DCPU_ONLY
CXXFLAGS := -std=c++11 -fopenmp
LDFLAGS := -fopenmp -Wl,-rpath,./lib
BUILD_DIR := build


# make rules -------------------------------
CXX ?= g++
BUILD_DIR ?= ./build

LIBRARIES += opencv_core opencv_highgui opencv_imgproc 
ifeq ($(OPENCV_VERSION), 3)
		LIBRARIES += opencv_imgcodecs
endif

CXXFLAGS += $(COMMON_FLAGS) $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
LDFLAGS +=  $(COMMON_FLAGS) $(foreach includedir,$(LIBRARY_DIRS),-L$(includedir)) $(foreach library,$(LIBRARIES),-l$(library))
SRC_DIRS += $(shell find * -type d -exec bash -c "find {} -maxdepth 1 \( -name '*.cpp' -o -name '*.proto' \) | grep -q ." \; -print)
CXX_SRCS += $(shell find src/ -name "*.cpp")
CXX_TARGETS:=$(patsubst %.cpp, $(BUILD_DIR)/%.o, $(CXX_SRCS))
ALL_BUILD_DIRS := $(sort $(BUILD_DIR) $(addprefix $(BUILD_DIR)/, $(SRC_DIRS)))

.PHONY: all
all: $(PROJECT_NAME)

.PHONY: $(ALL_BUILD_DIRS)
$(ALL_BUILD_DIRS):
	@mkdir -p $@

$(BUILD_DIR)/%.o: %.cpp | $(ALL_BUILD_DIRS)
	@echo "CXX" $<
	@$(CXX) $(CXXFLAGS) -c -o $@ $<

$(PROJECT_NAME): $(CXX_TARGETS)
	@echo "CXX/LD" $@
	@$(CXX) -o $@ $^ $(LDFLAGS)

.PHONY: clean
clean:
	@rm -rf $(CXX_TARGETS)
	@rm -rf $(PROJECT_NAME)
	@rm -rf $(BUILD_DIR)


================================================
FILE: runner/evaluator/culane/lane_evaluation/include/counter.hpp
================================================
#ifndef COUNTER_HPP
#define COUNTER_HPP

#include "lane_compare.hpp"
#include "hungarianGraph.hpp"
#include <iostream>
#include <algorithm>
#include <tuple>
#include <vector>
#include <opencv2/core/core.hpp>

using namespace std;
using namespace cv;

// Before using this class, resize the lanes to im_width x im_height with resize_lane() from lane_compare.hpp.
class Counter
{
	public:
		Counter(int _im_width, int _im_height, double _iou_threshold=0.4, int _lane_width=10):tp(0),fp(0),fn(0){
			im_width = _im_width;
			im_height = _im_height;
			sim_threshold = _iou_threshold;
			lane_compare = new LaneCompare(_im_width, _im_height,  _lane_width, LaneCompare::IOU);
		};
		double get_precision(void);
		double get_recall(void);
		long getTP(void);
		long getFP(void);
		long getFN(void);
		void setTP(long);
		void setFP(long);
		void setFN(long);
		// Returns (anno_match, tp, fp, tn, fn) for one image; annotated and
		// detected lanes are first matched with the Hungarian algorithm.
		tuple<vector<int>, long, long, long, long> count_im_pair(const vector<vector<Point2f> > &anno_lanes, const vector<vector<Point2f> > &detect_lanes);
		void makeMatch(const vector<vector<double> > &similarity, vector<int> &match1, vector<int> &match2);

	private:
		double sim_threshold;
		int im_width;
		int im_height;
		long tp;
		long fp;
		long fn;
		LaneCompare *lane_compare;
};
#endif


================================================
FILE: runner/evaluator/culane/lane_evaluation/include/hungarianGraph.hpp
================================================
#ifndef HUNGARIAN_GRAPH_HPP
#define HUNGARIAN_GRAPH_HPP
#include <vector>
#include <algorithm>
#include <cmath>
using namespace std;

struct pipartiteGraph {
    vector<vector<double> > mat;
    vector<bool> leftUsed, rightUsed;
    vector<double> leftWeight, rightWeight;
    vector<int>rightMatch, leftMatch;
    int leftNum, rightNum;
    bool matchDfs(int u) {
        leftUsed[u] = true;
        for (int v = 0; v < rightNum; v++) {
            if (!rightUsed[v] && fabs(leftWeight[u] + rightWeight[v] - mat[u][v]) < 1e-2) {
                rightUsed[v] = true;
                if (rightMatch[v] == -1 || matchDfs(rightMatch[v])) {
                    rightMatch[v] = u;
                    leftMatch[u] = v;
                    return true;
                }
            }
        }
        return false;
    }
    void resize(int leftNum, int rightNum) {
        this->leftNum = leftNum;
        this->rightNum = rightNum;
        leftMatch.resize(leftNum);
        rightMatch.resize(rightNum);
        leftUsed.resize(leftNum);
        rightUsed.resize(rightNum);
        leftWeight.resize(leftNum);
        rightWeight.resize(rightNum);
        mat.resize(leftNum);
        for (int i = 0; i < leftNum; i++) mat[i].resize(rightNum);
    }
    void match() {
        for (int i = 0; i < leftNum; i++) leftMatch[i] = -1;
        for (int i = 0; i < rightNum; i++) rightMatch[i] = -1;
        for (int i = 0; i < rightNum; i++) rightWeight[i] = 0;
        for (int i = 0; i < leftNum; i++) {
            leftWeight[i] = -1e5;
            for (int j = 0; j < rightNum; j++) {
                if (leftWeight[i] < mat[i][j]) leftWeight[i] = mat[i][j];
            }
        }

        for (int u = 0; u < leftNum; u++) {
            while (1) {
                for (int i = 0; i < leftNum; i++) leftUsed[i] = false;
                for (int i = 0; i < rightNum; i++) rightUsed[i] = false;
                if (matchDfs(u)) break;
                double d = 1e10;
                for (int i = 0; i < leftNum; i++) {
                    if (leftUsed[i] ) {
                        for (int j = 0; j < rightNum; j++) {
                            if (!rightUsed[j]) d = min(d, leftWeight[i] + rightWeight[j] - mat[i][j]);
                        }
                    }
                }
                if (d == 1e10) return ;
                for (int i = 0; i < leftNum; i++) if (leftUsed[i]) leftWeight[i] -= d;
                for (int i = 0; i < rightNum; i++) if (rightUsed[i]) rightWeight[i] += d;
            }
        }
    }
};


#endif // HUNGARIAN_GRAPH_HPP


================================================
FILE: runner/evaluator/culane/lane_evaluation/include/lane_compare.hpp
================================================
#ifndef LANE_COMPARE_HPP
#define LANE_COMPARE_HPP

#include "spline.hpp"
#include <vector>
#include <iostream>
#include <opencv2/core/version.hpp>
#include <opencv2/core/core.hpp>

#if CV_VERSION_EPOCH == 2
#define OPENCV2
#elif CV_VERSION_MAJOR == 3
#define  OPENCV3
#else
#error This OpenCV version is not supported (only OpenCV 2.x and 3.x are handled)
#endif

#ifdef OPENCV3
#include <opencv2/imgproc.hpp>
#elif defined(OPENCV2)
#include <opencv2/imgproc/imgproc.hpp>
#endif

using namespace std;
using namespace cv;

class LaneCompare{
	public:
		enum CompareMode{
			IOU,
			Caltech
		};

		LaneCompare(int _im_width, int _im_height, int _lane_width = 10, CompareMode _compare_mode = IOU){
			im_width = _im_width;
			im_height = _im_height;
			compare_mode = _compare_mode;
			lane_width = _lane_width;
		}

		double get_lane_similarity(const vector<Point2f> &lane1, const vector<Point2f> &lane2);
		void resize_lane(vector<Point2f> &curr_lane, int curr_width, int curr_height);
	private:
		CompareMode compare_mode;
		int im_width;
		int im_height;
		int lane_width;
		Spline splineSolver;
};

#endif


================================================
FILE: runner/evaluator/culane/lane_evaluation/include/spline.hpp
================================================
#ifndef SPLINE_HPP
#define SPLINE_HPP
#include <vector>
#include <cstdio>
#include <math.h>
#include <opencv2/core/core.hpp>

using namespace cv;
using namespace std;

struct Func {
    double a_x;
    double b_x;
    double c_x;
    double d_x;
    double a_y;
    double b_y;
    double c_y;
    double d_y;
    double h;
};
class Spline {
public:
	vector<Point2f> splineInterpTimes(const vector<Point2f> &tmp_line, int times);
    vector<Point2f> splineInterpStep(vector<Point2f> tmp_line, double step);
	vector<Func> cal_fun(const vector<Point2f> &point_v);
};
#endif


================================================
FILE: runner/evaluator/culane/lane_evaluation/src/counter.cpp
================================================
/*************************************************************************
	> File Name: counter.cpp
	> Author: Xingang Pan, Jun Li
	> Mail: px117@ie.cuhk.edu.hk
	> Created Time: Thu Jul 14 20:23:08 2016
 ************************************************************************/

#include "counter.hpp"

double Counter::get_precision(void)
{
	cerr<<"tp: "<<tp<<" fp: "<<fp<<" fn: "<<fn<<endl;
	if(tp+fp == 0)
	{
		cerr<<"no positive detection"<<endl;
		return -1;
	}
	return tp/double(tp + fp);
}

double Counter::get_recall(void)
{
	if(tp+fn == 0)
	{
		cerr<<"no ground truth positive"<<endl;
		return -1;
	}
	return tp/double(tp + fn);
}

long Counter::getTP(void)
{
	return tp;
}

long Counter::getFP(void)
{
	return fp;
}

long Counter::getFN(void)
{
	return fn;
}

void Counter::setTP(long value) 
{
	tp = value;
}

void Counter::setFP(long value)
{
	fp = value;
}

void Counter::setFN(long value)
{
	fn = value;
}

tuple<vector<int>, long, long, long, long> Counter::count_im_pair(const vector<vector<Point2f> > &anno_lanes, const vector<vector<Point2f> > &detect_lanes)
{
	vector<int> anno_match(anno_lanes.size(), -1);
	vector<int> detect_match;
	if(anno_lanes.empty())
	{
		return make_tuple(anno_match, 0, detect_lanes.size(), 0, 0);
	}

	if(detect_lanes.empty())
	{
		return make_tuple(anno_match, 0, 0, 0, anno_lanes.size());
	}
	// hungarian match first
	
	// first calc similarity matrix
	vector<vector<double> > similarity(anno_lanes.size(), vector<double>(detect_lanes.size(), 0));
	for(int i=0; i<anno_lanes.size(); i++)
	{
		const vector<Point2f> &curr_anno_lane = anno_lanes[i];
		for(int j=0; j<detect_lanes.size(); j++)
		{
			const vector<Point2f> &curr_detect_lane = detect_lanes[j];
			similarity[i][j] = lane_compare->get_lane_similarity(curr_anno_lane, curr_detect_lane);
		}
	}



	makeMatch(similarity, anno_match, detect_match);

	
	int curr_tp = 0;
	// count and add
	for(int i=0; i<anno_lanes.size(); i++)
	{
		if(anno_match[i]>=0 && similarity[i][anno_match[i]] > sim_threshold)
		{
			curr_tp++;
		}
		else
		{
			anno_match[i] = -1;
		}
	}
	int curr_fn = anno_lanes.size() - curr_tp;
	int curr_fp = detect_lanes.size() - curr_tp;
	return make_tuple(anno_match, curr_tp, curr_fp, 0, curr_fn);
}


void Counter::makeMatch(const vector<vector<double> > &similarity, vector<int> &match1, vector<int> &match2) {
	int m = similarity.size();
	int n = similarity[0].size();
    pipartiteGraph gra;
    bool have_exchange = false;
    if (m > n) {
        have_exchange = true;
        swap(m, n);
    }
    gra.resize(m, n);
    for (int i = 0; i < gra.leftNum; i++) {
        for (int j = 0; j < gra.rightNum; j++) {
			if(have_exchange)
				gra.mat[i][j] = similarity[j][i];
			else
				gra.mat[i][j] = similarity[i][j];
        }
    }
    gra.match();
    match1 = gra.leftMatch;
    match2 = gra.rightMatch;
    if (have_exchange) swap(match1, match2);
}


================================================
FILE: runner/evaluator/culane/lane_evaluation/src/evaluate.cpp
================================================
/*************************************************************************
        > File Name: evaluate.cpp
        > Author: Xingang Pan, Jun Li
        > Mail: px117@ie.cuhk.edu.hk
        > Created Time: Thu Jul 14 18:28:45 2016
 ************************************************************************/

#include "counter.hpp"
#include "spline.hpp"
#include <unistd.h>
#include <iostream>
#include <fstream>
#include <sstream>
#include <cstdlib>
#include <string>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
using namespace std;
using namespace cv;

void help(void) {
  cout << "./evaluate [OPTIONS]" << endl;
  cout << "-h                  : print usage help" << endl;
  cout << "-a                  : directory for annotation files (default: "
          "/data/driving/eval_data/anno_label/)" << endl;
  cout << "-d                  : directory for detection files (default: "
          "/data/driving/eval_data/predict_label/)" << endl;
  cout << "-i                  : directory for image files (default: "
          "/data/driving/eval_data/img/)" << endl;
  cout << "-l                  : list of images used for evaluation (default: "
          "/data/driving/eval_data/img/all.txt)" << endl;
  cout << "-w                  : width of the lanes (default: 10)" << endl;
  cout << "-t                  : threshold of iou (default: 0.4)" << endl;
  cout << "-c                  : cols (max image width) (default: 1920)"
       << endl;
  cout << "-r                  : rows (max image height) (default: 1080)"
       << endl;
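  cout << "-s                  : show visualization" << endl;
  cout << "-p                  : directory for saving visualizations (default: "
          "none)" << endl;
  cout << "-f                  : start frame in the test set (default: 1)"
       << endl;
  cout << "-o                  : output file for the results (default: "
          "./output.txt)" << endl;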
}

void read_lane_file(const string &file_name, vector<vector<Point2f>> &lanes);
void visualize(string &full_im_name, vector<vector<Point2f>> &anno_lanes,
               vector<vector<Point2f>> &detect_lanes, vector<int> anno_match,
               int width_lane, string save_path = "");

int main(int argc, char **argv) {
  // process params
  string anno_dir = "/data/driving/eval_data/anno_label/";
  string detect_dir = "/data/driving/eval_data/predict_label/";
  string im_dir = "/data/driving/eval_data/img/";
  string list_im_file = "/data/driving/eval_data/img/all.txt";
  string output_file = "./output.txt";
  int width_lane = 10;
  double iou_threshold = 0.4;
  int im_width = 1920;
  int im_height = 1080;
  int oc;
  bool show = false;
  int frame = 1;
  string save_path = "";
  while ((oc = getopt(argc, argv, "ha:d:i:l:w:t:c:r:sf:o:p:")) != -1) {
    switch (oc) {
    case 'h':
      help();
      return 0;
    case 'a':
      anno_dir = optarg;
      break;
    case 'd':
      detect_dir = optarg;
      break;
    case 'i':
      im_dir = optarg;
      break;
    case 'l':
      list_im_file = optarg;
      break;
    case 'w':
      width_lane = atoi(optarg);
      break;
    case 't':
      iou_threshold = atof(optarg);
      break;
    case 'c':
      im_width = atoi(optarg);
      break;
    case 'r':
      im_height = atoi(optarg);
      break;
    case 's':
      show = true;
      break;
    case 'p':
      save_path = optarg;
      break;
    case 'f':
      frame = atoi(optarg);
      break;
    case 'o':
      output_file = optarg;
      break;
    }
  }

  cout << "------------Configuration---------" << endl;
  cout << "anno_dir: " << anno_dir << endl;
  cout << "detect_dir: " << detect_dir << endl;
  cout << "im_dir: " << im_dir << endl;
  cout << "list_im_file: " << list_im_file << endl;
  cout << "width_lane: " << width_lane << endl;
  cout << "iou_threshold: " << iou_threshold << endl;
  cout << "im_width: " << im_width << endl;
  cout << "im_height: " << im_height << endl;
  cout << "-----------------------------------" << endl;
  cout << "Evaluating the results..." << endl;
  // this is the max_width and max_height

  if (width_lane < 1) {
    cerr << "width_lane must be positive" << endl;
    help();
    return 1;
  }

  ifstream ifs_im_list(list_im_file, ios::in);
  if (ifs_im_list.fail()) {
    cerr << "Error: file " << list_im_file << " not exist!" << endl;
    return 1;
  }

  Counter counter(im_width, im_height, iou_threshold, width_lane);

  vector<int> anno_match;
  string sub_im_name;
  // pre-load filelist
  vector<string> filelists;
  while (getline(ifs_im_list, sub_im_name)) {
    filelists.push_back(sub_im_name);
  }
  ifs_im_list.close();

  vector<tuple<vector<int>, long, long, long, long>> tuple_lists;
  tuple_lists.resize(filelists.size());

#pragma omp parallel for
  for (size_t i = 0; i < filelists.size(); i++) {
    auto sub_im_name = filelists[i];
    string full_im_name = im_dir + sub_im_name;
    string sub_txt_name =
        sub_im_name.substr(0, sub_im_name.find_last_of(".")) + ".lines.txt";
    string anno_file_name = anno_dir + sub_txt_name;
    string detect_file_name = detect_dir + sub_txt_name;
    vector<vector<Point2f>> anno_lanes;
    vector<vector<Point2f>> detect_lanes;
    read_lane_file(anno_file_name, anno_lanes);
    read_lane_file(detect_file_name, detect_lanes);
    // cerr<<count<<": "<<full_im_name<<endl;
    tuple_lists[i] = counter.count_im_pair(anno_lanes, detect_lanes);
    if (show) {
      auto anno_match = get<0>(tuple_lists[i]);
      visualize(full_im_name, anno_lanes, detect_lanes, anno_match, width_lane);
      waitKey(0);
    }
    if (save_path != "") {
      auto anno_match = get<0>(tuple_lists[i]);
      visualize(full_im_name, anno_lanes, detect_lanes, anno_match, width_lane,
                save_path);
    }
  }

  long tp = 0, fp = 0, tn = 0, fn = 0;
  for (auto result : tuple_lists) {
    tp += get<1>(result);
    fp += get<2>(result);
    // tn = get<3>(result);
    fn += get<4>(result);
  }
  counter.setTP(tp);
  counter.setFP(fp);
  counter.setFN(fn);

  double precision = counter.get_precision();
  double recall = counter.get_recall();
  double F = 2 * precision * recall / (precision + recall);
  cerr << "finished process file" << endl;
  cout << "precision: " << precision << endl;
  cout << "recall: " << recall << endl;
  cout << "Fmeasure: " << F << endl;
  cout << "----------------------------------" << endl;

  ofstream ofs_out_file;
  ofs_out_file.open(output_file, ios::out);
  ofs_out_file << "file: " << output_file << endl;
  ofs_out_file << "tp: " << counter.getTP() << " fp: " << counter.getFP()
               << " fn: " << counter.getFN() << endl;
  ofs_out_file << "precision: " << precision << endl;
  ofs_out_file << "recall: " << recall << endl;
  ofs_out_file << "Fmeasure: " << F << endl << endl;
  ofs_out_file.close();
  return 0;
}

void read_lane_file(const string &file_name, vector<vector<Point2f>> &lanes) {
  lanes.clear();
  ifstream ifs_lane(file_name, ios::in);
  if (ifs_lane.fail()) {
    return;
  }

  string str_line;
  while (getline(ifs_lane, str_line)) {
    vector<Point2f> curr_lane;
    stringstream ss;
    ss << str_line;
    double x, y;
    while (ss >> x >> y) {
      curr_lane.push_back(Point2f(x, y));
    }
    lanes.push_back(curr_lane);
  }

  ifs_lane.close();
}

void visualize(string &full_im_name, vector<vector<Point2f>> &anno_lanes,
               vector<vector<Point2f>> &detect_lanes, vector<int> anno_match,
               int width_lane, string save_path) {
  Mat img = imread(full_im_name, 1);
  Mat img2 = imread(full_im_name, 1);
  vector<Point2f> curr_lane;
  vector<Point2f> p_interp;
  Spline splineSolver;
  Scalar color_B = Scalar(255, 0, 0);
  Scalar color_G = Scalar(0, 255, 0);
  Scalar color_R = Scalar(0, 0, 255);
  Scalar color_P = Scalar(255, 0, 255);
  Scalar color;
  for (int i = 0; i < anno_lanes.size(); i++) {
    curr_lane = anno_lanes[i];
    if (curr_lane.size() == 2) {
      p_interp = curr_lane;
    } else {
      p_interp = splineSolver.splineInterpTimes(curr_lane, 50);
    }
    if (anno_match[i] >= 0) {
      color = color_G;
    } else {
      color = color_G;
    }
    for (size_t n = 0; n + 1 < p_interp.size(); n++) {
      line(img, p_interp[n], p_interp[n + 1], color, width_lane);
      line(img2, p_interp[n], p_interp[n + 1], color, 2);
    }
  }
  bool detected;
  for (int i = 0; i < detect_lanes.size(); i++) {
    detected = false;
    curr_lane = detect_lanes[i];
    if (curr_lane.size() == 2) {
      p_interp = curr_lane;
    } else {
      p_interp = splineSolver.splineInterpTimes(curr_lane, 50);
    }
    for (int n = 0; n < anno_lanes.size(); n++) {
      if (anno_match[n] == i) {
        detected = true;
        break;
      }
    }
    if (detected == true) {
      color = color_B;
    } else {
      color = color_R;
    }
    for (size_t n = 0; n + 1 < p_interp.size(); n++) {
      line(img, p_interp[n], p_interp[n + 1], color, width_lane);
      line(img2, p_interp[n], p_interp[n + 1], color, 2);
    }
  }
  if (save_path != "") {
    size_t pos = 0;
    string s = full_im_name;
    std::string token;
    std::string delimiter = "/";
    vector<string> names;
    while ((pos = s.find(delimiter)) != std::string::npos) {
      token = s.substr(0, pos);
      names.emplace_back(token);
      s.erase(0, pos + delimiter.length());
    }
    names.emplace_back(s);
    string file_name = names[3] + '_' + names[4] + '_' + names[5];
    // cout << file_name << endl;
    imwrite(save_path + '/' + file_name, img);
  } else {
    namedWindow("visualize", 1);
    imshow("visualize", img);
    namedWindow("visualize2", 1);
    imshow("visualize2", img2);
  }
}


================================================
FILE: runner/evaluator/culane/lane_evaluation/src/lane_compare.cpp
================================================
/*************************************************************************
	> File Name: lane_compare.cpp
	> Author: Xingang Pan, Jun Li
	> Mail: px117@ie.cuhk.edu.hk
	> Created Time: Fri Jul 15 10:26:32 2016
 ************************************************************************/

#include "lane_compare.hpp"

double LaneCompare::get_lane_similarity(const vector<Point2f> &lane1, const vector<Point2f> &lane2)
{
	if(lane1.size()<2 || lane2.size()<2)
	{
		cerr<<"lane size must be greater than or equal to 2"<<endl;
		return 0;
	}
	Mat im1 = Mat::zeros(im_height, im_width, CV_8UC1);
	Mat im2 = Mat::zeros(im_height, im_width, CV_8UC1);
	// draw lines on im1 and im2
	vector<Point2f> p_interp1;
	vector<Point2f> p_interp2;
	if(lane1.size() == 2)
	{
		p_interp1 = lane1;
	}
	else
	{
		p_interp1 = splineSolver.splineInterpTimes(lane1, 50);
	}

	if(lane2.size() == 2)
	{
		p_interp2 = lane2;
	}
	else
	{
		p_interp2 = splineSolver.splineInterpTimes(lane2, 50);
	}
	
	Scalar color_white = Scalar(1);
	for(int n=0; n<p_interp1.size()-1; n++)
	{
		line(im1, p_interp1[n], p_interp1[n+1], color_white, lane_width);
	}
	for(int n=0; n<p_interp2.size()-1; n++)
	{
		line(im2, p_interp2[n], p_interp2[n+1], color_white, lane_width);
	}

	double sum_1 = cv::sum(im1).val[0];
	double sum_2 = cv::sum(im2).val[0];
	double inter_sum = cv::sum(im1.mul(im2)).val[0];
	double union_sum = sum_1 + sum_2 - inter_sum; 
	double iou = inter_sum / union_sum;
	return iou;
}


// resize the lane from Size(curr_width, curr_height) to Size(im_width, im_height)
void LaneCompare::resize_lane(vector<Point2f> &curr_lane, int curr_width, int curr_height)
{
	if(curr_width == im_width && curr_height == im_height)
	{
		return;
	}
	double x_scale = im_width/(double)curr_width;
	double y_scale = im_height/(double)curr_height;
	for(int n=0; n<curr_lane.size(); n++)
	{
		curr_lane[n] = Point2f(curr_lane[n].x*x_scale, curr_lane[n].y*y_scale);
	}
}



================================================
FILE: runner/evaluator/culane/lane_evaluation/src/spline.cpp
================================================
#include <vector>
#include <iostream>
#include "spline.hpp"
using namespace std;
using namespace cv;

vector<Point2f> Spline::splineInterpTimes(const vector<Point2f>& tmp_line, int times) {
    vector<Point2f> res;

    if(tmp_line.size() == 2) {
        double x1 = tmp_line[0].x;
        double y1 = tmp_line[0].y;
        double x2 = tmp_line[1].x;
        double y2 = tmp_line[1].y;

        for (int k = 0; k <= times; k++) {
            double xi =  x1 + double((x2 - x1) * k) / times;
            double yi =  y1 + double((y2 - y1) * k) / times;
            res.push_back(Point2f(xi, yi));
        }
    }

    else if(tmp_line.size() > 2)
    {
        vector<Func> tmp_func;
        tmp_func = this->cal_fun(tmp_line);
        if (tmp_func.empty()) {
            cout << "in splineInterpTimes: cal_fun failed" << endl;
            return res;
        }
        for(int j = 0; j < tmp_func.size(); j++)
        {
            double delta = tmp_func[j].h / times;
            for(int k = 0; k < times; k++)
            {
                double t1 = delta*k;
                double x1 = tmp_func[j].a_x + tmp_func[j].b_x*t1 + tmp_func[j].c_x*pow(t1,2) + tmp_func[j].d_x*pow(t1,3);
                double y1 = tmp_func[j].a_y + tmp_func[j].b_y*t1 + tmp_func[j].c_y*pow(t1,2) + tmp_func[j].d_y*pow(t1,3);
                res.push_back(Point2f(x1, y1));
            }
        }
        res.push_back(tmp_line[tmp_line.size() - 1]);
    }
	else {
		cerr << "in splineInterpTimes: not enough points" << endl;
	}
    return res;
}
vector<Point2f> Spline::splineInterpStep(vector<Point2f> tmp_line, double step) {
	vector<Point2f> res;
	/*
	if (tmp_line.size() == 2) {
		double x1 = tmp_line[0].x;
		double y1 = tmp_line[0].y;
		double x2 = tmp_line[1].x;
		double y2 = tmp_line[1].y;

		for (double yi = std::min(y1, y2); yi < std::max(y1, y2); yi += step) {
            double xi;
			if (yi == y1) xi = x1;
			else xi = (x2 - x1) / (y2 - y1) * (yi - y1) + x1;
			res.push_back(Point2f(xi, yi));
		}
	}*/
    if (tmp_line.size() == 2) {
        // cal_fun needs at least 3 points: insert the midpoint so that the
        // resulting spline is the straight segment between the two points.
        double x1 = tmp_line[0].x;
        double y1 = tmp_line[0].y;
        double x2 = tmp_line[1].x;
        double y2 = tmp_line[1].y;
        tmp_line[1].x = (x1 + x2) / 2;
        tmp_line[1].y = (y1 + y2) / 2;
        tmp_line.push_back(Point2f(x2, y2));
    }
    if (tmp_line.size() > 2) {
        vector<Func> tmp_func = this->cal_fun(tmp_line);
        if (tmp_func.empty()) {
            cerr << "in splineInterpStep: cal_fun failed" << endl;
            return res;
        }

        // Sample each segment every `step` units of its chord-length parameter.
        for(size_t j = 0; j < tmp_func.size(); j++)
        {
            for(double t1 = 0; t1 < tmp_func[j].h; t1 += step)
            {
                double x1 = tmp_func[j].a_x + tmp_func[j].b_x*t1 + tmp_func[j].c_x*pow(t1,2) + tmp_func[j].d_x*pow(t1,3);
                double y1 = tmp_func[j].a_y + tmp_func[j].b_y*t1 + tmp_func[j].c_y*pow(t1,2) + tmp_func[j].d_y*pow(t1,3);
                res.push_back(Point2f(x1, y1));
            }
        }
        res.push_back(tmp_line[tmp_line.size() - 1]);
    }
    else {
        cerr << "in splineInterpStep: not enough points" << endl;
    }
    return res;
}

vector<Func> Spline::cal_fun(const vector<Point2f> &point_v)
{
    vector<Func> func_v;
    int n = point_v.size();
    if(n<=2) {
        cerr << "in cal_fun: point number less than 3" << endl;
        return func_v;
    }

    // One cubic segment per pair of consecutive points.
    func_v.resize(point_v.size()-1);

    vector<double> Mx(n);
    vector<double> My(n);
    vector<double> A(n-2);
    vector<double> B(n-2);
    vector<double> C(n-2);
    vector<double> Dx(n-2);
    vector<double> Dy(n-2);
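    vector<double> h(n-1);

    // Chord-length parameterization: h[i] is the distance between points i and i+1.
    for(int i = 0; i < n-1; i++)
    {
        h[i] = sqrt(pow(point_v[i+1].x - point_v[i].x, 2) + pow(point_v[i+1].y - point_v[i].y, 2));
    }

    // Tridiagonal system for the interior second derivatives M[1..n-2]:
    // h[i]*M[i] + 2*(h[i]+h[i+1])*M[i+1] + h[i+1]*M[i+2] = D[i]
    for(int i = 0; i < n-2; i++)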
    {
        A[i] = h[i];
        B[i] = 2*(h[i]+h[i+1]);
        C[i] = h[i+1];

        Dx[i] =  6*( (point_v[i+2].x - point_v[i+1].x)/h[i+1] - (point_v[i+1].x - point_v[i].x)/h[i] );
        Dy[i] =  6*( (point_v[i+2].y - point_v[i+1].y)/h[i+1] - (point_v[i+1].y - point_v[i].y)/h[i] );
    }

    // Solve with the Thomas algorithm (TDMA): forward sweep
    C[0] = C[0] / B[0];
    Dx[0] = Dx[0] / B[0];
    Dy[0] = Dy[0] / B[0];
    for(int i = 1; i < n-2; i++)
    {
        double tmp = B[i] - A[i]*C[i-1];
        C[i] = C[i] / tmp;
        Dx[i] = (Dx[i] - A[i]*Dx[i-1]) / tmp;
        Dy[i] = (Dy[i] - A[i]*Dy[i-1]) / tmp;
    }
    // Back substitution for M[n-2] .. M[1]
    Mx[n-2] = Dx[n-3];
    My[n-2] = Dy[n-3];
    for(int i = n-4; i >= 0; i--)
    {
        Mx[i+1] = Dx[i] - C[i]*Mx[i+2];
        My[i+1] = Dy[i] - C[i]*My[i+2];
    }

    // Natural boundary conditions: zero second derivative at both ends
    Mx[0] = 0;
    Mx[n-1] = 0;
    My[0] = 0;
    My[n-1] = 0;

    // Segment i: S_i(t) = a + b*t + c*t^2 + d*t^3 for t in [0, h[i]]
    for(int i = 0; i < n-1; i++)
    {
        func_v[i].a_x = point_v[i].x;
        func_v[i].b_x = (point_v[i+1].x - point_v[i].x)/h[i] - (2*h[i]*Mx[i] + h[i]*Mx[i+1]) / 6;
        func_v[i].c_x = Mx[i]/2;
        func_v[i].d_x = (Mx[i+1] - Mx[i]) / (6*h[i]);

        func_v[i].a_y = point_v[i].y;
        func_v[i].b_y = (point_v[i+1].y - point_v[i].y)/h[i] - (2*h[i]*My[i] + h[i]*My[i+1]) / 6;
        func_v[i].c_y = My[i]/2;
        func_v[i].d_y = (My[i+1] - My[i]) / (6*h[i]);

        func_v[i].h = h[i];
    }
    return func_v;
}
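
For readers checking `Spline::cal_fun`, here is a minimal Python transcription of the same natural cubic spline. It is a sketch for verification, not a replacement, and follows the same steps: chord-length spans `h`, a Thomas-algorithm solve for the interior second derivatives, and zero second derivatives at both ends. `natural_spline_coeffs` and `interp_times` are hypothetical names that do not exist in the repository. `interp_times` mimics `splineInterpTimes` for three or more points.

```python
import math


def natural_spline_coeffs(points):
    """Hypothetical helper: (x_coeffs, y_coeffs, h) of a natural cubic spline.

    Each coefficient entry is (a, b, c, d) for S(t) = a + b*t + c*t^2 + d*t^3,
    t in [0, h[i]], following the structure of Spline::cal_fun.
    """
    n = len(points)
    if n < 3:
        raise ValueError("need at least 3 points")
    # Chord-length spans between consecutive points
    h = [math.dist(points[i], points[i + 1]) for i in range(n - 1)]
    result = []
    for axis in (0, 1):
        v = [p[axis] for p in points]
        m = n - 2  # number of interior unknowns M[1..n-2]
        sub = [h[i] for i in range(m)]
        diag = [2 * (h[i] + h[i + 1]) for i in range(m)]
        sup = [h[i + 1] for i in range(m)]
        rhs = [6 * ((v[i + 2] - v[i + 1]) / h[i + 1] - (v[i + 1] - v[i]) / h[i])
               for i in range(m)]
        # Thomas algorithm: forward sweep
        sup[0] /= diag[0]
        rhs[0] /= diag[0]
        for i in range(1, m):
            tmp = diag[i] - sub[i] * sup[i - 1]
            sup[i] /= tmp
            rhs[i] = (rhs[i] - sub[i] * rhs[i - 1]) / tmp
        # Back substitution; M[0] = M[n-1] = 0 (natural end conditions)
        M = [0.0] * n
        M[n - 2] = rhs[m - 1]
        for i in range(m - 2, -1, -1):
            M[i + 1] = rhs[i] - sup[i] * M[i + 2]
        result.append([
            (v[i],
             (v[i + 1] - v[i]) / h[i] - h[i] * (2 * M[i] + M[i + 1]) / 6,
             M[i] / 2,
             (M[i + 1] - M[i]) / (6 * h[i]))
            for i in range(n - 1)
        ])
    return result[0], result[1], h


def interp_times(points, times):
    """Hypothetical helper: `times` samples per segment, then the last knot."""
    fx, fy, h = natural_spline_coeffs(points)
    out = []
    for (ax, bx, cx, dx), (ay, by, cy, dy), span in zip(fx, fy, h):
        for k in range(times):
            t = span / times * k
            out.append((ax + bx * t + cx * t ** 2 + dx * t ** 3,
                        ay + by * t + cy * t ** 2 + dy * t ** 3))
    out.append(tuple(points[-1]))
    return out
```

Each knot `points[j]` reappears at index `j * times`, and a segment evaluated at `t = h[j]` lands on the next knot. Both facts give a quick sanity check for the coefficient formulas.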


================================================
FILE: runner/evaluator/culane/prob2lines.py
================================================
import os
import argparse
import numpy as np
import pandas as pd
from PIL import Image
import tqdm


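def getLane(probmap, pts, cfg = None):
    """Sample one lane's x coordinate at `pts` CULane rows.

    Sample i corresponds to y = 590 - 20 * i in the original 1640x590 image and
    is mapped into the (cropped, resized) probability map with `cfg.img_height`
    and `cfg.cut_height`, so `cfg` is required. Returns 1-based x positions in
    probability-map pixels, 0 where the lane is absent, or all zeros if fewer
    than two points pass the threshold.
    """
    if cfg is None:
        raise ValueError('getLane requires a cfg with img_height and cut_height')
    thr = 0.3
    coordinate = np.zeros(pts)
    cut_height = cfg.cut_height or 0
    for i in range(pts):
        row = round(cfg.img_height - i * 20 / (590 - cut_height) * cfg.img_height) - 1
        line = probmap[row]
        # probability maps are stored as 8-bit images, hence the division by 255
        if np.max(line) / 255 > thr:
            coordinate[i] = np.argmax(line) + 1
    if np.sum(coordinate > 0) < 2:
        coordinate = np.zeros(pts)
    return coordinate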


def prob2lines(prob_dir, out_dir, list_file, cfg = None):
    """Write CULane `.lines.txt` files from saved probability maps.

    For each image in `list_file`, `<name>.exist.txt` under `prob_dir` flags which
    of the four lanes exist, and `<name>_{k}_avg.png` holds the map for lane k.
    """
    lists = pd.read_csv(list_file, sep=' ', header=None,
                        names=('img', 'probmap', 'label1', 'label2', 'label3', 'label4'))
    pts = 18  # rows y = 589, 569, ..., 249 of the 590-px-high CULane image

    for im in lists['img']:
        existPath = prob_dir + im[:-4] + '.exist.txt'
        outname = out_dir + im[:-4] + '.lines.txt'
        os.makedirs(os.path.dirname(outname), exist_ok=True)

        labels = list(pd.read_csv(existPath, sep=' ', header=None).iloc[0])
        coordinates = np.zeros((4, pts))
        with open(outname, 'w') as f:
            for i in range(4):
                if labels[i] != 1:
                    continue
                probfile = prob_dir + im[:-4] + '_{0}_avg.png'.format(i+1)
                probmap = np.array(Image.open(probfile))
                coordinates[i] = getLane(probmap, pts, cfg)

                # keep lanes with at least two points; map x back to the 1640-px width
                if np.sum(coordinates[i] > 0) > 1:
                    for idx, value in enumerate(coordinates[i]):
                        if value > 0:
                            f.write('%d %d ' % (
                                round(value*1640/cfg.img_width)-1, round(590-idx*20)-1))
                    f.write('\n')
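
The row and column conversions in `getLane` and `prob2lines` are easy to get wrong. The sketch below isolates them as two hypothetical helpers, `culane_row_index` and `coords_to_culane_line`, which use the same formulas. The values in the usage note (`img_height=288`, `cut_height=240`, `img_width=800`) are illustrative only; they are not taken from the repository configs.

```python
def culane_row_index(i, img_height, cut_height=0):
    """Hypothetical helper: 0-based probability-map row for CULane sample i.

    Sample i corresponds to y = 590 - 20 * i. Rows above `cut_height` were
    cropped away and the remaining 590 - cut_height rows resized to img_height.
    """
    return round(img_height - i * 20 / (590 - cut_height) * img_height) - 1


def coords_to_culane_line(coords, img_width):
    """Hypothetical helper: format 1-based x samples (0 = missing) as one line."""
    parts = []
    for idx, value in enumerate(coords):
        if value > 0:
            x = round(value * 1640 / img_width) - 1  # back to the 1640-px width
            y = round(590 - idx * 20) - 1            # 0-based CULane row
            parts.append('%d %d ' % (x, y))
    return ''.join(parts)
```

With the illustrative values, the bottom sample (i = 0) reads row 287 of the probability map and the top sample (i = 17) reads row 7.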


================================================
FILE: runner/evaluator/tusimple/getLane.py
================================================
import cv2
import numpy as np

def isShort(lane):
    """Return True if the lane has no positive x coordinate at all."""
    return not any(x > 0 for x in lane)

def fixGap(coordinate):
    """Fill interior gaps (negative x) of a lane by linear interpolation.

    Only gaps between the first and last positive entry are filled, and the
    interpolated values are truncated with int(). `coordinate` is modified in
    place and also returned.
    """
    if any(x > 0 for x in coordinate):
        start = [i for i, x in enumerate(coordinate) if x > 0][0]
        end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
        lane = coordinate[start:end+1]
        if any(x < 0 for x in lane):
            # gap_start: last valid index before a gap; gap_end: first valid index after it
            gap_start = [i for i, x in enumerate(
                lane[:-1]) if x > 0 and lane[i+1] < 0]
            gap_end = [i+1 for i,
                       x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0]
            gap_id = [i for i, x in enumerate(lane) if x < 0]
            if len(gap_start) == 0 or len(gap_end) == 0:
                return coordinate
            for idx in gap_id:
                for i in range(len(gap_start)):
                    if i >= len(gap_end):
                        return coordinate
                    if idx > gap_start[i] and idx < gap_end[i]:
                        gap_width = float(gap_end[i] - gap_start[i])
                        lane[idx] = int((idx - gap_start[i]) / gap_width * lane[gap_end[i]] + (
                            gap_end[i] - idx) / gap_width * lane[gap_start[i]])
            if not all(x > 0 for x in lane):
                print("Gaps still exist!")
            coordinate[start:end+1] = lane
    return coordinate

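def getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape=None, cfg=None):
    """
    Arguments:
    ----------
    prob_map:      prob map for a single lane, np array of size (h, w)
    y_px_gap:      vertical gap in pixels between sampled rows
    pts:           number of rows to sample
    thresh:        minimum probability for a row to yield a point
    resize_shape:  target size (H, W) of the output coordinates
    cfg:           config providing cut_height (required)

    Return:
    ----------
    coords: x coords bottom up every y_px_gap px in the resized shape,
            -1 where no point passes `thresh`; all zeros if fewer than two points are found
    """
    if resize_shape is None:
        resize_shape = prob_map.shape
    h, w = prob_map.shape
    H, W = resize_shape
    H -= cfg.cut_height  # the network input is cropped by cut_height rows

    coords = np.zeros(pts)
    coords[:] = -1.0
    for i in range(pts):
        # row H - 10 - i * y_px_gap of the cropped image, scaled to prob_map height
        y = int((H - 10 - i * y_px_gap) * h / H)
        if y < 0:
            break
        line = prob_map[y, :]
        x_idx = np.argmax(line)
        if line[x_idx] > thresh:
            coords[i] = int(x_idx / w * W)
    if (coords > 0).sum() < 2:
        coords = np.zeros(pts)
    fixGap(coords)  # fills interior gaps in place
    return coords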


def prob2lines_tusimple(seg_pred, exist, resize_shape=None, smooth=True, y_px_gap=10, pts=None, thresh=0.3, cfg=None):
    """
    Arguments:
    ----------
    seg_pred:      np.array of size (num_classes, h, w); channel 0 is the background
    exist:         list of existence flags, e.g. [0, 1, 1, 0] (not used by this function)
    resize_shape:  target size (H, W) of the output coordinates
    smooth:        whether to smooth each probability map with a 9x9 box blur first
    y_px_gap:      y pixel gap for sampling
    pts:           how many points for one lane (default: round(H / 2 / y_px_gap))
    thresh:        probability threshold
    cfg:           config providing num_classes and cut_height (required)

    Return:
    ----------
    coordinates: [x, y] list of lanes, x = -1 where a lane has no point,
                 e.g.: [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ]
    """
    if resize_shape is None:
        resize_shape = seg_pred.shape[1:]  # seg_pred (C, h, w)
    _, h, w = seg_pred.shape
    H, W = resize_shape
    coordinates = []

    if pts is None:
        pts = round(H / 2 / y_px_gap)

    seg_pred = np.ascontiguousarray(np.transpose(seg_pred, (1, 2, 0)))
    for i in range(cfg.num_classes - 1):
        prob_map = seg_pred[..., i + 1]
        if smooth:
            prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
        coords = getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape, cfg)
        if isShort(coords):
            continue
        coordinates.append(
            [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
             range(pts)])

    if len(coordinates) == 0:
        # No lane found: emit one placeholder lane whose x values are all -1
        coordinates.append([[-1, H - 10 - j * y_px_gap] for j in range(pts)])

    return coordinates
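
`fixGap` fills interior holes with hand-rolled index bookkeeping. The sketch below is a shorter formulation of the same idea, built on `np.interp`; the name `fill_interior_gaps` is hypothetical. It interpolates every negative x that lies between the first and last positive x, and it leaves leading and trailing -1 values alone. Unlike `fixGap`, it keeps float results instead of truncating with `int()`, so filled values can differ from `fixGap`'s by less than one pixel.

```python
import numpy as np


def fill_interior_gaps(coords):
    """Hypothetical helper: interpolate negative x values inside a lane.

    Only entries between the first and last positive x are filled; the input
    is not modified and a new float array is returned.
    """
    coords = np.asarray(coords, dtype=float).copy()
    valid = np.flatnonzero(coords > 0)
    if valid.size < 2:
        return coords
    span = np.arange(valid[0], valid[-1] + 1)
    missing = span[coords[span] < 0]
    coords[missing] = np.interp(missing, valid, coords[valid])
    return coords
```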


================================================
FILE: runner/evaluator/tusimple/lane.py
================================================
import numpy as np
from sklearn.linear_model import LinearRegression
import json


class LaneEval(object):
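    lr = LinearRegression()
    # max |x_pred - x_gt| in pixels for a point to count as correct, divided by cos(lane angle)
    pixel_thresh = 20
    # min fraction of correct points for a ground-truth lane to count as matched
    pt_thresh = 0.85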

    @staticmethod
    def get_angle(xs, y_samples):
        xs, ys = xs[xs >= 0], y_samples[xs >= 0]
        if len(xs) > 1:
            LaneEval.lr.fit(ys[:, None], xs)
            k = LaneEval.lr.coef_[0]
            theta = np.arctan(k)
        else:
            theta = 0
        return theta

    @staticmethod
    def line_accuracy(pred, gt, thresh):
        pred = np.array([p if p >= 0 else -100 for p in pred])
        gt = np.array([g if g >= 0 else -100 for g in gt])
        return np.sum(np.where(np.abs(pred - gt) < thresh, 1., 0.)) / len(gt)

    @staticmethod
    def bench(pred, gt, y_samples, running_time):
        if any(len(p) != len(y_samples) for p in pred):
            raise Exception('Format of lanes error.')
        if running_time > 200 or len(gt) + 2 < len(pred):
            return 0., 0., 1.
        angles = [LaneEval.get_angle(
            np.array(x_gts), np.array(y_samples)) for x_gts in gt]
        threshs = [LaneEval.pixel_thresh / np.cos(angle) for angle in angles]
        line_accs = []
        fp, fn = 0., 0.
        matched = 0.
        for x_gts, thresh in zip(gt, threshs):
            accs = [LaneEval.line_accuracy(
                np.array(x_preds), np.array(x_gts), thresh) for x_preds in pred]
            max_acc = np.max(accs) if len(accs) > 0 else 0.
            if max_acc < LaneEval.pt_thresh:
                fn += 1
            else:
                matched += 1
            line_accs.append(max_acc)
        fp = len(pred) - matched
        if len(gt) > 4 and fn > 0:
            fn -= 1
        s = sum(line_accs)
        if len(gt) > 4:
            s -= min(line_accs)
        return s / max(min(4.0, len(gt)), 1.), fp / len(pred) if len(pred) > 0 else 0., fn / max(min(len(gt), 4.), 1.)

    @staticmethod
    def bench_one_submit(pred_file, gt_file):
        try:
            with open(pred_file) as f:
                json_pred = [json.loads(line) for line in f]
        except BaseException as e:
            raise Exception('Fail to load json file of the prediction.') from e
        with open(gt_file) as f:
            json_gt = [json.loads(line) for line in f]
        if len(json_gt) != len(json_pred):
            raise Exception(
                'We do not get the predictions of all the test tasks')
        gts = {l['raw_file']: l for l in json_gt}
        accuracy, fp, fn = 0., 0., 0.
        for pred in json_pred:
            if 'raw_file' not in pred or 'lanes' not in pred or 'run_time' not in pred:
                raise Exception(
                    'raw_file or lanes or run_time not in some predictions.')
            raw_file = pred['raw_file']
            pred_lanes = pred['lanes']
            run_time = pred['run_time']
            if raw_file not in gts:
                raise Exception(
                    'Some raw_file from your predictions do not exist in the test tasks.')
            gt = gts[raw_file]
            gt_lanes = gt['lanes']
            y_samples = gt['h_samples']
            try:
                a, p, n = LaneEval.bench(
                    pred_lanes, gt_lanes, y_samples, run_time)
            except BaseException as e:
                raise Exception('Format of lanes error.')
            accuracy += a
            fp += p
            fn += n
        num = len(gts)
        # the first return parameter is the default ranking parameter
        return json.dumps([
            {'name': 'Accuracy', 'value': accuracy / num, 'order': 'desc'},
            {'name': 'FP', 'value': fp / num, 'order': 'asc'},
            {'name': 'FN', 'value': fn / num, 'order': 'asc'}
        ]), accuracy / num


if __name__ == '__main__':
    import sys
    try:
        if len(sys.argv) != 3:
            raise Exception('Invalid input arguments')
        print(LaneEval.bench_one_submit(sys.argv[1], sys.argv[2]))
    except Exception as e:
        # Python 3 exceptions have no `.message` attribute; use str(e)
        print(e)
        sys.exit(str(e))
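
The TuSimple metric in `LaneEval.bench` rests on two steps:

- an angle-adjusted threshold, `pixel_thresh / cos(theta)`, computed per ground-truth lane;
- a per-row hit rate.

The sketch below reproduces both with NumPy only. It swaps scikit-learn's `LinearRegression` for a degree-1 `np.polyfit`, which gives the same least-squares slope of x on y. The function names are hypothetical.

```python
import numpy as np

PIXEL_THRESH = 20  # same value as LaneEval.pixel_thresh


def lane_threshold(xs, ys, pixel_thresh=PIXEL_THRESH):
    """Hypothetical helper: angle-adjusted pixel threshold for one GT lane."""
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    mask = xs >= 0
    if mask.sum() > 1:
        # least-squares slope dx/dy, like LaneEval's LinearRegression of x on y
        k = np.polyfit(ys[mask], xs[mask], 1)[0]
        theta = np.arctan(k)
    else:
        theta = 0.0
    return pixel_thresh / np.cos(theta)


def line_accuracy(pred, gt, thresh):
    """Hypothetical helper: fraction of rows with |pred - gt| < thresh.

    Missing points (x < 0) are mapped to -100, as in LaneEval.line_accuracy.
    """
    pred = np.where(np.asarray(pred) >= 0, pred, -100)
    gt = np.where(np.asarray(gt) >= 0, gt, -100)
    return float(np.mean(np.abs(pred - gt) < thresh))
```

A row that is missing in both the prediction and the ground truth (both mapped to -100) counts as a hit, just as it does in `line_accuracy`.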


================================================
FILE: runner/evaluator/tusimple/tusimple.py
================================================
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.logger import get_logger

from runner.registry import EVALUATOR 
import json
import os
import cv2

from .lane import LaneEval

def split_path(path):
    """split path tree into list"""
    folders = []
    while True:
        path, folder = os.path.split(path)
        if folder != "":
            folders.insert(0, folder)
        else:
            if path != "":
                folders.insert(0, path)
            break
    return folders


@EVALUATOR.register_module
class Tusimple(nn.Module):
    def __init__(self, cfg):
        super(Tusimple, self).__init__()
        self.cfg = cfg 
        exp_dir = os.path.join(self.cfg.work_dir, "output")
        self.out_path = os.path.join(exp_dir, "coord_output")
        os.makedirs(self.out_path, exist_ok=True)
        self.dump_to_json = [] 
        self.thresh = cfg.evaluator.thresh
        self.logger = get_logger('resa')
        if cfg.view:
            self.view_dir = os.path.join(self.cfg.work_dir, 'vis')

    def evaluate_pred(self, dataset, seg_pred, exist_pred, batch):
        img_name = batch['meta']['img_name']
        img_path = batch['meta']['full_img_path']
        for b in range(len(seg_pred)):
            seg = seg_pred[b]
            exist = [1 if exist_pred[b, i] >
                     0.5 else 0 for i in range(self.cfg.num_classes-1)]
            lane_coords = dataset.probmap2lane(seg, exist, thresh = self.thresh)
            for i in range(len(lane_coords)):
                lane_coords[i] = sorted(
                    lane_coords[i], key=lambda pair: pair[1])

            path_tree = split_path(img_name[b])
            save_dir, save_name = path_tree[-3:-1], path_tree[-1]
            save_dir = os.path.join(self.out_path, *save_dir)
            # e.g. "20.jpg" -> "20.lines.txt"
            save_name = save_name[:-3] + "lines.txt"
            save_name = os.path.join(save_dir, save_name)
            os.makedirs(save_dir, exist_ok=True)

            with open(save_name, "w") as f:
                for l in lane_coords:
                    for (x, y) in l:
                        print("{} {}".format(x, y), end=" ", file=f)
                    print(file=f)

            json_dict = {}
            json_dict['lanes'] = []
            json_dict['h_samples'] = []
            json_dict['raw_file'] = os.path.join(*path_tree[-4:])
            json_dict['run_time'] = 0
            for l in lane_coords:
                if len(l) == 0:
                    continue
                json_dict['lanes'].append([])
                for (x, y) in l:
                    json_dict['lanes'][-1].append(int(x))
            # all lanes share the same y samples, so take them from the first lane
            for (x, y) in lane_coords[0]:
                json_dict['h_samples'].append(y)
            if self.cfg.view:
                img = cv2.imread(img_path[b])
                new_img_name = img_name[b].replace('/', '_')
                save_dir = os.path.join(self.view_dir, new_img_name)
                dataset.view(img, lane_coords, save_dir)


    def evaluate(self, dataset, output, batch):
        seg_pred, exist_pred = output['seg'], output['exist']
        seg_pred = F.softmax(seg_pred, dim=1)
        seg_pred = seg_pred.detach().cpu().numpy()
        exist_pred = exist_pred.detach().cpu().numpy()
        self.evaluate_pred(dataset, seg_pred, exist_pred, batch)

    def summarize(self):
        output_file = os.path.join(self.out_path, 'predict_test.json')
        with open(output_file, "w+") as f:
            for line in self.dump_to_json:
                print(line, end="\n", file=f)

        eval_result, acc = LaneEval.bench_one_submit(output_file,
                            self.cfg.test_json_file)

        self.logger.info(eval_result)
        self.dump_to_json = []
        # mean TuSimple accuracy; Runner.validate keeps track of the best value
        return acc
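
`Tusimple.evaluate_pred` writes one JSON object per image. `LaneEval.bench_one_submit` reads only the `raw_file`, `lanes` and `run_time` keys of each prediction; the y samples come from the ground truth. The sketch below builds such a line, assuming POSIX-style `/` separators in image paths; `to_tusimple_json_line` is a hypothetical helper.

```python
import json


def to_tusimple_json_line(img_name, lane_coords, run_time=0):
    """Hypothetical helper: one prediction line for LaneEval.bench_one_submit.

    lane_coords: list of lanes, each a list of (x, y) pairs with x = -1 where
    absent. raw_file keeps the last four path components, as evaluate_pred does.
    """
    parts = [p for p in img_name.split('/') if p]
    raw_file = '/'.join(parts[-4:])
    lanes = []
    for lane in lane_coords:
        if not lane:
            continue
        lane = sorted(lane, key=lambda pair: pair[1])  # top to bottom
        lanes.append([int(x) for x, _ in lane])
    return json.dumps({'raw_file': raw_file, 'lanes': lanes, 'run_time': run_time})
```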


================================================
FILE: runner/logger.py
================================================
import logging

logger_initialized = {}

def get_logger(name, log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified, a FileHandler will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            (opened in 'w' mode, i.e. truncating the file) will be added.
        log_level (int): Level applied to the logger and to its handlers.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
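    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a" (its records propagate to "a").
    # Matching on "a." avoids treating an unrelated name such as "ab" as a child.
    for logger_name in logger_initialized:
        if name.startswith(logger_name + '.'):
            return logger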

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)

    logger_initialized[name] = True

    return logger
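
The initialize-once pattern of `get_logger` can be sketched with the standard `logging` module alone. The hypothetical `get_logger_sketch` below treats a name as already initialized only if it equals a registered name or is a dotted child of one. As a result, `demo_resa_v2` is not mistaken for a child of `demo_resa`.

```python
import logging

_initialized = set()


def get_logger_sketch(name, log_file=None, level=logging.INFO):
    """Hypothetical helper mirroring get_logger's initialize-once behaviour."""
    logger = logging.getLogger(name)
    # Reuse a logger that is set up, or a dotted child of one ("a.b" under "a")
    if any(name == n or name.startswith(n + '.') for n in _initialized):
        return logger
    handlers = [logging.StreamHandler()]
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file, 'w'))
    fmt = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(fmt)
        handler.setLevel(level)
        logger.addHandler(handler)
    logger.setLevel(level)
    _initialized.add(name)
    return logger
```

Child loggers such as `demo_resa.eval` get no handlers of their own. Their records propagate to the handlers of `demo_resa`.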


================================================
FILE: runner/net_utils.py
================================================
import torch
import os
from torch import nn
import numpy as np
import torch.nn.functional
from termcolor import colored
from .logger import get_logger

def save_model(net, optim, scheduler, recorder, is_best=False):
    model_dir = os.path.join(recorder.work_dir, 'ckpt')
    os.makedirs(model_dir, exist_ok=True)
    epoch = recorder.epoch
    ckpt_name = 'best' if is_best else epoch
    torch.save({
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }, os.path.join(model_dir, '{}.pth'.format(ckpt_name)))


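def load_network_specified(net, model_dir, logger=None):
    """Load only the checkpoint weights whose name and shape match `net`.

    `model_dir` is the path of a checkpoint file saved by `save_model`;
    unknown or shape-mismatched keys are skipped (and logged if a logger is given).
    """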
    pretrained_net = torch.load(model_dir)['net']
    net_state = net.state_dict()
    state = {}
    for k, v in pretrained_net.items():
        if k not in net_state.keys() or v.size() != net_state[k].size():
            if logger:
                logger.info('skip weights: ' + k)
            continue
        state[k] = v
    net.load_state_dict(state, strict=False)


def load_network(net, model_dir, finetune_from=None, logger=None):
    if finetune_from:
        if logger:
            logger.info('Finetune model from: ' + finetune_from)
        load_network_specified(net, finetune_from, logger)
        return
    pretrained_model = torch.load(model_dir)
    net.load_state_dict(pretrained_model['net'], strict=True)
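
`load_network_specified` keeps a pretrained weight only when its key exists in the target network and the shapes agree. This lets a checkpoint be fine-tuned even when, for example, its output head has a different size. The sketch below shows the filtering rule on plain dictionaries. `filter_matching_weights` is hypothetical, and NumPy arrays stand in for tensors because both expose `.shape`.

```python
import numpy as np


def filter_matching_weights(pretrained, current):
    """Hypothetical helper: split a state dict into loadable and skipped keys."""
    kept, skipped = {}, []
    for key, value in pretrained.items():
        if key not in current or tuple(value.shape) != tuple(current[key].shape):
            skipped.append(key)
            continue
        kept[key] = value
    return kept, skipped
```

With PyTorch, the `kept` dictionary would then be passed to `net.load_state_dict(kept, strict=False)`, as the original function does.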


================================================
FILE: runner/optimizer.py
================================================
import torch


_optimizer_factory = {
    'adam': torch.optim.Adam,
    'sgd': torch.optim.SGD
}


def build_optimizer(cfg, net):
    params = []
    lr = cfg.optimizer.lr
    weight_decay = cfg.optimizer.weight_decay

    for key, value in net.named_parameters():
        if not value.requires_grad:
            continue
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    if 'adam' in cfg.optimizer.type:
        optimizer = _optimizer_factory[cfg.optimizer.type](params, lr, weight_decay=weight_decay)
    else:
        optimizer = _optimizer_factory[cfg.optimizer.type](
                params, lr, weight_decay=weight_decay, momentum=cfg.optimizer.momentum)

    return optimizer


================================================
FILE: runner/recorder.py
================================================
from collections import deque, defaultdict
import torch
import os
import datetime
from .logger import get_logger


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque))
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count


class Recorder(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.work_dir = self.get_work_dir()
        cfg.work_dir = self.work_dir
        self.log_path = os.path.join(self.work_dir, 'log.txt')

        self.logger = get_logger('resa', self.log_path)
        self.logger.info('Config: \n' + cfg.text)

        # scalars
        self.epoch = 0
        self.step = 0
        self.loss_stats = defaultdict(SmoothedValue)
        self.batch_time = SmoothedValue()
        self.data_time = SmoothedValue()
        self.max_iter = self.cfg.total_iter 
        self.lr = 0.

    def get_work_dir(self):
        now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        hyper_param_str = '_lr_%1.0e_b_%d' % (self.cfg.optimizer.lr, self.cfg.batch_size)
        work_dir = os.path.join(self.cfg.work_dirs, now + hyper_param_str)
        if not os.path.exists(work_dir):
            os.makedirs(work_dir)
        return work_dir

    def update_loss_stats(self, loss_dict):
        for k, v in loss_dict.items():
            self.loss_stats[k].update(v.detach().cpu())

    def record(self, prefix, step=-1, loss_stats=None, image_stats=None):
        self.logger.info(self)
        # self.write(str(self))

    def write(self, content):
        with open(self.log_path, 'a+') as f:
            f.write(content)
            f.write('\n')

    def state_dict(self):
        scalar_dict = {}
        scalar_dict['step'] = self.step
        return scalar_dict

    def load_state_dict(self, scalar_dict):
        self.step = scalar_dict['step']

    def __str__(self):
        loss_state = []
        for k, v in self.loss_stats.items():
            loss_state.append('{}: {:.4f}'.format(k, v.avg))
        loss_state = '  '.join(loss_state)

        recording_state = '  '.join(['epoch: {}', 'step: {}', 'lr: {:.4f}', '{}', 'data: {:.4f}', 'batch: {:.4f}', 'eta: {}'])
        eta_seconds = self.batch_time.global_avg * (self.max_iter - self.step)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        return recording_state.format(self.epoch, self.step, self.lr, loss_state, self.data_time.avg, self.batch_time.avg, eta_string)


def build_recorder(cfg):
    return Recorder(cfg)
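
`Recorder` logs windowed statistics from `SmoothedValue` and estimates the remaining time as `batch_time.global_avg * (max_iter - step)`. Below is a dependency-free sketch of both that uses `statistics` instead of torch. It uses `statistics.median_low` because `torch.median` returns the lower of the two middle values for even-sized inputs. The names are hypothetical.

```python
import datetime
import statistics
from collections import deque


class SmoothedValueSketch:
    """Hypothetical stand-in for SmoothedValue: window stats plus global average."""

    def __init__(self, window_size=20):
        self.window = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.window.append(value)
        self.total += value
        self.count += 1

    @property
    def avg(self):
        return statistics.fmean(self.window)

    @property
    def median(self):
        return statistics.median_low(self.window)

    @property
    def global_avg(self):
        return self.total / self.count


def eta_string(batch_time, step, max_iter):
    """Remaining-time estimate in the same form as Recorder.__str__."""
    seconds = batch_time.global_avg * (max_iter - step)
    return str(datetime.timedelta(seconds=int(seconds)))
```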



================================================
FILE: runner/registry.py
================================================
from torch import nn

from utils import Registry, build_from_cfg

TRAINER = Registry('trainer')
EVALUATOR = Registry('evaluator')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)

def build_trainer(cfg):
    return build(cfg.trainer, TRAINER, default_args=dict(cfg=cfg))

def build_evaluator(cfg):
    return build(cfg.evaluator, EVALUATOR, default_args=dict(cfg=cfg))
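
`build` delegates to `Registry` and `build_from_cfg` from `utils`, which this section does not show. The sketch below is a hypothetical minimal version of that pattern. It assumes that `build_from_cfg` pops the `type` key, looks the class up by name, and fills in `default_args` (such as `cfg=cfg`) for keys the config does not set. The real `utils` implementation may differ.

```python
class Registry:
    """Hypothetical minimal stand-in for utils.Registry: class name -> class."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, cls):
        # Used as a bare decorator, e.g. @TRAINER.register_module
        if cls.__name__ in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(cls.__name__, self.name))
        self._module_dict[cls.__name__] = cls
        return cls

    def get(self, key):
        return self._module_dict[key]


def build_from_cfg(cfg, registry, default_args=None):
    """Hypothetical: instantiate registry[cfg['type']] with the other keys as kwargs."""
    args = dict(cfg)
    obj_cls = registry.get(args.pop('type'))
    for name, value in (default_args or {}).items():
        args.setdefault(name, value)
    return obj_cls(**args)
```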


================================================
FILE: runner/resa_trainer.py
================================================
import torch.nn as nn
import torch
import torch.nn.functional as F

from runner.registry import TRAINER

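def dice_loss(input, target):
    """Soft Dice loss averaged over the batch.

    `input` and `target` are flattened per sample; 0.001 is added to each
    denominator term for numerical stability. Returns 1 - mean Dice coefficient.
    """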
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return (1-d).mean()

@TRAINER.register_module
class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.cfg = cfg
        self.loss_type = cfg.loss_type
        if self.loss_type == 'cross_entropy':
            weights = torch.ones(cfg.num_classes)
            weights[0] = cfg.bg_weight
            weights = weights.cuda()
            self.criterion = torch.nn.NLLLoss(ignore_index=self.cfg.ignore_label,
                                              weight=weights).cuda()
        elif self.loss_type != 'dice_loss':
            raise ValueError('Unsupported loss_type: {}'.format(self.loss_type))

        self.criterion_exist = torch.nn.BCEWithLogitsLoss().cuda()

    def forward(self, net, batch):
        output = net(batch['img'])

        loss_stats = {}
        loss = 0.

        if self.loss_type == 'dice_loss':
            target = F.one_hot(batch['label'], num_classes=self.cfg.num_classes).permute(0, 3, 1, 2)
            seg_loss = dice_loss(F.softmax(
                output['seg'], dim=1)[:, 1:], target[:, 1:])
        else:
            seg_loss = self.criterion(F.log_softmax(
                output['seg'], dim=1), batch['label'].long())

        loss += seg_loss * self.cfg.seg_loss_weight

        loss_stats.update({'seg_loss': seg_loss})

        if 'exist' in output:
            exist_loss = 0.1 * \
                self.criterion_exist(output['exist'], batch['exist'].float())
            loss += exist_loss
            loss_stats.update({'exist_loss': exist_loss})

        ret = {'loss': loss, 'loss_stats': loss_stats}

        return ret
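
`dice_loss` compares softmax probabilities with one-hot targets. Both have the background channel dropped and are flattened per sample. Below is a NumPy sketch of the same computation, including the one-hot and permute step from `RESA.forward`; the function names are hypothetical.

```python
import numpy as np


def one_hot_lanes(label, num_classes):
    """Hypothetical helper: (N, H, W) int labels -> (N, C-1, H, W), background dropped."""
    onehot = np.eye(num_classes, dtype=float)[label]  # (N, H, W, C)
    return onehot.transpose(0, 3, 1, 2)[:, 1:]


def dice_loss_np(probs, target, eps=0.001):
    """Hypothetical NumPy version of dice_loss: 1 - mean soft Dice over the batch."""
    n = probs.shape[0]
    p = probs.reshape(n, -1).astype(float)
    t = target.reshape(n, -1).astype(float)
    inter = (p * t).sum(axis=1)
    dice = 2 * inter / (((p * p).sum(axis=1) + eps) + ((t * t).sum(axis=1) + eps))
    return float((1 - dice).mean())
```

A perfect prediction gives a loss just above 0 because of the `eps` terms, and a completely disjoint one gives exactly 1.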


================================================
FILE: runner/runner.py
================================================
import time
import torch
import numpy as np
from tqdm import tqdm
import pytorch_warmup as warmup

from models.registry import build_net
from .registry import build_trainer, build_evaluator
from .optimizer import build_optimizer
from .scheduler import build_scheduler
from datasets import build_dataloader
from .recorder import build_recorder
from .net_utils import save_model, load_network


class Runner(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.recorder = build_recorder(self.cfg)
        self.net = build_net(self.cfg)
        self.net = torch.nn.parallel.DataParallel(
                self.net, device_ids = range(self.cfg.gpus)).cuda()
        self.recorder.logger.info('Network: \n' + str(self.net))
        self.resume()
        self.optimizer = build_optimizer(self.cfg, self.net)
        self.scheduler = build_scheduler(self.cfg, self.optimizer)
        self.evaluator = build_evaluator(self.cfg)
        self.warmup_scheduler = warmup.LinearWarmup(
            self.optimizer, warmup_period=5000)
        self.metric = 0.

    def resume(self):
        if not self.cfg.load_from and not self.cfg.finetune_from:
            return
        load_network(self.net, self.cfg.load_from,
                finetune_from=self.cfg.finetune_from, logger=self.recorder.logger)

    def to_cuda(self, batch):
        for k in batch:
            if k == 'meta':
                continue
            batch[k] = batch[k].cuda()
        return batch
    
    def train_epoch(self, epoch, train_loader):
        self.net.train()
        end = time.time()
        max_iter = len(train_loader)
        for i, data in enumerate(train_loader):
            if self.recorder.step >= self.cfg.total_iter:
                break
            data_time = time.time() - end
            self.recorder.step += 1
            data = self.to_cuda(data)
            output = self.trainer.forward(self.net, data)
            self.optimizer.zero_grad()
            loss = output['loss']
            loss.backward()
            self.optimizer.step()
            # LambdaLR sets lr = base_lr * lr_lambda(step); dampen() then scales it
            # by the linear warm-up factor during the first warmup_period steps.
            self.scheduler.step()
            self.warmup_scheduler.dampen()
            batch_time = time.time() - end
            end = time.time()
            self.recorder.update_loss_stats(output['loss_stats'])
            self.recorder.batch_time.update(batch_time)
            self.recorder.data_time.update(data_time)

            if i % self.cfg.log_interval == 0 or i == max_iter - 1:
                lr = self.optimizer.param_groups[0]['lr']
                self.recorder.lr = lr
                self.recorder.record('train')

    def train(self):
        self.recorder.logger.info('start training...')
        self.trainer = build_trainer(self.cfg)
        train_loader = build_dataloader(self.cfg.dataset.train, self.cfg, is_train=True)
        val_loader = build_dataloader(self.cfg.dataset.val, self.cfg, is_train=False)

        for epoch in range(self.cfg.epochs):
            self.recorder.epoch = epoch
            self.train_epoch(epoch, train_loader)
            if (epoch + 1) % self.cfg.save_ep == 0 or epoch == self.cfg.epochs - 1:
                self.save_ckpt()
            if (epoch + 1) % self.cfg.eval_ep == 0 or epoch == self.cfg.epochs - 1:
                self.validate(val_loader)
            if self.recorder.step >= self.cfg.total_iter:
                break

    def validate(self, val_loader):
        self.net.eval()
        for i, data in enumerate(tqdm(val_loader, desc='Validate')):
            data = self.to_cuda(data)
            with torch.no_grad():
                output = self.net(data['img'])
                self.evaluator.evaluate(val_loader.dataset, output, data)

        metric = self.evaluator.summarize()
        if not metric:
            return
        if metric > self.metric:
            self.metric = metric
            self.save_ckpt(is_best=True)
        self.recorder.logger.info('Best metric: ' + str(self.metric))

    def save_ckpt(self, is_best=False):
        save_model(self.net, self.optimizer, self.scheduler,
                self.recorder, is_best)
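The save/validate cadence in `Runner.train()` depends only on three integers, so it can be checked without a GPU. The sketch below mirrors the two `if` conditions in the epoch loop; `schedule_epochs` is a hypothetical helper (not part of the repository), and it ignores the early exit on `total_iter`.

```python
def schedule_epochs(epochs, save_ep, eval_ep):
    """Return (save_epochs, eval_epochs) as 0-based epoch indices.

    Mirrors Runner.train(): an action fires when (epoch + 1) is a multiple
    of its interval, and always on the last epoch.
    """
    save, evaluate = [], []
    for epoch in range(epochs):
        last = epoch == epochs - 1
        if (epoch + 1) % save_ep == 0 or last:
            save.append(epoch)
        if (epoch + 1) % eval_ep == 0 or last:
            evaluate.append(epoch)
    return save, evaluate


if __name__ == "__main__":
    saves, evals = schedule_epochs(epochs=12, save_ep=5, eval_ep=3)
    print(saves)  # [4, 9, 11]
    print(evals)  # [2, 5, 8, 11]
```

With `save_ep=5` and `epochs=12`, the final epoch is still saved even though 12 is not a multiple of 5; the `epoch == self.cfg.epochs - 1` clause guarantees that.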


================================================
FILE: runner/scheduler.py
================================================
import torch
import math


_scheduler_factory = {
    'LambdaLR': torch.optim.lr_scheduler.LambdaLR,
}


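def build_scheduler(cfg, optimizer):
    # 'type' selects the scheduler class; every other key in cfg.scheduler
    # is forwarded to its constructor as a keyword argument.
    assert cfg.scheduler.type in _scheduler_factory

    cfg_cp = cfg.scheduler.copy()
    cfg_cp.pop('type')

    scheduler = _scheduler_factory[cfg.scheduler.type](optimizer, **cfg_cp)
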
    return scheduler 
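`build_scheduler` looks up `cfg.scheduler.type` and forwards every other key to the scheduler's constructor, so a `LambdaLR` config supplies its `lr_lambda` directly. Because `train_epoch` calls `self.scheduler.step()` once per batch, the lambda receives an iteration count rather than an epoch count. The sketch below assumes a "poly" decay (a common choice for segmentation; the real lambda lives in the config files, which are not shown here) and uses a hypothetical `ToyLambdaLR` stand-in so it runs without PyTorch. Like `torch.optim.lr_scheduler.LambdaLR`, it sets `lr = base_lr * lr_lambda(step)`.

```python
import math


class ToyLambdaLR:
    """Hypothetical stand-in for torch.optim.lr_scheduler.LambdaLR.

    Tracks a single learning rate: lr = base_lr * lr_lambda(step_count).
    """

    def __init__(self, base_lr, lr_lambda):
        self.base_lr = base_lr
        self.lr_lambda = lr_lambda
        self.step_count = 0
        self.lr = base_lr * lr_lambda(0)

    def step(self):
        self.step_count += 1
        self.lr = self.base_lr * self.lr_lambda(self.step_count)


_toy_factory = {'LambdaLR': ToyLambdaLR}


def build_toy_scheduler(sched_cfg, base_lr):
    # Same pattern as runner/scheduler.py: 'type' selects the class and the
    # remaining keys become constructor keyword arguments.
    assert sched_cfg['type'] in _toy_factory
    kwargs = {k: v for k, v in sched_cfg.items() if k != 'type'}
    return _toy_factory[sched_cfg['type']](base_lr, **kwargs)


total_iter = 1000
# Assumed "poly" decay: the multiplier falls from 1 to 0 over total_iter steps.
sched_cfg = dict(type='LambdaLR',
                 lr_lambda=lambda it: math.pow(1 - it / total_iter, 0.9))

if __name__ == "__main__":
    sched = build_toy_scheduler(sched_cfg, base_lr=0.025)
    for _ in range(500):
        sched.step()
    print(sched.lr)  # equals 0.025 * 0.5 ** 0.9
```

With this lambda the multiplier reaches exactly 0 at `total_iter`, and `math.pow` raises `ValueError` for a negative base, so the schedule is only valid as long as training stops at `total_iter`.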


================================================
FILE: tools/generate_seg_tusimple.py
================================================
import json
import numpy as np
import cv2
import os
import argparse

TRAIN_SET = ['label_data_0313.json', 'label_data_0601.json']
VAL_SET = ['label_data_0531.json']
TRAIN_VAL_SET = TRAIN_SET + VAL_SET
TEST_SET = ['test_label.json']

def gen_label_for_json(args, image_set):
    H, W = 720, 1280
    SEG_WIDTH = 30
    out_dir = os.path.join(args.root, args.savedir)
    os.makedirs(os.path.join(out_dir, "list"), exist_ok=True)
    list_path = os.path.join(out_dir, "list", "{}_gt.txt".format(image_set))

    json_path = os.path.join(out_dir, "{}.json".format(image_set))
    # open the list file in the same `with` so it is flushed and closed
    with open(json_path) as f, open(list_path, "w") as list_f:
        for line in f:
            label = json.loads(line)
            # ---------- clean and sort lanes -------------
            lanes = []
            _lanes = []
            slope = [] # identify 0th, 1st, 2nd, 3rd, 4th, 5th lane through slope
            for i in range(len(label['lanes'])):
                l = [(x, y) for x, y in zip(label['lanes'][i], label['h_samples']) if x >= 0]
                if (len(l)>1):
                    _lanes.append(l)
                    slope.append(np.arctan2(l[-1][1]-l[0][1], l[0][0]-l[-1][0]) / np.pi * 180)
            _lanes = [_lanes[i] for i in np.argsort(slope)]
            slope = [slope[i] for i in np.argsort(slope)]

            idx = [None for i in range(6)]
            for i in range(len(slope)):
                if slope[i] <= 90:
                    idx[2] = i
                    idx[1] = i-1 if i > 0 else None
                    idx[0] = i-2 if i > 1 else None
                else:
                    idx[3] = i
                    idx[4] = i+1 if i+1 < len(slope) else None
                    idx[5] = i+2 if i+2 < len(slope) else None
                    break
            for i in range(6):
                lanes.append([] if idx[i] is None else _lanes[idx[i]])

            # ---------------------------------------------

            img_path = label['raw_file']
            # pixel value i+1 marks lane slot i; 0 is background
            seg_img = np.zeros((H, W, 3), dtype=np.uint8)
            list_str = []  # str to be written to list.txt
            for i in range(len(lanes)):
                coords = lanes[i]
                if len(coords) < 4:
                    list_str.append('0')
                    continue
                for j in range(len(coords)-1):
                    cv2.line(seg_img, coords[j], coords[j+1], (i+1, i+1, i+1), SEG_WIDTH//2)
                list_str.append('1')

            seg_path = img_path.split("/")
            seg_path, img_name = os.path.join(args.root, args.savedir, seg_path[1], seg_path[2]), seg_path[3]
            os.makedirs(seg_path, exist_ok=True)
            seg_path = os.path.join(seg_path, img_name[:-3]+"png")
            cv2.imwrite(seg_path, seg_img)

            seg_path = "/".join([args.savedir, *img_path.split("/")[1:3], img_name[:-3]+"png"])
            if seg_path[0] != '/':
                seg_path = '/' + seg_path
            if img_path[0] != '/':
                img_path = '/' + img_path
            list_str.insert(0, seg_path)
            list_str.insert(0, img_path)
            list_str = " ".join(list_str) + "\n"
            list_f.write(list_str)


def generate_json_file(root, save_dir, json_file, image_set):
    # Concatenate the per-split TuSimple label files (one JSON object per
    # line) found under `root` into a single file in `save_dir`.
    with open(os.path.join(save_dir, json_file), "w") as outfile:
        for json_name in image_set:
            with open(os.path.join(root, json_name)) as infile:
                for line in infile:
                    outfile.write(line)

def generate_label(args):
    save_dir = os.path.join(args.root, args.savedir)
    os.makedirs(save_dir, exist_ok=True)
    generate_json_file(args.root, save_dir, "train_val.json", TRAIN_VAL_SET)
    generate_json_file(args.root, save_dir, "test.json", TEST_SET)

    print("generating train_val set...")
    gen_label_for_json(args, 'train_val')
    print("generating test set...")
    gen_label_for_json(args, 'test')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', required=True, help='The root of the Tusimple dataset')
    parser.add_argument('--savedir', type=str, default='seg_label', help='Output directory for the segmentation labels, relative to --root')
    args = parser.parse_args()

    generate_label(args)
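The lane-ordering block in `gen_label_for_json` sorts lanes by the angle of the line through their first (top) and last (bottom) points, then fills six slots so that the two boundaries of the ego lane land in slots 2 and 3. Below is a standalone restatement of that logic for testing; `lane_angle` and `assign_lane_slots` are hypothetical names, and the input is assumed to be lists of `(x, y)` points already filtered to `x >= 0` and ordered top to bottom, as TuSimple's `h_samples` are.

```python
import numpy as np


def lane_angle(points):
    """Angle in degrees of the line from the top point to the bottom point.

    With image y pointing down the result lies in (0, 180): below 90 the
    lane leans right as it rises (left of the ego vehicle), above 90 it
    leans left (right of the ego vehicle).
    """
    (x_top, y_top), (x_bot, y_bot) = points[0], points[-1]
    return np.arctan2(y_bot - y_top, x_top - x_bot) / np.pi * 180


def assign_lane_slots(lanes):
    """Return six slots ordered left to right; unused slots are []."""
    lanes = [l for l in lanes if len(l) > 1]  # an angle needs two points
    angles = [lane_angle(l) for l in lanes]
    order = np.argsort(angles)
    lanes = [lanes[i] for i in order]
    angles = [angles[i] for i in order]

    idx = [None] * 6
    for i, a in enumerate(angles):
        if a <= 90:
            # the last lane with angle <= 90 is the ego-left boundary
            idx[2] = i
            idx[1] = i - 1 if i > 0 else None
            idx[0] = i - 2 if i > 1 else None
        else:
            # the first lane with angle > 90 is the ego-right boundary
            idx[3] = i
            idx[4] = i + 1 if i + 1 < len(angles) else None
            idx[5] = i + 2 if i + 2 < len(angles) else None
            break
    return [[] if j is None else lanes[j] for j in idx]


if __name__ == "__main__":
    left = [(600, 300), (400, 700)]
    right = [(680, 300), (900, 700)]
    far_left = [(560, 300), (100, 700)]
    print(assign_lane_slots([right, far_left, left]))
```

In the usage example the three lanes are given out of order and end up in slots 1, 2 and 3; slots 0, 4 and 5 stay empty, which the script writes as `0` in the list file.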


================================================
FILE: utils/__init__.py
================================================
from .config import Config
from .registry import Registry, build_from_cfg


================================================
FILE: utils/config.py
================================================
# Copyright (c) Open-MMLab. All rights reserved.
import ast
import os.path as osp
import shutil
import sys
import tempfile
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module

from addict import Dict
from yapf.yapflib.yapf_api import FormatCode


BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
RESERVED_KEYS = ['filename', 'text', 'pretty_text']

def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    if not osp.isfile(filename):
        raise FileNotFoundError(msg_tmpl.format(filename))



class ConfigDict(Dict):

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            ex = AttributeError(f"'{self.__class__.__name__}' object has no "
                                f"attribute '{name}'")
        except Exception as e:
            ex = e
        else:
            return value
        raise ex


def add_args(parser, cfg, prefix=''):
    for k, v in cfg.items():
        if isinstance(v, str):
            parser.add_argument('--' + prefix + k)
        # bool must be checked before int, because bool is a subclass of int
        elif isinstance(v, bool):
            parser.add_argument('--' + prefix + k, action='store_true')
        elif isinstance(v, int):
            parser.add_argument('--' + prefix + k, type=int)
        elif isinstance(v, float):
            parser.add_argument('--' + prefix + k, type=float)
        elif isinstance(v, dict):
            add_args(parser, v, prefix + k + '.')
        elif isinstance(v, abc.Iterable):
            parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
        else:
            print(f'cannot parse key {prefix + k} of type {type(v)}')
    return parser


class Config:
    """A facility for config and config files.
    It supports common file formats as configs: python/json/yaml (json/yaml
    loading requires mmcv). The interface is the same as a dict object and
    also allows accessing config values as attributes.
    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        "/home/kchen/projects/mmcv/tests/data/config/a.py"
        >>> cfg.item4
        'test'
        >>> cfg
        "Config (path: /home/kchen/projects/mmcv/tests/data/config/a.py): "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
    """

    @staticmethod
    def _validate_py_syntax(filename):
        with open(filename) as f:
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError:
            raise SyntaxError('There are syntax errors in config '
                              f'file {filename}')

    @staticmethod
    def _file2dict(filename):
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        if filename.endswith('.py'):
            with tempfile.TemporaryDirectory() as temp_config_dir:
                temp_config_file = tempfile.NamedTemporaryFile(
                    dir=temp_config_dir, suffix='.py')
                temp_config_name = osp.basename(temp_config_file.name)
                shutil.copyfile(filename,
                                osp.join(temp_config_dir, temp_config_name))
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                Config._validate_py_syntax(filename)
                mod = import_module(temp_module_name)
                sys.path.pop(0)
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith('__')
                }
                # delete imported module
                del sys.modules[temp_module_name]
                # close temp file
                temp_config_file.close()
        elif filename.endswith(('.yml', '.yaml', '.json')):
            import mmcv
            cfg_dict = mmcv.load(filename)
        else:
            raise IOError('Only py/yml/yaml/json types are supported.')

        cfg_text = filename + '\n'
        with open(filename, 'r') as f:
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(
                base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError('Duplicate key is not allowed among bases')
                base_cfg_dict.update(c)

            base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _merge_a_into_b(a, b):
        # merge dict `a` into dict `b` (non-inplace). values in `a` will
        # overwrite `b`.
        # copy first to avoid inplace modification
        b = b.copy()
        for k, v in a.items():
            if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
                if not isinstance(b[k], dict):
                    raise TypeError(
                        f'{k}={v} in child config cannot inherit from base '
                        f'because {k} is a dict in the child config but is of '
                        f'type {type(b[k])} in base config. You may set '
                        f'`{DELETE_KEY}=True` to ignore the base config')
                b[k] = Config._merge_a_into_b(v, b[k])
            else:
                b[k] = v
        return b

    @staticmethod
    def fromfile(filename):
        cfg_dict, cfg_text = Config._file2dict(filename)
        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def auto_argparser(description=None):
        """Generate argparser from config file automatically (experimental)
        """
        partial_parser = ArgumentParser(description=description)
        partial_parser.add_argument('config', help='config file path')
        cfg_file = partial_parser.parse_known_args()[0].config
        cfg = Config.fromfile(cfg_file)
        parser = ArgumentParser(description=description)
        parser.add_argument('config', help='config file path')
        add_args(parser, cfg)
        return parser, cfg

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f'{key} is reserved for config file')

        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(Config, self).__setattr__('_filename', filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, 'r') as f:
                text = f.read()
        else:
            text = ''
        super(Config, self).__setattr__('_text', text)

    @property
    def filename(self):
        return self._filename

    @property
    def text(self):
        return self._text

    @property
    def pretty_text(self):

        indent = 4

        def _indent(s_, num_spaces):
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        def _format_basic_types(k, v, use_mapping=False):
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)

            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f'{k_str}: {v_str}'
            else:
                attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v, use_mapping=False):
            # check if all items in the list are dict
            if all(isinstance(_, dict) for _ in v):
                v_str = '[\n'
                v_str += '\n'.join(
                    f'dict({_indent(_format_dict(v_), indent)}),'
                    for v_ in v).rstrip(',')
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f'{k_str}: {v_str}'
                else:
                    attr_str = f'{str(k)}={v_str}'
                attr_str = _indent(attr_str, indent) + ']'
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            contain_invalid_identifier = False
            for key_name in dict_str:
                contain_invalid_identifier |= \
                    (not str(key_name).isidentifier())
            return contain_invalid_identifier

        def _format_dict(input_dict, outest_level=False):
            r = ''
            s = []

            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += '{'
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                end = '' if outest_level or is_last else ','
                if isinstance(v, dict):
                    v_str = '\n' + _format_dict(v)
                    if use_mapping:
                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
                        attr_str = f'{k_str}: dict({v_str}'
                    else:
                        attr_str = f'{str(k)}=dict({v_str}'
                    attr_str = _indent(attr_str, indent) + ')' + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end

                s.append(attr_str)
            r += '\n'.join(s)
            if use_mapping:
                r += '}'
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)
        # copied from setup.cfg
        yapf_style = dict(
            based_on_style='pep8',
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True)
        text, _ = FormatCode(text, style_config=yapf_style, verify=True)

        return text

    def __repr__(self):
        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def dump(self, file=None):
        cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
        if self.filename.endswith('.py'):
            if file is None:
                return self.pretty_text
            else:
                with open(file, 'w') as f:
                    f.write(self.pretty_text)
        else:
            import mmcv
            if file is None:
                file_format = self.filename.split('.')[-1]
                return mmcv.dump(cfg_dict, file_format=file_format)
            else:
                mmcv.dump(cfg_dict, file)

    def merge_from_dict(self, options):
        """Merge a dict of (possibly dotted) keys into the config.
        Merge the dict parsed by DictAction into this cfg.
        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp':True}
            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = cfg._cfg_dict
            >>> assert cfg_dict == dict(
            ...     model=dict(backbone=dict(type='ResNet', depth=50,
            ...                              with_cp=True)))
        Args:
            options (dict): dict of configs to merge from.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split('.')
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        super(Config, self).__setattr__(
            '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict))


class DictAction(Action):
    """
    argparse action to split an argument into KEY=VALUE form
    on the first = and collect the pairs into a dictionary. List options
    should be passed as comma-separated values, e.g. KEY=V1,V2,V3. Each value
    is converted to int, float or bool when possible and kept as str otherwise.
    """

    @staticmethod
    def _parse_int_float_bool(val):
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ['true', 'false']:
            return val.lower() == 'true'
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        options = {}
        for kv in values:
            key, val = kv.split('=', maxsplit=1)
            val = [self._parse_int_float_bool(v) for v in val.split(',')]
            if len(val) == 1:
                val = val[0]
            options[key] = val
        setattr(namespace, self.dest, options)
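`Config._merge_a_into_b` is what makes `_base_` inheritance and `merge_from_dict` work: child values override base values key by key, nested dicts are merged recursively, and a nested dict carrying `_delete_=True` replaces the base entry instead of being merged into it. The sketch below is a simplified restatement of those rules on plain dicts; `merge_a_into_b` and `expand_dotted` are hypothetical standalone versions, and unlike the original they never mutate their inputs (the original pops `_delete_` out of the child dict). The config values are made up.

```python
DELETE_KEY = '_delete_'


def merge_a_into_b(a, b):
    """Return a new dict: b updated recursively with a (values in a win)."""
    b = dict(b)  # shallow copy so the caller's dict is left untouched
    for k, v in a.items():
        if isinstance(v, dict) and k in b and not v.get(DELETE_KEY, False):
            if not isinstance(b[k], dict):
                raise TypeError(f'{k} is a dict in the child config but not in the base')
            b[k] = merge_a_into_b(v, b[k])
        else:
            # new key, non-dict value, or _delete_=True: take the child
            # value as-is, minus the marker
            if isinstance(v, dict):
                v = {kk: vv for kk, vv in v.items() if kk != DELETE_KEY}
            b[k] = v
    return b


def expand_dotted(options):
    """Turn {'a.b': 1} into {'a': {'b': 1}}, as merge_from_dict does."""
    out = {}
    for full_key, value in options.items():
        *parents, leaf = full_key.split('.')
        d = out
        for p in parents:
            d = d.setdefault(p, {})
        d[leaf] = value
    return out


if __name__ == "__main__":
    base = dict(optimizer=dict(type='SGD', lr=0.025, momentum=0.9))
    print(merge_a_into_b(dict(optimizer=dict(lr=0.01)), base))
    print(merge_a_into_b(dict(optimizer=dict(_delete_=True, type='Adam', lr=1e-3)), base))
    print(merge_a_into_b(expand_dotted({'optimizer.lr': 0.05}), base))
```

The first call keeps `type` and `momentum` from the base; the second drops them because `_delete_=True` asks for replacement rather than merging.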


================================================
FILE: utils/registry.py
================================================
import inspect

import six

# borrow from mmdetection

def is_str(x):
    """Whether the input is a string instance."""
    return isinstance(x, six.string_types)

class Registry(object):

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def _register_module(self, module_class):
        """Register a module class under its ``__name__``.

        Args:
            module_class (type): Class to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
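        cfg (dict): Config dict. It must contain the key "type"; only that
            key is read, and no other entries are passed to the constructor.
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Keyword arguments for the
            constructor; these are the only arguments it receives.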

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = {}
    obj_type = cfg['type']
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)
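As written, `build_from_cfg` reads only the `type` entry of the config; the constructor receives nothing but `default_args`. A registered class is therefore expected to accept whatever `default_args` the caller supplies (for example the full config) and pick out its own section. The sketch below is a standalone restatement of that pattern in miniature; `NETS`, `ToyNet` and the config layout are hypothetical.

```python
import inspect


class Registry:
    """Minimal restatement of the Registry in utils/registry.py."""

    def __init__(self, name):
        self.name = name
        self.module_dict = {}

    def register_module(self, cls):
        if not inspect.isclass(cls):
            raise TypeError(f'module must be a class, but got {type(cls)}')
        if cls.__name__ in self.module_dict:
            raise KeyError(f'{cls.__name__} is already registered in {self.name}')
        self.module_dict[cls.__name__] = cls
        return cls  # returning cls lets register_module be used as a decorator


def build_from_cfg(cfg, registry, default_args=None):
    # Only cfg['type'] is consulted; the constructor gets default_args alone.
    obj_cls = registry.module_dict.get(cfg['type'])
    if obj_cls is None:
        raise KeyError(f"{cfg['type']} is not in the {registry.name} registry")
    return obj_cls(**(default_args or {}))


NETS = Registry('nets')  # hypothetical registry


@NETS.register_module
class ToyNet:
    """Hypothetical module that reads its own section of the full config."""

    def __init__(self, cfg):
        self.depth = cfg['net']['depth']


if __name__ == "__main__":
    full_cfg = {'net': {'type': 'ToyNet', 'depth': 34}}
    net = build_from_cfg(full_cfg['net'], NETS, default_args={'cfg': full_cfg})
    print(type(net).__name__, net.depth)  # ToyNet 34
```

Passing the whole config through `default_args` keeps every constructor signature identical (`__init__(self, cfg)`), at the cost of each module having to know where its own settings live.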


================================================
FILE: utils/transforms.py
================================================
import random
import cv2
import numpy as np
import numbers
import collections.abc

# copy from: https://github.com/cardwing/Codes-for-Lane-Detection/blob/master/ERFNet-CULane-PyTorch/utils/transforms.py

__all__ = ['GroupRandomCrop', 'GroupCenterCrop', 'GroupRandomPad', 'GroupCenterPad',
           'GroupRandomScale', 'GroupRandomHorizontalFlip', 'GroupNormalize']


class SampleResize(object):
    def __init__(self, size):
        # size is (width, height), the order cv2.resize expects for dsize
        assert (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size

    def __call__(self, sample):
        out = list()
        out.append(cv2.resize(sample[0], self.size,
                              interpolation=cv2.INTER_CUBIC))
        if len(sample) > 1:
            out.append(cv2.resize(sample[1], self.size,
                                  interpolation=cv2.INTER_NEAREST))
        return out


class GroupRandomCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = random.randint(0, max(0, h - th))
        w1 = random.randint(0, max(0, w - tw))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)

        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images


class GroupRandomCropRatio(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        tw, th = self.size  # size is (width, height) here, unlike GroupRandomCrop

        out_images = list()
        h1 = random.randint(0, max(0, h - th))
        w1 = random.randint(0, max(0, w - tw))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)

        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images


class GroupCenterCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = max(0, int((h - th) / 2))
        w1 = max(0, int((w - tw) / 2))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)

        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images


class GroupRandomPad(object):
    def __init__(self, size, padding):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img_group):
        assert (len(self.padding) == len(img_group))
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = random.randint(0, max(0, th - h))
        w1 = random.randint(0, max(0, tw - w))
        h2 = max(th - h - h1, 0)
        w2 = max(tw - w - w1, 0)

        for img, padding in zip(img_group, self.padding):
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(cv2.copyMakeBorder(
                img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][...,
                                                np.newaxis]  # single channel image
        return out_images


class GroupCenterPad(object):
    def __init__(self, size, padding):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img_group):
        assert (len(self.padding) == len(img_group))
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = max(0, int((th - h) / 2))
        w1 = max(0, int((tw - w) / 2))
        h2 = max(th - h - h1, 0)
        w2 = max(tw - w - w1, 0)

        for img, padding in zip(img_group, self.padding):
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(cv2.copyMakeBorder(
                img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][...,
                                                np.newaxis]  # single channel image
        return out_images


class GroupConcerPad(object):
    def __init__(self, size, padding):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img_group):
        assert (len(self.padding) == len(img_group))
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        # pad only at the bottom and right, keeping the image in the top-left corner
        h1 = 0
        w1 = 0
        h2 = max(th - h - h1, 0)
        w2 = max(tw - w - w1, 0)

        for img, padding in zip(img_group, self.padding):
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(cv2.copyMakeBorder(
                img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][...,
                                                np.newaxis]  # single channel image
        return out_images


class GroupRandomScaleNew(object):
    def __init__(self, size=(976, 208), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        # scale factors assume CULane's 1640x590 input resolution
        scale_w, scale_h = self.size[0] * 1.0 / 1640, self.size[1] * 1.0 / 590
        out_images = list()
        for img, interpolation in zip(img_group, self.interpolation):
            out_images.append(cv2.resize(img, None, fx=scale_w,
                                         fy=scale_h, interpolation=interpolation))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][...,
                                                np.newaxis]  # single channel image
        return out_images


class GroupRandomScale(object):
    def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        scale = random.uniform(self.size[0], self.size[1])
        out_images = list()
        for img, interpolation in zip(img_group, self.interpolation):
            out_images.append(cv2.resize(img, None, fx=scale,
                                         fy=scale, interpolation=interpolation))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][...,
                                                np.newaxis]  # single channel image
        return out_images


class GroupRandomMultiScale(object):
    def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        scales = [0.5, 1.0, 1.5]  # fixed scales; self.size is not used
        out_images = list()
        for scale in scales:
            for img, interpolation in zip(img_group, self.interpolation):
                out_images.append(cv2.resize(
                    img, None, fx=scale, fy=scale, interpolation=interpolation))
                if len(img.shape) > len(out_images[-1].shape):
                    out_images[-1] = out_images[-1][...,
                                                    np.newaxis]  # single channel image
        return out_images


class GroupRandomScaleRatio(object):
    def __init__(self, size=(680, 762, 562, 592), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation
        self.origin_id = [0, 1360, 580, 768, 255, 300, 680, 710, 312, 1509, 800, 1377, 880, 910, 1188, 128, 960, 1784,
                          1414, 1150, 512, 1162, 950, 750, 1575, 708, 2111, 1848, 1071, 1204, 892, 639, 2040, 1524, 832, 1122, 1224, 2295]

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        w_scale = random.randint(self.size[0], self.size[1])
        h_scale = random.randint(self.size[2], self.size[3])
        h, w, _ = img_group[0].shape
        out_images = list()
        out_images.append(cv2.resize(img_group[0], None, fx=w_scale*1.0/w, fy=h_scale*1.0/h,
                                     interpolation=self.interpolation[0]))
        ### process label map ###
        origin_label = cv2.resize(
            img_group[1], None, fx=w_scale*1.0/w, fy=h_scale*1.0/h, interpolation=self.interpolation[1])
        origin_label = origin_label.astype(int)
        label = origin_label[:, :, 0] * 5 + \
            origin_label[:, :, 1] * 3 + origin_label[:, :, 2]
        new_label = np.ones(label.shape) * 100
        new_label = new_label.astype(int)
        for cnt in range(37):
            new_label = (
                label == self.origin_id[cnt]) * (cnt - 100) + new_label
        new_label = (label == self.origin_id[37]) * (36 - 100) + new_label
        assert(100 not in np.unique(new_label))
        out_images.append(new_label)
        return out_images


class GroupRandomRotation(object):
    def __init__(self, degree=(-10, 10), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST), padding=None):
        self.degree = degree
        self.interpolation = interpolation
        self.padding = padding
        if self.padding is None:
            self.padding = [0, 0]

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        v = random.random()
        if v < 0.5:
            degree = random.uniform(self.degree[0], self.degree[1])
            h, w = img_group[0].shape[0:2]
            center = (w / 2, h / 2)
            map_matrix = cv2.getRotationMatrix2D(center, degree, 1.0)
            out_images = list()
            for img, interpolation, padding in zip(img_group, self.interpolation, self.padding):
                out_images.append(cv2.warpAffine(
                    img, map_matrix, (w, h), flags=interpolation, borderMode=cv2.BORDER_CONSTANT, borderValue=padding))
                if len(img.shape) > len(out_images[-1].shape):
                    out_images[-1] = out_images[-1][...,
                                                    np.newaxis]  # single channel image
            return out_images
        else:
            return img_group


class GroupRandomBlur(object):
    def __init__(self, applied):
        self.applied = applied

    def __call__(self, img_group):
        assert (len(self.applied) == len(img_group))
        v = random.random()
        if v < 0.5:
            out_images = []
            for img, a in zip(img_group, self.applied):
                out = img
                if a:
                    out = cv2.GaussianBlur(
                        img, (5, 5), random.uniform(1e-6, 0.6))
                # Compare the input with the blurred output (not the output with
                # itself) so the channel axis cv2 drops for single-channel images is restored.
                if len(img.shape) > len(out.shape):
                    out = out[..., np.newaxis]  # single channel image
                out_images.append(out)
            return out_images
        else:
            return img_group


class GroupRandomHorizontalFlip(object):
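    """Randomly flip every image in the group horizontally with probability 0.5.

    The ``is_flow`` argument of ``__call__`` is unused; flow handling is
    controlled by the ``is_flow`` flag passed to the constructor.
    """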

    def __init__(self, is_flow=False):
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        v = random.random()
        if v < 0.5:
            out_images = [np.fliplr(img) for img in img_group]
            if self.is_flow:
                for i in range(0, len(out_images), 2):
                    # invert flow pixel values when flipping
                    out_images[i] = -out_images[i]
            return out_images
        else:
            return img_group


class GroupNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, img_group):
        out_images = list()
        for img, m, s in zip(img_group, self.mean, self.std):
            if len(m) == 1:
                img = img - np.array(m)  # single channel image
                img = img / np.array(s)
            else:
                img = img - np.array(m)[np.newaxis, np.newaxis, ...]
                img = img / np.array(s)[np.newaxis, np.newaxis, ...]
            out_images.append(img)

        return out_images
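The label remapping in `GroupRandomScaleRatio` walks the 38-entry `origin_id` table in a Python loop. The same mapping can be written as a single sorted-table lookup. This is a minimal numpy sketch, assuming the label map has already been resized; `encode_label`, `ORIGIN_ID` and `CLASS_OF_ENTRY` are hypothetical names, not part of the repository, and the sketch raises `ValueError` where the original uses `assert`.

```python
import numpy as np

# Hypothetical copy of GroupRandomScaleRatio.origin_id (38 distinct keys).
ORIGIN_ID = np.array([0, 1360, 580, 768, 255, 300, 680, 710, 312, 1509, 800, 1377, 880, 910, 1188, 128, 960, 1784,
                      1414, 1150, 512, 1162, 950, 750, 1575, 708, 2111, 1848, 1071, 1204, 892, 639, 2040, 1524, 832,
                      1122, 1224, 2295])
# Table entry i maps to class i, except the last entry, which the
# original loop also assigns to class 36.
CLASS_OF_ENTRY = np.minimum(np.arange(len(ORIGIN_ID)), 36)


def encode_label(origin_label):
    """Hypothetical helper: map an HxWx3 label map to HxW class indices."""
    origin_label = np.asarray(origin_label).astype(int)  # avoid uint8 overflow
    key = origin_label[:, :, 0] * 5 + origin_label[:, :, 1] * 3 + origin_label[:, :, 2]
    order = np.argsort(ORIGIN_ID)
    sorted_ids = ORIGIN_ID[order]
    # Position of each key in the sorted table, clipped so indexing is safe.
    pos = np.clip(np.searchsorted(sorted_ids, key), 0, len(sorted_ids) - 1)
    found = sorted_ids[pos] == key
    if not found.all():
        # Same condition the original guards with `assert 100 not in ...`.
        raise ValueError("label contains a key that is not in ORIGIN_ID")
    return CLASS_OF_ENTRY[order[pos]]
```

The lookup keeps memory proportional to H x W instead of building an H x W x 38 comparison, and it relies on the 38 keys being distinct, which holds for the table above.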
SYMBOL INDEX (236 symbols across 31 files)

FILE: datasets/base_dataset.py
  class BaseDataset (line 13) | class BaseDataset(Dataset):
    method __init__ (line 14) | def __init__(self, img_path, data_list, list_path='list', cfg=None):
    method transform_train (line 30) | def transform_train(self):
    method transform_val (line 33) | def transform_val(self):
    method view (line 41) | def view(self, img, coords, file_path=None):
    method init (line 55) | def init(self):
    method __len__ (line 59) | def __len__(self):
    method __getitem__ (line 62) | def __getitem__(self, idx):

FILE: datasets/culane.py
  class CULane (line 13) | class CULane(BaseDataset):
    method __init__ (line 14) | def __init__(self, img_path, data_list, cfg=None):
    method init (line 19) | def init(self):
    method transform_train (line 32) | def transform_train(self):
    method probmap2lane (line 42) | def probmap2lane(self, probmaps, exists, pts=18):

FILE: datasets/registry.py
  function build (line 7) | def build(cfg, registry, default_args=None):
  function build_dataset (line 17) | def build_dataset(split_cfg, cfg):
  function build_dataloader (line 24) | def build_dataloader(split_cfg, cfg, is_train=True):

FILE: datasets/tusimple.py
  class TuSimple (line 11) | class TuSimple(BaseDataset):
    method __init__ (line 12) | def __init__(self, img_path, data_list, cfg=None):
    method transform_train (line 15) | def transform_train(self):
    method init (line 27) | def init(self):
    method fix_gap (line 42) | def fix_gap(self, coordinate):
    method is_short (line 68) | def is_short(self, lane):
    method get_lane (line 75) | def get_lane(self, prob_map, y_px_gap, pts, thresh, resize_shape=None):
    method probmap2lane (line 109) | def probmap2lane(self, seg_pred, exist, resize_shape=(720, 1280), smoo...

FILE: main.py
  function main (line 20) | def main():
  function parse_args (line 44) | def parse_args():

FILE: models/decoder.py
  class PlainDecoder (line 4) | class PlainDecoder(nn.Module):
    method __init__ (line 5) | def __init__(self, cfg):
    method forward (line 12) | def forward(self, x):
  function conv1x1 (line 21) | def conv1x1(in_planes, out_planes, stride=1):
  class non_bottleneck_1d (line 26) | class non_bottleneck_1d(nn.Module):
    method __init__ (line 27) | def __init__(self, chann, dropprob, dilated):
    method forward (line 48) | def forward(self, input):
  class UpsamplerBlock (line 67) | class UpsamplerBlock(nn.Module):
    method __init__ (line 68) | def __init__(self, ninput, noutput, up_width, up_height):
    method forward (line 87) | def forward(self, input):
  class BUSD (line 103) | class BUSD(nn.Module):
    method __init__ (line 104) | def __init__(self, cfg):
    method forward (line 121) | def forward(self, input):

FILE: models/registry.py
  function build (line 5) | def build(cfg, registry, default_args=None):
  function build_net (line 15) | def build_net(cfg):

FILE: models/resa.py
  class RESA (line 10) | class RESA(nn.Module):
    method __init__ (line 11) | def __init__(self, cfg):
    method forward (line 58) | def forward(self, x):
  class ExistHead (line 77) | class ExistHead(nn.Module):
    method __init__ (line 78) | def __init__(self, cfg=None):
    method forward (line 90) | def forward(self, x):
  class RESANet (line 106) | class RESANet(nn.Module):
    method __init__ (line 107) | def __init__(self, cfg):
    method forward (line 115) | def forward(self, batch):

FILE: models/resnet.py
  function conv3x3 (line 22) | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
  function conv1x1 (line 28) | def conv1x1(in_planes, out_planes, stride=1):
  class BasicBlock (line 33) | class BasicBlock(nn.Module):
    method __init__ (line 36) | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
    method forward (line 56) | def forward(self, x):
  class Bottleneck (line 75) | class Bottleneck(nn.Module):
    method __init__ (line 78) | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
    method forward (line 95) | def forward(self, x):
  class ResNetWrapper (line 118) | class ResNetWrapper(nn.Module):
    method __init__ (line 120) | def __init__(self, cfg):
    method forward (line 139) | def forward(self, x):
  class ResNet (line 146) | class ResNet(nn.Module):
    method __init__ (line 148) | def __init__(self, block, layers, zero_init_residual=False,
    method _make_layer (line 204) | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
    method forward (line 228) | def forward(self, x):
  function _resnet (line 247) | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
  function resnet18 (line 256) | def resnet18(pretrained=False, progress=True, **kwargs):
  function resnet34 (line 268) | def resnet34(pretrained=False, progress=True, **kwargs):
  function resnet50 (line 280) | def resnet50(pretrained=False, progress=True, **kwargs):
  function resnet101 (line 292) | def resnet101(pretrained=False, progress=True, **kwargs):
  function resnet152 (line 304) | def resnet152(pretrained=False, progress=True, **kwargs):
  function resnext50_32x4d (line 316) | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
  function resnext101_32x8d (line 330) | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
  function wide_resnet50_2 (line 344) | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
  function wide_resnet101_2 (line 362) | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):

FILE: runner/evaluator/culane/culane.py
  function check (line 14) | def check():
  function read_helper (line 24) | def read_helper(path):
  function call_culane_eval (line 33) | def call_culane_eval(data_dir, output_path='./output'):
  class CULane (line 88) | class CULane(nn.Module):
    method __init__ (line 89) | def __init__(self, cfg):
    method evaluate (line 102) | def evaluate(self, dataset, output, batch):
    method summarize (line 129) | def summarize(self):

FILE: runner/evaluator/culane/lane_evaluation/include/counter.hpp
  class Counter (line 16) | class Counter
    method Counter (line 19) | Counter(int _im_width, int _im_height, double _iou_threshold=0.4, int ...

FILE: runner/evaluator/culane/lane_evaluation/include/hungarianGraph.hpp
  type pipartiteGraph (line 6) | struct pipartiteGraph {
    method matchDfs (line 12) | bool matchDfs(int u) {
    method resize (line 26) | void resize(int leftNum, int rightNum) {
    method match (line 38) | void match() {

FILE: runner/evaluator/culane/lane_evaluation/include/lane_compare.hpp
  class LaneCompare (line 27) | class LaneCompare{
    type CompareMode (line 29) | enum CompareMode{
    method LaneCompare (line 34) | LaneCompare(int _im_width, int _im_height, int _lane_width = 10, Compa...

FILE: runner/evaluator/culane/lane_evaluation/include/spline.hpp
  type Func (line 11) | struct Func {
  class Spline (line 22) | class Spline {

FILE: runner/evaluator/culane/lane_evaluation/src/evaluate.cpp
  function help (line 21) | void help(void) {
  function main (line 48) | int main(int argc, char **argv) {
  function read_lane_file (line 204) | void read_lane_file(const string &file_name, vector<vector<Point2f>> &la...
  function visualize (line 226) | void visualize(string &full_im_name, vector<vector<Point2f>> &anno_lanes,

FILE: runner/evaluator/culane/prob2lines.py
  function getLane (line 9) | def getLane(probmap, pts, cfg = None):
  function prob2lines (line 24) | def prob2lines(prob_dir, out_dir, list_file, cfg = None):

FILE: runner/evaluator/tusimple/getLane.py
  function isShort (line 4) | def isShort(lane):
  function fixGap (line 11) | def fixGap(coordinate):
  function getLane_tusimple (line 37) | def getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape=None,...
  function prob2lines_tusimple (line 70) | def prob2lines_tusimple(seg_pred, exist, resize_shape=None, smooth=True,...

FILE: runner/evaluator/tusimple/lane.py
  class LaneEval (line 6) | class LaneEval(object):
    method get_angle (line 12) | def get_angle(xs, y_samples):
    method line_accuracy (line 23) | def line_accuracy(pred, gt, thresh):
    method bench (line 29) | def bench(pred, gt, y_samples, running_time):
    method bench_one_submit (line 58) | def bench_one_submit(pred_file, gt_file):

FILE: runner/evaluator/tusimple/tusimple.py
  function split_path (line 13) | def split_path(path):
  class Tusimple (line 28) | class Tusimple(nn.Module):
    method __init__ (line 29) | def __init__(self, cfg):
    method evaluate_pred (line 44) | def evaluate_pred(self, dataset, seg_pred, exist_pred, batch):
    method evaluate (line 91) | def evaluate(self, dataset, output, batch):
    method summarize (line 98) | def summarize(self):

FILE: runner/logger.py
  function get_logger (line 5) | def get_logger(name, log_file=None, log_level=logging.INFO):

FILE: runner/net_utils.py
  function save_model (line 9) | def save_model(net, optim, scheduler, recorder, is_best=False):
  function load_network_specified (line 23) | def load_network_specified(net, model_dir, logger=None):
  function load_network (line 36) | def load_network(net, model_dir, finetune_from=None, logger=None):

FILE: runner/optimizer.py
  function build_optimizer (line 10) | def build_optimizer(cfg, net):

FILE: runner/recorder.py
  class SmoothedValue (line 8) | class SmoothedValue(object):
    method __init__ (line 13) | def __init__(self, window_size=20):
    method update (line 18) | def update(self, value):
    method median (line 24) | def median(self):
    method avg (line 29) | def avg(self):
    method global_avg (line 34) | def global_avg(self):
  class Recorder (line 38) | class Recorder(object):
    method __init__ (line 39) | def __init__(self, cfg):
    method get_work_dir (line 57) | def get_work_dir(self):
    method update_loss_stats (line 65) | def update_loss_stats(self, loss_dict):
    method record (line 69) | def record(self, prefix, step=-1, loss_stats=None, image_stats=None):
    method write (line 73) | def write(self, content):
    method state_dict (line 78) | def state_dict(self):
    method load_state_dict (line 83) | def load_state_dict(self, scalar_dict):
    method __str__ (line 86) | def __str__(self):
  function build_recorder (line 98) | def build_recorder(cfg):

FILE: runner/registry.py
  function build (line 6) | def build(cfg, registry, default_args=None):
  function build_trainer (line 15) | def build_trainer(cfg):
  function build_evaluator (line 18) | def build_evaluator(cfg):

FILE: runner/resa_trainer.py
  function dice_loss (line 7) | def dice_loss(input, target):
  class RESA (line 18) | class RESA(nn.Module):
    method __init__ (line 19) | def __init__(self, cfg):
    method forward (line 32) | def forward(self, net, batch):

FILE: runner/runner.py
  class Runner (line 16) | class Runner(object):
    method __init__ (line 17) | def __init__(self, cfg):
    method resume (line 32) | def resume(self):
    method to_cuda (line 38) | def to_cuda(self, batch):
    method train_epoch (line 45) | def train_epoch(self, epoch, train_loader):
    method train (line 73) | def train(self):
    method validate (line 89) | def validate(self, val_loader):
    method save_ckpt (line 105) | def save_ckpt(self, is_best=False):

FILE: runner/scheduler.py
  function build_scheduler (line 10) | def build_scheduler(cfg, optimizer):

FILE: tools/generate_seg_tusimple.py
  function gen_label_for_json (line 12) | def gen_label_for_json(args, image_set):
  function generate_json_file (line 81) | def generate_json_file(save_dir, json_file, image_set):
  function generate_label (line 88) | def generate_label(args):

FILE: utils/config.py
  function check_file_exist (line 19) | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
  class ConfigDict (line 25) | class ConfigDict(Dict):
    method __missing__ (line 27) | def __missing__(self, name):
    method __getattr__ (line 30) | def __getattr__(self, name):
  function add_args (line 43) | def add_args(parser, cfg, prefix=''):
  class Config (line 62) | class Config:
    method _validate_py_syntax (line 86) | def _validate_py_syntax(filename):
    method _file2dict (line 96) | def _file2dict(filename):
    method _merge_a_into_b (line 159) | def _merge_a_into_b(a, b):
    method fromfile (line 178) | def fromfile(filename):
    method auto_argparser (line 183) | def auto_argparser(description=None):
    method __init__ (line 195) | def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
    method filename (line 217) | def filename(self):
    method text (line 221) | def text(self):
    method pretty_text (line 225) | def pretty_text(self):
    method __repr__ (line 318) | def __repr__(self):
    method __len__ (line 321) | def __len__(self):
    method __getattr__ (line 324) | def __getattr__(self, name):
    method __getitem__ (line 327) | def __getitem__(self, name):
    method __setattr__ (line 330) | def __setattr__(self, name, value):
    method __setitem__ (line 335) | def __setitem__(self, name, value):
    method __iter__ (line 340) | def __iter__(self):
    method dump (line 343) | def dump(self, file=None):
    method merge_from_dict (line 359) | def merge_from_dict(self, options):
  class DictAction (line 388) | class DictAction(Action):
    method _parse_int_float_bool (line 396) | def _parse_int_float_bool(val):
    method __call__ (line 409) | def __call__(self, parser, namespace, values, option_string=None):

FILE: utils/registry.py
  function is_str (line 7) | def is_str(x):
  class Registry (line 11) | class Registry(object):
    method __init__ (line 13) | def __init__(self, name):
    method __repr__ (line 17) | def __repr__(self):
    method name (line 23) | def name(self):
    method module_dict (line 27) | def module_dict(self):
    method get (line 30) | def get(self, key):
    method _register_module (line 33) | def _register_module(self, module_class):
    method register_module (line 48) | def register_module(self, cls):
  function build_from_cfg (line 53) | def build_from_cfg(cfg, registry, default_args=None):

FILE: utils/transforms.py
  class SampleResize (line 13) | class SampleResize(object):
    method __init__ (line 14) | def __init__(self, size):
    method __call__ (line 18) | def __call__(self, sample):
  class GroupRandomCrop (line 28) | class GroupRandomCrop(object):
    method __init__ (line 29) | def __init__(self, size):
    method __call__ (line 35) | def __call__(self, img_group):
  class GroupRandomCropRatio (line 51) | class GroupRandomCropRatio(object):
    method __init__ (line 52) | def __init__(self, size):
    method __call__ (line 58) | def __call__(self, img_group):
  class GroupCenterCrop (line 74) | class GroupCenterCrop(object):
    method __init__ (line 75) | def __init__(self, size):
    method __call__ (line 81) | def __call__(self, img_group):
  class GroupRandomPad (line 97) | class GroupRandomPad(object):
    method __init__ (line 98) | def __init__(self, size, padding):
    method __call__ (line 105) | def __call__(self, img_group):
  class GroupCenterPad (line 126) | class GroupCenterPad(object):
    method __init__ (line 127) | def __init__(self, size, padding):
    method __call__ (line 134) | def __call__(self, img_group):
  class GroupConcerPad (line 155) | class GroupConcerPad(object):
    method __init__ (line 156) | def __init__(self, size, padding):
    method __call__ (line 163) | def __call__(self, img_group):
  class GroupRandomScaleNew (line 184) | class GroupRandomScaleNew(object):
    method __init__ (line 185) | def __init__(self, size=(976, 208), interpolation=(cv2.INTER_LINEAR, c...
    method __call__ (line 189) | def __call__(self, img_group):
  class GroupRandomScale (line 202) | class GroupRandomScale(object):
    method __init__ (line 203) | def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, c...
    method __call__ (line 207) | def __call__(self, img_group):
  class GroupRandomMultiScale (line 220) | class GroupRandomMultiScale(object):
    method __init__ (line 221) | def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, c...
    method __call__ (line 225) | def __call__(self, img_group):
  class GroupRandomScaleRatio (line 239) | class GroupRandomScaleRatio(object):
    method __init__ (line 240) | def __init__(self, size=(680, 762, 562, 592), interpolation=(cv2.INTER...
    method __call__ (line 246) | def __call__(self, img_group):
  class GroupRandomRotation (line 271) | class GroupRandomRotation(object):
    method __init__ (line 272) | def __init__(self, degree=(-10, 10), interpolation=(cv2.INTER_LINEAR, ...
    method __call__ (line 279) | def __call__(self, img_group):
  class GroupRandomBlur (line 299) | class GroupRandomBlur(object):
    method __init__ (line 300) | def __init__(self, applied):
    method __call__ (line 303) | def __call__(self, img_group):
  class GroupRandomHorizontalFlip (line 321) | class GroupRandomHorizontalFlip(object):
    method __init__ (line 325) | def __init__(self, is_flow=False):
    method __call__ (line 328) | def __call__(self, img_group, is_flow=False):
  class GroupNormalize (line 341) | class GroupNormalize(object):
    method __init__ (line 342) | def __init__(self, mean, std):
    method __call__ (line 346) | def __call__(self, img_group):
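The `Group*` transforms indexed above share one contract: they take a list of arrays (the image first, then label maps) and make a single random decision per call, applying it to every member so the image and its labels stay aligned. A minimal numpy sketch of that contract for flipping and normalization; `GroupFlipSketch` and `normalize_group` are hypothetical names, and the injectable `rng` is an assumption added for testability.

```python
import random

import numpy as np


class GroupFlipSketch:
    """Hypothetical restatement of GroupRandomHorizontalFlip with an injectable RNG."""

    def __init__(self, p=0.5, rng=None):
        self.p = p
        self.rng = rng or random.Random()

    def __call__(self, img_group):
        # One draw per call, so the image and its label are flipped together.
        if self.rng.random() < self.p:
            # np.fliplr returns a negative-stride view; copy it into contiguous memory.
            return [np.ascontiguousarray(np.fliplr(img)) for img in img_group]
        return img_group


def normalize_group(img_group, mean, std):
    """Per-image normalization using the same (mean, std) pairing as GroupNormalize."""
    out = []
    for img, m, s in zip(img_group, mean, std):
        # Broadcasting applies a per-channel (length C) or scalar (length 1) statistic.
        out.append((img - np.asarray(m)) / np.asarray(s))
    return out
```

The contiguous copy matters only if a later step rejects negative strides (for example `torch.from_numpy`); whether this repository needs it depends on its conversion code, which is not shown here.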

About this extraction

This page contains the full source code of the ZJULearning/resa GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 48 files (152.3 KB), approximately 41.5k tokens, and a symbol index with 236 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
