Repository: vene/sparse-structured-attention
Branch: master
Commit: 7003b3deaa51
Files: 19
Total size: 33.2 KB
Directory structure:
gitextract_gpxv25pq/
├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
└── pytorch/
├── MANIFEST.in
├── setup.py
└── torchsparseattn/
├── __init__.py
├── _fused.pyx
├── _fused_jv.pyx
├── _isotonic.pyx
├── base.py
├── fused.py
├── isotonic.py
├── oscar.py
├── sparsemax.py
├── test_attention.py
├── test_fused.py
├── test_oscar.py
└── test_sparsemax.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
================================================
FILE: .travis.yml
================================================
sudo: false
language: python
dist: xenial
python:
- "2.7"
- "3.7"
env:
- TORCH_VERSION=1.0.1
- TORCH_VERSION=0.4.1
cache:
apt: true
directories:
- $HOME/.cache/pip
install:
- wget http://repo.continuum.io/miniconda/Miniconda-3.6.0-Linux-x86_64.sh -O miniconda.sh
- bash miniconda.sh -b -p $HOME/miniconda
- export PATH="$HOME/miniconda/bin:$PATH"
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
- hash -r
- conda info -a
- conda create -q -n testenv python=$TRAVIS_PYTHON_VERSION
- source activate testenv
- conda install -c pytorch pytorch-cpu=$TORCH_VERSION numpy scipy pytest cython
# install package
- cd pytorch
- pip install .
script:
- mkdir empty_dir
- cd empty_dir
- pytest -vs --pyargs torchsparseattn
- cd ..
================================================
FILE: LICENSE
================================================
BSD 3-Clause License
Copyright (c) 2017, Vlad Niculae
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: README.md
================================================
# Sparse and structured attention mechanisms
[![Build Status](https://travis-ci.org/vene/sparse-structured-attention.svg?branch=master)](https://travis-ci.org/vene/sparse-structured-attention)
[![PyPI version](https://badge.fury.io/py/torchsparseattn.svg)](https://badge.fury.io/py/torchsparseattn)

--------------------------------------------------------------------------------
Efficient implementation of structured sparsity-inducing
attention mechanisms: fusedmax, oscarmax and sparsemax.
**Note**: If you are just looking for sparsemax, I recommend the implementation in the [entmax](https://github.com/deep-spin/entmax) package.
Currently available for pytorch >= 0.4.1. (For older versions, use a previous
release of this package.) Requires python >= 2.7, cython, numpy, scipy.
Usage example:
```python
In [1]: import torch
In [2]: import torchsparseattn
In [3]: a = torch.tensor([1, 2.1, 1.9], dtype=torch.double)
In [4]: lengths = torch.tensor([3])
In [5]: fusedmax = torchsparseattn.Fusedmax(alpha=.1)
In [6]: fusedmax(a, lengths)
Out[6]: tensor([0.0000, 0.5000, 0.5000], dtype=torch.float64)
```
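The `Sparsemax` and `Oscarmax` modules follow the same call signature. A minimal sketch (the hyperparameter values are illustrative, and the outputs noted in comments are approximate):
```python
import torch
import torchsparseattn

a = torch.tensor([1, 2.1, 1.9], dtype=torch.double)
lengths = torch.tensor([3])

sparsemax = torchsparseattn.Sparsemax()
sparsemax(a, lengths)    # roughly tensor([0.0, 0.6, 0.4]): sparse, but no fusing

oscarmax = torchsparseattn.Oscarmax(beta=0.1)
oscarmax(a, lengths)     # ties weights into equal-valued clusters, regardless of adjacency
```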
For details, check out our paper:
> Vlad Niculae and Mathieu Blondel
> A Regularized Framework for Sparse and Structured Neural Attention
> In: Proceedings of NIPS, 2017.
> https://arxiv.org/abs/1705.07704
See also:
> André F. T. Martins and Ramón Fernandez Astudillo
> From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification
> In: Proceedings of ICML, 2016
> https://arxiv.org/abs/1602.02068
> X. Zeng and M. Figueiredo,
> The ordered weighted L1 norm: Atomic formulation, dual norm, and projections.
> eprint http://arxiv.org/abs/1409.4271
================================================
FILE: pytorch/MANIFEST.in
================================================
include MANIFEST.in
recursive-include torchsparseattn *.c *.h *.cpp *.pyx *.pxd
================================================
FILE: pytorch/setup.py
================================================
import numpy
from setuptools import setup, find_packages, Extension
from Cython.Build import cythonize
extensions = [
Extension('torchsparseattn._isotonic',
["torchsparseattn/_isotonic.pyx"],
include_dirs=[numpy.get_include()]),
Extension('torchsparseattn._fused',
["torchsparseattn/_fused.pyx"],
include_dirs=[numpy.get_include()]),
Extension('torchsparseattn._fused_jv',
["torchsparseattn/_fused_jv.pyx"]),
]
extensions = cythonize(extensions)
setup(name="torchsparseattn",
version="0.3.dev0",
description="Sparse structured attention mechanisms for pytorch",
author="Vlad Niculae",
author_email="vlad@vene.ro",
license="BSD 3-clause",
packages=find_packages(),
ext_modules=extensions,
install_requires=['numpy'],
zip_safe=False,
classifiers=[
'Intended Audience :: Science/Research',
'Intended Audience :: Developers', 'License :: OSI Approved',
'Programming Language :: C', 'Programming Language :: Python',
'Topic :: Software Development',
'Topic :: Scientific/Engineering',
'Operating System :: Microsoft :: Windows',
'Operating System :: POSIX', 'Operating System :: Unix',
'Operating System :: MacOS']
)
================================================
FILE: pytorch/torchsparseattn/__init__.py
================================================
from .fused import Fusedmax, FusedProxFunction
from .oscar import Oscarmax, OscarProxFunction
from .sparsemax import Sparsemax, SparsemaxFunction
__version__ = __VERSION__ = '0.3.dev0'
================================================
FILE: pytorch/torchsparseattn/_fused.pyx
================================================
# encoding: utf-8
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
#
# Authors: Fabian Pedregosa
# Bundled file from lightning library
"""
These are some helper functions to compute the proximal operator of some common penalties
"""
cimport numpy as np
from cython cimport floating
cpdef prox_tv1d(np.ndarray[ndim=1, dtype=floating] w, floating stepsize):
"""
Computes the proximal operator of the 1-dimensional total variation operator.
This solves a problem of the form
argmin_x TV(x) + (1/(2 stepsize)) ||x - w||^2
where TV(x) is the one-dimensional total variation
Parameters
----------
w: array
vector of coefficients
stepsize: float
step size (sometimes denoted gamma) in proximal objective function
References
----------
Condat, Laurent. "A direct algorithm for 1D total variation denoising."
IEEE Signal Processing Letters (2013)
"""
cdef long width, k, k0, kplus, kminus
cdef floating umin, umax, vmin, vmax, twolambda, minlambda
width = w.size
# to avoid invalid memory access to input[0] and invalid lambda values
if width > 0 and stepsize >= 0:
k, k0 = 0, 0 # k: current sample location, k0: beginning of current segment
umin = stepsize # u is the dual variable
umax = - stepsize
vmin = w[0] - stepsize
vmax = w[0] + stepsize # bounds for the segment's value
kplus = 0
kminus = 0 # last positions where umax=-lambda, umin=lambda, respectively
twolambda = 2.0 * stepsize # auxiliary variable
minlambda = -stepsize # auxiliary variable
while True: # simple loop, the exit test is inside
while k >= width-1: # we use the right boundary condition
if umin < 0.0: # vmin is too high -> negative jump necessary
while True:
w[k0] = vmin
k0 += 1
if k0 > kminus:
break
k = k0
kminus = k
vmin = w[kminus]
umin = stepsize
umax = vmin + umin - vmax
elif umax > 0.0: # vmax is too low -> positive jump necessary
while True:
w[k0] = vmax
k0 += 1
if k0 > kplus:
break
k = k0
kplus = k
vmax = w[kplus]
umax = minlambda
umin = vmax + umax - vmin
else:
vmin += umin / (k-k0+1)
while True:
w[k0] = vmin
k0 += 1
if k0 > k:
break
return
umin += w[k + 1] - vmin
if umin < minlambda: # negative jump necessary
while True:
w[k0] = vmin
k0 += 1
if k0 > kminus:
break
k = k0
kminus = k
kplus = kminus
vmin = w[kplus]
vmax = vmin + twolambda
umin = stepsize
umax = minlambda
else:
umax += w[k + 1] - vmax
if umax > stepsize:
while True:
w[k0] = vmax
k0 += 1
if k0 > kplus:
break
k = k0
kminus = k
kplus = kminus
vmax = w[kplus]
vmin = vmax - twolambda
umin = stepsize
umax = minlambda
else: # no jump necessary, we continue
k += 1
if umin >= stepsize: # update of vmin
kminus = k
vmin += (umin - stepsize) / (kminus - k0 + 1)
umin = stepsize
if umax <= minlambda: # update of vmax
kplus = k
vmax += (umax + stepsize) / (kplus - k0 + 1)
umax = minlambda
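# Illustrative usage sketch (values assumed). prox_tv1d modifies its input
# in place, so pass a copy of a contiguous float32/float64 array:
#
#     import numpy as np
#     w = np.array([1.0, 2.1, 1.9])
#     prox_tv1d(w, 0.1)   # solves argmin_x TV(x) + (1 / (2 * 0.1)) ||x - w||^2
#     # neighbouring values fuse: w is now roughly [1.1, 1.95, 1.95]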
================================================
FILE: pytorch/torchsparseattn/_fused_jv.pyx
================================================
cimport cython
from cython cimport floating
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def _inplace_fused_prox_jv(floating[::1] y_hat, floating[::1] dout):
cdef Py_ssize_t n_features = dout.shape[0]
cdef Py_ssize_t i, last_ix
cdef unsigned int n
cdef floating acc
for i in range(n_features + 1):
if i in (0, n_features) or y_hat[i] != y_hat[i - 1]:
if i > 0:
dout[last_ix:i] = acc / n
if i < n_features:
last_ix = i
acc = dout[i]
n = 1
else:
acc += dout[i]
n += 1
return dout
================================================
FILE: pytorch/torchsparseattn/_isotonic.pyx
================================================
# Author: Nelle Varoquaux, Andrew Tulloch, Antony Lee
# Uses the pool adjacent violators algorithm (PAVA), with the
# enhancement of searching for the longest decreasing subsequence to
# pool at each step.
import numpy as np
cimport numpy as np
cimport cython
from cython cimport floating
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def _inplace_contiguous_isotonic_regression(floating[::1] y, floating[::1] w):
cdef:
Py_ssize_t n = y.shape[0], i, k
floating prev_y, sum_wy, sum_w
Py_ssize_t[::1] target = np.arange(n, dtype=np.intp)
# target describes a list of blocks. At any time, if [i..j] (inclusive) is
# an active block, then target[i] := j and target[j] := i.
# For "active" indices (block starts):
# w[i] := sum{w_orig[j], j=[i..target[i]]}
# y[i] := sum{y_orig[j]*w_orig[j], j=[i..target[i]]} / w[i]
with nogil:
i = 0
while i < n:
k = target[i] + 1
if k == n:
break
if y[i] < y[k]:
i = k
continue
sum_wy = w[i] * y[i]
sum_w = w[i]
while True:
# We are within a decreasing subsequence.
prev_y = y[k]
sum_wy += w[k] * y[k]
sum_w += w[k]
k = target[k] + 1
if k == n or prev_y < y[k]:
# Non-singleton decreasing subsequence is finished,
# update first entry.
y[i] = sum_wy / sum_w
w[i] = sum_w
target[i] = k - 1
target[k - 1] = i
if i > 0:
# Backtrack if we can. This makes the algorithm
# single-pass and ensures O(n) complexity.
i = target[i - 1]
# Otherwise, restart from the same point.
break
# Reconstruct the solution.
i = 0
while i < n:
k = target[i] + 1
y[i + 1 : k] = y[i]
i = k
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def _make_unique(np.ndarray[dtype=floating] X,
np.ndarray[dtype=floating] y,
np.ndarray[dtype=floating] sample_weights):
"""Average targets for duplicate X, drop duplicates.
Aggregates duplicate X values into a single X value where
the target y is a (sample_weighted) average of the individual
targets.
Assumes that X is ordered, so that all duplicates follow each other.
"""
unique_values = len(np.unique(X))
if unique_values == len(X):
return X, y, sample_weights
cdef np.ndarray[dtype=floating] y_out = np.empty(unique_values)
cdef np.ndarray[dtype=floating] x_out = np.empty(unique_values)
cdef np.ndarray[dtype=floating] weights_out = np.empty(unique_values)
cdef floating current_x = X[0]
cdef floating current_y = 0
cdef floating current_weight = 0
cdef floating y_old = 0
cdef int i = 0
cdef int current_count = 0
cdef int j
cdef floating x
cdef int n_samples = len(X)
for j in range(n_samples):
x = X[j]
if x != current_x:
# next unique value
x_out[i] = current_x
weights_out[i] = current_weight / current_count
y_out[i] = current_y / current_weight
i += 1
current_x = x
current_weight = sample_weights[j]
current_y = y[j] * sample_weights[j]
current_count = 1
else:
current_weight += sample_weights[j]
current_y += y[j] * sample_weights[j]
current_count += 1
x_out[i] = current_x
weights_out[i] = current_weight / current_count
y_out[i] = current_y / current_weight
return x_out, y_out, weights_out
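# Illustrative sketch (values assumed): duplicate X values are collapsed and
# their targets averaged with the given sample weights.
#
#     import numpy as np
#     X = np.array([1.0, 1.0, 2.0])
#     y = np.array([0.0, 2.0, 5.0])
#     w = np.ones_like(X)
#     _make_unique(X, y, w)
#     # -> (array([1., 2.]), array([1., 5.]), array([1., 1.]))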
================================================
FILE: pytorch/torchsparseattn/base.py
================================================
from torch import nn
from torch import autograd as ta
class _BaseBatchProjection(ta.Function):
"""Applies a sample-wise normalizing projection over a batch."""
def forward(self, x, lengths=None):
requires_squeeze = False
if x.dim() == 1:
x = x.unsqueeze(0)
requires_squeeze = True
n_samples, max_dim = x.size()
has_lengths = True
if lengths is None:
has_lengths = False
lengths = [max_dim] * n_samples
y_star = x.new()
y_star.resize_as_(x)
y_star.zero_()
for i in range(n_samples):
y_star[i, :lengths[i]] = self.project(x[i, :lengths[i]])
if requires_squeeze:
y_star = y_star.squeeze()
self.mark_non_differentiable(y_star)
if has_lengths:
self.mark_non_differentiable(lengths)
self.save_for_backward(y_star, lengths)
else:
self.save_for_backward(y_star)
return y_star
def backward(self, dout):
if not self.needs_input_grad[0]:
return None
if len(self.needs_input_grad) > 1 and self.needs_input_grad[1]:
raise ValueError("Cannot differentiate {} w.r.t. the "
"sequence lengths".format(self.__name__))
saved = self.saved_tensors
if len(saved) == 2:
y_star, lengths = saved
else:
y_star, = saved
lengths = None
requires_squeeze = False
if y_star.dim() == 1:
y_star = y_star.unsqueeze(0)
dout = dout.unsqueeze(0)
requires_squeeze = True
n_samples, max_dim = y_star.size()
din = dout.new()
din.resize_as_(y_star)
din.zero_()
if lengths is None:
lengths = [max_dim] * n_samples
for i in range(n_samples):
din[i, :lengths[i]] = self.project_jv(dout[i, :lengths[i]],
y_star[i, :lengths[i]])
if requires_squeeze:
din = din.squeeze()
return din, None
================================================
FILE: pytorch/torchsparseattn/fused.py
================================================
"""Fusedmax attention
Clusters neighboring attention weights into groups with equal weight.
A Regularized Framework for Sparse and Structured Neural Attention
Vlad Niculae, Mathieu Blondel
https://arxiv.org/abs/1705.07704
"""
from __future__ import division
import torch
from torch import nn
from torch import autograd as ta
import warnings
from .base import _BaseBatchProjection
from .sparsemax import SparsemaxFunction
from ._fused import prox_tv1d
def _inplace_fused_prox_jv_slow(y_hat, dout):
"""not efficient in python for long seqs, but template for a cython impl"""
n_features = len(dout)
for i in range(n_features + 1):
if i in (0, n_features) or y_hat[i] != y_hat[i - 1]:
if i > 0:
dout[last_ix:i] = acc / n
if i < n_features:
last_ix = i
acc = dout[i]
n = 1
else:
acc += dout[i]
n += 1
return dout
try:
from ._fused_jv import _inplace_fused_prox_jv
except ImportError:
warnings.warn("Could not import cython implementation of fused backward "
"pass. Slow implementation used instead.")
_inplace_fused_prox_jv = _inplace_fused_prox_jv_slow
def fused_prox_jv_slow(y_hat, dout):
dout = dout.clone()
_inplace_fused_prox_jv_slow(y_hat, dout)
return dout
def fused_prox_jv_fast(y_hat, dout):
dout = dout.clone()
_inplace_fused_prox_jv(y_hat.detach().numpy(), dout.numpy())
return dout
class FusedProxFunction(_BaseBatchProjection):
def __init__(self, alpha=1):
self.alpha = alpha
def project(self, x):
x_np = x.detach().numpy().copy()
prox_tv1d(x_np, self.alpha)
y_hat = torch.from_numpy(x_np)
return y_hat
def project_jv(self, dout, y_hat):
dout = dout.clone()
_inplace_fused_prox_jv(y_hat.detach().numpy(), dout.numpy())
return dout
class Fusedmax(nn.Module):
def __init__(self, alpha=1):
self.alpha = alpha
super(Fusedmax, self).__init__()
def forward(self, x, lengths=None):
fused_prox = FusedProxFunction(self.alpha)
sparsemax = SparsemaxFunction()
return sparsemax(fused_prox(x, lengths), lengths)
if __name__ == '__main__':
from timeit import timeit
torch.manual_seed(1)
for dim in (5, 10, 50, 100, 500, 1000):
x = torch.randn(dim)
x_var = ta.Variable(x, requires_grad=True)
y_hat = FusedProxFunction()(x_var).data
dout = torch.arange(0, dim, dtype=x.dtype)
print("dimension={}".format(dim))
print("slow", timeit("fused_prox_jv_slow(y_hat, dout)",
globals=globals(),
number=10000))
print("fast", timeit("fused_prox_jv_fast(y_hat, dout)",
globals=globals(),
number=10000))
================================================
FILE: pytorch/torchsparseattn/isotonic.py
================================================
"""
Isotonic Regression that preserves 32bit inputs.
backported from scikit-learn pull request
https://github.com/scikit-learn/scikit-learn/pull/9106"""
import numpy as np
from ._isotonic import _inplace_contiguous_isotonic_regression
def isotonic_regression(y, sample_weight=None, y_min=None, y_max=None,
increasing=True):
"""Solve the isotonic regression model::
min sum w[i] (y[i] - y_[i]) ** 2
subject to y_min = y_[1] <= y_[2] ... <= y_[n] = y_max
where:
- y[i] are inputs (real numbers)
- y_[i] are fitted
- w[i] are optional strictly positive weights (default to 1.0)
Read more in the scikit-learn User Guide.
Parameters
----------
y : iterable of floating-point values
The data.
sample_weight : iterable of floating-point values, optional, default: None
Weights on each point of the regression.
If None, weight is set to 1 (equal weights).
y_min : optional, default: None
If not None, set the lowest value of the fit to y_min.
y_max : optional, default: None
If not None, set the highest value of the fit to y_max.
increasing : boolean, optional, default: True
Whether the fitted ``y_`` should be increasing (if set to True) or
decreasing (if set to False).
Returns
-------
y_ : list of floating-point values
Isotonic fit of y.
References
----------
"Active set algorithms for isotonic regression; A unifying framework"
by Michael J. Best and Nilotpal Chakravarti, section 3.
"""
order = np.s_[:] if increasing else np.s_[::-1]
# y = as_float_array(y) # avoid sklearn dependency; we always pass arrays
y = np.array(y[order], dtype=y.dtype)
if sample_weight is None:
sample_weight = np.ones(len(y), dtype=y.dtype)
else:
sample_weight = np.array(sample_weight[order], dtype=y.dtype)
_inplace_contiguous_isotonic_regression(y, sample_weight)
if y_min is not None or y_max is not None:
# Older versions of np.clip don't accept None as a bound, so use np.inf
if y_min is None:
y_min = -np.inf
if y_max is None:
y_max = np.inf
np.clip(y, y_min, y_max, y)
return y[order]
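# Illustrative sketch (values assumed): the input must already be a float
# ndarray, since the output dtype is taken from ``y``.
#
#     import numpy as np
#     y = np.array([3.0, 1.0, 2.0])
#     isotonic_regression(y)                    # -> array([2., 2., 2.])
#     isotonic_regression(y, increasing=False)  # -> array([3. , 1.5, 1.5])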
================================================
FILE: pytorch/torchsparseattn/oscar.py
================================================
"""Oscarmax attention
Clusters attention weights into groups with equal weight, regardless of index.
A Regularized Framework for Sparse and Structured Neural Attention
Vlad Niculae, Mathieu Blondel
https://arxiv.org/abs/1705.07704
"""
import numpy as np
import torch
from torch import nn
from torch import autograd as ta
from .isotonic import isotonic_regression
from .base import _BaseBatchProjection
from .sparsemax import SparsemaxFunction
def oscar_prox_jv(y_hat, dout):
y_hat = y_hat.detach().numpy()
din = dout.clone().zero_()
dout = dout.numpy()
din_np = din.numpy()
sign = np.sign(y_hat)
y_hat = np.abs(y_hat)
uniq, inv, counts = np.unique(y_hat, return_inverse=True,
return_counts=True)
n_unique = len(uniq)
tmp = np.zeros((n_unique,), dtype=y_hat.dtype)
np.add.at(tmp, inv, dout * sign)
tmp /= counts
tmp.take(inv, mode='clip', out=din_np)
din_np *= sign
return din
def prox_owl(v, w):
"""Proximal operator of the OWL norm dot(w, reversed(sort(v)))
Follows description and notation from:
X. Zeng, M. Figueiredo,
The ordered weighted L1 norm: Atomic formulation, dual norm,
and projections.
eprint http://arxiv.org/abs/1409.4271
"""
# wlog operate on absolute values
v_abs = np.abs(v)
ix = np.argsort(v_abs)[::-1]
v_abs = v_abs[ix]
# project to K+ (monotone non-negative decreasing cone)
v_abs = isotonic_regression(v_abs - w, y_min=0, increasing=False)
# undo the sorting
inv_ix = np.zeros_like(ix)
inv_ix[ix] = np.arange(len(v))
v_abs = v_abs[inv_ix]
return np.sign(v) * v_abs
def _oscar_weights(alpha, beta, size):
w = np.arange(size - 1, -1, -1, dtype=np.float32)
w *= beta
w += alpha
return w
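# Illustrative sketch (values assumed): the OSCAR prox is the OWL prox with
# linearly decreasing weights alpha + beta * k over the sorted magnitudes,
# which ties large entries together when beta is big enough.
#
#     import numpy as np
#     v = np.array([1.0, 2.1, 1.9], dtype=np.float32)
#     w = _oscar_weights(alpha=0.0, beta=0.3, size=3)   # -> [0.6, 0.3, 0.0]
#     prox_owl(v, w)   # -> roughly [1.0, 1.55, 1.55]: the two largest entries fuse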
class OscarProxFunction(_BaseBatchProjection):
    """Proximal operator of the OSCAR regularizer.

    ||w||_oscar = alpha * ||w||_1 + beta * sum_{i<j} max(|w_i|, |w_j|)
    """

    def __init__(self, alpha=0, beta=1):
        self.alpha = alpha
        self.beta = beta

    def project(self, x):
        x_np = x.detach().numpy().copy()
        weights = _oscar_weights(self.alpha, self.beta, x_np.shape[0])
        y_hat_np = prox_owl(x_np, weights)
        y_hat = torch.from_numpy(y_hat_np)
        return y_hat

    def project_jv(self, dout, y_hat):
        return oscar_prox_jv(y_hat, dout)


class Oscarmax(nn.Module):
    def __init__(self, beta=1):
        self.beta = beta
        super(Oscarmax, self).__init__()

    def forward(self, x, lengths=None):
        oscar_prox = OscarProxFunction(beta=self.beta)
        sparsemax = SparsemaxFunction()
        return sparsemax(oscar_prox(x, lengths), lengths)
================================================
FILE: pytorch/torchsparseattn/sparsemax.py
================================================
"""Sparsemax attention

From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification
André F. T. Martins, Ramón Fernandez Astudillo
https://arxiv.org/abs/1602.02068
"""
import torch
from torch import nn
from .base import _BaseBatchProjection


def project_simplex(v, z=1):
    """Euclidean projection of v onto the simplex {w : w >= 0, sum(w) = z}."""
    v_sorted, _ = torch.sort(v, dim=0, descending=True)
    cssv = torch.cumsum(v_sorted, dim=0) - z
    ind = torch.arange(1, 1 + len(v)).to(dtype=v.dtype, device=v.device)
    cond = v_sorted - cssv / ind > 0
    rho = ind.masked_select(cond)[-1]
    tau = cssv.masked_select(cond)[-1] / rho
    w = torch.clamp(v - tau, min=0)
    return w

def sparsemax_grad(dout, w_star):
supp = w_star > 0
masked = dout.masked_select(supp)
nnz = supp.to(dtype=dout.dtype).sum()
masked -= masked.sum() / nnz
out = dout.new(dout.size()).zero_()
out[supp] = masked
return out
class SparsemaxFunction(_BaseBatchProjection):
def project(self, x):
return project_simplex(x)
def project_jv(self, dout, y_star):
return sparsemax_grad(dout, y_star)
class Sparsemax(nn.Module):
def forward(self, x, lengths=None):
sparsemax = SparsemaxFunction()
return sparsemax(x, lengths)
================================================
FILE: pytorch/torchsparseattn/test_attention.py
================================================
import pytest
import torch
from torch import nn
from torch.autograd import Variable
from . import Sparsemax, Fusedmax, Oscarmax
class AttentionRegressor(nn.Module):
def __init__(self, projection, n_features=100):
super(AttentionRegressor, self).__init__()
self.projection = projection
self.attn_template = nn.Parameter(torch.Tensor(n_features))
self.attn_template.data.uniform_(-0.1, 0.1)
def forward(self, X, lengths):
# compute scores for each input word
scores = torch.matmul(X, self.attn_template)
weights = self.projection(scores, lengths)
weighted_avg = torch.bmm(X.transpose(1, 2),
weights.unsqueeze(-1)).squeeze(-1)
pred = weighted_avg.sum(dim=1) # very simple prediction rule
return pred
@pytest.mark.parametrize('projection', [Sparsemax(),
Fusedmax(0.1),
Oscarmax(0.01)])
def test_attention(projection):
n_samples = 20
max_len = 10
torch.manual_seed(1)
n_features = 50
X = torch.zeros(n_samples, max_len, n_features)
# generate lengths in [1, max_len]
lengths = 1 + (torch.rand(n_samples) * max_len).long()
for i in range(n_samples):
X[i, :lengths[i], :] = torch.randn(lengths[i], n_features)
X = Variable(X)
lengths = Variable(lengths)
targets = Variable(torch.randn(n_samples))
regr = AttentionRegressor(projection, n_features=n_features)
loss_func = nn.MSELoss()
optim = torch.optim.SGD(regr.parameters(), lr=0.0001)
pred = regr(X, lengths)
init_obj = loss_func(pred, targets)
for it in range(50):
optim.zero_grad()
pred = regr(X, lengths)
obj = loss_func(pred, targets)
obj.backward()
optim.step()
final_obj = obj
assert final_obj < init_obj
assert regr.attn_template.grad.size() == (n_features,)
================================================
FILE: pytorch/torchsparseattn/test_fused.py
================================================
from __future__ import division
import pytest
from numpy.testing import assert_allclose
import torch
from torch.autograd import gradcheck, Variable
from .fused import fused_prox_jv_slow, fused_prox_jv_fast
from .fused import FusedProxFunction
def _fused_prox_jacobian(y_hat, dout=None):
"""reference naive implementation: construct the jacobian"""
dim = y_hat.shape[0]
groups = torch.zeros(dim)
J = torch.zeros(dim, dim)
current_group = 0
for i in range(1, dim):
if y_hat[i] == y_hat[i - 1]:
groups[i] = groups[i - 1]
else:
current_group += 1
groups[i] = current_group
for i in range(dim):
for j in range(dim):
if groups[i] == groups[j]:
n_fused = (groups == groups[i]).sum()
J[i, j] = 1 / n_fused.to(y_hat.dtype)
if dout is not None:
return torch.mv(J, dout)
else:
return J
@pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1])
def test_jv(alpha):
torch.manual_seed(1)
torch.set_default_tensor_type('torch.DoubleTensor')
for _ in range(30):
x = Variable(torch.randn(15))
dout = torch.randn(15)
y_hat = FusedProxFunction(alpha=alpha)(x).data
ref = _fused_prox_jacobian(y_hat, dout)
din_slow = fused_prox_jv_slow(y_hat, dout)
din_fast = fused_prox_jv_fast(y_hat, dout)
assert_allclose(ref.numpy(), din_slow.numpy(), atol=1e-5)
assert_allclose(ref.numpy(), din_fast.numpy(), atol=1e-5)
@pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1])
def test_finite_diff(alpha):
torch.manual_seed(1)
torch.set_default_tensor_type('torch.DoubleTensor')
for _ in range(30):
x = Variable(torch.randn(20), requires_grad=True)
func = FusedProxFunction(alpha=alpha)
assert gradcheck(func, (x,), eps=1e-4, atol=1e-3)
================================================
FILE: pytorch/torchsparseattn/test_oscar.py
================================================
from __future__ import division
import pytest
from numpy.testing import assert_allclose
import numpy as np
import torch
from torch.autograd import gradcheck, Variable
from .oscar import OscarProxFunction, oscar_prox_jv
def _oscar_prox_jacobian(y_star, dout=None):
y_star = y_star.numpy()
dim = y_star.shape[0]
J = torch.zeros(dim, dim)
_, inv, counts = np.unique(np.abs(y_star),
return_inverse=True,
return_counts=True)
for i in range(dim):
for j in range(dim):
if (inv[i] == inv[j] and
y_star[i] != 0):
J[i, j] = (np.sign(y_star[i]) * np.sign(y_star[j])
/ counts[inv[i]])
if dout is not None:
return torch.mv(J, dout)
else:
return J
@pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1])
@pytest.mark.parametrize('beta', [0.001, 0.01, 0.1, 1])
def test_jv(alpha, beta):
torch.manual_seed(1)
torch.set_default_tensor_type('torch.DoubleTensor')
for _ in range(30):
x = Variable(torch.randn(15))
dout = torch.randn(15)
y_hat = OscarProxFunction(alpha=alpha, beta=beta)(x).data
ref = _oscar_prox_jacobian(y_hat, dout)
din = oscar_prox_jv(y_hat, dout)
assert_allclose(ref.numpy(), din.numpy(), atol=1e-5)
@pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1])
@pytest.mark.parametrize('beta', [0.001, 0.01, 0.1, 1])
def test_finite_diff(alpha, beta):
torch.manual_seed(1)
torch.set_default_tensor_type('torch.DoubleTensor')
for _ in range(30):
x = Variable(torch.randn(20), requires_grad=True)
func = OscarProxFunction(alpha, beta=beta)
assert gradcheck(func, (x,), eps=1e-5, atol=1e-3)
================================================
FILE: pytorch/torchsparseattn/test_sparsemax.py
================================================
import torch
from torch.autograd import gradcheck, Variable
from .sparsemax import SparsemaxFunction
def test_sparsemax():
torch.manual_seed(1)
torch.set_default_tensor_type('torch.DoubleTensor')
for _ in range(30):
func = SparsemaxFunction()
x = Variable(torch.randn(20), requires_grad=True)
assert gradcheck(func, (x,), eps=1e-4, atol=1e-3)