Repository: liaoxy169/Aggregation-Cross-Entropy Branch: master Commit: a31ccec38d6e Files: 12 Total size: 13.9 KB Directory structure: gitextract_44y1fogf/ ├── .gitignore ├── README.md ├── log/ │ ├── log/ │ │ └── .gitkeep │ └── snapshot/ │ └── .gitkeep └── source/ ├── main.py ├── models/ │ ├── __init__.py │ ├── seq_module.py │ └── solver.py ├── train.sh └── utils/ ├── __init__.py ├── basic.py └── data_loader.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ data/ *.txt *.pkl *.pyc ================================================ FILE: README.md ================================================ # Aggregation Cross-Entropy for Sequence Recognition This repository contains the code for the paper **Aggregation Cross-Entropy for Sequence Recognition**. Zecheng Xie, Yaoxiong Huang, Yuanzhi Zhu, Lianwen Jin, Yuliang Liu and Lele Xie. CVPR. 2019. [\[Paper\]](https://arxiv.org/abs/1904.08364) Connectionist temporal classification (CTC) and attention mechanism are the most popular methods for sequence-learning problem. However, CTC relies on a sophisticated forward-backward algorithm for transcription, which prevents it from addressing two-dimensional (2D) prediction problem, whereas the attention mechanism leans on a complex attention module to fulfill its functionality, resulting in additional network parameters and runtime. In this paper, we propose a novel method, aggregation cross-entropy (ACE), for sequence recognition from a brand new perspective. 
The ACE loss function exhibits competitive performance to CTC and the attention mechanism, with much quicker implementation (as it involves only four fundamental formulas), faster inference/back-propagation (approximately *O(1)* in parallel), less storage requirement (no parameter and negligible runtime memory), and convenient employment (by replacing CTC with ACE). Furthermore, the proposed ACE loss function exhibits two noteworthy properties: (1) it can be directly applied for 2D prediction by flattening the 2D prediction into 1D prediction as the input, and (2) it requires only the characters and their numbers in the sequence annotation for supervision, which allows it to advance beyond sequence recognition, e.g., the counting problem.

![](./image/1.jpg)

Figure 1: Illustration of the proposed ACE loss function. Generally, the 1D and 2D predictions are generated by integrated CNN-LSTM and FCN models, respectively. For the ACE loss function, the 2D prediction is further flattened to a 1D prediction. During aggregation, the 1D predictions at all time-steps are accumulated for each class independently. After normalization, the prediction, together with the ground truth, is utilized for loss estimation based on cross-entropy.

![](./image/2.jpg)

Figure 2: Toy example to show the advantage of the ACE loss function. ResNet-50 trained with the ACE loss function is able to recognize shuffled characters in the images. For each sub-image, the right column shows the 2D prediction of the recognition model for the text image. It is noteworthy that they have similar character distributions in the 2D space.

## Requirements

- [Python 2.7](https://www.python.org/)
- [PyTorch >= 0.4.1](https://pytorch.org/)
- [TorchVision](https://pypi.org/project/torchvision/)
- [OpenCV](https://opencv.org/)

## Data Preparation

    tar -xzvf data.tar.gz

## Training and Testing

Start training (in the 'source/' folder):

```bash
sh train.sh
```

- The training process should take **about 10s** for 100 iterations on a 1080Ti.
## Citation ``` @inproceedings{xie2019ace, title = {Aggregation Cross-Entropy for Sequence Recognition}, author = {Zecheng Xie, Yaoxiong Huang, Yuanzhi Zhu, Lianwen Jin, Yuliang Liu and Lele Xie}, booktitle = {CVPR}, year = {2019}, } ``` ## Attention The project is only free for academic research purposes. ================================================ FILE: log/log/.gitkeep ================================================ # Ignore everything in this directory * # Except this file !.gitkeep ================================================ FILE: log/snapshot/.gitkeep ================================================ # Ignore everything in this directory * # Except this file !.gitkeep ================================================ FILE: source/main.py ================================================ # -*- coding: utf-8 -*- from __future__ import print_function, division import torch import argparse import numpy as np import torch.nn as nn from torch import optim import torch.nn.functional as F from models.seq_module import ACE from torch.autograd import Variable from models.solver import seq_solver from utils.basic import timeSince from torch.utils.data import DataLoader from utils.data_loader import ImageDataset parser = argparse.ArgumentParser() parser.add_argument('--model_path', type=str, default='../log/snapshot/model-{:0>2d}.pkl') parser.add_argument('--total_epoch', type=int, default=50, help='total epoch number') parser.add_argument('--train_path', type=str, default='../data/train.txt') parser.add_argument('--test_path', type=str, default='../data/test.txt') parser.add_argument('--train_batch_size', type=int, default=50, help='training batch size') parser.add_argument('--test_batch_size', type=int, default=50, help='testing batch size') parser.add_argument('--last_epoch', type=int, default=0, help='last epoch') parser.add_argument('--class_num', type=int, default=26, help='class number') parser.add_argument('--dict', type=str, 
default='_abcdefghijklmnopqrstuvwxyz') opt = parser.parse_args() print(opt) import torchvision.models as models class ResnetEncoderDecoder(nn.Module): def __init__(self, loss_layer): super(ResnetEncoderDecoder, self).__init__() self.bn = nn.BatchNorm2d(64) resnet = models.resnet18(pretrained=True) self.conv = nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)) self.cnn = nn.Sequential(*list(resnet.children())[4:-2]) self.out = nn.Linear(512, opt.class_num+1) self.loss_layer = loss_layer(opt.dict) def forward(self, input, labels): input = F.relu(self.bn(self.conv(input)), True) input = F.max_pool2d(input, kernel_size=(2, 2), stride=(2, 2)) input = self.cnn(input) input = input.permute(0,2,3,1) input = F.softmax(self.out(input),dim=-1) labels = labels.cuda() return self.loss_layer(input,labels) if __name__ == "__main__": model = ResnetEncoderDecoder(ACE).cuda() print(model) optimizer = optim.Adadelta(model.parameters()) if opt.last_epoch != 0: check_point = torch.load(opt.model_path.format(opt.last_epoch)) model.load_state_dict(check_point['state_dict']) optimizer.load_state_dict(check_point['optimizer']) scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones = [opt.total_epoch], gamma = 0.1, last_epoch = opt.last_epoch) else: scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones = [opt.total_epoch], gamma = 0.1) train_set = ImageDataset(file_name = opt.train_path, length = 5000, class_num = opt.class_num) lmdb_train = DataLoader(train_set, batch_size=opt.train_batch_size, shuffle=True, num_workers=0) test_set = ImageDataset(file_name = opt.test_path, length = 1000, class_num = opt.class_num) lmdb_test = DataLoader(test_set, batch_size=opt.test_batch_size, shuffle=False, num_workers=0) the_solver = seq_solver(model = model, lmdb = [lmdb_train, lmdb_test], optimizer = optimizer, scheduler = scheduler, total_epoch = opt.total_epoch, model_path = opt.model_path, last_epoch = opt.last_epoch) the_solver.forward() 
================================================ FILE: source/models/__init__.py ================================================ ================================================ FILE: source/models/seq_module.py ================================================ # -*- coding: utf-8 -*- import math import torch import random import itertools import numpy as np import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable class Sequence(nn.Module): def __init__(self): super(Sequence, self).__init__() def result_analysis(self, iteration): pass; class ACE(Sequence): def __init__(self, dictionary): super(ACE, self).__init__() self.softmax = None; self.label = None; self.dict=dictionary def forward(self, input, label): self.bs,self.h,self.w,_ = input.size() T_ = self.h*self.w input = input.view(self.bs,T_,-1) input = input + 1e-10 self.softmax = input label[:,0] = T_ - label[:,0] self.label = label # ACE Implementation (four fundamental formulas) input = torch.sum(input,1) input = input/T_ label = label/T_ loss = (-torch.sum(torch.log(input)*label))/self.bs return loss def decode_batch(self): out_best = torch.max(self.softmax, 2)[1].data.cpu().numpy() pre_result = [0]*self.bs for j in range(self.bs): pre_result[j] = out_best[j][out_best[j]!=0] return pre_result def vis(self,iteration): sn = random.randint(0,self.bs-1) print('Test image %4d:' % (iteration*50+sn)) pred = torch.max(self.softmax, 2)[1].data.cpu().numpy() pred = pred[sn].tolist() # sample #0 pred_string = ''.join(['%2s' % self.dict[pn] for pn in pred]) pred_string_set = [pred_string[i:i+self.w*2] for i in xrange(0, len(pred_string), self.w*2)] print('Prediction: ') for pre_str in pred_string_set: print(pre_str) label = ''.join(['%2s:%2d'%(self.dict[idx],pn) for idx, pn in enumerate(self.label[sn]) if idx != 0 and pn != 0]) label = 'Label: ' + label print(label) def result_analysis(self, iteration): prediction = self.decode_batch() correct_count = 0 pre_total = 0 len_total = 
self.label[:,1:].sum() label_data = self.label.data.cpu().numpy() for idx, pre_list in enumerate(prediction): for pw in pre_list: if label_data[idx][pw] > 0: correct_count = correct_count + 1 label_data[idx][pw] -= 1 pre_total += len(pre_list) if not self.training and random.random() < 0.05: self.vis(iteration) return correct_count, len_total, pre_total ================================================ FILE: source/models/solver.py ================================================ import time import torch import numpy as np from torch.autograd import Variable from utils.basic import timeSince class solver(): def __init__(self, model, lmdb, optimizer, scheduler, total_epoch, model_path, last_epoch): self.model = model print(self.model) self.lmdb_train, self.lmdb_test = lmdb self.optimizer = optimizer self.scheduler = scheduler self.total_epoch = total_epoch self.model_path = model_path self.last_epoch = last_epoch self.start = time.time() def train_one_epoch(self, ep): pass def test_one_epoch(self, ep): pass def forward(self): for ep in range(self.total_epoch-self.last_epoch): ep = ep+self.last_epoch self.train_one_epoch(ep) self.test_one_epoch(ep) import pdb class seq_solver(solver): def train_one_epoch(self, ep): self.model.train() loss_aver = 0 if self.scheduler is not None: self.scheduler.step() print('learning_rate: ', self.scheduler.get_lr()) for it, sample_batched in enumerate(self.lmdb_train): inputs = sample_batched['image'].squeeze(0) labels = sample_batched['label'].squeeze(0) inputs = Variable(inputs.cuda()) loss = self.model(inputs, labels) self.optimizer.zero_grad() loss.backward() loss = loss.data.item() l2_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),10) if not np.isnan(l2_norm): self.optimizer.step() else: print('l2_norm: ', l2_norm) l2_norm = 0 if it == 0: loss_aver = loss loss_aver = 0.9*loss_aver+0.1*loss if it == len(self.lmdb_train)-1: correct_count, len_total, pre_total = self.model.loss_layer.result_analysis(it) recall = 
float(correct_count) / len_total precision = correct_count / (pre_total+0.000001) print('Train: %10s Epoch: %3d it: %6d, loss: %.4f, l2_norm: %.4f, recall: %.4f, precision: %.4f' % (timeSince(self.start), ep, it, loss_aver, l2_norm, recall, precision)) torch.save({ 'epoch': ep, 'state_dict': self.model.state_dict(), 'optimizer' : self.optimizer.state_dict(), }, self.model_path.format(ep)) def test_one_epoch(self, ep): self.model.eval() loss_aver = 0 for it, sample_batched in enumerate(self.lmdb_test): inputs = sample_batched['image'].squeeze(0) labels = sample_batched['label'].squeeze(0) inputs = Variable(inputs.cuda()) loss = self.model(inputs, labels) correct_count, len_total, pre_total = self.model.loss_layer.result_analysis(it) loss = loss.data.item() if it == 0: loss_aver = loss loss_aver = 0.9*loss_aver+0.1*loss if it == len(self.lmdb_test) -1: recall = float(correct_count) / len_total precision = correct_count / (pre_total+0.000001) print('Test : %10s Epoch: %3d it: %6d, loss: %.4f, len : %4d, recall: %.4f, precision: %.4f' % (timeSince(self.start), ep, it, loss_aver, len_total, recall, precision)) ================================================ FILE: source/train.sh ================================================ #!/usr/bin/env bash filename="../log/log/log_`date +%y_%m_%d_%H_%M_%S`.txt" CUDA_VISIBLE_DEVICES=0 python -u main.py \ 2>&1 | tee $filename ================================================ FILE: source/utils/__init__.py ================================================ ================================================ FILE: source/utils/basic.py ================================================ import time import math def asMinutes(s): m = math.floor(s / 60) s -= m * 60 return '%dm %ds' % (m, s) def timeSince(since): now = time.time() s = now - since return '%s' % (asMinutes(s)) ================================================ FILE: source/utils/data_loader.py ================================================ import cv2 import torch import numpy as 
np from torch.utils.data import Dataset, DataLoader class ImageDataset(Dataset): """Face Landmarks dataset.""" def __init__(self, file_name, length, class_num, transform=None): """ Args: file_name (string): Path to the files with images and their annotations. length (string): image number. class_num (int): class number. """ with open(file_name) as fh: self.img_and_label = fh.readlines() self.length = length self.transform = transform self.class_num = class_num def __len__(self): return self.length def __getitem__(self, idx): img_and_label = self.img_and_label[idx].strip() pth, word = img_and_label.split(' ') # image path and its annotation image = cv2.imread(pth,0) image = cv2.pyrDown(image).astype('float32') # 100*100 word = [ord(var)-97 for var in word] # a->0 label = np.zeros((self.class_num+1)).astype('float32') for ln in word: label[int(ln+1)] += 1 # label construction for ACE label[0] = len(word) sample = {'image': image, 'label': label} sample = {'image': torch.from_numpy(image).unsqueeze(0), 'label': torch.from_numpy(label)} if self.transform: sample = self.transform(sample) return sample