[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2020 Chongyu-Liu\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# EraseNet\n\nThis repository contains the implementation of EraseNet, a neural network for end-to-end scene text removal.\n\n## Data preparation\n\nFor data preparation, please refer to ./examples/. The training loader expects each input image to be stored under an `all_images` directory, with its text mask and text-free label at the same relative path under `mask` and `all_labels` directories, respectively. You can download our dataset [SCUT-EnsText](https://github.com/HCIILAB/SCUT-EnsText) or the synthetic dataset [SCUT-Syn](https://github.com/HCIILAB/Scene-Text-Removal) for training and testing.\n\nSCUT-EnsText requires a decompression password; please email [liuchongyu1996@gmail.com](mailto:liuchongyu1996@gmail.com) to request it.\n\n## Environment\n\nWe recommend using Anaconda to create a virtual environment for running the code. Our environment is as follows:\n```\npython = 3.7\npytorch = 1.3.1\ntorchvision = 0.4.2\n```\n\n## Demo\n\nWe provide our retrained model for quick inference on SCUT-EnsText: [Model Link](https://drive.google.com/file/d/1scrtQ2GFvKjjoGEqbKxpOMn37mJmXsFd/view)\n\n## Training\n\nOnce the data is prepared, start training with:\n```\npython train_STE.py --batchSize 4 \\\n  --dataRoot 'your path' \\\n  --modelsSavePath 'your path' \\\n  --logPath 'your path'\n```\n\n## Testing and evaluation\n\nTo generate predictions, run:\n\n```\npython test_image_STE.py --dataRoot 'your path' \\\n            --batchSize 1 \\\n            --pretrain 'your path' \\\n            --savePath 'your path'\n```\n\nTo evaluate the results, run:\n```\npython evaluatuion.py --target_path 'results_path' --gt_path 'labels_path'\n```\n\n## Acknowledgements\n\nThis repository benefits greatly from [LBAM](https://github.com/Vious/LBAM_Pytorch) and [GatedConv](https://github.com/avalonstrel/GatedConvolution_pytorch). Many thanks for their excellent work.\n\n## Citation\nIf you find our method or dataset useful for your research, please cite:\n```\n@ARTICLE{Erase2020Liu,\n  author     ={Liu, Chongyu and Liu, Yuliang and Jin, Lianwen and Zhang, Shuaitao and Luo, Canjie and Wang, Yongpan},\n  journal    ={IEEE Transactions on Image Processing},\n  title      ={EraseNet: End-to-End Text Removal in the Wild},\n  year       ={2020},\n  volume     ={29},\n  pages      ={8760-8775},}\n\n@inproceedings{zhang2019EnsNet,\n    title     = {EnsNet: Ensconce Text in the Wild},\n    author    = {Zhang, Shuaitao and Liu, Yuliang and Jin, Lianwen and Huang, Yaoxiong and Lai, Songxuan},\n    booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},\n    year      = {2019}\n  }\n```\n\n## Feedback\nSuggestions and opinions on our work (both positive and negative) are very welcome. Please contact the authors by emailing Chongyu Liu ([liuchongyu1996@gmail.com](mailto:liuchongyu1996@gmail.com)). For commercial use, please contact Prof. Lianwen Jin ([eelwjin@scut.edu.cn](mailto:eelwjin@scut.edu.cn)).\n"
  },
  {
    "path": "data/dataloader.py",
    "content": "import torch\nfrom torch.utils.data import Dataset\nfrom PIL import Image\nimport numpy as np\nimport cv2\nfrom os import listdir, walk\nfrom os.path import basename, join\nfrom random import randint\nimport random\nfrom torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, Resize, RandomHorizontalFlip\n\n\ndef random_horizontal_flip(imgs):\n    # With probability 0.3, flip every image in the list, so that the input,\n    # mask and label stay aligned\n    if random.random() < 0.3:\n        for i in range(len(imgs)):\n            imgs[i] = imgs[i].transpose(Image.FLIP_LEFT_RIGHT)\n    return imgs\n\ndef random_rotate(imgs):\n    # With probability 0.3, rotate every image in the list by the same random\n    # angle in [-10, 10) degrees, so that the input, mask and label stay aligned\n    if random.random() < 0.3:\n        max_angle = 10\n        angle = random.random() * 2 * max_angle - max_angle\n        for i in range(len(imgs)):\n            img = np.array(imgs[i])\n            h, w = img.shape[:2]\n            # OpenCV takes the centre as (x, y) and the output size as (width, height)\n            rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)\n            img_rotation = cv2.warpAffine(img, rotation_matrix, (w, h))\n            imgs[i] = Image.fromarray(img_rotation)\n    return imgs\n\ndef CheckImageFile(filename):\n    return filename.endswith(('.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.bmp', '.BMP'))\n\ndef ImageTransform(loadSize):\n    return Compose([\n        Resize(size=loadSize, interpolation=Image.BICUBIC),\n        ToTensor(),\n    ])\n\nclass ErasingData(Dataset):\n    def __init__(self, dataRoot, loadSize, training=True):\n        super(ErasingData, self).__init__()\n        self.imageFiles = [join(dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot)\n            for files in filenames if CheckImageFile(files)]\n        self.loadSize = loadSize\n        self.ImgTrans = ImageTransform(loadSize)\n        self.training = training\n\n    def __getitem__(self, index):\n        img = Image.open(self.imageFiles[index])\n        # The mask and label share the input path, with 'all_images' replaced\n        mask = Image.open(self.imageFiles[index].replace('all_images','mask'))\n        gt = Image.open(self.imageFiles[index].replace('all_images','all_labels'))\n        if self.training:\n            # data augmentation: identical random flip/rotation for all three images\n            img, mask, gt = random_rotate(random_horizontal_flip([img, mask, gt]))\n        inputImage = self.ImgTrans(img.convert('RGB'))\n        mask = self.ImgTrans(mask.convert('RGB'))\n        groundTruth = self.ImgTrans(gt.convert('RGB'))\n        path = basename(self.imageFiles[index])\n\n        return inputImage, groundTruth, mask, path\n\n    def __len__(self):\n        return len(self.imageFiles)\n\nclass devdata(Dataset):\n    def __init__(self, dataRoot, gtRoot, loadSize=512):\n        super(devdata, self).__init__()\n        # Sort both lists so that results and labels with matching file names are\n        # paired, instead of relying on the arbitrary order returned by os.walk\n        self.imageFiles = sorted(join(dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot)\n            for files in filenames if CheckImageFile(files))\n        self.gtFiles = sorted(join(gtRootK, files) for gtRootK, dn, filenames in walk(gtRoot)\n            for files in filenames if CheckImageFile(files))\n        self.loadSize = loadSize\n        self.ImgTrans = ImageTransform(loadSize)\n\n    def __getitem__(self, index):\n        img = Image.open(self.imageFiles[index])\n        gt = Image.open(self.gtFiles[index])\n        inputImage = self.ImgTrans(img.convert('RGB'))\n        groundTruth = self.ImgTrans(gt.convert('RGB'))\n        path = basename(self.imageFiles[index])\n\n        return inputImage, groundTruth, path\n\n    def __len__(self):\n        return len(self.imageFiles)\n"
  },
  {
    "path": "evaluatuion.py",
    "content": "import os\nimport math\nimport argparse\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom PIL import Image\nimport numpy as np\nfrom torch.autograd import Variable\nfrom torchvision.utils import save_image\nfrom torchvision.transforms import Compose, Resize, ToTensor\nfrom torch.utils.data import DataLoader\nfrom data.dataloader import devdata\nfrom scipy import signal, ndimage\nimport gauss\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--target_path', type=str, default='',\n                    help='directory of predicted (text-removed) images')\nparser.add_argument('--gt_path', type=str, default='',\n                    help='directory of ground-truth label images')\nargs = parser.parse_args()\n\nsum_psnr = 0\nsum_ssim = 0\nsum_AGE = 0\nsum_pCEPS = 0\nsum_pEPS = 0\nsum_mse = 0\n\ncount = 0\nsum_time = 0.0\nl1_loss = 0\n\nimg_path = args.target_path\ngt_path = args.gt_path\n\n\ndef ssim(img1, img2, cs_map=False):\n    \"\"\"Return the Structural Similarity Map corresponding to input images img1\n    and img2 (pixel values are assumed to be in the range 0-255)\n\n    This function attempts to mimic precisely the functionality of ssim.m, a\n    MATLAB script provided by the authors of SSIM\n    https://ece.uwaterloo.ca/~z70wang/research/ssim/ssim_index.m\n    \"\"\"\n    img1 = img1.astype(float)\n    img2 = img2.astype(float)\n\n    size = min(img1.shape[0], 11)\n    sigma = 1.5\n    window = gauss.fspecial_gauss(size, sigma)\n    K1 = 0.01\n    K2 = 0.03\n    L = 255  # dynamic range of 8-bit images\n    C1 = (K1 * L) ** 2\n    C2 = (K2 * L) ** 2\n    mu1 = signal.fftconvolve(img1, window, mode = 'valid')\n    mu2 = signal.fftconvolve(img2, window, mode = 'valid')\n    mu1_sq = mu1 * mu1\n    mu2_sq = mu2 * mu2\n    mu1_mu2 = mu1 * mu2\n    sigma1_sq = signal.fftconvolve(img1 * img1, window, mode = 'valid') - mu1_sq\n    sigma2_sq = signal.fftconvolve(img2 * img2, window, mode = 'valid') - mu2_sq\n    sigma12 = signal.fftconvolve(img1 * img2, window, mode = 'valid') - mu1_mu2\n    if cs_map:\n        return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)),\n                (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2))\n    else:\n        return ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *\n                    (sigma1_sq + sigma2_sq + C2))\n\n\ndef msssim(img1, img2):\n    \"\"\"This function implements Multi-Scale Structural Similarity (MSSSIM) Image\n    Quality Assessment according to Z. Wang's \"Multi-scale structural similarity\n    for image quality assessment\" Invited Paper, IEEE Asilomar Conference on\n    Signals, Systems and Computers, Nov. 2003\n\n    Author's MATLAB implementation:-\n    http://www.cns.nyu.edu/~lcv/ssim/msssim.zip\n    \"\"\"\n    level = 5\n    weight = np.array([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])\n    downsample_filter = np.ones((2, 2)) / 4.0\n    mssim = np.array([])\n    mcs = np.array([])\n    for l in range(level):\n        ssim_map, cs_map = ssim(img1, img2, cs_map = True)\n        mssim = np.append(mssim, ssim_map.mean())\n        mcs = np.append(mcs, cs_map.mean())\n        # Low-pass filter and downsample both images by 2 for the next scale\n        filtered_im1 = ndimage.convolve(img1, downsample_filter,\n                                        mode = 'reflect')\n        filtered_im2 = ndimage.convolve(img2, downsample_filter,\n                                        mode = 'reflect')\n        img1 = filtered_im1[: : 2, : : 2]\n        img2 = filtered_im2[: : 2, : : 2]\n\n    # Note: Remove the negative and add it later to avoid NaN in exponential.\n    sign_mcs = np.sign(mcs[0 : level - 1])\n    sign_mssim = np.sign(mssim[level - 1])\n    mcs_power = np.power(np.abs(mcs[0 : level - 1]), weight[0 : level - 1])\n    mssim_power = np.power(np.abs(mssim[level - 1]), weight[level - 1])\n    return np.prod(sign_mcs * mcs_power) * sign_mssim * mssim_power\n\ndef ImageTransform(loadSize, cropSize):\n    return Compose([\n        
Resize(size=loadSize, interpolation=Image.BICUBIC),\n        ToTensor(),\n    ])\n\ndef visual(image):\n    im = image.transpose(1,2).transpose(2,3).detach().cpu().numpy()\n    Image.fromarray(im[0].astype(np.uint8)).show()\n\nimgData = devdata(dataRoot=img_path, gtRoot=gt_path)\ndata_loader = DataLoader(imgData, batch_size=1, shuffle=False, num_workers=0, drop_last=False)\n\nfor k, (img, lbl, path) in enumerate(data_loader):\n    # img and lbl are 1x3xHxW tensors with values in [0, 1]\n    mse = ((lbl - img)**2).mean().item()\n    sum_mse += mse\n    print(path, count, 'mse: ', mse)\n    if mse == 0:\n        continue\n    count += 1\n    psnr = 10 * math.log10(1 / mse)\n    sum_psnr += psnr\n    print(path, count, ' psnr: ', psnr)\n\n    # Convert both images to luminance with ITU-R BT.601 weights\n    R = lbl[0, 0, :, :]\n    G = lbl[0, 1, :, :]\n    B = lbl[0, 2, :, :]\n    YGT = .299 * R + .587 * G + .114 * B\n\n    R = img[0, 0, :, :]\n    G = img[0, 1, :, :]\n    B = img[0, 2, :, :]\n    YBC = .299 * R + .587 * G + .114 * B\n\n    # AGE: mean absolute grey-level error on the 0-255 scale\n    Diff = abs(np.array(YBC * 255) - np.array(YGT * 255)).round().astype(np.uint8)\n    AGE = np.mean(Diff)\n    sum_AGE += AGE\n    print(' AGE: ', AGE)\n    mssim = msssim(np.array(YGT * 255), np.array(YBC * 255))\n    sum_ssim += mssim\n    print(count, ' ssim:', mssim)\n\n    # pEPs: fraction of pixels whose grey-level error exceeds the threshold\n    threshold = 20\n    Errors = Diff > threshold\n    EPs = float(Errors.sum())\n    pEPs = EPs / Errors.size\n    print(' pEPS: ', pEPs)\n    sum_pEPS += pEPs\n    ########################## CEPs and pCEPs ################################\n    # pCEPs: fraction of pixels that remain errors after binary erosion with a\n    # 4-connected cross, i.e. error pixels whose four neighbours are also errors\n    structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])\n    erodedErrors = ndimage.binary_erosion(Errors, structure).astype(Errors.dtype)\n    CEPs = float(erodedErrors.sum())\n    pCEPs = CEPs / erodedErrors.size\n    print(' pCEPS: ', pCEPs)\n    sum_pCEPS += pCEPs\n\nprint('sum psnr:', sum_psnr)\nprint('average mse:', sum_mse / count)\nprint('average psnr:', sum_psnr / count)\nprint('average ssim:', sum_ssim / count)\nprint('average AGE:', sum_AGE / count)\nprint('average pEPS:', sum_pEPS / count)\nprint('average pCEPS:', sum_pCEPS / count)\n"
  },
  {
    "path": "gauss.py",
    "content": "#!/usr/bin/env python\r\n\"\"\"Module providing functionality surrounding gaussian function.\r\n\"\"\"\r\nSVN_REVISION = '$LastChangedRevision: 16541 $'\r\n\r\nimport sys\r\nimport numpy\r\n\r\ndef gaussian2(size, sigma):\r\n    \"\"\"Returns a normalized circularly symmetric 2D gauss kernel array\r\n    \r\n    f(x,y) = A.e^{-(x^2/2*sigma^2 + y^2/2*sigma^2)} where\r\n    \r\n    A = 1/(2*pi*sigma^2)\r\n    \r\n    as defined by Wolfram MathWorld\r\n    http://mathworld.wolfram.com/GaussianFunction.html\r\n    \"\"\"\r\n    A = 1/(2.0*numpy.pi*sigma**2)\r\n    x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]\r\n    g = A*numpy.exp(-((x**2/(2.0*sigma**2))+(y**2/(2.0*sigma**2))))\r\n    return g\r\n\r\ndef fspecial_gauss(size, sigma):\r\n    \"\"\"Mimic MATLAB's fspecial('gaussian', size, sigma): a size x size\r\n    Gaussian kernel normalized so that its entries sum to 1.\r\n    \"\"\"\r\n    x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]\r\n    g = numpy.exp(-((x**2 + y**2)/(2.0*sigma**2)))\r\n    return g/g.sum()\r\n\r\ndef main():\r\n    \"\"\"Show simple use cases for functionality provided by this module.\"\"\"\r\n    from mpl_toolkits.mplot3d.axes3d import Axes3D\r\n    import pylab\r\n    argv = sys.argv\r\n    if len(argv) != 3:\r\n        print('usage: python gauss.py size sigma', file=sys.stderr)\r\n        sys.exit(2)\r\n    size = int(argv[1])\r\n    sigma = float(argv[2])\r\n    x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]\r\n\r\n    fig = pylab.figure()\r\n    fig.suptitle('Some 2-D Gauss Functions')\r\n    ax = fig.add_subplot(2, 1, 1, projection='3d')\r\n    ax.plot_surface(x, y, fspecial_gauss(size, sigma), rstride=1, cstride=1,\r\n                    linewidth=0, antialiased=False, cmap='jet')\r\n    ax = fig.add_subplot(2, 1, 2, projection='3d')\r\n    ax.plot_surface(x, y, gaussian2(size, sigma), rstride=1, cstride=1,\r\n                    linewidth=0, antialiased=False, cmap='jet')\r\n    pylab.show()\r\n    return 0\r\n\r\nif __name__ == '__main__':\r\n    sys.exit(main())"
  },
  {
    "path": "loss/Loss.py",
    "content": "import torch\nfrom torch import nn\nfrom torch import autograd\nimport torch.nn.functional as F\nfrom tensorboardX import SummaryWriter\nfrom models.discriminator import Discriminator_STE\nfrom PIL import Image\nimport numpy as np\n\ndef gram_matrix(feat):\n    # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py\n    (b, ch, h, w) = feat.size()\n    feat = feat.view(b, ch, h * w)\n    feat_t = feat.transpose(1, 2)\n    gram = torch.bmm(feat, feat_t) / (ch * h * w)\n    return gram\n\ndef visual(image):\n    im = image.transpose(1,2).transpose(2,3).detach().cpu().numpy()\n    Image.fromarray(im[0].astype(np.uint8)).show()\n\ndef dice_loss(input, target):\n    # Soft Dice loss between predicted mask logits (input) and the target mask\n    input = torch.sigmoid(input)\n\n    input = input.contiguous().view(input.size()[0], -1)\n    target = target.contiguous().view(target.size()[0], -1)\n\n    a = torch.sum(input * target, 1)\n    b = torch.sum(input * input, 1) + 0.001\n    c = torch.sum(target * target, 1) + 0.001\n    d = (2 * a) / (b + c)\n    dice = torch.mean(d)\n    return 1 - dice\n\nclass LossWithGAN_STE(nn.Module):\n    def __init__(self, logPath, extractor, Lamda, lr, betasInit=(0.5, 0.9)):\n        super(LossWithGAN_STE, self).__init__()\n        self.l1 = nn.L1Loss()\n        self.extractor = extractor\n        self.discriminator = Discriminator_STE(3)    ## local_global sn patch gan\n        self.D_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betasInit)\n        self.cudaAvailable = torch.cuda.is_available()\n        self.numOfGPUs = torch.cuda.device_count()\n        self.lamda = Lamda\n        self.writer = SummaryWriter(logPath)\n\n    def forward(self, input, mask, x_o1,x_o2,x_o3,output,mm, gt, count, epoch):\n        ### discriminator update (SN-PatchGAN hinge loss) ###\n        # output is detached so that the discriminator step does not\n        # propagate gradients into the generator.\n        self.discriminator.zero_grad()\n        D_real = self.discriminator(gt, mask)\n        D_real = D_real.mean().sum() * -1\n        D_fake = self.discriminator(output.detach(), mask)\n        D_fake = D_fake.mean().sum() * 1\n        D_loss = torch.mean(F.relu(1.+D_real)) + torch.mean(F.relu(1.+D_fake))  #SN-patch-GAN loss\n\n        self.D_optimizer.zero_grad()\n        D_loss.backward()\n        self.D_optimizer.step()\n\n        self.writer.add_scalar('LossD/Discriminator loss', D_loss.item(), count)\n\n        # Adversarial term for the generator, evaluated with the updated discriminator.\n        # Reusing the graph built before D_optimizer.step() can raise an in-place\n        # modification error on newer PyTorch versions.\n        D_fake = -torch.mean(self.discriminator(output, mask))     #  SN-Patch-GAN loss\n\n        output_comp = mask * input + (1 - mask) * output\n        holeLoss = 10 * self.l1((1 - mask) * output, (1 - mask) * gt)\n        validAreaLoss = 2*self.l1(mask * output, mask * gt)\n\n        mask_loss = dice_loss(mm, 1-mask)\n        ### MSR loss ###\n        masks_a = F.interpolate(mask, scale_factor=0.25)\n        masks_b = F.interpolate(mask, scale_factor=0.5)\n        imgs1 = F.interpolate(gt, scale_factor=0.25)\n        imgs2 = F.interpolate(gt, scale_factor=0.5)\n        msrloss = 8 * self.l1((1-mask)*x_o3,(1-mask)*gt) + 0.8*self.l1(mask*x_o3, mask*gt)+\\\n                    6 * self.l1((1-masks_b)*x_o2,(1-masks_b)*imgs2)+1*self.l1(masks_b*x_o2,masks_b*imgs2)+\\\n                    5 * self.l1((1-masks_a)*x_o1,(1-masks_a)*imgs1)+0.8*self.l1(masks_a*x_o1,masks_a*imgs1)\n\n        feat_output_comp = self.extractor(output_comp)\n        feat_output = self.extractor(output)\n        feat_gt = self.extractor(gt)\n\n        prcLoss = 0.0\n        for i in range(3):\n            prcLoss += 0.01 * self.l1(feat_output[i], feat_gt[i])\n            prcLoss += 0.01 * self.l1(feat_output_comp[i], feat_gt[i])\n\n        styleLoss = 0.0\n        for i in range(3):\n            styleLoss += 120 * self.l1(gram_matrix(feat_output[i]),\n                                          gram_matrix(feat_gt[i]))\n            styleLoss += 120 * self.l1(gram_matrix(feat_output_comp[i]),\n                                          gram_matrix(feat_gt[i]))\n        \"\"\" if self.numOfGPUs > 1:\n            holeLoss = holeLoss.sum() / self.numOfGPUs\n            validAreaLoss = validAreaLoss.sum() / self.numOfGPUs\n            prcLoss = prcLoss.sum() / self.numOfGPUs\n            styleLoss = styleLoss.sum() / self.numOfGPUs \"\"\"\n        self.writer.add_scalar('LossG/Hole loss', holeLoss.item(), count)\n        self.writer.add_scalar('LossG/Valid loss', validAreaLoss.item(), count)\n        self.writer.add_scalar('LossG/msr loss', msrloss.item(), count)\n        self.writer.add_scalar('LossPrc/Perceptual loss', prcLoss.item(), count)\n        self.writer.add_scalar('LossStyle/style loss', styleLoss.item(), count)\n\n        GLoss = msrloss+ holeLoss + validAreaLoss+ prcLoss + styleLoss + 0.1 * D_fake + 1*mask_loss\n        self.writer.add_scalar('Generator/Joint loss', GLoss.item(), count)\n        return GLoss.sum()\n    \n"
  },
  {
    "path": "models/Model.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torchvision import models\n\n# VGG16 feature extractor for the perceptual and style losses\nclass VGG16FeatureExtractor(nn.Module):\n    def __init__(self):\n        super(VGG16FeatureExtractor, self).__init__()\n        vgg16 = models.vgg16(pretrained=True)\n        # Alternatively, load the weights from a local file:\n        # vgg16.load_state_dict(torch.load('./vgg16-397923af.pth'))\n        # Three stages of VGG16 features, ending at pool1, pool2 and pool3\n        self.enc_1 = nn.Sequential(*vgg16.features[:5])\n        self.enc_2 = nn.Sequential(*vgg16.features[5:10])\n        self.enc_3 = nn.Sequential(*vgg16.features[10:17])\n\n        # freeze the encoder weights\n        for i in range(3):\n            for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():\n                param.requires_grad = False\n\n    def forward(self, image):\n        # Return the output of each stage: [pool1, pool2, pool3] features\n        results = [image]\n        for i in range(3):\n            func = getattr(self, 'enc_{:d}'.format(i + 1))\n            results.append(func(results[-1]))\n        return results[1:]\n\n"
  },
  {
    "path": "models/discriminator.py",
    "content": "import torch\nimport torch.nn as nn\nfrom .networks import ConvWithActivation, get_pad\n\n##discriminator\nclass Discriminator_STE(nn.Module):\n    def __init__(self, inputChannels):\n        super(Discriminator_STE, self).__init__()\n        cnum =32\n        self.globalDis = nn.Sequential(\n            ConvWithActivation(3, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)),\n            ConvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)),\n            ConvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)),\n            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)),\n            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)),\n            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)),            \n        )\n\n        self.localDis = nn.Sequential(\n            ConvWithActivation(3, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)),\n            ConvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)),\n            ConvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)),\n            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)),\n            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)),\n            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)),\n        )\n        \n        self.fusion = nn.Sequential(\n            nn.Conv2d(512, 1, kernel_size=4),\n            nn.Sigmoid()\n        )\n\n    def forward(self, input, masks):\n        global_feat = self.globalDis(input)\n        local_feat = self.localDis(input * (1 - masks))\n\n        concat_feat = torch.cat((global_feat, local_feat), 1)\n\n        return self.fusion(concat_feat).view(input.size()[0], -1)\n"
  },
  {
    "path": "models/networks.py",
    "content": "import torch\nimport numpy as np\nimport torch.nn.functional as F\nimport torch.nn as nn\n\ndef get_pad(in_,  ksize, stride, atrous=1):\n    out_ = np.ceil(float(in_)/stride)\n    return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2)\n\nclass ConvWithActivation(torch.nn.Module):\n    \"\"\"\n    SN convolution for spetral normalization conv\n    \"\"\"\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):\n        super(ConvWithActivation, self).__init__()\n        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)\n        self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)\n        self.activation = activation\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight)\n    def forward(self, input):\n        x = self.conv2d(input)\n        if self.activation is not None:\n            return self.activation(x)\n        else:\n            return x\n\nclass DeConvWithActivation(torch.nn.Module):\n    \"\"\"\n    SN convolution for spetral normalization conv\n    \"\"\"\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):\n        super(DeConvWithActivation, self).__init__()\n        self.conv2d = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)\n        self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)\n        self.activation = activation\n        for m in self.modules():\n            if isinstance(m, nn.ConvTranspose2d):\n                nn.init.kaiming_normal_(m.weight)\n    def forward(self, input):\n        x = self.conv2d(input)\n        if self.activation is not None:\n            return self.activation(x)\n        
else:\n            return x"
  },
  {
    "path": "models/sa_gan.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom PIL import Image\nfrom torch.autograd import Variable\nfrom .networks import get_pad, ConvWithActivation, DeConvWithActivation\n\ndef img2photo(imgs):\n    # map values from [-1, 1] to [0, 255] and reorder NCHW to NHWC\n    return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy()\n\ndef visual(imgs):\n    im = img2photo(imgs)\n    Image.fromarray(im[0].astype(np.uint8)).show()\n\nclass Residual(nn.Module):\n    def __init__(self, in_channels, out_channels, same_shape=True, **kwargs):\n        super(Residual,self).__init__()\n        self.same_shape = same_shape\n        strides = 1 if same_shape else 2\n        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1,stride=strides)\n        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)\n        # self.conv2 = torch.nn.utils.spectral_norm(self.conv2)\n        if not same_shape:\n            # 1x1 projection so that the shortcut matches the output channels and stride\n            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=strides)\n            # self.conv3 = torch.nn.utils.spectral_norm(self.conv3)\n        self.batch_norm2d = nn.BatchNorm2d(out_channels)\n\n    def forward(self,x):\n        out = F.relu(self.conv1(x))\n        out = self.conv2(out)\n        if not self.same_shape:\n            x = self.conv3(x)\n        out = self.batch_norm2d(out + x)\n        return F.relu(out)\n\n# Atrous Spatial Pyramid Pooling (defined here, but its use in STRnet2 is commented out)\nclass ASPP(nn.Module):\n    def __init__(self, in_channel=512, depth=256):\n        super(ASPP,self).__init__()\n        self.mean = nn.AdaptiveAvgPool2d((1, 1))\n        self.conv = nn.Conv2d(in_channel, depth, 1, 1)\n        # k=1 s=1 no pad\n        self.atrous_block1 = nn.Conv2d(in_channel, depth, 1, 1)\n        self.atrous_block6 = nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6)\n        self.atrous_block12 = nn.Conv2d(in_channel, depth, 3, 1, 
padding=12, dilation=12)\n        self.atrous_block18 = nn.Conv2d(in_channel, depth, 3, 1, padding=18, dilation=18)\n\n        self.conv_1x1_output = nn.Conv2d(depth * 5, depth, 1, 1)\n\n    def forward(self, x):\n        size = x.shape[2:]\n\n        image_features = self.mean(x)\n        image_features = self.conv(image_features)\n        image_features = F.interpolate(image_features, size=size, mode='bilinear', align_corners=False)\n\n        atrous_block1 = self.atrous_block1(x)\n        atrous_block6 = self.atrous_block6(x)\n        atrous_block12 = self.atrous_block12(x)\n        atrous_block18 = self.atrous_block18(x)\n\n        net = self.conv_1x1_output(torch.cat([image_features, atrous_block1, atrous_block6,\n                                              atrous_block12, atrous_block18], dim=1))\n        return net\n\n# Text-removal generator: a coarse U-Net with multi-scale outputs (x_o1, x_o2),\n# a text-mask branch (mm) and a refinement sub-network\nclass STRnet2(nn.Module):\n    def __init__(self, n_in_channel=3):\n        super(STRnet2, self).__init__()\n        #### U-Net ####\n        #downsample\n        self.conv1 = ConvWithActivation(3,32,kernel_size=4,stride=2,padding=1)\n        self.conva = ConvWithActivation(32,32,kernel_size=3, stride=1, padding=1)\n        self.convb = ConvWithActivation(32,64, kernel_size=4, stride=2, padding=1)\n        self.res1 = Residual(64,64)\n        self.res2 = Residual(64,64)\n        self.res3 = Residual(64,128,same_shape=False)\n        self.res4 = Residual(128,128)\n        self.res5 = Residual(128,256,same_shape=False)\n        # self.nn = ConvWithActivation(256, 512, 3, 1, dilation=2, padding=get_pad(64, 3, 1, 2))\n        self.res6 = Residual(256,256)\n        self.res7 = Residual(256,512,same_shape=False)\n        self.res8 = Residual(512,512)\n        self.conv2 = ConvWithActivation(512,512,kernel_size=1)\n\n        #upsample\n        self.deconv1 = DeConvWithActivation(512,256,kernel_size=3,padding=1,stride=2)\n        self.deconv2 = DeConvWithActivation(256*2,128,kernel_size=3,padding=1,stride=2)\n        self.deconv3 = 
DeConvWithActivation(128*2,64,kernel_size=3,padding=1,stride=2)\n        self.deconv4 = DeConvWithActivation(64*2,32,kernel_size=3,padding=1,stride=2)\n        self.deconv5 = DeConvWithActivation(64,3,kernel_size=3,padding=1,stride=2)\n\n        #lateral connection \n        self.lateral_connection1 = nn.Sequential(\n            nn.Conv2d(256, 256, kernel_size=1, padding=0,stride=1),\n            nn.Conv2d(256, 512, kernel_size=3, padding=1,stride=1),\n            nn.Conv2d(512, 512, kernel_size=3, padding=1,stride=1),\n            nn.Conv2d(512, 256, kernel_size=1, padding=0,stride=1),)\n        self.lateral_connection2 = nn.Sequential(\n            nn.Conv2d(128, 128, kernel_size=1, padding=0,stride=1),\n            nn.Conv2d(128, 256, kernel_size=3, padding=1,stride=1),\n            nn.Conv2d(256, 256, kernel_size=3, padding=1,stride=1),\n            nn.Conv2d(256, 128, kernel_size=1, padding=0,stride=1),)\n        self.lateral_connection3 = nn.Sequential(\n            nn.Conv2d(64, 64, kernel_size=1, padding=0,stride=1),\n            nn.Conv2d(64, 128, kernel_size=3, padding=1,stride=1),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1,stride=1),\n            nn.Conv2d(128, 64, kernel_size=1, padding=0,stride=1),)\n        self.lateral_connection4 = nn.Sequential(\n            nn.Conv2d(32, 32, kernel_size=1, padding=0,stride=1),\n            nn.Conv2d(32, 64, kernel_size=3, padding=1,stride=1),\n            nn.Conv2d(64, 64, kernel_size=3, padding=1,stride=1),\n            nn.Conv2d(64, 32, kernel_size=1, padding=0,stride=1),)   \n\n        #self.relu = nn.elu(alpha=1.0)\n        self.conv_o1 = nn.Conv2d(64,3,kernel_size=1)\n        self.conv_o2 = nn.Conv2d(32,3,kernel_size=1)\n        ##### U-Net #####\n\n        ### ASPP ###\n       # self.aspp = ASPP(512, 256)\n        ### ASPP ###\n\n        ### mask branch decoder ###\n        self.mask_deconv_a = DeConvWithActivation(512,256,kernel_size=3,padding=1,stride=2)\n        self.mask_conv_a = 
ConvWithActivation(256,128,kernel_size=3,padding=1,stride=1)\n        self.mask_deconv_b = DeConvWithActivation(256,128,kernel_size=3,padding=1,stride=2)\n        self.mask_conv_b = ConvWithActivation(128,64,kernel_size=3,padding=1,stride=1)\n        self.mask_deconv_c = DeConvWithActivation(128,64,kernel_size=3,padding=1,stride=2)\n        self.mask_conv_c = ConvWithActivation(64,32,kernel_size=3,padding=1,stride=1)\n        self.mask_deconv_d = DeConvWithActivation(64,32,kernel_size=3,padding=1,stride=2)\n        self.mask_conv_d = nn.Conv2d(32,3,kernel_size=1)\n        ### mask branch ###\n\n        ##### Refine sub-network ######\n        n_in_channel = 3\n        cnum = 32\n        #### downsample\n        self.coarse_conva = ConvWithActivation(n_in_channel, cnum, kernel_size=5, stride=1, padding=2)\n        self.coarse_convb = ConvWithActivation(cnum, 2*cnum, kernel_size=4, stride=2, padding=1)\n        self.coarse_convc = ConvWithActivation(2*cnum, 2*cnum, kernel_size=3, stride=1, padding=1)\n        self.coarse_convd = ConvWithActivation(2*cnum, 4*cnum, kernel_size=4, stride=2, padding=1)\n        self.coarse_conve = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1)\n        self.coarse_convf = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1)\n        ### atrous (dilated) convolutions\n        self.astrous_net = nn.Sequential(\n            ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=2, padding=get_pad(64, 3, 1, 2)),\n            ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=4, padding=get_pad(64, 3, 1, 4)),\n            ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=8, padding=get_pad(64, 3, 1, 8)),\n            ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=16, padding=get_pad(64, 3, 1, 16)),\n        )\n        ### end of atrous convolutions\n        ### upsample\n        self.coarse_convk = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1)\n        self.coarse_convl = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, 
stride=1, padding=1)\n        self.coarse_deconva = DeConvWithActivation(4*cnum*3, 2*cnum, kernel_size=3,padding=1,stride=2)\n        self.coarse_convm = ConvWithActivation(2*cnum, 2*cnum, kernel_size=3, stride=1, padding=1)\n        self.coarse_deconvb = DeConvWithActivation(2*cnum*3, cnum, kernel_size=3,padding=1,stride=2)\n        self.coarse_convn = nn.Sequential(\n            ConvWithActivation(cnum, cnum//2, kernel_size=3, stride=1, padding=1),\n            #Self_Attn(cnum//2, 'relu'),\n            ConvWithActivation(cnum//2, 3, kernel_size=3, stride=1, padding=1, activation=None),\n        )\n        self.c1 = nn.Conv2d(32,64,kernel_size=1)\n        self.c2 = nn.Conv2d(64,128,kernel_size=1)\n        ##### Refine sub-network ######\n\n    def forward(self, x):\n        # downsample\n        x = self.conv1(x)\n        x = self.conva(x)\n        con_x1 = x\n        x = self.convb(x)\n        x = self.res1(x)\n        con_x2 = x\n        x = self.res2(x)\n        x = self.res3(x)\n        con_x3 = x\n        x = self.res4(x)\n        x = self.res5(x)\n        con_x4 = x\n        x = self.res6(x)\n        # x_mask = self.nn(con_x4)    ### for mask branch aspp\n        # x_mask = self.aspp(x_mask)     ### for mask branch aspp\n        x_mask = x                      ### no aspp\n        x = self.res7(x)\n        x = self.res8(x)\n        x = self.conv2(x)\n        # upsample\n        x = self.deconv1(x)\n        x = torch.cat([self.lateral_connection1(con_x4), x], dim=1)\n        x = self.deconv2(x)\n        x = torch.cat([self.lateral_connection2(con_x3), x], dim=1)\n        x = self.deconv3(x)\n        xo1 = x\n        x = torch.cat([self.lateral_connection3(con_x2), x], dim=1)\n        x = self.deconv4(x)\n        xo2 = x\n        x = torch.cat([self.lateral_connection4(con_x1), x], dim=1)\n        x = self.deconv5(x)\n        x_o1 = 
self.conv_o1(xo1)\n        x_o2 = self.conv_o2(xo2)\n        x_o_unet = x\n\n        ### mask branch ###\n        mm = self.mask_deconv_a(torch.cat([x_mask,con_x4],dim=1))\n        mm = self.mask_conv_a(mm)\n        mm = self.mask_deconv_b(torch.cat([mm,con_x3],dim=1))\n        mm = self.mask_conv_b(mm)\n        mm = self.mask_deconv_c(torch.cat([mm,con_x2],dim=1))\n        mm = self.mask_conv_c(mm)\n        mm = self.mask_deconv_d(torch.cat([mm,con_x1],dim=1))\n        mm = self.mask_conv_d(mm)\n        ### mask branch ###\n\n        ### refine sub-network\n        x = self.coarse_conva(x_o_unet)\n        x = self.coarse_convb(x)\n        x = self.coarse_convc(x)\n        x_c1 = x     ### skip feature 1 (concatenated before coarse_deconvb)\n        x = self.coarse_convd(x)\n        x = self.coarse_conve(x)\n        x = self.coarse_convf(x)\n        x_c2 = x    ### skip feature 2 (concatenated before coarse_deconva)\n        x = self.astrous_net(x)\n        x = self.coarse_convk(x)\n        x = self.coarse_convl(x)\n        x = self.coarse_deconva(torch.cat([x, x_c2,self.c2(con_x2)],dim=1))\n        x = self.coarse_convm(x)\n        x = self.coarse_deconvb(torch.cat([x,x_c1,self.c1(con_x1)],dim=1))\n        x = self.coarse_convn(x)\n        # side outputs x_o1 and x_o2, coarse output x_o_unet, refined output x, mask-branch output mm\n        return x_o1, x_o2, x_o_unet, x, mm"
  },
  {
    "path": "test_image_STE.py",
    "content": "import os\nimport math\nimport argparse\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom PIL import Image\nimport numpy as np\nfrom torch.autograd import Variable\nfrom torchvision.utils import save_image\nfrom torch.utils.data import DataLoader\nfrom data.dataloader import ErasingData\nfrom models.sa_gan import STRnet2\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--numOfWorkers', type=int, default=0,\n                    help='workers for dataloader')\nparser.add_argument('--modelsSavePath', type=str, default='',\n                    help='path for saving models')\nparser.add_argument('--logPath', type=str,\n                    default='')\nparser.add_argument('--batchSize', type=int, default=16)\nparser.add_argument('--loadSize', type=int, default=512,\n                    help='image loading size')\nparser.add_argument('--dataRoot', type=str,\n                    default='')\nparser.add_argument('--pretrained',type=str, default='', help='path to the trained generator weights (.pth)')\nparser.add_argument('--savePath', type=str, default='./results/sn_tv/')\nargs = parser.parse_args()\n\ncuda = torch.cuda.is_available()\nif cuda:\n    print('Cuda is available!')\n    cudnn.benchmark = True\n\n\ndef visual(image):\n    im = image.transpose(1,2).transpose(2,3).detach().cpu().numpy()\n    Image.fromarray(im[0].astype(np.uint8)).show()\n\nbatchSize = args.batchSize\nloadSize = (args.loadSize, args.loadSize)\ndataRoot = args.dataRoot\nsavePath = args.savePath\n# os.path.join(..., '') keeps a trailing separator, so savePath works with or without a final '/'\nresult_with_mask = os.path.join(savePath, 'WithMaskOutput', '')\nresult_straight = os.path.join(savePath, 'StrOuput', '')\n\n# create each output folder even if savePath already exists\nos.makedirs(result_with_mask, exist_ok=True)\nos.makedirs(result_straight, exist_ok=True)\n\n\nErase_data = ErasingData(dataRoot, loadSize, training=False)\nErase_data = DataLoader(Erase_data, batch_size=batchSize, shuffle=False, num_workers=args.numOfWorkers, drop_last=False)\n\n\nnetG = 
STRnet2(3)\n\n# map_location='cpu' lets GPU-trained weights load on a CPU-only machine\nnetG.load_state_dict(torch.load(args.pretrained, map_location='cpu'))\n\nif cuda:\n    netG = netG.cuda()\n\nfor param in netG.parameters():\n    param.requires_grad = False\n\nprint('OK!')\n\nimport time\nstart = time.time()\nnetG.eval()\nfor imgs, gt, masks, path in Erase_data:\n    if cuda:\n        imgs = imgs.cuda()\n        gt = gt.cuda()\n        masks = masks.cuda()\n    out1, out2, out3, g_images, mm = netG(imgs)\n    g_image = g_images.data.cpu()\n    gt = gt.data.cpu()\n    mask = masks.data.cpu()\n    # blend: ground truth weighted by mask, generated output weighted by (1 - mask)\n    g_image_with_mask = gt * mask + g_image * (1 - mask)\n\n    # save each image in the batch under its own file name\n    for b in range(g_image.size(0)):\n        save_image(g_image_with_mask[b], os.path.join(result_with_mask, path[b]))\n        save_image(g_image[b], os.path.join(result_straight, path[b]))\n"
  },
  {
    "path": "train_STE.py",
    "content": "import os\nimport math\nimport argparse\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.backends.cudnn as cudnn\nfrom PIL import Image\nimport numpy as np\nfrom torch.autograd import Variable\nfrom torchvision.utils import save_image\nfrom torchvision import datasets\nfrom torch.utils.data import DataLoader\nfrom torchvision import utils\nfrom data.dataloader import ErasingData\nfrom loss.Loss import LossWithGAN_STE\nfrom models.Model import VGG16FeatureExtractor\nfrom models.sa_gan import STRnet2\n\ntorch.set_num_threads(5)\n\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"    ### hard-coded GPU index; edit it to match your machine\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--numOfWorkers', type=int, default=0,\n                    help='workers for dataloader')\nparser.add_argument('--modelsSavePath', type=str, default='',\n                    help='path for saving models')\nparser.add_argument('--logPath', type=str,\n                    default='')\nparser.add_argument('--batchSize', type=int, default=16)\nparser.add_argument('--loadSize', type=int, default=512,\n                    help='image loading size')\nparser.add_argument('--dataRoot', type=str,\n                    default='')\nparser.add_argument('--pretrained',type=str, default='', help='pretrained model for fine-tuning')\nparser.add_argument('--num_epochs', type=int, default=500, help='number of training epochs')\nargs = parser.parse_args()\n\n\ndef visual(image):\n    im = image.transpose(1,2).transpose(2,3).detach().cpu().numpy()\n    Image.fromarray(im[0].astype(np.uint8)).show()\n\n\ncuda = torch.cuda.is_available()\nif cuda:\n    print('Cuda is available!')\n    cudnn.enabled = True\n    cudnn.benchmark = True\n\nbatchSize = args.batchSize\nloadSize = (args.loadSize, args.loadSize)\n\nos.makedirs(args.modelsSavePath, exist_ok=True)\n\ndataRoot = args.dataRoot\n\nErase_data = ErasingData(dataRoot, loadSize, 
training=True)\nErase_data = DataLoader(Erase_data, batch_size=batchSize,\n                         shuffle=True, num_workers=args.numOfWorkers, drop_last=False, pin_memory=True)\n\nnetG = STRnet2(3)\n\nif args.pretrained != '':\n    netG.load_state_dict(torch.load(args.pretrained, map_location='cpu'))\n    print('Loaded pretrained weights from {}'.format(args.pretrained))\n\nnumOfGPUs = torch.cuda.device_count()\n\nif cuda:\n    netG = netG.cuda()\n    if numOfGPUs > 1:\n        netG = nn.DataParallel(netG, device_ids=range(numOfGPUs))\n\ncount = 1\n\n\nG_optimizer = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.9))\n\n\ncriterion = LossWithGAN_STE(args.logPath, VGG16FeatureExtractor(), lr=0.00001, betasInit=(0.0, 0.9), Lamda=10.0)\n\nif cuda:\n    criterion = criterion.cuda()\n\n    if numOfGPUs > 1:\n        criterion = nn.DataParallel(criterion, device_ids=range(numOfGPUs))\n\nprint('OK!')\nnum_epochs = args.num_epochs\n\nfor i in range(1, num_epochs + 1):\n    netG.train()\n\n    for k, (imgs, gt, masks, path) in enumerate(Erase_data):\n        if cuda:\n            imgs = imgs.cuda()\n            gt = gt.cuda()\n            masks = masks.cuda()\n\n        x_o1, x_o2, x_o3, fake_images, mm = netG(imgs)\n        G_loss = criterion(imgs, masks, x_o1, x_o2, x_o3, fake_images, mm, gt, count, i)\n        G_loss = G_loss.sum()  # reduces the per-GPU losses when criterion is wrapped in DataParallel\n        G_optimizer.zero_grad()\n        G_loss.backward()\n        G_optimizer.step()\n\n        print('[{}/{}] Generator Loss of epoch {} is {}'.format(k, len(Erase_data), i, G_loss.item()))\n\n        count += 1\n\n    if i % 10 == 0:\n        # unwrap DataParallel so the checkpoint keys carry no 'module.' prefix\n        model = netG.module if isinstance(netG, nn.DataParallel) else netG\n        torch.save(model.state_dict(), os.path.join(args.modelsSavePath, 'STE_{}.pth'.format(i)))"
  }
]