[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019 Jie Tang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# [Learning Guided Convolutional Network for Depth Completion](https://arxiv.org/pdf/1908.01238).\n\n\n## Introduction\n\nThis is the pytorch implementation of our paper.\n\n## Dependency\n```\nPyTorch 1.4\nPyTorch-Encoding v1.4.0\n```\n\n## Setup\nCompile the C++ and CUDA code:\n```\ncd exts\npython setup.py install\n```\n\n## Dataset\nPlease download KITTI [depth completion](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion)\ndataset.\nThe structure of data directory:\n```\n└── datas\n    └── kitti\n        ├── data_depth_annotated\n        │   ├── train\n        │   └── val\n        ├── data_depth_velodyne\n        │   ├── train\n        │   └── val\n        ├── raw\n        │   ├── 2011_09_26\n        │   ├── 2011_09_28\n        │   ├── 2011_09_29\n        │   ├── 2011_09_30\n        │   └── 2011_10_03\n        ├── test_depth_completion_anonymous\n        │   ├── image\n        │   ├── intrinsics\n        │   └── velodyne_raw\n        └── val_selection_cropped\n            ├── groundtruth_depth\n            ├── image\n            ├── intrinsics\n            └── velodyne_raw\n```\n\n## Configs\nThe config of different settings:\n- GN.yaml\n- GNS.yaml\n\n*Compared to **GN**, **GNS** uses fewer parameters to generate the guided kernels, \nbut achieves slightly better results.*\n\n\n## Trained Models\nYou can directly download the trained model and put it in *checkpoints*:\n- [GN](https://drive.google.com/file/d/1-sa2pnMMjSv2dV2bRwuyLxPr1onmVykj/view?usp=sharing)\n- [GNS](https://drive.google.com/file/d/16tVrZQEDBucgjZmTjZl4iFkklkjfeDcs/view?usp=sharing)\n\n## Train \nYou can also train by yourself:\n```\npython train.py\n```\n*Pay attention to the settings in the config file (e.g. gpu id).*\n\n## Test\nWith the trained model, \nyou can test and save depth images.\n```\npython test.py\n```\n\n## Citation\nIf you find this work useful in your research, please consider citing:\n```\n@article{guidenet,\n  title={Learning guided convolutional network for depth completion},\n  author={Tang, Jie and Tian, Fei-Peng and Feng, Wei and Li, Jian and Tan, Ping},\n  journal={IEEE Transactions on Image Processing},\n  volume={30},\n  pages={1116--1129},\n  year={2020},\n  publisher={IEEE}\n}\n```"
  },
  {
    "path": "augs.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# @Filename:    augs.py\n# @Project:     GuideNet\n# @Author:      jie\n# @Time:        2021/3/14 8:27 PM\n\nimport numpy as np\n\n__all__ = [\n    'Compose',\n    'Norm',\n    'Jitter',\n    'Flip',\n]\n\n\nclass Compose(object):\n    \"\"\"\n    Sequential operations on input images, (i.e. rgb, lidar and depth).\n    \"\"\"\n\n    def __init__(self, transforms):\n        self.transforms = transforms\n\n    def __call__(self, rgb, lidar, depth):\n        for t in self.transforms:\n            rgb, lidar, depth = t(rgb, lidar, depth)\n        return rgb, lidar, depth\n\n\nclass Norm(object):\n    \"\"\"\n    normalize rgb image.\n    \"\"\"\n\n    def __init__(self, mean, std):\n        self.mean = np.array(mean)\n        self.std = np.array(std)\n\n    def __call__(self, rgb, lidar, depth):\n        rgb = (rgb - self.mean) / self.std\n        return rgb, lidar, depth\n\n\nclass Jitter(object):\n    \"\"\"\n    borrow from https://github.com/kujason/avod/blob/master/avod/datasets/kitti/kitti_aug.py\n    \"\"\"\n\n    def __call__(self, rgb, lidar, depth):\n        pca = compute_pca(rgb)\n        rgb = add_pca_jitter(rgb, pca)\n        return rgb, lidar, depth\n\n\nclass Flip(object):\n    \"\"\"\n    random horizontal flip of images.\n    \"\"\"\n\n    def __call__(self, rgb, lidar, depth):\n        flip = bool(np.random.randint(2))\n        if flip:\n            rgb = rgb[:, ::-1, :]\n            lidar = lidar[:, ::-1, :]\n            depth = depth[:, ::-1, :]\n        return rgb, lidar, depth\n\n\ndef compute_pca(image):\n    \"\"\"\n    calculate PCA of image\n    \"\"\"\n\n    reshaped_data = image.reshape(-1, 3)\n    reshaped_data = (reshaped_data / 255.0).astype(np.float32)\n    covariance = np.cov(reshaped_data.T)\n    e_vals, e_vecs = np.linalg.eigh(covariance)\n    pca = np.sqrt(e_vals) * e_vecs\n    return pca\n\n\ndef add_pca_jitter(img_data, pca):\n    \"\"\"\n    add a multiple of principle components with Gaussian noise\n    \"\"\"\n    new_img_data = np.copy(img_data).astype(np.float32) / 255.0\n    magnitude = np.random.randn(3) * 0.1\n    noise = (pca * magnitude).sum(axis=1)\n\n    new_img_data = new_img_data + noise\n    np.clip(new_img_data, 0.0, 1.0, out=new_img_data)\n    new_img_data = (new_img_data * 255).astype(np.uint8)\n\n    return new_img_data\n"
  },
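  {
    "path": "augs_example.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n\"\"\"\nMinimal usage sketch for the transforms in augs.py (illustrative only, not part\nof the original release). It assumes H x W x C numpy arrays, as produced by\ndatasets.py, and reuses the KITTI mean/std values from the configs.\n\"\"\"\n\nimport numpy as np\nfrom augs import Compose, Jitter, Flip, Norm\n\n# Same composition and statistics as train_aug_configs in configs/GN.yaml.\ntransform = Compose([\n    Jitter(),\n    Flip(),\n    Norm(mean=[90.995, 96.2278, 94.3213], std=[79.2382, 80.5267, 82.1483]),\n])\n\n# Dummy inputs with the crop size used by the dataset (height 256, width 1216).\nrgb = np.random.randint(0, 256, (256, 1216, 3), dtype=np.uint8)\nlidar = np.random.rand(256, 1216, 1).astype(np.float32)\ndepth = np.random.rand(256, 1216, 1).astype(np.float32)\n\nrgb, lidar, depth = transform(rgb, lidar, depth)\nprint(rgb.shape, lidar.shape, depth.shape)\n"
  },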
  {
    "path": "checkpoints/.gitignore",
    "content": "# Ignore everything in this directory\n*\n# Except this file\n!.gitignore"
  },
  {
    "path": "configs/GN.yaml",
    "content": "batch_size: 8\ndata_config:\n  kitti:\n    path: datas/kitti\ngpu_ids:\n- 4\n- 5\nloss: MSE\nlr_config:\n  MultiStepLR:\n    gamma: 0.5\n    last_epoch: -1\n    milestones:\n    - 5\n    - 10\n    - 15\nmanual_seed: 0\nmetric: RMSE\nmodel: GN\nname: GN\nnepoch: 20\nnum_workers: 4\noptim_config:\n  AdamW:\n    lr: 0.001\n    weight_decay: 0.05\nresume_seed: 6288\nstart_epoch: 0\ntest_aug_configs:\n- Norm:\n    mean:\n    - 90.995\n    - 96.2278\n    - 94.3213\n    std:\n    - 79.2382\n    - 80.5267\n    - 82.1483\ntest_epoch: 15\ntest_iters: 500\ntrain_aug_configs:\n- Jitter\n- Flip\n- Norm:\n    mean:\n    - 90.995\n    - 96.2278\n    - 94.3213\n    std:\n    - 79.2382\n    - 80.5267\n    - 82.1483\ntta: true\nvis: true\nvis_iters: 100\n"
  },
  {
    "path": "configs/GNS.yaml",
    "content": "batch_size: 8\ndata_config:\n  kitti:\n    path: datas/kitti\ngpu_ids:\n- 6\n- 7\nloss: MSE\nlr_config:\n  MultiStepLR:\n    gamma: 0.5\n    last_epoch: -1\n    milestones:\n    - 5\n    - 10\n    - 15\nmanual_seed: 0\nmetric: RMSE\nmodel: GNS\nname: GNS\nnepoch: 20\nnum_workers: 4\noptim_config:\n  AdamW:\n    lr: 0.001\n    weight_decay: 0.05\nresume_seed: 1600\nstart_epoch: 0\ntest_aug_configs:\n- Norm:\n    mean:\n    - 90.995\n    - 96.2278\n    - 94.3213\n    std:\n    - 79.2382\n    - 80.5267\n    - 82.1483\ntest_epoch: 15\ntest_iters: 500\ntrain_aug_configs:\n- Jitter\n- Flip\n- Norm:\n    mean:\n    - 90.995\n    - 96.2278\n    - 94.3213\n    std:\n    - 79.2382\n    - 80.5267\n    - 82.1483\ntta: true\nvis: true\nvis_iters: 100\n"
  },
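  {
    "path": "configs/README.md",
    "content": "# Config notes\n\nAn illustrative note (see utils.py, train.py and test.py for the authoritative code). Single-key mappings such as `optim_config` and `lr_config` are resolved by name: the key selects a class, and the nested mapping becomes its keyword arguments, roughly:\n\n```python\n# from utils.init_optim: 'AdamW', {'lr': 0.001, 'weight_decay': 0.05}\nkey, params = config.optim_config.popitem()\noptimizer = getattr(optimizers, key)(config_param(net), **params)\n```\n\nOther keys, as consumed by train.py / test.py:\n- `gpu_ids`: written to `CUDA_VISIBLE_DEVICES` before CUDA is initialized.\n- `manual_seed`: 0 draws a random seed; `resume_seed` selects which checkpoint directory to load.\n- `test_epoch` / `test_iters`: from epoch `test_epoch` on, evaluate every `test_iters` iterations.\n- `tta`: average the prediction with a horizontally flipped pass at test time.\n- `vis` / `vis_iters`: print the running loss every `vis_iters` batches.\n"
  },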
  {
    "path": "criteria.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# @Filename:    criteria.py\n# @Project:     GuideNet\n# @Author:      jie\n# @Time:        2021/3/14 7:51 PM\n\nimport torch\nimport torch.nn as nn\n\n__all__ = [\n    'RMSE',\n    'MSE',\n]\n\n\nclass RMSE(nn.Module):\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, outputs, target, *args):\n        val_pixels = (target > 1e-3).float().cuda()\n        err = (target * val_pixels - outputs * val_pixels) ** 2\n        loss = torch.sum(err.view(err.size(0), 1, -1), -1, keepdim=True)\n        cnt = torch.sum(val_pixels.view(val_pixels.size(0), 1, -1), -1, keepdim=True)\n        return torch.sqrt(loss / cnt)\n\n\nclass MSE(nn.Module):\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, outputs, target, *args):\n        val_pixels = (target > 1e-3).float().cuda()\n        loss = target * val_pixels - outputs * val_pixels\n        return loss ** 2\n"
  },
  {
    "path": "datas/.gitignore",
    "content": "# Ignore everything in this directory\n*\n# Except this file\n!.gitignore"
  },
  {
    "path": "datasets.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# @Filename:    datasets.py\n# @Project:     GuideNet\n# @Author:      jie\n# @Time:        2021/3/14 8:08 PM\n\nimport os\nimport numpy as np\nimport glob\nfrom PIL import Image\nimport torch.utils.data as data\n\n__all__ = [\n    'kitti',\n]\n\n\nclass kitti(data.Dataset):\n    \"\"\"\n    kitti depth completion dataset: http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion\n    \"\"\"\n\n    def __init__(self, path='../datas/kitti', mode='train', height=256, width=1216, return_idx=False, return_size=False,\n                 transform=None):\n        self.base_dir = path\n        self.height = height\n        self.width = width\n        self.mode = mode\n        self.return_idx = return_idx\n        self.return_size = return_size\n        self.transform = transform\n        if mode in ['train', 'val']:\n            self.depth_path = os.path.join(self.base_dir, 'data_depth_annotated', mode)\n            self.lidar_path = os.path.join(self.base_dir, 'data_depth_velodyne', mode)\n            self.depths = list(sorted(glob.iglob(self.depth_path + \"/**/*.png\", recursive=True)))\n            self.lidars = list(sorted(glob.iglob(self.lidar_path + \"/**/*.png\", recursive=True)))\n        elif mode == 'selval':\n            self.depth_path = os.path.join(self.base_dir, 'val_selection_cropped', 'groundtruth_depth')\n            self.lidar_path = os.path.join(self.base_dir, 'val_selection_cropped', 'velodyne_raw')\n            self.image_path = os.path.join(self.base_dir, 'val_selection_cropped', 'image')\n            self.depths = list(sorted(glob.iglob(self.depth_path + \"/*.png\", recursive=True)))\n            self.lidars = list(sorted(glob.iglob(self.lidar_path + \"/*.png\", recursive=True)))\n            self.images = list(sorted(glob.iglob(self.image_path + \"/*.png\", recursive=True)))\n        elif mode == 'test':\n            self.lidar_path = os.path.join(self.base_dir, 'test_depth_completion_anonymous', 'velodyne_raw')\n            self.image_path = os.path.join(self.base_dir, 'test_depth_completion_anonymous', 'image')\n            self.lidars = list(sorted(glob.iglob(self.lidar_path + \"/*.png\", recursive=True)))\n            self.images = list(sorted(glob.iglob(self.image_path + \"/*.png\", recursive=True)))\n            self.depths = self.lidars\n        else:\n            raise ValueError(\"Unknown mode: {}\".format(mode))\n        assert (len(self.depths) == len(self.lidars))\n        self.names = [os.path.split(path)[-1] for path in self.depths]\n\n    def __len__(self):\n        return len(self.depths)\n\n    def __getitem__(self, index):\n\n        depth = self.pull_DEPTH(self.depths[index])\n        depth = np.expand_dims(depth, axis=2)\n        lidar = self.pull_DEPTH(self.lidars[index])\n        lidar = np.expand_dims(lidar, axis=2)\n        file_names = self.depths[index].split('/')\n        if self.mode in ['train', 'val']:\n            rgb_path = os.path.join(*file_names[:-7], 'raw', file_names[-5].split('_drive')[0], file_names[-5],\n                                    file_names[-2], 'data', file_names[-1])\n        elif self.mode in ['selval', 'test']:\n            rgb_path = self.images[index]\n        else:\n            ValueError(\"Unknown mode: {}\".format(self.mode))\n        rgb = self.pull_RGB(rgb_path)\n        rgb = rgb.astype(np.float32)\n        lidar = lidar.astype(np.float32)\n        depth = depth.astype(np.float32)\n        shape = lidar.shape\n        if 
self.transform:\n            rgb, lidar, depth = self.transform(rgb, lidar, depth)\n        rgb = rgb.transpose(2, 0, 1).astype(np.float32)\n        lidar = lidar.transpose(2, 0, 1).astype(np.float32)\n        depth = depth.transpose(2, 0, 1).astype(np.float32)\n        lp = (rgb.shape[2] - self.width) // 2\n        rgb = rgb[:, -self.height:, lp:lp + self.width]\n        lidar = lidar[:, -self.height:, lp:lp + self.width]\n        depth = depth[:, -self.height:, lp:lp + self.width]\n        output = [rgb, lidar, depth]\n        if self.return_idx:\n            output.append(np.array([index], dtype=int))\n        if self.return_size:\n            output.append(np.array(shape[:2], dtype=int))\n        return output\n\n    def pull_RGB(self, path):\n        img = np.array(Image.open(path).convert('RGB'), dtype=np.uint8)\n        return img\n\n    def pull_DEPTH(self, path):\n        depth_png = np.array(Image.open(path), dtype=int)\n        assert (np.max(depth_png) > 255)\n        depth_image = (depth_png / 256.).astype(np.float32)\n        return depth_image\n"
  },
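  {
    "path": "datasets_example.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n\"\"\"\nIllustrative sketch (not part of the original release): load one raw KITTI\ntraining sample with the dataset class from datasets.py and inspect its shapes.\nAssumes the directory layout described in the README under datas/kitti.\n\"\"\"\n\nimport datasets\n\ntrainset = datasets.kitti(path='datas/kitti', mode='train')\nrgb, lidar, depth = trainset[0]\n\n# Each item is a C x 256 x 1216 float32 array (bottom crop, horizontally centered):\n# rgb has 3 channels; the sparse lidar and ground-truth depth have 1 channel each.\nprint(len(trainset))\nprint(rgb.shape, lidar.shape, depth.shape)\n"
  },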
  {
    "path": "exts/guideconv.cpp",
    "content": "//\n// Created by jie on 09/02/19.\n//\n\n#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <vector>\n\n\nvoid Conv2d_LF_Cuda(at::Tensor x, at::Tensor y, at::Tensor z, size_t N1, size_t N2, size_t Ci, size_t Co, size_t B,\n                    size_t K);\n\nvoid\nConv2d_LB_Cuda(at::Tensor x, at::Tensor y, at::Tensor gx, at::Tensor gy, at::Tensor gz, size_t N1, size_t N2, size_t Ci,\n                    size_t Co, size_t B, size_t K);\n\n\nat::Tensor Conv2dLocal_F(\n        at::Tensor a, // BCHW\n        at::Tensor b // BCKKHW\n) {\n    int N1, N2, Ci, Co, K, B;\n    B = a.size(0);\n    Ci = a.size(1);\n    N1 = a.size(2);\n    N2 = a.size(3);\n    Co = Ci;\n    K = sqrt(b.size(1) / Co);\n    auto c = at::zeros_like(a);\n    Conv2d_LF_Cuda(a, b, c, N1, N2, Ci, Co, B, K);\n    return c;\n}\n\n\nstd::tuple <at::Tensor, at::Tensor> Conv2dLocal_B(\n        at::Tensor a,\n        at::Tensor b,\n        at::Tensor gc\n) {\n    int N1, N2, Ci, Co, K, B;\n    B = a.size(0);\n    Ci = a.size(1);\n    N1 = a.size(2);\n    N2 = a.size(3);\n    Co = Ci;\n    K = sqrt(b.size(1) / Co);\n    auto ga = at::zeros_like(a);\n    auto gb = at::zeros_like(b);\n    Conv2d_LB_Cuda(a, b, ga, gb, gc, N1, N2, Ci, Co, B, K);\n    return std::make_tuple(ga, gb);\n}\n\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m\n) {\nm.def(\"Conv2dLocal_F\", &Conv2dLocal_F, \"Conv2dLocal Forward (CUDA)\");\nm.def(\"Conv2dLocal_B\", &Conv2dLocal_B, \"Conv2dLocal Backward (CUDA)\");\n}"
  },
  {
    "path": "exts/guideconv_kernel.cu",
    "content": "//\n// Created by jie on 09/02/19.\n//\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <ATen/ATen.h>\n\nnamespace {\n\n    template<typename scalar_t>\n    __global__ void\n    conv2d_kernel_lf(scalar_t *__restrict__ x, scalar_t *__restrict__ y, scalar_t *__restrict__ z, size_t N1,\n                        size_t N2, size_t Ci, size_t Co, size_t B,\n                        size_t K) {\n        int col_index = threadIdx.x + blockIdx.x * blockDim.x;\n        int row_index = threadIdx.y + blockIdx.y * blockDim.y;\n        int cha_index = threadIdx.z + blockIdx.z * blockDim.z;\n        if ((row_index < N1) && (col_index < N2) && (cha_index < Co)) {\n            for (int b = 0; b < B; b++) {\n                scalar_t result = 0;\n                for (int i = -int((K - 1) / 2.); i < (K + 1) / 2.; i++) {\n                    for (int j = -int((K - 1) / 2.); j < (K + 1) / 2.; j++) {\n\n                        if ((row_index + i < 0) || (row_index + i >= N1) || (col_index + j < 0) ||\n                            (col_index + j >= N2)) {\n                            continue;\n                        }\n\n                        result += x[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index + i) * N2 + col_index + j] *\n                                  y[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K +\n                                    (i + (K - 1) / 2) * K * N1 * N2 +\n                                    (j + (K - 1) / 2) * N1 * N2 + row_index * N2 + col_index];\n                    }\n                }\n                z[b * N1 * N2 * Co + cha_index * N1 * N2 + row_index * N2 + col_index] = result;\n            }\n        }\n    }\n\n\n    template<typename scalar_t>\n    __global__ void conv2d_kernel_lb(scalar_t *__restrict__ x, scalar_t *__restrict__ y, scalar_t *__restrict__ gx,\n                                         scalar_t *__restrict__ gy, scalar_t *__restrict__ gz, size_t N1, size_t N2,\n                                         size_t Ci, size_t Co, size_t B,\n                                         size_t K) {\n        int col_index = threadIdx.x + blockIdx.x * blockDim.x;\n        int row_index = threadIdx.y + blockIdx.y * blockDim.y;\n        int cha_index = threadIdx.z + blockIdx.z * blockDim.z;\n        if ((row_index < N1) && (col_index < N2) && (cha_index < Co)) {\n            for (int b = 0; b < B; b++) {\n                scalar_t result = 0;\n                for (int i = -int((K - 1) / 2.); i < (K + 1) / 2.; i++) {\n                    for (int j = -int((K - 1) / 2.); j < (K + 1) / 2.; j++) {\n\n                        if ((row_index - i < 0) || (row_index - i >= N1) || (col_index - j < 0) ||\n                            (col_index - j >= N2)) {\n                            continue;\n                        }\n                        result += gz[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index - i) * N2 + col_index - j\n                                  ] *\n                                  y[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K +\n                                    (i + (K - 1) / 2) * K * N1 * N2 +\n                                    (j + (K - 1) / 2) * N1 * N2 + (row_index - i) * N2 + col_index - j];\n                        gy[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K + (i + (K - 1) / 2) * K * N1 * N2 +\n                           (j + (K - 1) / 2) * N1 * N2 + (row_index - i) * N2 + col_index - j] =\n                                gz[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index - i) * N2 + 
col_index - j\n                                ] * x[b * N1 * N2 * Ci + cha_index * N1 * N2 + row_index * N2 + col_index];\n\n                    }\n                }\n                gx[b * N1 * N2 * Co + cha_index * N1 * N2 + row_index * N2 + col_index] = result;\n            }\n        }\n    }\n}\n\n\nvoid Conv2d_LF_Cuda(at::Tensor x, at::Tensor y, at::Tensor z, size_t N1, size_t N2, size_t Ci, size_t Co, size_t B,\n                    size_t K) {\n    dim3 blockSize(32, 32, 1);\n    dim3 gridSize((N2 + blockSize.x - 1) / blockSize.x, (N1 + blockSize.y - 1) / blockSize.y,\n                  (Co + blockSize.z - 1) / blockSize.z);\n    AT_DISPATCH_FLOATING_TYPES(x.type(), \"Conv2d_LF\", ([&] {\n        conv2d_kernel_lf<scalar_t> << < gridSize, blockSize >> > (\n                x.data<scalar_t>(), y.data<scalar_t>(), z.data<scalar_t>(),\n                        N1, N2, Ci, Co, B, K);\n    }));\n}\n\n\nvoid\nConv2d_LB_Cuda(at::Tensor x, at::Tensor y, at::Tensor gx, at::Tensor gy, at::Tensor gz, size_t N1, size_t N2, size_t Ci,\n               size_t Co, size_t B, size_t K) {\n    dim3 blockSize(32, 32, 1);\n    dim3 gridSize((N2 + blockSize.x - 1) / blockSize.x, (N1 + blockSize.y - 1) / blockSize.y,\n                  (Co + blockSize.z - 1) / blockSize.z);\n    AT_DISPATCH_FLOATING_TYPES(x.type(), \"Conv2d_LB\", ([&] {\n        conv2d_kernel_lb<scalar_t> << < gridSize, blockSize >> > (\n                x.data<scalar_t>(), y.data<scalar_t>(),\n                        gx.data<scalar_t>(), gy.data<scalar_t>(), gz.data<scalar_t>(),\n                        N1, N2, Ci, Co, B, K);\n    }));\n}\n"
  },
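  {
    "path": "exts/local_conv_check.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n\"\"\"\nIllustrative sketch (not part of the original release): check the compiled\nGuideConv extension against a pure-PyTorch reference. The weight tensor has\nshape (B, C*K*K, H, W), i.e. one K x K kernel per channel and per output pixel:\n\n    z[b, c, h, w] = sum_{i, j} x[b, c, h + i, w + j] * w[b, c, i + K // 2, j + K // 2, h, w]\n\nwith i, j running over the (odd) K x K window and out-of-bounds taps read as\nzero. Assumes the extension was built via `python setup.py install` and a CUDA\ndevice is available.\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nimport GuideConv\n\n\ndef local_conv_reference(x, w, K):\n    \"\"\"Spatially-varying (local) convolution via unfold, for odd K.\"\"\"\n    B, C, H, W = x.shape\n    patches = F.unfold(x, K, padding=K // 2).view(B, C, K * K, H, W)\n    return (patches * w.view(B, C, K * K, H, W)).sum(2)\n\n\nif __name__ == '__main__':\n    B, C, H, W, K = 2, 4, 8, 10, 3\n    x = torch.randn(B, C, H, W, device='cuda')\n    w = torch.randn(B, C * K * K, H, W, device='cuda')\n    out_cuda = GuideConv.Conv2dLocal_F(x, w)\n    out_ref = local_conv_reference(x, w, K)\n    print('max abs diff:', (out_cuda - out_ref).abs().max().item())\n"
  },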
  {
    "path": "exts/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nsetup(\n    name='GuideConv',\n    ext_modules=[\n        CUDAExtension('GuideConv', [\n            'guideconv.cpp',\n            'guideconv_kernel.cu',\n        ]),\n    ],\n    cmdclass={\n        'build_ext': BuildExtension\n    })"
  },
  {
    "path": "models.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# @Filename:    model.py\n# @Project:     GuideNet\n# @Author:      jie\n# @Time:        2021/3/14 7:50 PM\n\nimport torch\nimport torch.nn as nn\nfrom scipy.stats import truncnorm\nimport math\nfrom torch.autograd import Function\nimport encoding\nimport GuideConv\n\n__all__ = [\n    'GN',\n    'GNS',\n]\n\n\ndef Conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\ndef Conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=False, dilation=dilation)\n\n\nclass Conv2dLocal_F(Function):\n    @staticmethod\n    def forward(ctx, input, weight):\n        ctx.save_for_backward(input, weight)\n        output = GuideConv.Conv2dLocal_F(input, weight)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input, weight = ctx.saved_tensors\n        grad_output = grad_output.contiguous()\n        grad_input, grad_weight = GuideConv.Conv2dLocal_B(input, weight, grad_output)\n        return grad_input, grad_weight\n\n\nclass Conv2dLocal(nn.Module):\n    def __init__(self, ):\n        super().__init__()\n\n    def forward(self, input, weight):\n        output = Conv2dLocal_F.apply(input, weight)\n        return output\n\n\nclass Basic2d(nn.Module):\n    def __init__(self, in_channels, out_channels, norm_layer=None, kernel_size=3, padding=1):\n        super().__init__()\n        if norm_layer:\n            conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,\n                             stride=1, padding=padding, bias=False)\n        else:\n            conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,\n                             stride=1, padding=padding, bias=True)\n        self.conv = nn.Sequential(conv, )\n        if norm_layer:\n            self.conv.add_module('bn', norm_layer(out_channels))\n        self.conv.add_module('relu', nn.ReLU(inplace=True))\n\n    def forward(self, x):\n        out = self.conv(x)\n        return out\n\n\nclass Basic2dTrans(nn.Module):\n    def __init__(self, in_channels, out_channels, norm_layer=None):\n        super().__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self.conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3,\n                                       stride=2, padding=1, output_padding=1, bias=False)\n        self.bn = norm_layer(out_channels)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        out = self.conv(x)\n        out = self.bn(out)\n        out = self.relu(out)\n        return out\n\n\nclass Basic2dLocal(nn.Module):\n    def __init__(self, out_channels, norm_layer=None):\n        super().__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n\n        self.conv = Conv2dLocal()\n        self.bn = norm_layer(out_channels)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, input, weight):\n        out = self.conv(input, weight)\n        out = self.bn(out)\n        out = self.relu(out)\n        return out\n\n\nclass Guide(nn.Module):\n\n    def __init__(self, input_planes, weight_planes, norm_layer=None, 
weight_ks=3):\n        super().__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self.local = Basic2dLocal(input_planes, norm_layer)\n        self.pool = nn.AdaptiveAvgPool2d((1, 1))\n        self.conv11 = Basic2d(input_planes + weight_planes, input_planes, None)\n        self.conv12 = nn.Conv2d(input_planes, input_planes * 9, kernel_size=weight_ks, padding=weight_ks // 2)\n        self.conv21 = Basic2d(input_planes + weight_planes, input_planes, None)\n        self.conv22 = nn.Conv2d(input_planes, input_planes * input_planes, kernel_size=1, padding=0)\n        self.br = nn.Sequential(\n            norm_layer(num_features=input_planes),\n            nn.ReLU(inplace=True),\n        )\n        self.conv3 = Basic2d(input_planes, input_planes, norm_layer)\n\n    def forward(self, input, weight):\n        B, Ci, H, W = input.shape\n        weight = torch.cat([input, weight], 1)\n        weight11 = self.conv11(weight)\n        weight12 = self.conv12(weight11)\n        weight21 = self.conv21(weight)\n        weight21 = self.pool(weight21)\n        weight22 = self.conv22(weight21).view(B, -1, Ci)\n        out = self.local(input, weight12).view(B, Ci, -1)\n        out = torch.bmm(weight22, out).view(B, Ci, H, W)\n        out = self.br(out)\n        out = self.conv3(out)\n        return out\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n    __constants__ = ['downsample']\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None, act=True):\n        super().__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self.conv1 = Conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = Conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.downsample = downsample\n        self.stride = stride\n        self.act = act\n\n    def forward(self, x):\n        identity = x\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n        out = self.conv2(out)\n        out = self.bn2(out)\n        if self.downsample is not None:\n            identity = self.downsample(x)\n        out += identity\n        if self.act:\n            out = self.relu(out)\n        return out\n\n\nclass GuideNet(nn.Module):\n    \"\"\"\n    Not activate at the ref\n    Init change to trunctated norm\n    \"\"\"\n\n    def __init__(self, block=BasicBlock, bc=16, img_layers=[2, 2, 2, 2, 2],\n                 depth_layers=[2, 2, 2, 2, 2], norm_layer=nn.BatchNorm2d, guide=Guide, weight_ks=3):\n        super().__init__()\n        self._norm_layer = norm_layer\n\n        self.conv_img = Basic2d(3, bc * 2, norm_layer=norm_layer, kernel_size=5, padding=2)\n        in_channels = bc * 2\n        self.inplanes = in_channels\n        self.layer1_img = self._make_layer(block, in_channels * 2, img_layers[0], stride=2)\n\n        self.guide1 = guide(in_channels * 2, in_channels * 2, norm_layer, weight_ks)\n        self.inplanes = in_channels * 2 * block.expansion\n        self.layer2_img = self._make_layer(block, in_channels * 4, img_layers[1], stride=2)\n\n        self.guide2 = guide(in_channels * 4, in_channels * 4, norm_layer, weight_ks)\n        self.inplanes = in_channels * 4 * block.expansion\n        self.layer3_img = self._make_layer(block, in_channels * 8, img_layers[2], stride=2)\n\n        self.guide3 = guide(in_channels * 8, in_channels * 8, norm_layer, weight_ks)\n        self.inplanes = 
in_channels * 8 * block.expansion\n        self.layer4_img = self._make_layer(block, in_channels * 8, img_layers[3], stride=2)\n\n        self.guide4 = guide(in_channels * 8, in_channels * 8, norm_layer, weight_ks)\n        self.inplanes = in_channels * 8 * block.expansion\n        self.layer5_img = self._make_layer(block, in_channels * 8, img_layers[4], stride=2)\n\n        self.layer2d_img = Basic2dTrans(in_channels * 4, in_channels * 2, norm_layer)\n        self.layer3d_img = Basic2dTrans(in_channels * 8, in_channels * 4, norm_layer)\n        self.layer4d_img = Basic2dTrans(in_channels * 8, in_channels * 8, norm_layer)\n        self.layer5d_img = Basic2dTrans(in_channels * 8, in_channels * 8, norm_layer)\n\n        self.conv_lidar = Basic2d(1, bc * 2, norm_layer=None, kernel_size=5, padding=2)\n\n        self.inplanes = in_channels\n        self.layer1_lidar = self._make_layer(block, in_channels * 2, depth_layers[0], stride=2)\n        self.inplanes = in_channels * 2 * block.expansion\n        self.layer2_lidar = self._make_layer(block, in_channels * 4, depth_layers[1], stride=2)\n        self.inplanes = in_channels * 4 * block.expansion\n        self.layer3_lidar = self._make_layer(block, in_channels * 8, depth_layers[2], stride=2)\n        self.inplanes = in_channels * 8 * block.expansion\n        self.layer4_lidar = self._make_layer(block, in_channels * 8, depth_layers[3], stride=2)\n        self.inplanes = in_channels * 8 * block.expansion\n        self.layer5_lidar = self._make_layer(block, in_channels * 8, depth_layers[4], stride=2)\n\n        self.layer1d = Basic2dTrans(in_channels * 2, in_channels, norm_layer)\n        self.layer2d = Basic2dTrans(in_channels * 4, in_channels * 2, norm_layer)\n        self.layer3d = Basic2dTrans(in_channels * 8, in_channels * 4, norm_layer)\n        self.layer4d = Basic2dTrans(in_channels * 8, in_channels * 8, norm_layer)\n        self.layer5d = Basic2dTrans(in_channels * 8, in_channels * 8, norm_layer)\n\n        self.conv = nn.Conv2d(bc * 2, 1, kernel_size=3, stride=1, padding=1)\n        self.ref = block(bc * 2, bc * 2, norm_layer=norm_layer, act=False)\n\n        self._initialize_weights()\n\n    def forward(self, img, lidar):\n        c0_img = self.conv_img(img)\n        c1_img = self.layer1_img(c0_img)\n        c2_img = self.layer2_img(c1_img)\n        c3_img = self.layer3_img(c2_img)\n        c4_img = self.layer4_img(c3_img)\n        c5_img = self.layer5_img(c4_img)\n        dc5_img = self.layer5d_img(c5_img)\n        c4_mix = dc5_img + c4_img\n        dc4_img = self.layer4d_img(c4_mix)\n        c3_mix = dc4_img + c3_img\n        dc3_img = self.layer3d_img(c3_mix)\n        c2_mix = dc3_img + c2_img\n        dc2_img = self.layer2d_img(c2_mix)\n        c1_mix = dc2_img + c1_img\n\n        c0_lidar = self.conv_lidar(lidar)\n        c1_lidar = self.layer1_lidar(c0_lidar)\n        c1_lidar_dyn = self.guide1(c1_lidar, c1_mix)\n        c2_lidar = self.layer2_lidar(c1_lidar_dyn)\n        c2_lidar_dyn = self.guide2(c2_lidar, c2_mix)\n        c3_lidar = self.layer3_lidar(c2_lidar_dyn)\n        c3_lidar_dyn = self.guide3(c3_lidar, c3_mix)\n        c4_lidar = self.layer4_lidar(c3_lidar_dyn)\n        c4_lidar_dyn = self.guide4(c4_lidar, c4_mix)\n        c5_lidar = self.layer5_lidar(c4_lidar_dyn)\n        c5 = c5_img + c5_lidar\n        dc5 = self.layer5d(c5)\n        c4 = dc5 + c4_lidar_dyn\n        dc4 = self.layer4d(c4)\n        c3 = dc4 + c3_lidar_dyn\n        dc3 = self.layer3d(c3)\n        c2 = dc3 + c2_lidar_dyn\n        dc2 = self.layer2d(c2)\n  
      c1 = dc2 + c1_lidar_dyn\n        dc1 = self.layer1d(c1)\n        c0 = dc1 + c0_lidar\n        output = self.ref(c0)\n        output = self.conv(output)\n        return (output,)\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        norm_layer = self._norm_layer\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                Conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))\n\n        return nn.Sequential(*layers)\n\n    def _initialize_weights(self):\n        def truncated_normal_(num, mean=0., std=1.):\n            lower = -2 * std\n            upper = 2 * std\n            X = truncnorm((lower - mean) / std, (upper - mean) / std, loc=mean, scale=std)\n            samples = X.rvs(num)\n            output = torch.from_numpy(samples)\n            return output\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels\n                data = truncated_normal_(m.weight.nelement(), mean=0, std=math.sqrt(1.3 * 2. / n))\n                data = data.type_as(m.weight.data)\n                m.weight.data = data.view_as(m.weight.data)\n                if m.bias is not None:\n                    nn.init.zeros_(m.bias)\n\n\ndef GN():\n    return GuideNet(norm_layer=encoding.nn.SyncBatchNorm, guide=Guide)\n\n\ndef GNS():\n    return GuideNet(norm_layer=encoding.nn.SyncBatchNorm, guide=Guide, weight_ks=1)\n"
  },
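  {
    "path": "models_example.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n\"\"\"\nIllustrative forward-pass sketch (not part of the original release). It builds\nGuideNet with plain BatchNorm2d instead of the SyncBatchNorm used by GN()/GNS(),\nso it runs on a single GPU without the data-parallel wrappers. Importing\nmodels.py still requires PyTorch-Encoding and the compiled GuideConv extension.\n\"\"\"\n\nimport torch\nfrom models import GuideNet, Guide\n\nnet = GuideNet(norm_layer=torch.nn.BatchNorm2d, guide=Guide, weight_ks=3).cuda()\nnet.eval()\n\n# The KITTI crop used by datasets.py is 256 x 1216; both spatial dims must be\n# divisible by 32 (five stride-2 encoder stages).\nrgb = torch.randn(1, 3, 256, 1216, device='cuda')\nlidar = torch.randn(1, 1, 256, 1216, device='cuda')\n\nwith torch.no_grad():\n    depth, = net(rgb, lidar)\nprint(depth.shape)  # torch.Size([1, 1, 256, 1216])\n"
  },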
  {
    "path": "optimizers.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# @Filename:    optimizers.py\n# @Project:     GuideNet\n# @Author:      jie\n# @Time:        2021/3/15 4:59 PM\n\"\"\"\nThis is a fixup as pytorch 1.4.0 can not import AdamW directly from torch.optim\n\"\"\"\n\nfrom torch.optim import *\nfrom torch.optim.adamw import AdamW"
  },
  {
    "path": "test.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# @Filename:    test.py\n# @Project:     GuideNet\n# @Author:      jie\n# @Time:        2021/3/16 4:47 PM\n\nimport os\n\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\nimport torch\nimport yaml\nfrom easydict import EasyDict as edict\nimport datasets\nimport encoding\n\ndef test():\n    net.eval()\n    for batch_idx, (rgb, lidar, _, idx, ori_size) in enumerate(testloader):\n        with torch.no_grad():\n            if config.tta:\n                rgbf = torch.flip(rgb, [-1])\n                lidarf = torch.flip(lidar, [-1])\n                rgbs = torch.cat([rgb, rgbf], 0)\n                lidars = torch.cat([lidar, lidarf], 0)\n                rgbs, lidars = rgbs.cuda(), lidars.cuda()\n                depth_preds, = net(rgbs, lidars)\n                depth_pred, depth_predf = depth_preds.split(depth_preds.shape[0] // 2)\n                depth_predf = torch.flip(depth_predf, [-1])\n                depth_pred = (depth_pred + depth_predf) / 2.\n            else:\n                rgb, lidar = rgb.cuda(), lidar.cuda()\n                depth_pred, = net(rgb, lidar)\n            depth_pred[depth_pred < 0] = 0\n        depth_pred = depth_pred.cpu().squeeze(1).numpy()\n        idx = idx.cpu().squeeze(1).numpy()\n        ori_size = ori_size.cpu().numpy()\n        name = [testset.names[i] for i in idx]\n        save_result(config, depth_pred, name, ori_size)\n\n\nif __name__ == '__main__':\n    # config_name = 'GN.yaml'\n    config_name = 'GNS.yaml'\n    with open(os.path.join('configs', config_name), 'r') as file:\n        config_data = yaml.load(file, Loader=yaml.FullLoader)\n    config = edict(config_data)\n    from utils import *\n\n    transform = init_aug(config.test_aug_configs)\n    key, params = config.data_config.popitem()\n    dataset = getattr(datasets, key)\n    testset = dataset(**params, mode='test', transform=transform, return_idx=True, return_size=True)\n    testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size, num_workers=config.num_workers,\n                                             shuffle=False, pin_memory=True)\n    print('num_test = {}'.format(len(testset)))\n    net = init_net(config)\n    torch.cuda.empty_cache()\n    torch.backends.cudnn.benchmark = True\n    net.cuda()\n    net = encoding.parallel.DataParallelModel(net)\n    net = resume_state(config, net)\n    test()\n"
  },
  {
    "path": "train.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# @Filename:    train.py\n# @Project:     GuideNet\n# @Author:      jie\n# @Time:        2021/3/14 7:50 PM\n\nimport os\nimport torch\nimport yaml\nfrom easydict import EasyDict as edict\n\n\ndef train(epoch):\n    global iters\n    Avg = AverageMeter()\n    for batch_idx, (rgb, lidar, depth) in enumerate(trainloader):\n        if epoch >= config.test_epoch and iters % config.test_iters == 0:\n            test()\n        net.train()\n        rgb, lidar, depth = rgb.cuda(), lidar.cuda(), depth.cuda()\n        optimizer.zero_grad()\n        output = net(rgb, lidar)\n        loss = criterion(output, depth).mean()\n        loss.backward()\n        optimizer.step()\n        Avg.update(loss.item())\n        iters += 1\n        if config.vis and batch_idx % config.vis_iters == 0:\n            print('Epoch {} Idx {} Loss {:.4f}'.format(epoch, batch_idx, Avg.avg))\n\n\ndef test():\n    global best_metric\n    Avg = AverageMeter()\n    net.eval()\n    for batch_idx, (rgb, lidar, depth) in enumerate(testloader):\n        rgb, lidar, depth = rgb.cuda(), lidar.cuda(), depth.cuda()\n        with torch.no_grad():\n            output = net(rgb, lidar)\n            prec = metric(output, depth).mean()\n        Avg.update(prec.item(), rgb.size(0))\n    if Avg.avg < best_metric:\n        best_metric = Avg.avg\n        save_state(config, net)\n        print('Best Result: {:.4f}\\n'.format(best_metric))\n\n\nif __name__ == '__main__':\n    # config_name = 'GN.yaml'\n    config_name = 'GNS.yaml'\n    with open(os.path.join('configs', config_name), 'r') as file:\n        config_data = yaml.load(file, Loader=yaml.FullLoader)\n    config = edict(config_data)\n    print(config.name)\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = ','.join([str(gpu_id) for gpu_id in config.gpu_ids])\n    from utils import *\n\n    init_seed(config)\n    trainloader, testloader = init_dataset(config)\n    net = init_net(config)\n    criterion = init_loss(config)\n    metric = init_metric(config)\n    net, criterion, metric = init_cuda(net, criterion, metric)\n    optimizer = init_optim(config, net)\n    lr_scheduler = init_lr_scheduler(config, optimizer)\n    iters = 0\n    best_metric = 100\n    for epoch in range(config.start_epoch, config.nepoch):\n        train(epoch)\n        lr_scheduler.step()\n    print('Best Results: {:.4f}\\n'.format(best_metric))\n"
  },
  {
    "path": "utils.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# @Filename:    utils.py\n# @Project:     GuideNet\n# @Author:      jie\n# @Time:        2021/3/15 5:25 PM\n\nimport os\nimport torch\nimport random\nimport numpy as np\nimport augs\nimport models\nimport datasets\nimport optimizers\nimport encoding\nimport criteria\nfrom PIL import Image\n\n__all__ = [\n    'AverageMeter',\n    'init_seed',\n    'init_aug',\n    'init_dataset',\n    'init_cuda',\n    'init_net',\n    'init_loss',\n    'init_metric',\n    'init_optim',\n    'init_lr_scheduler',\n    'save_state',\n    'resume_state',\n    'save_result',\n]\n\n\nclass AverageMeter(object):\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef config_param(model):\n    param_groups = []\n    other_params = []\n    for name, param in model.named_parameters():\n        if len(param.shape) == 1:\n            g = {'params': [param], 'weight_decay': 0.0}\n            param_groups.append(g)\n        else:\n            other_params.append(param)\n    param_groups.append({'params': other_params})\n    return param_groups\n\n\ndef save_state(config, model):\n    print('==> Saving model ...')\n    env_name = config.name + '_' + str(config.manual_seed)\n    save_path = os.path.join('checkpoints', env_name)\n    os.makedirs(save_path, exist_ok=True)\n    model_state_dict = model.state_dict()\n    state_dict = {\n        'net': model_state_dict,\n    }\n    torch.save(state_dict, os.path.join(save_path, 'result.pth'))\n\n\ndef resume_state(config, model):\n    env_name = config.name + '_' + str(config.resume_seed)\n    cp_path = os.path.join('checkpoints', env_name, 'result.pth')\n    resume_model = torch.load(cp_path)['net']\n    model.load_state_dict(resume_model, strict=True)\n    return model\n\n\ndef pad_rep(image, ori_size):\n    h, w = image.shape\n    oh, ow = ori_size\n    pl = (ow - w) // 2\n    pr = ow - w - pl\n    pt = oh - h\n    image_pad = np.pad(image, pad_width=((pt, 0), (pl, pr)), mode='edge')\n    return image_pad\n\n\ndef save_result(config, depths, names, ori_sizes=None):\n    env_name = config.name + '_' + str(config.resume_seed)\n    save_path = os.path.join('results', env_name)\n    os.makedirs(save_path, exist_ok=True)\n    for i in range(depths.shape[0]):\n        depth, name = depths[i], names[i]\n        if ori_sizes is not None:\n            depth = pad_rep(depth, ori_sizes[i])\n        filename = os.path.join(save_path, name)\n        img = (depth * 256.0).astype('uint16')\n        Img = Image.fromarray(img)\n        Img.save(filename)\n\n\ndef init_seed(config):\n    if config.manual_seed == 0:\n        config.manual_seed = random.randint(1, 10000)\n    print(\"Random Seed: \", config.manual_seed)\n    torch.initial_seed()\n    random.seed(config.manual_seed)\n    np.random.seed(config.manual_seed)\n    torch.manual_seed(config.manual_seed)\n    torch.cuda.manual_seed_all(config.manual_seed)\n\n\ndef init_net(config):\n    return getattr(models, config.model)()\n\n\ndef init_loss(config):\n    return getattr(criteria, config.loss)()\n\n\ndef init_metric(config):\n    return getattr(criteria, config.metric)()\n\n\ndef init_aug(aug_config):\n    transform = []\n    for x in aug_config:\n        print(x)\n        if type(x) == str:\n            
transform.append(getattr(augs, x)())\n        else:\n            key, params = x.popitem()\n            transform.append(getattr(augs, key)(**params))\n    return augs.Compose(transform)\n\n\ndef init_dataset(config):\n    train_transform = init_aug(config.train_aug_configs)\n    test_transform = init_aug(config.test_aug_configs)\n    key, params = config.data_config.popitem()\n    dataset = getattr(datasets, key)\n    trainset = dataset(**params, mode='train', transform=train_transform)\n    testset = dataset(**params, mode='selval', transform=test_transform)\n    trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size,\n                                              num_workers=config.num_workers, shuffle=True, drop_last=True,\n                                              pin_memory=True)\n    testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size,\n                                             num_workers=config.num_workers, shuffle=True, drop_last=True,\n                                             pin_memory=True)\n    print('num_train = {}, num_test = {}'.format(len(trainset), len(testset)))\n    return trainloader, testloader\n\n\ndef init_cuda(net, criterion, metric):\n    torch.cuda.empty_cache()\n    net.cuda()\n    criterion.cuda()\n    metric.cuda()\n    net = encoding.parallel.DataParallelModel(net)\n    criterion = encoding.parallel.DataParallelCriterion(criterion)\n    metric = encoding.parallel.DataParallelCriterion(metric)\n    torch.backends.cudnn.benchmark = True\n    return net, criterion, metric\n\n\ndef init_optim(config, net):\n    key, params = config.optim_config.popitem()\n    return getattr(optimizers, key)(config_param(net), **params)\n\n\ndef init_lr_scheduler(config, optimizer):\n    key, params = config.lr_config.popitem()\n    return getattr(torch.optim.lr_scheduler, key)(optimizer, **params)\n"
  }
]