Repository: sadeepj/crfasrnn_pytorch
Branch: master
Commit: 24899c528981
Files: 18
Total size: 56.1 KB

Directory structure:
gitextract_ryfdypmp/
├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── crfasrnn/
│   ├── __init__.py
│   ├── crfasrnn_model.py
│   ├── crfrnn.py
│   ├── fcn8s.py
│   ├── filters.py
│   ├── params.py
│   ├── permuto.cpp
│   ├── permutohedral.cpp
│   ├── permutohedral.h
│   ├── setup.py
│   └── util.py
├── quick_run.py
├── requirements.txt
└── run_demo.py


================================================
FILE CONTENTS
================================================

================================================
FILE: .gitattributes
================================================
crfasrnn/permutohedral.cpp linguist-vendored
crfasrnn/permutohedral.h linguist-vendored


================================================
FILE: .gitignore
================================================
.idea
__pycache__
.pyc


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2017 Sadeep Jayasumana

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.


================================================
FILE: README.md
================================================
# CRF-RNN for Semantic Image Segmentation - PyTorch version
![sample](sample.png)

Live demo: [http://crfasrnn.torr.vision](http://crfasrnn.torr.vision)
Caffe version: [http://github.com/torrvision/crfasrnn](http://github.com/torrvision/crfasrnn)
Tensorflow/Keras version: [http://github.com/sadeepj/crfasrnn_keras](http://github.com/sadeepj/crfasrnn_keras)
This repository contains the official PyTorch implementation of the "CRF-RNN" semantic image segmentation method, published in the ICCV 2015 paper [Conditional Random Fields as Recurrent Neural Networks](http://www.robots.ox.ac.uk/~szheng/papers/CRFasRNN.pdf). The [online demo](http://crfasrnn.torr.vision) of this project won the Best Demo Prize at ICCV 2015. Results of this PyTorch code are identical to those of the Caffe and Tensorflow/Keras versions above.

If you use this code/model for your research, please cite the following paper:
```
@inproceedings{crfasrnn_ICCV2015,
    author = {Shuai Zheng and Sadeep Jayasumana and Bernardino Romera-Paredes and Vibhav Vineet and
    Zhizhong Su and Dalong Du and Chang Huang and Philip H. S. Torr},
    title = {Conditional Random Fields as Recurrent Neural Networks},
    booktitle = {International Conference on Computer Vision (ICCV)},
    year = {2015}
}
```

## Installation Guide
_Note_: If you are using a Python virtualenv, make sure it is activated before running each command in this guide.

### Step 1: Clone the repository
```
$ git clone https://github.com/sadeepj/crfasrnn_pytorch.git
```
The root directory of the clone will be referred to as `crfasrnn_pytorch` hereafter.

### Step 2: Install dependencies
Use the `requirements.txt` file in this repository to install all the dependencies via `pip`:
```
$ cd crfasrnn_pytorch
$ pip install -r requirements.txt
```

After installing the dependencies, run the following commands to make sure they are properly installed:
```
$ python
>>> import torch
```
You should not see any errors while importing `torch` above.

### Step 3: Build CRF-RNN custom op
Run `setup.py` inside the `crfasrnn_pytorch/crfasrnn` directory:
```
$ cd crfasrnn_pytorch/crfasrnn
$ python setup.py install
```
Note that the `python` command in the console should refer to the Python interpreter associated with your PyTorch installation.
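To confirm that the custom op was built for the same interpreter that has PyTorch, a quick check along the following lines may help. This is a minimal sketch and not part of the repository: the helper name `check_crfrnn_build` is hypothetical, and it assumes the compiled module is named `permuto_cpp` (the name used in `crfasrnn/setup.py`) and that `torch` should be imported before it.

```python
import importlib
import sys


def check_crfrnn_build():
    """Hypothetical helper: report which interpreter is running and whether
    torch and the compiled permuto_cpp extension can be imported by it."""
    status = {"python": sys.executable}
    # Import torch first: permuto_cpp is built against PyTorch's C++ library.
    for module_name in ("torch", "permuto_cpp"):
        try:
            importlib.import_module(module_name)
            status[module_name] = True
        except (ImportError, OSError):
            status[module_name] = False
    return status


if __name__ == "__main__":
    for key, value in check_crfrnn_build().items():
        print(f"{key}: {value}")
```

If `torch` reports `True` but `permuto_cpp` reports `False`, re-running Step 3 with the interpreter printed under `python` is a reasonable first thing to try.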
### Step 4: Download the pre-trained model weights
Download the model weights from [here](https://github.com/sadeepj/crfasrnn_pytorch/releases/download/0.0.1/crfasrnn_weights.pth) and place the file in the `crfasrnn_pytorch` directory with the file name `crfasrnn_weights.pth`.

### Step 5: Run the demo
```
$ cd crfasrnn_pytorch
$ python run_demo.py
```
If all goes well, the segmentation results will be saved to a file named `labels.png`.

## Contributors
* Sadeep Jayasumana ([sadeepj](https://github.com/sadeepj))
* Harsha Ranasinghe ([HarshaPrabhath](https://github.com/HarshaPrabhath))


================================================
FILE: crfasrnn/__init__.py
================================================


================================================
FILE: crfasrnn/crfasrnn_model.py
================================================
"""
MIT License

Copyright (c) 2019 Sadeep Jayasumana

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from crfasrnn.crfrnn import CrfRnn
from crfasrnn.fcn8s import Fcn8s


class CrfRnnNet(Fcn8s):
    """
    The full CRF-RNN network with the FCN-8s backbone as described in the paper:

    Conditional Random Fields as Recurrent Neural Networks,
    S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du, C. Huang and P. Torr,
    ICCV 2015 (https://arxiv.org/abs/1502.03240).
    """

    def __init__(self):
        super(CrfRnnNet, self).__init__()
        self.crfrnn = CrfRnn(num_labels=21, num_iterations=10)

    def forward(self, image):
        out = super(CrfRnnNet, self).forward(image)
        # Plug the CRF-RNN module at the end
        return self.crfrnn(image, out)


================================================
FILE: crfasrnn/crfrnn.py
================================================
"""
MIT License

Copyright (c) 2019 Sadeep Jayasumana

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import torch
import torch.nn as nn

from crfasrnn.filters import SpatialFilter, BilateralFilter
from crfasrnn.params import DenseCRFParams


class CrfRnn(nn.Module):
    """
    PyTorch implementation of the CRF-RNN module described in the paper:

    Conditional Random Fields as Recurrent Neural Networks,
    S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du, C. Huang and P. Torr,
    ICCV 2015 (https://arxiv.org/abs/1502.03240).
    """

    def __init__(self, num_labels, num_iterations=5, crf_init_params=None):
        """
        Create a new instance of the CRF-RNN layer.

        Args:
            num_labels: Number of semantic labels in the dataset
            num_iterations: Number of mean-field iterations to perform
            crf_init_params: CRF initialization parameters (a DenseCRFParams instance;
                the defaults of DenseCRFParams are used if None)
        """
        super(CrfRnn, self).__init__()

        if crf_init_params is None:
            crf_init_params = DenseCRFParams()

        self.params = crf_init_params
        self.num_iterations = num_iterations

        self._softmax = torch.nn.Softmax(dim=0)

        self.num_labels = num_labels

        # --------------------------------------------------------------------------------------------
        # --------------------------------- Trainable Parameters -------------------------------------
        # --------------------------------------------------------------------------------------------

        # Spatial kernel weights
        self.spatial_ker_weights = nn.Parameter(
            crf_init_params.spatial_ker_weight
            * torch.eye(num_labels, dtype=torch.float32)
        )

        # Bilateral kernel weights
        self.bilateral_ker_weights = nn.Parameter(
            crf_init_params.bilateral_ker_weight
            * torch.eye(num_labels, dtype=torch.float32)
        )

        # Compatibility transform matrix
        self.compatibility_matrix = nn.Parameter(
            torch.eye(num_labels, dtype=torch.float32)
        )

    def forward(self, image, logits):
        """
        Perform CRF inference.
        Args:
            image: Tensor of shape (1, 3, h, w) containing the RGB image (batch size must be 1)
            logits: Tensor of shape (1, num_labels, h, w) containing the unary logits

        Returns:
            Tensor of shape (1, num_labels, h, w) containing the log-Q distributions (logits)
            after CRF inference
        """
        if logits.shape[0] != 1:
            raise ValueError("Only batch size 1 is currently supported!")

        image = image[0]
        logits = logits[0]

        spatial_filter = SpatialFilter(image, gamma=self.params.gamma)
        bilateral_filter = BilateralFilter(
            image, alpha=self.params.alpha, beta=self.params.beta
        )
        _, h, w = image.shape
        cur_logits = logits

        for _ in range(self.num_iterations):
            # Normalization
            q_values = self._softmax(cur_logits)

            # Spatial filtering
            spatial_out = torch.mm(
                self.spatial_ker_weights,
                spatial_filter.apply(q_values).view(self.num_labels, -1),
            )

            # Bilateral filtering
            bilateral_out = torch.mm(
                self.bilateral_ker_weights,
                bilateral_filter.apply(q_values).view(self.num_labels, -1),
            )

            # Compatibility transform
            msg_passing_out = (
                spatial_out + bilateral_out
            )  # Shape: (self.num_labels, -1)
            msg_passing_out = torch.mm(self.compatibility_matrix, msg_passing_out).view(
                self.num_labels, h, w
            )

            # Adding unary potentials
            cur_logits = msg_passing_out + logits

        return torch.unsqueeze(cur_logits, 0)


================================================
FILE: crfasrnn/fcn8s.py
================================================
"""
This file contains a modified version of the FCN-8s code available in https://github.com/wkentaro/pytorch-fcn

The original copyright notice from that repository is included below:

Copyright (c) 2017 - 2019 Kentaro Wada.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import numpy as np
import torch
import torch.nn as nn


def _upsampling_weights(in_channels, out_channels, kernel_size):
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros(
        (in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64
    )
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.from_numpy(weight).float()


class Fcn8s(nn.Module):
    def __init__(self, n_class=21):
        """
        Create the FCN-8s network with the given number of classes.

        Args:
            n_class: The number of semantic classes.
        """

        super(Fcn8s, self).__init__()

        # conv1
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # conv2
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # conv3
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # conv4
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu4_2 = nn.ReLU(inplace=True)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu4_3 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # conv5
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_2 = nn.ReLU(inplace=True)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_3 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # fc6
        self.fc6 = nn.Conv2d(512, 4096, 7)
        self.relu6 = nn.ReLU(inplace=True)
        self.drop6 = nn.Dropout2d()

        # fc7
        self.fc7 = nn.Conv2d(4096, 4096, 1)
        self.relu7 = nn.ReLU(inplace=True)
        self.drop7 = nn.Dropout2d()

        self.score_fr = nn.Conv2d(4096, n_class, 1)
        self.score_pool3 = nn.Conv2d(256, n_class, 1)
        self.score_pool4 = nn.Conv2d(512, n_class, 1)

        self.upscore2 = nn.ConvTranspose2d(n_class, n_class, 4, stride=2, bias=True)
        self.upscore8 = nn.ConvTranspose2d(n_class, n_class, 16, stride=8, bias=False)
        self.upscore_pool4 = nn.ConvTranspose2d(
            n_class, n_class, 4, stride=2, bias=False
        )

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.zero_()
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.ConvTranspose2d):
                assert m.kernel_size[0] == m.kernel_size[1]
                initial_weight = _upsampling_weights(
                    m.in_channels, m.out_channels, m.kernel_size[0]
                )
                m.weight.data.copy_(initial_weight)

    def forward(self, image):
        h = self.relu1_1(self.conv1_1(image))
        h = self.relu1_2(self.conv1_2(h))
        h = self.pool1(h)

        h = self.relu2_1(self.conv2_1(h))
        h = self.relu2_2(self.conv2_2(h))
        h = self.pool2(h)

        h = self.relu3_1(self.conv3_1(h))
        h = self.relu3_2(self.conv3_2(h))
        h = self.relu3_3(self.conv3_3(h))
        h = self.pool3(h)
        pool3 = h  # 1/8

        h = self.relu4_1(self.conv4_1(h))
        h = self.relu4_2(self.conv4_2(h))
        h = self.relu4_3(self.conv4_3(h))
        h = self.pool4(h)
        pool4 = h  # 1/16

        h = self.relu5_1(self.conv5_1(h))
        h = self.relu5_2(self.conv5_2(h))
        h = self.relu5_3(self.conv5_3(h))
        h = self.pool5(h)

        h = self.relu6(self.fc6(h))
        h = self.drop6(h)

        h = self.relu7(self.fc7(h))
        h = self.drop7(h)

        h = self.score_fr(h)
        h = self.upscore2(h)
        upscore2 = h  # 1/16

        h = self.score_pool4(pool4)
        h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]]
        score_pool4c = h  # 1/16

        h = upscore2 + score_pool4c  # 1/16
        h = self.upscore_pool4(h)
        upscore_pool4 = h  # 1/8

        h = self.score_pool3(pool3)
        h = h[:, :, 9:9 + upscore_pool4.size()[2], 9:9 + upscore_pool4.size()[3]]
        score_pool3c = h  # 1/8

        h = upscore_pool4 + score_pool3c  # 1/8

        h = self.upscore8(h)
        h = h[:, :, 31:31 + image.size()[2], 31:31 + image.size()[3]].contiguous()

        return h


================================================
FILE: crfasrnn/filters.py
================================================
"""
MIT License

Copyright (c) 2019 Sadeep Jayasumana

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from abc import ABC, abstractmethod

import numpy as np
import torch

try:
    import permuto_cpp
except ImportError as e:
    # Raising a tuple is a TypeError in Python 3, so wrap the original error instead.
    raise ImportError(
        "Could not import `permuto_cpp`. Did you import `torch` first?"
    ) from e

_CPU = torch.device("cpu")

_EPS = np.finfo("float").eps


class PermutoFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q_in, features):
        q_out = permuto_cpp.forward(q_in, features)[0]
        ctx.save_for_backward(features)
        return q_out

    @staticmethod
    def backward(ctx, grad_q_out):
        feature_saved = ctx.saved_tensors[0]
        grad_q_back = permuto_cpp.backward(
            grad_q_out.contiguous(), feature_saved.contiguous()
        )[0]
        return grad_q_back, None  # No need of grads w.r.t. features


def _spatial_features(image, sigma):
    """
    Return the spatial features as a Tensor

    Args:
        image:  Image as a Tensor of shape (channels, height, width)
        sigma:  Bandwidth parameter

    Returns:
        Tensor of shape [h, w, 2] with spatial features
    """
    sigma = float(sigma)
    _, h, w = image.size()
    x = torch.arange(start=0, end=w, dtype=torch.float32, device=_CPU)
    xx = x.repeat([h, 1]) / sigma

    y = torch.arange(start=0, end=h, dtype=torch.float32, device=_CPU).view(-1, 1)
    yy = y.repeat([1, w]) / sigma

    return torch.stack([xx, yy], dim=2)


class AbstractFilter(ABC):
    """
    Super-class for permutohedral-based Gaussian filters
    """

    def __init__(self, image):
        self.features = self._calc_features(image)
        self.norm = self._calc_norm(image)

    def apply(self, input_):
        output = PermutoFunction.apply(input_, self.features)
        return output * self.norm

    @abstractmethod
    def _calc_features(self, image):
        pass

    def _calc_norm(self, image):
        _, h, w = image.size()
        all_ones = torch.ones((1, h, w), dtype=torch.float32, device=_CPU)
        norm = PermutoFunction.apply(all_ones, self.features)
        return 1.0 / (norm + _EPS)


class SpatialFilter(AbstractFilter):
    """
    Gaussian filter in the spatial ([x, y]) domain
    """

    def __init__(self, image, gamma):
        """
        Create new instance

        Args:
            image:  Image tensor of shape (3, height, width)
            gamma:  Standard deviation
        """
        self.gamma = gamma
        super(SpatialFilter, self).__init__(image)

    def _calc_features(self, image):
        return _spatial_features(image, self.gamma)


class BilateralFilter(AbstractFilter):
    """
    Gaussian filter in the bilateral ([x, y, r, g, b]) domain
    """

    def __init__(self, image, alpha, beta):
        """
        Create new instance

        Args:
            image:  Image tensor of shape (3, height, width)
            alpha:  Smoothness (spatial) sigma
            beta:   Appearance (color) sigma
        """
        self.alpha = alpha
        self.beta = beta
        super(BilateralFilter, self).__init__(image)

    def _calc_features(self, image):
        xy = _spatial_features(
            image, self.alpha
        )  # TODO Possible optimisation, was calculated in the spatial kernel
        rgb = (image / float(self.beta)).permute(1, 2, 0)  # Channel last order
        return torch.cat([xy, rgb], dim=2)


================================================
FILE: crfasrnn/params.py
================================================
"""
MIT License

Copyright (c) 2019 Sadeep Jayasumana

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""


class DenseCRFParams(object):
    """
    Parameters for the DenseCRF model
    """

    def __init__(
        self,
        alpha=160.0,
        beta=3.0,
        gamma=3.0,
        spatial_ker_weight=3.0,
        bilateral_ker_weight=5.0,
    ):
        """
        Default values were taken from https://github.com/sadeepj/crfasrnn_keras.
        More details about these parameters can be found in https://arxiv.org/pdf/1210.5644.pdf

        Args:
            alpha:                  Bandwidth for the spatial component of the bilateral filter
            beta:                   Bandwidth for the color component of the bilateral filter
            gamma:                  Bandwidth for the spatial filter
            spatial_ker_weight:     Spatial kernel weight
            bilateral_ker_weight:   Bilateral kernel weight
        """
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.spatial_ker_weight = spatial_ker_weight
        self.bilateral_ker_weight = bilateral_ker_weight


================================================
FILE: crfasrnn/permuto.cpp
================================================
#include
#include
#include
#include
#include "permutohedral.h"

/**
 *
 * @param input_values Input values to filter (e.g. Q distributions). Has shape (channels, height, width)
 * @param features Features for the permutohedral lattice. Has shape (height, width, feature_channels). Note that
 * channels are at the end!
 * @return Filtered values with shape (channels, height, width)
 */
std::vector permuto_forward(torch::Tensor input_values, torch::Tensor features) {

    auto input_sizes = input_values.sizes();  // (channels, height, width)
    auto feature_sizes = features.sizes();  // (height, width, num_features)

    auto h = feature_sizes[0];
    auto w = feature_sizes[1];
    auto n_feature_dims = static_cast(feature_sizes[2]);
    auto n_pixels = static_cast(h * w);
    auto n_channels = static_cast(input_sizes[0]);

    // Validate the arguments
    if (input_sizes[1] != h || input_sizes[2] != w) {
        throw std::runtime_error("Sizes of `input_values` and `features` do not match!");
    }

    if (!(input_values.dtype() == torch::kFloat32)) {
        throw std::runtime_error("`input_values` must have float32 type.");
    }

    if (!(features.dtype() == torch::kFloat32)) {
        throw std::runtime_error("`features` must have float32 type.");
    }

    // Create the output tensor
    auto options = torch::TensorOptions()
            .dtype(torch::kFloat32)
            .layout(torch::kStrided)
            .device(torch::kCPU)
            .requires_grad(false);

    auto output_values = torch::empty(input_sizes, options);
    output_values = output_values.contiguous();

    Permutohedral p;
    p.init(features.contiguous().data(), n_feature_dims, n_pixels);
    p.compute(output_values.data(), input_values.contiguous().data(), n_channels);

    return {output_values};
}


std::vector permuto_backward(torch::Tensor grads, torch::Tensor features) {

    auto grad_sizes = grads.sizes();  // (channels, height, width)
    auto feature_sizes = features.sizes();  // (height, width, num_features)

    auto h = feature_sizes[0];
    auto w = feature_sizes[1];
    auto n_feature_dims = static_cast(feature_sizes[2]);
    auto n_pixels = static_cast(h * w);
    auto n_channels = static_cast(grad_sizes[0]);

    // Validate the arguments
    if (grad_sizes[1] != h || grad_sizes[2] != w) {
        throw std::runtime_error("Sizes of `grads` and `features` do not match!");
    }

    if (!(grads.dtype() == torch::kFloat32)) {
        throw std::runtime_error("`grads` must have float32 type.");
    }

    if (!(features.dtype() == torch::kFloat32)) {
        throw std::runtime_error("`features` must have float32 type.");
    }

    // Create the output tensor
    auto options = torch::TensorOptions()
            .dtype(torch::kFloat32)
            .layout(torch::kStrided)
            .device(torch::kCPU)
            .requires_grad(false);

    auto grads_back = torch::empty(grad_sizes, options);
    grads_back = grads_back.contiguous();

    Permutohedral p;
    p.init(features.contiguous().data(), n_feature_dims, n_pixels);
    p.compute(grads_back.data(), grads.contiguous().data(), n_channels, true);

    return {grads_back};
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &permuto_forward, "PERMUTO forward");
    m.def("backward", &permuto_backward, "PERMUTO backward");
}


================================================
FILE: crfasrnn/permutohedral.cpp
================================================
/*
    This file contains a modified version of the "permutohedral.cpp" code
    available at http://graphics.stanford.edu/projects/drf/.
    Copyright notice of the original file is included below:

    Copyright (c) 2013, Philipp Krähenbühl
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:
        * Redistributions of source code must retain the above copyright
          notice, this list of conditions and the following disclaimer.
        * Redistributions in binary form must reproduce the above copyright
          notice, this list of conditions and the following disclaimer in the
          documentation and/or other materials provided with the distribution.
        * Neither the name of the Stanford University nor the
          names of its contributors may be used to endorse or promote products
          derived from this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY Philipp Krähenbühl ''AS IS'' AND ANY
    EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    DISCLAIMED. IN NO EVENT SHALL Philipp Krähenbühl BE LIABLE FOR ANY
    DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
    LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
    ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/ //#include "stdafx.h" #include "permutohedral.h" #ifdef __SSE__ // SSE Permutoheral lattice # define SSE_PERMUTOHEDRAL #endif #if defined(SSE_PERMUTOHEDRAL) # include # include # ifdef __SSE4_1__ # include # endif #endif /************************************************/ /*** Hash Table ***/ /************************************************/ class HashTable{ protected: size_t key_size_, filled_, capacity_; std::vector< short > keys_; std::vector< int > table_; void grow(){ // Create the new memory and copy the values in int old_capacity = capacity_; capacity_ *= 2; std::vector old_keys( (old_capacity+10)*key_size_ ); std::copy( keys_.begin(), keys_.end(), old_keys.begin() ); std::vector old_table( capacity_, -1 ); // Swap the memory table_.swap( old_table ); keys_.swap( old_keys ); // Reinsert each element for( int i=0; i= 0){ int e = old_table[i]; size_t h = hash( getKey(e) ) % capacity_; for(; table_[h] >= 0; h = h= capacity_) grow(); // Get the hash value size_t h = hash( k ) % capacity_; // Find the element with he right key, using linear probing while(1){ int e = table_[h]; if (e==-1){ if (create){ // Insert a new key and return the new id for( size_t i=0; i0; j-- ){ __m128 cf = f[j-1]*scale_factor[j-1]; elevated[j] = sm - _mm_set1_ps(j)*cf; sm += cf; } elevated[0] = sm; // Find the closest 0-colored simplex through rounding __m128 sum = Zero; for( int i=0; i<=d_; i++ ){ __m128 v = invdplus1 * elevated[i]; #ifdef __SSE4_1__ v = _mm_round_ps( v, _MM_FROUND_TO_NEAREST_INT ); #else v = _mm_cvtepi32_ps( _mm_cvtps_epi32( v ) ); #endif rem0[i] = v*dplus1; sum += v; } // Find the simplex we are in and store it in rank (where rank describes what position coorinate i has in the sorted order of the features values) for( int i=0; i<=d_; i++ ) rank[i] = Zero; for( int i=0; i0; j-- ){ float cf = f[j-1]*scale_factor[j-1]; elevated[j] = sm - j*cf; sm += cf; } elevated[0] = sm; // Find the closest 0-colored simplex through rounding float down_factor = 1.0f / (d_+1); float 
up_factor = (d_+1); int sum = 0; for( int i=0; i<=d_; i++ ){ //int rd1 = round( down_factor * elevated[i]); int rd2; float v = down_factor * elevated[i]; float up = ceilf(v)*up_factor; float down = floorf(v)*up_factor; if (up - elevated[i] < elevated[i] - down) rd2 = (short)up; else rd2 = (short)down; //if(rd1!=rd2) // break; rem0[i] = rd2; sum += rd2*down_factor; } // Find the simplex we are in and store it in rank (where rank describes what position coorinate i has in the sorted order of the features values) for( int i=0; i<=d_; i++ ) rank[i] = 0; for( int i=0; i d_ ){ rank[i] -= d_+1; rem0[i] -= d_+1; } } // Compute the barycentric coordinates (p.10 in [Adams etal 2010]) for( int i=0; i<=d_+1; i++ ) barycentric[i] = 0; for( int i=0; i<=d_; i++ ){ float v = (elevated[i] - rem0[i])*down_factor; barycentric[d_-rank[i] ] += v; barycentric[d_-rank[i]+1] -= v; } // Wrap around barycentric[0] += 1.0 + barycentric[d_+1]; // Compute all vertices and their offset for( int remainder=0; remainder<=d_; remainder++ ){ for( int i=0; i 0 (used for blurring) float * values = new float[ (M_+2)*value_size ]; float * new_values = new float[ (M_+2)*value_size ]; for( int i=0; i<(M_+2)*value_size; i++ ) values[i] = new_values[i] = 0; // Splatting for( int i=0; i=0; reverse?j--:j++ ){ for( int i=0; i 0 (used for blurring) __m128 * sse_val = (__m128*) _mm_malloc( sse_value_size*sizeof(__m128), 16 ); __m128 * values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 ); __m128 * new_values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 ); __m128 Zero = _mm_set1_ps( 0 ); for( int i=0; i<(M_+2)*sse_value_size; i++ ) values[i] = new_values[i] = Zero; for( int i=0; i=0; reverse?j--:j++ ){ for( int i=0; i #include #include #include #include #include /************************************************/ /*** Permutohedral Lattice ***/ /************************************************/ class Permutohedral { protected: struct Neighbors { int n1, n2; Neighbors(int n1 
= 0, int n2 = 0) : n1(n1), n2(n2) {
        }
    };

    std::vector offset_, rank_;
    std::vector barycentric_;
    std::vector blur_neighbors_;

    // Number of elements, size of sparse discretized space, dimension of features
    int N_, M_, d_;

    void sseCompute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const;

    void seqCompute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const;

public:
    Permutohedral();

    void init(const float *features, int num_dimensions, int num_points);

    void compute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const;
};


================================================
FILE: crfasrnn/setup.py
================================================
from setuptools import setup, Extension
from torch.utils import cpp_extension

setup(name='permuto_cpp',
      ext_modules=[cpp_extension.CppExtension('permuto_cpp', ['permuto.cpp', 'permutohedral.cpp'])],
      cmdclass={'build_ext': cpp_extension.BuildExtension})


================================================
FILE: crfasrnn/util.py
================================================
"""
MIT License

Copyright (c) 2019 Sadeep Jayasumana

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""

import numpy as np
from PIL import Image

# Pascal VOC color palette for labels
_PALETTE = [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0,
            0, 0, 128, 128, 0, 128, 0, 128, 128, 128, 128, 128,
            64, 0, 0, 192, 0, 0, 64, 128, 0, 192, 128, 0,
            64, 0, 128, 192, 0, 128, 64, 128, 128, 192, 128, 128,
            0, 64, 0, 128, 64, 0, 0, 192, 0, 128, 192, 0,
            0, 64, 128, 128, 64, 128, 0, 192, 128, 128, 192, 128,
            64, 64, 0, 192, 64, 0, 64, 192, 0, 192, 192, 0]

_IMAGENET_MEANS = np.array([123.68, 116.779, 103.939], dtype=np.float32)  # RGB mean values


def get_preprocessed_image(file_name):
    """
    Reads an image from disk, resizes it to fit within 500x500 (keeping the aspect ratio),
    subtracts the per-channel RGB means, converts it to BGR and zero-pads it to 500x500,
    returning a numpy array that's ready to be fed into the PyTorch model.

    Args:
        file_name: File to read the image from

    Returns:
        A tuple (preprocessed image of shape (1, 3, 500, 500), img_h, img_w, original_size),
        where img_h and img_w are the resized height and width before padding and
        original_size is the original (width, height).
    """
    image = Image.open(file_name)
    original_size = image.size
    w, h = original_size
    ratio = min(500.0 / w, 500.0 / h)
    image = image.resize((int(w * ratio), int(h * ratio)), resample=Image.BILINEAR)
    im = np.array(image).astype(np.float32)
    assert im.ndim == 3, 'Only RGB images are supported.'
    im = im[:, :, :3]
    im = im - _IMAGENET_MEANS
    im = im[:, :, ::-1]  # Convert to BGR
    img_h, img_w, _ = im.shape

    pad_h = 500 - img_h
    pad_w = 500 - img_w
    im = np.pad(im, pad_width=((0, pad_h), (0, pad_w), (0, 0)), mode='constant', constant_values=0)
    return np.expand_dims(im.transpose([2, 0, 1]), 0), img_h, img_w, original_size


def get_label_image(probs, img_h, img_w, original_size):
    """
    Returns the label image (PNG with Pascal VOC colormap) given the probabilities.

    Args:
        probs: Probability output of shape (num_labels, height, width)
        img_h: Height of the resized image before padding (as returned by get_preprocessed_image)
        img_w: Width of the resized image before padding (as returned by get_preprocessed_image)
        original_size: Original image size (width, height)

    Returns:
        Label image as a PIL Image
    """
    labels = probs.argmax(axis=0).astype('uint8')[:img_h, :img_w]
    label_im = Image.fromarray(labels, 'P')
    label_im.putpalette(_PALETTE)
    label_im = label_im.resize(original_size)
    return label_im

================================================
FILE: quick_run.py
================================================
"""
MIT License

Copyright (c) 2019 Sadeep Jayasumana

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""

import argparse

import torch

from crfasrnn import util
from crfasrnn.crfasrnn_model import CrfRnnNet


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--weights",
        help="Path to the .pth file (download from https://tinyurl.com/crfasrnn-weights-pth)",
        required=True,
    )
    parser.add_argument("--image", help="Path to the input image", required=True)
    parser.add_argument("--output", help="Path to the output label image", default=None)
    args = parser.parse_args()

    img_data, img_h, img_w, size = util.get_preprocessed_image(args.image)

    # Default output path: "<input image path>_labels.png"
    output_file = args.output or args.image + "_labels.png"

    model = CrfRnnNet()
    model.load_state_dict(torch.load(args.weights))
    model.eval()
    out = model.forward(torch.from_numpy(img_data))

    probs = out.detach().numpy()[0]
    label_im = util.get_label_image(probs, img_h, img_w, size)
    label_im.save(output_file)


if __name__ == "__main__":
    main()

================================================
FILE: requirements.txt
================================================
torch
torchvision
Pillow

================================================
FILE: run_demo.py
================================================
"""
MIT License

Copyright (c) 2019 Sadeep Jayasumana

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""

import torch

from crfasrnn import util
from crfasrnn.crfasrnn_model import CrfRnnNet


def main():
    input_file = "image.jpg"
    output_file = "labels.png"

    # Read the image
    img_data, img_h, img_w, size = util.get_preprocessed_image(input_file)

    # Download the model from https://tinyurl.com/crfasrnn-weights-pth
    saved_weights_path = "crfasrnn_weights.pth"

    model = CrfRnnNet()
    model.load_state_dict(torch.load(saved_weights_path))
    model.eval()
    out = model.forward(torch.from_numpy(img_data))

    probs = out.detach().numpy()[0]
    label_im = util.get_label_image(probs, img_h, img_w, size)
    label_im.save(output_file)


if __name__ == "__main__":
    main()
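When tracing tensor shapes through `crfasrnn/util.py`, note what the two helpers do. `get_preprocessed_image` scales the image by `min(500 / w, 500 / h)`, subtracts the RGB means, flips to BGR and zero-pads to a 500x500 canvas. `get_label_image` later crops the argmax map back to `[:img_h, :img_w]`. The following is a minimal NumPy-only sketch of that bookkeeping, so it can be checked without PIL or the model weights. The helper names `fit_and_pad` and `crop_to_labels` are hypothetical. The nearest-neighbour index sampling is a stand-in for PIL's bilinear resize, so pixel values will not match the repository's output exactly.

```python
import numpy as np

CANVAS = 500  # fixed canvas size used by get_preprocessed_image in util.py
RGB_MEANS = np.array([123.68, 116.779, 103.939], dtype=np.float32)


def fit_and_pad(rgb):
    """Hypothetical helper: fit an HxWx3 RGB array inside a 500x500 canvas
    (aspect ratio kept), subtract RGB means, flip to BGR, zero-pad at the
    bottom/right and return a (1, 3, 500, 500) array plus the content size.
    Uses nearest-neighbour sampling instead of PIL's bilinear resize."""
    h, w = rgb.shape[:2]
    ratio = min(CANVAS / w, CANVAS / h)
    new_w, new_h = int(w * ratio), int(h * ratio)
    # Nearest-neighbour source index for every output row and column.
    rows = np.minimum((np.arange(new_h) / ratio).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) / ratio).astype(int), w - 1)
    im = rgb[rows][:, cols, :3].astype(np.float32)
    im = im - RGB_MEANS   # mean subtraction while still in RGB order
    im = im[:, :, ::-1]   # RGB -> BGR, as util.py does
    im = np.pad(im, ((0, CANVAS - new_h), (0, CANVAS - new_w), (0, 0)),
                mode="constant", constant_values=0)
    return np.expand_dims(im.transpose(2, 0, 1), 0), new_h, new_w


def crop_to_labels(probs, img_h, img_w):
    """Hypothetical helper: per-pixel argmax over the label axis of a
    (num_labels, 500, 500) array, cropped back to the un-padded region."""
    return probs.argmax(axis=0).astype(np.uint8)[:img_h, :img_w]


if __name__ == "__main__":
    batch, img_h, img_w = fit_and_pad(np.zeros((600, 1000, 3), dtype=np.uint8))
    print(batch.shape, img_h, img_w)  # (1, 3, 500, 500) 300 500
```

Only the shapes and the crop region are meant to agree with `util.py`. The real pipeline additionally resizes the cropped label map back to `original_size` as a palette ("P" mode) image.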