[
  {
    "path": "ACON/ResNet_ACON/resnet_acon.py",
    "content": "import torch\nfrom torch import Tensor\nimport torch.nn as nn\nfrom typing import Type, Any, Callable, Union, List, Optional\n\nimport sys\nsys.path.insert(0,'../..')\nfrom acon import AconC\n\n\n__all__ = ['ResNet', 'resnet50_acon', 'resnet101_acon', 'resnet152_acon']\n\n\nmodel_urls = {}\n\n\ndef conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=True, dilation=dilation)\n\n\ndef conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=True)\n\nclass BasicBlock_ACON(nn.Module):\n    # We change the ReLU activation functions to ACON-C\n    # according to \"Activate or Not: Learning Customized Activation\" <https://arxiv.org/pdf/2009.04759.pdf>.\n    expansion: int = 1\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super(BasicBlock_ACON, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.acon1 = AconC(planes)\n        self.relu = 
nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.acon2 = AconC(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.acon1(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.acon2(out)\n\n        return out\n\nclass Bottleneck_ACON(nn.Module):\n    # We change the ReLU activation function after the 3x3 convolution(self.conv2) to ACON-C\n    # according to \"Activate or Not: Learning Customized Activation\" <https://arxiv.org/pdf/2009.04759.pdf>.\n\n    # We use the original implementation which places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n\n    expansion: int = 4\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super(Bottleneck_ACON, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width, stride)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, 1, groups, 
dilation)\n        self.bn2 = norm_layer(width)\n        self.acon = AconC(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.acon(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n\n    def __init__(\n        self,\n        block: Type[Union[BasicBlock_ACON, Bottleneck_ACON]],\n        layers: List[int],\n        num_classes: int = 1000,\n        zero_init_residual: bool = False,\n        groups: int = 1,\n        width_per_group: int = 64,\n        replace_stride_with_dilation: Optional[List[bool]] = None,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = 
width_per_group\n        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,\n                               bias=True)\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0])\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,\n                                       dilate=replace_stride_with_dilation[1])\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,\n                                       dilate=replace_stride_with_dilation[2])\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck_ACON):\n                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]\n                elif isinstance(m, BasicBlock_ACON):\n                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]\n\n    def _make_layer(self, block: Type[Union[BasicBlock_ACON, Bottleneck_ACON]], planes: int, blocks: int,\n            
        stride: int = 1, dilate: bool = False) -> nn.Sequential:\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer))\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x: Tensor) -> Tensor:\n        # See note [TorchScript super()]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        x = self.fc(x)\n\n        return x\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self._forward_impl(x)\n\n\ndef _resnet(\n    arch: str,\n    block: Type[Union[BasicBlock_ACON, Bottleneck_ACON]],\n    layers: List[int],\n    pretrained: bool,\n    progress: bool,\n    **kwargs: Any\n) -> ResNet:\n    model = ResNet(block, layers, **kwargs)\n    return model\n\n\ndef resnet50_acon(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-50-acon model from\n    `\"Activate or Not: Learning Customized 
Activation\" <https://arxiv.org/pdf/2009.04759.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet50_acon', Bottleneck_ACON, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\ndef resnet101_acon(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-101-acon model from\n    `\"Activate or Not: Learning Customized Activation\" <https://arxiv.org/pdf/2009.04759.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet101_acon', Bottleneck_ACON, [3, 4, 23, 3], pretrained, progress,\n                   **kwargs)\n\ndef resnet152_acon(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-152-acon model from\n    `\"Activate or Not: Learning Customized Activation\" <https://arxiv.org/pdf/2009.04759.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet152_acon', Bottleneck_ACON, [3, 8, 36, 3], pretrained, progress,\n                   **kwargs)\n\n"
  },
  {
    "path": "ACON/ResNet_ACON/train.py",
    "content": "import os\nimport sys\nimport torch\nimport argparse\nimport torch.nn as nn\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport cv2\nimport numpy as np\nimport PIL\nfrom PIL import Image\nimport time\nimport logging\nimport argparse\nfrom resnet_acon import resnet50_acon\nfrom utils import accuracy, AvgrageMeter, CrossEntropyLabelSmooth, save_checkpoint, get_lastest_model, get_parameters\n\nclass OpencvResize(object):\n\n    def __init__(self, size=256):\n        self.size = size\n\n    def __call__(self, img):\n        assert isinstance(img, PIL.Image.Image)\n        img = np.asarray(img) # (H,W,3) RGB\n        img = img[:,:,::-1] # 2 BGR\n        img = np.ascontiguousarray(img)\n        H, W, _ = img.shape\n        target_size = (int(self.size/H * W + 0.5), self.size) if H < W else (self.size, int(self.size/W * H + 0.5))\n        img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)\n        img = img[:,:,::-1] # 2 RGB\n        img = np.ascontiguousarray(img)\n        img = Image.fromarray(img)\n        return img\n\nclass ToBGRTensor(object):\n\n    def __call__(self, img):\n        assert isinstance(img, (np.ndarray, PIL.Image.Image))\n        if isinstance(img, PIL.Image.Image):\n            img = np.asarray(img)\n        img = img[:,:,::-1] # 2 BGR\n        img = np.transpose(img, [2, 0, 1]) # 2 (3, H, W)\n        img = np.ascontiguousarray(img)\n        img = torch.from_numpy(img).float()\n        return img\n\nclass DataIterator(object):\n\n    def __init__(self, dataloader):\n        self.dataloader = dataloader\n        self.iterator = enumerate(self.dataloader)\n\n    def next(self):\n        try:\n            _, data = next(self.iterator)\n        except Exception:\n            self.iterator = enumerate(self.dataloader)\n            _, data = next(self.iterator)\n        return data[0], data[1]\n\ndef get_args():\n    parser = argparse.ArgumentParser(\"ResNet\")\n    
parser.add_argument('--eval', default=False, action='store_true')\n    parser.add_argument('--eval-resume', type=str, default='./res50.acon.pth', help='path for eval model')\n    parser.add_argument('--batch-size', type=int, default=256, help='batch size')\n    parser.add_argument('--total-iters', type=int, default=600000, help='total iters')\n    parser.add_argument('--learning-rate', type=float, default=0.1, help='init learning rate')\n    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')\n    parser.add_argument('--weight-decay', type=float, default=1e-4, help='weight decay')\n    parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')\n\n    parser.add_argument('--auto-continue', type=bool, default=True, help='auto continue')\n    parser.add_argument('--display-interval', type=int, default=20, help='display interval')\n    parser.add_argument('--val-interval', type=int, default=50000, help='val interval')\n    parser.add_argument('--save-interval', type=int, default=50000, help='save interval')\n\n\n\n    parser.add_argument('--train-dir', type=str, default='data/train', help='path to training dataset')\n    parser.add_argument('--val-dir', type=str, default='data/val', help='path to validation dataset')\n\n    args = parser.parse_args()\n    return args\n\ndef main():\n    args = get_args()\n\n    # Log\n    log_format = '[%(asctime)s] %(message)s'\n    logging.basicConfig(stream=sys.stdout, level=logging.INFO,\n        format=log_format, datefmt='%d %I:%M:%S')\n    t = time.time()\n    local_time = time.localtime(t)\n    if not os.path.exists('./log'):\n        os.mkdir('./log')\n    fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))\n    fh.setFormatter(logging.Formatter(log_format))\n    logging.getLogger().addHandler(fh)\n\n    use_gpu = False\n    if torch.cuda.is_available():\n        use_gpu = True\n\n    assert 
os.path.exists(args.train_dir)\n    train_dataset = datasets.ImageFolder(\n        args.train_dir,\n        transforms.Compose([\n            transforms.RandomResizedCrop(224),\n            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),\n            transforms.RandomHorizontalFlip(0.5),\n            ToBGRTensor(),\n        ])\n    )\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.batch_size, shuffle=True,\n        num_workers=1, pin_memory=use_gpu)\n    train_dataprovider = DataIterator(train_loader)\n\n    assert os.path.exists(args.val_dir)\n    val_loader = torch.utils.data.DataLoader(\n        datasets.ImageFolder(args.val_dir, transforms.Compose([\n            OpencvResize(256),\n            transforms.CenterCrop(224),\n            ToBGRTensor(),\n        ])),\n        batch_size=200, shuffle=False,\n        num_workers=1, pin_memory=use_gpu\n    )\n    val_dataprovider = DataIterator(val_loader)\n    print('load data successfully')\n\n    model = resnet50_acon()\n    optimizer = torch.optim.SGD(get_parameters(model),\n                                lr=args.learning_rate,\n                                momentum=args.momentum,\n                                weight_decay=args.weight_decay)\n    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.0)\n\n    if use_gpu:\n        model = nn.DataParallel(model)\n        loss_function = criterion_smooth.cuda()\n        device = torch.device(\"cuda\")\n    else:\n        loss_function = criterion_smooth\n        device = torch.device(\"cpu\")\n\n    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,\n                    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)\n\n    model = model.to(device)\n\n    all_iters = 0\n    if args.auto_continue:\n        lastest_model, iters = get_lastest_model()\n        if lastest_model is not None:\n            all_iters = iters\n            checkpoint = 
torch.load(lastest_model, map_location=None if use_gpu else 'cpu')\n            model.load_state_dict(checkpoint['state_dict'], strict=True)\n            print('load from checkpoint')\n            for i in range(iters):\n                scheduler.step()\n\n    args.optimizer = optimizer\n    args.loss_function = loss_function\n    args.scheduler = scheduler\n    args.train_dataprovider = train_dataprovider\n    args.val_dataprovider = val_dataprovider\n\n    if args.eval:\n        if args.eval_resume is not None:\n            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')\n            load_checkpoint(model, checkpoint)\n            validate(model, device, args, all_iters=all_iters)\n        exit(0)\n\n    while all_iters < args.total_iters:\n        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)\n        validate(model, device, args, all_iters=all_iters)\n    validate(model, device, args, all_iters=all_iters)\n    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')\n\ndef adjust_bn_momentum(model, iters):\n    for m in model.modules():\n        if isinstance(m, nn.BatchNorm2d):\n            m.momentum = 1 / iters\n\ndef train(model, device, args, *, val_interval, bn_process=False, all_iters=None):\n\n    optimizer = args.optimizer\n    loss_function = args.loss_function\n    scheduler = args.scheduler\n    train_dataprovider = args.train_dataprovider\n\n    t1 = time.time()\n    Top1_err, Top5_err = 0.0, 0.0\n    model.train()\n    for iters in range(1, val_interval + 1):\n        scheduler.step()\n        if bn_process:\n            adjust_bn_momentum(model, iters)\n\n        all_iters += 1\n        d_st = time.time()\n        data, target = train_dataprovider.next()\n        target = target.type(torch.LongTensor)\n        data, target = data.to(device), target.to(device)\n        data_time = time.time() - d_st\n\n        output = 
model(data)\n        loss = loss_function(output, target)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        prec1, prec5 = accuracy(output, target, topk=(1, 5))\n\n        Top1_err += 1 - prec1.item() / 100\n        Top5_err += 1 - prec5.item() / 100\n\n        if all_iters % args.display_interval == 0:\n            printInfo = 'TRAIN Iter {}: lr = {:.6f},\\tloss = {:.6f},\\t'.format(all_iters, scheduler.get_lr()[0], loss.item()) + \\\n                        'Top-1 err = {:.6f},\\t'.format(Top1_err / args.display_interval) + \\\n                        'Top-5 err = {:.6f},\\t'.format(Top5_err / args.display_interval) + \\\n                        'data_time = {:.6f},\\ttrain_time = {:.6f}'.format(data_time, (time.time() - t1) / args.display_interval)\n            logging.info(printInfo)\n            t1 = time.time()\n            Top1_err, Top5_err = 0.0, 0.0\n\n        if all_iters % args.save_interval == 0:\n            save_checkpoint({\n                'state_dict': model.state_dict(),\n                }, all_iters)\n\n    return all_iters\n\ndef validate(model, device, args, *, all_iters=None):\n    objs = AvgrageMeter()\n    top1 = AvgrageMeter()\n    top5 = AvgrageMeter()\n\n    loss_function = args.loss_function\n    val_dataprovider = args.val_dataprovider\n\n    model.eval()\n    max_val_iters = 250\n    t1  = time.time()\n    with torch.no_grad():\n        for _ in range(1, max_val_iters + 1):\n            data, target = val_dataprovider.next()\n            target = target.type(torch.LongTensor)\n            data, target = data.to(device), target.to(device)\n\n            output = model(data)\n            loss = loss_function(output, target)\n\n            prec1, prec5 = accuracy(output, target, topk=(1, 5))\n            n = data.size(0)\n            objs.update(loss.item(), n)\n            top1.update(prec1.item(), n)\n            top5.update(prec5.item(), n)\n\n    logInfo = 'TEST Iter {}: loss = 
{:.6f},\\t'.format(all_iters, objs.avg) + \\\n              'Top-1 err = {:.6f},\\t'.format(1 - top1.avg / 100) + \\\n              'Top-5 err = {:.6f},\\t'.format(1 - top5.avg / 100) + \\\n              'val_time = {:.6f}'.format(time.time() - t1)\n    logging.info(logInfo)\n\ndef load_checkpoint(net, checkpoint):\n    from collections import OrderedDict\n\n    temp = OrderedDict()\n    if 'state_dict' in checkpoint:\n        checkpoint = dict(checkpoint['state_dict'])\n    for k in checkpoint:\n        k2 = 'module.'+k if not k.startswith('module.') else k\n        temp[k2] = checkpoint[k]\n\n    net.load_state_dict(temp, strict=True)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "ACON/ResNet_ACON/utils.py",
    "content": "import os\nimport re\nimport torch\nimport torch.nn as nn\n\nclass CrossEntropyLabelSmooth(nn.Module):\n\n\tdef __init__(self, num_classes, epsilon):\n\t\tsuper(CrossEntropyLabelSmooth, self).__init__()\n\t\tself.num_classes = num_classes\n\t\tself.epsilon = epsilon\n\t\tself.logsoftmax = nn.LogSoftmax(dim=1)\n\n\tdef forward(self, inputs, targets):\n\t\tlog_probs = self.logsoftmax(inputs)\n\t\ttargets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)\n\t\ttargets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes\n\t\tloss = (-targets * log_probs).mean(0).sum()\n\t\treturn loss\n\n\nclass AvgrageMeter(object):\n\n\tdef __init__(self):\n\t\tself.reset()\n\n\tdef reset(self):\n\t\tself.avg = 0\n\t\tself.sum = 0\n\t\tself.cnt = 0\n\t\tself.val = 0\n\n\tdef update(self, val, n=1):\n\t\tself.val = val\n\t\tself.sum += val * n\n\t\tself.cnt += n\n\t\tself.avg = self.sum / self.cnt\n\n\ndef accuracy(output, target, topk=(1,)):\n\tmaxk = max(topk)\n\tbatch_size = target.size(0)\n\n\t_, pred = output.topk(maxk, 1, True, True)\n\tpred = pred.t()\n\tcorrect = pred.eq(target.view(1, -1).expand_as(pred))\n\n\tres = []\n\tfor k in topk:\n\t\tcorrect_k = correct[:k].reshape(-1).float().sum(0)\n\t\tres.append(correct_k.mul_(100.0/batch_size))\n\treturn res\n\n\ndef save_checkpoint(state, iters, tag=''):\n\tif not os.path.exists(\"./models\"):\n\t\tos.makedirs(\"./models\")\n\tfilename = os.path.join(\"./models/{}checkpoint-{:06}.pth.tar\".format(tag, iters))\n\ttorch.save(state, filename)\n\ndef get_lastest_model():\n\tif not os.path.exists('./models'):\n\t\tos.mkdir('./models')\n\tmodel_list = os.listdir('./models/')\n\tif model_list == []:\n\t\treturn None, 0\n\tmodel_list.sort()\n\tlastest_model = model_list[-1]\n\titers = re.findall(r'\\d+', lastest_model)\n\treturn './models/' + lastest_model, int(iters[0])\n\n\ndef get_parameters(model):\n\tgroup_no_weight_decay = []\n\tgroup_weight_decay = []\n\tfor pname, p in 
model.named_parameters():\n\t\tif pname.find('weight') >= 0 and len(p.size()) > 1:\n\t\t\t# print('include ', pname, p.size())\n\t\t\tgroup_weight_decay.append(p)\n\t\telse:\n\t\t\t# print('not include ', pname, p.size())\n\t\t\tgroup_no_weight_decay.append(p)\n\tassert len(list(model.parameters())) == len(group_weight_decay) + len(group_no_weight_decay)\n\tgroups = [dict(params=group_weight_decay), dict(params=group_no_weight_decay, weight_decay=0.)]\n\treturn groups\n"
  },
  {
    "path": "ACON/ShuffleNetV2_ACON/network.py",
    "content": "import torch\nimport torch.nn as nn\n\nimport sys\nsys.path.insert(0,'../..')\nfrom acon import AconC\n\nclass ShuffleV2Block_ACON(nn.Module):\n    def __init__(self, inp, oup, mid_channels, *, ksize, stride):\n        super(ShuffleV2Block_ACON, self).__init__()\n        self.stride = stride\n        assert stride in [1, 2]\n\n        self.mid_channels = mid_channels\n        self.ksize = ksize\n        pad = ksize // 2\n        self.pad = pad\n        self.inp = inp\n\n        outputs = oup - inp\n\n        branch_main = [\n            # pw\n            nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=True),\n            nn.BatchNorm2d(mid_channels),\n            AconC(mid_channels),\n            # dw\n            nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=True),\n            nn.BatchNorm2d(mid_channels),\n            # pw-linear\n            nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=True),\n            nn.BatchNorm2d(outputs),\n            AconC(outputs),\n        ]\n        self.branch_main = nn.Sequential(*branch_main)\n\n        if stride == 2:\n            branch_proj = [\n                # dw\n                nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=True),\n                nn.BatchNorm2d(inp),\n                # pw-linear\n                nn.Conv2d(inp, inp, 1, 1, 0, bias=True),\n                nn.BatchNorm2d(inp),\n                AconC(inp),\n            ]\n            self.branch_proj = nn.Sequential(*branch_proj)\n        else:\n            self.branch_proj = None\n\n    def forward(self, old_x):\n        if self.stride==1:\n            x_proj, x = self.channel_shuffle(old_x)\n            return torch.cat((x_proj, self.branch_main(x)), 1)\n        elif self.stride==2:\n            x_proj = old_x\n            x = old_x\n            return torch.cat((self.branch_proj(x_proj), self.branch_main(x)), 1)\n\n    def channel_shuffle(self, x):\n        batchsize, num_channels, height, width 
= x.data.size()\n        assert (num_channels % 4 == 0)\n        x = x.reshape(batchsize * num_channels // 2, 2, height * width)\n        x = x.permute(1, 0, 2)\n        x = x.reshape(2, -1, num_channels // 2, height, width)\n        return x[0], x[1]\n\n\nclass ShuffleNetV2_ACON(nn.Module):\n    def __init__(self, input_size=224, n_class=1000, model_size='1.5x'):\n        super(ShuffleNetV2_ACON, self).__init__()\n        print('model size is ', model_size)\n\n        self.stage_repeats = [4, 8, 4]\n        self.model_size = model_size\n        if model_size == '0.5x':\n            self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]\n        elif model_size == '1.0x':\n            self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]\n        elif model_size == '1.5x':\n            self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]\n        elif model_size == '2.0x':\n            self.stage_out_channels = [-1, 24, 244, 488, 976, 2048]\n        else:\n            raise NotImplementedError\n\n        # building first layer\n        input_channel = self.stage_out_channels[1]\n        self.first_conv = nn.Sequential(\n            nn.Conv2d(3, input_channel, 3, 2, 1, bias=True),\n            nn.BatchNorm2d(input_channel),\n            AconC(input_channel),\n        )\n\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\n        self.features = []\n        for idxstage in range(len(self.stage_repeats)):\n            numrepeat = self.stage_repeats[idxstage]\n            output_channel = self.stage_out_channels[idxstage+2]\n\n            for i in range(numrepeat):\n                if i == 0:\n                    self.features.append(ShuffleV2Block_ACON(input_channel, output_channel,\n                                                mid_channels=output_channel // 2, ksize=3, stride=2))\n                else:\n                    self.features.append(ShuffleV2Block_ACON(input_channel // 2, output_channel,\n                                     
           mid_channels=output_channel // 2, ksize=3, stride=1))\n\n                input_channel = output_channel\n\n        self.features = nn.Sequential(*self.features)\n\n        self.conv_last = nn.Sequential(\n            nn.Conv2d(input_channel, self.stage_out_channels[-1], 1, 1, 0, bias=True),\n            nn.BatchNorm2d(self.stage_out_channels[-1]),\n            AconC(self.stage_out_channels[-1]),\n        )\n        self.globalpool = nn.AvgPool2d(7)\n        if self.model_size == '2.0x':\n            self.dropout = nn.Dropout(0.2)\n        self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class, bias=True))\n        self._initialize_weights()\n\n    def forward(self, x):\n        x = self.first_conv(x)\n        x = self.maxpool(x)\n        x = self.features(x)\n        x = self.conv_last(x)\n\n        x = self.globalpool(x)\n        if self.model_size == '2.0x':\n            x = self.dropout(x)\n        x = x.contiguous().view(-1, self.stage_out_channels[-1])\n        x = self.classifier(x)\n        return x\n\n    def _initialize_weights(self):\n        for name, m in self.named_modules():\n            if isinstance(m, nn.Conv2d):\n                if 'first' in name:\n                    nn.init.normal_(m.weight, 0, 0.01)\n                else:\n                    nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0.0001)\n                nn.init.constant_(m.running_mean, 0)\n            elif isinstance(m, nn.BatchNorm1d):\n                nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0.0001)\n                nn.init.constant_(m.running_mean, 0)\n            elif isinstance(m, 
nn.Linear):\n                nn.init.normal_(m.weight, 0, 0.01)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n"
  },
  {
    "path": "ACON/ShuffleNetV2_ACON/train.py",
    "content": "import os\nimport sys\nimport torch\nimport argparse\nimport torch.nn as nn\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport cv2\nimport numpy as np\nimport PIL\nfrom PIL import Image\nimport time\nimport logging\nimport argparse\nfrom network import ShuffleNetV2_ACON\nfrom utils import accuracy, AvgrageMeter, CrossEntropyLabelSmooth, save_checkpoint, get_lastest_model, get_parameters\n\nclass OpencvResize(object):\n\n    def __init__(self, size=256):\n        self.size = size\n\n    def __call__(self, img):\n        assert isinstance(img, PIL.Image.Image)\n        img = np.asarray(img) # (H,W,3) RGB\n        img = img[:,:,::-1] # 2 BGR\n        img = np.ascontiguousarray(img)\n        H, W, _ = img.shape\n        target_size = (int(self.size/H * W + 0.5), self.size) if H < W else (self.size, int(self.size/W * H + 0.5))\n        img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)\n        img = img[:,:,::-1] # 2 RGB\n        img = np.ascontiguousarray(img)\n        img = Image.fromarray(img)\n        return img\n\nclass ToBGRTensor(object):\n\n    def __call__(self, img):\n        assert isinstance(img, (np.ndarray, PIL.Image.Image))\n        if isinstance(img, PIL.Image.Image):\n            img = np.asarray(img)\n        img = img[:,:,::-1] # 2 BGR\n        img = np.transpose(img, [2, 0, 1]) # 2 (3, H, W)\n        img = np.ascontiguousarray(img)\n        img = torch.from_numpy(img).float()\n        return img\n\nclass DataIterator(object):\n\n    def __init__(self, dataloader):\n        self.dataloader = dataloader\n        self.iterator = enumerate(self.dataloader)\n\n    def next(self):\n        try:\n            _, data = next(self.iterator)\n        except Exception:\n            self.iterator = enumerate(self.dataloader)\n            _, data = next(self.iterator)\n        return data[0], data[1]\n\ndef get_args():\n    parser = argparse.ArgumentParser(\"ShuffleNetV2_ACON\")\n    
parser.add_argument('--eval', default=False, action='store_true')\n    parser.add_argument('--eval-resume', type=str, default='./shufflenetv2.0.5.acon.pth', help='path for eval model')\n    parser.add_argument('--batch-size', type=int, default=1024, help='batch size')\n    parser.add_argument('--total-iters', type=int, default=300000, help='total iters')\n    parser.add_argument('--learning-rate', type=float, default=0.5, help='init learning rate')\n    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')\n    parser.add_argument('--weight-decay', type=float, default=4e-5, help='weight decay')\n    parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')\n    parser.add_argument('--label-smooth', type=float, default=0.1, help='label smoothing')\n\n    parser.add_argument('--auto-continue', type=bool, default=True, help='auto continue')\n    parser.add_argument('--display-interval', type=int, default=20, help='display interval')\n    parser.add_argument('--val-interval', type=int, default=10000, help='val interval')\n    parser.add_argument('--save-interval', type=int, default=10000, help='save interval')\n\n\n    parser.add_argument('--model-size', type=str, default='0.5x', choices=['0.5x', '1.0x', '1.5x', '2.0x'], help='size of the model')\n\n    parser.add_argument('--train-dir', type=str, default='data/train', help='path to training dataset')\n    parser.add_argument('--val-dir', type=str, default='data/val', help='path to validation dataset')\n\n    args = parser.parse_args()\n    return args\n\ndef main():\n    args = get_args()\n\n    # Log\n    log_format = '[%(asctime)s] %(message)s'\n    logging.basicConfig(stream=sys.stdout, level=logging.INFO,\n        format=log_format, datefmt='%d %I:%M:%S')\n    t = time.time()\n    local_time = time.localtime(t)\n    if not os.path.exists('./log'):\n        os.mkdir('./log')\n    fh = 
logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))\n    fh.setFormatter(logging.Formatter(log_format))\n    logging.getLogger().addHandler(fh)\n\n    use_gpu = False\n    if torch.cuda.is_available():\n        use_gpu = True\n\n    assert os.path.exists(args.train_dir)\n    train_dataset = datasets.ImageFolder(\n        args.train_dir,\n        transforms.Compose([\n            transforms.RandomResizedCrop(224),\n            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),\n            transforms.RandomHorizontalFlip(0.5),\n            ToBGRTensor(),\n        ])\n    )\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.batch_size, shuffle=True,\n        num_workers=1, pin_memory=use_gpu)\n    train_dataprovider = DataIterator(train_loader)\n\n    assert os.path.exists(args.val_dir)\n    val_loader = torch.utils.data.DataLoader(\n        datasets.ImageFolder(args.val_dir, transforms.Compose([\n            OpencvResize(256),\n            transforms.CenterCrop(224),\n            ToBGRTensor(),\n        ])),\n        batch_size=200, shuffle=False,\n        num_workers=1, pin_memory=use_gpu\n    )\n    val_dataprovider = DataIterator(val_loader)\n    print('load data successfully')\n\n    model = ShuffleNetV2_ACON(model_size=args.model_size)\n\n    optimizer = torch.optim.SGD(get_parameters(model),\n                                lr=args.learning_rate,\n                                momentum=args.momentum,\n                                weight_decay=args.weight_decay)\n    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)\n\n    if use_gpu:\n        model = nn.DataParallel(model)\n        loss_function = criterion_smooth.cuda()\n        device = torch.device(\"cuda\")\n    else:\n        loss_function = criterion_smooth\n        device = torch.device(\"cpu\")\n\n    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,\n                    
lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)\n\n    model = model.to(device)\n\n    all_iters = 0\n    if args.auto_continue:\n        lastest_model, iters = get_lastest_model()\n        if lastest_model is not None:\n            all_iters = iters\n            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')\n            model.load_state_dict(checkpoint['state_dict'], strict=True)\n            print('load from checkpoint')\n            for i in range(iters):\n                scheduler.step()\n\n    args.optimizer = optimizer\n    args.loss_function = loss_function\n    args.scheduler = scheduler\n    args.train_dataprovider = train_dataprovider\n    args.val_dataprovider = val_dataprovider\n\n    if args.eval:\n        if args.eval_resume is not None:\n            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')\n            load_checkpoint(model, checkpoint)\n            validate(model, device, args, all_iters=all_iters)\n        exit(0)\n\n    while all_iters < args.total_iters:\n        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)\n        validate(model, device, args, all_iters=all_iters)\n    validate(model, device, args, all_iters=all_iters)\n    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')\n\ndef adjust_bn_momentum(model, iters):\n    for m in model.modules():\n        if isinstance(m, nn.BatchNorm2d):\n            m.momentum = 1 / iters\n\ndef train(model, device, args, *, val_interval, bn_process=False, all_iters=None):\n\n    optimizer = args.optimizer\n    loss_function = args.loss_function\n    scheduler = args.scheduler\n    train_dataprovider = args.train_dataprovider\n\n    t1 = time.time()\n    Top1_err, Top5_err = 0.0, 0.0\n    model.train()\n    for iters in range(1, val_interval + 1):\n        scheduler.step()\n        if 
bn_process:\n            adjust_bn_momentum(model, iters)\n\n        all_iters += 1\n        d_st = time.time()\n        data, target = train_dataprovider.next()\n        target = target.type(torch.LongTensor)\n        data, target = data.to(device), target.to(device)\n        data_time = time.time() - d_st\n\n        output = model(data)\n        loss = loss_function(output, target)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        prec1, prec5 = accuracy(output, target, topk=(1, 5))\n\n        Top1_err += 1 - prec1.item() / 100\n        Top5_err += 1 - prec5.item() / 100\n\n        if all_iters % args.display_interval == 0:\n            printInfo = 'TRAIN Iter {}: lr = {:.6f},\\tloss = {:.6f},\\t'.format(all_iters, scheduler.get_lr()[0], loss.item()) + \\\n                        'Top-1 err = {:.6f},\\t'.format(Top1_err / args.display_interval) + \\\n                        'Top-5 err = {:.6f},\\t'.format(Top5_err / args.display_interval) + \\\n                        'data_time = {:.6f},\\ttrain_time = {:.6f}'.format(data_time, (time.time() - t1) / args.display_interval)\n            logging.info(printInfo)\n            t1 = time.time()\n            Top1_err, Top5_err = 0.0, 0.0\n\n        if all_iters % args.save_interval == 0:\n            save_checkpoint({\n                'state_dict': model.state_dict(),\n                }, all_iters)\n\n    return all_iters\n\ndef validate(model, device, args, *, all_iters=None):\n    objs = AvgrageMeter()\n    top1 = AvgrageMeter()\n    top5 = AvgrageMeter()\n\n    loss_function = args.loss_function\n    val_dataprovider = args.val_dataprovider\n\n    model.eval()\n    max_val_iters = 250\n    t1  = time.time()\n    with torch.no_grad():\n        for _ in range(1, max_val_iters + 1):\n            data, target = val_dataprovider.next()\n            target = target.type(torch.LongTensor)\n            data, target = data.to(device), target.to(device)\n\n            output = 
model(data)\n            loss = loss_function(output, target)\n\n            prec1, prec5 = accuracy(output, target, topk=(1, 5))\n            n = data.size(0)\n            objs.update(loss.item(), n)\n            top1.update(prec1.item(), n)\n            top5.update(prec5.item(), n)\n\n    logInfo = 'TEST Iter {}: loss = {:.6f},\\t'.format(all_iters, objs.avg) + \\\n              'Top-1 err = {:.6f},\\t'.format(1 - top1.avg / 100) + \\\n              'Top-5 err = {:.6f},\\t'.format(1 - top5.avg / 100) + \\\n              'val_time = {:.6f}'.format(time.time() - t1)\n    logging.info(logInfo)\n\ndef load_checkpoint(net, checkpoint):\n    from collections import OrderedDict\n\n    temp = OrderedDict()\n    if 'state_dict' in checkpoint:\n        checkpoint = dict(checkpoint['state_dict'])\n    for k in checkpoint:\n        k2 = 'module.'+k if not k.startswith('module.') else k\n        temp[k2] = checkpoint[k]\n\n    net.load_state_dict(temp, strict=True)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "ACON/ShuffleNetV2_ACON/utils.py",
    "content": "import os\nimport re\nimport torch\nimport torch.nn as nn\n\nclass CrossEntropyLabelSmooth(nn.Module):\n\n\tdef __init__(self, num_classes, epsilon):\n\t\tsuper(CrossEntropyLabelSmooth, self).__init__()\n\t\tself.num_classes = num_classes\n\t\tself.epsilon = epsilon\n\t\tself.logsoftmax = nn.LogSoftmax(dim=1)\n\n\tdef forward(self, inputs, targets):\n\t\tlog_probs = self.logsoftmax(inputs)\n\t\ttargets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)\n\t\ttargets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes\n\t\tloss = (-targets * log_probs).mean(0).sum()\n\t\treturn loss\n\n\nclass AvgrageMeter(object):\n\n\tdef __init__(self):\n\t\tself.reset()\n\n\tdef reset(self):\n\t\tself.avg = 0\n\t\tself.sum = 0\n\t\tself.cnt = 0\n\t\tself.val = 0\n\n\tdef update(self, val, n=1):\n\t\tself.val = val\n\t\tself.sum += val * n\n\t\tself.cnt += n\n\t\tself.avg = self.sum / self.cnt\n\n\ndef accuracy(output, target, topk=(1,)):\n\tmaxk = max(topk)\n\tbatch_size = target.size(0)\n\n\t_, pred = output.topk(maxk, 1, True, True)\n\tpred = pred.t()\n\tcorrect = pred.eq(target.view(1, -1).expand_as(pred))\n\n\tres = []\n\tfor k in topk:\n\t\tcorrect_k = correct[:k].reshape(-1).float().sum(0)\n\t\tres.append(correct_k.mul_(100.0/batch_size))\n\treturn res\n\n\ndef save_checkpoint(state, iters, tag=''):\n\tif not os.path.exists(\"./models\"):\n\t\tos.makedirs(\"./models\")\n\tfilename = os.path.join(\"./models/{}checkpoint-{:06}.pth.tar\".format(tag, iters))\n\ttorch.save(state, filename)\n\ndef get_lastest_model():\n\tif not os.path.exists('./models'):\n\t\tos.mkdir('./models')\n\tmodel_list = os.listdir('./models/')\n\tif model_list == []:\n\t\treturn None, 0\n\tmodel_list.sort()\n\tlastest_model = model_list[-1]\n\titers = re.findall(r'\\d+', lastest_model)\n\treturn './models/' + lastest_model, int(iters[0])\n\n\ndef get_parameters(model):\n\tgroup_no_weight_decay = []\n\tgroup_weight_decay = []\n\tfor pname, p in 
model.named_parameters():\n\t\tif pname.find('weight') >= 0 and len(p.size()) > 1:\n\t\t\t# print('include ', pname, p.size())\n\t\t\tgroup_weight_decay.append(p)\n\t\telse:\n\t\t\t# print('not include ', pname, p.size())\n\t\t\tgroup_no_weight_decay.append(p)\n\tassert len(list(model.parameters())) == len(group_weight_decay) + len(group_no_weight_decay)\n\tgroups = [dict(params=group_weight_decay), dict(params=group_no_weight_decay, weight_decay=0.)]\n\treturn groups\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2021 nmaac \n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "MetaACON/ResNet_MetaACON/resnet_metaacon.py",
    "content": "import torch\nfrom torch import Tensor\nimport torch.nn as nn\nfrom typing import Type, Any, Callable, Union, List, Optional\n\nimport sys\nsys.path.insert(0,'../..')\nfrom acon import MetaAconC\n\n\n__all__ = ['ResNet', 'resnet50_metaacon', 'resnet101_metaacon', 'resnet152_metaacon']\n\n\nmodel_urls = {}\n\n\ndef conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=True, dilation=dilation)\n\n\ndef conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=True)\n\n\nclass Bottleneck_MetaACON(nn.Module):\n    # We change the ReLU activation function after the 3x3 convolution(self.conv2) to Meta-ACON\n    # according to \"Activate or Not: Learning Customized Activation\" <https://arxiv.org/pdf/2009.04759.pdf>.\n\n    # Unlike torchvision's Bottleneck, which places the stride for downsampling at the 3x3 convolution(self.conv2),\n    # we use the original implementation which places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\" https://arxiv.org/abs/1512.03385.\n    # (The torchvision variant is known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.)\n\n    expansion: int = 4\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super(Bottleneck_MetaACON, self).__init__()\n        if norm_layer is 
None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.)) * groups\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width, stride)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, 1, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.acon = MetaAconC(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.acon(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\nclass ResNet(nn.Module):\n\n    def __init__(\n        self,\n        block: Type[Union[Bottleneck_MetaACON]],\n        layers: List[int],\n        num_classes: int = 1000,\n        zero_init_residual: bool = False,\n        groups: int = 1,\n        width_per_group: int = 64,\n        replace_stride_with_dilation: Optional[List[bool]] = None,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            
replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,\n                               bias=True)\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0])\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,\n                                       dilate=replace_stride_with_dilation[1])\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,\n                                       dilate=replace_stride_with_dilation[2])\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, 
Bottleneck_MetaACON):\n                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]\n\n    def _make_layer(self, block: Type[Union[Bottleneck_MetaACON]], planes: int, blocks: int,\n                    stride: int = 1, dilate: bool = False) -> nn.Sequential:\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer))\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x: Tensor) -> Tensor:\n        # See note [TorchScript super()]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        x = self.fc(x)\n\n        return x\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self._forward_impl(x)\n\n\ndef _resnet(\n    arch: str,\n    block: Type[Union[Bottleneck_MetaACON]],\n    layers: List[int],\n    pretrained: bool,\n    progress: bool,\n    **kwargs: Any\n) -> ResNet:\n    model = ResNet(block, layers, **kwargs)\n    
return model\n\n\ndef resnet50_metaacon(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-50-meta-acon model from\n    `\"Activate or Not: Learning Customized Activation\" <https://arxiv.org/pdf/2009.04759.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet50_metaacon', Bottleneck_MetaACON, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\ndef resnet101_metaacon(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-101-meta-acon model from\n    `\"Activate or Not: Learning Customized Activation\" <https://arxiv.org/pdf/2009.04759.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet101_metaacon', Bottleneck_MetaACON, [3, 4, 23, 3], pretrained, progress,\n                   **kwargs)\n\ndef resnet152_metaacon(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-152-meta-acon model from\n    `\"Activate or Not: Learning Customized Activation\" <https://arxiv.org/pdf/2009.04759.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet152_metaacon', Bottleneck_MetaACON, [3, 8, 36, 3], pretrained, progress,\n                   **kwargs)\n\n"
  },
  {
    "path": "MetaACON/ResNet_MetaACON/train.py",
    "content": "import os\nimport sys\nimport torch\nimport argparse\nimport torch.nn as nn\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport cv2\nimport numpy as np\nimport PIL\nfrom PIL import Image\nimport time\nimport logging\nimport argparse\nfrom resnet_metaacon import resnet50_metaacon\nfrom utils import accuracy, AvgrageMeter, CrossEntropyLabelSmooth, save_checkpoint, get_lastest_model, get_parameters\n\nclass OpencvResize(object):\n\n    def __init__(self, size=256):\n        self.size = size\n\n    def __call__(self, img):\n        assert isinstance(img, PIL.Image.Image)\n        img = np.asarray(img) # (H,W,3) RGB\n        img = img[:,:,::-1] # 2 BGR\n        img = np.ascontiguousarray(img)\n        H, W, _ = img.shape\n        target_size = (int(self.size/H * W + 0.5), self.size) if H < W else (self.size, int(self.size/W * H + 0.5))\n        img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)\n        img = img[:,:,::-1] # 2 RGB\n        img = np.ascontiguousarray(img)\n        img = Image.fromarray(img)\n        return img\n\nclass ToBGRTensor(object):\n\n    def __call__(self, img):\n        assert isinstance(img, (np.ndarray, PIL.Image.Image))\n        if isinstance(img, PIL.Image.Image):\n            img = np.asarray(img)\n        img = img[:,:,::-1] # 2 BGR\n        img = np.transpose(img, [2, 0, 1]) # 2 (3, H, W)\n        img = np.ascontiguousarray(img)\n        img = torch.from_numpy(img).float()\n        return img\n\nclass DataIterator(object):\n\n    def __init__(self, dataloader):\n        self.dataloader = dataloader\n        self.iterator = enumerate(self.dataloader)\n\n    def next(self):\n        try:\n            _, data = next(self.iterator)\n        except Exception:\n            self.iterator = enumerate(self.dataloader)\n            _, data = next(self.iterator)\n        return data[0], data[1]\n\ndef get_args():\n    parser = argparse.ArgumentParser(\"ResNet\")\n    
parser.add_argument('--eval', default=False, action='store_true')\n    parser.add_argument('--eval-resume', type=str, default='./res50.metaacon.pth', help='path for eval model')\n    parser.add_argument('--batch-size', type=int, default=256, help='batch size')\n    parser.add_argument('--total-iters', type=int, default=600000, help='total iters')\n    parser.add_argument('--learning-rate', type=float, default=0.1, help='init learning rate')\n    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')\n    parser.add_argument('--weight-decay', type=float, default=1e-4, help='weight decay')\n    parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')\n\n    parser.add_argument('--auto-continue', type=bool, default=True, help='auto continue')\n    parser.add_argument('--display-interval', type=int, default=20, help='display interval')\n    parser.add_argument('--val-interval', type=int, default=50000, help='val interval')\n    parser.add_argument('--save-interval', type=int, default=50000, help='save interval')\n\n\n\n    parser.add_argument('--train-dir', type=str, default='data/train', help='path to training dataset')\n    parser.add_argument('--val-dir', type=str, default='data/val', help='path to validation dataset')\n\n    args = parser.parse_args()\n    return args\n\ndef main():\n    args = get_args()\n\n    # Log\n    log_format = '[%(asctime)s] %(message)s'\n    logging.basicConfig(stream=sys.stdout, level=logging.INFO,\n        format=log_format, datefmt='%d %I:%M:%S')\n    t = time.time()\n    local_time = time.localtime(t)\n    if not os.path.exists('./log'):\n        os.mkdir('./log')\n    fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))\n    fh.setFormatter(logging.Formatter(log_format))\n    logging.getLogger().addHandler(fh)\n\n    use_gpu = False\n    if torch.cuda.is_available():\n        use_gpu = True\n\n    assert 
os.path.exists(args.train_dir)\n    train_dataset = datasets.ImageFolder(\n        args.train_dir,\n        transforms.Compose([\n            transforms.RandomResizedCrop(224),\n            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),\n            transforms.RandomHorizontalFlip(0.5),\n            ToBGRTensor(),\n        ])\n    )\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.batch_size, shuffle=True,\n        num_workers=1, pin_memory=use_gpu)\n    train_dataprovider = DataIterator(train_loader)\n\n    assert os.path.exists(args.val_dir)\n    val_loader = torch.utils.data.DataLoader(\n        datasets.ImageFolder(args.val_dir, transforms.Compose([\n            OpencvResize(256),\n            transforms.CenterCrop(224),\n            ToBGRTensor(),\n        ])),\n        batch_size=200, shuffle=False,\n        num_workers=1, pin_memory=use_gpu\n    )\n    val_dataprovider = DataIterator(val_loader)\n    print('load data successfully')\n\n    model = resnet50_metaacon()\n    optimizer = torch.optim.SGD(get_parameters(model),\n                                lr=args.learning_rate,\n                                momentum=args.momentum,\n                                weight_decay=args.weight_decay)\n    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.0)\n\n    if use_gpu:\n        model = nn.DataParallel(model)\n        loss_function = criterion_smooth.cuda()\n        device = torch.device(\"cuda\")\n    else:\n        loss_function = criterion_smooth\n        device = torch.device(\"cpu\")\n\n    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,\n                    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)\n\n    model = model.to(device)\n\n    all_iters = 0\n    if args.auto_continue:\n        lastest_model, iters = get_lastest_model()\n        if lastest_model is not None:\n            all_iters = iters\n            checkpoint = 
torch.load(lastest_model, map_location=None if use_gpu else 'cpu')\n            model.load_state_dict(checkpoint['state_dict'], strict=True)\n            print('load from checkpoint')\n            for i in range(iters):\n                scheduler.step()\n\n    args.optimizer = optimizer\n    args.loss_function = loss_function\n    args.scheduler = scheduler\n    args.train_dataprovider = train_dataprovider\n    args.val_dataprovider = val_dataprovider\n\n    if args.eval:\n        if args.eval_resume is not None:\n            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')\n            load_checkpoint(model, checkpoint)\n            validate(model, device, args, all_iters=all_iters)\n        exit(0)\n\n    while all_iters < args.total_iters:\n        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)\n        validate(model, device, args, all_iters=all_iters)\n    validate(model, device, args, all_iters=all_iters)\n    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')\n\ndef adjust_bn_momentum(model, iters):\n    for m in model.modules():\n        if isinstance(m, nn.BatchNorm2d):\n            m.momentum = 1 / iters\n\ndef train(model, device, args, *, val_interval, bn_process=False, all_iters=None):\n\n    optimizer = args.optimizer\n    loss_function = args.loss_function\n    scheduler = args.scheduler\n    train_dataprovider = args.train_dataprovider\n\n    t1 = time.time()\n    Top1_err, Top5_err = 0.0, 0.0\n    model.train()\n    for iters in range(1, val_interval + 1):\n        scheduler.step()\n        if bn_process:\n            adjust_bn_momentum(model, iters)\n\n        all_iters += 1\n        d_st = time.time()\n        data, target = train_dataprovider.next()\n        target = target.type(torch.LongTensor)\n        data, target = data.to(device), target.to(device)\n        data_time = time.time() - d_st\n\n        output = 
model(data)\n        loss = loss_function(output, target)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        prec1, prec5 = accuracy(output, target, topk=(1, 5))\n\n        Top1_err += 1 - prec1.item() / 100\n        Top5_err += 1 - prec5.item() / 100\n\n        if all_iters % args.display_interval == 0:\n            printInfo = 'TRAIN Iter {}: lr = {:.6f},\\tloss = {:.6f},\\t'.format(all_iters, scheduler.get_lr()[0], loss.item()) + \\\n                        'Top-1 err = {:.6f},\\t'.format(Top1_err / args.display_interval) + \\\n                        'Top-5 err = {:.6f},\\t'.format(Top5_err / args.display_interval) + \\\n                        'data_time = {:.6f},\\ttrain_time = {:.6f}'.format(data_time, (time.time() - t1) / args.display_interval)\n            logging.info(printInfo)\n            t1 = time.time()\n            Top1_err, Top5_err = 0.0, 0.0\n\n        if all_iters % args.save_interval == 0:\n            save_checkpoint({\n                'state_dict': model.state_dict(),\n                }, all_iters)\n\n    return all_iters\n\ndef validate(model, device, args, *, all_iters=None):\n    objs = AvgrageMeter()\n    top1 = AvgrageMeter()\n    top5 = AvgrageMeter()\n\n    loss_function = args.loss_function\n    val_dataprovider = args.val_dataprovider\n\n    model.eval()\n    max_val_iters = 250\n    t1  = time.time()\n    with torch.no_grad():\n        for _ in range(1, max_val_iters + 1):\n            data, target = val_dataprovider.next()\n            target = target.type(torch.LongTensor)\n            data, target = data.to(device), target.to(device)\n\n            output = model(data)\n            loss = loss_function(output, target)\n\n            prec1, prec5 = accuracy(output, target, topk=(1, 5))\n            n = data.size(0)\n            objs.update(loss.item(), n)\n            top1.update(prec1.item(), n)\n            top5.update(prec5.item(), n)\n\n    logInfo = 'TEST Iter {}: loss = 
{:.6f},\\t'.format(all_iters, objs.avg) + \\\n              'Top-1 err = {:.6f},\\t'.format(1 - top1.avg / 100) + \\\n              'Top-5 err = {:.6f},\\t'.format(1 - top5.avg / 100) + \\\n              'val_time = {:.6f}'.format(time.time() - t1)\n    logging.info(logInfo)\n\ndef load_checkpoint(net, checkpoint):\n    from collections import OrderedDict\n\n    temp = OrderedDict()\n    if 'state_dict' in checkpoint:\n        checkpoint = dict(checkpoint['state_dict'])\n    for k in checkpoint:\n        k2 = 'module.'+k if not k.startswith('module.') else k\n        temp[k2] = checkpoint[k]\n\n    net.load_state_dict(temp, strict=True)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "MetaACON/ResNet_MetaACON/utils.py",
    "content": "import os\nimport re\nimport torch\nimport torch.nn as nn\n\nclass CrossEntropyLabelSmooth(nn.Module):\n\n\tdef __init__(self, num_classes, epsilon):\n\t\tsuper(CrossEntropyLabelSmooth, self).__init__()\n\t\tself.num_classes = num_classes\n\t\tself.epsilon = epsilon\n\t\tself.logsoftmax = nn.LogSoftmax(dim=1)\n\n\tdef forward(self, inputs, targets):\n\t\tlog_probs = self.logsoftmax(inputs)\n\t\ttargets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)\n\t\ttargets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes\n\t\tloss = (-targets * log_probs).mean(0).sum()\n\t\treturn loss\n\n\nclass AvgrageMeter(object):\n\n\tdef __init__(self):\n\t\tself.reset()\n\n\tdef reset(self):\n\t\tself.avg = 0\n\t\tself.sum = 0\n\t\tself.cnt = 0\n\t\tself.val = 0\n\n\tdef update(self, val, n=1):\n\t\tself.val = val\n\t\tself.sum += val * n\n\t\tself.cnt += n\n\t\tself.avg = self.sum / self.cnt\n\n\ndef accuracy(output, target, topk=(1,)):\n\tmaxk = max(topk)\n\tbatch_size = target.size(0)\n\n\t_, pred = output.topk(maxk, 1, True, True)\n\tpred = pred.t()\n\tcorrect = pred.eq(target.view(1, -1).expand_as(pred))\n\n\tres = []\n\tfor k in topk:\n\t\tcorrect_k = correct[:k].reshape(-1).float().sum(0)\n\t\tres.append(correct_k.mul_(100.0/batch_size))\n\treturn res\n\n\ndef save_checkpoint(state, iters, tag=''):\n\tif not os.path.exists(\"./models\"):\n\t\tos.makedirs(\"./models\")\n\tfilename = os.path.join(\"./models/{}checkpoint-{:06}.pth.tar\".format(tag, iters))\n\ttorch.save(state, filename)\n\ndef get_lastest_model():\n\tif not os.path.exists('./models'):\n\t\tos.mkdir('./models')\n\tmodel_list = os.listdir('./models/')\n\tif model_list == []:\n\t\treturn None, 0\n\tmodel_list.sort()\n\tlastest_model = model_list[-1]\n\titers = re.findall(r'\\d+', lastest_model)\n\treturn './models/' + lastest_model, int(iters[0])\n\n\ndef get_parameters(model):\n\tgroup_no_weight_decay = []\n\tgroup_weight_decay = []\n\tfor pname, p in 
model.named_parameters():\n\t\tif pname.find('weight') >= 0 and len(p.size()) > 1:\n\t\t\t# print('include ', pname, p.size())\n\t\t\tgroup_weight_decay.append(p)\n\t\telse:\n\t\t\t# print('not include ', pname, p.size())\n\t\t\tgroup_no_weight_decay.append(p)\n\tassert len(list(model.parameters())) == len(group_weight_decay) + len(group_no_weight_decay)\n\tgroups = [dict(params=group_weight_decay), dict(params=group_no_weight_decay, weight_decay=0.)]\n\treturn groups\n"
  },
  {
    "path": "MetaACON/ShuffleNet_MetaACON/network.py",
    "content": "import torch\nimport torch.nn as nn\n\nimport sys\nsys.path.insert(0,'../..')\nfrom acon import MetaAconC\n\nclass ShuffleV2Block_MetaACON(nn.Module):\n    def __init__(self, inp, oup, mid_channels, *, ksize, stride, r=16):\n        super(ShuffleV2Block_MetaACON, self).__init__()\n        self.stride = stride\n        assert stride in [1, 2]\n\n        self.mid_channels = mid_channels\n        self.ksize = ksize\n        pad = ksize // 2\n        self.pad = pad\n        self.inp = inp\n\n        outputs = oup - inp\n\n        branch_main = [\n            # pw\n            nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=True),\n            nn.BatchNorm2d(mid_channels),\n            MetaAconC(mid_channels, r=r),\n            # dw\n            nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=True),\n            nn.BatchNorm2d(mid_channels),\n            # pw-linear\n            nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=True),\n            nn.BatchNorm2d(outputs),\n            MetaAconC(outputs, r=r),\n        ]\n        self.branch_main = nn.Sequential(*branch_main)\n\n        if stride == 2:\n            branch_proj = [\n                # dw\n                nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=True),\n                nn.BatchNorm2d(inp),\n                # pw-linear\n                nn.Conv2d(inp, inp, 1, 1, 0, bias=True),\n                nn.BatchNorm2d(inp),\n                MetaAconC(inp, r=r),\n            ]\n            self.branch_proj = nn.Sequential(*branch_proj)\n        else:\n            self.branch_proj = None\n\n    def forward(self, old_x):\n        if self.stride==1:\n            x_proj, x = self.channel_shuffle(old_x)\n            return torch.cat((x_proj, self.branch_main(x)), 1)\n        elif self.stride==2:\n            x_proj = old_x\n            x = old_x\n            return torch.cat((self.branch_proj(x_proj), self.branch_main(x)), 1)\n\n    def channel_shuffle(self, x):\n   
     batchsize, num_channels, height, width = x.data.size()\n        assert (num_channels % 4 == 0)\n        x = x.reshape(batchsize * num_channels // 2, 2, height * width)\n        x = x.permute(1, 0, 2)\n        x = x.reshape(2, -1, num_channels // 2, height, width)\n        return x[0], x[1]\n\n\nclass ShuffleNetV2_MetaACON(nn.Module):\n    def __init__(self, input_size=224, n_class=1000, model_size='1.5x'):\n        super(ShuffleNetV2_MetaACON, self).__init__()\n        print('model size is ', model_size)\n\n        self.stage_repeats = [4, 8, 4]\n        self.model_size = model_size\n        self.r = 16\n        if model_size == '0.5x':\n            self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]\n            self.r = 8\n        elif model_size == '1.0x':\n            self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]\n        elif model_size == '1.5x':\n            self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]\n        elif model_size == '2.0x':\n            self.stage_out_channels = [-1, 24, 244, 488, 976, 2048]\n        else:\n            raise NotImplementedError\n\n        # building first layer\n        input_channel = self.stage_out_channels[1]\n        self.first_conv = nn.Sequential(\n            nn.Conv2d(3, input_channel, 3, 2, 1, bias=True),\n            nn.BatchNorm2d(input_channel),\n            MetaAconC(input_channel, r=self.r),\n        )\n\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\n        self.features = []\n        for idxstage in range(len(self.stage_repeats)):\n            numrepeat = self.stage_repeats[idxstage]\n            output_channel = self.stage_out_channels[idxstage+2]\n\n            for i in range(numrepeat):\n                if i == 0:\n                    self.features.append(ShuffleV2Block_MetaACON(input_channel, output_channel,\n                                                mid_channels=output_channel // 2, ksize=3, stride=2, r=self.r))\n                else:\n          
          self.features.append(ShuffleV2Block_MetaACON(input_channel // 2, output_channel,\n                                                mid_channels=output_channel // 2, ksize=3, stride=1, r=self.r))\n\n                input_channel = output_channel\n\n        self.features = nn.Sequential(*self.features)\n\n        self.conv_last = nn.Sequential(\n            nn.Conv2d(input_channel, self.stage_out_channels[-1], 1, 1, 0, bias=True),\n            nn.BatchNorm2d(self.stage_out_channels[-1]),\n            MetaAconC(self.stage_out_channels[-1], r=self.r),\n        )\n        self.globalpool = nn.AvgPool2d(7)\n        if self.model_size == '2.0x':\n            self.dropout = nn.Dropout(0.2)\n        self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class, bias=True))\n        self._initialize_weights()\n\n    def forward(self, x):\n        x = self.first_conv(x)\n        x = self.maxpool(x)\n        x = self.features(x)\n        x = self.conv_last(x)\n\n        x = self.globalpool(x)\n        if self.model_size == '2.0x':\n            x = self.dropout(x)\n        x = x.contiguous().view(-1, self.stage_out_channels[-1])\n        x = self.classifier(x)\n        return x\n\n    def _initialize_weights(self):\n        for name, m in self.named_modules():\n            if isinstance(m, nn.Conv2d):\n                if 'first' in name:\n                    nn.init.normal_(m.weight, 0, 0.01)\n                else:\n                    nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0.0001)\n                nn.init.constant_(m.running_mean, 0)\n            elif isinstance(m, nn.BatchNorm1d):\n                nn.init.constant_(m.weight, 1)\n                if m.bias is not 
None:\n                    nn.init.constant_(m.bias, 0.0001)\n                nn.init.constant_(m.running_mean, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, 0, 0.01)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n"
  },
  {
    "path": "MetaACON/ShuffleNet_MetaACON/train.py",
    "content": "import os\nimport sys\nimport torch\nimport argparse\nimport torch.nn as nn\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport cv2\nimport numpy as np\nimport PIL\nfrom PIL import Image\nimport time\nimport logging\nimport argparse\nfrom network import ShuffleNetV2_MetaACON\nfrom utils import accuracy, AvgrageMeter, CrossEntropyLabelSmooth, save_checkpoint, get_lastest_model, get_parameters\n\nclass OpencvResize(object):\n\n    def __init__(self, size=256):\n        self.size = size\n\n    def __call__(self, img):\n        assert isinstance(img, PIL.Image.Image)\n        img = np.asarray(img) # (H,W,3) RGB\n        img = img[:,:,::-1] # 2 BGR\n        img = np.ascontiguousarray(img)\n        H, W, _ = img.shape\n        target_size = (int(self.size/H * W + 0.5), self.size) if H < W else (self.size, int(self.size/W * H + 0.5))\n        img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)\n        img = img[:,:,::-1] # 2 RGB\n        img = np.ascontiguousarray(img)\n        img = Image.fromarray(img)\n        return img\n\nclass ToBGRTensor(object):\n\n    def __call__(self, img):\n        assert isinstance(img, (np.ndarray, PIL.Image.Image))\n        if isinstance(img, PIL.Image.Image):\n            img = np.asarray(img)\n        img = img[:,:,::-1] # 2 BGR\n        img = np.transpose(img, [2, 0, 1]) # 2 (3, H, W)\n        img = np.ascontiguousarray(img)\n        img = torch.from_numpy(img).float()\n        return img\n\nclass DataIterator(object):\n\n    def __init__(self, dataloader):\n        self.dataloader = dataloader\n        self.iterator = enumerate(self.dataloader)\n\n    def next(self):\n        try:\n            _, data = next(self.iterator)\n        except Exception:\n            self.iterator = enumerate(self.dataloader)\n            _, data = next(self.iterator)\n        return data[0], data[1]\n\ndef get_args():\n    parser = argparse.ArgumentParser(\"ShuffleNetV2_MetaACON\")\n  
  parser.add_argument('--eval', default=False, action='store_true')\n    parser.add_argument('--eval-resume', type=str, default='./shufflenetv2.0.5.metaacon.pth', help='path for eval model')\n    parser.add_argument('--batch-size', type=int, default=1024, help='batch size')\n    parser.add_argument('--total-iters', type=int, default=300000, help='total iters')\n    parser.add_argument('--learning-rate', type=float, default=0.5, help='init learning rate')\n    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')\n    parser.add_argument('--weight-decay', type=float, default=4e-5, help='weight decay')\n    parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')\n    parser.add_argument('--label-smooth', type=float, default=0.1, help='label smoothing')\n\n    parser.add_argument('--auto-continue', type=bool, default=True, help='auto continue')\n    parser.add_argument('--display-interval', type=int, default=20, help='display interval')\n    parser.add_argument('--val-interval', type=int, default=10000, help='val interval')\n    parser.add_argument('--save-interval', type=int, default=10000, help='save interval')\n\n\n    parser.add_argument('--model-size', type=str, default='0.5x', choices=['0.5x', '1.0x', '1.5x', '2.0x'], help='size of the model')\n\n    parser.add_argument('--train-dir', type=str, default='data/train', help='path to training dataset')\n    parser.add_argument('--val-dir', type=str, default='data/val', help='path to validation dataset')\n\n    args = parser.parse_args()\n    return args\n\ndef main():\n    args = get_args()\n\n    # Log\n    log_format = '[%(asctime)s] %(message)s'\n    logging.basicConfig(stream=sys.stdout, level=logging.INFO,\n        format=log_format, datefmt='%d %I:%M:%S')\n    t = time.time()\n    local_time = time.localtime(t)\n    if not os.path.exists('./log'):\n        os.mkdir('./log')\n    fh = 
logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))\n    fh.setFormatter(logging.Formatter(log_format))\n    logging.getLogger().addHandler(fh)\n\n    use_gpu = False\n    if torch.cuda.is_available():\n        use_gpu = True\n\n    assert os.path.exists(args.train_dir)\n    train_dataset = datasets.ImageFolder(\n        args.train_dir,\n        transforms.Compose([\n            transforms.RandomResizedCrop(224),\n            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),\n            transforms.RandomHorizontalFlip(0.5),\n            ToBGRTensor(),\n        ])\n    )\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.batch_size, shuffle=True,\n        num_workers=1, pin_memory=use_gpu)\n    train_dataprovider = DataIterator(train_loader)\n\n    assert os.path.exists(args.val_dir)\n    val_loader = torch.utils.data.DataLoader(\n        datasets.ImageFolder(args.val_dir, transforms.Compose([\n            OpencvResize(256),\n            transforms.CenterCrop(224),\n            ToBGRTensor(),\n        ])),\n        batch_size=200, shuffle=False,\n        num_workers=1, pin_memory=use_gpu\n    )\n    val_dataprovider = DataIterator(val_loader)\n    print('load data successfully')\n\n    model = ShuffleNetV2_MetaACON(model_size=args.model_size)\n\n    optimizer = torch.optim.SGD(get_parameters(model),\n                                lr=args.learning_rate,\n                                momentum=args.momentum,\n                                weight_decay=args.weight_decay)\n    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)\n\n    if use_gpu:\n        model = nn.DataParallel(model)\n        loss_function = criterion_smooth.cuda()\n        device = torch.device(\"cuda\")\n    else:\n        loss_function = criterion_smooth\n        device = torch.device(\"cpu\")\n\n    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,\n                
    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)\n\n    model = model.to(device)\n\n    all_iters = 0\n    if args.auto_continue:\n        lastest_model, iters = get_lastest_model()\n        if lastest_model is not None:\n            all_iters = iters\n            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')\n            model.load_state_dict(checkpoint['state_dict'], strict=True)\n            print('load from checkpoint')\n            for i in range(iters):\n                scheduler.step()\n\n    args.optimizer = optimizer\n    args.loss_function = loss_function\n    args.scheduler = scheduler\n    args.train_dataprovider = train_dataprovider\n    args.val_dataprovider = val_dataprovider\n\n    if args.eval:\n        if args.eval_resume is not None:\n            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')\n            load_checkpoint(model, checkpoint)\n            validate(model, device, args, all_iters=all_iters)\n        exit(0)\n\n    while all_iters < args.total_iters:\n        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)\n        validate(model, device, args, all_iters=all_iters)\n    validate(model, device, args, all_iters=all_iters)\n    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')\n\ndef adjust_bn_momentum(model, iters):\n    for m in model.modules():\n        if isinstance(m, nn.BatchNorm2d):\n            m.momentum = 1 / iters\n\ndef train(model, device, args, *, val_interval, bn_process=False, all_iters=None):\n\n    optimizer = args.optimizer\n    loss_function = args.loss_function\n    scheduler = args.scheduler\n    train_dataprovider = args.train_dataprovider\n\n    t1 = time.time()\n    Top1_err, Top5_err = 0.0, 0.0\n    model.train()\n    for iters in range(1, val_interval + 1):\n        scheduler.step()\n        if 
bn_process:\n            adjust_bn_momentum(model, iters)\n\n        all_iters += 1\n        d_st = time.time()\n        data, target = train_dataprovider.next()\n        target = target.type(torch.LongTensor)\n        data, target = data.to(device), target.to(device)\n        data_time = time.time() - d_st\n\n        output = model(data)\n        loss = loss_function(output, target)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        prec1, prec5 = accuracy(output, target, topk=(1, 5))\n\n        Top1_err += 1 - prec1.item() / 100\n        Top5_err += 1 - prec5.item() / 100\n\n        if all_iters % args.display_interval == 0:\n            printInfo = 'TRAIN Iter {}: lr = {:.6f},\\tloss = {:.6f},\\t'.format(all_iters, scheduler.get_lr()[0], loss.item()) + \\\n                        'Top-1 err = {:.6f},\\t'.format(Top1_err / args.display_interval) + \\\n                        'Top-5 err = {:.6f},\\t'.format(Top5_err / args.display_interval) + \\\n                        'data_time = {:.6f},\\ttrain_time = {:.6f}'.format(data_time, (time.time() - t1) / args.display_interval)\n            logging.info(printInfo)\n            t1 = time.time()\n            Top1_err, Top5_err = 0.0, 0.0\n\n        if all_iters % args.save_interval == 0:\n            save_checkpoint({\n                'state_dict': model.state_dict(),\n                }, all_iters)\n\n    return all_iters\n\ndef validate(model, device, args, *, all_iters=None):\n    objs = AvgrageMeter()\n    top1 = AvgrageMeter()\n    top5 = AvgrageMeter()\n\n    loss_function = args.loss_function\n    val_dataprovider = args.val_dataprovider\n\n    model.eval()\n    max_val_iters = 250\n    t1  = time.time()\n    with torch.no_grad():\n        for _ in range(1, max_val_iters + 1):\n            data, target = val_dataprovider.next()\n            target = target.type(torch.LongTensor)\n            data, target = data.to(device), target.to(device)\n\n            output = 
model(data)\n            loss = loss_function(output, target)\n\n            prec1, prec5 = accuracy(output, target, topk=(1, 5))\n            n = data.size(0)\n            objs.update(loss.item(), n)\n            top1.update(prec1.item(), n)\n            top5.update(prec5.item(), n)\n\n    logInfo = 'TEST Iter {}: loss = {:.6f},\\t'.format(all_iters, objs.avg) + \\\n              'Top-1 err = {:.6f},\\t'.format(1 - top1.avg / 100) + \\\n              'Top-5 err = {:.6f},\\t'.format(1 - top5.avg / 100) + \\\n              'val_time = {:.6f}'.format(time.time() - t1)\n    logging.info(logInfo)\n\ndef load_checkpoint(net, checkpoint):\n    from collections import OrderedDict\n\n    temp = OrderedDict()\n    if 'state_dict' in checkpoint:\n        checkpoint = dict(checkpoint['state_dict'])\n    for k in checkpoint:\n        k2 = 'module.'+k if not k.startswith('module.') else k\n        temp[k2] = checkpoint[k]\n\n    net.load_state_dict(temp, strict=True)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "MetaACON/ShuffleNet_MetaACON/utils.py",
    "content": "import os\nimport re\nimport torch\nimport torch.nn as nn\n\nclass CrossEntropyLabelSmooth(nn.Module):\n\n\tdef __init__(self, num_classes, epsilon):\n\t\tsuper(CrossEntropyLabelSmooth, self).__init__()\n\t\tself.num_classes = num_classes\n\t\tself.epsilon = epsilon\n\t\tself.logsoftmax = nn.LogSoftmax(dim=1)\n\n\tdef forward(self, inputs, targets):\n\t\tlog_probs = self.logsoftmax(inputs)\n\t\ttargets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)\n\t\ttargets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes\n\t\tloss = (-targets * log_probs).mean(0).sum()\n\t\treturn loss\n\n\nclass AvgrageMeter(object):\n\n\tdef __init__(self):\n\t\tself.reset()\n\n\tdef reset(self):\n\t\tself.avg = 0\n\t\tself.sum = 0\n\t\tself.cnt = 0\n\t\tself.val = 0\n\n\tdef update(self, val, n=1):\n\t\tself.val = val\n\t\tself.sum += val * n\n\t\tself.cnt += n\n\t\tself.avg = self.sum / self.cnt\n\n\ndef accuracy(output, target, topk=(1,)):\n\tmaxk = max(topk)\n\tbatch_size = target.size(0)\n\n\t_, pred = output.topk(maxk, 1, True, True)\n\tpred = pred.t()\n\tcorrect = pred.eq(target.view(1, -1).expand_as(pred))\n\n\tres = []\n\tfor k in topk:\n\t\tcorrect_k = correct[:k].reshape(-1).float().sum(0)\n\t\tres.append(correct_k.mul_(100.0/batch_size))\n\treturn res\n\n\ndef save_checkpoint(state, iters, tag=''):\n\tif not os.path.exists(\"./models\"):\n\t\tos.makedirs(\"./models\")\n\tfilename = os.path.join(\"./models/{}checkpoint-{:06}.pth.tar\".format(tag, iters))\n\ttorch.save(state, filename)\n\ndef get_lastest_model():\n\tif not os.path.exists('./models'):\n\t\tos.mkdir('./models')\n\tmodel_list = os.listdir('./models/')\n\tif model_list == []:\n\t\treturn None, 0\n\tmodel_list.sort()\n\tlastest_model = model_list[-1]\n\titers = re.findall(r'\\d+', lastest_model)\n\treturn './models/' + lastest_model, int(iters[0])\n\n\ndef get_parameters(model):\n\tgroup_no_weight_decay = []\n\tgroup_weight_decay = []\n\tfor pname, p in 
model.named_parameters():\n\t\tif pname.find('weight') >= 0 and len(p.size()) > 1:\n\t\t\t# print('include ', pname, p.size())\n\t\t\tgroup_weight_decay.append(p)\n\t\telse:\n\t\t\t# print('not include ', pname, p.size())\n\t\t\tgroup_no_weight_decay.append(p)\n\tassert len(list(model.parameters())) == len(group_weight_decay) + len(group_no_weight_decay)\n\tgroups = [dict(params=group_weight_decay), dict(params=group_no_weight_decay, weight_decay=0.)]\n\treturn groups\n"
  },
  {
    "path": "README.md",
    "content": "\n## CVPR 2021 | Activate or Not: Learning Customized Activation.\n\nThis repository contains the official Pytorch implementation of the paper [Activate or Not: Learning Customized Activation, CVPR 2021](https://arxiv.org/pdf/2009.04759.pdf).\n\n### ACON\n\nWe propose a novel activation function we term the ACON that explicitly learns to activate the neurons or not. \nBelow we show the ACON activation function and its first derivatives. β controls how fast the first derivative asymptotes to the upper/lower bounds, which are determined by p1 and p2.\n\n\n<img src=\"https://user-images.githubusercontent.com/5032208/113257297-fc76f380-92fc-11eb-9559-39d033baea4c.png\" width=90%>\n\n<img src=\"https://user-images.githubusercontent.com/5032208/113257194-cfc2dc00-92fc-11eb-94a0-f81569bed15e.png\" width=90%>\n\n### Training curves\nWe show the training curves of different activations here.\n\n<img src=\"https://user-images.githubusercontent.com/5032208/113260052-65ac3600-9300-11eb-8d2f-ef968be1c3a2.png\"  width=60%>\n\n\n### TFNet\nTo show the effectiveness of the proposed acon family, we also provide an extremely simple toy funnel network (TFNet) made only by pointwise convolution and ACON-FReLU operators.\n\n<img src=\"https://user-images.githubusercontent.com/5032208/113963614-7a3a8200-985c-11eb-8946-65c0bcef0a80.png\"  width=60%>\n\n\n\n\n## Main results\n\nThe following results are the ImageNet top-1 accuracy relative improvements compared with the ReLU baselines. The relative improvements of Meta-ACON are about twice as much as SENet.\n\n<img src=\"https://user-images.githubusercontent.com/5032208/113256618-fcc2bf00-92fb-11eb-9b1d-8f0589009a9b.png\" width=60%>\n\nThe comparison between ReLU, Swish and ACON-C. We show improvements without additional amount of FLOPs and parameters:\n| Model             | FLOPs | #Params. | top-1 err. (ReLU) | top-1 err. (Swish) |   top-1 err. 
(ACON)   |\n|-------------------|:-----:|:--------:|:-----------------:|:------------------:|:---------------------:|\n| ShuffleNetV2 0.5x |  41M  |   1.4M   |        39.4       |     38.3 (+1.1)    |    **37.0 (+2.4)**    |\n| ShuffleNetV2 1.5x |  299M |   3.5M   |        27.4       |     26.8 (+0.6)    |    **26.5 (+0.9)**    |\n| ResNet 50         |  3.9G |   25.5M  |        24.0       |     23.5 (+0.5)    |    **23.2 (+0.8)**    |\n| ResNet 101        |  7.6G |   44.4M  |        22.8       |     22.7 (+0.1)    |    **21.8 (+1.0)**    |\n| ResNet 152        | 11.3G |   60.0M  |        22.3       |     22.2 (+0.1)    |    **21.2 (+1.1)**    |\n\n\nNext, by adding a negligible amount of FLOPs and parameters, meta-ACON shows significant improvements:\n| Model                         | FLOPs | #Params. |        top-1 err.      | \n|-------------------------------|:-----:|:--------:|:----------------------:|\n| ShuffleNetV2 0.5x (meta-acon) | 41M   | 1.7M     |   **34.8 (+4.6)**      | \n| ShuffleNetV2 1.5x (meta-acon) | 299M  | 3.9M     |   **24.7 (+2.7)**      | \n| ResNet 50 (meta-acon)         | 3.9G  | 25.7M    |   **22.0 (+2.0)**      | \n| ResNet 101 (meta-acon)        | 7.6G  | 44.8M    |   **21.0 (+1.8)**      | \n| ResNet 152 (meta-acon)        | 11.3G | 60.5M    |   **20.5 (+1.8)**      | \n\n\n\n\n\nThe simple TFNet without the SE modules can outperform the state-of-the-art light-weight networks without the SE modules.\n\n|                   | FLOPs | #Params. |   top-1 err.   
|\n|-----------------  |:-----:|:--------:|:--------------:|\n|  MobileNetV2 0.17 |  42M  |   1.4M   |    52.6    |\n| ShuffleNetV2 0.5x |  41M  |   1.4M   |    39.4    |\n|     TFNet 0.5     |  43M  |   1.3M   |  **36.6 (+2.8)**  |\n|  MobileNetV2 0.6  |  141M |   2.2M   |    33.3    |\n| ShuffleNetV2 1.0x |  146M |   2.3M   |    30.6    |\n|     TFNet 1.0     |  135M |   1.9M   |  **29.7 (+0.9)**  |\n|  MobileNetV2 1.0  |  300M |   3.4M   |    28.0    |\n| ShuffleNetV2 1.5x |  299M |   3.5M   |    27.4    |\n|     TFNet 1.5     |  279M |   2.7M   |  **26.0 (+1.4)**  |\n|  MobileNetV2 1.4  |  585M |   5.5M   |    25.3    |\n| ShuffleNetV2 2.0x |  591M |   7.4M   |    25.0    |\n| TFNet 2.0         |  474M |   3.8M   |  **24.3 (+0.7)**  |\n\n\n\n\n## Trained Models\n- OneDrive download: [Link](https://1drv.ms/u/s!AgaP37NGYuEXhWbwpi4SX1IX6gOs?e=wIQYs1)\n- BaiduYun download: [Link](https://pan.baidu.com/s/18uDVWe-rh4b7qI_NBvWUCw) (extract code: 13fu)\n\n\n## Usage\n\n### Requirements\nDownload the ImageNet dataset and move validation images to labeled subfolders. To do this, you can use the following script:\nhttps://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh\n\n\nTrain:\n```shell\npython train.py  --train-dir YOUR_TRAINDATASET_PATH --val-dir YOUR_VALDATASET_PATH\n```\nEval:\n```shell\npython train.py --eval --eval-resume YOUR_WEIGHT_PATH --train-dir YOUR_TRAINDATASET_PATH --val-dir YOUR_VALDATASET_PATH\n```\n\n\n## Citation\nIf you use these models in your research, please cite:\n\n    @inproceedings{ma2021activate,\n      title={Activate or Not: Learning Customized Activation},\n      author={Ma, Ningning and Zhang, Xiangyu and Liu, Ming and Sun, Jian},\n      booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},\n      year={2021}\n    }\n\n"
  },
  {
    "path": "TFNet/README.md",
    "content": "# [TFNet](https://arxiv.org/pdf/2009.04759.pdf)\nThis repository contains TFNet implementation by Pytorch.\n\n\n### TFNet\nTo show the effectiveness of the proposed acon family, we provide an extremely simple toy funnel network (TFNet) made only by pointwise convolution and ACON-FReLU operators.\n\n<img src=\"https://user-images.githubusercontent.com/5032208/113963614-7a3a8200-985c-11eb-8946-65c0bcef0a80.png\"  width=60%>\n\n\n## Main results\n\n\n\nThe simple TFNet without the SE modules can outperform the state-of-the-art light-weight networks without the SE modules.\n\n|                   | FLOPs | #Params. |   top-1 err.   |\n|-----------------  |:-----:|:--------:|:--------------:|\n|  MobileNetV2 0.17 |  42M  |   1.4M   |    52.6    |\n| ShuffleNetV2 0.5x |  41M  |   1.4M   |    39.4    |\n|     TFNet 0.5     |  43M  |   1.3M   |  **36.6 (+2.8)**  |\n|  MobileNetV2 0.6  |  141M |   2.2M   |    33.3    |\n| ShuffleNetV2 1.0x |  146M |   2.3M   |    30.6    |\n|     TFNet 1.0     |  135M |   1.9M   |  **29.7 (+0.9)**  |\n|  MobileNetV2 1.0  |  300M |   3.4M   |    28.0    |\n| ShuffleNetV2 1.5x |  299M |   3.5M   |    27.4    |\n|     TFNet 1.5     |  279M |   2.7M   |  **26.0 (+1.4)**  |\n|  MobileNetV2 1.4  |  585M |   5.5M   |    25.3    |\n| ShuffleNetV2 2.0x |  591M |   7.4M   |    25.0    |\n| TFNet 2.0         |  474M |   3.8M   |  **24.3 (+0.7)**  |\n\n\n\n\n## Trained Models\n- OneDrive download: [Link](https://1drv.ms/u/s!AgaP37NGYuEXhWbwpi4SX1IX6gOs?e=wIQYs1)\n- BaiduYun download: [Link](https://pan.baidu.com/s/18uDVWe-rh4b7qI_NBvWUCw) (extract code: 13fu)\n\n\n## Usage\n\n### Requirements\nDownload the ImageNet dataset and move validation images to labeled subfolders. 
To do this, you can use the following script:\nhttps://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh\n\n\nTrain:\n```shell\npython train.py  --train-dir YOUR_TRAINDATASET_PATH --val-dir YOUR_VALDATASET_PATH\n```\nEval:\n```shell\npython train.py --eval --eval-resume YOUR_WEIGHT_PATH --train-dir YOUR_TRAINDATASET_PATH --val-dir YOUR_VALDATASET_PATH\n```\n\n\n## Citation\nIf you use these models in your research, please cite:\n\n    @inproceedings{ma2021activate,\n      title={Activate or Not: Learning Customized Activation},\n      author={Ma, Ningning and Zhang, Xiangyu and Liu, Ming and Sun, Jian},\n      booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},\n      year={2021}\n    }\n"
  },
  {
    "path": "TFNet/network.py",
    "content": "import torch\nimport torch.nn as nn\n\nclass Acon_FReLU(nn.Module):\n    r\"\"\" ACON activation (activate or not) based on FReLU:\n    # eta_a(x) = x, eta_b(x) = dw_conv(x), according to\n    # \"Funnel Activation for Visual Recognition\" <https://arxiv.org/pdf/2007.11824.pdf>.\n    \"\"\"\n    def __init__(self, width, stride=1):\n        super().__init__()\n        self.stride = stride\n\n        # eta_b(x)\n        self.conv_frelu = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=width, bias=True)\n        self.bn1 = nn.BatchNorm2d(width)\n\n        # eta_a(x)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.bn2 = nn.BatchNorm2d(width)\n        self.sigmoid = nn.Sigmoid()\n\n    def forward(self, x, **kwargs):\n        if self.stride == 2:\n            x1 = self.maxpool(x)\n        else:\n            x1 = x\n\n        x2 = self.bn1(self.conv_frelu(x))\n\n        return self.bn2( (x1 - x2) * self.sigmoid(x1 - x2) + x2 )\n\n\nclass TFBlock(nn.Module):\n    def __init__(self, inp, stride):\n        super(TFBlock, self).__init__()\n        self.oup = inp * stride\n        self.stride = stride\n\n        branch_main = [\n            # pw conv\n            nn.Conv2d(inp, inp, kernel_size=1, stride=1, bias=True),\n            nn.BatchNorm2d(inp),\n            Acon_FReLU(inp),\n            # pw conv\n            nn.Conv2d(inp, inp, kernel_size=1, stride=1, bias=True),\n            nn.BatchNorm2d(inp)\n        ]\n        self.branch_main = nn.Sequential(*branch_main)\n\n        self.acon = Acon_FReLU(self.oup, stride)\n\n    def forward(self, x):\n        x_proj = x\n        x = self.branch_main(x)\n\n        if self.stride==1:\n            return self.acon(x_proj + x)\n\n        elif self.stride==2:\n            return self.acon(torch.cat((x_proj, x), 1))\n\n\nclass TFNet(nn.Module):\n    def __init__(self, n_class=1000, model_size=0.5):\n        super(TFNet, self).__init__()\n        
print('model size is ', model_size)\n\n        self.stages = [2, 3, 8, 3]\n        self.in_channel = int(16 * model_size)\n        self.out_channel = 1024\n        self.model_size = model_size\n\n        # building the first layer\n        self.first_conv = nn.Sequential(\n            nn.Conv2d(3, self.in_channel, 3, 2, 1, bias=True),\n            nn.BatchNorm2d(self.in_channel),\n            nn.ReLU(inplace=True),\n        )\n\n        # building the four stages' features\n        self.features = []\n        for stage in self.stages:\n            for i in range(stage):\n                self.features.append(\n                        TFBlock(self.in_channel, stride = 1 if i > 0 else 2))\n                self.in_channel = self.in_channel * 2 if i == 0 else self.in_channel\n        self.features = nn.Sequential(*self.features)\n\n        # building the last layer\n        self.conv_last = nn.Sequential(\n            nn.Conv2d(self.in_channel, self.out_channel, 1, 1, 0, bias=True),\n            nn.BatchNorm2d(self.out_channel),\n            Acon_FReLU(self.out_channel),\n        )\n        self.globalpool = nn.AvgPool2d(7)\n        if self.model_size > 0.5:\n            self.dropout = nn.Dropout(0.2)\n        self.classifier = nn.Sequential(nn.Linear(self.out_channel, n_class, bias=True))\n        self._initialize_weights()\n\n    def forward(self, x):\n        x = self.first_conv(x)\n        x = self.features(x)\n        x = self.conv_last(x)\n\n        x = self.globalpool(x)\n        if self.model_size > 0.5:\n            x = self.dropout(x)\n        x = x.contiguous().view(-1, self.out_channel)\n        x = self.classifier(x)\n        return x\n\n    def _initialize_weights(self):\n        for name, m in self.named_modules():\n            if isinstance(m, nn.Conv2d):\n                if 'first' in name or 'frelu' in name:\n                    nn.init.normal_(m.weight, 0, 0.01)\n                else:\n                    nn.init.normal_(m.weight, 0, 1.0 / 
m.weight.shape[1])\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0.0001)\n                nn.init.constant_(m.running_mean, 0)\n            elif isinstance(m, nn.BatchNorm1d):\n                nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0.0001)\n                nn.init.constant_(m.running_mean, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, 0, 0.01)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n"
  },
  {
    "path": "TFNet/train.py",
    "content": "import os\nimport sys\nimport torch\nimport argparse\nimport torch.nn as nn\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport cv2\nimport numpy as np\nimport PIL\nfrom PIL import Image\nimport time\nimport logging\nimport argparse\nfrom network import TFNet\nfrom utils import accuracy, AvgrageMeter, CrossEntropyLabelSmooth, save_checkpoint, get_lastest_model, get_parameters\n\nclass OpencvResize(object):\n\n    def __init__(self, size=256):\n        self.size = size\n\n    def __call__(self, img):\n        assert isinstance(img, PIL.Image.Image)\n        img = np.asarray(img) # (H,W,3) RGB\n        img = img[:,:,::-1] # 2 BGR\n        img = np.ascontiguousarray(img)\n        H, W, _ = img.shape\n        target_size = (int(self.size/H * W + 0.5), self.size) if H < W else (self.size, int(self.size/W * H + 0.5))\n        img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)\n        img = img[:,:,::-1] # 2 RGB\n        img = np.ascontiguousarray(img)\n        img = Image.fromarray(img)\n        return img\n\nclass ToBGRTensor(object):\n\n    def __call__(self, img):\n        assert isinstance(img, (np.ndarray, PIL.Image.Image))\n        if isinstance(img, PIL.Image.Image):\n            img = np.asarray(img)\n        img = img[:,:,::-1] # 2 BGR\n        img = np.transpose(img, [2, 0, 1]) # 2 (3, H, W)\n        img = np.ascontiguousarray(img)\n        img = torch.from_numpy(img).float()\n        return img\n\nclass DataIterator(object):\n\n    def __init__(self, dataloader):\n        self.dataloader = dataloader\n        self.iterator = enumerate(self.dataloader)\n\n    def next(self):\n        try:\n            _, data = next(self.iterator)\n        except Exception:\n            self.iterator = enumerate(self.dataloader)\n            _, data = next(self.iterator)\n        return data[0], data[1]\n\ndef get_args():\n    parser = argparse.ArgumentParser(\"TFNet\")\n    parser.add_argument('--eval', 
default=False, action='store_true')\n    parser.add_argument('--eval-resume', type=str, default='./tfnet.0.5.pth', help='path for eval model')\n    parser.add_argument('--batch-size', type=int, default=1024, help='batch size')\n    parser.add_argument('--total-iters', type=int, default=300000, help='total iters')\n    parser.add_argument('--learning-rate', type=float, default=0.5, help='init learning rate')\n    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')\n    parser.add_argument('--weight-decay', type=float, default=4e-5, help='weight decay')\n    parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')\n    parser.add_argument('--label-smooth', type=float, default=0.1, help='label smoothing')\n\n    parser.add_argument('--auto-continue', type=bool, default=True, help='auto continue')\n    parser.add_argument('--display-interval', type=int, default=20, help='display interval')\n    parser.add_argument('--val-interval', type=int, default=10000, help='val interval')\n    parser.add_argument('--save-interval', type=int, default=10000, help='save interval')\n\n\n    parser.add_argument('--model-size', type=float, default=0.5, choices=[0.5, 1.0, 1.5, 2.0], help='size of the model')\n\n    parser.add_argument('--train-dir', type=str, default='data/train', help='path to training dataset')\n    parser.add_argument('--val-dir', type=str, default='data/val', help='path to validation dataset')\n\n    args = parser.parse_args()\n    return args\n\ndef main():\n    args = get_args()\n\n    # Log\n    log_format = '[%(asctime)s] %(message)s'\n    logging.basicConfig(stream=sys.stdout, level=logging.INFO,\n        format=log_format, datefmt='%d %I:%M:%S')\n    t = time.time()\n    local_time = time.localtime(t)\n    if not os.path.exists('./log'):\n        os.mkdir('./log')\n    fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))\n    
fh.setFormatter(logging.Formatter(log_format))\n    logging.getLogger().addHandler(fh)\n\n    use_gpu = False\n    if torch.cuda.is_available():\n        use_gpu = True\n\n    assert os.path.exists(args.train_dir)\n    train_dataset = datasets.ImageFolder(\n        args.train_dir,\n        transforms.Compose([\n            transforms.RandomResizedCrop(224),\n            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),\n            transforms.RandomHorizontalFlip(0.5),\n            ToBGRTensor(),\n        ])\n    )\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.batch_size, shuffle=True,\n        num_workers=1, pin_memory=use_gpu)\n    train_dataprovider = DataIterator(train_loader)\n\n    assert os.path.exists(args.val_dir)\n    val_loader = torch.utils.data.DataLoader(\n        datasets.ImageFolder(args.val_dir, transforms.Compose([\n            OpencvResize(256),\n            transforms.CenterCrop(224),\n            ToBGRTensor(),\n        ])),\n        batch_size=200, shuffle=False,\n        num_workers=1, pin_memory=use_gpu\n    )\n    val_dataprovider = DataIterator(val_loader)\n    print('load data successfully')\n\n    model = TFNet(model_size=args.model_size)\n\n    optimizer = torch.optim.SGD(get_parameters(model),\n                                lr=args.learning_rate,\n                                momentum=args.momentum,\n                                weight_decay=args.weight_decay)\n    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)\n\n    if use_gpu:\n        model = nn.DataParallel(model)\n        loss_function = criterion_smooth.cuda()\n        device = torch.device(\"cuda\")\n    else:\n        loss_function = criterion_smooth\n        device = torch.device(\"cpu\")\n\n    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,\n                    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)\n\n    model = model.to(device)\n\n  
  all_iters = 0\n    if args.auto_continue:\n        lastest_model, iters = get_lastest_model()\n        if lastest_model is not None:\n            all_iters = iters\n            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')\n            model.load_state_dict(checkpoint['state_dict'], strict=True)\n            print('load from checkpoint')\n            for i in range(iters):\n                scheduler.step()\n\n    args.optimizer = optimizer\n    args.loss_function = loss_function\n    args.scheduler = scheduler\n    args.train_dataprovider = train_dataprovider\n    args.val_dataprovider = val_dataprovider\n\n    if args.eval:\n        if args.eval_resume is not None:\n            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')\n            load_checkpoint(model, checkpoint)\n            validate(model, device, args, all_iters=all_iters)\n        exit(0)\n\n    while all_iters < args.total_iters:\n        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)\n        validate(model, device, args, all_iters=all_iters)\n    validate(model, device, args, all_iters=all_iters)\n    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')\n\ndef adjust_bn_momentum(model, iters):\n    for m in model.modules():\n        if isinstance(m, nn.BatchNorm2d):\n            m.momentum = 1 / iters\n\ndef train(model, device, args, *, val_interval, bn_process=False, all_iters=None):\n\n    optimizer = args.optimizer\n    loss_function = args.loss_function\n    scheduler = args.scheduler\n    train_dataprovider = args.train_dataprovider\n\n    t1 = time.time()\n    Top1_err, Top5_err = 0.0, 0.0\n    model.train()\n    for iters in range(1, val_interval + 1):\n        scheduler.step()\n        if bn_process:\n            adjust_bn_momentum(model, iters)\n\n        all_iters += 1\n        d_st = time.time()\n        data, target = 
train_dataprovider.next()\n        target = target.type(torch.LongTensor)\n        data, target = data.to(device), target.to(device)\n        data_time = time.time() - d_st\n\n        output = model(data)\n        loss = loss_function(output, target)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        prec1, prec5 = accuracy(output, target, topk=(1, 5))\n\n        Top1_err += 1 - prec1.item() / 100\n        Top5_err += 1 - prec5.item() / 100\n\n        if all_iters % args.display_interval == 0:\n            printInfo = 'TRAIN Iter {}: lr = {:.6f},\\tloss = {:.6f},\\t'.format(all_iters, scheduler.get_lr()[0], loss.item()) + \\\n                        'Top-1 err = {:.6f},\\t'.format(Top1_err / args.display_interval) + \\\n                        'Top-5 err = {:.6f},\\t'.format(Top5_err / args.display_interval) + \\\n                        'data_time = {:.6f},\\ttrain_time = {:.6f}'.format(data_time, (time.time() - t1) / args.display_interval)\n            logging.info(printInfo)\n            t1 = time.time()\n            Top1_err, Top5_err = 0.0, 0.0\n\n        if all_iters % args.save_interval == 0:\n            save_checkpoint({\n                'state_dict': model.state_dict(),\n                }, all_iters)\n\n    return all_iters\n\ndef validate(model, device, args, *, all_iters=None):\n    objs = AvgrageMeter()\n    top1 = AvgrageMeter()\n    top5 = AvgrageMeter()\n\n    loss_function = args.loss_function\n    val_dataprovider = args.val_dataprovider\n\n    model.eval()\n    max_val_iters = 250\n    t1  = time.time()\n    with torch.no_grad():\n        for _ in range(1, max_val_iters + 1):\n            data, target = val_dataprovider.next()\n            target = target.type(torch.LongTensor)\n            data, target = data.to(device), target.to(device)\n\n            output = model(data)\n            loss = loss_function(output, target)\n\n            prec1, prec5 = accuracy(output, target, topk=(1, 5))\n            n = 
data.size(0)\n            objs.update(loss.item(), n)\n            top1.update(prec1.item(), n)\n            top5.update(prec5.item(), n)\n\n    logInfo = 'TEST Iter {}: loss = {:.6f},\\t'.format(all_iters, objs.avg) + \\\n              'Top-1 err = {:.6f},\\t'.format(1 - top1.avg / 100) + \\\n              'Top-5 err = {:.6f},\\t'.format(1 - top5.avg / 100) + \\\n              'val_time = {:.6f}'.format(time.time() - t1)\n    logging.info(logInfo)\n\ndef load_checkpoint(net, checkpoint):\n    from collections import OrderedDict\n\n    temp = OrderedDict()\n    if 'state_dict' in checkpoint:\n        checkpoint = dict(checkpoint['state_dict'])\n    for k in checkpoint:\n        k2 = 'module.'+k if not k.startswith('module.') else k\n        temp[k2] = checkpoint[k]\n\n    net.load_state_dict(temp, strict=True)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "TFNet/utils.py",
    "content": "import os\nimport re\nimport torch\nimport torch.nn as nn\n\nclass CrossEntropyLabelSmooth(nn.Module):\n\n\tdef __init__(self, num_classes, epsilon):\n\t\tsuper(CrossEntropyLabelSmooth, self).__init__()\n\t\tself.num_classes = num_classes\n\t\tself.epsilon = epsilon\n\t\tself.logsoftmax = nn.LogSoftmax(dim=1)\n\n\tdef forward(self, inputs, targets):\n\t\tlog_probs = self.logsoftmax(inputs)\n\t\ttargets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)\n\t\ttargets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes\n\t\tloss = (-targets * log_probs).mean(0).sum()\n\t\treturn loss\n\n\nclass AvgrageMeter(object):\n\n\tdef __init__(self):\n\t\tself.reset()\n\n\tdef reset(self):\n\t\tself.avg = 0\n\t\tself.sum = 0\n\t\tself.cnt = 0\n\t\tself.val = 0\n\n\tdef update(self, val, n=1):\n\t\tself.val = val\n\t\tself.sum += val * n\n\t\tself.cnt += n\n\t\tself.avg = self.sum / self.cnt\n\n\ndef accuracy(output, target, topk=(1,)):\n\tmaxk = max(topk)\n\tbatch_size = target.size(0)\n\n\t_, pred = output.topk(maxk, 1, True, True)\n\tpred = pred.t()\n\tcorrect = pred.eq(target.view(1, -1).expand_as(pred))\n\n\tres = []\n\tfor k in topk:\n\t\tcorrect_k = correct[:k].reshape(-1).float().sum(0)\n\t\tres.append(correct_k.mul_(100.0/batch_size))\n\treturn res\n\n\ndef save_checkpoint(state, iters, tag=''):\n\tif not os.path.exists(\"./models\"):\n\t\tos.makedirs(\"./models\")\n\tfilename = os.path.join(\"./models/{}checkpoint-{:06}.pth.tar\".format(tag, iters))\n\ttorch.save(state, filename)\n\ndef get_lastest_model():\n\tif not os.path.exists('./models'):\n\t\tos.mkdir('./models')\n\tmodel_list = os.listdir('./models/')\n\tif model_list == []:\n\t\treturn None, 0\n\tmodel_list.sort()\n\tlastest_model = model_list[-1]\n\titers = re.findall(r'\\d+', lastest_model)\n\treturn './models/' + lastest_model, int(iters[0])\n\n\ndef get_parameters(model):\n\tgroup_no_weight_decay = []\n\tgroup_weight_decay = []\n\tfor pname, p in 
model.named_parameters():\n\t\tif pname.find('weight') >= 0 and len(p.size()) > 1:\n\t\t\t# print('include ', pname, p.size())\n\t\t\tgroup_weight_decay.append(p)\n\t\telse:\n\t\t\t# print('not include ', pname, p.size())\n\t\t\tgroup_no_weight_decay.append(p)\n\tassert len(list(model.parameters())) == len(group_weight_decay) + len(group_no_weight_decay)\n\tgroups = [dict(params=group_weight_decay), dict(params=group_no_weight_decay, weight_decay=0.)]\n\treturn groups\n"
  },
  {
    "path": "acon.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass AconC(nn.Module):\n    r\"\"\" ACON activation (activate or not).\n    # AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter\n    # according to \"Activate or Not: Learning Customized Activation\" <https://arxiv.org/pdf/2009.04759.pdf>.\n    \"\"\"\n\n    def __init__(self, width):\n        super().__init__()\n        self.p1 = nn.Parameter(torch.randn(1, width, 1, 1))\n        self.p2 = nn.Parameter(torch.randn(1, width, 1, 1))\n        self.beta = nn.Parameter(torch.ones(1, width, 1, 1))\n\n    def forward(self, x):\n        return (self.p1 * x - self.p2 * x) * torch.sigmoid(self.beta * (self.p1 * x - self.p2 * x)) + self.p2 * x\n\n\nclass MetaAconC(nn.Module):\n    r\"\"\" ACON activation (activate or not).\n    # MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network\n    # according to \"Activate or Not: Learning Customized Activation\" <https://arxiv.org/pdf/2009.04759.pdf>.\n    \"\"\"\n\n    def __init__(self, width, r=16):\n        super().__init__()\n        self.fc1 = nn.Conv2d(width, max(r, width // r), kernel_size=1, stride=1, bias=True)\n        self.bn1 = nn.BatchNorm2d(max(r, width // r))\n        self.fc2 = nn.Conv2d(max(r, width // r), width, kernel_size=1, stride=1, bias=True)\n        self.bn2 = nn.BatchNorm2d(width)\n\n        self.p1 = nn.Parameter(torch.randn(1, width, 1, 1))\n        self.p2 = nn.Parameter(torch.randn(1, width, 1, 1))\n\n    def forward(self, x):\n        beta = torch.sigmoid(\n            self.bn2(self.fc2(self.bn1(self.fc1(x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True))))))\n        return (self.p1 * x - self.p2 * x) * torch.sigmoid(beta * (self.p1 * x - self.p2 * x)) + self.p2 * x\n"
  }
]