Repository: sadeepj/crfasrnn_pytorch
Branch: master
Commit: 24899c528981
Files: 18
Total size: 56.1 KB
Directory structure:
gitextract_ryfdypmp/
├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── crfasrnn/
│ ├── __init__.py
│ ├── crfasrnn_model.py
│ ├── crfrnn.py
│ ├── fcn8s.py
│ ├── filters.py
│ ├── params.py
│ ├── permuto.cpp
│ ├── permutohedral.cpp
│ ├── permutohedral.h
│ ├── setup.py
│ └── util.py
├── quick_run.py
├── requirements.txt
└── run_demo.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitattributes
================================================
crfasrnn/permutohedral.cpp linguist-vendored
crfasrnn/permutohedral.h linguist-vendored
================================================
FILE: .gitignore
================================================
.idea
__pycache__
*.pyc
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2017 Sadeep Jayasumana
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# CRF-RNN for Semantic Image Segmentation - PyTorch version

<b>Live demo:</b> [http://crfasrnn.torr.vision](http://crfasrnn.torr.vision) <br/>
<b>Caffe version:</b> [http://github.com/torrvision/crfasrnn](http://github.com/torrvision/crfasrnn)<br/>
<b>Tensorflow/Keras version:</b> [http://github.com/sadeepj/crfasrnn_keras](http://github.com/sadeepj/crfasrnn_keras)<br/>
This repository contains the official PyTorch implementation of the "CRF-RNN" semantic image segmentation method, published in the ICCV 2015 paper [Conditional Random Fields as Recurrent Neural Networks](http://www.robots.ox.ac.uk/~szheng/papers/CRFasRNN.pdf). The [online demo](http://crfasrnn.torr.vision) of this project won the Best Demo Prize at ICCV 2015. Results of this PyTorch code are identical to those of the Caffe and Tensorflow/Keras based versions above.
If you use this code/model for your research, please cite the following paper:
```
@inproceedings{crfasrnn_ICCV2015,
author = {Shuai Zheng and Sadeep Jayasumana and Bernardino Romera-Paredes and Vibhav Vineet and
Zhizhong Su and Dalong Du and Chang Huang and Philip H. S. Torr},
title = {Conditional Random Fields as Recurrent Neural Networks},
booktitle = {International Conference on Computer Vision (ICCV)},
year = {2015}
}
```
## Installation Guide
_Note_: If you are using a Python virtualenv, make sure it is activated before running each command in this guide.
### Step 1: Clone the repository
```
$ git clone https://github.com/sadeepj/crfasrnn_pytorch.git
```
The root directory of the clone will be referred to as `crfasrnn_pytorch` hereafter.
### Step 2: Install dependencies
Use the `requirements.txt` file in this repository to install all the dependencies via `pip`:
```
$ cd crfasrnn_pytorch
$ pip install -r requirements.txt
```
After installing the dependencies, run the following commands to make sure they are properly installed:
```
$ python
>>> import torch
```
You should not see any errors while importing `torch` above.
### Step 3: Build CRF-RNN custom op
Run `setup.py` inside the `crfasrnn_pytorch/crfasrnn` directory:
```
$ cd crfasrnn_pytorch/crfasrnn
$ python setup.py install
```
Note that the `python` command in the console should refer to the Python interpreter associated with your PyTorch installation.
### Step 4: Download the pre-trained model weights
Download the model weights file from [here](https://github.com/sadeepj/crfasrnn_pytorch/releases/download/0.0.1/crfasrnn_weights.pth) and place it in the `crfasrnn_pytorch` directory under the name `crfasrnn_weights.pth`.
### Step 5: Run the demo
```
$ cd crfasrnn_pytorch
$ python run_demo.py
```
If all goes well, you will see the segmentation results in a file named `labels.png`.
## Contributors
* Sadeep Jayasumana ([sadeepj](https://github.com/sadeepj))
* Harsha Ranasinghe ([HarshaPrabhath](https://github.com/HarshaPrabhath))
================================================
FILE: crfasrnn/__init__.py
================================================
================================================
FILE: crfasrnn/crfasrnn_model.py
================================================
"""
MIT License
Copyright (c) 2019 Sadeep Jayasumana
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from crfasrnn.crfrnn import CrfRnn
from crfasrnn.fcn8s import Fcn8s
class CrfRnnNet(Fcn8s):
"""
The full CRF-RNN network with the FCN-8s backbone as described in the paper:
Conditional Random Fields as Recurrent Neural Networks,
S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du, C. Huang and P. Torr,
ICCV 2015 (https://arxiv.org/abs/1502.03240).
"""
def __init__(self):
super(CrfRnnNet, self).__init__()
self.crfrnn = CrfRnn(num_labels=21, num_iterations=10)
def forward(self, image):
out = super(CrfRnnNet, self).forward(image)
# Plug the CRF-RNN module at the end
return self.crfrnn(image, out)
================================================
FILE: crfasrnn/crfrnn.py
================================================
"""
MIT License
Copyright (c) 2019 Sadeep Jayasumana
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import torch
import torch.nn as nn
from crfasrnn.filters import SpatialFilter, BilateralFilter
from crfasrnn.params import DenseCRFParams
class CrfRnn(nn.Module):
"""
PyTorch implementation of the CRF-RNN module described in the paper:
Conditional Random Fields as Recurrent Neural Networks,
S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du, C. Huang and P. Torr,
ICCV 2015 (https://arxiv.org/abs/1502.03240).
"""
def __init__(self, num_labels, num_iterations=5, crf_init_params=None):
"""
Create a new instance of the CRF-RNN layer.
Args:
num_labels: Number of semantic labels in the dataset
num_iterations: Number of mean-field iterations to perform
crf_init_params: CRF initialization parameters
"""
super(CrfRnn, self).__init__()
if crf_init_params is None:
crf_init_params = DenseCRFParams()
self.params = crf_init_params
self.num_iterations = num_iterations
self._softmax = torch.nn.Softmax(dim=0)
self.num_labels = num_labels
# --------------------------------------------------------------------------------------------
# --------------------------------- Trainable Parameters -------------------------------------
# --------------------------------------------------------------------------------------------
# Spatial kernel weights
self.spatial_ker_weights = nn.Parameter(
crf_init_params.spatial_ker_weight
* torch.eye(num_labels, dtype=torch.float32)
)
# Bilateral kernel weights
self.bilateral_ker_weights = nn.Parameter(
crf_init_params.bilateral_ker_weight
* torch.eye(num_labels, dtype=torch.float32)
)
# Compatibility transform matrix
self.compatibility_matrix = nn.Parameter(
torch.eye(num_labels, dtype=torch.float32)
)
def forward(self, image, logits):
"""
Perform CRF inference.
Args:
image: Tensor of shape (1, 3, h, w) containing the RGB image
logits: Tensor of shape (1, num_labels, h, w) containing the unary logits
Returns:
log-Q distributions (logits) after CRF inference
"""
if logits.shape[0] != 1:
raise ValueError("Only batch size 1 is currently supported!")
image = image[0]
logits = logits[0]
spatial_filter = SpatialFilter(image, gamma=self.params.gamma)
bilateral_filter = BilateralFilter(
image, alpha=self.params.alpha, beta=self.params.beta
)
_, h, w = image.shape
cur_logits = logits
for _ in range(self.num_iterations):
# Normalization
q_values = self._softmax(cur_logits)
# Spatial filtering
spatial_out = torch.mm(
self.spatial_ker_weights,
spatial_filter.apply(q_values).view(self.num_labels, -1),
)
# Bilateral filtering
bilateral_out = torch.mm(
self.bilateral_ker_weights,
bilateral_filter.apply(q_values).view(self.num_labels, -1),
)
# Compatibility transform
msg_passing_out = (
spatial_out + bilateral_out
) # Shape: (self.num_labels, -1)
msg_passing_out = torch.mm(self.compatibility_matrix, msg_passing_out).view(
self.num_labels, h, w
)
# Adding unary potentials
cur_logits = msg_passing_out + logits
return torch.unsqueeze(cur_logits, 0)
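
The loop in `forward` above implements the mean-field scheme from the paper: normalize with a softmax, pass messages through the two Gaussian filters, reweight per label, apply the compatibility transform, and add the unary term back. A minimal NumPy sketch of a single iteration follows; the permutohedral filters are replaced by identity filtering purely for illustration, and all names here are illustrative, not part of this repository:

```python
import numpy as np

def softmax(x, axis=0):
    # numerically stable softmax over the label axis
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mean_field_step(logits, unary, spatial_w, bilateral_w, compat):
    """One CRF-RNN mean-field update; filters are identity stand-ins."""
    num_labels, h, w = logits.shape
    q = softmax(logits, axis=0)                   # normalization
    flat = q.reshape(num_labels, -1)
    spatial_out = spatial_w @ flat                # spatial message passing
    bilateral_out = bilateral_w @ flat            # bilateral message passing
    msg = compat @ (spatial_out + bilateral_out)  # compatibility transform
    return msg.reshape(num_labels, h, w) + unary  # add unary potentials

rng = np.random.default_rng(0)
L, H, W = 3, 4, 5
unary = rng.standard_normal((L, H, W)).astype(np.float32)
out = mean_field_step(unary, unary,
                      3.0 * np.eye(L, dtype=np.float32),   # spatial_ker_weight
                      5.0 * np.eye(L, dtype=np.float32),   # bilateral_ker_weight
                      np.eye(L, dtype=np.float32))         # compatibility matrix
print(out.shape)  # (3, 4, 5)
```

The identity-times-weight initializations mirror how `CrfRnn.__init__` seeds its trainable `spatial_ker_weights`, `bilateral_ker_weights`, and `compatibility_matrix` parameters.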
================================================
FILE: crfasrnn/fcn8s.py
================================================
"""
This file contains a modified version of the FCN-8s code available in https://github.com/wkentaro/pytorch-fcn
The original copyright notice from that repository is included below:
Copyright (c) 2017 - 2019 Kentaro Wada.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import numpy as np
import torch
import torch.nn as nn
def _upsampling_weights(in_channels, out_channels, kernel_size):
factor = (kernel_size + 1) // 2
if kernel_size % 2 == 1:
center = factor - 1
else:
center = factor - 0.5
og = np.ogrid[:kernel_size, :kernel_size]
filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
weight = np.zeros(
(in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64
)
weight[range(in_channels), range(out_channels), :, :] = filt
return torch.from_numpy(weight).float()
class Fcn8s(nn.Module):
def __init__(self, n_class=21):
"""
Create the FCN-8s network with the given number of classes.
Args:
n_class: The number of semantic classes.
"""
super(Fcn8s, self).__init__()
# conv1
self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)
self.relu1_1 = nn.ReLU(inplace=True)
self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
self.relu1_2 = nn.ReLU(inplace=True)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# conv2
self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
self.relu2_1 = nn.ReLU(inplace=True)
self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
self.relu2_2 = nn.ReLU(inplace=True)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# conv3
self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
self.relu3_1 = nn.ReLU(inplace=True)
self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
self.relu3_2 = nn.ReLU(inplace=True)
self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
self.relu3_3 = nn.ReLU(inplace=True)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# conv4
self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
self.relu4_1 = nn.ReLU(inplace=True)
self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
self.relu4_2 = nn.ReLU(inplace=True)
self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
self.relu4_3 = nn.ReLU(inplace=True)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# conv5
self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
self.relu5_1 = nn.ReLU(inplace=True)
self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
self.relu5_2 = nn.ReLU(inplace=True)
self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
self.relu5_3 = nn.ReLU(inplace=True)
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
# fc6
self.fc6 = nn.Conv2d(512, 4096, 7)
self.relu6 = nn.ReLU(inplace=True)
self.drop6 = nn.Dropout2d()
# fc7
self.fc7 = nn.Conv2d(4096, 4096, 1)
self.relu7 = nn.ReLU(inplace=True)
self.drop7 = nn.Dropout2d()
self.score_fr = nn.Conv2d(4096, n_class, 1)
self.score_pool3 = nn.Conv2d(256, n_class, 1)
self.score_pool4 = nn.Conv2d(512, n_class, 1)
self.upscore2 = nn.ConvTranspose2d(n_class, n_class, 4, stride=2, bias=True)
self.upscore8 = nn.ConvTranspose2d(n_class, n_class, 16, stride=8, bias=False)
self.upscore_pool4 = nn.ConvTranspose2d(
n_class, n_class, 4, stride=2, bias=False
)
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.zero_()
if m.bias is not None:
m.bias.data.zero_()
if isinstance(m, nn.ConvTranspose2d):
assert m.kernel_size[0] == m.kernel_size[1]
initial_weight = _upsampling_weights(
m.in_channels, m.out_channels, m.kernel_size[0]
)
m.weight.data.copy_(initial_weight)
def forward(self, image):
h = self.relu1_1(self.conv1_1(image))
h = self.relu1_2(self.conv1_2(h))
h = self.pool1(h)
h = self.relu2_1(self.conv2_1(h))
h = self.relu2_2(self.conv2_2(h))
h = self.pool2(h)
h = self.relu3_1(self.conv3_1(h))
h = self.relu3_2(self.conv3_2(h))
h = self.relu3_3(self.conv3_3(h))
h = self.pool3(h)
pool3 = h # 1/8
h = self.relu4_1(self.conv4_1(h))
h = self.relu4_2(self.conv4_2(h))
h = self.relu4_3(self.conv4_3(h))
h = self.pool4(h)
pool4 = h # 1/16
h = self.relu5_1(self.conv5_1(h))
h = self.relu5_2(self.conv5_2(h))
h = self.relu5_3(self.conv5_3(h))
h = self.pool5(h)
h = self.relu6(self.fc6(h))
h = self.drop6(h)
h = self.relu7(self.fc7(h))
h = self.drop7(h)
h = self.score_fr(h)
h = self.upscore2(h)
upscore2 = h # 1/16
h = self.score_pool4(pool4)
h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]]
score_pool4c = h # 1/16
h = upscore2 + score_pool4c # 1/16
h = self.upscore_pool4(h)
upscore_pool4 = h # 1/8
h = self.score_pool3(pool3)
h = h[:, :, 9:9 + upscore_pool4.size()[2], 9:9 + upscore_pool4.size()[3]]
score_pool3c = h # 1/8
h = upscore_pool4 + score_pool3c # 1/8
h = self.upscore8(h)
h = h[:, :, 31:31 + image.size()[2], 31:31 + image.size()[3]].contiguous()
return h
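
`_upsampling_weights` above seeds each `ConvTranspose2d` with a 2D bilinear interpolation kernel, so the network starts out performing plain bilinear upsampling before training adjusts it. A small NumPy sketch of the same kernel construction (the function name here is illustrative):

```python
import numpy as np

def bilinear_kernel(kernel_size):
    """2D bilinear interpolation kernel, as used to seed ConvTranspose2d."""
    factor = (kernel_size + 1) // 2
    # odd kernels have an integer center, even kernels a half-pixel center
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    return ((1 - abs(og[0] - center) / factor)
            * (1 - abs(og[1] - center) / factor))

# kernel matching the 4x4, stride-2 upscore layers in Fcn8s
k = bilinear_kernel(4)
print(k)
```

For `kernel_size=4` the 1D profile is `[0.25, 0.75, 0.75, 0.25]`, and the 2D kernel is its outer product, which is symmetric in both axes.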
================================================
FILE: crfasrnn/filters.py
================================================
"""
MIT License
Copyright (c) 2019 Sadeep Jayasumana
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from abc import ABC, abstractmethod
import numpy as np
import torch
try:
import permuto_cpp
except ImportError as e:
raise ImportError(
"Cannot import `permuto_cpp`. Did you build it (`python setup.py install`) and import `torch` first?"
) from e
_CPU = torch.device("cpu")
_EPS = np.finfo("float").eps
class PermutoFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, q_in, features):
q_out = permuto_cpp.forward(q_in, features)[0]
ctx.save_for_backward(features)
return q_out
@staticmethod
def backward(ctx, grad_q_out):
feature_saved = ctx.saved_tensors[0]
grad_q_back = permuto_cpp.backward(
grad_q_out.contiguous(), feature_saved.contiguous()
)[0]
return grad_q_back, None # No need of grads w.r.t. features
def _spatial_features(image, sigma):
"""
Return the spatial features as a Tensor
Args:
image: Image as a Tensor of shape (channels, height, width)
sigma: Bandwidth parameter
Returns:
Tensor of shape [h, w, 2] with spatial features
"""
sigma = float(sigma)
_, h, w = image.size()
x = torch.arange(start=0, end=w, dtype=torch.float32, device=_CPU)
xx = x.repeat([h, 1]) / sigma
y = torch.arange(start=0, end=h, dtype=torch.float32, device=_CPU).view(-1, 1)
yy = y.repeat([1, w]) / sigma
return torch.stack([xx, yy], dim=2)
class AbstractFilter(ABC):
"""
Super-class for permutohedral-based Gaussian filters
"""
def __init__(self, image):
self.features = self._calc_features(image)
self.norm = self._calc_norm(image)
def apply(self, input_):
output = PermutoFunction.apply(input_, self.features)
return output * self.norm
@abstractmethod
def _calc_features(self, image):
pass
def _calc_norm(self, image):
_, h, w = image.size()
all_ones = torch.ones((1, h, w), dtype=torch.float32, device=_CPU)
norm = PermutoFunction.apply(all_ones, self.features)
return 1.0 / (norm + _EPS)
class SpatialFilter(AbstractFilter):
"""
Gaussian filter in the spatial ([x, y]) domain
"""
def __init__(self, image, gamma):
"""
Create new instance
Args:
image: Image tensor of shape (3, height, width)
gamma: Standard deviation
"""
self.gamma = gamma
super(SpatialFilter, self).__init__(image)
def _calc_features(self, image):
return _spatial_features(image, self.gamma)
class BilateralFilter(AbstractFilter):
"""
Gaussian filter in the bilateral ([r, g, b, x, y]) domain
"""
def __init__(self, image, alpha, beta):
"""
Create new instance
Args:
image: Image tensor of shape (3, height, width)
alpha: Smoothness (spatial) sigma
beta: Appearance (color) sigma
"""
self.alpha = alpha
self.beta = beta
super(BilateralFilter, self).__init__(image)
def _calc_features(self, image):
xy = _spatial_features(
image, self.alpha
) # TODO Possible optimisation, was calculated in the spatial kernel
rgb = (image / float(self.beta)).permute(1, 2, 0) # Channel last order
return torch.cat([xy, rgb], dim=2)
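
The two filter classes above differ only in their feature spaces: `SpatialFilter` uses the pixel grid scaled by `gamma`, while `BilateralFilter` stacks the grid scaled by `alpha` with RGB values scaled by `beta`, giving a 5-channel, channels-last feature tensor. A NumPy sketch of the same feature construction (function names here are illustrative stand-ins for `_spatial_features` and `BilateralFilter._calc_features`):

```python
import numpy as np

def spatial_features(h, w, sigma):
    """(h, w, 2) grid of x/sigma, y/sigma coordinates, channels last."""
    xx = np.tile(np.arange(w, dtype=np.float32), (h, 1)) / sigma
    yy = np.tile(np.arange(h, dtype=np.float32)[:, None], (1, w)) / sigma
    return np.stack([xx, yy], axis=2)

def bilateral_features(image_chw, alpha, beta):
    """(h, w, 5) features: scaled x, y plus scaled r, g, b."""
    _, h, w = image_chw.shape
    xy = spatial_features(h, w, alpha)
    rgb = (image_chw / beta).transpose(1, 2, 0)  # channels-last order
    return np.concatenate([xy, rgb], axis=2)

img = np.arange(3 * 2 * 4, dtype=np.float32).reshape(3, 2, 4)  # (3, h, w)
feats = bilateral_features(img, alpha=160.0, beta=3.0)
print(feats.shape)  # (2, 4, 5)
```

The `alpha=160.0` and `beta=3.0` values match the defaults in `DenseCRFParams`.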
================================================
FILE: crfasrnn/params.py
================================================
"""
MIT License
Copyright (c) 2019 Sadeep Jayasumana
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
class DenseCRFParams(object):
"""
Parameters for the DenseCRF model
"""
def __init__(
self,
alpha=160.0,
beta=3.0,
gamma=3.0,
spatial_ker_weight=3.0,
bilateral_ker_weight=5.0,
):
"""
Default values were taken from https://github.com/sadeepj/crfasrnn_keras. More details about these parameters
can be found in https://arxiv.org/pdf/1210.5644.pdf
Args:
alpha: Bandwidth for the spatial component of the bilateral filter
beta: Bandwidth for the color component of the bilateral filter
gamma: Bandwidth for the spatial filter
spatial_ker_weight: Spatial kernel weight
bilateral_ker_weight: Bilateral kernel weight
"""
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.spatial_ker_weight = spatial_ker_weight
self.bilateral_ker_weight = bilateral_ker_weight
================================================
FILE: crfasrnn/permuto.cpp
================================================
#include <torch/extension.h>
#include <vector>
#include <iostream>
#include <stdexcept>
#include "permutohedral.h"
/**
*
* @param input_values Input values to filter (e.g. Q distributions). Has shape (channels, height, width)
* @param features Features for the permutohedral lattice. Has shape (height, width, feature_channels). Note that
* channels are at the end!
* @return Filtered values with shape (channels, height, width)
*/
std::vector<at::Tensor> permuto_forward(torch::Tensor input_values, torch::Tensor features) {
auto input_sizes = input_values.sizes(); // (channels, height, width)
auto feature_sizes = features.sizes(); // (height, width, num_features)
auto h = feature_sizes[0];
auto w = feature_sizes[1];
auto n_feature_dims = static_cast<int>(feature_sizes[2]);
auto n_pixels = static_cast<int>(h * w);
auto n_channels = static_cast<int>(input_sizes[0]);
// Validate the arguments
if (input_sizes[1] != h || input_sizes[2] != w) {
throw std::runtime_error("Sizes of `input_values` and `features` do not match!");
}
if (!(input_values.dtype() == torch::kFloat32)) {
throw std::runtime_error("`input_values` must have float32 type.");
}
if (!(features.dtype() == torch::kFloat32)) {
throw std::runtime_error("`features` must have float32 type.");
}
// Create the output tensor
auto options = torch::TensorOptions()
.dtype(torch::kFloat32)
.layout(torch::kStrided)
.device(torch::kCPU)
.requires_grad(false);
auto output_values = torch::empty(input_sizes, options);
output_values = output_values.contiguous();
Permutohedral p;
p.init(features.contiguous().data<float>(), n_feature_dims, n_pixels);
p.compute(output_values.data<float>(), input_values.contiguous().data<float>(), n_channels);
return {output_values};
}
std::vector<at::Tensor> permuto_backward(torch::Tensor grads, torch::Tensor features) {
auto grad_sizes = grads.sizes(); // (channels, height, width)
auto feature_sizes = features.sizes(); // (height, width, num_features)
auto h = feature_sizes[0];
auto w = feature_sizes[1];
auto n_feature_dims = static_cast<int>(feature_sizes[2]);
auto n_pixels = static_cast<int>(h * w);
auto n_channels = static_cast<int>(grad_sizes[0]);
// Validate the arguments
if (grad_sizes[1] != h || grad_sizes[2] != w) {
throw std::runtime_error("Sizes of `grads` and `features` do not match!");
}
if (!(grads.dtype() == torch::kFloat32)) {
throw std::runtime_error("`grads` must have float32 type.");
}
if (!(features.dtype() == torch::kFloat32)) {
throw std::runtime_error("`features` must have float32 type.");
}
// Create the output tensor
auto options = torch::TensorOptions()
.dtype(torch::kFloat32)
.layout(torch::kStrided)
.device(torch::kCPU)
.requires_grad(false);
auto grads_back = torch::empty(grad_sizes, options);
grads_back = grads_back.contiguous();
Permutohedral p;
p.init(features.contiguous().data<float>(), n_feature_dims, n_pixels);
p.compute(grads_back.data<float>(), grads.contiguous().data<float>(), n_channels, true);
return {grads_back};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &permuto_forward, "PERMUTO forward");
m.def("backward", &permuto_backward, "PERMUTO backward");
}
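
`permuto_forward` filters each input channel with a Gaussian in the given feature space, using the permutohedral lattice to avoid the quadratic cost of comparing every pixel pair. Conceptually it approximates the brute-force computation sketched below in NumPy (O(n²), illustrative only; the function name is not part of this repository). The tensor layouts match the binding's documentation: values are (channels, h, w) and features are (h, w, feature_channels):

```python
import numpy as np

def gaussian_filter_bruteforce(values, features):
    """O(n^2) reference for what the permutohedral lattice approximates.

    values:   (channels, h, w) signal to filter
    features: (h, w, d) feature vectors, e.g. scaled (x, y) or (x, y, r, g, b)
    """
    c, h, w = values.shape
    v = values.reshape(c, -1)                    # (c, n)
    f = features.reshape(-1, features.shape[2])  # (n, d)
    d2 = ((f[:, None, :] - f[None, :, :]) ** 2).sum(-1)  # pairwise distances
    k = np.exp(-0.5 * d2)                        # Gaussian kernel matrix
    return (v @ k.T).reshape(c, h, w)            # weighted sum per pixel

vals = np.ones((1, 2, 2), dtype=np.float32)
feats = np.zeros((2, 2, 2), dtype=np.float32)  # all pixels share one feature
out = gaussian_filter_bruteforce(vals, feats)
print(out)  # every entry is 4.0: each pixel sums over all four pixels
```

Note the result is unnormalized, just like the raw lattice output: `AbstractFilter._calc_norm` in `filters.py` divides by the same filter applied to an all-ones input, which would turn the 4.0 entries above back into 1.0.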
================================================
FILE: crfasrnn/permutohedral.cpp
================================================
/*
This file contains a modified version of the "permutohedral.cpp" code
available at http://graphics.stanford.edu/projects/drf/. Copyright notice of
the original file is included below:
Copyright (c) 2013, Philipp Krähenbühl
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the Stanford University nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY Philipp Krähenbühl ''AS IS'' AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL Philipp Krähenbühl BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
//#include "stdafx.h"
#include "permutohedral.h"
#ifdef __SSE__
// SSE Permutohedral lattice
# define SSE_PERMUTOHEDRAL
#endif
#if defined(SSE_PERMUTOHEDRAL)
# include <emmintrin.h>
# include <xmmintrin.h>
# ifdef __SSE4_1__
# include <smmintrin.h>
# endif
#endif
/************************************************/
/*** Hash Table ***/
/************************************************/
class HashTable{
protected:
size_t key_size_, filled_, capacity_;
std::vector< short > keys_;
std::vector< int > table_;
void grow(){
// Create the new memory and copy the values in
int old_capacity = capacity_;
capacity_ *= 2;
std::vector<short> old_keys( (old_capacity+10)*key_size_ );
std::copy( keys_.begin(), keys_.end(), old_keys.begin() );
std::vector<int> old_table( capacity_, -1 );
// Swap the memory
table_.swap( old_table );
keys_.swap( old_keys );
// Reinsert each element
for( int i=0; i<old_capacity; i++ )
if (old_table[i] >= 0){
int e = old_table[i];
size_t h = hash( getKey(e) ) % capacity_;
for(; table_[h] >= 0; h = h<capacity_-1 ? h+1 : 0);
table_[h] = e;
}
}
size_t hash( const short * k ) {
size_t r = 0;
for( size_t i=0; i<key_size_; i++ ){
r += k[i];
r *= 1664525;
}
return r;
}
public:
explicit HashTable( int key_size, int n_elements ) : key_size_ ( key_size ), filled_(0), capacity_(2*n_elements), keys_((capacity_/2+10)*key_size_), table_(2*n_elements,-1) {
}
int size() const {
return filled_;
}
void reset() {
filled_ = 0;
std::fill( table_.begin(), table_.end(), -1 );
}
int find( const short * k, bool create = false ){
if (2*filled_ >= capacity_) grow();
// Get the hash value
size_t h = hash( k ) % capacity_;
// Find the element with the right key, using linear probing
while(1){
int e = table_[h];
if (e==-1){
if (create){
// Insert a new key and return the new id
for( size_t i=0; i<key_size_; i++ )
keys_[ filled_*key_size_+i ] = k[i];
return table_[h] = filled_++;
}
else
return -1;
}
// Check if the current key is The One
bool good = true;
for( size_t i=0; i<key_size_ && good; i++ )
if (keys_[ e*key_size_+i ] != k[i])
good = false;
if (good)
return e;
// Continue searching
h++;
if (h==capacity_) h = 0;
}
}
const short * getKey( int i ) const{
return &keys_[i*key_size_];
}
};
/************************************************/
/*** Permutohedral Lattice ***/
/************************************************/
Permutohedral::Permutohedral():N_( 0 ), M_( 0 ), d_( 0 ) {
}
#ifdef SSE_PERMUTOHEDRAL
void Permutohedral::init(const float* features, int num_dimensions, int num_points)
{
// Compute the lattice coordinates for each feature (there is going to be a lot of magic here)
N_ = num_points;
d_ = num_dimensions;
HashTable hash_table( d_, N_/**(d_+1)*/ );
const int blocksize = sizeof(__m128) / sizeof(float);
const __m128 invdplus1 = _mm_set1_ps( 1.0f / (d_+1) );
const __m128 dplus1 = _mm_set1_ps( d_+1 );
const __m128 Zero = _mm_set1_ps( 0 );
const __m128 One = _mm_set1_ps( 1 );
// Allocate the class memory
offset_.resize( (d_+1)*(N_+16) );
std::fill( offset_.begin(), offset_.end(), 0 );
barycentric_.resize( (d_+1)*(N_+16) );
std::fill( barycentric_.begin(), barycentric_.end(), 0 );
rank_.resize( (d_+1)*(N_+16) );
// Allocate the local memory
__m128 * scale_factor = (__m128*) _mm_malloc( (d_ )*sizeof(__m128) , 16 );
__m128 * f = (__m128*) _mm_malloc( (d_ )*sizeof(__m128) , 16 );
__m128 * elevated = (__m128*) _mm_malloc( (d_+1)*sizeof(__m128) , 16 );
__m128 * rem0 = (__m128*) _mm_malloc( (d_+1)*sizeof(__m128) , 16 );
__m128 * rank = (__m128*) _mm_malloc( (d_+1)*sizeof(__m128), 16 );
float * barycentric = new float[(d_+2)*blocksize];
short * canonical = new short[(d_+1)*(d_+1)];
short * key = new short[d_+1];
// Compute the canonical simplex
for( int i=0; i<=d_; i++ ){
for( int j=0; j<=d_-i; j++ )
canonical[i*(d_+1)+j] = i;
for( int j=d_-i+1; j<=d_; j++ )
canonical[i*(d_+1)+j] = i - (d_+1);
}
// Expected standard deviation of our filter (p.6 in [Adams etal 2010])
float inv_std_dev = sqrt(2.0 / 3.0)*(d_+1);
// Compute the diagonal part of E (p.5 in [Adams etal 2010])
for( int i=0; i<d_; i++ )
scale_factor[i] = _mm_set1_ps( 1.0 / sqrt( (i+2)*(i+1) ) * inv_std_dev );
// Setup the SSE rounding
#ifndef __SSE4_1__
const unsigned int old_rounding = _mm_getcsr();
_mm_setcsr( (old_rounding&~_MM_ROUND_MASK) | _MM_ROUND_NEAREST );
#endif
// Compute the simplex each feature lies in
for( int k=0; k<N_; k+=blocksize ){
// Load the feature from memory
float * ff = (float*)f;
for( int j=0; j<d_; j++ )
for( int i=0; i<blocksize; i++ )
ff[ j*blocksize + i ] = k+i < N_ ? *(features + (k + i) * num_dimensions + j) : 0.0;
// Elevate the feature ( y = Ep, see p.5 in [Adams etal 2010])
// sm contains the sum of 1..n of our feature vector
__m128 sm = Zero;
for( int j=d_; j>0; j-- ){
__m128 cf = f[j-1]*scale_factor[j-1];
elevated[j] = sm - _mm_set1_ps(j)*cf;
sm += cf;
}
elevated[0] = sm;
// Find the closest 0-colored simplex through rounding
__m128 sum = Zero;
for( int i=0; i<=d_; i++ ){
__m128 v = invdplus1 * elevated[i];
#ifdef __SSE4_1__
v = _mm_round_ps( v, _MM_FROUND_TO_NEAREST_INT );
#else
v = _mm_cvtepi32_ps( _mm_cvtps_epi32( v ) );
#endif
rem0[i] = v*dplus1;
sum += v;
}
// Find the simplex we are in and store it in rank (where rank describes what position coordinate i has in the sorted order of the feature values)
for( int i=0; i<=d_; i++ )
rank[i] = Zero;
for( int i=0; i<d_; i++ ){
__m128 di = elevated[i] - rem0[i];
for( int j=i+1; j<=d_; j++ ){
__m128 dj = elevated[j] - rem0[j];
__m128 c = _mm_and_ps( One, _mm_cmplt_ps( di, dj ) );
rank[i] += c;
rank[j] += One-c;
}
}
// If the point doesn't lie on the plane (sum != 0) bring it back
for( int i=0; i<=d_; i++ ){
rank[i] += sum;
__m128 add = _mm_and_ps( dplus1, _mm_cmplt_ps( rank[i], Zero ) );
__m128 sub = _mm_and_ps( dplus1, _mm_cmpge_ps( rank[i], dplus1 ) );
rank[i] += add-sub;
rem0[i] += add-sub;
}
// Compute the barycentric coordinates (p.10 in [Adams etal 2010])
for( int i=0; i<(d_+2)*blocksize; i++ )
barycentric[ i ] = 0;
for( int i=0; i<=d_; i++ ){
__m128 v = (elevated[i] - rem0[i])*invdplus1;
// Didn't figure out how to SSE this
float * fv = (float*)&v;
float * frank = (float*)&rank[i];
for( int j=0; j<blocksize; j++ ){
int p = d_-frank[j];
barycentric[j*(d_+2)+p ] += fv[j];
barycentric[j*(d_+2)+p+1] -= fv[j];
}
}
// The rest is not SSE'd
for( int j=0; j<blocksize; j++ ){
// Wrap around
barycentric[j*(d_+2)+0]+= 1 + barycentric[j*(d_+2)+d_+1];
float * frank = (float*)rank;
float * frem0 = (float*)rem0;
// Compute all vertices and their offset
for( int remainder=0; remainder<=d_; remainder++ ){
for( int i=0; i<d_; i++ ){
key[i] = frem0[i*blocksize+j] + canonical[ remainder*(d_+1) + (int)frank[i*blocksize+j] ];
}
offset_[ (j+k)*(d_+1)+remainder ] = hash_table.find( key, true );
rank_[ (j+k)*(d_+1)+remainder ] = frank[remainder*blocksize+j];
barycentric_[ (j+k)*(d_+1)+remainder ] = barycentric[ j*(d_+2)+remainder ];
}
}
}
_mm_free( scale_factor );
_mm_free( f );
_mm_free( elevated );
_mm_free( rem0 );
_mm_free( rank );
delete [] barycentric;
delete [] canonical;
delete [] key;
// Reset the SSE rounding
#ifndef __SSE4_1__
_mm_setcsr( old_rounding );
#endif
// This is normally fast enough so no SSE needed here
// Find the Neighbors of each lattice point
// Get the number of vertices in the lattice
M_ = hash_table.size();
// Create the neighborhood structure
blur_neighbors_.resize( (d_+1)*M_ );
short * n1 = new short[d_+1];
short * n2 = new short[d_+1];
// For each of d+1 axes,
for( int j = 0; j <= d_; j++ ){
for( int i=0; i<M_; i++ ){
const short * key = hash_table.getKey( i );
for( int k=0; k<d_; k++ ){
n1[k] = key[k] - 1;
n2[k] = key[k] + 1;
}
n1[j] = key[j] + d_;
n2[j] = key[j] - d_;
blur_neighbors_[j*M_+i].n1 = hash_table.find( n1 );
blur_neighbors_[j*M_+i].n2 = hash_table.find( n2 );
}
}
delete[] n1;
delete[] n2;
}
#else
void Permutohedral::init (const float* features, int num_dimensions, int num_points)
{
// Compute the lattice coordinates for each feature (there is going to be a lot of magic here)
N_ = num_points;
d_ = num_dimensions;
HashTableCopy hash_table( d_, N_*(d_+1) );
// Allocate the class memory
offset_.resize( (d_+1)*N_ );
rank_.resize( (d_+1)*N_ );
barycentric_.resize( (d_+1)*N_ );
// Allocate the local memory
float * scale_factor = new float[d_];
float * elevated = new float[d_+1];
float * rem0 = new float[d_+1];
float * barycentric = new float[d_+2];
short * rank = new short[d_+1];
short * canonical = new short[(d_+1)*(d_+1)];
short * key = new short[d_+1];
// Compute the canonical simplex
for( int i=0; i<=d_; i++ ){
for( int j=0; j<=d_-i; j++ )
canonical[i*(d_+1)+j] = i;
for( int j=d_-i+1; j<=d_; j++ )
canonical[i*(d_+1)+j] = i - (d_+1);
}
// Expected standard deviation of our filter (p.6 in [Adams etal 2010])
float inv_std_dev = sqrt(2.0 / 3.0)*(d_+1);
// Compute the diagonal part of E (p.5 in [Adams etal 2010])
for( int i=0; i<d_; i++ )
scale_factor[i] = 1.0 / sqrt( double((i+2)*(i+1)) ) * inv_std_dev;
// Compute the simplex each feature lies in
for( int k=0; k<N_; k++ ){
// Elevate the feature ( y = Ep, see p.5 in [Adams etal 2010])
const float * f = (features + k * num_dimensions);
// sm contains the sum of 1..n of our feature vector
float sm = 0;
for( int j=d_; j>0; j-- ){
float cf = f[j-1]*scale_factor[j-1];
elevated[j] = sm - j*cf;
sm += cf;
}
elevated[0] = sm;
// Find the closest 0-colored simplex through rounding
float down_factor = 1.0f / (d_+1);
float up_factor = (d_+1);
int sum = 0;
for( int i=0; i<=d_; i++ ){
//int rd1 = round( down_factor * elevated[i]);
int rd2;
float v = down_factor * elevated[i];
float up = ceilf(v)*up_factor;
float down = floorf(v)*up_factor;
if (up - elevated[i] < elevated[i] - down) rd2 = (short)up;
else rd2 = (short)down;
//if(rd1!=rd2)
// break;
rem0[i] = rd2;
sum += rd2*down_factor;
}
// Find the simplex we are in and store it in rank (where rank describes what position coordinate i has in the sorted order of the feature values)
for( int i=0; i<=d_; i++ )
rank[i] = 0;
for( int i=0; i<d_; i++ ){
double di = elevated[i] - rem0[i];
for( int j=i+1; j<=d_; j++ )
if ( di < elevated[j] - rem0[j])
rank[i]++;
else
rank[j]++;
}
// If the point doesn't lie on the plane (sum != 0) bring it back
for( int i=0; i<=d_; i++ ){
rank[i] += sum;
if ( rank[i] < 0 ){
rank[i] += d_+1;
rem0[i] += d_+1;
}
else if ( rank[i] > d_ ){
rank[i] -= d_+1;
rem0[i] -= d_+1;
}
}
// Compute the barycentric coordinates (p.10 in [Adams etal 2010])
for( int i=0; i<=d_+1; i++ )
barycentric[i] = 0;
for( int i=0; i<=d_; i++ ){
float v = (elevated[i] - rem0[i])*down_factor;
barycentric[d_-rank[i] ] += v;
barycentric[d_-rank[i]+1] -= v;
}
// Wrap around
barycentric[0] += 1.0 + barycentric[d_+1];
// Compute all vertices and their offset
for( int remainder=0; remainder<=d_; remainder++ ){
for( int i=0; i<d_; i++ )
key[i] = rem0[i] + canonical[ remainder*(d_+1) + rank[i] ];
offset_[ k*(d_+1)+remainder ] = hash_table.find( key, true );
rank_[ k*(d_+1)+remainder ] = rank[remainder];
barycentric_[ k*(d_+1)+remainder ] = barycentric[ remainder ];
}
}
delete [] scale_factor;
delete [] elevated;
delete [] rem0;
delete [] barycentric;
delete [] rank;
delete [] canonical;
delete [] key;
// Find the Neighbors of each lattice point
// Get the number of vertices in the lattice
M_ = hash_table.size();
// Create the neighborhood structure
blur_neighbors_.resize( (d_+1)*M_ );
short * n1 = new short[d_+1];
short * n2 = new short[d_+1];
// For each of d+1 axes,
for( int j = 0; j <= d_; j++ ){
for( int i=0; i<M_; i++ ){
const short * key = hash_table.getKey( i );
for( int k=0; k<d_; k++ ){
n1[k] = key[k] - 1;
n2[k] = key[k] + 1;
}
n1[j] = key[j] + d_;
n2[j] = key[j] - d_;
blur_neighbors_[j*M_+i].n1 = hash_table.find( n1 );
blur_neighbors_[j*M_+i].n2 = hash_table.find( n2 );
}
}
delete[] n1;
delete[] n2;
}
#endif
void Permutohedral::seqCompute(float* out, const float* in, int value_size, bool reverse, bool add) const
{
// Shift all values by 1 such that -1 -> 0 (used for blurring)
float * values = new float[ (M_+2)*value_size ];
float * new_values = new float[ (M_+2)*value_size ];
for( int i=0; i<(M_+2)*value_size; i++ )
values[i] = new_values[i] = 0;
// Splatting
for( int i=0; i<N_; i++ ){
for( int j=0; j<=d_; j++ ){
int o = offset_[i*(d_+1)+j]+1;
float w = barycentric_[i*(d_+1)+j];
for( int k=0; k<value_size; k++ )
values[ o*value_size+k ] += w * in[k*N_ + i];
}
}
for( int j=reverse?d_:0; j<=d_ && j>=0; reverse?j--:j++ ){
for( int i=0; i<M_; i++ ){
float * old_val = values + (i+1)*value_size;
float * new_val = new_values + (i+1)*value_size;
int n1 = blur_neighbors_[j*M_+i].n1+1;
int n2 = blur_neighbors_[j*M_+i].n2+1;
float * n1_val = values + n1*value_size;
float * n2_val = values + n2*value_size;
for( int k=0; k<value_size; k++ )
new_val[k] = old_val[k]+0.5*(n1_val[k] + n2_val[k]);
}
std::swap( values, new_values );
}
// Alpha is a magic scaling constant (write Andrew if you really wanna understand this)
float alpha = 1.0f / (1+powf(2, -d_));
// Slicing
for( int i=0; i<N_; i++ ){
if (!add) {
for( int k=0; k<value_size; k++ )
out[i + k*N_] = 0; //out[i*value_size+k] = 0;
}
for( int j=0; j<=d_; j++ ){
int o = offset_[i*(d_+1)+j]+1;
float w = barycentric_[i*(d_+1)+j];
for( int k=0; k<value_size; k++ )
out[ i + k*N_ ] += w * values[ o*value_size+k ] * alpha;
}
}
delete[] values;
delete[] new_values;
}
#ifdef SSE_PERMUTOHEDRAL
void Permutohedral::sseCompute( float* out, const float* in, int value_size, const bool reverse, const bool add) const
{
const int sse_value_size = (value_size-1)*sizeof(float) / sizeof(__m128) + 1;
// Shift all values by 1 such that -1 -> 0 (used for blurring)
__m128 * sse_val = (__m128*) _mm_malloc( sse_value_size*sizeof(__m128), 16 );
__m128 * values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 );
__m128 * new_values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 );
__m128 Zero = _mm_set1_ps( 0 );
for( int i=0; i<(M_+2)*sse_value_size; i++ )
values[i] = new_values[i] = Zero;
for( int i=0; i<sse_value_size; i++ )
sse_val[i] = Zero;
float* sdp_temp = new float[value_size];
// Splatting
for( int i=0; i<N_; i++ ){
for (int s = 0; s < value_size; s++) {
sdp_temp[s] = in[s*N_ + i];
}
memcpy(sse_val, sdp_temp, value_size*sizeof(float));
for( int j=0; j<=d_; j++ ){
int o = offset_[i*(d_+1)+j]+1;
__m128 w = _mm_set1_ps( barycentric_[i*(d_+1)+j] );
for( int k=0; k<sse_value_size; k++ )
values[ o*sse_value_size+k ] += w * sse_val[k];
}
}
// Blurring
__m128 half = _mm_set1_ps(0.5);
for( int j=reverse?d_:0; j<=d_ && j>=0; reverse?j--:j++ ){
for( int i=0; i<M_; i++ ){
__m128 * old_val = values + (i+1)*sse_value_size;
__m128 * new_val = new_values + (i+1)*sse_value_size;
int n1 = blur_neighbors_[j*M_+i].n1+1;
int n2 = blur_neighbors_[j*M_+i].n2+1;
__m128 * n1_val = values + n1*sse_value_size;
__m128 * n2_val = values + n2*sse_value_size;
for( int k=0; k<sse_value_size; k++ )
new_val[k] = old_val[k]+half*(n1_val[k] + n2_val[k]);
}
std::swap( values, new_values );
}
// Alpha is a magic scaling constant (write Andrew if you really wanna understand this)
float alpha = 1.0f / (1+powf(2, -d_));
// Slicing
for( int i=0; i<N_; i++ ){
for( int k=0; k<sse_value_size; k++ )
sse_val[ k ] = Zero;
for( int j=0; j<=d_; j++ ){
int o = offset_[i*(d_+1)+j]+1;
__m128 w = _mm_set1_ps( barycentric_[i*(d_+1)+j] * alpha );
for( int k=0; k<sse_value_size; k++ )
sse_val[ k ] += w * values[ o*sse_value_size+k ];
}
memcpy(sdp_temp, sse_val, value_size*sizeof(float) );
if (!add) {
for (int s = 0; s < value_size; s++) {
out[i + s*N_] = sdp_temp[s];
}
} else {
for (int s = 0; s < value_size; s++) {
out[i + s*N_] += sdp_temp[s];
}
}
}
_mm_free( sse_val );
_mm_free( values );
_mm_free( new_values );
delete[] sdp_temp;
}
#else
void Permutohedral::sseCompute( float* out, const float* in, int value_size, bool reverse, bool add) const
{
seqCompute( out, in, value_size, reverse, add);
}
#endif
void Permutohedral::compute(float * out, const float * in, int value_size, bool reverse, bool add) const
{
if (value_size <= 2)
seqCompute(out, in, value_size, reverse, add);
else
sseCompute(out, in, value_size, reverse, add);
}
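For reference, the quantity `compute()` approximates is a plain Gaussian filter over the feature vectors, which costs O(N^2) when evaluated directly; the splat/blur/slice pipeline above brings this to roughly O(N). A minimal NumPy sketch of that brute-force baseline (a conceptual reference only, not part of this repo; array layouts are chosen to mirror the `value_size x N` input layout of `seqCompute` and the `N x d` layout passed to `init`):

```python
import numpy as np

def gauss_filter_reference(values, features):
    """Brute-force O(N^2) version of the Gaussian filtering that the
    permutohedral lattice approximates:
        out[:, i] = sum_j exp(-||f_i - f_j||^2 / 2) * values[:, j]
    values:   (value_size, N)  -- matches the `in` layout in seqCompute
    features: (N, d)           -- matches the layout passed to init()
    """
    diff = features[None, :, :] - features[:, None, :]  # (N, N, d) pairwise differences
    w = np.exp(-0.5 * np.sum(diff ** 2, axis=-1))       # (N, N) Gaussian weights
    return values @ w                                   # (value_size, N)

# Identical features -> every weight is 1, so each output is the plain sum.
out = gauss_filter_reference(np.ones((1, 3)), np.zeros((3, 2)))
# out == [[3., 3., 3.]]
```

The lattice output differs from this reference by the approximation error of the simplex interpolation (and the `alpha` normalization above), but the two agree closely for well-scaled features.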
================================================
FILE: crfasrnn/permutohedral.h
================================================
/*
This file contains a modified version of the "permutohedral.h" code
available at http://graphics.stanford.edu/projects/drf/. Copyright notice of
the original file is included below:
Copyright (c) 2013, Philipp Krähenbühl
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the Stanford University nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY Philipp Krähenbühl ''AS IS'' AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL Philipp Krähenbühl BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once
#include <cstdlib>
#include <vector>
#include <cstring>
#include <cassert>
#include <cstdio>
#include <cmath>
/************************************************/
/*** Permutohedral Lattice ***/
/************************************************/
class Permutohedral {
protected:
struct Neighbors {
int n1, n2;
Neighbors(int n1 = 0, int n2 = 0) : n1(n1), n2(n2) {
}
};
std::vector<int> offset_, rank_;
std::vector<float> barycentric_;
std::vector<Neighbors> blur_neighbors_;
// Number of elements, size of sparse discretized space, dimension of features
int N_, M_, d_;
void sseCompute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const;
void seqCompute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const;
public:
Permutohedral();
void init(const float *features, int num_dimensions, int num_points);
void compute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const;
};
================================================
FILE: crfasrnn/setup.py
================================================
from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(name='permuto_cpp',
ext_modules=[cpp_extension.CppExtension('permuto_cpp', ['permuto.cpp', 'permutohedral.cpp'])],
cmdclass={'build_ext': cpp_extension.BuildExtension})
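Once the extension is built (conventionally via `python setup.py install` from inside `crfasrnn/`, though the file itself does not document the command), the compiled module should be importable under the name passed to `setup()`. A hedged smoke check:

```python
import importlib.util

# True only if the permuto_cpp extension has been built and installed;
# the filtering code in this package is expected to load it at run time.
have_permuto = importlib.util.find_spec("permuto_cpp") is not None
print(have_permuto)
```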
================================================
FILE: crfasrnn/util.py
================================================
"""
MIT License
Copyright (c) 2019 Sadeep Jayasumana
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import numpy as np
from PIL import Image
# Pascal VOC color palette for labels
_PALETTE = [0, 0, 0,
128, 0, 0,
0, 128, 0,
128, 128, 0,
0, 0, 128,
128, 0, 128,
0, 128, 128,
128, 128, 128,
64, 0, 0,
192, 0, 0,
64, 128, 0,
192, 128, 0,
64, 0, 128,
192, 0, 128,
64, 128, 128,
192, 128, 128,
0, 64, 0,
128, 64, 0,
0, 192, 0,
128, 192, 0,
0, 64, 128,
128, 64, 128,
0, 192, 128,
128, 192, 128,
64, 64, 0,
192, 64, 0,
64, 192, 0,
192, 192, 0]
_IMAGENET_MEANS = np.array([123.68, 116.779, 103.939], dtype=np.float32) # RGB mean values
def get_preprocessed_image(file_name):
"""
Reads an image from the disk, pre-processes it by subtracting mean etc. and
returns a numpy array that's ready to be fed into the PyTorch model.
Args:
file_name: File to read the image from
Returns:
A tuple containing:
(preprocessed image array, resized image height, resized image width, original image size as (width, height))
"""
image = Image.open(file_name)
original_size = image.size
w, h = original_size
ratio = min(500.0 / w, 500.0 / h)
image = image.resize((int(w * ratio), int(h * ratio)), resample=Image.BILINEAR)
im = np.array(image).astype(np.float32)
assert im.ndim == 3, 'Only RGB images are supported.'
im = im[:, :, :3]
im = im - _IMAGENET_MEANS
im = im[:, :, ::-1] # Convert to BGR
img_h, img_w, _ = im.shape
pad_h = 500 - img_h
pad_w = 500 - img_w
im = np.pad(im, pad_width=((0, pad_h), (0, pad_w), (0, 0)), mode='constant', constant_values=0)
return np.expand_dims(im.transpose([2, 0, 1]), 0), img_h, img_w, original_size
def get_label_image(probs, img_h, img_w, original_size):
"""
Returns the label image (PNG with Pascal VOC colormap) given the probabilities.
Args:
probs: Probability output of shape (num_labels, height, width)
img_h: Image height
img_w: Image width
original_size: Original image size (width, height)
Returns:
Label image as a PIL Image
"""
labels = probs.argmax(axis=0).astype('uint8')[:img_h, :img_w]
label_im = Image.fromarray(labels, 'P')
label_im.putpalette(_PALETTE)
label_im = label_im.resize(original_size)
return label_im
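The preprocessing contract above (resize to fit within 500x500, subtract the ImageNet means, convert RGB to BGR, zero-pad to 500x500, transpose to NCHW) can be sanity-checked without an image file; a minimal sketch using a synthetic array (the input array here is hypothetical, standing in for a decoded RGB image):

```python
import numpy as np

# Hypothetical stand-in for a decoded 375x500 RGB image after resizing.
im = np.zeros((375, 500, 3), dtype=np.float32)

means = np.array([123.68, 116.779, 103.939], dtype=np.float32)
im = im - means                  # subtract per-channel ImageNet means
im = im[:, :, ::-1]              # RGB -> BGR, matching the pretrained weights
h, w, _ = im.shape
im = np.pad(im, ((0, 500 - h), (0, 500 - w), (0, 0)), mode="constant")
batch = np.expand_dims(im.transpose(2, 0, 1), 0)
print(batch.shape)               # (1, 3, 500, 500), the model's input shape
```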
================================================
FILE: quick_run.py
================================================
"""
MIT License
Copyright (c) 2019 Sadeep Jayasumana
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import argparse
import torch
from crfasrnn import util
from crfasrnn.crfasrnn_model import CrfRnnNet
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--weights",
help="Path to the .pth file (download from https://tinyurl.com/crfasrnn-weights-pth)",
required=True,
)
parser.add_argument("--image", help="Path to the input image", required=True)
parser.add_argument("--output", help="Path to the output label image", default=None)
args = parser.parse_args()
img_data, img_h, img_w, size = util.get_preprocessed_image(args.image)
output_file = args.output or args.image + "_labels.png"
model = CrfRnnNet()
model.load_state_dict(torch.load(args.weights))
model.eval()
out = model.forward(torch.from_numpy(img_data))
probs = out.detach().numpy()[0]
label_im = util.get_label_image(probs, img_h, img_w, size)
label_im.save(output_file)
if __name__ == "__main__":
main()
================================================
FILE: requirements.txt
================================================
torch
torchvision
Pillow
================================================
FILE: run_demo.py
================================================
"""
MIT License
Copyright (c) 2019 Sadeep Jayasumana
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import torch
from crfasrnn import util
from crfasrnn.crfasrnn_model import CrfRnnNet
def main():
input_file = "image.jpg"
output_file = "labels.png"
# Read the image
img_data, img_h, img_w, size = util.get_preprocessed_image(input_file)
# Download the model from https://tinyurl.com/crfasrnn-weights-pth
saved_weights_path = "crfasrnn_weights.pth"
model = CrfRnnNet()
model.load_state_dict(torch.load(saved_weights_path))
model.eval()
out = model.forward(torch.from_numpy(img_data))
probs = out.detach().numpy()[0]
label_im = util.get_label_image(probs, img_h, img_w, size)
label_im.save(output_file)
if __name__ == "__main__":
main()