Full Code of blue-season/pywarm for AI

master 92b9d85bee2d cached
24 files
99.2 KB
29.3k tokens
109 symbols
1 requests
Download .txt
Repository: blue-season/pywarm
Branch: master
Commit: 92b9d85bee2d
Files: 24
Total size: 99.2 KB

Directory structure:
gitextract_uvlwuzum/

├── .gitignore
├── CONTRIBUTING.md
├── LICENSE.md
├── README.md
├── docs/
│   ├── example.md
│   ├── text.mako
│   └── tutorial.md
├── examples/
│   ├── efficientnet.py
│   ├── lstm.py
│   ├── mnist.py
│   ├── mobilenet.py
│   ├── resnet.py
│   └── transformer.py
├── pyproject.toml
├── tests/
│   ├── test_engine.py
│   ├── test_functional.py
│   ├── test_module.py
│   ├── test_util.py
│   └── test_warm.py
└── warm/
    ├── __init__.py
    ├── engine.py
    ├── functional.py
    ├── module.py
    └── util.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# Auto-generated content above this. Manually added content below.
.vscode/
.cache/
*cache*
/.project
tmp/
data/
site/


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to PyWarm

PyWarm is developed on [GitHub](https://github.com/blue-season/pywarm). 

Please use GitHub to file Bug reports and submit pull requests. 

Please document and test before submissions.

PyWarm is developed with Python 3.7, but has been tested to work with Python 3.6+.

# Coding Style

For the rational behind the distinct coding style use in PyWarm, please check

[A Coding Style for Python](https://blue-season.github.io/a-coding-style-for-python/).


================================================
FILE: LICENSE.md
================================================
MIT License

Copyright (c) 2019 blue-season

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================

[![PyWarm - A cleaner way to build neural networks for PyTorch](https://github.com/blue-season/pywarm/raw/gh-pages/docs/pywarm-logo.png)](https://blue-season.github.io/pywarm/)

# PyWarm

A cleaner way to build neural networks for PyTorch.

[![PyPI Python Version](https://img.shields.io/pypi/pyversions/pywarm)](https://github.com/blue-season/pywarm)
[![PyPI Version](https://img.shields.io/pypi/v/pywarm)](https://pypi.org/project/pywarm/)
[![License](https://img.shields.io/github/license/blue-season/pywarm)](https://github.com/blue-season/pywarm/blob/master/LICENSE)

[Examples](https://blue-season.github.io/pywarm/docs/example/)  |  [Tutorial](https://blue-season.github.io/pywarm/docs/tutorial/)  |   [API reference](https://blue-season.github.io/pywarm/reference/warm/functional/)

----

## Introduction

PyWarm is a lightweight, high-level neural network construction API for PyTorch.
It enables defining all parts of NNs in the functional way.

With PyWarm, you can put *all* network data flow logic in the `forward()` method of
your model, without the need to define children modules in the `__init__()` method
and then call it again in the `forward()`.
This result in a much more readable model definition in fewer lines of code.

PyWarm only aims to simplify the network definition, and does not attempt to cover
model training, validation or data handling.

----

For example, a convnet for MNIST:
(If needed, click the tabs to switch between Warm and Torch versions)


``` Python tab="Warm" linenums="1"
# powered by PyWarm
import torch.nn as nn
import torch.nn.functional as F
import warm
import warm.functional as W


class ConvNet(nn.Module):

    def __init__(self):
        super().__init__()
        warm.up(self, [2, 1, 28, 28])

    def forward(self, x):
        x = W.conv(x, 20, 5, activation='relu')
        x = F.max_pool2d(x, 2)
        x = W.conv(x, 50, 5, activation='relu')
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 800)
        x = W.linear(x, 500, activation='relu')
        x = W.linear(x, 10)
        return F.log_softmax(x, dim=1)
```

``` Python tab="Torch" linenums="1"
# vanilla PyTorch version, taken from
# pytorch tutorials/beginner_source/blitz/neural_networks_tutorial.py 
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
```

----

A couple of things you may have noticed:

-   First of all, in the PyWarm version, the entire network definition and
    data flow logic resides in the `forward()` method. You don't have to look
    up and down repeatedly to understand what `self.conv1`, `self.fc1` etc.
    is doing.

-   You do not need to track and specify `in_channels` (or `in_features`, etc.)
    for network layers. PyWarm can infer the information for you. e.g.

```Python
# Warm
x = W.conv(x, 20, 5, activation='relu')
x = W.conv(x, 50, 5, activation='relu')


# Torch
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
```

-   One unified `W.conv` for all 1D, 2D, and 3D cases. Fewer things to keep track of!

-   `activation='relu'`. All `warm.functional` APIs accept an optional `activation` keyword,
    which is basically equivalent to `F.relu(W.conv(...))`. The keyword `activation` can also 
    take in a callable, for example `activation=torch.nn.ReLU(inplace=True)` or `activation=swish`.

For deeper neural networks, see additional [examples](https://blue-season.github.io/pywarm/docs/example/).

----
## Installation

    pip3 install pywarm

----
## Quick start: 30 seconds to PyWarm

If you already have experinces with PyTorch, using PyWarm is very straightforward:

-   First, import PyWarm in you model file:
```Python
import warm
import warm.functional as W
```

-   Second, remove child module definitions in the model's `__init__()` method.
    In stead, use `W.conv`, `W.linear` ... etc. in the model's `forward()` method,
    just like how you would use torch nn functional `F.max_pool2d`, `F.relu` ... etc.

    For example, instead of writing:

```Python
# Torch
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size)
        # other child module definitions
    def forward(self, x):
        x = self.conv1(x)
        # more forward steps
```

-   You can now write in the warm way:

```Python
# Warm
class MyWarmModule(nn.Module):
    def __init__(self):
        super().__init__()
        warm.up(self, input_shape_or_data)
    def forward(self, x):
        x = W.conv(x, out_channels, kernel_size) # no in_channels needed
        # more forward steps
```

-   Finally, don't forget to warmify the model by adding
    
    `warm.up(self, input_shape_or_data)`

    at the end of the model's `__init__()` method. You need to supply
    `input_shape_or_data`, which is either a tensor of input data, 
    or just its shape, e.g. `[2, 1, 28, 28]` for MNIST inputs.
    
    The model is now ready to use, just like any other PyTorch models.

Check out the [tutorial](https://blue-season.github.io/pywarm/docs/tutorial/) 
and [examples](https://blue-season.github.io/pywarm/docs/example/) if you want to learn more!

----
## Testing

Clone the repository first, then

    cd pywarm
    pytest -v

----
## Documentation

Documentations are generated using the excellent [Portray](https://timothycrosley.github.io/portray/) package.

-   [Examples](https://blue-season.github.io/pywarm/docs/example/)

-   [Tutorial](https://blue-season.github.io/pywarm/docs/tutorial/) 

-   [API reference](https://blue-season.github.io/pywarm/reference/warm/functional/)


================================================
FILE: docs/example.md
================================================

# PyWarm Examples

## ResNet

A more detailed example, the ResNet18 network defined in PyWarm and vanilla PyTorch:

``` Python tab="Warm" linenums="1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import warm
import warm.functional as W


def basic(x, size, stride):
    y = W.conv(x, size, 3, stride=stride, padding=1, bias=False)
    y = W.batch_norm(y, activation='relu')
    y = W.conv(y, size, 3, stride=1, padding=1, bias=False)
    y = W.batch_norm(y)
    if y.shape[1] != x.shape[1]: # channel size mismatch, needs projection
        x = W.conv(x, y.shape[1], 1, stride=stride, bias=False)
        x = W.batch_norm(x)
    y = y+x # residual shortcut connection
    return F.relu(y)


def stack(x, num_block, size, stride, block=basic):
    for s in [stride]+[1]*(num_block-1):
        x = block(x, size, s)
    return x


class ResNet(nn.Module):

    def __init__(self, block=basic,
            stack_spec=((2, 64, 1), (2, 128, 2), (2, 256, 2), (2, 512, 2))):
        super().__init__()
        self.block = block
        self.stack_spec = stack_spec
        warm.up(self, [2, 3, 32, 32])

    def forward(self, x):
        y = W.conv(x, 64, 7, stride=2, padding=3, bias=False)
        y = W.batch_norm(y, activation='relu')
        y = F.max_pool2d(y, 3, stride=2, padding=1)
        for spec in self.stack_spec:
            y = stack(y, *spec, block=self.block)
        y = F.adaptive_avg_pool2d(y, 1)
        y = torch.flatten(y, 1)
        y = W.linear(y, 1000)
        return y


resnet18 = ResNet()
```

``` Python tab="Torch" linenums="1"
# code based on torchvision/models/resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv3x3(size_in, size_out, stride=1):
    return nn.Conv2d(size_in, size_out, kernel_size=3, stride=stride,
        padding=1, groups=1, bias=False, dilation=1, )


def conv1x1(size_in, size_out, stride=1):
    return nn.Conv2d(size_in, size_out, kernel_size=1, stride=stride,
        padding=0, groups=1, bias=False, dilation=1, )


class BasicBlock(nn.Module):

    expansion = 1

    def __init__(self, size_in, size_out, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(size_in, size_out, stride)
        self.bn1 = nn.BatchNorm2d(size_out)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(size_out, size_out)
        self.bn2 = nn.BatchNorm2d(size_out)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        if self.downsample is not None:
            identity = self.downsample(x)
        y += identity
        y = self.relu(y)
        return y


class ResNet(nn.Module):

    def __init__(self,
            block=BasicBlock, num_block=[2, 2, 2, 2]):
        super().__init__()
        self.size_in = 64
        self.conv1 = nn.Conv2d(3, self.size_in, kernel_size=7, stride=2,
            padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.size_in)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.stack1 = self._make_stack(block, 64, num_block[0], 1)
        self.stack2 = self._make_stack(block, 128, num_block[1], 2)
        self.stack3 = self._make_stack(block, 256, num_block[2], 2)
        self.stack4 = self._make_stack(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, 1000)

    def _make_stack(self, block, size_out, num_blocks, stride):
        downsample = None
        if stride != 1 or self.size_in != size_out:
            downsample = nn.Sequential(
                conv1x1(self.size_in, size_out, stride),
                nn.BatchNorm2d(size_out), )
        stacks = []
        for stride in strides:
            stacks.append(
                block(self.size_in, size_out, stride, downsample))
            self.size_in = size_out
        return nn.Sequential(*stacks)

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.maxpool(y)
        y = self.stack1(y)
        y = self.stack2(y)
        y = self.stack3(y)
        y = self.stack4(y)
        y = self.avg_pool(y)
        y = torch.flatten(y, 1)
        y = self.fc(y)
        return y


resnet18 = ResNet()
```

-   The PyWarm version significantly reduces self-repititions of code as in the vanilla PyTorch version.

-   Note that when warming the model via `warm.up(self, [2, 3, 32, 32])`
    We set the first `Batch` dimension to 2 because the model uses `batch_norm`,
    which will not work when `Batch` is 1.

----

## MobileNet

``` Python tab="Warm" linenums="1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import warm
import warm.functional as W


def conv_bn_relu(x, size, stride=1, expand=1, kernel=3, groups=1):
    x = W.conv(x, size, kernel, padding=(kernel-1)//2,
        stride=stride, groups=groups, bias=False, )
    return W.batch_norm(x, activation='relu6')


def bottleneck(x, size_out, stride, expand):
    size_in = x.shape[1]
    size_mid = size_in*expand
    y = conv_bn_relu(x, size_mid, kernel=1) if expand > 1 else x
    y = conv_bn_relu(y, size_mid, stride, kernel=3, groups=size_mid)
    y = W.conv(y, size_out, kernel=1, bias=False)
    y = W.batch_norm(y)
    if stride == 1 and size_in == size_out:
        y += x # residual shortcut
    return y


def conv1x1(x, *arg):
    return conv_bn_relu(x, *arg, kernel=1)


def pool(x, *arg):
    return x.mean([2, 3])


def classify(x, size, *arg):
    x = W.dropout(x, rate=0.2)
    return W.linear(x, size)


default_spec = (
    (None, 32, 1, 2, conv_bn_relu),  # t, c, n, s, operator
    (1, 16, 1, 1, bottleneck),
    (6, 24, 2, 2, bottleneck),
    (6, 32, 3, 2, bottleneck),
    (6, 64, 4, 2, bottleneck),
    (6, 96, 3, 1, bottleneck),
    (6, 160, 3, 2, bottleneck),
    (6, 320, 1, 1, bottleneck),
    (None, 1280, 1, 1, conv1x1),
    (None, None, 1, None, pool),
    (None, 1000, 1, None, classify), )


class MobileNetV2(nn.Module):

    def __init__(self):
        super().__init__()
        warm.up(self, [2, 3, 224, 224])
        
    def forward(self, x):
        for t, c, n, s, op in default_spec:
            for i in range(n):
                stride = s if i == 0 else 1
                x = op(x, c, stride, t)
        return x


net = MobileNetV2()
```

``` Python tab="Torch" linenums="1"
# code based on torchvision/models/mobilenet.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBNReLU(nn.Sequential):

    def __init__(self, in_planes, out_planes, 
            kernel_size=3, stride=1, groups=1):
        padding = (kernel_size-1)//2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, 
                stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True), )


class BottleNeck(nn.Module):

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup
        layers = []
        if expand_ratio != 1:
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            ConvBNReLU(hidden_dim, hidden_dim, 
                stride=stride, groups=hidden_dim),
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup), ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


default_spec = [
    [1, 16, 1, 1], # t, c, n, s
    [6, 24, 2, 2],
    [6, 32, 3, 2],
    [6, 64, 4, 2],
    [6, 96, 3, 1],
    [6, 160, 3, 2],
    [6, 320, 1, 1], ]


class MobileNetV2(nn.Module):

    def __init__(self):
        super().__init__()
        input_channel = 32
        last_channel = 1280
        features = [ConvBNReLU(3, input_channel, stride=2)]
        for t, c, n, s in default_spec:
            output_channel = c
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(
                    BottleNeck(
                        input_channel, output_channel,
                        stride, expand_ratio=t))
                input_channel = output_channel
        features.append(ConvBNReLU(input_channel, 
            last_channel, kernel_size=1))
        self.features = nn.Sequential(*features)
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(last_channel, 1000), )

    def forward(self, x):
        x = self.features(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x


net = MobileNetV2()
```

## Transformer

```Python
"""
The Transformer model from paper Attention is all you need.
The Transformer instance accepts two inputs:
x is Tensor with shape (Batch, Channel, LengthX).
    usually a source sequence from embedding (in such cases,
    Channel equals the embedding size).
y is Tensor with shape (Batch, Channel, lengthY).
    usually a target sequence, also from embedding.
**kw is passed down to inner components.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import warm
import warm.functional as W


def multi_head_attention(x, y=None, num_head=8, dropout=0.1, mask=None, **kw):
    def split_heads(t):
        return t.reshape(batch, num_head, size//num_head, t.shape[-1])
    def merge_heads(t):
        return t.reshape(batch, -1, t.shape[-1])
    if y is None:
        y = x # self attention
    batch, size = x.shape[:2]
    assert size%num_head == 0, 'num_head must be a divisor of size.'
    assert y.shape[:2] == x.shape[:2], 'The first 2 dims of x, y must match.'
    q = W.linear(x, size) # query
    k = W.linear(y, size) # key
    v = W.linear(y, size) # value
    q = split_heads(q)
    k = split_heads(k)
    v = split_heads(v)
    q *= (size//num_head)**(-0.5)
    a = q.transpose(2, 3).contiguous().matmul(k) # attention weights
    if mask is not None:
        a += mask
    a = F.softmax(a, dim=-1)
    a = W.dropout(a, dropout)
    x = v.matmul(a.transpose(2, 3).contiguous())
    x = merge_heads(x)
    return W.linear(x, size)


def feed_forward(x, size_ff=2048, dropout=0.1, **kw):
    y = W.linear(x, size_ff, activation='relu')
    y = W.dropout(y, dropout)
    return W.linear(y, x.shape[1])


def residual_add(x, layer, dropout=0.1, **kw):
    y = W.layer_norm(x)
    y = layer(y, **kw)
    y = W.dropout(y, dropout)
    return x+y


def encoder(x, num_encoder=6, **kw):
    for i in range(num_encoder):
        x = residual_add(x, multi_head_attention, **kw)
        x = residual_add(x, feed_forward, **kw)
    return W.layer_norm(x)


def decoder(x, y, num_decoder=6, mask_x=None, mask_y=None, **kw):
    for i in range(num_decoder):
        y = residual_add(y, multi_head_attention, mask=mask_y, **kw)
        y = residual_add(x, multi_head_attention, y=y, mask=mask_x, **kw)
        y = residual_add(y, feed_forward, **kw)
    return W.layer_norm(y)


def transformer(x, y, **kw):
    x = encoder(x, **kw)
    x = decoder(x, y, **kw)
    return x


class Transformer(nn.Module):

    def __init__(self, *shape, **kw):
        super().__init__()
        self.kw = kw
        warm.up(self, *shape)
        
    def forward(self, x, y):
        return transformer(x, y, **self.kw)

```


## EfficientNet

For a brief overview, check the [blog post](https://blue-season.github.io/efficientnet-in-5-minutes/).

```python
"""
EfficientNet model from https://arxiv.org/abs/1905.11946
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import warm
import warm.functional as W


def swish(x):
    return x*torch.sigmoid(x)


def squeeze_excitation(x, size_se):
    if size_se == 0:
        return x
    size_in = x.shape[1]
    x = F.adaptive_avg_pool2d(x, 1)
    x = W.conv(x, size_se, 1, activation=swish)
    return W.conv(x, size_in, 1, activation=swish)


def drop_connect(x, rate):
    if rate == 0:
        return x
    rate = 1.0-rate
    drop_mask = rate + torch.rand([x.shape[0], 1, 1, 1],
        device=x.device, requires_grad=False)
    return x/rate*drop_mask.floor()


def conv_pad_same(x, size, kernel=1, stride=1, **kw):
    """ Same padding so that out_size*stride == in_size. """
    pad = 0
    if kernel != 1 or stride != 1:
        in_size, s, k = [torch.as_tensor(v)
            for v in (x.shape[2:], stride, kernel)]
        pad = torch.max(((in_size+s-1)//s-1)*s+k-in_size, torch.tensor(0))
        left, right = pad//2, pad-pad//2
        if torch.all(left == right):
            pad = tuple(left.tolist())
        else:
            left, right = left.tolist(), right.tolist()
            pad = sum(zip(left[::-1], right[::-1]), ())
            x = F.pad(x, pad)
            pad = 0
    return W.conv(x, size, kernel, stride=stride, padding=pad, **kw)


def conv_bn_act(x, size, kernel=1, stride=1, groups=1, 
        bias=False, eps=1e-3, momentum=1e-2, act=swish):
    x = conv_pad_same(x, size, kernel, stride=stride, groups=groups, bias=bias)
    return W.batch_norm(x, eps=eps, momentum=momentum, activation=act)


def mb_block(x, size_out, expand=1, kernel=1, stride=1,
        se_ratio=0.25, dc_ratio=0.2):
    """ Mobilenet Bottleneck Block. """
    size_in = x.shape[1]
    size_mid = size_in*expand
    y = conv_bn_act(x, size_mid, 1) if expand > 1 else x
    y = conv_bn_act(y, size_mid, kernel, stride=stride, groups=size_mid)
    y = squeeze_excitation(y, int(size_in*se_ratio))
    y = conv_bn_act(y, size_out, 1, act=None)
    if stride == 1 and size_in == size_out:
        y = drop_connect(y, dc_ratio)
        y += x
    return y


spec_b0 = (
# size, expand, kernel, stride, repeat, squeeze_excitation, drop_connect
    (16, 1, 3, 1, 1, 0.25, 0.2),
    (24, 6, 3, 2, 2, 0.25, 0.2),
    (40, 6, 5, 2, 2, 0.25, 0.2),
    (80, 6, 3, 2, 3, 0.25, 0.2),
    (112, 6, 5, 1, 3, 0.25, 0.2),
    (192, 6, 5, 2, 4, 0.25, 0.2),
    (320, 6, 3, 1, 1, 0.25, 0.2), )


class WarmEfficientNet(nn.Module):
    def __init__(self):
        super().__init__()
        warm.up(self, [2, 3, 32, 32])
    def forward(self, x):
        x = conv_bn_act(x, 32, kernel=3, stride=2)
        for size, expand, kernel, stride, repeat, se, dc in spec_b0:
            for i in range(repeat):
                stride = stride if i == 0 else 1
                x = mb_block(x, size, expand, kernel, stride, se, dc)
        x = conv_bn_act(x, 1280)
        x = F.adaptive_avg_pool2d(x, 1)
        x = W.dropout(x, 0.2)
        x = x.view(x.shape[0], -1)
        x = W.linear(x, 1000)
        return x
```


================================================
FILE: docs/text.mako
================================================
## Define mini-templates for each portion of the doco.

<%!
  def indent(s, spaces=4):
      new = s.replace('\n', '\n' + ' ' * spaces)
      return ' ' * spaces + new.strip()
%>

<%def name="deflist(s)">:${indent(s)[1:]}</%def>

<%def name="h3(s)">### ${s}
</%def>

<%def name="function(func)" buffered="True">
    <%
        returns = show_type_annotations and func.return_annotation() or ''
        if returns:
            returns = ' -> ' + returns
    %>
${"---"}
${"### " + func.name}


```python3
def :
    ${",\n  ".join(func.params(annotate=show_type_annotations))} ${returns}
```
${func.docstring}

% if show_source_code and func.source and func.obj is not getattr(func.inherits, 'obj', None):

??? example "View Source"
        ${"\n        ".join(func.source.split("\n"))}

% endif
</%def>

<%def name="variable(var)" buffered="True">
```python3
${var.name}
```
${var.docstring | deflist}
</%def>

<%def name="class_(cls)" buffered="True">
${"---"}
${"### " + cls.name}

```python3
def :
    ${",\n  ".join(cls.params(annotate=show_type_annotations))}
```

${cls.docstring}

% if show_source_code and cls.source:

??? example "View Source"
        ${"\n        ".join(cls.source.split("\n"))}

------

% endif

<%
  class_vars = cls.class_variables(show_inherited_members, sort=sort_identifiers)
  static_methods = cls.functions(show_inherited_members, sort=sort_identifiers)
  inst_vars = cls.instance_variables(show_inherited_members, sort=sort_identifiers)
  methods = cls.methods(show_inherited_members, sort=sort_identifiers)
  mro = cls.mro()
  subclasses = cls.subclasses()
%>
% if mro:
${h3('Ancestors (in MRO)')}
    % for c in mro:
* ${c.refname}
    % endfor
% endif

% if subclasses:
${h3('Descendants')}
    % for c in subclasses:
* ${c.refname}
    % endfor
% endif

% if class_vars:
${h3('Class variables')}
    % for v in class_vars:
${variable(v)}

    % endfor
% endif

% if static_methods:
${h3('Static methods')}
    % for f in static_methods:
${function(f)}

    % endfor
% endif

% if inst_vars:
${h3('Instance variables')}
% for v in inst_vars:
${variable(v)}

% endfor
% endif
% if methods:
${h3('Methods')}
    % for m in methods:
${function(m)}

% endfor
% endif

</%def>

## Start the output logic for an entire module.

<%
  variables = module.variables()
  classes = module.classes()
  functions = module.functions()
  submodules = module.submodules()
  heading = 'Namespace' if module.is_namespace else 'Module'
%>

${"# " + heading} ${module.name}

${module.docstring}

% if show_source_code:

??? example "View Source"
        ${"\n        ".join(module.source.split("\n"))}

% endif


% if submodules:
Sub-modules
-----------
    % for m in submodules:
* [${m.name}](${m.name.split(".")[-1]}/)
    % endfor
% endif

% if variables:
Variables
---------
    % for v in variables:
${variable(v)}

    % endfor
% endif

% if functions:
Functions
---------
    % for f in functions:
${function(f)}

    % endfor
% endif

% if classes:
Classes
-------
    % for c in classes:
${class_(c)}

    % endfor
% endif


================================================
FILE: docs/tutorial.md
================================================

# PyWarm Basic Tutorial

## Import

To get started, first import PyWarm in your project:

```Python
import warm
import warm.functional as W
```

## Rewrite

Now you can replace child module definitions with function calls. 
For example, instead of:

```Python
# Torch
class MyModule(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size)
        # other child module definitions

    def forward(self, x):
        x = self.conv1(x)
        # more forward steps
```

You now use the warm functions:

```Python
# Warm
class MyWarmModule(nn.Module):

    def __init__(self):
        super().__init__()
        warm.up(self, input_shape_or_data)

    def forward(self, x):
        x = W.conv(x, out_channels, kernel_size) # no in_channels needed
        # more forward steps
```

Notice the `warm.up(self, input_shape_or_data)` at the end of the `__init__()` method.
It is required so that PyWarm can infer all shapes of itermediate steps and set up trainable parameters.
The only argument `input_shape_or_data` can either be a tensor, e.g. `torch.randn(2, 1, 28, 28)`,
or just the shape, e.g. `[2, 1, 28, 28]` for the model inputs. If the model has multiple inputs,
you may supple them in a list or a dictionary.

Although it is recommended that you attach `warm.up()` to the end of the `__init__()` of your model, you can actually
use it on the class instances outside of the definition, like a normal function call:

```Python
class MyWarmModule(nn.Module):

    def __init__(self):
        super().__init__() # no warm.up here

    def forward(self, x):
        x = W.conv(x, 10, 3)
        # forward step, powered by PyWarm


model = MyWarmModule() # call warm.up outside of the module definition

warm.up(model, [2, 1, 28, 28])
```

**Note**: If the model contains `batch_norm` layers, you need to specify the `Batch` dimension to at least 2.

# Advanced Topics

## Default shapes

PyWarm has a unified functional interface, that by default all functions accept and return tensors with shape
`(Batch, Channel, *)`, where `*` is any number of additional dimensions. For example, for 2d images,
the `*` usually stands for `(Height, Width)`, and for 1d time series, the `*` means `(Time,)`.

This convention is optimized for the performance of Convolutional networks. It may become less efficient if your
model relies heavily on dense (Linear) or recurrent (LSTM, GRU) layers. You can use different input and
output shapes by specifying `in_shape`, `out_shape` keyword arguments in the function calls. These keywords
accept only letters `'B'`, `'C'` and `'D'`, which stand for `Batch`, `Channel`, and `*` (extra Dimensions)
respectively. So for example if for a 1d time series you want to have `(Time, Batch, Channel)` as the output shape,
you can specify `out_shape='DBC'`.

## Dimensional awareness

PyWarm functions can automatically identify 1d, 2d and 3d input data, so the same function can be used on different
dimensional cases. For example, the single `W.conv` is enough to replace `nn.Conv1d, nn.Conv2d, nn.Conv3d`.
Similarly, you don't need `nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d` for differnt inputs, a single `W.batch_norm`
can replace them all.

## Shape inference

Many neural network layers will perform a transformation of shapes. For example, after a convolution operation,
the shape is changed from `(Batch, ChannelIn, *)` to `(Batch, ChannelOut, *)`. PyTorch nn Modules require the user to 
keep track of both `in_channels` and `out_channels`. PyWarm relieves this pain by inferring the `in_channels` for you,
so you can focus more on the nature of your tasks, rather than chores.

## Argument passdown

If the signature of a PyWarm function does not specify all possible argument of its torch nn Module couterpart, it will pass down
additional keyword arguments to the underlying nn Module. For example, if you want to specify strides of 2 for a conv layer,
just use `W.conv(..., stride=2)`. The only thing to remember is that you have to specify the full keyword, instead of 
relying on the position of arguments.

## Parameter initialization per layer

Unlike PyTorch's approach, paramter initialization can be specified directly in PyWarm's functional interface.
For example:

```Python
x = W.conv(x, 20, 1, init_weight='kaiming_uniform_')
```
This makes it easier to create layer specific initialization in PyWarm. You no long need to go through
`self.modules()` and `self.parameters()` to create custom initializations.

By default, PyWarm will look into `torch.nn.init` for initialization function names.
Alternatively, you may just specify a callable, or a tuple `(fn, kwargs)` if the callable accepts more than 1 input.

If the initialization is not specified or `None` is used, the corresponding layer will get default initializations as used
in torch nn modules. 

## Apply activation nonlinearity to the output

PyWarm's functional interface supports adding an optional keyword argument `activation=name`, where
name is a callable or just its name, which represents an activation (nonlinearity) function
in `torch.nn.functional` or just `torch`. By default no activation is used.

## Mix and Match

You are not limited to only use PyWarm's functional interface. It is completely ok to mix and match the old
PyTorch way of child module definitions with PyWarm's function API. For example:

```Python
class MyModel(nn.Module):

    def __init__(self):
        super().__init__()
        # other stuff
        self.conv1 = nn.Conv2d(2, 30, 7, padding=3)
        # other stuff

    def forward(self, x):
        y = F.relu(self.conv1(x))
        y = W.conv(y, 40, 3, activation='relu')
```

## Custom layer names

Normally you do not have to specify layer names when using the functional API.
PyWarm will track and count usage for the layer type and automatically assign names for you. For example,
subsequent convolutional layer calls via `W.conv` will create `conv_1`, `conv_2`, ... etc. in the parent module.

Nevertheless, if you want to ensure certain layer have particular names, you can specify `name='my_name'`
keyword arguments in the call.

Alternatively, if you still want PyWarm to count usage and increment ordinal for you, but only want to customize 
the base type name, you can use `base_name='my_prefix'` keyword instead. The PyWarm modules will then have
names like `my_prefix_1`, `my_prefix_2` in the parent module.

See the PyWarm [resnet example in the examples folder](https://github.com/blue-season/pywarm/blob/master/examples/resnet.py)
on how to use these features to load pre-trained model parameters into PyWarm models.


================================================
FILE: examples/efficientnet.py
================================================

# 09-20-2019;
"""
EfficientNet
"""
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent))
sys.path.append('..')
import torch
import torch.nn as nn
import torch.nn.functional as F
import warm
import warm.util
import warm.functional as W
from warm.engine import namespace


def swish(x):
    return x*torch.sigmoid(x)


def conv_pad_same(x, size, kernel=1, stride=1, **kw):
    pad = 0
    if kernel != 1 or stride != 1:
        in_size, s, k = [torch.as_tensor(v) for v in (x.shape[2:], stride, kernel)]
        pad = torch.max(((in_size+s-1)//s-1)*s+k-in_size, torch.tensor(0))
        left, right = pad//2, pad-pad//2
        if torch.all(left == right):
            pad = tuple(left.tolist())
        else:
            left, right = left.tolist(), right.tolist()
            pad = sum(zip(left[::-1], right[::-1]), ())
            x = F.pad(x, pad)
            pad = 0
    return W.conv(x, size, kernel, stride=stride, padding=pad, **kw)


@namespace
def conv_bn_act(x, size, kernel=1, stride=1, groups=1, bias=False, eps=1e-3, momentum=1e-2, act=swish, name='', **kw):
    x = conv_pad_same(x, size, kernel, stride=stride, groups=groups, bias=bias, name=name+'-conv')
    return W.batch_norm(x, eps=eps, momentum=momentum, activation=act, name=name+'-bn')


@namespace
def mb_block(x, size_out, expand=1, kernel=1, stride=1, se_ratio=0.25, dc_ratio=0.2, **kw):
    """ MobileNet Bottleneck Block. """
    size_in = x.shape[1]
    size_mid = size_in*expand
    y = conv_bn_act(x, size_mid, 1, **kw) if expand > 1 else x
    y = conv_bn_act(y, size_mid, kernel, stride=stride, groups=size_mid, **kw)
    y = squeeze_excitation(y, int(size_in*se_ratio), **kw)
    y = conv_bn_act(y, size_out, 1, act=None, **kw)
    if stride == 1 and size_in == size_out:
        y = drop_connect(y, dc_ratio)
        y += x
    return y


@namespace
def squeeze_excitation(x, size_se, name='', **kw):
    if size_se == 0:
        return x
    size_in = x.shape[1]
    x = F.adaptive_avg_pool2d(x, 1)
    x = W.conv(x, size_se, 1, activation=swish, name=name+'-conv1')
    return W.conv(x, size_in, 1, activation=swish, name=name+'-conv2')


def drop_connect(x, rate):
    """ Randomly set entire batch to 0. """
    if rate == 0:
        return x
    rate = 1.0-rate
    drop_mask = torch.rand([x.shape[0], 1, 1, 1], device=x.device, requires_grad=False)+rate
    return x/rate*drop_mask.floor()


spec_b0 = (
    (16, 1, 3, 1, 1, 0.25, 0.2), # size, expand, kernel, stride, repeat, se_ratio, dc_ratio
    (24, 6, 3, 2, 2, 0.25, 0.2),
    (40, 6, 5, 2, 2, 0.25, 0.2),
    (80, 6, 3, 2, 3, 0.25, 0.2),
    (112, 6, 5, 1, 3, 0.25, 0.2),
    (192, 6, 5, 2, 4, 0.25, 0.2),
    (320, 6, 3, 1, 1, 0.25, 0.2), )


class WarmEfficientNet(nn.Module):
    def __init__(self):
        super().__init__()
        warm.up(self, [2, 3, 32, 32])
    def forward(self, x):
        x = conv_bn_act(x, 32, kernel=3, stride=2, name='head')
        for size, expand, kernel, stride, repeat, se_ratio, dc_ratio in spec_b0:
            for i in range(repeat):
                stride = stride if i == 0 else 1
                x = mb_block(x, size, expand, kernel, stride, se_ratio, dc_ratio)
        x = conv_bn_act(x, 1280, name='tail')
        x = F.adaptive_avg_pool2d(x, 1)
        x = W.dropout(x, 0.2)
        x = x.view(x.shape[0], -1)
        x = W.linear(x, 1000)
        return x


if __name__ == '__main__':
    m = WarmEfficientNet()
    warm.util.summary(m)


================================================
FILE: examples/lstm.py
================================================
# 09-07-2019;
"""
LSTM sequence model example, based on
https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html
"""
import argparse
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent))
sys.path.append('..')
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import warm
import warm.functional as W


training_data = [
    ('The dog ate the apple'.split(), ['DET', 'NN', 'V', 'DET', 'NN']),
    ('Everybody read that book'.split(), ['NN', 'V', 'DET', 'NN']), ]
testing_data = [('The dog ate the book'.split(), ['DET', 'NN', 'V', 'DET', 'NN'])]
word_to_ix = {}
for sent, tags in training_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
tag_to_ix = {'DET': 0, 'NN': 1, 'V': 2}
ix_to_tag = {v:k for k, v in tag_to_ix.items()}


class WarmTagger(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super().__init__()
        self.arg = (embedding_dim, hidden_dim, vocab_size, tagset_size)
        warm.up(self, torch.tensor([0, 1], dtype=torch.long))
    def forward(self, x): # D
        embedding_dim, hidden_dim, vocab_size, tagset_size = self.arg
        y = W.embedding(x, embedding_dim, vocab_size) # D->DC
        y = W.lstm(y.T[None, ...], hidden_dim) # DC->BCD
        y = W.linear(y, tagset_size) # BCD
        y = F.log_softmax(y, dim=1) # BCD
        return y[0].T # DC


class TorchTagger(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores


def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)


def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument(
        '--warm', action='store_true', help='use warm instead of vanilla pytorch.')
    p = parser.parse_args()
    torch.manual_seed(1)
    #
    arg = (6, 6, len(word_to_ix), len(tag_to_ix))
    model = WarmTagger(*arg) if p.warm else TorchTagger(*arg)
    print(f'Using {model._get_name()}.')
    loss_function = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    #
    for epoch in range(300):
        for sentence, tags in training_data:
            model.zero_grad()
            sentence_in = prepare_sequence(sentence, word_to_ix)
            targets = prepare_sequence(tags, tag_to_ix)
            tag_scores = model(sentence_in)
            loss = loss_function(tag_scores, targets)
            loss.backward()
            optimizer.step()
    #
    with torch.no_grad():
        inputs = prepare_sequence(testing_data[0][0], word_to_ix)
        tag_scores = model(inputs)
        ix = torch.argmax(tag_scores, -1).numpy()
        print(testing_data[0][0])
        print('Network tags:\n', [ix_to_tag[i] for i in ix])
        print('True tags:\n', testing_data[0][1])


if __name__ == '__main__':
    main()


================================================
FILE: examples/mnist.py
================================================
# 08-27-2019;
"""
MNIST training example.
Use `python mnist.py` to run with PyTorch NN.
Use `python mnist.py --warm` to run with PyWarm NN.
Use `python mnist.py --help` to see a list of cli argument options.
"""
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent))
sys.path.append('..')
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import warm
import warm.functional as W


class WarmNet(nn.Module):
    def __init__(self):
        super().__init__()
        warm.up(self, [1, 1, 28, 28])
    def forward(self, x):
        x = W.conv(x, 20, 5, activation='relu')
        x = F.max_pool2d(x, 2)
        x = W.conv(x, 50, 5, activation='relu')
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 800)
        x = W.linear(x, 500, activation='relu')
        x = W.linear(x, 10)
        return F.log_softmax(x, dim=1)


class TorchNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(p, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx%p.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx*len(data), len(train_loader.dataset),
                100.*batch_idx/len(train_loader), loss.item()))


def test(p, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    size = len(test_loader.dataset)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= size
    print(f'\nTest loss: {test_loss:.4f}, Accuracy: {correct}/{size} ({100*correct/size:.2f}%)\n')


def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument(
        '--warm', action='store_true', help='use warm instead of vanilla pytorch.')
    parser.add_argument(
        '--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default: 128)')
    parser.add_argument(
        '--test-batch-size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)')
    parser.add_argument(
        '--epochs', type=int, default=3, metavar='N', help='number of epochs to train (default: 3)')
    parser.add_argument(
        '--lr', type=float, default=0.02, metavar='LR', help='learning rate (default: 0.02)')
    parser.add_argument(
        '--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)')
    parser.add_argument(
        '--no-cuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument(
        '--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval', type=int, default=10, metavar='N', help='number of batchs between logging training status')
    parser.add_argument(
        '--save-model', action='store_true', default=False, help='For Saving the current Model')
    p = parser.parse_args()
    #
    torch.manual_seed(p.seed)
    use_cuda = not p.no_cuda and torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'
    kw = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), ])
    train_data = datasets.MNIST('../data', train=True, download=True, transform=data_transform)
    test_data = datasets.MNIST('../data', train=False, download=True, transform=data_transform)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=p.batch_size, shuffle=True, **kw)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=p.test_batch_size, shuffle=True, **kw)
    model = WarmNet() if p.warm else TorchNet()
    print(f'Using {model._get_name()}.')
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=p.lr, momentum=p.momentum)
    print(f'Training with {p.epochs} epochs on {device} device.')
    #
    for i in range(p.epochs):
        train(p, model, device, train_loader, optimizer, i)
        test(p, model, device, test_loader)
    #
    if p.save_model:
        torch.save(model.state_dict(), 'mnist_cnn.pt')


if __name__ == '__main__':
    main()


================================================
FILE: examples/mobilenet.py
================================================
# 09-03-2019;
"""
Construct a WarmMobileNetV2() using PyWarm, then copy state dicts
from torchvision.models.mobilenet_v2() into WarmMobileNetV2(),
compare if it produce identical results as the official one.
"""
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent))
sys.path.append('..')
import torch
import torch.nn as nn
import torch.nn.functional as F
import warm
import warm.util
import warm.functional as W


def conv_bn_relu(x, size, stride=1, expand=1, kernel=3, groups=1, name=''):
    x = W.conv(x, size, kernel, padding=(kernel-1)//2, stride=stride, groups=groups, bias=False,
        name=f'{name}-0', )
    return W.batch_norm(x, activation='relu6', name=f'{name}-1')


def bottleneck(x, size_out, stride, expand, name=''):
    size_in = x.shape[1]
    size_mid = size_in*expand
    y = conv_bn_relu(x, size_mid, kernel=1, name=f'{name}-conv-0') if expand > 1 else x
    y = conv_bn_relu(y, size_mid, stride, kernel=3, groups=size_mid, name=f'{name}-conv-{1 if expand > 1 else 0}')
    y = W.conv(y, size_out, kernel=1, bias=False, name=f'{name}-conv-{2 if expand > 1 else 1}')
    y = W.batch_norm(y, name=f'{name}-conv-{3 if expand > 1 else 2}')
    if stride == 1 and size_in == size_out:
        y += x # residual shortcut
    return y


def conv1x1(x, *arg, **kw):
    return conv_bn_relu(x, *arg, kernel=1, **kw)


def pool(x, *arg, **kw):
    return x.mean([2, 3])


def classify(x, size, *arg, **kw):
    x = W.dropout(x, rate=0.2, name='classifier-0')
    return W.linear(x, size, name='classifier-1')


default_spec = (
    (None, 32, 1, 2, conv_bn_relu),  # t, c, n, s, operator
    (1, 16, 1, 1, bottleneck),
    (6, 24, 2, 2, bottleneck),
    (6, 32, 3, 2, bottleneck),
    (6, 64, 4, 2, bottleneck),
    (6, 96, 3, 1, bottleneck),
    (6, 160, 3, 2, bottleneck),
    (6, 320, 1, 1, bottleneck),
    (None, 1280, 1, 1, conv1x1),
    (None, None, 1, None, pool),
    (None, 1000, 1, None, classify), )


class WarmMobileNetV2(nn.Module):
    def __init__(self):
        super().__init__()
        warm.up(self, [2, 3, 224, 224])
    def forward(self, x):
        count = 0
        for t, c, n, s, op in default_spec:
            for i in range(n):
                stride = s if i == 0 else 1
                x = op(x, c, stride, t, name=f'features-{count}')
                count += 1
        return x


def test():
    """ Compare the classification result of WarmMobileNetV2 versus torchvision mobilenet_v2. """
    new = WarmMobileNetV2()
    from torchvision.models import mobilenet_v2
    old = mobilenet_v2()
    state = old.state_dict()
    for k in list(state.keys()): # Map parameters of old, e.g. layer2.0.conv1.weight
        s = k.split('.') # to parameters of new, e.g. layer2-0-conv1.weight
        s = '-'.join(s[:-1])+'.'+s[-1]
        state[s] = state.pop(k)
    new.load_state_dict(state)
    warm.util.summary(old)
    warm.util.summary(new)
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        old.eval()
        y_old = old(x)
        new.eval()
        y_new = new(x)
        if torch.equal(y_old, y_new):
            print('Success! Same results from old and new.')
        else:
            print('Warning! New and old produce different results.')


if __name__ == '__main__':
    test()


================================================
FILE: examples/resnet.py
================================================
# 08-29-2019;
"""
Construct a WarmResNet() using PyWarm, then copy state dicts
from torchvision.models.resnet18() into WarmResNet(),
compare if it produce identical results as the official one.
"""
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent))
sys.path.append('..')
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import warm
import warm.util
import warm.functional as W


def basic(x, size, stride, stack_index, block_index):
    """ The basic block. """
    prefix = f'layer{stack_index+1}-{block_index}-'
    y = W.conv(x, size, 3, stride=stride, padding=1, bias=False, name=prefix+'conv1')
    y = W.batch_norm(y, activation='relu', name=prefix+'bn1')
    y = W.conv(y, size, 3, stride=1, padding=1, bias=False, name=prefix+'conv2')
    y = W.batch_norm(y, name=prefix+'bn2')
    if y.shape[1] != x.shape[1]:
        x = W.conv(x, y.shape[1], 1, stride=stride, bias=False, name=prefix+'downsample-0')
        x = W.batch_norm(x, name=prefix+'downsample-1')
    return F.relu(y+x)


def stack(x, num_block, size, stride, stack_index, block=basic):
    """ A stack of num_block blocks. """
    for block_index, s in enumerate([stride]+[1]*(num_block-1)):
        x = block(x, size, s, stack_index, block_index)
    return x


class WarmResNet(nn.Module):
    def __init__(self, block=basic, stack_spec=((2, 64, 1), (2, 128, 2), (2, 256, 2), (2, 512, 2))):
        super().__init__()
        self.block = block
        self.stack_spec = stack_spec
        warm.up(self, [2, 3, 32, 32])
    def forward(self, x):
        y = W.conv(x, 64, 7, stride=2, padding=3, bias=False, name='conv1')
        y = W.batch_norm(y, activation='relu', name='bn1')
        y = F.max_pool2d(y, 3, stride=2, padding=1)
        for i, spec in enumerate(self.stack_spec):
            y = stack(y, *spec, i, block=self.block)
        y = F.adaptive_avg_pool2d(y, 1)
        y = torch.flatten(y, 1)
        y = W.linear(y, 1000, name='fc')
        return y


def test_time(fn, *arg, repeat=10, **kw):
    dur = 0.0
    for i in range(repeat):
        start = time.time()
        y = fn(*arg, **kw)
        dur += time.time()-start
    return dur


def test():
    """ Compare the classification result of WarmResNet versus torchvision resnet18. """
    new = WarmResNet()
    from torchvision.models import resnet18
    old = resnet18()
    state = old.state_dict()
    for k in list(state.keys()): # Map parameters of old, e.g. layer2.0.conv1.weight
        s = k.split('.') # to parameters of new, e.g. layer2-0-conv1.weight
        s = '-'.join(s[:-1])+'.'+s[-1]
        state[s] = state.pop(k)
    new.load_state_dict(state)
    warm.util.summary(old)
    warm.util.summary(new)
    x = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        old.eval()
        y_old = old(x)
        new.eval()
        y_new = new(x)
        if torch.equal(y_old, y_new):
            print('Success! Same results from old and new.')
        else:
            print('Warning! New and old produce different results.')
        t_old = test_time(old, x)
        t_new = test_time(new, x)
        print('Total forward time for old:', t_old, 'seconds.')
        print('Total forward time for new:', t_new, 'seconds.')


if __name__ == '__main__':
    test()


================================================
FILE: examples/transformer.py
================================================
# 09-05-2019;
"""
The Transformer model from paper *Attention is all you need*.
"""
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent))
sys.path.append('..')
import torch
import torch.nn as nn
import torch.nn.functional as F
import warm
import warm.util
import warm.functional as W


def multi_head_attention(x, y=None, num_head=8, dropout=0.1, mask=None, **kw):
    def split_heads(t): # (B, C, L) -> (B, N, H, L) where N*H == C
        return t.reshape(batch, num_head, size//num_head, t.shape[-1])
    def merge_heads(t): # (B, N, H, L) -> (B, C, L)
        return t.reshape(batch, -1, t.shape[-1]) # (B, C, L)
    if y is None:
        y = x # self attention
    batch, size = x.shape[:2] # B, C, Lx
    assert size%num_head == 0, 'num_head must be a divisor of size.'
    assert y.shape[:2] == x.shape[:2], 'The first 2 dims of x, y must match.'
    q = W.linear(x, size) # query
    k = W.linear(y, size) # key
    v = W.linear(y, size) # value
    q = split_heads(q) # (B, N, H, Lx)
    k = split_heads(k) # (B, N, H, Ly)
    v = split_heads(v) # (B, N, H, Ly)
    q *= (size//num_head)**(-0.5)
    a = q.transpose(2, 3).contiguous().matmul(k) # attention weights, (B, N, Lx, Ly)
    if mask is not None:
        a += mask
    a = F.softmax(a, dim=-1)
    a = W.dropout(a, dropout)
    x = v.matmul(a.transpose(2, 3).contiguous()) # (B, N, H, Lx)
    x = merge_heads(x) # (B, C, Lx)
    return W.linear(x, size)


def feed_forward(x, size_ff=2048, dropout=0.1, **kw):
    y = W.linear(x, size_ff, activation='relu')
    y = W.dropout(y, dropout)
    return W.linear(y, x.shape[1])


def residual_add(x, layer, dropout=0.1, **kw):
    y = W.layer_norm(x)
    y = layer(y, **kw)
    y = W.dropout(y, dropout)
    return x+y


def encoder(x, num_encoder=6, **kw):
    for i in range(num_encoder):
        x = residual_add(x, multi_head_attention, **kw)
        x = residual_add(x, feed_forward, **kw)
    return W.layer_norm(x)


def decoder(x, y, num_decoder=6, mask_x=None, mask_y=None, **kw):
    for i in range(num_decoder):
        y = residual_add(y, multi_head_attention, mask=mask_y, **kw)
        y = residual_add(x, multi_head_attention, y=y, mask=mask_x, **kw)
        y = residual_add(y, feed_forward, **kw)
    return W.layer_norm(y)


def transformer(x, y, **kw):
    x = encoder(x, **kw)
    x = decoder(x, y, **kw)
    return x


class Transformer(nn.Module):
    def __init__(self, *shape, **kw):
        super().__init__()
        self.kw = kw
        warm.up(self, *shape)
    def forward(self, x, y):
        return transformer(x, y, **self.kw)


================================================
FILE: pyproject.toml
================================================
[tool.poetry]
name = 'PyWarm'
version = '0.4.1'
description = 'A cleaner way to build neural networks for PyTorch.'
license = 'MIT'
authors = ['blue-season <very.blue.season@gmail.com>']
readme = 'README.md'
repository = 'https://github.com/blue-season/pywarm'
homepage = 'https://github.com/blue-season/pywarm'
keywords = ['pywarm', 'pytorch', 'neural network', 'deep learning']
packages = [ { include='warm' }, ]


[tool.poetry.dependencies]
python = '>=3.6'


[tool.poetry.dev-dependencies]
toml = '>=0.9'
pytest = '>=3.0'
torch = '>=1.0'
torchvision = '>=0.4'


[tool.portray]
modules = ['warm']


[tool.portray.mkdocs]
markdown_extensions = ['pymdownx.superfences']


[tool.portray.mkdocs.theme]
logo = 'docs/pywarm-logo-small-light.gif'
favicon = 'docs/pywarm-logo-small-dark.gif'
name = 'material'
palette = {primary='deep orange', accent='pink'}


[tool.portray.pdoc3]
config = ['show_source_code=False',
    'show_type_annotations=False',
    'sort_identifiers=True',
    'show_inherited_members=False']
template_dir = 'docs'


================================================
FILE: tests/test_engine.py
================================================
# 08-31-2019;
"""
Test cases for warm.engine.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent))
from warm import engine


def test_set_get_default_parent():
    a = nn.Identity()
    b = nn.Identity()
    engine.set_default_parent(a)
    assert engine.get_default_parent() is a, 'get_default_parent result mismatchs set_default_parent.'
    engine.set_default_parent(b)
    assert engine.get_default_parent() is b, 'get_default_parent result mismatchs set_default_parent.'


def test_auto_name():
    a = nn.Identity()
    for i in range(10):
        assert engine._auto_name('test', a) == f'test_{i+1}', 'new calls to _auto_name failed to increment name count.'
    a(None) # test if forward pre hook is triggered to reset names
    assert engine._auto_name('test', a) == 'test_1', 'forward_pre_hook did not work.'


def test_initialize():
    a = nn.Parameter(torch.zeros(3, 4))
    b = nn.Parameter(torch.zeros(3, 4))
    c = nn.Parameter(torch.zeros(3, 4))
    torch.manual_seed(1)
    engine.initialize_(a, 'normal_')
    torch.manual_seed(1)
    nn.init.normal_(b)
    assert torch.equal(a, b), 'initialize_ with str spec did not work correctly.'
    assert not torch.equal(a, c), 'initialize_ with str spec did not work.'
    torch.manual_seed(1)
    engine.initialize_(c, nn.init.normal_)
    assert torch.equal(a, c), 'initialize_ with function spec did not work correctly.'


def test_activate():
    a = torch.randn(3, 4)
    b = copy.deepcopy(a)
    a = engine.activate(a, 'hardshrink')
    b = F.hardshrink(b)
    assert torch.equal(a, b), 'activate with str spec did not work correctly.'
    a = engine.activate(a, 'relu')
    b = F.relu(b)
    assert torch.equal(a, b), 'activate with str spec did not work correctly.'


def test_permute():
    x = torch.randn(1, 2, 3)
    y = engine.permute(x, 'BCD', 'DCB')
    assert list(y.shape) == [3, 2, 1], 'permute 3d tensor with str in_shape and str out_shape did not work correctly.'
    y = engine.permute(x, 'BCD', None)
    assert list(y.shape) == [1, 2, 3], 'permute tensor with None out_shape did not work corretly.'
    y = engine.permute(x, 'BCD', [1, 0, 2])
    assert list(y.shape) == [2, 1, 3], 'permute tensor with list out_shape did not work corretly.'
    x = torch.randn(1, 2, 3, 4)
    y = engine.permute(x, 'BCD', 'DCB')
    assert list(y.shape) == [3, 4, 2, 1], 'permute 4d tensor with str in_shape and str out_shape did not work correctly.'
    y = engine.permute(x, 'DBC', 'CDB')
    assert list(y.shape) == [4, 1, 2, 3], 'permute 4d tensor with str in_shape and str out_shape did not work correctly.'
    x = torch.randn(1, 2, 3, 4, 5)
    y = engine.permute(x, 'BDC', 'BCD')
    assert list(y.shape) == [1, 5, 2, 3, 4], 'permute 5d tensor with str in_shape and str out_shape did not work correctly.'
    x = torch.randn(1, 2)
    y = engine.permute(x, 'BDC', 'BCD')
    assert list(y.shape) == [1, 2], 'permute 2d tensor with str in_shape and str out_shape did not work correctly.'
    y = engine.permute(x, 'CBD', 'DBC')
    assert list(y.shape) == [2, 1], 'permute 2d tensor with str in_shape and str out_shape did not work correctly.'


def test_unused_kwargs():
    kw = {'unused1':0, 'unused2':0, 'base_class':0}
    unused = engine.unused_kwargs(kw)
    assert 'base_class' not in unused, 'unused_kwargs leaks used.'
    assert set(unused.keys()) == {'unused1', 'unused2'}, 'unused_kwargs did not filter kw correctly.'


def test_prepare_model_is_ready():
    class TestModel(nn.Module):
        def forward(self, x):
            x = engine.forward(x, nn.Linear, 'linear',
                base_arg=(x.shape[-1], 4, False), # in_features, out_features, bias
                in_shape=None, out_shape=None, base_shape=None,
                initialization={'weight':'ones_'}, activation=(F.dropout, {'p':1.0}), )
            return x
    x = torch.randn(1, 2, 3)
    m = TestModel()
    assert not engine.is_ready(m), 'is_ready did not work correctly.'
    engine.prepare_model_(m, x)
    assert engine.is_ready(m), 'prepare_model_ did not work correctly.'
    assert m.linear_1.bias is None, 'linear_1 should not have bias.'
    assert torch.allclose(m.linear_1.weight, torch.Tensor([1.0])), 'linear_1.weight should be initialized to all 1s.'
    y = m(x)
    assert torch.allclose(y, torch.Tensor([0.0])), 'y should be all 0s because we dropout everything.'
    assert list(y.shape) == [1, 2, 4], 'y should have shape [1, 2, 4] after linear projection.'


def test_forward():
    x = torch.randn(1, 2, 3)
    m = nn.Module()
    engine.set_default_parent(m)
    class TripleOut(nn.Module): # to test tuple_out
        def forward(self, x, b=1, c='2'):
            return x+b, x, c
    y = engine.forward(x, base_class=TripleOut, base_name='tri', tuple_out=False)
    assert isinstance(y, torch.Tensor), 'tuple_out did not work correctly.'
    y = engine.forward(x, base_class=TripleOut, base_name='tri', tuple_out=True)
    assert isinstance(y, tuple) and len(y) == 3 and y[-1] == '2', 'tuple_out did not work correctly.'
    y = engine.forward(x, base_class=TripleOut, base_name='tri', forward_kw={'c':3}, tuple_out=True)
    assert y[-1] == 3, 'forward_kw did not work correctly.'
    y = engine.forward(x, base_class=TripleOut, base_name='tri', forward_arg=(2.0,))
    assert torch.allclose(y-x, torch.Tensor([2.0])), 'forward_arg did not work correctly.'
    y = engine.forward(x, base_class=TripleOut, activation=(F.dropout, {'p':1.0}))
    assert torch.allclose(y, torch.Tensor([0.0])), 'activation did not work correctly.'
    y = engine.forward(
        x, base_class=nn.Linear, base_kw={'out_features':4}, infer_kw={'in_features':'C'}, base_shape='BDC')
    assert  y.shape[1] == 4, 'base_kw, infer_kw did not work correctly.'


def test_namespace():
    m = nn.Module()
    engine.set_default_parent(m)
    @engine.namespace
    def f1(name=''):
        return ';'.join([f2(name=name) for i in range(2)])
    @engine.namespace
    def f2(name=''):
        return name
    s0, s1, s2 = [f1() for i in range(3)]
    assert s0 == 'f1_1-f2_1;f1_1-f2_2'
    assert s1 == 'f1_2-f2_1;f1_2-f2_2'
    assert s2 == 'f1_3-f2_1;f1_3-f2_2'


================================================
FILE: tests/test_functional.py
================================================
# 08-31-2019;
"""
Test cases for warm.functional.
"""
import torch
import torch.nn as nn
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent))
import warm.module as mm
import warm.functional as W


def test_conv():
    m = nn.Module()
    x = torch.randn(1, 2, 8) # BCD
    torch.manual_seed(100)
    y0 = nn.Conv1d(2, 3, 3)(x)
    torch.manual_seed(100)
    y1 = W.conv(x, 3, 3, parent=m)
    assert torch.equal(y0, y1), 'conv incorrect output on 1d signal.'
    m = nn.Module()
    x = torch.randn(1, 2, 3, 4) # BCD
    torch.manual_seed(100)
    y0 = nn.Conv2d(2, 3, 3)(x)
    torch.manual_seed(100)
    y1 = W.conv(x, 3, 3, parent=m)
    assert torch.equal(y0, y1), 'conv incorrect output on 2d signal.'


def test_linear():
    m = nn.Module()
    x = torch.randn(1, 2, 3) # BDC
    torch.manual_seed(100)
    y0 = nn.Linear(3, 4)(x)
    torch.manual_seed(100)
    y1 = W.linear(x, 4, parent=m, in_shape='BDC', out_shape='BDC')
    assert torch.equal(y0, y1), 'linear incorrect output on 1d signal.'
    m = nn.Module()
    x = torch.randn(1, 2, 3, 4) # BDC
    torch.manual_seed(100)
    y0 = nn.Linear(4, 3)(x)
    torch.manual_seed(100)
    y1 = W.linear(x, 3, parent=m, in_shape='BDC', out_shape='BDC')
    assert torch.equal(y0, y1), 'batch_norm incorrect output on 2d signal.'


def test_batch_norm():
    m = nn.Module()
    x = torch.randn(1, 2, 3) # BCD
    torch.manual_seed(100)
    y0 = nn.BatchNorm1d(2)(x)
    torch.manual_seed(100)
    y1 = W.batch_norm(x, parent=m)
    m = nn.Module()
    assert torch.equal(y0, y1), 'batch_norm incorrect output on 1d signal.'
    x = torch.randn(1, 2, 3, 4) # BCD
    torch.manual_seed(100)
    y0 = nn.BatchNorm2d(2)(x)
    torch.manual_seed(100)
    y1 = W.batch_norm(x, parent=m)
    assert torch.equal(y0, y1), 'batch_norm incorrect output on 2d signal.'


def test_lstm():
    m = nn.Module()
    x = torch.randn(3, 2, 1) # DBC
    torch.manual_seed(100)
    y0, *_ = nn.LSTM(1, 2, num_layers=2)(x)
    torch.manual_seed(100)
    y1 = W.lstm(x, 2, num_layers=2, parent=m, init_weight_hh=None, in_shape='DBC', out_shape='DBC')
    assert torch.equal(y0, y1)
    y1, s1 = W.lstm(x, 2, parent=m, tuple_out=True) # test tuple out
    assert len(s1) == 2
    y2 = W.lstm((y1, s1), 2, parent=m) # test tuple in
    assert torch.is_tensor(y2)


def test_gru():
    m = nn.Module()
    x = torch.randn(3, 2, 1) # DBC
    torch.manual_seed(100)
    y0, *_ = nn.GRU(1, 2, num_layers=2)(x)
    torch.manual_seed(100)
    y1 = W.gru(x, 2, num_layers=2, parent=m, init_weight_hh=None, in_shape='DBC', out_shape='DBC')
    assert torch.equal(y0, y1)


def test_identity():
    x = torch.randn(1, 2, 3)
    assert torch.equal(W.identity(x, 7, 8, a='b'), x)


def test_dropout():
    m = nn.Module()
    x = torch.ones(2, 6, 6, 6)
    torch.manual_seed(100)
    y0 = nn.Dropout(0.3)(x)
    torch.manual_seed(100)
    y1 = W.dropout(x, 0.3, parent=m)
    assert torch.equal(y0, y1)
    torch.manual_seed(100)
    y0 = nn.Dropout2d(0.3)(x)
    torch.manual_seed(100)
    y1 = W.dropout(x, 0.3, by_channel=True, parent=m)
    assert torch.equal(y0, y1)


def test_transformer():
    m = nn.Module()
    x = torch.randn(10, 2, 4)
    y = torch.randn(6, 2, 4)
    torch.manual_seed(100)
    z0 = nn.Transformer(4, 2, 1, 1, dim_feedforward=8)(x, y)
    torch.manual_seed(100)
    z1 = W.transformer(x, y, 1, 1, 2, dim_feedforward=8, in_shape='DBC', out_shape='DBC', parent=m)
    assert torch.equal(z0, z1)
    torch.manual_seed(100)
    z1 = W.transformer(x, y, 1, 1, 2, dim_feedforward=8, in_shape='DBC', out_shape='DBC', parent=m, causal=True)
    assert not torch.equal(z0, z1)
    z1 = W.transformer(x, None, 2, 0, 2, dim_feedforward=8, in_shape='DBC', out_shape='DBC', parent=m)
    assert z1.shape == x.shape


def test_layer_norm():
    m = nn.Module()
    x = torch.randn(1, 2, 3, 4, 5)
    y0 = nn.LayerNorm([3, 4, 5])(x)
    y1 = W.layer_norm(x, [2, -2, -1], parent=m)
    assert torch.equal(y0, y1)
    y0 = nn.LayerNorm(5)(x)
    y1 = W.layer_norm(x, dim=-1, parent=m)
    assert torch.equal(y0, y1)
    x0 = x.permute(0, 4, 2, 1, 3)
    y0 = nn.LayerNorm([2, 4])(x0)
    y0 = y0.permute(0, 3, 2, 4, 1)
    y1 = W.layer_norm(x, dim=[1, -2], parent=m)
    assert torch.equal(y0, y1)


def test_embedding():
    m = nn.Module()
    x = torch.randint(0, 20, (1, 2, 3, 4, 5))
    torch.manual_seed(10)
    y0 = nn.Embedding(20, 8)(x)
    torch.manual_seed(10)
    y1 = W.embedding(x, 8, 20, parent=m)
    assert torch.equal(y0, y1)
    torch.manual_seed(10)
    y1 = W.embedding(x, 8, 20, in_shape='DCB', parent=m) # shapes should have no effect
    assert torch.equal(y0, y1)
    torch.manual_seed(10)
    y1 = W.embedding(x, 8, 20, out_shape='CBD', parent=m) # shapes should have no effect
    assert torch.equal(y0, y1)
    y1 = W.embedding(x, 8, parent=m) # should work without a explicit vocabulary size
    torch.manual_seed(10)
    y1 = W.embedding(x.double(), 8, parent=m) # should work with non integer tensors.
    assert torch.equal(y0, y1)


================================================
FILE: tests/test_module.py
================================================
# 08-31-2019;
"""
Test cases for warm.module.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent))
import warm.module as mm
import warm.functional as W


def test_lambda():
    f = lambda x: x*2
    m = mm.Lambda(f)
    x = torch.randn(1, 2)
    assert torch.equal(f(x), m(x)), 'lambda did not work correctly.'
    def f(x, w, b=5):
        return x*w+b
    m = mm.Lambda(f, 2, b=1)
    assert torch.equal(f(x, 2, 1), m(x)), 'function with args and kwargs did not work correctly.'
    x = torch.randn(3, 2, 4)
    m = mm.Lambda(W.permute, 'BDC', 'BCD')
    assert list(m(x).shape) == [3, 4, 2], 'lambda permute did not work correctly.'


def test_sequential():
    s = mm.Sequential(
        nn.Linear(1, 2),
        nn.LSTM(2, 3, batch_first=True), # lstm and gru return multiple outputs
        nn.GRU(3, 4, batch_first=True),
        mm.Lambda(W.permute, 'BDC', 'BCD'),
        nn.Conv1d(4, 5, 1), )
    x = torch.randn(3, 2, 1)
    assert list(s(x).shape) == [3, 5, 2]


def test_shortcut():
    l = nn.Linear(1, 1, bias=False)
    nn.init.constant_(l.weight, 2.0)
    s = mm.Shortcut(l)
    x = torch.ones(1, 1)
    assert torch.allclose(s(x), torch.Tensor([3.0]))


================================================
FILE: tests/test_util.py
================================================
# 08-31-2019;
"""
Test cases for warm.util.
"""
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent))
from warm import util


def test_camel_to_snake():
    assert util.camel_to_snake('CamelAndSnake') == 'camel_and_snake'
    assert util.camel_to_snake('camelAndSnake') == 'camel_and_snake'
    assert util.camel_to_snake('camelANDSnake') == 'camel_and_snake'
    assert util.camel_to_snake('CAMELAndSnake') == 'camel_and_snake'
    assert util.camel_to_snake('CAMELAndSNAKE') == 'camel_and_snake'
    assert util.camel_to_snake('CamelAndSnake_') == 'camel_and_snake_'
    assert util.camel_to_snake('_CamelAndSnake') == '__camel_and_snake'


def test_summary_str():
    from examples.resnet import WarmResNet
    m = WarmResNet()
    s = util.summary_str(m)
    assert len(s) > 0


def test_summary():
    from examples.resnet import WarmResNet
    m = WarmResNet()
    util.summary(m)


================================================
FILE: tests/test_warm.py
================================================
# 09-10-2019;
"""
Test cases for the warm module.
"""
import torch.nn as nn
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent))
import warm


def test_warm_up():
    m = nn.Identity()
    assert not warm.engine.is_ready(m), 'is_ready did not work correctly.'
    warm.up(m, [1, 2, 3])
    assert warm.engine.is_ready(m), 'warm.up did not work correctly.'


================================================
FILE: warm/__init__.py
================================================
# 09-10-2019;

""" `warm.up` is an alias of
[`warm.engine.prepare_model_`](https://blue-season.github.io/pywarm/reference/warm/engine/#prepare_model_). """
from warm.engine import prepare_model_ as up


================================================
FILE: warm/engine.py
================================================
# 08-26-2019;
"""
PyWarm engine to the functional interface.
"""
import torch
import torch.nn as nn
import numpy as np
from warm import util


_DEFAULT_PARENT_MODULE = None


def set_default_parent(parent):
    """ Set the default `parent` module. """
    global _DEFAULT_PARENT_MODULE
    _DEFAULT_PARENT_MODULE = parent


def get_default_parent():
    """ Get the default `parent` module. """
    global _DEFAULT_PARENT_MODULE
    return _DEFAULT_PARENT_MODULE


def _auto_name(name, parent):
    """ Track the count of reference to `name` from `parent`. """
    if not is_ready(parent):
        parent._pywarm_auto_name_dict = {}
        def _hook(model, x):
            model._pywarm_auto_name_dict = {}
        parent._pywarm_forward_pre_hook = parent.register_forward_pre_hook(_hook)
    track = parent._pywarm_auto_name_dict
    if name not in track:
        track[name] = 0
    track[name] += 1
    return f'{name}_{track[name]}'


def prepare_model_(model, *data, device='cpu'):
    """ Initialize all childen modules defined by `warm` in a parent `model`.\n
    -  `model: Module`; The parent model to be prepared.
    -  `data: Tensor, or list of int`; A batch of data with the correct shape and type to be forwarded by model.
        `data` can also be a list of `int`, in which case it is interpreted as the shape of the input data.
    -  `device: str, or torch.device`; Should be the same for `model` and `data`. Default: `'cpu'`.
    -  `return: Module`; The prepared model, with all children modules defined by `warm` initialized. """
    _auto_name('', model)
    set_default_parent(model)
    def _prep_data(d):
        if isinstance(d, (np.ndarray, torch.Tensor)):
            return torch.as_tensor(d).to(device)
        elif isinstance(d, (list, tuple)):
            if all(isinstance(x, int) for x in d):
                return torch.randn(*d, device=device)
            return [_prep_data(x) for x in d]
        elif isinstance(d, dict):
            return {k:_prep_data(v) for k, v in d.items()}
    with torch.no_grad():
        is_training = model.training
        data = [_prep_data(d) for d in data]
        model.eval()
        model.to(device)
        model(*data)
        model.train(is_training)
    return model


def is_ready(model):
    """ Check if a `model` is prepared. """
    return hasattr(model, '_pywarm_forward_pre_hook')


def activate(x, spec, lookup=None):
    """ Activate tensors with given nonlinearity `spec`ification.\n
    -  `x: Tensor or list of Tensor`; The tensors to be initialized.
    -  `spec: str or callable or 2-tuple`; If a `str`, should be one of the nonlinearity functions contained in
        `torch.nn.functional` or `torch`. If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`.
        If a 2-`tuple`, it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`.
    -  `lookup: None or list of module`; Parent modules to look for `spec`. If `None`, `[nn.functional, torch]` is used.
    -  `return: Tensor or list of Tensor`; Activation results. """
    if spec is None:
        return x
    lookup = lookup or [nn.functional, torch]
    if isinstance(spec, str):
        for look in lookup:
            try:
                spec = getattr(look, spec)
                break
            except:
                pass
        if isinstance(spec, str):
            raise ValueError(f'Unknown spec {spec}.')
    if callable(spec):
        spec = (spec, {})
    fn, kw = spec
    if isinstance(x, (list, tuple)):
        return [fn(y, **kw) for y in x]
    return fn(x, **kw)


def initialize_(x, spec):
    """ Initialize parameters with given nonlinearity `spec`ification.\n
    -  `x: Tensor or list of Tensor`; The tensors to be initialized.
    -  `spec: str or callable or 2-tuple`; If a `str`, should be one of the nonlinearity functions contained in
        `torch.nn.init`. If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. If a 2-`tuple`,
        it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`. """
    activate(x, spec, lookup=[nn.init])


def permute(x, in_shape='BCD', out_shape='BCD', **kw):
    """ Permute the dimensions of a tensor.\n
    -  `x: Tensor`; The nd-tensor to be permuted.
    -  `in_shape: str`; The dimension shape of `x`. Can only have characters `'B'` or `'C'` or `'D'`,
        which stand for Batch, Channel, or extra Dimensions. The default value `'BCD'` means
        the input tensor `x` should be at lest 2-d with shape `(Batch, Channel, Dim0, Dim1, Dim2, ...)`,
        where `Dim0, Dim1, Dim2 ...` stand for any number of extra dimensions.
    -  `out_shape: str or tuple or None`; The dimension shape of returned tensor.  Default: `'BCD'`.
        If a `str`, it is restricted to the same three characters `'B'`, `'C'` or `'D'` as the `in_shape`.
        If a `tuple`, `in_shape` is ignored, and simply `x.permute(out_shape)` is returned.
        If `None`, no permution will be performed.
    -  `return: Tensor`; Permuted nd-tensor. """
    if (in_shape == out_shape) or (out_shape is None):
        return x
    if isinstance(out_shape, (list, tuple, torch.Size)):
        return x.permute(*out_shape)
    if isinstance(in_shape, str) and isinstance(out_shape, str) :
        assert set(in_shape) == set(out_shape) <= {'B', 'C', 'D'}, 'In and out shapes must have save set of chars among B, C, and D.'
        in_shape = in_shape.lower().replace('d', '...')
        out_shape = out_shape.lower().replace('d', '...')
        return torch.einsum(f'{in_shape}->{out_shape}', x)
    return x


def unused_kwargs(kw):
    """ Filter out entries used by `forward` and return the rest. """
    fn_kw = dict(base_class=None,
        base_name=None, name=None, base_arg=None, base_kw=None, parent=None,
        infer_kw=None, in_shape='BCD', base_shape=None, out_shape='BCD', tuple_out=False,
        forward_arg=None, forward_kw=None, initialization=None, activation=None, )
    return {k:v for k, v in kw.items() if k not in fn_kw}


def forward(x, base_class, 
        base_name=None, name=None, base_arg=None, base_kw=None, parent=None,
        infer_kw=None, in_shape='BCD', base_shape='BCD', out_shape='BCD', tuple_out=False,
        forward_arg=None, forward_kw=None, initialization=None, activation=None, **kw):
    """ A forward template that creates child instances at the first time it is called.\n
    -  `x: Tensor`; The nd-tensor to be forwarded.
    -  `base_class: Module`; A child `torch.nn.Module` that will be created at the first time this function is called.
    -  `base_name: str`; Name for the `base_class`. Default: base_class name.
    -  `name: str`; Name for the child module instance. Default: class name plus ordinal.
    -  `base_arg: tuple`; Positional args to be passed to create the child module instance. Default: None.
    -  `base_kw: dict`; KWargs to be passed to create the child module instance. Default: None.
    -  `parent: Module`; The parent of the child instance.  Default: None. If `None`, will use `get_default_parent`.
    -  `infer_kw: dict`; Key should be valid for the child instance. Value shoud be a character,
        one of `'B'`, `'C'`, or `'D'` (see `permute`), to substitute for a dimension of `x`. Default: None.
    -  `in_shape: str`; The dimension shape of `x`. See also `permute`. Default: `'BCD'`.
    -  `base_shape: str`; The dimension shape required by the child module. See also `permute`. Default: `'BCD'`.
    -  `out_shape: str or tuple or None`; The dimension shape of returned tensor. See also `permute`. Default: `'BCD'`.
    -  `tuple_out: bool`; Whether the child module will return more than 1 outputs (e.g. `nn.RNN`).
        If `True`, the returned value of the function will be a tuple containing all outputs. Default: False.
    -  `forward_arg: tuple`; positional args to be passed when calling the child module instance. Default: None.
    -  `forward_kw: dict`; KWargs to be passed when calling the child module instance. Default: None.
    -  `initialization: dict`; Keys are name of parameters to initialize. Values are init specs, which can be 
        a, `str`, a `callable`, or `2-tuple`; See the `spec` argument of `initialize_` for details. Default: None.
    -  `activation: str or callable or 2-tuple`; See the `spec` argument of `activate`. Default: None.
    -  `return: Tensor or tuple`; If `tuple_out` is `True`, the returned value will be a `tuple`. """
    parent = parent or get_default_parent()
    if name is None:
        base_name = base_name or util.camel_to_snake(base_class.__name__)
        name = _auto_name(base_name, parent)
    if name not in parent._modules:
        if infer_kw is not None:
            shape = in_shape
            if 'D' in shape:
                shape = list(shape)
                shape[shape.index('D')] = 'D'*(x.ndim-len(shape)+1)
                shape = ''.join(shape)
            infer_kw = {
                k:x.shape[shape.find(v) if isinstance(v, str) else v]
                for k, v in infer_kw.items()}
        base = base_class(*(base_arg or []), **(infer_kw or {}), **(base_kw or {}), )
        parent.add_module(name, base)
        if initialization is not None:
            s = parent.state_dict()
            for k, v in initialization.items():
                initialize_(s[name+'.'+k], v)
    x = permute(x, in_shape, base_shape)
    y = parent._modules[name](x, *(forward_arg or []), **(forward_kw or {}))
    r = []
    if isinstance(y, tuple):
        y, *r = y
    y = permute(y, base_shape, out_shape)
    y = activate(y, activation)
    if tuple_out:
        return (y, *r)
    return y


import functools
def namespace(f):
    """ After decoration, the function name and call count will be appended to the `name` kw. """
    @functools.wraps(f)
    def _wrapped(*arg, **kw):
        parent = kw.get('parent', get_default_parent())
        name = kw.get('name', '')
        name = '_warmns_' + name + ('-' if name else '') + f.__name__
        name = _auto_name(name, parent)
        kw['name'] = name.replace('_warmns_', '')
        return f(*arg, **kw)
    return _wrapped


================================================
FILE: warm/functional.py
================================================
# 08-27-2019;
"""
Wraps around various torch.nn Modules to fit into a functional interface.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from warm import engine
from warm import util


permute = engine.permute


def conv(x, size, kernel, init_weight=None, init_bias=None, bias=True, **kw):
    """ Convolution layer.\n
    -  `x: Tensor`; With shape `(Batch, Channel, *)` where `*` Can be 1d or 2d or 3d.
        If 3d, shapes are `(Batch, Channel, Length)`.
        If 4d, shapes are `(Batch, Channel, Height, Width)`.
        If 5d, shapes are `(Batch, Channel, Depth, Height, Width)`.
    -  `size: int`; Size of hidden filters, and size of the output channel.
    -  `kernel: int or tuple`; Size of the convolution kernel.
    -  `init_weight: None or str or callable`; Initialization specification for the weight tensor.
        If a `str`, should be one of the nonlinearity functions contained in `torch.nn.init`.
        If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. If a 2-`tuple`,
        it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`.
        Default: `None`, and the weight tensor is initialized using `torch.nn.ConvNd`s default scheme.
    -  `init_bias: None or str or callable`; Same as `init_weight`, but for the bias tensor.
    -  `bias: bool`; If `True`, adds a learnable bias to the output. Default: `True`.
    -  `**kw:dict`; Any additional KWargs are passed down to `torch.nn.ConvNd`, where N can be 1, 2 or 3.
        as well as `warm.engine.forward`. Refer to their docs for details. Some of the additional ConvNd arguments:
        `stride, padding, dilation, groups`.
    -  `return: Tensor`; With shape `(Batch, Size, *)` where `*` can be 1d, 2d, 3d that depends on `x`. """
    d = x.ndim-3
    assert d in [0, 1, 2], 'Incompatible number of dims for input x.'
    inferred_kw = dict(
        base_name='conv',
        base_class=[nn.Conv1d, nn.Conv2d, nn.Conv3d][d],
        base_kw={
            'out_channels':size,
            'kernel_size':kernel,
            'bias':bias,
            **engine.unused_kwargs(kw), },
        infer_kw={'in_channels':'C'},
        initialization={'weight':init_weight, **({'bias':init_bias} if bias else {})}, )
    return engine.forward(x, **{**inferred_kw, **kw})


def linear(x, size, init_weight=None, init_bias=None, bias=True, **kw):
    """ Linear transformation layer.\n
    -  `x: Tensor`; 2d or more, with shapes `(Batch, Channel, *)` where `*` means any number of additional dimensions.
    -  `size: int`; Size of hidden features, and size of the output channel.
    -  `init_weight: None or str or callable`; Initialization specification for the weight tensor.
        If a `str`, should be one of the nonlinearity functions contained in `torch.nn.init`.
        If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. If a 2-`tuple`,
        it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`.
        Default: `None`, and the weight tensor is initialized using `torch.nn.Linear`s default scheme.
    -  `init_bias: None or str or callable`; Same as `init_weight`, but for the bias tensor.
    -  `bias: bool`; If `True`, adds a learnable bias to the output. Default: `True`.
    -  `**kw:dict`; Any additional KWargs are passed down to `warm.engine.forward`. Refer to its docs for details.
    -  `return: Tensor`; With shape `(Batch, Size, *)` where `*` can be 1d, 2d, 3d that depends on `x`. """
    inferred_kw = dict(
        base_name='linear',
        base_class=nn.Linear,
        base_kw={'out_features':size, 'bias':bias},
        base_shape='BDC',
        infer_kw={'in_features':'C'},
        initialization={'weight':init_weight, **({'bias':init_bias} if bias else {})}, )
    return engine.forward(x, **{**inferred_kw, **kw})


def batch_norm(x, **kw):
    """ Batch Normalization layer.\n
    -  `x: Tensor`; 2d or more, with shapes `(Batch, Channel, *)` where `*` means any number of additional dimensions.
    -  `**kw: dict`; Any additional KWargs are passed down to `torch.nn.BatchNormNd`, where N can be 1, 2 or 3.
        as well as `warm.engine.forward`. Refer to their docs for details. Some of the additional BatchNorm arguments:
        `eps, momentum, affine, track_running_stats`.
    -  `return: Tensor`; Same shape as input  `x`. """
    d = x.ndim-3
    assert d in [0, 1, 2], 'Incompatible number of dims for input x.'
    inferred_kw = dict(
        base_name='batch_norm',
        base_class=[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d][d],
        base_kw={'num_features':x.shape[1]}, )
    return engine.forward(x, **{**inferred_kw, **kw})


def lstm(x, size,
        init_weight_hh='orthogonal_', init_weight_ih=None, init_bias_hh=None, init_bias_ih=None,
        bias=True, num_layers=1, **kw):
    """ Long Short Term Memory layer.\n
    -  `x: Tensor or tuple`; If tuple, must be of format `(x, (h_0, c_0))`, where `x` is a 3d tensor,
        with shapes `(Batch, Channel, Length)`.
    -  `size: int`; Size of hidden features, and size of the output channel.
    -  `init_weight_hh: None or str or callable`; Initialization specification for the hidden-hidden weight tensor.
        If a `str`, should be one of the nonlinearity functions contained in `torch.nn.init`.
        If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. If a 2-`tuple`,
        it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`.
        Default: `'orthogonal_'`.
    -  `init_weight_ih: None or str or callable`; Initialization specification for the input-hidden weight tensor.
        Default: `None`, and the weight tensor is initialized using `torch.nn.LSTM`s default scheme.
    -  `init_bias_hh: None or str or callable`; Initialization specification for the hidden-hidden bias tensor.
        Default: `None`, and the weight tensor is initialized using `torch.nn.LSTM`s default scheme.
    -  `init_bias_ih: None or str or callable`; Initialization specification for the input-hidden bias tensor.
        Default: `None`, and the weight tensor is initialized using `torch.nn.LSTM`s default scheme.
    -  `bias: bool`; If `False`, then the layer does not use `bias_ih` and `bias_hh`. Default: `True`.
    -  `num_layers: int`; Number of the recurrent layers. Default: 1.
    -  `tuple_out: bool`; If `True`, the returned value will be a tuple `(out, (h_n, c_n))`. Default: False.
    -  `**kw: dict`; Any additional KWargs are passed down to `torch.nn.LSTM`, as well as `warm.engine.forward`.
        Refer to their docs for details. Some of the additional LSTM arguments: `dropout, bidirectional, batch_first`.
    -  `return: Tensor or tuple`; If `tuple_out` set to true, will return `(out, (h_n, c_n)`, otherwise just `out`.
        `out` has shape `(Batch, Size, Length*Directions)`,
            where Directions = 2 if `bidirectional` else 1.
        `h_n` is the hidden states with shape `(num_layers*Directions, Batch, Size)`.
        `c_n` is the cell states with shape `(num_layers*Directions, Batch, Size)`. """
    states = None
    if isinstance(x, tuple):
        x, *states = x
    init = dict(
        weight_hh=init_weight_hh,
        weight_ih=init_weight_ih,
        bias_hh=init_bias_hh,
        bias_ih=init_bias_ih, )
    inferred_kw = dict(
        base_name='lstm',
        base_class=nn.LSTM,
        base_kw={
            'hidden_size':size,
            'num_layers':num_layers,
            **engine.unused_kwargs(kw), },
        base_shape='DBC',
        infer_kw={'input_size':'C'},
        forward_arg=states,
        initialization={
            f'{k}_l{l}':init[k] for k in ['weight_hh', 'weight_ih']+(['bias_hh', 'bias_ih'] if bias else [])
            for l in range(num_layers)}, )
    return engine.forward(x, **{**inferred_kw, **kw})


def gru(*arg, **kw):
    """ Gated Recurrent Unit layer.\n
    -  `x: Tensor or tuple`; If tuple, must be of format `(x, (h_0, c_0))`, where `x` is a 3d tensor,
        with shapes `(Batch, Channel, Length)`.
    -  `size: int`; Size of hidden features, and size of the output channel.
    -  `init_weight_hh: None or str or callable`; Initialization specification for the hidden-hidden weight tensor.
        If a `str`, should be one of the nonlinearity functions contained in `torch.nn.init`.
        If a `callable`, it will be applied to `x` directly, i.e. `spec(x)`. If a 2-`tuple`,
        it must be of format `(callable, kwargs)`, i.e. `callable(x, **kwargs)`.
        Default: `'orthogonal_'`.
    -  `init_weight_ih: None or str or callable`; Initialization specification for the input-hidden weight tensor.
        Default: `None`, and the weight tensor is initialized using `torch.nn.GRU`s default scheme.
    -  `init_bias_hh: None or str or callable`; Initialization specification for the hidden-hidden bias tensor.
        Default: `None`, and the weight tensor is initialized using `torch.nn.GRU`s default scheme.
    -  `init_bias_ih: None or str or callable`; Initialization specification for the input-hidden bias tensor.
        Default: `None`, and the weight tensor is initialized using `torch.nn.GRU`s default scheme.
    -  `bias: bool`; If `False`, then the layer does not use `bias_ih` and `bias_hh`. Default: `True`.
    -  `num_layers: int`; Number of the recurrent layers. Default: 1.
    -  `tuple_out: bool`; If `True`, the returned value will be a tuple `(out, (h_n, c_n))`. Default: False.
    -  `**kw: dict`; Any additional KWargs are passed down to `torch.nn.GRU`, as well as `warm.engine.forward`.
        Refer to their docs for details. Some of the additional GRU arguments: `dropout, bidirectional, batch_first`.
    -  `return: Tensor or tuple`; If `tuple_out` set to true, will return `(out, (h_n, c_n)`, otherwise just `out`.
        `out` has shape `(Batch, Size, Length*Directions)`,
            where Directions = 2 if `bidirectional` else 1.
        `h_n` is the hidden states with shape `(num_layers*Directions, Batch, Size)`.
        `c_n` is the cell states with shape `(num_layers*Directions, Batch, Size)`. """
    return lstm(*arg, base_name='gru', base_class=nn.GRU, **kw)


def identity(x, *arg, **kw):
    """ Identity layer that returns the first input, ignores the rest arguments. """
    return x


def dropout(x, rate=0.5, by_channel=False, **kw):
    """ Dropout layer.\n
    During training, randomly zeros part of input tensor `x`, at probability `rate`.\n
    -  `x: Tensor`; Can be of any shape if `by_channel` is false, or 2d and up if `by_channel` is true.
    -  `rate: float`; The probability of dropout. Default 0.5.
    -  `by_channel: bool`; If true, will dropout entire channels (all `'D'` dimensions will be 0 if x is `'BCD'`).
        `by_channel` true requires `x` to be 2d or more.
    -  `inplace: bool`; If true, the operation will be in-place and the input `x` will be altered.
    -  `return: Tensor`; Same shape as `x`. """
    inferred_kw = dict(
        base_name='dropout',
        base_class=[nn.Dropout, nn.Dropout2d][by_channel],
        base_kw={'p':rate},
        base_shape=[None, 'BCD'][by_channel], )
    return engine.forward(x, **{**inferred_kw, **kw})


def transformer(x, y=None, num_encoder=6, num_decoder=6, num_head=8,
        mask=None, causal=False, in_shape='BCD', **kw):
    """ Transformer layer.\n
    This layer covers functionality of `Transformer`, `TransformerEncoder`, and `TransformerDecoder`.
    See [`torch.nn.Transformer`](https://pytorch.org/docs/stable/nn.html#transformer) for more details.\n
    -  `x: Tensor`; The source sequence, with shape `(Batch, Channel, LengthX)`.
        `Channel` is usually from embedding.
    -  `y: None or Tensor`; The target sequence. Also with shape `(Batch, Channel, LengthY)`.
        If not present, default to equal `x`.
    -  `num_encoder: int`; Number of encoder layers. Set to 0 to disable encoder and use only decoder. Default 6.
    -  `num_decoder: int`; Number of decoder layers. Set to 0 to disable decoder and use only encoder. Default 6.
    -  `num_head: int`; Number of heads for multi-headed attention. Default 8.
    -  `mask: None or dict`; Keys are among: `src_mask`, `tgt_mask`, `memory_mask`,
        `src_key_padding_mask`, `tgt_key_padding_mask`, `memory_key_padding_mask`.
        See the `forward` method of `torch.nn.Transformer` for details.
    -  `causal: bool`; Default false. if true, will add causal masks to source and target, so that
        current value only depends on the past, not the future, in the sequences.
    -  `**kw: dict`; Any additional KWargs are passed down to `torch.nn.Transformer`, as well as `warm.engine.forward`.
    -  `return: Tensor`; Same shape as `y`, if `num_decoder` > 0. Otherwise same shape as `x`. """
    def _causal_mask(n):
        mask = (torch.triu(torch.ones(n, n)) == 1).transpose(0, 1)
        return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    if y is None:
        y = x
    y = permute(y, in_shape, 'DBC')
    mask = mask or {}
    if causal:
        i = in_shape.find('D')
        mx = _causal_mask(x.shape[i])
        mask['src_mask'] = mask.pop('src_mask', 0.0)+mx
        my = _causal_mask(y.shape[0])
        mask['tgt_mask'] = mask.pop('tgt_mask', 0.0)+my
    encoder = identity if num_encoder == 0 else None
    decoder = identity if num_decoder == 0 else None
    inferred_kw = dict(
        base_name='transformer',
        base_class=nn.Transformer,
        base_shape='DBC',
        base_kw=dict(
            d_model=x.shape[in_shape.find('C')],
            custom_encoder=encoder,
            custom_decoder=decoder,
            nhead=num_head,
            num_encoder_layers=num_encoder,
            num_decoder_layers=num_decoder, 
            **engine.unused_kwargs(kw), ),
        in_shape=in_shape,
        forward_kw=mask,
        forward_arg=(y, ), )
    return engine.forward(x, **{**inferred_kw, **kw})


def layer_norm(x, dim=1, **kw):
    """ Layer Normalization.\n
    -  `x: Tensor`; Can be of any shape.
    -  `dim: int or list of int`; Dimensions to be normalized. Default: 1.
    -  `**kw: dict`; Any additional KWargs are passed down to `torch.nn.LayerNorm`, as well as `warm.engine.forward`.
    -  `return: Tensor`; Same shape as `x`. """
    if dim != -1:
        if isinstance(dim, int):
            dim = [dim]
        dim_norm = [x.ndim+i if i < 0 else i for i in dim]
        order = [i for i in range(x.ndim) if i not in dim_norm]+dim_norm
        x = x.permute(order)
        norm_shape = x.shape[-len(dim_norm):]
    else:
        norm_shape = [x.shape[-1]]
    inferred_kw = dict(
        base_name='layer_norm',
        base_class=nn.LayerNorm,
        base_kw={'normalized_shape':norm_shape}, )
    x = engine.forward(x, **{**inferred_kw, **kw})
    if dim != -1:
        x = x.permute(np.argsort(order).tolist())
    return x


def embedding(x, size, vocabulary=None, **kw):
    """ Embedding layer.\n
    The input is usually a list of indices (integers), and the output is a dense matrix which
    maps indices to dense vectors. Thus the output will have 1 more dimension than the input.\n
    **Note**: The output of this function is always one more dimension than the input. For input with shape `(*)`,
    The output will be `(*, size)`. Any shape specifications in the KWargs are ignored. \n
    -  `x: Tensor`; Contains indices into the vocabulary. Will be converted to `LongTensor` of integers.
        Can be of any shape.
    -  `size: int`; The size of embedding vector.
    -  `vocabulary: int or None`; The size of vocabulary of embedding, or max number of unique indices in `x`.
        By default it is set to `max(x)-min(x)+1`.
    -  `**kw: dict`; Any additional KWargs are passed down to `torch.nn.LayerNorm`, as well as `warm.engine.forward`.
    -  `return: Tensor`; With the embedded dim appended to the shape of x.
        Thus with shape `(*, Size)`, where `*` is the shape of `x`. """
    x = x.type(torch.LongTensor)
    if vocabulary is None:
        vocabulary = x.max()-x.min()+1
    kw.pop('in_shape', None)
    kw.pop('out_shape', None)
    kw.pop('base_shape', None)
    inferred_kw = dict(
        base_name='embedding',
        base_class=nn.Embedding,
        base_kw=dict(
            num_embeddings=vocabulary,
            embedding_dim=size,
            **engine.unused_kwargs(kw), ),
        base_shape=None,
        in_shape=None,
        out_shape=None, )
    return engine.forward(x, **{**inferred_kw, **kw})


================================================
FILE: warm/module.py
================================================
# 08-27-2019;
"""
Custom modules to enhance the nn Sequential experience.

PyWarm's core concept is to use a functional interface to simplify network building.
However, if you still prefer the classical way of defining child modules in `__init__()`,
PyWarm provides some utilities to help organize child modules better.

- `Lambda` can be used to wrap one line data transformations, like `x.view()`, `x.permute()` etc, into modules.

- `Sequential` is an extension to `nn.Sequential` that better accomodates PyTorch RNNs.

- `Shortcut` is another extension to `nn.Sequential` that will also perform a shortcut addition (AKA residual connection)
for the input with output, so that residual blocks can be written in an entire sequential way.

For example, to define the basic block type for resnet:


```Python
import torch.nn as nn
import warm.module as wm


def basic_block(size_in, size_out, stride=1):
    block = wm.Shortcut(
        nn.Conv2d(size_in, size_out, 3, stride, 1, bias=False),
        nn.BatchNorm2d(size_out),
        nn.ReLU(),
        nn.Conv2d(size_out, size_out, 3, 1, 1, bias=False),
        nn.BatchNorm2d(size_out),
        projection=wm.Lambda(
            lambda x: x if x.shape[1] == size_out else nn.Sequential(
                nn.Conv2d(size_in, size_out, 1, stride, bias=False),
                nn.BatchNorm2d(size_out), )(x), ), )
    return block
```
"""


import torch.nn as nn


class Lambda(nn.Module):
    """ Wraps a callable and all its call arguments.\n
    -  `fn: callable`; The callable being wrapped.
    -  `*arg: list`; Arguments to be passed to `fn`.
    -  `**kw: dict`; KWargs to be passed to `fn`. """
    def __init__(self, fn, *arg, **kw):
        super().__init__()
        self.fn = fn
        self.arg = arg
        self.kw = kw
    def forward(self, x):
        """ forward. """
        return self.fn(x, *self.arg, **self.kw)


class Sequential(nn.Sequential):
    """ Similar to `nn.Sequential`, except that child modules can have multiple outputs (e.g. `nn.RNN`).\n
    -  `*arg: list of Modules`; Same as `nn.Sequential`. """
    def forward(self, x):
        """ forward. """
        for module in self._modules.values():
            if isinstance(x, tuple):
                try:
                    x = module(x)
                except Exception:
                    x = module(x[0])
            else:
                x = module(x)
        return x


class Shortcut(Sequential):
    """ Similar to `nn.Sequential`, except that it performs a shortcut addition for the input and output.\n
    -  `*arg: list of Modules`; Same as `nn.Sequential`.
    -  `projection: None or callable`; If `None`, input with be added directly to the output.
        otherwise input will be passed to the `projection` first, usually to make the shapes match. """
    def __init__(self, *arg, projection=None):
        super().__init__(*arg)
        self.projection = projection or nn.Identity()
    def forward(self, x):
        """ forward. """
        return super().forward(x)+self.projection(x)


================================================
FILE: warm/util.py
================================================
# 08-28-2019;
"""
Short utilities.
"""
import torch
import torch.nn as nn
import numpy as np
import re


""" Create a property for class torch.Tensor called ndim, for pytorch earlier than 1.2. """
if not hasattr(torch.Tensor, 'ndim'):
    torch.Tensor.ndim = property(lambda x: x.dim())


def camel_to_snake(name):
    """ Convert a camelCaseString to its snake_case_equivalent. """
    s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


def summary_str(model):
    """ Get a string representation of model building blocks and parameter counts. """
    indent_list, name_list, count_list = [], [], []
    def module_info(m, name, indent_level):
        count_list.append(sum([np.prod(list(p.size())) for p in m.parameters()]))
        indent_list.append(indent_level)
        name_list.append(name)
        for name, child in m.named_children():
            if name.isdigit():
                name = child._get_name()
            module_info(child, name, indent_level+1)
    module_info(model, model._get_name(), 0)
    max_indent = max(indent_list)*4
    max_name = max(len(x) for x in name_list)+max_indent+2
    max_param = len(str(count_list[0]))+max_name+2
    out = ['Blocks{:>{w}}'.format('Params', w=max_param-6)]
    out += ['-'*max_param]
    for indent, name, param in zip(indent_list, name_list, count_list):
        s0 = '    '*indent
        s1 = '{:{w}}'.format(name, w=max_name-len(s0))
        s2 = '{:>{w}}'.format(param, w=max_param-len(s1)-len(s0))
        out += [s0+s1+s2]
    return '\n'.join(out)


def summary(model):
    """ Print a summary about model building blocks and parameter counts. """
    print(summary_str(model))
Download .txt
gitextract_uvlwuzum/

├── .gitignore
├── CONTRIBUTING.md
├── LICENSE.md
├── README.md
├── docs/
│   ├── example.md
│   ├── text.mako
│   └── tutorial.md
├── examples/
│   ├── efficientnet.py
│   ├── lstm.py
│   ├── mnist.py
│   ├── mobilenet.py
│   ├── resnet.py
│   └── transformer.py
├── pyproject.toml
├── tests/
│   ├── test_engine.py
│   ├── test_functional.py
│   ├── test_module.py
│   ├── test_util.py
│   └── test_warm.py
└── warm/
    ├── __init__.py
    ├── engine.py
    ├── functional.py
    ├── module.py
    └── util.py
Download .txt
SYMBOL INDEX (109 symbols across 15 files)

FILE: examples/efficientnet.py
  function swish (line 19) | def swish(x):
  function conv_pad_same (line 23) | def conv_pad_same(x, size, kernel=1, stride=1, **kw):
  function conv_bn_act (line 40) | def conv_bn_act(x, size, kernel=1, stride=1, groups=1, bias=False, eps=1...
  function mb_block (line 46) | def mb_block(x, size_out, expand=1, kernel=1, stride=1, se_ratio=0.25, d...
  function squeeze_excitation (line 61) | def squeeze_excitation(x, size_se, name='', **kw):
  function drop_connect (line 70) | def drop_connect(x, rate):
  class WarmEfficientNet (line 89) | class WarmEfficientNet(nn.Module):
    method __init__ (line 90) | def __init__(self):
    method forward (line 93) | def forward(self, x):

FILE: examples/lstm.py
  class WarmTagger (line 32) | class WarmTagger(nn.Module):
    method __init__ (line 33) | def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
    method forward (line 37) | def forward(self, x): # D
  class TorchTagger (line 46) | class TorchTagger(nn.Module):
    method __init__ (line 47) | def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
    method forward (line 53) | def forward(self, sentence):
  function prepare_sequence (line 61) | def prepare_sequence(seq, to_ix):
  function main (line 66) | def main():

FILE: examples/mnist.py
  class WarmNet (line 22) | class WarmNet(nn.Module):
    method __init__ (line 23) | def __init__(self):
    method forward (line 26) | def forward(self, x):
  class TorchNet (line 37) | class TorchNet(nn.Module):
    method __init__ (line 38) | def __init__(self):
    method forward (line 44) | def forward(self, x):
  function train (line 55) | def train(p, model, device, train_loader, optimizer, epoch):
  function test (line 70) | def test(p, model, device, test_loader):
  function main (line 86) | def main():

FILE: examples/mobilenet.py
  function conv_bn_relu (line 19) | def conv_bn_relu(x, size, stride=1, expand=1, kernel=3, groups=1, name=''):
  function bottleneck (line 25) | def bottleneck(x, size_out, stride, expand, name=''):
  function conv1x1 (line 37) | def conv1x1(x, *arg, **kw):
  function pool (line 41) | def pool(x, *arg, **kw):
  function classify (line 45) | def classify(x, size, *arg, **kw):
  class WarmMobileNetV2 (line 64) | class WarmMobileNetV2(nn.Module):
    method __init__ (line 65) | def __init__(self):
    method forward (line 68) | def forward(self, x):
  function test (line 78) | def test():

FILE: examples/resnet.py
  function basic (line 20) | def basic(x, size, stride, stack_index, block_index):
  function stack (line 33) | def stack(x, num_block, size, stride, stack_index, block=basic):
  class WarmResNet (line 40) | class WarmResNet(nn.Module):
    method __init__ (line 41) | def __init__(self, block=basic, stack_spec=((2, 64, 1), (2, 128, 2), (...
    method forward (line 46) | def forward(self, x):
  function test_time (line 58) | def test_time(fn, *arg, repeat=10, **kw):
  function test (line 67) | def test():

FILE: examples/transformer.py
  function multi_head_attention (line 17) | def multi_head_attention(x, y=None, num_head=8, dropout=0.1, mask=None, ...
  function feed_forward (line 44) | def feed_forward(x, size_ff=2048, dropout=0.1, **kw):
  function residual_add (line 50) | def residual_add(x, layer, dropout=0.1, **kw):
  function encoder (line 57) | def encoder(x, num_encoder=6, **kw):
  function decoder (line 64) | def decoder(x, y, num_decoder=6, mask_x=None, mask_y=None, **kw):
  function transformer (line 72) | def transformer(x, y, **kw):
  class Transformer (line 78) | class Transformer(nn.Module):
    method __init__ (line 79) | def __init__(self, *shape, **kw):
    method forward (line 83) | def forward(self, x, y):

FILE: tests/test_engine.py
  function test_set_get_default_parent (line 15) | def test_set_get_default_parent():
  function test_auto_name (line 24) | def test_auto_name():
  function test_initialize (line 32) | def test_initialize():
  function test_activate (line 47) | def test_activate():
  function test_permute (line 58) | def test_permute():
  function test_unused_kwargs (line 81) | def test_unused_kwargs():
  function test_prepare_model_is_ready (line 88) | def test_prepare_model_is_ready():
  function test_forward (line 108) | def test_forward():
  function test_namespace (line 130) | def test_namespace():

FILE: tests/test_functional.py
  function test_conv (line 14) | def test_conv():
  function test_linear (line 31) | def test_linear():
  function test_batch_norm (line 48) | def test_batch_norm():
  function test_lstm (line 65) | def test_lstm():
  function test_gru (line 79) | def test_gru():
  function test_identity (line 89) | def test_identity():
  function test_dropout (line 94) | def test_dropout():
  function test_transformer (line 109) | def test_transformer():
  function test_layer_norm (line 125) | def test_layer_norm():
  function test_embedding (line 141) | def test_embedding():

FILE: tests/test_module.py
  function test_lambda (line 15) | def test_lambda():
  function test_sequential (line 29) | def test_sequential():
  function test_shortcut (line 40) | def test_shortcut():

FILE: tests/test_util.py
  function test_camel_to_snake (line 14) | def test_camel_to_snake():
  function test_summary_str (line 24) | def test_summary_str():
  function test_summary (line 31) | def test_summary():

FILE: tests/test_warm.py
  function test_warm_up (line 12) | def test_warm_up():

FILE: warm/engine.py
  function set_default_parent (line 14) | def set_default_parent(parent):
  function get_default_parent (line 20) | def get_default_parent():
  function _auto_name (line 26) | def _auto_name(name, parent):
  function prepare_model_ (line 40) | def prepare_model_(model, *data, device='cpu'):
  function is_ready (line 68) | def is_ready(model):
  function activate (line 73) | def activate(x, spec, lookup=None):
  function initialize_ (line 101) | def initialize_(x, spec):
  function permute (line 110) | def permute(x, in_shape='BCD', out_shape='BCD', **kw):
  function unused_kwargs (line 134) | def unused_kwargs(kw):
  function forward (line 143) | def forward(x, base_class,
  function namespace (line 201) | def namespace(f):

FILE: warm/functional.py
  function conv (line 16) | def conv(x, size, kernel, init_weight=None, init_bias=None, bias=True, *...
  function linear (line 50) | def linear(x, size, init_weight=None, init_bias=None, bias=True, **kw):
  function batch_norm (line 73) | def batch_norm(x, **kw):
  function lstm (line 89) | def lstm(x, size,
  function gru (line 141) | def gru(*arg, **kw):
  function identity (line 170) | def identity(x, *arg, **kw):
  function dropout (line 175) | def dropout(x, rate=0.5, by_channel=False, **kw):
  function transformer (line 192) | def transformer(x, y=None, num_encoder=6, num_decoder=6, num_head=8,
  function layer_norm (line 244) | def layer_norm(x, dim=1, **kw):
  function embedding (line 269) | def embedding(x, size, vocabulary=None, **kw):

FILE: warm/module.py
  class Lambda (line 43) | class Lambda(nn.Module):
    method __init__ (line 48) | def __init__(self, fn, *arg, **kw):
    method forward (line 53) | def forward(self, x):
  class Sequential (line 58) | class Sequential(nn.Sequential):
    method forward (line 61) | def forward(self, x):
  class Shortcut (line 74) | class Shortcut(Sequential):
    method __init__ (line 79) | def __init__(self, *arg, projection=None):
    method forward (line 82) | def forward(self, x):

FILE: warm/util.py
  function camel_to_snake (line 16) | def camel_to_snake(name):
  function summary_str (line 22) | def summary_str(model):
  function summary (line 47) | def summary(model):
Condensed preview — 24 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (106K chars).
[
  {
    "path": ".gitignore",
    "chars": 1323,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 478,
    "preview": "# Contributing to PyWarm\n\nPyWarm is developed on [GitHub](https://github.com/blue-season/pywarm). \n\nPlease use GitHub to"
  },
  {
    "path": "LICENSE.md",
    "chars": 1068,
    "preview": "MIT License\n\nCopyright (c) 2019 blue-season\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
  },
  {
    "path": "README.md",
    "chars": 6105,
    "preview": "\n[![PyWarm - A cleaner way to build neural networks for PyTorch](https://github.com/blue-season/pywarm/raw/gh-pages/docs"
  },
  {
    "path": "docs/example.md",
    "chars": 14900,
    "preview": "\n# PyWarm Examples\n\n## ResNet\n\nA more detailed example, the ResNet18 network defined in PyWarm and vanilla PyTorch:\n\n```"
  },
  {
    "path": "docs/text.mako",
    "chars": 3050,
    "preview": "## Define mini-templates for each portion of the doco.\n\n<%!\n  def indent(s, spaces=4):\n      new = s.replace('\\n', '\\n' "
  },
  {
    "path": "docs/tutorial.md",
    "chars": 6663,
    "preview": "\n# PyWarm Basic Tutorial\n\n## Import\n\nTo get started, first import PyWarm in your project:\n\n```Python\nimport warm\nimport "
  },
  {
    "path": "examples/efficientnet.py",
    "chars": 3466,
    "preview": "\n# 09-20-2019;\n\"\"\"\nEfficientNet\n\"\"\"\nfrom pathlib import Path\nimport sys\nsys.path.append(str(Path(__file__).parent.parent"
  },
  {
    "path": "examples/lstm.py",
    "chars": 3503,
    "preview": "# 09-07-2019;\n\"\"\"\nLSTM sequence model example, based on\nhttps://pytorch.org/tutorials/beginner/nlp/sequence_models_tutor"
  },
  {
    "path": "examples/mnist.py",
    "chars": 5314,
    "preview": "# 08-27-2019;\n\"\"\"\nMNIST training example.\nUse `python mnist.py` to run with PyTorch NN.\nUse `python mnist.py --warm` to "
  },
  {
    "path": "examples/mobilenet.py",
    "chars": 3286,
    "preview": "# 09-03-2019;\n\"\"\"\nConstruct a WarmMobileNetV2() using PyWarm, then copy state dicts\nfrom torchvision.models.mobilenet_v2"
  },
  {
    "path": "examples/resnet.py",
    "chars": 3296,
    "preview": "# 08-29-2019;\n\"\"\"\nConstruct a WarmResNet() using PyWarm, then copy state dicts\nfrom torchvision.models.resnet18() into W"
  },
  {
    "path": "examples/transformer.py",
    "chars": 2604,
    "preview": "# 09-05-2019;\n\"\"\"\nThe Transformer model from paper *Attention is all you need*.\n\"\"\"\nfrom pathlib import Path\nimport sys\n"
  },
  {
    "path": "pyproject.toml",
    "chars": 1035,
    "preview": "[tool.poetry]\nname = 'PyWarm'\nversion = '0.4.1'\ndescription = 'A cleaner way to build neural networks for PyTorch.'\nlice"
  },
  {
    "path": "tests/test_engine.py",
    "chars": 6242,
    "preview": "# 08-31-2019;\n\"\"\"\nTest cases for warm.engine.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimp"
  },
  {
    "path": "tests/test_functional.py",
    "chars": 5034,
    "preview": "# 08-31-2019;\n\"\"\"\nTest cases for warm.functional.\n\"\"\"\nimport torch\nimport torch.nn as nn\nfrom pathlib import Path\nimport"
  },
  {
    "path": "tests/test_module.py",
    "chars": 1275,
    "preview": "# 08-31-2019;\n\"\"\"\nTest cases for warm.module.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfro"
  },
  {
    "path": "tests/test_util.py",
    "chars": 974,
    "preview": "# 08-31-2019;\n\"\"\"\nTest cases for warm.util.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom pathlib impor"
  },
  {
    "path": "tests/test_warm.py",
    "chars": 390,
    "preview": "# 09-10-2019;\n\"\"\"\nTest cases for the warm module.\n\"\"\"\nimport torch.nn as nn\nfrom pathlib import Path\nimport sys\nsys.path"
  },
  {
    "path": "warm/__init__.py",
    "chars": 201,
    "preview": "# 09-10-2019;\n\n\"\"\" `warm.up` is an alias of\n[`warm.engine.prepare_model_`](https://blue-season.github.io/pywarm/referenc"
  },
  {
    "path": "warm/engine.py",
    "chars": 10107,
    "preview": "# 08-26-2019;\n\"\"\"\nPyWarm engine to the functional interface.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport numpy as np\nf"
  },
  {
    "path": "warm/functional.py",
    "chars": 16515,
    "preview": "# 08-27-2019;\n\"\"\"\nWraps around various torch.nn Modules to fit into a functional interface.\n\"\"\"\nimport torch\nimport torc"
  },
  {
    "path": "warm/module.py",
    "chars": 3039,
    "preview": "# 08-27-2019;\n\"\"\"\nCustom modules to enhance the nn Sequential experience.\n\nPyWarm's core concept is to use a functional "
  },
  {
    "path": "warm/util.py",
    "chars": 1711,
    "preview": "# 08-28-2019;\n\"\"\"\nShort utilities.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport re\n\n\n\"\"\" Create a pr"
  }
]

About this extraction

This page contains the full source code of the blue-season/pywarm GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 24 files (99.2 KB), approximately 29.3k tokens, and a symbol index with 109 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!