Full Code of nmaac/acon for AI

main 99fd67928a6f cached
19 files
108.5 KB
29.5k tokens
188 symbols
1 request
Download .txt
Repository: nmaac/acon
Branch: main
Commit: 99fd67928a6f
Files: 19
Total size: 108.5 KB

Directory structure:
gitextract_vf0sezxr/

├── ACON/
│   ├── ResNet_ACON/
│   │   ├── resnet_acon.py
│   │   ├── train.py
│   │   └── utils.py
│   └── ShuffleNetV2_ACON/
│       ├── network.py
│       ├── train.py
│       └── utils.py
├── LICENSE
├── MetaACON/
│   ├── ResNet_MetaACON/
│   │   ├── resnet_metaacon.py
│   │   ├── train.py
│   │   └── utils.py
│   └── ShuffleNet_MetaACON/
│       ├── network.py
│       ├── train.py
│       └── utils.py
├── README.md
├── TFNet/
│   ├── README.md
│   ├── network.py
│   ├── train.py
│   └── utils.py
└── acon.py

================================================
FILE CONTENTS
================================================

================================================
FILE: ACON/ResNet_ACON/resnet_acon.py
================================================
import torch
from torch import Tensor
import torch.nn as nn
from typing import Type, Any, Callable, Union, List, Optional

import sys
sys.path.insert(0,'../..')
from acon import AconC


__all__ = ['ResNet', 'resnet50_acon', 'resnet101_acon', 'resnet152_acon']


model_urls = {}


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a 3x3 convolution whose padding equals its dilation.

    With stride 1 this preserves spatial size. Note the conv keeps a bias term
    even though it is typically followed by a norm layer in this file.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=True,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a pointwise (1x1) convolution, used for channel projection and the
    downsample path of residual blocks."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=True,
    )

class BasicBlock_ACON(nn.Module):
    """ResNet basic block with both ReLU activations replaced by learnable ACON-C,
    per "Activate or Not: Learning Customized Activation"
    <https://arxiv.org/pdf/2009.04759.pdf>.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(BasicBlock_ACON, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # conv1 (and downsample, when given) carry the spatial downsampling.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.acon1 = AconC(planes)
        # Parameter-free; kept to mirror the torchvision block's attribute set
        # even though forward() only uses the ACON activations.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.acon2 = AconC(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        # Shortcut branch: identity, or projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.acon1(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        y += shortcut
        return self.acon2(y)

class Bottleneck_ACON(nn.Module):
    """ResNet bottleneck block where the ReLU after the 3x3 conv (conv2) is
    replaced by ACON-C, per "Activate or Not: Learning Customized Activation"
    <https://arxiv.org/pdf/2009.04759.pdf>.

    Note: this implementation places the stride on the first 1x1 convolution
    (conv1), i.e. the original ResNet-v1 placement from
    <https://arxiv.org/abs/1512.03385>. The "v1.5" variant strides the 3x3
    conv instead; the code here intentionally keeps the v1 layout.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(Bottleneck_ACON, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Bottleneck width scales with base_width and groups (ResNeXt-style).
        width = int(planes * (base_width / 64.)) * groups
        # conv1 (and downsample, when given) carry the spatial downsampling.
        self.conv1 = conv1x1(inplanes, width, stride)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, 1, groups, dilation)
        self.bn2 = norm_layer(width)
        self.acon = AconC(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        # Shortcut branch: identity, or projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        # ACON-C replaces only the middle activation (after the 3x3 conv).
        y = self.acon(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y += shortcut
        return self.relu(y)


class ResNet(nn.Module):
    """Torchvision-style ResNet whose residual blocks use ACON-C activations.

    The stem (conv1/bn1/relu/maxpool) and stage layout are standard; the
    activation swap lives entirely inside the ``block`` classes.
    NOTE(review): convolutions here use ``bias=True`` even though each is
    followed by BatchNorm (torchvision uses ``bias=False``) — kept as written.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock_ACON, Bottleneck_ACON]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        # block: residual block class (BasicBlock_ACON or Bottleneck_ACON).
        # layers: number of blocks per stage, e.g. [3, 4, 6, 3] for ResNet-50.
        # replace_stride_with_dilation: per-stage flag trading stride for dilation.
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool (overall stride 4).
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=True)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages; stages 2-4 downsample (or dilate) by 2.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming init for convs, constant init for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck_ACON):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock_ACON):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block: Type[Union[BasicBlock_ACON, Bottleneck_ACON]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        """Build one stage: a strided (or dilated) first block plus ``blocks - 1``
        stride-1 blocks, inserting a 1x1 projection shortcut when shapes change."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation (keeps spatial resolution).
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # Only the first block of the stage downsamples / projects.
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run stem -> 4 stages -> global average pool -> classifier."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits of shape (batch, num_classes)."""
        return self._forward_impl(x)


def _resnet(
    arch: str,
    block: Type[Union[BasicBlock_ACON, Bottleneck_ACON]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    """Construct a ResNet-ACON variant.

    Args:
        arch: architecture name (kept for API parity with torchvision).
        block: residual block class to stack.
        layers: number of blocks in each of the four stages.
        pretrained: must be False; no pretrained ACON weights are published
            (``model_urls`` is empty in this file).
        progress: unused; kept for API parity with torchvision.

    Raises:
        NotImplementedError: if ``pretrained`` is True — previously the flag
            was silently ignored and random weights were returned.
    """
    if pretrained:
        raise NotImplementedError(
            'No pretrained weights are available for {} (model_urls is empty)'.format(arch))
    model = ResNet(block, layers, **kwargs)
    return model


def resnet50_acon(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build a ResNet-50 with ACON activations.

    From `"Activate or Not: Learning Customized Activation"
    <https://arxiv.org/pdf/2009.04759.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_blocks = [3, 4, 6, 3]
    return _resnet('resnet50_acon', Bottleneck_ACON, stage_blocks, pretrained, progress, **kwargs)

def resnet101_acon(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build a ResNet-101 with ACON activations.

    From `"Activate or Not: Learning Customized Activation"
    <https://arxiv.org/pdf/2009.04759.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_blocks = [3, 4, 23, 3]
    return _resnet('resnet101_acon', Bottleneck_ACON, stage_blocks, pretrained, progress, **kwargs)

def resnet152_acon(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build a ResNet-152 with ACON activations.

    From `"Activate or Not: Learning Customized Activation"
    <https://arxiv.org/pdf/2009.04759.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_blocks = [3, 8, 36, 3]
    return _resnet('resnet152_acon', Bottleneck_ACON, stage_blocks, pretrained, progress, **kwargs)



================================================
FILE: ACON/ResNet_ACON/train.py
================================================
import os
import sys
import torch
import argparse
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import cv2
import numpy as np
import PIL
from PIL import Image
import time
import logging
import argparse
from resnet_acon import resnet50_acon
from utils import accuracy, AvgrageMeter, CrossEntropyLabelSmooth, save_checkpoint, get_lastest_model, get_parameters

class OpencvResize(object):
    """Resize a PIL image so its shorter side equals ``size``, using OpenCV.

    The array round-trips through BGR because cv2 operates on BGR images;
    the returned PIL image is RGB again.
    """

    def __init__(self, size=256):
        self.size = size

    def __call__(self, img):
        assert isinstance(img, PIL.Image.Image)
        arr = np.ascontiguousarray(np.asarray(img)[:, :, ::-1])  # RGB -> BGR
        H, W, _ = arr.shape
        # Scale so the shorter side becomes self.size; round the longer side.
        if H < W:
            target_size = (int(self.size / H * W + 0.5), self.size)
        else:
            target_size = (self.size, int(self.size / W * H + 0.5))
        arr = cv2.resize(arr, target_size, interpolation=cv2.INTER_LINEAR)
        arr = np.ascontiguousarray(arr[:, :, ::-1])  # BGR -> RGB
        return Image.fromarray(arr)

class ToBGRTensor(object):
    """Convert an RGB image (PIL or HWC ndarray) into a float BGR CHW tensor.

    No 0-1 scaling is applied; values stay in the original 0-255 range.
    """

    def __call__(self, img):
        assert isinstance(img, (np.ndarray, PIL.Image.Image))
        if isinstance(img, PIL.Image.Image):
            img = np.asarray(img)
        bgr = img[:, :, ::-1]               # RGB -> BGR
        chw = np.transpose(bgr, [2, 0, 1])  # HWC -> CHW
        return torch.from_numpy(np.ascontiguousarray(chw)).float()

class DataIterator(object):
    """Endless iterator over a DataLoader: rewinds to the start when exhausted."""

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = enumerate(self.dataloader)

    def next(self):
        """Return the next (data, target) pair, restarting at the end of an epoch."""
        try:
            _, data = next(self.iterator)
        except StopIteration:
            # Only exhaustion triggers a rewind. The previous bare
            # ``except Exception`` also swallowed genuine loader failures
            # (worker crashes, corrupt samples) and silently restarted.
            self.iterator = enumerate(self.dataloader)
            _, data = next(self.iterator)
        return data[0], data[1]

def get_args():
    """Parse the command-line options for ResNet-ACON training/evaluation."""
    parser = argparse.ArgumentParser("ResNet")

    # Evaluation-only options.
    parser.add_argument('--eval', default=False, action='store_true')
    parser.add_argument('--eval-resume', type=str, default='./res50.acon.pth', help='path for eval model')

    # Optimization hyper-parameters.
    parser.add_argument('--batch-size', type=int, default=256, help='batch size')
    parser.add_argument('--total-iters', type=int, default=600000, help='total iters')
    parser.add_argument('--learning-rate', type=float, default=0.1, help='init learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight-decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')

    # Bookkeeping intervals.
    parser.add_argument('--auto-continue', type=bool, default=True, help='auto continue')
    parser.add_argument('--display-interval', type=int, default=20, help='display interval')
    parser.add_argument('--val-interval', type=int, default=50000, help='val interval')
    parser.add_argument('--save-interval', type=int, default=50000, help='save interval')

    # Dataset locations (ImageFolder layout).
    parser.add_argument('--train-dir', type=str, default='data/train', help='path to training dataset')
    parser.add_argument('--val-dir', type=str, default='data/val', help='path to validation dataset')

    return parser.parse_args()

def main():
    """Train (or, with --eval, evaluate) resnet50_acon on ImageFolder data."""
    args = get_args()

    # Log to stdout and to a timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    # Training pipeline: random crop/jitter/flip, then BGR float tensor
    # (no mean/std normalization — the custom transforms keep 0-255 values).
    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ])
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=1, pin_memory=use_gpu)
    train_dataprovider = DataIterator(train_loader)

    # Validation pipeline: shorter-side resize to 256, center crop 224.
    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(args.val_dir, transforms.Compose([
            OpencvResize(256),
            transforms.CenterCrop(224),
            ToBGRTensor(),
        ])),
        batch_size=200, shuffle=False,
        num_workers=1, pin_memory=use_gpu
    )
    val_dataprovider = DataIterator(val_loader)
    print('load data successfully')

    model = resnet50_acon()
    # get_parameters separates weights (decayed) from biases/norm params (no decay).
    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # NOTE(review): epsilon 0.0 disables label smoothing — confirm intended.
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.0)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    # Linear LR decay to zero over total_iters.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)

    model = model.to(device)

    all_iters = 0
    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            # Replay the scheduler to restore the LR at the resumed iteration.
            for i in range(iters):
                scheduler.step()

    # Stash the training context on args so train()/validate() can reach it.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')
            load_checkpoint(model, checkpoint)
            validate(model, device, args, all_iters=all_iters)
        exit(0)

    # Alternate training chunks of val_interval iterations with validation.
    while all_iters < args.total_iters:
        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)
        validate(model, device, args, all_iters=all_iters)
    validate(model, device, args, all_iters=all_iters)
    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')

def adjust_bn_momentum(model, iters):
    """Set every BatchNorm2d momentum to 1/iters.

    Decaying momentum this way makes the running statistics approach a
    cumulative average over the processed batches.
    """
    bn_layers = (m for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    for bn in bn_layers:
        bn.momentum = 1 / iters

def train(model, device, args, *, val_interval, bn_process=False, all_iters=None):
    """Run ``val_interval`` training iterations and return the updated iteration count.

    Reads optimizer/loss/scheduler/data provider from ``args`` (set up in main()).
    Logs running top-1/top-5 error every ``args.display_interval`` iterations and
    checkpoints every ``args.save_interval`` iterations.
    """

    optimizer = args.optimizer
    loss_function = args.loss_function
    scheduler = args.scheduler
    train_dataprovider = args.train_dataprovider

    t1 = time.time()
    Top1_err, Top5_err = 0.0, 0.0
    model.train()
    for iters in range(1, val_interval + 1):
        # NOTE(review): scheduler.step() precedes optimizer.step(); PyTorch >= 1.1
        # recommends the reverse order — kept as written to preserve the schedule.
        scheduler.step()
        if bn_process:
            adjust_bn_momentum(model, iters)

        all_iters += 1
        d_st = time.time()
        data, target = train_dataprovider.next()
        target = target.type(torch.LongTensor)
        data, target = data.to(device), target.to(device)
        data_time = time.time() - d_st

        output = model(data)
        loss = loss_function(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        prec1, prec5 = accuracy(output, target, topk=(1, 5))

        # Accumulate error rates; reset after each display window below.
        Top1_err += 1 - prec1.item() / 100
        Top5_err += 1 - prec5.item() / 100

        if all_iters % args.display_interval == 0:
            printInfo = 'TRAIN Iter {}: lr = {:.6f},\tloss = {:.6f},\t'.format(all_iters, scheduler.get_lr()[0], loss.item()) + \
                        'Top-1 err = {:.6f},\t'.format(Top1_err / args.display_interval) + \
                        'Top-5 err = {:.6f},\t'.format(Top5_err / args.display_interval) + \
                        'data_time = {:.6f},\ttrain_time = {:.6f}'.format(data_time, (time.time() - t1) / args.display_interval)
            logging.info(printInfo)
            t1 = time.time()
            Top1_err, Top5_err = 0.0, 0.0

        if all_iters % args.save_interval == 0:
            save_checkpoint({
                'state_dict': model.state_dict(),
                }, all_iters)

    return all_iters

def validate(model, device, args, *, all_iters=None):
    """Evaluate the model for a fixed 250 batches and log loss / top-1 / top-5 error.

    Uses the (endless) validation data provider from ``args``; with batch_size=200
    the 250 batches cover the standard 50k-image ImageNet validation set once.
    """
    objs = AvgrageMeter()
    top1 = AvgrageMeter()
    top5 = AvgrageMeter()

    loss_function = args.loss_function
    val_dataprovider = args.val_dataprovider

    model.eval()
    max_val_iters = 250
    t1  = time.time()
    with torch.no_grad():
        for _ in range(1, max_val_iters + 1):
            data, target = val_dataprovider.next()
            target = target.type(torch.LongTensor)
            data, target = data.to(device), target.to(device)

            output = model(data)
            loss = loss_function(output, target)

            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            n = data.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

    logInfo = 'TEST Iter {}: loss = {:.6f},\t'.format(all_iters, objs.avg) + \
              'Top-1 err = {:.6f},\t'.format(1 - top1.avg / 100) + \
              'Top-5 err = {:.6f},\t'.format(1 - top5.avg / 100) + \
              'val_time = {:.6f}'.format(time.time() - t1)
    logging.info(logInfo)

def load_checkpoint(net, checkpoint):
    """Load weights into ``net``, reconciling the DataParallel 'module.' prefix.

    The checkpoint may or may not have been saved from an ``nn.DataParallel``
    wrapper; keys are prefixed or stripped to match what ``net`` expects.
    The previous implementation unconditionally added the prefix, which made
    loading fail for models that were never wrapped (e.g. CPU evaluation).
    """
    from collections import OrderedDict

    if 'state_dict' in checkpoint:
        checkpoint = dict(checkpoint['state_dict'])

    # Does the target model itself use 'module.'-prefixed keys?
    wants_prefix = any(k.startswith('module.') for k in net.state_dict().keys())

    temp = OrderedDict()
    for k, v in checkpoint.items():
        if wants_prefix and not k.startswith('module.'):
            k = 'module.' + k
        elif not wants_prefix and k.startswith('module.'):
            k = k[len('module.'):]
        temp[k] = v

    net.load_state_dict(temp, strict=True)

if __name__ == "__main__":
    main()


================================================
FILE: ACON/ResNet_ACON/utils.py
================================================
import os
import re
import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):

	def __init__(self, num_classes, epsilon):
		super(CrossEntropyLabelSmooth, self).__init__()
		self.num_classes = num_classes
		self.epsilon = epsilon
		self.logsoftmax = nn.LogSoftmax(dim=1)

	def forward(self, inputs, targets):
		log_probs = self.logsoftmax(inputs)
		targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
		targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
		loss = (-targets * log_probs).mean(0).sum()
		return loss


class AvgrageMeter(object):

	def __init__(self):
		self.reset()

	def reset(self):
		self.avg = 0
		self.sum = 0
		self.cnt = 0
		self.val = 0

	def update(self, val, n=1):
		self.val = val
		self.sum += val * n
		self.cnt += n
		self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
	maxk = max(topk)
	batch_size = target.size(0)

	_, pred = output.topk(maxk, 1, True, True)
	pred = pred.t()
	correct = pred.eq(target.view(1, -1).expand_as(pred))

	res = []
	for k in topk:
		correct_k = correct[:k].reshape(-1).float().sum(0)
		res.append(correct_k.mul_(100.0/batch_size))
	return res


def save_checkpoint(state, iters, tag=''):
	if not os.path.exists("./models"):
		os.makedirs("./models")
	filename = os.path.join("./models/{}checkpoint-{:06}.pth.tar".format(tag, iters))
	torch.save(state, filename)

def get_lastest_model():
	if not os.path.exists('./models'):
		os.mkdir('./models')
	model_list = os.listdir('./models/')
	if model_list == []:
		return None, 0
	model_list.sort()
	lastest_model = model_list[-1]
	iters = re.findall(r'\d+', lastest_model)
	return './models/' + lastest_model, int(iters[0])


def get_parameters(model):
	group_no_weight_decay = []
	group_weight_decay = []
	for pname, p in model.named_parameters():
		if pname.find('weight') >= 0 and len(p.size()) > 1:
			# print('include ', pname, p.size())
			group_weight_decay.append(p)
		else:
			# print('not include ', pname, p.size())
			group_no_weight_decay.append(p)
	assert len(list(model.parameters())) == len(group_weight_decay) + len(group_no_weight_decay)
	groups = [dict(params=group_weight_decay), dict(params=group_no_weight_decay, weight_decay=0.)]
	return groups


================================================
FILE: ACON/ShuffleNetV2_ACON/network.py
================================================
import torch
import torch.nn as nn

import sys
sys.path.insert(0,'../..')
from acon import AconC

class ShuffleV2Block_ACON(nn.Module):
    """ShuffleNetV2 unit with its ReLU activations replaced by ACON-C.

    stride=1: the input is split in half by channels (via channel_shuffle);
    one half passes through, the other goes through branch_main, and the two
    are concatenated. stride=2: both a projection branch and the main branch
    process the full input, doubling the channel count while halving H and W.
    """

    def __init__(self, inp, oup, mid_channels, *, ksize, stride):
        # inp: channels entering branch_main; oup: channels after concat;
        # mid_channels: width of the depthwise bottleneck inside branch_main.
        super(ShuffleV2Block_ACON, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        self.mid_channels = mid_channels
        self.ksize = ksize
        pad = ksize // 2
        self.pad = pad
        self.inp = inp

        # The main branch must produce oup - inp channels so the concat
        # with the shortcut/projection branch totals oup.
        outputs = oup - inp

        branch_main = [
            # pw
            nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=True),
            nn.BatchNorm2d(mid_channels),
            AconC(mid_channels),
            # dw
            nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=True),
            nn.BatchNorm2d(mid_channels),
            # pw-linear
            nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=True),
            nn.BatchNorm2d(outputs),
            AconC(outputs),
        ]
        self.branch_main = nn.Sequential(*branch_main)

        if stride == 2:
            branch_proj = [
                # dw
                nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=True),
                nn.BatchNorm2d(inp),
                # pw-linear
                nn.Conv2d(inp, inp, 1, 1, 0, bias=True),
                nn.BatchNorm2d(inp),
                AconC(inp),
            ]
            self.branch_proj = nn.Sequential(*branch_proj)
        else:
            self.branch_proj = None

    def forward(self, old_x):
        """Apply the block; see the class docstring for the two stride modes."""
        if self.stride==1:
            x_proj, x = self.channel_shuffle(old_x)
            return torch.cat((x_proj, self.branch_main(x)), 1)
        elif self.stride==2:
            x_proj = old_x
            x = old_x
            return torch.cat((self.branch_proj(x_proj), self.branch_main(x)), 1)

    def channel_shuffle(self, x):
        """Interleave channels and split the batch into two half-channel tensors.

        The reshape/permute/reshape sequence realizes ShuffleNet's channel
        shuffle and returns two (batch, num_channels // 2, H, W) tensors.
        """
        batchsize, num_channels, height, width = x.data.size()
        assert (num_channels % 4 == 0)
        x = x.reshape(batchsize * num_channels // 2, 2, height * width)
        x = x.permute(1, 0, 2)
        x = x.reshape(2, -1, num_channels // 2, height, width)
        return x[0], x[1]


class ShuffleNetV2_ACON(nn.Module):
    """ShuffleNetV2 classifier whose activations are ACON-C.

    Architecture: 3x3/2 stem conv -> max-pool -> three shuffle stages of
    [4, 8, 4] blocks -> 1x1 conv_last -> 7x7 average pool -> linear classifier.
    ``model_size`` selects the channel widths.
    """

    def __init__(self, input_size=224, n_class=1000, model_size='1.5x'):
        # NOTE(review): input_size is never used; the fixed AvgPool2d(7) below
        # implies a 224x224 input — confirm before feeding other resolutions.
        super(ShuffleNetV2_ACON, self).__init__()
        print('model size is ', model_size)

        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        # stage_out_channels[0] is a placeholder; [1] is the stem width.
        if model_size == '0.5x':
            self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
        elif model_size == '1.0x':
            self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
        elif model_size == '1.5x':
            self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
        elif model_size == '2.0x':
            self.stage_out_channels = [-1, 24, 244, 488, 976, 2048]
        else:
            raise NotImplementedError

        # building first layer
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, input_channel, 3, 2, 1, bias=True),
            nn.BatchNorm2d(input_channel),
            AconC(input_channel),
        )

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Three stages; each starts with a stride-2 block, followed by
        # stride-1 blocks that receive the half-channel shuffled input.
        self.features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage+2]

            for i in range(numrepeat):
                if i == 0:
                    self.features.append(ShuffleV2Block_ACON(input_channel, output_channel,
                                                mid_channels=output_channel // 2, ksize=3, stride=2))
                else:
                    self.features.append(ShuffleV2Block_ACON(input_channel // 2, output_channel,
                                                mid_channels=output_channel // 2, ksize=3, stride=1))

                input_channel = output_channel

        self.features = nn.Sequential(*self.features)

        self.conv_last = nn.Sequential(
            nn.Conv2d(input_channel, self.stage_out_channels[-1], 1, 1, 0, bias=True),
            nn.BatchNorm2d(self.stage_out_channels[-1]),
            AconC(self.stage_out_channels[-1]),
        )
        self.globalpool = nn.AvgPool2d(7)
        # Dropout only exists in the widest configuration.
        if self.model_size == '2.0x':
            self.dropout = nn.Dropout(0.2)
        self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class, bias=True))
        self._initialize_weights()

    def forward(self, x):
        """Return class logits of shape (batch, n_class)."""
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.conv_last(x)

        x = self.globalpool(x)
        if self.model_size == '2.0x':
            x = self.dropout(x)
        x = x.contiguous().view(-1, self.stage_out_channels[-1])
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Custom init: small normal for the stem/linear, fan-in-scaled normal
        for other convs, near-zero BatchNorm biases."""
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'first' in name:
                    nn.init.normal_(m.weight, 0, 0.01)
                else:
                    nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


================================================
FILE: ACON/ShuffleNetV2_ACON/train.py
================================================
import os
import sys
import torch
import argparse
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import cv2
import numpy as np
import PIL
from PIL import Image
import time
import logging
import argparse
from network import ShuffleNetV2_ACON
from utils import accuracy, AvgrageMeter, CrossEntropyLabelSmooth, save_checkpoint, get_lastest_model, get_parameters

class OpencvResize(object):
    """Resize a PIL image so its shorter side equals ``size``, using OpenCV.

    The array round-trips through BGR because cv2 operates on BGR images;
    the returned PIL image is RGB again.
    """

    def __init__(self, size=256):
        self.size = size

    def __call__(self, img):
        assert isinstance(img, PIL.Image.Image)
        arr = np.ascontiguousarray(np.asarray(img)[:, :, ::-1])  # RGB -> BGR
        H, W, _ = arr.shape
        # Scale so the shorter side becomes self.size; round the longer side.
        if H < W:
            target_size = (int(self.size / H * W + 0.5), self.size)
        else:
            target_size = (self.size, int(self.size / W * H + 0.5))
        arr = cv2.resize(arr, target_size, interpolation=cv2.INTER_LINEAR)
        arr = np.ascontiguousarray(arr[:, :, ::-1])  # BGR -> RGB
        return Image.fromarray(arr)

class ToBGRTensor(object):
    """Convert an RGB PIL image or HxWx3 ndarray to a float BGR tensor
    in CHW layout. No mean/std normalization is applied."""

    def __call__(self, img):
        assert isinstance(img, (np.ndarray, PIL.Image.Image))
        if isinstance(img, PIL.Image.Image):
            img = np.asarray(img)
        bgr = img[:, :, ::-1]  # RGB -> BGR
        # HWC -> CHW; ascontiguousarray is required before from_numpy
        # because the flips above produce negative strides.
        chw = np.ascontiguousarray(np.transpose(bgr, [2, 0, 1]))
        return torch.from_numpy(chw).float()

class DataIterator(object):
    """Endless iterator over a DataLoader: when the loader is exhausted
    it is restarted, so ``next()`` never raises at an epoch boundary."""

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = enumerate(self.dataloader)

    def next(self):
        """Return the next ``(data, target)`` pair, wrapping around when
        the current pass over the dataloader ends.

        Fix: the original caught bare ``Exception``, which silently
        swallowed real failures (e.g. a crashed worker or a corrupt
        sample) and retried; only ``StopIteration`` signals exhaustion.
        """
        try:
            _, data = next(self.iterator)
        except StopIteration:
            # Epoch finished: start a fresh pass over the dataloader.
            self.iterator = enumerate(self.dataloader)
            _, data = next(self.iterator)
        return data[0], data[1]

def get_args():
    """Parse command-line options for ShuffleNetV2-ACON training."""
    ap = argparse.ArgumentParser("ShuffleNetV2_ACON")

    # Evaluation-only mode.
    ap.add_argument('--eval', default=False, action='store_true')
    ap.add_argument('--eval-resume', type=str, default='./shufflenetv2.0.5.acon.pth', help='path for eval model')

    # Optimization hyper-parameters.
    ap.add_argument('--batch-size', type=int, default=1024, help='batch size')
    ap.add_argument('--total-iters', type=int, default=300000, help='total iters')
    ap.add_argument('--learning-rate', type=float, default=0.5, help='init learning rate')
    ap.add_argument('--momentum', type=float, default=0.9, help='momentum')
    ap.add_argument('--weight-decay', type=float, default=4e-5, help='weight decay')
    ap.add_argument('--save', type=str, default='./models', help='path for saving trained models')
    ap.add_argument('--label-smooth', type=float, default=0.1, help='label smoothing')

    # Bookkeeping intervals (in iterations).
    # NOTE(review): type=bool is an argparse pitfall — any non-empty
    # string (including "False") parses as True; confirm before relying
    # on `--auto-continue False`.
    ap.add_argument('--auto-continue', type=bool, default=True, help='auto continue')
    ap.add_argument('--display-interval', type=int, default=20, help='display interval')
    ap.add_argument('--val-interval', type=int, default=10000, help='val interval')
    ap.add_argument('--save-interval', type=int, default=10000, help='save interval')

    # Model variant and dataset locations.
    ap.add_argument('--model-size', type=str, default='0.5x', choices=['0.5x', '1.0x', '1.5x', '2.0x'], help='size of the model')
    ap.add_argument('--train-dir', type=str, default='data/train', help='path to training dataset')
    ap.add_argument('--val-dir', type=str, default='data/val', help='path to validation dataset')

    return ap.parse_args()

def main():
    """Entry point: set up logging, data pipelines, model, optimizer and
    LR schedule, then run the train/validate loop (or a single
    evaluation pass when --eval is given)."""
    args = get_args()

    # Log to stdout and to a timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    # Training data: random crop/flip/jitter, converted to BGR tensors.
    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ])
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=1, pin_memory=use_gpu)
    train_dataprovider = DataIterator(train_loader)

    # Validation data: resize shorter side to 256, center-crop 224.
    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(args.val_dir, transforms.Compose([
            OpencvResize(256),
            transforms.CenterCrop(224),
            ToBGRTensor(),
        ])),
        batch_size=200, shuffle=False,
        num_workers=1, pin_memory=use_gpu
    )
    val_dataprovider = DataIterator(val_loader)
    print('load data successfully')

    model = ShuffleNetV2_ACON(model_size=args.model_size)

    # Weight decay is applied only to conv/fc weights (see get_parameters).
    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # NOTE(review): smoothing is hard-coded to 0.1; args.label_smooth is
    # not used here — confirm which value is intended.
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    # Linear LR decay from the initial LR to 0 over total_iters steps.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)

    model = model.to(device)

    all_iters = 0
    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            # Fast-forward the scheduler so the LR matches the resumed iteration.
            for i in range(iters):
                scheduler.step()

    # Stash shared training state on the args namespace for train()/validate().
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    if args.eval:
        # Evaluation-only path: load the given checkpoint, validate, exit.
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')
            load_checkpoint(model, checkpoint)
            validate(model, device, args, all_iters=all_iters)
        exit(0)

    # Alternate training chunks of val_interval iterations with validation.
    while all_iters < args.total_iters:
        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)
        validate(model, device, args, all_iters=all_iters)
    validate(model, device, args, all_iters=all_iters)
    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')

def adjust_bn_momentum(model, iters):
    """Set every BatchNorm2d momentum to 1/iters, so running statistics
    average uniformly over all batches seen so far."""
    bn_momentum = 1 / iters
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.momentum = bn_momentum

def train(model, device, args, *, val_interval, bn_process=False, all_iters=None):
    """Run `val_interval` training iterations and return the updated
    global iteration counter `all_iters`.

    Expects optimizer/loss/scheduler/dataprovider to have been attached
    to `args` by main(). When `bn_process` is True, BatchNorm momentum
    is re-tuned each iteration via adjust_bn_momentum().
    """

    optimizer = args.optimizer
    loss_function = args.loss_function
    scheduler = args.scheduler
    train_dataprovider = args.train_dataprovider

    t1 = time.time()
    # Running top-1/top-5 error sums, reset after each display window.
    Top1_err, Top5_err = 0.0, 0.0
    model.train()
    for iters in range(1, val_interval + 1):
        # NOTE(review): scheduler.step() is called before optimizer.step();
        # recent PyTorch versions warn about this ordering — confirm the
        # intended LR sequence for the target torch version.
        scheduler.step()
        if bn_process:
            adjust_bn_momentum(model, iters)

        all_iters += 1
        d_st = time.time()
        data, target = train_dataprovider.next()
        target = target.type(torch.LongTensor)
        data, target = data.to(device), target.to(device)
        data_time = time.time() - d_st

        # Standard SGD step: forward, loss, backward, update.
        output = model(data)
        loss = loss_function(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        prec1, prec5 = accuracy(output, target, topk=(1, 5))

        # accuracy() returns percentages; convert to error fractions.
        Top1_err += 1 - prec1.item() / 100
        Top5_err += 1 - prec5.item() / 100

        if all_iters % args.display_interval == 0:
            # Errors and timing are averaged over the display window.
            printInfo = 'TRAIN Iter {}: lr = {:.6f},\tloss = {:.6f},\t'.format(all_iters, scheduler.get_lr()[0], loss.item()) + \
                        'Top-1 err = {:.6f},\t'.format(Top1_err / args.display_interval) + \
                        'Top-5 err = {:.6f},\t'.format(Top5_err / args.display_interval) + \
                        'data_time = {:.6f},\ttrain_time = {:.6f}'.format(data_time, (time.time() - t1) / args.display_interval)
            logging.info(printInfo)
            t1 = time.time()
            Top1_err, Top5_err = 0.0, 0.0

        if all_iters % args.save_interval == 0:
            save_checkpoint({
                'state_dict': model.state_dict(),
                }, all_iters)

    return all_iters

def validate(model, device, args, *, all_iters=None):
    """Evaluate the model on 250 validation batches and log loss and
    top-1/top-5 error. `all_iters` is only used for the log line."""
    objs = AvgrageMeter()
    top1 = AvgrageMeter()
    top5 = AvgrageMeter()

    loss_function = args.loss_function
    val_dataprovider = args.val_dataprovider

    model.eval()
    # 250 batches; with the val loader's batch_size=200 this covers
    # 50,000 images (the full ImageNet validation set).
    max_val_iters = 250
    t1  = time.time()
    with torch.no_grad():
        for _ in range(1, max_val_iters + 1):
            data, target = val_dataprovider.next()
            target = target.type(torch.LongTensor)
            data, target = data.to(device), target.to(device)

            output = model(data)
            loss = loss_function(output, target)

            # Accumulate sample-weighted averages of loss and accuracy.
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            n = data.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

    logInfo = 'TEST Iter {}: loss = {:.6f},\t'.format(all_iters, objs.avg) + \
              'Top-1 err = {:.6f},\t'.format(1 - top1.avg / 100) + \
              'Top-5 err = {:.6f},\t'.format(1 - top5.avg / 100) + \
              'val_time = {:.6f}'.format(time.time() - t1)
    logging.info(logInfo)

def load_checkpoint(net, checkpoint):
    """Load weights into `net`, accepting either a raw state dict or a
    ``{'state_dict': ...}`` wrapper, and normalizing every key to the
    'module.'-prefixed form expected by nn.DataParallel-wrapped models."""
    from collections import OrderedDict

    if 'state_dict' in checkpoint:
        checkpoint = dict(checkpoint['state_dict'])
    remapped = OrderedDict(
        (key if key.startswith('module.') else 'module.' + key, value)
        for key, value in checkpoint.items()
    )
    net.load_state_dict(remapped, strict=True)

# Script entry point: run training (or --eval) when executed directly.
if __name__ == "__main__":
    main()


================================================
FILE: ACON/ShuffleNetV2_ACON/utils.py
================================================
import os
import re
import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy with label smoothing.

    The one-hot target is blended with a uniform distribution:
    ``(1 - epsilon) * one_hot + epsilon / num_classes``.
    """

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """Return the smoothed cross-entropy of `inputs` (N x C logits)
        against integer class indices `targets` (N,)."""
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-smoothed * log_probs).mean(0).sum()


class AvgrageMeter(object):
    """Tracks the latest value and the running (weighted) average of a
    scalar metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.avg = 0
        self.sum = 0
        self.cnt = 0
        self.val = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the average."""
        self.val = val
        self.cnt = self.cnt + n
        self.sum = self.sum + val * n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy in percent for each k in `topk`.

    `output` is (N, C) class scores, `target` is (N,) integer labels.
    Returns a list of 0-dim tensors, one per requested k.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # (N, maxk) indices of the highest-scoring classes, transposed to
    # (maxk, N) so row r holds every sample's rank-r prediction.
    _, predictions = output.topk(maxk, 1, True, True)
    predictions = predictions.t()
    hits = predictions.eq(target.view(1, -1).expand_as(predictions))

    return [
        hits[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size)
        for k in topk
    ]


def save_checkpoint(state, iters, tag=''):
    """Serialize `state` to ``./models/{tag}checkpoint-{iters:06}.pth.tar``,
    creating the directory on first use."""
    os.makedirs("./models", exist_ok=True)
    filename = "./models/{}checkpoint-{:06}.pth.tar".format(tag, iters)
    torch.save(state, filename)

def get_lastest_model():
    """Return ``(path, iteration)`` of the lexicographically last file in
    ./models, or ``(None, 0)`` when no checkpoint exists yet."""
    if not os.path.exists('./models'):
        os.mkdir('./models')
    checkpoints = sorted(os.listdir('./models/'))
    if not checkpoints:
        return None, 0
    newest = checkpoints[-1]
    # The iteration count is the first run of digits in the filename
    # (checkpoints are zero-padded to 6 digits, so the sort is stable).
    iters = int(re.findall(r'\d+', newest)[0])
    return './models/' + newest, iters


def get_parameters(model):
    """Split parameters into two optimizer groups: multi-dimensional
    'weight' tensors get weight decay; everything else (biases, BN
    scales and other 1-D parameters) gets weight_decay=0."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if 'weight' in name and param.dim() > 1:
            decay.append(param)
        else:
            no_decay.append(param)
    # Sanity check: every parameter landed in exactly one group.
    assert len(list(model.parameters())) == len(decay) + len(no_decay)
    return [dict(params=decay), dict(params=no_decay, weight_decay=0.)]


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2021 nmaac 

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: MetaACON/ResNet_MetaACON/resnet_metaacon.py
================================================
import torch
from torch import Tensor
import torch.nn as nn
from typing import Type, Any, Callable, Union, List, Optional

import sys
sys.path.insert(0,'../..')
from acon import MetaAconC


__all__ = ['ResNet', 'resnet50_metaacon', 'resnet101_metaacon', 'resnet152_metaacon']


model_urls = {}


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution; padding equals the dilation rate so spatial size
    is preserved at stride 1.

    NOTE(review): bias=True here although a norm layer typically follows;
    torchvision ResNets use bias=False — confirm this is intentional
    (checkpoint compatibility depends on it).
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=True,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 (pointwise) convolution, with bias (see conv3x3 note)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=True,
    )


class Bottleneck_MetaACON(nn.Module):
    """ResNet bottleneck block with the ReLU after the 3x3 convolution
    (self.conv2) replaced by Meta-ACON, per "Activate or Not: Learning
    Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.

    NOTE(review): the stride is applied at the first 1x1 convolution
    (self.conv1) — the original ResNet V1 placement from "Deep residual
    learning for image recognition" <https://arxiv.org/abs/1512.03385> —
    not the V1.5 variant (stride at conv2) used by torchvision, despite
    what the old comment said. Confirm before swapping checkpoints with
    torchvision models.
    """

    # The block's output has `planes * expansion` channels.
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(Bottleneck_MetaACON, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv1 and self.downsample downsample the input when stride != 1.
        self.conv1 = conv1x1(inplanes, width, stride)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, 1, groups, dilation)
        self.bn2 = norm_layer(width)
        # Meta-ACON activation replaces the ReLU after conv2/bn2.
        self.acon = MetaAconC(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """1x1 -> 3x3 (Meta-ACON) -> 1x1 plus the (optionally projected)
        identity shortcut, then a final ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.acon(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            # Project the identity when stride/channels change.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    """ResNet backbone built from Bottleneck_MetaACON blocks (ReLU after
    the 3x3 conv replaced by Meta-ACON). Standard torchvision-style
    constructor and forward."""

    def __init__(
        self,
        block: Type[Union[Bottleneck_MetaACON]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max-pool.
        # NOTE(review): bias=True although BN follows (torchvision uses
        # bias=False) — confirm intentional.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=True)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages; stages 2-4 halve resolution (or dilate instead).
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck_MetaACON):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block: Type[Union[Bottleneck_MetaACON]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        """Build one stage of `blocks` bottlenecks; only the first block
        strides (or dilates) and may need a projection shortcut."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Keep resolution; increase the receptive field instead.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut to match the main branch's shape.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return (N, num_classes) logits for an (N, 3, H, W) batch."""
        return self._forward_impl(x)


def _resnet(
    arch: str,
    block: Type[Union[Bottleneck_MetaACON]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    """Construct a ResNet from `block` and the per-stage `layers` counts.

    `arch`, `pretrained` and `progress` are accepted for torchvision API
    parity but are ignored: model_urls is empty, so no pretrained
    weights are available to download.
    """
    return ResNet(block, layers, **kwargs)


def resnet50_metaacon(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build a ResNet-50 with Meta-ACON activations, from
    `"Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet(
        'resnet50_metaacon', Bottleneck_MetaACON, [3, 4, 6, 3],
        pretrained, progress, **kwargs)

def resnet101_metaacon(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build a ResNet-101 with Meta-ACON activations, from
    `"Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet(
        'resnet101_metaacon', Bottleneck_MetaACON, [3, 4, 23, 3],
        pretrained, progress, **kwargs)

def resnet152_metaacon(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build a ResNet-152 with Meta-ACON activations, from
    `"Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet(
        'resnet152_metaacon', Bottleneck_MetaACON, [3, 8, 36, 3],
        pretrained, progress, **kwargs)



================================================
FILE: MetaACON/ResNet_MetaACON/train.py
================================================
import os
import sys
import torch
import argparse
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import cv2
import numpy as np
import PIL
from PIL import Image
import time
import logging
import argparse
from resnet_metaacon import resnet50_metaacon
from utils import accuracy, AvgrageMeter, CrossEntropyLabelSmooth, save_checkpoint, get_lastest_model, get_parameters

class OpencvResize(object):
    """Resize a PIL image so that its shorter side equals ``size``,
    using OpenCV bilinear interpolation (to match cv2-based pipelines
    exactly, rather than PIL's resampling)."""

    def __init__(self, size=256):
        self.size = size

    def __call__(self, img):
        assert isinstance(img, PIL.Image.Image)
        # PIL gives RGB; OpenCV works in BGR, so flip the channel axis.
        arr = np.ascontiguousarray(np.asarray(img)[:, :, ::-1])
        h, w, _ = arr.shape
        # Scale the shorter side to self.size, preserving aspect ratio
        # (+0.5 rounds to nearest integer).
        if h < w:
            new_size = (int(self.size / h * w + 0.5), self.size)
        else:
            new_size = (self.size, int(self.size / w * h + 0.5))
        arr = cv2.resize(arr, new_size, interpolation=cv2.INTER_LINEAR)
        # Back to RGB and to a PIL image for downstream transforms.
        arr = np.ascontiguousarray(arr[:, :, ::-1])
        return Image.fromarray(arr)

class ToBGRTensor(object):
    """Convert an RGB PIL image or HxWx3 ndarray to a float BGR tensor
    in CHW layout. No mean/std normalization is applied."""

    def __call__(self, img):
        assert isinstance(img, (np.ndarray, PIL.Image.Image))
        if isinstance(img, PIL.Image.Image):
            img = np.asarray(img)
        bgr = img[:, :, ::-1]  # RGB -> BGR
        # HWC -> CHW; ascontiguousarray is required before from_numpy
        # because the flips above produce negative strides.
        chw = np.ascontiguousarray(np.transpose(bgr, [2, 0, 1]))
        return torch.from_numpy(chw).float()

class DataIterator(object):
    """Endless iterator over a DataLoader: when the loader is exhausted
    it is restarted, so ``next()`` never raises at an epoch boundary."""

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = enumerate(self.dataloader)

    def next(self):
        """Return the next ``(data, target)`` pair, wrapping around when
        the current pass over the dataloader ends.

        Fix: the original caught bare ``Exception``, which silently
        swallowed real failures (e.g. a crashed worker or a corrupt
        sample) and retried; only ``StopIteration`` signals exhaustion.
        """
        try:
            _, data = next(self.iterator)
        except StopIteration:
            # Epoch finished: start a fresh pass over the dataloader.
            self.iterator = enumerate(self.dataloader)
            _, data = next(self.iterator)
        return data[0], data[1]

def get_args():
    """Parse command-line options for ResNet-50 Meta-ACON training."""
    ap = argparse.ArgumentParser("ResNet")

    # Evaluation-only mode.
    ap.add_argument('--eval', default=False, action='store_true')
    ap.add_argument('--eval-resume', type=str, default='./res50.metaacon.pth', help='path for eval model')

    # Optimization hyper-parameters.
    ap.add_argument('--batch-size', type=int, default=256, help='batch size')
    ap.add_argument('--total-iters', type=int, default=600000, help='total iters')
    ap.add_argument('--learning-rate', type=float, default=0.1, help='init learning rate')
    ap.add_argument('--momentum', type=float, default=0.9, help='momentum')
    ap.add_argument('--weight-decay', type=float, default=1e-4, help='weight decay')
    ap.add_argument('--save', type=str, default='./models', help='path for saving trained models')

    # Bookkeeping intervals (in iterations).
    # NOTE(review): type=bool is an argparse pitfall — any non-empty
    # string (including "False") parses as True; confirm before relying
    # on `--auto-continue False`.
    ap.add_argument('--auto-continue', type=bool, default=True, help='auto continue')
    ap.add_argument('--display-interval', type=int, default=20, help='display interval')
    ap.add_argument('--val-interval', type=int, default=50000, help='val interval')
    ap.add_argument('--save-interval', type=int, default=50000, help='save interval')

    # Dataset locations.
    ap.add_argument('--train-dir', type=str, default='data/train', help='path to training dataset')
    ap.add_argument('--val-dir', type=str, default='data/val', help='path to validation dataset')

    return ap.parse_args()

def main():
    """Entry point: set up logging, data pipelines, model, optimizer and
    LR schedule, then run the train/validate loop (or a single
    evaluation pass when --eval is given)."""
    args = get_args()

    # Log to stdout and to a timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    # Training data: random crop/flip/jitter, converted to BGR tensors.
    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ])
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=1, pin_memory=use_gpu)
    train_dataprovider = DataIterator(train_loader)

    # Validation data: resize shorter side to 256, center-crop 224.
    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(args.val_dir, transforms.Compose([
            OpencvResize(256),
            transforms.CenterCrop(224),
            ToBGRTensor(),
        ])),
        batch_size=200, shuffle=False,
        num_workers=1, pin_memory=use_gpu
    )
    val_dataprovider = DataIterator(val_loader)
    print('load data successfully')

    model = resnet50_metaacon()
    # Weight decay is applied only to conv/fc weights (see get_parameters).
    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # epsilon=0.0: label smoothing disabled, reduces to plain cross-entropy.
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.0)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    # Linear LR decay from the initial LR to 0 over total_iters steps.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)

    model = model.to(device)

    all_iters = 0
    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            # Fast-forward the scheduler so the LR matches the resumed iteration.
            for i in range(iters):
                scheduler.step()

    # Stash shared training state on the args namespace for train()/validate().
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    if args.eval:
        # Evaluation-only path: load the given checkpoint, validate, exit.
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')
            load_checkpoint(model, checkpoint)
            validate(model, device, args, all_iters=all_iters)
        exit(0)

    # Alternate training chunks of val_interval iterations with validation.
    while all_iters < args.total_iters:
        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)
        validate(model, device, args, all_iters=all_iters)
    validate(model, device, args, all_iters=all_iters)
    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')

def adjust_bn_momentum(model, iters):
    """Set every BatchNorm2d momentum to 1/iters, so running statistics
    average uniformly over all batches seen so far."""
    bn_momentum = 1 / iters
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.momentum = bn_momentum

def train(model, device, args, *, val_interval, bn_process=False, all_iters=None):
    """Run `val_interval` training iterations and return the updated
    global iteration counter `all_iters`.

    Expects optimizer/loss/scheduler/dataprovider to have been attached
    to `args` by main(). When `bn_process` is True, BatchNorm momentum
    is re-tuned each iteration via adjust_bn_momentum().
    """

    optimizer = args.optimizer
    loss_function = args.loss_function
    scheduler = args.scheduler
    train_dataprovider = args.train_dataprovider

    t1 = time.time()
    # Running top-1/top-5 error sums, reset after each display window.
    Top1_err, Top5_err = 0.0, 0.0
    model.train()
    for iters in range(1, val_interval + 1):
        # NOTE(review): scheduler.step() is called before optimizer.step();
        # recent PyTorch versions warn about this ordering — confirm the
        # intended LR sequence for the target torch version.
        scheduler.step()
        if bn_process:
            adjust_bn_momentum(model, iters)

        all_iters += 1
        d_st = time.time()
        data, target = train_dataprovider.next()
        target = target.type(torch.LongTensor)
        data, target = data.to(device), target.to(device)
        data_time = time.time() - d_st

        # Standard SGD step: forward, loss, backward, update.
        output = model(data)
        loss = loss_function(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        prec1, prec5 = accuracy(output, target, topk=(1, 5))

        # accuracy() returns percentages; convert to error fractions.
        Top1_err += 1 - prec1.item() / 100
        Top5_err += 1 - prec5.item() / 100

        if all_iters % args.display_interval == 0:
            # Errors and timing are averaged over the display window.
            printInfo = 'TRAIN Iter {}: lr = {:.6f},\tloss = {:.6f},\t'.format(all_iters, scheduler.get_lr()[0], loss.item()) + \
                        'Top-1 err = {:.6f},\t'.format(Top1_err / args.display_interval) + \
                        'Top-5 err = {:.6f},\t'.format(Top5_err / args.display_interval) + \
                        'data_time = {:.6f},\ttrain_time = {:.6f}'.format(data_time, (time.time() - t1) / args.display_interval)
            logging.info(printInfo)
            t1 = time.time()
            Top1_err, Top5_err = 0.0, 0.0

        if all_iters % args.save_interval == 0:
            save_checkpoint({
                'state_dict': model.state_dict(),
                }, all_iters)

    return all_iters

def validate(model, device, args, *, all_iters=None):
    """Evaluate the model on a fixed number of validation batches and log
    average loss, top-1 and top-5 error.

    Args:
        model: network to evaluate (switched to eval mode here).
        device: torch.device that inputs/targets are moved to.
        args: namespace carrying loss_function and val_dataprovider.
        all_iters: global iteration count, used only for logging.
    """
    objs = AvgrageMeter()
    top1 = AvgrageMeter()
    top5 = AvgrageMeter()

    loss_function = args.loss_function
    val_dataprovider = args.val_dataprovider

    model.eval()
    # Fixed number of validation batches drawn from the cycling provider.
    max_val_iters = 250
    t1  = time.time()
    with torch.no_grad():
        for _ in range(1, max_val_iters + 1):
            data, target = val_dataprovider.next()
            target = target.type(torch.LongTensor)
            data, target = data.to(device), target.to(device)

            output = model(data)
            loss = loss_function(output, target)

            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            n = data.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

    logInfo = 'TEST Iter {}: loss = {:.6f},\t'.format(all_iters, objs.avg) + \
              'Top-1 err = {:.6f},\t'.format(1 - top1.avg / 100) + \
              'Top-5 err = {:.6f},\t'.format(1 - top5.avg / 100) + \
              'val_time = {:.6f}'.format(time.time() - t1)
    logging.info(logInfo)

def load_checkpoint(net, checkpoint):
    """Load a checkpoint into `net`, prefixing keys with 'module.' when
    missing so plain checkpoints fit a DataParallel-wrapped model."""
    from collections import OrderedDict

    state = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    remapped = OrderedDict(
        (key if key.startswith('module.') else 'module.' + key, value)
        for key, value in state.items()
    )
    net.load_state_dict(remapped, strict=True)

if __name__ == "__main__":
    # Script entry point: parse args, build data/model and train/evaluate.
    main()


================================================
FILE: MetaACON/ResNet_MetaACON/utils.py
================================================
import os
import re
import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy loss over smoothed one-hot targets.

    The true class gets probability 1 - epsilon, and epsilon is spread
    uniformly over all `num_classes` classes.
    """

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # Build the smoothed target distribution from integer labels.
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)
        smoothed = one_hot * (1 - self.epsilon) + self.epsilon / self.num_classes
        # Mean over the batch, summed over classes.
        return (-smoothed * log_probs).mean(0).sum()


class AvgrageMeter(object):
    """Tracks the most recent value and a count-weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear all accumulated statistics.
        self.avg = 0
        self.sum = 0
        self.cnt = 0
        self.val = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.cnt = self.cnt + n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (percent, as tensors) for logits vs labels."""
    maxk = max(topk)
    batch_size = target.size(0)

    # (batch, maxk) indices of the highest-scoring classes, transposed so
    # row r holds the rank-r predictions for the whole batch.
    _, top_idx = output.topk(maxk, 1, True, True)
    top_idx = top_idx.t()
    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

    results = []
    for k in topk:
        k_hits = hits[:k].reshape(-1).float().sum(0)
        results.append(k_hits.mul_(100.0 / batch_size))
    return results


def save_checkpoint(state, iters, tag=''):
    """Serialize `state` to ./models/<tag>checkpoint-<iters>.pth.tar.

    Args:
        state: picklable object, typically {'state_dict': ...}.
        iters: iteration count, zero-padded to six digits in the file name.
        tag: optional file-name prefix.
    """
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs("./models", exist_ok=True)
    # Join directory and file name properly; the original passed one
    # pre-joined string to os.path.join, making the call a no-op.
    filename = os.path.join("./models", "{}checkpoint-{:06}.pth.tar".format(tag, iters))
    torch.save(state, filename)

def get_lastest_model():
    """Return (path, iters) of the newest checkpoint in ./models, or (None, 0).

    File names embed a zero-padded iteration count, so a lexicographic sort
    puts the newest checkpoint last.
    """
    if not os.path.exists('./models'):
        os.mkdir('./models')
    candidates = os.listdir('./models/')
    if candidates == []:
        return None, 0
    candidates.sort()
    newest = candidates[-1]
    # The first digit group in the file name is the iteration count.
    numbers = re.findall(r'\d+', newest)
    return './models/' + newest, int(numbers[0])


def get_parameters(model):
    """Split parameters into [decayed, non-decayed] optimizer param groups.

    Multi-dimensional 'weight' tensors receive weight decay; everything else
    (biases, 1-D BatchNorm affine parameters, ...) is exempted.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if 'weight' in name and param.dim() > 1:
            decay.append(param)
        else:
            no_decay.append(param)
    # Sanity check: every parameter landed in exactly one group.
    assert len(list(model.parameters())) == len(decay) + len(no_decay)
    return [dict(params=decay), dict(params=no_decay, weight_decay=0.)]


================================================
FILE: MetaACON/ShuffleNet_MetaACON/network.py
================================================
import torch
import torch.nn as nn

import sys
sys.path.insert(0,'../..')
from acon import MetaAconC

class ShuffleV2Block_MetaACON(nn.Module):
    """ShuffleNetV2 basic block using Meta-ACON activations.

    stride == 1: the input is split in half along channels; one half is an
    identity shortcut, the other passes through `branch_main`.
    stride == 2: both a projection branch and the main branch downsample the
    full input, and their outputs are concatenated.
    """
    def __init__(self, inp, oup, mid_channels, *, ksize, stride, r=16):
        super(ShuffleV2Block_MetaACON, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        self.mid_channels = mid_channels
        self.ksize = ksize
        pad = ksize // 2
        self.pad = pad
        self.inp = inp

        # Channels produced by branch_main so that cat(shortcut, main) == oup.
        outputs = oup - inp

        branch_main = [
            # pw
            nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=True),
            nn.BatchNorm2d(mid_channels),
            MetaAconC(mid_channels, r=r),
            # dw
            nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=True),
            nn.BatchNorm2d(mid_channels),
            # pw-linear
            nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=True),
            nn.BatchNorm2d(outputs),
            MetaAconC(outputs, r=r),
        ]
        self.branch_main = nn.Sequential(*branch_main)

        if stride == 2:
            branch_proj = [
                # dw
                nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=True),
                nn.BatchNorm2d(inp),
                # pw-linear
                nn.Conv2d(inp, inp, 1, 1, 0, bias=True),
                nn.BatchNorm2d(inp),
                MetaAconC(inp, r=r),
            ]
            self.branch_proj = nn.Sequential(*branch_proj)
        else:
            self.branch_proj = None

    def forward(self, old_x):
        if self.stride==1:
            # Channel-split: x_proj is the shortcut, x feeds branch_main.
            x_proj, x = self.channel_shuffle(old_x)
            return torch.cat((x_proj, self.branch_main(x)), 1)
        elif self.stride==2:
            # Downsampling block: both branches see the full input.
            x_proj = old_x
            x = old_x
            return torch.cat((self.branch_proj(x_proj), self.branch_main(x)), 1)

    def channel_shuffle(self, x):
        """Interleave channels, then return the tensor split into two halves."""
        batchsize, num_channels, height, width = x.data.size()
        assert (num_channels % 4 == 0)
        # Pair adjacent channels, transpose the pairing, then regroup so
        # even- and odd-indexed channels end up in separate halves.
        x = x.reshape(batchsize * num_channels // 2, 2, height * width)
        x = x.permute(1, 0, 2)
        x = x.reshape(2, -1, num_channels // 2, height, width)
        return x[0], x[1]


class ShuffleNetV2_MetaACON(nn.Module):
    """ShuffleNetV2 backbone with Meta-ACON activations.

    Architecture: 3x3 stem conv -> maxpool -> three stages of
    ShuffleV2Block_MetaACON -> 1x1 conv -> global average pool -> classifier.
    `model_size` selects the per-stage channel widths.
    """
    def __init__(self, input_size=224, n_class=1000, model_size='1.5x'):
        super(ShuffleNetV2_MetaACON, self).__init__()
        print('model size is ', model_size)

        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        # r is the channel-reduction ratio used inside MetaAconC.
        self.r = 16
        if model_size == '0.5x':
            self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
            self.r = 8
        elif model_size == '1.0x':
            self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
        elif model_size == '1.5x':
            self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
        elif model_size == '2.0x':
            self.stage_out_channels = [-1, 24, 244, 488, 976, 2048]
        else:
            raise NotImplementedError

        # building first layer
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, input_channel, 3, 2, 1, bias=True),
            nn.BatchNorm2d(input_channel),
            MetaAconC(input_channel, r=self.r),
        )

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stages: the first block of each stage downsamples (stride 2) and
        # sees full input channels; later blocks are stride 1 and, due to the
        # channel split inside the block, take input_channel // 2.
        self.features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage+2]

            for i in range(numrepeat):
                if i == 0:
                    self.features.append(ShuffleV2Block_MetaACON(input_channel, output_channel,
                                                mid_channels=output_channel // 2, ksize=3, stride=2, r=self.r))
                else:
                    self.features.append(ShuffleV2Block_MetaACON(input_channel // 2, output_channel,
                                                mid_channels=output_channel // 2, ksize=3, stride=1, r=self.r))

                input_channel = output_channel

        self.features = nn.Sequential(*self.features)

        self.conv_last = nn.Sequential(
            nn.Conv2d(input_channel, self.stage_out_channels[-1], 1, 1, 0, bias=True),
            nn.BatchNorm2d(self.stage_out_channels[-1]),
            MetaAconC(self.stage_out_channels[-1], r=self.r),
        )
        # 7x7 global pool matches the 224 input after five stride-2 layers.
        self.globalpool = nn.AvgPool2d(7)
        if self.model_size == '2.0x':
            self.dropout = nn.Dropout(0.2)
        self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class, bias=True))
        self._initialize_weights()

    def forward(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.conv_last(x)

        x = self.globalpool(x)
        if self.model_size == '2.0x':
            x = self.dropout(x)
        x = x.contiguous().view(-1, self.stage_out_channels[-1])
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Custom initialization: Gaussian conv/linear weights, BN reset."""
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                # Stem conv gets a smaller std than the 1/fan_in used elsewhere.
                if 'first' in name:
                    nn.init.normal_(m.weight, 0, 0.01)
                else:
                    nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)



================================================
FILE: MetaACON/ShuffleNet_MetaACON/train.py
================================================
import os
import sys
import torch
import argparse
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import cv2
import numpy as np
import PIL
from PIL import Image
import time
import logging
import argparse
from network import ShuffleNetV2_MetaACON
from utils import accuracy, AvgrageMeter, CrossEntropyLabelSmooth, save_checkpoint, get_lastest_model, get_parameters

class OpencvResize(object):
    """Resize a PIL image with OpenCV so its shorter side equals `size`.

    The array is flipped RGB->BGR for cv2 and flipped back afterwards, so the
    returned PIL image stays RGB.
    """

    def __init__(self, size=256):
        self.size = size

    def __call__(self, img):
        assert isinstance(img, PIL.Image.Image)
        arr = np.ascontiguousarray(np.asarray(img)[:, :, ::-1])  # RGB -> BGR
        H, W, _ = arr.shape
        # cv2.resize expects (width, height); scale so min(H, W) == self.size.
        if H < W:
            target_size = (int(self.size / H * W + 0.5), self.size)
        else:
            target_size = (self.size, int(self.size / W * H + 0.5))
        arr = cv2.resize(arr, target_size, interpolation=cv2.INTER_LINEAR)
        arr = np.ascontiguousarray(arr[:, :, ::-1])  # BGR -> RGB
        return Image.fromarray(arr)

class ToBGRTensor(object):
    """Convert an RGB image (PIL or HxWx3 ndarray) to a float CHW BGR tensor."""

    def __call__(self, img):
        assert isinstance(img, (np.ndarray, PIL.Image.Image))
        if isinstance(img, PIL.Image.Image):
            img = np.asarray(img)
        bgr = img[:, :, ::-1]                # RGB -> BGR
        chw = np.transpose(bgr, [2, 0, 1])   # HWC -> CHW
        return torch.from_numpy(np.ascontiguousarray(chw)).float()

class DataIterator(object):
    """Endless iterator over a DataLoader that restarts when exhausted."""

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = enumerate(self.dataloader)

    def next(self):
        """Return the next (data, target) pair, cycling past the end."""
        try:
            _, data = next(self.iterator)
        except StopIteration:
            # Only catch exhaustion. The original bare `except Exception`
            # also swallowed real data-loading errors and silently retried.
            self.iterator = enumerate(self.dataloader)
            _, data = next(self.iterator)
        return data[0], data[1]

def get_args():
    """Parse and return command-line options for ShuffleNetV2_MetaACON training."""
    arg_parser = argparse.ArgumentParser("ShuffleNetV2_MetaACON")
    # Evaluation-only mode
    arg_parser.add_argument('--eval', default=False, action='store_true')
    arg_parser.add_argument('--eval-resume', type=str, default='./shufflenetv2.0.5.metaacon.pth', help='path for eval model')
    # Optimization hyper-parameters
    arg_parser.add_argument('--batch-size', type=int, default=1024, help='batch size')
    arg_parser.add_argument('--total-iters', type=int, default=300000, help='total iters')
    arg_parser.add_argument('--learning-rate', type=float, default=0.5, help='init learning rate')
    arg_parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    arg_parser.add_argument('--weight-decay', type=float, default=4e-5, help='weight decay')
    arg_parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')
    arg_parser.add_argument('--label-smooth', type=float, default=0.1, help='label smoothing')
    # Bookkeeping intervals
    arg_parser.add_argument('--auto-continue', type=bool, default=True, help='auto continue')
    arg_parser.add_argument('--display-interval', type=int, default=20, help='display interval')
    arg_parser.add_argument('--val-interval', type=int, default=10000, help='val interval')
    arg_parser.add_argument('--save-interval', type=int, default=10000, help='save interval')
    # Model and data locations
    arg_parser.add_argument('--model-size', type=str, default='0.5x', choices=['0.5x', '1.0x', '1.5x', '2.0x'], help='size of the model')
    arg_parser.add_argument('--train-dir', type=str, default='data/train', help='path to training dataset')
    arg_parser.add_argument('--val-dir', type=str, default='data/val', help='path to validation dataset')
    return arg_parser.parse_args()

def main():
    """Entry point: set up logging, data, model and optimizer, then either
    evaluate a checkpoint (--eval) or train until --total-iters."""
    args = get_args()

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    # Training data: random crop/jitter/flip, converted to BGR tensors.
    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ])
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=1, pin_memory=use_gpu)
    train_dataprovider = DataIterator(train_loader)

    # Validation data: resize-256 / center-crop-224.
    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(args.val_dir, transforms.Compose([
            OpencvResize(256),
            transforms.CenterCrop(224),
            ToBGRTensor(),
        ])),
        batch_size=200, shuffle=False,
        num_workers=1, pin_memory=use_gpu
    )
    val_dataprovider = DataIterator(val_loader)
    print('load data successfully')

    model = ShuffleNetV2_MetaACON(model_size=args.model_size)

    # SGD with weight decay applied only to multi-dim weights (get_parameters).
    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    # Linear learning-rate decay down to 0 at total_iters.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)

    model = model.to(device)

    # Optionally resume from the newest checkpoint in ./models and fast-forward
    # the scheduler to the resumed iteration.
    all_iters = 0
    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            for i in range(iters):
                scheduler.step()

    # Stash runtime objects on args so train()/validate() can reach them.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')
            load_checkpoint(model, checkpoint)
            validate(model, device, args, all_iters=all_iters)
        exit(0)

    # Alternate training and validation until the iteration budget is spent.
    while all_iters < args.total_iters:
        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)
        validate(model, device, args, all_iters=all_iters)
    validate(model, device, args, all_iters=all_iters)
    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')

def adjust_bn_momentum(model, iters):
    """Set every BatchNorm2d momentum to 1/iters (cumulative moving average)."""
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.momentum = 1.0 / iters

def train(model, device, args, *, val_interval, bn_process=False, all_iters=None):
    """Run `val_interval` training iterations and return the updated counter.

    Args:
        model: network to train, already moved to `device`.
        device: torch.device that inputs/targets are moved to.
        args: namespace carrying optimizer, loss_function, scheduler,
            train_dataprovider, display_interval and save_interval.
        val_interval: number of iterations to run in this call.
        bn_process: when True, reset BatchNorm momentum to 1/iter each step.
        all_iters: global iteration counter; incremented here and returned.
    """

    optimizer = args.optimizer
    loss_function = args.loss_function
    scheduler = args.scheduler
    train_dataprovider = args.train_dataprovider

    t1 = time.time()
    Top1_err, Top5_err = 0.0, 0.0
    model.train()
    for iters in range(1, val_interval + 1):
        # NOTE(review): scheduler.step() runs before optimizer.step(); recent
        # PyTorch expects the opposite order — kept as-is to preserve the
        # original learning-rate schedule.
        scheduler.step()
        if bn_process:
            adjust_bn_momentum(model, iters)

        all_iters += 1
        d_st = time.time()
        data, target = train_dataprovider.next()
        target = target.type(torch.LongTensor)
        data, target = data.to(device), target.to(device)
        data_time = time.time() - d_st

        output = model(data)
        loss = loss_function(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        prec1, prec5 = accuracy(output, target, topk=(1, 5))

        # accuracy() returns percentages; accumulate error rates instead.
        Top1_err += 1 - prec1.item() / 100
        Top5_err += 1 - prec5.item() / 100

        if all_iters % args.display_interval == 0:
            printInfo = 'TRAIN Iter {}: lr = {:.6f},\tloss = {:.6f},\t'.format(all_iters, scheduler.get_lr()[0], loss.item()) + \
                        'Top-1 err = {:.6f},\t'.format(Top1_err / args.display_interval) + \
                        'Top-5 err = {:.6f},\t'.format(Top5_err / args.display_interval) + \
                        'data_time = {:.6f},\ttrain_time = {:.6f}'.format(data_time, (time.time() - t1) / args.display_interval)
            logging.info(printInfo)
            t1 = time.time()
            Top1_err, Top5_err = 0.0, 0.0

        if all_iters % args.save_interval == 0:
            save_checkpoint({
                'state_dict': model.state_dict(),
                }, all_iters)

    return all_iters

def validate(model, device, args, *, all_iters=None):
    """Evaluate the model on a fixed number of validation batches and log
    average loss, top-1 and top-5 error.

    Args:
        model: network to evaluate (switched to eval mode here).
        device: torch.device that inputs/targets are moved to.
        args: namespace carrying loss_function and val_dataprovider.
        all_iters: global iteration count, used only for logging.
    """
    objs = AvgrageMeter()
    top1 = AvgrageMeter()
    top5 = AvgrageMeter()

    loss_function = args.loss_function
    val_dataprovider = args.val_dataprovider

    model.eval()
    # Fixed number of validation batches drawn from the cycling provider.
    max_val_iters = 250
    t1  = time.time()
    with torch.no_grad():
        for _ in range(1, max_val_iters + 1):
            data, target = val_dataprovider.next()
            target = target.type(torch.LongTensor)
            data, target = data.to(device), target.to(device)

            output = model(data)
            loss = loss_function(output, target)

            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            n = data.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

    logInfo = 'TEST Iter {}: loss = {:.6f},\t'.format(all_iters, objs.avg) + \
              'Top-1 err = {:.6f},\t'.format(1 - top1.avg / 100) + \
              'Top-5 err = {:.6f},\t'.format(1 - top5.avg / 100) + \
              'val_time = {:.6f}'.format(time.time() - t1)
    logging.info(logInfo)

def load_checkpoint(net, checkpoint):
    """Load a checkpoint into `net`, adding the 'module.' prefix to keys that
    lack it so plain checkpoints fit a DataParallel-wrapped model."""
    from collections import OrderedDict

    state = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    remapped = OrderedDict(
        (key if key.startswith('module.') else 'module.' + key, value)
        for key, value in state.items()
    )
    net.load_state_dict(remapped, strict=True)

if __name__ == "__main__":
    # Script entry point: parse args, build data/model and train/evaluate.
    main()


================================================
FILE: MetaACON/ShuffleNet_MetaACON/utils.py
================================================
import os
import re
import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy loss over smoothed one-hot targets.

    The true class gets probability 1 - epsilon, and epsilon is spread
    uniformly over all `num_classes` classes.
    """

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # Build the smoothed target distribution from integer labels.
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)
        smoothed = one_hot * (1 - self.epsilon) + self.epsilon / self.num_classes
        # Mean over the batch, summed over classes.
        return (-smoothed * log_probs).mean(0).sum()


class AvgrageMeter(object):
    """Tracks the most recent value and a count-weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear all accumulated statistics.
        self.avg = 0
        self.sum = 0
        self.cnt = 0
        self.val = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.cnt = self.cnt + n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (percent, as tensors) for logits vs labels."""
    maxk = max(topk)
    batch_size = target.size(0)

    # (batch, maxk) indices of the highest-scoring classes, transposed so
    # row r holds the rank-r predictions for the whole batch.
    _, top_idx = output.topk(maxk, 1, True, True)
    top_idx = top_idx.t()
    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

    results = []
    for k in topk:
        k_hits = hits[:k].reshape(-1).float().sum(0)
        results.append(k_hits.mul_(100.0 / batch_size))
    return results


def save_checkpoint(state, iters, tag=''):
    """Serialize `state` to ./models/<tag>checkpoint-<iters>.pth.tar.

    Args:
        state: picklable object, typically {'state_dict': ...}.
        iters: iteration count, zero-padded to six digits in the file name.
        tag: optional file-name prefix.
    """
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs("./models", exist_ok=True)
    # Join directory and file name properly; the original passed one
    # pre-joined string to os.path.join, making the call a no-op.
    filename = os.path.join("./models", "{}checkpoint-{:06}.pth.tar".format(tag, iters))
    torch.save(state, filename)

def get_lastest_model():
    """Return (path, iters) of the newest checkpoint in ./models, or (None, 0).

    File names embed a zero-padded iteration count, so a lexicographic sort
    puts the newest checkpoint last.
    """
    if not os.path.exists('./models'):
        os.mkdir('./models')
    candidates = os.listdir('./models/')
    if candidates == []:
        return None, 0
    candidates.sort()
    newest = candidates[-1]
    # The first digit group in the file name is the iteration count.
    numbers = re.findall(r'\d+', newest)
    return './models/' + newest, int(numbers[0])


def get_parameters(model):
    """Split parameters into [decayed, non-decayed] optimizer param groups.

    Multi-dimensional 'weight' tensors receive weight decay; everything else
    (biases, 1-D BatchNorm affine parameters, ...) is exempted.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if 'weight' in name and param.dim() > 1:
            decay.append(param)
        else:
            no_decay.append(param)
    # Sanity check: every parameter landed in exactly one group.
    assert len(list(model.parameters())) == len(decay) + len(no_decay)
    return [dict(params=decay), dict(params=no_decay, weight_decay=0.)]


================================================
FILE: README.md
================================================

## CVPR 2021 | Activate or Not: Learning Customized Activation.

This repository contains the official Pytorch implementation of the paper [Activate or Not: Learning Customized Activation, CVPR 2021](https://arxiv.org/pdf/2009.04759.pdf).

### ACON

We propose a novel activation function we term the ACON that explicitly learns to activate the neurons or not. 
Below we show the ACON activation function and its first derivatives. β controls how fast the first derivative asymptotes to the upper/lower bounds, which are determined by p1 and p2.


<img src="https://user-images.githubusercontent.com/5032208/113257297-fc76f380-92fc-11eb-9559-39d033baea4c.png" width=90%>

<img src="https://user-images.githubusercontent.com/5032208/113257194-cfc2dc00-92fc-11eb-94a0-f81569bed15e.png" width=90%>

### Training curves
We show the training curves of different activations here.

<img src="https://user-images.githubusercontent.com/5032208/113260052-65ac3600-9300-11eb-8d2f-ef968be1c3a2.png"  width=60%>


### TFNet
To show the effectiveness of the proposed ACON family, we also provide an extremely simple toy funnel network (TFNet) made only of pointwise convolutions and ACON-FReLU operators.

<img src="https://user-images.githubusercontent.com/5032208/113963614-7a3a8200-985c-11eb-8946-65c0bcef0a80.png"  width=60%>




## Main results

The following results are the ImageNet top-1 accuracy relative improvements compared with the ReLU baselines. The relative improvements of Meta-ACON are about twice as much as SENet.

<img src="https://user-images.githubusercontent.com/5032208/113256618-fcc2bf00-92fb-11eb-9b1d-8f0589009a9b.png" width=60%>

The comparison between ReLU, Swish and ACON-C. We show improvements without additional amount of FLOPs and parameters:
| Model             | FLOPs | #Params. | top-1 err. (ReLU) | top-1 err. (Swish) |   top-1 err. (ACON)   |
|-------------------|:-----:|:--------:|:-----------------:|:------------------:|:---------------------:|
| ShuffleNetV2 0.5x |  41M  |   1.4M   |        39.4       |     38.3 (+1.1)    |    **37.0 (+2.4)**    |
| ShuffleNetV2 1.5x |  299M |   3.5M   |        27.4       |     26.8 (+0.6)    |    **26.5 (+0.9)**    |
| ResNet 50         |  3.9G |   25.5M  |        24.0       |     23.5 (+0.5)    |    **23.2 (+0.8)**    |
| ResNet 101        |  7.6G |   44.4M  |        22.8       |     22.7 (+0.1)    |    **21.8 (+1.0)**    |
| ResNet 152        | 11.3G |   60.0M  |        22.3       |     22.2 (+0.1)    |    **21.2 (+1.1)**    |


Next, by adding a negligible amount of FLOPs and parameters, meta-ACON shows significant improvements:
| Model                         | FLOPs | #Params. |        top-1 err.      | 
|-------------------------------|:-----:|:--------:|:----------------------:|
| ShuffleNetV2 0.5x (meta-acon) | 41M   | 1.7M     |   **34.8 (+4.6)**      | 
| ShuffleNetV2 1.5x (meta-acon) | 299M  | 3.9M     |   **24.7 (+2.7)**      | 
| ResNet 50 (meta-acon)         | 3.9G  | 25.7M    |   **22.0 (+2.0)**      | 
| ResNet 101 (meta-acon)        | 7.6G  | 44.8M    |   **21.0 (+1.8)**      | 
| ResNet 152 (meta-acon)        | 11.3G | 60.5M    |   **20.5 (+1.8)**      | 





The simple TFNet without the SE modules can outperform the state-of-the-art light-weight networks without the SE modules.

|                   | FLOPs | #Params. |   top-1 err.   |
|-----------------  |:-----:|:--------:|:--------------:|
|  MobileNetV2 0.17 |  42M  |   1.4M   |    52.6    |
| ShuffleNetV2 0.5x |  41M  |   1.4M   |    39.4    |
|     TFNet 0.5     |  43M  |   1.3M   |  **36.6 (+2.8)**  |
|  MobileNetV2 0.6  |  141M |   2.2M   |    33.3    |
| ShuffleNetV2 1.0x |  146M |   2.3M   |    30.6    |
|     TFNet 1.0     |  135M |   1.9M   |  **29.7 (+0.9)**  |
|  MobileNetV2 1.0  |  300M |   3.4M   |    28.0    |
| ShuffleNetV2 1.5x |  299M |   3.5M   |    27.4    |
|     TFNet 1.5     |  279M |   2.7M   |  **26.0 (+1.4)**  |
|  MobileNetV2 1.4  |  585M |   5.5M   |    25.3    |
| ShuffleNetV2 2.0x |  591M |   7.4M   |    25.0    |
| TFNet 2.0         |  474M |   3.8M   |  **24.3 (+0.7)**  |




## Trained Models
- OneDrive download: [Link](https://1drv.ms/u/s!AgaP37NGYuEXhWbwpi4SX1IX6gOs?e=wIQYs1)
- BaiduYun download: [Link](https://pan.baidu.com/s/18uDVWe-rh4b7qI_NBvWUCw) (extract code: 13fu)


## Usage

### Requirements
Download the ImageNet dataset and move validation images to labeled subfolders. To do this, you can use the following script:
https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh


Train:
```shell
python train.py  --train-dir YOUR_TRAINDATASET_PATH --val-dir YOUR_VALDATASET_PATH
```
Eval:
```shell
python train.py --eval --eval-resume YOUR_WEIGHT_PATH --train-dir YOUR_TRAINDATASET_PATH --val-dir YOUR_VALDATASET_PATH
```


## Citation
If you use these models in your research, please cite:

    @inproceedings{ma2021activate,
      title={Activate or Not: Learning Customized Activation},
      author={Ma, Ningning and Zhang, Xiangyu and Liu, Ming and Sun, Jian},
      booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
      year={2021}
    }



================================================
FILE: TFNet/README.md
================================================
# [TFNet](https://arxiv.org/pdf/2009.04759.pdf)
This repository contains TFNet implementation by Pytorch.


### TFNet
To show the effectiveness of the proposed ACON family, we provide an extremely simple toy funnel network (TFNet) made only of pointwise convolutions and ACON-FReLU operators.

<img src="https://user-images.githubusercontent.com/5032208/113963614-7a3a8200-985c-11eb-8946-65c0bcef0a80.png"  width=60%>


## Main results



The simple TFNet without the SE modules can outperform the state-of-the-art light-weight networks without the SE modules.

|                   | FLOPs | #Params. |   top-1 err.   |
|-----------------  |:-----:|:--------:|:--------------:|
|  MobileNetV2 0.17 |  42M  |   1.4M   |    52.6    |
| ShuffleNetV2 0.5x |  41M  |   1.4M   |    39.4    |
|     TFNet 0.5     |  43M  |   1.3M   |  **36.6 (+2.8)**  |
|  MobileNetV2 0.6  |  141M |   2.2M   |    33.3    |
| ShuffleNetV2 1.0x |  146M |   2.3M   |    30.6    |
|     TFNet 1.0     |  135M |   1.9M   |  **29.7 (+0.9)**  |
|  MobileNetV2 1.0  |  300M |   3.4M   |    28.0    |
| ShuffleNetV2 1.5x |  299M |   3.5M   |    27.4    |
|     TFNet 1.5     |  279M |   2.7M   |  **26.0 (+1.4)**  |
|  MobileNetV2 1.4  |  585M |   5.5M   |    25.3    |
| ShuffleNetV2 2.0x |  591M |   7.4M   |    25.0    |
| TFNet 2.0         |  474M |   3.8M   |  **24.3 (+0.7)**  |




## Trained Models
- OneDrive download: [Link](https://1drv.ms/u/s!AgaP37NGYuEXhWbwpi4SX1IX6gOs?e=wIQYs1)
- BaiduYun download: [Link](https://pan.baidu.com/s/18uDVWe-rh4b7qI_NBvWUCw) (extract code: 13fu)


## Usage

### Requirements
Download the ImageNet dataset and move validation images to labeled subfolders. To do this, you can use the following script:
https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh


Train:
```shell
python train.py  --train-dir YOUR_TRAINDATASET_PATH --val-dir YOUR_VALDATASET_PATH
```
Eval:
```shell
python train.py --eval --eval-resume YOUR_WEIGHT_PATH --train-dir YOUR_TRAINDATASET_PATH --val-dir YOUR_VALDATASET_PATH
```


## Citation
If you use these models in your research, please cite:

    @inproceedings{ma2021activate,
      title={Activate or Not: Learning Customized Activation},
      author={Ma, Ningning and Zhang, Xiangyu and Liu, Ming and Sun, Jian},
      booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
      year={2021}
    }


================================================
FILE: TFNet/network.py
================================================
import torch
import torch.nn as nn

class Acon_FReLU(nn.Module):
    r""" ACON activation (activate or not) based on FReLU:
    # eta_a(x) = x, eta_b(x) = dw_conv(x), according to
    # "Funnel Activation for Visual Recognition" <https://arxiv.org/pdf/2007.11824.pdf>.
    """
    def __init__(self, width, stride=1):
        super().__init__()
        self.stride = stride

        # eta_b(x)
        self.conv_frelu = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=width, bias=True)
        self.bn1 = nn.BatchNorm2d(width)

        # eta_a(x)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, **kwargs):
        # eta_a branch: identity, or spatial downsample when stride == 2 so
        # it matches the strided depthwise conv of the other branch.
        branch_a = self.maxpool(x) if self.stride == 2 else x
        # eta_b branch: depthwise 3x3 conv followed by BN.
        branch_b = self.bn1(self.conv_frelu(x))
        diff = branch_a - branch_b
        # Smooth maximum of the two branches: b + (a - b) * sigmoid(a - b).
        return self.bn2(diff * self.sigmoid(diff) + branch_b)


class TFBlock(nn.Module):
    """Residual block of two pointwise convolutions joined by an ACON-FReLU.

    With stride 1 the input is added back (channels unchanged); with stride 2
    the input is concatenated instead, doubling the channel count, and the
    final ACON-FReLU performs the spatial downsampling.
    """
    def __init__(self, inp, stride):
        super(TFBlock, self).__init__()
        self.oup = inp * stride
        self.stride = stride

        self.branch_main = nn.Sequential(
            # first pointwise conv
            nn.Conv2d(inp, inp, kernel_size=1, stride=1, bias=True),
            nn.BatchNorm2d(inp),
            Acon_FReLU(inp),
            # second pointwise conv
            nn.Conv2d(inp, inp, kernel_size=1, stride=1, bias=True),
            nn.BatchNorm2d(inp),
        )

        self.acon = Acon_FReLU(self.oup, stride)

    def forward(self, x):
        residual = x
        out = self.branch_main(x)

        if self.stride == 1:
            # additive skip connection
            return self.acon(residual + out)
        elif self.stride == 2:
            # channel concat doubles width; acon downsamples spatially
            return self.acon(torch.cat((residual, out), 1))


class TFNet(nn.Module):
    """TFNet backbone: a stem conv, four stages of TFBlocks, a 1x1 head conv,
    global average pooling and a linear classifier.

    `model_size` scales the stem width (channels double at the start of each
    stage via the stride-2 TFBlock's channel concatenation).
    """

    def __init__(self, n_class=1000, model_size=0.5):
        super(TFNet, self).__init__()
        print('model size is ', model_size)

        self.stages = [2, 3, 8, 3]              # TFBlocks per stage
        self.in_channel = int(16 * model_size)  # stem width
        self.out_channel = 1024                 # head width
        self.model_size = model_size

        # Stem: 3x3 stride-2 convolution.
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, self.in_channel, 3, 2, 1, bias=True),
            nn.BatchNorm2d(self.in_channel),
            nn.ReLU(inplace=True),
        )

        # Four stages; each opens with a stride-2 block that doubles width.
        blocks = []
        for repeats in self.stages:
            for idx in range(repeats):
                blocks.append(TFBlock(self.in_channel, stride=2 if idx == 0 else 1))
                if idx == 0:
                    self.in_channel = self.in_channel * 2
        self.features = nn.Sequential(*blocks)

        # Head: 1x1 conv up to out_channel, then pool + classifier.
        self.conv_last = nn.Sequential(
            nn.Conv2d(self.in_channel, self.out_channel, 1, 1, 0, bias=True),
            nn.BatchNorm2d(self.out_channel),
            Acon_FReLU(self.out_channel),
        )
        self.globalpool = nn.AvgPool2d(7)  # assumes 224x224 input -> 7x7 feature map
        if self.model_size > 0.5:
            # Dropout only for the larger variants.
            self.dropout = nn.Dropout(0.2)
        self.classifier = nn.Sequential(nn.Linear(self.out_channel, n_class, bias=True))
        self._initialize_weights()

    def forward(self, x):
        out = self.first_conv(x)
        out = self.features(out)
        out = self.conv_last(out)

        out = self.globalpool(out)
        if self.model_size > 0.5:
            out = self.dropout(out)
        out = out.contiguous().view(-1, self.out_channel)
        return self.classifier(out)

    def _initialize_weights(self):
        """Custom init: small normals for stem/FReLU convs, fan-in-scaled
        normals for the remaining convs, near-zero BatchNorm biases."""
        for name, module in self.named_modules():
            if isinstance(module, nn.Conv2d):
                if 'first' in name or 'frelu' in name:
                    nn.init.normal_(module.weight, 0, 0.01)
                else:
                    nn.init.normal_(module.weight, 0, 1.0 / module.weight.shape[1])
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0001)
                nn.init.constant_(module.running_mean, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)


================================================
FILE: TFNet/train.py
================================================
import os
import sys
import torch
import argparse
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import cv2
import numpy as np
import PIL
from PIL import Image
import time
import logging
import argparse
from network import TFNet
from utils import accuracy, AvgrageMeter, CrossEntropyLabelSmooth, save_checkpoint, get_lastest_model, get_parameters

class OpencvResize(object):
    """Resize the shorter side of a PIL image to `size` with OpenCV's
    bilinear interpolation (OpenCV resizing differs slightly from PIL's,
    so this matches OpenCV-based training pipelines)."""

    def __init__(self, size=256):
        self.size = size

    def __call__(self, img):
        assert isinstance(img, PIL.Image.Image)
        arr = np.ascontiguousarray(np.asarray(img)[:, :, ::-1])  # RGB -> BGR
        H, W, _ = arr.shape
        if H < W:
            target_size = (int(self.size / H * W + 0.5), self.size)
        else:
            target_size = (self.size, int(self.size / W * H + 0.5))
        arr = cv2.resize(arr, target_size, interpolation=cv2.INTER_LINEAR)
        arr = np.ascontiguousarray(arr[:, :, ::-1])  # BGR -> RGB
        return Image.fromarray(arr)

class ToBGRTensor(object):
    """Convert an RGB PIL image or HWC uint8 ndarray into a float CHW BGR
    tensor. Pixel values are kept in [0, 255]; no normalization is done."""

    def __call__(self, img):
        assert isinstance(img, (np.ndarray, PIL.Image.Image))
        if isinstance(img, PIL.Image.Image):
            img = np.asarray(img)
        bgr = img[:, :, ::-1]                # RGB -> BGR
        chw = np.transpose(bgr, [2, 0, 1])   # HWC -> CHW
        return torch.from_numpy(np.ascontiguousarray(chw)).float()

class DataIterator(object):
    """Endless iterator over a dataloader: when the underlying loader is
    exhausted it is restarted from the beginning, so `next()` never raises
    StopIteration.
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = enumerate(self.dataloader)

    def next(self):
        """Return the next (data, target) pair, cycling past the end.

        Fix: only StopIteration restarts the loader. The original caught
        bare `Exception`, which silently restarted on *any* loader error
        (corrupt image, worker crash, ...) and could mask real data bugs;
        such errors now propagate to the caller.
        """
        try:
            _, data = next(self.iterator)
        except StopIteration:
            self.iterator = enumerate(self.dataloader)
            _, data = next(self.iterator)
        return data[0], data[1]

def get_args():
    """Parse and return the command-line arguments for TFNet training."""
    parser = argparse.ArgumentParser("TFNet")

    # evaluation-only mode
    parser.add_argument('--eval', default=False, action='store_true')
    parser.add_argument('--eval-resume', type=str, default='./tfnet.0.5.pth', help='path for eval model')

    # optimization hyper-parameters
    parser.add_argument('--batch-size', type=int, default=1024, help='batch size')
    parser.add_argument('--total-iters', type=int, default=300000, help='total iters')
    parser.add_argument('--learning-rate', type=float, default=0.5, help='init learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight-decay', type=float, default=4e-5, help='weight decay')
    parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')
    parser.add_argument('--label-smooth', type=float, default=0.1, help='label smoothing')

    # run control
    parser.add_argument('--auto-continue', type=bool, default=True, help='auto continue')
    parser.add_argument('--display-interval', type=int, default=20, help='display interval')
    parser.add_argument('--val-interval', type=int, default=10000, help='val interval')
    parser.add_argument('--save-interval', type=int, default=10000, help='save interval')

    # model and data
    parser.add_argument('--model-size', type=float, default=0.5, choices=[0.5, 1.0, 1.5, 2.0], help='size of the model')
    parser.add_argument('--train-dir', type=str, default='data/train', help='path to training dataset')
    parser.add_argument('--val-dir', type=str, default='data/val', help='path to validation dataset')

    return parser.parse_args()

def main():
    """Entry point: build data pipelines, model, optimizer and LR schedule,
    then either evaluate a checkpoint (--eval) or run the iteration-based
    training loop with periodic validation and checkpointing.
    """
    args = get_args()

    # Log to stdout and to a timestamped file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    # Training pipeline: random crop / jitter / flip, BGR tensors in [0, 255].
    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ])
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=1, pin_memory=use_gpu)
    train_dataprovider = DataIterator(train_loader)

    # Validation pipeline: shorter side to 256 (OpenCV), center-crop 224.
    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(args.val_dir, transforms.Compose([
            OpencvResize(256),
            transforms.CenterCrop(224),
            ToBGRTensor(),
        ])),
        batch_size=200, shuffle=False,
        num_workers=1, pin_memory=use_gpu
    )
    val_dataprovider = DataIterator(val_loader)
    print('load data successfully')

    model = TFNet(model_size=args.model_size)

    # Two parameter groups: only matrix-shaped weights get weight decay.
    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # Fix: honor the --label-smooth flag (was hard-coded to 0.1, which
    # silently ignored the CLI option; the default is still 0.1).
    criterion_smooth = CrossEntropyLabelSmooth(1000, args.label_smooth)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    # Linear LR decay from the initial rate down to 0 over total_iters.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)

    model = model.to(device)

    all_iters = 0
    if args.auto_continue:
        # Resume from the newest checkpoint and fast-forward the scheduler.
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            for i in range(iters):
                scheduler.step()

    # Stash run state on args so train()/validate() can reach it.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')
            load_checkpoint(model, checkpoint)
            validate(model, device, args, all_iters=all_iters)
        exit(0)

    while all_iters < args.total_iters:
        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)
        validate(model, device, args, all_iters=all_iters)
    validate(model, device, args, all_iters=all_iters)
    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')

def adjust_bn_momentum(model, iters):
    """Set every BatchNorm2d momentum to 1/iters, so a recalibration pass
    averages the statistics of all batches seen so far."""
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.momentum = 1 / iters

def train(model, device, args, *, val_interval, bn_process=False, all_iters=None):
    """Run `val_interval` training iterations and return the updated global
    iteration counter.

    Reads the optimizer, loss function, LR scheduler and train data provider
    off `args` (stashed there by main()). Logs running top-1/top-5 error every
    `args.display_interval` iterations and saves a checkpoint every
    `args.save_interval` iterations.
    """

    optimizer = args.optimizer
    loss_function = args.loss_function
    scheduler = args.scheduler
    train_dataprovider = args.train_dataprovider

    t1 = time.time()
    Top1_err, Top5_err = 0.0, 0.0
    model.train()
    for iters in range(1, val_interval + 1):
        # NOTE(review): scheduler.step() runs before optimizer.step(); newer
        # PyTorch warns about this order, but changing it would shift the LR
        # schedule by one step, so it is kept as-is.
        scheduler.step()
        if bn_process:
            # BN recalibration mode: momentum 1/iters averages all batches so far.
            adjust_bn_momentum(model, iters)

        all_iters += 1
        d_st = time.time()
        data, target = train_dataprovider.next()
        target = target.type(torch.LongTensor)
        data, target = data.to(device), target.to(device)
        data_time = time.time() - d_st

        # Standard forward / backward / update step.
        output = model(data)
        loss = loss_function(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        prec1, prec5 = accuracy(output, target, topk=(1, 5))

        # Accumulate error rates; reset after each display window.
        Top1_err += 1 - prec1.item() / 100
        Top5_err += 1 - prec5.item() / 100

        if all_iters % args.display_interval == 0:
            printInfo = 'TRAIN Iter {}: lr = {:.6f},\tloss = {:.6f},\t'.format(all_iters, scheduler.get_lr()[0], loss.item()) + \
                        'Top-1 err = {:.6f},\t'.format(Top1_err / args.display_interval) + \
                        'Top-5 err = {:.6f},\t'.format(Top5_err / args.display_interval) + \
                        'data_time = {:.6f},\ttrain_time = {:.6f}'.format(data_time, (time.time() - t1) / args.display_interval)
            logging.info(printInfo)
            t1 = time.time()
            Top1_err, Top5_err = 0.0, 0.0

        if all_iters % args.save_interval == 0:
            save_checkpoint({
                'state_dict': model.state_dict(),
                }, all_iters)

    return all_iters

def validate(model, device, args, *, all_iters=None):
    """Evaluate on a fixed budget of 250 validation batches and log the
    average loss plus top-1/top-5 error rates."""
    loss_meter = AvgrageMeter()
    top1_meter = AvgrageMeter()
    top5_meter = AvgrageMeter()

    loss_function = args.loss_function
    val_dataprovider = args.val_dataprovider

    model.eval()
    max_val_iters = 250
    t1  = time.time()
    with torch.no_grad():
        for _ in range(max_val_iters):
            data, target = val_dataprovider.next()
            target = target.type(torch.LongTensor)
            data, target = data.to(device), target.to(device)

            output = model(data)
            loss = loss_function(output, target)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))

            batch = data.size(0)
            loss_meter.update(loss.item(), batch)
            top1_meter.update(prec1.item(), batch)
            top5_meter.update(prec5.item(), batch)

    logInfo = 'TEST Iter {}: loss = {:.6f},\t'.format(all_iters, loss_meter.avg) + \
              'Top-1 err = {:.6f},\t'.format(1 - top1_meter.avg / 100) + \
              'Top-5 err = {:.6f},\t'.format(1 - top5_meter.avg / 100) + \
              'val_time = {:.6f}'.format(time.time() - t1)
    logging.info(logInfo)

def load_checkpoint(net, checkpoint):
    """Strictly load weights into `net`, adding the 'module.' prefix that
    nn.DataParallel expects when the checkpoint was saved from an
    unwrapped (bare) model."""
    from collections import OrderedDict

    state = dict(checkpoint['state_dict']) if 'state_dict' in checkpoint else checkpoint
    remapped = OrderedDict()
    for key in state:
        prefixed = key if key.startswith('module.') else 'module.' + key
        remapped[prefixed] = state[key]

    net.load_state_dict(remapped, strict=True)

if __name__ == "__main__":
    main()


================================================
FILE: TFNet/utils.py
================================================
import os
import re
import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
	"""Cross-entropy with label smoothing: the target distribution puts
	(1 - epsilon) on the true class and epsilon/num_classes on every class."""

	def __init__(self, num_classes, epsilon):
		super(CrossEntropyLabelSmooth, self).__init__()
		self.num_classes = num_classes
		self.epsilon = epsilon
		self.logsoftmax = nn.LogSoftmax(dim=1)

	def forward(self, inputs, targets):
		# One-hot encode the labels, then blend with the uniform distribution.
		log_probs = self.logsoftmax(inputs)
		one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
		smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
		return (-smoothed * log_probs).mean(0).sum()


class AvgrageMeter(object):
	"""Tracks the running average of a scalar (e.g. loss or accuracy)."""

	def __init__(self):
		self.reset()

	def reset(self):
		# Clear all statistics back to zero.
		self.val = 0
		self.sum = 0
		self.cnt = 0
		self.avg = 0

	def update(self, val, n=1):
		# Record `val` observed `n` times and refresh the running mean.
		self.val = val
		self.sum += val * n
		self.cnt += n
		self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
	"""Return the top-k accuracies (as percentages) for each k in `topk`."""
	maxk = max(topk)
	batch_size = target.size(0)

	# (maxk, batch) matrix: row k holds every sample's k-th ranked guess.
	_, pred = output.topk(maxk, 1, True, True)
	pred = pred.t()
	hits = pred.eq(target.view(1, -1).expand_as(pred))

	results = []
	for k in topk:
		correct_k = hits[:k].reshape(-1).float().sum(0)
		results.append(correct_k.mul_(100.0 / batch_size))
	return results


def save_checkpoint(state, iters, tag=''):
	"""Serialize `state` to ./models/<tag>checkpoint-<iters>.pth.tar,
	creating the directory on first use."""
	os.makedirs("./models", exist_ok=True)
	filename = "./models/{}checkpoint-{:06}.pth.tar".format(tag, iters)
	torch.save(state, filename)

def get_lastest_model():
	"""Return (path, iteration) of the newest checkpoint under ./models,
	or (None, 0) when none exist. Creates the directory if missing.
	Relies on zero-padded iteration counts so lexicographic sort == newest."""
	if not os.path.exists('./models'):
		os.mkdir('./models')
	candidates = sorted(os.listdir('./models/'))
	if not candidates:
		return None, 0
	newest = candidates[-1]
	# First number embedded in the filename is the iteration count.
	iters = int(re.findall(r'\d+', newest)[0])
	return './models/' + newest, iters


def get_parameters(model):
	"""Split the model's parameters into two optimizer groups: matrix-shaped
	weights receive weight decay, everything else (biases, BN affine
	parameters) is exempt."""
	decay, no_decay = [], []
	for name, param in model.named_parameters():
		if 'weight' in name and param.dim() > 1:
			decay.append(param)
		else:
			no_decay.append(param)
	# Sanity check: every parameter landed in exactly one group.
	assert len(list(model.parameters())) == len(decay) + len(no_decay)
	return [dict(params=decay), dict(params=no_decay, weight_decay=0.)]


================================================
FILE: acon.py
================================================
import torch
from torch import nn


class AconC(nn.Module):
    r""" ACON-C activation (activate or not).
    # AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, with per-channel
    # learnable p1, p2 and switching factor beta, following
    # "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
    """

    def __init__(self, width):
        super().__init__()
        # Per-channel learnable parameters, broadcast over (N, C, H, W).
        self.p1 = nn.Parameter(torch.randn(1, width, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, width, 1, 1))
        self.beta = nn.Parameter(torch.ones(1, width, 1, 1))

    def forward(self, x):
        # dpx = (p1 - p2) * x is the switchable part of the activation.
        dpx = self.p1 * x - self.p2 * x
        return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x


class MetaAconC(nn.Module):
    r""" Meta-ACON activation (activate or not).
    # MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, where beta is
    # produced per-channel by a small bottleneck network over globally pooled
    # features, following
    # "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
    """

    def __init__(self, width, r=16):
        super().__init__()
        # Bottleneck that generates beta; hidden width is max(r, width // r).
        hidden = max(r, width // r)
        self.fc1 = nn.Conv2d(width, hidden, kernel_size=1, stride=1, bias=True)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.fc2 = nn.Conv2d(hidden, width, kernel_size=1, stride=1, bias=True)
        self.bn2 = nn.BatchNorm2d(width)

        # Per-channel learnable slopes.
        self.p1 = nn.Parameter(torch.randn(1, width, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, width, 1, 1))

    def forward(self, x):
        # Global average pool over H then W, keeping the NCHW rank.
        pooled = x.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(pooled)))))
        dpx = self.p1 * x - self.p2 * x
        return dpx * torch.sigmoid(beta * dpx) + self.p2 * x
Download .txt
gitextract_vf0sezxr/

├── ACON/
│   ├── ResNet_ACON/
│   │   ├── resnet_acon.py
│   │   ├── train.py
│   │   └── utils.py
│   └── ShuffleNetV2_ACON/
│       ├── network.py
│       ├── train.py
│       └── utils.py
├── LICENSE
├── MetaACON/
│   ├── ResNet_MetaACON/
│   │   ├── resnet_metaacon.py
│   │   ├── train.py
│   │   └── utils.py
│   └── ShuffleNet_MetaACON/
│       ├── network.py
│       ├── train.py
│       └── utils.py
├── README.md
├── TFNet/
│   ├── README.md
│   ├── network.py
│   ├── train.py
│   └── utils.py
└── acon.py
Download .txt
SYMBOL INDEX (188 symbols across 16 files)

FILE: ACON/ResNet_ACON/resnet_acon.py
  function conv3x3 (line 17) | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: in...
  function conv1x1 (line 23) | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
  class BasicBlock_ACON (line 27) | class BasicBlock_ACON(nn.Module):
    method __init__ (line 32) | def __init__(
    method forward (line 61) | def forward(self, x: Tensor) -> Tensor:
  class Bottleneck_ACON (line 79) | class Bottleneck_ACON(nn.Module):
    method __init__ (line 90) | def __init__(
    method forward (line 117) | def forward(self, x: Tensor) -> Tensor:
  class ResNet (line 140) | class ResNet(nn.Module):
    method __init__ (line 142) | def __init__(
    method _make_layer (line 201) | def _make_layer(self, block: Type[Union[BasicBlock_ACON, Bottleneck_AC...
    method _forward_impl (line 226) | def _forward_impl(self, x: Tensor) -> Tensor:
    method forward (line 244) | def forward(self, x: Tensor) -> Tensor:
  function _resnet (line 248) | def _resnet(
  function resnet50_acon (line 260) | def resnet50_acon(pretrained: bool = False, progress: bool = True, **kwa...
  function resnet101_acon (line 270) | def resnet101_acon(pretrained: bool = False, progress: bool = True, **kw...
  function resnet152_acon (line 280) | def resnet152_acon(pretrained: bool = False, progress: bool = True, **kw...

FILE: ACON/ResNet_ACON/train.py
  class OpencvResize (line 18) | class OpencvResize(object):
    method __init__ (line 20) | def __init__(self, size=256):
    method __call__ (line 23) | def __call__(self, img):
  class ToBGRTensor (line 36) | class ToBGRTensor(object):
    method __call__ (line 38) | def __call__(self, img):
  class DataIterator (line 48) | class DataIterator(object):
    method __init__ (line 50) | def __init__(self, dataloader):
    method next (line 54) | def next(self):
  function get_args (line 62) | def get_args():
  function main (line 86) | def main():
  function adjust_bn_momentum (line 183) | def adjust_bn_momentum(model, iters):
  function train (line 188) | def train(model, device, args, *, val_interval, bn_process=False, all_it...
  function validate (line 236) | def validate(model, device, args, *, all_iters=None):
  function load_checkpoint (line 268) | def load_checkpoint(net, checkpoint):

FILE: ACON/ResNet_ACON/utils.py
  class CrossEntropyLabelSmooth (line 6) | class CrossEntropyLabelSmooth(nn.Module):
    method __init__ (line 8) | def __init__(self, num_classes, epsilon):
    method forward (line 14) | def forward(self, inputs, targets):
  class AvgrageMeter (line 22) | class AvgrageMeter(object):
    method __init__ (line 24) | def __init__(self):
    method reset (line 27) | def reset(self):
    method update (line 33) | def update(self, val, n=1):
  function accuracy (line 40) | def accuracy(output, target, topk=(1,)):
  function save_checkpoint (line 55) | def save_checkpoint(state, iters, tag=''):
  function get_lastest_model (line 61) | def get_lastest_model():
  function get_parameters (line 73) | def get_parameters(model):

FILE: ACON/ShuffleNetV2_ACON/network.py
  class ShuffleV2Block_ACON (line 8) | class ShuffleV2Block_ACON(nn.Module):
    method __init__ (line 9) | def __init__(self, inp, oup, mid_channels, *, ksize, stride):
    method forward (line 51) | def forward(self, old_x):
    method channel_shuffle (line 60) | def channel_shuffle(self, x):
  class ShuffleNetV2_ACON (line 69) | class ShuffleNetV2_ACON(nn.Module):
    method __init__ (line 70) | def __init__(self, input_size=224, n_class=1000, model_size='1.5x'):
    method forward (line 125) | def forward(self, x):
    method _initialize_weights (line 138) | def _initialize_weights(self):

FILE: ACON/ShuffleNetV2_ACON/train.py
  class OpencvResize (line 18) | class OpencvResize(object):
    method __init__ (line 20) | def __init__(self, size=256):
    method __call__ (line 23) | def __call__(self, img):
  class ToBGRTensor (line 36) | class ToBGRTensor(object):
    method __call__ (line 38) | def __call__(self, img):
  class DataIterator (line 48) | class DataIterator(object):
    method __init__ (line 50) | def __init__(self, dataloader):
    method next (line 54) | def next(self):
  function get_args (line 62) | def get_args():
  function main (line 88) | def main():
  function adjust_bn_momentum (line 186) | def adjust_bn_momentum(model, iters):
  function train (line 191) | def train(model, device, args, *, val_interval, bn_process=False, all_it...
  function validate (line 239) | def validate(model, device, args, *, all_iters=None):
  function load_checkpoint (line 271) | def load_checkpoint(net, checkpoint):

FILE: ACON/ShuffleNetV2_ACON/utils.py
  class CrossEntropyLabelSmooth (line 6) | class CrossEntropyLabelSmooth(nn.Module):
    method __init__ (line 8) | def __init__(self, num_classes, epsilon):
    method forward (line 14) | def forward(self, inputs, targets):
  class AvgrageMeter (line 22) | class AvgrageMeter(object):
    method __init__ (line 24) | def __init__(self):
    method reset (line 27) | def reset(self):
    method update (line 33) | def update(self, val, n=1):
  function accuracy (line 40) | def accuracy(output, target, topk=(1,)):
  function save_checkpoint (line 55) | def save_checkpoint(state, iters, tag=''):
  function get_lastest_model (line 61) | def get_lastest_model():
  function get_parameters (line 73) | def get_parameters(model):

FILE: MetaACON/ResNet_MetaACON/resnet_metaacon.py
  function conv3x3 (line 17) | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: in...
  function conv1x1 (line 23) | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
  class Bottleneck_MetaACON (line 28) | class Bottleneck_MetaACON(nn.Module):
    method __init__ (line 40) | def __init__(
    method forward (line 67) | def forward(self, x: Tensor) -> Tensor:
  class ResNet (line 89) | class ResNet(nn.Module):
    method __init__ (line 91) | def __init__(
    method _make_layer (line 148) | def _make_layer(self, block: Type[Union[Bottleneck_MetaACON]], planes:...
    method _forward_impl (line 173) | def _forward_impl(self, x: Tensor) -> Tensor:
    method forward (line 191) | def forward(self, x: Tensor) -> Tensor:
  function _resnet (line 195) | def _resnet(
  function resnet50_metaacon (line 207) | def resnet50_metaacon(pretrained: bool = False, progress: bool = True, *...
  function resnet101_metaacon (line 217) | def resnet101_metaacon(pretrained: bool = False, progress: bool = True, ...
  function resnet152_metaacon (line 227) | def resnet152_metaacon(pretrained: bool = False, progress: bool = True, ...

FILE: MetaACON/ResNet_MetaACON/train.py
  class OpencvResize (line 18) | class OpencvResize(object):
    method __init__ (line 20) | def __init__(self, size=256):
    method __call__ (line 23) | def __call__(self, img):
  class ToBGRTensor (line 36) | class ToBGRTensor(object):
    method __call__ (line 38) | def __call__(self, img):
  class DataIterator (line 48) | class DataIterator(object):
    method __init__ (line 50) | def __init__(self, dataloader):
    method next (line 54) | def next(self):
  function get_args (line 62) | def get_args():
  function main (line 86) | def main():
  function adjust_bn_momentum (line 183) | def adjust_bn_momentum(model, iters):
  function train (line 188) | def train(model, device, args, *, val_interval, bn_process=False, all_it...
  function validate (line 236) | def validate(model, device, args, *, all_iters=None):
  function load_checkpoint (line 268) | def load_checkpoint(net, checkpoint):

FILE: MetaACON/ResNet_MetaACON/utils.py
  class CrossEntropyLabelSmooth (line 6) | class CrossEntropyLabelSmooth(nn.Module):
    method __init__ (line 8) | def __init__(self, num_classes, epsilon):
    method forward (line 14) | def forward(self, inputs, targets):
  class AvgrageMeter (line 22) | class AvgrageMeter(object):
    method __init__ (line 24) | def __init__(self):
    method reset (line 27) | def reset(self):
    method update (line 33) | def update(self, val, n=1):
  function accuracy (line 40) | def accuracy(output, target, topk=(1,)):
  function save_checkpoint (line 55) | def save_checkpoint(state, iters, tag=''):
  function get_lastest_model (line 61) | def get_lastest_model():
  function get_parameters (line 73) | def get_parameters(model):

FILE: MetaACON/ShuffleNet_MetaACON/network.py
  class ShuffleV2Block_MetaACON (line 8) | class ShuffleV2Block_MetaACON(nn.Module):
    method __init__ (line 9) | def __init__(self, inp, oup, mid_channels, *, ksize, stride, r=16):
    method forward (line 51) | def forward(self, old_x):
    method channel_shuffle (line 60) | def channel_shuffle(self, x):
  class ShuffleNetV2_MetaACON (line 69) | class ShuffleNetV2_MetaACON(nn.Module):
    method __init__ (line 70) | def __init__(self, input_size=224, n_class=1000, model_size='1.5x'):
    method forward (line 127) | def forward(self, x):
    method _initialize_weights (line 140) | def _initialize_weights(self):

FILE: MetaACON/ShuffleNet_MetaACON/train.py
  class OpencvResize (line 18) | class OpencvResize(object):
    method __init__ (line 20) | def __init__(self, size=256):
    method __call__ (line 23) | def __call__(self, img):
  class ToBGRTensor (line 36) | class ToBGRTensor(object):
    method __call__ (line 38) | def __call__(self, img):
  class DataIterator (line 48) | class DataIterator(object):
    method __init__ (line 50) | def __init__(self, dataloader):
    method next (line 54) | def next(self):
  function get_args (line 62) | def get_args():
  function main (line 88) | def main():
  function adjust_bn_momentum (line 186) | def adjust_bn_momentum(model, iters):
  function train (line 191) | def train(model, device, args, *, val_interval, bn_process=False, all_it...
  function validate (line 239) | def validate(model, device, args, *, all_iters=None):
  function load_checkpoint (line 271) | def load_checkpoint(net, checkpoint):

FILE: MetaACON/ShuffleNet_MetaACON/utils.py
  class CrossEntropyLabelSmooth (line 6) | class CrossEntropyLabelSmooth(nn.Module):
    method __init__ (line 8) | def __init__(self, num_classes, epsilon):
    method forward (line 14) | def forward(self, inputs, targets):
  class AvgrageMeter (line 22) | class AvgrageMeter(object):
    method __init__ (line 24) | def __init__(self):
    method reset (line 27) | def reset(self):
    method update (line 33) | def update(self, val, n=1):
  function accuracy (line 40) | def accuracy(output, target, topk=(1,)):
  function save_checkpoint (line 55) | def save_checkpoint(state, iters, tag=''):
  function get_lastest_model (line 61) | def get_lastest_model():
  function get_parameters (line 73) | def get_parameters(model):

FILE: TFNet/network.py
  class Acon_FReLU (line 4) | class Acon_FReLU(nn.Module):
    method __init__ (line 9) | def __init__(self, width, stride=1):
    method forward (line 22) | def forward(self, x, **kwargs):
  class TFBlock (line 33) | class TFBlock(nn.Module):
    method __init__ (line 34) | def __init__(self, inp, stride):
    method forward (line 52) | def forward(self, x):
  class TFNet (line 63) | class TFNet(nn.Module):
    method __init__ (line 64) | def __init__(self, n_class=1000, model_size=0.5):
    method forward (line 101) | def forward(self, x):
    method _initialize_weights (line 113) | def _initialize_weights(self):

FILE: TFNet/train.py
  class OpencvResize (line 18) | class OpencvResize(object):
    method __init__ (line 20) | def __init__(self, size=256):
    method __call__ (line 23) | def __call__(self, img):
  class ToBGRTensor (line 36) | class ToBGRTensor(object):
    method __call__ (line 38) | def __call__(self, img):
  class DataIterator (line 48) | class DataIterator(object):
    method __init__ (line 50) | def __init__(self, dataloader):
    method next (line 54) | def next(self):
  function get_args (line 62) | def get_args():
  function main (line 88) | def main():
  function adjust_bn_momentum (line 186) | def adjust_bn_momentum(model, iters):
  function train (line 191) | def train(model, device, args, *, val_interval, bn_process=False, all_it...
  function validate (line 239) | def validate(model, device, args, *, all_iters=None):
  function load_checkpoint (line 271) | def load_checkpoint(net, checkpoint):

FILE: TFNet/utils.py
  class CrossEntropyLabelSmooth (line 6) | class CrossEntropyLabelSmooth(nn.Module):
    method __init__ (line 8) | def __init__(self, num_classes, epsilon):
    method forward (line 14) | def forward(self, inputs, targets):
  class AvgrageMeter (line 22) | class AvgrageMeter(object):
    method __init__ (line 24) | def __init__(self):
    method reset (line 27) | def reset(self):
    method update (line 33) | def update(self, val, n=1):
  function accuracy (line 40) | def accuracy(output, target, topk=(1,)):
  function save_checkpoint (line 55) | def save_checkpoint(state, iters, tag=''):
  function get_lastest_model (line 61) | def get_lastest_model():
  function get_parameters (line 73) | def get_parameters(model):

FILE: acon.py
  class AconC (line 5) | class AconC(nn.Module):
    method __init__ (line 11) | def __init__(self, width):
    method forward (line 17) | def forward(self, x):
  class MetaAconC (line 21) | class MetaAconC(nn.Module):
    method __init__ (line 27) | def __init__(self, width, r=16):
    method forward (line 37) | def forward(self, x):
Condensed preview — 19 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (116K chars).
[
  {
    "path": "ACON/ResNet_ACON/resnet_acon.py",
    "chars": 11048,
    "preview": "import torch\nfrom torch import Tensor\nimport torch.nn as nn\nfrom typing import Type, Any, Callable, Union, List, Optiona"
  },
  {
    "path": "ACON/ResNet_ACON/train.py",
    "chars": 10314,
    "preview": "import os\nimport sys\nimport torch\nimport argparse\nimport torch.nn as nn\nimport torchvision.transforms as transforms\nimpo"
  },
  {
    "path": "ACON/ResNet_ACON/utils.py",
    "chars": 2259,
    "preview": "import os\nimport re\nimport torch\nimport torch.nn as nn\n\nclass CrossEntropyLabelSmooth(nn.Module):\n\n\tdef __init__(self, n"
  },
  {
    "path": "ACON/ShuffleNetV2_ACON/network.py",
    "chars": 5967,
    "preview": "import torch\nimport torch.nn as nn\n\nimport sys\nsys.path.insert(0,'../..')\nfrom acon import AconC\n\nclass ShuffleV2Block_A"
  },
  {
    "path": "ACON/ShuffleNetV2_ACON/train.py",
    "chars": 10593,
    "preview": "import os\nimport sys\nimport torch\nimport argparse\nimport torch.nn as nn\nimport torchvision.transforms as transforms\nimpo"
  },
  {
    "path": "ACON/ShuffleNetV2_ACON/utils.py",
    "chars": 2259,
    "preview": "import os\nimport re\nimport torch\nimport torch.nn as nn\n\nclass CrossEntropyLabelSmooth(nn.Module):\n\n\tdef __init__(self, n"
  },
  {
    "path": "LICENSE",
    "chars": 1063,
    "preview": "MIT License\n\nCopyright (c) 2021 nmaac \n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof "
  },
  {
    "path": "MetaACON/ResNet_MetaACON/resnet_metaacon.py",
    "chars": 9352,
    "preview": "import torch\nfrom torch import Tensor\nimport torch.nn as nn\nfrom typing import Type, Any, Callable, Union, List, Optiona"
  },
  {
    "path": "MetaACON/ResNet_MetaACON/train.py",
    "chars": 10330,
    "preview": "import os\nimport sys\nimport torch\nimport argparse\nimport torch.nn as nn\nimport torchvision.transforms as transforms\nimpo"
  },
  {
    "path": "MetaACON/ResNet_MetaACON/utils.py",
    "chars": 2259,
    "preview": "import os\nimport re\nimport torch\nimport torch.nn as nn\n\nclass CrossEntropyLabelSmooth(nn.Module):\n\n\tdef __init__(self, n"
  },
  {
    "path": "MetaACON/ShuffleNet_MetaACON/network.py",
    "chars": 6119,
    "preview": "import torch\nimport torch.nn as nn\n\nimport sys\nsys.path.insert(0,'../..')\nfrom acon import MetaAconC\n\nclass ShuffleV2Blo"
  },
  {
    "path": "MetaACON/ShuffleNet_MetaACON/train.py",
    "chars": 10609,
    "preview": "import os\nimport sys\nimport torch\nimport argparse\nimport torch.nn as nn\nimport torchvision.transforms as transforms\nimpo"
  },
  {
    "path": "MetaACON/ShuffleNet_MetaACON/utils.py",
    "chars": 2259,
    "preview": "import os\nimport re\nimport torch\nimport torch.nn as nn\n\nclass CrossEntropyLabelSmooth(nn.Module):\n\n\tdef __init__(self, n"
  },
  {
    "path": "README.md",
    "chars": 5134,
    "preview": "\n## CVPR 2021 | Activate or Not: Learning Customized Activation.\n\nThis repository contains the official Pytorch implemen"
  },
  {
    "path": "TFNet/README.md",
    "chars": 2403,
    "preview": "# [TFNet](https://arxiv.org/pdf/2009.04759.pdf)\nThis repository contains TFNet implementation by Pytorch.\n\n\n### TFNet\nTo"
  },
  {
    "path": "TFNet/network.py",
    "chars": 4609,
    "preview": "import torch\nimport torch.nn as nn\n\nclass Acon_FReLU(nn.Module):\n    r\"\"\" ACON activation (activate or not) based on FRe"
  },
  {
    "path": "TFNet/train.py",
    "chars": 10532,
    "preview": "import os\nimport sys\nimport torch\nimport argparse\nimport torch.nn as nn\nimport torchvision.transforms as transforms\nimpo"
  },
  {
    "path": "TFNet/utils.py",
    "chars": 2259,
    "preview": "import os\nimport re\nimport torch\nimport torch.nn as nn\n\nclass CrossEntropyLabelSmooth(nn.Module):\n\n\tdef __init__(self, n"
  },
  {
    "path": "acon.py",
    "chars": 1726,
    "preview": "import torch\nfrom torch import nn\n\n\nclass AconC(nn.Module):\n    r\"\"\" ACON activation (activate or not).\n    # AconC: (p1"
  }
]

About this extraction

This page contains the full source code of the nmaac/acon GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 19 files (108.5 KB, approximately 29.5k tokens) and a symbol index with 188 extracted functions, classes, methods, constants, and types. Use it with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — a free GitHub repo-to-text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!