Repository: lcy0604/EraseNet Branch: master Commit: 673bfad49ec5 Files: 12 Total size: 40.4 KB Directory structure: gitextract_sn7w10qi/ ├── LICENSE ├── README.md ├── data/ │ └── dataloader.py ├── evaluatuion.py ├── gauss.py ├── loss/ │ └── Loss.py ├── models/ │ ├── Model.py │ ├── discriminator.py │ ├── networks.py │ └── sa_gan.py ├── test_image_STE.py └── train_STE.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2020 Chongyu-Liu Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # EraseNet This repository is the implementation of EraseNet, a neural network for end-to-end scene text removal. ## Data preparation The data preparation can be refer to ./examples/. You can download our datatset at [SCUT-EnsText](https://github.com/HCIILAB/SCUT-EnsText) or synthetic dataset [SCUT-Syn](https://github.com/HCIILAB/Scene-Text-Removal) for training and testing. SCUT-EnsText needs decompression password, you can send me at [liuchongyu1996@gmail.com](mailto:liuchongyu1996@gmail.com) for it. ## Environment Anaconda is recommended to establish a virtual environment to run our code. My environment can be refered as follows: ``` python = 3.7 pytorch = 1.3.1 torchvision = 0.4.2 ``` ## Demo We provide our retrain model for quick inference for SCUT-EnsText. [Model Link](https://drive.google.com/file/d/1scrtQ2GFvKjjoGEqbKxpOMn37mJmXsFd/view) ## Training Once the data is well prepared, you can begin training: ``` python train_STE.py --batchSize 4 \ --dataRoot 'your path' \ --modelsSavePath 'your path' \ --logPath 'your path' \ ``` ## Testing and evaluation If you want to predict the results, run: ``` python test_image_STE.py --dataRoot 'your path' \ --batchSize 1 \ --pretrain 'your path' \ --savePath 'your path' ``` To evaluate the results: ``` python evaluatuion.py --target_path 'results_path' --gt_path 'labels_path' ``` ## Acknowledge The repository is benefit a lot from [LBAM](https://github.com/Vious/LBAM_Pytorch) and [GatedConv](https://github.com/avalonstrel/GatedConvolution_pytorch). Thanks a lot for their excellent work. ## Citation If you find our method or dataset useful for your reserach, please cite: ``` @ARTICLE{Erase2020Liu, author ={Liu, Chongyu and Liu, Yuliang and Jin, lianwen and Zhang, Shuaitao and Luo, Canjie and Wang, Yongpan}, journal ={IEEE Transactions on Image Processing}, title ={EraseNet: End-to-End Text Removal in the Wild}, year ={2020}, volume ={29}, pages ={8760-8775},} @article{zhang2019EnsNet, title = {EnsNet: Ensconce Text in the Wild}, author = {Shuaitao Zhang∗, Yuliang Liu∗, Lianwen Jin†, Yaoxiong Huang, Songxuan Lai joural = {AAAI} year = {2019} } ``` ## Feedback Suggestions and opinions of our work (both positive and negative) are greatly welcome. Please contact the authors by sending email to Chongyu Liu([liuchongyu1996@gmail.com](mailto:liuchongyu1996@gmail.com)). For commercial usage, please contact Prof. Lianwen Jin via ([eelwjin@scut.edu.cn](mailto:eelwjin@scut.edu.cn)). ================================================ FILE: data/dataloader.py ================================================ import torch from torch.utils.data import Dataset from PIL import Image import numpy as np import cv2 from os import listdir, walk from os.path import join from random import randint import random from PIL import Image from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, Resize, RandomHorizontalFlip def random_horizontal_flip(imgs): if random.random() < 0.3: for i in range(len(imgs)): imgs[i] = imgs[i].transpose(Image.FLIP_LEFT_RIGHT) return imgs def random_rotate(imgs): if random.random() < 0.3: max_angle = 10 angle = random.random() * 2 * max_angle - max_angle # print(angle) for i in range(len(imgs)): img = np.array(imgs[i]) w, h = img.shape[:2] rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w)) imgs[i] =Image.fromarray(img_rotation) return imgs def CheckImageFile(filename): return any(filename.endswith(extention) for extention in ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.bmp', '.BMP']) def ImageTransform(loadSize): return Compose([ Resize(size=loadSize, interpolation=Image.BICUBIC), ToTensor(), ]) class ErasingData(Dataset): def __init__(self, dataRoot, loadSize, training=True): super(ErasingData, self).__init__() self.imageFiles = [join (dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot) \ for files in filenames if CheckImageFile(files)] self.loadSize = loadSize self.ImgTrans = ImageTransform(loadSize) self.training = training def __getitem__(self, index): img = Image.open(self.imageFiles[index]) mask = Image.open(self.imageFiles[index].replace('all_images','mask')) gt = Image.open(self.imageFiles[index].replace('all_images','all_labels')) # import pdb;pdb.set_trace() if self.training: # ### for data augmentation all_input = [img, mask, gt] all_input = random_horizontal_flip(all_input) all_input = random_rotate(all_input) img = all_input[0] mask = all_input[1] gt = all_input[2] ### for data augmentation inputImage = self.ImgTrans(img.convert('RGB')) mask = self.ImgTrans(mask.convert('RGB')) groundTruth = self.ImgTrans(gt.convert('RGB')) path = self.imageFiles[index].split('/')[-1] # import pdb;pdb.set_trace() return inputImage, groundTruth, mask, path def __len__(self): return len(self.imageFiles) class devdata(Dataset): def __init__(self, dataRoot, gtRoot, loadSize=512): super(devdata, self).__init__() self.imageFiles = [join (dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot) \ for files in filenames if CheckImageFile(files)] self.gtFiles = [join (gtRootK, files) for gtRootK, dn, filenames in walk(gtRoot) \ for files in filenames if CheckImageFile(files)] self.loadSize = loadSize self.ImgTrans = ImageTransform(loadSize) def __getitem__(self, index): img = Image.open(self.imageFiles[index]) gt = Image.open(self.gtFiles[index]) #import pdb;pdb.set_trace() inputImage = self.ImgTrans(img.convert('RGB')) groundTruth = self.ImgTrans(gt.convert('RGB')) path = self.imageFiles[index].split('/')[-1] return inputImage, groundTruth,path def __len__(self): return len(self.imageFiles) ================================================ FILE: evaluatuion.py ================================================ import os import math import argparse import torch import torch.nn as nn import torch.backends.cudnn as cudnn from PIL import Image import numpy as np from torch.autograd import Variable from torchvision.utils import save_image from torch.utils.data import DataLoader from data.dataloader import devdata from scipy import signal, ndimage import gauss parser = argparse.ArgumentParser() parser.add_argument('--target_path', type=str, default='', help='results') parser.add_argument('--gt_path', type=str, default='', help='labels') args = parser.parse_args() sum_psnr = 0 sum_ssim = 0 sum_AGE = 0 sum_pCEPS = 0 sum_pEPS = 0 sum_mse = 0 count = 0 sum_time = 0.0 l1_loss = 0 img_path = args.target_path gt_path = args.gt_path def ssim(img1, img2, cs_map=False): """Return the Structural Similarity Map corresponding to input images img1 and img2 (images are assumed to be uint8) This function attempts to mimic precisely the functionality of ssim.m a MATLAB provided by the author's of SSIM https://ece.uwaterloo.ca/~z70wang/research/ssim/ssim_index.m """ img1 = img1.astype(float) img2 = img2.astype(float) size = min(img1.shape[0], 11) sigma = 1.5 window = gauss.fspecial_gauss(size, sigma) K1 = 0.01 K2 = 0.03 L = 255 #bitdepth of image C1 = (K1 * L) ** 2 C2 = (K2 * L) ** 2 # import pdb;pdb.set_trace() mu1 = signal.fftconvolve(img1, window, mode = 'valid') mu2 = signal.fftconvolve(img2, window, mode = 'valid') mu1_sq = mu1 * mu1 mu2_sq = mu2 * mu2 mu1_mu2 = mu1 * mu2 sigma1_sq = signal.fftconvolve(img1 * img1, window, mode = 'valid') - mu1_sq sigma2_sq = signal.fftconvolve(img2 * img2, window, mode = 'valid') - mu2_sq sigma12 = signal.fftconvolve(img1 * img2, window, mode = 'valid') - mu1_mu2 if cs_map: return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)), (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)) else: return ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) def msssim(img1, img2): """This function implements Multi-Scale Structural Similarity (MSSSIM) Image Quality Assessment according to Z. Wang's "Multi-scale structural similarity for image quality assessment" Invited Paper, IEEE Asilomar Conference on Signals, Systems and Computers, Nov. 2003 Author's MATLAB implementation:- http://www.cns.nyu.edu/~lcv/ssim/msssim.zip """ level = 5 weight = np.array([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) downsample_filter = np.ones((2, 2)) / 4.0 # im1 = img1.astype(np.float64) # im2 = img2.astype(np.float64) mssim = np.array([]) mcs = np.array([]) for l in range(level): ssim_map, cs_map = ssim(img1, img2, cs_map = True) mssim = np.append(mssim, ssim_map.mean()) mcs = np.append(mcs, cs_map.mean()) filtered_im1 = ndimage.filters.convolve(img1, downsample_filter, mode = 'reflect') filtered_im2 = ndimage.filters.convolve(img2, downsample_filter, mode = 'reflect') im1 = filtered_im1[: : 2, : : 2] im2 = filtered_im2[: : 2, : : 2] # Note: Remove the negative and add it later to avoid NaN in exponential. sign_mcs = np.sign(mcs[0 : level - 1]) sign_mssim = np.sign(mssim[level - 1]) mcs_power = np.power(np.abs(mcs[0 : level - 1]), weight[0 : level - 1]) mssim_power = np.power(np.abs(mssim[level - 1]), weight[level - 1]) return np.prod(sign_mcs * mcs_power) * sign_mssim * mssim_power def ImageTransform(loadSize, cropSize): return Compose([ Resize(size=loadSize, interpolation=Image.BICUBIC), # RandomCrop(size=cropSize), #RandomHorizontalFlip(p=0.5), ToTensor(), ]) def visual(image): im =(image).transpose(1,2).transpose(2,3).detach().cpu().numpy() Image.fromarray(im[0].astype(np.uint8)).show() imgData = devdata(dataRoot=img_path, gtRoot=gt_path) data_loader = DataLoader(imgData, batch_size=1, shuffle=True, num_workers=0, drop_last=False) for k, (img,lbl,path) in enumerate(data_loader): ##import pdb;pdb.set_trace() mse = ((lbl - img)**2).mean() sum_mse += mse print(path,count, 'mse: ', mse) if mse == 0: continue count += 1 psnr = 10 * math.log10(1/mse) sum_psnr += psnr print(path,count, ' psnr: ', psnr) #l1_loss += nn.L1Loss()(img, lbl) R = lbl[0,0,:, :] G = lbl[0,1,:, :] B = lbl[0,2,:, :] YGT = .299 * R + .587 * G + .114 * B R = img[0,0,:, :] G = img[0,1,:, :] B = img[0,2,:, :] YBC = .299 * R + .587 * G + .114 * B Diff = abs(np.array(YBC*255) - np.array(YGT*255)).round().astype(np.uint8) AGE = np.mean(Diff) print(' AGE: ', AGE) mssim = msssim(np.array(YGT*255), np.array(YBC*255)) sum_ssim += mssim print(count, ' ssim:', mssim) threshold = 20 Errors = Diff > threshold EPs = sum(sum(Errors)).astype(float) pEPs = EPs / float(512*512) print(' pEPS: ' , pEPs) sum_pEPS += pEPs ########################## CEPs and pCEPs ################################ structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) sum_AGE+=AGE erodedErrors = ndimage.binary_erosion(Errors, structure).astype(Errors.dtype) CEPs = sum(sum(erodedErrors)) pCEPs = CEPs / float(512*512) print(' pCEPS: ' , pCEPs) sum_pCEPS += pCEPs print(sum_psnr) print('avg mse:', sum_mse / count) print('average psnr:', sum_psnr / count) print('average ssim:', sum_ssim / count) print('average AGE:', sum_AGE / count) print('average pEPS:', sum_pEPS / count) print('average pCEPS:', sum_pCEPS / count) ================================================ FILE: gauss.py ================================================ #!/usr/bin/env python """Module providing functionality surrounding gaussian function. """ SVN_REVISION = '$LastChangedRevision: 16541 $' import sys import numpy def gaussian2(size, sigma): """Returns a normalized circularly symmetric 2D gauss kernel array f(x,y) = A.e^{-(x^2/2*sigma^2 + y^2/2*sigma^2)} where A = 1/(2*pi*sigma^2) as define by Wolfram Mathworld http://mathworld.wolfram.com/GaussianFunction.html """ A = 1/(2.0*numpy.pi*sigma**2) x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] g = A*numpy.exp(-((x**2/(2.0*sigma**2))+(y**2/(2.0*sigma**2)))) return g def fspecial_gauss(size, sigma): """Function to mimic the 'fspecial' gaussian MATLAB function """ x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] g = numpy.exp(-((x**2 + y**2)/(2.0*sigma**2))) return g/g.sum() def main(): """Show simple use cases for functionality provided by this module.""" from mpl_toolkits.mplot3d.axes3d import Axes3D import pylab argv = sys.argv if len(argv) != 3: print >>sys.stderr, 'usage: python -m pim.sp.gauss size sigma' sys.exit(2) size = int(argv[1]) sigma = float(argv[2]) x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] fig = pylab.figure() fig.suptitle('Some 2-D Gauss Functions') ax = fig.add_subplot(2, 1, 1, projection='3d') ax.plot_surface(x, y, fspecial_gauss(size, sigma), rstride=1, cstride=1, linewidth=0, antialiased=False, cmap=pylab.jet()) ax = fig.add_subplot(2, 1, 2, projection='3d') ax.plot_surface(x, y, gaussian2(size, sigma), rstride=1, cstride=1, linewidth=0, antialiased=False, cmap=pylab.jet()) pylab.show() return 0 if __name__ == '__main__': sys.exit(main()) ================================================ FILE: loss/Loss.py ================================================ import torch from torch import nn from torch import autograd import torch.nn.functional as F from tensorboardX import SummaryWriter from models.discriminator import Discriminator_STE from PIL import Image import numpy as np def gram_matrix(feat): # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py (b, ch, h, w) = feat.size() feat = feat.view(b, ch, h * w) feat_t = feat.transpose(1, 2) gram = torch.bmm(feat, feat_t) / (ch * h * w) return gram def visual(image): im = image.transpose(1,2).transpose(2,3).detach().cpu().numpy() Image.fromarray(im[0].astype(np.uint8)).show() def dice_loss(input, target): input = torch.sigmoid(input) input = input.contiguous().view(input.size()[0], -1) target = target.contiguous().view(target.size()[0], -1) input = input target = target a = torch.sum(input * target, 1) b = torch.sum(input * input, 1) + 0.001 c = torch.sum(target * target, 1) + 0.001 d = (2 * a) / (b + c) dice_loss = torch.mean(d) return 1 - dice_loss class LossWithGAN_STE(nn.Module): def __init__(self, logPath, extractor, Lamda, lr, betasInit=(0.5, 0.9)): super(LossWithGAN_STE, self).__init__() self.l1 = nn.L1Loss() self.extractor = extractor self.discriminator = Discriminator_STE(3) ## local_global sn patch gan self.D_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betasInit) self.cudaAvailable = torch.cuda.is_available() self.numOfGPUs = torch.cuda.device_count() self.lamda = Lamda self.writer = SummaryWriter(logPath) def forward(self, input, mask, x_o1,x_o2,x_o3,output,mm, gt, count, epoch): self.discriminator.zero_grad() D_real = self.discriminator(gt, mask) D_real = D_real.mean().sum() * -1 D_fake = self.discriminator(output, mask) D_fake = D_fake.mean().sum() * 1 D_loss = torch.mean(F.relu(1.+D_real)) + torch.mean(F.relu(1.+D_fake)) #SN-patch-GAN loss D_fake = -torch.mean(D_fake) # SN-Patch-GAN loss self.D_optimizer.zero_grad() D_loss.backward(retain_graph=True) self.D_optimizer.step() self.writer.add_scalar('LossD/Discrinimator loss', D_loss.item(), count) output_comp = mask * input + (1 - mask) * output # import pdb;pdb.set_trace() holeLoss = 10 * self.l1((1 - mask) * output, (1 - mask) * gt) validAreaLoss = 2*self.l1(mask * output, mask * gt) mask_loss = dice_loss(mm, 1-mask) ### MSR loss ### masks_a = F.interpolate(mask, scale_factor=0.25) masks_b = F.interpolate(mask, scale_factor=0.5) imgs1 = F.interpolate(gt, scale_factor=0.25) imgs2 = F.interpolate(gt, scale_factor=0.5) msrloss = 8 * self.l1((1-mask)*x_o3,(1-mask)*gt) + 0.8*self.l1(mask*x_o3, mask*gt)+\ 6 * self.l1((1-masks_b)*x_o2,(1-masks_b)*imgs2)+1*self.l1(masks_b*x_o2,masks_b*imgs2)+\ 5 * self.l1((1-masks_a)*x_o1,(1-masks_a)*imgs1)+0.8*self.l1(masks_a*x_o1,masks_a*imgs1) feat_output_comp = self.extractor(output_comp) feat_output = self.extractor(output) feat_gt = self.extractor(gt) prcLoss = 0.0 for i in range(3): prcLoss += 0.01 * self.l1(feat_output[i], feat_gt[i]) prcLoss += 0.01 * self.l1(feat_output_comp[i], feat_gt[i]) styleLoss = 0.0 for i in range(3): styleLoss += 120 * self.l1(gram_matrix(feat_output[i]), gram_matrix(feat_gt[i])) styleLoss += 120 * self.l1(gram_matrix(feat_output_comp[i]), gram_matrix(feat_gt[i])) """ if self.numOfGPUs > 1: holeLoss = holeLoss.sum() / self.numOfGPUs validAreaLoss = validAreaLoss.sum() / self.numOfGPUs prcLoss = prcLoss.sum() / self.numOfGPUs styleLoss = styleLoss.sum() / self.numOfGPUs """ self.writer.add_scalar('LossG/Hole loss', holeLoss.item(), count) self.writer.add_scalar('LossG/Valid loss', validAreaLoss.item(), count) self.writer.add_scalar('LossG/msr loss', msrloss.item(), count) self.writer.add_scalar('LossPrc/Perceptual loss', prcLoss.item(), count) self.writer.add_scalar('LossStyle/style loss', styleLoss.item(), count) GLoss = msrloss+ holeLoss + validAreaLoss+ prcLoss + styleLoss + 0.1 * D_fake + 1*mask_loss self.writer.add_scalar('Generator/Joint loss', GLoss.item(), count) return GLoss.sum() ================================================ FILE: models/Model.py ================================================ import torch import torch.nn as nn from torchvision import models #VGG16 feature extract class VGG16FeatureExtractor(nn.Module): def __init__(self): super(VGG16FeatureExtractor, self).__init__() vgg16 = models.vgg16(pretrained=True) # vgg16.load_state_dict(torch.load('./vgg16-397923af.pth')) self.enc_1 = nn.Sequential(*vgg16.features[:5]) self.enc_2 = nn.Sequential(*vgg16.features[5:10]) self.enc_3 = nn.Sequential(*vgg16.features[10:17]) # fix the encoder for i in range(3): for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters(): param.requires_grad = False def forward(self, image): results = [image] for i in range(3): func = getattr(self, 'enc_{:d}'.format(i + 1)) results.append(func(results[-1])) return results[1:] ================================================ FILE: models/discriminator.py ================================================ import torch import torch.nn as nn from .networks import ConvWithActivation, get_pad ##discriminator class Discriminator_STE(nn.Module): def __init__(self, inputChannels): super(Discriminator_STE, self).__init__() cnum =32 self.globalDis = nn.Sequential( ConvWithActivation(3, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)), ConvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)), ConvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)), ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)), ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)), ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)), ) self.localDis = nn.Sequential( ConvWithActivation(3, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)), ConvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)), ConvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)), ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)), ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)), ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)), ) self.fusion = nn.Sequential( nn.Conv2d(512, 1, kernel_size=4), nn.Sigmoid() ) def forward(self, input, masks): global_feat = self.globalDis(input) local_feat = self.localDis(input * (1 - masks)) concat_feat = torch.cat((global_feat, local_feat), 1) return self.fusion(concat_feat).view(input.size()[0], -1) ================================================ FILE: models/networks.py ================================================ import torch import numpy as np import torch.nn.functional as F import torch.nn as nn def get_pad(in_, ksize, stride, atrous=1): out_ = np.ceil(float(in_)/stride) return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2) class ConvWithActivation(torch.nn.Module): """ SN convolution for spetral normalization conv """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)): super(ConvWithActivation, self).__init__() self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) self.conv2d = torch.nn.utils.spectral_norm(self.conv2d) self.activation = activation for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) def forward(self, input): x = self.conv2d(input) if self.activation is not None: return self.activation(x) else: return x class DeConvWithActivation(torch.nn.Module): """ SN convolution for spetral normalization conv """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)): super(DeConvWithActivation, self).__init__() self.conv2d = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) self.conv2d = torch.nn.utils.spectral_norm(self.conv2d) self.activation = activation for m in self.modules(): if isinstance(m, nn.ConvTranspose2d): nn.init.kaiming_normal_(m.weight) def forward(self, input): x = self.conv2d(input) if self.activation is not None: return self.activation(x) else: return x ================================================ FILE: models/sa_gan.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from PIL import Image from torch.autograd import Variable from .networks import get_pad, ConvWithActivation, DeConvWithActivation def img2photo(imgs): return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy() def visual(imgs): im = img2photo(imgs) Image.fromarray(im[0].astype(np.uint8)).show() class Residual(nn.Module): def __init__(self, in_channels, out_channels, same_shape=True, **kwargs): super(Residual,self).__init__() self.same_shape = same_shape strides = 1 if same_shape else 2 self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1,stride=strides) self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) # self.conv2 = torch.nn.utils.spectral_norm(self.conv2) if not same_shape: self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, # self.conv3 = nn.Conv2D(channels, kernel_size=3, padding=1, stride=strides) # self.conv3 = torch.nn.utils.spectral_norm(self.conv3) self.batch_norm2d = nn.BatchNorm2d(out_channels) def forward(self,x): out = F.relu(self.conv1(x)) out = self.conv2(out) if not self.same_shape: x = self.conv3(x) out = self.batch_norm2d(out + x) # out = out + x return F.relu(out) class ASPP(nn.Module): def __init__(self, in_channel=512, depth=256): super(ASPP,self).__init__() self.mean = nn.AdaptiveAvgPool2d((1, 1)) self.conv = nn.Conv2d(in_channel, depth, 1, 1) # k=1 s=1 no pad self.atrous_block1 = nn.Conv2d(in_channel, depth, 1, 1) self.atrous_block6 = nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6) self.atrous_block12 = nn.Conv2d(in_channel, depth, 3, 1, padding=12, dilation=12) self.atrous_block18 = nn.Conv2d(in_channel, depth, 3, 1, padding=18, dilation=18) self.conv_1x1_output = nn.Conv2d(depth * 5, depth, 1, 1) def forward(self, x): size = x.shape[2:] image_features = self.mean(x) image_features = self.conv(image_features) image_features = F.upsample(image_features, size=size, mode='bilinear') atrous_block1 = self.atrous_block1(x) atrous_block6 = self.atrous_block6(x) atrous_block12 = self.atrous_block12(x) atrous_block18 = self.atrous_block18(x) net = self.conv_1x1_output(torch.cat([image_features, atrous_block1, atrous_block6, atrous_block12, atrous_block18], dim=1)) return net class STRnet2(nn.Module): def __init__(self, n_in_channel=3): super(STRnet2, self).__init__() #### U-Net #### #downsample self.conv1 = ConvWithActivation(3,32,kernel_size=4,stride=2,padding=1) self.conva = ConvWithActivation(32,32,kernel_size=3, stride=1, padding=1) self.convb = ConvWithActivation(32,64, kernel_size=4, stride=2, padding=1) self.res1 = Residual(64,64) self.res2 = Residual(64,64) self.res3 = Residual(64,128,same_shape=False) self.res4 = Residual(128,128) self.res5 = Residual(128,256,same_shape=False) # self.nn = ConvWithActivation(256, 512, 3, 1, dilation=2, padding=get_pad(64, 3, 1, 2)) self.res6 = Residual(256,256) self.res7 = Residual(256,512,same_shape=False) self.res8 = Residual(512,512) self.conv2 = ConvWithActivation(512,512,kernel_size=1) #upsample self.deconv1 = DeConvWithActivation(512,256,kernel_size=3,padding=1,stride=2) self.deconv2 = DeConvWithActivation(256*2,128,kernel_size=3,padding=1,stride=2) self.deconv3 = DeConvWithActivation(128*2,64,kernel_size=3,padding=1,stride=2) self.deconv4 = DeConvWithActivation(64*2,32,kernel_size=3,padding=1,stride=2) self.deconv5 = DeConvWithActivation(64,3,kernel_size=3,padding=1,stride=2) #lateral connection self.lateral_connection1 = nn.Sequential( nn.Conv2d(256, 256, kernel_size=1, padding=0,stride=1), nn.Conv2d(256, 512, kernel_size=3, padding=1,stride=1), nn.Conv2d(512, 512, kernel_size=3, padding=1,stride=1), nn.Conv2d(512, 256, kernel_size=1, padding=0,stride=1),) self.lateral_connection2 = nn.Sequential( nn.Conv2d(128, 128, kernel_size=1, padding=0,stride=1), nn.Conv2d(128, 256, kernel_size=3, padding=1,stride=1), nn.Conv2d(256, 256, kernel_size=3, padding=1,stride=1), nn.Conv2d(256, 128, kernel_size=1, padding=0,stride=1),) self.lateral_connection3 = nn.Sequential( nn.Conv2d(64, 64, kernel_size=1, padding=0,stride=1), nn.Conv2d(64, 128, kernel_size=3, padding=1,stride=1), nn.Conv2d(128, 128, kernel_size=3, padding=1,stride=1), nn.Conv2d(128, 64, kernel_size=1, padding=0,stride=1),) self.lateral_connection4 = nn.Sequential( nn.Conv2d(32, 32, kernel_size=1, padding=0,stride=1), nn.Conv2d(32, 64, kernel_size=3, padding=1,stride=1), nn.Conv2d(64, 64, kernel_size=3, padding=1,stride=1), nn.Conv2d(64, 32, kernel_size=1, padding=0,stride=1),) #self.relu = nn.elu(alpha=1.0) self.conv_o1 = nn.Conv2d(64,3,kernel_size=1) self.conv_o2 = nn.Conv2d(32,3,kernel_size=1) ##### U-Net ##### ### ASPP ### # self.aspp = ASPP(512, 256) ### ASPP ### ### mask branch decoder ### self.mask_deconv_a = DeConvWithActivation(512,256,kernel_size=3,padding=1,stride=2) self.mask_conv_a = ConvWithActivation(256,128,kernel_size=3,padding=1,stride=1) self.mask_deconv_b = DeConvWithActivation(256,128,kernel_size=3,padding=1,stride=2) self.mask_conv_b = ConvWithActivation(128,64,kernel_size=3,padding=1,stride=1) self.mask_deconv_c = DeConvWithActivation(128,64,kernel_size=3,padding=1,stride=2) self.mask_conv_c = ConvWithActivation(64,32,kernel_size=3,padding=1,stride=1) self.mask_deconv_d = DeConvWithActivation(64,32,kernel_size=3,padding=1,stride=2) self.mask_conv_d = nn.Conv2d(32,3,kernel_size=1) ### mask branch ### ##### Refine sub-network ###### n_in_channel = 3 cnum = 32 ####downsapmle self.coarse_conva = ConvWithActivation(n_in_channel, cnum, kernel_size=5, stride=1, padding=2) self.coarse_convb = ConvWithActivation(cnum, 2*cnum, kernel_size=4, stride=2, padding=1) self.coarse_convc = ConvWithActivation(2*cnum, 2*cnum, kernel_size=3, stride=1, padding=1) self.coarse_convd = ConvWithActivation(2*cnum, 4*cnum, kernel_size=4, stride=2, padding=1) self.coarse_conve = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1) self.coarse_convf = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1) ### astrous self.astrous_net = nn.Sequential( ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=2, padding=get_pad(64, 3, 1, 2)), ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=4, padding=get_pad(64, 3, 1, 4)), ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=8, padding=get_pad(64, 3, 1, 8)), ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=16, padding=get_pad(64, 3, 1, 16)), ) ###astrous ### upsample self.coarse_convk = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1) self.coarse_convl = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1) self.coarse_deconva = DeConvWithActivation(4*cnum*3, 2*cnum, kernel_size=3,padding=1,stride=2) self.coarse_convm = ConvWithActivation(2*cnum, 2*cnum, kernel_size=3, stride=1, padding=1) self.coarse_deconvb = DeConvWithActivation(2*cnum*3, cnum, kernel_size=3,padding=1,stride=2) self.coarse_convn = nn.Sequential( ConvWithActivation(cnum, cnum//2, kernel_size=3, stride=1, padding=1), #Self_Attn(cnum//2, 'relu'), ConvWithActivation(cnum//2, 3, kernel_size=3, stride=1, padding=1, activation=None), ) self.c1 = nn.Conv2d(32,64,kernel_size=1) self.c2 = nn.Conv2d(64,128,kernel_size=1) ##### Refine network ###### def forward(self, x): #downsample x = self.conv1(x) x = self.conva(x) con_x1 = x # import pdb;pdb.set_trace() x = self.convb(x) x = self.res1(x) con_x2 = x x = self.res2(x) x = self.res3(x) con_x3 = x x = self.res4(x) x = self.res5(x) con_x4 = x x = self.res6(x) # x_mask = self.nn(con_x4) ### for mask branch aspp # x_mask = self.aspp(x_mask) ### for mask branch aspp x_mask=x ### no aspp # import pdb;pdb.set_trace() x = self.res7(x) x = self.res8(x) x = self.conv2(x) #upsample x = self.deconv1(x) x = torch.cat([self.lateral_connection1(con_x4), x], dim=1) x = self.deconv2(x) x = torch.cat([self.lateral_connection2(con_x3), x], dim=1) x = self.deconv3(x) xo1 = x x = torch.cat([self.lateral_connection3(con_x2), x], dim=1) x = self.deconv4(x) xo2 = x x = torch.cat([self.lateral_connection4(con_x1), x], dim=1) #import pdb;pdb.set_trace() x = self.deconv5(x) x_o1 = self.conv_o1(xo1) x_o2 = self.conv_o2(xo2) x_o_unet = x ### mask branch ### mm = self.mask_deconv_a(torch.cat([x_mask,con_x4],dim=1)) mm = self.mask_conv_a(mm) mm = self.mask_deconv_b(torch.cat([mm,con_x3],dim=1)) mm = self.mask_conv_b(mm) mm = self.mask_deconv_c(torch.cat([mm,con_x2],dim=1)) mm = self.mask_conv_c(mm) mm = self.mask_deconv_d(torch.cat([mm,con_x1],dim=1)) mm = self.mask_conv_d(mm) ### mask branch ### ###refine sub-network x = self.coarse_conva(x_o_unet) x = self.coarse_convb(x) x = self.coarse_convc(x) x_c1 = x ###concate feature1 x = self.coarse_convd(x) x = self.coarse_conve(x) x = self.coarse_convf(x) x_c2 = x ###concate feature2 x = self.astrous_net(x) x = self.coarse_convk(x) x = self.coarse_convl(x) x = self.coarse_deconva(torch.cat([x, x_c2,self.c2(con_x2)],dim=1)) x = self.coarse_convm(x) x = self.coarse_deconvb(torch.cat([x,x_c1,self.c1(con_x1)],dim=1)) x = self.coarse_convn(x) return x_o1, x_o2, x_o_unet, x, mm ================================================ FILE: test_image_STE.py ================================================ import os import math import argparse import torch import torch.nn as nn import torch.backends.cudnn as cudnn from PIL import Image import numpy as np from torch.autograd import Variable from torchvision.utils import save_image from torch.utils.data import DataLoader from data.dataloader import ErasingData from models.sa_gan import STRnet2 parser = argparse.ArgumentParser() parser.add_argument('--numOfWorkers', type=int, default=0, help='workers for dataloader') parser.add_argument('--modelsSavePath', type=str, default='', help='path for saving models') parser.add_argument('--logPath', type=str, default='') parser.add_argument('--batchSize', type=int, default=16) parser.add_argument('--loadSize', type=int, default=512, help='image loading size') parser.add_argument('--dataRoot', type=str, default='') parser.add_argument('--pretrained',type=str, default='', help='pretrained models for finetuning') parser.add_argument('--savePath', type=str, default='./results/sn_tv/') args = parser.parse_args() cuda = torch.cuda.is_available() if cuda: print('Cuda is available!') cudnn.benchmark = True def visual(image): im =(image).transpose(1,2).transpose(2,3).detach().cpu().numpy() Image.fromarray(im[0].astype(np.uint8)).show() batchSize = args.batchSize loadSize = (args.loadSize, args.loadSize) dataRoot = args.dataRoot savePath = args.savePath result_with_mask = savePath + 'WithMaskOutput/' result_straight = savePath + 'StrOuput/' #import pdb;pdb.set_trace() if not os.path.exists(savePath): os.makedirs(savePath) os.makedirs(result_with_mask) os.makedirs(result_straight) Erase_data = ErasingData(dataRoot, loadSize, training=False) Erase_data = DataLoader(Erase_data, batch_size=batchSize, shuffle=True, num_workers=args.numOfWorkers, drop_last=False) netG = STRnet2(3) netG.load_state_dict(torch.load(args.pretrained)) # if cuda: netG = netG.cuda() for param in netG.parameters(): param.requires_grad = False print('OK!') import time start = time.time() netG.eval() for imgs, gt, masks, path in (Erase_data): if cuda: imgs = imgs.cuda() gt = gt.cuda() masks = masks.cuda() out1, out2, out3, g_images,mm = netG(imgs) g_image = g_images.data.cpu() gt = gt.data.cpu() mask = masks.data.cpu() g_image_with_mask = gt * (mask) + g_image * (1- mask) save_image(g_image_with_mask, result_with_mask+path[0]) save_image(g_image, result_straight+path[0]) ================================================ FILE: train_STE.py ================================================ import os import math import argparse import torch import torch.nn as nn import torch.optim as optim import torch.backends.cudnn as cudnn from PIL import Image import numpy as np from torch.autograd import Variable from torchvision.utils import save_image from torchvision import datasets from torch.utils.data import DataLoader from torchvision import utils from data.dataloader import ErasingData from loss.Loss import LossWithGAN_STE from models.Model import VGG16FeatureExtractor from models.sa_gan import STRnet2 torch.set_num_threads(5) os.environ["CUDA_VISIBLE_DEVICES"] = "3" ### set the gpu as No.... parser = argparse.ArgumentParser() parser.add_argument('--numOfWorkers', type=int, default=0, help='workers for dataloader') parser.add_argument('--modelsSavePath', type=str, default='', help='path for saving models') parser.add_argument('--logPath', type=str, default='') parser.add_argument('--batchSize', type=int, default=16) parser.add_argument('--loadSize', type=int, default=512, help='image loading size') parser.add_argument('--dataRoot', type=str, default='') parser.add_argument('--pretrained',type=str, default='', help='pretrained models for finetuning') parser.add_argument('--num_epochs', type=int, default=500, help='epochs') args = parser.parse_args() def visual(image): im = image.transpose(1,2).transpose(2,3).detach().cpu().numpy() Image.fromarray(im[0].astype(np.uint8)).show() cuda = torch.cuda.is_available() if cuda: print('Cuda is available!') cudnn.enable = True cudnn.benchmark = True batchSize = args.batchSize loadSize = (args.loadSize, args.loadSize) if not os.path.exists(args.modelsSavePath): os.makedirs(args.modelsSavePath) dataRoot = args.dataRoot # import pdb;pdb.set_trace() Erase_data = ErasingData(dataRoot, loadSize, training=True) Erase_data = DataLoader(Erase_data, batch_size=batchSize, shuffle=True, num_workers=args.numOfWorkers, drop_last=False, pin_memory=True) netG = STRnet2(3) if args.pretrained != '': print('loaded ') netG.load_state_dict(torch.load(args.pretrained)) numOfGPUs = torch.cuda.device_count() if cuda: netG = netG.cuda() if numOfGPUs > 1: netG = nn.DataParallel(netG, device_ids=range(numOfGPUs)) count = 1 G_optimizer = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.9)) criterion = LossWithGAN_STE(args.logPath, VGG16FeatureExtractor(), lr=0.00001, betasInit=(0.0, 0.9), Lamda=10.0) if cuda: criterion = criterion.cuda() if numOfGPUs > 1: criterion = nn.DataParallel(criterion, device_ids=range(numOfGPUs)) print('OK!') num_epochs = args.num_epochs for i in range(1, num_epochs + 1): netG.train() for k,(imgs, gt, masks, path) in enumerate(Erase_data): if cuda: imgs = imgs.cuda() gt = gt.cuda() masks = masks.cuda() netG.zero_grad() x_o1,x_o2,x_o3,fake_images,mm = netG(imgs) G_loss = criterion(imgs, masks, x_o1, x_o2, x_o3, fake_images, mm, gt, count, i) G_loss = G_loss.sum() G_optimizer.zero_grad() G_loss.backward() G_optimizer.step() print('[{}/{}] Generator Loss of epoch{} is {}'.format(k,len(Erase_data),i, G_loss.item())) count += 1 if ( i % 10 == 0): if numOfGPUs > 1 : torch.save(netG.module.state_dict(), args.modelsSavePath + '/STE_{}.pth'.format(i)) else: torch.save(netG.state_dict(), args.modelsSavePath + '/STE_{}.pth'.format(i))