[
  {
    "path": ".gitattributes",
    "content": "crfasrnn/permutohedral.cpp linguist-vendored\ncrfasrnn/permutohedral.h linguist-vendored\n\n"
  },
  {
    "path": ".gitignore",
    "content": ".idea\n__pycache__\n.pyc\n\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2017 Sadeep Jayasumana\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# CRF-RNN for Semantic Image Segmentation - PyTorch version\n![sample](sample.png)\n\n<b>Live demo:</b> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; [http://crfasrnn.torr.vision](http://crfasrnn.torr.vision) <br/>\n<b>Caffe version:</b> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;[http://github.com/torrvision/crfasrnn](http://github.com/torrvision/crfasrnn)<br/>\n<b>Tensorflow/Keras version:</b> [http://github.com/sadeepj/crfasrnn_keras](http://github.com/sadeepj/crfasrnn_keras)<br/>\n\nThis repository contains the official PyTorch implementation of the \"CRF-RNN\" semantic image segmentation method, published in the ICCV 2015 paper [Conditional Random Fields as Recurrent Neural Networks](http://www.robots.ox.ac.uk/~szheng/papers/CRFasRNN.pdf). The [online demo](http://crfasrnn.torr.vision) of this project won the Best Demo Prize at ICCV 2015. Results of this PyTorch code are identical to that of the Caffe and Tensorflow/Keras based versions above.\n\nIf you use this code/model for your research, please cite the following paper:\n```\n@inproceedings{crfasrnn_ICCV2015,\n    author = {Shuai Zheng and Sadeep Jayasumana and Bernardino Romera-Paredes and Vibhav Vineet and\n    Zhizhong Su and Dalong Du and Chang Huang and Philip H. S. Torr},\n    title  = {Conditional Random Fields as Recurrent Neural Networks},\n    booktitle = {International Conference on Computer Vision (ICCV)},\n    year   = {2015}\n}\n```\n\n## Installation Guide\n\n_Note_: If you are using a Python virtualenv, make sure it is activated before running each command in this guide.\n\n### Step 1: Clone the repository\n```\n$ git clone https://github.com/sadeepj/crfasrnn_pytorch.git\n```\nThe root directory of the clone will be referred to as `crfasrnn_pytorch` hereafter.\n\n### Step 2: Install dependencies\n\n\nUse the `requirements.txt` file in this repository to install all the dependencies via `pip`:\n```\n$ cd crfasrnn_pytorch\n$ pip install -r requirements.txt\n```\n\nAfter installing the dependencies, run the following commands to make sure they are properly installed:\n```\n$ python\n>>> import torch \n```\nYou should not see any errors while importing `torch` above.\n\n### Step 3: Build CRF-RNN custom op\n\nRun `setup.py` inside the `crfasrnn_pytorch/crfasrnn` directory:\n```\n$ cd crfasrnn_pytorch/crfasrnn\n$ python setup.py install \n``` \nNote that the `python` command in the console should refer to the Python interpreter associated with your PyTorch installation. \n\n### Step 4: Download the pre-trained model weights\n\nDownload the model weights from [here](https://github.com/sadeepj/crfasrnn_pytorch/releases/download/0.0.1/crfasrnn_weights.pth) and place it in the `crfasrnn_pytorch` directory with the file name `crfasrnn_weights.pth`.\n\n### Step 5: Run the demo\n```\n$ cd crfasrnn_pytorch\n$ python run_demo.py\n```\nIf all goes well, you will see the segmentation results in a file named \"labels.png\".\n\n## Contributors\n* Sadeep Jayasumana ([sadeepj](https://github.com/sadeepj))\n* Harsha Ranasinghe ([HarshaPrabhath](https://github.com/HarshaPrabhath))\n\n"
  },
  {
    "path": "crfasrnn/__init__.py",
    "content": ""
  },
  {
    "path": "crfasrnn/crfasrnn_model.py",
    "content": "\"\"\"\nMIT License\n\nCopyright (c) 2019 Sadeep Jayasumana\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\nfrom crfasrnn.crfrnn import CrfRnn\nfrom crfasrnn.fcn8s import Fcn8s\n\n\nclass CrfRnnNet(Fcn8s):\n    \"\"\"\n    The full CRF-RNN network with the FCN-8s backbone as described in the paper:\n\n    Conditional Random Fields as Recurrent Neural Networks,\n    S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du, C. Huang and P. Torr,\n    ICCV 2015 (https://arxiv.org/abs/1502.03240).\n    \"\"\"\n\n    def __init__(self):\n        super(CrfRnnNet, self).__init__()\n        self.crfrnn = CrfRnn(num_labels=21, num_iterations=10)\n\n    def forward(self, image):\n        out = super(CrfRnnNet, self).forward(image)\n        # Plug the CRF-RNN module at the end\n        return self.crfrnn(image, out)\n"
  },
  {
    "path": "crfasrnn/crfrnn.py",
    "content": "\"\"\"\nMIT License\n\nCopyright (c) 2019 Sadeep Jayasumana\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\nimport torch\nimport torch.nn as nn\n\nfrom crfasrnn.filters import SpatialFilter, BilateralFilter\nfrom crfasrnn.params import DenseCRFParams\n\n\nclass CrfRnn(nn.Module):\n    \"\"\"\n    PyTorch implementation of the CRF-RNN module described in the paper:\n\n    Conditional Random Fields as Recurrent Neural Networks,\n    S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du, C. Huang and P. Torr,\n    ICCV 2015 (https://arxiv.org/abs/1502.03240).\n    \"\"\"\n\n    def __init__(self, num_labels, num_iterations=5, crf_init_params=None):\n        \"\"\"\n        Create a new instance of the CRF-RNN layer.\n\n        Args:\n            num_labels:         Number of semantic labels in the dataset\n            num_iterations:     Number of mean-field iterations to perform\n            crf_init_params:    CRF initialization parameters\n        \"\"\"\n        super(CrfRnn, self).__init__()\n\n        if crf_init_params is None:\n            crf_init_params = DenseCRFParams()\n\n        self.params = crf_init_params\n        self.num_iterations = num_iterations\n\n        self._softmax = torch.nn.Softmax(dim=0)\n\n        self.num_labels = num_labels\n\n        # --------------------------------------------------------------------------------------------\n        # --------------------------------- Trainable Parameters -------------------------------------\n        # --------------------------------------------------------------------------------------------\n\n        # Spatial kernel weights\n        self.spatial_ker_weights = nn.Parameter(\n            crf_init_params.spatial_ker_weight\n            * torch.eye(num_labels, dtype=torch.float32)\n        )\n\n        # Bilateral kernel weights\n        self.bilateral_ker_weights = nn.Parameter(\n            crf_init_params.bilateral_ker_weight\n            * torch.eye(num_labels, dtype=torch.float32)\n        )\n\n        # Compatibility transform matrix\n        self.compatibility_matrix = nn.Parameter(\n            torch.eye(num_labels, dtype=torch.float32)\n        )\n\n    def forward(self, image, logits):\n        \"\"\"\n        Perform CRF inference.\n\n        Args:\n            image:  Tensor of shape (3, h, w) containing the RGB image\n            logits: Tensor of shape (num_classes, h, w) containing the unary logits\n        Returns:\n            log-Q distributions (logits) after CRF inference\n        
\"\"\"\n        if logits.shape[0] != 1:\n            raise ValueError(\"Only batch size 1 is currently supported!\")\n\n        image = image[0]\n        logits = logits[0]\n\n        spatial_filter = SpatialFilter(image, gamma=self.params.gamma)\n        bilateral_filter = BilateralFilter(\n            image, alpha=self.params.alpha, beta=self.params.beta\n        )\n        _, h, w = image.shape\n        cur_logits = logits\n\n        for _ in range(self.num_iterations):\n            # Normalization\n            q_values = self._softmax(cur_logits)\n\n            # Spatial filtering\n            spatial_out = torch.mm(\n                self.spatial_ker_weights,\n                spatial_filter.apply(q_values).view(self.num_labels, -1),\n            )\n\n            # Bilateral filtering\n            bilateral_out = torch.mm(\n                self.bilateral_ker_weights,\n                bilateral_filter.apply(q_values).view(self.num_labels, -1),\n            )\n\n            # Compatibility transform\n            msg_passing_out = (\n                spatial_out + bilateral_out\n            )  # Shape: (self.num_labels, -1)\n            msg_passing_out = torch.mm(self.compatibility_matrix, msg_passing_out).view(\n                self.num_labels, h, w\n            )\n\n            # Adding unary potentials\n            cur_logits = msg_passing_out + logits\n\n        return torch.unsqueeze(cur_logits, 0)\n"
  },
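  {
    "path": "examples/crfrnn_example.py",
    "content": "\"\"\"\nA minimal, illustrative sketch (file name and values are examples only): it exercises the CrfRnn\nmodule on its own with random inputs. It assumes the `permuto_cpp` extension has already been\nbuilt via `crfasrnn/setup.py` and that `torch` is imported before `crfasrnn.crfrnn`.\n\"\"\"\n\nimport torch\n\nfrom crfasrnn.crfrnn import CrfRnn\n\n\ndef main():\n    h, w = 40, 60  # Arbitrary small image size for a quick check\n    image = torch.rand(1, 3, h, w, dtype=torch.float32)  # Batch size must be 1\n    logits = torch.randn(1, 21, h, w, dtype=torch.float32)  # Unary logits for 21 Pascal VOC labels\n\n    crf = CrfRnn(num_labels=21, num_iterations=2)\n    with torch.no_grad():\n        refined = crf(image, logits)\n\n    print(refined.shape)  # Expected: torch.Size([1, 21, 40, 60])\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },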
  {
    "path": "crfasrnn/fcn8s.py",
    "content": "\"\"\"\nThis file contains a modified version of the FCN-8s code available in https://github.com/wkentaro/pytorch-fcn\nThe original copyright notice from that repository is included below:\n\nCopyright (c) 2017 - 2019 Kentaro Wada.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in\nall copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\nTHE SOFTWARE.\n\"\"\"\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\n\ndef _upsampling_weights(in_channels, out_channels, kernel_size):\n    factor = (kernel_size + 1) // 2\n    if kernel_size % 2 == 1:\n        center = factor - 1\n    else:\n        center = factor - 0.5\n    og = np.ogrid[:kernel_size, :kernel_size]\n    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)\n    weight = np.zeros(\n        (in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64\n    )\n    weight[range(in_channels), range(out_channels), :, :] = filt\n    return torch.from_numpy(weight).float()\n\n\nclass Fcn8s(nn.Module):\n    def __init__(self, n_class=21):\n        \"\"\"\n        Create the FCN-8s network the the given number of classes.\n\n        Args:\n            n_class:    The number of semantic classes.\n        \"\"\"\n\n        super(Fcn8s, self).__init__()\n\n        # conv1\n        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)\n        self.relu1_1 = nn.ReLU(inplace=True)\n        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)\n        self.relu1_2 = nn.ReLU(inplace=True)\n        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # conv2\n        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)\n        self.relu2_1 = nn.ReLU(inplace=True)\n        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)\n        self.relu2_2 = nn.ReLU(inplace=True)\n        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # conv3\n        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)\n        self.relu3_1 = nn.ReLU(inplace=True)\n        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)\n        self.relu3_2 = nn.ReLU(inplace=True)\n        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)\n        self.relu3_3 = nn.ReLU(inplace=True)\n        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # conv4\n        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)\n        self.relu4_1 = nn.ReLU(inplace=True)\n        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)\n        self.relu4_2 = nn.ReLU(inplace=True)\n        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)\n        self.relu4_3 = nn.ReLU(inplace=True)\n        self.pool4 = nn.MaxPool2d(2, 
stride=2, ceil_mode=True)\n\n        # conv5\n        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)\n        self.relu5_1 = nn.ReLU(inplace=True)\n        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)\n        self.relu5_2 = nn.ReLU(inplace=True)\n        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)\n        self.relu5_3 = nn.ReLU(inplace=True)\n        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # fc6\n        self.fc6 = nn.Conv2d(512, 4096, 7)\n        self.relu6 = nn.ReLU(inplace=True)\n        self.drop6 = nn.Dropout2d()\n\n        # fc7\n        self.fc7 = nn.Conv2d(4096, 4096, 1)\n        self.relu7 = nn.ReLU(inplace=True)\n        self.drop7 = nn.Dropout2d()\n\n        self.score_fr = nn.Conv2d(4096, n_class, 1)\n        self.score_pool3 = nn.Conv2d(256, n_class, 1)\n        self.score_pool4 = nn.Conv2d(512, n_class, 1)\n\n        self.upscore2 = nn.ConvTranspose2d(n_class, n_class, 4, stride=2, bias=True)\n        self.upscore8 = nn.ConvTranspose2d(n_class, n_class, 16, stride=8, bias=False)\n        self.upscore_pool4 = nn.ConvTranspose2d(\n            n_class, n_class, 4, stride=2, bias=False\n        )\n\n        self._initialize_weights()\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                m.weight.data.zero_()\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            if isinstance(m, nn.ConvTranspose2d):\n                assert m.kernel_size[0] == m.kernel_size[1]\n                initial_weight = _upsampling_weights(\n                    m.in_channels, m.out_channels, m.kernel_size[0]\n                )\n                m.weight.data.copy_(initial_weight)\n\n    def forward(self, image):\n        h = self.relu1_1(self.conv1_1(image))\n        h = self.relu1_2(self.conv1_2(h))\n        h = self.pool1(h)\n\n        h = self.relu2_1(self.conv2_1(h))\n        h = self.relu2_2(self.conv2_2(h))\n        h = self.pool2(h)\n\n        h = self.relu3_1(self.conv3_1(h))\n        h = self.relu3_2(self.conv3_2(h))\n        h = self.relu3_3(self.conv3_3(h))\n        h = self.pool3(h)\n        pool3 = h  # 1/8\n\n        h = self.relu4_1(self.conv4_1(h))\n        h = self.relu4_2(self.conv4_2(h))\n        h = self.relu4_3(self.conv4_3(h))\n        h = self.pool4(h)\n        pool4 = h  # 1/16\n\n        h = self.relu5_1(self.conv5_1(h))\n        h = self.relu5_2(self.conv5_2(h))\n        h = self.relu5_3(self.conv5_3(h))\n        h = self.pool5(h)\n\n        h = self.relu6(self.fc6(h))\n        h = self.drop6(h)\n\n        h = self.relu7(self.fc7(h))\n        h = self.drop7(h)\n\n        h = self.score_fr(h)\n        h = self.upscore2(h)\n        upscore2 = h  # 1/16\n\n        h = self.score_pool4(pool4)\n        h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]]\n        score_pool4c = h  # 1/16\n\n        h = upscore2 + score_pool4c  # 1/16\n        h = self.upscore_pool4(h)\n        upscore_pool4 = h  # 1/8\n\n        h = self.score_pool3(pool3)\n        h = h[:, :, 9:9 + upscore_pool4.size()[2], 9:9 + upscore_pool4.size()[3]]\n        score_pool3c = h  # 1/8\n\n        h = upscore_pool4 + score_pool3c  # 1/8\n\n        h = self.upscore8(h)\n        h = h[:, :, 31:31 + image.size()[2], 31:31 + image.size()[3]].contiguous()\n\n        return h\n"
  },
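  {
    "path": "examples/fcn8s_example.py",
    "content": "\"\"\"\nA minimal, illustrative sketch (file name and input size are examples only): it passes a dummy\nbatch through the Fcn8s backbone to check that the skip connections and crops return a score map\nwith the same spatial size as the input. With freshly initialized weights the scores are not\nmeaningful; load `crfasrnn_weights.pth` for real predictions.\n\"\"\"\n\nimport torch\n\nfrom crfasrnn.fcn8s import Fcn8s\n\n\ndef main():\n    model = Fcn8s(n_class=21)\n    model.eval()\n\n    image = torch.zeros(1, 3, 224, 224, dtype=torch.float32)  # Arbitrary input size\n    with torch.no_grad():\n        scores = model(image)\n\n    print(scores.shape)  # Expected: torch.Size([1, 21, 224, 224])\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },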
  {
    "path": "crfasrnn/filters.py",
    "content": "\"\"\"\nMIT License\n\nCopyright (c) 2019 Sadeep Jayasumana\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nimport numpy as np\nimport torch\n\ntry:\n    import permuto_cpp\nexcept ImportError as e:\n    raise (e, \"Did you import `torch` first?\")\n\n_CPU = torch.device(\"cpu\")\n_EPS = np.finfo(\"float\").eps\n\n\nclass PermutoFunction(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, q_in, features):\n        q_out = permuto_cpp.forward(q_in, features)[0]\n        ctx.save_for_backward(features)\n        return q_out\n\n    @staticmethod\n    def backward(ctx, grad_q_out):\n        feature_saved = ctx.saved_tensors[0]\n        grad_q_back = permuto_cpp.backward(\n            grad_q_out.contiguous(), feature_saved.contiguous()\n        )[0]\n        return grad_q_back, None  # No need of grads w.r.t. 
features\n\n\ndef _spatial_features(image, sigma):\n    \"\"\"\n    Return the spatial features as a Tensor\n\n    Args:\n        image:  Image as a Tensor of shape (channels, height, wight)\n        sigma:  Bandwidth parameter\n\n    Returns:\n        Tensor of shape [h, w, 2] with spatial features\n    \"\"\"\n    sigma = float(sigma)\n    _, h, w = image.size()\n    x = torch.arange(start=0, end=w, dtype=torch.float32, device=_CPU)\n    xx = x.repeat([h, 1]) / sigma\n\n    y = torch.arange(\n        start=0, end=h, dtype=torch.float32, device=torch.device(\"cpu\")\n    ).view(-1, 1)\n    yy = y.repeat([1, w]) / sigma\n\n    return torch.stack([xx, yy], dim=2)\n\n\nclass AbstractFilter(ABC):\n    \"\"\"\n    Super-class for permutohedral-based Gaussian filters\n    \"\"\"\n\n    def __init__(self, image):\n        self.features = self._calc_features(image)\n        self.norm = self._calc_norm(image)\n\n    def apply(self, input_):\n        output = PermutoFunction.apply(input_, self.features)\n        return output * self.norm\n\n    @abstractmethod\n    def _calc_features(self, image):\n        pass\n\n    def _calc_norm(self, image):\n        _, h, w = image.size()\n        all_ones = torch.ones((1, h, w), dtype=torch.float32, device=_CPU)\n        norm = PermutoFunction.apply(all_ones, self.features)\n        return 1.0 / (norm + _EPS)\n\n\nclass SpatialFilter(AbstractFilter):\n    \"\"\"\n    Gaussian filter in the spatial ([x, y]) domain\n    \"\"\"\n\n    def __init__(self, image, gamma):\n        \"\"\"\n        Create new instance\n\n        Args:\n            image:  Image tensor of shape (3, height, width)\n            gamma:  Standard deviation\n        \"\"\"\n        self.gamma = gamma\n        super(SpatialFilter, self).__init__(image)\n\n    def _calc_features(self, image):\n        return _spatial_features(image, self.gamma)\n\n\nclass BilateralFilter(AbstractFilter):\n    \"\"\"\n    Gaussian filter in the bilateral ([r, g, b, x, y]) domain\n    \"\"\"\n\n    def __init__(self, image, alpha, beta):\n        \"\"\"\n        Create new instance\n\n        Args:\n            image:  Image tensor of shape (3, height, width)\n            alpha:  Smoothness (spatial) sigma\n            beta:   Appearance (color) sigma\n        \"\"\"\n        self.alpha = alpha\n        self.beta = beta\n        super(BilateralFilter, self).__init__(image)\n\n    def _calc_features(self, image):\n        xy = _spatial_features(\n            image, self.alpha\n        )  # TODO Possible optimisation, was calculated in the spatial kernel\n        rgb = (image / float(self.beta)).permute(1, 2, 0)  # Channel last order\n        return torch.cat([xy, rgb], dim=2)\n"
  },
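  {
    "path": "examples/filters_example.py",
    "content": "\"\"\"\nA minimal, illustrative sketch (file name and values are examples only): it applies the\npermutohedral-based SpatialFilter and BilateralFilter directly to a random image and a random\nset of per-label values. It assumes the `permuto_cpp` extension has already been built via\n`crfasrnn/setup.py`.\n\"\"\"\n\nimport torch\n\nfrom crfasrnn.filters import BilateralFilter, SpatialFilter\n\n\ndef main():\n    image = torch.rand(3, 24, 32, dtype=torch.float32)  # (channels, height, width)\n    values = torch.rand(5, 24, 32, dtype=torch.float32)  # e.g. Q distributions for 5 labels\n\n    spatial = SpatialFilter(image, gamma=3.0)\n    bilateral = BilateralFilter(image, alpha=160.0, beta=3.0)\n\n    spatial_out = spatial.apply(values)  # Gaussian filtering over (x, y)\n    bilateral_out = bilateral.apply(values)  # Gaussian filtering over (x, y, r, g, b)\n\n    print(spatial_out.shape, bilateral_out.shape)  # Both: torch.Size([5, 24, 32])\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },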
  {
    "path": "crfasrnn/params.py",
    "content": "\"\"\"\nMIT License\n\nCopyright (c) 2019 Sadeep Jayasumana\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\n\nclass DenseCRFParams(object):\n    \"\"\"\n    Parameters for the DenseCRF model\n    \"\"\"\n\n    def __init__(\n        self,\n        alpha=160.0,\n        beta=3.0,\n        gamma=3.0,\n        spatial_ker_weight=3.0,\n        bilateral_ker_weight=5.0,\n    ):\n        \"\"\"\n        Default values were taken from https://github.com/sadeepj/crfasrnn_keras. More details about these parameters\n        can be found in https://arxiv.org/pdf/1210.5644.pdf\n\n        Args:\n            alpha:                  Bandwidth for the spatial component of the bilateral filter\n            beta:                   Bandwidth for the color component of the bilateral filter\n            gamma:                  Bandwidth for the spatial filter\n            spatial_ker_weight:     Spatial kernel weight\n            bilateral_ker_weight:   Bilateral kernel weight\n        \"\"\"\n        self.alpha = alpha\n        self.beta = beta\n        self.gamma = gamma\n        self.spatial_ker_weight = spatial_ker_weight\n        self.bilateral_ker_weight = bilateral_ker_weight\n"
  },
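  {
    "path": "examples/params_example.py",
    "content": "\"\"\"\nA minimal, illustrative sketch (file name and values are examples only): it constructs a CrfRnn\nlayer with custom DenseCRFParams instead of the defaults. The parameter values below are\narbitrary, not tuned settings. It assumes the `permuto_cpp` extension has already been built.\n\"\"\"\n\nfrom crfasrnn.crfrnn import CrfRnn\nfrom crfasrnn.params import DenseCRFParams\n\nparams = DenseCRFParams(\n    alpha=120.0,  # Spatial bandwidth of the bilateral kernel\n    beta=5.0,  # Color bandwidth of the bilateral kernel\n    gamma=2.0,  # Bandwidth of the spatial kernel\n    spatial_ker_weight=2.0,\n    bilateral_ker_weight=4.0,\n)\ncrf = CrfRnn(num_labels=21, num_iterations=5, crf_init_params=params)\nprint(crf)\n"
  },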
  {
    "path": "crfasrnn/permuto.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n#include <iostream>\n#include <stdexcept>\n#include \"permutohedral.h\"\n\n/**\n *\n * @param input_values  Input values to filter (e.g. Q distributions). Has shape (channels, height, width)\n * @param features      Features for the permutohedral lattice. Has shape (height, width, feature_channels). Note that\n *                      channels are at the end!\n * @return Filtered values with shape (channels, height, width)\n */\nstd::vector<at::Tensor> permuto_forward(torch::Tensor input_values, torch::Tensor features) {\n\n    auto input_sizes = input_values.sizes();  // (channels, height, width)\n    auto feature_sizes = features.sizes();  // (height, width, num_features)\n\n    auto h = feature_sizes[0];\n    auto w = feature_sizes[1];\n    auto n_feature_dims = static_cast<int>(feature_sizes[2]);\n    auto n_pixels = static_cast<int>(h * w);\n    auto n_channels = static_cast<int>(input_sizes[0]);\n\n    // Validate the arguments\n    if (input_sizes[1] != h || input_sizes[2] != w) {\n        throw std::runtime_error(\"Sizes of `input_values` and `features` do not match!\");\n    }\n\n    if (!(input_values.dtype() == torch::kFloat32)) {\n        throw std::runtime_error(\"`input_values` must have float32 type.\");\n    }\n\n    if (!(features.dtype() == torch::kFloat32)) {\n        throw std::runtime_error(\"`features` must have float32 type.\");\n    }\n\n    // Create the output tensor\n    auto options = torch::TensorOptions()\n            .dtype(torch::kFloat32)\n            .layout(torch::kStrided)\n            .device(torch::kCPU)\n            .requires_grad(false);\n\n    auto output_values = torch::empty(input_sizes, options);\n    output_values = output_values.contiguous();\n\n    Permutohedral p;\n    p.init(features.contiguous().data<float>(), n_feature_dims, n_pixels);\n    p.compute(output_values.data<float>(), input_values.contiguous().data<float>(), n_channels);\n\n    return {output_values};\n}\n\n\nstd::vector<at::Tensor> permuto_backward(torch::Tensor grads, torch::Tensor features) {\n\n    auto grad_sizes = grads.sizes();  // (channels, height, width)\n    auto feature_sizes = features.sizes();  // (height, width, num_features)\n\n    auto h = feature_sizes[0];\n    auto w = feature_sizes[1];\n    auto n_feature_dims = static_cast<int>(feature_sizes[2]);\n    auto n_pixels = static_cast<int>(h * w);\n    auto n_channels = static_cast<int>(grad_sizes[0]);\n\n    // Validate the arguments\n    if (grad_sizes[1] != h || grad_sizes[2] != w) {\n        throw std::runtime_error(\"Sizes of `grad_values` and `features` do not match!\");\n    }\n\n    if (!(grads.dtype() == torch::kFloat32)) {\n        throw std::runtime_error(\"`input_values` must have float32 type.\");\n    }\n\n    if (!(features.dtype() == torch::kFloat32)) {\n        throw std::runtime_error(\"`features` must have float32 type.\");\n    }\n\n    // Create the output tensor\n    auto options = torch::TensorOptions()\n            .dtype(torch::kFloat32)\n            .layout(torch::kStrided)\n            .device(torch::kCPU)\n            .requires_grad(false);\n\n    auto grads_back = torch::empty(grad_sizes, options);\n    grads_back = grads_back.contiguous();\n\n    Permutohedral p;\n    p.init(features.contiguous().data<float>(), n_feature_dims, n_pixels);\n    p.compute(grads_back.data<float>(), grads.contiguous().data<float>(), n_channels, true);\n\n    return {grads_back};\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    
m.def(\"forward\", &permuto_forward, \"PERMUTO forward\");\n    m.def(\"backward\", &permuto_backward, \"PERMUTO backward\");\n}\n"
  },
  {
    "path": "crfasrnn/permutohedral.cpp",
    "content": "/*\n   This file contains a modified version of the \"permutohedral.cpp\" code\n   available at http://graphics.stanford.edu/projects/drf/. Copyright notice of\n   the original file is included below:\n\n    Copyright (c) 2013, Philipp Krähenbühl\n    All rights reserved.\n\n    Redistribution and use in source and binary forms, with or without\n    modification, are permitted provided that the following conditions are met:\n        * Redistributions of source code must retain the above copyright\n        notice, this list of conditions and the following disclaimer.\n        * Redistributions in binary form must reproduce the above copyright\n        notice, this list of conditions and the following disclaimer in the\n        documentation and/or other materials provided with the distribution.\n        * Neither the name of the Stanford University nor the\n        names of its contributors may be used to endorse or promote products\n        derived from this software without specific prior written permission.\n\n    THIS SOFTWARE IS PROVIDED BY Philipp Krähenbühl ''AS IS'' AND ANY\n    EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n    DISCLAIMED. IN NO EVENT SHALL Philipp Krähenbühl BE LIABLE FOR ANY\n    DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n    LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n    ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n*/\n\n//#include \"stdafx.h\"\n#include \"permutohedral.h\"\n\n#ifdef __SSE__\n// SSE Permutoheral lattice\n# define SSE_PERMUTOHEDRAL\n#endif\n\n#if defined(SSE_PERMUTOHEDRAL)\n# include <emmintrin.h>\n# include <xmmintrin.h>\n# ifdef __SSE4_1__\n#  include <smmintrin.h>\n# endif\n#endif\n\n\n/************************************************/\n/***                Hash Table                ***/\n/************************************************/\n\nclass HashTable{\nprotected:\n\tsize_t key_size_, filled_, capacity_;\n\tstd::vector< short > keys_;\n\tstd::vector< int > table_;\n\tvoid grow(){\n\t\t// Create the new memory and copy the values in\n\t\tint old_capacity = capacity_;\n\t\tcapacity_ *= 2;\n\t\tstd::vector<short> old_keys( (old_capacity+10)*key_size_ );\n\t\tstd::copy( keys_.begin(), keys_.end(), old_keys.begin() );\n\t\tstd::vector<int> old_table( capacity_, -1 );\n\n\t\t// Swap the memory\n\t\ttable_.swap( old_table );\n\t\tkeys_.swap( old_keys );\n\n\t\t// Reinsert each element\n\t\tfor( int i=0; i<old_capacity; i++ )\n\t\t\tif (old_table[i] >= 0){\n\t\t\t\tint e = old_table[i];\n\t\t\t\tsize_t h = hash( getKey(e) ) % capacity_;\n\t\t\t\tfor(; table_[h] >= 0; h = h<capacity_-1 ? 
h+1 : 0);\n\t\t\t\ttable_[h] = e;\n\t\t\t}\n\t}\n\tsize_t hash( const short * k ) {\n\t\tsize_t r = 0;\n\t\tfor( size_t i=0; i<key_size_; i++ ){\n\t\t\tr += k[i];\n\t\t\tr *= 1664525;\n\t\t}\n\t\treturn r;\n\t}\npublic:\n\texplicit HashTable( int key_size, int n_elements ) : key_size_ ( key_size ), filled_(0), capacity_(2*n_elements), keys_((capacity_/2+10)*key_size_), table_(2*n_elements,-1) {\n\t}\n\tint size() const {\n\t\treturn filled_;\n\t}\n\tvoid reset() {\n\t\tfilled_ = 0;\n\t\tstd::fill( table_.begin(), table_.end(), -1 );\n\t}\n\tint find( const short * k, bool create = false ){\n\t\tif (2*filled_ >= capacity_) grow();\n\t\t// Get the hash value\n\t\tsize_t h = hash( k ) % capacity_;\n\t\t// Find the element with he right key, using linear probing\n\t\twhile(1){\n\t\t\tint e = table_[h];\n\t\t\tif (e==-1){\n\t\t\t\tif (create){\n\t\t\t\t\t// Insert a new key and return the new id\n\t\t\t\t\tfor( size_t i=0; i<key_size_; i++ )\n\t\t\t\t\t\tkeys_[ filled_*key_size_+i ] = k[i];\n\t\t\t\t\treturn table_[h] = filled_++;\n\t\t\t\t}\n\t\t\t\telse\n\t\t\t\t\treturn -1;\n\t\t\t}\n\t\t\t// Check if the current key is The One\n\t\t\tbool good = true;\n\t\t\tfor( size_t i=0; i<key_size_ && good; i++ )\n\t\t\t\tif (keys_[ e*key_size_+i ] != k[i])\n\t\t\t\t\tgood = false;\n\t\t\tif (good)\n\t\t\t\treturn e;\n\t\t\t// Continue searching\n\t\t\th++;\n\t\t\tif (h==capacity_) h = 0;\n\t\t}\n\t}\n\tconst short * getKey( int i ) const{\n\t\treturn &keys_[i*key_size_];\n\t}\n\n};\n\n/************************************************/\n/***          Permutohedral Lattice           ***/\n/************************************************/\n\nPermutohedral::Permutohedral():N_( 0 ), M_( 0 ), d_( 0 ) {\n}\n#ifdef SSE_PERMUTOHEDRAL\nvoid Permutohedral::init(const float* features, int num_dimensions, int num_points)\n{\n\t// Compute the lattice coordinates for each feature [there is going to be a lot of magic here\n\tN_ = num_points;\n\td_ = num_dimensions;\n\tHashTable hash_table( d_, N_/**(d_+1)*/ );\n\n\tconst int blocksize = sizeof(__m128) / sizeof(float);\n\tconst __m128 invdplus1   = _mm_set1_ps( 1.0f / (d_+1) );\n\tconst __m128 dplus1      = _mm_set1_ps( d_+1 );\n\tconst __m128 Zero        = _mm_set1_ps( 0 );\n\tconst __m128 One         = _mm_set1_ps( 1 );\n\n\t// Allocate the class memory\n\toffset_.resize( (d_+1)*(N_+16) );\n\tstd::fill( offset_.begin(), offset_.end(), 0 );\n\tbarycentric_.resize( (d_+1)*(N_+16) );\n\tstd::fill( barycentric_.begin(), barycentric_.end(), 0 );\n\trank_.resize( (d_+1)*(N_+16) );\n\n\t// Allocate the local memory\n\t__m128 * scale_factor = (__m128*) _mm_malloc( (d_  )*sizeof(__m128) , 16 );\n\t__m128 * f            = (__m128*) _mm_malloc( (d_  )*sizeof(__m128) , 16 );\n\t__m128 * elevated     = (__m128*) _mm_malloc( (d_+1)*sizeof(__m128) , 16 );\n\t__m128 * rem0         = (__m128*) _mm_malloc( (d_+1)*sizeof(__m128) , 16 );\n\t__m128 * rank         = (__m128*) _mm_malloc( (d_+1)*sizeof(__m128), 16 );\n\tfloat * barycentric = new float[(d_+2)*blocksize];\n\tshort * canonical = new short[(d_+1)*(d_+1)];\n\tshort * key = new short[d_+1];\n\n\t// Compute the canonical simplex\n\tfor( int i=0; i<=d_; i++ ){\n\t\tfor( int j=0; j<=d_-i; j++ )\n\t\t\tcanonical[i*(d_+1)+j] = i;\n\t\tfor( int j=d_-i+1; j<=d_; j++ )\n\t\t\tcanonical[i*(d_+1)+j] = i - (d_+1);\n\t}\n\n\t// Expected standard deviation of our filter (p.6 in [Adams etal 2010])\n\tfloat inv_std_dev = sqrt(2.0 / 3.0)*(d_+1);\n\t// Compute the diagonal part of E (p.5 in [Adams etal 2010])\n\tfor( int i=0; i<d_; i++ 
)\n\t\tscale_factor[i] = _mm_set1_ps( 1.0 / sqrt( (i+2)*(i+1) ) * inv_std_dev );\n\n\t// Setup the SSE rounding\n#ifndef __SSE4_1__\n\tconst unsigned int old_rounding = _mm_getcsr();\n\t_mm_setcsr( (old_rounding&~_MM_ROUND_MASK) | _MM_ROUND_NEAREST );\n#endif\n\n\t// Compute the simplex each feature lies in\n\tfor( int k=0; k<N_; k+=blocksize ){\n\t\t// Load the feature from memory\n\t\tfloat * ff = (float*)f;\n\t\tfor( int j=0; j<d_; j++ )\n\t\t\tfor( int i=0; i<blocksize; i++ )\n\t\t\t\tff[ j*blocksize + i ] = k+i < N_ ? *(features + (k + i) * num_dimensions + j) : 0.0;\n\n\t\t// Elevate the feature ( y = Ep, see p.5 in [Adams etal 2010])\n\n\t\t// sm contains the sum of 1..n of our faeture vector\n\t\t__m128 sm = Zero;\n\t\tfor( int j=d_; j>0; j-- ){\n\t\t\t__m128 cf = f[j-1]*scale_factor[j-1];\n\t\t\televated[j] = sm - _mm_set1_ps(j)*cf;\n\t\t\tsm += cf;\n\t\t}\n\t\televated[0] = sm;\n\n\t\t// Find the closest 0-colored simplex through rounding\n\t\t__m128 sum = Zero;\n\t\tfor( int i=0; i<=d_; i++ ){\n\t\t\t__m128 v = invdplus1 * elevated[i];\n#ifdef __SSE4_1__\n\t\t\tv = _mm_round_ps( v, _MM_FROUND_TO_NEAREST_INT );\n#else\n\t\t\tv = _mm_cvtepi32_ps( _mm_cvtps_epi32( v ) );\n#endif\n\t\t\trem0[i] = v*dplus1;\n\t\t\tsum += v;\n\t\t}\n\n\t\t// Find the simplex we are in and store it in rank (where rank describes what position coorinate i has in the sorted order of the features values)\n\t\tfor( int i=0; i<=d_; i++ )\n\t\t\trank[i] = Zero;\n\t\tfor( int i=0; i<d_; i++ ){\n\t\t\t__m128 di = elevated[i] - rem0[i];\n\t\t\tfor( int j=i+1; j<=d_; j++ ){\n\t\t\t\t__m128 dj = elevated[j] - rem0[j];\n\t\t\t\t__m128 c = _mm_and_ps( One, _mm_cmplt_ps( di, dj ) );\n\t\t\t\trank[i] += c;\n\t\t\t\trank[j] += One-c;\n\t\t\t}\n\t\t}\n\n\t\t// If the point doesn't lie on the plane (sum != 0) bring it back\n\t\tfor( int i=0; i<=d_; i++ ){\n\t\t\trank[i] += sum;\n\t\t\t__m128 add = _mm_and_ps( dplus1, _mm_cmplt_ps( rank[i], Zero ) );\n\t\t\t__m128 sub = _mm_and_ps( dplus1, _mm_cmpge_ps( rank[i], dplus1 ) );\n\t\t\trank[i] += add-sub;\n\t\t\trem0[i] += add-sub;\n\t\t}\n\n\t\t// Compute the barycentric coordinates (p.10 in [Adams etal 2010])\n\t\tfor( int i=0; i<(d_+2)*blocksize; i++ )\n\t\t\tbarycentric[ i ] = 0;\n\t\tfor( int i=0; i<=d_; i++ ){\n\t\t\t__m128 v = (elevated[i] - rem0[i])*invdplus1;\n\n\t\t\t// Didn't figure out how to SSE this\n\t\t\tfloat * fv = (float*)&v;\n\t\t\tfloat * frank = (float*)&rank[i];\n\t\t\tfor( int j=0; j<blocksize; j++ ){\n\t\t\t\tint p = d_-frank[j];\n\t\t\t\tbarycentric[j*(d_+2)+p  ] += fv[j];\n\t\t\t\tbarycentric[j*(d_+2)+p+1] -= fv[j];\n\t\t\t}\n\t\t}\n\n\t\t// The rest is not SSE'd\n\t\tfor( int j=0; j<blocksize; j++ ){\n\t\t\t// Wrap around\n\t\t\tbarycentric[j*(d_+2)+0]+= 1 + barycentric[j*(d_+2)+d_+1];\n\n\t\t\tfloat * frank = (float*)rank;\n\t\t\tfloat * frem0 = (float*)rem0;\n\t\t\t// Compute all vertices and their offset\n\t\t\tfor( int remainder=0; remainder<=d_; remainder++ ){\n\t\t\t\tfor( int i=0; i<d_; i++ ){\n\t\t\t\t\tkey[i] = frem0[i*blocksize+j] + canonical[ remainder*(d_+1) + (int)frank[i*blocksize+j] ];\n\t\t\t\t}\n\t\t\t\toffset_[ (j+k)*(d_+1)+remainder ] = hash_table.find( key, true );\n\t\t\t\trank_[ (j+k)*(d_+1)+remainder ] = frank[remainder*blocksize+j];\n\t\t\t\tbarycentric_[ (j+k)*(d_+1)+remainder ] = barycentric[ j*(d_+2)+remainder ];\n\t\t\t}\n\t\t}\n\t}\n\t_mm_free( scale_factor );\n\t_mm_free( f );\n\t_mm_free( elevated );\n\t_mm_free( rem0 );\n\t_mm_free( rank );\n\tdelete [] barycentric;\n\tdelete [] canonical;\n\tdelete [] key;\n\n\t// 
Reset the SSE rounding\n#ifndef __SSE4_1__\n\t_mm_setcsr( old_rounding );\n#endif\n\n\t// This is normally fast enough so no SSE needed here\n\t// Find the Neighbors of each lattice point\n\n\t// Get the number of vertices in the lattice\n\tM_ = hash_table.size();\n\n\t// Create the neighborhood structure\n\tblur_neighbors_.resize( (d_+1)*M_ );\n\n\tshort * n1 = new short[d_+1];\n\tshort * n2 = new short[d_+1];\n\n\t// For each of d+1 axes,\n\tfor( int j = 0; j <= d_; j++ ){\n\t\tfor( int i=0; i<M_; i++ ){\n\t\t\tconst short * key = hash_table.getKey( i );\n\t\t\tfor( int k=0; k<d_; k++ ){\n\t\t\t\tn1[k] = key[k] - 1;\n\t\t\t\tn2[k] = key[k] + 1;\n\t\t\t}\n\t\t\tn1[j] = key[j] + d_;\n\t\t\tn2[j] = key[j] - d_;\n\n\t\t\tblur_neighbors_[j*M_+i].n1 = hash_table.find( n1 );\n\t\t\tblur_neighbors_[j*M_+i].n2 = hash_table.find( n2 );\n\t\t}\n\t}\n\tdelete[] n1;\n\tdelete[] n2;\n}\n#else\nvoid Permutohedral::init (const float* features, int num_dimensions, int num_points)\n{\n\t// Compute the lattice coordinates for each feature [there is going to be a lot of magic here\n\tN_ = num_points;\n\td_ = num_dimensions;\n\tHashTable hash_table( d_, N_*(d_+1) );\n\n\t// Allocate the class memory\n\toffset_.resize( (d_+1)*N_ );\n\trank_.resize( (d_+1)*N_ );\n\tbarycentric_.resize( (d_+1)*N_ );\n\n\t// Allocate the local memory\n\tfloat * scale_factor = new float[d_];\n\tfloat * elevated = new float[d_+1];\n\tfloat * rem0 = new float[d_+1];\n\tfloat * barycentric = new float[d_+2];\n\tshort * rank = new short[d_+1];\n\tshort * canonical = new short[(d_+1)*(d_+1)];\n\tshort * key = new short[d_+1];\n\n\t// Compute the canonical simplex\n\tfor( int i=0; i<=d_; i++ ){\n\t\tfor( int j=0; j<=d_-i; j++ )\n\t\t\tcanonical[i*(d_+1)+j] = i;\n\t\tfor( int j=d_-i+1; j<=d_; j++ )\n\t\t\tcanonical[i*(d_+1)+j] = i - (d_+1);\n\t}\n\n\t// Expected standard deviation of our filter (p.6 in [Adams etal 2010])\n\tfloat inv_std_dev = sqrt(2.0 / 3.0)*(d_+1);\n\t// Compute the diagonal part of E (p.5 in [Adams etal 2010])\n\tfor( int i=0; i<d_; i++ )\n\t\tscale_factor[i] = 1.0 / sqrt( double((i+2)*(i+1)) ) * inv_std_dev;\n\n\t// Compute the simplex each feature lies in\n\tfor( int k=0; k<N_; k++ ){\n\t\t// Elevate the feature ( y = Ep, see p.5 in [Adams etal 2010])\n\t\tconst float * f = (features + k * num_dimensions);\n\n\t\t// sm contains the sum of 1..n of our faeture vector\n\t\tfloat sm = 0;\n\t\tfor( int j=d_; j>0; j-- ){\n\t\t\tfloat cf = f[j-1]*scale_factor[j-1];\n\t\t\televated[j] = sm - j*cf;\n\t\t\tsm += cf;\n\t\t}\n\t\televated[0] = sm;\n\n\t\t// Find the closest 0-colored simplex through rounding\n\t\tfloat down_factor = 1.0f / (d_+1);\n\t\tfloat up_factor = (d_+1);\n\t\tint sum = 0;\n\t\tfor( int i=0; i<=d_; i++ ){\n\t\t\t//int rd1 = round( down_factor * elevated[i]);\n\t\t\tint rd2;\n\t\t\tfloat v = down_factor * elevated[i];\n\t\t\tfloat up = ceilf(v)*up_factor;\n\t\t\tfloat down = floorf(v)*up_factor;\n\t\t\tif (up - elevated[i] < elevated[i] - down) rd2 = (short)up;\n\t\t\telse rd2 = (short)down;\n\n\t\t\t//if(rd1!=rd2)\n\t\t\t//\tbreak;\n\n\t\t\trem0[i] = rd2;\n\t\t\tsum += rd2*down_factor;\n\t\t}\n\n\t\t// Find the simplex we are in and store it in rank (where rank describes what position coorinate i has in the sorted order of the features values)\n\t\tfor( int i=0; i<=d_; i++ )\n\t\t\trank[i] = 0;\n\t\tfor( int i=0; i<d_; i++ ){\n\t\t\tdouble di = elevated[i] - rem0[i];\n\t\t\tfor( int j=i+1; j<=d_; j++ )\n\t\t\t\tif ( di < elevated[j] - 
rem0[j])\n\t\t\t\t\trank[i]++;\n\t\t\t\telse\n\t\t\t\t\trank[j]++;\n\t\t}\n\n\t\t// If the point doesn't lie on the plane (sum != 0) bring it back\n\t\tfor( int i=0; i<=d_; i++ ){\n\t\t\trank[i] += sum;\n\t\t\tif ( rank[i] < 0 ){\n\t\t\t\trank[i] += d_+1;\n\t\t\t\trem0[i] += d_+1;\n\t\t\t}\n\t\t\telse if ( rank[i] > d_ ){\n\t\t\t\trank[i] -= d_+1;\n\t\t\t\trem0[i] -= d_+1;\n\t\t\t}\n\t\t}\n\n\t\t// Compute the barycentric coordinates (p.10 in [Adams etal 2010])\n\t\tfor( int i=0; i<=d_+1; i++ )\n\t\t\tbarycentric[i] = 0;\n\t\tfor( int i=0; i<=d_; i++ ){\n\t\t\tfloat v = (elevated[i] - rem0[i])*down_factor;\n\t\t\tbarycentric[d_-rank[i]  ] += v;\n\t\t\tbarycentric[d_-rank[i]+1] -= v;\n\t\t}\n\t\t// Wrap around\n\t\tbarycentric[0] += 1.0 + barycentric[d_+1];\n\n\t\t// Compute all vertices and their offset\n\t\tfor( int remainder=0; remainder<=d_; remainder++ ){\n\t\t\tfor( int i=0; i<d_; i++ )\n\t\t\t\tkey[i] = rem0[i] + canonical[ remainder*(d_+1) + rank[i] ];\n\t\t\toffset_[ k*(d_+1)+remainder ] = hash_table.find( key, true );\n\t\t\trank_[ k*(d_+1)+remainder ] = rank[remainder];\n\t\t\tbarycentric_[ k*(d_+1)+remainder ] = barycentric[ remainder ];\n\t\t}\n\t}\n\tdelete [] scale_factor;\n\tdelete [] elevated;\n\tdelete [] rem0;\n\tdelete [] barycentric;\n\tdelete [] rank;\n\tdelete [] canonical;\n\tdelete [] key;\n\n\n\t// Find the Neighbors of each lattice point\n\n\t// Get the number of vertices in the lattice\n\tM_ = hash_table.size();\n\n\t// Create the neighborhood structure\n\tblur_neighbors_.resize( (d_+1)*M_ );\n\n\tshort * n1 = new short[d_+1];\n\tshort * n2 = new short[d_+1];\n\n\t// For each of d+1 axes,\n\tfor( int j = 0; j <= d_; j++ ){\n\t\tfor( int i=0; i<M_; i++ ){\n\t\t\tconst short * key = hash_table.getKey( i );\n\t\t\tfor( int k=0; k<d_; k++ ){\n\t\t\t\tn1[k] = key[k] - 1;\n\t\t\t\tn2[k] = key[k] + 1;\n\t\t\t}\n\t\t\tn1[j] = key[j] + d_;\n\t\t\tn2[j] = key[j] - d_;\n\n\t\t\tblur_neighbors_[j*M_+i].n1 = hash_table.find( n1 );\n\t\t\tblur_neighbors_[j*M_+i].n2 = hash_table.find( n2 );\n\t\t}\n\t}\n\tdelete[] n1;\n\tdelete[] n2;\n}\n#endif\nvoid Permutohedral::seqCompute(float* out, const float* in, int value_size, bool reverse, bool add) const\n{\n\t// Shift all values by 1 such that -1 -> 0 (used for blurring)\n\tfloat * values = new float[ (M_+2)*value_size ];\n\tfloat * new_values = new float[ (M_+2)*value_size ];\n\n\tfor( int i=0; i<(M_+2)*value_size; i++ )\n\t\tvalues[i] = new_values[i] = 0;\n\n\t// Splatting\n\tfor( int i=0;  i<N_; i++ ){\n\t\tfor( int j=0; j<=d_; j++ ){\n\t\t\tint o = offset_[i*(d_+1)+j]+1;\n\t\t\tfloat w = barycentric_[i*(d_+1)+j];\n\t\t\tfor( int k=0; k<value_size; k++ )\n\t\t\t\tvalues[ o*value_size+k ] += w * in[k*N_ + i];\n\t\t}\n\t}\n\n\tfor( int j=reverse?d_:0; j<=d_ && j>=0; reverse?j--:j++ ){\n\t\tfor( int i=0; i<M_; i++ ){\n\t\t\tfloat * old_val = values + (i+1)*value_size;\n\t\t\tfloat * new_val = new_values + (i+1)*value_size;\n\n\t\t\tint n1 = blur_neighbors_[j*M_+i].n1+1;\n\t\t\tint n2 = blur_neighbors_[j*M_+i].n2+1;\n\t\t\tfloat * n1_val = values + n1*value_size;\n\t\t\tfloat * n2_val = values + n2*value_size;\n\t\t\tfor( int k=0; k<value_size; k++ )\n\t\t\t\tnew_val[k] = old_val[k]+0.5*(n1_val[k] + n2_val[k]);\n\t\t}\n\t\tstd::swap( values, new_values );\n\t}\n\t// Alpha is a magic scaling constant (write Andrew if you really wanna understand this)\n\tfloat alpha = 1.0f / (1+powf(2, -d_));\n\n\t// Slicing\n\tfor( int i=0; i<N_; i++ ){\n        if (!add) {\n          for( int k=0; k<value_size; k++ )\n            out[i + k*N_] = 0; 
//out[i*value_size+k] = 0;\n        }\n\t\tfor( int j=0; j<=d_; j++ ){\n\t\t\tint o = offset_[i*(d_+1)+j]+1;\n\t\t\tfloat w = barycentric_[i*(d_+1)+j];\n\t\t\tfor( int k=0; k<value_size; k++ )\n\t\t\t  out[ i + k*N_ ] += w * values[ o*value_size+k ] * alpha;\n\t\t}\n\t}\n\n\n\tdelete[] values;\n\tdelete[] new_values;\n}\n\n#ifdef SSE_PERMUTOHEDRAL\nvoid Permutohedral::sseCompute( float* out, const float* in, int value_size, const bool reverse, const bool add) const\n{\n\tconst int sse_value_size = (value_size-1)*sizeof(float) / sizeof(__m128) + 1;\n\t// Shift all values by 1 such that -1 -> 0 (used for blurring)\n\t__m128 * sse_val    = (__m128*) _mm_malloc( sse_value_size*sizeof(__m128), 16 );\n\t__m128 * values     = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 );\n\t__m128 * new_values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 );\n\n\t__m128 Zero = _mm_set1_ps( 0 );\n\n\tfor( int i=0; i<(M_+2)*sse_value_size; i++ )\n\t\tvalues[i] = new_values[i] = Zero;\n\tfor( int i=0; i<sse_value_size; i++ )\n\t\tsse_val[i] = Zero;\n\n\tfloat* sdp_temp = new float[value_size];\n\n\t// Splatting\n\tfor( int i=0;  i<N_; i++ ){\n\n\n\t\tfor (int s = 0; s < value_size; s++) {\n\t\t  sdp_temp[s] = in[s*N_ + i];\n\t\t}\n\t\tmemcpy(sse_val, sdp_temp, value_size*sizeof(float));\n\n\t\tfor( int j=0; j<=d_; j++ ){\n\t\t\tint o = offset_[i*(d_+1)+j]+1;\n\t\t\t__m128 w = _mm_set1_ps( barycentric_[i*(d_+1)+j] );\n\t\t\tfor( int k=0; k<sse_value_size; k++ )\n\t\t\t\tvalues[ o*sse_value_size+k ] += w * sse_val[k];\n\t\t}\n\t}\n\t// Blurring\n\t__m128 half = _mm_set1_ps(0.5);\n\tfor( int j=reverse?d_:0; j<=d_ && j>=0; reverse?j--:j++ ){\n\t\tfor( int i=0; i<M_; i++ ){\n\t\t\t__m128 * old_val = values + (i+1)*sse_value_size;\n\t\t\t__m128 * new_val = new_values + (i+1)*sse_value_size;\n\n\t\t\tint n1 = blur_neighbors_[j*M_+i].n1+1;\n\t\t\tint n2 = blur_neighbors_[j*M_+i].n2+1;\n\t\t\t__m128 * n1_val = values + n1*sse_value_size;\n\t\t\t__m128 * n2_val = values + n2*sse_value_size;\n\t\t\tfor( int k=0; k<sse_value_size; k++ )\n\t\t\t\tnew_val[k] = old_val[k]+half*(n1_val[k] + n2_val[k]);\n\t\t}\n\t\tstd::swap( values, new_values );\n\t}\n\t// Alpha is a magic scaling constant (write Andrew if you really wanna understand this)\n\tfloat alpha = 1.0f / (1+powf(2, -d_));\n\n\t// Slicing\n\tfor( int i=0; i<N_; i++ ){\n\t\tfor( int k=0; k<sse_value_size; k++ )\n\t\t\tsse_val[ k ] = Zero;\n\t\tfor( int j=0; j<=d_; j++ ){\n\t\t\tint o = offset_[i*(d_+1)+j]+1;\n\t\t\t__m128 w = _mm_set1_ps( barycentric_[i*(d_+1)+j] * alpha );\n\t\t\tfor( int k=0; k<sse_value_size; k++ )\n\t\t\t\tsse_val[ k ] += w * values[ o*sse_value_size+k ];\n\t\t}\n\n\t\tmemcpy(sdp_temp, sse_val, value_size*sizeof(float) );\n        if (!add) {\n          for (int s = 0; s < value_size; s++) {\n            out[i + s*N_] = sdp_temp[s];\n          }\n        } else {\n          for (int s = 0; s < value_size; s++) {\n            out[i + s*N_] += sdp_temp[s];\n          }\n        }\n\t}\n\n\t_mm_free( sse_val );\n\t_mm_free( values );\n\t_mm_free( new_values );\n\tdelete[] sdp_temp;\n}\n#else\nvoid Permutohedral::sseCompute( float* out, const float* in, int value_size, bool reverse, bool add) const\n{\n\tseqCompute( out, in, value_size, reverse, add);\n}\n#endif\n\n\nvoid Permutohedral::compute(float * out, const float * in, int value_size, bool reverse, bool add) const\n{\n\tif (value_size <= 2)\n\t\tseqCompute(out, in, value_size, reverse, add);\n\telse\n\t\tsseCompute(out, in, value_size, reverse, add);\n}\n"
  },
  {
    "path": "crfasrnn/permutohedral.h",
    "content": "/*\n   This file contains a modified version of the \"permutohedral.h\" code\n   available at http://graphics.stanford.edu/projects/drf/. Copyright notice of\n   the original file is included below:\n\n    Copyright (c) 2013, Philipp Krähenbühl\n    All rights reserved.\n\n    Redistribution and use in source and binary forms, with or without\n    modification, are permitted provided that the following conditions are met:\n        * Redistributions of source code must retain the above copyright\n        notice, this list of conditions and the following disclaimer.\n        * Redistributions in binary form must reproduce the above copyright\n        notice, this list of conditions and the following disclaimer in the\n        documentation and/or other materials provided with the distribution.\n        * Neither the name of the Stanford University nor the\n        names of its contributors may be used to endorse or promote products\n        derived from this software without specific prior written permission.\n\n    THIS SOFTWARE IS PROVIDED BY Philipp Krähenbühl ''AS IS'' AND ANY\n    EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n    DISCLAIMED. IN NO EVENT SHALL Philipp Krähenbühl BE LIABLE FOR ANY\n    DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n    LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n    ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n*/\n#pragma once\n#include <cstdlib>\n#include <vector>\n#include <cstring>\n#include <cassert>\n#include <cstdio>\n#include <cmath>\n\n\n/************************************************/\n/***          Permutohedral Lattice   ***/\n/************************************************/\nclass Permutohedral {\nprotected:\n  struct Neighbors {\n    int n1, n2;\n\n    Neighbors(int n1 = 0, int n2 = 0) : n1(n1), n2(n2) {\n    }\n  };\n\n  std::vector<int> offset_, rank_;\n  std::vector<float> barycentric_;\n  std::vector<Neighbors> blur_neighbors_;\n  // Number of elements, size of sparse discretized space, dimension of features\n  int N_, M_, d_;\n\n  void sseCompute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const;\n\n  void seqCompute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const;\n\npublic:\n  Permutohedral();\n\n  void init(const float *features, int num_dimensions, int num_points);\n\n  void compute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const;\n};\n"
  },
  {
    "path": "crfasrnn/setup.py",
    "content": "from setuptools import setup, Extension\nfrom torch.utils import cpp_extension\n\nsetup(name='permuto_cpp',\n      ext_modules=[cpp_extension.CppExtension('permuto_cpp', ['permuto.cpp', 'permutohedral.cpp'])],\n      cmdclass={'build_ext': cpp_extension.BuildExtension})\n"
  },
  {
    "path": "crfasrnn/util.py",
    "content": "\"\"\"\nMIT License\n\nCopyright (c) 2019 Sadeep Jayasumana\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\nimport numpy as np\nfrom PIL import Image\n\n# Pascal VOC color palette for labels\n_PALETTE = [0, 0, 0,\n            128, 0, 0,\n            0, 128, 0,\n            128, 128, 0,\n            0, 0, 128,\n            128, 0, 128,\n            0, 128, 128,\n            128, 128, 128,\n            64, 0, 0,\n            192, 0, 0,\n            64, 128, 0,\n            192, 128, 0,\n            64, 0, 128,\n            192, 0, 128,\n            64, 128, 128,\n            192, 128, 128,\n            0, 64, 0,\n            128, 64, 0,\n            0, 192, 0,\n            128, 192, 0,\n            0, 64, 128,\n            128, 64, 128,\n            0, 192, 128,\n            128, 192, 128,\n            64, 64, 0,\n            192, 64, 0,\n            64, 192, 0,\n            192, 192, 0]\n\n_IMAGENET_MEANS = np.array([123.68, 116.779, 103.939], dtype=np.float32)  # RGB mean values\n\n\ndef get_preprocessed_image(file_name):\n    \"\"\"\n    Reads an image from the disk, pre-processes it by subtracting mean etc. 
and\n    returns a numpy array that's ready to be fed into the PyTorch model.\n\n    Args:\n        file_name:  File to read the image from\n\n    Returns:\n        A tuple containing:\n\n        (preprocessed image, img_h, img_w, original width & height)\n    \"\"\"\n\n    image = Image.open(file_name)\n    original_size = image.size\n    w, h = original_size\n    ratio = min(500.0 / w, 500.0 / h)\n    image = image.resize((int(w * ratio), int(h * ratio)), resample=Image.BILINEAR)\n    im = np.array(image).astype(np.float32)\n    assert im.ndim == 3, 'Only RGB images are supported.'\n    im = im[:, :, :3]\n    im = im - _IMAGENET_MEANS\n    im = im[:, :, ::-1]  # Convert to BGR\n    img_h, img_w, _ = im.shape\n\n    pad_h = 500 - img_h\n    pad_w = 500 - img_w\n    im = np.pad(im, pad_width=((0, pad_h), (0, pad_w), (0, 0)), mode='constant', constant_values=0)\n    return np.expand_dims(im.transpose([2, 0, 1]), 0), img_h, img_w, original_size\n\n\ndef get_label_image(probs, img_h, img_w, original_size):\n    \"\"\"\n    Returns the label image (PNG with Pascal VOC colormap) given the probabilities.\n\n    Args:\n        probs:  Probability output of shape (num_labels, height, width)\n        img_h:  Image height\n        img_w:  Image width\n        original_size: Original image size (width, height)\n\n    Returns:\n        Label image as a PIL Image\n    \"\"\"\n\n    labels = probs.argmax(axis=0).astype('uint8')[:img_h, :img_w]\n    label_im = Image.fromarray(labels, 'P')\n    label_im.putpalette(_PALETTE)\n    label_im = label_im.resize(original_size)\n    return label_im\n"
  },
  {
    "path": "quick_run.py",
    "content": "\"\"\"\nMIT License\n\nCopyright (c) 2019 Sadeep Jayasumana\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\nimport argparse\n\nimport torch\n\nfrom crfasrnn import util\nfrom crfasrnn.crfasrnn_model import CrfRnnNet\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\n        \"--weights\",\n        help=\"Path to the .pth file (download from https://tinyurl.com/crfasrnn-weights-pth)\",\n        required=True,\n    )\n    parser.add_argument(\"--image\", help=\"Path to the input image\", required=True)\n    parser.add_argument(\"--output\", help=\"Path to the output label image\", default=None)\n    args = parser.parse_args()\n\n    img_data, img_h, img_w, size = util.get_preprocessed_image(args.image)\n\n    output_file = args.output or args.imaage + \"_labels.png\"\n\n    model = CrfRnnNet()\n    model.load_state_dict(torch.load(args.weights))\n    model.eval()\n    out = model.forward(torch.from_numpy(img_data))\n\n    probs = out.detach().numpy()[0]\n    label_im = util.get_label_image(probs, img_h, img_w, size)\n    label_im.save(output_file)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch\ntorchvision\nPillow\n\n"
  },
  {
    "path": "run_demo.py",
    "content": "\"\"\"\nMIT License\n\nCopyright (c) 2019 Sadeep Jayasumana\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\nimport torch\n\nfrom crfasrnn import util\nfrom crfasrnn.crfasrnn_model import CrfRnnNet\n\n\ndef main():\n    input_file = \"image.jpg\"\n    output_file = \"labels.png\"\n\n    # Read the image\n    img_data, img_h, img_w, size = util.get_preprocessed_image(input_file)\n\n    # Download the model from https://tinyurl.com/crfasrnn-weights-pth\n    saved_weights_path = \"crfasrnn_weights.pth\"\n\n    model = CrfRnnNet()\n    model.load_state_dict(torch.load(saved_weights_path))\n    model.eval()\n    out = model.forward(torch.from_numpy(img_data))\n\n    probs = out.detach().numpy()[0]\n    label_im = util.get_label_image(probs, img_h, img_w, size)\n    label_im.save(output_file)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  }
]