[
  {
    "path": ".gitattributes",
    "content": "*.pth filter=lfs diff=lfs merge=lfs -text\n"
  },
  {
    "path": ".gitignore",
    "content": ".git\n.idea\nlog\n__pycache__\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# PPGNet: Learning Point-Pair Graph for Line Segment Detection\n\nPyTorch implementation of our CVPR 2019 paper:\n\n[**PPGNet: Learning Point-Pair Graph for Line Segment Detection**](https://www.aiyoggle.me/publication/ppgnet-cvpr19/ppgnet-cvpr19.pdf)\n\nZiheng Zhang*, Zhengxin Li*, Ning Bi, Jia Zheng, Jinlei Wang, Kun Huang, Weixin Luo, Yanyu Xu, Shenghua Gao\n\n(\\* Equal Contribution)\n\nThe poster can be found [HERE](https://www.aiyoggle.me/publication/ppgnet-cvpr19).\n\n\n![pipe-line](https://svip-lab.github.io/img/project/cvpr2019_zhangzh1.png)\n**Demonstraton of juncton-line graph representaton G={V, E}. (a) an sample image patch with 10 junctons (V); (b) the graph which describes the connectvity of all junctons (G); (c) the adjacency matrix of all junctons (E, black means the junction pair is connected).** \n\n## Requirements\n- Python >= 3.6\n- fire >= 0.1.3\n- numba >= 0.40.0\n- numpy >= 1.14.5\n- pytorch = 0.4.1\n- scikit-learn = 0.19.2\n- scipy = 1.1.0\n- tensorboard >= 1.11.0\n- tensorboardX >= 1.4\n- torchvision >= 0.2.1\n- OpenCV >= 3.4.3\n\n## Usage\n\n1. clone this repository (and make sure you fetch all .pth files right with [git-lfs](https://git-lfs.github.com/)): `git clone https://github.com/svip-lab/PPGNet.git`\n2. download the preprocessed *SIST-Wireframe* dataset from [BaiduPan](https://pan.baidu.com/s/1Sbdi1lL492fhmPL1t1Ov0w) (code:lnfp) or [Google Drive](https://drive.google.com/file/d/1KggPcHCRu8BcOqCvVZCXiB64y9L2nQDf/view?usp=sharing).\n3. specify the dataset path in the `train.sh` script. (modify the --data-root parameter)\n4. run `train.sh`.\n\nPlease note that the code requires the GPU memory to be at least 24GB. For GPU with memory smaller than 24GB, you can use a smaller batch with `--batch-size` parameter and/or change the `--block-inference-size` parameter in `train.sh` to be a smaller integer to avoid the out-of-memory error.\n\n## Citation\n\nPlease cite our paper for any purpose of usage.\n```\n@inproceedings{zhang2019ppgnet,\n  title={PPGNet: Learning Point-Pair Graph for Line Segment Detection},\n  author={Ziheng Zhang and Zhengxin Li and Ning Bi and Jia Zheng and Jinlei Wang and Kun Huang and Weixin Luo and Yanyu Xu and Shenghua Gao},\n  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},\n  year={2019}\n}\n```\n\n"
  },
  {
    "path": "ckpt/backbone/decoder_epoch_20.pth",
    "content": "version https://git-lfs.github.com/spec/v1\noid sha256:83be3696848929d3ed93deec1fbe31c94b8acbd40110bf1604e11fad024784fc\nsize 162474609\n"
  },
  {
    "path": "ckpt/backbone/encoder_epoch_20.pth",
    "content": "version https://git-lfs.github.com/spec/v1\noid sha256:bb24707e745689a005ca0c857d6c516afd2f532a233b2697f4832333c8af48d9\nsize 95013117\n"
  },
  {
    "path": "data/common.py",
    "content": "import numpy as np\nfrom functools import lru_cache as cache\n\n\n@cache(maxsize=None)\ndef _assert_valid_param(param):\n    A, B, C = param\n    assert not np.isclose(A ** 2 + B ** 2, 0), \"invalid line param.\"\n    return np.array(param) / np.sqrt(A ** 2 + B ** 2)\n\n\ndef assert_valid_param(param):\n    return _assert_valid_param(tuple(param))\n\n\n@cache(maxsize=None)\ndef _fit_line(pts):\n    P = np.array(pts)\n    P = np.hstack((P, np.ones((len(P), 1))))\n    assert np.linalg.matrix_rank(P) >= 2, f\"points to fit line are not valid: {P}\"\n    u, s, vt = np.linalg.svd(P)\n    param = assert_valid_param(vt[-1])\n    res = np.linalg.norm(P.dot(param)) / len(P)\n\n    return param, res\n\n\ndef fit_line(pts):\n    return _fit_line(tuple(tuple(pt) for pt in pts))\n\n\ndef dist_pts_to_line(pts, param):\n    param = assert_valid_param(param)\n    P = np.array([pt for pt in pts])\n    P = np.hstack((P, np.ones((len(P), 1))))\n    dists = np.abs(P.dot(param))\n\n    return dists\n\n\ndef assert_pts_in_line(pts, param, atol=1.):\n    dists = np.array(dist_pts_to_line(pts, param))\n    assert np.all(dists < atol)\n\n\n@cache(maxsize=None)\ndef _find_pt_in_line(param):\n    A, B, C = assert_valid_param(param)\n    if np.abs(A) > np.abs(B):\n        x, y = -C / A, 0\n    else:\n        x, y = 0, -C / B\n\n    return np.array([x, y])\n\n\ndef find_pt_in_line(param):\n    return _find_pt_in_line(tuple(param))\n\n\ndef project_pts_on_line(pts, param):\n    param = assert_valid_param(param)\n    P = np.array([pt for pt in pts])\n    pt0 = np.array(find_pt_in_line(param))\n    e = np.array([-param[1], param[0]])\n    alpha = (P - pt0).dot(e)\n    P_proj = np.outer(alpha, e) + pt0\n\n    assert P_proj.ndim == 2 and alpha.ndim == 1, f\"internal error occored when project pts to line: {P_proj}, {alpha}\"\n\n    return P_proj, alpha\n\n\ndef find_lines_intersect(params):\n    P = []\n    for param in params:\n        param = assert_valid_param(param)\n        P.append(param)\n    P = np.array(P)\n    assert np.linalg.matrix_rank(P) >= 2, \"lines do not intersect\"\n    u, s, vh = np.linalg.svd(P)\n    x, y, _ = vh[-1] / vh[-1][-1]\n    hpt = np.array([x, y, 1])\n    dist = np.abs(P.dot(hpt)).mean()\n\n    return np.array([x, y]), dist\n\n\n@cache(maxsize=None)\ndef is_pt_in_line_seg(eps, pt, pt1, pt2):\n    param, _ = fit_line([pt1, pt2])\n    _, alphas = project_pts_on_line([pt, pt1, pt2], param)\n    dist = dist_pts_to_line([pt], param)[0]\n    return (dist < eps * 2) and (np.min(alphas[1:]) <= alphas[0] <= np.max(alphas[1:]))\n"
  },
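  {
    "path": "examples/common_geometry_demo.py",
    "content": "\"\"\"Hypothetical usage sketch for the geometry helpers in data/common.py.\n\nThe example path and the sample points below are made up for illustration;\nthe calls only exercise fit_line, dist_pts_to_line and project_pts_on_line\nas defined in data/common.py.\n\"\"\"\nfrom data.common import dist_pts_to_line, fit_line, project_pts_on_line\n\n# three nearly collinear points close to the line y = 2x + 1\npts = [(0., 1.), (1., 3.), (2., 5.1)]\n\n# fit_line returns normalized line parameters (A, B, C) with A^2 + B^2 = 1\n# and the mean residual of the SVD fit\nparam, res = fit_line(pts)\nprint('line param (A, B, C):', param, 'residual:', res)\n\n# absolute point-to-line distances for the same points\nprint('distances:', dist_pts_to_line(pts, param))\n\n# project the points onto the fitted line; alpha is the 1-D coordinate of\n# each projection along the line direction\nproj, alpha = project_pts_on_line(pts, param)\nprint('projections:', proj)\nprint('alphas:', alpha)\n"
  },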
  {
    "path": "data/line_graph.py",
    "content": "import os\nfrom sklearn.neighbors import KDTree\nfrom scipy.cluster.hierarchy import fclusterdata\nimport pickle\nfrom itertools import combinations\nimport collections\nfrom data.common import *\nimport cv2\n\n\nclass LineGraph(object):\n    def __init__(\n            self, eps_junction=3., eps_line_deg=np.pi / 20, verbose=False\n    ):\n        self._line_segs = []\n        self._junctions = []\n        self._end_points = []\n        self._refine_junctions = None\n        self._neighbor = {}\n        self._junc2line = {}\n        self._freeze_junction = False\n        self._kdtree = None\n        self._eps_junc = eps_junction\n        self._eps_line_seg = eps_line_deg\n        self.verbose = verbose\n\n    def load(self, filename):\n        with open(filename, \"rb\") as f:\n            data = pickle.load(f)\n            for mem in dir(self):\n                if (\n                        mem.startswith(\"_\")\n                        and not mem.startswith(\"__\")\n                        and not isinstance(getattr(self, mem), collections.Callable)\n                ):\n                    setattr(self, mem, data[mem])\n        return self\n\n    def save(self, filename):\n        with open(filename, \"wb+\") as f:\n            data = {}\n            for mem in dir(self):\n                if (\n                        mem.startswith(\"_\")\n                        and not mem.startswith(\"__\")\n                        and not isinstance(getattr(self, mem), collections.Callable)\n                ):\n                    data[mem] = getattr(self, mem)\n            pickle.dump(data, f)\n        return self\n\n    def _is_pt_in_line_seg(self, pt, pt1, pt2):\n\n        return is_pt_in_line_seg(self._eps_junc, tuple(pt), tuple(pt1), tuple(pt2))\n        # param, _ = fit_line([pt1, pt2])\n        # _, alphas = project_pts_on_line([pt, pt1, pt2], param)\n        # dist = dist_pts_to_line([pt], param)[0]\n        # return (dist < self._eps_junc * 2) and (np.min(alphas[1:]) <= alphas[0] <= np.max(alphas[1:]))\n\n    def freeze_junction(self, status=True):\n        self._freeze_junction = status\n        if status:\n            clusters = fclusterdata(self._junctions, self._eps_junc, criterion=\"distance\")\n            junc_groups = {}\n            for ind_junc, ind_group in enumerate(clusters):\n                if ind_group not in junc_groups.keys():\n                    junc_groups[ind_group] = []\n                junc_groups[ind_group].append(self._junctions[ind_junc])\n            if self.verbose:\n                print(f\"{len(self._junctions) - len(junc_groups)} junctions merged.\")\n            self._junctions = [np.mean(junc_group, axis=0) for junc_group in junc_groups.values()]\n\n            self._kdtree = KDTree(self._junctions, leaf_size=30)\n            dists, inds = self._kdtree.query(self._junctions, k=2)\n            repl_inds = np.nonzero(dists.sum(axis=1) < self._eps_junc)[0].tolist()\n            # assert len(repl_inds) == 0\n        else:\n            self._kdtree = None\n\n    def _can_be_extented_by(self, line_seg1, line_seg2):\n        pt1 = line_seg1[\"pt1\"]\n        pt2 = line_seg1[\"pt2\"]\n        _pt1 = line_seg2[\"pt1\"]\n        _pt2 = line_seg2[\"pt2\"]\n        if self._is_pt_in_line_seg(_pt1, pt1, pt2) or self._is_pt_in_line_seg(_pt2, pt1, pt2):\n            arr1 = pt1 - pt2\n            arr2 = _pt1 - _pt2\n            arr1 /= np.linalg.norm(arr1)\n            arr2 /= np.linalg.norm(arr2)\n            if np.abs(arr1.dot(arr2)) > 
np.cos(self._eps_line_seg):\n                return True\n\n    def freeze_line_seg(self, status=True):\n        assert self._freeze_junction, \"junction should be freezed before.\"\n        self._freeze_line_seg = status\n        if status:\n            # remove all junctions that are not in any line segment\n            junc_remove = set(list(range(len(self._junctions))))\n            for line_seg in self._line_segs:\n                junc_remove -= line_seg[\"junctions\"]\n\n            junc_remain = [ind_junc for ind_junc in range(len(self._junctions)) if ind_junc not in junc_remove]\n            self._junctions = [junc for ind_junc, junc in enumerate(self._junctions) if ind_junc not in junc_remove]\n            map_old_to_new = {old: new for new, old in enumerate(junc_remain)}\n            for line_seg in self._line_segs:\n                line_seg[\"junctions\"] = set([map_old_to_new[old] for old in line_seg[\"junctions\"]])\n            self._junc2line = {}\n\n            # # extend all line segments\n            # cnt = 0\n            # finished = False\n            # while (not finished):\n            #     finished = True\n            #     for ind_ls1 in range(len(self._line_segs)):\n            #         for ind_ls2 in range(ind_ls1 + 1, len(self._line_segs)):\n            #             # if ind_ls1 == ind_ls2:\n            #             #     continue\n            #             ls1 = self._line_segs[ind_ls1]\n            #             ls2 = self._line_segs[ind_ls2]\n            #             if ls2[\"junctions\"].issubset(ls1[\"junctions\"]):\n            #                 continue\n            #             if self._can_be_extented_by(ls1, ls2):\n            #                 pt1 = ls1[\"pt1\"]\n            #                 pt2 = ls1[\"pt2\"]\n            #                 param = ls1[\"param\"]\n            #                 _pt1 = ls2[\"pt1\"]\n            #                 _pt2 = ls2[\"pt2\"]\n            #                 _param = ls2[\"param\"]\n            #                 P, alphas = project_pts_on_line([pt1, pt2, _pt1, _pt2], param)\n            #                 _P, _alphas = project_pts_on_line([pt1, pt2, _pt1, _pt2], _param)\n            #                 ind_min, ind_max = np.argmin(alphas), np.argmax(alphas)\n            #                 _ind_min, _ind_max = np.argmin(_alphas), np.argmax(_alphas)\n            #                 if np.abs(alphas[ind_min] - alphas[ind_max]) < self._eps_junc:  # this should not happen...\n            #                     continue\n            #                 ls1[\"pt1\"] = P[ind_min]\n            #                 ls1[\"pt2\"] = P[ind_max]\n            #                 ls2[\"pt1\"] = _P[_ind_min]\n            #                 ls2[\"pt2\"] = _P[_ind_max]\n            #                 ls1[\"junctions\"] = ls1[\"junctions\"].union(ls2[\"junctions\"])\n            #                 ls2[\"junctions\"] = ls2[\"junctions\"].union(ls1[\"junctions\"])\n            #                 cnt += 1\n            #                 finished = False\n            #\n            # if self.verbose:\n            #     print(f\"line segments extend {cnt} times.\", flush=True)\n            #\n            # # merge line segments of which junction set is subset of that of other line segments\n            # finished = False\n            # merged_inds = []\n            # while (not finished):\n            #     finished = True\n            #     for ind_ls1 in range(len(self._line_segs)):\n            #         for ind_ls2 in range(len(self._line_segs)):\n  
          #             if (ind_ls1 == ind_ls2) or (ind_ls2 in merged_inds):\n            #                 continue\n            #             ls1 = self._line_segs[ind_ls1]\n            #             ls2 = self._line_segs[ind_ls2]\n            #             if ls2[\"junctions\"].issubset(ls1[\"junctions\"]):\n            #                 merged_inds.append(ind_ls2)\n            #                 finished = False\n            # self._line_segs = [line_seg for ind, line_seg in enumerate(self._line_segs) if ind not in merged_inds]\n            # if self.verbose:\n            #     print(f\"{len(merged_inds)} line segments merged.\", flush=True)\n\n            # # refine line segments w.r.t all associated junctions\n            # ling_seg_remove = []\n            # for ind_line_seg, line_seg in enumerate(self._line_segs):\n            #     junc_set = line_seg[\"junctions\"]\n            #     if len(junc_set) == 2:  # no need to refine\n            #         continue\n            #     ind_innr_junctions = [ind_junc for ind_junc in junc_set]\n            #     param, res = fit_line([self._junctions[ind_junc] for ind_junc in ind_innr_junctions])\n            #     ind_outliers = np.nonzero(res > self._eps_junc * 2)[0]\n            #     # if too many outliers, remove this line segment\n            #     if (len(junc_set) - len(ind_outliers) < 2) or (len(ind_outliers) / len(junc_set) > 0.3):\n            #         # ling_seg_remove.append(ind_line_seg)\n            #         continue\n            #     # else remove outliers\n            #     for ind_outlier in ind_outliers:\n            #         # junc_set.remove(ind_innr_junctions[ind_outlier])\n            #         pass\n            #     line_seg[\"param\"] = param\n            #     # find new endpoints\n            #     ind_innr_junctions = [ind_junc for ind_junc in junc_set]\n            #     P, alphas = project_pts_on_line([self._junctions[ind_junc] for ind_junc in ind_innr_junctions], param)\n            #     ind_max, ind_min = np.argmax(alphas), np.argmin(alphas)\n            #     line_seg[\"pt1\"] = P[ind_min]\n            #     line_seg[\"pt2\"] = P[ind_max]\n            #\n            # self._line_segs = [line_seg for ind, line_seg in enumerate(self._line_segs) if ind not in ling_seg_remove]\n            # if self.verbose:\n            #     print(f\"{len(ling_seg_remove)} line segments removed.\", flush=True)\n\n            # # remove all junctions that are not in any line segment\n            # junc_remove = set(list(range(len(self._junctions))))\n            # for line_seg in self._line_segs:\n            #     junc_remove -= line_seg[\"junctions\"]\n            # junc_remain = [ind_junc for ind_junc in range(len(self._junctions)) if ind_junc not in junc_remove]\n            # self._junctions = [junc for ind_junc, junc in enumerate(self._junctions) if ind_junc not in junc_remove]\n            # map_old_to_new = {old: new for new, old in enumerate(junc_remain)}\n            # for line_seg in self._line_segs:\n            #     line_seg[\"junctions\"] = set([map_old_to_new[old] for old in line_seg[\"junctions\"]])\n            # self._junc2line = {}\n            #\n            # # build hash table mapping from junction id to line segment\n            # for line_seg in self._line_segs:\n            #     for ind_junc in line_seg[\"junctions\"]:\n            #         if ind_junc not in self._junc2line.keys():\n            #             self._junc2line[ind_junc] = []\n            #         
self._junc2line[ind_junc].append(line_seg)\n\n            # # remove possibly noise line segments\n            # line_seg_remove = []\n            # for ind_line_seg, line_seg in enumerate(self._line_segs):\n            #     junc_set = line_seg[\"junctions\"]\n            #     if len(junc_set) == 1: # not possible\n            #         line_seg_remove.append(ind_line_seg)\n            #     elif len(junc_set) == 2:\n            #         ind_junc1, ind_junc2 = list(junc_set)\n            #         if len(self._junc2line[ind_junc1]) >= 5 or len(self._junc2line[ind_junc2]) >= 5:\n            #             is_delete = True\n            #             for neigh_line_seg in self._junc2line[ind_junc1]:\n            #                 if neigh_line_seg is line_seg:\n            #                     continue\n            #                 arr1 = neigh_line_seg[\"pt1\"] - neigh_line_seg[\"pt2\"]\n            #                 arr2 = line_seg[\"pt1\"] - line_seg[\"pt2\"]\n            #                 if np.abs(arr1.dot(arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))) < np.cos(\n            #                         self._eps_line_seg):\n            #                     for neigh_line_seg2 in self._junc2line[ind_junc2]:\n            #                         if neigh_line_seg2 is line_seg:\n            #                             continue\n            #                         arr1 = neigh_line_seg2[\"pt1\"] - neigh_line_seg2[\"pt2\"]\n            #                         arr2 = line_seg[\"pt1\"] - line_seg[\"pt2\"]\n            #                         if np.abs(arr1.dot(arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))) < np.cos(\n            #                                 self._eps_line_seg):\n            #                             is_delete = False\n            #                             break\n            #             if is_delete:\n            #                 line_seg_remove.append(ind_line_seg)\n            #                 break\n            #         if len(self._junc2line[ind_junc1]) == 1 or len(self._junc2line[ind_junc1]) == 1:\n            #             for neigh_line_seg in self._junc2line[ind_junc1]:\n            #                 if neigh_line_seg is line_seg:\n            #                     continue\n            #                 arr1 = neigh_line_seg[\"pt1\"] - neigh_line_seg[\"pt2\"]\n            #                 arr2 = line_seg[\"pt1\"] - line_seg[\"pt2\"]\n            #                 if np.abs(arr1.dot(arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))) < np.cos(self._eps_line_seg):\n            #                     line_seg_remove.append(ind_line_seg)\n            #                     break\n            #             for neigh_line_seg in self._junc2line[ind_junc2]:\n            #                 if neigh_line_seg is line_seg:\n            #                     continue\n            #                 arr1 = neigh_line_seg[\"pt1\"] - neigh_line_seg[\"pt2\"]\n            #                 arr2 = line_seg[\"pt1\"] - line_seg[\"pt2\"]\n            #                 if np.abs(arr1.dot(arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))) < np.cos(self._eps_line_seg):\n            #                     line_seg_remove.append(ind_line_seg)\n            #                     break\n            # self._line_segs = [line_seg for ind, line_seg in enumerate(self._line_segs) if ind not in line_seg_remove]\n            # if self.verbose:\n            #     print(f\"{len(line_seg_remove)} possibly noisy line segments removed.\", flush=True)\n\n         
   # # remove all junctions that are not in any line segment\n            # junc_remove = set(list(range(len(self._junctions))))\n            # for line_seg in self._line_segs:\n            #     junc_remove -= line_seg[\"junctions\"]\n            # junc_remain = [ind_junc for ind_junc in range(len(self._junctions)) if ind_junc not in junc_remove]\n            # self._junctions = [junc for ind_junc, junc in enumerate(self._junctions) if ind_junc not in junc_remove]\n            # map_old_to_new = {old: new for new, old in enumerate(junc_remain)}\n            # for line_seg in self._line_segs:\n            #     line_seg[\"junctions\"] = set([map_old_to_new[old] for old in line_seg[\"junctions\"]])\n            # self._junc2line = {}\n\n            # build hash table mapping from junction id to line segment\n            for line_seg in self._line_segs:\n                for ind_junc in line_seg[\"junctions\"]:\n                    if ind_junc not in self._junc2line.keys():\n                        self._junc2line[ind_junc] = []\n                    self._junc2line[ind_junc].append(line_seg)\n            assert len(self._junc2line) == len(self._junctions)\n\n            # # refine all junctions w.r.t associated line segments\n            # junctions_refined = []\n            # for ind_junc, junc in enumerate(self._junctions):\n            #     line_segs = self._junc2line[ind_junc]\n            #     # assert len(line_segs) > 0, \"line seg\"\n            #     if len(line_segs) == 1:\n            #         param = line_segs[0][\"param\"]\n            #         refined, _ = project_pts_on_line([junc], param)\n            #         junctions_refined.append(refined[0])\n            #     elif len(line_segs) >= 2:\n            #         refined, _ = find_lines_intersect([line_seg[\"param\"] for line_seg in line_segs])\n            #         if np.linalg.norm(refined - junc) > 3 * self._eps_junc:  # maybe something wrong, do nothing...\n            #             junctions_refined.append(junc)\n            #         junctions_refined.append(refined)\n            # self._junctions = junctions_refined\n            self._kdtree = KDTree(self._junctions, leaf_size=30)\n\n            # # remove all junctions that are not in any line segment\n            # junc_remove = set(list(range(len(self._junctions))))\n            # for line_seg in self._line_segs:\n            #     junc_remove -= line_seg[\"junctions\"]\n            # junc_remain = [ind_junc for ind_junc in range(len(self._junctions)) if ind_junc not in junc_remove]\n            # self._junctions = [junc for ind_junc, junc in enumerate(self._junctions) if ind_junc not in junc_remove]\n            # map_old_to_new = {old: new for new, old in enumerate(junc_remain)}\n            # for line_seg in self._line_segs:\n            #     line_seg[\"junctions\"] = set([map_old_to_new[old] for old in line_seg[\"junctions\"]])\n            # self._junc2line = {}\n\n            # # build hash table mapping from junction id to line segment\n            # for line_seg in self._line_segs:\n            #     for ind_junc in line_seg[\"junctions\"]:\n            #         if ind_junc not in self._junc2line.keys():\n            #             self._junc2line[ind_junc] = []\n            #         self._junc2line[ind_junc].append(line_seg)\n\n            # add line segment intersections\n            cnt_new = 0\n            for ind_ls1 in range(len(self._line_segs)):\n                for ind_ls2 in range(ind_ls1 + 1, len(self._line_segs)):\n                    ls1, 
ls2 = self._line_segs[ind_ls1], self._line_segs[ind_ls2]\n                    pt11, pt12 = ls1[\"pt1\"], ls1[\"pt2\"]\n                    pt21, pt22 = ls2[\"pt1\"], ls2[\"pt2\"]\n                    p = np.array(pt11)\n                    r = np.array(pt12) - np.array(pt11)\n                    q = np.array(pt21)\n                    s = np.array(pt22) - np.array(pt21)\n                    alpha = np.cross(r, s)\n                    if np.isclose(alpha, 0):\n                        continue\n                    # if np.abs(np.dot(r, s) / np.linalg.norm(r) / np.linalg.norm(s)) > self._eps_line_seg:\n                    #     continue\n                    beta_t = np.cross(q - p, s)\n                    beta_u = np.cross(q - p, r)\n                    t = np.mean(beta_t / alpha)\n                    u = np.mean(beta_u / alpha)\n                    # exact intersect\n                    if 0 <= t <= 1 and 0 <= u <= 1:\n                        # print(\"find exact intersect\")\n                        assert np.allclose(p + t * r, q + u * s, rtol=1.e-3), \"intersecting math assertion (exact)\"\n                        intersect = p + t * r\n                        dists, ind = self._kdtree.query([intersect], k=1)\n                        if dists[0, 0] < self._eps_junc:\n                            ls1[\"junctions\"].add(ind[0, 0])\n                            ls2[\"junctions\"].add(ind[0, 0])\n                            self._junc2line[ind[0, 0]].append(ls1)\n                            self._junc2line[ind[0, 0]].append(ls2)\n                        else:\n                            ind = len(self._junctions)\n                            self._junctions.append(intersect)\n                            ls1[\"junctions\"].add(ind)\n                            ls2[\"junctions\"].add(ind)\n                            self._junc2line[ind] = [ls1, ls2]\n                            self._kdtree = KDTree(self._junctions, leaf_size=30)\n                            cnt_new += 1\n\n                    # # close to intersect\n                    # elif (min(abs(t), abs(t - 1)) * np.linalg.norm(r) < self._eps_junc * 5) and (\n                    #         min(abs(u), abs(u - 1)) * np.linalg.norm(s) < self._eps_junc * 5):\n                    #     assert np.allclose(p + t * r, q + u * s, rtol=1.e-3), \"intersecting math assertion (close)\"\n                    #     intersect = p + t * r\n                    #     dists, ind = self._kdtree.query([intersect], k=1)\n                    #     if dists[0, 0] < self._eps_junc:\n                    #         ls1[\"junctions\"].add(ind[0, 0])\n                    #         ls2[\"junctions\"].add(ind[0, 0])\n                    #         self._junc2line[ind[0, 0]].append(ls1)\n                    #         self._junc2line[ind[0, 0]].append(ls2)\n                    #     else:\n                    #         ind = len(self._junctions)\n                    #         self._junctions.append(intersect)\n                    #         ls1[\"junctions\"].add(ind)\n                    #         ls2[\"junctions\"].add(ind)\n                    #         self._junc2line[ind] = [ls1, ls2]\n                    #         cnt_new += 1\n            if self.verbose:\n                print(f\"found {cnt_new} new intercept junctions\", flush=True)\n\n    def add_junction(self, junction):\n        self._junctions.append(np.array(junction))\n\n    def add_line_seg(self, junction1, junction2):\n        assert self._freeze_junction\n        junc1 = np.array(junction1)\n        junc2 = 
np.array(junction2)\n        dist1, ind1 = self._kdtree.query([junc1], k=1)\n        dist2, ind2 = self._kdtree.query([junc2], k=1)\n        if not (dist1[0, 0] < self._eps_junc and dist2[0, 0] < self._eps_junc):\n            if self.verbose:\n                print(f\"warn: invalid line endpoints: {junc1} -> {junc2}, ignored.\")\n            return\n        if ind1[0, 0] == ind2[0, 0]:\n            if self.verbose:\n                print(f\"warn: zero length line segment found ({junc1} -> {junc2}), ignored.\")\n            return\n        self._line_segs.append(dict(\n            pt1=junc1,\n            pt2=junc2,\n            param=fit_line([junc1, junc2])[0],\n            junctions=set([ind1[0, 0], ind2[0, 0]])\n        ))\n\n    def junctions(self):\n        assert self.freeze_junction and self.freeze_line_seg\n        for junc in self._junctions:\n            yield junc\n\n    def line_segs(self):\n        assert self.freeze_junction and self.freeze_line_seg\n        for line_seg in self._line_segs:\n            for ind_junc1, ind_junc2 in combinations(line_seg[\"junctions\"], 2):\n                yield self._junctions[ind_junc1], self._junctions[ind_junc2]\n\n    def longest_line_segs(self):\n        for line_seg in self._line_segs:\n            yield line_seg[\"pt1\"], line_seg[\"pt2\"]\n\n    @property\n    def adj_mtx(self):\n        mtx = np.zeros((len(self._junctions), len(self._junctions)))\n        for line_seg in self._line_segs:\n            for ind_junc1, ind_junc2 in combinations(line_seg[\"junctions\"], 2):\n                mtx[ind_junc1, ind_junc2] = 1\n                mtx[ind_junc2, ind_junc1] = 1\n\n        return mtx\n\n    def line_map(self, size, scale_x=1., scale_y=1., line_width=2.):\n        if isinstance(size, tuple):\n            lmap = np.zeros(size, dtype=np.uint8)\n        else:\n            lmap = np.zeros((size, size), dtype=np.uint8)\n        for line_seg in self._line_segs:\n            for ind_junc1, ind_junc2 in combinations(line_seg[\"junctions\"], 2):\n                x1, y1 = self._junctions[ind_junc1]\n                x2, y2 = self._junctions[ind_junc2]\n                x1, x2 = int(x1 * scale_x + 0.5), int(x2 * scale_x + 0.5)\n                y1, y2 = int(y1 * scale_y + 0.5), int(y2 * scale_y + 0.5)\n                lmap = cv2.line(lmap, (x1, y1), (x2, y2), 255, int(line_width), cv2.LINE_AA)\n        # lmap = cv2.GaussianBlur(lmap, (int(line_width), int(line_width)), 1)\n        # lmap[lmap > 1] = 1\n        return lmap\n\n    @property\n    def num_junctions(self):\n        return len(self._junctions)\n\n    @property\n    def num_line_segs(self):\n        return np.sum(\n            [\n                len(line_seg[\"junctions\"])\n                * (len(line_seg[\"junctions\"]) - 1)\n                / 2\n                for line_seg in self._line_segs\n            ]\n        )\n\n\nif __name__ == \"__main__\":\n    from glob import glob\n    from tqdm import trange\n    data_root = \"/home/ziheng/indoorDist_new\"\n    img = [os.path.join(\"train\", os.path.basename(f)) for f in glob(os.path.join(data_root, \"train\", \"*.jpg\"))]\n    max_junc = 0\n    for item in trange(len(img)):\n        lg = LineGraph().load(os.path.join(data_root, img[item][:-4] + \".lg\"))\n        max_junc = max(lg.num_junctions, max_junc)\n\n    print(max_junc)"
  },
  {
    "path": "data/sist_line.py",
    "content": "import os\nimport numpy as np\nimport torch as th\nfrom torch.utils import data\nfrom data.line_graph import LineGraph\nfrom glob import glob\nfrom PIL import Image\nfrom data.utils import gen_gaussian_map\n\n\nclass SISTLine(data.Dataset):\n    def __init__(self, data_root, transforms, phase=\"train\", sigma_junction=3., max_junctions=512):\n        self.data_root = data_root\n        self.img = [os.path.join(phase, os.path.basename(f)) for f in glob(os.path.join(data_root, phase, \"*.jpg\"))]\n        self.transforms = transforms\n        self.phase = phase\n        self.max_junctions = max_junctions\n        self.sigma_junction = sigma_junction\n\n    def __getitem__(self, item):\n        img = Image.open(os.path.join(self.data_root, self.img[item]))\n        ori_w, ori_h = img.size\n        \n        lg = LineGraph().load(os.path.join(self.data_root, self.img[item][:-4] + \".lg\"))\n        num_junc = lg.num_junctions\n        # assert num_junc <= self.max_junctions, f\"{(item, num_junc)}\"\n        if self.phase == \"train\" and num_junc > self.max_junctions:\n            return self[(item + 1) % len(self)]\n        elif num_junc > self.max_junctions:\n            return self[(item + 1) % len(self)]\n            # raise AssertionError()\n        junc = np.zeros((self.max_junctions, 2))\n        # tic = time()\n        junc[:num_junc] = np.array([j if np.sum(j) > 0 else j + 1 for j in lg.junctions()])\n        # print(f\"junc time: {time() - tic:.4f}\")\n\n        assert np.sum(junc[:num_junc].sum(axis=1) <= 0) == 0, f\"{item}\"\n        # tic = time()\n        adj_mtx = np.zeros((self.max_junctions, self.max_junctions))\n        # print(f\"mtx time: {time() - tic:.4f}\")\n        adj_mtx[:num_junc, :num_junc] = lg.adj_mtx\n\n        if self.transforms is not None:\n            img, junc = self.transforms(img, junc)\n\n        cur_w, cur_h = img.size\n\n        junc[junc >= img.size[0]] = img.size[0] - 1\n        junc[junc < 0] = 0\n        # tic = time()\n        heatmap = gen_gaussian_map(junc[:num_junc], img.size[:2], self.sigma_junction)\n        assert cur_h == cur_w\n        line_map = lg.line_map(cur_h, cur_w / ori_w, cur_h / ori_h, line_width=self.sigma_junction)\n        # print(f\"gaussian time: {time() - tic:.4f}\")\n\n        img = np.array(np.asarray(img)[:, :, ::-1])\n        img = th.from_numpy(img).permute(2, 0, 1)\n        adj_mtx = th.from_numpy(adj_mtx)\n        junc = th.from_numpy(junc)\n        heatmap = th.from_numpy(heatmap)\n        line_map = th.from_numpy(line_map)\n\n        batch = dict(\n            image=img.float(),\n            adj_mtx=adj_mtx.float(),\n            heatmap = heatmap.float(),\n            junctions = junc.float(),\n            line_map = line_map.float()\n        )\n\n        return batch\n\n    def __call__(self, item):\n        return self.__getitem__(item)\n\n    def __len__(self):\n        return len(self.img)\n\n\nif __name__ == \"__main__\":\n    from tqdm import trange\n    from multiprocessing.pool import Pool\n    data = SISTLine(\"/home/ziheng/indoorDist_new\", None, \"train\")\n    # os.makedirs(\"/home/ziheng/heatmaps\")\n    pool = Pool(20)\n    cnt = 0\n\n    def readnsave(i):\n        batch = data[i]\n        hm = batch[\"heatmap\"].numpy()\n        np.save(f\"/home/ziheng/heatmaps/{i}.npy\", hm)\n\n    def juncsave(i):\n        batch = data[i]\n        hm = batch[\"heatmap\"].numpy()\n        junc = batch[\"heatmap\"].numpy()\n        np.save(f\"/home/ziheng/heatmaps/{i}.npy\", hm)\n\n\n    readnsave.cnt = 
0\n    # for i in trange(len(data)):\n    pool.map_async(readnsave, range(len(data)))\n    pool.close()\n    pool.join()\n\n"
  },
  {
    "path": "data/transforms.py",
    "content": "from torchvision.transforms import functional as tf\nimport numpy as np\nfrom PIL import Image\nimport random\nfrom functools import partial\n\n\nclass Compose(object):\n    def __init__(self, *transforms):\n        self.transforms = transforms\n\n    def __call__(self, img, pt):\n        for t in self.transforms:\n            img, pt = t(img, pt)\n\n        return img, pt\n\n\nclass RandomCompose(object):\n    def __init__(self, *transforms):\n        self.transforms = transforms\n\n    def __call__(self, img, pt):\n        random.shuffle(self.transforms)\n        for t in self.transforms:\n            img, pt = t(img, pt)\n\n        return img, pt        \n\n\nclass Resize(object):\n    def __init__(self, size, interpolation=Image.BILINEAR):\n        self.size = size\n        self.interpolation = interpolation\n    def __call__(self, img, pt):\n        w, h = img.size\n        y_scale = (self.size[0] - 1) / (h - 1)\n        x_scale = (self.size[1] - 1) / (w - 1)\n\n        pt_new = np.zeros_like(pt)\n        pt_mask = pt.sum(axis=1) > 0\n        pt_new[pt_mask] = np.vstack((pt[pt_mask][:, 0] * x_scale, pt[pt_mask][:, 1] * y_scale)).T\n\n        assert not pt_new.sum() == 0\n\n        return img.resize(self.size[::-1], self.interpolation), pt_new\n\n\nclass RandomHorizontalFlip(object):\n    def __init__(self, p=0.5):\n        self.p = p\n\n    def __call__(self, img, pt):\n        if self.p > np.random.rand():\n            w, _ = img.size\n            img = tf.hflip(img)\n            pt_new = np.zeros_like(pt)\n            pt_mask = pt.sum(axis=1) > 0\n            pt_new[pt_mask] = np.vstack((w - 1 - pt[pt_mask][:, 0], pt[pt_mask][:, 1])).T\n            return img, pt_new\n        return img, pt\n\n\nclass RandomColorAug(object):\n    def __init__(self, factor=0.2):\n        self.factor = factor\n\n    def __call__(self, img, pt):\n        transforms = [\n            tf.adjust_brightness,\n            tf.adjust_contrast,\n            tf.adjust_saturation\n            ]\n        random.shuffle(transforms)\n        for t in transforms:\n            img = t(img, (np.random.rand() - 0.5) * 2 * self.factor + 1)\n\n        return img, pt"
  },
  {
    "path": "data/utils.py",
    "content": "from numba import jit, float32, int32\nimport numpy as np    \n    \n\n@jit(float32[:, :](float32[:, :], float32[:, :], int32[:, :], int32[:, :], float32), nopython=True, fastmath=True)\ndef apply_gaussian(accumulate_confid_map, centers, xx, yy, sigma):\n    for i in range(len(centers)):\n        center = centers[i]\n        d2 = (xx - center[0]) ** 2 + (yy - center[1]) ** 2\n        exponent = d2 / 2.0 / sigma / sigma\n        mask = exponent <= 4.6052\n        cofid_map = np.exp(-exponent)\n        cofid_map = np.multiply(mask, cofid_map)\n        accumulate_confid_map += cofid_map\n    return accumulate_confid_map\n\ndef gen_gaussian_map(centers, shape, sigma):\n    centers = np.float32(centers)\n    sigma = np.float32(sigma)\n    accumulate_confid_map = np.zeros(shape, dtype=np.float32)\n    y_range = np.arange(accumulate_confid_map.shape[0], dtype=np.int32)\n    x_range = np.arange(accumulate_confid_map.shape[1], dtype=np.int32)\n    xx, yy = np.meshgrid(x_range, y_range)\n\n    accumulate_confid_map = apply_gaussian(accumulate_confid_map, centers, xx, yy, sigma)\n    accumulate_confid_map[accumulate_confid_map > 1.0] = 1.0\n    \n    return accumulate_confid_map\n"
  },
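  {
    "path": "examples/heatmap_demo.py",
    "content": "\"\"\"Hypothetical usage sketch for gen_gaussian_map in data/utils.py.\n\nThe example path, map size and junction coordinates below are made up for\nillustration; the call simply renders a junction confidence map the same way\nthe dataset classes do.\n\"\"\"\nimport numpy as np\n\nfrom data.utils import gen_gaussian_map\n\n# three (x, y) junction centers on a 64x64 map\ncenters = np.array([[10., 12.], [32., 32.], [50., 20.]])\n\n# sigma controls the spread of each Gaussian peak; the accumulated map is\n# clipped to a maximum of 1.0 inside gen_gaussian_map\nheatmap = gen_gaussian_map(centers, (64, 64), sigma=3.)\n\nprint('heatmap shape:', heatmap.shape)\nprint('value range:', heatmap.min(), '-', heatmap.max())\n"
  },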
  {
    "path": "data/york_urban.py",
    "content": "import os\nimport numpy as np\nimport torch as th\nfrom torch.utils import data\nfrom data.line_graph import LineGraph\nfrom glob import glob\nfrom PIL import Image\nfrom data.utils import gen_gaussian_map\n\n\nclass YorkUrban(data.Dataset):\n    def __init__(self, data_root, transforms, phase=\"test\", sigma_junction=3., max_junctions=800):\n        print(f\"{phase}\")\n        assert phase == \"eval\"\n        self.data_root = data_root\n        self.img = [os.path.basename(f) for f in glob(os.path.join(data_root, phase, \"*.jpg\"))]\n        self.transforms = transforms\n        self.phase = phase\n        self.max_junctions = max_junctions\n        self.sigma_junction = sigma_junction\n\n    def __getitem__(self, item):\n        img = Image.open(os.path.join(self.data_root, self.phase, self.img[item]))\n        ori_w, ori_h = img.size\n\n        lg = LineGraph().load(os.path.join(self.data_root, self.phase, self.img[item][:-4] + \".lg\"))\n        num_junc = lg.num_junctions\n        assert num_junc <= self.max_junctions, f\"{(item, num_junc)}\"\n        junc = np.zeros((self.max_junctions, 2))\n        # tic = time()\n        junc[:num_junc] = np.array([j if np.sum(j) > 0 else j + 1 for j in lg.junctions()])\n        # print(f\"junc time: {time() - tic:.4f}\")\n\n        assert np.sum(junc[:num_junc].sum(axis=1) <= 0) == 0, f\"{item}\"\n        # tic = time()\n        adj_mtx = np.zeros((self.max_junctions, self.max_junctions))\n        # print(f\"mtx time: {time() - tic:.4f}\")\n        adj_mtx[:num_junc, :num_junc] = lg.adj_mtx\n\n        if self.transforms is not None:\n            img, junc = self.transforms(img, junc)\n\n        cur_w, cur_h = img.size\n\n        junc[junc >= img.size[0]] = img.size[0] - 1\n        junc[junc < 0] = 0\n        # tic = time()\n        heatmap = gen_gaussian_map(junc[:num_junc], img.size[:2], self.sigma_junction)\n        assert cur_h == cur_w\n        line_map = lg.line_map(cur_h, cur_w / ori_w, cur_h / ori_h, line_width=self.sigma_junction)\n        # print(f\"gaussian time: {time() - tic:.4f}\")\n\n        img = np.array(np.asarray(img)[:, :, ::-1])\n        img = th.from_numpy(img).permute(2, 0, 1)\n        adj_mtx = th.from_numpy(adj_mtx)\n        junc = th.from_numpy(junc)\n        heatmap = th.from_numpy(heatmap)\n        line_map = th.from_numpy(line_map)\n\n        batch = dict(\n            image=img.float(),\n            adj_mtx=adj_mtx.float(),\n            heatmap=heatmap.float(),\n            junctions=junc.float(),\n            line_map=line_map.float()\n        )\n\n        return batch\n\n    def __call__(self, item):\n        return self.__getitem__(item)\n\n    def __len__(self):\n        return len(self.img)\n\n\n# class YorkUrbanTrain(data.Dataset):\n#     def __init__(self, data_root, transforms, phase=\"train\", sigma_junction=3., max_junctions=512):\n#         assert phase == \"train\"\n#         self.data_root = data_root\n#         self.img = [os.path.basename(f) for f in glob(os.path.join(data_root, \"*.jpg\"))]\n#         self.transforms = transforms\n#         self.phase = phase\n#         self.max_junctions = max_junctions\n#         self.sigma_junction = sigma_junction\n#\n#     def __getitem__(self, item):\n#         img = Image.open(os.path.join(self.data_root, self.img[item]))\n#         ori_w, ori_h = img.size\n#\n#         lg = LineGraph().load(os.path.join(self.data_root, self.img[item][:-4] + \".lg\"))\n#         num_junc = lg.num_junctions\n#         # assert num_junc <= self.max_junctions, 
f\"{(item, num_junc)}\"\n#         junc = np.zeros((max(num_junc, self.max_junctions), 2))\n#         # tic = time()\n#         junc[:num_junc] = np.array([j if np.sum(j) > 0 else j + 1 for j in lg.junctions()])\n#         # print(f\"junc time: {time() - tic:.4f}\")\n#\n#         assert np.sum(junc[:num_junc].sum(axis=1) <= 0) == 0, f\"{item}\"\n#         # tic = time()\n#         adj_mtx = np.zeros((max(num_junc, self.max_junctions), max(num_junc, self.max_junctions)))\n#         # print(f\"mtx time: {time() - tic:.4f}\")\n#         adj_mtx[:num_junc, :num_junc] = lg.adj_mtx\n#\n#         if self.transforms is not None:\n#             img, junc = self.transforms(img, junc)\n#\n#         cur_w, cur_h = img.size\n#\n#         junc[junc >= img.size[0]] = img.size[0] - 1\n#         junc[junc < 0] = 0\n#         # tic = time()\n#         heatmap = gen_gaussian_map(junc[:num_junc], img.size[:2], self.sigma_junction)\n#         assert cur_h == cur_w\n#         line_map = lg.line_map(cur_h, cur_w / ori_w, cur_h / ori_h, line_width=self.sigma_junction)\n#         # print(f\"gaussian time: {time() - tic:.4f}\")\n#\n#         if num_junc > self.max_junctions:\n#             choice_junc = np.random.choice(num_junc, self.max_junctions, replace=False)\n#             junc = np.array(junc[choice_junc])\n#             adj_mtx = np.array(adj_mtx[choice_junc][:, choice_junc])\n#\n#         img = np.array(np.asarray(img)[:, :, ::-1])\n#         img = th.from_numpy(img).permute(2, 0, 1)\n#         adj_mtx = th.from_numpy(adj_mtx)\n#         junc = th.from_numpy(junc)\n#         heatmap = th.from_numpy(heatmap)\n#         line_map = th.from_numpy(line_map)\n#\n#         batch = dict(\n#             image=img.float(),\n#             adj_mtx=adj_mtx.float(),\n#             heatmap=heatmap.float(),\n#             junctions=junc.float(),\n#             line_map=line_map.float()\n#         )\n#\n#         return batch\n#\n#     def __call__(self, item):\n#         return self.__getitem__(item)\n#\n#     def __len__(self):\n#         return len(self.img)"
  },
  {
    "path": "main.py",
    "content": "# System libs\nimport os\nimport time\n\n# Numerical libs\nimport torch\nimport torch.nn as nn\nfrom torch import optim\nfrom torch.utils.data import DataLoader\nimport numpy as np\n\n# Our libs\nfrom data.sist_line import SISTLine\nimport data.transforms as tf\nfrom models.lsd import LSDModule\nfrom utils import AverageMeter, graph2line, draw_lines, draw_jucntions\n\n# tensorboard\nfrom tensorboardX import SummaryWriter\nimport torchvision.utils as vutils\n\nimport fire\n\n\ndef weight_fn(dist_map, max_dist, mid=0.1, scale=10):\n    with torch.no_grad():\n        dist_map = dist_map / max_dist\n        weight = (torch.exp(scale * (dist_map - mid)) - torch.exp(scale * (-dist_map + mid))) / \\\n                   (torch.exp(scale * (dist_map - mid)) + torch.exp(scale * (-dist_map + mid))) / 2 + 0.5\n        return weight\n\n\nclass LSDTrainer(object):\n    def __init__(\n            self,\n            # exp params\n            exp_name=\"u50_block\",\n            # arch params\n            backbone=\"resnet50\",\n            backbone_kwargs={},\n            dim_embedding=256,\n            feature_spatial_scale=0.25,\n            max_junctions=512,\n            junction_pooling_threshold=0.2,\n            junc_pooling_size=15,\n            attention_sigma=1.,\n            junction_heatmap_criterion=\"binary_cross_entropy\",\n            block_inference_size=64,\n            adjacency_matrix_criterion=\"binary_cross_entropy\",\n            # data params\n            data_root=r\"/home/ziheng/indoorDist_new2\",\n            img_size=416,\n            junc_sigma=3.,\n            batch_size=2,\n            # train params\n            gpus=[0,],\n            num_workers=5,\n            resume_epoch=\"latest\",\n            is_train_junc=True,\n            is_train_adj=True,\n            # vis params\n            vis_junc_th=0.3,\n            vis_line_th=0.3\n    ):\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join(str(c) for c in gpus)\n\n        self.is_cuda = bool(gpus)\n\n        self.model = LSDModule(\n            backbone=backbone,\n            dim_embedding=dim_embedding,\n            backbone_kwargs=backbone_kwargs,\n            junction_pooling_threshold=junction_pooling_threshold,\n            max_junctions=max_junctions,\n            feature_spatial_scale=feature_spatial_scale,\n            junction_heatmap_criterion=junction_heatmap_criterion,\n            junction_pooling_size=junc_pooling_size,\n            attention_sigma=attention_sigma,\n            block_inference_size=block_inference_size,\n            adjacency_matrix_criterion=adjacency_matrix_criterion,\n            weight_fn=weight_fn,\n            is_train_adj=is_train_adj,\n            is_train_junc=is_train_junc\n        )\n\n        self.exp_name = exp_name\n        os.makedirs(os.path.join(\"log\", exp_name), exist_ok=True)\n        os.makedirs(os.path.join(\"ckpt\", exp_name), exist_ok=True)\n        self.writer = SummaryWriter(log_dir=os.path.join(\"log\", exp_name))\n\n        # checkpoints\n        self.states = dict(\n            last_epoch=-1,\n            elapsed_time=0,\n            state_dict=None\n        )\n\n        if resume_epoch and os.path.isfile(os.path.join(\"ckpt\", exp_name, f\"train_states_{resume_epoch}.pth\")):\n            states = torch.load(\n                os.path.join(\"ckpt\", exp_name, f\"train_states_{resume_epoch}.pth\"))\n            print(f\"resume traning from epoch {states['last_epoch']}\")\n            self.model.load_state_dict(states[\"state_dict\"])\n    
        self.states.update(states)\n\n        self.train_data = SISTLine(\n            data_root=data_root,\n            transforms=tf.Compose(\n                tf.Resize((img_size, img_size)),\n                tf.RandomHorizontalFlip(),\n                tf.RandomColorAug()\n            ),\n            phase=\"train\",\n            sigma_junction=junc_sigma,\n            max_junctions=max_junctions)\n                  \n        assert len(self.train_data) > 0, \"Wow, there is nothing in your data folder. Please check the --data-root parameter in your train.sh.\"\n\n        self.train_loader = DataLoader(\n            self.train_data,\n            batch_size=batch_size,\n            shuffle=True,\n            num_workers=num_workers,\n            pin_memory=True\n        )\n\n        self.eval_data = SISTLine(\n            data_root=data_root,\n            transforms=tf.Compose(\n                tf.Resize((img_size, img_size)),\n            ),\n            phase=\"val\",\n            sigma_junction=junc_sigma,\n            max_junctions=max_junctions)\n\n        self.eval_loader = DataLoader(\n            self.eval_data,\n            batch_size=batch_size,\n            shuffle=False,\n            num_workers=num_workers,\n            pin_memory=True\n        )\n\n        self.vis_junc_th = vis_junc_th\n        self.vis_line_th = vis_line_th\n        self.block_size = block_inference_size\n        self.max_junctions = max_junctions\n        self.is_train_junc = is_train_junc\n        self.is_train_adj = is_train_adj\n\n    @staticmethod\n    def _group_weight(module, lr):\n        group_decay = []\n        group_no_decay = []\n        for m in module.modules():\n            if isinstance(m, nn.Linear):\n                group_decay.append(m.weight)\n                if m.bias is not None:\n                    group_no_decay.append(m.bias)\n            elif isinstance(m, nn.modules.conv._ConvNd):\n                group_decay.append(m.weight)\n                if m.bias is not None:\n                    group_no_decay.append(m.bias)\n            elif isinstance(m, nn.modules.batchnorm._BatchNorm) or isinstance(m, nn.GroupNorm):\n                if m.weight is not None:\n                    group_no_decay.append(m.weight)\n                if m.bias is not None:\n                    group_no_decay.append(m.bias)\n\n        assert len(list(\n            module.parameters())) == len(group_decay) + len(group_no_decay)\n        groups = [\n            dict(params=group_decay, lr=lr),\n            dict(params=group_no_decay, lr=lr, weight_decay=.0)\n        ]\n        return groups\n\n    def end(self):\n        self.writer.close()\n        return \"command queue finished.\"\n\n    def _train_epoch(self):\n        net_time = AverageMeter()\n        data_time = AverageMeter()\n        vis_time = AverageMeter()\n\n        epoch = self.states[\"last_epoch\"]\n        data_loader = self.train_loader\n        if self.is_cuda:\n            self.model = self.model.cuda()\n        params = self._group_weight(self.model.backbone, self.lr)\n        if self.is_train_junc:\n            params += self._group_weight(self.model.junc_infer, self.lr)\n        if self.is_train_adj:\n            params += self._group_weight(self.model.adj_infer, self.lr)\n            params += self._group_weight(self.model.adj_embed, self.lr)\n        if self.solver == \"Adadelta\":\n            solver = optim.__dict__[self.solver](params, weight_decay=self.weight_decay)\n        else:\n            solver = 
optim.__dict__[self.solver](params, weight_decay=self.weight_decay, momentum=self.momentum)\n\n        # main loop\n        torch.set_grad_enabled(True)\n        tic = time.time()\n        print(f\"start training epoch: {epoch}\", flush=True)\n\n        if self.is_cuda:\n            model = nn.DataParallel(self.model).train()\n        else:\n            model = self.model.train()\n\n        for i, batch in enumerate(data_loader):\n            if self.is_cuda:\n                img = batch[\"image\"].cuda()\n                heatmap_gt = batch[\"heatmap\"].cuda()\n                adj_mtx_gt = batch[\"adj_mtx\"].cuda()\n                junctions_gt = batch[\"junctions\"].cuda()\n            else:\n                img = batch[\"image\"]\n                heatmap_gt = batch[\"heatmap\"]\n                adj_mtx_gt = batch[\"adj_mtx\"]\n                junctions_gt = batch[\"junctions\"]\n\n            # measure elapsed time\n            data_time.update(time.time() - tic)\n            tic = time.time()\n\n            junc_pred, heatmap_pred, adj_mtx_pred, loss_hm, loss_adj = model(\n                img, heatmap_gt, adj_mtx_gt, self.lambda_heatmap, self.lambda_adj, junctions_gt\n            )\n\n            model.zero_grad()\n            loss_adj = loss_adj.mean()\n            loss_hm = loss_hm.mean()\n            loss = (loss_hm if self.is_train_junc else 0) + (loss_adj if self.is_train_adj else 0)\n            loss.backward()\n            solver.step()\n\n            # measure elapsed time\n            net_time.update(time.time() - tic)\n            tic = time.time()\n\n            # visualize result\n            if i % self.vis_line_interval == 0:\n                img = img.cpu().numpy()\n                heatmap_pred = heatmap_pred.detach().cpu()\n                adj_mtx_pred = adj_mtx_pred.detach().cpu().numpy()\n                junctions_gt = junctions_gt.cpu().numpy()\n                adj_mtx_gt = adj_mtx_gt.cpu().numpy()\n                self._vis_train(epoch, i, len(data_loader), img, heatmap_pred, adj_mtx_pred, junctions_gt, adj_mtx_gt)\n\n            vis_heatmap_gt = vutils.make_grid(\n                heatmap_gt.view(heatmap_gt.size(0), 1, heatmap_gt.size(1), heatmap_gt.size(2)))\n            vis_heatmap_pred = vutils.make_grid(\n                heatmap_pred.view(heatmap_gt.size(0), 1, heatmap_gt.size(1), heatmap_gt.size(2)))\n\n            self.writer.add_scalar(self.exp_name + \"/\" + \"train/loss_total\",\n                                   loss.item(),\n                                   epoch * len(data_loader) + i)\n            self.writer.add_scalar(self.exp_name + \"/\" + \"train/loss_heatmap\",\n                                   loss_hm.item() / self.lambda_heatmap if self.lambda_heatmap else 0,\n                                   epoch * len(data_loader) + i)\n            self.writer.add_scalar(self.exp_name + \"/\" + \"train/loss_adj_mtx\",\n                                   loss_adj.item() / self.lambda_adj if self.lambda_adj else 0,\n                                   epoch * len(data_loader) + i)\n            self.writer.add_image(self.exp_name + \"/\" + \"train/heatmap_gt\",\n                                  vis_heatmap_gt,\n                                  epoch * len(data_loader) + i)\n            self.writer.add_image(self.exp_name + \"/\" + \"train/heatmap_pred\",\n                                  vis_heatmap_pred,\n                                  epoch * len(data_loader) + i)\n\n            vis_time.update(time.time() - tic)\n            info = f\"epoch: 
[{epoch}][{i}/{len(data_loader)}], lr: {self.lr}, \" \\\n                   f\"time_total: {net_time.average() + data_time.average() + vis_time.average():.2f}, \" \\\n                   f\"time_data: {data_time.average():.2f}, time_net: {net_time.average():.2f}, \" \\\n                   f\"time_vis: {vis_time.average():.2f}, \" \\\n                   f\"loss: {loss.item():.4f}, \" \\\n                   f\"loss_heatmap: {loss_hm.item() / self.lambda_heatmap if self.lambda_heatmap else 0:.4f}, \" \\\n                   f\"loss_adj_mtx: {loss_adj.item() / self.lambda_adj if self.lambda_adj else 0:.4f}\"\n            self.writer.add_text(self.exp_name + \"/\" + \"train/info\", info,\n                                 epoch * len(data_loader) + i)\n            print(info, flush=True)\n            # measure elapsed time\n            tic = time.time()\n\n    def _vis_train(self, epoch, i, len_loader, img, heatmap, adj_mtx, junctions_gt, adj_mtx_gt):\n        junctions_gt = np.int32(junctions_gt)\n        lines_gt, scores_gt = graph2line(junctions_gt, adj_mtx_gt)\n        vis_line_gt = vutils.make_grid(\n            draw_lines(img, lines_gt, scores_gt))\n        lines_pred, score_pred = graph2line(junctions_gt, adj_mtx, threshold=self.vis_line_th)\n        vis_line_pred = vutils.make_grid(\n            draw_lines(img, lines_pred, score_pred))\n        junc_score = []\n        line_score = []\n        for m, juncs in zip(heatmap, junctions_gt):\n            juncs = juncs[juncs.sum(axis=1) > 0]\n            junc_score += m[juncs[:, 1], juncs[:, 0]].tolist()\n        for s in score_pred:\n            line_score += s.tolist()\n\n        self.writer.add_image(self.exp_name + \"/\" + \"train/lines_gt\",\n                              vis_line_gt,\n                              epoch * len_loader + i)\n        self.writer.add_image(self.exp_name + \"/\" + \"train/lines_pred\",\n                              vis_line_pred,\n                              epoch * len_loader + i)\n        self.writer.add_scalar(\n            self.exp_name + \"/\" + \"train/mean_junc_score\",\n            np.mean(junc_score),\n            epoch * len_loader + i)\n        self.writer.add_scalar(\n            self.exp_name + \"/\" + \"train/mean_line_score\",\n            np.mean(line_score),\n            epoch * len_loader + i)\n\n    def _checkpoint(self):\n        print('Saving checkpoints...')\n\n        train_states = self.states\n\n        train_states[\"state_dict\"] = self.model.cpu().state_dict()\n\n        torch.save(\n            train_states,\n            os.path.join(\"ckpt\", self.exp_name,\n                         \"train_states_latest.pth\"))\n        torch.save(\n            train_states,\n            os.path.join(\"ckpt\", self.exp_name,\n                         f\"train_states_{self.states['last_epoch']}.pth\"))\n\n        state = torch.load(os.path.join(\"ckpt\", self.exp_name, \"train_states_latest.pth\"))\n        self.model.load_state_dict(state[\"state_dict\"])\n\n    def train(\n            self,\n            end_epoch=20,\n            solver=\"SGD\",\n            lr=1.,\n            weight_decay=5e-4,\n            momentum=0.9,\n            lambda_heatmap=1.,\n            lambda_adj=1.,\n            vis_line_interval=20,\n    ):\n        self.vis_line_interval = vis_line_interval\n        self.end_epoch = end_epoch\n        self.lr = lr\n        self.weight_decay = weight_decay\n        self.momentum = momentum\n        self.lambda_heatmap = lambda_heatmap\n        self.lambda_adj = lambda_adj\n     
   self.solver = solver\n\n        start_epoch = self.states[\"last_epoch\"] + 1\n\n        for epoch in range(start_epoch, end_epoch):\n            self.states[\"last_epoch\"] = epoch\n            self._train_epoch()\n            self._checkpoint()\n\n        return self\n\n    def _vis_eval(self, epoch, i, len_loader, img, heatmap, adj_mtx, junctions_pred, junctions_gt, adj_mtx_gt):\n        junctions_gt = np.int32(junctions_gt)\n        lines_gt, scores_gt = graph2line(junctions_gt, adj_mtx_gt, threshold=self.vis_junc_th)\n        vis_line_gt = vutils.make_grid(\n            draw_lines(img, lines_gt, scores_gt))\n        img_with_junc = draw_jucntions(img, junctions_pred)\n        img_with_junc = torch.stack(img_with_junc, dim=0).numpy()[:, ::-1, :, :]\n        lines_pred, score_pred = graph2line(junctions_pred, adj_mtx)\n        vis_line_pred = vutils.make_grid(\n            draw_lines(img_with_junc, lines_pred, score_pred))\n        junc_score = []\n        line_score = []\n        for m, juncs in zip(heatmap, junctions_gt):\n            juncs = juncs[juncs.sum(axis=1) > 0]\n            junc_score += m[juncs[:, 1], juncs[:, 0]].tolist()\n        for s in score_pred:\n            line_score += s.tolist()\n\n        junc_pooling = vutils.make_grid(draw_jucntions(heatmap, junctions_pred))\n\n        self.writer.add_image(self.exp_name + \"/\" + \"eval/junction_pooling\",\n                              junc_pooling,\n                              epoch * len_loader + i)\n\n        self.writer.add_image(self.exp_name + \"/\" + \"eval/lines_gt\",\n                              vis_line_gt,\n                              epoch * len_loader + i)\n        self.writer.add_image(self.exp_name + \"/\" + \"eval/lines_pred\",\n                              vis_line_pred,\n                              epoch * len_loader + i)\n        self.writer.add_scalar(\n            self.exp_name + \"/\" + \"eval/mean_junc_score\",\n            np.mean(junc_score),\n            epoch * len_loader + i)\n        self.writer.add_scalar(\n            self.exp_name + \"/\" + \"eval/mean_line_score\",\n            np.mean(line_score),\n            epoch * len_loader + i)\n\n    def eval(self,\n             lambda_heatmap=1.,\n             lambda_adj=1.,\n             off_line=False,\n             epoch=None\n             ):\n\n        if not off_line:\n            if not (self.states[\"last_epoch\"] == epoch - 1):\n                return self\n        else:\n            self.lambda_heatmap = lambda_heatmap\n            self.lambda_adj = lambda_adj\n\n        net_time = AverageMeter()\n        data_time = AverageMeter()\n        vis_time = AverageMeter()\n        ave_loss = AverageMeter()\n        ave_loss_heatmap = AverageMeter()\n        ave_loss_adj_mtx = AverageMeter()\n\n        epoch = self.states[\"last_epoch\"]\n        data_loader = self.eval_loader\n\n        # main loop\n        torch.set_grad_enabled(False)\n        tic = time.time()\n        print(f\"start evaluating epoch: {epoch}\", flush=True)\n\n        if self.is_cuda:\n            model = nn.DataParallel(self.model.cuda()).train()\n        else:\n            model = self.model.train()\n\n        for i, batch in enumerate(data_loader):\n            if self.is_cuda:\n                img = batch[\"image\"].cuda()\n                heatmap_gt = batch[\"heatmap\"].cuda()\n                adj_mtx_gt = batch[\"adj_mtx\"].cuda()\n                junctions_gt = batch[\"junctions\"].cuda()\n            else:\n                img = batch[\"image\"]\n            
    heatmap_gt = batch[\"heatmap\"]\n                adj_mtx_gt = batch[\"adj_mtx\"]\n                junctions_gt = batch[\"junctions\"]\n\n            # measure elapsed time\n            data_time.update(time.time() - tic)\n            tic = time.time()\n\n            junc_pred, heatmap_pred, adj_mtx_pred, loss_hm, loss_adj = model(\n                img, heatmap_gt, adj_mtx_gt, self.lambda_heatmap, self.lambda_adj, junctions_gt\n            )\n\n            loss_adj = loss_adj.mean()\n            loss_hm = loss_hm.mean()\n            loss = loss_adj + loss_hm\n            ave_loss_adj_mtx.update(loss_adj.item() / self.lambda_adj if self.lambda_adj else 0)\n            ave_loss_heatmap.update(loss_hm.item() / self.lambda_heatmap if self.lambda_heatmap else 0)\n            ave_loss.update(loss.item())\n\n            # measure elapsed time\n            net_time.update(time.time() - tic)\n            tic = time.time()\n\n            # visualize eval\n            img = img.cpu().numpy()\n            heatmap = heatmap_pred.detach().cpu().numpy()\n            junctions_pred = junc_pred.detach().cpu().numpy()\n            adj_mtx = adj_mtx_pred.detach().cpu().numpy()\n            junctions_gt = junctions_gt.cpu().numpy()\n            adj_mtx_gt = adj_mtx_gt.cpu().numpy()\n            self._vis_eval(epoch, i, len(data_loader), img, heatmap, adj_mtx, junctions_pred, junctions_gt, adj_mtx_gt)\n\n            vis_heatmap_gt = vutils.make_grid(\n                heatmap_gt.view(heatmap_gt.size(0), 1, heatmap_gt.size(1), heatmap_gt.size(2)))\n            vis_heatmap_pred = vutils.make_grid(\n                heatmap.view(heatmap_gt.size(0), 1, heatmap_gt.size(1), heatmap_gt.size(2)))\n\n            self.writer.add_scalar(self.exp_name + \"/\" + \"eval/loss_total\",\n                                   loss.item(),\n                                   epoch * len(data_loader) + i)\n            self.writer.add_scalar(self.exp_name + \"/\" + \"eval/loss_heatmap\",\n                                   loss_hm.item() / self.lambda_heatmap if self.lambda_heatmap else 0,\n                                   epoch * len(data_loader) + i)\n            self.writer.add_scalar(self.exp_name + \"/\" + \"eval/loss_adj_mtx\",\n                                   loss_adj.item() / self.lambda_adj if self.lambda_adj else 0,\n                                   epoch * len(data_loader) + i)\n            self.writer.add_image(self.exp_name + \"/\" + \"eval/heatmap_gt\",\n                                  vis_heatmap_gt,\n                                  epoch * len(data_loader) + i)\n            self.writer.add_image(self.exp_name + \"/\" + \"eval/heatmap_pred\",\n                                  vis_heatmap_pred,\n                                  epoch * len(data_loader) + i)\n\n            vis_time.update(time.time() - tic)\n            info = f\"epoch: [{epoch}][{i}/{len(data_loader)}], \" \\\n                   f\"time_total: {net_time.average() + data_time.average() + vis_time.average():.2f}, \" \\\n                   f\"time_data: {data_time.average():.2f}, time_net: {net_time.average():.2f}, \" \\\n                   f\"time_vis: {vis_time.average():.2f}, \" \\\n                   f\"loss: {loss.item():.4f}, \" \\\n                   f\"loss_heatmap: {loss_hm.item() / self.lambda_heatmap if self.lambda_heatmap else 0:.4f}, \" \\\n                   f\"loss_adj_mtx: {loss_adj.item() / self.lambda_adj if self.lambda_adj else 0:.4f}\"\n            if i == len(data_loader) - 1:\n                info += f\"\\n*[{epoch}] \" 
\\\n                        f\"ave_loss: {ave_loss.average():.4f}, \" \\\n                        f\"ave_loss_heatmap: {ave_loss_heatmap.average():.4f}, \" \\\n                        f\"ave_loss_adj_mtx: {ave_loss_adj_mtx.average():.4f}\"\n\n            self.writer.add_text(self.exp_name + \"/\" + \"eval/info\", info,\n                                 epoch * len(data_loader) + i)\n            print(info, flush=True)\n            # measure elapsed time\n            tic = time.time()\n\n        return self\n\n\nif __name__ == \"__main__\":\n    fire.Fire(LSDTrainer)\n    # trainer = LSDTrainer().train(lr=1.)\n"
  },
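  {
    "path": "examples/adjacency_to_lines.py",
    "content": "\"\"\"Illustrative sketch (not part of the original training pipeline): convert a\njunction list plus a thresholded adjacency matrix into line segments, mirroring\nwhat the repo's graph2line utility is used for in the trainer. The file path and\nfunction name here are assumptions made for demonstration only.\n\"\"\"\nimport numpy as np\n\n\ndef adjacency_to_lines(junctions, adj_mtx, threshold=0.5):\n    \"\"\"Return (lines, scores), where each line is ((x1, y1), (x2, y2)).\n\n    junctions: (N, 2) array of (x, y) coordinates; zero-padded rows are ignored.\n    adj_mtx:   (N, N) array of connection scores in [0, 1].\n    \"\"\"\n    junctions = np.asarray(junctions, dtype=np.float32)\n    adj_mtx = np.asarray(adj_mtx, dtype=np.float32)\n    valid = junctions.sum(axis=1) > 0  # drop zero-padding rows\n    lines, scores = [], []\n    for i in range(len(junctions)):\n        for j in range(i + 1, len(junctions)):  # treat the adjacency as symmetric\n            if valid[i] and valid[j] and adj_mtx[i, j] >= threshold:\n                lines.append((tuple(junctions[i]), tuple(junctions[j])))\n                scores.append(float(adj_mtx[i, j]))\n    return lines, scores\n\n\nif __name__ == \"__main__\":\n    juncs = np.array([[10., 20.], [100., 40.], [200., 120.], [0., 0.]])\n    adj = np.zeros((4, 4), dtype=np.float32)\n    adj[0, 1] = adj[1, 0] = 0.9  # junction 0 is connected to junction 1\n    adj[1, 2] = adj[2, 1] = 0.7  # junction 1 is connected to junction 2\n    print(adjacency_to_lines(juncs, adj, threshold=0.5))\n"
  },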
  {
    "path": "models/__init__.py",
    "content": "from . import lsd"
  },
  {
    "path": "models/backbone.py",
    "content": "import os\nimport sys\nimport torch\nimport torch.nn as nn\nimport math\nfrom models.common import conv3x3, conv3x3_bn_relu, inconv, up, down, outconv, weights_init\nfrom torch.nn import functional as F\n\ntry:\n    from urllib import urlretrieve\nexcept ImportError:\n    from urllib.request import urlretrieve\n\n\n__all__ = ['ResNetU50Backbone', 'UNetBackbone']\n\n\nmodel_urls = {\n    'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',\n}\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * 4)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n\n    def __init__(self, block, layers, num_classes=1000):\n        self.inplanes = 128\n        super(ResNet, self).__init__()\n        self.conv1 = conv3x3(3, 64, stride=2)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.relu1 = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(64, 64)\n        self.bn2 = nn.BatchNorm2d(64)\n        self.relu2 = nn.ReLU(inplace=True)\n        self.conv3 = conv3x3(64, 128)\n        self.bn3 = nn.BatchNorm2d(128)\n        self.relu3 = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n        self.avgpool = nn.AvgPool2d(7, stride=1)\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. 
/ n))\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.relu1(self.bn1(self.conv1(x)))\n        x = self.relu2(self.bn2(self.conv2(x)))\n        x = self.relu3(self.bn3(self.conv3(x)))\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n\n        return x\n\n\nclass ResnetDilated(nn.Module):\n    def __init__(self, orig_resnet, dilate_scale=8):\n        super(ResnetDilated, self).__init__()\n        from functools import partial\n\n        if dilate_scale == 8:\n            orig_resnet.layer3.apply(\n                partial(self._nostride_dilate, dilate=2))\n            orig_resnet.layer4.apply(\n                partial(self._nostride_dilate, dilate=4))\n        elif dilate_scale == 16:\n            orig_resnet.layer4.apply(\n                partial(self._nostride_dilate, dilate=2))\n\n        # take pretrained resnet, except AvgPool and FC\n        self.conv1 = orig_resnet.conv1\n        self.bn1 = orig_resnet.bn1\n        self.relu1 = orig_resnet.relu1\n        self.conv2 = orig_resnet.conv2\n        self.bn2 = orig_resnet.bn2\n        self.relu2 = orig_resnet.relu2\n        self.conv3 = orig_resnet.conv3\n        self.bn3 = orig_resnet.bn3\n        self.relu3 = orig_resnet.relu3\n        self.maxpool = orig_resnet.maxpool\n        self.layer1 = orig_resnet.layer1\n        self.layer2 = orig_resnet.layer2\n        self.layer3 = orig_resnet.layer3\n        self.layer4 = orig_resnet.layer4\n\n    @staticmethod\n    def _nostride_dilate(m, dilate):\n        classname = m.__class__.__name__\n        if classname.find('Conv') != -1:\n            # the convolution with stride\n            if m.stride == (2, 2):\n                m.stride = (1, 1)\n                if m.kernel_size == (3, 3):\n                    m.dilation = (dilate // 2, dilate // 2)\n                    m.padding = (dilate // 2, dilate // 2)\n            # other convoluions\n            else:\n                if m.kernel_size == (3, 3):\n                    m.dilation = (dilate, dilate)\n                    m.padding = (dilate, dilate)\n\n    def forward(self, x):\n        conv_out = []\n\n        x = self.relu1(self.bn1(self.conv1(x)))\n        x = self.relu2(self.bn2(self.conv2(x)))\n        x = self.relu3(self.bn3(self.conv3(x)))\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        conv_out.append(x)\n        x = self.layer2(x)\n        conv_out.append(x)\n        x = self.layer3(x)\n        conv_out.append(x)\n        x = self.layer4(x)\n        conv_out.append(x)\n\n        return conv_out\n\n\ndef resnet50(pretrained=False, 
**kwargs):\n    \"\"\"Constructs a ResNet-50 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Places\n    \"\"\"\n    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)\n    if pretrained:\n        model.load_state_dict(load_url(model_urls['resnet50']), strict=False)\n    return model\n\n\ndef load_url(url, model_dir='./pretrained', map_location=None):\n    if not os.path.exists(model_dir):\n        os.makedirs(model_dir)\n    filename = url.split('/')[-1]\n    cached_file = os.path.join(model_dir, filename)\n    if not os.path.exists(cached_file):\n        sys.stderr.write('Downloading: \"{}\" to {}\\n'.format(url, cached_file))\n        urlretrieve(url, cached_file)\n    return torch.load(cached_file, map_location=map_location)\n\n\nclass UPerNet(nn.Module):\n    def __init__(self, num_class=150, fc_dim=4096, pool_scales=(1, 2, 3, 6),\n                 fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256):\n        super(UPerNet, self).__init__()\n\n        # PPM Module\n        self.ppm_pooling = []\n        self.ppm_conv = []\n\n        for scale in pool_scales:\n            self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale))\n            self.ppm_conv.append(nn.Sequential(\n                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),\n                nn.BatchNorm2d(512),\n                nn.ReLU(inplace=True)\n            ))\n        self.ppm_pooling = nn.ModuleList(self.ppm_pooling)\n        self.ppm_conv = nn.ModuleList(self.ppm_conv)\n        self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales) * 512, fpn_dim, 1)\n\n        # FPN Module\n        self.fpn_in = []\n        for fpn_inplane in fpn_inplanes[:-1]:  # skip the top layer\n            self.fpn_in.append(nn.Sequential(\n                nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False),\n                nn.BatchNorm2d(fpn_dim),\n                nn.ReLU(inplace=True)\n            ))\n        self.fpn_in = nn.ModuleList(self.fpn_in)\n\n        self.fpn_out = []\n        for i in range(len(fpn_inplanes) - 1):  # skip the top layer\n            self.fpn_out.append(nn.Sequential(\n                conv3x3_bn_relu(fpn_dim, fpn_dim, 1),\n            ))\n        self.fpn_out = nn.ModuleList(self.fpn_out)\n\n        self.conv_last = nn.Sequential(\n            conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1),\n            nn.Conv2d(fpn_dim, num_class, kernel_size=1)\n        )\n\n    def forward(self, conv_out, segSize=None):\n        conv5 = conv_out[-1]\n\n        input_size = conv5.size()\n        ppm_out = [conv5]\n        for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):\n            ppm_out.append(pool_conv(F.interpolate(\n                pool_scale(conv5),\n                (input_size[2], input_size[3]),\n                mode='bilinear', align_corners=False)))\n        ppm_out = torch.cat(ppm_out, 1)\n        f = self.ppm_last_conv(ppm_out)\n\n        fpn_feature_list = [f]\n        for i in reversed(range(len(conv_out) - 1)):\n            conv_x = conv_out[i]\n            conv_x = self.fpn_in[i](conv_x)  # lateral branch\n\n            f = F.interpolate(\n                f, size=conv_x.size()[2:], mode='bilinear', align_corners=False)  # top-down branch\n            f = conv_x + f\n\n            fpn_feature_list.append(self.fpn_out[i](f))\n\n        fpn_feature_list.reverse()  # [P2 - P5]\n        output_size = fpn_feature_list[0].size()[2:]\n        fusion_list = [fpn_feature_list[0]]\n        for i in range(1, len(fpn_feature_list)):\n       
     fusion_list.append(F.interpolate(\n                fpn_feature_list[i],\n                output_size,\n                mode='bilinear', align_corners=False))\n        fusion_out = torch.cat(fusion_list, 1)\n        x = self.conv_last(fusion_out)\n\n        x = F.interpolate(\n            x, size=segSize, mode='bilinear', align_corners=False)\n\n        return x\n\n\nclass ResNetU50Backbone(nn.Module):\n    def __init__(self, dim_embedding=256, encoder_weights=\"\", decoder_weights=\"\"):\n        super(ResNetU50Backbone, self).__init__()\n        self.encoder = self.build_encoder(encoder_weights)\n        self.decoder = self.build_decoder(decoder_weights, dim_embedding)\n\n    def forward(self, img):\n        out = self.encoder(img)\n        out = self.decoder(out, segSize=(img.size(2) // 4, img.size(3) // 4))\n\n        return out\n\n    @staticmethod\n    def build_encoder(weights=''):\n        pretrained = True if len(weights) == 0 else False\n        orig_resnet = resnet50(pretrained=pretrained)\n        net_encoder = ResnetDilated(orig_resnet,\n                                    dilate_scale=8)\n        if len(weights) > 0 and os.path.isfile(weights):\n            print(f'Loading weights for net_encoder @ {weights}')\n            net_encoder.load_state_dict(\n                torch.load(weights, map_location=lambda storage, loc: storage),\n                strict=False\n            )\n        return net_encoder\n\n    @staticmethod\n    def build_decoder(weights='', dim_embedding=64):\n        net_decoder = UPerNet(\n            num_class=150,\n            fc_dim=2048,\n            fpn_dim=512\n        )\n\n        net_decoder.conv_last[1] = conv3x3_bn_relu(net_decoder.conv_last[1].in_channels, dim_embedding, 1)\n        net_decoder.apply(weights_init)\n        if len(weights) > 0 and os.path.isfile(weights):\n            print(f'Loading weights for net_decoder @ {weights}')\n            net_decoder.load_state_dict(\n                torch.load(weights, map_location=lambda storage, loc: storage),\n                strict=False\n            )\n\n        return net_decoder\n\n\nclass UNetBackbone(nn.Module):\n    def __init__(self, dim_embedding=256, n_downs=5, n_ups=3, weights=\"\"):\n        super(UNetBackbone, self).__init__()\n        assert n_downs > 0 and 0 < n_ups <= n_downs\n        self.n_downs = n_downs\n        self.n_ups = n_ups\n        self.inc = inconv(3, 64)\n        down_channels = []\n        for i in range(n_downs):\n            down_channel = 64 * 2**min(n_downs - 1, i + 1)\n            self.add_module(\n                f\"down{i+1}\",\n                down(64 * 2**i, down_channel)\n            )\n            down_channels.append(down_channel)\n        for i in range(n_ups):\n            down_channels.pop()\n            self.add_module(\n                f\"up{i+1}\",\n                up(64 * 2**(n_downs - i), 64 * 2**max(0, n_downs - i - 2))\n            )\n\n        self.outc = outconv(64 * 2**max(0, n_downs - n_ups - 1) + sum(down_channels) // 2, dim_embedding)\n\n        self.apply(weights_init)\n\n        if len(weights) > 0 and os.path.isfile(weights):\n            print(f'Loading weights for unet @ {weights}')\n            self.load_state_dict(\n                torch.load(weights, map_location=lambda storage, loc: storage),\n                strict=False\n            )\n\n    def forward(self, x):\n        downs = [self.inc(x)]\n        for i in range(self.n_downs):\n            downs.append(getattr(self, f\"down{i+1}\")(downs[i]))\n        out = downs.pop()\n 
       for i in range(self.n_ups):\n            out = getattr(self, f\"up{i+1}\")(out, downs.pop())\n        for i in range(len(downs)):\n            downs[i] = F.interpolate(downs[i], size=out.shape[2:], mode=\"bilinear\", align_corners=False)\n        downs.append(out)\n        out = self.outc(torch.cat(downs, dim=1))\n\n        return out\n"
  },
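  {
    "path": "examples/backbone_smoke_test.py",
    "content": "\"\"\"Minimal smoke-test sketch for the backbones in models/backbone.py.\n\nBoth backbones return a feature map with dim_embedding channels at 1/4 of the\ninput resolution; the file path and the tensor sizes used here are\nillustrative assumptions, not values taken from the training scripts.\n\"\"\"\nimport torch\n\nfrom models.backbone import UNetBackbone\n\n\nif __name__ == \"__main__\":\n    net = UNetBackbone(dim_embedding=256, n_downs=5, n_ups=3).eval()\n    img = torch.randn(1, 3, 256, 256)\n    with torch.no_grad():\n        feat = net(img)\n    # 5 downsamplings followed by 3 upsamplings leave a net stride of 4.\n    print(feat.shape)  # expected: torch.Size([1, 256, 64, 64])\n"
  },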
  {
    "path": "models/common.py",
    "content": "from scipy.ndimage.filters import maximum_filter\nfrom scipy.ndimage.morphology import generate_binary_structure, binary_erosion\nfrom scipy.cluster.hierarchy import fclusterdata\nimport numpy as np\nimport torch.nn as nn\nimport torch\nfrom torch.nn import functional as F\nfrom torch.autograd import Function\nfrom torch.autograd.function import once_differentiable\nfrom collections import Sized, Iterable\n\n\nclass LMFPeakFinder(object):\n    \"\"\"\n    shamelessly borrow from https://stackoverflow.com/a/3689710\n    Takes an image and detect the peaks using the local maximum filter.\n    Returns a boolean mask of the peaks (i.e. 1 when\n    the pixel's value is the neighborhood maximum, 0 otherwise)\n    \"\"\"\n\n    def __init__(self, min_dist=5., min_th=0.3):\n        self.min_dist = min_dist\n        self.min_th = min_th\n\n    def detect(self, image):\n        # define an 8-connected neighborhood\n        neighborhood = generate_binary_structure(2, 2)\n\n        # apply the local maximum filter; all pixel of maximal value\n        # in their neighborhood are set to 1\n        local_max = maximum_filter(image, footprint=neighborhood) == image\n        # local_max is a mask that contains the peaks we are\n        # looking for, but also the background.\n        # In order to isolate the peaks we must remove the background from the mask.\n\n        # we create the mask of the background\n        background = (image < self.min_th)\n\n        # a little technicality: we must erode the background in order to\n        # successfully subtract it form local_max, otherwise a line will\n        # appear along the background border (artifact of the local maximum filter)\n        eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)\n\n        # we obtain the final mask, containing only peaks,\n        # by removing the background from the local_max mask (xor operation)\n        detected_peaks = local_max ^ eroded_background\n\n        detected_peaks[image < self.min_th] = False\n        peaks = np.array(np.nonzero(detected_peaks)).T\n\n        if len(peaks) == 0:\n            return peaks, np.array([])\n\n        # nms\n        if len(peaks) == 1:\n            clusters = [0]\n        else:\n            clusters = fclusterdata(peaks, self.min_dist, criterion=\"distance\")\n        peak_groups = {}\n        for ind_junc, ind_group in enumerate(clusters):\n            if ind_group not in peak_groups.keys():\n                peak_groups[ind_group] = []\n                peak_groups[ind_group].append(peaks[ind_junc])\n        peaks_nms = []\n        peaks_score = []\n        for peak_group in peak_groups.values():\n            values = [image[y, x] for y, x in peak_group]\n            ind_max = np.argmax(values)\n            peaks_nms.append(peak_group[int(ind_max)])\n            peaks_score.append(values[int(ind_max)])\n\n        return np.float32(np.array(peaks_nms)), np.float32(np.array(peaks_score))\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"3x3 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\n\ndef conv3x3_bn_relu(in_planes, out_planes, stride=1):\n    return nn.Sequential(\n        conv3x3(in_planes, out_planes, stride),\n        nn.BatchNorm2d(out_planes),\n        nn.ReLU(inplace=True),\n    )\n\n\nclass double_conv(nn.Module):\n    '''(conv => BN => ReLU) * 2'''\n\n    def __init__(self, in_ch, out_ch):\n        
super(double_conv, self).__init__()\n        self.conv = nn.Sequential(\n            nn.Conv2d(in_ch, out_ch, 3, padding=1),\n            nn.BatchNorm2d(out_ch),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(out_ch, out_ch, 3, padding=1),\n            nn.BatchNorm2d(out_ch),\n            nn.ReLU(inplace=True)\n        )\n\n    def forward(self, x):\n        x = self.conv(x)\n        return x\n\n\nclass inconv(nn.Module):\n    def __init__(self, in_ch, out_ch):\n        super(inconv, self).__init__()\n        self.conv = double_conv(in_ch, out_ch)\n\n    def forward(self, x):\n        x = self.conv(x)\n        return x\n\n\nclass down(nn.Module):\n    def __init__(self, in_ch, out_ch):\n        super(down, self).__init__()\n        self.mpconv = nn.Sequential(\n            nn.MaxPool2d(2),\n            double_conv(in_ch, out_ch)\n        )\n\n    def forward(self, x):\n        x = self.mpconv(x)\n        return x\n\n\nclass up(nn.Module):\n    def __init__(self, in_ch, out_ch, bilinear=True):\n        super(up, self).__init__()\n\n        #  would be a nice idea if the upsampling could be learned too,\n        #  but my machine do not have enough memory to handle all those weights\n        if bilinear:\n            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n        else:\n            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)\n\n        self.conv = double_conv(in_ch, out_ch)\n\n    def forward(self, x1, x2):\n        x1 = self.up(x1)\n        diffX = x1.size()[2] - x2.size()[2]\n        diffY = x1.size()[3] - x2.size()[3]\n        x2 = F.pad(x2, (diffX // 2, int(diffX / 2),\n                        diffY // 2, int(diffY / 2)))\n        x = torch.cat([x2, x1], dim=1)\n        x = self.conv(x)\n        return x\n\n\nclass outconv(nn.Module):\n    def __init__(self, in_ch, out_ch):\n        super(outconv, self).__init__()\n        self.conv = nn.Conv2d(in_ch, out_ch, 1)\n\n    def forward(self, x):\n        x = self.conv(x)\n        return x\n\n\ndef weights_init(m):\n    classname = m.__class__.__name__\n    if classname.find('Conv') != -1:\n        nn.init.kaiming_normal_(m.weight.data)\n    elif classname.find('BatchNorm') != -1:\n        m.weight.data.fill_(1.)\n        m.bias.data.fill_(1e-4)\n\n\ndef roi_pooling(input, rois, size=(7, 7), spatial_scale=1.0):\n    assert (rois.dim() == 2)\n    assert (rois.size(1) == 5)\n    output = []\n    rois = rois.data.float()\n    num_rois = rois.size(0)\n\n    rois[:, 1:].mul_(spatial_scale)\n    rois = rois.long()\n    for i in range(num_rois):\n        roi = rois[i]\n        im_idx = roi[0]\n        if roi[1] >= input.size(3) or roi[2] >= input.size(2) or roi[1] < 0 or roi[2] < 0:\n            # print(f\"Runtime Warning: roi top left corner out of range: {roi}\", file=sys.stderr)\n            roi[1] = torch.clamp(roi[1], 0, input.size(3) - 1)\n            roi[2] = torch.clamp(roi[2], 0, input.size(2) - 1)\n        if roi[3] >= input.size(3) or roi[4] >= input.size(2) or roi[3] < 0 or roi[4] < 0:\n            # print(f\"Runtime Warning: roi bottom right corner out of range: {roi}\", file=sys.stderr)\n            roi[3] = torch.clamp(roi[3], 0, input.size(3) - 1)\n            roi[4] = torch.clamp(roi[4], 0, input.size(2) - 1)\n        if (roi[3:5] - roi[1:3] < 0).any():\n            # print(f\"Runtime Warning: invalid roi: {roi}\", file=sys.stderr)\n            im = input.new_full((1, input.size(1), 1, 1), 0)\n        else:\n            im = input.narrow(0, im_idx, 1)[..., 
roi[2]:(roi[4] + 1), roi[1]:(roi[3] + 1)]\n        output.append(F.adaptive_max_pool2d(im, size))\n\n    return torch.cat(output, 0)\n\n\nclass GradAccumulatorFunction(Function):\n    @staticmethod\n    def forward(ctx, input, accumulated_grad=None, mode=\"release\"):\n        ctx.accumulated_grad = accumulated_grad\n        ctx.mode = mode\n        return input\n\n    @staticmethod\n    @once_differentiable\n    def backward(ctx, grad_output):\n        accumulated_grad = ctx.accumulated_grad\n        ctx.accumulated_grad = None\n        if ctx.mode == \"accumulate\":\n            accumulated_grad.add_(grad_output)\n            return torch.zeros_like(grad_output), None, None\n        elif ctx.mode == \"release\":\n            if accumulated_grad is not None:\n                accumulated_grad.add_(grad_output)\n            else:\n                accumulated_grad = grad_output\n            grad_output = accumulated_grad\n            return grad_output, None, None\n        else:\n            raise ValueError(f\"invalid mode {ctx.mode}\")\n\n\nclass GradAccumulator(nn.Module):\n    \"\"\"\n    Helper module used to accumulate gradient of the given tensor w.r.t output of criterion.\n    Typically used when we have a feature extractor followed by several modules that can be calculate independently, the\n    module only retains the last executed submodule and accumulate the gradient produced by former submodules, so that\n    GPU memory used to store the temporary variables in former submodules is saved. It can be also used to extend\n    effective batch size at little expense of memory.\n    \"\"\"\n    def __init__(self, criterion_fns, submodules, collect_fn=None, reduce_method=\"mean\"):\n        super(GradAccumulator, self).__init__()\n        assert isinstance(submodules, (Sized, Iterable)), \"invalid submodules\"\n        if isinstance(criterion_fns, (Sized, Iterable)):\n            assert len(submodules) == len(criterion_fns)\n            assert all([isinstance(submodule, nn.Module) for submodule in submodules])\n            assert all([isinstance(criterion_fn, nn.Module) for criterion_fn in criterion_fns])\n        elif isinstance(criterion_fns, nn.Module):\n            criterion_fns = [criterion_fns for _ in range(len(submodules))]\n        elif criterion_fns is None:\n            criterion_fns = [criterion_fns for _ in range(len(submodules))]\n        else:\n            raise ValueError(\"invalid criterion function\")\n        assert reduce_method in (\"mean\", \"sum\", None)\n\n        self.submodules = nn.ModuleList(submodules)\n        self.criterion_fns = nn.ModuleList(criterion_fns)\n        self.method = reduce_method\n        self.grad_buffer = None\n        self.func = GradAccumulatorFunction.apply\n        self.collect_fn = collect_fn\n\n    def forward(self, tensor):\n        outputs = []\n        losses = tensor.new_full((1,), 0)\n        self.grad_buffer = None\n        for i, (submodule, criterion) in enumerate(zip(self.submodules, self.criterion_fns)):\n            mode = \"accumulate\" if i < len(self.submodules) - 1 else \"release\"\n            if self.grad_buffer is None:\n                self.grad_buffer = torch.zeros_like(tensor)\n            if mode == \"accumulate\":\n                output = tensor.detach()\n                output.requires_grad = True\n            else:\n                output = tensor\n            output = self.func(\n                output,\n                self.grad_buffer,\n                mode,\n            )\n            if 
isinstance(output, tuple):\n                output = submodule(*output)\n            else:\n                output = submodule(output)\n            if criterion is not None:\n                loss = criterion(output)\n                if self.method == \"mean\":\n                    loss = loss / len(self.submodules)\n\n                if mode == \"accumulate\" and torch.is_grad_enabled():\n                    loss.backward()\n                    loss = loss.detach()\n\n                output = output.detach()\n                losses += loss\n            else:\n                assert not output.requires_grad, \"criterion must be specified to calculate output gradient\"\n\n            outputs.append(output)\n\n        if self.collect_fn is not None:\n            with torch.no_grad():\n                outputs = self.collect_fn(outputs)\n\n        return outputs, losses\n"
  },
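  {
    "path": "examples/grad_accumulator_demo.py",
    "content": "\"\"\"Usage sketch for models.common.GradAccumulator (assumed file path; the\nbackbone, heads and criterion below are toy stand-ins, not the repo's real\nsubmodules). GradAccumulator evaluates several heads over a shared feature map\nwhile keeping only one head's activations alive at a time: gradients from the\nearlier heads are summed into a buffer that is released through the last head,\nso a single backward() still reaches the feature extractor.\n\"\"\"\nimport torch\nimport torch.nn as nn\n\nfrom models.common import GradAccumulator\n\n\nclass MeanAbsCriterion(nn.Module):\n    \"\"\"Toy criterion: mean absolute activation of a head output.\"\"\"\n\n    def forward(self, x):\n        return x.abs().mean()\n\n\nif __name__ == \"__main__\":\n    backbone = nn.Conv2d(3, 8, 3, padding=1)\n    heads = [nn.Conv2d(8, 1, 1) for _ in range(3)]\n    crits = [MeanAbsCriterion() for _ in heads]\n    acc = GradAccumulator(crits, heads, reduce_method=\"mean\")\n\n    feat = backbone(torch.randn(2, 3, 16, 16))\n    outputs, loss = acc(feat)  # loss is already averaged over the three heads\n    loss.backward()  # gradient reaches the backbone through the released buffer\n    print(len(outputs), loss.item(), backbone.weight.grad is not None)\n"
  },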
  {
    "path": "models/graph.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom models.common import conv3x3_bn_relu, LMFPeakFinder, weights_init\n\n\nclass JunctionInference(nn.Module):\n    def __init__(self, dim_embedding, pooling_threshold=0.2, max_junctions=512, spatial_scale=0.25, verbose=False):\n        super(JunctionInference, self).__init__()\n        self.dim_embedding = dim_embedding\n        self.pool_th = pooling_threshold\n        self.max_juncs = max_junctions\n        self.map_infer = nn.Sequential(\n            conv3x3_bn_relu(dim_embedding, dim_embedding // 2, 1),\n            conv3x3_bn_relu(dim_embedding // 2, dim_embedding // 2, 1),\n            nn.Conv2d(dim_embedding // 2, 1, 1),\n            nn.Sigmoid()\n        )\n        self.verbose = verbose\n        self.scale = spatial_scale\n        self.map_infer.apply(weights_init)\n\n    def forward(self, feat):\n        bs, ch, h, w = feat.size()\n        junc_map = self.map_infer(feat)\n        junc_map = nn.functional.interpolate(\n            junc_map,\n            scale_factor=1. / self.scale,\n            mode=\"bilinear\",\n            align_corners=False\n        )\n        junc_coord = []\n        for b in range(bs):\n            peak_finder = LMFPeakFinder(min_th=self.pool_th)\n            coord, score = peak_finder.detect(junc_map[b, 0].data.cpu().numpy())\n            if self.verbose:\n                print(f\"find {len(coord)} jucntions.\", flush=True)\n            if coord is None or len(coord) == 0:\n                continue\n            junc_score = torch.from_numpy(score).to(feat)\n            _, ind = torch.sort(junc_score, descending=True)\n            ind = ind.cpu() \n            coord = coord[ind[:self.max_juncs]]\n            coord = coord.reshape((-1, 2))\n            y, x = coord[:, 0], coord[:, 1]\n            y = torch.from_numpy(y).to(feat)\n            x = torch.from_numpy(x).to(feat)\n            assert (x >= 0).all() and (x < junc_map.size(3)).all() and (y >= 0).all() and (y < junc_map.size(2)).all()\n            junc_coord.append(\n                torch.cat([feat.new_full((len(x), 1), b), x.view(-1, 1), y.view(-1, 1)], dim=1)\n            )\n        if len(junc_coord) > 0:\n            junc_coord = torch.cat(junc_coord, dim=0)\n        else:\n            junc_coord = feat.new_full((1, 3), 0.)\n\n        return junc_map.squeeze(1), junc_coord\n\n\nclass LinePooling(nn.Module):\n    def __init__(self, align_size=256, spatial_scale=0.25):\n        super(LinePooling, self).__init__()\n        self.align_size = align_size\n        assert isinstance(self.align_size, int)\n        self.scale = spatial_scale\n\n    def forward(self, feat, coord_st, coord_ed):\n        _, ch, h, w = feat.size()\n        num_st, num_ed = coord_st.size(0), coord_ed.size(0)\n        assert coord_st.size(1) == 3 and coord_ed.size(1) == 3\n        assert (coord_st[:, 0] == coord_st[0, 0]).all() and (coord_ed[:, 0] == coord_st[0, 0]).all()\n        bs = coord_st[0, 0].item()\n        # construct bounding boxes from junction points\n        with torch.no_grad():\n            coord_st = coord_st[:, 1:] * self.scale\n            coord_ed = coord_ed[:, 1:] * self.scale\n            coord_st = coord_st.unsqueeze(1).expand(num_st, num_ed, 2)\n            coord_ed = coord_ed.unsqueeze(0).expand(num_st, num_ed, 2)\n            arr_st2ed = coord_ed - coord_st\n            sample_grid = torch.linspace(0, 1, steps=self.align_size).to(feat).view(1, 1, self.align_size).expand(num_st, num_ed, self.align_size)\n            
sample_grid = torch.einsum(\"ijd,ijs->ijsd\", (arr_st2ed, sample_grid)) + coord_st.view(num_st, num_ed, 1, 2).expand(num_st, num_ed, self.align_size, 2)\n            sample_grid = sample_grid.view(num_st, num_ed, self.align_size, 2)\n            sample_grid[..., 0] = sample_grid[..., 0] / (w - 1) * 2 - 1\n            sample_grid[..., 1] = sample_grid[..., 1] / (h - 1) * 2 - 1\n\n        output = F.grid_sample(feat[int(bs)].view(1, ch, h, w).expand(num_st, ch, h, w), sample_grid)\n        assert output.size() == (num_st, ch, num_ed, self.align_size)\n        output = output.permute(0, 2, 1, 3).contiguous()\n\n        return output\n\n\nclass AdjacencyMatrixInference(nn.Module):\n    def __init__(self, dim_embedding=256, align_size=256):\n        super(AdjacencyMatrixInference, self).__init__()\n        self.dim_embedding = dim_embedding\n        self.align_size = align_size\n        self.dblock = nn.Sequential(\n            nn.Conv1d(dim_embedding, dim_embedding, 8, 4, 2, bias=False),\n            nn.GroupNorm(32, dim_embedding),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(dim_embedding, dim_embedding, 8, 4, 2, bias=False),\n            nn.GroupNorm(32, dim_embedding),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(dim_embedding, dim_embedding, 8, 4, 2, bias=False),\n            nn.GroupNorm(32, dim_embedding),\n            nn.ReLU(inplace=True)\n        )\n        self.connectivity_inference = nn.Sequential(\n            nn.Conv1d(dim_embedding, 1, 1, 1, 0),\n            nn.Sigmoid()\n        )\n\n    def forward(self, line_feat):\n        num_st, num_ed, c, s = line_feat.size()\n        output_st2ed = line_feat.view(num_st * num_ed, c, s)\n        output_ed2st = torch.flip(output_st2ed, (2, ))\n        output_st2ed = self.dblock(output_st2ed)\n        output_ed2st = self.dblock(output_ed2st)\n        adjacency_matrix1 = self.connectivity_inference(output_st2ed).view(num_st, num_ed)\n        adjacency_matrix2 = self.connectivity_inference(output_ed2st).view(num_st, num_ed)\n\n        return torch.min(adjacency_matrix1, adjacency_matrix2)\n"
  },
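  {
    "path": "examples/line_pooling_demo.py",
    "content": "\"\"\"Shape-checking sketch for models.graph (assumed file path; the feature map,\njunction coordinates and align_size=64 below are illustrative choices, not the\nvalues used by the training scripts). LinePooling samples align_size feature\nvectors along the segment between every pair of junctions, and\nAdjacencyMatrixInference turns each pooled line feature into a connection score.\n\"\"\"\nimport torch\n\nfrom models.graph import LinePooling, AdjacencyMatrixInference\n\n\nif __name__ == \"__main__\":\n    feat = torch.randn(1, 256, 64, 64)  # backbone features at 1/4 spatial scale\n    # junctions given as (batch_index, x, y) in full-image coordinates\n    juncs = torch.tensor([[0., 10., 20.],\n                          [0., 100., 40.],\n                          [0., 200., 120.]])\n    pool = LinePooling(align_size=64, spatial_scale=0.25)\n    adj_head = AdjacencyMatrixInference(dim_embedding=256, align_size=64)\n    with torch.no_grad():\n        line_feat = pool(feat, juncs, juncs)  # (3, 3, 256, 64)\n        adj = adj_head(line_feat)  # (3, 3), scores in [0, 1]\n    print(line_feat.shape, adj.shape)\n"
  },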
  {
    "path": "models/lsd.py",
    "content": "import torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport models.graph\nimport models.backbone\nimport models.common\nimport numpy as np\n\n\nclass LSDDataLayer(nn.Module):\n    def __init__(self, mean=None, std=None):\n        super(LSDDataLayer, self).__init__()\n        self.std = [1., 1., 1.] if std is None else std\n        self.mean = [102.9801, 115.9465, 122.7717] if mean is None else mean\n\n    def forward(self, img):\n        assert img.size(1) == 3\n        for ch in range(3):\n            img[:, ch, :, :] = (img[:, ch, :, :] - self.mean[ch]) / self.std[ch]\n\n        return img\n\n\n# noinspection PyTypeChecker\nclass BinaryFocalLoss(nn.Module):\n    def __init__(self, gamma=2., alpha=0.25, size_average=True):\n        super(BinaryFocalLoss, self).__init__()\n        self.gamma = gamma\n        self.alpha = alpha\n        self.size_average = size_average\n\n    def forward(self, input, target, weight=None):\n        if weight is not None:\n            assert weight.size() == input.size(), f\"weight size: {weight.size()}, input size: {input.size()}\"\n            assert (weight >= 0).all() and (weight <= 1).all(), f\"weight max: {weight.max()}, min: {weight.min()}\"\n        input = input.clamp(1.e-6, 1. - 1.e-6)\n        if weight is None:\n            loss = th.sum(\n                - self.alpha * target * ((1 - input) ** self.gamma) * th.log(input)\n                - (1 - self.alpha) * (1 - target) * (input ** self.gamma) * th.log(1 - input))\n        else:\n            loss = th.sum(\n                (- self.alpha * target * ((1 - input) ** self.gamma) * th.log(input)\n                 - (1 - self.alpha) * (1 - target) * (input ** self.gamma) * th.log(1 - input)) * weight\n            )\n        if self.size_average:\n            loss /= input.nelement()\n        return loss\n\n\nclass BlockAdjacencyMatrixInference(nn.Module):\n    def __init__(self,\n                 line_pool_module, adj_infer_module,\n                 current_batch_id, junc_st_st, junc_st_len, junc_ed_st, junc_ed_len, junc_pred\n                 ):\n        super(BlockAdjacencyMatrixInference, self).__init__()\n        self.line_pool = line_pool_module\n        self.adj_infer = adj_infer_module\n        self.b = current_batch_id\n        self.st_st = junc_st_st\n        self.st_len = junc_st_len\n        self.ed_st = junc_ed_st\n        self.ed_len = junc_ed_len\n        self.juncs = junc_pred\n\n    def forward(self, feat):\n        junc_st = self.juncs.narrow(0, self.st_st, self.st_len)\n        junc_ed = self.juncs.narrow(0, self.ed_st, self.ed_len)\n        assert (junc_st[:, 0] == self.b).all() and (junc_ed[:, 0] == self.b).all(), f\"{self.b}\\n{junc_st[:, 0]}\\n{junc_ed[:, 0]}\"\n        line_feat = self.line_pool(feat, junc_st, junc_ed)\n        block_adj_matrix = self.adj_infer(line_feat)\n\n        return block_adj_matrix\n\n\nclass BlockAdjacencyMatrixInferenceCriterion(nn.Module):\n    def __init__(self, adj_matrix_crit, adj_matrix_gt, adj_matrix_loss_lambda,\n                 current_batch_id, mtx_st_st, mtx_st_len, mtx_ed_st, mtx_ed_len,\n                 junc_padded, img_size, line_seg_length_weight_fn\n                 ):\n        super(BlockAdjacencyMatrixInferenceCriterion, self).__init__()\n        self.adj_crit = adj_matrix_crit\n        self.adj_gt = adj_matrix_gt\n        self.loss_lambda = adj_matrix_loss_lambda\n        self.b = current_batch_id\n        self.st_st = mtx_st_st\n        self.st_len = mtx_st_len\n        self.ed_st = 
mtx_ed_st\n        self.ed_len = mtx_ed_len\n        self.junc = junc_padded\n        self.img_size = img_size\n        self.weight = line_seg_length_weight_fn\n\n    def forward(self, block_adj_matrix):\n        block_adj_matrix_gt = self.adj_gt[self.b, self.st_st:self.st_st+self.st_len, self.ed_st:self.ed_st+self.ed_len]\n        if self.junc is not None:\n            junc_st = self.junc[self.b, self.st_st:self.st_st + self.st_len].view(self.st_len, 1, 2).expand(self.st_len,\n                                                                                                            self.ed_len,\n                                                                                                            2)\n            junc_ed = self.junc[self.b, self.ed_st:self.ed_st + self.ed_len].view(1, self.ed_len, 2).expand(self.st_len,\n                                                                                                            self.ed_len,\n                                                                                                            2)\n            line_len = (junc_ed - junc_st).norm(dim=2)\n            return self.loss_lambda * self.adj_crit(block_adj_matrix, block_adj_matrix_gt, weight=self.weight(line_len, self.img_size * 1.4143))\n        else:\n            return self.loss_lambda * self.adj_crit(block_adj_matrix, block_adj_matrix_gt)\n\n\nclass LSDModule(nn.Module):\n    def __init__(\n            self,\n            # backbone parameters\n            backbone=\"unet\",\n            dim_embedding=256,\n            backbone_kwargs={},\n            # junction inference parameters\n            junction_pooling_threshold=0.2,\n            max_junctions=512,\n            feature_spatial_scale=0.25,\n            junction_heatmap_criterion=\"binary_cross_entropy\",\n            # junction pooling parameters\n            junction_pooling_size=15.,\n            # directional attention parameters\n            attention_sigma=1.,\n            # adjacency matrix inference parameters\n            block_inference_size=64,\n            adjacency_matrix_criterion=\"binary_cross_entropy\",\n            weight_fn=None,\n            is_train_junc=True,\n            is_train_adj=True,\n            enable_junc_infer=True,\n            enable_adj_infer=True,\n            verbose=True,\n            **kwargs\n    ):\n        super(LSDModule, self).__init__()\n        if backbone == \"unet\":\n            backbone_kwargs.update({\n                \"n_downs\": 5,\n                \"n_ups\": 3\n            })\n            self.backbone = models.backbone.UNetBackbone(\n                dim_embedding=dim_embedding,\n                **backbone_kwargs\n            )\n        elif backbone == \"resnet50\":\n            self.backbone = models.backbone.ResNetU50Backbone(\n                dim_embedding=dim_embedding,\n                **backbone_kwargs\n            )\n        else:\n            raise ValueError(f\"invalid backbone: {backbone}\")\n\n        self.prep_data = LSDDataLayer()\n\n        self.junc_infer = models.graph.JunctionInference(\n            dim_embedding=dim_embedding,\n            pooling_threshold=junction_pooling_threshold,\n            max_junctions=max_junctions,\n            spatial_scale=feature_spatial_scale,\n            verbose=verbose\n        )\n\n        self.line_pool = models.graph.LinePooling(\n            align_size=junction_pooling_size,\n            spatial_scale=feature_spatial_scale\n        )\n\n        self.adj_infer = models.graph.AdjacencyMatrixInference(\n  
          dim_embedding=dim_embedding,\n            align_size=junction_pooling_size,\n        )\n\n        self.adj_embed = nn.Sequential(\n            models.common.double_conv(dim_embedding, dim_embedding),\n        )\n\n        self.adj_embed.apply(models.common.weights_init)\n        if junction_heatmap_criterion == \"focal\":\n            self.hm_crit = BinaryFocalLoss()\n        else:\n            self.hm_crit = getattr(F, junction_heatmap_criterion)\n        if adjacency_matrix_criterion == \"focal\":\n            self.adj_crit = BinaryFocalLoss()\n        else:\n            self.adj_crit = getattr(F, adjacency_matrix_criterion)\n        self.adj_block_size = block_inference_size\n        self.max_junctions = max_junctions\n        self.weight_fn = weight_fn\n        self.is_train_junc = is_train_junc\n        self.is_train_adj = is_train_adj\n        self.enable_junc_infer = enable_junc_infer\n        self.enable_adj_infer = enable_adj_infer\n\n    def forward(self, img, junc_map_gt, adj_matrix_gt, junc_loss_lambda=1., adj_loss_lambda=1., junc_coord_gt=None):\n        img = self.prep_data(img)\n        feat = self.backbone(img)\n        bs = img.size(0)\n\n        if self.enable_junc_infer:\n            if self.is_train_junc:\n                junc_hm, junc_coords = self.junc_infer(feat)\n            else:\n                with th.no_grad():\n                    junc_hm, junc_coords = self.junc_infer(feat)\n            # junc_coords[junc_coords[:, 1:].sum(dim=1) == 0] += 0.1\n            # padding junction prediction\n            junc_cnt = []\n            j0 = 0\n            for b in range(bs):\n                junc_cnt.append(0)\n                for j in range(j0, len(junc_coords)):\n                    if np.isclose(junc_coords[j, 0].item(), b, atol=.1):\n                        junc_cnt[-1] += 1\n                    else:\n                        j0 = j\n                        break\n            junc_st = np.cumsum([0] + junc_cnt).tolist()\n            junc_pred = junc_coords.new_full((bs, self.max_junctions, 2), 0.)\n            for b in range(bs):\n                junc_pred[b, :junc_cnt[b]] = junc_coords[junc_st[b]:junc_st[b + 1], 1:] + .1\n            loss_hm = self.hm_crit(junc_hm, junc_map_gt) * junc_loss_lambda\n        else:\n            assert junc_coord_gt is not None\n            junc_hm = junc_map_gt\n            junc_pred = junc_coord_gt\n            loss_hm = img.new_full((1, ), 0)\n\n        if self.enable_adj_infer:\n            # block-wise junction pooling and adjacency matrix inference\n            # first count number of detected junctions of each image\n            if junc_coord_gt is not None:\n                junc_st = [0]\n                junc_cnt = []\n                for b in range(bs):\n                    junc_cnt.append(th.sum(junc_coord_gt[b].sum(dim=1) != 0).item())\n                    assert junc_cnt[-1] > 0\n                    junc_st.append(junc_st[-1] + junc_cnt[-1])\n                junc_coord_gt_ = img.new_full((sum(junc_cnt), 3), 0.)\n                for b in range(bs):\n                    junc_coord_gt_[junc_st[b]:junc_st[b+1], 0] = b\n                    junc_coord_gt_[junc_st[b]:junc_st[b+1], 1:] = junc_coord_gt[b, :junc_cnt[b]]\n\n            # then for each image, build list of subgraph that processes at most block_size junctions\n            block_crit = []\n            block_infer = []\n            for b in range(bs):\n                num_blocks = junc_cnt[b] // self.adj_block_size + (1 if junc_cnt[b] % self.adj_block_size else 
0)\n                for bst in range(num_blocks):\n                    for bed in range(num_blocks):\n                        st_st = junc_st[b] + bst * self.adj_block_size\n                        st_len = min(self.adj_block_size, junc_cnt[b] - bst * self.adj_block_size)\n                        ed_st = junc_st[b] + bed * self.adj_block_size\n                        ed_len = min(self.adj_block_size, junc_cnt[b] - bed * self.adj_block_size)\n                        block_crit.append(\n                            BlockAdjacencyMatrixInferenceCriterion(\n                                self.adj_crit, adj_matrix_gt, adj_loss_lambda, b,\n                                bst * self.adj_block_size, min(self.adj_block_size, junc_cnt[b] - bst * self.adj_block_size),\n                                bed * self.adj_block_size, min(self.adj_block_size, junc_cnt[b] - bed * self.adj_block_size),\n                                None if junc_coord_gt is None else junc_coord_gt, img.size(2), self.weight_fn\n                            )\n                        )\n                        block_infer.append(\n                            BlockAdjacencyMatrixInference(\n                                self.line_pool, self.adj_infer,\n                                b, st_st, st_len, ed_st, ed_len, junc_coords if junc_coord_gt is None else junc_coord_gt_,\n                            )\n                        )\n\n            def output_collect_fn(outputs):\n                output = img.new_full((bs, self.max_junctions, self.max_junctions), 0.)\n                current_block = 0\n                for b in range(bs):\n                    num_blocks = junc_cnt[b] // self.adj_block_size + (1 if junc_cnt[b] % self.adj_block_size else 0)\n                    for bst in range(num_blocks):\n                        for bed in range(num_blocks):\n                            st_st = bst * self.adj_block_size\n                            st_len = min(self.adj_block_size, junc_cnt[b] - bst * self.adj_block_size)\n                            ed_st = bed * self.adj_block_size\n                            ed_len = min(self.adj_block_size, junc_cnt[b] - bed * self.adj_block_size)\n                            output[b, st_st:st_st + st_len, ed_st:ed_st + ed_len] = outputs[current_block]\n                            current_block += 1\n\n                return output\n\n            block_adj_infer = models.common.GradAccumulator(\n                block_crit, block_infer, output_collect_fn, reduce_method=\"mean\"\n            )\n            if self.is_train_adj:\n                feat_adj = self.adj_embed(feat)\n                adj_matrix_pred, loss_adj = block_adj_infer(feat_adj)\n            else:\n                with th.no_grad():\n                    feat_adj = self.adj_embed(feat)\n                    adj_matrix_pred, loss_adj = block_adj_infer(feat_adj)\n        else:\n            adj_matrix_pred = adj_matrix_gt\n            loss_adj = img.new_full((1, ), 0)\n\n        return junc_pred, junc_hm, adj_matrix_pred, loss_hm, loss_adj\n"
  },
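  {
    "path": "examples/focal_loss_demo.py",
    "content": "\"\"\"Numerical sketch for models.lsd.BinaryFocalLoss (assumed file path; the\ntensors below are random toys). The module implements the binary focal loss\n\n    FL(p, y) = -alpha * y * (1 - p)**gamma * log(p)\n               - (1 - alpha) * (1 - y) * p**gamma * log(1 - p),\n\nsummed over all elements and divided by the element count when\nsize_average=True (the default), with an optional per-element weight in\n[0, 1] as used for the adjacency-matrix supervision.\n\"\"\"\nimport torch\n\nfrom models.lsd import BinaryFocalLoss\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    pred = torch.sigmoid(torch.randn(2, 128, 128))  # predictions in (0, 1)\n    target = (torch.rand(2, 128, 128) > 0.99).float()  # sparse positive labels\n    weight = torch.rand_like(pred)  # optional per-element weights\n    crit = BinaryFocalLoss(gamma=2., alpha=0.25)\n    print(crit(pred, target).item(), crit(pred, target, weight=weight).item())\n"
  },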
  {
    "path": "models/lsd_test.py",
    "content": "import torch as th\nimport torch.nn as nn\nimport models.graph\nimport models.backbone\nimport models.common\nimport numpy as np\n\nfrom .lsd import LSDModule\n\n\nclass LSDDataLayer(nn.Module):\n    def __init__(self, mean=None, std=None):\n        super(LSDDataLayer, self).__init__()\n        self.std = [1., 1., 1.] if std is None else std\n        self.mean = [102.9801, 115.9465, 122.7717] if mean is None else mean\n\n    def forward(self, img):\n        assert img.size(1) == 3\n        for ch in range(3):\n            img[:, ch, :, :] = (img[:, ch, :, :] - self.mean[ch]) / self.std[ch]\n\n        return img\n\n\n# noinspection PyTypeChecker\nclass BinaryFocalLoss(nn.Module):\n    def __init__(self, gamma=2., alpha=0.25, size_average=True):\n        super(BinaryFocalLoss, self).__init__()\n        self.gamma = gamma\n        self.alpha = alpha\n        self.size_average = size_average\n\n    def forward(self, input, target, weight=None):\n        if weight is not None:\n            assert weight.size() == input.size(), f\"weight size: {weight.size()}, input size: {input.size()}\"\n            assert (weight >= 0).all() and (weight <= 1).all(), f\"weight max: {weight.max()}, min: {weight.min()}\"\n        input = input.clamp(1.e-6, 1. - 1.e-6)\n        if weight is None:\n            loss = th.sum(\n                - self.alpha * target * ((1 - input) ** self.gamma) * th.log(input)\n                - (1 - self.alpha) * (1 - target) * (input ** self.gamma) * th.log(1 - input))\n        else:\n            loss = th.sum(\n                (- self.alpha * target * ((1 - input) ** self.gamma) * th.log(input)\n                 - (1 - self.alpha) * (1 - target) * (input ** self.gamma) * th.log(1 - input)) * weight\n            )\n        if self.size_average:\n            loss /= input.nelement()\n        return loss\n\n\nclass BlockAdjacencyMatrixInference(nn.Module):\n    def __init__(self,\n                 line_pool_module, adj_infer_module,\n                 current_batch_id, junc_st_st, junc_st_len, junc_ed_st, junc_ed_len, junc_pred\n                 ):\n        super(BlockAdjacencyMatrixInference, self).__init__()\n        self.line_pool = line_pool_module\n        self.adj_infer = adj_infer_module\n        self.b = current_batch_id\n        self.st_st = junc_st_st\n        self.st_len = junc_st_len\n        self.ed_st = junc_ed_st\n        self.ed_len = junc_ed_len\n        self.juncs = junc_pred\n\n    def forward(self, feat):\n        junc_st = self.juncs.narrow(0, self.st_st, self.st_len)\n        junc_ed = self.juncs.narrow(0, self.ed_st, self.ed_len)\n        assert (junc_st[:, 0] == self.b).all() and (junc_ed[:, 0] == self.b).all(), f\"{self.b}\\n{junc_st[:, 0]}\\n{junc_ed[:, 0]}\"\n        line_feat = self.line_pool(feat, junc_st, junc_ed)\n        block_adj_matrix = self.adj_infer(line_feat)\n\n        return block_adj_matrix\n\n\nclass BlockAdjacencyMatrixInferenceCriterion(nn.Module):\n    def __init__(self, adj_matrix_crit, adj_matrix_gt, adj_matrix_loss_lambda,\n                 current_batch_id, mtx_st_st, mtx_st_len, mtx_ed_st, mtx_ed_len,\n                 junc_padded, img_size, line_seg_length_weight_fn\n                 ):\n        super(BlockAdjacencyMatrixInferenceCriterion, self).__init__()\n        self.adj_crit = adj_matrix_crit\n        self.adj_gt = adj_matrix_gt\n        self.loss_lambda = adj_matrix_loss_lambda\n        self.b = current_batch_id\n        self.st_st = mtx_st_st\n        self.st_len = mtx_st_len\n        self.ed_st = mtx_ed_st\n   
     self.ed_len = mtx_ed_len\n        self.junc = junc_padded\n        self.img_size = img_size\n        self.weight = line_seg_length_weight_fn\n\n    def forward(self, block_adj_matrix):\n        block_adj_matrix_gt = self.adj_gt[self.b, self.st_st:self.st_st+self.st_len, self.ed_st:self.ed_st+self.ed_len]\n        if self.junc is not None:\n            junc_st = self.junc[self.b, self.st_st:self.st_st + self.st_len].view(self.st_len, 1, 2).expand(self.st_len,\n                                                                                                            self.ed_len,\n                                                                                                            2)\n            junc_ed = self.junc[self.b, self.ed_st:self.ed_st + self.ed_len].view(1, self.ed_len, 2).expand(self.st_len,\n                                                                                                            self.ed_len,\n                                                                                                            2)\n            line_len = (junc_ed - junc_st).norm(dim=2)\n            return self.loss_lambda * self.adj_crit(block_adj_matrix, block_adj_matrix_gt, weight=self.weight(line_len, self.img_size * 1.4143))\n        else:\n            return self.loss_lambda * self.adj_crit(block_adj_matrix, block_adj_matrix_gt)\n\n\nclass LSDTestModule(LSDModule):\n    def __init__(\n            self,\n            # backbone parameters\n            backbone=\"unet\",\n            dim_embedding=256,\n            backbone_kwargs={},\n            # junction inference parameters\n            junction_pooling_threshold=0.2,\n            max_junctions=512,\n            feature_spatial_scale=0.25,\n            # junction pooling parameters\n            junction_pooling_size=15.,\n    ):\n        super(LSDTestModule, self).__init__(\n            backbone=backbone,\n            dim_embedding=dim_embedding,\n            backbone_kwargs=backbone_kwargs,\n            # junction inference parameters\n            junction_pooling_threshold=junction_pooling_threshold,\n            max_junctions=max_junctions,\n            feature_spatial_scale=feature_spatial_scale,\n            junction_heatmap_criterion=\"binary_cross_entropy\",\n            # junction pooling parameters\n            junction_pooling_size=junction_pooling_size,\n            # adjacency matrix inference parameters\n            block_inference_size=64,\n            adjacency_matrix_criterion=\"binary_cross_entropy\",\n            weight_fn=None,\n            is_train_junc=True,\n            is_train_adj=True,\n            enable_junc_infer=True,\n            enable_adj_infer=True,\n            verbose=True,\n        )\n\n    def forward(self, img):\n        img = self.prep_data(img)\n        feat = self.backbone(img)\n        bs = img.size(0)\n\n        junc_hm, junc_coords = self.junc_infer(feat)\n\n        # first count number of detected junctions of each image\n        junc_cnt = []\n        j0 = 0\n        for b in range(bs):\n            junc_cnt.append(0)\n            for j in range(j0, len(junc_coords)):\n                if np.isclose(junc_coords[j, 0].item(), b, atol=.1):\n                    junc_cnt[-1] += 1\n                else:\n                    j0 = j\n                    break\n        junc_st = np.cumsum([0] + junc_cnt).tolist()\n        junc_pred = junc_coords.new_full((bs, self.max_junctions, 2), 0.)\n        for b in range(bs):\n            junc_pred[b, :junc_cnt[b]] = 
junc_coords[junc_st[b]:junc_st[b + 1], 1:] + .1\n\n        # then for each image, build list of subgraph that processes at most block_size junctions\n        block_infer = []\n        for b in range(bs):\n            num_blocks = junc_cnt[b] // self.adj_block_size + (1 if junc_cnt[b] % self.adj_block_size else 0)\n            for bst in range(num_blocks):\n                for bed in range(num_blocks):\n                    st_st = junc_st[b] + bst * self.adj_block_size\n                    st_len = min(self.adj_block_size, junc_cnt[b] - bst * self.adj_block_size)\n                    ed_st = junc_st[b] + bed * self.adj_block_size\n                    ed_len = min(self.adj_block_size, junc_cnt[b] - bed * self.adj_block_size)\n                    block_infer.append(\n                        BlockAdjacencyMatrixInference(\n                            self.line_pool, self.adj_infer,\n                            b, st_st, st_len, ed_st, ed_len, junc_coords,\n                        )\n                    )\n\n        def output_collect_fn(outputs):\n            output = img.new_full((bs, self.max_junctions, self.max_junctions), 0.)\n            current_block = 0\n            for b in range(bs):\n                num_blocks = junc_cnt[b] // self.adj_block_size + (1 if junc_cnt[b] % self.adj_block_size else 0)\n                for bst in range(num_blocks):\n                    for bed in range(num_blocks):\n                        st_st = bst * self.adj_block_size\n                        st_len = min(self.adj_block_size, junc_cnt[b] - bst * self.adj_block_size)\n                        ed_st = bed * self.adj_block_size\n                        ed_len = min(self.adj_block_size, junc_cnt[b] - bed * self.adj_block_size)\n                        output[b, st_st:st_st + st_len, ed_st:ed_st + ed_len] = outputs[current_block]\n                        current_block += 1\n\n            return output\n\n        block_adj_infer = models.common.GradAccumulator(None, block_infer, output_collect_fn)\n        feat_adj = self.adj_embed(feat)\n        adj_matrix_pred, _ = block_adj_infer(feat_adj)\n\n        return junc_pred, junc_hm, adj_matrix_pred\n"
  },
  {
    "path": "models/pretrained/resnet50-imagenet.pth",
    "content": "version https://git-lfs.github.com/spec/v1\noid sha256:902bbfbc9b570be36e0f94e757b211b24f9216df14260c58b359ddb182c0723e\nsize 1122304\n"
  },
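  {
    "path": "models/test_block_partition.py",
    "content": "\"\"\"Illustrative sanity check; not part of the original PPGNet test suite.\n\nIt mirrors the blockwise partition arithmetic used in\nmodels.lsd_test.LSDTestModule.forward, where the adjacency matrix of the\ndetected junctions is predicted tile by tile (at most block_size junctions\nper side) to bound GPU memory. The helper block_spans below is a local\nre-implementation written only for this check; it is not imported from the\nrepository code.\n\"\"\"\nfrom unittest import TestCase\nimport numpy as np\n\n\ndef block_spans(num_junc, block_size):\n    # (start, length) pairs covering [0, num_junc) in chunks of at most block_size\n    num_blocks = num_junc // block_size + (1 if num_junc % block_size else 0)\n    return [(b * block_size, min(block_size, num_junc - b * block_size)) for b in range(num_blocks)]\n\n\nclass TestBlockPartition(TestCase):\n    def test_tiles_cover_matrix_exactly_once(self):\n        for num_junc in (1, 63, 64, 65, 130, 512):\n            for block_size in (64, 128):\n                cover = np.zeros((num_junc, num_junc), dtype=np.int64)\n                for st_st, st_len in block_spans(num_junc, block_size):\n                    for ed_st, ed_len in block_spans(num_junc, block_size):\n                        cover[st_st:st_st + st_len, ed_st:ed_st + ed_len] += 1\n                # every entry of the adjacency matrix is written by exactly one tile\n                self.assertTrue((cover == 1).all(), f\"{num_junc} junctions, block size {block_size}\")\n\n    def test_block_length_is_bounded(self):\n        for st, length in block_spans(130, 64):\n            self.assertGreater(length, 0)\n            self.assertLessEqual(length, 64)\n"
  },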
  {
    "path": "models/test_common.py",
    "content": "from unittest import TestCase\nimport numpy as np\nimport models.common as common\nimport torch as th\nimport torch.nn as nn\nfrom itertools import chain\nfrom libs.roi_align.modules.roi_align import RoIAlign\n\n\nclass TestRoiPooling(TestCase):\n    def setUp(self):\n        th.manual_seed(1234)\n        self.test_data = dict(\n            input=th.rand(2, 3, 9, 10),\n            rois=th.tensor([\n                [0, 0, 0, 9, 8], # whole feature map\n                [1, 0, 0, 9, 8], # whole feature map\n                [0, 0, 0, 0, 0], # top left pixel\n                [0, 0, 0, 1, 1], # top left 2x2\n                [0, 0, 0, -1, -1], # bottom right out of range\n                [1, 9, 8, 9, 8], # bottom right pixel\n                [1, -3, -5, 6, 7], # top left out of range\n                [1, -3, -5, 10, 11],  # both corner out of range\n                [0, 3, 2, 9, 8], # 7x7 roi\n                [0, 3, 3, 8, 8], # 6x6 roi\n                [1, 1, 1, 5, 5], # 5x5 roi\n            ], dtype=th.float32)\n        )\n\n    def test_output_size_7x7(self):\n        input, rois = self.test_data[\"input\"], self.test_data[\"rois\"]\n        output7x7 = common.roi_pooling(\n            input=input,\n            rois=rois,\n            size=(7, 7),\n            spatial_scale=1.0\n        )\n        # output7x7 = RoIAlign(aligned_height=7, aligned_width=7, spatial_scale=1.)(input, rois)\n        self.assertEqual(output7x7.size(0), rois.size(0), \"output size dismatches rois size\")\n        self.assertEqual(input.size(1), output7x7.size(1), \"output channel dismatch input channel\")\n        self.assertTupleEqual(output7x7.shape[2:], (7, 7), \"output shape dismatch required shape\")\n\n    def test_output_size_5x5(self):\n        input, rois = self.test_data[\"input\"], self.test_data[\"rois\"]\n        output5x5 = common.roi_pooling(\n            input=input,\n            rois=rois,\n            size=(5, 5),\n            spatial_scale=1.0\n        )\n        # output5x5 = RoIAlign(aligned_height=5, aligned_width=5, spatial_scale=1.)(input, rois)\n        self.assertTupleEqual(output5x5.shape[2:], (5, 5), \"output shape dismatch required shape\")\n\n    def test_output_size_1x1(self):\n        input, rois = self.test_data[\"input\"], self.test_data[\"rois\"]\n        output1x1 = common.roi_pooling(\n            input=input,\n            rois=rois,\n            size=(1, 1),\n            spatial_scale=1.0\n        )\n        self.assertTupleEqual(output1x1.shape[2:], (1, 1), \"output shape dismatch required shape\")\n\n    def test_output_value(self):\n        input, rois = self.test_data[\"input\"], self.test_data[\"rois\"]\n        # rois[2, 3:] -= 1\n        output1x1 = common.roi_pooling(\n            input=input,\n            rois=rois,\n            size=(1, 1),\n            spatial_scale=1.0\n        )\n        # output1x1 = RoIAlign(aligned_height=1, aligned_width=1, spatial_scale=1.)(input, rois)\n        output5x5 = common.roi_pooling(\n            input=input,\n            rois=rois,\n            size=(5, 5),\n            spatial_scale=1.0\n        )\n        # output7x7 = common.roi_pooling(\n        #     input=input,\n        #     rois=rois,\n        #     size=(7, 7),\n        #     spatial_scale=1.0\n        # )\n        rois[8, 3:] -= 1\n        rois[10, 3:] -= 1\n        output7x7 = RoIAlign(aligned_height=7, aligned_width=7, spatial_scale=1.)(input, rois)\n        output5x5 = RoIAlign(aligned_height=5, aligned_width=5, spatial_scale=1.)(input, rois)\n        
self.assertTrue((output1x1[2] == input[0, :, :1, :1]).all())\n        self.assertTrue((output1x1[5] == input[1, :, 8:, 9:]).all())\n        self.assertTrue((output5x5[10] == input[1, :, 1:6, 1:6]).all())\n        self.assertTrue((output7x7[8] == input[0, :, 2:9, 3:10]).all())\n\n\nclass TestGradAccumulator(TestCase):\n    def setUp(self):\n        th.manual_seed(789)\n        self.net1 = nn.Sequential(\n            nn.Conv2d(16, 32, 3, 1, 1, bias=True),\n            nn.BatchNorm2d(32),\n            nn.ReLU(inplace=True)\n        )\n        self.net2 = nn.Sequential(\n            nn.Conv2d(32, 16, 3, 1, 1, bias=True),\n            nn.BatchNorm2d(16),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(16, 1, 3, 1, 1, bias=True),\n        )\n        self.test_data = dict(\n            X = th.rand(100, 16, 32, 32),\n            Y = th.rand(100, 1, 32, 32),\n        )\n        self.net1.apply(common.weights_init)\n        self.net2.apply(common.weights_init)\n        self.crit = nn.SmoothL1Loss()\n\n    def test_forward_pass(self):\n        X, Y = self.test_data[\"X\"], self.test_data[\"Y\"]\n        with th.no_grad():\n            feat = self.net1(X)\n            out = []\n            for i in range(4):\n                for j in range(4):\n                    out.append(self.net2(feat[:, :, i*4:(i+1)*4, j*4:(j+1)*4]))\n            loss = 0\n            for i in range(4):\n                for j in range(4):\n                    loss += self.crit(out[i*4+j], Y[:, :, i*4:(i+1)*4, j*4:(j+1)*4])\n            loss /= 16\n\n            class Net2(nn.Module):\n                def __init__(self, net2, st_st, st_ed, ed_st, ed_ed):\n                    super(Net2, self).__init__()\n                    self.net2 = net2\n                    self.st_st = st_st\n                    self.st_ed = st_ed\n                    self.ed_st = ed_st\n                    self.ed_ed = ed_ed\n\n                def forward(self, input):\n                    return self.net2(input[:, :, self.st_st:self.st_ed, self.ed_st:self.ed_ed])\n\n            class Crit(nn.Module):\n                def __init__(self, crit, target, st_st, st_ed, ed_st, ed_ed):\n                    super(Crit, self).__init__()\n                    self.crit = crit\n                    self.st_st = st_st\n                    self.st_ed = st_ed\n                    self.ed_st = ed_st\n                    self.ed_ed = ed_ed\n                    self.register_buffer(\"target\", target)\n\n                def forward(self, x):\n                    return self.crit(x, self.target[:, :, self.st_st:self.st_ed, self.ed_st:self.ed_ed])\n\n            gradacc = common.GradAccumulator(\n                [Crit(self.crit, Y, i*4, (i+1)*4, j*4, (j+1)*4) for i in range(4) for j in range(4)],\n                [Net2(self.net2, i*4, (i+1)*4, j*4, (j+1)*4) for i in range(4) for j in range(4)],\n                collect_fn=None\n            )\n            net_ = nn.Sequential(self.net1, gradacc)\n            out_, loss_ = net_(X)\n\n        for i in range(len(out)):\n            self.assertTrue(th.allclose(out[i], out_[i]))\n        self.assertTrue(th.allclose(loss, loss_), f\"{loss}\\n{loss_}\")\n\n    def test_backward_pass(self):\n        X, Y = self.test_data[\"X\"], self.test_data[\"Y\"]\n\n        self.net1.zero_grad()\n        self.net2.zero_grad()\n\n        feat = self.net1(X)\n        out = []\n        for i in range(4):\n            for j in range(4):\n                out.append(self.net2(feat[:, :, i * 4:(i + 1) * 4, j * 4:(j + 1) * 4]))\n        loss = 
0\n        for i in range(4):\n            for j in range(4):\n                loss += self.crit(out[i * 4 + j], Y[:, :, i * 4:(i + 1) * 4, j * 4:(j + 1) * 4])\n        loss /= 16\n        loss.backward()\n\n        grad = {}\n\n        for k, v in chain(self.net1.named_parameters(prefix=\"net1\"), self.net2.named_parameters(prefix=\"net2\")):\n            grad[k] = th.tensor(v.grad)\n\n        self.net1.zero_grad()\n        self.net2.zero_grad()\n\n        class Net2(nn.Module):\n            def __init__(self, net2, st_st, st_ed, ed_st, ed_ed):\n                super(Net2, self).__init__()\n                self.net2 = net2\n                self.st_st = st_st\n                self.st_ed = st_ed\n                self.ed_st = ed_st\n                self.ed_ed = ed_ed\n\n            def forward(self, input):\n                return self.net2(input[:, :, self.st_st:self.st_ed, self.ed_st:self.ed_ed])\n\n        class Crit(nn.Module):\n            def __init__(self, crit, target, st_st, st_ed, ed_st, ed_ed):\n                super(Crit, self).__init__()\n                self.crit = crit\n                self.st_st = st_st\n                self.st_ed = st_ed\n                self.ed_st = ed_st\n                self.ed_ed = ed_ed\n                self.register_buffer(\"target\", target)\n\n            def forward(self, x):\n                return self.crit(x, self.target[:, :, self.st_st:self.st_ed, self.ed_st:self.ed_ed])\n\n        gradacc = common.GradAccumulator(\n            [Crit(self.crit, Y, i * 4, (i + 1) * 4, j * 4, (j + 1) * 4) for i in range(4) for j in range(4)],\n            [Net2(self.net2, i * 4, (i + 1) * 4, j * 4, (j + 1) * 4) for i in range(4) for j in range(4)],\n            collect_fn=None\n        )\n        net_ = nn.Sequential(self.net1, gradacc)\n        out_, loss_ = net_(X)\n        loss_.backward()\n\n        grad_ = {}\n        for k, v in chain(self.net1.named_parameters(prefix=\"net1\"), self.net2.named_parameters(prefix=\"net2\")):\n            grad_[k] = th.tensor(v.grad)\n\n        for k in sorted(grad.keys()):\n            self.assertTrue(th.allclose(grad[k], grad_[k]), f\"{k}:\\n{grad[k]}\\n{grad_[k]}\")\n"
  },
  {
    "path": "models/test_graph.py",
    "content": "from unittest import TestCase\nimport torch as th\nimport models.graph as graph\nimport matplotlib.pyplot as plt\n\n\nclass TestJunctionInference(TestCase):\n    def test_junction_inference_forward(self):\n        junc_infer = graph.JunctionInference(256, 0.1, 512, 0.25, False)\n        with th.no_grad():\n            feat_map = th.rand(5, 256, 32, 64)\n            junc_map, junc_coord = junc_infer(feat_map)\n            self.assertTrue(junc_map.size(0) == 5)\n            self.assertTupleEqual(junc_map.shape[1:], (128, 256))\n            self.assertTrue(junc_coord.size(1) == 3)\n            self.assertTrue((junc_coord[:, 1:] >= 0).all() and (junc_coord[:, 1] < 256).all() and (junc_coord[:, 2] < 128).all())\n\n\nclass TestJunctionPooling(TestCase):\n    def test_junc_pooling_forward(self):\n        junc_infer = graph.JunctionInference(256, 0.1, 512, 0.25, False)\n        junc_pool = graph.JunctionPooling(5, 5, 0.25)\n        with th.no_grad():\n            feat_map = th.rand(5, 256, 32, 64)\n            junc_map, junc_coord = junc_infer(feat_map)\n            out = junc_pool(feat_map, junc_coord)\n            self.assertTrue(out.size(0) == junc_coord.size(0))\n            self.assertTrue(out.size(1) == feat_map.size(1))\n            self.assertTrue(out.shape[2:] == (5, 5))\n\n\nclass TestDirectionalAttention(TestCase):\n    def test_attention_forward(self):\n        attn = graph.DirectionalAttention(\n            15,\n            attn_sigma_dir=3.1415926/300,\n            attn_sigma_pos=2\n        )\n        junc = th.tensor([\n            [0., 0.],\n            [2., 0.],\n            [1., 1.],\n            [0., 1.]\n        ])\n        map_st2ed, map_ed2st = attn(junc, junc)\n        self.assertFalse(th.isnan(map_st2ed).any() or th.isnan(map_ed2st).any())\n        self.assertTupleEqual(map_st2ed.size(), map_ed2st.size())\n        self.assertTupleEqual(map_st2ed.shape[2:], map_ed2st.shape[2:])\n        self.assertTupleEqual(map_st2ed.shape[2:], (15, 15))\n        self.assertTupleEqual(map_st2ed.shape[:2], (4, 4))\n        attn_map_st2ed = map_st2ed.numpy()\n        attn_map_ed2st = map_ed2st.numpy()\n        figs, axes = plt.subplots(4, 4)\n        for i in range(4):\n            for j in range(4):\n                axes[i, j].imshow(attn_map_ed2st[i, j])\n        plt.show()\n        map_ct = attn(junc)\n        self.assertTupleEqual(map_ct.size(), (15, 15))\n\n\nclass TestAdjacencyMatrixInference(TestCase):\n    def test_adjacency_matrix_inference_forward(self):\n        attn = graph.DirectionalAttention(15)\n        junc_st = th.rand(30, 2)\n        junc_ed = th.rand(60, 2)\n        attn_st2ed, attn_ed2st = attn(junc_st, junc_ed)\n        attn_center = attn()\n        adj_with_center = graph.AdjacencyMatrixInference(256, junc_align_size=15, align_center=True)\n        adj_wo_center = graph.AdjacencyMatrixInference(256, junc_align_size=15, align_center=False)\n        feat_start = th.rand(30, 128, 15, 15)\n        feat_end = th.rand(60, 128, 15, 15)\n        feat_center = th.rand(30, 60, 128, 15, 15)\n        adjacency_matrix_with_center = adj_with_center(feat_start, feat_end, attn_st2ed, attn_ed2st, feat_center, attn_center)\n        adjacency_matrix_wo_center = adj_wo_center(feat_start, feat_end, attn_st2ed, attn_ed2st)\n        self.assertTrue(adjacency_matrix_with_center.size() == adjacency_matrix_wo_center.size() == (30, 60))\n"
  },
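  {
    "path": "models/test_lsd_modules.py",
    "content": "\"\"\"Illustrative sanity checks; not part of the original PPGNet test suite.\n\nThey exercise two small building blocks defined in models/lsd_test.py:\nBinaryFocalLoss, which reduces to a scaled binary cross-entropy when\ngamma=0 and alpha=0.5, and LSDDataLayer, which performs per-channel\nmean/std normalization. The toy tensors below are made up for these\nchecks only.\n\"\"\"\nfrom unittest import TestCase\nimport torch as th\n\nfrom models.lsd_test import BinaryFocalLoss, LSDDataLayer\n\n\nclass TestBinaryFocalLoss(TestCase):\n    def test_reduces_to_scaled_bce_when_gamma_is_zero(self):\n        th.manual_seed(0)\n        crit = BinaryFocalLoss(gamma=0., alpha=0.5, size_average=True)\n        # keep predictions away from the clamp boundaries used inside the loss\n        pred = th.rand(4, 1, 16, 16) * 0.98 + 0.01\n        target = (th.rand(4, 1, 16, 16) > 0.5).float()\n        bce = -(target * th.log(pred) + (1 - target) * th.log(1 - pred)).mean()\n        self.assertTrue(th.allclose(crit(pred, target), 0.5 * bce, atol=1e-6))\n\n    def test_uniform_weight_matches_unweighted_loss(self):\n        th.manual_seed(1)\n        crit = BinaryFocalLoss()\n        pred = th.rand(2, 8, 8) * 0.98 + 0.01\n        target = (th.rand(2, 8, 8) > 0.5).float()\n        weight = th.ones_like(pred)\n        self.assertTrue(th.allclose(crit(pred, target), crit(pred, target, weight=weight)))\n\n\nclass TestLSDDataLayer(TestCase):\n    def test_per_channel_normalization(self):\n        layer = LSDDataLayer(mean=[1., 2., 3.], std=[2., 4., 8.])\n        img = th.ones(1, 3, 4, 4) * 5.\n        out = layer(img)\n        self.assertTrue(th.allclose(out[0, 0], th.full((4, 4), 2.0)))\n        self.assertTrue(th.allclose(out[0, 1], th.full((4, 4), 0.75)))\n        self.assertTrue(th.allclose(out[0, 2], th.full((4, 4), 0.25)))\n"
  },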
  {
    "path": "test.py",
    "content": "# System libs\nimport os\nimport time\n\n# Numerical libs\nimport torch\nimport torch.nn as nn\nfrom torch import optim\nfrom torch.utils.data import DataLoader\nimport numpy as np\n\n# Our libs\nfrom data.sist_line import SISTLine\nimport data.transforms as tf\nfrom models.lsd_test import LSDTestModule\nfrom utils import AverageMeter, graph2line, draw_lines, draw_jucntions\n\n# tensorboard\nfrom tensorboardX import SummaryWriter\nimport torchvision.utils as vutils\n\nimport fire\nimport cv2\n\n\nclass LSD(object):\n    def __init__(\n            self,\n            # exp params\n            exp_name=\"u50_block\",\n            # arch params\n            backbone=\"resnet50\",\n            backbone_kwargs={},\n            dim_embedding=256,\n            feature_spatial_scale=0.25,\n            max_junctions=512,\n            junction_pooling_threshold=0.2,\n            junc_pooling_size=15,\n            block_inference_size=64,\n            # data params\n            img_size=416,\n            gpus=[0,],\n            resume_epoch=\"latest\",\n            # vis params\n            vis_junc_th=0.3,\n            vis_line_th=0.3\n    ):\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join(str(c) for c in gpus)\n\n        self.is_cuda = bool(gpus)\n\n        self.model = LSDTestModule(\n            backbone=backbone,\n            dim_embedding=dim_embedding,\n            backbone_kwargs=backbone_kwargs,\n            junction_pooling_threshold=junction_pooling_threshold,\n            max_junctions=max_junctions,\n            feature_spatial_scale=feature_spatial_scale,\n            junction_pooling_size=junc_pooling_size,\n        )\n\n        self.exp_name = exp_name\n        os.makedirs(os.path.join(\"log\", exp_name), exist_ok=True)\n        os.makedirs(os.path.join(\"ckpt\", exp_name), exist_ok=True)\n        self.writer = SummaryWriter(log_dir=os.path.join(\"log\", exp_name))\n\n        # checkpoints\n        self.states = dict(\n            last_epoch=-1,\n            elapsed_time=0,\n            state_dict=None\n        )\n\n        if resume_epoch and os.path.isfile(os.path.join(\"ckpt\", exp_name, f\"train_states_{resume_epoch}.pth\")):\n            states = torch.load(\n                os.path.join(\"ckpt\", exp_name, f\"train_states_{resume_epoch}.pth\"))\n            print(f\"resume traning from epoch {states['last_epoch']}\")\n            self.model.load_state_dict(states[\"state_dict\"])\n            self.states.update(states)\n\n        self.vis_junc_th = vis_junc_th\n        self.vis_line_th = vis_line_th\n        self.block_size = block_inference_size\n        self.max_junctions = max_junctions\n        self.img_size = img_size\n\n    def end(self):\n        self.writer.close()\n        return \"command queue finished.\"\n\n    def test(self, path_to_image):\n        # main loop\n        torch.set_grad_enabled(False)\n        print(f\"test for image: {path_to_image}\", flush=True)\n\n        if self.is_cuda:\n            model = self.model.cuda().eval()\n        else:\n            model = self.model.eval()\n\n        img = cv2.imread(path_to_image)\n        img = cv2.resize(img, (self.img_size, self.img_size))\n        img_reverse = img[..., [2, 1, 0]]\n        img = torch.from_numpy(img_reverse).float().permute(2, 0, 1).unsqueeze(0)\n\n        if self.is_cuda:\n            img = img.cuda()\n\n        # measure elapsed time\n        junc_pred, heatmap_pred, adj_mtx_pred = model(img)\n\n        # visualize eval\n        img = img.cpu().numpy()\n        
junctions_pred = junc_pred.cpu().numpy()\n        adj_mtx = adj_mtx_pred.cpu().numpy()\n\n        img_with_junc = draw_jucntions(img, junctions_pred)\n        img_with_junc = img_with_junc[0].numpy()[None]\n        img_with_junc = img_with_junc[:, ::-1, :, :]\n        lines_pred, score_pred = graph2line(junctions_pred, adj_mtx)\n        vis_line_pred = draw_lines(img_with_junc, lines_pred, score_pred)[0]\n        vis_line_pred = vis_line_pred.permute(1, 2, 0).numpy()\n\n        cv2.imshow(\"result\", vis_line_pred)\n\n\nif __name__ == \"__main__\":\n    fire.Fire(LSD)\n    # trainer = LSDTrainer().train(lr=1.)\n"
  },
  {
    "path": "test.sh",
    "content": "python test.py \\\n--exp-name line_weighted_wo_focal_junc --backbone resnet50 \\\n--backbone-kwargs '{\"encoder_weights\": \"ckpt/backbone/encoder_epoch_20.pth\", \"decoder_weights\": \"ckpt/backbone/decoder_epoch_20.pth\"}' \\\n--dim-embedding 256 --junction-pooling-threshold 0.2 \\\n--junc-pooling-size 64 --block-inference-size 128 \\\n--gpus 0, --resume-epoch latest \\\n--vis-junc-th 0.25 --vis-line-th 0.25 \\\n    - test $1\n"
  },
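  {
    "path": "tools/graph2line_demo.py",
    "content": "\"\"\"Small usage sketch; not part of the original PPGNet pipeline.\n\nIt illustrates the data layout expected by utils.graph2line: a zero-padded\n(batch, max_junctions, 2) junction array together with a\n(batch, max_junctions, max_junctions) adjacency matrix, from which line\nsegments are recovered by thresholding. The toy coordinates and scores\nbelow are made up for demonstration only; run from the repository root.\n\"\"\"\nimport numpy as np\n\nfrom utils import graph2line\n\n\ndef main():\n    max_junctions = 8\n    # three real junctions; the remaining slots stay zero-padded\n    junctions = np.zeros((1, max_junctions, 2), dtype=np.float32)\n    junctions[0, 0] = (10., 10.)\n    junctions[0, 1] = (100., 10.)\n    junctions[0, 2] = (100., 80.)\n\n    # adjacency scores: connect 0-1 and 1-2, leave 0-2 below the threshold\n    adj = np.zeros((1, max_junctions, max_junctions), dtype=np.float32)\n    adj[0, 0, 1] = adj[0, 1, 0] = 0.9\n    adj[0, 1, 2] = adj[0, 2, 1] = 0.7\n    adj[0, 0, 2] = adj[0, 2, 0] = 0.2\n\n    lines, scores = graph2line(junctions, adj, threshold=0.5)\n    # graph2line only scans the upper triangle, so each segment is reported once\n    for (x1, y1, x2, y2), s in zip(lines[0], scores[0]):\n        print(f\"segment ({x1:.0f}, {y1:.0f}) -> ({x2:.0f}, {y2:.0f}) with score {s:.2f}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },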
  {
    "path": "tools/rebuild_yorkurban.py",
    "content": "from data.line_graph import LineGraph\nimport os\nfrom scipy import io\nimport numpy as np\nfrom shutil import copyfile\nfrom tqdm import trange\n\ndata_root = \"/home/ziheng/YorkUrbanDB\"\nout_root = \"/home/ziheng/YorkUrbanDB_new/test\"\n\n\nlist_file = io.loadmat(os.path.join(data_root, \"Manhattan_Image_DB_Names.mat\"))\nname_list = [e[0][0].strip(\"\\\\\") for e in list_file[\"Manhattan_Image_DB_Names\"]]\ntest_set = io.loadmat(os.path.join(data_root, \"ECCV_TrainingAndTestImageNumbers.mat\"))\ntest_set_id = test_set[\"testSetIndex\"].flatten().tolist()\nimgs = [os.path.join(data_root, name_list[i - 1], name_list[i - 1] + \".jpg\") for i in test_set_id]\nlabels = [io.loadmat(os.path.join(data_root, name_list[i - 1], name_list[i - 1] + \"LinesAndVP.mat\")) for i in test_set_id]\nlines = [np.float32(lab[\"lines\"]).reshape((-1, 4)) for lab in labels]\nmaps = [np.uint8(lab[\"finalImg\"]) for lab in labels]\n\nos.makedirs(out_root, exist_ok=True)\nmax_juncs = 512\nfor i in trange(len(imgs)):\n    img, line = imgs[i], lines[i]\n    fname = os.path.basename(img)[:-4]\n    hm = maps[i]\n    lg = LineGraph(eps_junction=1., eps_line_deg=np.pi / 30, verbose=False)\n    for x1, y1, x2, y2 in line:\n        lg.add_junction((x1, y1))\n        lg.add_junction((x2, y2))\n    lg.freeze_junction()\n    for x1, y1, x2, y2 in line:\n        lg.add_line_seg((x1, y1), (x2, y2))\n    lg.freeze_line_seg()\n    max_juncs = max(lg.num_junctions, max_juncs)\n    lg.save(os.path.join(out_root, fname + \".lg\"))\n    copyfile(img, os.path.join(out_root, fname + \".jpg\"))\n    # img = cv2.imread(img)\n    # print(fname, flush=True)\n    # cv2.imshow(\"line_\", lg.line_map(img.shape[:2]))\n    # cv2.imshow(\"line\", hm)\n    # cv2.waitKey()\n\nprint(max_juncs)\n"
  },
  {
    "path": "train.sh",
    "content": "python main.py \\\n--exp-name line_weighted_wo_focal_junc --backbone resnet50 \\\n--backbone-kwargs '{\"encoder_weights\": \"ckpt/backbone/encoder_epoch_20.pth\", \"decoder_weights\": \"ckpt/backbone/decoder_epoch_20.pth\"}' \\\n--dim-embedding 256 --junction-pooling-threshold 0.2 \\\n--junc-pooling-size 64 --attention-sigma 1.5 --block-inference-size 128 \\\n--data-root /data/path --junc-sigma 3 \\\n--batch-size 16 --gpus 0,1,2,3 --num-workers 10 --resume-epoch latest \\\n--is-train-junc True --is-train-adj True \\\n--vis-junc-th 0.1 --vis-line-th 0.1 \\\n    - train --end-epoch 9 --solver SGD --lr 0.2 --weight-decay 5e-4 --lambda-heatmap 1. --lambda-adj 5. \\\n    - train --end-epoch 15 --solver SGD --lr 0.02 --weight-decay 5e-4 --lambda-heatmap 1. --lambda-adj 10. \\\n    - train --end-epoch 30 --solver SGD --lr 0.002 --weight-decay 5e-4 --lambda-heatmap 1. --lambda-adj 10. \\\n    - end\n"
  },
  {
    "path": "utils.py",
    "content": "import numpy as np\nimport re\nimport functools\nimport torch as th\nimport cv2\nfrom numba import jit\n\n\ndef graph2line(junctions, adj_mtx, threshold=0.5):\n    assert len(junctions) == len(adj_mtx)\n    # assert np.allclose(adj_mtx, adj_mtx.transpose((0, 2, 1)), rtol=1e-2, atol=1e-2), f\"{adj_mtx}\"\n    bs = len(junctions)\n    lines = []\n    scores = []\n    for b in range(bs):\n        junc = junctions[b]\n        mtx = adj_mtx[b]\n        num_junc = np.sum(junc.sum(axis=1) > 0)\n        line = []\n        score = []\n        for i in range(num_junc):\n            for j in range(i, num_junc):\n                if mtx[i, j] > threshold:\n                    line.append(np.hstack((junc[i], junc[j])))\n                    score.append(mtx[i, j])\n        scores.append(np.array(score))\n        lines.append(np.array(line))\n\n    return lines, scores\n\n\ndef draw_lines(imgs, lines, scores=None, width=2):\n    assert len(imgs) == len(lines)\n    imgs = np.uint8(imgs)\n    bs = len(imgs)\n    if scores is not None:\n        assert len(scores) == bs\n    res = []\n    for b in range(bs):\n        img = imgs[b].transpose((1, 2, 0))\n        line = lines[b]\n        if scores is None:\n            score = np.zeros(len(line))\n        else:\n            score = scores[b]\n        img = img.copy()\n        for (x1, y1, x2, y2), c in zip(line, score):\n            pt1, pt2 = (x1, y1), (x2, y2)\n            c = tuple(cv2.applyColorMap(np.array(c * 255, dtype=np.uint8), cv2.COLORMAP_JET).flatten().tolist())\n            img = cv2.line(img, pt1, pt2, c, width)\n        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n        res.append(th.from_numpy(img.transpose((2, 0, 1))))\n\n    return res\n\n\ndef draw_jucntions(hms, junctions):\n    assert len(hms) == len(junctions)\n    if hms.ndim == 3:\n        imgs = np.uint8(hms * 255)\n    else:\n        imgs = np.uint8(hms)\n    bs = len(imgs)\n    res = []\n    for b in range(bs):\n        if hms.ndim == 3:\n            img = cv2.cvtColor(imgs[b], cv2.COLOR_GRAY2BGR)\n        else:\n            img = np.array(imgs[b].transpose((1, 2, 0)))\n        junc = junctions[b]\n        junc = junc[junc.sum(axis=1) > 0.1]\n        if hms.ndim == 3:\n            score = hms[b][np.int32(junc[:, 1]), np.int32(junc[:, 0])]\n        else:\n            score = [1.] 
* len(junc)\n        img = img.copy()\n        for (x, y), c in zip(junc, score):\n            c = tuple(cv2.applyColorMap(np.array(c * 255, dtype=np.uint8), cv2.COLORMAP_JET).flatten().tolist())\n            cv2.circle(img, (int(x), int(y)), 5, c, thickness=2)\n        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n        res.append(th.from_numpy(img.transpose((2, 0, 1))))\n\n    return res\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.initialized = False\n        self.val = None\n        self.avg = None\n        self.sum = None\n        self.count = None\n\n    def initialize(self, val, weight):\n        self.val = val\n        self.avg = val\n        self.sum = val * weight\n        self.count = weight\n        self.initialized = True\n\n    def update(self, val, weight=1):\n        if not self.initialized:\n            self.initialize(val, weight)\n        else:\n            self.add(val, weight)\n\n    def add(self, val, weight):\n        self.val = val\n        self.sum += val * weight\n        self.count += weight\n        self.avg = self.sum / self.count\n\n    def value(self):\n        return self.val\n\n    def average(self):\n        return self.avg\n\n\ndef unique(ar, return_index=False, return_inverse=False, return_counts=False):\n    ar = np.asanyarray(ar).flatten()\n\n    optional_indices = return_index or return_inverse\n    optional_returns = optional_indices or return_counts\n\n    if ar.size == 0:\n        if not optional_returns:\n            ret = ar\n        else:\n            ret = (ar,)\n            if return_index:\n                ret += (np.empty(0, np.intp),)\n            if return_inverse:\n                ret += (np.empty(0, np.intp),)\n            if return_counts:\n                ret += (np.empty(0, np.intp),)\n        return ret\n    if optional_indices:\n        perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')\n        aux = ar[perm]\n    else:\n        ar.sort()\n        aux = ar\n    flag = np.concatenate(([True], aux[1:] != aux[:-1]))\n\n    if not optional_returns:\n        ret = aux[flag]\n    else:\n        ret = (aux[flag],)\n        if return_index:\n            ret += (perm[flag],)\n        if return_inverse:\n            iflag = np.cumsum(flag) - 1\n            inv_idx = np.empty(ar.shape, dtype=np.intp)\n            inv_idx[perm] = iflag\n            ret += (inv_idx,)\n        if return_counts:\n            idx = np.concatenate(np.nonzero(flag) + ([ar.size],))\n            ret += (np.diff(idx),)\n    return ret\n\n\ndef colorEncode(labelmap, colors, mode='BGR'):\n    labelmap = labelmap.astype('int')\n    labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),\n                            dtype=np.uint8)\n    for label in unique(labelmap):\n        if label < 0:\n            continue\n        labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \\\n                        np.tile(colors[label],\n                                (labelmap.shape[0], labelmap.shape[1], 1))\n\n    if mode == 'BGR':\n        return labelmap_rgb[:, :, ::-1]\n    else:\n        return labelmap_rgb\n\n\ndef accuracy(preds, label):\n    valid = (label >= 0)\n    acc_sum = (valid * (preds == label)).sum()\n    valid_sum = valid.sum()\n    acc = float(acc_sum) / (valid_sum + 1e-10)\n    return acc, valid_sum\n\n\ndef intersectionAndUnion(imPred, imLab, numClass):\n    imPred = np.asarray(imPred).copy()\n    imLab = np.asarray(imLab).copy()\n\n    imPred += 
1\n    imLab += 1\n    # Remove classes from unlabeled pixels in gt image.\n    # We should not penalize detections in unlabeled portions of the image.\n    imPred = imPred * (imLab > 0)\n\n    # Compute area intersection:\n    intersection = imPred * (imPred == imLab)\n    (area_intersection, _) = np.histogram(\n        intersection, bins=numClass, range=(1, numClass))\n\n    # Compute area union:\n    (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))\n    (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))\n    area_union = area_pred + area_lab - area_intersection\n\n    return (area_intersection, area_union)\n\n\nclass NotSupportedCliException(Exception):\n    pass\n\n\ndef process_range(xpu, inp):\n    start, end = map(int, inp)\n    if start > end:\n        end, start = start, end\n    return map(lambda x: '{}{}'.format(xpu, x), range(start, end + 1))\n\n\nREGEX = [\n    (re.compile(r'^gpu(\\d+)$'), lambda x: ['gpu%s' % x[0]]),\n    (re.compile(r'^(\\d+)$'), lambda x: ['gpu%s' % x[0]]),\n    (re.compile(r'^gpu(\\d+)-(?:gpu)?(\\d+)$'),\n     functools.partial(process_range, 'gpu')),\n    (re.compile(r'^(\\d+)-(\\d+)$'),\n     functools.partial(process_range, 'gpu')),\n]\n\n\ndef parse_devices(input_devices):\n    \"\"\"Parse user's devices input str to standard format.\n    e.g. [gpu0, gpu1, ...]\n\n    \"\"\"\n    ret = []\n    for d in input_devices.split(','):\n        for regex, func in REGEX:\n            m = regex.match(d.lower().strip())\n            if m:\n                tmp = func(m.groups())\n                # prevent duplicate\n                for x in tmp:\n                    if x not in ret:\n                        ret.append(x)\n                break\n        else:\n            raise NotSupportedCliException(\n                'Can not recognize device: \"%s\"' % d)\n    return ret\n"
  }
]