[
  {
    "path": ".gitignore",
    "content": "data/\r\n*.txt\r\n*.pkl\r\n*.pyc"
  },
  {
    "path": "README.md",
    "content": "# Aggregation Cross-Entropy for Sequence Recognition\r\nThis repository contains the code for the paper **Aggregation Cross-Entropy for Sequence Recognition**. Zecheng Xie, Yaoxiong Huang, Yuanzhi Zhu, Lianwen Jin, Yuliang Liu and Lele Xie. CVPR. 2019. [\\[Paper\\]](https://arxiv.org/abs/1904.08364)\r\n\r\nConnectionist temporal classification (CTC) and attention mechanism are the most popular methods for sequence-learning problem. However, CTC relies on a sophisticated forward-backward algorithm for transcription, which prevents it from addressing two-dimensional (2D) prediction problem, whereas the attention mechanism leans on a complex attention module to fulfill its functionality, resulting in additional network parameters and runtime. \r\n\r\nIn this paper, we propose a novel method, aggregation cross-entropy (ACE), for sequence recognition from a brand new perspective. The ACE loss function exhibits competitive performance to CTC and the attention mechanism, with much quicker implementation (as it involves only four fundamental formulas), faster inference\\back-propagation (approximately *O(1)* in parallel), less storage requirement (no parameter and negligible runtime memory), and convenient employment (by replacing CTC with ACE). Furthermore, the proposed ACE loss function exhibits two noteworthy properties: (1) it can be directly applied for 2D prediction by flattening the 2D prediction into 1D prediction as the input and (2) it requires only characters and their numbers in the sequence annotation for supervision, which allows it to advance beyond sequence recognition, e.g., counting problem.\r\n\r\n![](./image/1.jpg)\r\nFigure 1: Illustration of proposed ACE loss function. Generally, the 1D and 2D predictions are generated by integrated CNN-LSTM and FCN model, respectively. For the ACE loss function, the 2D prediction is further flattened to 1D prediction. 
During aggregation, the 1D predictions at all time-steps are accumulated for each class independently. After normalization, the prediction, together with the ground-truth, is utilized for loss estimation based on cross-entropy.\r\n\r\n![](./image/2.jpg)\r\nFigure 2: Toy example to show the advantage of ACE loss function. Resnet-50 trained with ACE loss function is able to recognize shuffled characters in the images. For each sub-image, the right column shows the 2D prediction of the recognition model for the text images. It is noteworthy that they have similar character distributions in the 2D space.\r\n\r\n## Requirements\r\n- [Python 2.7](https://www.python.org/) \r\n- [PyTorch >= 0.4.1](https://pytorch.org/) \r\n- [TorchVision](https://pypi.org/project/torchvision/)\r\n- [OpenCV](https://opencv.org/)\r\n\r\n## Data Preparation\r\ntar -xzvf data.tar.gz\r\n\r\n## Training and Testing\r\nStart training: (in 'source/' folder)\r\n```bash\r\n  sh train.sh\r\n```\r\n- The training process should take **about 10s** for 100 iterations on a 1080Ti.\r\n\r\n## Citation\r\n```\r\n@inproceedings{xie2019ace,\r\n  title     = {Aggregation Cross-Entropy for Sequence Recognition},\r\n  author    = {Zecheng Xie, Yaoxiong Huang, Yuanzhi Zhu, Lianwen Jin, Yuliang Liu and Lele Xie},\r\n  booktitle = {CVPR}, \r\n  year      = {2019},\r\n}\r\n```\r\n\r\n## Attention\r\nThe project is only free for academic research purposes."
  },
  {
    "path": "log/log/.gitkeep",
    "content": "# Ignore everything in this directory \r\n* \r\n# Except this file !.gitkeep "
  },
  {
    "path": "log/snapshot/.gitkeep",
    "content": "# Ignore everything in this directory \r\n* \r\n# Except this file !.gitkeep "
  },
  {
    "path": "source/main.py",
    "content": "# -*- coding: utf-8 -*-\nfrom __future__ import print_function, division\nimport torch\nimport argparse\nimport numpy as np\nimport torch.nn as nn\nfrom torch import optim\nimport torch.nn.functional as F\nfrom models.seq_module import ACE\nfrom torch.autograd import Variable\nfrom models.solver import seq_solver\nfrom utils.basic import timeSince\nfrom torch.utils.data import DataLoader\nfrom utils.data_loader import ImageDataset\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--model_path', type=str, default='../log/snapshot/model-{:0>2d}.pkl')\nparser.add_argument('--total_epoch', type=int, default=50, help='total epoch number')\nparser.add_argument('--train_path', type=str, default='../data/train.txt')\nparser.add_argument('--test_path', type=str, default='../data/test.txt')\nparser.add_argument('--train_batch_size', type=int, default=50, help='training batch size')\nparser.add_argument('--test_batch_size', type=int, default=50, help='testing batch size')\nparser.add_argument('--last_epoch', type=int, default=0, help='last epoch')\nparser.add_argument('--class_num', type=int, default=26, help='class number')\nparser.add_argument('--dict', type=str, default='_abcdefghijklmnopqrstuvwxyz')\nopt = parser.parse_args()\nprint(opt)\n\nimport torchvision.models as models\n\nclass ResnetEncoderDecoder(nn.Module):\n    def __init__(self, loss_layer):\n        super(ResnetEncoderDecoder, self).__init__()\n        self.bn = nn.BatchNorm2d(64)\n        resnet = models.resnet18(pretrained=True)\n        self.conv  = nn.Conv2d(1,   64,   kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))\n        self.cnn = nn.Sequential(*list(resnet.children())[4:-2])\n        self.out = nn.Linear(512, opt.class_num+1)\n        self.loss_layer = loss_layer(opt.dict)\n\n    def forward(self, input, labels):\n        input = F.relu(self.bn(self.conv(input)), True)\n        input = F.max_pool2d(input, kernel_size=(2, 2), stride=(2, 2)) \n        input = 
self.cnn(input)\n\n        input = input.permute(0,2,3,1)\n        input = F.softmax(self.out(input),dim=-1)\n\n        labels = labels.cuda()\n\n        return  self.loss_layer(input,labels)\n\n\nif __name__ == \"__main__\":\n\n    model = ResnetEncoderDecoder(ACE).cuda()\n    print(model)\n\n    optimizer = optim.Adadelta(model.parameters())\n\n    if opt.last_epoch != 0:\n        check_point = torch.load(opt.model_path.format(opt.last_epoch))\n        model.load_state_dict(check_point['state_dict'])\n        optimizer.load_state_dict(check_point['optimizer'])\n        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones = [opt.total_epoch], gamma = 0.1, last_epoch = opt.last_epoch)    \n    else:\n        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones = [opt.total_epoch], gamma = 0.1)    \n\n\n    train_set = ImageDataset(file_name = opt.train_path, length = 5000, class_num = opt.class_num)\n    lmdb_train = DataLoader(train_set, batch_size=opt.train_batch_size, shuffle=True, num_workers=0) \n\n    test_set = ImageDataset(file_name = opt.test_path, length = 1000, class_num = opt.class_num)\n    lmdb_test = DataLoader(test_set, batch_size=opt.test_batch_size, shuffle=False, num_workers=0) \n\n\n    the_solver = seq_solver(model = model,\n                        lmdb = [lmdb_train, lmdb_test],\n                        optimizer = optimizer, \n                        scheduler = scheduler,\n                        total_epoch = opt.total_epoch,\n                        model_path = opt.model_path,\n                        last_epoch = opt.last_epoch)\n\n    the_solver.forward()\n\n"
  },
  {
    "path": "source/models/__init__.py",
    "content": ""
  },
  {
    "path": "source/models/seq_module.py",
    "content": "# -*- coding: utf-8 -*-\nimport math\nimport torch\nimport random\nimport itertools\nimport numpy as np\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\n\nclass Sequence(nn.Module):\n    def __init__(self):\n        super(Sequence, self).__init__()\n\n    def result_analysis(self, iteration):\n        pass;\n\n\n\nclass ACE(Sequence):\n\n    def __init__(self, dictionary):\n        super(ACE, self).__init__()\n        self.softmax = None;\n        self.label = None;\n        self.dict=dictionary\n\n    def forward(self, input, label):\n\n        self.bs,self.h,self.w,_ = input.size()\n        T_ = self.h*self.w\n\n        input = input.view(self.bs,T_,-1)\n        input = input + 1e-10\n\n        self.softmax = input\n        label[:,0] = T_ - label[:,0]\n        self.label = label\n\n        # ACE Implementation (four fundamental formulas)\n        input = torch.sum(input,1)\n        input = input/T_\n        label = label/T_\n        loss = (-torch.sum(torch.log(input)*label))/self.bs\n\n        return loss\n\n\n    def decode_batch(self):\n        out_best = torch.max(self.softmax, 2)[1].data.cpu().numpy()\n        pre_result = [0]*self.bs\n        for j in range(self.bs):\n            pre_result[j] = out_best[j][out_best[j]!=0]\n        return pre_result\n\n\n    def vis(self,iteration):\n\n        sn = random.randint(0,self.bs-1)\n        print('Test image %4d:' % (iteration*50+sn))\n\n        pred = torch.max(self.softmax, 2)[1].data.cpu().numpy()\n        pred = pred[sn].tolist() # sample #0\n        pred_string = ''.join(['%2s' % self.dict[pn] for pn in pred])\n        pred_string_set = [pred_string[i:i+self.w*2] for i in xrange(0, len(pred_string), self.w*2)]\n        print('Prediction: ')\n        for pre_str in pred_string_set:\n            print(pre_str)\n        label = ''.join(['%2s:%2d'%(self.dict[idx],pn) for idx, pn in enumerate(self.label[sn]) if idx != 0 and pn != 0])\n        label = 
'Label: ' + label\n        print(label)\n\n\n\n    def result_analysis(self, iteration):\n        prediction = self.decode_batch()\n        correct_count = 0\n        pre_total = 0\n        len_total = self.label[:,1:].sum()\n        label_data = self.label.data.cpu().numpy()\n        for idx, pre_list in enumerate(prediction):\n            for pw in pre_list:\n                if label_data[idx][pw] > 0:\n                    correct_count = correct_count + 1\n                    label_data[idx][pw] -= 1\n\n            pre_total += len(pre_list)  \n\n        if not self.training and random.random() < 0.05:\n            self.vis(iteration)\n\n        return correct_count, len_total, pre_total  \n\n"
  },
  {
    "path": "source/models/solver.py",
    "content": "import time\nimport torch\nimport numpy as np\nfrom torch.autograd import Variable\nfrom utils.basic import timeSince\n\nclass solver():\n\n\tdef __init__(self, model, lmdb, optimizer, scheduler, total_epoch, model_path, last_epoch):\n\n\t\tself.model = model\n\t\tprint(self.model)\n\n\t\tself.lmdb_train, self.lmdb_test = lmdb\n\t\tself.optimizer = optimizer\n\t\tself.scheduler = scheduler\n\t\tself.total_epoch = total_epoch\n\t\tself.model_path = model_path\n\t\tself.last_epoch = last_epoch\n\n\t\tself.start = time.time()\n\n\tdef train_one_epoch(self, ep):\n\t\tpass\n\tdef test_one_epoch(self, ep):\n\t\tpass\n\n\tdef forward(self):\n\t\tfor ep in range(self.total_epoch-self.last_epoch):\n\t\t\tep = ep+self.last_epoch\n\t\t\tself.train_one_epoch(ep)\n\t\t\tself.test_one_epoch(ep)\n\t\t\nimport pdb\nclass seq_solver(solver):\n\n\tdef train_one_epoch(self, ep):\n\t\tself.model.train()\n\t\tloss_aver = 0\n\t\tif self.scheduler is not None:\n\t\t\tself.scheduler.step()\n\t\t\tprint('learning_rate: ', self.scheduler.get_lr())\t\n\t\tfor it, sample_batched in enumerate(self.lmdb_train):\n\t\t\tinputs = sample_batched['image'].squeeze(0)\n\t\t\tlabels = sample_batched['label'].squeeze(0)\n\n\t\t\tinputs = Variable(inputs.cuda())\n\t\t\tloss = self.model(inputs, labels)\n\t\t\tself.optimizer.zero_grad()\n\t\t\tloss.backward()\n\t\t\tloss = loss.data.item()\n\t\t\tl2_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),10)\n\n\t\t\tif not np.isnan(l2_norm):\n\t\t\t\tself.optimizer.step()\n\t\t\telse:\n\t\t\t\tprint('l2_norm: ', l2_norm)\n\t\t\t\tl2_norm = 0\n\n\t\t\tif it == 0:\n\t\t\t\tloss_aver = loss\n\t\t\tloss_aver = 0.9*loss_aver+0.1*loss\t\t\t  \n\t\t\tif it == len(self.lmdb_train)-1:\n\n\t\t\t\tcorrect_count, len_total, pre_total = self.model.loss_layer.result_analysis(it)\n\n\t\t\t\trecall = float(correct_count) / len_total\n\t\t\t\tprecision = correct_count / (pre_total+0.000001)\n\n\t\t\t\tprint('Train: %10s Epoch: %3d it: %6d, loss: 
%.4f, l2_norm: %.4f, recall: %.4f, precision: %.4f' % \n\t\t\t\t\t(timeSince(self.start), ep, it, loss_aver, l2_norm, recall, precision))\n\n\t\ttorch.save({\n\t\t\t'epoch': ep,\n\t\t\t'state_dict': self.model.state_dict(),\n\t\t\t'optimizer' : self.optimizer.state_dict(),\n\t\t\t}, self.model_path.format(ep)) \t\n\n\n\tdef test_one_epoch(self, ep):\n\t\tself.model.eval()\n\t\tloss_aver = 0\n\t\tfor it, sample_batched in enumerate(self.lmdb_test):\n\t\t\tinputs = sample_batched['image'].squeeze(0)\n\t\t\tlabels = sample_batched['label'].squeeze(0)\n\n\t\t\tinputs = Variable(inputs.cuda())\n\t\t\tloss = self.model(inputs, labels)\n\t\t\tcorrect_count, len_total, pre_total = self.model.loss_layer.result_analysis(it)\n\n\t\t\tloss = loss.data.item()\n\t\t\tif it == 0:\n\t\t\t\tloss_aver = loss\n\t\t\tloss_aver = 0.9*loss_aver+0.1*loss\t\t\n\n\t\t\tif it == len(self.lmdb_test) -1:\n\t\t\t\trecall = float(correct_count) / len_total\n\t\t\t\tprecision = correct_count / (pre_total+0.000001)\t\n\t\t\t\tprint('Test : %10s Epoch: %3d it: %6d, loss: %.4f, len : %4d, recall: %.4f, precision: %.4f' % \n\t\t\t\t\t\t\t(timeSince(self.start), ep, it, loss_aver, len_total, recall, precision))\t\n\n\n"
  },
  {
    "path": "source/train.sh",
    "content": "#!/usr/bin/env bash\n\nfilename=\"../log/log/log_`date +%y_%m_%d_%H_%M_%S`.txt\"\nCUDA_VISIBLE_DEVICES=0 python -u main.py \\\n\t2>&1 | tee $filename\n\n\n\n\t"
  },
  {
    "path": "source/utils/__init__.py",
    "content": ""
  },
  {
    "path": "source/utils/basic.py",
    "content": "import time\r\nimport math\r\n\r\n\r\ndef asMinutes(s):\r\n    m = math.floor(s / 60)\r\n    s -= m * 60\r\n    return '%dm %ds' % (m, s)\r\n\r\n\r\ndef timeSince(since):\r\n    now = time.time()\r\n    s = now - since\r\n    return '%s' % (asMinutes(s))\r\n"
  },
  {
    "path": "source/utils/data_loader.py",
    "content": "import cv2\nimport torch\nimport numpy as np\nfrom torch.utils.data import Dataset, DataLoader\n\n\nclass ImageDataset(Dataset):\n    \"\"\"Face Landmarks dataset.\"\"\"\n\n    def __init__(self, file_name, length, class_num, transform=None):\n        \"\"\"\n        Args:\n            file_name (string): Path to the files with images and their annotations.\n            length (string): image number.\n            class_num (int): class number.\n        \"\"\"\n        with open(file_name) as fh:\n            self.img_and_label = fh.readlines()\n        self.length = length\n        self.transform = transform\n        self.class_num = class_num\n\n    def __len__(self):\n        return self.length\n\n    def __getitem__(self, idx):\n\n        img_and_label = self.img_and_label[idx].strip()\n        pth, word = img_and_label.split(' ') # image path and its annotation\n\n        image = cv2.imread(pth,0)\n        image = cv2.pyrDown(image).astype('float32') # 100*100\n\n        word = [ord(var)-97 for var in word] # a->0\n\n        label = np.zeros((self.class_num+1)).astype('float32')\n\n        for ln in word:\n            label[int(ln+1)] += 1 # label construction for ACE\n\n        label[0] = len(word)\n\n        sample = {'image': image, 'label': label}\n\n        sample = {'image': torch.from_numpy(image).unsqueeze(0), 'label': torch.from_numpy(label)}\n\n        if self.transform:\n            sample = self.transform(sample)\n\n        return sample    \n\n\n"
  }
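  ,
  {
    "path": "source/demo_ace.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"Minimal standalone sketch of the ACE loss (this file is not part of the\noriginal release; it is an illustrative example, runnable on CPU). It mirrors\nthe four-step computation in models/seq_module.py: aggregate the per-class\nprobabilities over all time-steps, normalize the aggregation, normalize the\nground-truth counts, and compute the cross-entropy between the two.\"\"\"\nimport torch\n\n\ndef ace_loss(probs, counts):\n    # probs:  (batch, T, classes) softmax outputs, flattened over time/space\n    # counts: (batch, classes); counts[:, 0] initially holds the sequence\n    #         length |S| and is replaced by the blank count T - |S|\n    bs, T_, _ = probs.size()\n    probs = probs + 1e-10  # numerical stability, as in seq_module.py\n    counts = counts.clone()\n    counts[:, 0] = T_ - counts[:, 0]\n    aggregated = probs.sum(1) / T_  # formulas 1-2: aggregation and normalization\n    counts = counts / T_            # formula 3: normalize the counts\n    return -(torch.log(aggregated) * counts).sum() / bs  # formula 4: cross-entropy\n\n\nif __name__ == '__main__':\n    probs = torch.softmax(torch.randn(2, 36, 27), dim=-1)  # e.g. a flattened 6x6 2D prediction\n    counts = torch.zeros(2, 27)\n    counts[:, 1] = 2  # each sample contains the character 'a' twice\n    counts[:, 0] = 2  # sequence length |S|\n    print(ace_loss(probs, counts))\n"
  }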
]