main 72f0e61b9f06 cached
35 files
174.8 KB
45.1k tokens
353 symbols
1 requests
Download .txt
Repository: vinthony/deep-blind-watermark-removal
Branch: main
Commit: 72f0e61b9f06
Files: 35
Total size: 174.8 KB

Directory structure:
gitextract_p5y2m_4m/

├── README.md
├── examples/
│   ├── evaluate.sh
│   └── test.sh
├── main.py
├── options.py
├── requirements.txt
├── scripts/
│   ├── __init__.py
│   ├── datasets/
│   │   ├── BIH.py
│   │   ├── COCO.py
│   │   └── __init__.py
│   ├── machines/
│   │   ├── BasicMachine.py
│   │   ├── S2AM.py
│   │   ├── VX.py
│   │   └── __init__.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── backbone_unet.py
│   │   ├── blocks.py
│   │   ├── discriminator.py
│   │   ├── rasc.py
│   │   ├── sa_resunet.py
│   │   ├── unet.py
│   │   ├── vgg.py
│   │   └── vmu.py
│   └── utils/
│       ├── __init__.py
│       ├── evaluation.py
│       ├── imutils.py
│       ├── logger.py
│       ├── losses.py
│       ├── misc.py
│       ├── model_init.py
│       ├── osutils.py
│       ├── parallel.py
│       └── transforms.py
├── test.py
└── watermark_synthesis.ipynb

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
This repo contains the code and results of the AAAI 2021 paper:

<i><b> [Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal](https://arxiv.org/abs/2012.07007)</b></i><br>
[Xiaodong Cun](http://vinthony.github.io), [Chi-Man Pun<sup>*</sup>](http://www.cis.umac.mo/~cmpun/) <br>
[University of Macau](http://um.edu.mo/)

[Datasets](#Resources) | [Models](#Resources) | [Paper](https://arxiv.org/abs/2012.07007)  | [🔥Online Demo!](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing)(Google CoLab)

<hr>

<img width="726" alt="nn" src="https://user-images.githubusercontent.com/4397546/101241905-37915d80-3735-11eb-9fb9-2e1e46d63f15.png">

<i>The overview of the proposed two-stage framework. Firstly, we propose a multi-task network, SplitNet, for watermark detection, removal,  and recovery. Then, we propose the RefineNet to smooth the learned region with the predicted mask and the recovered background from the previous stage. As a consequence, our network can be trained in an end-to-end fashion without any manual intervention. Note that, for clarity, we do not show any skip-connections between all the encoders and decoders.</i>
<hr>

> The whole project will be released in the January of 2021 (almost).


### Datasets

We synthesized four different datasets for training and testing, you can download the dataset via [huggingface](https://huggingface.co/datasets/vinthony/watermark-removal-logo/tree/main).

![image](https://user-images.githubusercontent.com/4397546/104273158-74413900-54d9-11eb-95fa-c6bee94de0ea.png)


### Pre-trained Models

* [27kpng_model_best.pth.tar (google drive)](https://drive.google.com/file/d/1KpSJ6385CHN6WlAINqB3CYrJdleQTJBc/view?usp=sharing)

> Other Pre-trained Models are still reorganizing and uploading, it will be released soon.


### Demos

An easy-to-use online demo can be founded in [google colab](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing).

The local demo will be released soon.

### Pre-requirements

```
pip install -r requirements.txt
```

### Train

Besides training our methods, here, we also give an example of how to train the [s2am](https://github.com/vinthony/s2am) under our framework. More details can be found in the shell scripts.


```
bash examples/evaluation.sh
```

### Test

```
bash examples/test.sh
```

## **Acknowledgements**
The author would like to thanks Nan Chen for her helpful discussion.

Part of the code is based upon our previous work on image harmonization [s2am](https://github.com/vinthony/s2am) 

## **Citation**

If you find our work useful in your research, please consider citing:

```
@misc{cun2020split,
      title={Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal}, 
      author={Xiaodong Cun and Chi-Man Pun},
      year={2020},
      eprint={2012.07007},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## **Contact**
Please contact me if there is any question (Xiaodong Cun yb87432@um.edu.mo)


================================================
FILE: examples/evaluate.sh
================================================
set -ex



# example training scripts for AAAI-21
# Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal


CUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/main.py  --epochs 100\
 --schedule 100\
 --lr 1e-3\
 -c eval/10kgray/1e3_bs4_256_hybrid_ssim_vgg\
 --arch vvv4n\
 --sltype vggx\
 --style-loss 0.025\
 --ssim-loss 0.15\
 --masked True\
 --loss-type hybrid\
 --limited-dataset 1\
 --machine vx\
 --input-size 256\
 --train-batch 4\
 --test-batch 1\
 --base-dir $HOME/watermark/10kgray/\
 --data _images





# example training scripts for TIP-20
# Improving the Harmony of the Composite Image by Spatial-Separated Attention Module
# * in the original version, the res = False
# suitable for the iHarmony4 dataset.

python /data/home/yb87432/mypaper/s2am/main.py  --epochs 200\
 --schedule 150\
 --lr 1e-3\
 -c checkpoint/normal_rasc_HAdobe5k_res \
 --arch rascv2\
 --style-loss 0\
 --ssim-loss 0\
 --limited-dataset 0\
 --res True\
 --machine s2am\
 --input-size 256\
 --train-batch 16\
 --test-batch 1\
 --base-dir $HOME/Datasets/\
 --data HAdobe5k

================================================
FILE: examples/test.sh
================================================

set -ex

CUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/test.py \
  -c test/10kgray_ssim\
  --resume /data/home/yb87432/s2am/eval/10kgray/1e3_bs6_256_hybrid_ssim_vgg_vx__images_vvv4n/model_best.pth.tar\
  --arch vvv4n\
  --machine vx\
  --input-size 256\
  --test-batch 1\
  --evaluate\
  --base-dir $HOME/watermark/10kgray/\
  --data _images

================================================
FILE: main.py
================================================
from __future__ import print_function, absolute_import

import argparse
import torch,time,os

torch.backends.cudnn.benchmark = True

from scripts.utils.misc import save_checkpoint, adjust_learning_rate

import scripts.datasets as datasets
import scripts.machines as machines
from options import Options

def main(args):
    
    if 'HFlickr' or 'HCOCO' or 'Hday2night' or 'HAdobe5k' in args.base_dir:
        dataset_func = datasets.BIH
    else:
        dataset_func = datasets.COCO

    train_loader = torch.utils.data.DataLoader(dataset_func('train',args),batch_size=args.train_batch, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    
    val_loader = torch.utils.data.DataLoader(dataset_func('val',args),batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    lr = args.lr
    data_loaders = (train_loader,val_loader)

    Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)
    print('============================ Initization Finish && Training Start =============================================')

    for epoch in range(Machine.args.start_epoch, Machine.args.epochs):

        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))
        lr = adjust_learning_rate(data_loaders, Machine.optimizer, epoch, lr, args)

        Machine.record('lr',lr, epoch)        
        Machine.train(epoch)

        if args.freq < 0:
            Machine.validate(epoch)
            Machine.flush()
            Machine.save_checkpoint()

if __name__ == '__main__':
    parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal'))
    args = parser.parse_args()
    print('==================================== WaterMark Removal =============================================')
    print('==> {:50}: {:<}'.format("Start Time",time.ctime(time.time())))
    print('==> {:50}: {:<}'.format("USE GPU",os.environ['CUDA_VISIBLE_DEVICES']))
    print('==================================== Stable Parameters =============================================')
    for arg in vars(args):
        if type(getattr(args, arg)) == type([]):
            if ','.join([ str(i) for i in getattr(args, arg)]) == ','.join([ str(i) for i in parser.get_default(arg)]):
                print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)])))
        else:
            if getattr(args, arg) == parser.get_default(arg):
                print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg)))
    print('==================================== Changed Parameters =============================================')
    for arg in vars(args):
        if type(getattr(args, arg)) == type([]):
            if ','.join([ str(i) for i in getattr(args, arg)]) != ','.join([ str(i) for i in parser.get_default(arg)]):
                print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)])))
        else:
            if getattr(args, arg) != parser.get_default(arg):
                print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg)))
    print('==================================== Start Init Model  ===============================================')
    main(args)
    print('==================================== FINISH WITHOUT ERROR =============================================')


================================================
FILE: options.py
================================================

import scripts.models as models

model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))
    
class Options():
    """docstring for Options"""
    def __init__(self):
        pass

    def init(self, parser):        
        # Model structure
        parser.add_argument('--arch', '-a', metavar='ARCH', default='dhn',
                            choices=model_names,
                            help='model architecture: ' +
                                ' | '.join(model_names) +
                                ' (default: resnet18)')
        parser.add_argument('--darch', metavar='ARCH', default='dhn',
                            choices=model_names,
                            help='model architecture: ' +
                                ' | '.join(model_names) +
                                ' (default: resnet18)')
                                
        parser.add_argument('--machine', '-m', metavar='NACHINE', default='basic')
        # Training strategy
        parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                            help='number of data loading workers (default: 4)')
        parser.add_argument('--epochs', default=30, type=int, metavar='N',
                            help='number of total epochs to run')
        parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                            help='manual epoch number (useful on restarts)')
        parser.add_argument('--train-batch', default=64, type=int, metavar='N',
                            help='train batchsize')
        parser.add_argument('--test-batch', default=6, type=int, metavar='N',
                            help='test batchsize')
        parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,metavar='LR', help='initial learning rate')
        parser.add_argument('--dlr', '--dlearning-rate', default=1e-3, type=float, help='initial learning rate')
        parser.add_argument('--beta1', default=0.9, type=float, help='initial learning rate')
        parser.add_argument('--beta2', default=0.999, type=float, help='initial learning rate')
        parser.add_argument('--momentum', default=0, type=float, metavar='M',
                            help='momentum')
        parser.add_argument('--weight-decay', '--wd', default=0, type=float,
                            metavar='W', help='weight decay (default: 0)')
        parser.add_argument('--schedule', type=int, nargs='+', default=[5, 10],
                            help='Decrease learning rate at these epochs.')
        parser.add_argument('--gamma', type=float, default=0.1,
                            help='LR is multiplied by gamma on schedule.')
        # Data processing
        parser.add_argument('-f', '--flip', dest='flip', action='store_true',
                            help='flip the input during validation')
        parser.add_argument('--lambdaL1', type=float, default=1, help='the weight of L1.')
        parser.add_argument('--alpha', type=float, default=0.5,
                            help='Groundtruth Gaussian sigma.')
        parser.add_argument('--sigma-decay', type=float, default=0,
                            help='Sigma decay rate for each epoch.')
        # Miscs
        parser.add_argument('--base-dir', default='/PATH_TO_DATA_FOLDER/', type=str, metavar='PATH')
        parser.add_argument('--data', default='', type=str, metavar='PATH',
                            help='path to save checkpoint (default: checkpoint)')
        parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                            help='path to save checkpoint (default: checkpoint)')
        parser.add_argument('--resume', default='', type=str, metavar='PATH',
                            help='path to latest checkpoint (default: none)')
        parser.add_argument('--finetune', default='', type=str, metavar='PATH',
                            help='path to latest checkpoint (default: none)')

        parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                            help='evaluate model on validation set')
        parser.add_argument('--style-loss', default=0, type=float,
                            help='preception loss')
        parser.add_argument('--ssim-loss', default=0, type=float,help='msssim loss')
        parser.add_argument('--att-loss', default=1, type=float,help='msssim loss')
        parser.add_argument('--default-loss',default=False,type=bool)
        parser.add_argument('--sltype', default='vggx', type=str)
        parser.add_argument('-da', '--data-augumentation', default=False, type=bool,
                            help='preception loss')
        parser.add_argument('-d', '--debug', dest='debug', action='store_true',
                            help='show intermediate results')
        parser.add_argument('--input-size', default=256, type=int, metavar='N',
                            help='train batchsize')
        parser.add_argument('--freq', default=-1, type=int, metavar='N',
                            help='evaluation frequence')
        parser.add_argument('--normalized-input', default=False, type=bool,
                            help='train batchsize')
        parser.add_argument('--res', default=False, type=bool,help='residual learning for s2am')
        parser.add_argument('--requires-grad', default=False, type=bool,
                            help='train batchsize')
        parser.add_argument('--limited-dataset', default=0, type=int, metavar='N')
        parser.add_argument('--gpu',default=True,type=bool)
        parser.add_argument('--masked',default=False,type=bool)
        parser.add_argument('--gan-norm', default=False,type=bool, help='train batchsize')
        parser.add_argument('--hl', default=False,type=bool, help='homogenious leanring')
        parser.add_argument('--loss-type', default='l2',type=str, help='train batchsize')
        return parser

================================================
FILE: requirements.txt
================================================
numpy==1.19.1
opencv-python==3.4.8.29
Pillow
scikit-image==0.14.5
scikit-learn==0.23.1
scipy==1.2.1
sklearn==0.0
tensorboardX
torch>=1.0.0
torchvision

================================================
FILE: scripts/__init__.py
================================================
from __future__ import absolute_import

from . import datasets
from . import models
from . import utils

# import os, sys
# sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
# from progress.bar import Bar as Bar

# __version__ = '0.1.0'

================================================
FILE: scripts/datasets/BIH.py
================================================
from __future__ import print_function, absolute_import

import os
import csv
import numpy as np
import json
import random
import math
import matplotlib.pyplot as plt
from collections import namedtuple
from os import listdir
from os.path import isfile, join

import torch
import torch.utils.data as data

from scripts.utils.osutils import *
from scripts.utils.imutils import *
from scripts.utils.transforms import *
import torchvision.transforms as transforms
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageFilter
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

class BIH(data.Dataset):
    def __init__(self,train,config=None, sample=[],gan_norm=False):

        self.train = []
        self.anno = []
        self.mask = []
        self.wm = []
        self.input_size = config.input_size
        self.normalized_input = config.normalized_input
        self.base_folder = config.base_dir +'/' + config.data
        self.dataset = config.data

        if config == None:
            self.data_augumentation = False
        else:
            self.data_augumentation = config.data_augumentation

        self.istrain = False if train.find('train') == -1 else True
        self.sample = sample
        self.gan_norm = gan_norm
        mypath = join(self.base_folder,self.dataset+'_'+train+'.txt')

        with open(mypath) as f:
            # here we get the filenames 
            file_names = [ im.strip() for im in f.readlines() ]

        if config.limited_dataset > 0:
            xtrain = sorted(list(set([ file_name.split('-')[0] for file_name in file_names ])))
            tmp = []
            for x in xtrain:
                tmp.append([y for y in file_names if x in y][0])

            file_names = tmp
        else:
            file_names = file_names

        for file_name in file_names:
            self.train.append(os.path.join(self.base_folder,'images',file_name))
            self.mask.append(os.path.join(self.base_folder,'masks','_'.join(file_name.split('_')[0:2])+'.png'))
            self.anno.append(os.path.join(self.base_folder,'reals',file_name.split('_')[0]+'.jpg'))

        if len(self.sample) > 0 :
            self.train = [ self.train[i] for i in self.sample ] 
            self.mask = [ self.mask[i] for i in self.sample ] 
            self.anno = [ self.anno[i] for i in self.sample ] 

        self.trans = transforms.Compose([
                transforms.Resize((self.input_size,self.input_size)),
                transforms.ToTensor()
            ])

        print('total Dataset of '+self.dataset+' is : ', len(self.train))


    def __getitem__(self, index):
        img = Image.open(self.train[index]).convert('RGB')
        mask = Image.open(self.mask[index]).convert('L')
        anno = Image.open(self.anno[index]).convert('RGB')

        # for shadow removal and blind image harmonization, here is no ground truth wm
        # wm = Image.open(self.wm[index]).convert('RGB')

        return {"image": self.trans(img),
                "target": self.trans(anno), 
                "mask": self.trans(mask), 
                "name": self.train[index].split('/')[-1],
                "imgurl":self.train[index],
                "maskurl":self.mask[index],
                "targeturl":self.anno[index],
                }

    def __len__(self):

        return len(self.train)


================================================
FILE: scripts/datasets/COCO.py
================================================
from __future__ import print_function, absolute_import

import os
import csv
import numpy as np
import json
import random
import math
import matplotlib.pyplot as plt
from collections import namedtuple
from os import listdir
from os.path import isfile, join

import torch
import torch.utils.data as data

from scripts.utils.osutils import *
from scripts.utils.imutils import *
from scripts.utils.transforms import *
import torchvision.transforms as transforms
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageFilter
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

class COCO(data.Dataset):
    def __init__(self,train,config=None, sample=[],gan_norm=False):

        self.train = []
        self.anno = []
        self.mask = []
        self.wm = []
        self.input_size = config.input_size
        self.normalized_input = config.normalized_input
        self.base_folder = config.base_dir
        self.dataset = train+config.data

        if config == None:
            self.data_augumentation = False
        else:
            self.data_augumentation = config.data_augumentation

        self.istrain = False if self.dataset.find('train') == -1 else True
        self.sample = sample
        self.gan_norm = gan_norm
        mypath = join(self.base_folder,self.dataset)
        file_names = sorted([f for f in listdir(join(mypath,'image')) if isfile(join(mypath,'image', f)) ])

        if config.limited_dataset > 0:
            xtrain = sorted(list(set([ file_name.split('-')[0] for file_name in file_names ])))
            tmp = []
            for x in xtrain:
                # get the file_name by identifier
                tmp.append([y for y in file_names if x in y][0])

            file_names = tmp
        else:
            file_names = file_names

        for file_name in file_names:
            self.train.append(os.path.join(mypath,'image',file_name))
            self.mask.append(os.path.join(mypath,'mask',file_name))
            self.wm.append(os.path.join(mypath,'wm',file_name))
            self.anno.append(os.path.join(self.base_folder,'natural',file_name.split('-')[0]+'.jpg'))

        if len(self.sample) > 0 :
            self.train = [ self.train[i] for i in self.sample ] 
            self.mask = [ self.mask[i] for i in self.sample ] 
            self.anno = [ self.anno[i] for i in self.sample ] 

        self.trans = transforms.Compose([
                transforms.Resize((self.input_size,self.input_size)),
                transforms.ToTensor()
            ])

        print('total Dataset of '+self.dataset+' is : ', len(self.train))


    def __getitem__(self, index):
        img = Image.open(self.train[index]).convert('RGB')
        mask = Image.open(self.mask[index]).convert('L')
        anno = Image.open(self.anno[index]).convert('RGB')
        wm = Image.open(self.wm[index]).convert('RGB')

        return {"image": self.trans(img),
                "target": self.trans(anno), 
                "mask": self.trans(mask), 
                "wm": self.trans(wm),
                "name": self.train[index].split('/')[-1],
                "imgurl":self.train[index],
                "maskurl":self.mask[index],
                "targeturl":self.anno[index],
                "wmurl":self.wm[index]
                }

    def __len__(self):

        return len(self.train)


================================================
FILE: scripts/datasets/__init__.py
================================================
from .COCO import COCO
from .BIH import BIH

__all__ = ('COCO','BIH')

================================================
FILE: scripts/machines/BasicMachine.py
================================================
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from progress.bar import Bar
import json
import numpy as np
from tensorboardX import SummaryWriter
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.osutils import mkdir_p, isfile, isdir, join
from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
import pytorch_ssim as pytorch_ssim
import torch.optim
import sys,shutil,os
import time
import scripts.models as archs
from math import log10
from torch.autograd import Variable
from scripts.utils.losses import VGGLoss
from scripts.utils.imutils import im_to_numpy

import skimage.io
from skimage.measure import compare_psnr,compare_ssim


class BasicMachine(object):
    def __init__(self, datasets =(None,None), models = None, args = None, **kwargs):
        super(BasicMachine, self).__init__()
        
        self.args = args
        
        # create model
        print("==> creating model ")
        self.model = archs.__dict__[self.args.arch]()
        print("==> creating model [Finish]")
       
        self.train_loader, self.val_loader = datasets
        self.loss = torch.nn.MSELoss()
        
        self.title = '_'+args.machine + '_' + args.data + '_' + args.arch
        self.args.checkpoint = args.checkpoint + self.title
        self.device = torch.device('cuda')
         # create checkpoint dir
        if not isdir(self.args.checkpoint):
            mkdir_p(self.args.checkpoint)

        self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), 
                            lr=args.lr,
                            betas=(args.beta1,args.beta2),
                            weight_decay=args.weight_decay)  
        
        if not self.args.evaluate:
            self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt')
        
        self.best_acc = 0
        self.is_best = False
        self.current_epoch = 0
        self.metric = -100000
        self.hl = 6 if self.args.hl else 1
        self.count_gpu = len(range(torch.cuda.device_count()))

        if self.args.style_loss > 0:
            # init perception loss
            self.vggloss = VGGLoss(self.args.sltype).to(self.device)

        if self.count_gpu > 1 : # multiple
            # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count()))
            # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))

        self.model.to(self.device)
        self.loss.to(self.device)

        print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0))
        print('==> Total devices: %d' % (torch.cuda.device_count()))
        print('==> Current Checkpoint: %s' % (self.args.checkpoint))


        if self.args.resume != '':
            self.resume(self.args.resume)


    def train(self,epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        lossvgg = AverageMeter()
        
        # switch to train mode
        self.model.train()
        end = time.time()

        bar = Bar('Processing', max=len(self.train_loader)*self.hl)
        for _ in range(self.hl):
            for i, batches in enumerate(self.train_loader):
                # measure data loading time
                inputs = batches['image']
                target = batches['target'].to(self.device)
                mask =batches['mask'].to(self.device)
                current_index = len(self.train_loader) * epoch + i

                if self.args.hl:
                    feeded = torch.cat([inputs,mask],dim=1)
                else:
                    feeded = inputs
                feeded = feeded.to(self.device)

                output = self.model(feeded)
                L2_loss =  self.loss(output,target) 
                
                if self.args.style_loss > 0:
                    vgg_loss = self.vggloss(output,target,mask)
                else:
                    vgg_loss = 0

                total_loss = L2_loss + self.args.style_loss * vgg_loss

                # compute gradient and do SGD step
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()

                # measure accuracy and record loss
                losses.update(L2_loss.item(), inputs.size(0))
                
                if self.args.style_loss > 0 :
                    lossvgg.update(vgg_loss.item(), inputs.size(0))
                
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                suffix  = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format(
                            batch=i + 1,
                            size=len(self.train_loader),
                            data=data_time.val,
                            bt=batch_time.val,
                            total=bar.elapsed_td,
                            eta=bar.eta_td,
                            loss_label=losses.avg,
                            loss_vgg=lossvgg.avg
                            )

                if current_index % 1000 == 0:
                    print(suffix)
                
                if self.args.freq > 0 and current_index % self.args.freq == 0:
                    self.validate(current_index)
                    self.flush()
                    self.save_checkpoint()
        
        self.record('train/loss_L2', losses.avg, current_index)


    def test(self, ):

        # switch to evaluate mode
        self.model.eval()

        ssimes = AverageMeter()
        psnres = AverageMeter()

        with torch.no_grad():
            for i, batches in enumerate(self.val_loader):

                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)
                mask =batches['mask'].to(self.device)

                outputs = self.model(inputs)

                # select the outputs by the giving arch
                if type(outputs) == type(inputs):
                    output = outputs
                elif type(outputs[0]) == type([]):
                    output = outputs[0][0]
                else:
                    output = outputs[0]

                # recover the image to 255
                output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8)
                target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)

                skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output)

                psnr = compare_psnr(target,output)
                ssim = compare_ssim(target,output,multichannel=True)

                psnres.update(psnr, inputs.size(0))
                ssimes.update(ssim, inputs.size(0))

        print("%s:PSNR:%s,SSIM:%s"%(self.args.checkpoint,psnres.avg,ssimes.avg))
        print("DONE.\n")
              
        
    def validate(self, epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        ssimes = AverageMeter()
        psnres = AverageMeter()
        # switch to evaluate mode
        self.model.eval()

        end = time.time()
        with torch.no_grad():
            for i, batches in enumerate(self.val_loader):

                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)
                mask =batches['mask'].to(self.device)
                
                if self.args.hl:
                    feeded = torch.cat([inputs,torch.zeros((1,4,self.args.input_size,self.args.input_size)).to(self.device)],dim=1)
                else:
                    feeded = inputs

                output = self.model(feeded)

                L2_loss = self.loss(output, target)

                psnr = 10 * log10(1 / L2_loss.item())   
                ssim = pytorch_ssim.ssim(output, target)    

                losses.update(L2_loss.item(), inputs.size(0))
                psnres.update(psnr, inputs.size(0))
                ssimes.update(ssim.item(), inputs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

        print("Epoches:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f"%(epoch+1, losses.avg,psnres.avg,ssimes.avg))
        self.record('val/loss_L2', losses.avg, epoch)
        self.record('val/PSNR', psnres.avg, epoch)
        self.record('val/SSIM', ssimes.avg, epoch)
        
        self.metric = psnres.avg
        
    def resume(self,resume_path):
        if isfile(resume_path):
                print("=> loading checkpoint '{}'".format(resume_path))
                current_checkpoint = torch.load(resume_path)
                if isinstance(current_checkpoint['state_dict'], torch.nn.DataParallel):
                    current_checkpoint['state_dict'] = current_checkpoint['state_dict'].module

                if isinstance(current_checkpoint['optimizer'], torch.nn.DataParallel):
                    current_checkpoint['optimizer'] = current_checkpoint['optimizer'].module

                self.args.start_epoch = current_checkpoint['epoch']
                self.metric = current_checkpoint['best_acc']
                self.model.load_state_dict(current_checkpoint['state_dict'])
                # self.optimizer.load_state_dict(current_checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(resume_path, current_checkpoint['epoch']))
        else:
            raise Exception("=> no checkpoint found at '{}'".format(resume_path))

    def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):
        is_best = True if self.best_acc < self.metric else False

        if is_best:
            self.best_acc = self.metric

        state = {
                    'epoch': self.current_epoch + 1,
                    'arch': self.args.arch,
                    'state_dict': self.model.state_dict(),
                    'best_acc': self.best_acc,
                    'optimizer' : self.optimizer.state_dict() if self.optimizer else None,
                }

        filepath = os.path.join(self.args.checkpoint, filename)
        torch.save(state, filepath)

        if snapshot and state['epoch'] % snapshot == 0:
            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state.epoch)))
        
        if is_best:
            self.best_acc = self.metric
            print('Saving Best Metric with PSNR:%s'%self.best_acc)
            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar'))

    def clean(self):
        self.writer.close()

    def record(self,k,v,epoch):
        self.writer.add_scalar(k, v, epoch)

    def flush(self):
        self.writer.flush()
        sys.stdout.flush()

    def norm(self,x):
        if self.args.gan_norm:
            return x*2.0 - 1.0
        else:
            return x

    def denorm(self,x):
        if self.args.gan_norm:
            return (x+1.0)/2.0
        else:
            return x



================================================
FILE: scripts/machines/S2AM.py
================================================
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from progress.bar import Bar
import json
import numpy as np
from tensorboardX import SummaryWriter
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.osutils import mkdir_p, isfile, isdir, join
from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
import pytorch_ssim as pytorch_ssim
import torch.optim
import sys,shutil,os
import time
import scripts.models as archs
from math import log10
from torch.autograd import Variable
from scripts.utils.losses import VGGLoss
from scripts.utils.imutils import im_to_numpy

import skimage.io
from skimage.measure import compare_psnr,compare_ssim


class S2AM(object):
    def __init__(self, datasets =(None,None), models = None, args = None, **kwargs):
        super(S2AM, self).__init__()
        
        self.args = args
        
        # create model
        print("==> creating model ")
        self.model = archs.__dict__[self.args.arch]()
        print("==> creating model [Finish]")
       
        self.train_loader, self.val_loader = datasets
        self.loss = torch.nn.MSELoss()
        
        self.title = '_'+args.machine + '_' + args.data + '_' + args.arch
        self.args.checkpoint = args.checkpoint + self.title
        self.device = torch.device('cuda')
         # create checkpoint dir
        if not isdir(self.args.checkpoint):
            mkdir_p(self.args.checkpoint)

        self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), 
                            lr=args.lr,
                            betas=(args.beta1,args.beta2),
                            weight_decay=args.weight_decay)  
        
        if not self.args.evaluate:
            self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt')
        
        self.best_acc = 0
        self.is_best = False
        self.current_epoch = 0
        self.hl = 1
        self.metric = -100000
        self.count_gpu = len(range(torch.cuda.device_count()))

        if self.args.style_loss > 0:
            # init perception loss
            self.vggloss = VGGLoss(self.args.sltype).to(self.device)

        if self.count_gpu > 1 : # multiple
            # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count()))
            # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))

        self.model.to(self.device)
        self.loss.to(self.device)

        print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0))
        print('==> Total devices: %d' % (torch.cuda.device_count()))
        print('==> Current Checkpoint: %s' % (self.args.checkpoint))


        if self.args.resume != '':
            self.resume(self.args.resume)


    def train(self,epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        lossvgg = AverageMeter()
        
        # switch to train mode
        self.model.train()
        end = time.time()

        bar = Bar('Processing', max=len(self.train_loader)*self.hl)
        for _ in range(self.hl):
            for i, batches in enumerate(self.train_loader):
                # measure data loading time
                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)
                mask =batches['mask'].to(self.device)
                current_index = len(self.train_loader) * epoch + i

                feeded = torch.cat([inputs,mask],dim=1)
                feeded = feeded.to(self.device)

                output = self.model(feeded)

                if self.args.res:
                    output = output + inputs

                L2_loss =  self.loss(output,target) 
                
                if self.args.style_loss > 0:
                    vgg_loss = self.vggloss(output,target,mask)
                else:
                    vgg_loss = 0

                total_loss = L2_loss + self.args.style_loss * vgg_loss

                # compute gradient and do SGD step
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()

                # measure accuracy and record loss
                losses.update(L2_loss.item(), inputs.size(0))
                
                if self.args.style_loss > 0 :
                    lossvgg.update(vgg_loss.item(), inputs.size(0))
                
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                suffix  = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format(
                            batch=i + 1,
                            size=len(self.train_loader),
                            data=data_time.val,
                            bt=batch_time.val,
                            total=bar.elapsed_td,
                            eta=bar.eta_td,
                            loss_label=losses.avg,
                            loss_vgg=lossvgg.avg
                            )

                if current_index % 1000 == 0:
                    print(suffix)
                
                if self.args.freq > 0 and current_index % self.args.freq == 0:
                    self.validate(current_index)
                    self.flush()
                    self.save_checkpoint()
        
        self.record('train/loss_L2', losses.avg, current_index)


    def test(self, ):

        # switch to evaluate mode
        self.model.eval()

        ssimes = AverageMeter()
        psnres = AverageMeter()

        with torch.no_grad():
            for i, batches in enumerate(self.val_loader):

                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)
                mask =batches['mask'].to(self.device)

                feeded = torch.cat([inputs,mask],dim=1)
                feeded = feeded.to(self.device)

                output = self.model(feeded)

                if self.args.res:
                    output = output + inputs

                # recover the image to 255
                output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8)
                target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)

                skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output)

                psnr = compare_psnr(target,output)
                ssim = compare_ssim(target,output,multichannel=True)

                psnres.update(psnr, inputs.size(0))
                ssimes.update(ssim, inputs.size(0))

        print("%s:PSNR:%s,SSIM:%s"%(self.args.checkpoint,psnres.avg,ssimes.avg))
        print("DONE.\n")
              
        
    def validate(self, epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        ssimes = AverageMeter()
        psnres = AverageMeter()
        # switch to evaluate mode
        self.model.eval()

        end = time.time()
        with torch.no_grad():
            for i, batches in enumerate(self.val_loader):

                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)
                mask =batches['mask'].to(self.device)
                
                feeded = torch.cat([inputs,mask],dim=1)
                feeded = feeded.to(self.device)

                output = self.model(feeded)

                if self.args.res:
                    output = output + inputs

                L2_loss = self.loss(output, target)

                psnr = 10 * log10(1 / L2_loss.item())   
                ssim = pytorch_ssim.ssim(output, target)    

                losses.update(L2_loss.item(), inputs.size(0))
                psnres.update(psnr, inputs.size(0))
                ssimes.update(ssim.item(), inputs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

        print("Epoches:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f"%(epoch+1, losses.avg,psnres.avg,ssimes.avg))
        self.record('val/loss_L2', losses.avg, epoch)
        self.record('val/PSNR', psnres.avg, epoch)
        self.record('val/SSIM', ssimes.avg, epoch)
        
        self.metric = psnres.avg
        
    def resume(self,resume_path):
        if isfile(resume_path):
                print("=> loading checkpoint '{}'".format(resume_path))
                current_checkpoint = torch.load(resume_path)
                if isinstance(current_checkpoint['state_dict'], torch.nn.DataParallel):
                    current_checkpoint['state_dict'] = current_checkpoint['state_dict'].module

                if isinstance(current_checkpoint['optimizer'], torch.nn.DataParallel):
                    current_checkpoint['optimizer'] = current_checkpoint['optimizer'].module

                self.args.start_epoch = current_checkpoint['epoch']
                self.metric = current_checkpoint['best_acc']
                self.model.load_state_dict(current_checkpoint['state_dict'])
                # self.optimizer.load_state_dict(current_checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(resume_path, current_checkpoint['epoch']))
        else:
            raise Exception("=> no checkpoint found at '{}'".format(resume_path))

    def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):
        is_best = True if self.best_acc < self.metric else False

        if is_best:
            self.best_acc = self.metric

        state = {
                    'epoch': self.current_epoch + 1,
                    'arch': self.args.arch,
                    'state_dict': self.model.state_dict(),
                    'best_acc': self.best_acc,
                    'optimizer' : self.optimizer.state_dict() if self.optimizer else None,
                }

        filepath = os.path.join(self.args.checkpoint, filename)
        torch.save(state, filepath)

        if snapshot and state['epoch'] % snapshot == 0:
            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state.epoch)))
        
        if is_best:
            self.best_acc = self.metric
            print('Saving Best Metric with PSNR:%s'%self.best_acc)
            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar'))

    def clean(self):
        self.writer.close()

    def record(self,k,v,epoch):
        self.writer.add_scalar(k, v, epoch)

    def flush(self):
        self.writer.flush()
        sys.stdout.flush()

    def norm(self,x):
        if self.args.gan_norm:
            return x*2.0 - 1.0
        else:
            return x

    def denorm(self,x):
        if self.args.gan_norm:
            return (x+1.0)/2.0
        else:
            return x



================================================
FILE: scripts/machines/VX.py
================================================
import torch
import torch.nn as nn
from progress.bar import Bar
from tqdm import tqdm
import pytorch_ssim
import json
import sys,time,os
import torchvision
from math import log10
import numpy as np
from .BasicMachine import BasicMachine
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.misc import resize_to_match
from torch.autograd import Variable
import torch.nn.functional as F
from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
from scripts.utils.losses import VGGLoss, l1_relative,is_dic
from scripts.utils.imutils import im_to_numpy
import skimage.io
from skimage.measure import compare_psnr,compare_ssim


class Losses(nn.Module):
    def __init__(self, argx, device, norm_func=None, denorm_func=None):
        super(Losses, self).__init__()
        self.args = argx

        if self.args.loss_type == 'l1bl2':
            self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss()
        elif self.args.loss_type == 'l2xbl2':
            self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss()
        elif self.args.loss_type == 'relative' or self.args.loss_type == 'hybrid':
            self.outputLoss, self.attLoss, self.wrloss = l1_relative, nn.BCELoss(), l1_relative
        else: # l2bl2
            self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss()

        self.default = nn.L1Loss()

        if self.args.style_loss > 0:
            self.vggloss = VGGLoss(self.args.sltype).to(device)
        
        if self.args.ssim_loss > 0:
            self.ssimloss =  pytorch_ssim.SSIM().to(device)
        
        self.norm = norm_func
        self.denorm = denorm_func


    def forward(self,pred_ims,target,pred_ms,mask,pred_wms,wm):
        pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = [0]*5
        pred_ims = pred_ims if is_dic(pred_ims) else [pred_ims]

        # try the loss in the masked region
        if self.args.masked and 'hybrid' in self.args.loss_type: # masked loss
            pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims])
            pixel_loss += sum([self.default(pred_im*pred_ms,target*mask) for pred_im in pred_ims])
            recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ]
            wm_loss += self.wrloss(pred_wms, wm, mask)
            wm_loss += self.default(pred_wms*pred_ms, wm*mask)

        elif self.args.masked and 'relative' in self.args.loss_type: # masked loss
            pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims])
            recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ]
            wm_loss = self.wrloss(pred_wms, wm, mask)
        elif self.args.masked:
            pixel_loss += sum([self.outputLoss(pred_im*mask, target*mask) for pred_im in pred_ims])
            recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ]
            wm_loss = self.wrloss(pred_wms*mask, wm*mask)
        else:
            pixel_loss += sum([self.outputLoss(pred_im*pred_ms, target*mask) for pred_im in pred_ims])
            recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ]
            wm_loss = self.wrloss(pred_wms*pred_ms,wm*mask)

        pixel_loss += sum([self.default(im,target) for im in recov_imgs])

        if self.args.style_loss > 0:
            vgg_loss = sum([self.vggloss(im,target,mask) for im in recov_imgs])

        if self.args.ssim_loss > 0:
            ssim_loss = sum([ 1 - self.ssimloss(im,target) for im in recov_imgs])

        att_loss =  self.attLoss(pred_ms, mask)

        return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss


class VX(BasicMachine):
    def __init__(self,**kwargs):
        BasicMachine.__init__(self,**kwargs)
        self.loss = Losses(self.args, self.device, self.norm, self.denorm)
        self.model.set_optimizers()
        self.optimizer = None
       
    def train(self,epoch):

        self.current_epoch = epoch

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        lossMask = AverageMeter()
        lossWM = AverageMeter()
        lossMX = AverageMeter()
        lossvgg = AverageMeter()
        lossssim = AverageMeter()

        # switch to train mode
        self.model.train()

        end = time.time()
        bar = Bar('Processing {} '.format(self.args.arch), max=len(self.train_loader))

        for i, batches in enumerate(self.train_loader):

            current_index = len(self.train_loader) * epoch + i

            inputs = batches['image'].to(self.device)
            target = batches['target'].to(self.device)
            mask = batches['mask'].to(self.device)
            wm =  batches['wm'].to(self.device)

            outputs = self.model(self.norm(inputs))
            
            self.model.zero_grad_all()

            l2_loss,att_loss,wm_loss,style_loss,ssim_loss = self.loss(outputs[0],self.norm(target),outputs[1],mask,outputs[2],self.norm(wm))
            total_loss = 2*l2_loss + self.args.att_loss * att_loss + wm_loss + self.args.style_loss * style_loss + self.args.ssim_loss * ssim_loss

            # compute gradient and do SGD step
            total_loss.backward()
            self.model.step_all()

            # measure accuracy and record loss
            losses.update(l2_loss.item(), inputs.size(0))
            lossMask.update(att_loss.item(), inputs.size(0))
            lossWM.update(wm_loss.item(), inputs.size(0))

            if self.args.style_loss > 0 :
                lossvgg.update(style_loss.item(), inputs.size(0))

            if self.args.ssim_loss > 0 :
                lossssim.update(ssim_loss.item(), inputs.size(0))


            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            suffix  = "({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss Mask: {loss_mask:.4f} | loss WM: {loss_wm:.4f} | loss VGG: {loss_vgg:.4f} | loss SSIM: {loss_ssim:.4f}| loss MX: {loss_mx:.4f}".format(
                        batch=i + 1,
                        size=len(self.train_loader),
                        data=data_time.val,
                        bt=batch_time.val,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss_label=losses.avg,
                        loss_mask=lossMask.avg,
                        loss_wm=lossWM.avg,
                        loss_vgg=lossvgg.avg,
                        loss_ssim=lossssim.avg,
                        loss_mx=lossMX.avg
                        )
            if current_index % 1000 == 0:
                print(suffix)

            if self.args.freq > 0 and current_index % self.args.freq == 0:
                self.validate(current_index)
                self.flush()
                self.save_checkpoint()

        self.record('train/loss_L2', losses.avg, epoch)
        self.record('train/loss_Mask', lossMask.avg, epoch)
        self.record('train/loss_WM', lossWM.avg, epoch)
        self.record('train/loss_VGG', lossvgg.avg, epoch)
        self.record('train/loss_SSIM', lossssim.avg, epoch)
        self.record('train/loss_MX', lossMX.avg, epoch)




    def validate(self, epoch):

        self.current_epoch = epoch
        
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        lossMask = AverageMeter()
        psnres = AverageMeter()
        ssimes = AverageMeter()

        # switch to evaluate mode
        self.model.eval()

        end = time.time()
        bar = Bar('Processing {} '.format(self.args.arch), max=len(self.val_loader))
        with torch.no_grad():
            for i, batches in enumerate(self.val_loader):

                current_index = len(self.val_loader) * epoch + i

                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)

                outputs = self.model(self.norm(inputs))
                imoutput,immask,imwatermark = outputs
                imoutput = imoutput[0] if is_dic(imoutput) else imoutput

                imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))

                if i % 300 == 0:
                    # save the sample images
                    ims = torch.cat([inputs,target,imfinal,immask.repeat(1,3,1,1)],dim=3)
                    torchvision.utils.save_image(ims,os.path.join(self.args.checkpoint,'%s_%s.jpg'%(i,epoch)))

                # here two choice: mseLoss or NLLLoss
                psnr = 10 * log10(1 / F.mse_loss(imfinal,target).item())       

                ssim = pytorch_ssim.ssim(imfinal,target)

                psnres.update(psnr, inputs.size(0))
                ssimes.update(ssim, inputs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                bar.suffix  = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_L2: {loss_label:.4f} | Loss_Mask: {loss_mask:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}'.format(
                            batch=i + 1,
                            size=len(self.val_loader),
                            data=data_time.val,
                            bt=batch_time.val,
                            total=bar.elapsed_td,
                            eta=bar.eta_td,
                            loss_label=losses.avg,
                            loss_mask=lossMask.avg,
                            psnr=psnres.avg,
                            ssim=ssimes.avg
                            )
                bar.next()
        bar.finish()
        
        print("Iter:%s,Losses:%s,PSNR:%.4f,SSIM:%.4f"%(epoch, losses.avg,psnres.avg,ssimes.avg))
        self.record('val/loss_L2', losses.avg, epoch)
        self.record('val/lossMask', lossMask.avg, epoch)
        self.record('val/PSNR', psnres.avg, epoch)
        self.record('val/SSIM', ssimes.avg, epoch)
        self.metric = psnres.avg

        self.model.train()

    def test(self, ):

        # switch to evaluate mode
        self.model.eval()
        print("==> testing VM model ")
        ssimes = AverageMeter()
        psnres = AverageMeter()
        ssimesx = AverageMeter()
        psnresx = AverageMeter()

        with torch.no_grad():
            for i, batches in enumerate(tqdm(self.val_loader)):

                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)
                mask =batches['mask'].to(self.device)

                # select the outputs by the giving arch
                outputs = self.model(self.norm(inputs))
                imoutput,immask,imwatermark = outputs
                imoutput = imoutput[0] if is_dic(imoutput) else imoutput

                imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))
                psnrx = 10 * log10(1 / F.mse_loss(imfinal,target).item())       
                ssimx = pytorch_ssim.ssim(imfinal,target)
                # recover the image to 255
                imfinal = im_to_numpy(torch.clamp(imfinal[0]*255,min=0.0,max=255.0)).astype(np.uint8)
                target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)

                skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), imfinal)

                psnr = compare_psnr(target,imfinal)
                ssim = compare_ssim(target,imfinal,multichannel=True)

                psnres.update(psnr, inputs.size(0))
                ssimes.update(ssim, inputs.size(0))
                psnresx.update(psnrx, inputs.size(0))
                ssimesx.update(ssimx, inputs.size(0))

        print("%s:PSNR:%.5f(%.5f),SSIM:%.5f(%.5f)"%(self.args.checkpoint,psnres.avg,psnresx.avg,ssimes.avg,ssimesx.avg))
        print("DONE.\n")

================================================
FILE: scripts/machines/__init__.py
================================================

from .BasicMachine import BasicMachine
from .VX import VX
from .S2AM import S2AM

def basic(**kwargs):
	return BasicMachine(**kwargs)

def s2am(**kwargs):
    return S2AM(**kwargs)

def vx(**kwargs):
    return VX(**kwargs)


================================================
FILE: scripts/models/__init__.py
================================================
from .vgg import *
from .backbone_unet import *
from .discriminator import *



================================================
FILE: scripts/models/backbone_unet.py
================================================


import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
import math

from scripts.utils.model_init import *
from scripts.models.rasc import *
from scripts.models.unet import UnetGenerator,MinimalUnetV2
from scripts.models.vmu import UnetVM
from scripts.models.sa_resunet import UnetVMS2AMv4


# our method
def vvv4n(**kwargs):
    return UnetVMS2AMv4(shared_depth=2, blocks=3, long_skip=True, use_vm_decoder=True,s2am='vms2am')


# BVMR
def vm3(**kwargs):
    return UnetVM(shared_depth=2, blocks=3, use_vm_decoder=True)


# Blind version of S2AM
def urasc(**kwargs):
    model = UnetGenerator(3,3,is_attention_layer=True,attention_model=URASC,basicblock=MinimalUnetV2)
    model.apply(weights_init_kaiming)
    return model


# Improving the Harmony of the Composite Image by Spatial-Separated Attention Module
# Xiaodong Cun and Chi-Man Pun
# University of Macau
# Trans. on Image Processing, vol. 29, pp. 4759-4771, 2020.
def rascv2(**kwargs):
    model = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2)
    model.apply(weights_init_kaiming)
    return model

# just original unet
def unet(**kwargs):
    model = UnetGenerator(3,3)
    model.apply(weights_init_kaiming)
    return model




================================================
FILE: scripts/models/blocks.py
================================================
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
import math
import numbers

from scripts.utils.model_init import *
from scripts.models.vgg import Vgg16
from torch import nn, cuda
from torch.autograd import Variable

class BasicLearningBlock(nn.Module):
    """docstring for BasicLearningBlock"""
    def __init__(self,channel):
        super(BasicLearningBlock, self).__init__()
        self.rconv1 = nn.Conv2d(channel,channel*2,3,padding=1,bias=False)
        self.rbn1 = nn.BatchNorm2d(channel*2)
        self.rconv2 = nn.Conv2d(channel*2,channel,3,padding=1,bias=False)
        self.rbn2 = nn.BatchNorm2d(channel)

    def forward(self,feature):
        return F.elu(self.rbn2(self.rconv2(F.elu(self.rbn1(self.rconv1(feature)))))) 
        


# From https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/3
class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a
    1d, 2d or 3d tensor. Filtering is performed seperately for each channel
    in the input using a depthwise convolution.
    Arguments:
        channels (int, sequence): Number of channels of the input tensors. Output will
            have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel.
        sigma (float, sequence): Standard deviation of the gaussian kernel.
        dim (int, optional): The number of dimensions of the data.
            Default value is 2 (spatial).
    """
    def __init__(self, channels, kernel_size, sigma, dim=2):
        super(GaussianSmoothing, self).__init__()
        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim

        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                      torch.exp(-((mgrid - mean) / (2 * std)) ** 2)

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels

        if dim == 1:
            self.conv = F.conv1d
        elif dim == 2:
            self.conv = F.conv2d
        elif dim == 3:
            self.conv = F.conv3d
        else:
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
            )

    def forward(self, input):
        """
        Apply gaussian filter to input.
        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output.
        """
        return self.conv(input, weight=self.weight, groups=self.groups)

class ChannelPool(nn.Module):
    def __init__(self,types):
        super(ChannelPool, self).__init__()
        if types == 'avg': 
            self.poolingx = nn.AdaptiveAvgPool1d(1)
        elif types == 'max':
            self.poolingx = nn.AdaptiveMaxPool1d(1)
        else:
            raise 'inner error'

    def forward(self, input):
        n, c, w, h = input.size()
        input = input.view(n,c,w*h).permute(0,2,1) 
        pooled =  self.poolingx(input)# b,w*h,c ->  b,w*h,1
        _, _, c = pooled.size()
        return pooled.view(n,c,w,h)



class SEBlock(nn.Module):
    """docstring for SEBlock"""
    def __init__(self, channel,reducation=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel,channel//reducation),
            nn.ReLU(inplace=True),
            nn.Linear(channel//reducation,channel),
            nn.Sigmoid())
        
    def forward(self,x):
        b,c,w,h = x.size()
        y1 = self.avg_pool(x).view(b,c)
        y = self.fc(y1).view(b,c,1,1)
        return x*y



class GlobalAttentionModule(nn.Module):
    """docstring for GlobalAttentionModule"""
    def __init__(self, channel,reducation=16):
        super(GlobalAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel*2,channel//reducation),
            nn.ReLU(inplace=True),
            nn.Linear(channel//reducation,channel),
            nn.Sigmoid())
        
    def forward(self,x):
        b,c,w,h = x.size()
        y1 = self.avg_pool(x).view(b,c)
        y2 = self.max_pool(x).view(b,c)
        y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1)
        return x*y

class SpatialAttentionModule(nn.Module):
    """docstring for SpatialAttentionModule"""
    def __init__(self, channel,reducation=16):
        super(SpatialAttentionModule, self).__init__()
        self.avg_pool = ChannelPool('avg')
        self.max_pool = ChannelPool('max')
        self.fc = nn.Sequential(
            nn.Conv2d(2,reducation,7,stride=1,padding=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(reducation,1,7,stride=1,padding=3),
            nn.Sigmoid())
        
    def forward(self,x):
        b,c,w,h = x.size()
        y1 = self.avg_pool(x)
        y2 = self.max_pool(x)
        y = self.fc(torch.cat([y1,y2],1))
        yr = 1-y
        return y,yr



class GlobalAttentionModuleJustSigmoid(nn.Module):
    """docstring for GlobalAttentionModule"""
    def __init__(self, channel,reducation=16):
        super(GlobalAttentionModuleJustSigmoid, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel*2,channel//reducation),
            nn.ReLU(inplace=True),
            nn.Linear(channel//reducation,channel),
            nn.Sigmoid())
        
    def forward(self,x):
        b,c,w,h = x.size()
        y1 = self.avg_pool(x).view(b,c)
        y2 = self.max_pool(x).view(b,c)
        y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1)
        return y



class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicBlock, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types
    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type=='avg':
                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( avg_pool )
            elif pool_type=='max':
                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( max_pool )
            elif pool_type=='lp':
                lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( lp_pool )
            elif pool_type=='lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp( lse_pool )

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs

class ChannelPoolX(nn.Module):
    def forward(self, x):
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPoolX()
        self.spatial = BasicBlock(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = F.sigmoid(x_out) # broadcasting
        return x * scale

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial=no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out




================================================
FILE: scripts/models/discriminator.py
================================================
import numpy as np
import functools
import math
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torch.nn import Parameter
from scripts.utils.model_init import *
from torch.optim.optimizer import Optimizer, required


__all__ = ['patchgan','sngan','maskedsngan']


class SNCoXvWithActivation(torch.nn.Module):
    """
    SN convolution for spetral normalization conv
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(SNCoXvWithActivation, self).__init__()
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
        self.activation = activation
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
    def forward(self, input):
        x = self.conv2d(input)
        if self.activation is not None:
            return self.activation(x)
        else:
            return x

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False


    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)


    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


def get_pad(in_,  ksize, stride, atrous=1):
    out_ = np.ceil(float(in_)/stride)
    return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2)

class SNDiscriminator(nn.Module):
    def __init__(self,channel=6):
        super(SNDiscriminator, self).__init__()
        cnum = 32
        self.discriminator_net = nn.Sequential(
            SNCoXvWithActivation(channel, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)),
            SNCoXvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)), 
            SNCoXvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)),
            SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)),
            SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)), # 8*8*256
            # SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)), # 4*4*256
            # SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(4, 5, 2)), # 2*2*256
        )
        # self.linear = nn.Linear(2*2*256,1)

    def forward(self, img_A, img_B):
        # Concatenate image and condition image by channels to produce input
        img_input = torch.cat((img_A, img_B), 1)
        x = self.discriminator_net(img_input)
        # x = x.view((x.size(0),-1))
        # x = self.linear(x)
        return x

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels*2, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, img_A, img_B):
        # Concatenate image and condition image by channels to produce input
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)


def patchgan():
    model = Discriminator()
    model.apply(weights_init_kaiming)
    return model

def sngan():
    model = SNDiscriminator()
    model.apply(weights_init_kaiming)
    return model

def maskedsngan():
    model = SNDiscriminator(channel=7)
    model.apply(weights_init_kaiming)
    return model

================================================
FILE: scripts/models/rasc.py
================================================


import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

from scripts.utils.model_init import *
from scripts.models.vgg import Vgg16
from scripts.models.blocks import *


class CAWapper(nn.Module):
    """docstring for SENet"""

    def __init__(self, channel, type_of_connection=BasicLearningBlock):
        super(CAWapper, self).__init__()
        self.attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10, fuse=True, use_cuda=True)

    def forward(self, feature, mask):
        _, _, w, _ = feature.size()
        _, _, mw, _ = mask.size()
        # binaryfiy
        # selected the feature from the background as the additional feature to masked splicing feature.
        mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w))

        result = self.attention(feature,mask)

        return result


class NLWapper(nn.Module):
    """docstring for SENet"""

    def __init__(self, channel, type_of_connection=BasicLearningBlock):
        super(NLWapper, self).__init__()
        self.attention = NONLocalBlock2D(channel)

    def forward(self, feature, mask):
        _, _, w, _ = feature.size()
        _, _, mw, _ = mask.size()
        # binaryfiy
        # selected the feature from the background as the additional feature to masked splicing feature.
        # mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w))

        result = self.attention(feature)

        return result

class SENet(nn.Module):
    """docstring for SENet"""
    def __init__(self,channel,type_of_connection=BasicLearningBlock):
        super(SENet, self).__init__()
        self.attention = SEBlock(channel,16)

    def forward(self,feature,mask):
        _,_,w,_ = feature.size()
        _,_,mw,_ = mask.size()
        # binaryfiy
        # selected the feature from the background as the additional feature to masked splicing feature.
        mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w))

        result = self.attention(feature) 
        
        return result

class CBAMConnect(nn.Module):
    def __init__(self,channel):
        super(CBAMConnect, self).__init__()
        self.attention = CBAM(channel)

    def forward(self,feature,mask):
        results = self.attention(feature)
        return results



class RASC(nn.Module):
    def __init__(self,channel,type_of_connection=BasicLearningBlock):
        super(RASC, self).__init__()
        self.connection = type_of_connection(channel)
        self.background_attention = GlobalAttentionModule(channel,16)
        self.mixed_attention = GlobalAttentionModule(channel,16)
        self.spliced_attention = GlobalAttentionModule(channel,16)
        self.gaussianMask = GaussianSmoothing(1,5,1)

    def forward(self,feature,mask):
        _,_,w,_ = feature.size()
        _,_,mw,_ = mask.size()
        # binaryfiy
        # selected the feature from the background as the additional feature to masked splicing feature.
        if w != mw:
            mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w))
        reverse_mask = -1*(mask-1)
        # here we add gaussin filter to mask and reverse_mask for better harimoization of edges.

        mask = self.gaussianMask(F.pad(mask,(2,2,2,2),mode='reflect'))
        reverse_mask = self.gaussianMask(F.pad(reverse_mask,(2,2,2,2),mode='reflect'))


        background = self.background_attention(feature) * reverse_mask
        selected_feature = self.mixed_attention(feature)
        spliced_feature = self.spliced_attention(feature) 
        spliced = ( self.connection(spliced_feature) + selected_feature ) * mask
        return background + spliced    


class UNO(nn.Module):
    def __init__(self,channel):
        super(UNO, self).__init__()

    def forward(self,feature,_m):
        return feature 


class URASC(nn.Module):
    def __init__(self,channel,type_of_connection=BasicLearningBlock):
        super(URASC, self).__init__()
        self.connection = type_of_connection(channel)
        self.background_attention = GlobalAttentionModule(channel,16)
        self.mixed_attention = GlobalAttentionModule(channel,16)
        self.spliced_attention = GlobalAttentionModule(channel,16)
        self.mask_attention = SpatialAttentionModule(channel,16)

    def forward(self,feature, m=None):
        _,_,w,_ = feature.size()
      
        mask, reverse_mask = self.mask_attention(feature)

        background = self.background_attention(feature) * reverse_mask
        selected_feature = self.mixed_attention(feature)
        spliced_feature = self.spliced_attention(feature) 
        spliced = ( self.connection(spliced_feature) + selected_feature ) * mask
        return background + spliced  


class MaskedURASC(nn.Module):
    def __init__(self,channel,type_of_connection=BasicLearningBlock):
        super(MaskedURASC, self).__init__()
        self.connection = type_of_connection(channel)
        self.background_attention = GlobalAttentionModule(channel,16)
        self.mixed_attention = GlobalAttentionModule(channel,16)
        self.spliced_attention = GlobalAttentionModule(channel,16)
        self.mask_attention = SpatialAttentionModule(channel,16)

    def forward(self,feature):
        _,_,w,_ = feature.size()
      
        mask, reverse_mask = self.mask_attention(feature)

        background = self.background_attention(feature) * reverse_mask
        selected_feature = self.mixed_attention(feature)
        spliced_feature = self.spliced_attention(feature) 
        spliced = ( self.connection(spliced_feature) + selected_feature ) * mask
        return background + spliced, mask



================================================
FILE: scripts/models/sa_resunet.py
================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from scripts.models.blocks import SEBlock
from scripts.models.rasc import *
from scripts.models.unet import UnetGenerator,MinimalUnetV2

def weight_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)

def reset_params(model):
    for i, m in enumerate(model.modules()):
        weight_init(m)


def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)


def up_conv2x2(in_channels, out_channels, transpose=True):
    if transpose:
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2)
    else:
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv1x1(in_channels, out_channels))


def conv1x1(in_channels, out_channels, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        groups=groups,
        stride=1)


class UpCoXvD(nn.Module):

    def __init__(self, in_channels, out_channels, blocks, residual=True,norm=nn.BatchNorm2d, act=F.relu,batch_norm=True, transpose=True,concat=True,use_att=False):
        super(UpCoXvD, self).__init__()
        self.concat = concat
        self.residual = residual
        self.batch_norm = batch_norm
        self.bn = None
        self.conv2 = []
        self.use_att = use_att
        self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose)
        self.norm0 = norm(out_channels)
        
        if self.use_att:
            self.s2am = RASC(2 * out_channels)
        else:
            self.s2am = None

        if self.concat:
            self.conv1 = conv3x3(2 * out_channels, out_channels)
            self.norm1 = norm(out_channels , out_channels)
        else:
            self.conv1 = conv3x3(out_channels, out_channels)
            self.norm1 = norm(out_channels , out_channels)

        for _ in range(blocks):
            self.conv2.append(conv3x3(out_channels, out_channels))
        if self.batch_norm:
            self.bn = []
            for _ in range(blocks):
                self.bn.append(norm(out_channels))
            self.bn = nn.ModuleList(self.bn)
        self.conv2 = nn.ModuleList(self.conv2)
        self.act = act

    def forward(self, from_up, from_down, mask=None,se=None):
        from_up = self.act(self.norm0(self.up_conv(from_up)))
        if self.concat:
            x1 = torch.cat((from_up, from_down), 1)
        else:
            if from_down is not None:
                x1 = from_up + from_down
            else:
                x1 = from_up

        if self.use_att:
            x1 = self.s2am(x1,mask)
        
        x1 = self.act(self.norm1(self.conv1(x1)))
        x2 = None
        for idx, conv in enumerate(self.conv2):
            x2 = conv(x1)
            if self.batch_norm:
                x2 = self.bn[idx](x2)
            
            if (se is not None) and (idx == len(self.conv2) - 1): # last 
                x2 = se(x2)

            if self.residual:
                x2 = x2 + x1
            x2 = self.act(x2)
            x1 = x2
        return x2


class DownCoXvD(nn.Module):

    def __init__(self, in_channels, out_channels, blocks, pooling=True, norm=nn.BatchNorm2d,act=F.relu,residual=True, batch_norm=True):
        super(DownCoXvD, self).__init__()
        self.pooling = pooling
        self.residual = residual
        self.batch_norm = batch_norm
        self.bn = None
        self.pool = None
        self.conv1 = conv3x3(in_channels, out_channels)
        self.norm1 = norm(out_channels)

        self.conv2 = []
        for _ in range(blocks):
            self.conv2.append(conv3x3(out_channels, out_channels))
        if self.batch_norm:
            self.bn = []
            for _ in range(blocks):
                self.bn.append(norm(out_channels))
            self.bn = nn.ModuleList(self.bn)
        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.ModuleList(self.conv2)
        self.act = act

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        x1 = self.act(self.norm1(self.conv1(x)))
        x2 = None
        for idx, conv in enumerate(self.conv2):
            x2 = conv(x1)
            if self.batch_norm:
                x2 = self.bn[idx](x2)
            if self.residual:
                x2 = x2 + x1
            x2 = self.act(x2)
            x1 = x2
        before_pool = x2
        if self.pooling:
            x2 = self.pool(x2)
        return x2, before_pool

class UnetDecoderD(nn.Module):
    def __init__(self, in_channels=512, out_channels=3, norm=nn.BatchNorm2d,act=F.relu, depth=5, blocks=1, residual=True, batch_norm=True,
                 transpose=True, concat=True, is_final=True, use_att=False):
        super(UnetDecoderD, self).__init__()
        self.conv_final = None
        self.up_convs = []
        self.atts = []
        self.use_att = use_att

        outs = in_channels
        for i in range(depth-1): # depth = 1
            ins = outs
            outs = ins // 2
            # 512,256
            # 256,128
            # 128,64
            # 64,32
            up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
                              concat=concat, norm=norm, act=act)
            if self.use_att:
                self.atts.append(SEBlock(outs))
            
            self.up_convs.append(up_conv)

        if is_final:
            self.conv_final = conv1x1(outs, out_channels)
        else:
            up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
                              concat=concat,norm=norm, act=act)
            if self.use_att:
                self.atts.append(SEBlock(out_channels))

            self.up_convs.append(up_conv)
        self.up_convs = nn.ModuleList(self.up_convs)
        self.atts = nn.ModuleList(self.atts)

        reset_params(self)

    def __call__(self, x, encoder_outs=None):
        return self.forward(x, encoder_outs)

    def forward(self, x, encoder_outs=None):
        for i, up_conv in enumerate(self.up_convs):
            before_pool = None
            if encoder_outs is not None:
                before_pool = encoder_outs[-(i+2)]
            x = up_conv(x, before_pool)
            if self.use_att:
                x = self.atts[i](x)

        if self.conv_final is not None:
            x = self.conv_final(x)
        return x


class UnetDecoderDatt(nn.Module):
    def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True,
                 transpose=True, concat=True, is_final=True, norm=nn.BatchNorm2d,act=F.relu):
        super(UnetDecoderDatt, self).__init__()
        self.conv_final = None
        self.up_convs = []
        self.im_atts = []
        self.vm_atts = []
        self.mask_atts = []

        outs = in_channels
        for i in range(depth-1): # depth = 5 [0,1,2,3]
            ins = outs
            outs = ins // 2
            # 512,256
            # 256,128
            # 128,64
            # 64,32
            up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
                              concat=concat, norm=nn.BatchNorm2d,act=F.relu)
            self.up_convs.append(up_conv)
            self.im_atts.append(SEBlock(outs))
            self.vm_atts.append(SEBlock(outs))
            self.mask_atts.append(SEBlock(outs))
        if is_final:
            self.conv_final = conv1x1(outs, out_channels)
        else:
            up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
                              concat=concat, norm=nn.BatchNorm2d,act=F.relu)
            self.up_convs.append(up_conv)
            self.im_atts.append(SEBlock(out_channels))
            self.vm_atts.append(SEBlock(out_channels))
            self.mask_atts.append(SEBlock(out_channels))

        self.up_convs = nn.ModuleList(self.up_convs)
        self.im_atts = nn.ModuleList(self.im_atts)
        self.vm_atts = nn.ModuleList(self.vm_atts)
        self.mask_atts = nn.ModuleList(self.mask_atts)

        reset_params(self)

    def forward(self, input, encoder_outs=None):
        # im branch
        x = input
        for i, up_conv in enumerate(self.up_convs):
            before_pool = None
            if encoder_outs is not None:
                before_pool = encoder_outs[-(i+2)]
            x = up_conv(x, before_pool,se=self.im_atts[i])
        x_im = x

        x = input        
        for i, up_conv in enumerate(self.up_convs):
            before_pool = None
            if encoder_outs is not None:
                before_pool = encoder_outs[-(i+2)]
            x = up_conv(x, before_pool, se = self.mask_atts[i])
        x_mask = x

        x = input
        for i, up_conv in enumerate(self.up_convs):
            before_pool = None
            if encoder_outs is not None:
                before_pool = encoder_outs[-(i+2)]
            x = up_conv(x, before_pool, se=self.vm_atts[i])
        x_vm = x

        return x_im,x_mask,x_vm

class UnetEncoderD(nn.Module):

    def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True, norm=nn.BatchNorm2d, act=F.relu):
        super(UnetEncoderD, self).__init__()
        self.down_convs = []
        outs = None
        if type(blocks) is tuple:
            blocks = blocks[0]
        for i in range(depth):
            ins = in_channels if i == 0 else outs
            outs = start_filters*(2**i)
            pooling = True if i < depth-1 else False
            down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm, norm=nn.BatchNorm2d, act=F.relu)
            self.down_convs.append(down_conv)
        self.down_convs = nn.ModuleList(self.down_convs)
        reset_params(self)

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        encoder_outs = []
        for d_conv in self.down_convs:
            x, before_pool = d_conv(x)
            encoder_outs.append(before_pool)
        return x, encoder_outs

class ResDown(nn.Module):
    def __init__(self, in_size, out_size, pooling=True, use_att=False):
        super(ResDown, self).__init__()
        self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling)

    def forward(self, x):
        return self.model(x)

class ResUp(nn.Module):
    def __init__(self, in_size, out_size, use_att=False):
        super(ResUp, self).__init__()
        self.model = UpCoXvD(in_size, out_size, 3, use_att=use_att)

    def forward(self, x, skip_input, mask=None):
        return self.model(x,skip_input,mask)

class ResDownNew(nn.Module):
    def __init__(self, in_size, out_size, pooling=True, use_att=False):
        super(ResDownNew, self).__init__()
        self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling, norm=nn.InstanceNorm2d, act=F.leaky_relu)

    def forward(self, x):
        return self.model(x)

class ResUpNew(nn.Module):
    def __init__(self, in_size, out_size, use_att=False):
        super(ResUpNew, self).__init__()
        self.model = UpCoXvD(in_size, out_size, 3, use_att=use_att, norm=nn.InstanceNorm2d)

    def forward(self, x, skip_input, mask=None):
        return self.model(x,skip_input,mask)



class VMSingle(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32, res=True,use_att=False):
        super(VMSingle, self).__init__()

        self.down1 = down(in_channels, ngf)
        self.down2 = down(ngf, ngf*2)
        self.down3 = down(ngf*2, ngf*4)
        self.down4 = down(ngf*4, ngf*8)
        self.down5 = down(ngf*8, ngf*16, pooling=False)

        self.up1 = up(ngf*16, ngf*8)
        self.up2 = up(ngf*8, ngf*4, use_att=use_att)
        self.up3 = up(ngf*4, ngf*2, use_att=use_att)
        self.up4 = up(ngf*2, ngf*1, use_att=use_att)

        self.im = nn.Conv2d(ngf, 3, 1)
        self.res = res


    def forward(self, input):
        img, mask = input[:,0:3,:,:],input[:,3:4,:,:]
        # U-Net generator with skip connections from encoder to decoder
        x,d1 = self.down1(input) # 128,256
        x,d2 = self.down2(x) # 64,128
        x,d3 = self.down3(x) # 32,64
        x,d4 = self.down4(x) # 16,32
        x,_ = self.down5(x) # 8,16

        x = self.up1(x, d4) # 16
        x = self.up2(x, d3, mask) # 32
        x = self.up3(x, d2, mask) # 64
        x = self.up4(x, d1, mask) # 128
        im = self.im(x)

        return im



class VMSingleS2AM(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32):
        super(VMSingleS2AM, self).__init__()

        self.down1 = down(in_channels, ngf)
        self.down2 = down(ngf, ngf*2)
        self.down3 = down(ngf*2, ngf*4)
        self.down4 = down(ngf*4, ngf*8)
        self.down5 = down(ngf*8, ngf*16, pooling=False)

        self.up1 = up(ngf*16, ngf*8)
        self.up2 = up(ngf*8, ngf*4)
        self.s2am2 = RASC(ngf*4)
        
        self.up3 = up(ngf*4, ngf*2)
        self.s2am3 = RASC(ngf*2)

        self.up4 = up(ngf*2, ngf*1)
        self.s2am4 = RASC(ngf)

        self.im = nn.Conv2d(ngf, 3, 1)


    def forward(self, input):
        img, mask = input[:,0:3,:,:],input[:,3:4,:,:]
        # U-Net generator with skip connections from encoder to decoder
        x,d1 = self.down1(input) # 128,256
        x,d2 = self.down2(x) # 64,128
        x,d3 = self.down3(x) # 32,64
        x,d4 = self.down4(x) # 16,32
        x,_ = self.down5(x) # 8,16

        x = self.up1(x, d4) # 16
        x = self.up2(x, d3) # 32
        x = self.s2am2(x, mask)

        x = self.up3(x, d2) # 64
        x = self.s2am3(x, mask)

        x = self.up4(x, d1) # 128
        x = self.s2am4(x, mask)
        im = self.im(x)
        return im


class UnetVMS2AMv4(nn.Module):

    def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1,
                 out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True,
                 transpose=True, concat=True, transfer_data=True, long_skip=False, s2am='unet', use_coarser=True,no_stage2=False):
        super(UnetVMS2AMv4, self).__init__()
        self.transfer_data = transfer_data
        self.shared = shared_depth
        self.optimizer_encoder,  self.optimizer_image, self.optimizer_vm = None, None, None
        self.optimizer_mask, self.optimizer_shared = None, None
        if type(blocks) is not tuple:
            blocks = (blocks, blocks, blocks, blocks, blocks)
        if not transfer_data:
            concat = False
        self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0],
                                    start_filters=start_filters, residual=residual, batch_norm=batch_norm,norm=nn.InstanceNorm2d,act=F.leaky_relu)
        self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
                                          out_channels=out_channels_image, depth=depth - shared_depth,
                                          blocks=blocks[1], residual=residual, batch_norm=batch_norm,
                                          transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)
        self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
                                         out_channels=out_channels_mask, depth=depth - shared_depth,
                                         blocks=blocks[2], residual=residual, batch_norm=batch_norm,
                                         transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)
        self.vm_decoder = None
        if use_vm_decoder:
            self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
                                           out_channels=out_channels_image, depth=depth - shared_depth,
                                           blocks=blocks[3], residual=residual, batch_norm=batch_norm,
                                           transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)
        self.shared_decoder = None
        self.use_coarser = use_coarser
        self.long_skip = long_skip
        self.no_stage2 = no_stage2
        self._forward = self.unshared_forward
        if self.shared != 0:
            self._forward = self.shared_forward
            self.shared_decoder = UnetDecoderDatt(in_channels=start_filters * 2 ** (depth - 1),
                                               out_channels=start_filters * 2 ** (depth - shared_depth - 1),
                                               depth=shared_depth, blocks=blocks[4], residual=residual,
                                               batch_norm=batch_norm, transpose=transpose, concat=concat,
                                               is_final=False,norm=nn.InstanceNorm2d)

        if s2am == 'unet':
            self.s2am = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2)
        elif s2am == 'vm':
            self.s2am = VMSingle(4)
        elif s2am == 'vms2am':
            self.s2am = VMSingleS2AM(4,down=ResDownNew,up=ResUpNew)

    def set_optimizers(self):
        self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001)
        self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001)
        self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001)
        self.optimizer_s2am = torch.optim.Adam(self.s2am.parameters(), lr=0.001)

        if self.vm_decoder is not None:
            self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001)
        if self.shared != 0:
            self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001)

    def zero_grad_all(self):
        self.optimizer_encoder.zero_grad()
        self.optimizer_image.zero_grad()
        self.optimizer_mask.zero_grad()
        self.optimizer_s2am.zero_grad()
        if self.vm_decoder is not None:
            self.optimizer_vm.zero_grad()
        if self.shared != 0:
            self.optimizer_shared.zero_grad()

    def step_all(self):
        self.optimizer_encoder.step()
        self.optimizer_image.step()
        self.optimizer_mask.step()
        self.optimizer_s2am.step()
        if self.vm_decoder is not None:
            self.optimizer_vm.step()
        if self.shared != 0:
            self.optimizer_shared.step()

    def step_optimizer_image(self):
        self.optimizer_image.step()

    def __call__(self, synthesized):
        return self._forward(synthesized)

    def forward(self, synthesized):
        return self._forward(synthesized)

    def unshared_forward(self, synthesized):
        image_code, before_pool = self.encoder(synthesized)
        if not self.transfer_data:
            before_pool = None
        reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool))
        reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))
        if self.vm_decoder is not None:
            reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool))
            return reconstructed_image, reconstructed_mask, reconstructed_vm
        return reconstructed_image, reconstructed_mask

    def shared_forward(self, synthesized):
        image_code, before_pool = self.encoder(synthesized)
        if self.transfer_data:
            shared_before_pool = before_pool[- self.shared - 1:]
            unshared_before_pool = before_pool[: - self.shared]
        else:
            before_pool = None
            shared_before_pool = None
            unshared_before_pool = None
        im,mask,vm = self.shared_decoder(image_code, shared_before_pool)
        reconstructed_image = torch.tanh(self.image_decoder(im, unshared_before_pool))
        if self.long_skip:
            reconstructed_image = reconstructed_image + synthesized

        reconstructed_mask = torch.sigmoid(self.mask_decoder(mask, unshared_before_pool))
        if self.vm_decoder is not None:
            reconstructed_vm = torch.tanh(self.vm_decoder(vm, unshared_before_pool))
            if self.long_skip:
                reconstructed_vm = reconstructed_vm + synthesized

        coarser = reconstructed_image * reconstructed_mask + (1-reconstructed_mask)* synthesized
        
        if self.use_coarser:
            refine =  torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) + coarser
        elif self.no_stage2:
            refine =  torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1)))
        else:
            refine =  torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) + synthesized

        # final = refine * reconstructed_mask + (1-reconstructed_mask)* synthesized
        if self.vm_decoder is not None:
            return [refine, reconstructed_image], reconstructed_mask, reconstructed_vm
        else:
            return [refine, reconstructed_image], reconstructed_mask




================================================
FILE: scripts/models/unet.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools
from scripts.models.blocks import *
from scripts.models.rasc import *


class MinimalUnetV2(nn.Module):
    """docstring for MinimalUnet"""
    def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwags):
        super(MinimalUnetV2, self).__init__()
        
        self.down = nn.Sequential(*down)
        self.up = nn.Sequential(*up) 
        self.sub = submodule
        self.attention = attention
        self.withoutskip = withoutskip
        self.is_attention = not self.attention == None 
        self.is_sub = not submodule == None 
    
    def forward(self,x,mask=None):
        if self.is_sub: 
            x_up,_ = self.sub(self.down(x),mask)
        else:
            x_up = self.down(x)

        if self.withoutskip: #outer or inner.
            x_out = self.up(x_up)
        else:
            if self.is_attention:
                x_out = (self.attention(torch.cat([x,self.up(x_up)],1),mask),mask)
            else:
                x_out = (torch.cat([x,self.up(x_up)],1),mask)

        return x_out


class MinimalUnet(nn.Module):
    """docstring for MinimalUnet"""
    def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwags):
        super(MinimalUnet, self).__init__()
        
        self.down = nn.Sequential(*down)
        self.up = nn.Sequential(*up) 
        self.sub = submodule
        self.attention = attention
        self.withoutskip = withoutskip
        self.is_attention = not self.attention == None 
        self.is_sub = not submodule == None 
    
    def forward(self,x,mask=None):
        if self.is_sub: 
            x_up,_ = self.sub(self.down(x),mask)
        else:
            x_up = self.down(x)

        if self.is_attention:
            x = self.attention(x,mask)
        
        if self.withoutskip: #outer or inner.
            x_out = self.up(x_up)
        else:
            x_out = (torch.cat([x,self.up(x_up)],1),mask)

        return x_out


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,is_attention_layer=False,
                 attention_model=RASC,basicblock=MinimalUnet,outermostattention=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)


        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv]
            model = basicblock(down,up,submodule,withoutskip=outermost)
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = basicblock(down,up)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if is_attention_layer:
                if MinimalUnetV2.__qualname__ in basicblock.__qualname__  :
                    attention_model = attention_model(input_nc*2)
                else:
                    attention_model = attention_model(input_nc)     
            else:
                attention_model = None
                
            if use_dropout:
                model = basicblock(down,up.append(nn.Dropout(0.5)),submodule,attention_model,outermostattention=outermostattention)
            else:
                model = basicblock(down,up,submodule,attention_model,outermostattention=outermostattention)

        self.model = model


    def forward(self, x,mask=None):
        # build the mask for attention use
        return self.model(x,mask)
            
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs=8, ngf=64,norm_layer=nn.BatchNorm2d, use_dropout=False,
                 is_attention_layer=False,attention_model=RASC,use_inner_attention=False,basicblock=MinimalUnet):
        super(UnetGenerator, self).__init__()

        # 8 for 256x256
        # 9 for 512x512
        # construct unet structure
        self.need_mask = not input_nc == output_nc

        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True,basicblock=basicblock) # 1
        for i in range(num_downs - 5): #3 times
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout,is_attention_layer=use_inner_attention,attention_model=attention_model,basicblock=basicblock) # 8,4,2
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #16
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #32
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock, outermostattention=True) #64 
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, basicblock=basicblock, norm_layer=norm_layer) # 128

        self.model = unet_block

    def forward(self, input):
        if self.need_mask:
            return self.model(input,input[:,3:4,:,:])
        else:
            return self.model(input[:,0:3,:,:],input[:,3:4,:,:])





================================================
FILE: scripts/models/vgg.py
================================================
from collections import namedtuple

import torch
from torchvision import models


class Vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23,30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
                
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        # vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3','relu5_3'])
        # out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return (h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)


class Vgg19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        # vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.vgg_pretrained_features = models.vgg19(pretrained=True).features
        # self.slice1 = torch.nn.Sequential()
        # self.slice2 = torch.nn.Sequential()
        # self.slice3 = torch.nn.Sequential()
        # self.slice4 = torch.nn.Sequential()
        # self.slice5 = torch.nn.Sequential()
        # for x in range(2):
        #     self.slice1.add_module(str(x), vgg_pretrained_features[x])
        # for x in range(2, 7):
        #     self.slice2.add_module(str(x), vgg_pretrained_features[x])
        # for x in range(7, 12):
        #     self.slice3.add_module(str(x), vgg_pretrained_features[x])
        # for x in range(12, 21):
        #     self.slice4.add_module(str(x), vgg_pretrained_features[x])
        # for x in range(21, 30):
        #     self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X, indices=None):
        if indices is None:
            indices = [2, 7, 12, 21, 30]
        out = []
        #indices = sorted(indices)
        for i in range(indices[-1]):
            X = self.vgg_pretrained_features[i](X)
            if (i+1) in indices:
                out.append(X)
        
        return out


================================================
FILE: scripts/models/vmu.py
================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from scripts.models.blocks import SEBlock
from scripts.models.rasc import *
from scripts.models.unet import UnetGenerator,MinimalUnetV2

def weight_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)

def reset_params(model):
    for i, m in enumerate(model.modules()):
        weight_init(m)


def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)


def up_conv2x2(in_channels, out_channels, transpose=True):
    if transpose:
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2)
    else:
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv1x1(in_channels, out_channels))


def conv1x1(in_channels, out_channels, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        groups=groups,
        stride=1)




class UpCoXvD(nn.Module):

    def __init__(self, in_channels, out_channels, blocks, residual=True, batch_norm=True, transpose=True,concat=True,use_att=False):
        super(UpCoXvD, self).__init__()
        self.concat = concat
        self.residual = residual
        self.batch_norm = batch_norm
        self.bn = None
        self.conv2 = []
        self.use_att = use_att
        self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose)
        
        if self.use_att:
            self.s2am = RASC(2 * out_channels)
        else:
            self.s2am = None

        if self.concat:
            self.conv1 = conv3x3(2 * out_channels, out_channels)
        else:
            self.conv1 = conv3x3(out_channels, out_channels)
        for _ in range(blocks):
            self.conv2.append(conv3x3(out_channels, out_channels))
        if self.batch_norm:
            self.bn = []
            for _ in range(blocks):
                self.bn.append(nn.BatchNorm2d(out_channels))
            self.bn = nn.ModuleList(self.bn)
        self.conv2 = nn.ModuleList(self.conv2)

    def forward(self, from_up, from_down, mask=None):
        from_up = self.up_conv(from_up)
        if self.concat:
            x1 = torch.cat((from_up, from_down), 1)
        else:
            if from_down is not None:
                x1 = from_up + from_down
            else:
                x1 = from_up

        if self.use_att:
            x1 = self.s2am(x1,mask)

        x1 = F.relu(self.conv1(x1))
        x2 = None
        for idx, conv in enumerate(self.conv2):
            x2 = conv(x1)
            if self.batch_norm:
                x2 = self.bn[idx](x2)
            if self.residual:
                x2 = x2 + x1
            x2 = F.relu(x2)
            x1 = x2
        return x2


class DownCoXvD(nn.Module):

    def __init__(self, in_channels, out_channels, blocks, pooling=True, residual=True, batch_norm=True):
        super(DownCoXvD, self).__init__()
        self.pooling = pooling
        self.residual = residual
        self.batch_norm = batch_norm
        self.bn = None
        self.pool = None
        self.conv1 = conv3x3(in_channels, out_channels)
        self.conv2 = []
        for _ in range(blocks):
            self.conv2.append(conv3x3(out_channels, out_channels))
        if self.batch_norm:
            self.bn = []
            for _ in range(blocks):
                self.bn.append(nn.BatchNorm2d(out_channels))
            self.bn = nn.ModuleList(self.bn)
        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.ModuleList(self.conv2)

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        x1 = F.relu(self.conv1(x))
        x2 = None
        for idx, conv in enumerate(self.conv2):
            x2 = conv(x1)
            if self.batch_norm:
                x2 = self.bn[idx](x2)
            if self.residual:
                x2 = x2 + x1
            x2 = F.relu(x2)
            x1 = x2
        before_pool = x2
        if self.pooling:
            x2 = self.pool(x2)
        return x2, before_pool

class UnetDecoderD(nn.Module):
    def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True,
                 transpose=True, concat=True, is_final=True):
        super(UnetDecoderD, self).__init__()
        self.conv_final = None
        self.up_convs = []
        outs = in_channels
        for i in range(depth-1):
            ins = outs
            outs = ins // 2
            # 512,256
            # 256,128
            # 128,64
            # 64,32
            up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
                              concat=concat)
            self.up_convs.append(up_conv)
        if is_final:
            self.conv_final = conv1x1(outs, out_channels)
        else:
            up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
                              concat=concat)
            self.up_convs.append(up_conv)
        self.up_convs = nn.ModuleList(self.up_convs)
        reset_params(self)

    def __call__(self, x, encoder_outs=None):
        return self.forward(x, encoder_outs)

    def forward(self, x, encoder_outs=None):
        for i, up_conv in enumerate(self.up_convs):
            before_pool = None
            if encoder_outs is not None:
                before_pool = encoder_outs[-(i+2)]
            x = up_conv(x, before_pool)
        if self.conv_final is not None:
            x = self.conv_final(x)
        return x


class UnetEncoderD(nn.Module):

    def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True):
        super(UnetEncoderD, self).__init__()
        self.down_convs = []
        outs = None
        if type(blocks) is tuple:
            blocks = blocks[0]
        for i in range(depth):
            ins = in_channels if i == 0 else outs
            outs = start_filters*(2**i)
            pooling = True if i < depth-1 else False
            down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm)
            self.down_convs.append(down_conv)
        self.down_convs = nn.ModuleList(self.down_convs)
        reset_params(self)

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        encoder_outs = []
        for d_conv in self.down_convs:
            x, before_pool = d_conv(x)
            encoder_outs.append(before_pool)
        return x, encoder_outs



class UnetVM(nn.Module):

    def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1,
                 out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True,
                 transpose=True, concat=True, transfer_data=True, long_skip=False):
        super(UnetVM, self).__init__()
        self.transfer_data = transfer_data
        self.shared = shared_depth
        self.optimizer_encoder,  self.optimizer_image, self.optimizer_vm = None, None, None
        self.optimizer_mask, self.optimizer_shared = None, None
        if type(blocks) is not tuple:
            blocks = (blocks, blocks, blocks, blocks, blocks)
        if not transfer_data:
            concat = False
        self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0],
                                    start_filters=start_filters, residual=residual, batch_norm=batch_norm)
        self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
                                          out_channels=out_channels_image, depth=depth - shared_depth,
                                          blocks=blocks[1], residual=residual, batch_norm=batch_norm,
                                          transpose=transpose, concat=concat)
        self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1),
                                         out_channels=out_channels_mask, depth=depth,
                                         blocks=blocks[2], residual=residual, batch_norm=batch_norm,
                                         transpose=transpose, concat=concat)
        self.vm_decoder = None
        if use_vm_decoder:
            self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
                                           out_channels=out_channels_image, depth=depth - shared_depth,
                                           blocks=blocks[3], residual=residual, batch_norm=batch_norm,
                                           transpose=transpose, concat=concat)
        self.shared_decoder = None
        self.long_skip = long_skip
        self._forward = self.unshared_forward
        if self.shared != 0:
            self._forward = self.shared_forward
            self.shared_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1),
                                               out_channels=start_filters * 2 ** (depth - shared_depth - 1),
                                               depth=shared_depth, blocks=blocks[4], residual=residual,
                                               batch_norm=batch_norm, transpose=transpose, concat=concat,
                                               is_final=False)

    def set_optimizers(self):
        self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001)
        self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001)
        self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001)
        if self.vm_decoder is not None:
            self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001)
        if self.shared != 0:
            self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001)

    def zero_grad_all(self):
        self.optimizer_encoder.zero_grad()
        self.optimizer_image.zero_grad()
        self.optimizer_mask.zero_grad()
        if self.vm_decoder is not None:
            self.optimizer_vm.zero_grad()
        if self.shared != 0:
            self.optimizer_shared.zero_grad()

    def step_all(self):
        self.optimizer_encoder.step()
        self.optimizer_image.step()
        self.optimizer_mask.step()
        if self.vm_decoder is not None:
            self.optimizer_vm.step()
        if self.shared != 0:
            self.optimizer_shared.step()

    def step_optimizer_image(self):
        self.optimizer_image.step()

    def __call__(self, synthesized):
        return self._forward(synthesized)

    def forward(self, synthesized):
        return self._forward(synthesized)

    def unshared_forward(self, synthesized):
        image_code, before_pool = self.encoder(synthesized)
        if not self.transfer_data:
            before_pool = None
        reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool))
        reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))
        if self.vm_decoder is not None:
            reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool))
            return reconstructed_image, reconstructed_mask, reconstructed_vm
        return reconstructed_image, reconstructed_mask

    def shared_forward(self, synthesized):
        image_code, before_pool = self.encoder(synthesized)
        if self.transfer_data:
            shared_before_pool = before_pool[- self.shared - 1:]
            unshared_before_pool = before_pool[: - self.shared]
        else:
            before_pool = None
            shared_before_pool = None
            unshared_before_pool = None
        x = self.shared_decoder(image_code, shared_before_pool)
        reconstructed_image = torch.tanh(self.image_decoder(x, unshared_before_pool))
        if self.long_skip:
            reconstructed_image = reconstructed_image + synthesized

        reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))
        if self.vm_decoder is not None:
            reconstructed_vm = torch.tanh(self.vm_decoder(x, unshared_before_pool))
            if self.long_skip:
                reconstructed_vm = reconstructed_vm + synthesized
            return reconstructed_image, reconstructed_mask, reconstructed_vm
        return reconstructed_image, reconstructed_mask


================================================
FILE: scripts/utils/__init__.py
================================================
from __future__ import absolute_import

from .evaluation import *
from .imutils import *
from .logger import *
from .misc import *
from .osutils import *
from .transforms import *


================================================
FILE: scripts/utils/evaluation.py
================================================
from __future__ import absolute_import

import math
import numpy as np
import matplotlib.pyplot as plt
from random import randint

from .misc import *
from .transforms import transform, transform_preds

__all__ = ['accuracy', 'AverageMeter']

def get_preds(scores):
    ''' get predictions from score maps in torch Tensor
        return type: torch.LongTensor
    '''
    assert scores.dim() == 4, 'Score maps should be 4-dim'
    maxval, idx = torch.max(scores.view(scores.size(0), scores.size(1), -1), 2)

    maxval = maxval.view(scores.size(0), scores.size(1), 1)
    idx = idx.view(scores.size(0), scores.size(1), 1) + 1

    preds = idx.repeat(1, 1, 2).float()

    preds[:,:,0] = (preds[:,:,0] - 1) % scores.size(3) + 1
    preds[:,:,1] = torch.floor((preds[:,:,1] - 1) / scores.size(2)) + 1

    pred_mask = maxval.gt(0).repeat(1, 1, 2).float()
    preds *= pred_mask
    return preds

def calc_dists(preds, target, normalize):
    preds = preds.float()
    target = target.float()
    dists = torch.zeros(preds.size(1), preds.size(0))
    for n in range(preds.size(0)):
        for c in range(preds.size(1)):
            if target[n,c,0] > 1 and target[n, c, 1] > 1:
                dists[c, n] = torch.dist(preds[n,c,:], target[n,c,:])/normalize[n]
            else:
                dists[c, n] = -1
    return dists

def dist_acc(dists, thr=0.5):
    ''' Return percentage below threshold while ignoring values with a -1 '''
    if dists.ne(-1).sum() > 0:
        return dists.le(thr).eq(dists.ne(-1)).sum()*1.0 / dists.ne(-1).sum()
    else:
        return -1



def accuracy(output, target, thr=0.5):
    ''' Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations
        First value to be returned is average accuracy across 'idxs', followed by individual accuracies
    '''
    # output_mask = torch.gt(output,thr);
    # target_mask = torch.gt(target,thr);
    # equal_mask = torch.eq(output_mask,target_mask);
    # fp_equal_mask = torch.lt(output_mask,target_mask);
    # fn_equal_mask = torch.gt(output_mask,target_mask);


    # tp = torch.sum(equal_mask);
    # fn = torch.sum(fn_equal_mask);
    # fp = torch.sum(fp_equal_mask);

    # return 2*tp / (2*tp+fn+fp)


    if output.dim() > 2:
        v,i = torch.max(output,1);
    else:
        v,i = torch.max(output,1);
    return torch.sum(target.long() == i).float()/target.numel()

def final_preds(output, center, scale, res):
    coords = get_preds(output) # float type

    # pose-processing
    for n in range(coords.size(0)):
        for p in range(coords.size(1)):
            hm = output[n][p]
            px = int(math.floor(coords[n][p][0]))
            py = int(math.floor(coords[n][p][1]))
            if px > 1 and px < res[0] and py > 1 and py < res[1]:
                diff = torch.Tensor([hm[py - 1][px] - hm[py - 1][px - 2], hm[py][px - 1]-hm[py - 2][px - 1]])
                coords[n][p] += diff.sign() * .25
    coords += 0.5
    preds = coords.clone()

    # Transform back
    for i in range(coords.size(0)):
        preds[i] = transform_preds(coords[i], center[i], scale[i], res)

    if preds.dim() < 3:
        preds = preds.view(1, preds.size())

    return preds

    
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


================================================
FILE: scripts/utils/imutils.py
================================================
from __future__ import absolute_import

import torch
import torch.nn as nn
import numpy as np
import scipy.misc

from .misc import *

def im_to_numpy(img):
    img = to_numpy(img)
    img = np.transpose(img, (1, 2, 0)) # H*W*C
    return img

def im_to_torch(img):
    img = np.transpose(img, (2, 0, 1)) # C*H*W
    img = to_torch(img).float()
    if img.max() > 1:
        img /= 255
    return img

def load_image(img_path):
    # H x W x C => C x H x W
    return im_to_torch(scipy.misc.imread(img_path, mode='RGB'))

def imread_all(img_path):
    return scipy.misc.imread(img_path, mode='RGB')

def load_image_gray(img_path):
    # H x W x C => C x H x W
    x = scipy.misc.imread(img_path, mode='L')
    x = x[:,:,np.newaxis]
    return im_to_torch(x)

def resize(img, owidth, oheight):
    img = im_to_numpy(img)

    if img.shape[2] == 1:
        img = scipy.misc.imresize(img.squeeze(),(oheight,owidth))
        img = img[:,:,np.newaxis]
    else:
        img = scipy.misc.imresize(
                img,
                (oheight, owidth)
            )
    img = im_to_torch(img)
    # print('%f %f' % (img.min(), img.max()))
    return img

# =============================================================================
# Helpful functions generating groundtruth labelmap 
# =============================================================================

def gaussian(shape=(7,7),sigma=1):
    """
    2D gaussian mask - should give the same result as MATLAB's
    fspecial('gaussian',[shape],[sigma])
    """
    m,n = [(ss-1.)/2. for ss in shape]
    y,x = np.ogrid[-m:m+1,-n:n+1]
    h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
    h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
    return to_torch(h).float()

def draw_labelmap(img, pt, sigma, type='Gaussian'):
    # Draw a 2D gaussian 
    # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py
    img = to_numpy(img)

    # Check that any part of the gaussian is in-bounds
    ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
    br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
    if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
            br[0] < 0 or br[1] < 0):
        # If not, just return the image as is
        return to_torch(img)

    # Generate gaussian
    size = 6 * sigma + 1
    x = np.arange(0, size, 1, float)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    # The gaussian is not normalized, we want the center value to equal 1
    if type == 'Gaussian':
        g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    elif type == 'Cauchy':
        g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)


    # Usable gaussian range
    g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
    # Image range
    img_x = max(0, ul[0]), min(br[0], img.shape[1])
    img_y = max(0, ul[1]), min(br[1], img.shape[0])

    img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    return to_torch(img)

# =============================================================================
# Helpful display functions
# =============================================================================

def gauss(x, a, b, c, d=0):
    return a * np.exp(-(x - b)**2 / (2 * c**2)) + d

def color_heatmap(x):
    x = to_numpy(x)
    color = np.zeros((x.shape[0],x.shape[1],3))
    color[:,:,0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
    color[:,:,1] = gauss(x, 1, .5, .3)
    color[:,:,2] = gauss(x, 1, .2, .3)
    color[color > 1] = 1
    color = (color * 255).astype(np.uint8)
    return color

def imshow(img):
    npimg = im_to_numpy(img*255).astype(np.uint8)
    plt.imshow(npimg)
    plt.axis('off')

def show_joints(img, pts):
    imshow(img)
    
    for i in range(pts.size(0)):
        if pts[i, 2] > 0:
            plt.plot(pts[i, 0], pts[i, 1], 'yo')
    plt.axis('off')

def show_sample(inputs, target):
    num_sample = inputs.size(0)
    num_joints = target.size(1)
    height = target.size(2)
    width = target.size(3)

    for n in range(num_sample):
        inp = resize(inputs[n], width, height)
        out = inp
        for p in range(num_joints):
            tgt = inp*0.5 + color_heatmap(target[n,p,:,:])*0.5
            out = torch.cat((out, tgt), 2)
        
        imshow(out)
        plt.show()

def sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None):
    inp = to_numpy(inp * 255)
    out = to_numpy(out)

    img = np.zeros((inp.shape[1], inp.shape[2], inp.shape[0]))
    for i in range(3):
        img[:, :, i] = inp[i, :, :]

    if parts_to_show is None:
        parts_to_show = np.arange(out.shape[0])

    # Generate a single image to display input/output pair
    num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows))
    size = img.shape[0] // num_rows

    full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), np.uint8)
    full_img[:img.shape[0], :img.shape[1]] = img

    inp_small = scipy.misc.imresize(img, [size, size])

    # Set up heatmap display for each part
    for i, part in enumerate(parts_to_show):
        part_idx = part
        out_resized = scipy.misc.imresize(out[part_idx], [size, size])
        out_resized = out_resized.astype(float)/255
        out_img = inp_small.copy() * .3
        color_hm = color_heatmap(out_resized)
        out_img += color_hm * .7

        col_offset = (i % num_cols + num_rows) * size
        row_offset = (i // num_cols) * size
        full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img

    return full_img

def batch_with_heatmap(inputs, outputs, mean=torch.Tensor([0.5, 0.5, 0.5]), num_rows=2, parts_to_show=None):
    batch_img = []
    for n in range(min(inputs.size(0), 4)):
        inp = inputs[n] + mean.view(3, 1, 1).expand_as(inputs[n])
        batch_img.append(
            sample_with_heatmap(inp.clamp(0, 1), outputs[n], num_rows=num_rows, parts_to_show=parts_to_show)
        )
    return np.concatenate(batch_img)


def normalize_batch(batch):
    # normalize using imagenet mean and std
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    batch = batch/255.0
    return (batch - mean) / std

def show_image_tensor(tensor):
    re = []
    for i in range(tensor.size(0)):
        inp = tensor[i].data.cpu() #w,h,c
        inp = inp.numpy().transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1).transpose((2,0,1))
        re.append(torch.from_numpy(inp).unsqueeze(0))
    return torch.cat(re,0)


def get_jet():
    colormap_int = np.zeros((256, 3), np.uint8)
 
    for i in range(0, 256, 1):
        colormap_int[i, 0] = np.int_(np.round(cm.jet(i)[0] * 255.0))
        colormap_int[i, 1] = np.int_(np.round(cm.jet(i)[1] * 255.0))
        colormap_int[i, 2] = np.int_(np.round(cm.jet(i)[2] * 255.0))

    return colormap_int

def clamp(num, min_value, max_value):
    return max(min(num, max_value), min_value)

def gray2color(gray_array, color_map):
    
    rows, cols = gray_array.shape
    color_array = np.zeros((rows, cols, 3), np.uint8)
 
    for i in range(0, rows):
        for j in range(0, cols):
#             log(256,2) = 8 , log(1,2) = 0 * 8
            color_array[i, j] = color_map[clamp(int(abs(gray_array[i, j])*10),0,255)]
    
    return color_array

class objectview(object):
    def __init__(self, *args, **kwargs):
        d = dict(*args, **kwargs)
        self.__dict__ = d

================================================
FILE: scripts/utils/logger.py
================================================
# A simple torch style logger
# (C) Wei YANG 2017
from __future__ import absolute_import

import os
import sys
import numpy as np
import matplotlib.pyplot as plt

__all__ = ['Logger', 'LoggerMonitor', 'savefig']

def savefig(fname, dpi=None):
    dpi = 150 if dpi == None else dpi
    plt.savefig(fname, dpi=dpi)
    
def plot_overlap(logger, names=None):
    names = logger.names if names == None else names
    numbers = logger.numbers
    for _, name in enumerate(names):
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + '(' + name + ')' for name in names]

class Logger(object):
    '''Save training process to log file with simple plot function.'''
    def __init__(self, fpath, title=None, resume=False): 
        self.file = None
        self.resume = resume
        self.title = '' if title == None else title
        if fpath is not None:
            if resume: 
                self.file = open(fpath, 'r') 
                name = self.file.readline()
                self.names = name.rstrip().split('\t')
                self.numbers = {}
                for _, name in enumerate(self.names):
                    self.numbers[name] = []

                for numbers in self.file:
                    numbers = numbers.rstrip().split('\t')
                    for i in range(0, len(numbers)):
                        self.numbers[self.names[i]].append(numbers[i])
                self.file.close()
                self.file = open(fpath, 'a')  
            else:
                self.file = open(fpath, 'w')

    def set_names(self, names):
        if self.resume: 
            pass
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for _, name in enumerate(self.names):
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()


    def append(self, numbers):
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            self.file.write("{0:.6f}".format(num))
            self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):   
        names = self.names if names == None else names
        numbers = self.numbers
        for _, name in enumerate(names):
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        if self.file is not None:
            self.file.close()

class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''
    def __init__ (self, paths):
        '''paths is a distionary with {name:filepath} pair'''
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            self.loggers.append(logger)

    def plot(self, names=None):
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
                    
if __name__ == '__main__':
    # # Example
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])

    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1

    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor
    paths = {
    'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 
    'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
    'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }

    field = ['Valid Acc.']

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.eps')

================================================
FILE: scripts/utils/losses.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from scripts.models.vgg import Vgg19
from torchvision import models
from scripts.utils.misc import resize_to_match
# from pytorch_msssim import SSIM, MS_SSIM
import pytorch_ssim

class WeightedBCE(nn.Module):
    def __init__(self):
        super(WeightedBCE, self).__init__()

    def forward(self, pred, gt):
        eposion = 1e-10
        sigmoid_pred = torch.sigmoid(pred)
        count_pos = torch.sum(gt)*1.0+eposion
        count_neg = torch.sum(1.-gt)*1.0
        beta = count_neg/count_pos
        beta_back = count_pos / (count_pos + count_neg)

        bce1 = nn.BCEWithLogitsLoss(pos_weight=beta)
        loss = beta_back*bce1(pred, gt)

        return loss


def l1_relative(reconstructed, real, mask):
    batch = real.size(0)
    area = torch.sum(mask.view(batch,-1),dim=1)
    reconstructed = reconstructed * mask
    real = real * mask
    
    loss_l1 = torch.abs(reconstructed - real).view(batch, -1)
    loss_l1 = torch.sum(loss_l1, dim=1) / area
    loss_l1 = torch.sum(loss_l1) / batch
    return loss_l1


def is_dic(x):
    return type(x) == type([])

class Losses(nn.Module):
    def __init__(self, argx, device):
        super(Losses, self).__init__()
        self.args = argx

        if self.args.loss_type == 'l1bl2':
            self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss()
        elif self.args.loss_type == 'l1wbl2':
            self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), WeightedBCE(), nn.MSELoss() 
        elif self.args.loss_type == 'l2wbl2':
            self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), WeightedBCE(), nn.MSELoss()
        elif self.args.loss_type == 'l2xbl2':
            self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss()
        else: # l2bl2
            self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss()

        if self.args.style_loss > 0:
            self.vggloss = VGGLoss(self.args.sltype).to(device)
        
        if self.args.ssim_loss > 0:
            self.ssimloss =  pytorch_ssim.SSIM().to(device)

        self.outputLoss = self.outputLoss.to(device)
        self.attLoss = self.attLoss.to(device)
        self.wrloss = self.wrloss.to(device)


    def forward(self,imgx,target,attx,mask,wmx,wm):
        pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = 0,0,0,0,0

        if is_dic(imgx):

            if self.args.masked:
            # calculate the overall loss and side output
                pixel_loss = self.outputLoss(imgx[0],target) + sum([self.outputLoss(im,resize_to_match(mask,im)*resize_to_match(target,im)) for im in imgx[1:]])
            else:
                pixel_loss =  sum([self.outputLoss(im,resize_to_match(target,im)) for im in imgx])

            if self.args.style_loss > 0:
                vgg_loss = sum([self.vggloss(im,resize_to_match(target,im),resize_to_match(mask,im)) for im in imgx])

            if self.args.ssim_loss > 0:
                ssim_loss = sum([ 1 - self.ssimloss(im,resize_to_match(target,im)) for im in imgx])
        else:

            if self.args.masked:
                pixel_loss = self.outputLoss(imgx,mask*target)
            else:
                pixel_loss =  self.outputLoss(imgx,target)

            if self.args.style_loss > 0:
                vgg_loss = self.vggloss(imgx,target,mask)

            if self.args.ssim_loss > 0:
                ssim_loss = 1 - self.ssimloss(imgx,target)

        if is_dic(attx):
            att_loss =  sum([self.attLoss(at,resize_to_match(mask,at)) for at in attx])
        else:
            att_loss =  self.attLoss(attx, mask)

        if is_dic(wmx):
            wm_loss = sum([self.wrloss(w,resize_to_match(wm,w)) for w in wmx])
        else:
            if self.args.masked:
                wm_loss = self.wrloss(wmx,mask*wm)
            else:
                wm_loss = self.wrloss(wmx, wm)

        return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss



def gram_matrix(feat):
    # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py
    (b, ch, h, w) = feat.size()
    feat = feat.view(b, ch, h * w)
    feat_t = feat.transpose(1, 2)
    gram = torch.bmm(feat, feat_t) / (ch * h * w)
    return gram
    
class MeanShift(nn.Conv2d):
    def __init__(self, data_mean, data_std, data_range=1, norm=True):
        """norm (bool): normalize/denormalize the stats"""
        c = len(data_mean)
        super(MeanShift, self).__init__(c, c, kernel_size=1)
        std = torch.Tensor(data_std)
        self.weight.data = torch.eye(c).view(c, c, 1, 1)
        if norm:
            self.weight.data.div_(std.view(c, 1, 1, 1))
            self.bias.data = -1 * data_range * torch.Tensor(data_mean)
            self.bias.data.div_(std)
        else:
            self.weight.data.mul_(std.view(c, 1, 1, 1))
            self.bias.data = data_range * torch.Tensor(data_mean)
        self.requires_grad = False



def VGGLoss(losstype):
    if losstype == 'vgg':
        return VGGLossA()
    elif losstype == 'vggx':
        return VGGLossX(mask=False)
    elif losstype == 'mvggx':
        return VGGLossX(mask=True)
    elif losstype == 'rvggx':
        return VGGLossX(mask=True,relative=True)
    else:
        raise Exception("error in %s"%losstype)

        

class VGGLossA(nn.Module):
    def __init__(self, vgg=None, weights=None, indices=None, normalize=True):
        super(VGGLossA, self).__init__()        
        if vgg is None:
            self.vgg = Vgg19().cuda()
        else:
            self.vgg = vgg
        self.criterion = nn.L1Loss()
        self.weights = weights or [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
        self.indices = indices or [2, 7, 12, 21, 30]
        if normalize:
            self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
        else:
            self.normalize = None

    def forward(self, x, y):
        if self.normalize is not None:
            x = self.normalize(x)
            y = self.normalize(y)
        x_vgg, y_vgg = self.vgg(x, self.indices), self.vgg(y, self.indices)
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss


class VGG16FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        vgg16 = models.vgg16(pretrained=True)
        self.enc_1 = nn.Sequential(*vgg16.features[:5])
        self.enc_2 = nn.Sequential(*vgg16.features[5:10])
        self.enc_3 = nn.Sequential(*vgg16.features[10:17])

        # fix the encoder
        for i in range(3):
            for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
                param.requires_grad = False

    def forward(self, image):
        results = [image]
        for i in range(3):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

class VGGLossX(nn.Module):
    def __init__(self, normalize=True, mask=False, relative=False):
        super(VGGLossX, self).__init__()
        
        self.vgg = VGG16FeatureExtractor().cuda()
        self.criterion = nn.L1Loss().cuda() if not relative else l1_relative
        self.use_mask= mask
        self.relative = relative

        if normalize:
            self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
        else:
            self.normalize = None

    def forward(self, x, y, Xmask=None):
        if not self.use_mask:
            mask = torch.ones_like(x)[:,0:1,:,:]
        else:
            mask = Xmask

        if self.normalize is not None:
            x = self.normalize(x)
            y = self.normalize(y)

        x_vgg = self.vgg(x)
        y_vgg = self.vgg(y)
        
        loss = 0
        for i in range(3):
            if self.relative:
                loss += self.criterion(x_vgg[i],y_vgg[i].detach(),resize_to_match(mask,x_vgg[i]))
            else:
                loss += self.criterion(resize_to_match(mask,x_vgg[i])*x_vgg[i],resize_to_match(mask,y_vgg[i])*y_vgg[i].detach())

        return loss


class GANLosses(object):
    """docstring for Loss"""
    def __init__(self, gantype):
        super(GANLosses, self).__init__()        
        self.generator_loss = gen_gan(gantype)
        self.discriminator_loss = dis_gan(gantype)
        self.gantype = gantype

    def g_loss(self,dis_fake):
        if 'hinge' in self.gantype:
            return gen_hinge(dis_fake)
        else:
            return self.generator_loss(dis_fake)

    def d_loss(self,dis_fake,dis_real):
        if 'hinge' in self.gantype:
            return dis_hinge(dis_fake,dis_real)
        else:
            return self.discriminator_loss(dis_fake,dis_real)


class gen_gan(nn.Module):
    def __init__(self,gantype):
        super(gen_gan,self).__init__()
        if gantype == 'lsgan':
            self.criterion = nn.MSELoss()
        elif gantype == 'naive':
            self.criterion = nn.BCEWithLogitsLoss()
        else:
            raise Exception("error gan type")
    
    def forward(self,dis_fake):
        return self.criterion(dis_fake, torch.ones_like(dis_fake))

class dis_gan(nn.Module):
    def __init__(self,gantype):
        super(dis_gan,self).__init__()
        if gantype == 'lsgan':
            self.criterion = nn.MSELoss()
        elif gantype == 'naive':
            self.criterion = nn.BCEWithLogitsLoss()
        else:
            raise Exception("error gan type")
    
    def forward(self,dis_fake,dis_real):
        loss_fake = self.criterion(dis_fake, torch.zeros_like(dis_fake))
        loss_real = self.criterion(dis_real, torch.ones_like(dis_real))
        return loss_fake, loss_real

# def gen_gan(dis_fake):
#     # fake -> 1
#     return F.binary_cross_entropy_with_logits(dis_fake,torch.ones_like(dis_fake))

# def dis_gan(dis_fake,dis_real):
#     # fake -> 0 , real ->1
#     loss_fake = F.binary_cross_entropy_with_logits(dis_fake, torch.zeros_like(dis_real))
#     loss_real = F.binary_cross_entropy_with_logits(dis_real, torch.ones_like(dis_fake))
#     return loss_fake,loss_real 

# def gen_lsgan(dis_fake):
#     loss = F.mse_loss(dis_fake,torch.ones_like(dis_fake)) # 
#     return loss

# def dis_lsgan(dis_fake, dis_real):
#     loss_fake = F.mse_loss(dis_fake, torch.zeros_like(dis_real))
#     loss_real = F.mse_loss(dis_real, torch.ones_like(dis_real))
#     return loss_fake,loss_real

def gen_hinge(dis_fake, dis_real=None):
    return -torch.mean(dis_fake)

def dis_hinge(dis_fake, dis_real):
    loss_fake = torch.mean(torch.relu(1. + dis_fake))
    loss_real = torch.mean(torch.relu(1. - dis_real))
    return loss_fake,loss_real



================================================
FILE: scripts/utils/misc.py
================================================
from __future__ import absolute_import

import os
import shutil
import torch 
import math
import numpy as np
import scipy.io
import matplotlib.pyplot as plt
import torch.nn.functional as F

def to_numpy(tensor):
    if torch.is_tensor(tensor):
        return tensor.cpu().numpy()
    elif type(tensor).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to numpy array"
                         .format(type(tensor)))
    return tensor

def resize_to_match(fm,to):
    # just use interpolate
    # [1,3] = (h,w)
    return F.interpolate(fm,to.size()[-2:],mode='bilinear',align_corners=False)

def to_torch(ndarray):
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray)
    elif not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor"
                         .format(type(ndarray)))
    return ndarray


def save_checkpoint(machine,filename='checkpoint.pth.tar', snapshot=None):
    is_best = True if machine.best_acc < machine.metric else False

    if is_best:
        machine.best_acc = machine.metric

    state = {
                'epoch': machine.current_epoch + 1,
                'arch': machine.args.arch,
                'state_dict': machine.model.state_dict(),
                'best_acc': machine.best_acc,
                'optimizer' : machine.optimizer.state_dict(),
            }

    filepath = os.path.join(machine.args.checkpoint, filename)
    torch.save(state, filepath)

    if snapshot and state['epoch'] % snapshot == 0:
        shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state.epoch)))
       
    if is_best:
        machine.best_acc = machine.metric
        print('Saving Best Metric with PSNR:%s'%machine.best_acc)
        shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'model_best.pth.tar'))
        


def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'):
    preds = to_numpy(preds)
    filepath = os.path.join(checkpoint, filename)
    scipy.io.savemat(filepath, mdict={'preds' : preds})


def adjust_learning_rate(datasets,optimizer, epoch, lr,args):
    """Sets the learning rate to the initial LR decayed by schedule"""
    if epoch in args.schedule:
        lr *= args.gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    
    # decay sigma
    for dset in datasets:
        if args.sigma_decay > 0:
            dset.dataset.sigma *=  args.sigma_decay
            dset.dataset.sigma *=  args.sigma_decay

    return lr






================================================
FILE: scripts/utils/model_init.py
================================================


from torch.nn import init


def weights_init_normal(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_xavier(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        init.xavier_normal(m.weight.data, gain=0.02)
    elif classname.find('Linear') != -1:
        init.xavier_normal(m.weight.data, gain=0.02)
    # elif classname.find('BatchNorm2d') != -1:
    #     init.normal(m.weight.data, 1.0, 0.02)
    #     init.constant(m.bias.data, 0.0)


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1 and m.weight.requires_grad == True:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1 and m.weight.requires_grad == True:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm2d') != -1 and m.weight.requires_grad == True:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.orthogonal(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.orthogonal(m.weight.data, gain=1)
    # elif classname.find('BatchNorm2d') != -1:
    #     init.normal(m.weight.data, 1.0, 0.02)
    #     init.constant(m.bias.data, 0.0)

================================================
FILE: scripts/utils/osutils.py
================================================
from __future__ import absolute_import

import os
import errno

def mkdir_p(dir_path):
    try:
        os.makedirs(dir_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

def isfile(fname):
    return os.path.isfile(fname) 

def isdir(dirname):
    return os.path.isdir(dirname)

def join(path, *paths):
    return os.path.join(path, *paths)


================================================
FILE: scripts/utils/parallel.py
================================================
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang, Rutgers University, Email: zhang.hang@rutgers.edu
## Modified by Thomas Wolf, HuggingFace Inc., Email: thomas@huggingface.co
## Copyright (c) 2017-2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Encoding Data Parallel"""
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel.scatter_gather import gather
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

torch_ver = torch.__version__[:3]

__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
           'patch_replication_callback']

def allreduce(*inputs):
    """Cross GPU all reduce autograd operation for calculate mean and
    variance in SyncBN.
    """
    return AllReduce.apply(*inputs)

class AllReduce(Function):
    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        ctx.num_inputs = num_inputs
        ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
        inputs = [inputs[i:i + num_inputs]
                 for i in range(0, len(inputs), num_inputs)]
        # sort before reduce sum
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *inputs):
        inputs = [i.data for i in inputs]
        inputs = [inputs[i:i + ctx.num_inputs]
                 for i in range(0, len(inputs), ctx.num_inputs)]
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])


class Reduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutput):
        return Broadcast.apply(ctx.target_gpus, gradOutput)

class DistributedDataParallelModel(DistributedDataParallel):
    """Implements data parallelism at the module level for the DistributedDataParallel module.
    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the
    batch dimension.
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backwards pass,
    gradients from each replica are summed into the original module.
    Note that the outputs are not gathered, please use compatible
    :class:`encoding.parallel.DataParallelCriterion`.
    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (so that each GPU processes the same number of samples).
    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)
    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
    Example::
        >>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
        return outputs

class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the
    batch dimension.
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backwards pass,
    gradients from each replica are summed into the original module.
    Note that the outputs are not gathered, please use compatible
    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (so that each GPU processes the same number of samples).

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
        return outputs

    def replicate(self, module, device_ids):
        modules = super(DataParallelModel, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


class DataParallelCriterion(DataParallel):
    """
    Calculate loss in multiple-GPUs, which balance the memory usage.
    The targets are splitted across the specified devices by chunking in
    the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """
    def forward(self, inputs, *targets, **kwargs):
        # input should be already scatterd
        # scattering the targets instead
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
        #return Reduce.apply(*outputs) / len(outputs)
        #return self.gather(outputs, self.output_device).mean()
        return self.gather(outputs, self.output_device)


def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    if torch_ver != "0.3":
        grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        if torch_ver != "0.3":
            torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                if not isinstance(target, (list, tuple)):
                    target = (target,)
                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target,
                                          kwargs, device),)
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs


###########################################################################
# Adapted from Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
#
class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute an replication callback `__data_parallel_replicate__` on each module created
    by original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all modules are isomorphism, we assign each sub-module with a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead
    of calling the callback of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate

================================================
FILE: scripts/utils/transforms.py
================================================
from __future__ import absolute_import

import os
import numpy as np
import scipy.misc
import matplotlib.pyplot as plt
import torch
import torchvision

from .misc import *
from .imutils import *


def color_normalize(x, mean, std):
    if x.size(0) == 1:
        x = x.repeat(3, x.size(1), x.size(2))

    for t, m, s in zip(x, mean, std):
        t.sub_(m)
    return x


def flip_back(flip_output, dataset='mpii'):
    """
    flip output map
    """
    if dataset ==  'mpii':
        matchedParts = (
            [0,5],   [1,4],   [2,3],
            [10,15], [11,14], [12,13]
        )
    else:
        print('Not supported dataset: ' + dataset)

    # flip output horizontally
    flip_output = fliplr(flip_output.numpy())

    # Change left-right parts
    for pair in matchedParts:
        tmp = np.copy(flip_output[:, pair[0], :, :])
        flip_output[:, pair[0], :, :] = flip_output[:, pair[1], :, :]
        flip_output[:, pair[1], :, :] = tmp

    return torch.from_numpy(flip_output).float()


def shufflelr(x, width, dataset='mpii'):
    """
    flip coords
    """
    if dataset ==  'mpii':
        matchedParts = (
            [0,5],   [1,4],   [2,3],
            [10,15], [11,14], [12,13]
        )
    else:
        print('Not supported dataset: ' + dataset)

    # Flip horizontal
    x[:, 0] = width - x[:, 0]

    # Change left-right parts
    for pair in matchedParts:
        tmp = x[pair[0], :].clone()
        x[pair[0], :] = x[pair[1], :]
        x[pair[1], :] = tmp

    return x


def fliplr(x):
    if x.ndim == 3:
        x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1))
    elif x.ndim == 4:
        for i in range(x.shape[0]):
            x[i] = np.transpose(np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1))
    return x.astype(float)


def get_transform(center, scale, res, rot=0):
    """
    General image processing functions
    """
    # Generate transformation matrix
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + .5)
    t[1, 2] = res[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1
    if not rot == 0:
        rot = -rot # To match direction of rotation from cropping
        rot_mat = np.zeros((3,3))
        rot_rad = rot * np.pi / 180
        sn,cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0,:2] = [cs, -sn]
        rot_mat[1,:2] = [sn, cs]
        rot_mat[2,2] = 1
        # Need to rotate around center
        t_mat = np.eye(3)
        t_mat[0,2] = -res[1]/2
        t_mat[1,2] = -res[0]/2
        t_inv = t_mat.copy()
        t_inv[:2,2] *= -1
        t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
    return t


def transform(pt, center, scale, res, invert=0, rot=0):
    # Transform pixel location to different reference
    t = get_transform(center, scale, res, rot=rot)
    if invert:
        t = np.linalg.inv(t)
    new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2].astype(int) + 1


def transform_preds(coords, center, scale, res):
    # size = coords.size()
    # coords = coords.view(-1, coords.size(-1))
    # print(coords.size())
    for p in range(coords.size(0)):
        coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0))
    return coords


def crop(img, center, scale, res, rot=0):
    img = im_to_numpy(img)

    # Upper left point
    ul = np.array(transform([0, 0], center, scale, res, invert=1))
    # Bottom right point
    br = np.array(transform(res, center, scale, res, invert=1))

    # Padding so that when rotated proper amount of context is included
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
    if not rot == 0:
        ul -= pad
        br += pad

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)

    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Range to sample from original image
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])
    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]

    if not rot == 0:
        # Remove padding
        new_img = scipy.misc.imrotate(new_img, rot)
        new_img = new_img[pad:-pad, pad:-pad]

    new_img = im_to_torch(scipy.misc.imresize(new_img, res))
    return new_img


def get_right(img,gray=False):
    img = im_to_numpy(img) #H*W*C

    new_img = img[:,0:256,:]

   
    new_img = im_to_torch(new_img)
    if gray == True:
        new_img = new_img[1,:,:];

    return new_img

class NormalizeInverse(torchvision.transforms.Normalize):
    """
    Undoes the normalization and returns the reconstructed images in the input domain.
    """

    def __init__(self, mean, std):
        mean = torch.as_tensor(mean)
        std = torch.as_tensor(std)
        std_inv = 1 / (std + 1e-7)
        mean_inv = -mean * std_inv
        super().__init__(mean=mean_inv, std=std_inv)

    def __call__(self, tensor):
        return super().__call__(tensor.clone())


================================================
FILE: test.py
================================================
from __future__ import print_function, absolute_import

import argparse
import torch

torch.backends.cudnn.benchmark = True

from scripts.utils.misc import save_checkpoint, adjust_learning_rate

import scripts.datasets as datasets
import scripts.machines as machines
from options import Options

def main(args):
    
    val_loader = torch.utils.data.DataLoader(datasets.COCO('val',args),batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    data_loaders = (None,val_loader)

    Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)

    Machine.test()

if __name__ == '__main__':
    parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal'))
    main(parser.parse_args())


================================================
FILE: watermark_synthesis.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SAVE ALL THE SETTING\n"
     ]
    }
   ],
   "source": [
    "# watermark synthesis\n",
    "import os \n",
    "import random\n",
    "import shutil\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "\n",
    "def trans_paste(bg_img,fg_img,mask,box=(0,0)):\n",
    "    fg_img_trans = Image.new(\"RGBA\",bg_img.size)\n",
    "    fg_img_trans.paste(fg_img,box,mask=mask)\n",
    "    new_img = Image.alpha_composite(bg_img,fg_img_trans)\n",
    "    return new_img,fg_img_trans\n",
    "\n",
    "if os.path.isdir('dataset'):\n",
    "    shutil.rmtree('dataset')\n",
    "\n",
    "os.mkdir('dataset')\n",
    "BASE_IMG_DIR = '/Users/oishii/Downloads/val2014/'\n",
    "WATERMARK_DIR = 'logos' #1080 \n",
    "images = sorted([os.path.join(BASE_IMG_DIR,x) for x in os.listdir(BASE_IMG_DIR) if '.jpg' in x])\n",
    "watermarks = sorted([os.path.join(WATERMARK_DIR,x).replace(' ','_') for x in os.listdir(WATERMARK_DIR) if '.png' in x])\n",
    "# rename all the watermark from replace ' ' to '_'\n",
    "\n",
    "random.shuffle(images)\n",
    "random.shuffle(watermarks)\n",
    "\n",
    "train_images = images[:int(len(images)*0.7)]\n",
    "val_images = images[int(len(images)*0.7):int(len(images)*0.8)]\n",
    "test_images = images[int(len(images)*0.8):]\n",
    "\n",
    "train_wms = watermarks[:int(len(watermarks)*0.7)]\n",
    "val_wms = watermarks[int(len(watermarks)*0.7):int(len(watermarks)*0.8)]\n",
    "test_wms = watermarks[int(len(watermarks)*0.8):]\n",
    "\n",
    "# save all the settings to file\n",
    "names = ['train_images','val_images','test_images','train_wms','val_wms','test_wms']\n",
    "lists = [train_images,val_images,test_images,train_wms,val_wms,test_wms]\n",
    "dataset = dict(zip(names, lists))\n",
    "\n",
    "for name,content in dataset.items():\n",
    "    with open('dataset/%s.txt'%name,'w') as f:\n",
    "        f.write(\"\\n\".join(content))\n",
    "\n",
    "print('SAVE ALL THE SETTING')\n",
    "\n",
    "for name, images in dataset.items():\n",
    "    if 'images' not in name:\n",
    "        continue\n",
    "    # for each setting, synthesis the watermark\n",
    "    # for each image, add X(X=6) watermark in differnet position, alpha,\n",
    "    # save the synthesized image, watermark mask, reshaped mask,\n",
    "    save_path = 'dataset/%s/'%name\n",
    "    os.makedirs('%s/image'%(save_path))\n",
    "    os.makedirs('%s/mask'%(save_path))\n",
    "    os.makedirs('%s/wm'%(save_path))\n",
    "    \n",
    "    for img in images:\n",
    "        im = Image.open(img).convert('RGBA')\n",
    "        imw,imh = im.size\n",
    "        \n",
    "        for wmg in random.choices(dataset[name.replace('images','wms')],k=6):\n",
    "            wm = Image.open(wmg.replace('_',' ')).convert(\"RGBA\") # RGBA\n",
    "            # get the mask of wm\n",
    "            # data agumentation of wm\n",
    "            wm = wm.rotate(angle=random.randint(0,360),expand=True) # rotate\n",
    "            \n",
    "            # make sure the \n",
    "            imrw = random.randrange(int(0.4*imw),int(0.8*imw))\n",
    "            imrh = random.randrange(int(0.4*imh),int(0.8*imh))\n",
    "            wmsize = imrh if imrw > imrh else imrw\n",
    "            wm = wm.resize((wmsize,wmsize),Image.BILINEAR)\n",
    "            w,h = wm.size # new size \n",
    "            \n",
    "            box_left = random.randint(0,imw-w)\n",
    "            box_upper = random.randint(0,imh-h)\n",
    "            wmm = wm.copy()\n",
    "            wm.putalpha(random.randint(int(255*0.4),int(255*0.8))) # alpha\n",
    "            \n",
    "            ims,wmc = trans_paste(im,wm,wmm,(box_left,box_upper))\n",
    "            \n",
    "            wmnp = np.array(wmc) # h,w,3\n",
    "            mask = np.sum(wmnp,axis=2)>0\n",
    "            mm = Image.fromarray(np.uint8(mask*255),mode='L')\n",
    "            \n",
    "            identifier = os.path.basename(img).split('.')[0] +'-'+os.path.basename(wmg).split('.')[0] + '.png'\n",
    "            # save \n",
    "            wmc.save('%s/wm/%s'%(save_path,identifier))\n",
    "            ims.save('%s/image/%s'%(save_path,identifier))\n",
    "            mm.save('%s/mask/%s'%(save_path,identifier))\n",
    "            \n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
Download .txt
gitextract_p5y2m_4m/

├── README.md
├── examples/
│   ├── evaluate.sh
│   └── test.sh
├── main.py
├── options.py
├── requirements.txt
├── scripts/
│   ├── __init__.py
│   ├── datasets/
│   │   ├── BIH.py
│   │   ├── COCO.py
│   │   └── __init__.py
│   ├── machines/
│   │   ├── BasicMachine.py
│   │   ├── S2AM.py
│   │   ├── VX.py
│   │   └── __init__.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── backbone_unet.py
│   │   ├── blocks.py
│   │   ├── discriminator.py
│   │   ├── rasc.py
│   │   ├── sa_resunet.py
│   │   ├── unet.py
│   │   ├── vgg.py
│   │   └── vmu.py
│   └── utils/
│       ├── __init__.py
│       ├── evaluation.py
│       ├── imutils.py
│       ├── logger.py
│       ├── losses.py
│       ├── misc.py
│       ├── model_init.py
│       ├── osutils.py
│       ├── parallel.py
│       └── transforms.py
├── test.py
└── watermark_synthesis.ipynb
Download .txt
SYMBOL INDEX (353 symbols across 26 files)

FILE: main.py
  function main (line 14) | def main(args):

FILE: options.py
  class Options (line 8) | class Options():
    method __init__ (line 10) | def __init__(self):
    method init (line 13) | def init(self, parser):

FILE: scripts/datasets/BIH.py
  class BIH (line 27) | class BIH(data.Dataset):
    method __init__ (line 28) | def __init__(self,train,config=None, sample=[],gan_norm=False):
    method __getitem__ (line 81) | def __getitem__(self, index):
    method __len__ (line 98) | def __len__(self):

FILE: scripts/datasets/COCO.py
  class COCO (line 27) | class COCO(data.Dataset):
    method __init__ (line 28) | def __init__(self,train,config=None, sample=[],gan_norm=False):
    method __getitem__ (line 80) | def __getitem__(self, index):
    method __len__ (line 97) | def __len__(self):

FILE: scripts/machines/BasicMachine.py
  class BasicMachine (line 25) | class BasicMachine(object):
    method __init__ (line 26) | def __init__(self, datasets =(None,None), models = None, args = None, ...
    method train (line 82) | def train(self,epoch):
    method test (line 155) | def test(self, ):
    method validate (line 196) | def validate(self, epoch):
    method resume (line 240) | def resume(self,resume_path):
    method save_checkpoint (line 259) | def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):
    method clean (line 284) | def clean(self):
    method record (line 287) | def record(self,k,v,epoch):
    method flush (line 290) | def flush(self):
    method norm (line 294) | def norm(self,x):
    method denorm (line 300) | def denorm(self,x):

FILE: scripts/machines/S2AM.py
  class S2AM (line 25) | class S2AM(object):
    method __init__ (line 26) | def __init__(self, datasets =(None,None), models = None, args = None, ...
    method train (line 82) | def train(self,epoch):
    method test (line 156) | def test(self, ):
    method validate (line 195) | def validate(self, epoch):
    method resume (line 240) | def resume(self,resume_path):
    method save_checkpoint (line 259) | def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):
    method clean (line 284) | def clean(self):
    method record (line 287) | def record(self,k,v,epoch):
    method flush (line 290) | def flush(self):
    method norm (line 294) | def norm(self,x):
    method denorm (line 300) | def denorm(self,x):

FILE: scripts/machines/VX.py
  class Losses (line 23) | class Losses(nn.Module):
    method __init__ (line 24) | def __init__(self, argx, device, norm_func=None, denorm_func=None):
    method forward (line 49) | def forward(self,pred_ims,target,pred_ms,mask,pred_wms,wm):
  class VX (line 87) | class VX(BasicMachine):
    method __init__ (line 88) | def __init__(self,**kwargs):
    method train (line 94) | def train(self,epoch):
    method validate (line 182) | def validate(self, epoch):
    method test (line 254) | def test(self, ):

FILE: scripts/machines/__init__.py
  function basic (line 6) | def basic(**kwargs):
  function s2am (line 9) | def s2am(**kwargs):
  function vx (line 12) | def vx(**kwargs):

FILE: scripts/models/backbone_unet.py
  function vvv4n (line 19) | def vvv4n(**kwargs):
  function vm3 (line 24) | def vm3(**kwargs):
  function urasc (line 29) | def urasc(**kwargs):
  function rascv2 (line 39) | def rascv2(**kwargs):
  function unet (line 45) | def unet(**kwargs):

FILE: scripts/models/blocks.py
  class BasicLearningBlock (line 15) | class BasicLearningBlock(nn.Module):
    method __init__ (line 17) | def __init__(self,channel):
    method forward (line 24) | def forward(self,feature):
  class GaussianSmoothing (line 30) | class GaussianSmoothing(nn.Module):
    method __init__ (line 43) | def __init__(self, channels, kernel_size, sigma, dim=2):
    method forward (line 85) | def forward(self, input):
  class ChannelPool (line 95) | class ChannelPool(nn.Module):
    method __init__ (line 96) | def __init__(self,types):
    method forward (line 105) | def forward(self, input):
  class SEBlock (line 114) | class SEBlock(nn.Module):
    method __init__ (line 116) | def __init__(self, channel,reducation=16):
    method forward (line 125) | def forward(self,x):
  class GlobalAttentionModule (line 133) | class GlobalAttentionModule(nn.Module):
    method __init__ (line 135) | def __init__(self, channel,reducation=16):
    method forward (line 145) | def forward(self,x):
  class SpatialAttentionModule (line 152) | class SpatialAttentionModule(nn.Module):
    method __init__ (line 154) | def __init__(self, channel,reducation=16):
    method forward (line 164) | def forward(self,x):
  class GlobalAttentionModuleJustSigmoid (line 174) | class GlobalAttentionModuleJustSigmoid(nn.Module):
    method __init__ (line 176) | def __init__(self, channel,reducation=16):
    method forward (line 186) | def forward(self,x):
  class BasicBlock (line 195) | class BasicBlock(nn.Module):
    method __init__ (line 196) | def __init__(self, in_planes, out_planes, kernel_size, stride=1, paddi...
    method forward (line 203) | def forward(self, x):
  class Flatten (line 211) | class Flatten(nn.Module):
    method forward (line 212) | def forward(self, x):
  class ChannelGate (line 215) | class ChannelGate(nn.Module):
    method __init__ (line 216) | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg...
    method forward (line 226) | def forward(self, x):
  function logsumexp_2d (line 251) | def logsumexp_2d(tensor):
  class ChannelPoolX (line 257) | class ChannelPoolX(nn.Module):
    method forward (line 258) | def forward(self, x):
  class SpatialGate (line 261) | class SpatialGate(nn.Module):
    method __init__ (line 262) | def __init__(self):
    method forward (line 267) | def forward(self, x):
  class CBAM (line 273) | class CBAM(nn.Module):
    method __init__ (line 274) | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg...
    method forward (line 280) | def forward(self, x):

FILE: scripts/models/discriminator.py
  class SNCoXvWithActivation (line 17) | class SNCoXvWithActivation(torch.nn.Module):
    method __init__ (line 21) | def __init__(self, in_channels, out_channels, kernel_size, stride=1, p...
    method forward (line 29) | def forward(self, input):
  function l2normalize (line 36) | def l2normalize(v, eps=1e-12):
  class SpectralNorm (line 40) | class SpectralNorm(nn.Module):
    method __init__ (line 41) | def __init__(self, module, name='weight', power_iterations=1):
    method _update_u_v (line 49) | def _update_u_v(self):
    method _made_params (line 63) | def _made_params(self):
    method _make_params (line 73) | def _make_params(self):
    method forward (line 92) | def forward(self, *args):
  function get_pad (line 97) | def get_pad(in_,  ksize, stride, atrous=1):
  class SNDiscriminator (line 101) | class SNDiscriminator(nn.Module):
    method __init__ (line 102) | def __init__(self,channel=6):
    method forward (line 116) | def forward(self, img_A, img_B):
  class Discriminator (line 124) | class Discriminator(nn.Module):
    method __init__ (line 125) | def __init__(self, in_channels=3):
    method forward (line 145) | def forward(self, img_A, img_B):
  function patchgan (line 151) | def patchgan():
  function sngan (line 156) | def sngan():
  function maskedsngan (line 161) | def maskedsngan():

FILE: scripts/models/rasc.py
  class CAWapper (line 15) | class CAWapper(nn.Module):
    method __init__ (line 18) | def __init__(self, channel, type_of_connection=BasicLearningBlock):
    method forward (line 22) | def forward(self, feature, mask):
  class NLWapper (line 34) | class NLWapper(nn.Module):
    method __init__ (line 37) | def __init__(self, channel, type_of_connection=BasicLearningBlock):
    method forward (line 41) | def forward(self, feature, mask):
  class SENet (line 52) | class SENet(nn.Module):
    method __init__ (line 54) | def __init__(self,channel,type_of_connection=BasicLearningBlock):
    method forward (line 58) | def forward(self,feature,mask):
  class CBAMConnect (line 69) | class CBAMConnect(nn.Module):
    method __init__ (line 70) | def __init__(self,channel):
    method forward (line 74) | def forward(self,feature,mask):
  class RASC (line 80) | class RASC(nn.Module):
    method __init__ (line 81) | def __init__(self,channel,type_of_connection=BasicLearningBlock):
    method forward (line 89) | def forward(self,feature,mask):
  class UNO (line 110) | class UNO(nn.Module):
    method __init__ (line 111) | def __init__(self,channel):
    method forward (line 114) | def forward(self,feature,_m):
  class URASC (line 118) | class URASC(nn.Module):
    method __init__ (line 119) | def __init__(self,channel,type_of_connection=BasicLearningBlock):
    method forward (line 127) | def forward(self,feature, m=None):
  class MaskedURASC (line 139) | class MaskedURASC(nn.Module):
    method __init__ (line 140) | def __init__(self,channel,type_of_connection=BasicLearningBlock):
    method forward (line 148) | def forward(self,feature):

FILE: scripts/models/sa_resunet.py
  function weight_init (line 9) | def weight_init(m):
  function reset_params (line 14) | def reset_params(model):
  function conv3x3 (line 19) | def conv3x3(in_channels, out_channels, stride=1,
  function up_conv2x2 (line 31) | def up_conv2x2(in_channels, out_channels, transpose=True):
  function conv1x1 (line 44) | def conv1x1(in_channels, out_channels, groups=1):
  class UpCoXvD (line 53) | class UpCoXvD(nn.Module):
    method __init__ (line 55) | def __init__(self, in_channels, out_channels, blocks, residual=True,no...
    method forward (line 88) | def forward(self, from_up, from_down, mask=None,se=None):
  class DownCoXvD (line 118) | class DownCoXvD(nn.Module):
    method __init__ (line 120) | def __init__(self, in_channels, out_channels, blocks, pooling=True, no...
    method __call__ (line 143) | def __call__(self, x):
    method forward (line 146) | def forward(self, x):
  class UnetDecoderD (line 162) | class UnetDecoderD(nn.Module):
    method __init__ (line 163) | def __init__(self, in_channels=512, out_channels=3, norm=nn.BatchNorm2...
    method __call__ (line 200) | def __call__(self, x, encoder_outs=None):
    method forward (line 203) | def forward(self, x, encoder_outs=None):
  class UnetDecoderDatt (line 217) | class UnetDecoderDatt(nn.Module):
    method __init__ (line 218) | def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1,...
    method forward (line 258) | def forward(self, input, encoder_outs=None):
  class UnetEncoderD (line 286) | class UnetEncoderD(nn.Module):
    method __init__ (line 288) | def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32,...
    method __call__ (line 303) | def __call__(self, x):
    method forward (line 306) | def forward(self, x):
  class ResDown (line 313) | class ResDown(nn.Module):
    method __init__ (line 314) | def __init__(self, in_size, out_size, pooling=True, use_att=False):
    method forward (line 318) | def forward(self, x):
  class ResUp (line 321) | class ResUp(nn.Module):
    method __init__ (line 322) | def __init__(self, in_size, out_size, use_att=False):
    method forward (line 326) | def forward(self, x, skip_input, mask=None):
  class ResDownNew (line 329) | class ResDownNew(nn.Module):
    method __init__ (line 330) | def __init__(self, in_size, out_size, pooling=True, use_att=False):
    method forward (line 334) | def forward(self, x):
  class ResUpNew (line 337) | class ResUpNew(nn.Module):
    method __init__ (line 338) | def __init__(self, in_size, out_size, use_att=False):
    method forward (line 342) | def forward(self, x, skip_input, mask=None):
  class VMSingle (line 347) | class VMSingle(nn.Module):
    method __init__ (line 348) | def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=Res...
    method forward (line 366) | def forward(self, input):
  class VMSingleS2AM (line 385) | class VMSingleS2AM(nn.Module):
    method __init__ (line 386) | def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=Res...
    method forward (line 408) | def forward(self, input):
  class UnetVMS2AMv4 (line 430) | class UnetVMS2AMv4(nn.Module):
    method __init__ (line 432) | def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_deco...
    method set_optimizers (line 480) | def set_optimizers(self):
    method zero_grad_all (line 491) | def zero_grad_all(self):
    method step_all (line 501) | def step_all(self):
    method step_optimizer_image (line 511) | def step_optimizer_image(self):
    method __call__ (line 514) | def __call__(self, synthesized):
    method forward (line 517) | def forward(self, synthesized):
    method unshared_forward (line 520) | def unshared_forward(self, synthesized):
    method shared_forward (line 531) | def shared_forward(self, synthesized):

FILE: scripts/models/unet.py
  class MinimalUnetV2 (line 9) | class MinimalUnetV2(nn.Module):
    method __init__ (line 11) | def __init__(self, down=None,up=None,submodule=None,attention=None,wit...
    method forward (line 22) | def forward(self,x,mask=None):
  class MinimalUnet (line 39) | class MinimalUnet(nn.Module):
    method __init__ (line 41) | def __init__(self, down=None,up=None,submodule=None,attention=None,wit...
    method forward (line 52) | def forward(self,x,mask=None):
  class UnetSkipConnectionBlock (line 72) | class UnetSkipConnectionBlock(nn.Module):
    method __init__ (line 73) | def __init__(self, outer_nc, inner_nc, input_nc=None,
    method forward (line 129) | def forward(self, x,mask=None):
  class UnetGenerator (line 133) | class UnetGenerator(nn.Module):
    method __init__ (line 134) | def __init__(self, input_nc, output_nc, num_downs=8, ngf=64,norm_layer...
    method forward (line 153) | def forward(self, input):

FILE: scripts/models/vgg.py
  class Vgg16 (line 7) | class Vgg16(torch.nn.Module):
    method __init__ (line 8) | def __init__(self, requires_grad=False):
    method forward (line 31) | def forward(self, X):
  class Vgg19 (line 47) | class Vgg19(torch.nn.Module):
    method __init__ (line 48) | def __init__(self, requires_grad=False):
    method forward (line 71) | def forward(self, X, indices=None):

FILE: scripts/models/vmu.py
  function weight_init (line 9) | def weight_init(m):
  function reset_params (line 14) | def reset_params(model):
  function conv3x3 (line 19) | def conv3x3(in_channels, out_channels, stride=1,
  function up_conv2x2 (line 31) | def up_conv2x2(in_channels, out_channels, transpose=True):
  function conv1x1 (line 44) | def conv1x1(in_channels, out_channels, groups=1):
  class UpCoXvD (line 55) | class UpCoXvD(nn.Module):
    method __init__ (line 57) | def __init__(self, in_channels, out_channels, blocks, residual=True, b...
    method forward (line 85) | def forward(self, from_up, from_down, mask=None):
  class DownCoXvD (line 111) | class DownCoXvD(nn.Module):
    method __init__ (line 113) | def __init__(self, in_channels, out_channels, blocks, pooling=True, re...
    method __call__ (line 133) | def __call__(self, x):
    method forward (line 136) | def forward(self, x):
  class UnetDecoderD (line 152) | class UnetDecoderD(nn.Module):
    method __init__ (line 153) | def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1,...
    method __call__ (line 178) | def __call__(self, x, encoder_outs=None):
    method forward (line 181) | def forward(self, x, encoder_outs=None):
  class UnetEncoderD (line 192) | class UnetEncoderD(nn.Module):
    method __init__ (line 194) | def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32,...
    method __call__ (line 209) | def __call__(self, x):
    method forward (line 212) | def forward(self, x):
  class UnetVM (line 221) | class UnetVM(nn.Module):
    method __init__ (line 223) | def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_deco...
    method set_optimizers (line 262) | def set_optimizers(self):
    method zero_grad_all (line 271) | def zero_grad_all(self):
    method step_all (line 280) | def step_all(self):
    method step_optimizer_image (line 289) | def step_optimizer_image(self):
    method __call__ (line 292) | def __call__(self, synthesized):
    method forward (line 295) | def forward(self, synthesized):
    method unshared_forward (line 298) | def unshared_forward(self, synthesized):
    method shared_forward (line 309) | def shared_forward(self, synthesized):

FILE: scripts/utils/evaluation.py
  function get_preds (line 13) | def get_preds(scores):
  function calc_dists (line 32) | def calc_dists(preds, target, normalize):
  function dist_acc (line 44) | def dist_acc(dists, thr=0.5):
  function accuracy (line 53) | def accuracy(output, target, thr=0.5):
  function final_preds (line 77) | def final_preds(output, center, scale, res):
  class AverageMeter (line 102) | class AverageMeter(object):
    method __init__ (line 104) | def __init__(self):
    method reset (line 107) | def reset(self):
    method update (line 113) | def update(self, val, n=1):

FILE: scripts/utils/imutils.py
  function im_to_numpy (line 10) | def im_to_numpy(img):
  function im_to_torch (line 15) | def im_to_torch(img):
  function load_image (line 22) | def load_image(img_path):
  function imread_all (line 26) | def imread_all(img_path):
  function load_image_gray (line 29) | def load_image_gray(img_path):
  function resize (line 35) | def resize(img, owidth, oheight):
  function gaussian (line 54) | def gaussian(shape=(7,7),sigma=1):
  function draw_labelmap (line 65) | def draw_labelmap(img, pt, sigma, type='Gaussian'):
  function gauss (line 104) | def gauss(x, a, b, c, d=0):
  function color_heatmap (line 107) | def color_heatmap(x):
  function imshow (line 117) | def imshow(img):
  function show_joints (line 122) | def show_joints(img, pts):
  function show_sample (line 130) | def show_sample(inputs, target):
  function sample_with_heatmap (line 146) | def sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None):
  function batch_with_heatmap (line 181) | def batch_with_heatmap(inputs, outputs, mean=torch.Tensor([0.5, 0.5, 0.5...
  function normalize_batch (line 191) | def normalize_batch(batch):
  function show_image_tensor (line 198) | def show_image_tensor(tensor):
  function get_jet (line 211) | def get_jet():
  function clamp (line 221) | def clamp(num, min_value, max_value):
  function gray2color (line 224) | def gray2color(gray_array, color_map):
  class objectview (line 236) | class objectview(object):
    method __init__ (line 237) | def __init__(self, *args, **kwargs):

FILE: scripts/utils/logger.py
  function savefig (line 12) | def savefig(fname, dpi=None):
  function plot_overlap (line 16) | def plot_overlap(logger, names=None):
  class Logger (line 24) | class Logger(object):
    method __init__ (line 26) | def __init__(self, fpath, title=None, resume=False):
    method set_names (line 48) | def set_names(self, names):
    method append (line 62) | def append(self, numbers):
    method plot (line 71) | def plot(self, names=None):
    method close (line 80) | def close(self):
  class LoggerMonitor (line 84) | class LoggerMonitor(object):
    method __init__ (line 86) | def __init__ (self, paths):
    method plot (line 93) | def plot(self, names=None):

FILE: scripts/utils/losses.py
  class WeightedBCE (line 10) | class WeightedBCE(nn.Module):
    method __init__ (line 11) | def __init__(self):
    method forward (line 14) | def forward(self, pred, gt):
  function l1_relative (line 28) | def l1_relative(reconstructed, real, mask):
  function is_dic (line 40) | def is_dic(x):
  class Losses (line 43) | class Losses(nn.Module):
    method __init__ (line 44) | def __init__(self, argx, device):
    method forward (line 70) | def forward(self,imgx,target,attx,mask,wmx,wm):
  function gram_matrix (line 116) | def gram_matrix(feat):
  class MeanShift (line 124) | class MeanShift(nn.Conv2d):
    method __init__ (line 125) | def __init__(self, data_mean, data_std, data_range=1, norm=True):
  function VGGLoss (line 142) | def VGGLoss(losstype):
  class VGGLossA (line 156) | class VGGLossA(nn.Module):
    method __init__ (line 157) | def __init__(self, vgg=None, weights=None, indices=None, normalize=True):
    method forward (line 171) | def forward(self, x, y):
  class VGG16FeatureExtractor (line 182) | class VGG16FeatureExtractor(nn.Module):
    method __init__ (line 183) | def __init__(self):
    method forward (line 195) | def forward(self, image):
  class VGGLossX (line 202) | class VGGLossX(nn.Module):
    method __init__ (line 203) | def __init__(self, normalize=True, mask=False, relative=False):
    method forward (line 216) | def forward(self, x, y, Xmask=None):
  class GANLosses (line 239) | class GANLosses(object):
    method __init__ (line 241) | def __init__(self, gantype):
    method g_loss (line 247) | def g_loss(self,dis_fake):
    method d_loss (line 253) | def d_loss(self,dis_fake,dis_real):
  class gen_gan (line 260) | class gen_gan(nn.Module):
    method __init__ (line 261) | def __init__(self,gantype):
    method forward (line 270) | def forward(self,dis_fake):
  class dis_gan (line 273) | class dis_gan(nn.Module):
    method __init__ (line 274) | def __init__(self,gantype):
    method forward (line 283) | def forward(self,dis_fake,dis_real):
  function gen_hinge (line 307) | def gen_hinge(dis_fake, dis_real=None):
  function dis_hinge (line 310) | def dis_hinge(dis_fake, dis_real):

FILE: scripts/utils/misc.py
  function to_numpy (line 12) | def to_numpy(tensor):
  function resize_to_match (line 20) | def resize_to_match(fm,to):
  function to_torch (line 25) | def to_torch(ndarray):
  function save_checkpoint (line 34) | def save_checkpoint(machine,filename='checkpoint.pth.tar', snapshot=None):
  function save_pred (line 61) | def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'):
  function adjust_learning_rate (line 67) | def adjust_learning_rate(datasets,optimizer, epoch, lr,args):

FILE: scripts/utils/model_init.py
  function weights_init_normal (line 6) | def weights_init_normal(m):
  function weights_init_xavier (line 18) | def weights_init_xavier(m):
  function weights_init_kaiming (line 30) | def weights_init_kaiming(m):
  function weights_init_orthogonal (line 42) | def weights_init_orthogonal(m):

FILE: scripts/utils/osutils.py
  function mkdir_p (line 6) | def mkdir_p(dir_path):
  function isfile (line 13) | def isfile(fname):
  function isdir (line 16) | def isdir(dirname):
  function join (line 19) | def join(path, *paths):

FILE: scripts/utils/parallel.py
  function allreduce (line 27) | def allreduce(*inputs):
  class AllReduce (line 33) | class AllReduce(Function):
    method forward (line 35) | def forward(ctx, num_inputs, *inputs):
    method backward (line 47) | def backward(ctx, *inputs):
  class Reduce (line 56) | class Reduce(Function):
    method forward (line 58) | def forward(ctx, *inputs):
    method backward (line 64) | def backward(ctx, gradOutput):
  class DistributedDataParallelModel (line 67) | class DistributedDataParallelModel(DistributedDataParallel):
    method gather (line 91) | def gather(self, outputs, output_device):
  class DataParallelModel (line 94) | class DataParallelModel(DataParallel):
    method gather (line 124) | def gather(self, outputs, output_device):
    method replicate (line 127) | def replicate(self, module, device_ids):
  class DataParallelCriterion (line 133) | class DataParallelCriterion(DataParallel):
    method forward (line 151) | def forward(self, inputs, *targets, **kwargs):
  function _criterion_parallel_apply (line 166) | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None,...
  class CallbackContext (line 229) | class CallbackContext(object):
  function execute_replication_callbacks (line 233) | def execute_replication_callbacks(modules):
  function patch_replication_callback (line 257) | def patch_replication_callback(data_parallel):

FILE: scripts/utils/transforms.py
  function color_normalize (line 14) | def color_normalize(x, mean, std):
  function flip_back (line 23) | def flip_back(flip_output, dataset='mpii'):
  function shufflelr (line 47) | def shufflelr(x, width, dataset='mpii'):
  function fliplr (line 71) | def fliplr(x):
  function get_transform (line 80) | def get_transform(center, scale, res, rot=0):
  function transform (line 110) | def transform(pt, center, scale, res, invert=0, rot=0):
  function transform_preds (line 120) | def transform_preds(coords, center, scale, res):
  function crop (line 129) | def crop(img, center, scale, res, rot=0):
  function get_right (line 165) | def get_right(img,gray=False):
  class NormalizeInverse (line 177) | class NormalizeInverse(torchvision.transforms.Normalize):
    method __init__ (line 182) | def __init__(self, mean, std):
    method __call__ (line 189) | def __call__(self, tensor):

FILE: test.py
  function main (line 14) | def main(args):
Condensed preview — 35 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (188K chars).
[
  {
    "path": "README.md",
    "chars": 3088,
    "preview": "This repo contains the code and results of the AAAI 2021 paper:\n\n<i><b> [Split then Refine: Stacked Attention-guided Res"
  },
  {
    "path": "examples/evaluate.sh",
    "chars": 1107,
    "preview": "set -ex\n\n\n\n# example training scripts for AAAI-21\n# Split then Refine: Stacked Attention-guided ResUNets for Blind Singl"
  },
  {
    "path": "examples/test.sh",
    "chars": 350,
    "preview": "\nset -ex\n\nCUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/test.py \\\n  -c test/10kgray_ssim\\\n  --resume /data/home/"
  },
  {
    "path": "main.py",
    "chars": 3470,
    "preview": "from __future__ import print_function, absolute_import\n\nimport argparse\nimport torch,time,os\n\ntorch.backends.cudnn.bench"
  },
  {
    "path": "options.py",
    "chars": 6050,
    "preview": "\nimport scripts.models as models\n\nmodel_names = sorted(name for name in models.__dict__\n    if name.islower() and not na"
  },
  {
    "path": "requirements.txt",
    "chars": 150,
    "preview": "numpy==1.19.1\nopencv-python==3.4.8.29\nPillow\nscikit-image==0.14.5\nscikit-learn==0.23.1\nscipy==1.2.1\nsklearn==0.0\ntensorb"
  },
  {
    "path": "scripts/__init__.py",
    "chars": 255,
    "preview": "from __future__ import absolute_import\n\nfrom . import datasets\nfrom . import models\nfrom . import utils\n\n# import os, sy"
  },
  {
    "path": "scripts/datasets/BIH.py",
    "chars": 3357,
    "preview": "from __future__ import print_function, absolute_import\n\nimport os\nimport csv\nimport numpy as np\nimport json\nimport rando"
  },
  {
    "path": "scripts/datasets/COCO.py",
    "chars": 3353,
    "preview": "from __future__ import print_function, absolute_import\n\nimport os\nimport csv\nimport numpy as np\nimport json\nimport rando"
  },
  {
    "path": "scripts/datasets/__init__.py",
    "chars": 69,
    "preview": "from .COCO import COCO\nfrom .BIH import BIH\n\n__all__ = ('COCO','BIH')"
  },
  {
    "path": "scripts/machines/BasicMachine.py",
    "chars": 11420,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom progress.bar import Bar\nimport json\nimport "
  },
  {
    "path": "scripts/machines/S2AM.py",
    "chars": 11228,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom progress.bar import Bar\nimport json\nimport "
  },
  {
    "path": "scripts/machines/VX.py",
    "chars": 12188,
    "preview": "import torch\nimport torch.nn as nn\nfrom progress.bar import Bar\nfrom tqdm import tqdm\nimport pytorch_ssim\nimport json\nim"
  },
  {
    "path": "scripts/machines/__init__.py",
    "chars": 225,
    "preview": "\nfrom .BasicMachine import BasicMachine\nfrom .VX import VX\nfrom .S2AM import S2AM\n\ndef basic(**kwargs):\n\treturn BasicMac"
  },
  {
    "path": "scripts/models/__init__.py",
    "chars": 78,
    "preview": "from .vgg import *\nfrom .backbone_unet import *\nfrom .discriminator import *\n\n"
  },
  {
    "path": "scripts/models/backbone_unet.py",
    "chars": 1301,
    "preview": "\n\nimport torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport functo"
  },
  {
    "path": "scripts/models/blocks.py",
    "chars": 10365,
    "preview": "import torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport functool"
  },
  {
    "path": "scripts/models/discriminator.py",
    "chars": 5929,
    "preview": "import numpy as np\nimport functools\nimport math\nimport torch\nfrom torch.autograd import Variable\nimport torch.nn.functio"
  },
  {
    "path": "scripts/models/rasc.py",
    "chars": 5778,
    "preview": "\r\n\r\nimport torch\r\nimport torchvision\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport numpy as np\r\nimport"
  },
  {
    "path": "scripts/models/sa_resunet.py",
    "chars": 22050,
    "preview": "\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nfrom scripts.models.blocks import SEBlock\r\nfrom "
  },
  {
    "path": "scripts/models/unet.py",
    "chars": 7038,
    "preview": "import torch\r\nimport torch.nn as nn\r\nfrom torch.nn import init\r\nimport functools\r\nfrom scripts.models.blocks import *\r\nf"
  },
  {
    "path": "scripts/models/vgg.py",
    "chars": 3144,
    "preview": "from collections import namedtuple\n\nimport torch\nfrom torchvision import models\n\n\nclass Vgg16(torch.nn.Module):\n    def "
  },
  {
    "path": "scripts/models/vmu.py",
    "chars": 13059,
    "preview": "\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nfrom scripts.models.blocks import SEBlock\r\nfrom "
  },
  {
    "path": "scripts/utils/__init__.py",
    "chars": 180,
    "preview": "from __future__ import absolute_import\n\nfrom .evaluation import *\nfrom .imutils import *\nfrom .logger import *\nfrom .mis"
  },
  {
    "path": "scripts/utils/evaluation.py",
    "chars": 3597,
    "preview": "from __future__ import absolute_import\n\nimport math\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom random impor"
  },
  {
    "path": "scripts/utils/imutils.py",
    "chars": 7611,
    "preview": "from __future__ import absolute_import\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport scipy.misc\n\nfrom .m"
  },
  {
    "path": "scripts/utils/logger.py",
    "chars": 4399,
    "preview": "# A simple torch style logger\n# (C) Wei YANG 2017\nfrom __future__ import absolute_import\n\nimport os\nimport sys\nimport nu"
  },
  {
    "path": "scripts/utils/losses.py",
    "chars": 10807,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom scripts.models.vgg import Vgg19\nfrom torchvision"
  },
  {
    "path": "scripts/utils/misc.py",
    "chars": 2573,
    "preview": "from __future__ import absolute_import\n\nimport os\nimport shutil\nimport torch \nimport math\nimport numpy as np\nimport scip"
  },
  {
    "path": "scripts/utils/model_init.py",
    "chars": 1752,
    "preview": "\n\nfrom torch.nn import init\n\n\ndef weights_init_normal(m):\n    classname = m.__class__.__name__\n    # print(classname)\n  "
  },
  {
    "path": "scripts/utils/osutils.py",
    "chars": 377,
    "preview": "from __future__ import absolute_import\n\nimport os\nimport errno\n\ndef mkdir_p(dir_path):\n    try:\n        os.makedirs(dir_"
  },
  {
    "path": "scripts/utils/parallel.py",
    "chars": 11406,
    "preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Hang Zhang, Rutgers Universit"
  },
  {
    "path": "scripts/utils/transforms.py",
    "chars": 5225,
    "preview": "from __future__ import absolute_import\n\nimport os\nimport numpy as np\nimport scipy.misc\nimport matplotlib.pyplot as plt\ni"
  },
  {
    "path": "test.py",
    "chars": 763,
    "preview": "from __future__ import print_function, absolute_import\n\nimport argparse\nimport torch\n\ntorch.backends.cudnn.benchmark = T"
  },
  {
    "path": "watermark_synthesis.ipynb",
    "chars": 5294,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"nam"
  }
]

About this extraction

This page contains the full source code of the vinthony/deep-blind-watermark-removal GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 35 files (174.8 KB), approximately 45.1k tokens, and a symbol index with 353 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!