[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2018 mishimori\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# CondConv\n\nImplementation of [CondConv: Conditionally Parameterized Convolutions for Efficient Inference](https://arxiv.org/abs/1904.04971) in PyTorch.\n\n## Abstract\n\nConvolutional layers are one of the basic building blocks of modern deep neural networks. One fundamental assumption is that convolutional kernels should be shared for all examples in a dataset. We propose conditionally parameterized convolutions (CondConv), which learn specialized convolutional kernels for each example. Replacing normal convolutions with CondConv enables us to increase the size and capacity of a network, while maintaining efficient inference. We demonstrate that scaling networks with CondConv improves the performance and inference cost trade-off of several existing convolutional neural network architectures on both classification and detection tasks. On ImageNet classification, our CondConv approach applied to EfficientNet-B0 achieves state-of-the-art performance of 78.3% accuracy with only 413M multiply-adds. Code and checkpoints for the CondConv TensorFlow layer and CondConv-EfficientNet models are available at: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv.\n\n## Installation\n\n    pip install git+https://github.com/nibuiro/CondConv-pytorch.git\n\n## Usage\n\nFor 2D inputs (CondConv2D):\n\n```python\nimport torch\nfrom torch import nn\nfrom condconv import CondConv2D\n\n\nclass Model(nn.Module):\n    def __init__(self, num_experts, dropout_rate=0.2):\n        super(Model, self).__init__()\n        self.condconv2d = CondConv2D(10, 128, kernel_size=1, num_experts=num_experts, dropout_rate=dropout_rate)\n\n    def forward(self, x):\n        x = self.condconv2d(x)\n        return x\n\n\nmodel = Model(num_experts=3)\nout = model(torch.randn(4, 10, 32, 32))  # out.shape == (4, 128, 32, 32)\n```\n\n`CondConv1D`, also exported by the package, takes the same constructor arguments and expects inputs of shape `(N, C, L)`.\n\n## Reference\n[Yang et al., 2019] CondConv: Conditionally Parameterized Convolutions for Efficient Inference\n"
  },
  {
    "path": "condconv/__init__.py",
    "content": "from .condconv import CondConv1D, CondConv2D\n\n__all__ = ['CondConv1D', 'CondConv2D']\n"
  },
  {
    "path": "condconv/condconv.py",
    "content": "import functools\nimport math\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom torch.nn import init\nfrom torch.nn.modules.conv import _ConvNd\nfrom torch.nn.modules.utils import _single, _pair\nfrom torch.nn.parameter import Parameter\n\n\nclass _routing(nn.Module):\n    \"\"\"Routing function: maps one pooled example to per-expert weights in (0, 1).\"\"\"\n\n    def __init__(self, in_channels, num_experts, dropout_rate):\n        super(_routing, self).__init__()\n\n        self.dropout = nn.Dropout(dropout_rate)\n        self.fc = nn.Linear(in_channels, num_experts)\n\n    def forward(self, x):\n        # x is one pooled example of shape (1, C, 1, 1) or (1, C, 1);\n        # flattening it gives a vector of length C.\n        x = torch.flatten(x)\n        x = self.dropout(x)\n        x = self.fc(x)\n        # Sigmoid routing: each expert weight lies in (0, 1) independently.\n        return torch.sigmoid(x)\n\n\ndef _reset_condconv_parameters(module):\n    # Initialize each expert like a standard convolution kernel, so the fan-in is\n    # computed per expert rather than over the stacked (num_experts, ...) weight.\n    # _ConvNd.__init__ also calls reset_parameters() before the expert weight\n    # exists; that first initialization is discarded.\n    for expert in module.weight:\n        init.kaiming_uniform_(expert, a=math.sqrt(5))\n    if module.bias is not None:\n        fan_in = module.in_channels // module.groups\n        for k in module.kernel_size:\n            fan_in *= k\n        bound = 1 / math.sqrt(fan_in)\n        init.uniform_(module.bias, -bound, bound)\n\n\nclass CondConv2D(_ConvNd):\n    r\"\"\"Learn specialized convolutional kernels for each example.\n\n    As described in the paper\n    `CondConv: Conditionally Parameterized Convolutions for Efficient Inference`_,\n    conditionally parameterized convolutions (CondConv)\n    challenge the paradigm of static convolutional kernels\n    by computing convolutional kernels as a function of the input.\n    Each example's kernel is a weighted sum of ``num_experts`` expert kernels,\n    with sigmoid routing weights computed from the globally average-pooled input.\n\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0\n        padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``\n        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1\n        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``\n        num_experts (int, optional): Number of experts per layer. Default: 3\n        dropout_rate (float, optional): Dropout rate applied to the pooled input before the routing layer. Default: 0.2\n\n    Shape:\n        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`\n        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where\n          .. math::\n              H_{out} = \\left\\lfloor\\frac{H_{in}  + 2 \\times \\text{padding}[0] - \\text{dilation}[0]\n                        \\times (\\text{kernel\\_size}[0] - 1) - 1}{\\text{stride}[0]} + 1\\right\\rfloor\n          .. math::\n              W_{out} = \\left\\lfloor\\frac{W_{in}  + 2 \\times \\text{padding}[1] - \\text{dilation}[1]\n                        \\times (\\text{kernel\\_size}[1] - 1) - 1}{\\text{stride}[1]} + 1\\right\\rfloor\n    Attributes:\n        weight (Tensor): the learnable expert kernels of shape\n                         :math:`(\\text{num\\_experts}, \\text{out\\_channels}, \\frac{\\text{in\\_channels}}{\\text{groups}},`\n                         :math:`\\text{kernel\\_size[0]}, \\text{kernel\\_size[1]})`.\n                         The values of each expert kernel are sampled from\n                         :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where\n                         :math:`k = \\frac{groups}{C_\\text{in} * \\prod_{i=0}^{1}\\text{kernel\\_size}[i]}`\n        bias (Tensor):   the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,\n                         then the values of these weights are\n                         sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where\n                         :math:`k = \\frac{groups}{C_\\text{in} * \\prod_{i=0}^{1}\\text{kernel\\_size}[i]}`\n\n    .. _`CondConv: Conditionally Parameterized Convolutions for Efficient Inference`:\n       https://arxiv.org/abs/1904.04971\n\n    \"\"\"\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 padding=0, dilation=1, groups=1,\n                 bias=True, padding_mode='zeros', num_experts=3, dropout_rate=0.2):\n        kernel_size = _pair(kernel_size)\n        stride = _pair(stride)\n        padding = _pair(padding)\n        dilation = _pair(dilation)\n        super(CondConv2D, self).__init__(\n            in_channels, out_channels, kernel_size, stride, padding, dilation,\n            False, _pair(0), groups, bias, padding_mode)\n\n        self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1))\n        self._routing_fn = _routing(in_channels, num_experts, dropout_rate)\n\n        # One kernel per expert: (num_experts, out_channels, in_channels // groups, kH, kW).\n        self.weight = Parameter(torch.Tensor(\n            num_experts, out_channels, in_channels // groups, *kernel_size))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        _reset_condconv_parameters(self)\n\n    def _conv_forward(self, input, weight):\n        if self.padding_mode != 'zeros':\n            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),\n                            weight, self.bias, self.stride,\n                            _pair(0), self.dilation, self.groups)\n        return F.conv2d(input, weight, self.bias, self.stride,\n                        self.padding, self.dilation, self.groups)\n\n    def forward(self, inputs):\n        res = []\n        # Each example gets its own kernel, so the batch is convolved one example at a time.\n        for x in inputs:\n            x = x.unsqueeze(0)\n            pooled_inputs = self._avg_pooling(x)\n            routing_weights = self._routing_fn(pooled_inputs)\n            # Routing-weighted sum over the expert dimension:\n            # (E, 1, 1, 1, 1) * (E, C_out, C_in // groups, kH, kW) -> (C_out, C_in // groups, kH, kW)\n            kernels = torch.sum(routing_weights[:, None, None, None, None] * self.weight, 0)\n            out = self._conv_forward(x, kernels)\n            res.append(out)\n        return torch.cat(res, dim=0)\n\n\nclass CondConv1D(_ConvNd):\n    r\"\"\"Learn specialized convolutional kernels for each example.\n\n    As described in the paper\n    `CondConv: Conditionally Parameterized Convolutions for Efficient Inference`_,\n    conditionally parameterized convolutions (CondConv)\n    challenge the paradigm of static convolutional kernels\n    by computing convolutional kernels as a function of the input.\n    Each example's kernel is a weighted sum of ``num_experts`` expert kernels,\n    with sigmoid routing weights computed from the globally average-pooled input.\n\n    Args:\n        in_channels (int): Number of channels in the input\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int): Size of the convolving kernel\n        stride (int, optional): Stride of the convolution. Default: 1\n        padding (int, optional): Zero-padding added to both sides of the input. Default: 0\n        padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``\n        dilation (int, optional): Spacing between kernel elements. Default: 1\n        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``\n        num_experts (int, optional): Number of experts per layer. Default: 3\n        dropout_rate (float, optional): Dropout rate applied to the pooled input before the routing layer. Default: 0.2\n\n    Shape:\n        - Input: :math:`(N, C_{in}, L_{in})`\n        - Output: :math:`(N, C_{out}, L_{out})` where\n          .. math::\n              L_{out} = \\left\\lfloor\\frac{L_{in} + 2 \\times \\text{padding} - \\text{dilation}\n                        \\times (\\text{kernel\\_size} - 1) - 1}{\\text{stride}} + 1\\right\\rfloor\n    Attributes:\n        weight (Tensor): the learnable expert kernels of shape\n            :math:`(\\text{num\\_experts}, \\text{out\\_channels},\n            \\frac{\\text{in\\_channels}}{\\text{groups}}, \\text{kernel\\_size})`.\n            The values of each expert kernel are sampled from\n            :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where\n            :math:`k = \\frac{groups}{C_\\text{in} * \\text{kernel\\_size}}`\n        bias (Tensor):   the learnable bias of the module of shape\n            (out_channels). If :attr:`bias` is ``True``, then the values of these weights are\n            sampled from :math:`\\mathcal{U}(-\\sqrt{k}, \\sqrt{k})` where\n            :math:`k = \\frac{groups}{C_\\text{in} * \\text{kernel\\_size}}`\n\n    .. _`CondConv: Conditionally Parameterized Convolutions for Efficient Inference`:\n       https://arxiv.org/abs/1904.04971\n\n    \"\"\"\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 padding=0, dilation=1, groups=1,\n                 bias=True, padding_mode='zeros', num_experts=3, dropout_rate=0.2):\n        kernel_size = _single(kernel_size)\n        stride = _single(stride)\n        padding = _single(padding)\n        dilation = _single(dilation)\n        super(CondConv1D, self).__init__(\n            in_channels, out_channels, kernel_size, stride, padding, dilation,\n            False, _single(0), groups, bias, padding_mode)\n\n        self._avg_pooling = functools.partial(F.adaptive_avg_pool1d, output_size=1)\n        self._routing_fn = _routing(in_channels, num_experts, dropout_rate)\n\n        # One kernel per expert: (num_experts, out_channels, in_channels // groups, kL).\n        self.weight = Parameter(torch.Tensor(\n            num_experts, out_channels, in_channels // groups, *kernel_size))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        _reset_condconv_parameters(self)\n\n    def _conv_forward(self, input, weight):\n        if self.padding_mode != 'zeros':\n            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),\n                            weight, self.bias, self.stride,\n                            _single(0), self.dilation, self.groups)\n        return F.conv1d(input, weight, self.bias, self.stride,\n                        self.padding, self.dilation, self.groups)\n\n    def forward(self, inputs):\n        res = []\n        # Each example gets its own kernel, so the batch is convolved one example at a time.\n        for x in inputs:\n            x = x.unsqueeze(0)\n            pooled_inputs = self._avg_pooling(x)\n            routing_weights = self._routing_fn(pooled_inputs)\n            # Routing-weighted sum over the expert dimension:\n            # (E, 1, 1, 1) * (E, C_out, C_in // groups, kL) -> (C_out, C_in // groups, kL)\n            kernels = torch.sum(routing_weights[:, None, None, None] * self.weight, 0)\n            out = self._conv_forward(x, kernels)\n            res.append(out)\n        return torch.cat(res, dim=0)\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch>=0.4.1"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\n\nwith open('requirements.txt', encoding='utf-8') as f:\n    required = f.read().splitlines()\n\nwith open('README.md', encoding='utf-8') as f:\n    long_description = f.read()\n\nsetup(\n    name='condconv',\n    version='1.0.0',\n    packages=find_packages(),\n    long_description=long_description,\n    long_description_content_type='text/markdown',\n    install_requires=required,\n    url='https://github.com/nibuiro/CondConv-pytorch',\n    license='MIT',\n    author='nibuiro',\n    author_email='immay1999@gmail.com',\n    description='Implementation of CondConv: Conditionally Parameterized Convolutions for Efficient Inference in PyTorch.'\n)\n"
  }
]