Repository: svip-lab/PPGNet
Branch: master
Commit: ae55d6729964
Files: 27
Total size: 139.4 KB

Directory structure:
gitextract_j4mo8f9e/

├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── ckpt/
│   └── backbone/
│       ├── decoder_epoch_20.pth
│       └── encoder_epoch_20.pth
├── data/
│   ├── common.py
│   ├── line_graph.py
│   ├── sist_line.py
│   ├── transforms.py
│   ├── utils.py
│   └── york_urban.py
├── main.py
├── models/
│   ├── __init__.py
│   ├── backbone.py
│   ├── common.py
│   ├── graph.py
│   ├── lsd.py
│   ├── lsd_test.py
│   ├── pretrained/
│   │   └── resnet50-imagenet.pth
│   ├── test_common.py
│   └── test_graph.py
├── test.py
├── test.sh
├── tools/
│   └── rebuild_yorkurban.py
├── train.sh
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitattributes
================================================
*.pth filter=lfs diff=lfs merge=lfs -text


================================================
FILE: .gitignore
================================================
.git
.idea
log
__pycache__


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2019

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# PPGNet: Learning Point-Pair Graph for Line Segment Detection

PyTorch implementation of our CVPR 2019 paper:

[**PPGNet: Learning Point-Pair Graph for Line Segment Detection**](https://www.aiyoggle.me/publication/ppgnet-cvpr19/ppgnet-cvpr19.pdf)

Ziheng Zhang*, Zhengxin Li*, Ning Bi, Jia Zheng, Jinlei Wang, Kun Huang, Weixin Luo, Yanyu Xu, Shenghua Gao

(\* Equal Contribution)

The poster can be found [HERE](https://www.aiyoggle.me/publication/ppgnet-cvpr19).


![pipe-line](https://svip-lab.github.io/img/project/cvpr2019_zhangzh1.png)
**Demonstration of the junction-line graph representation G={V, E}. (a) a sample image patch with 10 junctions (V); (b) the graph describing the connectivity of all junctions (G); (c) the adjacency matrix of all junctions (E; black means the junction pair is connected).**
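
To make the representation concrete, here is a tiny standalone illustration (the junction coordinates and segments are hypothetical, not from the dataset) that builds the adjacency matrix E the same way `LineGraph.adj_mtx` does: every pair of junctions lying on a common line segment is connected.

```python
import numpy as np
from itertools import combinations

# Hypothetical toy graph: 4 junctions (V) and two line segments,
# each stored as the set of junction indices lying on it.
junctions = np.array([[0, 0], [4, 0], [4, 3], [4, 6]])  # V: (x, y) coordinates
segments = [{0, 1}, {1, 2, 3}]                          # junction indices per segment

# E: symmetric 0/1 adjacency matrix over junction pairs.
adj = np.zeros((len(junctions), len(junctions)), dtype=int)
for seg in segments:
    for i, j in combinations(seg, 2):
        adj[i, j] = adj[j, i] = 1
print(adj.sum())  # 8: four connected pairs, each counted twice by symmetry
```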

## Requirements
- Python >= 3.6
- fire >= 0.1.3
- numba >= 0.40.0
- numpy >= 1.14.5
- pytorch = 0.4.1
- scikit-learn = 0.19.2
- scipy = 1.1.0
- tensorboard >= 1.11.0
- tensorboardX >= 1.4
- torchvision >= 0.2.1
- OpenCV >= 3.4.3

## Usage

1. clone this repository (and make sure you fetch all `.pth` files correctly with [git-lfs](https://git-lfs.github.com/)): `git clone https://github.com/svip-lab/PPGNet.git`
2. download the preprocessed *SIST-Wireframe* dataset from [BaiduPan](https://pan.baidu.com/s/1Sbdi1lL492fhmPL1t1Ov0w) (code:lnfp) or [Google Drive](https://drive.google.com/file/d/1KggPcHCRu8BcOqCvVZCXiB64y9L2nQDf/view?usp=sharing).
3. specify the dataset path in the `train.sh` script (modify the `--data-root` parameter).
4. run `train.sh`.

Please note that the code requires at least 24 GB of GPU memory. On GPUs with less memory, use a smaller batch via the `--batch-size` parameter and/or reduce the `--block-inference-size` parameter in `train.sh` to avoid out-of-memory errors.

## Citation

Please cite our paper if you use this code in your work.
```
@inproceedings{zhang2019ppgnet,
  title={PPGNet: Learning Point-Pair Graph for Line Segment Detection},
  author={Ziheng Zhang and Zhengxin Li and Ning Bi and Jia Zheng and Jinlei Wang and Kun Huang and Weixin Luo and Yanyu Xu and Shenghua Gao},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2019}
}
```



================================================
FILE: ckpt/backbone/decoder_epoch_20.pth
================================================
version https://git-lfs.github.com/spec/v1
oid sha256:83be3696848929d3ed93deec1fbe31c94b8acbd40110bf1604e11fad024784fc
size 162474609


================================================
FILE: ckpt/backbone/encoder_epoch_20.pth
================================================
version https://git-lfs.github.com/spec/v1
oid sha256:bb24707e745689a005ca0c857d6c516afd2f532a233b2697f4832333c8af48d9
size 95013117


================================================
FILE: data/common.py
================================================
import numpy as np
from functools import lru_cache as cache


@cache(maxsize=None)
def _assert_valid_param(param):
    A, B, C = param
    assert not np.isclose(A ** 2 + B ** 2, 0), "invalid line param."
    return np.array(param) / np.sqrt(A ** 2 + B ** 2)


def assert_valid_param(param):
    return _assert_valid_param(tuple(param))


@cache(maxsize=None)
def _fit_line(pts):
    P = np.array(pts)
    P = np.hstack((P, np.ones((len(P), 1))))
    assert np.linalg.matrix_rank(P) >= 2, f"points to fit line are not valid: {P}"
    u, s, vt = np.linalg.svd(P)
    param = assert_valid_param(vt[-1])
    res = np.linalg.norm(P.dot(param)) / len(P)

    return param, res


def fit_line(pts):
    return _fit_line(tuple(tuple(pt) for pt in pts))


def dist_pts_to_line(pts, param):
    param = assert_valid_param(param)
    P = np.array([pt for pt in pts])
    P = np.hstack((P, np.ones((len(P), 1))))
    dists = np.abs(P.dot(param))

    return dists


def assert_pts_in_line(pts, param, atol=1.):
    dists = np.array(dist_pts_to_line(pts, param))
    assert np.all(dists < atol)


@cache(maxsize=None)
def _find_pt_in_line(param):
    A, B, C = assert_valid_param(param)
    if np.abs(A) > np.abs(B):
        x, y = -C / A, 0
    else:
        x, y = 0, -C / B

    return np.array([x, y])


def find_pt_in_line(param):
    return _find_pt_in_line(tuple(param))


def project_pts_on_line(pts, param):
    param = assert_valid_param(param)
    P = np.array([pt for pt in pts])
    pt0 = np.array(find_pt_in_line(param))
    e = np.array([-param[1], param[0]])
    alpha = (P - pt0).dot(e)
    P_proj = np.outer(alpha, e) + pt0

    assert P_proj.ndim == 2 and alpha.ndim == 1, f"internal error occurred when projecting pts onto line: {P_proj}, {alpha}"

    return P_proj, alpha


def find_lines_intersect(params):
    P = []
    for param in params:
        param = assert_valid_param(param)
        P.append(param)
    P = np.array(P)
    assert np.linalg.matrix_rank(P) >= 2, "lines do not intersect"
    u, s, vh = np.linalg.svd(P)
    x, y, _ = vh[-1] / vh[-1][-1]
    hpt = np.array([x, y, 1])
    dist = np.abs(P.dot(hpt)).mean()

    return np.array([x, y]), dist


@cache(maxsize=None)
def is_pt_in_line_seg(eps, pt, pt1, pt2):
    param, _ = fit_line([pt1, pt2])
    _, alphas = project_pts_on_line([pt, pt1, pt2], param)
    dist = dist_pts_to_line([pt], param)[0]
    return (dist < eps * 2) and (np.min(alphas[1:]) <= alphas[0] <= np.max(alphas[1:]))
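
The helpers above fit lines in homogeneous form `A*x + B*y + C = 0` via SVD. A condensed, standalone sketch of the same idea (without the memoization wrappers; `fit_line_svd` and `dist_to_line` are illustrative names, not part of the module):

```python
import numpy as np

def fit_line_svd(pts):
    # Fit A*x + B*y + C = 0: the null vector of the [x y 1] row matrix is
    # the right-singular vector with the smallest singular value; normalize
    # so that A^2 + B^2 = 1, making P.dot(param) a signed point-line distance.
    P = np.hstack([np.asarray(pts, dtype=float), np.ones((len(pts), 1))])
    _, _, vt = np.linalg.svd(P)
    param = vt[-1] / np.hypot(vt[-1][0], vt[-1][1])
    residual = np.linalg.norm(P.dot(param)) / len(P)
    return param, residual

def dist_to_line(pts, param):
    # absolute distances of points to the normalized line
    P = np.hstack([np.asarray(pts, dtype=float), np.ones((len(pts), 1))])
    return np.abs(P.dot(param))

param, res = fit_line_svd([(0, 0), (1, 1), (2, 2)])   # the line y = x
print(round(res, 6))                                   # 0.0: points are collinear
print(round(dist_to_line([(0, 1)], param)[0], 6))      # 0.707107 = 1/sqrt(2)
```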


================================================
FILE: data/line_graph.py
================================================
import os
from sklearn.neighbors import KDTree
from scipy.cluster.hierarchy import fclusterdata
import pickle
from itertools import combinations
import collections
from data.common import *
import cv2


class LineGraph(object):
    def __init__(
            self, eps_junction=3., eps_line_deg=np.pi / 20, verbose=False
    ):
        self._line_segs = []
        self._junctions = []
        self._end_points = []
        self._refine_junctions = None
        self._neighbor = {}
        self._junc2line = {}
        self._freeze_junction = False
        self._kdtree = None
        self._eps_junc = eps_junction
        self._eps_line_seg = eps_line_deg
        self.verbose = verbose

    def load(self, filename):
        with open(filename, "rb") as f:
            data = pickle.load(f)
            for mem in dir(self):
                if (
                        mem.startswith("_")
                        and not mem.startswith("__")
                        and not isinstance(getattr(self, mem), collections.Callable)
                ):
                    setattr(self, mem, data[mem])
        return self

    def save(self, filename):
        with open(filename, "wb+") as f:
            data = {}
            for mem in dir(self):
                if (
                        mem.startswith("_")
                        and not mem.startswith("__")
                        and not isinstance(getattr(self, mem), collections.Callable)
                ):
                    data[mem] = getattr(self, mem)
            pickle.dump(data, f)
        return self

    def _is_pt_in_line_seg(self, pt, pt1, pt2):

        return is_pt_in_line_seg(self._eps_junc, tuple(pt), tuple(pt1), tuple(pt2))
        # param, _ = fit_line([pt1, pt2])
        # _, alphas = project_pts_on_line([pt, pt1, pt2], param)
        # dist = dist_pts_to_line([pt], param)[0]
        # return (dist < self._eps_junc * 2) and (np.min(alphas[1:]) <= alphas[0] <= np.max(alphas[1:]))

    def freeze_junction(self, status=True):
        self._freeze_junction = status
        if status:
            clusters = fclusterdata(self._junctions, self._eps_junc, criterion="distance")
            junc_groups = {}
            for ind_junc, ind_group in enumerate(clusters):
                if ind_group not in junc_groups.keys():
                    junc_groups[ind_group] = []
                junc_groups[ind_group].append(self._junctions[ind_junc])
            if self.verbose:
                print(f"{len(self._junctions) - len(junc_groups)} junctions merged.")
            self._junctions = [np.mean(junc_group, axis=0) for junc_group in junc_groups.values()]

            self._kdtree = KDTree(self._junctions, leaf_size=30)
            dists, inds = self._kdtree.query(self._junctions, k=2)
            repl_inds = np.nonzero(dists.sum(axis=1) < self._eps_junc)[0].tolist()
            # assert len(repl_inds) == 0
        else:
            self._kdtree = None

    def _can_be_extended_by(self, line_seg1, line_seg2):
        pt1 = line_seg1["pt1"]
        pt2 = line_seg1["pt2"]
        _pt1 = line_seg2["pt1"]
        _pt2 = line_seg2["pt2"]
        if self._is_pt_in_line_seg(_pt1, pt1, pt2) or self._is_pt_in_line_seg(_pt2, pt1, pt2):
            arr1 = pt1 - pt2
            arr2 = _pt1 - _pt2
            arr1 /= np.linalg.norm(arr1)
            arr2 /= np.linalg.norm(arr2)
            if np.abs(arr1.dot(arr2)) > np.cos(self._eps_line_seg):
                return True
        return False

    def freeze_line_seg(self, status=True):
        assert self._freeze_junction, "junctions must be frozen first."
        self._freeze_line_seg = status
        if status:
            # remove all junctions that are not in any line segment
            junc_remove = set(list(range(len(self._junctions))))
            for line_seg in self._line_segs:
                junc_remove -= line_seg["junctions"]

            junc_remain = [ind_junc for ind_junc in range(len(self._junctions)) if ind_junc not in junc_remove]
            self._junctions = [junc for ind_junc, junc in enumerate(self._junctions) if ind_junc not in junc_remove]
            map_old_to_new = {old: new for new, old in enumerate(junc_remain)}
            for line_seg in self._line_segs:
                line_seg["junctions"] = set([map_old_to_new[old] for old in line_seg["junctions"]])
            self._junc2line = {}

            # # extend all line segments
            # cnt = 0
            # finished = False
            # while (not finished):
            #     finished = True
            #     for ind_ls1 in range(len(self._line_segs)):
            #         for ind_ls2 in range(ind_ls1 + 1, len(self._line_segs)):
            #             # if ind_ls1 == ind_ls2:
            #             #     continue
            #             ls1 = self._line_segs[ind_ls1]
            #             ls2 = self._line_segs[ind_ls2]
            #             if ls2["junctions"].issubset(ls1["junctions"]):
            #                 continue
            #             if self._can_be_extented_by(ls1, ls2):
            #                 pt1 = ls1["pt1"]
            #                 pt2 = ls1["pt2"]
            #                 param = ls1["param"]
            #                 _pt1 = ls2["pt1"]
            #                 _pt2 = ls2["pt2"]
            #                 _param = ls2["param"]
            #                 P, alphas = project_pts_on_line([pt1, pt2, _pt1, _pt2], param)
            #                 _P, _alphas = project_pts_on_line([pt1, pt2, _pt1, _pt2], _param)
            #                 ind_min, ind_max = np.argmin(alphas), np.argmax(alphas)
            #                 _ind_min, _ind_max = np.argmin(_alphas), np.argmax(_alphas)
            #                 if np.abs(alphas[ind_min] - alphas[ind_max]) < self._eps_junc:  # this should not happen...
            #                     continue
            #                 ls1["pt1"] = P[ind_min]
            #                 ls1["pt2"] = P[ind_max]
            #                 ls2["pt1"] = _P[_ind_min]
            #                 ls2["pt2"] = _P[_ind_max]
            #                 ls1["junctions"] = ls1["junctions"].union(ls2["junctions"])
            #                 ls2["junctions"] = ls2["junctions"].union(ls1["junctions"])
            #                 cnt += 1
            #                 finished = False
            #
            # if self.verbose:
            #     print(f"line segments extend {cnt} times.", flush=True)
            #
            # # merge line segments of which junction set is subset of that of other line segments
            # finished = False
            # merged_inds = []
            # while (not finished):
            #     finished = True
            #     for ind_ls1 in range(len(self._line_segs)):
            #         for ind_ls2 in range(len(self._line_segs)):
            #             if (ind_ls1 == ind_ls2) or (ind_ls2 in merged_inds):
            #                 continue
            #             ls1 = self._line_segs[ind_ls1]
            #             ls2 = self._line_segs[ind_ls2]
            #             if ls2["junctions"].issubset(ls1["junctions"]):
            #                 merged_inds.append(ind_ls2)
            #                 finished = False
            # self._line_segs = [line_seg for ind, line_seg in enumerate(self._line_segs) if ind not in merged_inds]
            # if self.verbose:
            #     print(f"{len(merged_inds)} line segments merged.", flush=True)

            # # refine line segments w.r.t all associated junctions
            # ling_seg_remove = []
            # for ind_line_seg, line_seg in enumerate(self._line_segs):
            #     junc_set = line_seg["junctions"]
            #     if len(junc_set) == 2:  # no need to refine
            #         continue
            #     ind_innr_junctions = [ind_junc for ind_junc in junc_set]
            #     param, res = fit_line([self._junctions[ind_junc] for ind_junc in ind_innr_junctions])
            #     ind_outliers = np.nonzero(res > self._eps_junc * 2)[0]
            #     # if too many outliers, remove this line segment
            #     if (len(junc_set) - len(ind_outliers) < 2) or (len(ind_outliers) / len(junc_set) > 0.3):
            #         # ling_seg_remove.append(ind_line_seg)
            #         continue
            #     # else remove outliers
            #     for ind_outlier in ind_outliers:
            #         # junc_set.remove(ind_innr_junctions[ind_outlier])
            #         pass
            #     line_seg["param"] = param
            #     # find new endpoints
            #     ind_innr_junctions = [ind_junc for ind_junc in junc_set]
            #     P, alphas = project_pts_on_line([self._junctions[ind_junc] for ind_junc in ind_innr_junctions], param)
            #     ind_max, ind_min = np.argmax(alphas), np.argmin(alphas)
            #     line_seg["pt1"] = P[ind_min]
            #     line_seg["pt2"] = P[ind_max]
            #
            # self._line_segs = [line_seg for ind, line_seg in enumerate(self._line_segs) if ind not in ling_seg_remove]
            # if self.verbose:
            #     print(f"{len(ling_seg_remove)} line segments removed.", flush=True)

            # # remove all junctions that are not in any line segment
            # junc_remove = set(list(range(len(self._junctions))))
            # for line_seg in self._line_segs:
            #     junc_remove -= line_seg["junctions"]
            # junc_remain = [ind_junc for ind_junc in range(len(self._junctions)) if ind_junc not in junc_remove]
            # self._junctions = [junc for ind_junc, junc in enumerate(self._junctions) if ind_junc not in junc_remove]
            # map_old_to_new = {old: new for new, old in enumerate(junc_remain)}
            # for line_seg in self._line_segs:
            #     line_seg["junctions"] = set([map_old_to_new[old] for old in line_seg["junctions"]])
            # self._junc2line = {}
            #
            # # build hash table mapping from junction id to line segment
            # for line_seg in self._line_segs:
            #     for ind_junc in line_seg["junctions"]:
            #         if ind_junc not in self._junc2line.keys():
            #             self._junc2line[ind_junc] = []
            #         self._junc2line[ind_junc].append(line_seg)

            # # remove possibly noise line segments
            # line_seg_remove = []
            # for ind_line_seg, line_seg in enumerate(self._line_segs):
            #     junc_set = line_seg["junctions"]
            #     if len(junc_set) == 1: # not possible
            #         line_seg_remove.append(ind_line_seg)
            #     elif len(junc_set) == 2:
            #         ind_junc1, ind_junc2 = list(junc_set)
            #         if len(self._junc2line[ind_junc1]) >= 5 or len(self._junc2line[ind_junc2]) >= 5:
            #             is_delete = True
            #             for neigh_line_seg in self._junc2line[ind_junc1]:
            #                 if neigh_line_seg is line_seg:
            #                     continue
            #                 arr1 = neigh_line_seg["pt1"] - neigh_line_seg["pt2"]
            #                 arr2 = line_seg["pt1"] - line_seg["pt2"]
            #                 if np.abs(arr1.dot(arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))) < np.cos(
            #                         self._eps_line_seg):
            #                     for neigh_line_seg2 in self._junc2line[ind_junc2]:
            #                         if neigh_line_seg2 is line_seg:
            #                             continue
            #                         arr1 = neigh_line_seg2["pt1"] - neigh_line_seg2["pt2"]
            #                         arr2 = line_seg["pt1"] - line_seg["pt2"]
            #                         if np.abs(arr1.dot(arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))) < np.cos(
            #                                 self._eps_line_seg):
            #                             is_delete = False
            #                             break
            #             if is_delete:
            #                 line_seg_remove.append(ind_line_seg)
            #                 break
            #         if len(self._junc2line[ind_junc1]) == 1 or len(self._junc2line[ind_junc1]) == 1:
            #             for neigh_line_seg in self._junc2line[ind_junc1]:
            #                 if neigh_line_seg is line_seg:
            #                     continue
            #                 arr1 = neigh_line_seg["pt1"] - neigh_line_seg["pt2"]
            #                 arr2 = line_seg["pt1"] - line_seg["pt2"]
            #                 if np.abs(arr1.dot(arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))) < np.cos(self._eps_line_seg):
            #                     line_seg_remove.append(ind_line_seg)
            #                     break
            #             for neigh_line_seg in self._junc2line[ind_junc2]:
            #                 if neigh_line_seg is line_seg:
            #                     continue
            #                 arr1 = neigh_line_seg["pt1"] - neigh_line_seg["pt2"]
            #                 arr2 = line_seg["pt1"] - line_seg["pt2"]
            #                 if np.abs(arr1.dot(arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))) < np.cos(self._eps_line_seg):
            #                     line_seg_remove.append(ind_line_seg)
            #                     break
            # self._line_segs = [line_seg for ind, line_seg in enumerate(self._line_segs) if ind not in line_seg_remove]
            # if self.verbose:
            #     print(f"{len(line_seg_remove)} possibly noisy line segments removed.", flush=True)

            # # remove all junctions that are not in any line segment
            # junc_remove = set(list(range(len(self._junctions))))
            # for line_seg in self._line_segs:
            #     junc_remove -= line_seg["junctions"]
            # junc_remain = [ind_junc for ind_junc in range(len(self._junctions)) if ind_junc not in junc_remove]
            # self._junctions = [junc for ind_junc, junc in enumerate(self._junctions) if ind_junc not in junc_remove]
            # map_old_to_new = {old: new for new, old in enumerate(junc_remain)}
            # for line_seg in self._line_segs:
            #     line_seg["junctions"] = set([map_old_to_new[old] for old in line_seg["junctions"]])
            # self._junc2line = {}

            # build hash table mapping from junction id to line segment
            for line_seg in self._line_segs:
                for ind_junc in line_seg["junctions"]:
                    if ind_junc not in self._junc2line.keys():
                        self._junc2line[ind_junc] = []
                    self._junc2line[ind_junc].append(line_seg)
            assert len(self._junc2line) == len(self._junctions)

            # # refine all junctions w.r.t associated line segments
            # junctions_refined = []
            # for ind_junc, junc in enumerate(self._junctions):
            #     line_segs = self._junc2line[ind_junc]
            #     # assert len(line_segs) > 0, "line seg"
            #     if len(line_segs) == 1:
            #         param = line_segs[0]["param"]
            #         refined, _ = project_pts_on_line([junc], param)
            #         junctions_refined.append(refined[0])
            #     elif len(line_segs) >= 2:
            #         refined, _ = find_lines_intersect([line_seg["param"] for line_seg in line_segs])
            #         if np.linalg.norm(refined - junc) > 3 * self._eps_junc:  # maybe something wrong, do nothing...
            #             junctions_refined.append(junc)
            #         junctions_refined.append(refined)
            # self._junctions = junctions_refined
            self._kdtree = KDTree(self._junctions, leaf_size=30)

            # # remove all junctions that are not in any line segment
            # junc_remove = set(list(range(len(self._junctions))))
            # for line_seg in self._line_segs:
            #     junc_remove -= line_seg["junctions"]
            # junc_remain = [ind_junc for ind_junc in range(len(self._junctions)) if ind_junc not in junc_remove]
            # self._junctions = [junc for ind_junc, junc in enumerate(self._junctions) if ind_junc not in junc_remove]
            # map_old_to_new = {old: new for new, old in enumerate(junc_remain)}
            # for line_seg in self._line_segs:
            #     line_seg["junctions"] = set([map_old_to_new[old] for old in line_seg["junctions"]])
            # self._junc2line = {}

            # # build hash table mapping from junction id to line segment
            # for line_seg in self._line_segs:
            #     for ind_junc in line_seg["junctions"]:
            #         if ind_junc not in self._junc2line.keys():
            #             self._junc2line[ind_junc] = []
            #         self._junc2line[ind_junc].append(line_seg)

            # add line segment intersections
            cnt_new = 0
            for ind_ls1 in range(len(self._line_segs)):
                for ind_ls2 in range(ind_ls1 + 1, len(self._line_segs)):
                    ls1, ls2 = self._line_segs[ind_ls1], self._line_segs[ind_ls2]
                    pt11, pt12 = ls1["pt1"], ls1["pt2"]
                    pt21, pt22 = ls2["pt1"], ls2["pt2"]
                    p = np.array(pt11)
                    r = np.array(pt12) - np.array(pt11)
                    q = np.array(pt21)
                    s = np.array(pt22) - np.array(pt21)
                    alpha = np.cross(r, s)
                    if np.isclose(alpha, 0):
                        continue
                    # if np.abs(np.dot(r, s) / np.linalg.norm(r) / np.linalg.norm(s)) > self._eps_line_seg:
                    #     continue
                    beta_t = np.cross(q - p, s)
                    beta_u = np.cross(q - p, r)
                    t = np.mean(beta_t / alpha)
                    u = np.mean(beta_u / alpha)
                    # exact intersect
                    if 0 <= t <= 1 and 0 <= u <= 1:
                        # print("find exact intersect")
                        assert np.allclose(p + t * r, q + u * s, rtol=1.e-3), "intersecting math assertion (exact)"
                        intersect = p + t * r
                        dists, ind = self._kdtree.query([intersect], k=1)
                        if dists[0, 0] < self._eps_junc:
                            ls1["junctions"].add(ind[0, 0])
                            ls2["junctions"].add(ind[0, 0])
                            self._junc2line[ind[0, 0]].append(ls1)
                            self._junc2line[ind[0, 0]].append(ls2)
                        else:
                            ind = len(self._junctions)
                            self._junctions.append(intersect)
                            ls1["junctions"].add(ind)
                            ls2["junctions"].add(ind)
                            self._junc2line[ind] = [ls1, ls2]
                            self._kdtree = KDTree(self._junctions, leaf_size=30)
                            cnt_new += 1

                    # # close to intersect
                    # elif (min(abs(t), abs(t - 1)) * np.linalg.norm(r) < self._eps_junc * 5) and (
                    #         min(abs(u), abs(u - 1)) * np.linalg.norm(s) < self._eps_junc * 5):
                    #     assert np.allclose(p + t * r, q + u * s, rtol=1.e-3), "intersecting math assertion (close)"
                    #     intersect = p + t * r
                    #     dists, ind = self._kdtree.query([intersect], k=1)
                    #     if dists[0, 0] < self._eps_junc:
                    #         ls1["junctions"].add(ind[0, 0])
                    #         ls2["junctions"].add(ind[0, 0])
                    #         self._junc2line[ind[0, 0]].append(ls1)
                    #         self._junc2line[ind[0, 0]].append(ls2)
                    #     else:
                    #         ind = len(self._junctions)
                    #         self._junctions.append(intersect)
                    #         ls1["junctions"].add(ind)
                    #         ls2["junctions"].add(ind)
                    #         self._junc2line[ind] = [ls1, ls2]
                    #         cnt_new += 1
            if self.verbose:
                print(f"found {cnt_new} new intersection junctions", flush=True)

    def add_junction(self, junction):
        self._junctions.append(np.array(junction))

    def add_line_seg(self, junction1, junction2):
        assert self._freeze_junction
        junc1 = np.array(junction1)
        junc2 = np.array(junction2)
        dist1, ind1 = self._kdtree.query([junc1], k=1)
        dist2, ind2 = self._kdtree.query([junc2], k=1)
        if not (dist1[0, 0] < self._eps_junc and dist2[0, 0] < self._eps_junc):
            if self.verbose:
                print(f"warn: invalid line endpoints: {junc1} -> {junc2}, ignored.")
            return
        if ind1[0, 0] == ind2[0, 0]:
            if self.verbose:
                print(f"warn: zero length line segment found ({junc1} -> {junc2}), ignored.")
            return
        self._line_segs.append(dict(
            pt1=junc1,
            pt2=junc2,
            param=fit_line([junc1, junc2])[0],
            junctions=set([ind1[0, 0], ind2[0, 0]])
        ))

    def junctions(self):
        # note: `freeze_junction` is a bound method and is therefore always
        # truthy, so the original `assert self.freeze_junction` was a no-op;
        # the status flag is what must be checked.
        assert self._freeze_junction, "junctions must be frozen first."
        for junc in self._junctions:
            yield junc

    def line_segs(self):
        assert self._freeze_junction, "junctions must be frozen first."
        for line_seg in self._line_segs:
            for ind_junc1, ind_junc2 in combinations(line_seg["junctions"], 2):
                yield self._junctions[ind_junc1], self._junctions[ind_junc2]

    def longest_line_segs(self):
        for line_seg in self._line_segs:
            yield line_seg["pt1"], line_seg["pt2"]

    @property
    def adj_mtx(self):
        mtx = np.zeros((len(self._junctions), len(self._junctions)))
        for line_seg in self._line_segs:
            for ind_junc1, ind_junc2 in combinations(line_seg["junctions"], 2):
                mtx[ind_junc1, ind_junc2] = 1
                mtx[ind_junc2, ind_junc1] = 1

        return mtx

    def line_map(self, size, scale_x=1., scale_y=1., line_width=2.):
        if isinstance(size, tuple):
            lmap = np.zeros(size, dtype=np.uint8)
        else:
            lmap = np.zeros((size, size), dtype=np.uint8)
        for line_seg in self._line_segs:
            for ind_junc1, ind_junc2 in combinations(line_seg["junctions"], 2):
                x1, y1 = self._junctions[ind_junc1]
                x2, y2 = self._junctions[ind_junc2]
                x1, x2 = int(x1 * scale_x + 0.5), int(x2 * scale_x + 0.5)
                y1, y2 = int(y1 * scale_y + 0.5), int(y2 * scale_y + 0.5)
                lmap = cv2.line(lmap, (x1, y1), (x2, y2), 255, int(line_width), cv2.LINE_AA)
        # lmap = cv2.GaussianBlur(lmap, (int(line_width), int(line_width)), 1)
        # lmap[lmap > 1] = 1
        return lmap

    @property
    def num_junctions(self):
        return len(self._junctions)

    @property
    def num_line_segs(self):
        return np.sum(
            [
                len(line_seg["junctions"])
                * (len(line_seg["junctions"]) - 1)
                / 2
                for line_seg in self._line_segs
            ]
        )


if __name__ == "__main__":
    from glob import glob
    from tqdm import trange
    data_root = "/home/ziheng/indoorDist_new"
    img = [os.path.join("train", os.path.basename(f)) for f in glob(os.path.join(data_root, "train", "*.jpg"))]
    max_junc = 0
    for item in trange(len(img)):
        lg = LineGraph().load(os.path.join(data_root, img[item][:-4] + ".lg"))
        max_junc = max(lg.num_junctions, max_junc)

    print(max_junc)

================================================
FILE: data/sist_line.py
================================================
import os
import numpy as np
import torch as th
from torch.utils import data
from data.line_graph import LineGraph
from glob import glob
from PIL import Image
from data.utils import gen_gaussian_map


class SISTLine(data.Dataset):
    def __init__(self, data_root, transforms, phase="train", sigma_junction=3., max_junctions=512):
        self.data_root = data_root
        self.img = [os.path.join(phase, os.path.basename(f)) for f in glob(os.path.join(data_root, phase, "*.jpg"))]
        self.transforms = transforms
        self.phase = phase
        self.max_junctions = max_junctions
        self.sigma_junction = sigma_junction

    def __getitem__(self, item):
        img = Image.open(os.path.join(self.data_root, self.img[item]))
        ori_w, ori_h = img.size
        
        lg = LineGraph().load(os.path.join(self.data_root, self.img[item][:-4] + ".lg"))
        num_junc = lg.num_junctions
        # assert num_junc <= self.max_junctions, f"{(item, num_junc)}"
        if num_junc > self.max_junctions:
            # too many junctions for the fixed-size tensors; skip to the next sample
            return self[(item + 1) % len(self)]
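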
        junc = np.zeros((self.max_junctions, 2))
        # tic = time()
        junc[:num_junc] = np.array([j if np.sum(j) > 0 else j + 1 for j in lg.junctions()])
        # print(f"junc time: {time() - tic:.4f}")

        assert np.sum(junc[:num_junc].sum(axis=1) <= 0) == 0, f"{item}"
        # tic = time()
        adj_mtx = np.zeros((self.max_junctions, self.max_junctions))
        # print(f"mtx time: {time() - tic:.4f}")
        adj_mtx[:num_junc, :num_junc] = lg.adj_mtx

        if self.transforms is not None:
            img, junc = self.transforms(img, junc)

        cur_w, cur_h = img.size

        junc[junc >= img.size[0]] = img.size[0] - 1
        junc[junc < 0] = 0
        # tic = time()
        heatmap = gen_gaussian_map(junc[:num_junc], img.size[:2], self.sigma_junction)
        assert cur_h == cur_w
        line_map = lg.line_map(cur_h, cur_w / ori_w, cur_h / ori_h, line_width=self.sigma_junction)
        # print(f"gaussian time: {time() - tic:.4f}")

        img = np.array(np.asarray(img)[:, :, ::-1])
        img = th.from_numpy(img).permute(2, 0, 1)
        adj_mtx = th.from_numpy(adj_mtx)
        junc = th.from_numpy(junc)
        heatmap = th.from_numpy(heatmap)
        line_map = th.from_numpy(line_map)

        batch = dict(
            image=img.float(),
            adj_mtx=adj_mtx.float(),
            heatmap=heatmap.float(),
            junctions=junc.float(),
            line_map=line_map.float()
        )

        return batch

    def __call__(self, item):
        return self.__getitem__(item)

    def __len__(self):
        return len(self.img)


if __name__ == "__main__":
    from tqdm import trange
    from multiprocessing.pool import Pool
    data = SISTLine("/home/ziheng/indoorDist_new", None, "train")
    # os.makedirs("/home/ziheng/heatmaps")
    pool = Pool(20)
    cnt = 0

    def readnsave(i):
        batch = data[i]
        hm = batch["heatmap"].numpy()
        np.save(f"/home/ziheng/heatmaps/{i}.npy", hm)

    def juncsave(i):
        batch = data[i]
        hm = batch["heatmap"].numpy()
        junc = batch["junctions"].numpy()
        np.save(f"/home/ziheng/heatmaps/{i}.npy", hm)


    readnsave.cnt = 0
    # for i in trange(len(data)):
    pool.map_async(readnsave, range(len(data)))
    pool.close()
    pool.join()



================================================
FILE: data/transforms.py
================================================
from torchvision.transforms import functional as tf
import numpy as np
from PIL import Image
import random
from functools import partial


class Compose(object):
    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, img, pt):
        for t in self.transforms:
            img, pt = t(img, pt)

        return img, pt


class RandomCompose(object):
    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, img, pt):
        # random.shuffle needs a mutable sequence; self.transforms is a tuple
        transforms = list(self.transforms)
        random.shuffle(transforms)
        for t in transforms:
            img, pt = t(img, pt)

        return img, pt


class Resize(object):
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
    def __call__(self, img, pt):
        w, h = img.size
        y_scale = (self.size[0] - 1) / (h - 1)
        x_scale = (self.size[1] - 1) / (w - 1)

        pt_new = np.zeros_like(pt)
        pt_mask = pt.sum(axis=1) > 0
        pt_new[pt_mask] = np.vstack((pt[pt_mask][:, 0] * x_scale, pt[pt_mask][:, 1] * y_scale)).T

        assert not pt_new.sum() == 0

        return img.resize(self.size[::-1], self.interpolation), pt_new


class RandomHorizontalFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, pt):
        if self.p > np.random.rand():
            w, _ = img.size
            img = tf.hflip(img)
            pt_new = np.zeros_like(pt)
            pt_mask = pt.sum(axis=1) > 0
            pt_new[pt_mask] = np.vstack((w - 1 - pt[pt_mask][:, 0], pt[pt_mask][:, 1])).T
            return img, pt_new
        return img, pt


class RandomColorAug(object):
    def __init__(self, factor=0.2):
        self.factor = factor

    def __call__(self, img, pt):
        transforms = [
            tf.adjust_brightness,
            tf.adjust_contrast,
            tf.adjust_saturation
            ]
        random.shuffle(transforms)
        for t in transforms:
            img = t(img, (np.random.rand() - 0.5) * 2 * self.factor + 1)

        return img, pt
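These transforms keep the padded junction array in sync with the image: all-zero rows mark unused slots and are never moved. A numpy-only sketch of the horizontal-flip coordinate rule (x maps to w - 1 - x), with a hypothetical width:

```python
import numpy as np

w = 100  # hypothetical image width
# two real junctions plus one zero-padded (unused) slot
pt = np.array([[10.0, 20.0], [99.0, 5.0], [0.0, 0.0]])

pt_new = np.zeros_like(pt)
mask = pt.sum(axis=1) > 0  # padded rows stay at the origin
pt_new[mask] = np.vstack((w - 1 - pt[mask][:, 0], pt[mask][:, 1])).T
# pt_new → [[89., 20.], [0., 5.], [0., 0.]]
```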

================================================
FILE: data/utils.py
================================================
from numba import jit, float32, int32
import numpy as np    
    

@jit(float32[:, :](float32[:, :], float32[:, :], int32[:, :], int32[:, :], float32), nopython=True, fastmath=True)
def apply_gaussian(accumulate_confid_map, centers, xx, yy, sigma):
    for i in range(len(centers)):
        center = centers[i]
        d2 = (xx - center[0]) ** 2 + (yy - center[1]) ** 2
        exponent = d2 / 2.0 / sigma / sigma
        mask = exponent <= 4.6052
        confid_map = np.exp(-exponent)
        confid_map = np.multiply(mask, confid_map)
        accumulate_confid_map += confid_map
    return accumulate_confid_map

def gen_gaussian_map(centers, shape, sigma):
    centers = np.float32(centers)
    sigma = np.float32(sigma)
    accumulate_confid_map = np.zeros(shape, dtype=np.float32)
    y_range = np.arange(accumulate_confid_map.shape[0], dtype=np.int32)
    x_range = np.arange(accumulate_confid_map.shape[1], dtype=np.int32)
    xx, yy = np.meshgrid(x_range, y_range)

    accumulate_confid_map = apply_gaussian(accumulate_confid_map, centers, xx, yy, sigma)
    accumulate_confid_map[accumulate_confid_map > 1.0] = 1.0
    
    return accumulate_confid_map
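gen_gaussian_map accumulates one truncated Gaussian per junction center. A numba-free numpy sketch of the same computation for a single center, using the same 4.6052 truncation threshold (contributions below exp(-4.6052), about 0.01, are dropped):

```python
import numpy as np

def gaussian_map(center, shape, sigma):
    # shape = (rows, cols); center = (x, y), as in gen_gaussian_map
    yy, xx = np.mgrid[:shape[0], :shape[1]]
    exponent = ((xx - center[0]) ** 2 + (yy - center[1]) ** 2) / (2.0 * sigma ** 2)
    conf = np.exp(-exponent)
    conf[exponent > 4.6052] = 0.0  # truncate negligible contributions
    return np.minimum(conf, 1.0)   # clamp, as the accumulated map is clamped above

m = gaussian_map((5, 5), (11, 11), 1.5)
# m peaks at 1.0 on the center pixel and decays radially
```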


================================================
FILE: data/york_urban.py
================================================
import os
import numpy as np
import torch as th
from torch.utils import data
from data.line_graph import LineGraph
from glob import glob
from PIL import Image
from data.utils import gen_gaussian_map


class YorkUrban(data.Dataset):
    def __init__(self, data_root, transforms, phase="eval", sigma_junction=3., max_junctions=800):
        print(f"{phase}")
        assert phase == "eval"
        self.data_root = data_root
        self.img = [os.path.basename(f) for f in glob(os.path.join(data_root, phase, "*.jpg"))]
        self.transforms = transforms
        self.phase = phase
        self.max_junctions = max_junctions
        self.sigma_junction = sigma_junction

    def __getitem__(self, item):
        img = Image.open(os.path.join(self.data_root, self.phase, self.img[item]))
        ori_w, ori_h = img.size

        lg = LineGraph().load(os.path.join(self.data_root, self.phase, self.img[item][:-4] + ".lg"))
        num_junc = lg.num_junctions
        assert num_junc <= self.max_junctions, f"{(item, num_junc)}"
        junc = np.zeros((self.max_junctions, 2))
        # tic = time()
        junc[:num_junc] = np.array([j if np.sum(j) > 0 else j + 1 for j in lg.junctions()])
        # print(f"junc time: {time() - tic:.4f}")

        assert np.sum(junc[:num_junc].sum(axis=1) <= 0) == 0, f"{item}"
        # tic = time()
        adj_mtx = np.zeros((self.max_junctions, self.max_junctions))
        # print(f"mtx time: {time() - tic:.4f}")
        adj_mtx[:num_junc, :num_junc] = lg.adj_mtx

        if self.transforms is not None:
            img, junc = self.transforms(img, junc)

        cur_w, cur_h = img.size

        junc[junc >= img.size[0]] = img.size[0] - 1
        junc[junc < 0] = 0
        # tic = time()
        heatmap = gen_gaussian_map(junc[:num_junc], img.size[:2], self.sigma_junction)
        assert cur_h == cur_w
        line_map = lg.line_map(cur_h, cur_w / ori_w, cur_h / ori_h, line_width=self.sigma_junction)
        # print(f"gaussian time: {time() - tic:.4f}")

        img = np.array(np.asarray(img)[:, :, ::-1])
        img = th.from_numpy(img).permute(2, 0, 1)
        adj_mtx = th.from_numpy(adj_mtx)
        junc = th.from_numpy(junc)
        heatmap = th.from_numpy(heatmap)
        line_map = th.from_numpy(line_map)

        batch = dict(
            image=img.float(),
            adj_mtx=adj_mtx.float(),
            heatmap=heatmap.float(),
            junctions=junc.float(),
            line_map=line_map.float()
        )

        return batch

    def __call__(self, item):
        return self.__getitem__(item)

    def __len__(self):
        return len(self.img)


# class YorkUrbanTrain(data.Dataset):
#     def __init__(self, data_root, transforms, phase="train", sigma_junction=3., max_junctions=512):
#         assert phase == "train"
#         self.data_root = data_root
#         self.img = [os.path.basename(f) for f in glob(os.path.join(data_root, "*.jpg"))]
#         self.transforms = transforms
#         self.phase = phase
#         self.max_junctions = max_junctions
#         self.sigma_junction = sigma_junction
#
#     def __getitem__(self, item):
#         img = Image.open(os.path.join(self.data_root, self.img[item]))
#         ori_w, ori_h = img.size
#
#         lg = LineGraph().load(os.path.join(self.data_root, self.img[item][:-4] + ".lg"))
#         num_junc = lg.num_junctions
#         # assert num_junc <= self.max_junctions, f"{(item, num_junc)}"
#         junc = np.zeros((max(num_junc, self.max_junctions), 2))
#         # tic = time()
#         junc[:num_junc] = np.array([j if np.sum(j) > 0 else j + 1 for j in lg.junctions()])
#         # print(f"junc time: {time() - tic:.4f}")
#
#         assert np.sum(junc[:num_junc].sum(axis=1) <= 0) == 0, f"{item}"
#         # tic = time()
#         adj_mtx = np.zeros((max(num_junc, self.max_junctions), max(num_junc, self.max_junctions)))
#         # print(f"mtx time: {time() - tic:.4f}")
#         adj_mtx[:num_junc, :num_junc] = lg.adj_mtx
#
#         if self.transforms is not None:
#             img, junc = self.transforms(img, junc)
#
#         cur_w, cur_h = img.size
#
#         junc[junc >= img.size[0]] = img.size[0] - 1
#         junc[junc < 0] = 0
#         # tic = time()
#         heatmap = gen_gaussian_map(junc[:num_junc], img.size[:2], self.sigma_junction)
#         assert cur_h == cur_w
#         line_map = lg.line_map(cur_h, cur_w / ori_w, cur_h / ori_h, line_width=self.sigma_junction)
#         # print(f"gaussian time: {time() - tic:.4f}")
#
#         if num_junc > self.max_junctions:
#             choice_junc = np.random.choice(num_junc, self.max_junctions, replace=False)
#             junc = np.array(junc[choice_junc])
#             adj_mtx = np.array(adj_mtx[choice_junc][:, choice_junc])
#
#         img = np.array(np.asarray(img)[:, :, ::-1])
#         img = th.from_numpy(img).permute(2, 0, 1)
#         adj_mtx = th.from_numpy(adj_mtx)
#         junc = th.from_numpy(junc)
#         heatmap = th.from_numpy(heatmap)
#         line_map = th.from_numpy(line_map)
#
#         batch = dict(
#             image=img.float(),
#             adj_mtx=adj_mtx.float(),
#             heatmap=heatmap.float(),
#             junctions=junc.float(),
#             line_map=line_map.float()
#         )
#
#         return batch
#
#     def __call__(self, item):
#         return self.__getitem__(item)
#
#     def __len__(self):
#         return len(self.img)

================================================
FILE: main.py
================================================
# System libs
import os
import time

# Numerical libs
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
import numpy as np

# Our libs
from data.sist_line import SISTLine
import data.transforms as tf
from models.lsd import LSDModule
from utils import AverageMeter, graph2line, draw_lines, draw_jucntions

# tensorboard
from tensorboardX import SummaryWriter
import torchvision.utils as vutils

import fire


def weight_fn(dist_map, max_dist, mid=0.1, scale=10):
    with torch.no_grad():
        dist_map = dist_map / max_dist
        # the exp ratio below is exactly tanh, rescaled from (-1, 1) into (0, 1)
        weight = torch.tanh(scale * (dist_map - mid)) / 2 + 0.5
        return weight
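weight_fn is a smooth step in normalized distance: the exponential ratio is the definition of tanh, so the weight is tanh(scale * (d/max_dist - mid)) mapped from (-1, 1) into (0, 1), crossing 0.5 exactly where d/max_dist equals mid. A pure-Python scalar sanity check of that shape (no torch needed):

```python
import math

def weight_scalar(dist, max_dist, mid=0.1, scale=10):
    # scalar version of weight_fn: scaled tanh step centered at mid
    d = dist / max_dist
    return math.tanh(scale * (d - mid)) / 2 + 0.5

mid_w = weight_scalar(0.1 * 50, 50)   # d/max_dist == mid → exactly 0.5
near_w = weight_scalar(0, 50)         # below the midpoint → well under 0.5
far_w = weight_scalar(50, 50)         # far above the midpoint → saturates near 1
```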


class LSDTrainer(object):
    def __init__(
            self,
            # exp params
            exp_name="u50_block",
            # arch params
            backbone="resnet50",
            backbone_kwargs={},
            dim_embedding=256,
            feature_spatial_scale=0.25,
            max_junctions=512,
            junction_pooling_threshold=0.2,
            junc_pooling_size=15,
            attention_sigma=1.,
            junction_heatmap_criterion="binary_cross_entropy",
            block_inference_size=64,
            adjacency_matrix_criterion="binary_cross_entropy",
            # data params
            data_root=r"/home/ziheng/indoorDist_new2",
            img_size=416,
            junc_sigma=3.,
            batch_size=2,
            # train params
            gpus=[0,],
            num_workers=5,
            resume_epoch="latest",
            is_train_junc=True,
            is_train_adj=True,
            # vis params
            vis_junc_th=0.3,
            vis_line_th=0.3
    ):
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(c) for c in gpus)

        self.is_cuda = bool(gpus)

        self.model = LSDModule(
            backbone=backbone,
            dim_embedding=dim_embedding,
            backbone_kwargs=backbone_kwargs,
            junction_pooling_threshold=junction_pooling_threshold,
            max_junctions=max_junctions,
            feature_spatial_scale=feature_spatial_scale,
            junction_heatmap_criterion=junction_heatmap_criterion,
            junction_pooling_size=junc_pooling_size,
            attention_sigma=attention_sigma,
            block_inference_size=block_inference_size,
            adjacency_matrix_criterion=adjacency_matrix_criterion,
            weight_fn=weight_fn,
            is_train_adj=is_train_adj,
            is_train_junc=is_train_junc
        )

        self.exp_name = exp_name
        os.makedirs(os.path.join("log", exp_name), exist_ok=True)
        os.makedirs(os.path.join("ckpt", exp_name), exist_ok=True)
        self.writer = SummaryWriter(log_dir=os.path.join("log", exp_name))

        # checkpoints
        self.states = dict(
            last_epoch=-1,
            elapsed_time=0,
            state_dict=None
        )

        if resume_epoch and os.path.isfile(os.path.join("ckpt", exp_name, f"train_states_{resume_epoch}.pth")):
            states = torch.load(
                os.path.join("ckpt", exp_name, f"train_states_{resume_epoch}.pth"))
            print(f"resume training from epoch {states['last_epoch']}")
            self.model.load_state_dict(states["state_dict"])
            self.states.update(states)

        self.train_data = SISTLine(
            data_root=data_root,
            transforms=tf.Compose(
                tf.Resize((img_size, img_size)),
                tf.RandomHorizontalFlip(),
                tf.RandomColorAug()
            ),
            phase="train",
            sigma_junction=junc_sigma,
            max_junctions=max_junctions)
                  
        assert len(self.train_data) > 0, "Wow, there is nothing in your data folder. Please check the --data-root parameter in your train.sh."

        self.train_loader = DataLoader(
            self.train_data,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True
        )

        self.eval_data = SISTLine(
            data_root=data_root,
            transforms=tf.Compose(
                tf.Resize((img_size, img_size)),
            ),
            phase="val",
            sigma_junction=junc_sigma,
            max_junctions=max_junctions)

        self.eval_loader = DataLoader(
            self.eval_data,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        )

        self.vis_junc_th = vis_junc_th
        self.vis_line_th = vis_line_th
        self.block_size = block_inference_size
        self.max_junctions = max_junctions
        self.is_train_junc = is_train_junc
        self.is_train_adj = is_train_adj

    @staticmethod
    def _group_weight(module, lr):
        group_decay = []
        group_no_decay = []
        for m in module.modules():
            if isinstance(m, nn.Linear):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            elif isinstance(m, nn.modules.conv._ConvNd):
                group_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)
            elif isinstance(m, nn.modules.batchnorm._BatchNorm) or isinstance(m, nn.GroupNorm):
                if m.weight is not None:
                    group_no_decay.append(m.weight)
                if m.bias is not None:
                    group_no_decay.append(m.bias)

        assert len(list(
            module.parameters())) == len(group_decay) + len(group_no_decay)
        groups = [
            dict(params=group_decay, lr=lr),
            dict(params=group_no_decay, lr=lr, weight_decay=.0)
        ]
        return groups

    def end(self):
        self.writer.close()
        return "command queue finished."

    def _train_epoch(self):
        net_time = AverageMeter()
        data_time = AverageMeter()
        vis_time = AverageMeter()

        epoch = self.states["last_epoch"]
        data_loader = self.train_loader
        if self.is_cuda:
            self.model = self.model.cuda()
        params = self._group_weight(self.model.backbone, self.lr)
        if self.is_train_junc:
            params += self._group_weight(self.model.junc_infer, self.lr)
        if self.is_train_adj:
            params += self._group_weight(self.model.adj_infer, self.lr)
            params += self._group_weight(self.model.adj_embed, self.lr)
        if self.solver == "Adadelta":
            solver = optim.__dict__[self.solver](params, weight_decay=self.weight_decay)
        else:
            solver = optim.__dict__[self.solver](params, weight_decay=self.weight_decay, momentum=self.momentum)

        # main loop
        torch.set_grad_enabled(True)
        tic = time.time()
        print(f"start training epoch: {epoch}", flush=True)

        if self.is_cuda:
            model = nn.DataParallel(self.model).train()
        else:
            model = self.model.train()

        for i, batch in enumerate(data_loader):
            if self.is_cuda:
                img = batch["image"].cuda()
                heatmap_gt = batch["heatmap"].cuda()
                adj_mtx_gt = batch["adj_mtx"].cuda()
                junctions_gt = batch["junctions"].cuda()
            else:
                img = batch["image"]
                heatmap_gt = batch["heatmap"]
                adj_mtx_gt = batch["adj_mtx"]
                junctions_gt = batch["junctions"]

            # measure elapsed time
            data_time.update(time.time() - tic)
            tic = time.time()

            junc_pred, heatmap_pred, adj_mtx_pred, loss_hm, loss_adj = model(
                img, heatmap_gt, adj_mtx_gt, self.lambda_heatmap, self.lambda_adj, junctions_gt
            )

            model.zero_grad()
            loss_adj = loss_adj.mean()
            loss_hm = loss_hm.mean()
            loss = (loss_hm if self.is_train_junc else 0) + (loss_adj if self.is_train_adj else 0)
            loss.backward()
            solver.step()

            # measure elapsed time
            net_time.update(time.time() - tic)
            tic = time.time()

            # visualize result
            if i % self.vis_line_interval == 0:
                img = img.cpu().numpy()
                heatmap_pred = heatmap_pred.detach().cpu()
                adj_mtx_pred = adj_mtx_pred.detach().cpu().numpy()
                junctions_gt = junctions_gt.cpu().numpy()
                adj_mtx_gt = adj_mtx_gt.cpu().numpy()
                self._vis_train(epoch, i, len(data_loader), img, heatmap_pred, adj_mtx_pred, junctions_gt, adj_mtx_gt)

            vis_heatmap_gt = vutils.make_grid(
                heatmap_gt.view(heatmap_gt.size(0), 1, heatmap_gt.size(1), heatmap_gt.size(2)))
            vis_heatmap_pred = vutils.make_grid(
                heatmap_pred.view(heatmap_gt.size(0), 1, heatmap_gt.size(1), heatmap_gt.size(2)))

            self.writer.add_scalar(self.exp_name + "/" + "train/loss_total",
                                   loss.item(),
                                   epoch * len(data_loader) + i)
            self.writer.add_scalar(self.exp_name + "/" + "train/loss_heatmap",
                                   loss_hm.item() / self.lambda_heatmap if self.lambda_heatmap else 0,
                                   epoch * len(data_loader) + i)
            self.writer.add_scalar(self.exp_name + "/" + "train/loss_adj_mtx",
                                   loss_adj.item() / self.lambda_adj if self.lambda_adj else 0,
                                   epoch * len(data_loader) + i)
            self.writer.add_image(self.exp_name + "/" + "train/heatmap_gt",
                                  vis_heatmap_gt,
                                  epoch * len(data_loader) + i)
            self.writer.add_image(self.exp_name + "/" + "train/heatmap_pred",
                                  vis_heatmap_pred,
                                  epoch * len(data_loader) + i)

            vis_time.update(time.time() - tic)
            info = f"epoch: [{epoch}][{i}/{len(data_loader)}], lr: {self.lr}, " \
                   f"time_total: {net_time.average() + data_time.average() + vis_time.average():.2f}, " \
                   f"time_data: {data_time.average():.2f}, time_net: {net_time.average():.2f}, " \
                   f"time_vis: {vis_time.average():.2f}, " \
                   f"loss: {loss.item():.4f}, " \
                   f"loss_heatmap: {loss_hm.item() / self.lambda_heatmap if self.lambda_heatmap else 0:.4f}, " \
                   f"loss_adj_mtx: {loss_adj.item() / self.lambda_adj if self.lambda_adj else 0:.4f}"
            self.writer.add_text(self.exp_name + "/" + "train/info", info,
                                 epoch * len(data_loader) + i)
            print(info, flush=True)
            # measure elapsed time
            tic = time.time()

    def _vis_train(self, epoch, i, len_loader, img, heatmap, adj_mtx, junctions_gt, adj_mtx_gt):
        junctions_gt = np.int32(junctions_gt)
        lines_gt, scores_gt = graph2line(junctions_gt, adj_mtx_gt)
        vis_line_gt = vutils.make_grid(
            draw_lines(img, lines_gt, scores_gt))
        lines_pred, score_pred = graph2line(junctions_gt, adj_mtx, threshold=self.vis_line_th)
        vis_line_pred = vutils.make_grid(
            draw_lines(img, lines_pred, score_pred))
        junc_score = []
        line_score = []
        for m, juncs in zip(heatmap, junctions_gt):
            juncs = juncs[juncs.sum(axis=1) > 0]
            junc_score += m[juncs[:, 1], juncs[:, 0]].tolist()
        for s in score_pred:
            line_score += s.tolist()

        self.writer.add_image(self.exp_name + "/" + "train/lines_gt",
                              vis_line_gt,
                              epoch * len_loader + i)
        self.writer.add_image(self.exp_name + "/" + "train/lines_pred",
                              vis_line_pred,
                              epoch * len_loader + i)
        self.writer.add_scalar(
            self.exp_name + "/" + "train/mean_junc_score",
            np.mean(junc_score),
            epoch * len_loader + i)
        self.writer.add_scalar(
            self.exp_name + "/" + "train/mean_line_score",
            np.mean(line_score),
            epoch * len_loader + i)

    def _checkpoint(self):
        print('Saving checkpoints...')

        train_states = self.states

        train_states["state_dict"] = self.model.cpu().state_dict()

        torch.save(
            train_states,
            os.path.join("ckpt", self.exp_name,
                         "train_states_latest.pth"))
        torch.save(
            train_states,
            os.path.join("ckpt", self.exp_name,
                         f"train_states_{self.states['last_epoch']}.pth"))

        state = torch.load(os.path.join("ckpt", self.exp_name, "train_states_latest.pth"))
        self.model.load_state_dict(state["state_dict"])

    def train(
            self,
            end_epoch=20,
            solver="SGD",
            lr=1.,
            weight_decay=5e-4,
            momentum=0.9,
            lambda_heatmap=1.,
            lambda_adj=1.,
            vis_line_interval=20,
    ):
        self.vis_line_interval = vis_line_interval
        self.end_epoch = end_epoch
        self.lr = lr
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.lambda_heatmap = lambda_heatmap
        self.lambda_adj = lambda_adj
        self.solver = solver

        start_epoch = self.states["last_epoch"] + 1

        for epoch in range(start_epoch, end_epoch):
            self.states["last_epoch"] = epoch
            self._train_epoch()
            self._checkpoint()

        return self

    def _vis_eval(self, epoch, i, len_loader, img, heatmap, adj_mtx, junctions_pred, junctions_gt, adj_mtx_gt):
        junctions_gt = np.int32(junctions_gt)
        lines_gt, scores_gt = graph2line(junctions_gt, adj_mtx_gt, threshold=self.vis_junc_th)
        vis_line_gt = vutils.make_grid(
            draw_lines(img, lines_gt, scores_gt))
        img_with_junc = draw_jucntions(img, junctions_pred)
        img_with_junc = torch.stack(img_with_junc, dim=0).numpy()[:, ::-1, :, :]
        lines_pred, score_pred = graph2line(junctions_pred, adj_mtx)
        vis_line_pred = vutils.make_grid(
            draw_lines(img_with_junc, lines_pred, score_pred))
        junc_score = []
        line_score = []
        for m, juncs in zip(heatmap, junctions_gt):
            juncs = juncs[juncs.sum(axis=1) > 0]
            junc_score += m[juncs[:, 1], juncs[:, 0]].tolist()
        for s in score_pred:
            line_score += s.tolist()

        junc_pooling = vutils.make_grid(draw_jucntions(heatmap, junctions_pred))

        self.writer.add_image(self.exp_name + "/" + "eval/junction_pooling",
                              junc_pooling,
                              epoch * len_loader + i)

        self.writer.add_image(self.exp_name + "/" + "eval/lines_gt",
                              vis_line_gt,
                              epoch * len_loader + i)
        self.writer.add_image(self.exp_name + "/" + "eval/lines_pred",
                              vis_line_pred,
                              epoch * len_loader + i)
        self.writer.add_scalar(
            self.exp_name + "/" + "eval/mean_junc_score",
            np.mean(junc_score),
            epoch * len_loader + i)
        self.writer.add_scalar(
            self.exp_name + "/" + "eval/mean_line_score",
            np.mean(line_score),
            epoch * len_loader + i)

    def eval(self,
             lambda_heatmap=1.,
             lambda_adj=1.,
             off_line=False,
             epoch=None
             ):

        if not off_line:
            if self.states["last_epoch"] != epoch - 1:
                return self
        else:
            self.lambda_heatmap = lambda_heatmap
            self.lambda_adj = lambda_adj

        net_time = AverageMeter()
        data_time = AverageMeter()
        vis_time = AverageMeter()
        ave_loss = AverageMeter()
        ave_loss_heatmap = AverageMeter()
        ave_loss_adj_mtx = AverageMeter()

        epoch = self.states["last_epoch"]
        data_loader = self.eval_loader

        # main loop
        torch.set_grad_enabled(False)
        tic = time.time()
        print(f"start evaluating epoch: {epoch}", flush=True)

        if self.is_cuda:
            model = nn.DataParallel(self.model.cuda()).train()
        else:
            model = self.model.train()

        for i, batch in enumerate(data_loader):
            if self.is_cuda:
                img = batch["image"].cuda()
                heatmap_gt = batch["heatmap"].cuda()
                adj_mtx_gt = batch["adj_mtx"].cuda()
                junctions_gt = batch["junctions"].cuda()
            else:
                img = batch["image"]
                heatmap_gt = batch["heatmap"]
                adj_mtx_gt = batch["adj_mtx"]
                junctions_gt = batch["junctions"]

            # measure elapsed time
            data_time.update(time.time() - tic)
            tic = time.time()

            junc_pred, heatmap_pred, adj_mtx_pred, loss_hm, loss_adj = model(
                img, heatmap_gt, adj_mtx_gt, self.lambda_heatmap, self.lambda_adj, junctions_gt
            )

            loss_adj = loss_adj.mean()
            loss_hm = loss_hm.mean()
            loss = loss_adj + loss_hm
            ave_loss_adj_mtx.update(loss_adj.item() / self.lambda_adj if self.lambda_adj else 0)
            ave_loss_heatmap.update(loss_hm.item() / self.lambda_heatmap if self.lambda_heatmap else 0)
            ave_loss.update(loss.item())

            # measure elapsed time
            net_time.update(time.time() - tic)
            tic = time.time()

            # visualize eval
            img = img.cpu().numpy()
            heatmap = heatmap_pred.detach().cpu().numpy()
            junctions_pred = junc_pred.detach().cpu().numpy()
            adj_mtx = adj_mtx_pred.detach().cpu().numpy()
            junctions_gt = junctions_gt.cpu().numpy()
            adj_mtx_gt = adj_mtx_gt.cpu().numpy()
            self._vis_eval(epoch, i, len(data_loader), img, heatmap, adj_mtx, junctions_pred, junctions_gt, adj_mtx_gt)

            vis_heatmap_gt = vutils.make_grid(
                heatmap_gt.view(heatmap_gt.size(0), 1, heatmap_gt.size(1), heatmap_gt.size(2)))
            vis_heatmap_pred = vutils.make_grid(
                # use the tensor, not the numpy copy: np.ndarray.view has a different signature
                heatmap_pred.detach().cpu().view(heatmap_gt.size(0), 1, heatmap_gt.size(1), heatmap_gt.size(2)))

            self.writer.add_scalar(self.exp_name + "/" + "eval/loss_total",
                                   loss.item(),
                                   epoch * len(data_loader) + i)
            self.writer.add_scalar(self.exp_name + "/" + "eval/loss_heatmap",
                                   loss_hm.item() / self.lambda_heatmap if self.lambda_heatmap else 0,
                                   epoch * len(data_loader) + i)
            self.writer.add_scalar(self.exp_name + "/" + "eval/loss_adj_mtx",
                                   loss_adj.item() / self.lambda_adj if self.lambda_adj else 0,
                                   epoch * len(data_loader) + i)
            self.writer.add_image(self.exp_name + "/" + "eval/heatmap_gt",
                                  vis_heatmap_gt,
                                  epoch * len(data_loader) + i)
            self.writer.add_image(self.exp_name + "/" + "eval/heatmap_pred",
                                  vis_heatmap_pred,
                                  epoch * len(data_loader) + i)

            vis_time.update(time.time() - tic)
            info = f"epoch: [{epoch}][{i}/{len(data_loader)}], " \
                   f"time_total: {net_time.average() + data_time.average() + vis_time.average():.2f}, " \
                   f"time_data: {data_time.average():.2f}, time_net: {net_time.average():.2f}, " \
                   f"time_vis: {vis_time.average():.2f}, " \
                   f"loss: {loss.item():.4f}, " \
                   f"loss_heatmap: {loss_hm.item() / self.lambda_heatmap if self.lambda_heatmap else 0:.4f}, " \
                   f"loss_adj_mtx: {loss_adj.item() / self.lambda_adj if self.lambda_adj else 0:.4f}"
            if i == len(data_loader) - 1:
                info += f"\n*[{epoch}] " \
                        f"ave_loss: {ave_loss.average():.4f}, " \
                        f"ave_loss_heatmap: {ave_loss_heatmap.average():.4f}, " \
                        f"ave_loss_adj_mtx: {ave_loss_adj_mtx.average():.4f}"

            self.writer.add_text(self.exp_name + "/" + "eval/info", info,
                                 epoch * len(data_loader) + i)
            print(info, flush=True)
            # measure elapsed time
            tic = time.time()

        return self


if __name__ == "__main__":
    fire.Fire(LSDTrainer)
    # trainer = LSDTrainer().train(lr=1.)


================================================
FILE: models/__init__.py
================================================
from . import lsd

================================================
FILE: models/backbone.py
================================================
import os
import sys
import torch
import torch.nn as nn
import math
from models.common import conv3x3, conv3x3_bn_relu, inconv, up, down, outconv, weights_init
from torch.nn import functional as F

try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve


__all__ = ['ResNetU50Backbone', 'UNetBackbone']


model_urls = {
    'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
}


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 128
        super(ResNet, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


class ResnetDilated(nn.Module):
    def __init__(self, orig_resnet, dilate_scale=8):
        super(ResnetDilated, self).__init__()
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(
                partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=2))

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu1 = orig_resnet.relu1
        self.conv2 = orig_resnet.conv2
        self.bn2 = orig_resnet.bn2
        self.relu2 = orig_resnet.relu2
        self.conv3 = orig_resnet.conv3
        self.bn3 = orig_resnet.bn3
        self.relu3 = orig_resnet.relu3
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    @staticmethod
    def _nostride_dilate(m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        conv_out = []

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)

        return conv_out


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
    return model


def load_url(url, model_dir='./pretrained', map_location=None):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        urlretrieve(url, cached_file)
    return torch.load(cached_file, map_location=map_location)


class UPerNet(nn.Module):
    def __init__(self, num_class=150, fc_dim=4096, pool_scales=(1, 2, 3, 6),
                 fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256):
        super(UPerNet, self).__init__()

        # PPM Module
        self.ppm_pooling = []
        self.ppm_conv = []

        for scale in pool_scales:
            self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale))
            self.ppm_conv.append(nn.Sequential(
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm_pooling = nn.ModuleList(self.ppm_pooling)
        self.ppm_conv = nn.ModuleList(self.ppm_conv)
        self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales) * 512, fpn_dim, 1)

        # FPN Module
        self.fpn_in = []
        for fpn_inplane in fpn_inplanes[:-1]:  # skip the top layer
            self.fpn_in.append(nn.Sequential(
                nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(fpn_dim),
                nn.ReLU(inplace=True)
            ))
        self.fpn_in = nn.ModuleList(self.fpn_in)

        self.fpn_out = []
        for i in range(len(fpn_inplanes) - 1):  # skip the top layer
            self.fpn_out.append(nn.Sequential(
                conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
            ))
        self.fpn_out = nn.ModuleList(self.fpn_out)

        self.conv_last = nn.Sequential(
            conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1),
            nn.Conv2d(fpn_dim, num_class, kernel_size=1)
        )

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):
            ppm_out.append(pool_conv(F.interpolate(
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False)))
        ppm_out = torch.cat(ppm_out, 1)
        f = self.ppm_last_conv(ppm_out)

        fpn_feature_list = [f]
        for i in reversed(range(len(conv_out) - 1)):
            conv_x = conv_out[i]
            conv_x = self.fpn_in[i](conv_x)  # lateral branch

            f = F.interpolate(
                f, size=conv_x.size()[2:], mode='bilinear', align_corners=False)  # top-down branch
            f = conv_x + f

            fpn_feature_list.append(self.fpn_out[i](f))

        fpn_feature_list.reverse()  # [P2 - P5]
        output_size = fpn_feature_list[0].size()[2:]
        fusion_list = [fpn_feature_list[0]]
        for i in range(1, len(fpn_feature_list)):
            fusion_list.append(F.interpolate(
                fpn_feature_list[i],
                output_size,
                mode='bilinear', align_corners=False))
        fusion_out = torch.cat(fusion_list, 1)
        x = self.conv_last(fusion_out)

        x = F.interpolate(
            x, size=segSize, mode='bilinear', align_corners=False)

        return x


class ResNetU50Backbone(nn.Module):
    def __init__(self, dim_embedding=256, encoder_weights="", decoder_weights=""):
        super(ResNetU50Backbone, self).__init__()
        self.encoder = self.build_encoder(encoder_weights)
        self.decoder = self.build_decoder(decoder_weights, dim_embedding)

    def forward(self, img):
        out = self.encoder(img)
        out = self.decoder(out, segSize=(img.size(2) // 4, img.size(3) // 4))

        return out

    @staticmethod
    def build_encoder(weights=''):
        pretrained = True if len(weights) == 0 else False
        orig_resnet = resnet50(pretrained=pretrained)
        net_encoder = ResnetDilated(orig_resnet,
                                    dilate_scale=8)
        if len(weights) > 0 and os.path.isfile(weights):
            print(f'Loading weights for net_encoder @ {weights}')
            net_encoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage),
                strict=False
            )
        return net_encoder

    @staticmethod
    def build_decoder(weights='', dim_embedding=64):
        net_decoder = UPerNet(
            num_class=150,
            fc_dim=2048,
            fpn_dim=512
        )

        net_decoder.conv_last[1] = conv3x3_bn_relu(net_decoder.conv_last[1].in_channels, dim_embedding, 1)
        net_decoder.apply(weights_init)
        if len(weights) > 0 and os.path.isfile(weights):
            print(f'Loading weights for net_decoder @ {weights}')
            net_decoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage),
                strict=False
            )

        return net_decoder


class UNetBackbone(nn.Module):
    def __init__(self, dim_embedding=256, n_downs=5, n_ups=3, weights=""):
        super(UNetBackbone, self).__init__()
        assert n_downs > 0 and 0 < n_ups <= n_downs
        self.n_downs = n_downs
        self.n_ups = n_ups
        self.inc = inconv(3, 64)
        down_channels = []
        for i in range(n_downs):
            down_channel = 64 * 2**min(n_downs - 1, i + 1)
            self.add_module(
                f"down{i+1}",
                down(64 * 2**i, down_channel)
            )
            down_channels.append(down_channel)
        for i in range(n_ups):
            down_channels.pop()
            self.add_module(
                f"up{i+1}",
                up(64 * 2**(n_downs - i), 64 * 2**max(0, n_downs - i - 2))
            )

        self.outc = outconv(64 * 2**max(0, n_downs - n_ups - 1) + sum(down_channels) // 2, dim_embedding)

        self.apply(weights_init)

        if len(weights) > 0 and os.path.isfile(weights):
            print(f'Loading weights for unet @ {weights}')
            self.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage),
                strict=False
            )

    def forward(self, x):
        downs = [self.inc(x)]
        for i in range(self.n_downs):
            downs.append(getattr(self, f"down{i+1}")(downs[i]))
        out = downs.pop()
        for i in range(self.n_ups):
            out = getattr(self, f"up{i+1}")(out, downs.pop())
        for i in range(len(downs)):
            downs[i] = F.interpolate(downs[i], size=out.shape[2:], mode="bilinear", align_corners=False)
        downs.append(out)
        out = self.outc(torch.cat(downs, dim=1))

        return out
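The channel bookkeeping in `__init__` and `forward` has to agree: after `n_ups` up blocks, the remaining skip features plus the decoder output are concatenated into `outc`. A small standalone check of that arithmetic for the default `n_downs=5, n_ups=3` (the helper name is illustrative, not part of the repository):

```python
def unet_outc_in_channels(n_downs, n_ups):
    # mirrors the bookkeeping in UNetBackbone.__init__
    down_channels = [64 * 2 ** min(n_downs - 1, i + 1) for i in range(n_downs)]
    for _ in range(n_ups):
        down_channels.pop()
    return 64 * 2 ** max(0, n_downs - n_ups - 1) + sum(down_channels) // 2

# default configuration: the remaining skip features carry 64 + 128 channels,
# the last up block outputs 128, so outc sees 320 input channels
channels = unet_outc_in_channels(5, 3)
```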


================================================
FILE: models/common.py
================================================
from scipy.ndimage import maximum_filter, generate_binary_structure, binary_erosion
from scipy.cluster.hierarchy import fclusterdata
import numpy as np
import torch.nn as nn
import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from collections.abc import Sized, Iterable


class LMFPeakFinder(object):
    """
    shamelessly borrow from https://stackoverflow.com/a/3689710
    Takes an image and detect the peaks using the local maximum filter.
    Returns a boolean mask of the peaks (i.e. 1 when
    the pixel's value is the neighborhood maximum, 0 otherwise)
    """

    def __init__(self, min_dist=5., min_th=0.3):
        self.min_dist = min_dist
        self.min_th = min_th

    def detect(self, image):
        # define an 8-connected neighborhood
        neighborhood = generate_binary_structure(2, 2)

        # apply the local maximum filter; all pixel of maximal value
        # in their neighborhood are set to 1
        local_max = maximum_filter(image, footprint=neighborhood) == image
        # local_max is a mask that contains the peaks we are
        # looking for, but also the background.
        # In order to isolate the peaks we must remove the background from the mask.

        # we create the mask of the background
        background = (image < self.min_th)

        # a little technicality: we must erode the background in order to
        # successfully subtract it from local_max, otherwise a line will
        # appear along the background border (an artifact of the local maximum filter)
        eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)

        # we obtain the final mask, containing only peaks,
        # by removing the background from the local_max mask (xor operation)
        detected_peaks = local_max ^ eroded_background

        detected_peaks[image < self.min_th] = False
        peaks = np.array(np.nonzero(detected_peaks)).T

        if len(peaks) == 0:
            return peaks, np.array([])

        # nms
        if len(peaks) == 1:
            clusters = [0]
        else:
            clusters = fclusterdata(peaks, self.min_dist, criterion="distance")
        peak_groups = {}
        for ind_junc, ind_group in enumerate(clusters):
            if ind_group not in peak_groups:
                peak_groups[ind_group] = []
            # append every peak to its cluster (not only the first one),
            # so the argmax below truly picks the strongest peak per cluster
            peak_groups[ind_group].append(peaks[ind_junc])
        peaks_nms = []
        peaks_score = []
        for peak_group in peak_groups.values():
            values = [image[y, x] for y, x in peak_group]
            ind_max = np.argmax(values)
            peaks_nms.append(peak_group[int(ind_max)])
            peaks_score.append(values[int(ind_max)])

        return np.float32(np.array(peaks_nms)), np.float32(np.array(peaks_score))
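For reference, the local-maximum-filter trick above can be exercised standalone; the helper below (hypothetical, not part of the repository) mirrors the masking steps in `detect` without the clustering/NMS stage:

```python
import numpy as np
from scipy.ndimage import maximum_filter, generate_binary_structure, binary_erosion

def find_peaks(image, min_th=0.3):
    # 8-connected neighborhood, as in LMFPeakFinder.detect
    neighborhood = generate_binary_structure(2, 2)
    # pixels that equal the maximum of their neighborhood
    local_max = maximum_filter(image, footprint=neighborhood) == image
    # erode the background so its border does not survive the xor below
    background = image < min_th
    eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
    detected = local_max ^ eroded_background
    detected[image < min_th] = False
    return np.array(np.nonzero(detected)).T  # (N, 2) array of (y, x)

img = np.zeros((32, 32), dtype=np.float32)
img[5, 5], img[20, 25] = 1.0, 0.8
peaks = find_peaks(img)  # two isolated peaks -> [[5, 5], [20, 25]]
```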


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    return nn.Sequential(
        conv3x3(in_planes, out_planes, stride),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )


class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        #  it would be a nice idea if the upsampling could be learned too,
        #  but my machine does not have enough memory to handle all those weights
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # dim 2 is height, dim 3 is width; F.pad pads as (left, right, top, bottom)
        diffY = x1.size()[2] - x2.size()[2]
        diffX = x1.size()[3] - x2.size()[3]
        x2 = F.pad(x2, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight.data)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.fill_(1.)
        m.bias.data.fill_(1e-4)


def roi_pooling(input, rois, size=(7, 7), spatial_scale=1.0):
    assert (rois.dim() == 2)
    assert (rois.size(1) == 5)
    output = []
    rois = rois.data.float()
    num_rois = rois.size(0)

    rois[:, 1:].mul_(spatial_scale)
    rois = rois.long()
    for i in range(num_rois):
        roi = rois[i]
        im_idx = roi[0]
        if roi[1] >= input.size(3) or roi[2] >= input.size(2) or roi[1] < 0 or roi[2] < 0:
            # print(f"Runtime Warning: roi top left corner out of range: {roi}", file=sys.stderr)
            roi[1] = torch.clamp(roi[1], 0, input.size(3) - 1)
            roi[2] = torch.clamp(roi[2], 0, input.size(2) - 1)
        if roi[3] >= input.size(3) or roi[4] >= input.size(2) or roi[3] < 0 or roi[4] < 0:
            # print(f"Runtime Warning: roi bottom right corner out of range: {roi}", file=sys.stderr)
            roi[3] = torch.clamp(roi[3], 0, input.size(3) - 1)
            roi[4] = torch.clamp(roi[4], 0, input.size(2) - 1)
        if (roi[3:5] - roi[1:3] < 0).any():
            # print(f"Runtime Warning: invalid roi: {roi}", file=sys.stderr)
            im = input.new_full((1, input.size(1), 1, 1), 0)
        else:
            im = input.narrow(0, im_idx, 1)[..., roi[2]:(roi[4] + 1), roi[1]:(roi[3] + 1)]
        output.append(F.adaptive_max_pool2d(im, size))

    return torch.cat(output, 0)


class GradAccumulatorFunction(Function):
    @staticmethod
    def forward(ctx, input, accumulated_grad=None, mode="release"):
        ctx.accumulated_grad = accumulated_grad
        ctx.mode = mode
        return input

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        accumulated_grad = ctx.accumulated_grad
        ctx.accumulated_grad = None
        if ctx.mode == "accumulate":
            accumulated_grad.add_(grad_output)
            return torch.zeros_like(grad_output), None, None
        elif ctx.mode == "release":
            if accumulated_grad is not None:
                accumulated_grad.add_(grad_output)
            else:
                accumulated_grad = grad_output
            grad_output = accumulated_grad
            return grad_output, None, None
        else:
            raise ValueError(f"invalid mode {ctx.mode}")


class GradAccumulator(nn.Module):
    """
    Helper module that accumulates the gradient of a given tensor w.r.t. the outputs of several criteria.
    Typically used when a feature extractor is followed by several submodules that can be evaluated
    independently: only the last executed submodule keeps its computation graph alive, while gradients
    produced by earlier submodules are accumulated into a buffer, saving the GPU memory that would
    otherwise hold their intermediate activations. It can also be used to increase the effective batch
    size at little extra memory cost.
    """
    def __init__(self, criterion_fns, submodules, collect_fn=None, reduce_method="mean"):
        super(GradAccumulator, self).__init__()
        assert isinstance(submodules, (Sized, Iterable)), "invalid submodules"
        if isinstance(criterion_fns, (Sized, Iterable)):
            assert len(submodules) == len(criterion_fns)
            assert all([isinstance(submodule, nn.Module) for submodule in submodules])
            assert all([isinstance(criterion_fn, nn.Module) for criterion_fn in criterion_fns])
        elif isinstance(criterion_fns, nn.Module):
            criterion_fns = [criterion_fns for _ in range(len(submodules))]
        elif criterion_fns is None:
            criterion_fns = [criterion_fns for _ in range(len(submodules))]
        else:
            raise ValueError("invalid criterion function")
        assert reduce_method in ("mean", "sum", None)

        self.submodules = nn.ModuleList(submodules)
        self.criterion_fns = nn.ModuleList(criterion_fns)
        self.method = reduce_method
        self.grad_buffer = None
        self.func = GradAccumulatorFunction.apply
        self.collect_fn = collect_fn

    def forward(self, tensor):
        outputs = []
        losses = tensor.new_full((1,), 0)
        self.grad_buffer = None
        for i, (submodule, criterion) in enumerate(zip(self.submodules, self.criterion_fns)):
            mode = "accumulate" if i < len(self.submodules) - 1 else "release"
            if self.grad_buffer is None:
                self.grad_buffer = torch.zeros_like(tensor)
            if mode == "accumulate":
                output = tensor.detach()
                output.requires_grad = True
            else:
                output = tensor
            output = self.func(
                output,
                self.grad_buffer,
                mode,
            )
            if isinstance(output, tuple):
                output = submodule(*output)
            else:
                output = submodule(output)
            if criterion is not None:
                loss = criterion(output)
                if self.method == "mean":
                    loss = loss / len(self.submodules)

                if mode == "accumulate" and torch.is_grad_enabled():
                    loss.backward()
                    loss = loss.detach()

                output = output.detach()
                losses += loss
            else:
                assert not output.requires_grad, "criterion must be specified to calculate output gradient"

            outputs.append(output)

        if self.collect_fn is not None:
            with torch.no_grad():
                outputs = self.collect_fn(outputs)

        return outputs, losses
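A minimal sketch of the accumulation idea (assuming a recent PyTorch; the `trunk_out`/`heads` names are illustrative, not from this repository): each head backpropagates separately into a detached copy of the shared feature, and the buffered gradient is pushed through the trunk once at the end:

```python
import torch

torch.manual_seed(0)
trunk_out = torch.randn(4, 8, requires_grad=True)  # shared feature
heads = [torch.nn.Linear(8, 1) for _ in range(3)]  # independent submodules

grad_buffer = torch.zeros_like(trunk_out)
for head in heads:
    feat = trunk_out.detach().requires_grad_(True)  # cut the graph per head
    loss = head(feat).mean() / len(heads)           # "mean" reduction
    loss.backward()                                 # frees this head's graph
    grad_buffer += feat.grad                        # accumulate dL/dfeat

trunk_out.backward(grad_buffer)  # one pass through the trunk with summed grads
```

The result matches backpropagating all heads jointly, but at most one head's activations are resident at a time.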


================================================
FILE: models/graph.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.common import conv3x3_bn_relu, LMFPeakFinder, weights_init


class JunctionInference(nn.Module):
    def __init__(self, dim_embedding, pooling_threshold=0.2, max_junctions=512, spatial_scale=0.25, verbose=False):
        super(JunctionInference, self).__init__()
        self.dim_embedding = dim_embedding
        self.pool_th = pooling_threshold
        self.max_juncs = max_junctions
        self.map_infer = nn.Sequential(
            conv3x3_bn_relu(dim_embedding, dim_embedding // 2, 1),
            conv3x3_bn_relu(dim_embedding // 2, dim_embedding // 2, 1),
            nn.Conv2d(dim_embedding // 2, 1, 1),
            nn.Sigmoid()
        )
        self.verbose = verbose
        self.scale = spatial_scale
        self.map_infer.apply(weights_init)

    def forward(self, feat):
        bs, ch, h, w = feat.size()
        junc_map = self.map_infer(feat)
        junc_map = nn.functional.interpolate(
            junc_map,
            scale_factor=1. / self.scale,
            mode="bilinear",
            align_corners=False
        )
        junc_coord = []
        for b in range(bs):
            peak_finder = LMFPeakFinder(min_th=self.pool_th)
            coord, score = peak_finder.detect(junc_map[b, 0].data.cpu().numpy())
            if self.verbose:
                print(f"found {len(coord)} junctions.", flush=True)
            if coord is None or len(coord) == 0:
                continue
            junc_score = torch.from_numpy(score).to(feat)
            _, ind = torch.sort(junc_score, descending=True)
            ind = ind.cpu() 
            coord = coord[ind[:self.max_juncs]]
            coord = coord.reshape((-1, 2))
            y, x = coord[:, 0], coord[:, 1]
            y = torch.from_numpy(y).to(feat)
            x = torch.from_numpy(x).to(feat)
            assert (x >= 0).all() and (x < junc_map.size(3)).all() and (y >= 0).all() and (y < junc_map.size(2)).all()
            junc_coord.append(
                torch.cat([feat.new_full((len(x), 1), b), x.view(-1, 1), y.view(-1, 1)], dim=1)
            )
        if len(junc_coord) > 0:
            junc_coord = torch.cat(junc_coord, dim=0)
        else:
            junc_coord = feat.new_full((1, 3), 0.)

        return junc_map.squeeze(1), junc_coord


class LinePooling(nn.Module):
    def __init__(self, align_size=256, spatial_scale=0.25):
        super(LinePooling, self).__init__()
        self.align_size = align_size
        assert isinstance(self.align_size, int)
        self.scale = spatial_scale

    def forward(self, feat, coord_st, coord_ed):
        _, ch, h, w = feat.size()
        num_st, num_ed = coord_st.size(0), coord_ed.size(0)
        assert coord_st.size(1) == 3 and coord_ed.size(1) == 3
        assert (coord_st[:, 0] == coord_st[0, 0]).all() and (coord_ed[:, 0] == coord_st[0, 0]).all()
        bs = coord_st[0, 0].item()
        # construct bounding boxes from junction points
        with torch.no_grad():
            coord_st = coord_st[:, 1:] * self.scale
            coord_ed = coord_ed[:, 1:] * self.scale
            coord_st = coord_st.unsqueeze(1).expand(num_st, num_ed, 2)
            coord_ed = coord_ed.unsqueeze(0).expand(num_st, num_ed, 2)
            arr_st2ed = coord_ed - coord_st
            sample_grid = torch.linspace(0, 1, steps=self.align_size).to(feat).view(1, 1, self.align_size).expand(num_st, num_ed, self.align_size)
            sample_grid = torch.einsum("ijd,ijs->ijsd", (arr_st2ed, sample_grid)) + coord_st.view(num_st, num_ed, 1, 2).expand(num_st, num_ed, self.align_size, 2)
            sample_grid = sample_grid.view(num_st, num_ed, self.align_size, 2)
            sample_grid[..., 0] = sample_grid[..., 0] / (w - 1) * 2 - 1
            sample_grid[..., 1] = sample_grid[..., 1] / (h - 1) * 2 - 1

        # the (w - 1)-based normalization above follows the align_corners=True convention
        output = F.grid_sample(
            feat[int(bs)].view(1, ch, h, w).expand(num_st, ch, h, w),
            sample_grid, align_corners=True
        )
        assert output.size() == (num_st, ch, num_ed, self.align_size)
        output = output.permute(0, 2, 1, 3).contiguous()

        return output
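The grid construction above maps evenly spaced points along each candidate segment into the `[-1, 1]` coordinate system expected by `F.grid_sample`. A standalone numpy sketch of the same arithmetic for a single segment (the function name is illustrative):

```python
import numpy as np

def line_sample_grid(st, ed, align_size, w, h):
    # `align_size` evenly spaced (x, y) samples from st to ed
    t = np.linspace(0.0, 1.0, align_size)[:, None]       # (S, 1)
    pts = st[None, :] + t * (ed[None, :] - st[None, :])  # (S, 2)
    # normalize pixel coords to [-1, 1] (align_corners=True convention)
    grid = np.empty_like(pts)
    grid[:, 0] = pts[:, 0] / (w - 1) * 2 - 1
    grid[:, 1] = pts[:, 1] / (h - 1) * 2 - 1
    return grid

g = line_sample_grid(np.array([0.0, 0.0]), np.array([63.0, 31.0]), 5, 64, 32)
# endpoints land exactly on the grid corners (-1, -1) and (1, 1)
```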


class AdjacencyMatrixInference(nn.Module):
    def __init__(self, dim_embedding=256, align_size=256):
        super(AdjacencyMatrixInference, self).__init__()
        self.dim_embedding = dim_embedding
        self.align_size = align_size
        self.dblock = nn.Sequential(
            nn.Conv1d(dim_embedding, dim_embedding, 8, 4, 2, bias=False),
            nn.GroupNorm(32, dim_embedding),
            nn.ReLU(inplace=True),
            nn.Conv1d(dim_embedding, dim_embedding, 8, 4, 2, bias=False),
            nn.GroupNorm(32, dim_embedding),
            nn.ReLU(inplace=True),
            nn.Conv1d(dim_embedding, dim_embedding, 8, 4, 2, bias=False),
            nn.GroupNorm(32, dim_embedding),
            nn.ReLU(inplace=True)
        )
        self.connectivity_inference = nn.Sequential(
            nn.Conv1d(dim_embedding, 1, 1, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, line_feat):
        num_st, num_ed, c, s = line_feat.size()
        output_st2ed = line_feat.view(num_st * num_ed, c, s)
        output_ed2st = torch.flip(output_st2ed, (2, ))
        output_st2ed = self.dblock(output_st2ed)
        output_ed2st = self.dblock(output_ed2st)
        adjacency_matrix1 = self.connectivity_inference(output_st2ed).view(num_st, num_ed)
        adjacency_matrix2 = self.connectivity_inference(output_ed2st).view(num_st, num_ed)

        return torch.min(adjacency_matrix1, adjacency_matrix2)
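The three strided `Conv1d` layers in `dblock` shrink the pooled line feature from `align_size` sample points down to a short sequence before the 1x1 connectivity head. A minimal sketch of that length arithmetic (pure Python, using the standard Conv1d output-length formula and the class default `align_size=256`; the numbers are illustrative, not taken from the repo):

```python
def conv1d_out_len(length, kernel=8, stride=4, padding=2):
    """Standard Conv1d output-length formula (dilation 1)."""
    return (length + 2 * padding - kernel) // stride + 1

# dblock applies three Conv1d(kernel=8, stride=4, padding=2) layers in a row.
length = 256  # class default align_size
trace = [length]
for _ in range(3):
    length = conv1d_out_len(length)
    trace.append(length)

print(trace)  # [256, 64, 16, 4] -- each layer divides the sequence length by 4
```

Note that a much smaller `align_size` (e.g. the 15 used for junction pooling) would collapse to length 0 under the same three layers, so the sampling resolution and the `dblock` strides have to be chosen together.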


================================================
FILE: models/lsd.py
================================================
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import models.graph
import models.backbone
import models.common
import numpy as np


class LSDDataLayer(nn.Module):
    def __init__(self, mean=None, std=None):
        super(LSDDataLayer, self).__init__()
        self.std = [1., 1., 1.] if std is None else std
        self.mean = [102.9801, 115.9465, 122.7717] if mean is None else mean

    def forward(self, img):
        assert img.size(1) == 3
        for ch in range(3):
            img[:, ch, :, :] = (img[:, ch, :, :] - self.mean[ch]) / self.std[ch]

        return img


# noinspection PyTypeChecker
class BinaryFocalLoss(nn.Module):
    def __init__(self, gamma=2., alpha=0.25, size_average=True):
        super(BinaryFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target, weight=None):
        if weight is not None:
            assert weight.size() == input.size(), f"weight size: {weight.size()}, input size: {input.size()}"
            assert (weight >= 0).all() and (weight <= 1).all(), f"weight max: {weight.max()}, min: {weight.min()}"
        input = input.clamp(1.e-6, 1. - 1.e-6)
        if weight is None:
            loss = th.sum(
                - self.alpha * target * ((1 - input) ** self.gamma) * th.log(input)
                - (1 - self.alpha) * (1 - target) * (input ** self.gamma) * th.log(1 - input))
        else:
            loss = th.sum(
                (- self.alpha * target * ((1 - input) ** self.gamma) * th.log(input)
                 - (1 - self.alpha) * (1 - target) * (input ** self.gamma) * th.log(1 - input)) * weight
            )
        if self.size_average:
            loss /= input.nelement()
        return loss
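For a single prediction-target pair, the expression inside `th.sum` above is the standard binary focal loss; with `gamma=0` the modulating factors vanish and it collapses to alpha-weighted binary cross-entropy. A scalar sketch of that identity (pure Python, mirroring the module's clamping; the sample values are hypothetical):

```python
import math

def binary_focal(p, t, gamma=2.0, alpha=0.25):
    """Per-element binary focal loss, matching the tensor expression above."""
    p = min(max(p, 1e-6), 1.0 - 1e-6)  # same clamp as the module
    return (-alpha * t * (1 - p) ** gamma * math.log(p)
            - (1 - alpha) * (1 - t) * p ** gamma * math.log(1 - p))

p, t = 0.8, 1.0
# gamma=0: reduces to alpha-weighted BCE.
bce = -0.25 * math.log(0.8)
assert abs(binary_focal(p, t, gamma=0.0) - bce) < 1e-12
# gamma=2: a confident correct prediction is strongly down-weighted.
assert binary_focal(p, t) < bce
```

This down-weighting of easy examples is why the focal variant is offered for the sparse junction heatmap and adjacency targets, where negatives vastly outnumber positives.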


class BlockAdjacencyMatrixInference(nn.Module):
    def __init__(self,
                 line_pool_module, adj_infer_module,
                 current_batch_id, junc_st_st, junc_st_len, junc_ed_st, junc_ed_len, junc_pred
                 ):
        super(BlockAdjacencyMatrixInference, self).__init__()
        self.line_pool = line_pool_module
        self.adj_infer = adj_infer_module
        self.b = current_batch_id
        self.st_st = junc_st_st
        self.st_len = junc_st_len
        self.ed_st = junc_ed_st
        self.ed_len = junc_ed_len
        self.juncs = junc_pred

    def forward(self, feat):
        junc_st = self.juncs.narrow(0, self.st_st, self.st_len)
        junc_ed = self.juncs.narrow(0, self.ed_st, self.ed_len)
        assert (junc_st[:, 0] == self.b).all() and (junc_ed[:, 0] == self.b).all(), f"{self.b}\n{junc_st[:, 0]}\n{junc_ed[:, 0]}"
        line_feat = self.line_pool(feat, junc_st, junc_ed)
        block_adj_matrix = self.adj_infer(line_feat)

        return block_adj_matrix


class BlockAdjacencyMatrixInferenceCriterion(nn.Module):
    def __init__(self, adj_matrix_crit, adj_matrix_gt, adj_matrix_loss_lambda,
                 current_batch_id, mtx_st_st, mtx_st_len, mtx_ed_st, mtx_ed_len,
                 junc_padded, img_size, line_seg_length_weight_fn
                 ):
        super(BlockAdjacencyMatrixInferenceCriterion, self).__init__()
        self.adj_crit = adj_matrix_crit
        self.adj_gt = adj_matrix_gt
        self.loss_lambda = adj_matrix_loss_lambda
        self.b = current_batch_id
        self.st_st = mtx_st_st
        self.st_len = mtx_st_len
        self.ed_st = mtx_ed_st
        self.ed_len = mtx_ed_len
        self.junc = junc_padded
        self.img_size = img_size
        self.weight = line_seg_length_weight_fn

    def forward(self, block_adj_matrix):
        block_adj_matrix_gt = self.adj_gt[self.b, self.st_st:self.st_st+self.st_len, self.ed_st:self.ed_st+self.ed_len]
        if self.junc is not None:
            junc_st = self.junc[self.b, self.st_st:self.st_st + self.st_len].view(self.st_len, 1, 2).expand(self.st_len,
                                                                                                            self.ed_len,
                                                                                                            2)
            junc_ed = self.junc[self.b, self.ed_st:self.ed_st + self.ed_len].view(1, self.ed_len, 2).expand(self.st_len,
                                                                                                            self.ed_len,
                                                                                                            2)
            line_len = (junc_ed - junc_st).norm(dim=2)
            return self.loss_lambda * self.adj_crit(block_adj_matrix, block_adj_matrix_gt, weight=self.weight(line_len, self.img_size * 1.4143))  # 1.4143 ~ sqrt(2): normalize line length by the image diagonal
        else:
            return self.loss_lambda * self.adj_crit(block_adj_matrix, block_adj_matrix_gt)


class LSDModule(nn.Module):
    def __init__(
            self,
            # backbone parameters
            backbone="unet",
            dim_embedding=256,
            backbone_kwargs={},
            # junction inference parameters
            junction_pooling_threshold=0.2,
            max_junctions=512,
            feature_spatial_scale=0.25,
            junction_heatmap_criterion="binary_cross_entropy",
            # junction pooling parameters
            junction_pooling_size=15.,
            # directional attention parameters
            attention_sigma=1.,
            # adjacency matrix inference parameters
            block_inference_size=64,
            adjacency_matrix_criterion="binary_cross_entropy",
            weight_fn=None,
            is_train_junc=True,
            is_train_adj=True,
            enable_junc_infer=True,
            enable_adj_infer=True,
            verbose=True,
            **kwargs
    ):
        super(LSDModule, self).__init__()
        if backbone == "unet":
            backbone_kwargs.update({
                "n_downs": 5,
                "n_ups": 3
            })
            self.backbone = models.backbone.UNetBackbone(
                dim_embedding=dim_embedding,
                **backbone_kwargs
            )
        elif backbone == "resnet50":
            self.backbone = models.backbone.ResNetU50Backbone(
                dim_embedding=dim_embedding,
                **backbone_kwargs
            )
        else:
            raise ValueError(f"invalid backbone: {backbone}")

        self.prep_data = LSDDataLayer()

        self.junc_infer = models.graph.JunctionInference(
            dim_embedding=dim_embedding,
            pooling_threshold=junction_pooling_threshold,
            max_junctions=max_junctions,
            spatial_scale=feature_spatial_scale,
            verbose=verbose
        )

        self.line_pool = models.graph.LinePooling(
            align_size=junction_pooling_size,
            spatial_scale=feature_spatial_scale
        )

        self.adj_infer = models.graph.AdjacencyMatrixInference(
            dim_embedding=dim_embedding,
            align_size=junction_pooling_size,
        )

        self.adj_embed = nn.Sequential(
            models.common.double_conv(dim_embedding, dim_embedding),
        )

        self.adj_embed.apply(models.common.weights_init)
        if junction_heatmap_criterion == "focal":
            self.hm_crit = BinaryFocalLoss()
        else:
            self.hm_crit = getattr(F, junction_heatmap_criterion)
        if adjacency_matrix_criterion == "focal":
            self.adj_crit = BinaryFocalLoss()
        else:
            self.adj_crit = getattr(F, adjacency_matrix_criterion)
        self.adj_block_size = block_inference_size
        self.max_junctions = max_junctions
        self.weight_fn = weight_fn
        self.is_train_junc = is_train_junc
        self.is_train_adj = is_train_adj
        self.enable_junc_infer = enable_junc_infer
        self.enable_adj_infer = enable_adj_infer

    def forward(self, img, junc_map_gt, adj_matrix_gt, junc_loss_lambda=1., adj_loss_lambda=1., junc_coord_gt=None):
        img = self.prep_data(img)
        feat = self.backbone(img)
        bs = img.size(0)

        if self.enable_junc_infer:
            if self.is_train_junc:
                junc_hm, junc_coords = self.junc_infer(feat)
            else:
                with th.no_grad():
                    junc_hm, junc_coords = self.junc_infer(feat)
            # junc_coords[junc_coords[:, 1:].sum(dim=1) == 0] += 0.1
            # pad junction predictions to a fixed max_junctions length
            junc_cnt = []
            j0 = 0
            for b in range(bs):
                junc_cnt.append(0)
                for j in range(j0, len(junc_coords)):
                    if np.isclose(junc_coords[j, 0].item(), b, atol=.1):
                        junc_cnt[-1] += 1
                    else:
                        j0 = j
                        break
            junc_st = np.cumsum([0] + junc_cnt).tolist()
            junc_pred = junc_coords.new_full((bs, self.max_junctions, 2), 0.)
            for b in range(bs):
                junc_pred[b, :junc_cnt[b]] = junc_coords[junc_st[b]:junc_st[b + 1], 1:] + .1
            loss_hm = self.hm_crit(junc_hm, junc_map_gt) * junc_loss_lambda
        else:
            assert junc_coord_gt is not None
            junc_hm = junc_map_gt
            junc_pred = junc_coord_gt
            loss_hm = img.new_full((1, ), 0)

        if self.enable_adj_infer:
            # block-wise junction pooling and adjacency matrix inference
            # first count the number of detected junctions in each image
            if junc_coord_gt is not None:
                junc_st = [0]
                junc_cnt = []
                for b in range(bs):
                    junc_cnt.append(th.sum(junc_coord_gt[b].sum(dim=1) != 0).item())
                    assert junc_cnt[-1] > 0
                    junc_st.append(junc_st[-1] + junc_cnt[-1])
                junc_coord_gt_ = img.new_full((sum(junc_cnt), 3), 0.)
                for b in range(bs):
                    junc_coord_gt_[junc_st[b]:junc_st[b+1], 0] = b
                    junc_coord_gt_[junc_st[b]:junc_st[b+1], 1:] = junc_coord_gt[b, :junc_cnt[b]]

            # then, for each image, build a list of subgraphs, each processing at most block_size junctions
            block_crit = []
            block_infer = []
            for b in range(bs):
                num_blocks = junc_cnt[b] // self.adj_block_size + (1 if junc_cnt[b] % self.adj_block_size else 0)
                for bst in range(num_blocks):
                    for bed in range(num_blocks):
                        st_st = junc_st[b] + bst * self.adj_block_size
                        st_len = min(self.adj_block_size, junc_cnt[b] - bst * self.adj_block_size)
                        ed_st = junc_st[b] + bed * self.adj_block_size
                        ed_len = min(self.adj_block_size, junc_cnt[b] - bed * self.adj_block_size)
                        block_crit.append(
                            BlockAdjacencyMatrixInferenceCriterion(
                                self.adj_crit, adj_matrix_gt, adj_loss_lambda, b,
                                bst * self.adj_block_size, min(self.adj_block_size, junc_cnt[b] - bst * self.adj_block_size),
                                bed * self.adj_block_size, min(self.adj_block_size, junc_cnt[b] - bed * self.adj_block_size),
                                None if junc_coord_gt is None else junc_coord_gt, img.size(2), self.weight_fn
                            )
                        )
                        block_infer.append(
                            BlockAdjacencyMatrixInference(
                                self.line_pool, self.adj_infer,
                                b, st_st, st_len, ed_st, ed_len, junc_coords if junc_coord_gt is None else junc_coord_gt_,
                            )
                        )

            def output_collect_fn(outputs):
                output = img.new_full((bs, self.max_junctions, self.max_junctions), 0.)
                current_block = 0
                for b in range(bs):
                    num_blocks = junc_cnt[b] // self.adj_block_size + (1 if junc_cnt[b] % self.adj_block_size else 0)
                    for bst in range(num_blocks):
                        for bed in range(num_blocks):
                            st_st = bst * self.adj_block_size
                            st_len = min(self.adj_block_size, junc_cnt[b] - bst * self.adj_block_size)
                            ed_st = bed * self.adj_block_size
                            ed_len = min(self.adj_block_size, junc_cnt[b] - bed * self.adj_block_size)
                            output[b, st_st:st_st + st_len, ed_st:ed_st + ed_len] = outputs[current_block]
                            current_block += 1

                return output

            block_adj_infer = models.common.GradAccumulator(
                block_crit, block_infer, output_collect_fn, reduce_method="mean"
            )
            if self.is_train_adj:
                feat_adj = self.adj_embed(feat)
                adj_matrix_pred, loss_adj = block_adj_infer(feat_adj)
            else:
                with th.no_grad():
                    feat_adj = self.adj_embed(feat)
                    adj_matrix_pred, loss_adj = block_adj_infer(feat_adj)
        else:
            adj_matrix_pred = adj_matrix_gt
            loss_adj = img.new_full((1, ), 0)

        return junc_pred, junc_hm, adj_matrix_pred, loss_hm, loss_adj
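The nested loops above tile each image's adjacency matrix into sub-blocks of at most `block_inference_size` junctions per side, so line pooling never has to hold more than block_size^2 junction pairs in memory at once. A pure-Python sketch of that partition arithmetic (hypothetical junction counts, default block size 64):

```python
def block_spans(junc_cnt, block_size=64):
    """(start offset, length) of each block, as computed in the loops above."""
    num_blocks = junc_cnt // block_size + (1 if junc_cnt % block_size else 0)
    return [(b * block_size, min(block_size, junc_cnt - b * block_size))
            for b in range(num_blocks)]

spans = block_spans(150)
print(spans)  # [(0, 64), (64, 64), (128, 22)]
assert sum(length for _, length in spans) == 150
```

Every (row-block, column-block) pair then becomes one `BlockAdjacencyMatrixInference` unit, and `output_collect_fn` stitches the per-block results back into the full `max_junctions` x `max_junctions` matrix.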


================================================
FILE: models/lsd_test.py
================================================
import torch as th
import torch.nn as nn
import models.graph
import models.backbone
import models.common
import numpy as np

from .lsd import LSDModule


class LSDDataLayer(nn.Module):
    def __init__(self, mean=None, std=None):
        super(LSDDataLayer, self).__init__()
        self.std = [1., 1., 1.] if std is None else std
        self.mean = [102.9801, 115.9465, 122.7717] if mean is None else mean

    def forward(self, img):
        assert img.size(1) == 3
        for ch in range(3):
            img[:, ch, :, :] = (img[:, ch, :, :] - self.mean[ch]) / self.std[ch]

        return img


# noinspection PyTypeChecker
class BinaryFocalLoss(nn.Module):
    def __init__(self, gamma=2., alpha=0.25, size_average=True):
        super(BinaryFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target, weight=None):
        if weight is not None:
            assert weight.size() == input.size(), f"weight size: {weight.size()}, input size: {input.size()}"
            assert (weight >= 0).all() and (weight <= 1).all(), f"weight max: {weight.max()}, min: {weight.min()}"
        input = input.clamp(1.e-6, 1. - 1.e-6)
        if weight is None:
            loss = th.sum(
                - self.alpha * target * ((1 - input) ** self.gamma) * th.log(input)
                - (1 - self.alpha) * (1 - target) * (input ** self.gamma) * th.log(1 - input))
        else:
            loss = th.sum(
                (- self.alpha * target * ((1 - input) ** self.gamma) * th.log(input)
                 - (1 - self.alpha) * (1 - target) * (input ** self.gamma) * th.log(1 - input)) * weight
            )
        if self.size_average:
            loss /= input.nelement()
        return loss


class BlockAdjacencyMatrixInference(nn.Module):
    def __init__(self,
                 line_pool_module, adj_infer_module,
                 current_batch_id, junc_st_st, junc_st_len, junc_ed_st, junc_ed_len, junc_pred
                 ):
        super(BlockAdjacencyMatrixInference, self).__init__()
        self.line_pool = line_pool_module
        self.adj_infer = adj_infer_module
        self.b = current_batch_id
        self.st_st = junc_st_st
        self.st_len = junc_st_len
        self.ed_st = junc_ed_st
        self.ed_len = junc_ed_len
        self.juncs = junc_pred

    def forward(self, feat):
        junc_st = self.juncs.narrow(0, self.st_st, self.st_len)
        junc_ed = self.juncs.narrow(0, self.ed_st, self.ed_len)
        assert (junc_st[:, 0] == self.b).all() and (junc_ed[:, 0] == self.b).all(), f"{self.b}\n{junc_st[:, 0]}\n{junc_ed[:, 0]}"
        line_feat = self.line_pool(feat, junc_st, junc_ed)
        block_adj_matrix = self.adj_infer(line_feat)

        return block_adj_matrix


class BlockAdjacencyMatrixInferenceCriterion(nn.Module):
    def __init__(self, adj_matrix_crit, adj_matrix_gt, adj_matrix_loss_lambda,
                 current_batch_id, mtx_st_st, mtx_st_len, mtx_ed_st, mtx_ed_len,
                 junc_padded, img_size, line_seg_length_weight_fn
                 ):
        super(BlockAdjacencyMatrixInferenceCriterion, self).__init__()
        self.adj_crit = adj_matrix_crit
        self.adj_gt = adj_matrix_gt
        self.loss_lambda = adj_matrix_loss_lambda
        self.b = current_batch_id
        self.st_st = mtx_st_st
        self.st_len = mtx_st_len
        self.ed_st = mtx_ed_st
        self.ed_len = mtx_ed_len
        self.junc = junc_padded
        self.img_size = img_size
        self.weight = line_seg_length_weight_fn

    def forward(self, block_adj_matrix):
        block_adj_matrix_gt = self.adj_gt[self.b, self.st_st:self.st_st+self.st_len, self.ed_st:self.ed_st+self.ed_len]
        if self.junc is not None:
            junc_st = self.junc[self.b, self.st_st:self.st_st + self.st_len].view(self.st_len, 1, 2).expand(self.st_len,
                                                                                                            self.ed_len,
                                                                                                            2)
            junc_ed = self.junc[self.b, self.ed_st:self.ed_st + self.ed_len].view(1, self.ed_len, 2).expand(self.st_len,
                                                                                                            self.ed_len,
                                                                                                            2)
            line_len = (junc_ed - junc_st).norm(dim=2)
            return self.loss_lambda * self.adj_crit(block_adj_matrix, block_adj_matrix_gt, weight=self.weight(line_len, self.img_size * 1.4143))  # 1.4143 ~ sqrt(2): normalize line length by the image diagonal
        else:
            return self.loss_lambda * self.adj_crit(block_adj_matrix, block_adj_matrix_gt)


class LSDTestModule(LSDModule):
    def __init__(
            self,
            # backbone parameters
            backbone="unet",
            dim_embedding=256,
            backbone_kwargs={},
            # junction inference parameters
            junction_pooling_threshold=0.2,
            max_junctions=512,
            feature_spatial_scale=0.25,
            # junction pooling parameters
            junction_pooling_size=15.,
    ):
        super(LSDTestModule, self).__init__(
            backbone=backbone,
            dim_embedding=dim_embedding,
            backbone_kwargs=backbone_kwargs,
            # junction inference parameters
            junction_pooling_threshold=junction_pooling_threshold,
            max_junctions=max_junctions,
            feature_spatial_scale=feature_spatial_scale,
            junction_heatmap_criterion="binary_cross_entropy",
            # junction pooling parameters
            junction_pooling_size=junction_pooling_size,
            # adjacency matrix inference parameters
            block_inference_size=64,
            adjacency_matrix_criterion="binary_cross_entropy",
            weight_fn=None,
            is_train_junc=True,
            is_train_adj=True,
            enable_junc_infer=True,
            enable_adj_infer=True,
            verbose=True,
        )

    def forward(self, img):
        img = self.prep_data(img)
        feat = self.backbone(img)
        bs = img.size(0)

        junc_hm, junc_coords = self.junc_infer(feat)

        # first count the number of detected junctions in each image
        junc_cnt = []
        j0 = 0
        for b in range(bs):
            junc_cnt.append(0)
            for j in range(j0, len(junc_coords)):
                if np.isclose(junc_coords[j, 0].item(), b, atol=.1):
                    junc_cnt[-1] += 1
                else:
                    j0 = j
                    break
        junc_st = np.cumsum([0] + junc_cnt).tolist()
        junc_pred = junc_coords.new_full((bs, self.max_junctions, 2), 0.)
        for b in range(bs):
            junc_pred[b, :junc_cnt[b]] = junc_coords[junc_st[b]:junc_st[b + 1], 1:] + .1

        # then, for each image, build a list of subgraphs, each processing at most block_size junctions
        block_infer = []
        for b in range(bs):
            num_blocks = junc_cnt[b] // self.adj_block_size + (1 if junc_cnt[b] % self.adj_block_size else 0)
            for bst in range(num_blocks):
                for bed in range(num_blocks):
                    st_st = junc_st[b] + bst * self.adj_block_size
                    st_len = min(self.adj_block_size, junc_cnt[b] - bst * self.adj_block_size)
                    ed_st = junc_st[b] + bed * self.adj_block_size
                    ed_len = min(self.adj_block_size, junc_cnt[b] - bed * self.adj_block_size)
                    block_infer.append(
                        BlockAdjacencyMatrixInference(
                            self.line_pool, self.adj_infer,
                            b, st_st, st_len, ed_st, ed_len, junc_coords,
                        )
                    )

        def output_collect_fn(outputs):
            output = img.new_full((bs, self.max_junctions, self.max_junctions), 0.)
            current_block = 0
            for b in range(bs):
                num_blocks = junc_cnt[b] // self.adj_block_size + (1 if junc_cnt[b] % self.adj_block_size else 0)
                for bst in range(num_blocks):
                    for bed in range(num_blocks):
                        st_st = bst * self.adj_block_size
                        st_len = min(self.adj_block_size, junc_cnt[b] - bst * self.adj_block_size)
                        ed_st = bed * self.adj_block_size
                        ed_len = min(self.adj_block_size, junc_cnt[b] - bed * self.adj_block_size)
                        output[b, st_st:st_st + st_len, ed_st:ed_st + ed_len] = outputs[current_block]
                        current_block += 1

            return output

        block_adj_infer = models.common.GradAccumulator(None, block_infer, output_collect_fn)
        feat_adj = self.adj_embed(feat)
        adj_matrix_pred, _ = block_adj_infer(feat_adj)

        return junc_pred, junc_hm, adj_matrix_pred


================================================
FILE: models/pretrained/resnet50-imagenet.pth
================================================
version https://git-lfs.github.com/spec/v1
oid sha256:902bbfbc9b570be36e0f94e757b211b24f9216df14260c58b359ddb182c0723e
size 1122304


================================================
FILE: models/test_common.py
================================================
from unittest import TestCase
import numpy as np
import models.common as common
import torch as th
import torch.nn as nn
from itertools import chain
from libs.roi_align.modules.roi_align import RoIAlign


class TestRoiPooling(TestCase):
    def setUp(self):
        th.manual_seed(1234)
        self.test_data = dict(
            input=th.rand(2, 3, 9, 10),
            rois=th.tensor([
                [0, 0, 0, 9, 8], # whole feature map
                [1, 0, 0, 9, 8], # whole feature map
                [0, 0, 0, 0, 0], # top left pixel
                [0, 0, 0, 1, 1], # top left 2x2
                [0, 0, 0, -1, -1], # bottom right out of range
                [1, 9, 8, 9, 8], # bottom right pixel
                [1, -3, -5, 6, 7], # top left out of range
                [1, -3, -5, 10, 11],  # both corner out of range
                [0, 3, 2, 9, 8], # 7x7 roi
                [0, 3, 3, 8, 8], # 6x6 roi
                [1, 1, 1, 5, 5], # 5x5 roi
            ], dtype=th.float32)
        )

    def test_output_size_7x7(self):
        input, rois = self.test_data["input"], self.test_data["rois"]
        output7x7 = common.roi_pooling(
            input=input,
            rois=rois,
            size=(7, 7),
            spatial_scale=1.0
        )
        # output7x7 = RoIAlign(aligned_height=7, aligned_width=7, spatial_scale=1.)(input, rois)
        self.assertEqual(output7x7.size(0), rois.size(0), "output batch size does not match number of rois")
        self.assertEqual(input.size(1), output7x7.size(1), "output channels do not match input channels")
        self.assertTupleEqual(output7x7.shape[2:], (7, 7), "output shape does not match required shape")

    def test_output_size_5x5(self):
        input, rois = self.test_data["input"], self.test_data["rois"]
        output5x5 = common.roi_pooling(
            input=input,
            rois=rois,
            size=(5, 5),
            spatial_scale=1.0
        )
        # output5x5 = RoIAlign(aligned_height=5, aligned_width=5, spatial_scale=1.)(input, rois)
        self.assertTupleEqual(output5x5.shape[2:], (5, 5), "output shape does not match required shape")

    def test_output_size_1x1(self):
        input, rois = self.test_data["input"], self.test_data["rois"]
        output1x1 = common.roi_pooling(
            input=input,
            rois=rois,
            size=(1, 1),
            spatial_scale=1.0
        )
        self.assertTupleEqual(output1x1.shape[2:], (1, 1), "output shape does not match required shape")

    def test_output_value(self):
        input, rois = self.test_data["input"], self.test_data["rois"]
        # rois[2, 3:] -= 1
        output1x1 = common.roi_pooling(
            input=input,
            rois=rois,
            size=(1, 1),
            spatial_scale=1.0
        )
        # output1x1 = RoIAlign(aligned_height=1, aligned_width=1, spatial_scale=1.)(input, rois)
        output5x5 = common.roi_pooling(
            input=input,
            rois=rois,
            size=(5, 5),
            spatial_scale=1.0
        )
        # output7x7 = common.roi_pooling(
        #     input=input,
        #     rois=rois,
        #     size=(7, 7),
        #     spatial_scale=1.0
        # )
        rois[8, 3:] -= 1
        rois[10, 3:] -= 1
        output7x7 = RoIAlign(aligned_height=7, aligned_width=7, spatial_scale=1.)(input, rois)
        output5x5 = RoIAlign(aligned_height=5, aligned_width=5, spatial_scale=1.)(input, rois)
        self.assertTrue((output1x1[2] == input[0, :, :1, :1]).all())
        self.assertTrue((output1x1[5] == input[1, :, 8:, 9:]).all())
        self.assertTrue((output5x5[10] == input[1, :, 1:6, 1:6]).all())
        self.assertTrue((output7x7[8] == input[0, :, 2:9, 3:10]).all())


class TestGradAccumulator(TestCase):
    def setUp(self):
        th.manual_seed(789)
        self.net1 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 1, bias=True),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.net2 = nn.Sequential(
            nn.Conv2d(32, 16, 3, 1, 1, bias=True),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 1, 3, 1, 1, bias=True),
        )
        self.test_data = dict(
            X = th.rand(100, 16, 32, 32),
            Y = th.rand(100, 1, 32, 32),
        )
        self.net1.apply(common.weights_init)
        self.net2.apply(common.weights_init)
        self.crit = nn.SmoothL1Loss()

    def test_forward_pass(self):
        X, Y = self.test_data["X"], self.test_data["Y"]
        with th.no_grad():
            feat = self.net1(X)
            out = []
            for i in range(4):
                for j in range(4):
                    out.append(self.net2(feat[:, :, i*4:(i+1)*4, j*4:(j+1)*4]))
            loss = 0
            for i in range(4):
                for j in range(4):
                    loss += self.crit(out[i*4+j], Y[:, :, i*4:(i+1)*4, j*4:(j+1)*4])
            loss /= 16

            class Net2(nn.Module):
                def __init__(self, net2, st_st, st_ed, ed_st, ed_ed):
                    super(Net2, self).__init__()
                    self.net2 = net2
                    self.st_st = st_st
                    self.st_ed = st_ed
                    self.ed_st = ed_st
                    self.ed_ed = ed_ed

                def forward(self, input):
                    return self.net2(input[:, :, self.st_st:self.st_ed, self.ed_st:self.ed_ed])

            class Crit(nn.Module):
                def __init__(self, crit, target, st_st, st_ed, ed_st, ed_ed):
                    super(Crit, self).__init__()
                    self.crit = crit
                    self.st_st = st_st
                    self.st_ed = st_ed
                    self.ed_st = ed_st
                    self.ed_ed = ed_ed
                    self.register_buffer("target", target)

                def forward(self, x):
                    return self.crit(x, self.target[:, :, self.st_st:self.st_ed, self.ed_st:self.ed_ed])

            gradacc = common.GradAccumulator(
                [Crit(self.crit, Y, i*4, (i+1)*4, j*4, (j+1)*4) for i in range(4) for j in range(4)],
                [Net2(self.net2, i*4, (i+1)*4, j*4, (j+1)*4) for i in range(4) for j in range(4)],
                collect_fn=None
            )
            net_ = nn.Sequential(self.net1, gradacc)
            out_, loss_ = net_(X)

        for i in range(len(out)):
            self.assertTrue(th.allclose(out[i], out_[i]))
        self.assertTrue(th.allclose(loss, loss_), f"{loss}\n{loss_}")

    def test_backward_pass(self):
        X, Y = self.test_data["X"], self.test_data["Y"]

        self.net1.zero_grad()
        self.net2.zero_grad()

        feat = self.net1(X)
        out = []
        for i in range(4):
            for j in range(4):
                out.append(self.net2(feat[:, :, i * 4:(i + 1) * 4, j * 4:(j + 1) * 4]))
        loss = 0
        for i in range(4):
            for j in range(4):
                loss += self.crit(out[i * 4 + j], Y[:, :, i * 4:(i + 1) * 4, j * 4:(j + 1) * 4])
        loss /= 16
        loss.backward()

        grad = {}

        for k, v in chain(self.net1.named_parameters(prefix="net1"), self.net2.named_parameters(prefix="net2")):
            grad[k] = th.tensor(v.grad)

        self.net1.zero_grad()
        self.net2.zero_grad()

        class Net2(nn.Module):
            def __init__(self, net2, st_st, st_ed, ed_st, ed_ed):
                super(Net2, self).__init__()
                self.net2 = net2
                self.st_st = st_st
                self.st_ed = st_ed
                self.ed_st = ed_st
                self.ed_ed = ed_ed

            def forward(self, input):
                return self.net2(input[:, :, self.st_st:self.st_ed, self.ed_st:self.ed_ed])

        class Crit(nn.Module):
            def __init__(self, crit, target, st_st, st_ed, ed_st, ed_ed):
                super(Crit, self).__init__()
                self.crit = crit
                self.st_st = st_st
                self.st_ed = st_ed
                self.ed_st = ed_st
                self.ed_ed = ed_ed
                self.register_buffer("target", target)

            def forward(self, x):
                return self.crit(x, self.target[:, :, self.st_st:self.st_ed, self.ed_st:self.ed_ed])

        gradacc = common.GradAccumulator(
            [Crit(self.crit, Y, i * 4, (i + 1) * 4, j * 4, (j + 1) * 4) for i in range(4) for j in range(4)],
            [Net2(self.net2, i * 4, (i + 1) * 4, j * 4, (j + 1) * 4) for i in range(4) for j in range(4)],
            collect_fn=None
        )
        net_ = nn.Sequential(self.net1, gradacc)
        out_, loss_ = net_(X)
        loss_.backward()

        grad_ = {}
        for k, v in chain(self.net1.named_parameters(prefix="net1"), self.net2.named_parameters(prefix="net2")):
            grad_[k] = v.grad.detach().clone()  # snapshot the gradient; th.tensor(tensor) is deprecated

        for k in sorted(grad.keys()):
            self.assertTrue(th.allclose(grad[k], grad_[k]), f"{k}:\n{grad[k]}\n{grad_[k]}")
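The equivalence this test checks, that averaging per-block losses reproduces the gradients of the full loss, follows from linearity of the gradient. A tiny standalone numpy illustration of that identity (not the repo's `GradAccumulator`, just the math it relies on):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(16, 3))
y = rng.normal(size=(16, 3))

# Gradient of the mean-squared-error loss w.r.t. x, computed in one shot...
grad_full = 2 * (x - y) / x.size

# ...and accumulated over 4 blocks of the batch, each contributing 1/4 of
# its own per-block mean gradient (mirroring `loss /= 16` in the test).
grad_acc = np.zeros_like(x)
for blk in np.split(np.arange(16), 4):
    grad_acc[blk] = 2 * (x[blk] - y[blk]) / (x[blk].size * 4)

print(np.allclose(grad_full, grad_acc))   # True
```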


================================================
FILE: models/test_graph.py
================================================
from unittest import TestCase
import torch as th
import models.graph as graph
import matplotlib.pyplot as plt


class TestJunctionInference(TestCase):
    def test_junction_inference_forward(self):
        junc_infer = graph.JunctionInference(256, 0.1, 512, 0.25, False)
        with th.no_grad():
            feat_map = th.rand(5, 256, 32, 64)
            junc_map, junc_coord = junc_infer(feat_map)
            self.assertTrue(junc_map.size(0) == 5)
            self.assertTupleEqual(junc_map.shape[1:], (128, 256))
            self.assertTrue(junc_coord.size(1) == 3)
            self.assertTrue((junc_coord[:, 1:] >= 0).all() and (junc_coord[:, 1] < 256).all() and (junc_coord[:, 2] < 128).all())


class TestJunctionPooling(TestCase):
    def test_junc_pooling_forward(self):
        junc_infer = graph.JunctionInference(256, 0.1, 512, 0.25, False)
        junc_pool = graph.JunctionPooling(5, 5, 0.25)
        with th.no_grad():
            feat_map = th.rand(5, 256, 32, 64)
            junc_map, junc_coord = junc_infer(feat_map)
            out = junc_pool(feat_map, junc_coord)
            self.assertTrue(out.size(0) == junc_coord.size(0))
            self.assertTrue(out.size(1) == feat_map.size(1))
            self.assertTrue(out.shape[2:] == (5, 5))


class TestDirectionalAttention(TestCase):
    def test_attention_forward(self):
        attn = graph.DirectionalAttention(
            15,
            attn_sigma_dir=3.1415926/300,
            attn_sigma_pos=2
        )
        junc = th.tensor([
            [0., 0.],
            [2., 0.],
            [1., 1.],
            [0., 1.]
        ])
        map_st2ed, map_ed2st = attn(junc, junc)
        self.assertFalse(th.isnan(map_st2ed).any() or th.isnan(map_ed2st).any())
        self.assertTupleEqual(map_st2ed.size(), map_ed2st.size())
        self.assertTupleEqual(map_st2ed.shape[2:], map_ed2st.shape[2:])
        self.assertTupleEqual(map_st2ed.shape[2:], (15, 15))
        self.assertTupleEqual(map_st2ed.shape[:2], (4, 4))
        attn_map_st2ed = map_st2ed.numpy()
        attn_map_ed2st = map_ed2st.numpy()
        figs, axes = plt.subplots(4, 4)
        for i in range(4):
            for j in range(4):
                axes[i, j].imshow(attn_map_ed2st[i, j])
        plt.show()
        map_ct = attn(junc)
        self.assertTupleEqual(map_ct.size(), (15, 15))


class TestAdjacencyMatrixInference(TestCase):
    def test_adjacency_matrix_inference_forward(self):
        attn = graph.DirectionalAttention(15)
        junc_st = th.rand(30, 2)
        junc_ed = th.rand(60, 2)
        attn_st2ed, attn_ed2st = attn(junc_st, junc_ed)
        attn_center = attn()
        adj_with_center = graph.AdjacencyMatrixInference(256, junc_align_size=15, align_center=True)
        adj_wo_center = graph.AdjacencyMatrixInference(256, junc_align_size=15, align_center=False)
        feat_start = th.rand(30, 128, 15, 15)
        feat_end = th.rand(60, 128, 15, 15)
        feat_center = th.rand(30, 60, 128, 15, 15)
        adjacency_matrix_with_center = adj_with_center(feat_start, feat_end, attn_st2ed, attn_ed2st, feat_center, attn_center)
        adjacency_matrix_wo_center = adj_wo_center(feat_start, feat_end, attn_st2ed, attn_ed2st)
        self.assertTrue(adjacency_matrix_with_center.size() == adjacency_matrix_wo_center.size() == (30, 60))


================================================
FILE: test.py
================================================
# System libs
import os
import time

# Numerical libs
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
import numpy as np

# Our libs
from data.sist_line import SISTLine
import data.transforms as tf
from models.lsd_test import LSDTestModule
from utils import AverageMeter, graph2line, draw_lines, draw_jucntions

# tensorboard
from tensorboardX import SummaryWriter
import torchvision.utils as vutils

import fire
import cv2


class LSD(object):
    def __init__(
            self,
            # exp params
            exp_name="u50_block",
            # arch params
            backbone="resnet50",
            backbone_kwargs={},
            dim_embedding=256,
            feature_spatial_scale=0.25,
            max_junctions=512,
            junction_pooling_threshold=0.2,
            junc_pooling_size=15,
            block_inference_size=64,
            # data params
            img_size=416,
            gpus=[0,],
            resume_epoch="latest",
            # vis params
            vis_junc_th=0.3,
            vis_line_th=0.3
    ):
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(c) for c in gpus)

        self.is_cuda = bool(gpus)

        self.model = LSDTestModule(
            backbone=backbone,
            dim_embedding=dim_embedding,
            backbone_kwargs=backbone_kwargs,
            junction_pooling_threshold=junction_pooling_threshold,
            max_junctions=max_junctions,
            feature_spatial_scale=feature_spatial_scale,
            junction_pooling_size=junc_pooling_size,
        )

        self.exp_name = exp_name
        os.makedirs(os.path.join("log", exp_name), exist_ok=True)
        os.makedirs(os.path.join("ckpt", exp_name), exist_ok=True)
        self.writer = SummaryWriter(log_dir=os.path.join("log", exp_name))

        # checkpoints
        self.states = dict(
            last_epoch=-1,
            elapsed_time=0,
            state_dict=None
        )

        if resume_epoch and os.path.isfile(os.path.join("ckpt", exp_name, f"train_states_{resume_epoch}.pth")):
            states = torch.load(
                os.path.join("ckpt", exp_name, f"train_states_{resume_epoch}.pth"))
            print(f"resume training from epoch {states['last_epoch']}")
            self.model.load_state_dict(states["state_dict"])
            self.states.update(states)

        self.vis_junc_th = vis_junc_th
        self.vis_line_th = vis_line_th
        self.block_size = block_inference_size
        self.max_junctions = max_junctions
        self.img_size = img_size

    def end(self):
        self.writer.close()
        return "command queue finished."

    def test(self, path_to_image):
        # main loop
        torch.set_grad_enabled(False)
        print(f"test for image: {path_to_image}", flush=True)

        if self.is_cuda:
            model = self.model.cuda().eval()
        else:
            model = self.model.eval()

        img = cv2.imread(path_to_image)
        img = cv2.resize(img, (self.img_size, self.img_size))
        img_reverse = img[..., [2, 1, 0]]
        img = torch.from_numpy(img_reverse).float().permute(2, 0, 1).unsqueeze(0)

        if self.is_cuda:
            img = img.cuda()

        # measure elapsed time
        junc_pred, heatmap_pred, adj_mtx_pred = model(img)

        # visualize eval
        img = img.cpu().numpy()
        junctions_pred = junc_pred.cpu().numpy()
        adj_mtx = adj_mtx_pred.cpu().numpy()

        img_with_junc = draw_jucntions(img, junctions_pred)
        img_with_junc = img_with_junc[0].numpy()[None]
        img_with_junc = img_with_junc[:, ::-1, :, :]
        lines_pred, score_pred = graph2line(junctions_pred, adj_mtx)
        vis_line_pred = draw_lines(img_with_junc, lines_pred, score_pred)[0]
        vis_line_pred = vis_line_pred.permute(1, 2, 0).numpy()

        cv2.imshow("result", vis_line_pred)
        cv2.waitKey(0)  # block until a key is pressed so the window stays open


if __name__ == "__main__":
    fire.Fire(LSD)
    # trainer = LSDTrainer().train(lr=1.)
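The preprocessing in `test()` reverses the channel axis (OpenCV loads BGR, the model expects RGB) and permutes HWC into batched CHW. A minimal numpy sketch of those two steps, with a toy 2x2 image in place of a real file:

```python
import numpy as np

# cv2.imread returns an HxWx3 BGR array; the model wants 1x3xHxW RGB,
# mirroring test.py's `img[..., [2, 1, 0]]` followed by `permute(2, 0, 1)`.
bgr = np.arange(2 * 2 * 3, dtype=np.float32).reshape(2, 2, 3)
rgb = bgr[..., [2, 1, 0]]             # reverse channel order per pixel
chw = rgb.transpose(2, 0, 1)[None]    # HWC -> CHW, then add a batch dim
print(chw.shape)                      # (1, 3, 2, 2)
```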


================================================
FILE: test.sh
================================================
python test.py \
--exp-name line_weighted_wo_focal_junc --backbone resnet50 \
--backbone-kwargs '{"encoder_weights": "ckpt/backbone/encoder_epoch_20.pth", "decoder_weights": "ckpt/backbone/decoder_epoch_20.pth"}' \
--dim-embedding 256 --junction-pooling-threshold 0.2 \
--junc-pooling-size 64 --block-inference-size 128 \
--gpus 0, --resume-epoch latest \
--vis-junc-th 0.25 --vis-line-th 0.25 \
    - test $1


================================================
FILE: tools/rebuild_yorkurban.py
================================================
from data.line_graph import LineGraph
import os
from scipy import io
import numpy as np
from shutil import copyfile
from tqdm import trange

data_root = "/home/ziheng/YorkUrbanDB"
out_root = "/home/ziheng/YorkUrbanDB_new/test"


list_file = io.loadmat(os.path.join(data_root, "Manhattan_Image_DB_Names.mat"))
name_list = [e[0][0].strip("\\") for e in list_file["Manhattan_Image_DB_Names"]]
test_set = io.loadmat(os.path.join(data_root, "ECCV_TrainingAndTestImageNumbers.mat"))
test_set_id = test_set["testSetIndex"].flatten().tolist()
imgs = [os.path.join(data_root, name_list[i - 1], name_list[i - 1] + ".jpg") for i in test_set_id]
labels = [io.loadmat(os.path.join(data_root, name_list[i - 1], name_list[i - 1] + "LinesAndVP.mat")) for i in test_set_id]
lines = [np.float32(lab["lines"]).reshape((-1, 4)) for lab in labels]
maps = [np.uint8(lab["finalImg"]) for lab in labels]

os.makedirs(out_root, exist_ok=True)
max_juncs = 512
for i in trange(len(imgs)):
    img, line = imgs[i], lines[i]
    fname = os.path.basename(img)[:-4]
    hm = maps[i]
    lg = LineGraph(eps_junction=1., eps_line_deg=np.pi / 30, verbose=False)
    for x1, y1, x2, y2 in line:
        lg.add_junction((x1, y1))
        lg.add_junction((x2, y2))
    lg.freeze_junction()
    for x1, y1, x2, y2 in line:
        lg.add_line_seg((x1, y1), (x2, y2))
    lg.freeze_line_seg()
    max_juncs = max(lg.num_junctions, max_juncs)
    lg.save(os.path.join(out_root, fname + ".lg"))
    copyfile(img, os.path.join(out_root, fname + ".jpg"))
    # img = cv2.imread(img)
    # print(fname, flush=True)
    # cv2.imshow("line_", lg.line_map(img.shape[:2]))
    # cv2.imshow("line", hm)
    # cv2.waitKey()

print(max_juncs)


================================================
FILE: train.sh
================================================
python main.py \
--exp-name line_weighted_wo_focal_junc --backbone resnet50 \
--backbone-kwargs '{"encoder_weights": "ckpt/backbone/encoder_epoch_20.pth", "decoder_weights": "ckpt/backbone/decoder_epoch_20.pth"}' \
--dim-embedding 256 --junction-pooling-threshold 0.2 \
--junc-pooling-size 64 --attention-sigma 1.5 --block-inference-size 128 \
--data-root /data/path --junc-sigma 3 \
--batch-size 16 --gpus 0,1,2,3 --num-workers 10 --resume-epoch latest \
--is-train-junc True --is-train-adj True \
--vis-junc-th 0.1 --vis-line-th 0.1 \
    - train --end-epoch 9 --solver SGD --lr 0.2 --weight-decay 5e-4 --lambda-heatmap 1. --lambda-adj 5. \
    - train --end-epoch 15 --solver SGD --lr 0.02 --weight-decay 5e-4 --lambda-heatmap 1. --lambda-adj 10. \
    - train --end-epoch 30 --solver SGD --lr 0.002 --weight-decay 5e-4 --lambda-heatmap 1. --lambda-adj 10. \
    - end


================================================
FILE: utils.py
================================================
import numpy as np
import re
import functools
import torch as th
import cv2
from numba import jit


def graph2line(junctions, adj_mtx, threshold=0.5):
    assert len(junctions) == len(adj_mtx)
    # assert np.allclose(adj_mtx, adj_mtx.transpose((0, 2, 1)), rtol=1e-2, atol=1e-2), f"{adj_mtx}"
    bs = len(junctions)
    lines = []
    scores = []
    for b in range(bs):
        junc = junctions[b]
        mtx = adj_mtx[b]
        num_junc = np.sum(junc.sum(axis=1) > 0)
        line = []
        score = []
        for i in range(num_junc):
            for j in range(i + 1, num_junc):  # upper triangle only; skip degenerate self-loops
                if mtx[i, j] > threshold:
                    line.append(np.hstack((junc[i], junc[j])))
                    score.append(mtx[i, j])
        scores.append(np.array(score))
        lines.append(np.array(line))

    return lines, scores
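The core of `graph2line` is thresholding the upper triangle of a symmetric adjacency matrix and emitting one `(x1, y1, x2, y2)` segment per surviving junction pair. A standalone restatement of that inner loop on a toy 3-junction graph:

```python
import numpy as np

# Toy graph: 3 junction coordinates and a symmetric matrix of line scores.
junc = np.array([[0., 0.], [4., 0.], [4., 3.]])
adj = np.array([
    [0.0, 0.9, 0.1],
    [0.9, 0.0, 0.7],
    [0.1, 0.7, 0.0],
])

threshold = 0.5
lines, scores = [], []
n = len(junc)
for i in range(n):
    for j in range(i + 1, n):          # upper triangle: each pair visited once
        if adj[i, j] > threshold:
            lines.append(np.hstack((junc[i], junc[j])))
            scores.append(adj[i, j])

lines = np.array(lines)
print(lines.shape)   # (2, 4): two segments, each as (x1, y1, x2, y2)
```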


def draw_lines(imgs, lines, scores=None, width=2):
    assert len(imgs) == len(lines)
    imgs = np.uint8(imgs)
    bs = len(imgs)
    if scores is not None:
        assert len(scores) == bs
    res = []
    for b in range(bs):
        img = imgs[b].transpose((1, 2, 0))
        line = lines[b]
        if scores is None:
            score = np.zeros(len(line))
        else:
            score = scores[b]
        img = img.copy()
        for (x1, y1, x2, y2), c in zip(line, score):
            pt1, pt2 = (int(x1), int(y1)), (int(x2), int(y2))  # cv2.line requires integer pixel coordinates
            c = tuple(cv2.applyColorMap(np.array(c * 255, dtype=np.uint8), cv2.COLORMAP_JET).flatten().tolist())
            img = cv2.line(img, pt1, pt2, c, width)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        res.append(th.from_numpy(img.transpose((2, 0, 1))))

    return res


def draw_jucntions(hms, junctions):
    assert len(hms) == len(junctions)
    if hms.ndim == 3:
        imgs = np.uint8(hms * 255)
    else:
        imgs = np.uint8(hms)
    bs = len(imgs)
    res = []
    for b in range(bs):
        if hms.ndim == 3:
            img = cv2.cvtColor(imgs[b], cv2.COLOR_GRAY2BGR)
        else:
            img = np.array(imgs[b].transpose((1, 2, 0)))
        junc = junctions[b]
        junc = junc[junc.sum(axis=1) > 0.1]
        if hms.ndim == 3:
            score = hms[b][np.int32(junc[:, 1]), np.int32(junc[:, 0])]
        else:
            score = [1.] * len(junc)
        img = img.copy()
        for (x, y), c in zip(junc, score):
            c = tuple(cv2.applyColorMap(np.array(c * 255, dtype=np.uint8), cv2.COLORMAP_JET).flatten().tolist())
            cv2.circle(img, (int(x), int(y)), 5, c, thickness=2)  # cv2.circle requires an integer center
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        res.append(th.from_numpy(img.transpose((2, 0, 1))))

    return res


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.initialized = False
        self.val = None
        self.avg = None
        self.sum = None
        self.count = None

    def initialize(self, val, weight):
        self.val = val
        self.avg = val
        self.sum = val * weight
        self.count = weight
        self.initialized = True

    def update(self, val, weight=1):
        if not self.initialized:
            self.initialize(val, weight)
        else:
            self.add(val, weight)

    def add(self, val, weight):
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        return self.val

    def average(self):
        return self.avg
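`AverageMeter` maintains a weighted running mean: `sum` accumulates `val * weight` and `avg` is `sum / count`. A compact standalone sketch of the same bookkeeping (not the class above, just its arithmetic):

```python
class Meter:
    """Weighted running average, mirroring the repo's AverageMeter arithmetic."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, weight=1):
        self.sum += val * weight
        self.count += weight

    def average(self):
        return self.sum / self.count

m = Meter()
m.update(1.0)
m.update(3.0)
m.update(10.0, weight=2)   # a weighted sample counts `weight` times
print(m.average())         # (1 + 3 + 10*2) / (1 + 1 + 2) = 6.0
```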


def unique(ar, return_index=False, return_inverse=False, return_counts=False):
    ar = np.asanyarray(ar).flatten()

    optional_indices = return_index or return_inverse
    optional_returns = optional_indices or return_counts

    if ar.size == 0:
        if not optional_returns:
            ret = ar
        else:
            ret = (ar,)
            if return_index:
                ret += (np.empty(0, np.intp),)
            if return_inverse:
                ret += (np.empty(0, np.intp),)
            if return_counts:
                ret += (np.empty(0, np.intp),)
        return ret
    if optional_indices:
        perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
        aux = ar[perm]
    else:
        ar.sort()
        aux = ar
    flag = np.concatenate(([True], aux[1:] != aux[:-1]))

    if not optional_returns:
        ret = aux[flag]
    else:
        ret = (aux[flag],)
        if return_index:
            ret += (perm[flag],)
        if return_inverse:
            iflag = np.cumsum(flag) - 1
            inv_idx = np.empty(ar.shape, dtype=np.intp)
            inv_idx[perm] = iflag
            ret += (inv_idx,)
        if return_counts:
            idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
            ret += (np.diff(idx),)
    return ret
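This `unique` is a vendored copy of NumPy's `np.unique`; in particular, `return_inverse` yields indices such that indexing the unique values reconstructs the original array. Demonstrated with the stock `np.unique` (same contract):

```python
import numpy as np

ar = np.array([3, 1, 3, 2, 1])
vals, inv = np.unique(ar, return_inverse=True)
print(vals)          # [1 2 3] -- sorted unique values
print(vals[inv])     # [3 1 3 2 1] -- reconstructs ar exactly
```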


def colorEncode(labelmap, colors, mode='BGR'):
    labelmap = labelmap.astype('int')
    labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                            dtype=np.uint8)
    for label in unique(labelmap):
        if label < 0:
            continue
        labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
                        np.tile(colors[label],
                                (labelmap.shape[0], labelmap.shape[1], 1))

    if mode == 'BGR':
        return labelmap_rgb[:, :, ::-1]
    else:
        return labelmap_rgb


def accuracy(preds, label):
    valid = (label >= 0)
    acc_sum = (valid * (preds == label)).sum()
    valid_sum = valid.sum()
    acc = float(acc_sum) / (valid_sum + 1e-10)
    return acc, valid_sum


def intersectionAndUnion(imPred, imLab, numClass):
    imPred = np.asarray(imPred).copy()
    imLab = np.asarray(imLab).copy()

    imPred += 1
    imLab += 1
    # Remove classes from unlabeled pixels in gt image.
    # We should not penalize detections in unlabeled portions of the image.
    imPred = imPred * (imLab > 0)

    # Compute area intersection:
    intersection = imPred * (imPred == imLab)
    (area_intersection, _) = np.histogram(
        intersection, bins=numClass, range=(1, numClass))

    # Compute area union:
    (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
    (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
    area_union = area_pred + area_lab - area_intersection

    return (area_intersection, area_union)
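`intersectionAndUnion` shifts labels by +1 so class ids land in histogram range `(1, numClass)`, masks unlabeled pixels, then derives the union as `pred + label - intersection`. A standalone re-implementation of the same steps on a 2x2 toy case (mirrors the function above; names are mine):

```python
import numpy as np

def intersection_and_union(im_pred, im_lab, num_class):
    # Mirrors the repo's intersectionAndUnion: classes become 1..num_class
    # after the +1 shift; pixels unlabeled in im_lab are zeroed out of im_pred.
    im_pred = np.asarray(im_pred) + 1
    im_lab = np.asarray(im_lab) + 1
    im_pred = im_pred * (im_lab > 0)
    inter = im_pred * (im_pred == im_lab)
    area_inter, _ = np.histogram(inter, bins=num_class, range=(1, num_class))
    area_pred, _ = np.histogram(im_pred, bins=num_class, range=(1, num_class))
    area_lab, _ = np.histogram(im_lab, bins=num_class, range=(1, num_class))
    return area_inter, area_pred + area_lab - area_inter

pred = np.array([[0, 0], [1, 1]])
lab = np.array([[0, 1], [1, 1]])
inter, union = intersection_and_union(pred, lab, num_class=2)
print(inter, union)   # per-class IoU here is 1/2 and 2/3
```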


class NotSupportedCliException(Exception):
    pass


def process_range(xpu, inp):
    start, end = map(int, inp)
    if start > end:
        end, start = start, end
    return map(lambda x: '{}{}'.format(xpu, x), range(start, end + 1))


REGEX = [
    (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]),
    (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]),
    (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'),
     functools.partial(process_range, 'gpu')),
    (re.compile(r'^(\d+)-(\d+)$'),
     functools.partial(process_range, 'gpu')),
]


def parse_devices(input_devices):
    """Parse user's devices input str to standard format.
    e.g. [gpu0, gpu1, ...]

    """
    ret = []
    for d in input_devices.split(','):
        for regex, func in REGEX:
            m = regex.match(d.lower().strip())
            if m:
                tmp = func(m.groups())
                # prevent duplicate
                for x in tmp:
                    if x not in ret:
                        ret.append(x)
                break
        else:
            raise NotSupportedCliException(
                'Can not recognize device: "%s"' % d)
    return ret
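`parse_devices` accepts comma-separated specs like `"0"`, `"gpu1"`, or ranges `"2-3"` / `"gpu2-gpu3"`, normalizes each to `gpuN` names, and drops duplicates. A condensed single-regex sketch of the same behavior (not the `REGEX` table above):

```python
import re

def parse_devices(spec):
    """Normalize strings like "0,2-3,gpu5" into ["gpu0", "gpu2", "gpu3", "gpu5"]."""
    ret = []
    for d in spec.split(","):
        d = d.strip().lower()
        m = re.match(r"^(?:gpu)?(\d+)(?:-(?:gpu)?(\d+))?$", d)
        if not m:
            raise ValueError(f"cannot recognize device: {d!r}")
        start = int(m.group(1))
        end = int(m.group(2)) if m.group(2) else start
        if start > end:                      # tolerate reversed ranges
            start, end = end, start
        for i in range(start, end + 1):
            name = f"gpu{i}"
            if name not in ret:              # prevent duplicates
                ret.append(name)
    return ret

print(parse_devices("0,2-3,gpu5"))   # ['gpu0', 'gpu2', 'gpu3', 'gpu5']
```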
SYMBOL INDEX (198 symbols across 16 files)

FILE: data/common.py
  function _assert_valid_param (line 6) | def _assert_valid_param(param):
  function assert_valid_param (line 12) | def assert_valid_param(param):
  function _fit_line (line 17) | def _fit_line(pts):
  function fit_line (line 28) | def fit_line(pts):
  function dist_pts_to_line (line 32) | def dist_pts_to_line(pts, param):
  function assert_pts_in_line (line 41) | def assert_pts_in_line(pts, param, atol=1.):
  function _find_pt_in_line (line 47) | def _find_pt_in_line(param):
  function find_pt_in_line (line 57) | def find_pt_in_line(param):
  function project_pts_on_line (line 61) | def project_pts_on_line(pts, param):
  function find_lines_intersect (line 74) | def find_lines_intersect(params):
  function is_pt_in_line_seg (line 90) | def is_pt_in_line_seg(eps, pt, pt1, pt2):

FILE: data/line_graph.py
  class LineGraph (line 11) | class LineGraph(object):
    method __init__ (line 12) | def __init__(
    method load (line 27) | def load(self, filename):
    method save (line 39) | def save(self, filename):
    method _is_pt_in_line_seg (line 52) | def _is_pt_in_line_seg(self, pt, pt1, pt2):
    method freeze_junction (line 60) | def freeze_junction(self, status=True):
    method _can_be_extented_by (line 80) | def _can_be_extented_by(self, line_seg1, line_seg2):
    method freeze_line_seg (line 93) | def freeze_line_seg(self, status=True):
    method add_junction (line 377) | def add_junction(self, junction):
    method add_line_seg (line 380) | def add_line_seg(self, junction1, junction2):
    method junctions (line 401) | def junctions(self):
    method line_segs (line 406) | def line_segs(self):
    method longest_line_segs (line 412) | def longest_line_segs(self):
    method adj_mtx (line 417) | def adj_mtx(self):
    method line_map (line 426) | def line_map(self, size, scale_x=1., scale_y=1., line_width=2.):
    method num_junctions (line 443) | def num_junctions(self):
    method num_line_segs (line 447) | def num_line_segs(self):

FILE: data/sist_line.py
  class SISTLine (line 11) | class SISTLine(data.Dataset):
    method __init__ (line 12) | def __init__(self, data_root, transforms, phase="train", sigma_junctio...
    method __getitem__ (line 20) | def __getitem__(self, item):
    method __call__ (line 73) | def __call__(self, item):
    method __len__ (line 76) | def __len__(self):
  function readnsave (line 88) | def readnsave(i):
  function juncsave (line 93) | def juncsave(i):

FILE: data/transforms.py
  class Compose (line 8) | class Compose(object):
    method __init__ (line 9) | def __init__(self, *transforms):
    method __call__ (line 12) | def __call__(self, img, pt):
  class RandomCompose (line 19) | class RandomCompose(object):
    method __init__ (line 20) | def __init__(self, *transforms):
    method __call__ (line 23) | def __call__(self, img, pt):
  class Resize (line 31) | class Resize(object):
    method __init__ (line 32) | def __init__(self, size, interpolation=Image.BILINEAR):
    method __call__ (line 35) | def __call__(self, img, pt):
  class RandomHorizontalFlip (line 49) | class RandomHorizontalFlip(object):
    method __init__ (line 50) | def __init__(self, p=0.5):
    method __call__ (line 53) | def __call__(self, img, pt):
  class RandomColorAug (line 64) | class RandomColorAug(object):
    method __init__ (line 65) | def __init__(self, factor=0.2):
    method __call__ (line 68) | def __call__(self, img, pt):

FILE: data/utils.py
  function apply_gaussian (line 6) | def apply_gaussian(accumulate_confid_map, centers, xx, yy, sigma):
  function gen_gaussian_map (line 17) | def gen_gaussian_map(centers, shape, sigma):

FILE: data/york_urban.py
  class YorkUrban (line 11) | class YorkUrban(data.Dataset):
    method __init__ (line 12) | def __init__(self, data_root, transforms, phase="test", sigma_junction...
    method __getitem__ (line 22) | def __getitem__(self, item):
    method __call__ (line 70) | def __call__(self, item):
    method __len__ (line 73) | def __len__(self):

FILE: main.py
  function weight_fn (line 25) | def weight_fn(dist_map, max_dist, mid=0.1, scale=10):
  class LSDTrainer (line 33) | class LSDTrainer(object):
    method __init__ (line 34) | def __init__(
    method _group_weight (line 151) | def _group_weight(module, lr):
    method end (line 177) | def end(self):
    method _train_epoch (line 181) | def _train_epoch(self):
    method _vis_train (line 286) | def _vis_train(self, epoch, i, len_loader, img, heatmap, adj_mtx, junc...
    method _checkpoint (line 317) | def _checkpoint(self):
    method train (line 336) | def train(
    method _vis_eval (line 365) | def _vis_eval(self, epoch, i, len_loader, img, heatmap, adj_mtx, junct...
    method eval (line 404) | def eval(self,

FILE: models/backbone.py
  class Bottleneck (line 23) | class Bottleneck(nn.Module):
    method __init__ (line 26) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 39) | def forward(self, x):
  class ResNet (line 62) | class ResNet(nn.Module):
    method __init__ (line 64) | def __init__(self, block, layers, num_classes=1000):
    method _make_layer (line 93) | def _make_layer(self, block, planes, blocks, stride=1):
    method forward (line 110) | def forward(self, x):
  class ResnetDilated (line 128) | class ResnetDilated(nn.Module):
    method __init__ (line 129) | def __init__(self, orig_resnet, dilate_scale=8):
    method _nostride_dilate (line 159) | def _nostride_dilate(m, dilate):
    method forward (line 174) | def forward(self, x):
  function resnet50 (line 194) | def resnet50(pretrained=False, **kwargs):
  function load_url (line 206) | def load_url(url, model_dir='./pretrained', map_location=None):
  class UPerNet (line 217) | class UPerNet(nn.Module):
    method __init__ (line 218) | def __init__(self, num_class=150, fc_dim=4096, pool_scales=(1, 2, 3, 6),
    method forward (line 259) | def forward(self, conv_out, segSize=None):
  class ResNetU50Backbone (line 300) | class ResNetU50Backbone(nn.Module):
    method __init__ (line 301) | def __init__(self, dim_embedding=256, encoder_weights="", decoder_weig...
    method forward (line 306) | def forward(self, img):
    method build_encoder (line 313) | def build_encoder(weights=''):
    method build_decoder (line 327) | def build_decoder(weights='', dim_embedding=64):
  class UNetBackbone (line 346) | class UNetBackbone(nn.Module):
    method __init__ (line 347) | def __init__(self, dim_embedding=256, n_downs=5, n_ups=3, weights=""):
    method forward (line 379) | def forward(self, x):

FILE: models/common.py
  class LMFPeakFinder (line 13) | class LMFPeakFinder(object):
    method __init__ (line 21) | def __init__(self, min_dist=5., min_th=0.3):
    method detect (line 25) | def detect(self, image):
  function conv3x3 (line 75) | def conv3x3(in_planes, out_planes, stride=1):
  function conv3x3_bn_relu (line 81) | def conv3x3_bn_relu(in_planes, out_planes, stride=1):
  class double_conv (line 89) | class double_conv(nn.Module):
    method __init__ (line 92) | def __init__(self, in_ch, out_ch):
    method forward (line 103) | def forward(self, x):
  class inconv (line 108) | class inconv(nn.Module):
    method __init__ (line 109) | def __init__(self, in_ch, out_ch):
    method forward (line 113) | def forward(self, x):
  class down (line 118) | class down(nn.Module):
    method __init__ (line 119) | def __init__(self, in_ch, out_ch):
    method forward (line 126) | def forward(self, x):
  class up (line 131) | class up(nn.Module):
    method __init__ (line 132) | def __init__(self, in_ch, out_ch, bilinear=True):
    method forward (line 144) | def forward(self, x1, x2):
  class outconv (line 155) | class outconv(nn.Module):
    method __init__ (line 156) | def __init__(self, in_ch, out_ch):
    method forward (line 160) | def forward(self, x):
  function weights_init (line 165) | def weights_init(m):
  function roi_pooling (line 174) | def roi_pooling(input, rois, size=(7, 7), spatial_scale=1.0):
  class GradAccumulatorFunction (line 204) | class GradAccumulatorFunction(Function):
    method forward (line 206) | def forward(ctx, input, accumulated_grad=None, mode="release"):
    method backward (line 213) | def backward(ctx, grad_output):
  class GradAccumulator (line 230) | class GradAccumulator(nn.Module):
    method __init__ (line 238) | def __init__(self, criterion_fns, submodules, collect_fn=None, reduce_...
    method forward (line 260) | def forward(self, tensor):

FILE: models/graph.py
  class JunctionInference (line 7) | class JunctionInference(nn.Module):
    method __init__ (line 8) | def __init__(self, dim_embedding, pooling_threshold=0.2, max_junctions...
    method forward (line 23) | def forward(self, feat):
  class LinePooling (line 60) | class LinePooling(nn.Module):
    method __init__ (line 61) | def __init__(self, align_size=256, spatial_scale=0.25):
    method forward (line 67) | def forward(self, feat, coord_st, coord_ed):
  class AdjacencyMatrixInference (line 93) | class AdjacencyMatrixInference(nn.Module):
    method __init__ (line 94) | def __init__(self, dim_embedding=256, align_size=256):
    method forward (line 114) | def forward(self, line_feat):

FILE: models/lsd.py
  class LSDDataLayer (line 10) | class LSDDataLayer(nn.Module):
    method __init__ (line 11) | def __init__(self, mean=None, std=None):
    method forward (line 16) | def forward(self, img):
  class BinaryFocalLoss (line 25) | class BinaryFocalLoss(nn.Module):
    method __init__ (line 26) | def __init__(self, gamma=2., alpha=0.25, size_average=True):
    method forward (line 32) | def forward(self, input, target, weight=None):
  class BlockAdjacencyMatrixInference (line 51) | class BlockAdjacencyMatrixInference(nn.Module):
    method __init__ (line 52) | def __init__(self,
    method forward (line 66) | def forward(self, feat):
  class BlockAdjacencyMatrixInferenceCriterion (line 76) | class BlockAdjacencyMatrixInferenceCriterion(nn.Module):
    method __init__ (line 77) | def __init__(self, adj_matrix_crit, adj_matrix_gt, adj_matrix_loss_lam...
    method forward (line 94) | def forward(self, block_adj_matrix):
  class LSDModule (line 109) | class LSDModule(nn.Module):
    method __init__ (line 110) | def __init__(
    method forward (line 195) | def forward(self, img, junc_map_gt, adj_matrix_gt, junc_loss_lambda=1....

FILE: models/lsd_test.py
  class LSDDataLayer (line 11) | class LSDDataLayer(nn.Module):
    method __init__ (line 12) | def __init__(self, mean=None, std=None):
    method forward (line 17) | def forward(self, img):
  class BinaryFocalLoss (line 26) | class BinaryFocalLoss(nn.Module):
    method __init__ (line 27) | def __init__(self, gamma=2., alpha=0.25, size_average=True):
    method forward (line 33) | def forward(self, input, target, weight=None):
  class BlockAdjacencyMatrixInference (line 52) | class BlockAdjacencyMatrixInference(nn.Module):
    method __init__ (line 53) | def __init__(self,
    method forward (line 67) | def forward(self, feat):
  class BlockAdjacencyMatrixInferenceCriterion (line 77) | class BlockAdjacencyMatrixInferenceCriterion(nn.Module):
    method __init__ (line 78) | def __init__(self, adj_matrix_crit, adj_matrix_gt, adj_matrix_loss_lam...
    method forward (line 95) | def forward(self, block_adj_matrix):
  class LSDTestModule (line 110) | class LSDTestModule(LSDModule):
    method __init__ (line 111) | def __init__(
    method forward (line 146) | def forward(self, img):

FILE: models/test_common.py
  class TestRoiPooling (line 10) | class TestRoiPooling(TestCase):
    method setUp (line 11) | def setUp(self):
    method test_output_size_7x7 (line 30) | def test_output_size_7x7(self):
    method test_output_size_5x5 (line 43) | def test_output_size_5x5(self):
    method test_output_size_1x1 (line 54) | def test_output_size_1x1(self):
    method test_output_value (line 64) | def test_output_value(self):
  class TestGradAccumulator (line 96) | class TestGradAccumulator(TestCase):
    method setUp (line 97) | def setUp(self):
    method test_forward_pass (line 118) | def test_forward_pass(self):
    method test_backward_pass (line 169) | def test_backward_pass(self):

FILE: models/test_graph.py
  class TestJunctionInference (line 7) | class TestJunctionInference(TestCase):
    method test_junction_inference_forward (line 8) | def test_junction_inference_forward(self):
  class TestJunctionPooling (line 19) | class TestJunctionPooling(TestCase):
    method test_junc_pooling_forward (line 20) | def test_junc_pooling_forward(self):
  class TestDirectionalAttention (line 32) | class TestDirectionalAttention(TestCase):
    method test_attention_forward (line 33) | def test_attention_forward(self):
  class TestAdjacencyMatrixInference (line 62) | class TestAdjacencyMatrixInference(TestCase):
    method test_adjacency_matrix_inference_forward (line 63) | def test_adjacency_matrix_inference_forward(self):

FILE: test.py
  class LSD (line 26) | class LSD(object):
    method __init__ (line 27) | def __init__(
    method end (line 87) | def end(self):
    method test (line 91) | def test(self, path_to_image):

FILE: utils.py
  function graph2line (line 9) | def graph2line(junctions, adj_mtx, threshold=0.5):
  function draw_lines (line 32) | def draw_lines(imgs, lines, scores=None, width=2):
  function draw_jucntions (line 57) | def draw_jucntions(hms, junctions):
  class AverageMeter (line 86) | class AverageMeter(object):
    method __init__ (line 89) | def __init__(self):
    method initialize (line 96) | def initialize(self, val, weight):
    method update (line 103) | def update(self, val, weight=1):
    method add (line 109) | def add(self, val, weight):
    method value (line 115) | def value(self):
    method average (line 118) | def average(self):
  function unique (line 122) | def unique(ar, return_index=False, return_inverse=False, return_counts=F...
  function colorEncode (line 165) | def colorEncode(labelmap, colors, mode='BGR'):
  function accuracy (line 182) | def accuracy(preds, label):
  function intersectionAndUnion (line 190) | def intersectionAndUnion(imPred, imLab, numClass):
  class NotSupportedCliException (line 213) | class NotSupportedCliException(Exception):
  function process_range (line 217) | def process_range(xpu, inp):
  function parse_devices (line 234) | def parse_devices(input_devices):
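The `graph2line` helper indexed above converts predicted junctions and a junction-adjacency matrix into line segments by thresholding pairwise connectivity, which is the core of the point-pair graph formulation. A sketch under the indexed signature — the output layout (list of endpoint pairs plus scores) is an assumption:

```python
import numpy as np

def graph2line(junctions, adj_mtx, threshold=0.5):
    """Sketch: emit a line segment for every junction pair whose
    adjacency score exceeds the threshold. Return format is assumed."""
    lines, scores = [], []
    n = len(junctions)
    for i in range(n):
        for j in range(i + 1, n):  # upper triangle: count each pair once
            if adj_mtx[i, j] > threshold:
                lines.append((junctions[i], junctions[j]))
                scores.append(float(adj_mtx[i, j]))
    return lines, scores
```

Iterating only the upper triangle avoids emitting each segment twice, since the adjacency matrix is symmetric for undirected line connectivity.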
Condensed preview — 27 files, each showing path, character count, and a content snippet (148K chars of full structured content).
[
  {
    "path": ".gitattributes",
    "chars": 42,
    "preview": "*.pth filter=lfs diff=lfs merge=lfs -text\n"
  },
  {
    "path": ".gitignore",
    "chars": 27,
    "preview": ".git\n.idea\nlog\n__pycache__\n"
  },
  {
    "path": "LICENSE",
    "chars": 1056,
    "preview": "MIT License\n\nCopyright (c) 2019\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this so"
  },
  {
    "path": "README.md",
    "chars": 2283,
    "preview": "# PPGNet: Learning Point-Pair Graph for Line Segment Detection\n\nPyTorch implementation of our CVPR 2019 paper:\n\n[**PPGNe"
  },
  {
    "path": "ckpt/backbone/decoder_epoch_20.pth",
    "chars": 134,
    "preview": "version https://git-lfs.github.com/spec/v1\noid sha256:83be3696848929d3ed93deec1fbe31c94b8acbd40110bf1604e11fad024784fc\ns"
  },
  {
    "path": "ckpt/backbone/encoder_epoch_20.pth",
    "chars": 133,
    "preview": "version https://git-lfs.github.com/spec/v1\noid sha256:bb24707e745689a005ca0c857d6c516afd2f532a233b2697f4832333c8af48d9\ns"
  },
  {
    "path": "data/common.py",
    "chars": 2460,
    "preview": "import numpy as np\nfrom functools import lru_cache as cache\n\n\n@cache(maxsize=None)\ndef _assert_valid_param(param):\n    A"
  },
  {
    "path": "data/line_graph.py",
    "chars": 23917,
    "preview": "import os\nfrom sklearn.neighbors import KDTree\nfrom scipy.cluster.hierarchy import fclusterdata\nimport pickle\nfrom itert"
  },
  {
    "path": "data/sist_line.py",
    "chars": 3522,
    "preview": "import os\nimport numpy as np\nimport torch as th\nfrom torch.utils import data\nfrom data.line_graph import LineGraph\nfrom "
  },
  {
    "path": "data/transforms.py",
    "chars": 2094,
    "preview": "from torchvision.transforms import functional as tf\nimport numpy as np\nfrom PIL import Image\nimport random\nfrom functool"
  },
  {
    "path": "data/utils.py",
    "chars": 1160,
    "preview": "from numba import jit, float32, int32\nimport numpy as np    \n    \n\n@jit(float32[:, :](float32[:, :], float32[:, :], int3"
  },
  {
    "path": "data/york_urban.py",
    "chars": 5416,
    "preview": "import os\nimport numpy as np\nimport torch as th\nfrom torch.utils import data\nfrom data.line_graph import LineGraph\nfrom "
  },
  {
    "path": "main.py",
    "chars": 21120,
    "preview": "# System libs\nimport os\nimport time\n\n# Numerical libs\nimport torch\nimport torch.nn as nn\nfrom torch import optim\nfrom to"
  },
  {
    "path": "models/__init__.py",
    "chars": 17,
    "preview": "from . import lsd"
  },
  {
    "path": "models/backbone.py",
    "chars": 13451,
    "preview": "import os\nimport sys\nimport torch\nimport torch.nn as nn\nimport math\nfrom models.common import conv3x3, conv3x3_bn_relu, "
  },
  {
    "path": "models/common.py",
    "chars": 11132,
    "preview": "from scipy.ndimage.filters import maximum_filter\nfrom scipy.ndimage.morphology import generate_binary_structure, binary_"
  },
  {
    "path": "models/graph.py",
    "chars": 5510,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom models.common import conv3x3_bn_relu, LMFPeakFin"
  },
  {
    "path": "models/lsd.py",
    "chars": 13359,
    "preview": "import torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport models.graph\nimport models.backbone\nimpo"
  },
  {
    "path": "models/lsd_test.py",
    "chars": 9009,
    "preview": "import torch as th\nimport torch.nn as nn\nimport models.graph\nimport models.backbone\nimport models.common\nimport numpy as"
  },
  {
    "path": "models/pretrained/resnet50-imagenet.pth",
    "chars": 132,
    "preview": "version https://git-lfs.github.com/spec/v1\noid sha256:902bbfbc9b570be36e0f94e757b211b24f9216df14260c58b359ddb182c0723e\ns"
  },
  {
    "path": "models/test_common.py",
    "chars": 8969,
    "preview": "from unittest import TestCase\nimport numpy as np\nimport models.common as common\nimport torch as th\nimport torch.nn as nn"
  },
  {
    "path": "models/test_graph.py",
    "chars": 3317,
    "preview": "from unittest import TestCase\nimport torch as th\nimport models.graph as graph\nimport matplotlib.pyplot as plt\n\n\nclass Te"
  },
  {
    "path": "test.py",
    "chars": 3986,
    "preview": "# System libs\nimport os\nimport time\n\n# Numerical libs\nimport torch\nimport torch.nn as nn\nfrom torch import optim\nfrom to"
  },
  {
    "path": "test.sh",
    "chars": 410,
    "preview": "python test.py \\\n--exp-name line_weighted_wo_focal_junc --backbone resnet50 \\\n--backbone-kwargs '{\"encoder_weights\": \"ck"
  },
  {
    "path": "tools/rebuild_yorkurban.py",
    "chars": 1692,
    "preview": "from data.line_graph import LineGraph\nimport os\nfrom scipy import io\nimport numpy as np\nfrom shutil import copyfile\nfrom"
  },
  {
    "path": "train.sh",
    "chars": 872,
    "preview": "python main.py \\\n--exp-name line_weighted_wo_focal_junc --backbone resnet50 \\\n--backbone-kwargs '{\"encoder_weights\": \"ck"
  },
  {
    "path": "utils.py",
    "chars": 7480,
    "preview": "import numpy as np\nimport re\nimport functools\nimport torch as th\nimport cv2\nfrom numba import jit\n\n\ndef graph2line(junct"
  }
]

About this extraction

This page contains the full source code of the svip-lab/PPGNet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 27 files (139.4 KB), approximately 37.3k tokens, and a symbol index with 198 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
