Repository: lcy0604/EraseNet
Branch: master
Commit: 673bfad49ec5
Files: 12
Total size: 40.4 KB
Directory structure:
gitextract_sn7w10qi/
├── LICENSE
├── README.md
├── data/
│ └── dataloader.py
├── evaluatuion.py
├── gauss.py
├── loss/
│ └── Loss.py
├── models/
│ ├── Model.py
│ ├── discriminator.py
│ ├── networks.py
│ └── sa_gan.py
├── test_image_STE.py
└── train_STE.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2020 Chongyu-Liu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# EraseNet
This repository is the implementation of EraseNet, a neural network for end-to-end scene text removal.
## Data preparation
For data preparation, refer to ./examples/. You can download our dataset [SCUT-EnsText](https://github.com/HCIILAB/SCUT-EnsText) or the synthetic dataset [SCUT-Syn](https://github.com/HCIILAB/Scene-Text-Removal) for training and testing.
SCUT-EnsText requires a decompression password; email me at [liuchongyu1996@gmail.com](mailto:liuchongyu1996@gmail.com) to request it.
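`ErasingData` in `data/dataloader.py` walks `--dataRoot` for images and finds each image's mask and label by replacing `all_images` in its path with `mask` and `all_labels`. The loader therefore appears to expect `--dataRoot` to be an `all_images` folder with sibling `mask` and `all_labels` folders (e.g. `train/all_images`, `train/mask`, `train/all_labels`). Below is a minimal sketch for checking that layout before training. `find_missing_pairs` is a hypothetical helper that is not part of this repository, and it matches extensions case-insensitively, which is a slight superset of `CheckImageFile`.

```python
import os

# Same extensions as CheckImageFile in data/dataloader.py, matched case-insensitively
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')


def find_missing_pairs(data_root):
    """Hypothetical helper: list (image, expected_partner) paths whose partner file is missing.

    Mirrors the path substitutions in ErasingData.__getitem__:
    'all_images' -> 'mask' and 'all_images' -> 'all_labels'.
    """
    missing = []
    for dirpath, _, filenames in os.walk(data_root):
        for name in filenames:
            if not name.lower().endswith(IMAGE_EXTENSIONS):
                continue
            image_path = os.path.join(dirpath, name)
            # Without 'all_images' in the path, replace() is a no-op and the
            # loader would silently reuse the input image as mask and label
            if 'all_images' not in image_path:
                raise ValueError(f"{image_path} is not under an 'all_images' folder")
            for folder in ('mask', 'all_labels'):
                partner = image_path.replace('all_images', folder)
                if not os.path.isfile(partner):
                    missing.append((image_path, partner))
    return missing
```

For example, `find_missing_pairs('your path/all_images')` returns an empty list when every image has both a mask and a label.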
## Environment
Anaconda is recommended for setting up a virtual environment to run our code. For reference, my environment is:
```
python = 3.7
pytorch = 1.3.1
torchvision = 0.4.2
```
## Demo
We provide our trained model on SCUT-EnsText for quick inference: [Model Link](https://drive.google.com/file/d/1scrtQ2GFvKjjoGEqbKxpOMn37mJmXsFd/view)
## Training
Once the data is well prepared, you can begin training:
```
python train_STE.py --batchSize 4 \
--dataRoot 'your path' \
--modelsSavePath 'your path' \
--logPath 'your path'
```
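During training, the mask-branch output `mm` of `STRnet2` is supervised with `dice_loss(mm, 1 - mask)` from `loss/Loss.py`, which applies a sigmoid to the logits and compares them with `1 - mask` per sample. The NumPy port below is only a sketch for inspecting that formula outside PyTorch; `dice_loss_np` is a hypothetical name, and training uses the PyTorch version.

```python
import numpy as np


def dice_loss_np(logits, target, eps=0.001):
    """Hypothetical NumPy port of dice_loss in loss/Loss.py.

    logits: (N, ...) raw mask-branch outputs; target: (N, ...) values in [0, 1].
    Returns 1 - batch mean of 2*sum(p*t) / (sum(p*p) + sum(t*t) + 2*eps).
    """
    p = 1.0 / (1.0 + np.exp(-logits))          # sigmoid, as in the original
    p = p.reshape(p.shape[0], -1)              # flatten each sample
    t = target.reshape(target.shape[0], -1)
    a = np.sum(p * t, axis=1)
    b = np.sum(p * p, axis=1) + eps
    c = np.sum(t * t, axis=1) + eps
    return 1.0 - np.mean(2.0 * a / (b + c))
```

A confident, correct prediction drives the loss toward 0, and a confidently wrong one drives it toward 1.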
## Testing and evaluation
To generate predictions, run:
```
python test_image_STE.py --dataRoot 'your path' \
--batchSize 1 \
--pretrained 'your path' \
--savePath 'your path'
```
To evaluate the results:
```
python evaluatuion.py --target_path 'results_path' --gt_path 'labels_path'
```
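`evaluatuion.py` loads each result and its label as tensors in [0, 1]. PSNR is `10 * log10(1 / MSE)`; AGE is the mean absolute difference of the luma channels (Y = 0.299R + 0.587G + 0.114B) on a 0-255 scale; pEPs is the fraction of pixels whose luma difference exceeds 20; and pCEPs is the fraction that survives eroding that error map with a 3×3 cross. The sketch below reproduces these formulas in NumPy for one pair of `H×W×3` arrays, normalizing by the number of pixels. `removal_metrics` is a hypothetical name, and the erosion is written out by hand instead of calling `scipy.ndimage.binary_erosion`. Neighbours outside the image count as non-errors, matching SciPy's default `border_value=0`.

```python
import math

import numpy as np


def removal_metrics(result, label, threshold=20):
    """Illustrative per-image metrics; result and label are (H, W, 3) floats in [0, 1]."""
    mse = float(np.mean((label - result) ** 2))
    psnr = 10 * math.log10(1 / mse) if mse > 0 else float('inf')

    # Luma with BT.601 weights, scaled to 0-255 and rounded as in the script
    weights = np.array([0.299, 0.587, 0.114])
    diff = np.abs(result @ weights * 255 - label @ weights * 255).round()
    age = float(diff.mean())

    errors = diff > threshold                  # error pixels (EPs)
    p_eps = errors.sum() / errors.size

    # Erode with a 3x3 cross: keep pixels whose 4 neighbours are also errors
    pad = np.pad(errors, 1, constant_values=False)
    eroded = (pad[1:-1, 1:-1] & pad[:-2, 1:-1] & pad[2:, 1:-1]
              & pad[1:-1, :-2] & pad[1:-1, 2:])
    p_ceps = eroded.sum() / errors.size
    return {'mse': mse, 'psnr': psnr, 'age': age, 'pEPs': p_eps, 'pCEPs': p_ceps}
```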
## Acknowledgements
This repository benefits greatly from [LBAM](https://github.com/Vious/LBAM_Pytorch) and [GatedConv](https://github.com/avalonstrel/GatedConvolution_pytorch). Many thanks for their excellent work.
## Citation
If you find our method or dataset useful for your research, please cite:
```
@ARTICLE{Erase2020Liu,
  author = {Liu, Chongyu and Liu, Yuliang and Jin, Lianwen and Zhang, Shuaitao and Luo, Canjie and Wang, Yongpan},
  journal = {IEEE Transactions on Image Processing},
  title = {EraseNet: End-to-End Text Removal in the Wild},
  year = {2020},
  volume = {29},
  pages = {8760-8775}
}
@inproceedings{zhang2019EnsNet,
  title = {EnsNet: Ensconce Text in the Wild},
  author = {Zhang, Shuaitao and Liu, Yuliang and Jin, Lianwen and Huang, Yaoxiong and Lai, Songxuan},
  booktitle = {AAAI},
  year = {2019}
}
```
## Feedback
Suggestions and opinions on our work (both positive and negative) are very welcome. Please contact the authors by emailing Chongyu Liu ([liuchongyu1996@gmail.com](mailto:liuchongyu1996@gmail.com)). For commercial use, please contact Prof. Lianwen Jin ([eelwjin@scut.edu.cn](mailto:eelwjin@scut.edu.cn)).
================================================
FILE: data/dataloader.py
================================================
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import cv2
from os import listdir, walk
from os.path import join
from random import randint
import random
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, Resize, RandomHorizontalFlip
def random_horizontal_flip(imgs):
    # With probability 0.3, mirror every image in the list (input, mask, label) together
    if random.random() < 0.3:
        for i in range(len(imgs)):
            imgs[i] = imgs[i].transpose(Image.FLIP_LEFT_RIGHT)
    return imgs
def random_rotate(imgs):
    # With probability 0.3, rotate all images by the same random angle in [-10, 10) degrees
    if random.random() < 0.3:
        max_angle = 10
        angle = random.random() * 2 * max_angle - max_angle
        for i in range(len(imgs)):
            img = np.array(imgs[i])
            h, w = img.shape[:2]
            # Rotate about the image centre; OpenCV expects (x, y) and dsize=(width, height)
            rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
            img_rotation = cv2.warpAffine(img, rotation_matrix, (w, h))
            imgs[i] = Image.fromarray(img_rotation)
    return imgs
def CheckImageFile(filename):
    return any(filename.endswith(extension) for extension in ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.bmp', '.BMP'])
def ImageTransform(loadSize):
    # Bicubic resize, then conversion to a float tensor in [0, 1]
    return Compose([
        Resize(size=loadSize, interpolation=Image.BICUBIC),
        ToTensor(),
    ])
class ErasingData(Dataset):
    """Images under dataRoot, each paired with the mask and label found by
    replacing 'all_images' in its path with 'mask' and 'all_labels'."""
    def __init__(self, dataRoot, loadSize, training=True):
        super(ErasingData, self).__init__()
        self.imageFiles = [join(dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot)
                           for files in filenames if CheckImageFile(files)]
        self.loadSize = loadSize
        self.ImgTrans = ImageTransform(loadSize)
        self.training = training
    def __getitem__(self, index):
        img = Image.open(self.imageFiles[index])
        mask = Image.open(self.imageFiles[index].replace('all_images', 'mask'))
        gt = Image.open(self.imageFiles[index].replace('all_images', 'all_labels'))
        if self.training:
            # Data augmentation: the same random flip/rotation is applied to image, mask and label
            all_input = [img, mask, gt]
            all_input = random_horizontal_flip(all_input)
            all_input = random_rotate(all_input)
            img = all_input[0]
            mask = all_input[1]
            gt = all_input[2]
        inputImage = self.ImgTrans(img.convert('RGB'))
        mask = self.ImgTrans(mask.convert('RGB'))
        groundTruth = self.ImgTrans(gt.convert('RGB'))
        path = self.imageFiles[index].split('/')[-1]
        return inputImage, groundTruth, mask, path
    def __len__(self):
        return len(self.imageFiles)
class devdata(Dataset):
    """Evaluation pairs: results under dataRoot matched with labels under gtRoot."""
    def __init__(self, dataRoot, gtRoot, loadSize=512):
        super(devdata, self).__init__()
        # os.walk order is arbitrary, so sort both lists to keep results and labels aligned by name
        self.imageFiles = sorted(join(dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot)
                                 for files in filenames if CheckImageFile(files))
        self.gtFiles = sorted(join(gtRootK, files) for gtRootK, dn, filenames in walk(gtRoot)
                              for files in filenames if CheckImageFile(files))
        self.loadSize = loadSize
        self.ImgTrans = ImageTransform(loadSize)
    def __getitem__(self, index):
        img = Image.open(self.imageFiles[index])
        gt = Image.open(self.gtFiles[index])
        inputImage = self.ImgTrans(img.convert('RGB'))
        groundTruth = self.ImgTrans(gt.convert('RGB'))
        path = self.imageFiles[index].split('/')[-1]
        return inputImage, groundTruth, path
    def __len__(self):
        return len(self.imageFiles)
================================================
FILE: evaluatuion.py
================================================
import os
import math
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from PIL import Image
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from data.dataloader import devdata
from scipy import signal, ndimage
from torchvision.transforms import Compose, Resize, ToTensor
import gauss
parser = argparse.ArgumentParser()
parser.add_argument('--target_path', type=str, default='',
                    help='folder of predicted (text-removed) images')
parser.add_argument('--gt_path', type=str, default='',
                    help='folder of ground-truth label images')
args = parser.parse_args()
sum_psnr = 0
sum_ssim = 0
sum_AGE = 0
sum_pCEPS = 0
sum_pEPS = 0
sum_mse = 0
count = 0
sum_time = 0.0
l1_loss = 0
img_path = args.target_path
gt_path = args.gt_path
def ssim(img1, img2, cs_map=False):
    """Return the structural similarity (SSIM) map of the grayscale images img1
    and img2, whose values are assumed to lie in the 0-255 range.
    With cs_map=True, also return the contrast-structure map used by MS-SSIM.
    This function attempts to mimic precisely the functionality of ssim.m,
    the MATLAB implementation provided by the authors of SSIM:
    https://ece.uwaterloo.ca/~z70wang/research/ssim/ssim_index.m
    """
    img1 = img1.astype(float)
    img2 = img2.astype(float)
    size = min(img1.shape[0], 11)
    sigma = 1.5
    window = gauss.fspecial_gauss(size, sigma)
    K1 = 0.01
    K2 = 0.03
    L = 255  # bit depth of image
    C1 = (K1 * L) ** 2
    C2 = (K2 * L) ** 2
    mu1 = signal.fftconvolve(img1, window, mode='valid')
    mu2 = signal.fftconvolve(img2, window, mode='valid')
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = signal.fftconvolve(img1 * img1, window, mode='valid') - mu1_sq
    sigma2_sq = signal.fftconvolve(img2 * img2, window, mode='valid') - mu2_sq
    sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') - mu1_mu2
    if cs_map:
        return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)),
                (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2))
    else:
        return ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                             (sigma1_sq + sigma2_sq + C2))
def msssim(img1, img2):
    """This function implements Multi-Scale Structural Similarity (MSSSIM) Image
    Quality Assessment according to Z. Wang's "Multi-scale structural similarity
    for image quality assessment" Invited Paper, IEEE Asilomar Conference on
    Signals, Systems and Computers, Nov. 2003
    Authors' MATLAB implementation:
    http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
    """
    level = 5
    weight = np.array([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
    downsample_filter = np.ones((2, 2)) / 4.0
    mssim = np.array([])
    mcs = np.array([])
    for _ in range(level):
        ssim_map, cs_map = ssim(img1, img2, cs_map=True)
        mssim = np.append(mssim, ssim_map.mean())
        mcs = np.append(mcs, cs_map.mean())
        # Average-filter and downsample by 2 so the next iteration works at a coarser scale
        filtered_im1 = ndimage.convolve(img1, downsample_filter, mode='reflect')
        filtered_im2 = ndimage.convolve(img2, downsample_filter, mode='reflect')
        img1 = filtered_im1[::2, ::2]
        img2 = filtered_im2[::2, ::2]
    # Note: Remove the negative and add it later to avoid NaN in exponential.
    sign_mcs = np.sign(mcs[0:level - 1])
    sign_mssim = np.sign(mssim[level - 1])
    mcs_power = np.power(np.abs(mcs[0:level - 1]), weight[0:level - 1])
    mssim_power = np.power(np.abs(mssim[level - 1]), weight[level - 1])
    return np.prod(sign_mcs * mcs_power) * sign_mssim * mssim_power
def ImageTransform(loadSize, cropSize):
    return Compose([
        Resize(size=loadSize, interpolation=Image.BICUBIC),
        # RandomCrop(size=cropSize),
        # RandomHorizontalFlip(p=0.5),
        ToTensor(),
    ])
def visual(image):
    im = (image).transpose(1, 2).transpose(2, 3).detach().cpu().numpy()
    Image.fromarray(im[0].astype(np.uint8)).show()
imgData = devdata(dataRoot=img_path, gtRoot=gt_path)
data_loader = DataLoader(imgData, batch_size=1, shuffle=True, num_workers=0, drop_last=False)
for k, (img, lbl, path) in enumerate(data_loader):
    mse = ((lbl - img) ** 2).mean()
    sum_mse += mse
    print(path, count, 'mse: ', mse)
    if mse == 0:
        continue
    count += 1
    psnr = 10 * math.log10(1 / mse)
    sum_psnr += psnr
    print(path, count, ' psnr: ', psnr)
    # l1_loss += nn.L1Loss()(img, lbl)
    # Luma (Y) of the label and of the result, with BT.601 weights
    R = lbl[0, 0, :, :]
    G = lbl[0, 1, :, :]
    B = lbl[0, 2, :, :]
    YGT = .299 * R + .587 * G + .114 * B
    R = img[0, 0, :, :]
    G = img[0, 1, :, :]
    B = img[0, 2, :, :]
    YBC = .299 * R + .587 * G + .114 * B
    Diff = abs(np.array(YBC * 255) - np.array(YGT * 255)).round().astype(np.uint8)
    AGE = np.mean(Diff)
    print(' AGE: ', AGE)
    mssim = msssim(np.array(YGT * 255), np.array(YBC * 255))
    sum_ssim += mssim
    print(count, ' ssim:', mssim)
    # Error pixels (EPs): luma difference above the threshold, as a fraction of all pixels
    threshold = 20
    Errors = Diff > threshold
    EPs = float(np.count_nonzero(Errors))
    pEPs = EPs / Errors.size
    print(' pEPS: ', pEPs)
    sum_pEPS += pEPs
    ########################## CEPs and pCEPs ################################
    # Clustered error pixels: error pixels that survive erosion with a 3x3 cross
    structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
    sum_AGE += AGE
    erodedErrors = ndimage.binary_erosion(Errors, structure).astype(Errors.dtype)
    CEPs = np.count_nonzero(erodedErrors)
    pCEPs = CEPs / Errors.size
    print(' pCEPS: ', pCEPs)
    sum_pCEPS += pCEPs
print(sum_psnr)
print('avg mse:', sum_mse / count)
print('average psnr:', sum_psnr / count)
print('average ssim:', sum_ssim / count)
print('average AGE:', sum_AGE / count)
print('average pEPS:', sum_pEPS / count)
print('average pCEPS:', sum_pCEPS / count)
================================================
FILE: gauss.py
================================================
#!/usr/bin/env python
"""Module providing functionality surrounding gaussian function.
"""
SVN_REVISION = '$LastChangedRevision: 16541 $'
import sys
import numpy
def gaussian2(size, sigma):
    """Returns a normalized circularly symmetric 2D gauss kernel array
    f(x,y) = A.e^{-(x^2/(2*sigma^2) + y^2/(2*sigma^2))} where
    A = 1/(2*pi*sigma^2)
    as defined by Wolfram MathWorld
    http://mathworld.wolfram.com/GaussianFunction.html
    """
    A = 1/(2.0*numpy.pi*sigma**2)
    x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]
    g = A*numpy.exp(-((x**2/(2.0*sigma**2))+(y**2/(2.0*sigma**2))))
    return g
def fspecial_gauss(size, sigma):
    """Function to mimic the 'fspecial' gaussian MATLAB function
    """
    x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]
    g = numpy.exp(-((x**2 + y**2)/(2.0*sigma**2)))
    return g/g.sum()
def main():
    """Show simple use cases for functionality provided by this module."""
    from mpl_toolkits.mplot3d.axes3d import Axes3D  # registers the '3d' projection on older Matplotlib
    import pylab
    argv = sys.argv
    if len(argv) != 3:
        print('usage: python gauss.py size sigma', file=sys.stderr)
        sys.exit(2)
    size = int(argv[1])
    sigma = float(argv[2])
    x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]
    fig = pylab.figure()
    fig.suptitle('Some 2-D Gauss Functions')
    ax = fig.add_subplot(2, 1, 1, projection='3d')
    ax.plot_surface(x, y, fspecial_gauss(size, sigma), rstride=1, cstride=1,
                    linewidth=0, antialiased=False, cmap='jet')
    ax = fig.add_subplot(2, 1, 2, projection='3d')
    ax.plot_surface(x, y, gaussian2(size, sigma), rstride=1, cstride=1,
                    linewidth=0, antialiased=False, cmap='jet')
    pylab.show()
    return 0
if __name__ == '__main__':
    sys.exit(main())
================================================
FILE: loss/Loss.py
================================================
import torch
from torch import nn
from torch import autograd
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from models.discriminator import Discriminator_STE
from PIL import Image
import numpy as np
def gram_matrix(feat):
    # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py
    (b, ch, h, w) = feat.size()
    feat = feat.view(b, ch, h * w)
    feat_t = feat.transpose(1, 2)
    gram = torch.bmm(feat, feat_t) / (ch * h * w)
    return gram
def visual(image):
    im = image.transpose(1, 2).transpose(2, 3).detach().cpu().numpy()
    Image.fromarray(im[0].astype(np.uint8)).show()
def dice_loss(input, target):
    # Soft dice loss between sigmoid(input) and target, averaged over the batch
    input = torch.sigmoid(input)
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1)
    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    dice = torch.mean(d)
    return 1 - dice
class LossWithGAN_STE(nn.Module):
    def __init__(self, logPath, extractor, Lamda, lr, betasInit=(0.5, 0.9)):
        super(LossWithGAN_STE, self).__init__()
        self.l1 = nn.L1Loss()
        self.extractor = extractor
        self.discriminator = Discriminator_STE(3)  ## local_global sn patch gan
        self.D_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betasInit)
        self.cudaAvailable = torch.cuda.is_available()
        self.numOfGPUs = torch.cuda.device_count()
        self.lamda = Lamda
        self.writer = SummaryWriter(logPath)
    def forward(self, input, mask, x_o1, x_o2, x_o3, output, mm, gt, count, epoch):
        # Discriminator step with a hinge loss; D_real is negated,
        # so relu(1 + D_real) = relu(1 - mean D(gt))
        self.discriminator.zero_grad()
        D_real = self.discriminator(gt, mask)
        D_real = D_real.mean().sum() * -1
        D_fake = self.discriminator(output, mask)
        D_fake = D_fake.mean().sum() * 1
        D_loss = torch.mean(F.relu(1.+D_real)) + torch.mean(F.relu(1.+D_fake))  # SN-patch-GAN loss
        D_fake = -torch.mean(D_fake)  # SN-Patch-GAN loss
        self.D_optimizer.zero_grad()
        D_loss.backward(retain_graph=True)
        self.D_optimizer.step()
        self.writer.add_scalar('LossD/Discriminator loss', D_loss.item(), count)
        # Composite image: input where mask == 1, prediction where mask == 0
        output_comp = mask * input + (1 - mask) * output
        holeLoss = 10 * self.l1((1 - mask) * output, (1 - mask) * gt)
        validAreaLoss = 2*self.l1(mask * output, mask * gt)
        mask_loss = dice_loss(mm, 1-mask)
        ### MSR loss on the full-, 1/2- and 1/4-scale outputs ###
        masks_a = F.interpolate(mask, scale_factor=0.25)
        masks_b = F.interpolate(mask, scale_factor=0.5)
        imgs1 = F.interpolate(gt, scale_factor=0.25)
        imgs2 = F.interpolate(gt, scale_factor=0.5)
        msrloss = 8 * self.l1((1-mask)*x_o3, (1-mask)*gt) + 0.8*self.l1(mask*x_o3, mask*gt) + \
            6 * self.l1((1-masks_b)*x_o2, (1-masks_b)*imgs2) + 1*self.l1(masks_b*x_o2, masks_b*imgs2) + \
            5 * self.l1((1-masks_a)*x_o1, (1-masks_a)*imgs1) + 0.8*self.l1(masks_a*x_o1, masks_a*imgs1)
        feat_output_comp = self.extractor(output_comp)
        feat_output = self.extractor(output)
        feat_gt = self.extractor(gt)
        prcLoss = 0.0
        for i in range(3):
            prcLoss += 0.01 * self.l1(feat_output[i], feat_gt[i])
            prcLoss += 0.01 * self.l1(feat_output_comp[i], feat_gt[i])
        styleLoss = 0.0
        for i in range(3):
            styleLoss += 120 * self.l1(gram_matrix(feat_output[i]),
                                       gram_matrix(feat_gt[i]))
            styleLoss += 120 * self.l1(gram_matrix(feat_output_comp[i]),
                                       gram_matrix(feat_gt[i]))
        """ if self.numOfGPUs > 1:
            holeLoss = holeLoss.sum() / self.numOfGPUs
            validAreaLoss = validAreaLoss.sum() / self.numOfGPUs
            prcLoss = prcLoss.sum() / self.numOfGPUs
            styleLoss = styleLoss.sum() / self.numOfGPUs """
        self.writer.add_scalar('LossG/Hole loss', holeLoss.item(), count)
        self.writer.add_scalar('LossG/Valid loss', validAreaLoss.item(), count)
        self.writer.add_scalar('LossG/msr loss', msrloss.item(), count)
        self.writer.add_scalar('LossPrc/Perceptual loss', prcLoss.item(), count)
        self.writer.add_scalar('LossStyle/style loss', styleLoss.item(), count)
        GLoss = msrloss + holeLoss + validAreaLoss + prcLoss + styleLoss + 0.1 * D_fake + 1*mask_loss
        self.writer.add_scalar('Generator/Joint loss', GLoss.item(), count)
        return GLoss.sum()
================================================
FILE: models/Model.py
================================================
import torch
import torch.nn as nn
from torchvision import models
#VGG16 feature extract
class VGG16FeatureExtractor(nn.Module):
    def __init__(self):
        super(VGG16FeatureExtractor, self).__init__()
        vgg16 = models.vgg16(pretrained=True)
        # vgg16.load_state_dict(torch.load('./vgg16-397923af.pth'))
        self.enc_1 = nn.Sequential(*vgg16.features[:5])
        self.enc_2 = nn.Sequential(*vgg16.features[5:10])
        self.enc_3 = nn.Sequential(*vgg16.features[10:17])
        # freeze the encoder weights
        for i in range(3):
            for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
                param.requires_grad = False
    def forward(self, image):
        # Return the outputs of the first three VGG16 blocks (after pool1, pool2 and pool3)
        results = [image]
        for i in range(3):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]
================================================
FILE: models/discriminator.py
================================================
import torch
import torch.nn as nn
from .networks import ConvWithActivation, get_pad
##discriminator
class Discriminator_STE(nn.Module):
    def __init__(self, inputChannels):
        super(Discriminator_STE, self).__init__()
        cnum = 32
        self.globalDis = nn.Sequential(
            ConvWithActivation(3, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)),
            ConvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)),
            ConvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)),
            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)),
            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)),
            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)),
        )
        self.localDis = nn.Sequential(
            ConvWithActivation(3, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)),
            ConvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)),
            ConvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)),
            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)),
            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)),
            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)),
        )
        self.fusion = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=4),
            nn.Sigmoid()
        )
    def forward(self, input, masks):
        # The global branch sees the whole image; the local branch sees input * (1 - masks)
        global_feat = self.globalDis(input)
        local_feat = self.localDis(input * (1 - masks))
        concat_feat = torch.cat((global_feat, local_feat), 1)
        return self.fusion(concat_feat).view(input.size()[0], -1)
================================================
FILE: models/networks.py
================================================
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
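def get_pad(in_, ksize, stride, atrous=1):
    # 'Same'-style padding: the padding for which a convolution with kernel size ksize,
    # the given stride and dilation atrous yields ceil(in_ / stride) outputs
    # (truncated to an int when the total padding is odd)
    out_ = np.ceil(float(in_)/stride)
    return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2)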
class ConvWithActivation(torch.nn.Module):
    """
    Convolution with spectral normalization, followed by an optional activation.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(ConvWithActivation, self).__init__()
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
        self.activation = activation
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
    def forward(self, input):
        x = self.conv2d(input)
        if self.activation is not None:
            return self.activation(x)
        else:
            return x
class DeConvWithActivation(torch.nn.Module):
    """
    Transposed convolution with spectral normalization, followed by an optional activation.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(DeConvWithActivation, self).__init__()
        # ConvTranspose2d's sixth positional parameter is output_padding, not dilation, so
        # `dilation` is passed as output_padding here. With its default of 1, a layer with
        # kernel_size=3, padding=1, stride=2 exactly doubles the spatial size.
        self.conv2d = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
        self.activation = activation
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight)
    def forward(self, input):
        x = self.conv2d(input)
        if self.activation is not None:
            return self.activation(x)
        else:
            return x
================================================
FILE: models/sa_gan.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torch.autograd import Variable
from .networks import get_pad, ConvWithActivation, DeConvWithActivation
def img2photo(imgs):
    return ((imgs+1)*127.5).transpose(1, 2).transpose(2, 3).detach().cpu().numpy()
def visual(imgs):
    im = img2photo(imgs)
    Image.fromarray(im[0].astype(np.uint8)).show()
class Residual(nn.Module):
    def __init__(self, in_channels, out_channels, same_shape=True, **kwargs):
        super(Residual, self).__init__()
        self.same_shape = same_shape
        strides = 1 if same_shape else 2
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # self.conv2 = torch.nn.utils.spectral_norm(self.conv2)
        if not same_shape:
            # 1x1 projection so the shortcut matches the new channel count and stride
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=strides)
            # self.conv3 = torch.nn.utils.spectral_norm(self.conv3)
        self.batch_norm2d = nn.BatchNorm2d(out_channels)
    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        if not self.same_shape:
            x = self.conv3(x)
        out = self.batch_norm2d(out + x)
        # out = out + x
        return F.relu(out)
class ASPP(nn.Module):
    def __init__(self, in_channel=512, depth=256):
        super(ASPP, self).__init__()
        self.mean = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = nn.Conv2d(in_channel, depth, 1, 1)
        # k=1 s=1 no pad
        self.atrous_block1 = nn.Conv2d(in_channel, depth, 1, 1)
        self.atrous_block6 = nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6)
        self.atrous_block12 = nn.Conv2d(in_channel, depth, 3, 1, padding=12, dilation=12)
        self.atrous_block18 = nn.Conv2d(in_channel, depth, 3, 1, padding=18, dilation=18)
        self.conv_1x1_output = nn.Conv2d(depth * 5, depth, 1, 1)
    def forward(self, x):
        size = x.shape[2:]
        image_features = self.mean(x)
        image_features = self.conv(image_features)
        # F.upsample is deprecated; F.interpolate with align_corners=False behaves the same
        image_features = F.interpolate(image_features, size=size, mode='bilinear', align_corners=False)
        atrous_block1 = self.atrous_block1(x)
        atrous_block6 = self.atrous_block6(x)
        atrous_block12 = self.atrous_block12(x)
        atrous_block18 = self.atrous_block18(x)
        net = self.conv_1x1_output(torch.cat([image_features, atrous_block1, atrous_block6,
                                              atrous_block12, atrous_block18], dim=1))
        return net
class STRnet2(nn.Module):
    def __init__(self, n_in_channel=3):
        super(STRnet2, self).__init__()
        #### U-Net ####
        # downsample
        self.conv1 = ConvWithActivation(3, 32, kernel_size=4, stride=2, padding=1)
        self.conva = ConvWithActivation(32, 32, kernel_size=3, stride=1, padding=1)
        self.convb = ConvWithActivation(32, 64, kernel_size=4, stride=2, padding=1)
        self.res1 = Residual(64, 64)
        self.res2 = Residual(64, 64)
        self.res3 = Residual(64, 128, same_shape=False)
        self.res4 = Residual(128, 128)
        self.res5 = Residual(128, 256, same_shape=False)
        # self.nn = ConvWithActivation(256, 512, 3, 1, dilation=2, padding=get_pad(64, 3, 1, 2))
        self.res6 = Residual(256, 256)
        self.res7 = Residual(256, 512, same_shape=False)
        self.res8 = Residual(512, 512)
        self.conv2 = ConvWithActivation(512, 512, kernel_size=1)
        # upsample
        self.deconv1 = DeConvWithActivation(512, 256, kernel_size=3, padding=1, stride=2)
        self.deconv2 = DeConvWithActivation(256*2, 128, kernel_size=3, padding=1, stride=2)
        self.deconv3 = DeConvWithActivation(128*2, 64, kernel_size=3, padding=1, stride=2)
        self.deconv4 = DeConvWithActivation(64*2, 32, kernel_size=3, padding=1, stride=2)
        self.deconv5 = DeConvWithActivation(64, 3, kernel_size=3, padding=1, stride=2)
        # lateral connections
        self.lateral_connection1 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1),
            nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(512, 512, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(512, 256, kernel_size=1, padding=0, stride=1),)
        self.lateral_connection2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=1, padding=0, stride=1),
            nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(256, 128, kernel_size=1, padding=0, stride=1),)
        self.lateral_connection3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1, padding=0, stride=1),
            nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(128, 64, kernel_size=1, padding=0, stride=1),)
        self.lateral_connection4 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=1, padding=0, stride=1),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(64, 32, kernel_size=1, padding=0, stride=1),)
        # self.relu = nn.elu(alpha=1.0)
        self.conv_o1 = nn.Conv2d(64, 3, kernel_size=1)
        self.conv_o2 = nn.Conv2d(32, 3, kernel_size=1)
        ##### U-Net #####
        ### ASPP ###
        # self.aspp = ASPP(512, 256)
        ### ASPP ###
        ### mask branch decoder ###
        self.mask_deconv_a = DeConvWithActivation(512, 256, kernel_size=3, padding=1, stride=2)
        self.mask_conv_a = ConvWithActivation(256, 128, kernel_size=3, padding=1, stride=1)
        self.mask_deconv_b = DeConvWithActivation(256, 128, kernel_size=3, padding=1, stride=2)
        self.mask_conv_b = ConvWithActivation(128, 64, kernel_size=3, padding=1, stride=1)
        self.mask_deconv_c = DeConvWithActivation(128, 64, kernel_size=3, padding=1, stride=2)
        self.mask_conv_c = ConvWithActivation(64, 32, kernel_size=3, padding=1, stride=1)
        self.mask_deconv_d = DeConvWithActivation(64, 32, kernel_size=3, padding=1, stride=2)
        self.mask_conv_d = nn.Conv2d(32, 3, kernel_size=1)
        ### mask branch ###
        ##### Refine sub-network ######
        n_in_channel = 3
        cnum = 32
        #### downsample
        self.coarse_conva = ConvWithActivation(n_in_channel, cnum, kernel_size=5, stride=1, padding=2)
        self.coarse_convb = ConvWithActivation(cnum, 2*cnum, kernel_size=4, stride=2, padding=1)
        self.coarse_convc = ConvWithActivation(2*cnum, 2*cnum, kernel_size=3, stride=1, padding=1)
        self.coarse_convd = ConvWithActivation(2*cnum, 4*cnum, kernel_size=4, stride=2, padding=1)
        self.coarse_conve = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1)
        self.coarse_convf = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1)
        ### atrous (dilated) convolutions
        self.astrous_net = nn.Sequential(
            ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=2, padding=get_pad(64, 3, 1, 2)),
            ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=4, padding=get_pad(64, 3, 1, 4)),
            ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=8, padding=get_pad(64, 3, 1, 8)),
            ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=16, padding=get_pad(64, 3, 1, 16)),
        )
        ### atrous
        ### upsample
        self.coarse_convk = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1)
        self.coarse_convl = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1)
        self.coarse_deconva = DeConvWithActivation(4*cnum*3, 2*cnum, kernel_size=3, padding=1, stride=2)
        self.coarse_convm = ConvWithActivation(2*cnum, 2*cnum, kernel_size=3, stride=1, padding=1)
        self.coarse_deconvb = DeConvWithActivation(2*cnum*3, cnum, kernel_size=3, padding=1, stride=2)
        self.coarse_convn = nn.Sequential(
            ConvWithActivation(cnum, cnum//2, kernel_size=3, stride=1, padding=1),
            # Self_Attn(cnum//2, 'relu'),
            ConvWithActivation(cnum//2, 3, kernel_size=3, stride=1, padding=1, activation=None),
        )
        self.c1 = nn.Conv2d(32, 64, kernel_size=1)
        self.c2 = nn.Conv2d(64, 128, kernel_size=1)
        ##### Refine network ######
    def forward(self, x):
        """Return (x_o1, x_o2, x_o_unet, x, mm): the 1/4- and 1/2-scale U-Net outputs,
        the full-size U-Net output, the refined output, and the mask-branch logits."""
        # downsample
        x = self.conv1(x)
        x = self.conva(x)
        con_x1 = x
        x = self.convb(x)
        x = self.res1(x)
        con_x2 = x
        x = self.res2(x)
        x = self.res3(x)
        con_x3 = x
        x = self.res4(x)
        x = self.res5(x)
        con_x4 = x
        x = self.res6(x)
        # x_mask = self.nn(con_x4) ### for mask branch aspp
        # x_mask = self.aspp(x_mask) ### for mask branch aspp
        x_mask = x  ### no aspp
        x = self.res7(x)
        x = self.res8(x)
        x = self.conv2(x)
        # upsample
        x = self.deconv1(x)
        x = torch.cat([self.lateral_connection1(con_x4), x], dim=1)
        x = self.deconv2(x)
        x = torch.cat([self.lateral_connection2(con_x3), x], dim=1)
        x = self.deconv3(x)
        xo1 = x
        x = torch.cat([self.lateral_connection3(con_x2), x], dim=1)
        x = self.deconv4(x)
        xo2 = x
        x = torch.cat([self.lateral_connection4(con_x1), x], dim=1)
        x = self.deconv5(x)
        x_o1 = self.conv_o1(xo1)
        x_o2 = self.conv_o2(xo2)
        x_o_unet = x
        ### mask branch ###
        mm = self.mask_deconv_a(torch.cat([x_mask, con_x4], dim=1))
        mm = self.mask_conv_a(mm)
        mm = self.mask_deconv_b(torch.cat([mm, con_x3], dim=1))
        mm = self.mask_conv_b(mm)
        mm = self.mask_deconv_c(torch.cat([mm, con_x2], dim=1))
        mm = self.mask_conv_c(mm)
        mm = self.mask_deconv_d(torch.cat([mm, con_x1], dim=1))
        mm = self.mask_conv_d(mm)
        ### mask branch ###
        ### refine sub-network
        x = self.coarse_conva(x_o_unet)
        x = self.coarse_convb(x)
        x = self.coarse_convc(x)
        x_c1 = x  ### concat feature 1
        x = self.coarse_convd(x)
        x = self.coarse_conve(x)
        x = self.coarse_convf(x)
        x_c2 = x  ### concat feature 2
        x = self.astrous_net(x)
        x = self.coarse_convk(x)
        x = self.coarse_convl(x)
        x = self.coarse_deconva(torch.cat([x, x_c2, self.c2(con_x2)], dim=1))
        x = self.coarse_convm(x)
        x = self.coarse_deconvb(torch.cat([x, x_c1, self.c1(con_x1)], dim=1))
        x = self.coarse_convn(x)
        return x_o1, x_o2, x_o_unet, x, mm
================================================
FILE: test_image_STE.py
================================================
import os
import math
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from PIL import Image
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from data.dataloader import ErasingData
from models.sa_gan import STRnet2
parser = argparse.ArgumentParser()
parser.add_argument('--numOfWorkers', type=int, default=0,
help='workers for dataloader')
parser.add_argument('--modelsSavePath', type=str, default='',
help='path for saving models')
parser.add_argument('--logPath', type=str,
default='')
parser.add_argument('--batchSize', type=int, default=16)
parser.add_argument('--loadSize', type=int, default=512,
help='image loading size')
parser.add_argument('--dataRoot', type=str,
default='')
parser.add_argument('--pretrained',type=str, default='', help='pretrained models for finetuning')
parser.add_argument('--savePath', type=str, default='./results/sn_tv/')
args = parser.parse_args()
cuda = torch.cuda.is_available()
if cuda:
    print('Cuda is available!')
    cudnn.benchmark = True

def visual(image):
    # Debug helper: shows the first image of an NCHW batch (expects values in [0, 255]).
    im = image.transpose(1, 2).transpose(2, 3).detach().cpu().numpy()
    Image.fromarray(im[0].astype(np.uint8)).show()

batchSize = args.batchSize
loadSize = (args.loadSize, args.loadSize)
dataRoot = args.dataRoot
savePath = args.savePath
result_with_mask = os.path.join(savePath, 'WithMaskOutput')
result_straight = os.path.join(savePath, 'StrOuput')
# exist_ok avoids a crash when savePath already exists but the sub-folders do not
os.makedirs(result_with_mask, exist_ok=True)
os.makedirs(result_straight, exist_ok=True)
Erase_data = ErasingData(dataRoot, loadSize, training=False)
# no shuffling needed at test time: every image is saved under its own file name
Erase_data = DataLoader(Erase_data, batch_size=batchSize, shuffle=False, num_workers=args.numOfWorkers, drop_last=False)
netG = STRnet2(3)
# map_location lets a checkpoint trained on GPU load on a CPU-only machine
netG.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
if cuda:
    netG = netG.cuda()
for param in netG.parameters():
    param.requires_grad = False
print('OK!')
import time
start = time.time()
netG.eval()
with torch.no_grad():
    for imgs, gt, masks, path in Erase_data:
        if cuda:
            imgs = imgs.cuda()
            gt = gt.cuda()
            masks = masks.cuda()
        out1, out2, out3, g_images, mm = netG(imgs)
        g_image = g_images.cpu()
        gt = gt.cpu()
        mask = masks.cpu()
        # keep ground-truth pixels where mask == 1 and the network output where mask == 0
        g_image_with_mask = gt * mask + g_image * (1 - mask)
        # Save each image separately: passing the whole batch to save_image would
        # write one grid to path[0] and drop the other file names.
        for b in range(g_image.size(0)):
            save_image(g_image_with_mask[b], os.path.join(result_with_mask, path[b]))
            save_image(g_image[b], os.path.join(result_straight, path[b]))
print('Done in {:.1f}s'.format(time.time() - start))
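The test loop blends the prediction with the ground truth and writes one file per image. The NumPy sketch below mirrors those two steps on plain arrays: `composite` repeats the script's blend, and `to_uint8_hwc` reproduces the scaling that torchvision's `save_image` applies before writing (multiply by 255, add 0.5, clamp to [0, 255], move channels last). Both helper names are hypothetical. It also illustrates why the `visual()` helpers, which cast straight to `uint8`, would show an almost black image if the tensors are in [0, 1]; we assume, without having checked `data/dataloader.py` here, that they are.

```python
import numpy as np


def composite(gt: np.ndarray, pred: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Same blend as test_image_STE.py: gt where mask == 1, prediction where mask == 0."""
    return gt * mask + pred * (1.0 - mask)


def to_uint8_hwc(img_chw: np.ndarray) -> np.ndarray:
    """Convert one CHW float image in [0, 1] to an HWC uint8 array for saving."""
    arr = np.clip(img_chw * 255.0 + 0.5, 0, 255).astype(np.uint8)
    return np.transpose(arr, (1, 2, 0))


gt = np.zeros((2, 3, 4, 4), dtype=np.float32)
pred = np.ones((2, 3, 4, 4), dtype=np.float32)
mask = np.zeros((2, 1, 4, 4), dtype=np.float32)
mask[:, :, :2, :] = 1.0  # top half keeps the ground truth
out = composite(gt, pred, mask)
frames = [to_uint8_hwc(out[b]) for b in range(out.shape[0])]  # one array per saved file
print(frames[0].shape, frames[0][0, 0, 0], frames[0][3, 0, 0])  # (4, 4, 3) 0 255
```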
================================================
FILE: train_STE.py
================================================
import os
import math
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from PIL import Image
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import utils
from data.dataloader import ErasingData
from loss.Loss import LossWithGAN_STE
from models.Model import VGG16FeatureExtractor
from models.sa_gan import STRnet2
torch.set_num_threads(5)
os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # hard-coded GPU index: change it for your machine; it must be set before CUDA is initialised
parser = argparse.ArgumentParser()
parser.add_argument('--numOfWorkers', type=int, default=0,
help='workers for dataloader')
parser.add_argument('--modelsSavePath', type=str, default='',
help='path for saving models')
parser.add_argument('--logPath', type=str,
default='')
parser.add_argument('--batchSize', type=int, default=16)
parser.add_argument('--loadSize', type=int, default=512,
help='image loading size')
parser.add_argument('--dataRoot', type=str,
default='')
parser.add_argument('--pretrained',type=str, default='', help='pretrained models for finetuning')
parser.add_argument('--num_epochs', type=int, default=500, help='epochs')
args = parser.parse_args()
def visual(image):
    # Debug helper: shows the first image of an NCHW batch (expects values in [0, 255]).
    im = image.transpose(1, 2).transpose(2, 3).detach().cpu().numpy()
    Image.fromarray(im[0].astype(np.uint8)).show()

cuda = torch.cuda.is_available()
if cuda:
    print('Cuda is available!')
    cudnn.enabled = True  # the attribute is `enabled`; assigning `cudnn.enable` has no effect
    cudnn.benchmark = True
batchSize = args.batchSize
loadSize = (args.loadSize, args.loadSize)
if not os.path.exists(args.modelsSavePath):
    os.makedirs(args.modelsSavePath)
dataRoot = args.dataRoot
Erase_data = ErasingData(dataRoot, loadSize, training=True)
Erase_data = DataLoader(Erase_data, batch_size=batchSize,
                        shuffle=True, num_workers=args.numOfWorkers, drop_last=False, pin_memory=True)
netG = STRnet2(3)
if args.pretrained != '':
    netG.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
    print('Loaded pretrained weights from', args.pretrained)
numOfGPUs = torch.cuda.device_count()
if cuda:
    netG = netG.cuda()
    if numOfGPUs > 1:
        netG = nn.DataParallel(netG, device_ids=range(numOfGPUs))
count = 1
G_optimizer = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.9))
criterion = LossWithGAN_STE(args.logPath, VGG16FeatureExtractor(), lr=0.00001, betasInit=(0.0, 0.9), Lamda=10.0)
if cuda:
    criterion = criterion.cuda()
    if numOfGPUs > 1:
        criterion = nn.DataParallel(criterion, device_ids=range(numOfGPUs))
print('OK!')
num_epochs = args.num_epochs
for i in range(1, num_epochs + 1):
    netG.train()
    for k, (imgs, gt, masks, path) in enumerate(Erase_data):
        if cuda:
            imgs = imgs.cuda()
            gt = gt.cuda()
            masks = masks.cuda()
        netG.zero_grad()
        x_o1, x_o2, x_o3, fake_images, mm = netG(imgs)
        # the global step `count` and the epoch `i` are passed through to the loss
        G_loss = criterion(imgs, masks, x_o1, x_o2, x_o3, fake_images, mm, gt, count, i)
        G_loss = G_loss.sum()  # sums the per-GPU losses when the criterion runs under DataParallel
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()
        print('[{}/{}] Generator Loss of epoch{} is {}'.format(k, len(Erase_data), i, G_loss.item()))
        count += 1
    # checkpoint every 10 epochs; unwrap DataParallel so the keys carry no 'module.' prefix
    if i % 10 == 0:
        if numOfGPUs > 1:
            torch.save(netG.module.state_dict(), args.modelsSavePath +
                       '/STE_{}.pth'.format(i))
        else:
            torch.save(netG.state_dict(), args.modelsSavePath +
                       '/STE_{}.pth'.format(i))
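The training loop saves `netG.module.state_dict()` when the generator is wrapped in `nn.DataParallel`, so its checkpoints stay free of the `module.` prefix that DataParallel adds to parameter names. A checkpoint produced some other way (for example by saving the wrapped model directly) would still carry the prefix and fail to load into a bare `STRnet2`. The sketch below is a hypothetical helper, `strip_module_prefix`, that normalises such a state dict before `load_state_dict`; it only renames keys, so it works on any mapping, including the `OrderedDict` a saved `state_dict()` is loaded back as.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")


def strip_module_prefix(state_dict: Mapping[str, V], prefix: str = "module.") -> "OrderedDict[str, V]":
    """Return a copy of state_dict with `prefix` removed from every key that starts with it."""
    cleaned: "OrderedDict[str, V]" = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        if new_key in cleaned:
            # two keys collapse onto the same name: refuse rather than silently overwrite
            raise KeyError(f"stripping {prefix!r} makes key {new_key!r} ambiguous")
        cleaned[new_key] = value
    return cleaned


# Hypothetical usage with the scripts above:
#   state = torch.load(args.pretrained, map_location="cpu")
#   netG.load_state_dict(strip_module_prefix(state))
demo = OrderedDict([("module.conv1.weight", 1), ("module.conv1.bias", 2)])
print(list(strip_module_prefix(demo)))  # ['conv1.weight', 'conv1.bias']
```

Keys without the prefix pass through unchanged, so the helper is safe to apply to checkpoints written by this script as well.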