Repository: vene/sparse-structured-attention
Branch: master
Commit: 7003b3deaa51
Files: 19
Total size: 33.2 KB

Directory structure:
gitextract_gpxv25pq/

├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
└── pytorch/
    ├── MANIFEST.in
    ├── setup.py
    └── torchsparseattn/
        ├── __init__.py
        ├── _fused.pyx
        ├── _fused_jv.pyx
        ├── _isotonic.pyx
        ├── base.py
        ├── fused.py
        ├── isotonic.py
        ├── oscar.py
        ├── sparsemax.py
        ├── test_attention.py
        ├── test_fused.py
        ├── test_oscar.py
        └── test_sparsemax.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject


================================================
FILE: .travis.yml
================================================
sudo: false
language: python
dist: xenial

python:
  - "2.7"
  - "3.7"

env:
  - TORCH_VERSION=1.0.1
  - TORCH_VERSION=0.4.1

cache:
  apt: true
  directories:
    - $HOME/.cache/pip

install:

  - wget http://repo.continuum.io/miniconda/Miniconda-3.6.0-Linux-x86_64.sh -O miniconda.sh
  - bash miniconda.sh -b -p $HOME/miniconda
  - export PATH="$HOME/miniconda/bin:$PATH"
  - conda config --set always_yes yes --set changeps1 no
  - conda update -q conda
  - hash -r
  - conda info -a
  - conda create -q -n testenv python=$TRAVIS_PYTHON_VERSION
  - source activate testenv
  - conda install -c pytorch pytorch-cpu=$TORCH_VERSION numpy scipy pytest cython

  # install package
  - cd pytorch
  - pip install .

script:
  - mkdir empty_dir
  - cd empty_dir
  - pytest -vs --pyargs torchsparseattn
  - cd ..


================================================
FILE: LICENSE
================================================
BSD 3-Clause License

Copyright (c) 2017, Vlad Niculae
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


================================================
FILE: README.md
================================================
# Sparse and structured attention mechanisms
[![Build Status](https://travis-ci.org/vene/sparse-structured-attention.svg?branch=master)](https://travis-ci.org/vene/sparse-structured-attention)
[![PyPI version](https://badge.fury.io/py/torchsparseattn.svg)](https://badge.fury.io/py/torchsparseattn)

<p align="center"><img src="fusedmax.png" /></p>

--------------------------------------------------------------------------------

Efficient implementations of structured, sparsity-inducing
attention mechanisms: fusedmax, oscarmax and sparsemax.
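
Sparsemax is the Euclidean projection of a score vector onto the probability simplex (Martins and Astudillo, 2016). In this package, `Fusedmax` and `Oscarmax` first apply a proximal step (1D total variation or OSCAR, respectively) and then apply sparsemax to the result. The NumPy sketch below mirrors the sort-based projection in `project_simplex`; it is a standalone illustration under that assumption, not the package's PyTorch code, and the name `sparsemax` here is a hypothetical helper.

```python
import numpy as np


def sparsemax(v):
    """Hypothetical helper: Euclidean projection of v onto the probability simplex."""
    v = np.asarray(v, dtype=float)
    v_sorted = np.sort(v)[::-1]                 # scores in decreasing order
    cssv = np.cumsum(v_sorted) - 1.0            # cumulative sums minus the simplex radius z = 1
    ind = np.arange(1, len(v) + 1)
    support = v_sorted - cssv / ind > 0         # prefix sizes whose last entry stays positive
    rho = ind[support][-1]                      # size of the support
    tau = cssv[support][-1] / rho               # threshold subtracted from every score
    return np.maximum(v - tau, 0.0)
```

For `[1.0, 2.1, 1.9]` the threshold is 1.5, so the first score is zeroed and the result is approximately `[0, 0.6, 0.4]`; unlike softmax, entries below the threshold get exactly zero weight.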

**Note**: If you are just looking for sparsemax, I recommend the implementation in the [entmax](https://github.com/deep-spin/entmax) package.

Currently available for pytorch >= 0.4.1. (For older versions, use a previous
release of this package.) Requires python >= 2.7, cython, numpy, scipy.

Usage example:

```python

In [1]: import torch
In [2]: import torchsparseattn
In [3]: a = torch.tensor([1, 2.1, 1.9], dtype=torch.double)
In [4]: lengths = torch.tensor([3])
In [5]: fusedmax = torchsparseattn.Fusedmax(alpha=.1)
In [6]: fusedmax(a, lengths)
Out[6]: tensor([0.0000, 0.5000, 0.5000], dtype=torch.float64)
```
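
The output above can be checked by hand. `Fusedmax(alpha=.1)` first computes the proximal operator of `alpha` times the 1D total variation, `sum_i |x[i+1] - x[i]|`, and then applies sparsemax. For `a = [1, 2.1, 1.9]` the proximal step fuses the last two entries into 1.95 and moves the first to 1.1; sparsemax then gives the first entry zero weight. The NumPy sketch below reproduces this. It swaps in a simple projected-gradient method on the dual problem (an assumption chosen for brevity) in place of Condat's direct algorithm implemented by the package's `prox_tv1d`; `prox_tv1d_dual`, `sparsemax` and `fusedmax` are hypothetical helper names.

```python
import numpy as np


def sparsemax(v):
    """Sort-based Euclidean projection onto the probability simplex."""
    v = np.asarray(v, dtype=float)
    v_sorted = np.sort(v)[::-1]
    cssv = np.cumsum(v_sorted) - 1.0
    ind = np.arange(1, len(v) + 1)
    support = v_sorted - cssv / ind > 0
    tau = cssv[support][-1] / ind[support][-1]
    return np.maximum(v - tau, 0.0)


def prox_tv1d_dual(w, lam, n_iter=500):
    """Approximate argmin_x lam * sum_i |x[i+1] - x[i]| + 0.5 * ||x - w||^2.

    Hypothetical helper (not the package's Condat-based prox_tv1d):
    projected gradient on the dual variable u, one entry per neighbouring
    difference, constrained to [-lam, lam]. The primal point is
    x = w - D^T u, where D is the forward-difference operator.
    """
    w = np.asarray(w, dtype=float)
    u = np.zeros(len(w) - 1)
    step = 0.25  # ||D D^T|| <= 4, so 1/4 is a safe step size

    def primal(u):
        # -D^T u equals the forward difference of u padded with zeros
        return w + np.diff(np.concatenate(([0.0], u, [0.0])))

    for _ in range(n_iter):
        # gradient step on the dual (its gradient is -D x), then clip to the box
        u = np.clip(u + step * np.diff(primal(u)), -lam, lam)
    return primal(u)


def fusedmax(x, alpha):
    """Hypothetical helper: sparsemax applied after the 1D TV proximal step."""
    return sparsemax(prox_tv1d_dual(x, alpha))


a = [1.0, 2.1, 1.9]
print(prox_tv1d_dual(a, 0.1))  # close to [1.1, 1.95, 1.95]
print(fusedmax(a, 0.1))        # close to [0.0, 0.5, 0.5]
```

Projected gradient only approximates the proximal point, so this sketch is meant for checking small cases, not as a replacement for `prox_tv1d`.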

For details, check out our paper:

> Vlad Niculae and Mathieu Blondel
> A Regularized Framework for Sparse and Structured Neural Attention
> In: Proceedings of NIPS, 2017. 
> https://arxiv.org/abs/1705.07704 

See also:

> André F. T. Martins and Ramón Fernandez Astudillo
> From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification
> In: Proceedings of ICML, 2016
> https://arxiv.org/abs/1602.02068

> X. Zeng and M. Figueiredo
> The ordered weighted L1 norm: Atomic formulation, dual norm, and projections
> arXiv preprint, 2014
> https://arxiv.org/abs/1409.4271
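
Oscarmax applies sparsemax after the proximal operator of the OSCAR regularizer, which this package evaluates as an OWL proximal operator (Zeng and Figueiredo) with weights `alpha + beta * (n - 1 - i)`: sort `|v|` in decreasing order, subtract the weights, fit a non-increasing sequence by isotonic regression, clip at zero, then undo the sort and restore the signs. The sketch below follows those steps in NumPy, assuming unit sample weights; it uses a plain pool-adjacent-violators loop in place of the package's Cython `_inplace_contiguous_isotonic_regression`, and `isotonic_decreasing`, `oscar_weights` and `prox_owl` here are hypothetical helper names.

```python
import numpy as np


def isotonic_decreasing(y):
    """Hypothetical helper: least-squares non-increasing fit to y (PAVA, unit weights)."""
    blocks = []  # each block is [sum of values, number of values]
    for value in y:
        blocks.append([float(value), 1])
        # pool adjacent blocks while the previous mean is below the last mean
        while len(blocks) > 1 and (blocks[-2][0] / blocks[-2][1]
                                   < blocks[-1][0] / blocks[-1][1]):
            total, count = blocks.pop()
            blocks[-1][0] += total
            blocks[-1][1] += count
    return np.concatenate([np.full(count, total / count)
                           for total, count in blocks])


def oscar_weights(alpha, beta, size):
    """Hypothetical helper: OWL weights alpha + beta * (size - 1 - i), non-increasing."""
    return alpha + beta * np.arange(size - 1, -1, -1, dtype=float)


def prox_owl(v, w):
    """Hypothetical helper: prox of the OWL norm dot(w, sort(|v|) decreasing)."""
    v = np.asarray(v, dtype=float)
    ix = np.argsort(np.abs(v))[::-1]              # indices of |v| in decreasing order
    z = isotonic_decreasing(np.abs(v)[ix] - w)    # monotone (non-increasing) fit
    z = np.maximum(z, 0.0)                        # clip at zero
    out = np.empty_like(v)
    out[ix] = z                                   # undo the sort
    return np.sign(v) * out


print(prox_owl([2.0, 2.5], oscar_weights(0.0, 1.0, 2)))  # both entries equal 1.75
```

With `alpha=0, beta=1`, the two entries of `[2.0, 2.5]` are pooled into a single cluster of value 1.75, the index-independent grouping that oscarmax is designed to produce.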



================================================
FILE: pytorch/MANIFEST.in
================================================
include MANIFEST.in
recursive-include torchsparseattn *.c *.h *.cpp *.pyx *.pxd


================================================
FILE: pytorch/setup.py
================================================
import numpy
from setuptools import setup, find_packages, Extension

from Cython.Build import cythonize

extensions = [
    Extension('torchsparseattn._isotonic',
              ["torchsparseattn/_isotonic.pyx"],
              include_dirs=[numpy.get_include()]),
    Extension('torchsparseattn._fused',
              ["torchsparseattn/_fused.pyx"],
              include_dirs=[numpy.get_include()]),
    Extension('torchsparseattn._fused_jv',
              ["torchsparseattn/_fused_jv.pyx"]),
]

extensions = cythonize(extensions)


setup(name="torchsparseattn",
      version="0.3.dev0",
      description="Sparse structured attention mechanisms for pytorch",
      author="Vlad Niculae",
      author_email="vlad@vene.ro",
      license="BSD 3-clause",
      packages=find_packages(),
      ext_modules=extensions,
      install_requires=['numpy'],
      zip_safe=False,
      classifiers=[
          'Intended Audience :: Science/Research',
          'Intended Audience :: Developers', 'License :: OSI Approved',
          'Programming Language :: C', 'Programming Language :: Python',
          'Topic :: Software Development',
          'Topic :: Scientific/Engineering',
          'Operating System :: Microsoft :: Windows',
          'Operating System :: POSIX', 'Operating System :: Unix',
          'Operating System :: MacOS']
)


================================================
FILE: pytorch/torchsparseattn/__init__.py
================================================
from .fused import Fusedmax, FusedProxFunction
from .oscar import Oscarmax, OscarProxFunction
from .sparsemax import Sparsemax, SparsemaxFunction

__version__ = __VERSION__ = '0.3.dev0'


================================================
FILE: pytorch/torchsparseattn/_fused.pyx
================================================
# encoding: utf-8
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
#
# Authors: Fabian Pedregosa
# Bundled file from lightning library

"""
These are some helper functions to compute the proximal operator of some common penalties
"""

cimport numpy as np
from cython cimport floating

cpdef prox_tv1d(np.ndarray[ndim=1, dtype=floating] w, floating stepsize):
    """
    Computes the proximal operator of the 1-dimensional total variation operator.

    This solves a problem of the form

         argmin_x TV(x) + (1/(2 stepsize)) ||x - w||^2

    where TV(x) is the one-dimensional total variation

    Parameters
    ----------
    w: array
        vector of coefficients; overwritten in place with the solution
    stepsize: float
        step size (sometimes denoted gamma) in proximal objective function

    References
    ----------
    Condat, Laurent. "A direct algorithm for 1D total variation denoising."
    IEEE Signal Processing Letters (2013)
    """
    cdef long width, k, k0, kplus, kminus
    cdef floating umin, umax, vmin, vmax, twolambda, minlambda
    width = w.size

    # to avoid invalid memory access to input[0] and invalid lambda values
    if width > 0 and stepsize >= 0:
        k, k0 = 0, 0			# k: current sample location, k0: beginning of current segment
        umin = stepsize  # u is the dual variable
        umax = - stepsize
        vmin = w[0] - stepsize
        vmax = w[0] + stepsize	  # bounds for the segment's value
        kplus = 0
        kminus = 0 	# last positions where umax=-lambda, umin=lambda, respectively
        twolambda = 2.0 * stepsize  # auxiliary variable
        minlambda = -stepsize		# auxiliary variable
        while True:				# simple loop, the exit test is inside
            while k >= width-1: 	# we use the right boundary condition
                if umin < 0.0:			# vmin is too high -> negative jump necessary
                    while True:
                        w[k0] = vmin
                        k0 += 1
                        if k0 > kminus:
                            break
                    k = k0
                    kminus = k
                    vmin = w[kminus]
                    umin = stepsize
                    umax = vmin + umin - vmax
                elif umax > 0.0:    # vmax is too low -> positive jump necessary
                    while True:
                        w[k0] = vmax
                        k0 += 1
                        if k0 > kplus:
                            break
                    k = k0
                    kplus = k
                    vmax = w[kplus]
                    umax = minlambda
                    umin = vmax + umax -vmin
                else:
                    vmin += umin / (k-k0+1)
                    while True:
                        w[k0] = vmin
                        k0 += 1
                        if k0 > k:
                            break
                    return
            umin += w[k + 1] - vmin
            if umin < minlambda:       # negative jump necessary
                while True:
                    w[k0] = vmin
                    k0 += 1
                    if k0 > kminus:
                        break
                k = k0
                kminus = k
                kplus = kminus
                vmin = w[kplus]
                vmax = vmin + twolambda
                umin = stepsize
                umax = minlambda
            else:
                umax += w[k + 1] - vmax
                if umax > stepsize:
                    while True:
                        w[k0] = vmax
                        k0 += 1
                        if k0 > kplus:
                            break
                    k = k0
                    kminus = k
                    kplus = kminus
                    vmax = w[kplus]
                    vmin = vmax - twolambda
                    umin = stepsize
                    umax = minlambda
                else:                   # no jump necessary, we continue
                    k += 1
                    if umin >= stepsize:		# update of vmin
                        kminus = k
                        vmin += (umin - stepsize) / (kminus - k0 + 1)
                        umin = stepsize
                    if umax <= minlambda:	    # update of vmax
                        kplus = k
                        vmax += (umax + stepsize) / (kplus - k0 + 1)
                        umax = minlambda


================================================
FILE: pytorch/torchsparseattn/_fused_jv.pyx
================================================
cimport cython
from cython cimport floating


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def _inplace_fused_prox_jv(floating[::1] y_hat, floating[::1] dout):
    cdef Py_ssize_t n_features = dout.shape[0]
    cdef Py_ssize_t i, last_ix
    cdef unsigned int n
    cdef floating acc
    for i in range(n_features + 1):
        if i in (0, n_features) or y_hat[i] != y_hat[i - 1]:
            if i > 0:
                dout[last_ix:i] = acc / n

            if i < n_features:
                last_ix = i
                acc = dout[i]
                n = 1

        else:
            acc += dout[i]
            n += 1
    return dout





================================================
FILE: pytorch/torchsparseattn/_isotonic.pyx
================================================
# Author: Nelle Varoquaux, Andrew Tulloch, Antony Lee

# Uses the pool adjacent violators algorithm (PAVA), with the
# enhancement of searching for the longest decreasing subsequence to
# pool at each step.

import numpy as np
cimport numpy as np
cimport cython
from cython cimport floating


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def _inplace_contiguous_isotonic_regression(floating[::1] y, floating[::1] w):
    cdef:
        Py_ssize_t n = y.shape[0], i, k
        floating prev_y, sum_wy, sum_w
        Py_ssize_t[::1] target = np.arange(n, dtype=np.intp)

    # target describes a list of blocks.  At any time, if [i..j] (inclusive) is
    # an active block, then target[i] := j and target[j] := i.

    # For "active" indices (block starts):
    # w[i] := sum{w_orig[j], j=[i..target[i]]}
    # y[i] := sum{y_orig[j]*w_orig[j], j=[i..target[i]]} / w[i]

    with nogil:
        i = 0
        while i < n:
            k = target[i] + 1
            if k == n:
                break
            if y[i] < y[k]:
                i = k
                continue
            sum_wy = w[i] * y[i]
            sum_w = w[i]
            while True:
                # We are within a decreasing subsequence.
                prev_y = y[k]
                sum_wy += w[k] * y[k]
                sum_w += w[k]
                k = target[k] + 1
                if k == n or prev_y < y[k]:
                    # Non-singleton decreasing subsequence is finished,
                    # update first entry.
                    y[i] = sum_wy / sum_w
                    w[i] = sum_w
                    target[i] = k - 1
                    target[k - 1] = i
                    if i > 0:
                        # Backtrack if we can.  This makes the algorithm
                        # single-pass and ensures O(n) complexity.
                        i = target[i - 1]
                    # Otherwise, restart from the same point.
                    break
        # Reconstruct the solution.
        i = 0
        while i < n:
            k = target[i] + 1
            y[i + 1 : k] = y[i]
            i = k


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def _make_unique(np.ndarray[dtype=floating] X,
                 np.ndarray[dtype=floating] y,
                 np.ndarray[dtype=floating] sample_weights):
    """Average targets for duplicate X, drop duplicates.

    Aggregates duplicate X values into a single X value where
    the target y is a (sample_weighted) average of the individual
    targets.

    Assumes that X is ordered, so that all duplicates follow each other.
    """
    unique_values = len(np.unique(X))
    if unique_values == len(X):
        return X, y, sample_weights
    cdef np.ndarray[dtype=floating] y_out = np.empty(unique_values, dtype=X.dtype)
    cdef np.ndarray[dtype=floating] x_out = np.empty(unique_values, dtype=X.dtype)
    cdef np.ndarray[dtype=floating] weights_out = np.empty(unique_values, dtype=X.dtype)

    cdef floating current_x = X[0]
    cdef floating current_y = 0
    cdef floating current_weight = 0
    cdef floating y_old = 0
    cdef int i = 0
    cdef int current_count = 0
    cdef int j
    cdef floating x
    cdef int n_samples = len(X)
    for j in range(n_samples):
        x = X[j]
        if x != current_x:
            # next unique value
            x_out[i] = current_x
            weights_out[i] = current_weight / current_count
            y_out[i] = current_y / current_weight
            i += 1
            current_x = x
            current_weight = sample_weights[j]
            current_y = y[j] * sample_weights[j]
            current_count = 1
        else:
            current_weight += sample_weights[j]
            current_y += y[j] * sample_weights[j]
            current_count += 1

    x_out[i] = current_x
    weights_out[i] = current_weight / current_count
    y_out[i] = current_y / current_weight
    return x_out, y_out, weights_out


================================================
FILE: pytorch/torchsparseattn/base.py
================================================
from torch import nn
from torch import autograd as ta

class _BaseBatchProjection(ta.Function):
    """Applies a sample-wise normalizing projection over a batch."""

    def forward(self, x, lengths=None):

        requires_squeeze = False
        if x.dim() == 1:
            x = x.unsqueeze(0)
            requires_squeeze = True

        n_samples, max_dim = x.size()

        has_lengths = True
        if lengths is None:
            has_lengths = False
            lengths = [max_dim] * n_samples

        y_star = x.new()
        y_star.resize_as_(x)
        y_star.zero_()

        for i in range(n_samples):
            y_star[i, :lengths[i]] = self.project(x[i, :lengths[i]])

        if requires_squeeze:
            y_star = y_star.squeeze()

        self.mark_non_differentiable(y_star)
        if has_lengths:
            self.mark_non_differentiable(lengths)
            self.save_for_backward(y_star, lengths)
        else:
            self.save_for_backward(y_star)

        return y_star

    def backward(self, dout):

        if not self.needs_input_grad[0]:
            return None

        if len(self.needs_input_grad) > 1 and self.needs_input_grad[1]:
            raise ValueError("Cannot differentiate {} w.r.t. the "
                             "sequence lengths".format(type(self).__name__))

        saved = self.saved_tensors
        if len(saved) == 2:
            y_star, lengths = saved
        else:
            y_star, = saved
            lengths = None

        requires_squeeze = False
        if y_star.dim() == 1:
            y_star = y_star.unsqueeze(0)
            dout = dout.unsqueeze(0)
            requires_squeeze = True

        n_samples, max_dim = y_star.size()
        din = dout.new()
        din.resize_as_(y_star)
        din.zero_()

        if lengths is None:
            lengths = [max_dim] * n_samples

        for i in range(n_samples):
            din[i, :lengths[i]] = self.project_jv(dout[i, :lengths[i]],
                                                  y_star[i, :lengths[i]])

        if requires_squeeze:
            din = din.squeeze()

        return din, None


================================================
FILE: pytorch/torchsparseattn/fused.py
================================================
"""Fusedmax attention

Clusters neighboring attention weights into groups with equal weight.

A Regularized Framework for Sparse and Structured Neural Attention
Vlad Niculae, Mathieu Blondel
https://arxiv.org/abs/1705.07704
"""

from __future__ import division

import torch
from torch import nn
from torch import autograd as ta
import warnings

from .base import _BaseBatchProjection
from .sparsemax import SparsemaxFunction
from ._fused import prox_tv1d


def _inplace_fused_prox_jv_slow(y_hat, dout):
    """Pure-Python reference implementation: slow for long sequences, but
    serves as the template for the Cython version in _fused_jv.pyx."""

    n_features = len(dout)

    for i in range(n_features + 1):
        if i in (0, n_features) or y_hat[i] != y_hat[i - 1]:
            if i > 0:
                dout[last_ix:i] = acc / n

            if i < n_features:
                last_ix = i
                acc = dout[i]
                n = 1
        else:
            acc += dout[i]
            n += 1
    return dout


try:
    from ._fused_jv import _inplace_fused_prox_jv
except ImportError:
    warnings.warn("Could not import cython implementation of fused backward "
                  "pass. Slow implementation used instead.")
    _inplace_fused_prox_jv = _inplace_fused_prox_jv_slow


def fused_prox_jv_slow(y_hat, dout):
    dout = dout.clone()
    _inplace_fused_prox_jv_slow(y_hat, dout)
    return dout


def fused_prox_jv_fast(y_hat, dout):
    dout = dout.clone()
    _inplace_fused_prox_jv(y_hat.detach().numpy(), dout.numpy())
    return dout


class FusedProxFunction(_BaseBatchProjection):

    def __init__(self, alpha=1):
        self.alpha = alpha

    def project(self, x):
        x_np = x.detach().numpy().copy()
        prox_tv1d(x_np, self.alpha)
        y_hat = torch.from_numpy(x_np)
        return y_hat

    def project_jv(self, dout, y_hat):
        dout = dout.clone()
        _inplace_fused_prox_jv(y_hat.detach().numpy(), dout.numpy())
        return dout


class Fusedmax(nn.Module):
    def __init__(self, alpha=1):
        super(Fusedmax, self).__init__()
        self.alpha = alpha

    def forward(self, x, lengths=None):
        fused_prox = FusedProxFunction(self.alpha)
        sparsemax = SparsemaxFunction()
        return sparsemax(fused_prox(x, lengths), lengths)


if __name__ == '__main__':
    from timeit import timeit
    torch.manual_seed(1)

    for dim in (5, 10, 50, 100, 500, 1000):

        x = torch.randn(dim)
        x_var = ta.Variable(x, requires_grad=True)
        y_hat = FusedProxFunction()(x_var).data
        dout = torch.arange(0, dim, dtype=x.dtype)
        print("dimension={}".format(dim))
        print("slow", timeit("fused_prox_jv_slow(y_hat, dout)",
                             globals=globals(),
                             number=10000))
        print("fast", timeit("fused_prox_jv_fast(y_hat, dout)",
                             globals=globals(),
                             number=10000))


================================================
FILE: pytorch/torchsparseattn/isotonic.py
================================================
"""
Isotonic Regression that preserves 32bit inputs.

backported from scikit-learn pull request
https://github.com/scikit-learn/scikit-learn/pull/9106"""

import numpy as np

from ._isotonic import _inplace_contiguous_isotonic_regression


def isotonic_regression(y, sample_weight=None, y_min=None, y_max=None,
                        increasing=True):
    """Solve the isotonic regression model::

        min sum w[i] (y[i] - y_[i]) ** 2

        subject to y_min = y_[1] <= y_[2] ... <= y_[n] = y_max

    where:
        - y[i] are inputs (real numbers)
        - y_[i] are fitted
        - w[i] are optional strictly positive weights (default to 1.0)

    Parameters
    ----------
    y : iterable of floating-point values
        The data.

    sample_weight : iterable of floating-point values, optional, default: None
        Weights on each point of the regression.
        If None, weight is set to 1 (equal weights).

    y_min : optional, default: None
        If not None, set the lowest value of the fit to y_min.

    y_max : optional, default: None
        If not None, set the highest value of the fit to y_max.

    increasing : boolean, optional, default: True
        Whether the fitted ``y_`` is increasing (if set to True) or decreasing
        (if set to False).

    Returns
    -------
    y_ : ndarray of floating-point values
        Isotonic fit of y.

    References
    ----------
    "Active set algorithms for isotonic regression; A unifying framework"
    by Michael J. Best and Nilotpal Chakravarti, section 3.
    """
    order = np.s_[:] if increasing else np.s_[::-1]
    # y = as_float_array(y)  # avoid sklearn dependency; we always pass arrays
    y = np.array(y[order], dtype=y.dtype)
    if sample_weight is None:
        sample_weight = np.ones(len(y), dtype=y.dtype)
    else:
        sample_weight = np.array(sample_weight[order], dtype=y.dtype)

    _inplace_contiguous_isotonic_regression(y, sample_weight)
    if y_min is not None or y_max is not None:
        # Older versions of np.clip don't accept None as a bound, so use np.inf
        if y_min is None:
            y_min = -np.inf
        if y_max is None:
            y_max = np.inf
        np.clip(y, y_min, y_max, y)
    return y[order]


================================================
FILE: pytorch/torchsparseattn/oscar.py
================================================
"""Oscarmax attention

Clusters attention weights into groups with equal weight, regardless of index.

A Regularized Framework for Sparse and Structured Neural Attention
Vlad Niculae, Mathieu Blondel
https://arxiv.org/abs/1705.07704
"""

import numpy as np
import torch
from torch import nn
from torch import autograd as ta

from .isotonic import isotonic_regression
from .base import _BaseBatchProjection
from .sparsemax import SparsemaxFunction


def oscar_prox_jv(y_hat, dout):
    y_hat = y_hat.detach().numpy()
    din = dout.clone().zero_()
    dout = dout.numpy()
    din_np = din.numpy()

    sign = np.sign(y_hat)
    y_hat = np.abs(y_hat)

    uniq, inv, counts = np.unique(y_hat, return_inverse=True,
                                  return_counts=True)
    n_unique = len(uniq)
    tmp = np.zeros((n_unique,), dtype=y_hat.dtype)
    np.add.at(tmp, inv, dout * sign)
    tmp /= counts
    tmp.take(inv, mode='clip', out=din_np)
    din_np *= sign
    return din


def prox_owl(v, w):
    """Proximal operator of the OWL norm dot(w, reversed(sort(abs(v))))

    Follows description and notation from:
    X. Zeng, M. Figueiredo,
    The ordered weighted L1 norm: Atomic formulation, dual norm,
    and projections.
    eprint http://arxiv.org/abs/1409.4271
    """

    # wlog operate on absolute values
    v_abs = np.abs(v)
    ix = np.argsort(v_abs)[::-1]
    v_abs = v_abs[ix]
    # project to K+ (monotone non-negative decreasing cone)
    v_abs = isotonic_regression(v_abs - w, y_min=0, increasing=False)

    # undo the sorting
    inv_ix = np.zeros_like(ix)
    inv_ix[ix] = np.arange(len(v))
    v_abs = v_abs[inv_ix]

    return np.sign(v) * v_abs


def _oscar_weights(alpha, beta, size):
    w = np.arange(size - 1, -1, -1, dtype=np.float32)
    w *= beta
    w += alpha
    return w


class OscarProxFunction(_BaseBatchProjection):
    """Proximal operator of the OSCAR regularizer.

    ||w||_oscar = alpha ||w||_1 + beta * sum_i<j max { |w_i|, |w_j| }

    Implemented via the OWL norm with appropriate choice of weights, as
    described in:

    X. Zeng, M. Figueiredo,
    The ordered weighted L1 norm: Atomic formulation, dual norm,
    and projections.
    eprint http://arxiv.org/abs/1409.4271

    Backward pass is described in:
    V. Niculae, M. Blondel,
    A Regularized Framework for Sparse and Structured Neural Attention.
    eprint https://arxiv.org/abs/1705.07704
    """

    def __init__(self, alpha=0, beta=1):
        self.alpha = alpha
        self.beta = beta

    def project(self, x):
        x_np = x.detach().numpy().copy()
        weights = _oscar_weights(self.alpha, self.beta, x_np.shape[0])
        y_hat_np = prox_owl(x_np, weights)
        y_hat = torch.from_numpy(y_hat_np)
        return y_hat

    def project_jv(self, dout, y_hat):
        return oscar_prox_jv(y_hat, dout)


class Oscarmax(nn.Module):
    def __init__(self, beta=1):
        super(Oscarmax, self).__init__()
        self.beta = beta

    def forward(self, x, lengths=None):
        oscar_prox = OscarProxFunction(beta=self.beta)
        sparsemax = SparsemaxFunction()
        return sparsemax(oscar_prox(x, lengths), lengths)


if __name__ == '__main__':
    from timeit import timeit
    torch.manual_seed(1)

    for dim in (5, 10, 50, 100, 500, 1000):

        x = torch.randn(dim)
        x_var = ta.Variable(x, requires_grad=True)

        def _run_backward(x):
            y_hat = OscarProxFunction(beta=0.1)(x)
            val = y_hat.mean()
            val.backward()

        print("dimension={}".format(dim))
        print("la", timeit("_run_backward(x_var)",
                           globals=globals(),
                           number=10000))


================================================
FILE: pytorch/torchsparseattn/sparsemax.py
================================================
# encoding: utf8

"""
From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label
Classification. André F. T. Martins, Ramón Fernandez Astudillo
In: Proc. of ICML 2016, https://arxiv.org/abs/1602.02068
"""

from __future__ import division

import numpy as np
import torch
from torch import nn
from .base import _BaseBatchProjection


def project_simplex(v, z=1):
    v_sorted, _ = torch.sort(v, dim=0, descending=True)
    cssv = torch.cumsum(v_sorted, dim=0) - z
    ind = torch.arange(1, 1 + len(v)).to(dtype=v.dtype)
    cond = v_sorted - cssv / ind > 0
    rho = ind.masked_select(cond)[-1]
    tau = cssv.masked_select(cond)[-1] / rho
    w = torch.clamp(v - tau, min=0)
    return w


def sparsemax_grad(dout, w_star):
    supp = w_star > 0
    masked = dout.masked_select(supp)
    nnz = supp.to(dtype=dout.dtype).sum()
    masked -= masked.sum() / nnz
    out = dout.new(dout.size()).zero_()
    out[supp] = masked
    return out


class SparsemaxFunction(_BaseBatchProjection):

    def project(self, x):
        return project_simplex(x)

    def project_jv(self, dout, y_star):
        return sparsemax_grad(dout, y_star)


class Sparsemax(nn.Module):

    def forward(self, x, lengths=None):
        sparsemax = SparsemaxFunction()
        return sparsemax(x, lengths)



================================================
FILE: pytorch/torchsparseattn/test_attention.py
================================================
import pytest

import torch
from torch import nn
from torch.autograd import Variable

from . import Sparsemax, Fusedmax, Oscarmax


class AttentionRegressor(nn.Module):

    def __init__(self, projection, n_features=100):
        super(AttentionRegressor, self).__init__()
        self.projection = projection
        self.attn_template = nn.Parameter(torch.Tensor(n_features))
        self.attn_template.data.uniform_(-0.1, 0.1)

    def forward(self, X, lengths):

        # compute scores for each input word
        scores = torch.matmul(X, self.attn_template)
        weights = self.projection(scores, lengths)
        weighted_avg = torch.bmm(X.transpose(1, 2),
                                 weights.unsqueeze(-1)).squeeze(-1)
        pred = weighted_avg.sum(dim=1)  # very simple prediction rule
        return pred


@pytest.mark.parametrize('projection', [Sparsemax(),
                                        Fusedmax(0.1),
                                        Oscarmax(0.01)])
def test_attention(projection):
    n_samples = 20
    max_len = 10
    torch.manual_seed(1)
    n_features = 50

    X = torch.zeros(n_samples, max_len, n_features)

    # generate lengths in [1, max_len]
    lengths = 1 + (torch.rand(n_samples) * max_len).long()

    for i in range(n_samples):
        X[i, :lengths[i], :] = torch.randn(lengths[i], n_features)

    X = Variable(X)
    lengths = Variable(lengths)
    targets = Variable(torch.randn(n_samples))

    regr = AttentionRegressor(projection, n_features=n_features)
    loss_func = nn.MSELoss()
    optim = torch.optim.SGD(regr.parameters(), lr=0.0001)

    pred = regr(X, lengths)

    init_obj = loss_func(pred, targets)

    for it in range(50):
        optim.zero_grad()
        pred = regr(X, lengths)
        obj = loss_func(pred, targets)
        obj.backward()
        optim.step()

    final_obj = obj
    assert final_obj < init_obj
    assert regr.attn_template.grad.size() == (n_features,)
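`AttentionRegressor.forward` relies on the projection giving padded time steps zero weight. The plain-Python sketch below spells out that contract. It assumes, since `base.py` is not part of this excerpt, that the batched projection is applied to the first `lengths[i]` scores of row `i` and that the remaining positions are zero-filled. `batch_project`, `attention_predict` and the `uniform` stand-in projection are hypothetical names; in the test, `Sparsemax`, `Fusedmax` or `Oscarmax` plays the role of `project`.

```python
from typing import Callable, List

Vector = List[float]


def batch_project(scores: List[Vector], lengths: List[int],
                  project: Callable[[Vector], Vector]) -> List[Vector]:
    """Project each row's first lengths[i] scores; zero-fill the padding."""
    out = []
    for row, n in zip(scores, lengths):
        out.append(project(row[:n]) + [0.0] * (len(row) - n))
    return out


def attention_predict(X: List[List[Vector]], weights: List[Vector]) -> Vector:
    """pred_i = sum_f sum_t weights[i][t] * X[i][t][f], as in the test."""
    preds = []
    for x_i, w_i in zip(X, weights):
        n_features = len(x_i[0])
        # Same quantity as bmm(X^T, w): weighted average over time steps,
        # one value per feature.
        avg = [sum(w_t * x_t[f] for w_t, x_t in zip(w_i, x_i))
               for f in range(n_features)]
        preds.append(sum(avg))  # the test's "very simple prediction rule"
    return preds


def uniform(row: Vector) -> Vector:
    # Hypothetical stand-in projection: equal weight on every unpadded step.
    return [1.0 / len(row)] * len(row)


weights = batch_project([[3.0, 1.0, 2.0, 0.0], [1.0, 1.0, 1.0, 1.0]],
                        [2, 4], uniform)
print(weights)  # [[0.5, 0.5, 0.0, 0.0], [0.25, 0.25, 0.25, 0.25]]
```

Under this assumption the prediction does not depend on what the padded positions of `X` contain; in the test they are zeros anyway.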


================================================
FILE: pytorch/torchsparseattn/test_fused.py
================================================
from __future__ import division

import pytest
from numpy.testing import assert_allclose
import torch
from torch.autograd import gradcheck, Variable

from .fused import fused_prox_jv_slow, fused_prox_jv_fast
from .fused import FusedProxFunction


def _fused_prox_jacobian(y_hat, dout=None):
    """Naive reference: dense Jacobian of the fused prox at y_hat, or J @ dout."""
    dim = y_hat.shape[0]
    groups = torch.zeros(dim)
    J = torch.zeros(dim, dim)
    current_group = 0

    for i in range(1, dim):
        if y_hat[i] == y_hat[i - 1]:
            groups[i] = groups[i - 1]
        else:
            current_group += 1
            groups[i] = current_group

    for i in range(dim):
        for j in range(dim):
            if groups[i] == groups[j]:
                n_fused = (groups == groups[i]).sum()
                J[i, j] = 1 / n_fused.to(y_hat.dtype)

    if dout is not None:
        return torch.mv(J, dout)
    else:
        return J


@pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1])
def test_jv(alpha):

    torch.manual_seed(1)
    torch.set_default_tensor_type('torch.DoubleTensor')

    for _ in range(30):
        x = Variable(torch.randn(15))
        dout = torch.randn(15)

        y_hat = FusedProxFunction(alpha=alpha)(x).data

        ref = _fused_prox_jacobian(y_hat, dout)
        din_slow = fused_prox_jv_slow(y_hat, dout)
        din_fast = fused_prox_jv_fast(y_hat, dout)
        assert_allclose(ref.numpy(), din_slow.numpy(), atol=1e-5)
        assert_allclose(ref.numpy(), din_fast.numpy(), atol=1e-5)


@pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1])
def test_finite_diff(alpha):
    torch.manual_seed(1)
    torch.set_default_tensor_type('torch.DoubleTensor')

    for _ in range(30):
        x = Variable(torch.randn(20), requires_grad=True)
        func = FusedProxFunction(alpha=alpha)
        assert gradcheck(func, (x,), eps=1e-4, atol=1e-3)
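`_fused_prox_jacobian` builds the dense Jacobian: entry (i, j) is `1 / n_fused` when positions i and j lie in the same run of equal consecutive values of `y_hat`, and 0 otherwise. Multiplying by `dout` therefore replaces `dout` by its mean within each run, which takes a single pass. A minimal pure-Python sketch of that idea follows; the name `fused_jv_sketch` is hypothetical, and this is not the repository's Cython-backed `fused_prox_jv_fast`.

```python
from typing import List


def fused_jv_sketch(y_hat: List[float], dout: List[float]) -> List[float]:
    """J @ dout for the fused prox: average dout over runs of equal y_hat."""
    n = len(y_hat)
    out = [0.0] * n
    start = 0
    for end in range(1, n + 1):
        # A run ends at the last element or where the value changes.
        if end == n or y_hat[end] != y_hat[start]:
            run_mean = sum(dout[start:end]) / (end - start)
            for i in range(start, end):
                out[i] = run_mean
            start = end
    return out


print(fused_jv_sketch([0.5, 0.5, -1.0, 2.0, 2.0, 2.0],
                      [1.0, 3.0, 5.0, 0.0, 3.0, 6.0]))
# [2.0, 2.0, 5.0, 3.0, 3.0, 3.0]
```

Because each run is visited once, this avoids materializing the n-by-n matrix that the reference helper constructs. Equal values that are not adjacent form separate runs, matching how the reference assigns group ids.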


================================================
FILE: pytorch/torchsparseattn/test_oscar.py
================================================
from __future__ import division

import pytest
from numpy.testing import assert_allclose
import numpy as np
import torch
from torch.autograd import gradcheck, Variable

from .oscar import OscarProxFunction, oscar_prox_jv


def _oscar_prox_jacobian(y_star, dout=None):
    """Naive reference: dense Jacobian of the OSCAR prox at y_star, or J @ dout."""
    y_star = y_star.numpy()
    dim = y_star.shape[0]
    J = torch.zeros(dim, dim)

    _, inv, counts = np.unique(np.abs(y_star),
                               return_inverse=True,
                               return_counts=True)

    for i in range(dim):
        for j in range(dim):
            if (inv[i] == inv[j] and
                    y_star[i] != 0):
                J[i, j] = (np.sign(y_star[i]) * np.sign(y_star[j])
                           / counts[inv[i]])
    if dout is not None:
        return torch.mv(J, dout)
    else:
        return J


@pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1])
@pytest.mark.parametrize('beta', [0.001, 0.01, 0.1, 1])
def test_jv(alpha, beta):

    torch.manual_seed(1)
    torch.set_default_tensor_type('torch.DoubleTensor')

    for _ in range(30):
        x = Variable(torch.randn(15))
        dout = torch.randn(15)
        y_hat = OscarProxFunction(alpha=alpha, beta=beta)(x).data

        ref = _oscar_prox_jacobian(y_hat, dout)
        din = oscar_prox_jv(y_hat, dout)
        assert_allclose(ref.numpy(), din.numpy(), atol=1e-5)


@pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1])
@pytest.mark.parametrize('beta', [0.001, 0.01, 0.1, 1])
def test_finite_diff(alpha, beta):
    torch.manual_seed(1)
    torch.set_default_tensor_type('torch.DoubleTensor')

    for _ in range(30):
        x = Variable(torch.randn(20), requires_grad=True)
        func = OscarProxFunction(alpha=alpha, beta=beta)
        assert gradcheck(func, (x,), eps=1e-5, atol=1e-3)
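`_oscar_prox_jacobian` couples coordinates that share the same nonzero magnitude `|y_star[i]|`, regardless of position, with entries `sign(y_i) * sign(y_j) / count`; rows for zero coordinates stay empty. The product with `dout` therefore reduces to a signed mean per magnitude group. A minimal pure-Python sketch under that reading of the reference; the name `oscar_jv_sketch` is hypothetical, and the repository's `oscar_prox_jv` is not shown in this excerpt.

```python
import math
from collections import defaultdict
from typing import Dict, List


def oscar_jv_sketch(y_hat: List[float], dout: List[float]) -> List[float]:
    """J @ dout for the OSCAR prox: signed mean over equal-|y| groups."""
    groups: Dict[float, List[int]] = defaultdict(list)
    for i, y in enumerate(y_hat):
        if y != 0:  # zero coordinates have an all-zero Jacobian row
            groups[abs(y)].append(i)

    out = [0.0] * len(y_hat)
    for idx in groups.values():
        sign = {j: math.copysign(1.0, y_hat[j]) for j in idx}
        signed_mean = sum(sign[j] * dout[j] for j in idx) / len(idx)
        for i in idx:
            out[i] = sign[i] * signed_mean
    return out


print(oscar_jv_sketch([0.5, -0.5, 0.0, 1.0], [1.0, 3.0, 7.0, 2.0]))
# [-1.0, 1.0, 0.0, 2.0]
```

Grouping by exact float equality mirrors the reference, which groups with `np.unique(np.abs(y_star))`.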


================================================
FILE: pytorch/torchsparseattn/test_sparsemax.py
================================================
import torch
from torch.autograd import gradcheck, Variable
from .sparsemax import SparsemaxFunction


def test_sparsemax():

    torch.manual_seed(1)
    torch.set_default_tensor_type('torch.DoubleTensor')

    for _ in range(30):
        func = SparsemaxFunction()
        x = Variable(torch.randn(20), requires_grad=True)
        assert gradcheck(func, (x,), eps=1e-4, atol=1e-3)
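The finite-difference tests rely on `torch.autograd.gradcheck`, which compares the Jacobian implied by the backward pass against a numerical estimate. The sketch below illustrates that comparison in pure Python for sparsemax, using central differences and an analytic Jacobian `diag(s) - s s^T / |S|`. It is a simplified illustration of the idea, not gradcheck's actual implementation, and all function names in it are hypothetical.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def sparsemax(v: Vector) -> Vector:
    """Sort-based Euclidean projection onto the probability simplex."""
    u = sorted(v, reverse=True)
    cumsum, rho, support_sum = 0.0, 0, 0.0
    for k, u_k in enumerate(u, start=1):
        cumsum += u_k
        if u_k - (cumsum - 1.0) / k > 0:
            rho, support_sum = k, cumsum
    tau = (support_sum - 1.0) / rho
    return [max(x - tau, 0.0) for x in v]


def sparsemax_jacobian(v: Vector) -> Matrix:
    """Analytic Jacobian diag(s) - s s^T / |S| evaluated at v."""
    s = [1.0 if w > 0 else 0.0 for w in sparsemax(v)]
    nnz = sum(s)
    return [[(s[i] if i == j else 0.0) - s[i] * s[j] / nnz
             for j in range(len(v))] for i in range(len(v))]


def numerical_jacobian(f: Callable[[Vector], Vector], v: Vector,
                       eps: float = 1e-4) -> Matrix:
    """Central differences: column j approximates df/dv_j."""
    n = len(v)
    J = [[0.0] * n for _ in range(n)]
    for j in range(n):
        plus = [x + (eps if k == j else 0.0) for k, x in enumerate(v)]
        minus = [x - (eps if k == j else 0.0) for k, x in enumerate(v)]
        f_plus, f_minus = f(plus), f(minus)
        for i in range(n):
            J[i][j] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return J


def check_grad(v: Vector, eps: float = 1e-4, atol: float = 1e-3) -> bool:
    """True if numerical and analytic Jacobians agree entrywise within atol."""
    num = numerical_jacobian(sparsemax, v, eps)
    ana = sparsemax_jacobian(v)
    return all(abs(a - b) <= atol
               for row_n, row_a in zip(num, ana)
               for a, b in zip(row_n, row_a))


print(check_grad([1.0, 0.5, -1.0, 0.2]))  # True
```

At a point where a coordinate sits exactly on the threshold, such as `[1.0, 0.0]`, sparsemax is not differentiable and this check reports a mismatch. Random Gaussian inputs, as in the tests, almost surely avoid such points, although a finite `eps` can still straddle one that lies very close.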