[
  {
    "path": ".gitattributes",
    "content": "# Auto detect text files and perform LF normalization\n* text=auto\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\n.DS_Store\nruns*\n*.json\nCIFAR/*"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2018 jxgu1016\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "# WACV2018/TIP: Gabor Convolutional Networks\n\nOfficial PyTorch implementation of Gabor Convolutional Networks (GCNs).\nNote that all results reported in the paper were obtained with the [Torch 7 implementation](https://github.com/bczhangbczhang/Gabor-Convolutional-Networks).\nBoth implementations share the same infrastructure-level code.\n\n## Requirements\n- PyTorch 1.1.0 (earlier versions are not supported)\n- torchvision\n\n## Install\n\n```\ngit clone https://github.com/jxgu1016/Gabor_CNN_PyTorch\ncd Gabor_CNN_PyTorch\nsh install.sh\n```\n\n## Run MNIST demo\n\n```\ncd demo\npython main.py --model gcn            # train on CPU (default)\npython main.py --model gcn --gpu 0    # train on GPU 0\n```\n\n## Citation\n\nIf you use this code, please cite:\n\n```\n@article{GaborCNNs,\n  title={Gabor Convolutional Networks},\n  author={Luan, Shangzhen and Chen, Chen and Zhang, Baochang and Han, Jungong and Liu, Jianzhuang},\n  journal={IEEE Transactions on Image Processing},\n  year={2018}\n}\n```\n"
  },
  {
    "path": "demo/main.py",
    "content": "from __future__ import division\nimport os\nimport time\nimport argparse\nimport torch\nfrom torchvision import datasets, transforms\nimport torch.optim as optim\nimport torch.nn as nn\nfrom torch.optim.lr_scheduler import StepLR, MultiStepLR\nimport torch.nn.functional as F\nfrom utils import accuracy, AverageMeter, save_checkpoint, visualize_graph, get_parameters_size\nfrom torch.utils.tensorboard import SummaryWriter\nfrom net_factory import get_network_fn\n\n\nparser = argparse.ArgumentParser(description='PyTorch GCN MNIST Training')\nparser.add_argument('--epochs', default=50, type=int, metavar='N',\n                    help='number of total epochs to run')\nparser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                    help='number of data loading workers (default: 4)')\nparser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n                    help='manual epoch number (useful on restarts)')\nparser.add_argument('-b', '--batch-size', default=128, type=int,\n                    metavar='N', help='mini-batch size (default: 128)')\nparser.add_argument('--lr', '--learning-rate', default=0.01, type=float,\n                    metavar='LR', help='initial learning rate')\nparser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                    help='momentum')\nparser.add_argument('--print-freq', '-p', default=100, type=int,\n                    metavar='N', help='print frequency (default: 100)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH',\n                    help='path to latest checkpoint (default: none)')\nparser.add_argument('--pretrained', default='', type=str, metavar='PATH',\n                    help='path to pretrained checkpoint (default: none)')\nparser.add_argument('--gpu', default=-1, type=int,\n                    metavar='N', help='GPU device ID (default: -1)')\nparser.add_argument('--dataset_dir', default='../../MNIST', type=str, 
metavar='PATH',\n                    help='path to dataset (default: ../../MNIST)')\nparser.add_argument('--comment', default='', type=str, metavar='INFO',\n                    help='Extra description for tensorboard')\nparser.add_argument('--model', default='', type=str, metavar='NETWORK',\n                    help='Network to train')\nargs = parser.parse_args()\n\nuse_cuda = (args.gpu >= 0) and torch.cuda.is_available()\nbest_prec1 = 0\nwriter = SummaryWriter(comment='_'+args.model+'_'+args.comment)\niteration = 0\n\n# Prepare the MNIST dataset\nnormalize = transforms.Normalize((0.1307,), (0.3081,))\ntrain_transform = transforms.Compose([\n    transforms.ToTensor(),\n    normalize,\n    ])\ntest_transform = transforms.Compose([\n    transforms.ToTensor(), \n    normalize,\n    ])\n\n\ntrain_dataset = datasets.MNIST(root=args.dataset_dir, train=True, \n    \t\t\t\tdownload=True, transform=train_transform)\ntest_dataset = datasets.MNIST(root=args.dataset_dir, train=False, \n                    download=True,transform=test_transform)\n\ntrain_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,\n                                num_workers=args.workers, pin_memory=True, shuffle=True)\ntest_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,\n                                num_workers=args.workers, pin_memory=True, shuffle=True)\n\n# Load model\nmodel = get_network_fn(args.model)\nprint(model)\n\n# Try to visualize the model\ntry:\n\tvisualize_graph(model, writer, input_size=(1, 1, 28, 28))\nexcept:\n\tprint('\\nNetwork Visualization Failed! 
But the training procedure continue.')\n\n# optimizer = optim.Adadelta(model.parameters(), lr=args.lr, rho=0.9, eps=1e-06, weight_decay=3e-05)\n# optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=3e-05)\noptimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=3e-05)\nscheduler = StepLR(optimizer, step_size=10, gamma=0.5)\ncriterion = nn.CrossEntropyLoss()\n\ndevice = torch.device(\"cuda\" if use_cuda else \"cpu\")\nmodel = model.to(device)\ncriterion = criterion.to(device)\n\n# Calculate the total parameters of the model\nprint('Model size: {:0.2f} million float parameters'.format(get_parameters_size(model)/1e6))\n\nif args.pretrained:\n    if os.path.isfile(args.pretrained):\n        print(\"=> loading checkpoint '{}'\".format(args.pretrained))\n        checkpoint = torch.load(args.pretrained)\n        model.load_state_dict(checkpoint['state_dict'])\n    else:\n        print(\"=> no checkpoint found at '{}'\".format(args.pretrained))\n\ndef train(epoch):\n    model.train()\n    global iteration\n    st = time.time()\n    for batch_idx, (data, target) in enumerate(train_loader):\n        iteration += 1\n        data, target = data.to(device), target.to(device)\n        optimizer.zero_grad()\n        output = model(data)\n        prec1, = accuracy(output, target)\n        loss = criterion(output, target)\n        loss.backward()\n        optimizer.step()\n        if batch_idx % args.print_freq == 0:\n            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}, Accuracy: {:.2f}'.format(\n                epoch, batch_idx * len(data), len(train_loader.dataset),\n                100. 
* batch_idx / len(train_loader), loss.item(), prec1.item()))\n            writer.add_scalar('Loss/Train', loss.item(), iteration)\n            writer.add_scalar('Accuracy/Train', prec1, iteration)\n    epoch_time = time.time() - st\n    print('Epoch time:{:0.2f}s'.format(epoch_time))\n    scheduler.step()\n\ndef test(epoch):\n    model.eval()\n    test_loss = AverageMeter()\n    acc = AverageMeter()\n    with torch.no_grad():\n        for data, target in test_loader:\n            data, target = data.to(device), target.to(device)\n            output = model(data)\n            test_loss.update(F.cross_entropy(output, target, reduction='mean').item(), target.size(0))\n            prec1, = accuracy(output, target) # test precison in one batch\n            acc.update(prec1.item(), target.size(0))\n    print('\\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\\n'.format(test_loss.avg, acc.avg))\n    writer.add_scalar('Loss/Test', test_loss.avg, epoch)\n    writer.add_scalar('Accuracy/Test', acc.avg, epoch)\n    return acc.avg\n\nfor epoch in range(args.start_epoch, args.epochs):\n    print('------------------------------------------------------------------------')\n    train(epoch+1)\n    prec1 = test(epoch+1)\n\n    # remember best prec@1 and save checkpoint\n    is_best = prec1 > best_prec1\n    best_prec1 = max(prec1, best_prec1)\n    save_checkpoint({\n        'epoch': epoch + 1,\n        'state_dict': model.state_dict(),\n        'best_prec1': best_prec1,\n        'optimizer' : optimizer.state_dict(),\n    }, is_best)\n\nprint('Finished!')\nprint('Best Test Precision@top1:{:.2f}'.format(best_prec1))\nwriter.add_scalar('Best TOP1', best_prec1, 0)\nwriter.close()"
  },
  {
    "path": "demo/net_factory.py",
    "content": "from __future__ import division\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom gcn.layers import GConv\n\nclass GCN(nn.Module):\n    def __init__(self, channel=4):\n        super(GCN, self).__init__()\n        self.channel = channel\n        self.model = nn.Sequential(\n            GConv(1, 10, 5, padding=2, stride=1, M=channel, nScale=1, bias=False, expand=True),\n            nn.BatchNorm2d(10*channel),\n            nn.ReLU(inplace=True),\n\n            GConv(10, 20, 5, padding=2, stride=1, M=channel, nScale=2, bias=False),\n            nn.BatchNorm2d(20*channel),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(2,2),\n\n            GConv(20, 40, 5, padding=0, stride=1, M=channel, nScale=3, bias=False),\n            nn.BatchNorm2d(40*channel),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(2,2),\n\n            GConv(40, 80, 5, padding=0, stride=1, M=channel, nScale=4, bias=False),\n            nn.BatchNorm2d(80*channel),\n            nn.ReLU(inplace=True),\n        )\n        self.fc1 = nn.Linear(80, 1024)\n        self.relu = nn.ReLU(inplace=True)\n        self.dropout = nn.Dropout(p=0.5)\n        self.fc2 = nn.Linear(1024, 10)\n\n    def forward(self, x):\n        x = self.model(x)\n        # x = x.view(-1, self.channel, 80)\n        # x = torch.max(x, 1)[0]\n        ## x = x.view(-1, 80 * self.channel)\n        x = x.view(-1, 80, self.channel)\n        x = torch.max(x, 2)[0]\n        x = self.fc1(x)\n        x = self.relu(x)\n        x = self.dropout(x)\n        x = self.fc2(x)\n        return x\n\n\ndef get_network_fn(name):\n    networks_zoo = {\n    'gcn': GCN(channel=4),\n    }\n    if name is '':\n        raise ValueError('Specify the network to train. All networks available:{}'.format(networks_zoo.keys()))\n    elif name not in networks_zoo:\n        raise ValueError('Name of network unknown {}. 
All networks available:{}'.format(name, networks_zoo.keys()))\n    return networks_zoo[name]"
  },
  {
    "path": "demo/utils.py",
    "content": "import shutil\nimport torch\nfrom torch.utils.tensorboard import SummaryWriter\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\ndef accuracy(output, target, topk=(1,)):\n    \"\"\"Computes the precision@k for the specified values of k\"\"\"\n    with torch.no_grad():\n        maxk = max(topk)\n        batch_size = target.size(0)\n\n        _, pred = output.topk(maxk, 1, True, True)\n        pred = pred.t()\n        correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n        res = []\n        for k in topk:\n            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)\n            res.append(correct_k.mul_(100.0 / batch_size))\n        return res\n\ndef save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):\n    torch.save(state, filename)\n    if is_best:\n        shutil.copyfile(filename, 'model_best.pth.tar')\n\ndef visualize_graph(model, writer, input_size=(1, 3, 32, 32)):\n    dummy_input = torch.rand(input_size)\n    # with SummaryWriter(comment=name) as w:\n    writer.add_graph(model, (dummy_input, ))\n\ndef get_parameters_size(model):\n    total = 0\n    for p in model.parameters():\n        _size = 1\n        for i in range(len(p.size())):\n            _size *= p.size(i)\n        total += _size\n    return total"
  },
  {
    "path": "gcn/__init__.py",
    "content": ""
  },
  {
    "path": "gcn/csrc/GOF.h",
    "content": "#pragma once\n\n#include \"cpu/vision.h\"\n\n#ifdef WITH_CUDA\n#include \"cuda/vision.h\"\n#endif\n\n// Interface for Python\nat::Tensor GOF_forward(const at::Tensor& weight, \n                       const at::Tensor& gaborFilterBank) {\n  if (weight.type().is_cuda()) {\n#ifdef WITH_CUDA\n    return GOF_forward_cuda(weight, gaborFilterBank);\n#else\n    AT_ERROR(\"Not compiled with GPU support\");\n#endif\n  }\n  return GOF_forward_cpu(weight, gaborFilterBank);\n}\n\nat::Tensor GOF_backward(const at::Tensor& grad_output,\n                        const at::Tensor& gaborFilterBank) {\n  if (grad_output.type().is_cuda()) {\n#ifdef WITH_CUDA\n    return GOF_backward_cuda(grad_output, gaborFilterBank);\n#else\n    AT_ERROR(\"Not compiled with GPU support\");\n#endif\n  }\n  return GOF_backward_cpu(grad_output, gaborFilterBank);\n}\n\n"
  },
  {
    "path": "gcn/csrc/cpu/GOF_cpu.cpp",
    "content": "#include \"cpu/vision.h\"\n\n\ntemplate <typename T>\nvoid GOFForward_cpu_kernel(\n  const T* weight_data,\n  const T* gaborFilterBank_data,\n  const int nOutputPlane,\n  const int nInputPlane,\n  const int nChannel,\n  const int kH,\n  const int kW,\n  T* output_data) {\n  for (int i = 0; i < nOutputPlane; i++) {\n    for (int j = 0; j < nInputPlane; j++) {\n      for (int l = 0; l < nChannel * kH * kW; l++) {\n        T val = *(weight_data + i * (nInputPlane * nChannel * kH * kW)\n                              + j * (nChannel * kH * kW)\n                              + l);\n        for (int k = 0; k < nChannel; k++) {\n          T gabortmp = *(gaborFilterBank_data + k * (kW * kH) \n                                              + l % (kW * kH));\n          T *target = output_data + i * (nChannel * nInputPlane * nChannel * kH * kW)\n                                  + k * (nInputPlane * nChannel * kH * kW)\n                                  + j * (nChannel * kH * kW)\n                                  + l;\n          *target = val * gabortmp;\n        }\n      }\n    }\n  }\n}\n\ntemplate <typename T>\nvoid GOFBackward_cpu_kernel(\n  const T* grad_output_data,\n  const T* gaborFilterBank_data,\n  const int nOutputPlane,\n  const int nInputPlane,\n  const int nChannel,\n  const int kH,\n  const int kW,\n  T* grad_weight_data) {\n  const int nEntry = nChannel * kH * kW;\n\n  for (int i = 0; i < nOutputPlane; i++) {\n    for (int j = 0; j < nInputPlane; j++) {\n      for (int l = 0; l < nEntry; l++) {\n        T *val = grad_weight_data + i * (nInputPlane * nEntry)\n                                  + j * (nEntry) + l;\n        *val = 0;\n        for (int k = 0; k < nChannel; k++) {\n          T gabortmp = *(gaborFilterBank_data + k * (kW * kH)\n                                              + l % (kW * kH));\n          T target = *(grad_output_data + i * (nChannel * nInputPlane * nEntry)\n                                       + k * (nInputPlane * 
nEntry)\n                                       + j * (nEntry)\n                                       + l);\n          *val = *val + target * gabortmp;\n        }\n      }\n    }\n  }\n}\n\n\nat::Tensor GOF_forward_cpu(const at::Tensor& weight,\n                           const at::Tensor& gaborFilterBank) {\n  AT_ASSERTM(!weight.type().is_cuda(), \"weight must be a CPU tensor\");\n  AT_ASSERTM(!gaborFilterBank.type().is_cuda(), \"gaborFilterBank must be a CPU tensor\");\n\n  auto nOutputPlane = weight.size(0);\n  auto nInputPlane = weight.size(1);\n  auto nChannel = weight.size(2);\n  auto kH = weight.size(3);\n  auto kW = weight.size(4);\n\n  auto output = at::empty({nOutputPlane * nChannel, nInputPlane * nChannel, kH, kW}, weight.options());\n\n  if (output.numel() == 0) {\n    return output;\n  }\n\n  AT_DISPATCH_FLOATING_TYPES(weight.type(), \"GOF_forward\", [&] {\n    GOFForward_cpu_kernel<scalar_t>(\n         weight.data<scalar_t>(),\n         gaborFilterBank.data<scalar_t>(),\n         nOutputPlane,\n         nInputPlane,\n         nChannel,\n         kH,\n         kW,\n         output.data<scalar_t>());\n  });\n  return output;\n}\n\nat::Tensor GOF_backward_cpu(const at::Tensor& grad_output,\n                            const at::Tensor& gaborFilterBank) {\n  AT_ASSERTM(!grad_output.type().is_cuda(), \"grad_output must be a CPU tensor\");\n  AT_ASSERTM(!gaborFilterBank.type().is_cuda(), \"gaborFilterBank must be a CPU tensor\");\n\n  auto nChannel = gaborFilterBank.size(0);\n  auto nOutputPlane = grad_output.size(0) / nChannel;\n  auto nInputPlane = grad_output.size(1) / nChannel;\n  auto kH = grad_output.size(2);\n  auto kW = grad_output.size(3);\n\n  auto grad_weight = at::empty({nOutputPlane, nInputPlane, nChannel, kH, kW}, grad_output.options());\n\n  if (grad_weight.numel() == 0) {\n    return grad_weight;\n  }\n\n  AT_DISPATCH_FLOATING_TYPES(grad_output.type(), \"GOF_backward\", [&] {\n    GOFBackward_cpu_kernel<scalar_t>(\n         
grad_output.data<scalar_t>(),\n         gaborFilterBank.data<scalar_t>(),\n         nOutputPlane,\n         nInputPlane,\n         nChannel,\n         kH,\n         kW,\n         grad_weight.data<scalar_t>());\n  });\n  return grad_weight;\n}"
  },
  {
    "path": "gcn/csrc/cpu/vision.h",
    "content": "#pragma once\n#include <torch/extension.h>\n\n\nat::Tensor GOF_forward_cpu(const at::Tensor& weight, \n                           const at::Tensor& gaborFilterBank);\n\nat::Tensor GOF_backward_cpu(const at::Tensor& grad_output,\n                            const at::Tensor& gaborFilterBank);\n"
  },
  {
    "path": "gcn/csrc/cuda/GOF_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <THC/THC.h>\n#include <THC/THCAtomics.cuh>\n#include <THC/THCDeviceUtils.cuh>\n\n// TODO make it in a common file\n#define CUDA_1D_KERNEL_LOOP(i, n)                            \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \\\n       i += blockDim.x * gridDim.x)\n\n\ntemplate <typename T>\n__global__ void GOFForward_cuda_kernel(const int nthreads,\n                                       const T* weight_data,\n                                       const T* gaborFilterBank_data, \n                                       const int nOutputPlane,\n                                       const int nInputPlane,\n                                       const int nChannel,\n                                       const int kH,\n                                       const int kW,\n                                       T* output_data) {\n  CUDA_1D_KERNEL_LOOP(index, nthreads) {\n    auto w = index % kW;\n    auto h = (index / kW) % kH;\n    auto c = (index / kW / kH) % nChannel;\n    auto in = (index / kW / kH / nChannel) % nInputPlane;\n    auto ori = (index / kW / kH / nChannel / nInputPlane) % nChannel;\n    auto ou = index / kW / kH / nChannel / nInputPlane / nChannel;\n    T val = *(weight_data + (((ou * nInputPlane + in) * nChannel + c) * kH + h) * kW + w);\n    T *target = output_data + index;\n    T gabortmp = *(gaborFilterBank_data + ori * (kH * kW)\n                                        + h * kW\n                                        + w);\n    *target = val * gabortmp;\n  }\n}\n\ntemplate <typename T>\n__global__ void GOFBackward_cuda_kernel(const int nthreads,\n                                       const T* grad_output_data,\n                                       const T* gaborFilterBank_data, \n                                       const int nOutputPlane,\n                                       const int nInputPlane,\n                                     
  const int nChannel,\n                                       const int kH,\n                                       const int kW,\n                                       T* grad_weight_data) {\n  auto nEntry = nChannel * kH * kW;\n  CUDA_1D_KERNEL_LOOP(index, nthreads) {\n    auto l = index % nEntry;\n    auto j = (index / nEntry) % nInputPlane;\n    auto i = index / nEntry / nInputPlane;\n    T *val = grad_weight_data + index;\n    *val = 0;\n    for (int k = 0; k < nChannel; k++) {\n      T gabortmp = *(gaborFilterBank_data + k * (kW * kH)\n                                          + l % (kW * kH));\n      T target = *(grad_output_data + i * (nChannel * nInputPlane * nEntry)\n                                    + k * (nInputPlane * nEntry)\n                                    + j * (nEntry)\n                                    + l);     \n\t\t\t*val = *val + target * gabortmp;\n    }\n  }\n}\n\nat::Tensor GOF_forward_cuda(const at::Tensor& weight,\n                            const at::Tensor& gaborFilterBank) {\n  AT_ASSERTM(weight.type().is_cuda(), \"weight must be a CUDA tensor\");\n  AT_ASSERTM(gaborFilterBank.type().is_cuda(), \"gaborFilterBank must be a CUDA tensor\");\n\n  auto nOutputPlane = weight.size(0);\n  auto nInputPlane = weight.size(1);\n  auto nChannel = weight.size(2);\n  auto kH = weight.size(3);\n  auto kW = weight.size(4);\n\n  auto output = at::empty({nOutputPlane * nChannel, nInputPlane * nChannel, kH, kW}, weight.options());\n  // auto nEntry = nChannel * kH * kW;\n  auto output_size = nOutputPlane * nChannel* nInputPlane * nChannel * kH * kW;\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));\n  dim3 block(512);\n\n  if (output.numel() == 0) {\n    THCudaCheck(cudaGetLastError());\n    return output;\n  }\n\n  AT_DISPATCH_FLOATING_TYPES(weight.type(), \"GOF_forward\", [&] {\n    GOFForward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(\n      output_size,\n      
weight.data<scalar_t>(),\n      gaborFilterBank.data<scalar_t>(),\n      nOutputPlane,\n      nInputPlane,\n      nChannel,\n      kH,\n      kW,\n      output.data<scalar_t>());\n  });\n  THCudaCheck(cudaGetLastError());\n  return output;\n}\n\nat::Tensor GOF_backward_cuda(const at::Tensor& grad_output,\n                             const at::Tensor& gaborFilterBank) {\n  AT_ASSERTM(grad_output.type().is_cuda(), \"grad_output must be a CUDA tensor\");\n  AT_ASSERTM(gaborFilterBank.type().is_cuda(), \"gaborFilterBank must be a CUDA tensor\");\n\n  auto nChannel = gaborFilterBank.size(0);\n  auto nOutputPlane = grad_output.size(0) / nChannel;\n  auto nInputPlane = grad_output.size(1) / nChannel;\n  auto kH = grad_output.size(2);\n  auto kW = grad_output.size(3);\n\n  auto grad_weight = at::empty({nOutputPlane, nInputPlane, nChannel, kH, kW}, grad_output.options());\n  auto nEntry = nChannel * kH * kW;\n  auto grad_weight_size = nOutputPlane * nInputPlane * nEntry;\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  dim3 grid(std::min(THCCeilDiv(grad_weight_size, 512L), 4096L));\n  dim3 block(512);\n\n  if (grad_weight.numel() == 0) {\n    THCudaCheck(cudaGetLastError());\n    return grad_weight;\n  }\n\n  AT_DISPATCH_FLOATING_TYPES(grad_output.type(), \"GOF_backward\", [&] {\n    GOFBackward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(\n      grad_weight_size,\n      grad_output.data<scalar_t>(),\n      gaborFilterBank.data<scalar_t>(),\n      nOutputPlane,\n      nInputPlane,\n      nChannel,\n      kH,\n      kW,\n      grad_weight.data<scalar_t>());\n  });\n  THCudaCheck(cudaGetLastError());\n  return grad_weight;\n}"
  },
  {
    "path": "gcn/csrc/cuda/vision.h",
    "content": "#pragma once\n#include <torch/extension.h>\n\n\nat::Tensor GOF_forward_cuda(const at::Tensor& weight, \n                            const at::Tensor& gaborFilterBank);\n\nat::Tensor GOF_backward_cuda(const at::Tensor& grad_output,\n                             const at::Tensor& gaborFilterBank);"
  },
  {
    "path": "gcn/csrc/vision.cpp",
    "content": "#include \"GOF.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"gof_forward\", &GOF_forward, \"GOF forward\");\n  m.def(\"gof_backward\", &GOF_backward, \"GOF backward\");\n}\n"
  },
  {
    "path": "gcn/layers/GConv.py",
    "content": "from __future__ import division\nimport math\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom torch.autograd import Function\nfrom torch.nn.modules.utils import _pair\nfrom torch.nn.modules.conv import _ConvNd\nfrom torch.autograd.function import once_differentiable\n\nfrom gcn import _C\n\nclass GOF_Function(Function):\n    @staticmethod\n    def forward(ctx, weight, gaborFilterBank):\n        ctx.save_for_backward(weight, gaborFilterBank)\n        output = _C.gof_forward(weight, gaborFilterBank)\n        return output\n\n    @staticmethod\n    @once_differentiable\n    def backward(ctx, grad_output):\n        weight, gaborFilterBank = ctx.saved_tensors\n        grad_weight = _C.gof_backward(grad_output, gaborFilterBank)\n        return grad_weight, None \n\nclass MConv(_ConvNd):\n    '''\n    Baee layer class for modulated convolution\n    '''\n    def __init__(self, in_channels, out_channels, kernel_size, M=4, nScale=3, stride=1,\n                    padding=0, dilation=1, groups=1, bias=True, expand=False, padding_mode='zeros'):\n        if groups != 1:\n            raise ValueError('Group-conv not supported!')\n        kernel_size = (M,) + _pair(kernel_size)\n        stride = _pair(stride)\n        padding = _pair(padding)\n        dilation = _pair(dilation)\n        super(MConv, self).__init__(\n            in_channels, out_channels, kernel_size, stride, padding, dilation,\n            False, _pair(0), groups, bias, padding_mode)\n        self.expand = expand\n        self.M = M\n        self.need_bias = bias\n        self.generate_MFilters(nScale, kernel_size)\n        self.GOF_Function = GOF_Function.apply\n\n    def generate_MFilters(self, nScale, kernel_size):\n        raise NotImplementedError\n\n    def forward(self, x):\n        if self.expand:\n            x = self.do_expanding(x)\n        new_weight = self.GOF_Function(self.weight, self.MFilters)\n        new_bias = self.expand_bias(self.bias) if 
self.need_bias else self.bias\n        if self.padding_mode == 'circular':\n            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,\n                                (self.padding[0] + 1) // 2, self.padding[0] // 2)\n            return F.conv2d(F.pad(input, expanded_padding, mode='circular'),\n                            self.weight, self.bias, self.stride,\n                            _pair(0), self.dilation, self.groups)\n        return F.conv2d(x, new_weight, new_bias, self.stride,\n                self.padding, self.dilation, self.groups)\n\n    def do_expanding(self, x):\n        index = []\n        for i in range(x.size(1)):\n            for _ in range(self.M):\n                index.append(i)\n        index = torch.LongTensor(index).cuda() if x.is_cuda else torch.LongTensor(index)\n        return x.index_select(1, index)\n    \n    def expand_bias(self, bias):\n        index = []\n        for i in range(bias.size()):\n            for _ in range(self.M):\n                index.append(i)\n        index = torch.LongTensor(index).cuda() if bias.is_cuda else torch.LongTensor(index)\n        return bias.index_select(0, index)\n\nclass GConv(MConv):\n    '''\n    Gabor Convolutional Operation Layer\n    '''\n    def __init__(self, in_channels, out_channels, kernel_size, M=4, nScale=3, stride=1,\n                    padding=0, dilation=1, groups=1, bias=True, expand=False, padding_mode='zeros'):\n        super(GConv, self).__init__(in_channels, out_channels, kernel_size, M, nScale, stride,\n                    padding, dilation, groups, bias, expand, padding_mode)\n\n    def generate_MFilters(self, nScale, kernel_size):\n        # To generate Gabor Filters\n        self.register_buffer('MFilters', getGaborFilterBank(nScale, *kernel_size))\n\ndef getGaborFilterBank(nScale, M, h, w):\n    Kmax = math.pi / 2\n    f = math.sqrt(2)\n    sigma = math.pi\n    sqsigma = sigma ** 2\n    postmean = math.exp(-sqsigma / 2)\n    if h != 1:\n        
gfilter_real = torch.zeros(M, h, w)\n        for i in range(M):\n            theta = i / M * math.pi\n            k = Kmax / f ** (nScale - 1)\n            xymax = -1e309\n            xymin = 1e309\n            for y in range(h):\n                for x in range(w):\n                    y1 = y + 1 - ((h + 1) / 2)\n                    x1 = x + 1 - ((w + 1) / 2)\n                    tmp1 = math.exp(-(k * k * (x1 * x1 + y1 * y1) / (2 * sqsigma)))\n                    tmp2 = math.cos(k * math.cos(theta) * x1 + k * math.sin(theta) * y1) - postmean # For real part\n                    # tmp3 = math.sin(k*math.cos(theta)*x1+k*math.sin(theta)*y1) # For imaginary part\n                    gfilter_real[i][y][x] = k * k * tmp1 * tmp2 / sqsigma\t\t\t\n                    xymax = max(xymax, gfilter_real[i][y][x])\n                    xymin = min(xymin, gfilter_real[i][y][x])\n            gfilter_real[i] = (gfilter_real[i] - xymin) / (xymax - xymin)\n    else:\n        gfilter_real = torch.ones(M, h, w)\n    return gfilter_real"
  },
  {
    "path": "gcn/layers/__init__.py",
    "content": "from .GConv import GConv\n\n__all__=['GConv']"
  },
  {
    "path": "gcn/layers/gradtest.py",
    "content": "import torch\nfrom torch.autograd import gradcheck\nfrom gcn.layers.GConv import GOF_Function\n\ndef gradchecking(use_cuda=False):\n    print('-'*80)\n    GOF = GOF_Function.apply\n    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n\n    weight = torch.randn(8,8,4,3,3).to(device).double().requires_grad_()\n    gfb = torch.randn(4,3,3).to(device).double()\n\n    test = gradcheck(GOF, (weight, gfb), eps=1e-6, atol=1e-4, rtol=1e-3, raise_exception=True)\n    print(test)\n\n\nif __name__ == \"__main__\":\n    gradchecking()\n    if torch.cuda.is_available():\n        gradchecking(use_cuda=True)"
  },
  {
    "path": "install.sh",
    "content": "python setup.py build develop\n\n# or if you are on macOS\n# MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py build develop"
  },
  {
    "path": "setup.py",
    "content": "#!/usr/bin/env python\n\nimport glob\nimport os\n\nimport torch\nfrom setuptools import find_packages\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import CUDA_HOME\nfrom torch.utils.cpp_extension import CppExtension\nfrom torch.utils.cpp_extension import CUDAExtension\n\nrequirements = [\"torch\", \"torchvision\"]\n\n\ndef get_extensions():\n    this_dir = os.path.dirname(os.path.abspath(__file__))\n    extensions_dir = os.path.join(this_dir, \"gcn\", \"csrc\")\n\n    main_file = glob.glob(os.path.join(extensions_dir, \"*.cpp\"))\n    source_cpu = glob.glob(os.path.join(extensions_dir, \"cpu\", \"*.cpp\"))\n    source_cuda = glob.glob(os.path.join(extensions_dir, \"cuda\", \"*.cu\"))\n\n    sources = main_file + source_cpu\n    extension = CppExtension\n\n    extra_compile_args = {\"cxx\": []}\n    define_macros = []\n\n    if torch.cuda.is_available() and CUDA_HOME is not None:\n        extension = CUDAExtension\n        sources += source_cuda\n        define_macros += [(\"WITH_CUDA\", None)]\n        extra_compile_args[\"nvcc\"] = [\n            \"-DCUDA_HAS_FP16=1\",\n            \"-D__CUDA_NO_HALF_OPERATORS__\",\n            \"-D__CUDA_NO_HALF_CONVERSIONS__\",\n            \"-D__CUDA_NO_HALF2_OPERATORS__\",\n        ]\n\n    sources = [os.path.join(extensions_dir, s) for s in sources]\n\n    include_dirs = [extensions_dir]\n\n    ext_modules = [\n        extension(\n            \"gcn._C\",\n            sources,\n            include_dirs=include_dirs,\n            define_macros=define_macros,\n            extra_compile_args=extra_compile_args,\n        )\n    ]\n\n    return ext_modules\n\n\nsetup(\n    name=\"gcn\",\n    version=\"0.1\",\n    author=\"gujiaxin\",\n    url=\"https://github.com/jxgu1016/Gabor_CNN_PyTorch\",\n    description=\"Gabor Convolutional Networks in pytorch\",\n    packages=find_packages(),\n    # install_requires=requirements,\n    ext_modules=get_extensions(),\n    cmdclass={\"build_ext\": 
torch.utils.cpp_extension.BuildExtension},\n)\n"
  }
]