[
  {
    "path": "README.md",
    "content": "This repo contains the code and results of the AAAI 2021 paper:\n\n<i><b> [Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal](https://arxiv.org/abs/2012.07007)</b></i><br>\n[Xiaodong Cun](http://vinthony.github.io), [Chi-Man Pun<sup>*</sup>](http://www.cis.umac.mo/~cmpun/) <br>\n[University of Macau](http://um.edu.mo/)\n\n[Datasets](#Resources) | [Models](#Resources) | [Paper](https://arxiv.org/abs/2012.07007)  | [🔥Online Demo!](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing)(Google CoLab)\n\n<hr>\n\n<img width=\"726\" alt=\"nn\" src=\"https://user-images.githubusercontent.com/4397546/101241905-37915d80-3735-11eb-9fb9-2e1e46d63f15.png\">\n\n<i>The overview of the proposed two-stage framework. Firstly, we propose a multi-task network, SplitNet, for watermark detection, removal,  and recovery. Then, we propose the RefineNet to smooth the learned region with the predicted mask and the recovered background from the previous stage. As a consequence, our network can be trained in an end-to-end fashion without any manual intervention. 
Note that, for clarity, we do not show the skip-connections between the encoders and decoders.</i>\n<hr>\n\n> The whole project will be released in January 2021 (almost).\n\n\n### Datasets\n\nWe synthesized four different datasets for training and testing; you can download them via [huggingface](https://huggingface.co/datasets/vinthony/watermark-removal-logo/tree/main).\n\n![image](https://user-images.githubusercontent.com/4397546/104273158-74413900-54d9-11eb-95fa-c6bee94de0ea.png)\n\n\n### Pre-trained Models\n\n* [27kpng_model_best.pth.tar (google drive)](https://drive.google.com/file/d/1KpSJ6385CHN6WlAINqB3CYrJdleQTJBc/view?usp=sharing)\n\n> The other pre-trained models are still being reorganized and uploaded; they will be released soon.\n\n\n### Demos\n\nAn easy-to-use online demo can be found in [google colab](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing).\n\nThe local demo will be released soon.\n\n### Pre-requirements\n\n```\npip install -r requirements.txt\n```\n\n### Train\n\nBesides training our method, we also give an example of how to train [s2am](https://github.com/vinthony/s2am) under our framework. More details can be found in the shell scripts.\n\n\n```\nbash examples/evaluate.sh\n```\n\n### Test\n\n```\nbash examples/test.sh\n```\n\n## **Acknowledgements**\nThe author would like to thank Nan Chen for her helpful discussions.\n\nPart of the code is based upon our previous work on image harmonization, [s2am](https://github.com/vinthony/s2am).\n\n## **Citation**\n\nIf you find our work useful in your research, please consider citing:\n\n```\n@misc{cun2020split,\n      title={Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal}, \n      author={Xiaodong Cun and Chi-Man Pun},\n      year={2020},\n      eprint={2012.07007},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV}\n}\n```\n\n## **Contact**\nPlease contact me if you have any questions (Xiaodong Cun, yb87432@um.edu.mo).\n"
  },
  {
    "path": "examples/evaluate.sh",
    "content": "set -ex\n\n\n\n# example training scripts for AAAI-21\n# Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal\n\n\nCUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/main.py  --epochs 100\\\n --schedule 100\\\n --lr 1e-3\\\n -c eval/10kgray/1e3_bs4_256_hybrid_ssim_vgg\\\n --arch vvv4n\\\n --sltype vggx\\\n --style-loss 0.025\\\n --ssim-loss 0.15\\\n --masked True\\\n --loss-type hybrid\\\n --limited-dataset 1\\\n --machine vx\\\n --input-size 256\\\n --train-batch 4\\\n --test-batch 1\\\n --base-dir $HOME/watermark/10kgray/\\\n --data _images\n\n\n\n\n\n# example training scripts for TIP-20\n# Improving the Harmony of the Composite Image by Spatial-Separated Attention Module\n# * in the original version, the res = False\n# suitable for the iHarmony4 dataset.\n\npython /data/home/yb87432/mypaper/s2am/main.py  --epochs 200\\\n --schedule 150\\\n --lr 1e-3\\\n -c checkpoint/normal_rasc_HAdobe5k_res \\\n --arch rascv2\\\n --style-loss 0\\\n --ssim-loss 0\\\n --limited-dataset 0\\\n --res True\\\n --machine s2am\\\n --input-size 256\\\n --train-batch 16\\\n --test-batch 1\\\n --base-dir $HOME/Datasets/\\\n --data HAdobe5k"
  },
  {
    "path": "examples/test.sh",
    "content": "\nset -ex\n\nCUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/test.py \\\n  -c test/10kgray_ssim\\\n  --resume /data/home/yb87432/s2am/eval/10kgray/1e3_bs6_256_hybrid_ssim_vgg_vx__images_vvv4n/model_best.pth.tar\\\n  --arch vvv4n\\\n  --machine vx\\\n  --input-size 256\\\n  --test-batch 1\\\n  --evaluate\\\n  --base-dir $HOME/watermark/10kgray/\\\n  --data _images"
  },
  {
    "path": "main.py",
"content": "from __future__ import print_function, absolute_import\n\nimport argparse\nimport torch,time,os\n\ntorch.backends.cudnn.benchmark = True\n\nfrom scripts.utils.misc import save_checkpoint, adjust_learning_rate\n\nimport scripts.datasets as datasets\nimport scripts.machines as machines\nfrom options import Options\n\ndef main(args):\n\n    # note: a chained `'HFlickr' or ... in args.base_dir` is always truthy,\n    # so each dataset name must be tested against base_dir explicitly\n    if any(name in args.base_dir for name in ('HFlickr', 'HCOCO', 'Hday2night', 'HAdobe5k')):\n        dataset_func = datasets.BIH\n    else:\n        dataset_func = datasets.COCO\n\n    train_loader = torch.utils.data.DataLoader(dataset_func('train',args),batch_size=args.train_batch, shuffle=True,\n        num_workers=args.workers, pin_memory=True)\n\n    val_loader = torch.utils.data.DataLoader(dataset_func('val',args),batch_size=args.test_batch, shuffle=False,\n        num_workers=args.workers, pin_memory=True)\n\n    lr = args.lr\n    data_loaders = (train_loader,val_loader)\n\n    Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)\n    print('============================ Initialization Finished && Training Start =============================================')\n\n    for epoch in range(Machine.args.start_epoch, Machine.args.epochs):\n\n        print('\\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))\n        lr = adjust_learning_rate(data_loaders, Machine.optimizer, epoch, lr, args)\n\n        Machine.record('lr',lr, epoch)\n        Machine.train(epoch)\n\n        if args.freq < 0:\n            Machine.validate(epoch)\n            Machine.flush()\n            Machine.save_checkpoint()\n\nif __name__ == '__main__':\n    parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal'))\n    args = parser.parse_args()\n    print('==================================== WaterMark Removal =============================================')\n    print('==> {:50}: {:<}'.format(\"Start Time\",time.ctime(time.time())))\n    print('==> {:50}: {:<}'.format(\"USE GPU\",os.environ['CUDA_VISIBLE_DEVICES']))\n    print('==================================== Stable Parameters =============================================')\n    for arg in vars(args):\n        if type(getattr(args, arg)) == type([]):\n            if ','.join([ str(i) for i in getattr(args, arg)]) == ','.join([ str(i) for i in parser.get_default(arg)]):\n                print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)])))\n        else:\n            if getattr(args, arg) == parser.get_default(arg):\n                print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg)))\n    print('==================================== Changed Parameters =============================================')\n    for arg in vars(args):\n        if type(getattr(args, arg)) == type([]):\n            if ','.join([ str(i) for i in getattr(args, arg)]) != ','.join([ str(i) for i in parser.get_default(arg)]):\n                print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)])))\n        else:\n            if getattr(args, arg) != parser.get_default(arg):\n                print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg)))\n    print('==================================== Start Init Model  ===============================================')\n    main(args)\n    print('==================================== FINISH WITHOUT ERROR =============================================')\n"
  },
  {
    "path": "options.py",
"content": "\nimport scripts.models as models\n\nmodel_names = sorted(name for name in models.__dict__\n    if name.islower() and not name.startswith(\"__\")\n    and callable(models.__dict__[name]))\n\nclass Options():\n    \"\"\"docstring for Options\"\"\"\n    def __init__(self):\n        pass\n\n    def init(self, parser):\n        # Model structure\n        parser.add_argument('--arch', '-a', metavar='ARCH', default='dhn',\n                            choices=model_names,\n                            help='model architecture: ' +\n                                ' | '.join(model_names) +\n                                ' (default: dhn)')\n        parser.add_argument('--darch', metavar='ARCH', default='dhn',\n                            choices=model_names,\n                            help='model architecture: ' +\n                                ' | '.join(model_names) +\n                                ' (default: dhn)')\n\n        parser.add_argument('--machine', '-m', metavar='MACHINE', default='basic')\n        # Training strategy\n        parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                            help='number of data loading workers (default: 2)')\n        parser.add_argument('--epochs', default=30, type=int, metavar='N',\n                            help='number of total epochs to run')\n        parser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n                            help='manual epoch number (useful on restarts)')\n        parser.add_argument('--train-batch', default=64, type=int, metavar='N',\n                            help='train batchsize')\n        parser.add_argument('--test-batch', default=6, type=int, metavar='N',\n                            help='test batchsize')\n        parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,metavar='LR', help='initial learning rate')\n        parser.add_argument('--dlr', '--dlearning-rate', default=1e-3, type=float, help='initial learning rate')\n        parser.add_argument('--beta1', default=0.9, type=float, help='beta1 of the Adam optimizer')\n        parser.add_argument('--beta2', default=0.999, type=float, help='beta2 of the Adam optimizer')\n        parser.add_argument('--momentum', default=0, type=float, metavar='M',\n                            help='momentum')\n        parser.add_argument('--weight-decay', '--wd', default=0, type=float,\n                            metavar='W', help='weight decay (default: 0)')\n        parser.add_argument('--schedule', type=int, nargs='+', default=[5, 10],\n                            help='Decrease learning rate at these epochs.')\n        parser.add_argument('--gamma', type=float, default=0.1,\n                            help='LR is multiplied by gamma on schedule.')\n        # Data processing\n        parser.add_argument('-f', '--flip', dest='flip', action='store_true',\n                            help='flip the input during validation')\n        parser.add_argument('--lambdaL1', type=float, default=1, help='the weight of L1.')\n        parser.add_argument('--alpha', type=float, default=0.5,\n                            help='Groundtruth Gaussian sigma.')\n        parser.add_argument('--sigma-decay', type=float, default=0,\n                            help='Sigma decay rate for each epoch.')\n        # Miscs\n        parser.add_argument('--base-dir', default='/PATH_TO_DATA_FOLDER/', type=str, metavar='PATH')\n        parser.add_argument('--data', default='', type=str, metavar='PATH',\n                            help='dataset name (appended to base-dir)')\n        parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',\n                            help='path to save checkpoint (default: checkpoint)')\n        parser.add_argument('--resume', default='', type=str, metavar='PATH',\n                            help='path to latest checkpoint (default: none)')\n        parser.add_argument('--finetune', default='', type=str, metavar='PATH',\n                            help='path to latest checkpoint (default: none)')\n\n        parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',\n                            help='evaluate model on validation set')\n        parser.add_argument('--style-loss', default=0, type=float,\n                            help='weight of the perception (VGG) loss')\n        parser.add_argument('--ssim-loss', default=0, type=float,help='weight of the SSIM loss')\n        parser.add_argument('--att-loss', default=1, type=float,help='weight of the attention loss')\n        parser.add_argument('--default-loss',default=False,type=bool)\n        parser.add_argument('--sltype', default='vggx', type=str)\n        parser.add_argument('-da', '--data-augumentation', default=False, type=bool,\n                            help='enable data augmentation')\n        parser.add_argument('-d', '--debug', dest='debug', action='store_true',\n                            help='show intermediate results')\n        parser.add_argument('--input-size', default=256, type=int, metavar='N',\n                            help='size of the (square) network input')\n        parser.add_argument('--freq', default=-1, type=int, metavar='N',\n                            help='evaluation frequency in iterations (negative: validate once per epoch)')\n        parser.add_argument('--normalized-input', default=False, type=bool,\n                            help='normalize the network input')\n        parser.add_argument('--res', default=False, type=bool,help='residual learning for s2am')\n        parser.add_argument('--requires-grad', default=False, type=bool,\n                            help='set requires_grad of the model parameters')\n        parser.add_argument('--limited-dataset', default=0, type=int, metavar='N')\n        parser.add_argument('--gpu',default=True,type=bool)\n        parser.add_argument('--masked',default=False,type=bool)\n        parser.add_argument('--gan-norm', default=False,type=bool, help='normalize images to [-1, 1] (GAN-style)')\n        parser.add_argument('--hl', default=False,type=bool, help='homogeneous learning')\n        parser.add_argument('--loss-type', default='l2',type=str, help='loss type (default: l2)')\n        return parser"
  },
  {
    "path": "requirements.txt",
"content": "numpy==1.19.1\nopencv-python==3.4.8.29\nPillow\nscikit-image==0.14.5\nscikit-learn==0.23.1\nscipy==1.2.1\ntensorboardX\ntorch>=1.0.0\ntorchvision"
  },
  {
    "path": "scripts/__init__.py",
    "content": "from __future__ import absolute_import\n\nfrom . import datasets\nfrom . import models\nfrom . import utils\n\n# import os, sys\n# sys.path.append(os.path.join(os.path.dirname(__file__), \"progress\"))\n# from progress.bar import Bar as Bar\n\n# __version__ = '0.1.0'"
  },
  {
    "path": "scripts/datasets/BIH.py",
"content": "from __future__ import print_function, absolute_import\n\nimport os\nimport csv\nimport numpy as np\nimport json\nimport random\nimport math\nimport matplotlib.pyplot as plt\nfrom collections import namedtuple\nfrom os import listdir\nfrom os.path import isfile, join\n\nimport torch\nimport torch.utils.data as data\n\nfrom scripts.utils.osutils import *\nfrom scripts.utils.imutils import *\nfrom scripts.utils.transforms import *\nimport torchvision.transforms as transforms\nfrom PIL import Image\nfrom PIL import ImageEnhance\nfrom PIL import ImageFilter\nfrom PIL import ImageFile\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\nclass BIH(data.Dataset):\n    def __init__(self,train,config=None, sample=[],gan_norm=False):\n\n        # config attributes are read below, so it cannot be None\n        if config is None:\n            raise ValueError('BIH requires a config object')\n\n        self.train = []\n        self.anno = []\n        self.mask = []\n        self.wm = []\n        self.input_size = config.input_size\n        self.normalized_input = config.normalized_input\n        self.base_folder = config.base_dir +'/' + config.data\n        self.dataset = config.data\n        self.data_augumentation = config.data_augumentation\n\n        self.istrain = False if train.find('train') == -1 else True\n        self.sample = sample\n        self.gan_norm = gan_norm\n        mypath = join(self.base_folder,self.dataset+'_'+train+'.txt')\n\n        with open(mypath) as f:\n            # here we get the filenames\n            file_names = [ im.strip() for im in f.readlines() ]\n\n        if config.limited_dataset > 0:\n            # keep only the first sample of each identifier\n            xtrain = sorted(list(set([ file_name.split('-')[0] for file_name in file_names ])))\n            tmp = []\n            for x in xtrain:\n                tmp.append([y for y in file_names if x in y][0])\n\n            file_names = tmp\n\n        for file_name in file_names:\n            self.train.append(os.path.join(self.base_folder,'images',file_name))\n            self.mask.append(os.path.join(self.base_folder,'masks','_'.join(file_name.split('_')[0:2])+'.png'))\n            self.anno.append(os.path.join(self.base_folder,'reals',file_name.split('_')[0]+'.jpg'))\n\n        if len(self.sample) > 0 :\n            self.train = [ self.train[i] for i in self.sample ]\n            self.mask = [ self.mask[i] for i in self.sample ]\n            self.anno = [ self.anno[i] for i in self.sample ]\n\n        self.trans = transforms.Compose([\n                transforms.Resize((self.input_size,self.input_size)),\n                transforms.ToTensor()\n            ])\n\n        print('total Dataset of '+self.dataset+' is : ', len(self.train))\n\n\n    def __getitem__(self, index):\n        img = Image.open(self.train[index]).convert('RGB')\n        mask = Image.open(self.mask[index]).convert('L')\n        anno = Image.open(self.anno[index]).convert('RGB')\n\n        # for shadow removal and blind image harmonization, there is no ground-truth wm\n        # wm = Image.open(self.wm[index]).convert('RGB')\n\n        return {\"image\": self.trans(img),\n                \"target\": self.trans(anno),\n                \"mask\": self.trans(mask),\n                \"name\": self.train[index].split('/')[-1],\n                \"imgurl\":self.train[index],\n                \"maskurl\":self.mask[index],\n                \"targeturl\":self.anno[index],\n                }\n\n    def __len__(self):\n\n        return len(self.train)\n"
  },
  {
    "path": "scripts/datasets/COCO.py",
"content": "from __future__ import print_function, absolute_import\n\nimport os\nimport csv\nimport numpy as np\nimport json\nimport random\nimport math\nimport matplotlib.pyplot as plt\nfrom collections import namedtuple\nfrom os import listdir\nfrom os.path import isfile, join\n\nimport torch\nimport torch.utils.data as data\n\nfrom scripts.utils.osutils import *\nfrom scripts.utils.imutils import *\nfrom scripts.utils.transforms import *\nimport torchvision.transforms as transforms\nfrom PIL import Image\nfrom PIL import ImageEnhance\nfrom PIL import ImageFilter\nfrom PIL import ImageFile\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\nclass COCO(data.Dataset):\n    def __init__(self,train,config=None, sample=[],gan_norm=False):\n\n        # config attributes are read below, so it cannot be None\n        if config is None:\n            raise ValueError('COCO requires a config object')\n\n        self.train = []\n        self.anno = []\n        self.mask = []\n        self.wm = []\n        self.input_size = config.input_size\n        self.normalized_input = config.normalized_input\n        self.base_folder = config.base_dir\n        self.dataset = train+config.data\n        self.data_augumentation = config.data_augumentation\n\n        self.istrain = False if self.dataset.find('train') == -1 else True\n        self.sample = sample\n        self.gan_norm = gan_norm\n        mypath = join(self.base_folder,self.dataset)\n        file_names = sorted([f for f in listdir(join(mypath,'image')) if isfile(join(mypath,'image', f)) ])\n\n        if config.limited_dataset > 0:\n            xtrain = sorted(list(set([ file_name.split('-')[0] for file_name in file_names ])))\n            tmp = []\n            for x in xtrain:\n                # get the first file_name of each identifier\n                tmp.append([y for y in file_names if x in y][0])\n\n            file_names = tmp\n\n        for file_name in file_names:\n            self.train.append(os.path.join(mypath,'image',file_name))\n            self.mask.append(os.path.join(mypath,'mask',file_name))\n            self.wm.append(os.path.join(mypath,'wm',file_name))\n            self.anno.append(os.path.join(self.base_folder,'natural',file_name.split('-')[0]+'.jpg'))\n\n        if len(self.sample) > 0 :\n            self.train = [ self.train[i] for i in self.sample ]\n            self.mask = [ self.mask[i] for i in self.sample ]\n            self.anno = [ self.anno[i] for i in self.sample ]\n            # also subsample the watermarks, which are indexed in __getitem__\n            self.wm = [ self.wm[i] for i in self.sample ]\n\n        self.trans = transforms.Compose([\n                transforms.Resize((self.input_size,self.input_size)),\n                transforms.ToTensor()\n            ])\n\n        print('total Dataset of '+self.dataset+' is : ', len(self.train))\n\n\n    def __getitem__(self, index):\n        img = Image.open(self.train[index]).convert('RGB')\n        mask = Image.open(self.mask[index]).convert('L')\n        anno = Image.open(self.anno[index]).convert('RGB')\n        wm = Image.open(self.wm[index]).convert('RGB')\n\n        return {\"image\": self.trans(img),\n                \"target\": self.trans(anno),\n                \"mask\": self.trans(mask),\n                \"wm\": self.trans(wm),\n                \"name\": self.train[index].split('/')[-1],\n                \"imgurl\":self.train[index],\n                \"maskurl\":self.mask[index],\n                \"targeturl\":self.anno[index],\n                \"wmurl\":self.wm[index]\n                }\n\n    def __len__(self):\n\n        return len(self.train)\n"
  },
  {
    "path": "scripts/datasets/__init__.py",
    "content": "from .COCO import COCO\nfrom .BIH import BIH\n\n__all__ = ('COCO','BIH')"
  },
  {
    "path": "scripts/machines/BasicMachine.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom progress.bar import Bar\nimport json\nimport numpy as np\nfrom tensorboardX import SummaryWriter\nfrom scripts.utils.evaluation import accuracy, AverageMeter, final_preds\nfrom scripts.utils.osutils import mkdir_p, isfile, isdir, join\nfrom scripts.utils.parallel import DataParallelModel, DataParallelCriterion\nimport pytorch_ssim as pytorch_ssim\nimport torch.optim\nimport sys,shutil,os\nimport time\nimport scripts.models as archs\nfrom math import log10\nfrom torch.autograd import Variable\nfrom scripts.utils.losses import VGGLoss\nfrom scripts.utils.imutils import im_to_numpy\n\nimport skimage.io\nfrom skimage.measure import compare_psnr,compare_ssim\n\n\nclass BasicMachine(object):\n    def __init__(self, datasets =(None,None), models = None, args = None, **kwargs):\n        super(BasicMachine, self).__init__()\n        \n        self.args = args\n        \n        # create model\n        print(\"==> creating model \")\n        self.model = archs.__dict__[self.args.arch]()\n        print(\"==> creating model [Finish]\")\n       \n        self.train_loader, self.val_loader = datasets\n        self.loss = torch.nn.MSELoss()\n        \n        self.title = '_'+args.machine + '_' + args.data + '_' + args.arch\n        self.args.checkpoint = args.checkpoint + self.title\n        self.device = torch.device('cuda')\n         # create checkpoint dir\n        if not isdir(self.args.checkpoint):\n            mkdir_p(self.args.checkpoint)\n\n        self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), \n                            lr=args.lr,\n                            betas=(args.beta1,args.beta2),\n                            weight_decay=args.weight_decay)  \n        \n        if not self.args.evaluate:\n            self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt')\n        \n        self.best_acc = 0\n        
self.is_best = False\n        self.current_epoch = 0\n        self.metric = -100000\n        self.hl = 6 if self.args.hl else 1\n        self.count_gpu = len(range(torch.cuda.device_count()))\n\n        if self.args.style_loss > 0:\n            # init perception loss\n            self.vggloss = VGGLoss(self.args.sltype).to(self.device)\n\n        if self.count_gpu > 1 : # multiple\n            # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count()))\n            # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count()))\n            self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))\n\n        self.model.to(self.device)\n        self.loss.to(self.device)\n\n        print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0))\n        print('==> Total devices: %d' % (torch.cuda.device_count()))\n        print('==> Current Checkpoint: %s' % (self.args.checkpoint))\n\n\n        if self.args.resume != '':\n            self.resume(self.args.resume)\n\n\n    def train(self,epoch):\n        batch_time = AverageMeter()\n        data_time = AverageMeter()\n        losses = AverageMeter()\n        lossvgg = AverageMeter()\n        \n        # switch to train mode\n        self.model.train()\n        end = time.time()\n\n        bar = Bar('Processing', max=len(self.train_loader)*self.hl)\n        for _ in range(self.hl):\n            for i, batches in enumerate(self.train_loader):\n                # measure data loading time\n                inputs = batches['image']\n                target = batches['target'].to(self.device)\n                mask =batches['mask'].to(self.device)\n                current_index = len(self.train_loader) * epoch + i\n\n                if self.args.hl:\n                    feeded = torch.cat([inputs,mask],dim=1)\n                else:\n                    feeded = inputs\n                feeded = 
feeded.to(self.device)\n\n                output = self.model(feeded)\n                L2_loss =  self.loss(output,target) \n                \n                if self.args.style_loss > 0:\n                    vgg_loss = self.vggloss(output,target,mask)\n                else:\n                    vgg_loss = 0\n\n                total_loss = L2_loss + self.args.style_loss * vgg_loss\n\n                # compute gradient and do SGD step\n                self.optimizer.zero_grad()\n                total_loss.backward()\n                self.optimizer.step()\n\n                # measure accuracy and record loss\n                losses.update(L2_loss.item(), inputs.size(0))\n                \n                if self.args.style_loss > 0 :\n                    lossvgg.update(vgg_loss.item(), inputs.size(0))\n                \n                # measure elapsed time\n                batch_time.update(time.time() - end)\n                end = time.time()\n\n                # plot progress\n                suffix  = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format(\n                            batch=i + 1,\n                            size=len(self.train_loader),\n                            data=data_time.val,\n                            bt=batch_time.val,\n                            total=bar.elapsed_td,\n                            eta=bar.eta_td,\n                            loss_label=losses.avg,\n                            loss_vgg=lossvgg.avg\n                            )\n\n                if current_index % 1000 == 0:\n                    print(suffix)\n                \n                if self.args.freq > 0 and current_index % self.args.freq == 0:\n                    self.validate(current_index)\n                    self.flush()\n                    self.save_checkpoint()\n        \n        self.record('train/loss_L2', losses.avg, current_index)\n\n\n    def 
test(self):\n\n        # switch to evaluate mode\n        self.model.eval()\n\n        ssimes = AverageMeter()\n        psnres = AverageMeter()\n\n        with torch.no_grad():\n            for i, batches in enumerate(self.val_loader):\n\n                inputs = batches['image'].to(self.device)\n                target = batches['target'].to(self.device)\n                mask = batches['mask'].to(self.device)\n\n                outputs = self.model(inputs)\n\n                # select the output according to the given arch\n                if isinstance(outputs, torch.Tensor):\n                    output = outputs\n                elif isinstance(outputs[0], list):\n                    output = outputs[0][0]\n                else:\n                    output = outputs[0]\n\n                # recover the image to [0,255]\n                output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8)\n                target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)\n\n                skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output)\n\n                psnr = compare_psnr(target,output)\n                ssim = compare_ssim(target,output,multichannel=True)\n\n                psnres.update(psnr, inputs.size(0))\n                ssimes.update(ssim, inputs.size(0))\n\n        print(\"%s:PSNR:%s,SSIM:%s\"%(self.args.checkpoint,psnres.avg,ssimes.avg))\n        print(\"DONE.\\n\")\n\n    def validate(self, epoch):\n        batch_time = AverageMeter()\n        data_time = AverageMeter()\n        losses = AverageMeter()\n        ssimes = AverageMeter()\n        psnres = AverageMeter()\n        # switch to evaluate mode\n        self.model.eval()\n\n        end = time.time()\n        with torch.no_grad():\n            for i, batches in enumerate(self.val_loader):\n\n                inputs = batches['image'].to(self.device)\n                target = batches['target'].to(self.device)\n                mask = batches['mask'].to(self.device)\n\n                if self.args.hl:\n                    feeded = torch.cat([inputs,torch.zeros((1,4,self.args.input_size,self.args.input_size)).to(self.device)],dim=1)\n                else:\n                    feeded = inputs\n\n                output = self.model(feeded)\n\n                L2_loss = self.loss(output, target)\n\n                psnr = 10 * log10(1 / L2_loss.item())\n                ssim = pytorch_ssim.ssim(output, target)\n\n                losses.update(L2_loss.item(), inputs.size(0))\n                psnres.update(psnr, inputs.size(0))\n                ssimes.update(ssim.item(), inputs.size(0))\n\n                # measure elapsed time\n                batch_time.update(time.time() - end)\n                end = time.time()\n\n        print(\"Epochs:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f\"%(epoch+1, losses.avg,psnres.avg,ssimes.avg))\n        self.record('val/loss_L2', losses.avg, epoch)\n        self.record('val/PSNR', psnres.avg, epoch)\n        self.record('val/SSIM', ssimes.avg, epoch)\n\n        self.metric = psnres.avg\n\n    def resume(self,resume_path):\n        if isfile(resume_path):\n                print(\"=> loading checkpoint '{}'\".format(resume_path))\n                current_checkpoint = torch.load(resume_path)\n                if isinstance(current_checkpoint['state_dict'], torch.nn.DataParallel):\n                    current_checkpoint['state_dict'] = current_checkpoint['state_dict'].module\n\n                if isinstance(current_checkpoint['optimizer'], torch.nn.DataParallel):\n                    current_checkpoint['optimizer'] = current_checkpoint['optimizer'].module\n\n                self.args.start_epoch = current_checkpoint['epoch']\n                self.metric = current_checkpoint['best_acc']\n                self.model.load_state_dict(current_checkpoint['state_dict'])\n                # self.optimizer.load_state_dict(current_checkpoint['optimizer'])\n                print(\"=> loaded checkpoint '{}' (epoch {})\"\n                      .format(resume_path, current_checkpoint['epoch']))\n        else:\n            raise Exception(\"=> no checkpoint found at '{}'\".format(resume_path))\n\n    def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):\n        is_best = self.metric > self.best_acc\n\n        if is_best:\n            self.best_acc = self.metric\n\n        state = {\n                    'epoch': self.current_epoch + 1,\n                    'arch': self.args.arch,\n                    'state_dict': self.model.state_dict(),\n                    'best_acc': self.best_acc,\n                    'optimizer' : self.optimizer.state_dict() if self.optimizer else None,\n                }\n\n        filepath = os.path.join(self.args.checkpoint, filename)\n        torch.save(state, filepath)\n\n        # state is a dict, so index it rather than using attribute access\n        if snapshot and state['epoch'] % snapshot == 0:\n            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))\n\n        if is_best:\n            print('Saving Best Metric with PSNR:%s'%self.best_acc)\n            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar'))\n\n    def clean(self):\n        self.writer.close()\n\n    def record(self,k,v,epoch):\n        self.writer.add_scalar(k, v, epoch)\n\n    def flush(self):\n        self.writer.flush()\n        sys.stdout.flush()\n\n    def norm(self,x):\n        if self.args.gan_norm:\n            return x*2.0 - 1.0\n        else:\n            return x\n\n    def denorm(self,x):\n        if self.args.gan_norm:\n            return (x+1.0)/2.0\n        else:\n            return x\n\n"
  },
  {
    "path": "scripts/machines/S2AM.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom progress.bar import Bar\nimport json\nimport numpy as np\nfrom tensorboardX import SummaryWriter\nfrom scripts.utils.evaluation import accuracy, AverageMeter, final_preds\nfrom scripts.utils.osutils import mkdir_p, isfile, isdir, join\nfrom scripts.utils.parallel import DataParallelModel, DataParallelCriterion\nimport pytorch_ssim as pytorch_ssim\nimport torch.optim\nimport sys,shutil,os\nimport time\nimport scripts.models as archs\nfrom math import log10\nfrom torch.autograd import Variable\nfrom scripts.utils.losses import VGGLoss\nfrom scripts.utils.imutils import im_to_numpy\n\nimport skimage.io\nfrom skimage.measure import compare_psnr,compare_ssim\n\n\nclass S2AM(object):\n    def __init__(self, datasets =(None,None), models = None, args = None, **kwargs):\n        super(S2AM, self).__init__()\n        \n        self.args = args\n        \n        # create model\n        print(\"==> creating model \")\n        self.model = archs.__dict__[self.args.arch]()\n        print(\"==> creating model [Finish]\")\n       \n        self.train_loader, self.val_loader = datasets\n        self.loss = torch.nn.MSELoss()\n        \n        self.title = '_'+args.machine + '_' + args.data + '_' + args.arch\n        self.args.checkpoint = args.checkpoint + self.title\n        self.device = torch.device('cuda')\n         # create checkpoint dir\n        if not isdir(self.args.checkpoint):\n            mkdir_p(self.args.checkpoint)\n\n        self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), \n                            lr=args.lr,\n                            betas=(args.beta1,args.beta2),\n                            weight_decay=args.weight_decay)  \n        \n        if not self.args.evaluate:\n            self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt')\n        \n        self.best_acc = 0\n        self.is_best = False\n      
  self.current_epoch = 0\n        self.hl = 1\n        self.metric = -100000\n        self.count_gpu = len(range(torch.cuda.device_count()))\n\n        if self.args.style_loss > 0:\n            # init perception loss\n            self.vggloss = VGGLoss(self.args.sltype).to(self.device)\n\n        if self.count_gpu > 1 : # multiple\n            # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count()))\n            # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count()))\n            self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))\n\n        self.model.to(self.device)\n        self.loss.to(self.device)\n\n        print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0))\n        print('==> Total devices: %d' % (torch.cuda.device_count()))\n        print('==> Current Checkpoint: %s' % (self.args.checkpoint))\n\n\n        if self.args.resume != '':\n            self.resume(self.args.resume)\n\n\n    def train(self,epoch):\n        batch_time = AverageMeter()\n        data_time = AverageMeter()\n        losses = AverageMeter()\n        lossvgg = AverageMeter()\n        \n        # switch to train mode\n        self.model.train()\n        end = time.time()\n\n        bar = Bar('Processing', max=len(self.train_loader)*self.hl)\n        for _ in range(self.hl):\n            for i, batches in enumerate(self.train_loader):\n                # measure data loading time\n                inputs = batches['image'].to(self.device)\n                target = batches['target'].to(self.device)\n                mask =batches['mask'].to(self.device)\n                current_index = len(self.train_loader) * epoch + i\n\n                feeded = torch.cat([inputs,mask],dim=1)\n                feeded = feeded.to(self.device)\n\n                output = self.model(feeded)\n\n                if self.args.res:\n                    output = output + 
inputs\n\n                L2_loss =  self.loss(output,target) \n                \n                if self.args.style_loss > 0:\n                    vgg_loss = self.vggloss(output,target,mask)\n                else:\n                    vgg_loss = 0\n\n                total_loss = L2_loss + self.args.style_loss * vgg_loss\n\n                # compute gradient and do SGD step\n                self.optimizer.zero_grad()\n                total_loss.backward()\n                self.optimizer.step()\n\n                # measure accuracy and record loss\n                losses.update(L2_loss.item(), inputs.size(0))\n                \n                if self.args.style_loss > 0 :\n                    lossvgg.update(vgg_loss.item(), inputs.size(0))\n                \n                # measure elapsed time\n                batch_time.update(time.time() - end)\n                end = time.time()\n\n                # plot progress\n                suffix  = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format(\n                            batch=i + 1,\n                            size=len(self.train_loader),\n                            data=data_time.val,\n                            bt=batch_time.val,\n                            total=bar.elapsed_td,\n                            eta=bar.eta_td,\n                            loss_label=losses.avg,\n                            loss_vgg=lossvgg.avg\n                            )\n\n                if current_index % 1000 == 0:\n                    print(suffix)\n                \n                if self.args.freq > 0 and current_index % self.args.freq == 0:\n                    self.validate(current_index)\n                    self.flush()\n                    self.save_checkpoint()\n        \n        self.record('train/loss_L2', losses.avg, current_index)\n\n\n    def test(self, ):\n\n        # switch to evaluate mode\n        
self.model.eval()\n\n        ssimes = AverageMeter()\n        psnres = AverageMeter()\n\n        with torch.no_grad():\n            for i, batches in enumerate(self.val_loader):\n\n                inputs = batches['image'].to(self.device)\n                target = batches['target'].to(self.device)\n                mask = batches['mask'].to(self.device)\n\n                feeded = torch.cat([inputs,mask],dim=1)\n                feeded = feeded.to(self.device)\n\n                output = self.model(feeded)\n\n                if self.args.res:\n                    output = output + inputs\n\n                # recover the image to [0,255]\n                output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8)\n                target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)\n\n                skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output)\n\n                psnr = compare_psnr(target,output)\n                ssim = compare_ssim(target,output,multichannel=True)\n\n                psnres.update(psnr, inputs.size(0))\n                ssimes.update(ssim, inputs.size(0))\n\n        print(\"%s:PSNR:%s,SSIM:%s\"%(self.args.checkpoint,psnres.avg,ssimes.avg))\n        print(\"DONE.\\n\")\n\n    def validate(self, epoch):\n        batch_time = AverageMeter()\n        data_time = AverageMeter()\n        losses = AverageMeter()\n        ssimes = AverageMeter()\n        psnres = AverageMeter()\n        # switch to evaluate mode\n        self.model.eval()\n\n        end = time.time()\n        with torch.no_grad():\n            for i, batches in enumerate(self.val_loader):\n\n                inputs = batches['image'].to(self.device)\n                target = batches['target'].to(self.device)\n                mask = batches['mask'].to(self.device)\n\n                feeded = torch.cat([inputs,mask],dim=1)\n                feeded = feeded.to(self.device)\n\n                output = self.model(feeded)\n\n                if self.args.res:\n                    output = output + inputs\n\n                L2_loss = self.loss(output, target)\n\n                psnr = 10 * log10(1 / L2_loss.item())\n                ssim = pytorch_ssim.ssim(output, target)\n\n                losses.update(L2_loss.item(), inputs.size(0))\n                psnres.update(psnr, inputs.size(0))\n                ssimes.update(ssim.item(), inputs.size(0))\n\n                # measure elapsed time\n                batch_time.update(time.time() - end)\n                end = time.time()\n\n        print(\"Epochs:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f\"%(epoch+1, losses.avg,psnres.avg,ssimes.avg))\n        self.record('val/loss_L2', losses.avg, epoch)\n        self.record('val/PSNR', psnres.avg, epoch)\n        self.record('val/SSIM', ssimes.avg, epoch)\n\n        self.metric = psnres.avg\n\n    def resume(self,resume_path):\n        if isfile(resume_path):\n                print(\"=> loading checkpoint '{}'\".format(resume_path))\n                current_checkpoint = torch.load(resume_path)\n                if isinstance(current_checkpoint['state_dict'], torch.nn.DataParallel):\n                    current_checkpoint['state_dict'] = current_checkpoint['state_dict'].module\n\n                if isinstance(current_checkpoint['optimizer'], torch.nn.DataParallel):\n                    current_checkpoint['optimizer'] = current_checkpoint['optimizer'].module\n\n                self.args.start_epoch = current_checkpoint['epoch']\n                self.metric = current_checkpoint['best_acc']\n                self.model.load_state_dict(current_checkpoint['state_dict'])\n                # self.optimizer.load_state_dict(current_checkpoint['optimizer'])\n                print(\"=> loaded checkpoint '{}' (epoch {})\"\n                      .format(resume_path, current_checkpoint['epoch']))\n        else:\n            raise Exception(\"=> no checkpoint found at '{}'\".format(resume_path))\n\n    def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):\n        is_best = self.metric > self.best_acc\n\n        if is_best:\n            self.best_acc = self.metric\n\n        state = {\n                    'epoch': self.current_epoch + 1,\n                    'arch': self.args.arch,\n                    'state_dict': self.model.state_dict(),\n                    'best_acc': self.best_acc,\n                    'optimizer' : self.optimizer.state_dict() if self.optimizer else None,\n                }\n\n        filepath = os.path.join(self.args.checkpoint, filename)\n        torch.save(state, filepath)\n\n        # state is a dict, so index it rather than using attribute access\n        if snapshot and state['epoch'] % snapshot == 0:\n            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))\n\n        if is_best:\n            print('Saving Best Metric with PSNR:%s'%self.best_acc)\n            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar'))\n\n    def clean(self):\n        self.writer.close()\n\n    def record(self,k,v,epoch):\n        self.writer.add_scalar(k, v, epoch)\n\n    def flush(self):\n        self.writer.flush()\n        sys.stdout.flush()\n\n    def norm(self,x):\n        if self.args.gan_norm:\n            return x*2.0 - 1.0\n        else:\n            return x\n\n    def denorm(self,x):\n        if self.args.gan_norm:\n            return (x+1.0)/2.0\n        else:\n            return x\n\n"
  },
  {
    "path": "scripts/machines/VX.py",
    "content": "import torch\nimport torch.nn as nn\nfrom progress.bar import Bar\nfrom tqdm import tqdm\nimport pytorch_ssim\nimport json\nimport sys,time,os\nimport torchvision\nfrom math import log10\nimport numpy as np\nfrom .BasicMachine import BasicMachine\nfrom scripts.utils.evaluation import accuracy, AverageMeter, final_preds\nfrom scripts.utils.misc import resize_to_match\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nfrom scripts.utils.parallel import DataParallelModel, DataParallelCriterion\nfrom scripts.utils.losses import VGGLoss, l1_relative,is_dic\nfrom scripts.utils.imutils import im_to_numpy\nimport skimage.io\nfrom skimage.measure import compare_psnr,compare_ssim\n\n\nclass Losses(nn.Module):\n    def __init__(self, argx, device, norm_func=None, denorm_func=None):\n        super(Losses, self).__init__()\n        self.args = argx\n\n        if self.args.loss_type == 'l1bl2':\n            self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss()\n        elif self.args.loss_type == 'l2xbl2':\n            self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss()\n        elif self.args.loss_type == 'relative' or self.args.loss_type == 'hybrid':\n            self.outputLoss, self.attLoss, self.wrloss = l1_relative, nn.BCELoss(), l1_relative\n        else: # l2bl2\n            self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss()\n\n        self.default = nn.L1Loss()\n\n        if self.args.style_loss > 0:\n            self.vggloss = VGGLoss(self.args.sltype).to(device)\n        \n        if self.args.ssim_loss > 0:\n            self.ssimloss =  pytorch_ssim.SSIM().to(device)\n        \n        self.norm = norm_func\n        self.denorm = denorm_func\n\n\n    def forward(self,pred_ims,target,pred_ms,mask,pred_wms,wm):\n        pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = [0]*5\n        pred_ims = pred_ims if is_dic(pred_ims) 
else [pred_ims]\n\n        # try the loss in the masked region\n        if self.args.masked and 'hybrid' in self.args.loss_type: # masked loss\n            pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims])\n            pixel_loss += sum([self.default(pred_im*pred_ms,target*mask) for pred_im in pred_ims])\n            recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ]\n            wm_loss += self.wrloss(pred_wms, wm, mask)\n            wm_loss += self.default(pred_wms*pred_ms, wm*mask)\n\n        elif self.args.masked and 'relative' in self.args.loss_type: # masked loss\n            pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims])\n            recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ]\n            wm_loss = self.wrloss(pred_wms, wm, mask)\n        elif self.args.masked:\n            pixel_loss += sum([self.outputLoss(pred_im*mask, target*mask) for pred_im in pred_ims])\n            recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ]\n            wm_loss = self.wrloss(pred_wms*mask, wm*mask)\n        else:\n            pixel_loss += sum([self.outputLoss(pred_im*pred_ms, target*mask) for pred_im in pred_ims])\n            recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ]\n            wm_loss = self.wrloss(pred_wms*pred_ms,wm*mask)\n\n        pixel_loss += sum([self.default(im,target) for im in recov_imgs])\n\n        if self.args.style_loss > 0:\n            vgg_loss = sum([self.vggloss(im,target,mask) for im in recov_imgs])\n\n        if self.args.ssim_loss > 0:\n            ssim_loss = sum([ 1 - self.ssimloss(im,target) for im in recov_imgs])\n\n        att_loss =  self.attLoss(pred_ms, mask)\n\n        return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss\n\n\nclass VX(BasicMachine):\n    def 
__init__(self,**kwargs):\n        BasicMachine.__init__(self,**kwargs)\n        self.loss = Losses(self.args, self.device, self.norm, self.denorm)\n        self.model.set_optimizers()\n        self.optimizer = None\n       \n    def train(self,epoch):\n\n        self.current_epoch = epoch\n\n        batch_time = AverageMeter()\n        data_time = AverageMeter()\n        losses = AverageMeter()\n        lossMask = AverageMeter()\n        lossWM = AverageMeter()\n        lossMX = AverageMeter()\n        lossvgg = AverageMeter()\n        lossssim = AverageMeter()\n\n        # switch to train mode\n        self.model.train()\n\n        end = time.time()\n        bar = Bar('Processing {} '.format(self.args.arch), max=len(self.train_loader))\n\n        for i, batches in enumerate(self.train_loader):\n\n            current_index = len(self.train_loader) * epoch + i\n\n            inputs = batches['image'].to(self.device)\n            target = batches['target'].to(self.device)\n            mask = batches['mask'].to(self.device)\n            wm =  batches['wm'].to(self.device)\n\n            outputs = self.model(self.norm(inputs))\n            \n            self.model.zero_grad_all()\n\n            l2_loss,att_loss,wm_loss,style_loss,ssim_loss = self.loss(outputs[0],self.norm(target),outputs[1],mask,outputs[2],self.norm(wm))\n            total_loss = 2*l2_loss + self.args.att_loss * att_loss + wm_loss + self.args.style_loss * style_loss + self.args.ssim_loss * ssim_loss\n\n            # compute gradient and do SGD step\n            total_loss.backward()\n            self.model.step_all()\n\n            # measure accuracy and record loss\n            losses.update(l2_loss.item(), inputs.size(0))\n            lossMask.update(att_loss.item(), inputs.size(0))\n            lossWM.update(wm_loss.item(), inputs.size(0))\n\n            if self.args.style_loss > 0 :\n                lossvgg.update(style_loss.item(), inputs.size(0))\n\n            if self.args.ssim_loss > 0 :\n      
          lossssim.update(ssim_loss.item(), inputs.size(0))\n\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            # plot progress\n            suffix  = \"({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss Mask: {loss_mask:.4f} | loss WM: {loss_wm:.4f} | loss VGG: {loss_vgg:.4f} | loss SSIM: {loss_ssim:.4f}| loss MX: {loss_mx:.4f}\".format(\n                        batch=i + 1,\n                        size=len(self.train_loader),\n                        data=data_time.val,\n                        bt=batch_time.val,\n                        total=bar.elapsed_td,\n                        eta=bar.eta_td,\n                        loss_label=losses.avg,\n                        loss_mask=lossMask.avg,\n                        loss_wm=lossWM.avg,\n                        loss_vgg=lossvgg.avg,\n                        loss_ssim=lossssim.avg,\n                        loss_mx=lossMX.avg\n                        )\n            if current_index % 1000 == 0:\n                print(suffix)\n\n            if self.args.freq > 0 and current_index % self.args.freq == 0:\n                self.validate(current_index)\n                self.flush()\n                self.save_checkpoint()\n\n        self.record('train/loss_L2', losses.avg, epoch)\n        self.record('train/loss_Mask', lossMask.avg, epoch)\n        self.record('train/loss_WM', lossWM.avg, epoch)\n        self.record('train/loss_VGG', lossvgg.avg, epoch)\n        self.record('train/loss_SSIM', lossssim.avg, epoch)\n        self.record('train/loss_MX', lossMX.avg, epoch)\n\n\n\n\n    def validate(self, epoch):\n\n        self.current_epoch = epoch\n        \n        batch_time = AverageMeter()\n        data_time = AverageMeter()\n        losses = AverageMeter()\n        lossMask = AverageMeter()\n        psnres = AverageMeter()\n        ssimes = AverageMeter()\n\n       
 # switch to evaluate mode\n        self.model.eval()\n\n        end = time.time()\n        bar = Bar('Processing {} '.format(self.args.arch), max=len(self.val_loader))\n        with torch.no_grad():\n            for i, batches in enumerate(self.val_loader):\n\n                current_index = len(self.val_loader) * epoch + i\n\n                inputs = batches['image'].to(self.device)\n                target = batches['target'].to(self.device)\n\n                outputs = self.model(self.norm(inputs))\n                imoutput,immask,imwatermark = outputs\n                imoutput = imoutput[0] if is_dic(imoutput) else imoutput\n\n                imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))\n\n                if i % 300 == 0:\n                    # save the sample images\n                    ims = torch.cat([inputs,target,imfinal,immask.repeat(1,3,1,1)],dim=3)\n                    torchvision.utils.save_image(ims,os.path.join(self.args.checkpoint,'%s_%s.jpg'%(i,epoch)))\n\n                # here two choice: mseLoss or NLLLoss\n                psnr = 10 * log10(1 / F.mse_loss(imfinal,target).item())       \n\n                ssim = pytorch_ssim.ssim(imfinal,target)\n\n                psnres.update(psnr, inputs.size(0))\n                ssimes.update(ssim, inputs.size(0))\n\n                # measure elapsed time\n                batch_time.update(time.time() - end)\n                end = time.time()\n\n                # plot progress\n                bar.suffix  = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_L2: {loss_label:.4f} | Loss_Mask: {loss_mask:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}'.format(\n                            batch=i + 1,\n                            size=len(self.val_loader),\n                            data=data_time.val,\n                            bt=batch_time.val,\n                            total=bar.elapsed_td,\n                            eta=bar.eta_td,\n    
                        loss_label=losses.avg,\n                            loss_mask=lossMask.avg,\n                            psnr=psnres.avg,\n                            ssim=ssimes.avg\n                            )\n                bar.next()\n        bar.finish()\n        \n        print(\"Iter:%s,Losses:%s,PSNR:%.4f,SSIM:%.4f\"%(epoch, losses.avg,psnres.avg,ssimes.avg))\n        self.record('val/loss_L2', losses.avg, epoch)\n        self.record('val/lossMask', lossMask.avg, epoch)\n        self.record('val/PSNR', psnres.avg, epoch)\n        self.record('val/SSIM', ssimes.avg, epoch)\n        self.metric = psnres.avg\n\n        self.model.train()\n\n    def test(self, ):\n\n        # switch to evaluate mode\n        self.model.eval()\n        print(\"==> testing VM model \")\n        ssimes = AverageMeter()\n        psnres = AverageMeter()\n        ssimesx = AverageMeter()\n        psnresx = AverageMeter()\n\n        with torch.no_grad():\n            for i, batches in enumerate(tqdm(self.val_loader)):\n\n                inputs = batches['image'].to(self.device)\n                target = batches['target'].to(self.device)\n                mask =batches['mask'].to(self.device)\n\n                # select the outputs by the giving arch\n                outputs = self.model(self.norm(inputs))\n                imoutput,immask,imwatermark = outputs\n                imoutput = imoutput[0] if is_dic(imoutput) else imoutput\n\n                imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))\n                psnrx = 10 * log10(1 / F.mse_loss(imfinal,target).item())       \n                ssimx = pytorch_ssim.ssim(imfinal,target)\n                # recover the image to 255\n                imfinal = im_to_numpy(torch.clamp(imfinal[0]*255,min=0.0,max=255.0)).astype(np.uint8)\n                target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)\n\n                
skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), imfinal)\n\n                psnr = compare_psnr(target,imfinal)\n                ssim = compare_ssim(target,imfinal,multichannel=True)\n\n                psnres.update(psnr, inputs.size(0))\n                ssimes.update(ssim, inputs.size(0))\n                psnresx.update(psnrx, inputs.size(0))\n                ssimesx.update(ssimx, inputs.size(0))\n\n        print(\"%s:PSNR:%.5f(%.5f),SSIM:%.5f(%.5f)\"%(self.args.checkpoint,psnres.avg,psnresx.avg,ssimes.avg,ssimesx.avg))\n        print(\"DONE.\\n\")"
  },
  {
    "path": "scripts/machines/__init__.py",
    "content": "\nfrom .BasicMachine import BasicMachine\nfrom .VX import VX\nfrom .S2AM import S2AM\n\ndef basic(**kwargs):\n    return BasicMachine(**kwargs)\n\ndef s2am(**kwargs):\n    return S2AM(**kwargs)\n\ndef vx(**kwargs):\n    return VX(**kwargs)\n"
  },
  {
    "path": "scripts/models/__init__.py",
    "content": "from .vgg import *\nfrom .backbone_unet import *\nfrom .discriminator import *\n\n"
  },
  {
    "path": "scripts/models/backbone_unet.py",
    "content": "\n\nimport torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport functools\nimport math\n\nfrom scripts.utils.model_init import *\nfrom scripts.models.rasc import *\nfrom scripts.models.unet import UnetGenerator,MinimalUnetV2\nfrom scripts.models.vmu import UnetVM\nfrom scripts.models.sa_resunet import UnetVMS2AMv4\n\n\n# our method\ndef vvv4n(**kwargs):\n    return UnetVMS2AMv4(shared_depth=2, blocks=3, long_skip=True, use_vm_decoder=True,s2am='vms2am')\n\n\n# BVMR\ndef vm3(**kwargs):\n    return UnetVM(shared_depth=2, blocks=3, use_vm_decoder=True)\n\n\n# Blind version of S2AM\ndef urasc(**kwargs):\n    model = UnetGenerator(3,3,is_attention_layer=True,attention_model=URASC,basicblock=MinimalUnetV2)\n    model.apply(weights_init_kaiming)\n    return model\n\n\n# Improving the Harmony of the Composite Image by Spatial-Separated Attention Module\n# Xiaodong Cun and Chi-Man Pun\n# University of Macau\n# Trans. on Image Processing, vol. 29, pp. 4759-4771, 2020.\ndef rascv2(**kwargs):\n    model = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2)\n    model.apply(weights_init_kaiming)\n    return model\n\n# just original unet\ndef unet(**kwargs):\n    model = UnetGenerator(3,3)\n    model.apply(weights_init_kaiming)\n    return model\n\n\n"
  },
  {
    "path": "scripts/models/blocks.py",
    "content": "import torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport functools\nimport math\nimport numbers\n\nfrom scripts.utils.model_init import *\nfrom scripts.models.vgg import Vgg16\nfrom torch import nn, cuda\nfrom torch.autograd import Variable\n\nclass BasicLearningBlock(nn.Module):\n    \"\"\"docstring for BasicLearningBlock\"\"\"\n    def __init__(self,channel):\n        super(BasicLearningBlock, self).__init__()\n        self.rconv1 = nn.Conv2d(channel,channel*2,3,padding=1,bias=False)\n        self.rbn1 = nn.BatchNorm2d(channel*2)\n        self.rconv2 = nn.Conv2d(channel*2,channel,3,padding=1,bias=False)\n        self.rbn2 = nn.BatchNorm2d(channel)\n\n    def forward(self,feature):\n        return F.elu(self.rbn2(self.rconv2(F.elu(self.rbn1(self.rconv1(feature)))))) \n        \n\n\n# From https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/3\nclass GaussianSmoothing(nn.Module):\n    \"\"\"\n    Apply gaussian smoothing on a\n    1d, 2d or 3d tensor. Filtering is performed seperately for each channel\n    in the input using a depthwise convolution.\n    Arguments:\n        channels (int, sequence): Number of channels of the input tensors. 
Output will\n            have this number of channels as well.\n        kernel_size (int, sequence): Size of the gaussian kernel.\n        sigma (float, sequence): Standard deviation of the gaussian kernel.\n        dim (int, optional): The number of dimensions of the data.\n            Default value is 2 (spatial).\n    \"\"\"\n    def __init__(self, channels, kernel_size, sigma, dim=2):\n        super(GaussianSmoothing, self).__init__()\n        if isinstance(kernel_size, numbers.Number):\n            kernel_size = [kernel_size] * dim\n        if isinstance(sigma, numbers.Number):\n            sigma = [sigma] * dim\n\n        # The gaussian kernel is the product of the\n        # gaussian function of each dimension.\n        kernel = 1\n        meshgrids = torch.meshgrid(\n            [\n                torch.arange(size, dtype=torch.float32)\n                for size in kernel_size\n            ]\n        )\n        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):\n            mean = (size - 1) / 2\n            # exp(-(x - mean)^2 / (2 * std^2)), so that std really is the standard deviation\n            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \\\n                      torch.exp(-((mgrid - mean) / std) ** 2 / 2)\n\n        # Make sure sum of values in gaussian kernel equals 1.\n        kernel = kernel / torch.sum(kernel)\n\n        # Reshape to depthwise convolutional weight\n        kernel = kernel.view(1, 1, *kernel.size())\n        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))\n\n        self.register_buffer('weight', kernel)\n        self.groups = channels\n\n        if dim == 1:\n            self.conv = F.conv1d\n        elif dim == 2:\n            self.conv = F.conv2d\n        elif dim == 3:\n            self.conv = F.conv3d\n        else:\n            raise RuntimeError(\n                'Only 1, 2 and 3 dimensions are supported. 
Received {}.'.format(dim)\n            )\n\n    def forward(self, input):\n        \"\"\"\n        Apply gaussian filter to input.\n        Arguments:\n            input (torch.Tensor): Input to apply gaussian filter on.\n        Returns:\n            filtered (torch.Tensor): Filtered output.\n        \"\"\"\n        return self.conv(input, weight=self.weight, groups=self.groups)\n\nclass ChannelPool(nn.Module):\n    def __init__(self,types):\n        super(ChannelPool, self).__init__()\n        if types == 'avg':\n            self.poolingx = nn.AdaptiveAvgPool1d(1)\n        elif types == 'max':\n            self.poolingx = nn.AdaptiveMaxPool1d(1)\n        else:\n            # raising a bare string is invalid in Python 3\n            raise ValueError('unsupported pooling type: %s' % types)\n\n    def forward(self, input):\n        n, c, w, h = input.size()\n        input = input.view(n,c,w*h).permute(0,2,1)\n        pooled = self.poolingx(input)  # b,w*h,c -> b,w*h,1\n        _, _, c = pooled.size()\n        return pooled.view(n,c,w,h)\n\n\n\nclass SEBlock(nn.Module):\n    \"\"\"docstring for SEBlock\"\"\"\n    def __init__(self, channel,reducation=16):\n        super(SEBlock, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel,channel//reducation),\n            nn.ReLU(inplace=True),\n            nn.Linear(channel//reducation,channel),\n            nn.Sigmoid())\n\n    def forward(self,x):\n        b,c,w,h = x.size()\n        y1 = self.avg_pool(x).view(b,c)\n        y = self.fc(y1).view(b,c,1,1)\n        return x*y\n\n\n\nclass GlobalAttentionModule(nn.Module):\n    \"\"\"docstring for GlobalAttentionModule\"\"\"\n    def __init__(self, channel,reducation=16):\n        super(GlobalAttentionModule, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.max_pool = nn.AdaptiveMaxPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel*2,channel//reducation),\n            nn.ReLU(inplace=True),\n            nn.Linear(channel//reducation,channel),\n            nn.Sigmoid())\n\n    def forward(self,x):\n        b,c,w,h = x.size()\n        y1 = self.avg_pool(x).view(b,c)\n        y2 = self.max_pool(x).view(b,c)\n        y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1)\n        return x*y\n\nclass SpatialAttentionModule(nn.Module):\n    \"\"\"docstring for SpatialAttentionModule\"\"\"\n    def __init__(self, channel,reducation=16):\n        super(SpatialAttentionModule, self).__init__()\n        self.avg_pool = ChannelPool('avg')\n        self.max_pool = ChannelPool('max')\n        self.fc = nn.Sequential(\n            nn.Conv2d(2,reducation,7,stride=1,padding=3),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(reducation,1,7,stride=1,padding=3),\n            nn.Sigmoid())\n\n    def forward(self,x):\n        b,c,w,h = x.size()\n        y1 = self.avg_pool(x)\n        y2 = self.max_pool(x)\n        y = self.fc(torch.cat([y1,y2],1))\n        yr = 1-y\n        return y,yr\n\n\n\nclass GlobalAttentionModuleJustSigmoid(nn.Module):\n    \"\"\"docstring for GlobalAttentionModuleJustSigmoid\"\"\"\n    def __init__(self, channel,reducation=16):\n        super(GlobalAttentionModuleJustSigmoid, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.max_pool = nn.AdaptiveMaxPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel*2,channel//reducation),\n            nn.ReLU(inplace=True),\n            nn.Linear(channel//reducation,channel),\n            nn.Sigmoid())\n\n    def forward(self,x):\n        b,c,w,h = x.size()\n        y1 = self.avg_pool(x).view(b,c)\n        y2 = self.max_pool(x).view(b,c)\n        y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1)\n        return y\n\n\n\nclass BasicBlock(nn.Module):\n    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):\n        super(BasicBlock, self).__init__()\n        self.out_channels = 
out_planes\n        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)\n        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None\n        self.relu = nn.ReLU() if relu else None\n\n    def forward(self, x):\n        x = self.conv(x)\n        if self.bn is not None:\n            x = self.bn(x)\n        if self.relu is not None:\n            x = self.relu(x)\n        return x\n\nclass Flatten(nn.Module):\n    def forward(self, x):\n        return x.view(x.size(0), -1)\n\nclass ChannelGate(nn.Module):\n    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):\n        super(ChannelGate, self).__init__()\n        self.gate_channels = gate_channels\n        self.mlp = nn.Sequential(\n            Flatten(),\n            nn.Linear(gate_channels, gate_channels // reduction_ratio),\n            nn.ReLU(),\n            nn.Linear(gate_channels // reduction_ratio, gate_channels)\n            )\n        self.pool_types = pool_types\n    def forward(self, x):\n        channel_att_sum = None\n        for pool_type in self.pool_types:\n            if pool_type=='avg':\n                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n                channel_att_raw = self.mlp( avg_pool )\n            elif pool_type=='max':\n                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n                channel_att_raw = self.mlp( max_pool )\n            elif pool_type=='lp':\n                lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n                channel_att_raw = self.mlp( lp_pool )\n            elif pool_type=='lse':\n                # LSE pool only\n                lse_pool = logsumexp_2d(x)\n                channel_att_raw = self.mlp( lse_pool )\n\n            if channel_att_sum is None:\n              
  channel_att_sum = channel_att_raw\n            else:\n                channel_att_sum = channel_att_sum + channel_att_raw\n\n        scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)\n        return x * scale\n\ndef logsumexp_2d(tensor):\n    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)\n    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)\n    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()\n    return outputs\n\nclass ChannelPoolX(nn.Module):\n    def forward(self, x):\n        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )\n\nclass SpatialGate(nn.Module):\n    def __init__(self):\n        super(SpatialGate, self).__init__()\n        kernel_size = 7\n        self.compress = ChannelPoolX()\n        self.spatial = BasicBlock(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)\n    def forward(self, x):\n        x_compress = self.compress(x)\n        x_out = self.spatial(x_compress)\n        scale = torch.sigmoid(x_out) # broadcasting\n        return x * scale\n\nclass CBAM(nn.Module):\n    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):\n        super(CBAM, self).__init__()\n        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)\n        self.no_spatial=no_spatial\n        if not no_spatial:\n            self.SpatialGate = SpatialGate()\n    def forward(self, x):\n        x_out = self.ChannelGate(x)\n        if not self.no_spatial:\n            x_out = self.SpatialGate(x_out)\n        return x_out\n\n\n"
  },
  {
    "path": "scripts/models/discriminator.py",
    "content": "import numpy as np\nimport functools\nimport math\nimport torch\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch import Tensor\nfrom torch.nn import Parameter\nfrom scripts.utils.model_init import *\nfrom torch.optim.optimizer import Optimizer, required\n\n\n__all__ = ['patchgan','sngan','maskedsngan']\n\n\nclass SNCoXvWithActivation(torch.nn.Module):\n    \"\"\"\n    SN convolution for spetral normalization conv\n    \"\"\"\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):\n        super(SNCoXvWithActivation, self).__init__()\n        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)\n        self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)\n        self.activation = activation\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight)\n    def forward(self, input):\n        x = self.conv2d(input)\n        if self.activation is not None:\n            return self.activation(x)\n        else:\n            return x\n\ndef l2normalize(v, eps=1e-12):\n    return v / (v.norm() + eps)\n\n\nclass SpectralNorm(nn.Module):\n    def __init__(self, module, name='weight', power_iterations=1):\n        super(SpectralNorm, self).__init__()\n        self.module = module\n        self.name = name\n        self.power_iterations = power_iterations\n        if not self._made_params():\n            self._make_params()\n\n    def _update_u_v(self):\n        u = getattr(self.module, self.name + \"_u\")\n        v = getattr(self.module, self.name + \"_v\")\n        w = getattr(self.module, self.name + \"_bar\")\n\n        height = w.data.shape[0]\n        for _ in range(self.power_iterations):\n            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), 
u.data))\n            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))\n\n        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))\n        sigma = u.dot(w.view(height, -1).mv(v))\n        setattr(self.module, self.name, w / sigma.expand_as(w))\n\n    def _made_params(self):\n        try:\n            u = getattr(self.module, self.name + \"_u\")\n            v = getattr(self.module, self.name + \"_v\")\n            w = getattr(self.module, self.name + \"_bar\")\n            return True\n        except AttributeError:\n            return False\n\n\n    def _make_params(self):\n        w = getattr(self.module, self.name)\n\n        height = w.data.shape[0]\n        width = w.view(height, -1).data.shape[1]\n\n        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)\n        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)\n        u.data = l2normalize(u.data)\n        v.data = l2normalize(v.data)\n        w_bar = Parameter(w.data)\n\n        del self.module._parameters[self.name]\n\n        self.module.register_parameter(self.name + \"_u\", u)\n        self.module.register_parameter(self.name + \"_v\", v)\n        self.module.register_parameter(self.name + \"_bar\", w_bar)\n\n\n    def forward(self, *args):\n        self._update_u_v()\n        return self.module.forward(*args)\n\n\ndef get_pad(in_,  ksize, stride, atrous=1):\n    out_ = np.ceil(float(in_)/stride)\n    return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2)\n\nclass SNDiscriminator(nn.Module):\n    def __init__(self,channel=6):\n        super(SNDiscriminator, self).__init__()\n        cnum = 32\n        self.discriminator_net = nn.Sequential(\n            SNCoXvWithActivation(channel, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)),\n            SNCoXvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)), \n            SNCoXvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)),\n            
SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)),\n            SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)), # 8*8*256\n            # SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)), # 4*4*256\n            # SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(4, 5, 2)), # 2*2*256\n        )\n        # self.linear = nn.Linear(2*2*256,1)\n\n    def forward(self, img_A, img_B):\n        # Concatenate image and condition image by channels to produce input\n        img_input = torch.cat((img_A, img_B), 1)\n        x = self.discriminator_net(img_input)\n        # x = x.view((x.size(0),-1))\n        # x = self.linear(x)\n        return x\n\nclass Discriminator(nn.Module):\n    def __init__(self, in_channels=3):\n        super(Discriminator, self).__init__()\n\n        def discriminator_block(in_filters, out_filters, normalization=True):\n            \"\"\"Returns downsampling layers of each discriminator block\"\"\"\n            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]\n            if normalization:\n                layers.append(nn.InstanceNorm2d(out_filters))\n            layers.append(nn.LeakyReLU(0.2, inplace=True))\n            return layers\n\n        self.model = nn.Sequential(\n            *discriminator_block(in_channels*2, 64, normalization=False),\n            *discriminator_block(64, 128),\n            *discriminator_block(128, 256),\n            *discriminator_block(256, 512),\n            nn.ZeroPad2d((1, 0, 1, 0)),\n            nn.Conv2d(512, 1, 4, padding=1, bias=False)\n        )\n\n    def forward(self, img_A, img_B):\n        # Concatenate image and condition image by channels to produce input\n        img_input = torch.cat((img_A, img_B), 1)\n        return self.model(img_input)\n\n\ndef patchgan():\n    model = Discriminator()\n    model.apply(weights_init_kaiming)\n    return model\n\ndef sngan():\n    model = SNDiscriminator()\n    
model.apply(weights_init_kaiming)\n    return model\n\ndef maskedsngan():\n    model = SNDiscriminator(channel=7)\n    model.apply(weights_init_kaiming)\n    return model"
  },
  {
    "path": "scripts/models/rasc.py",
    "content": "\r\n\r\nimport torch\r\nimport torchvision\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport numpy as np\r\nimport math\r\n\r\nfrom scripts.utils.model_init import *\r\nfrom scripts.models.vgg import Vgg16\r\nfrom scripts.models.blocks import *\r\n\r\n\r\nclass CAWapper(nn.Module):\r\n    \"\"\"docstring for SENet\"\"\"\r\n\r\n    def __init__(self, channel, type_of_connection=BasicLearningBlock):\r\n        super(CAWapper, self).__init__()\r\n        self.attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10, fuse=True, use_cuda=True)\r\n\r\n    def forward(self, feature, mask):\r\n        _, _, w, _ = feature.size()\r\n        _, _, mw, _ = mask.size()\r\n        # binaryfiy\r\n        # selected the feature from the background as the additional feature to masked splicing feature.\r\n        mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w))\r\n\r\n        result = self.attention(feature,mask)\r\n\r\n        return result\r\n\r\n\r\nclass NLWapper(nn.Module):\r\n    \"\"\"docstring for SENet\"\"\"\r\n\r\n    def __init__(self, channel, type_of_connection=BasicLearningBlock):\r\n        super(NLWapper, self).__init__()\r\n        self.attention = NONLocalBlock2D(channel)\r\n\r\n    def forward(self, feature, mask):\r\n        _, _, w, _ = feature.size()\r\n        _, _, mw, _ = mask.size()\r\n        # binaryfiy\r\n        # selected the feature from the background as the additional feature to masked splicing feature.\r\n        # mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w))\r\n\r\n        result = self.attention(feature)\r\n\r\n        return result\r\n\r\nclass SENet(nn.Module):\r\n    \"\"\"docstring for SENet\"\"\"\r\n    def __init__(self,channel,type_of_connection=BasicLearningBlock):\r\n        super(SENet, self).__init__()\r\n        self.attention = SEBlock(channel,16)\r\n\r\n    def forward(self,feature,mask):\r\n        _,_,w,_ = feature.size()\r\n        _,_,mw,_ = 
mask.size()\r\n        # binarize the mask\r\n        # select the feature from the background as the additional feature to the masked splicing feature.\r\n        mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w))\r\n\r\n        result = self.attention(feature) \r\n        \r\n        return result\r\n\r\nclass CBAMConnect(nn.Module):\r\n    def __init__(self,channel):\r\n        super(CBAMConnect, self).__init__()\r\n        self.attention = CBAM(channel)\r\n\r\n    def forward(self,feature,mask):\r\n        results = self.attention(feature)\r\n        return results\r\n\r\n\r\n\r\nclass RASC(nn.Module):\r\n    def __init__(self,channel,type_of_connection=BasicLearningBlock):\r\n        super(RASC, self).__init__()\r\n        self.connection = type_of_connection(channel)\r\n        self.background_attention = GlobalAttentionModule(channel,16)\r\n        self.mixed_attention = GlobalAttentionModule(channel,16)\r\n        self.spliced_attention = GlobalAttentionModule(channel,16)\r\n        self.gaussianMask = GaussianSmoothing(1,5,1)\r\n\r\n    def forward(self,feature,mask):\r\n        _,_,w,_ = feature.size()\r\n        _,_,mw,_ = mask.size()\r\n        # binarize the mask\r\n        # select the feature from the background as the additional feature to the masked splicing feature.\r\n        if w != mw:\r\n            mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w))\r\n        reverse_mask = -1*(mask-1)\r\n        # here we apply a gaussian filter to mask and reverse_mask for better harmonization of edges.\r\n\r\n        mask = self.gaussianMask(F.pad(mask,(2,2,2,2),mode='reflect'))\r\n        reverse_mask = self.gaussianMask(F.pad(reverse_mask,(2,2,2,2),mode='reflect'))\r\n\r\n\r\n        background = self.background_attention(feature) * reverse_mask\r\n        selected_feature = self.mixed_attention(feature)\r\n        spliced_feature = self.spliced_attention(feature) \r\n        spliced = ( self.connection(spliced_feature) + selected_feature ) * mask\r\n        return 
background + spliced    \r\n\r\n\r\nclass UNO(nn.Module):\r\n    def __init__(self,channel):\r\n        super(UNO, self).__init__()\r\n\r\n    def forward(self,feature,_m):\r\n        return feature \r\n\r\n\r\nclass URASC(nn.Module):\r\n    def __init__(self,channel,type_of_connection=BasicLearningBlock):\r\n        super(URASC, self).__init__()\r\n        self.connection = type_of_connection(channel)\r\n        self.background_attention = GlobalAttentionModule(channel,16)\r\n        self.mixed_attention = GlobalAttentionModule(channel,16)\r\n        self.spliced_attention = GlobalAttentionModule(channel,16)\r\n        self.mask_attention = SpatialAttentionModule(channel,16)\r\n\r\n    def forward(self,feature, m=None):\r\n        _,_,w,_ = feature.size()\r\n      \r\n        mask, reverse_mask = self.mask_attention(feature)\r\n\r\n        background = self.background_attention(feature) * reverse_mask\r\n        selected_feature = self.mixed_attention(feature)\r\n        spliced_feature = self.spliced_attention(feature) \r\n        spliced = ( self.connection(spliced_feature) + selected_feature ) * mask\r\n        return background + spliced  \r\n\r\n\r\nclass MaskedURASC(nn.Module):\r\n    def __init__(self,channel,type_of_connection=BasicLearningBlock):\r\n        super(MaskedURASC, self).__init__()\r\n        self.connection = type_of_connection(channel)\r\n        self.background_attention = GlobalAttentionModule(channel,16)\r\n        self.mixed_attention = GlobalAttentionModule(channel,16)\r\n        self.spliced_attention = GlobalAttentionModule(channel,16)\r\n        self.mask_attention = SpatialAttentionModule(channel,16)\r\n\r\n    def forward(self,feature):\r\n        _,_,w,_ = feature.size()\r\n      \r\n        mask, reverse_mask = self.mask_attention(feature)\r\n\r\n        background = self.background_attention(feature) * reverse_mask\r\n        selected_feature = self.mixed_attention(feature)\r\n        spliced_feature = 
self.spliced_attention(feature) \r\n        spliced = ( self.connection(spliced_feature) + selected_feature ) * mask\r\n        return background + spliced, mask\r\n\r\n"
  },
  {
    "path": "scripts/models/sa_resunet.py",
    "content": "\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nfrom scripts.models.blocks import SEBlock\r\nfrom scripts.models.rasc import *\r\nfrom scripts.models.unet import UnetGenerator,MinimalUnetV2\r\n\r\ndef weight_init(m):\r\n    if isinstance(m, nn.Conv2d):\r\n        nn.init.xavier_normal_(m.weight)\r\n        nn.init.constant_(m.bias, 0)\r\n\r\ndef reset_params(model):\r\n    for i, m in enumerate(model.modules()):\r\n        weight_init(m)\r\n\r\n\r\ndef conv3x3(in_channels, out_channels, stride=1,\r\n            padding=1, bias=True, groups=1):\r\n    return nn.Conv2d(\r\n        in_channels,\r\n        out_channels,\r\n        kernel_size=3,\r\n        stride=stride,\r\n        padding=padding,\r\n        bias=bias,\r\n        groups=groups)\r\n\r\n\r\ndef up_conv2x2(in_channels, out_channels, transpose=True):\r\n    if transpose:\r\n        return nn.ConvTranspose2d(\r\n            in_channels,\r\n            out_channels,\r\n            kernel_size=2,\r\n            stride=2)\r\n    else:\r\n        return nn.Sequential(\r\n            nn.Upsample(mode='bilinear', scale_factor=2),\r\n            conv1x1(in_channels, out_channels))\r\n\r\n\r\ndef conv1x1(in_channels, out_channels, groups=1):\r\n    return nn.Conv2d(\r\n        in_channels,\r\n        out_channels,\r\n        kernel_size=1,\r\n        groups=groups,\r\n        stride=1)\r\n\r\n\r\nclass UpCoXvD(nn.Module):\r\n\r\n    def __init__(self, in_channels, out_channels, blocks, residual=True,norm=nn.BatchNorm2d, act=F.relu,batch_norm=True, transpose=True,concat=True,use_att=False):\r\n        super(UpCoXvD, self).__init__()\r\n        self.concat = concat\r\n        self.residual = residual\r\n        self.batch_norm = batch_norm\r\n        self.bn = None\r\n        self.conv2 = []\r\n        self.use_att = use_att\r\n        self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose)\r\n        self.norm0 = norm(out_channels)\r\n        \r\n    
    if self.use_att:\r\n            self.s2am = RASC(2 * out_channels)\r\n        else:\r\n            self.s2am = None\r\n\r\n        if self.concat:\r\n            self.conv1 = conv3x3(2 * out_channels, out_channels)\r\n            self.norm1 = norm(out_channels)\r\n        else:\r\n            self.conv1 = conv3x3(out_channels, out_channels)\r\n            self.norm1 = norm(out_channels)\r\n\r\n        for _ in range(blocks):\r\n            self.conv2.append(conv3x3(out_channels, out_channels))\r\n        if self.batch_norm:\r\n            self.bn = []\r\n            for _ in range(blocks):\r\n                self.bn.append(norm(out_channels))\r\n            self.bn = nn.ModuleList(self.bn)\r\n        self.conv2 = nn.ModuleList(self.conv2)\r\n        self.act = act\r\n\r\n    def forward(self, from_up, from_down, mask=None,se=None):\r\n        from_up = self.act(self.norm0(self.up_conv(from_up)))\r\n        if self.concat:\r\n            x1 = torch.cat((from_up, from_down), 1)\r\n        else:\r\n            if from_down is not None:\r\n                x1 = from_up + from_down\r\n            else:\r\n                x1 = from_up\r\n\r\n        if self.use_att:\r\n            x1 = self.s2am(x1,mask)\r\n        \r\n        x1 = self.act(self.norm1(self.conv1(x1)))\r\n        x2 = None\r\n        for idx, conv in enumerate(self.conv2):\r\n            x2 = conv(x1)\r\n            if self.batch_norm:\r\n                x2 = self.bn[idx](x2)\r\n            \r\n            if (se is not None) and (idx == len(self.conv2) - 1): # last \r\n                x2 = se(x2)\r\n\r\n            if self.residual:\r\n                x2 = x2 + x1\r\n            x2 = self.act(x2)\r\n            x1 = x2\r\n        return x2\r\n\r\n\r\nclass DownCoXvD(nn.Module):\r\n\r\n    def __init__(self, in_channels, out_channels, blocks, pooling=True, norm=nn.BatchNorm2d,act=F.relu,residual=True, batch_norm=True):\r\n        super(DownCoXvD, self).__init__()\r\n       
 self.pooling = pooling\r\n        self.residual = residual\r\n        self.batch_norm = batch_norm\r\n        self.bn = None\r\n        self.pool = None\r\n        self.conv1 = conv3x3(in_channels, out_channels)\r\n        self.norm1 = norm(out_channels)\r\n\r\n        self.conv2 = []\r\n        for _ in range(blocks):\r\n            self.conv2.append(conv3x3(out_channels, out_channels))\r\n        if self.batch_norm:\r\n            self.bn = []\r\n            for _ in range(blocks):\r\n                self.bn.append(norm(out_channels))\r\n            self.bn = nn.ModuleList(self.bn)\r\n        if self.pooling:\r\n            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\r\n        self.conv2 = nn.ModuleList(self.conv2)\r\n        self.act = act\r\n\r\n    def __call__(self, x):\r\n        return self.forward(x)\r\n\r\n    def forward(self, x):\r\n        x1 = self.act(self.norm1(self.conv1(x)))\r\n        x2 = None\r\n        for idx, conv in enumerate(self.conv2):\r\n            x2 = conv(x1)\r\n            if self.batch_norm:\r\n                x2 = self.bn[idx](x2)\r\n            if self.residual:\r\n                x2 = x2 + x1\r\n            x2 = self.act(x2)\r\n            x1 = x2\r\n        before_pool = x2\r\n        if self.pooling:\r\n            x2 = self.pool(x2)\r\n        return x2, before_pool\r\n\r\nclass UnetDecoderD(nn.Module):\r\n    def __init__(self, in_channels=512, out_channels=3, norm=nn.BatchNorm2d,act=F.relu, depth=5, blocks=1, residual=True, batch_norm=True,\r\n                 transpose=True, concat=True, is_final=True, use_att=False):\r\n        super(UnetDecoderD, self).__init__()\r\n        self.conv_final = None\r\n        self.up_convs = []\r\n        self.atts = []\r\n        self.use_att = use_att\r\n\r\n        outs = in_channels\r\n        for i in range(depth-1): # depth = 1\r\n            ins = outs\r\n            outs = ins // 2\r\n            # 512,256\r\n            # 256,128\r\n            # 128,64\r\n            # 
64,32\r\n            up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,\r\n                              concat=concat, norm=norm, act=act)\r\n            if self.use_att:\r\n                self.atts.append(SEBlock(outs))\r\n            \r\n            self.up_convs.append(up_conv)\r\n\r\n        if is_final:\r\n            self.conv_final = conv1x1(outs, out_channels)\r\n        else:\r\n            up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,\r\n                              concat=concat,norm=norm, act=act)\r\n            if self.use_att:\r\n                self.atts.append(SEBlock(out_channels))\r\n\r\n            self.up_convs.append(up_conv)\r\n        self.up_convs = nn.ModuleList(self.up_convs)\r\n        self.atts = nn.ModuleList(self.atts)\r\n\r\n        reset_params(self)\r\n\r\n    def __call__(self, x, encoder_outs=None):\r\n        return self.forward(x, encoder_outs)\r\n\r\n    def forward(self, x, encoder_outs=None):\r\n        for i, up_conv in enumerate(self.up_convs):\r\n            before_pool = None\r\n            if encoder_outs is not None:\r\n                before_pool = encoder_outs[-(i+2)]\r\n            x = up_conv(x, before_pool)\r\n            if self.use_att:\r\n                x = self.atts[i](x)\r\n\r\n        if self.conv_final is not None:\r\n            x = self.conv_final(x)\r\n        return x\r\n\r\n\r\nclass UnetDecoderDatt(nn.Module):\r\n    def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True,\r\n                 transpose=True, concat=True, is_final=True, norm=nn.BatchNorm2d,act=F.relu):\r\n        super(UnetDecoderDatt, self).__init__()\r\n        self.conv_final = None\r\n        self.up_convs = []\r\n        self.im_atts = []\r\n        self.vm_atts = []\r\n        self.mask_atts = []\r\n\r\n        outs = in_channels\r\n        for i in range(depth-1): # 
depth = 5 [0,1,2,3]\r\n            ins = outs\r\n            outs = ins // 2\r\n            # 512,256\r\n            # 256,128\r\n            # 128,64\r\n            # 64,32\r\n            up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,\r\n                              concat=concat, norm=nn.BatchNorm2d,act=F.relu)\r\n            self.up_convs.append(up_conv)\r\n            self.im_atts.append(SEBlock(outs))\r\n            self.vm_atts.append(SEBlock(outs))\r\n            self.mask_atts.append(SEBlock(outs))\r\n        if is_final:\r\n            self.conv_final = conv1x1(outs, out_channels)\r\n        else:\r\n            up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,\r\n                              concat=concat, norm=nn.BatchNorm2d,act=F.relu)\r\n            self.up_convs.append(up_conv)\r\n            self.im_atts.append(SEBlock(out_channels))\r\n            self.vm_atts.append(SEBlock(out_channels))\r\n            self.mask_atts.append(SEBlock(out_channels))\r\n\r\n        self.up_convs = nn.ModuleList(self.up_convs)\r\n        self.im_atts = nn.ModuleList(self.im_atts)\r\n        self.vm_atts = nn.ModuleList(self.vm_atts)\r\n        self.mask_atts = nn.ModuleList(self.mask_atts)\r\n\r\n        reset_params(self)\r\n\r\n    def forward(self, input, encoder_outs=None):\r\n        # im branch\r\n        x = input\r\n        for i, up_conv in enumerate(self.up_convs):\r\n            before_pool = None\r\n            if encoder_outs is not None:\r\n                before_pool = encoder_outs[-(i+2)]\r\n            x = up_conv(x, before_pool,se=self.im_atts[i])\r\n        x_im = x\r\n\r\n        x = input        \r\n        for i, up_conv in enumerate(self.up_convs):\r\n            before_pool = None\r\n            if encoder_outs is not None:\r\n                before_pool = encoder_outs[-(i+2)]\r\n            x = up_conv(x, before_pool, se = 
self.mask_atts[i])\r\n        x_mask = x\r\n\r\n        x = input\r\n        for i, up_conv in enumerate(self.up_convs):\r\n            before_pool = None\r\n            if encoder_outs is not None:\r\n                before_pool = encoder_outs[-(i+2)]\r\n            x = up_conv(x, before_pool, se=self.vm_atts[i])\r\n        x_vm = x\r\n\r\n        return x_im,x_mask,x_vm\r\n\r\nclass UnetEncoderD(nn.Module):\r\n\r\n    def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True, norm=nn.BatchNorm2d, act=F.relu):\r\n        super(UnetEncoderD, self).__init__()\r\n        self.down_convs = []\r\n        outs = None\r\n        if type(blocks) is tuple:\r\n            blocks = blocks[0]\r\n        for i in range(depth):\r\n            ins = in_channels if i == 0 else outs\r\n            outs = start_filters*(2**i)\r\n            pooling = True if i < depth-1 else False\r\n            down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm, norm=nn.BatchNorm2d, act=F.relu)\r\n            self.down_convs.append(down_conv)\r\n        self.down_convs = nn.ModuleList(self.down_convs)\r\n        reset_params(self)\r\n\r\n    def __call__(self, x):\r\n        return self.forward(x)\r\n\r\n    def forward(self, x):\r\n        encoder_outs = []\r\n        for d_conv in self.down_convs:\r\n            x, before_pool = d_conv(x)\r\n            encoder_outs.append(before_pool)\r\n        return x, encoder_outs\r\n\r\nclass ResDown(nn.Module):\r\n    def __init__(self, in_size, out_size, pooling=True, use_att=False):\r\n        super(ResDown, self).__init__()\r\n        self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling)\r\n\r\n    def forward(self, x):\r\n        return self.model(x)\r\n\r\nclass ResUp(nn.Module):\r\n    def __init__(self, in_size, out_size, use_att=False):\r\n        super(ResUp, self).__init__()\r\n        self.model = UpCoXvD(in_size, out_size, 3, 
use_att=use_att)\r\n\r\n    def forward(self, x, skip_input, mask=None):\r\n        return self.model(x,skip_input,mask)\r\n\r\nclass ResDownNew(nn.Module):\r\n    def __init__(self, in_size, out_size, pooling=True, use_att=False):\r\n        super(ResDownNew, self).__init__()\r\n        self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling, norm=nn.InstanceNorm2d, act=F.leaky_relu)\r\n\r\n    def forward(self, x):\r\n        return self.model(x)\r\n\r\nclass ResUpNew(nn.Module):\r\n    def __init__(self, in_size, out_size, use_att=False):\r\n        super(ResUpNew, self).__init__()\r\n        self.model = UpCoXvD(in_size, out_size, 3, use_att=use_att, norm=nn.InstanceNorm2d)\r\n\r\n    def forward(self, x, skip_input, mask=None):\r\n        return self.model(x,skip_input,mask)\r\n\r\n\r\n\r\nclass VMSingle(nn.Module):\r\n    def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32, res=True,use_att=False):\r\n        super(VMSingle, self).__init__()\r\n\r\n        self.down1 = down(in_channels, ngf)\r\n        self.down2 = down(ngf, ngf*2)\r\n        self.down3 = down(ngf*2, ngf*4)\r\n        self.down4 = down(ngf*4, ngf*8)\r\n        self.down5 = down(ngf*8, ngf*16, pooling=False)\r\n\r\n        self.up1 = up(ngf*16, ngf*8)\r\n        self.up2 = up(ngf*8, ngf*4, use_att=use_att)\r\n        self.up3 = up(ngf*4, ngf*2, use_att=use_att)\r\n        self.up4 = up(ngf*2, ngf*1, use_att=use_att)\r\n\r\n        self.im = nn.Conv2d(ngf, 3, 1)\r\n        self.res = res\r\n\r\n\r\n    def forward(self, input):\r\n        img, mask = input[:,0:3,:,:],input[:,3:4,:,:]\r\n        # U-Net generator with skip connections from encoder to decoder\r\n        x,d1 = self.down1(input) # 128,256\r\n        x,d2 = self.down2(x) # 64,128\r\n        x,d3 = self.down3(x) # 32,64\r\n        x,d4 = self.down4(x) # 16,32\r\n        x,_ = self.down5(x) # 8,16\r\n\r\n        x = self.up1(x, d4) # 16\r\n        x = self.up2(x, d3, mask) # 32\r\n        x = 
self.up3(x, d2, mask) # 64\r\n        x = self.up4(x, d1, mask) # 128\r\n        im = self.im(x)\r\n\r\n        return im\r\n\r\n\r\n\r\nclass VMSingleS2AM(nn.Module):\r\n    def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32):\r\n        super(VMSingleS2AM, self).__init__()\r\n\r\n        self.down1 = down(in_channels, ngf)\r\n        self.down2 = down(ngf, ngf*2)\r\n        self.down3 = down(ngf*2, ngf*4)\r\n        self.down4 = down(ngf*4, ngf*8)\r\n        self.down5 = down(ngf*8, ngf*16, pooling=False)\r\n\r\n        self.up1 = up(ngf*16, ngf*8)\r\n        self.up2 = up(ngf*8, ngf*4)\r\n        self.s2am2 = RASC(ngf*4)\r\n        \r\n        self.up3 = up(ngf*4, ngf*2)\r\n        self.s2am3 = RASC(ngf*2)\r\n\r\n        self.up4 = up(ngf*2, ngf*1)\r\n        self.s2am4 = RASC(ngf)\r\n\r\n        self.im = nn.Conv2d(ngf, 3, 1)\r\n\r\n\r\n    def forward(self, input):\r\n        img, mask = input[:,0:3,:,:],input[:,3:4,:,:]\r\n        # U-Net generator with skip connections from encoder to decoder\r\n        x,d1 = self.down1(input) # 128,256\r\n        x,d2 = self.down2(x) # 64,128\r\n        x,d3 = self.down3(x) # 32,64\r\n        x,d4 = self.down4(x) # 16,32\r\n        x,_ = self.down5(x) # 8,16\r\n\r\n        x = self.up1(x, d4) # 16\r\n        x = self.up2(x, d3) # 32\r\n        x = self.s2am2(x, mask)\r\n\r\n        x = self.up3(x, d2) # 64\r\n        x = self.s2am3(x, mask)\r\n\r\n        x = self.up4(x, d1) # 128\r\n        x = self.s2am4(x, mask)\r\n        im = self.im(x)\r\n        return im\r\n\r\n\r\nclass UnetVMS2AMv4(nn.Module):\r\n\r\n    def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1,\r\n                 out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True,\r\n                 transpose=True, concat=True, transfer_data=True, long_skip=False, s2am='unet', use_coarser=True,no_stage2=False):\r\n        super(UnetVMS2AMv4, 
self).__init__()\r\n        self.transfer_data = transfer_data\r\n        self.shared = shared_depth\r\n        self.optimizer_encoder,  self.optimizer_image, self.optimizer_vm = None, None, None\r\n        self.optimizer_mask, self.optimizer_shared = None, None\r\n        if type(blocks) is not tuple:\r\n            blocks = (blocks, blocks, blocks, blocks, blocks)\r\n        if not transfer_data:\r\n            concat = False\r\n        self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0],\r\n                                    start_filters=start_filters, residual=residual, batch_norm=batch_norm,norm=nn.InstanceNorm2d,act=F.leaky_relu)\r\n        self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),\r\n                                          out_channels=out_channels_image, depth=depth - shared_depth,\r\n                                          blocks=blocks[1], residual=residual, batch_norm=batch_norm,\r\n                                          transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)\r\n        self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),\r\n                                         out_channels=out_channels_mask, depth=depth - shared_depth,\r\n                                         blocks=blocks[2], residual=residual, batch_norm=batch_norm,\r\n                                         transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)\r\n        self.vm_decoder = None\r\n        if use_vm_decoder:\r\n            self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),\r\n                                           out_channels=out_channels_image, depth=depth - shared_depth,\r\n                                           blocks=blocks[3], residual=residual, batch_norm=batch_norm,\r\n                                           transpose=transpose, 
concat=concat,norm=nn.InstanceNorm2d)\r\n        self.shared_decoder = None\r\n        self.use_coarser = use_coarser\r\n        self.long_skip = long_skip\r\n        self.no_stage2 = no_stage2\r\n        self._forward = self.unshared_forward\r\n        if self.shared != 0:\r\n            self._forward = self.shared_forward\r\n            self.shared_decoder = UnetDecoderDatt(in_channels=start_filters * 2 ** (depth - 1),\r\n                                               out_channels=start_filters * 2 ** (depth - shared_depth - 1),\r\n                                               depth=shared_depth, blocks=blocks[4], residual=residual,\r\n                                               batch_norm=batch_norm, transpose=transpose, concat=concat,\r\n                                               is_final=False,norm=nn.InstanceNorm2d)\r\n\r\n        if s2am == 'unet':\r\n            self.s2am = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2)\r\n        elif s2am == 'vm':\r\n            self.s2am = VMSingle(4)\r\n        elif s2am == 'vms2am':\r\n            self.s2am = VMSingleS2AM(4,down=ResDownNew,up=ResUpNew)\r\n\r\n    def set_optimizers(self):\r\n        self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001)\r\n        self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001)\r\n        self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001)\r\n        self.optimizer_s2am = torch.optim.Adam(self.s2am.parameters(), lr=0.001)\r\n\r\n        if self.vm_decoder is not None:\r\n            self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001)\r\n        if self.shared != 0:\r\n            self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001)\r\n\r\n    def zero_grad_all(self):\r\n        self.optimizer_encoder.zero_grad()\r\n        self.optimizer_image.zero_grad()\r\n        
self.optimizer_mask.zero_grad()\r\n        self.optimizer_s2am.zero_grad()\r\n        if self.vm_decoder is not None:\r\n            self.optimizer_vm.zero_grad()\r\n        if self.shared != 0:\r\n            self.optimizer_shared.zero_grad()\r\n\r\n    def step_all(self):\r\n        self.optimizer_encoder.step()\r\n        self.optimizer_image.step()\r\n        self.optimizer_mask.step()\r\n        self.optimizer_s2am.step()\r\n        if self.vm_decoder is not None:\r\n            self.optimizer_vm.step()\r\n        if self.shared != 0:\r\n            self.optimizer_shared.step()\r\n\r\n    def step_optimizer_image(self):\r\n        self.optimizer_image.step()\r\n\r\n    def __call__(self, synthesized):\r\n        return self._forward(synthesized)\r\n\r\n    def forward(self, synthesized):\r\n        return self._forward(synthesized)\r\n\r\n    def unshared_forward(self, synthesized):\r\n        image_code, before_pool = self.encoder(synthesized)\r\n        if not self.transfer_data:\r\n            before_pool = None\r\n        reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool))\r\n        reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))\r\n        if self.vm_decoder is not None:\r\n            reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool))\r\n            return reconstructed_image, reconstructed_mask, reconstructed_vm\r\n        return reconstructed_image, reconstructed_mask\r\n\r\n    def shared_forward(self, synthesized):\r\n        image_code, before_pool = self.encoder(synthesized)\r\n        if self.transfer_data:\r\n            shared_before_pool = before_pool[- self.shared - 1:]\r\n            unshared_before_pool = before_pool[: - self.shared]\r\n        else:\r\n            before_pool = None\r\n            shared_before_pool = None\r\n            unshared_before_pool = None\r\n        im,mask,vm = self.shared_decoder(image_code, shared_before_pool)\r\n        
reconstructed_image = torch.tanh(self.image_decoder(im, unshared_before_pool))\r\n        if self.long_skip:\r\n            reconstructed_image = reconstructed_image + synthesized\r\n\r\n        reconstructed_mask = torch.sigmoid(self.mask_decoder(mask, unshared_before_pool))\r\n        if self.vm_decoder is not None:\r\n            reconstructed_vm = torch.tanh(self.vm_decoder(vm, unshared_before_pool))\r\n            if self.long_skip:\r\n                reconstructed_vm = reconstructed_vm + synthesized\r\n\r\n        coarser = reconstructed_image * reconstructed_mask + (1-reconstructed_mask)* synthesized\r\n        \r\n        if self.use_coarser:\r\n            refine =  torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) + coarser\r\n        elif self.no_stage2:\r\n            refine =  torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1)))\r\n        else:\r\n            refine =  torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) + synthesized\r\n\r\n        # final = refine * reconstructed_mask + (1-reconstructed_mask)* synthesized\r\n        if self.vm_decoder is not None:\r\n            return [refine, reconstructed_image], reconstructed_mask, reconstructed_vm\r\n        else:\r\n            return [refine, reconstructed_image], reconstructed_mask\r\n\r\n\r\n"
  },
  {
    "path": "scripts/models/unet.py",
"content": "import torch\r\nimport torch.nn as nn\r\nfrom torch.nn import init\r\nimport functools\r\nfrom scripts.models.blocks import *\r\nfrom scripts.models.rasc import *\r\n\r\n\r\nclass MinimalUnetV2(nn.Module):\r\n    \"\"\"U-Net building block that applies attention to the concatenated skip and upsampled features.\"\"\"\r\n    def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwargs):\r\n        super(MinimalUnetV2, self).__init__()\r\n        \r\n        self.down = nn.Sequential(*down)\r\n        self.up = nn.Sequential(*up) \r\n        self.sub = submodule\r\n        self.attention = attention\r\n        self.withoutskip = withoutskip\r\n        self.is_attention = self.attention is not None\r\n        self.is_sub = submodule is not None\r\n    \r\n    def forward(self,x,mask=None):\r\n        if self.is_sub: \r\n            x_up,_ = self.sub(self.down(x),mask)\r\n        else:\r\n            x_up = self.down(x)\r\n\r\n        if self.withoutskip: #outer or inner.\r\n            x_out = self.up(x_up)\r\n        else:\r\n            if self.is_attention:\r\n                x_out = (self.attention(torch.cat([x,self.up(x_up)],1),mask),mask)\r\n            else:\r\n                x_out = (torch.cat([x,self.up(x_up)],1),mask)\r\n\r\n        return x_out\r\n\r\n\r\nclass MinimalUnet(nn.Module):\r\n    \"\"\"U-Net building block that applies attention to the skip features before concatenation.\"\"\"\r\n    def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwargs):\r\n        super(MinimalUnet, self).__init__()\r\n        \r\n        self.down = nn.Sequential(*down)\r\n        self.up = nn.Sequential(*up) \r\n        self.sub = submodule\r\n        self.attention = attention\r\n        self.withoutskip = withoutskip\r\n        self.is_attention = self.attention is not None\r\n        self.is_sub = submodule is not None\r\n    \r\n    def forward(self,x,mask=None):\r\n        if self.is_sub: \r\n            x_up,_ = self.sub(self.down(x),mask)\r\n        else:\r\n            x_up = self.down(x)\r\n\r\n        
if self.is_attention:\r\n            x = self.attention(x,mask)\r\n        \r\n        if self.withoutskip: #outer or inner.\r\n            x_out = self.up(x_up)\r\n        else:\r\n            x_out = (torch.cat([x,self.up(x_up)],1),mask)\r\n\r\n        return x_out\r\n\r\n\r\n# Defines the submodule with skip connection.\r\n# X -------------------identity---------------------- X\r\n#   |-- downsampling -- |submodule| -- upsampling --|\r\nclass UnetSkipConnectionBlock(nn.Module):\r\n    def __init__(self, outer_nc, inner_nc, input_nc=None,\r\n                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,is_attention_layer=False,\r\n                 attention_model=RASC,basicblock=MinimalUnet,outermostattention=False):\r\n        super(UnetSkipConnectionBlock, self).__init__()\r\n        self.outermost = outermost\r\n        if type(norm_layer) == functools.partial:\r\n            use_bias = norm_layer.func == nn.InstanceNorm2d\r\n        else:\r\n            use_bias = norm_layer == nn.InstanceNorm2d\r\n        if input_nc is None:\r\n            input_nc = outer_nc\r\n        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,\r\n                             stride=2, padding=1, bias=use_bias)\r\n        downrelu = nn.LeakyReLU(0.2, True)\r\n        downnorm = norm_layer(inner_nc)\r\n        uprelu = nn.ReLU(True)\r\n        upnorm = norm_layer(outer_nc)\r\n\r\n\r\n        if outermost:\r\n            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,\r\n                                        kernel_size=4, stride=2,\r\n                                        padding=1)\r\n            down = [downconv]\r\n            up = [uprelu, upconv]\r\n            model = basicblock(down,up,submodule,withoutskip=outermost)\r\n        elif innermost:\r\n            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,\r\n                                        kernel_size=4, stride=2,\r\n                                       
 padding=1, bias=use_bias)\r\n            down = [downrelu, downconv]\r\n            up = [uprelu, upconv, upnorm]\r\n            model = basicblock(down,up)\r\n        else:\r\n            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,\r\n                                        kernel_size=4, stride=2,\r\n                                        padding=1, bias=use_bias)\r\n            down = [downrelu, downconv, downnorm]\r\n            up = [uprelu, upconv, upnorm]\r\n\r\n            if is_attention_layer:\r\n                if MinimalUnetV2.__qualname__ in basicblock.__qualname__  :\r\n                    attention_model = attention_model(input_nc*2)\r\n                else:\r\n                    attention_model = attention_model(input_nc)     \r\n            else:\r\n                attention_model = None\r\n                \r\n            if use_dropout:\r\n                # list.append returns None, so concatenate instead of passing None as the up path\r\n                model = basicblock(down,up + [nn.Dropout(0.5)],submodule,attention_model,outermostattention=outermostattention)\r\n            else:\r\n                model = basicblock(down,up,submodule,attention_model,outermostattention=outermostattention)\r\n\r\n        self.model = model\r\n\r\n\r\n    def forward(self, x,mask=None):\r\n        # build the mask for attention use\r\n        return self.model(x,mask)\r\n            \r\nclass UnetGenerator(nn.Module):\r\n    def __init__(self, input_nc, output_nc, num_downs=8, ngf=64,norm_layer=nn.BatchNorm2d, use_dropout=False,\r\n                 is_attention_layer=False,attention_model=RASC,use_inner_attention=False,basicblock=MinimalUnet):\r\n        super(UnetGenerator, self).__init__()\r\n\r\n        # 8 for 256x256\r\n        # 9 for 512x512\r\n        # construct unet structure\r\n        self.need_mask = not input_nc == output_nc\r\n\r\n        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True,basicblock=basicblock) # 1\r\n        for i in range(num_downs - 5): #3 
times\r\n            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout,is_attention_layer=use_inner_attention,attention_model=attention_model,basicblock=basicblock) # 8,4,2\r\n        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #16\r\n        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #32\r\n        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock, outermostattention=True) #64 \r\n        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, basicblock=basicblock, norm_layer=norm_layer) # 128\r\n\r\n        self.model = unet_block\r\n\r\n    def forward(self, input):\r\n        if self.need_mask:\r\n            return self.model(input,input[:,3:4,:,:])\r\n        else:\r\n            return self.model(input[:,0:3,:,:],input[:,3:4,:,:])\r\n\r\n\r\n\r\n"
  },
  {
    "path": "scripts/models/vgg.py",
    "content": "from collections import namedtuple\n\nimport torch\nfrom torchvision import models\n\n\nclass Vgg16(torch.nn.Module):\n    def __init__(self, requires_grad=False):\n        super(Vgg16, self).__init__()\n        vgg_pretrained_features = models.vgg16(pretrained=True).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        for x in range(4):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(4, 9):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(9, 16):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(16, 23):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(23,30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n                \n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1_2 = h\n        h = self.slice2(h)\n        h_relu2_2 = h\n        h = self.slice3(h)\n        h_relu3_3 = h\n        h = self.slice4(h)\n        h_relu4_3 = h\n        h = self.slice5(h)\n        h_relu5_3 = h\n        # vgg_outputs = namedtuple(\"VggOutputs\", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3','relu5_3'])\n        # out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)\n        return (h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)\n\n\nclass Vgg19(torch.nn.Module):\n    def __init__(self, requires_grad=False):\n        super(Vgg19, self).__init__()\n        # vgg_pretrained_features = models.vgg19(pretrained=True).features\n        self.vgg_pretrained_features = models.vgg19(pretrained=True).features\n   
     if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X, indices=None):\n        if indices is None:\n            indices = [2, 7, 12, 21, 30]\n        out = []\n        #indices = sorted(indices)\n        for i in range(indices[-1]):\n            X = self.vgg_pretrained_features[i](X)\n            if (i+1) in indices:\n                out.append(X)\n        \n        return out\n"
  },
  {
    "path": "scripts/models/vmu.py",
"content": "\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nfrom scripts.models.blocks import SEBlock\r\nfrom scripts.models.rasc import *\r\nfrom scripts.models.unet import UnetGenerator,MinimalUnetV2\r\n\r\ndef weight_init(m):\r\n    if isinstance(m, nn.Conv2d):\r\n        nn.init.xavier_normal_(m.weight)\r\n        if m.bias is not None:  # conv layers created with bias=False have no bias tensor\r\n            nn.init.constant_(m.bias, 0)\r\n\r\ndef reset_params(model):\r\n    for i, m in enumerate(model.modules()):\r\n        weight_init(m)\r\n\r\n\r\ndef conv3x3(in_channels, out_channels, stride=1,\r\n            padding=1, bias=True, groups=1):\r\n    return nn.Conv2d(\r\n        in_channels,\r\n        out_channels,\r\n        kernel_size=3,\r\n        stride=stride,\r\n        padding=padding,\r\n        bias=bias,\r\n        groups=groups)\r\n\r\n\r\ndef up_conv2x2(in_channels, out_channels, transpose=True):\r\n    if transpose:\r\n        return nn.ConvTranspose2d(\r\n            in_channels,\r\n            out_channels,\r\n            kernel_size=2,\r\n            stride=2)\r\n    else:\r\n        return nn.Sequential(\r\n            nn.Upsample(mode='bilinear', scale_factor=2),\r\n            conv1x1(in_channels, out_channels))\r\n\r\n\r\ndef conv1x1(in_channels, out_channels, groups=1):\r\n    return nn.Conv2d(\r\n        in_channels,\r\n        out_channels,\r\n        kernel_size=1,\r\n        groups=groups,\r\n        stride=1)\r\n\r\n\r\n\r\n\r\nclass UpCoXvD(nn.Module):\r\n\r\n    def __init__(self, in_channels, out_channels, blocks, residual=True, batch_norm=True, transpose=True,concat=True,use_att=False):\r\n        super(UpCoXvD, self).__init__()\r\n        self.concat = concat\r\n        self.residual = residual\r\n        self.batch_norm = batch_norm\r\n        self.bn = None\r\n        self.conv2 = []\r\n        self.use_att = use_att\r\n        self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose)\r\n        \r\n        if self.use_att:\r\n            self.s2am = RASC(2 * 
out_channels)\r\n        else:\r\n            self.s2am = None\r\n\r\n        if self.concat:\r\n            self.conv1 = conv3x3(2 * out_channels, out_channels)\r\n        else:\r\n            self.conv1 = conv3x3(out_channels, out_channels)\r\n        for _ in range(blocks):\r\n            self.conv2.append(conv3x3(out_channels, out_channels))\r\n        if self.batch_norm:\r\n            self.bn = []\r\n            for _ in range(blocks):\r\n                self.bn.append(nn.BatchNorm2d(out_channels))\r\n            self.bn = nn.ModuleList(self.bn)\r\n        self.conv2 = nn.ModuleList(self.conv2)\r\n\r\n    def forward(self, from_up, from_down, mask=None):\r\n        from_up = self.up_conv(from_up)\r\n        if self.concat:\r\n            x1 = torch.cat((from_up, from_down), 1)\r\n        else:\r\n            if from_down is not None:\r\n                x1 = from_up + from_down\r\n            else:\r\n                x1 = from_up\r\n\r\n        if self.use_att:\r\n            x1 = self.s2am(x1,mask)\r\n\r\n        x1 = F.relu(self.conv1(x1))\r\n        x2 = None\r\n        for idx, conv in enumerate(self.conv2):\r\n            x2 = conv(x1)\r\n            if self.batch_norm:\r\n                x2 = self.bn[idx](x2)\r\n            if self.residual:\r\n                x2 = x2 + x1\r\n            x2 = F.relu(x2)\r\n            x1 = x2\r\n        return x2\r\n\r\n\r\nclass DownCoXvD(nn.Module):\r\n\r\n    def __init__(self, in_channels, out_channels, blocks, pooling=True, residual=True, batch_norm=True):\r\n        super(DownCoXvD, self).__init__()\r\n        self.pooling = pooling\r\n        self.residual = residual\r\n        self.batch_norm = batch_norm\r\n        self.bn = None\r\n        self.pool = None\r\n        self.conv1 = conv3x3(in_channels, out_channels)\r\n        self.conv2 = []\r\n        for _ in range(blocks):\r\n            self.conv2.append(conv3x3(out_channels, out_channels))\r\n        if self.batch_norm:\r\n            self.bn = []\r\n       
     for _ in range(blocks):\r\n                self.bn.append(nn.BatchNorm2d(out_channels))\r\n            self.bn = nn.ModuleList(self.bn)\r\n        if self.pooling:\r\n            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\r\n        self.conv2 = nn.ModuleList(self.conv2)\r\n\r\n    def __call__(self, x):\r\n        return self.forward(x)\r\n\r\n    def forward(self, x):\r\n        x1 = F.relu(self.conv1(x))\r\n        x2 = None\r\n        for idx, conv in enumerate(self.conv2):\r\n            x2 = conv(x1)\r\n            if self.batch_norm:\r\n                x2 = self.bn[idx](x2)\r\n            if self.residual:\r\n                x2 = x2 + x1\r\n            x2 = F.relu(x2)\r\n            x1 = x2\r\n        before_pool = x2\r\n        if self.pooling:\r\n            x2 = self.pool(x2)\r\n        return x2, before_pool\r\n\r\nclass UnetDecoderD(nn.Module):\r\n    def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True,\r\n                 transpose=True, concat=True, is_final=True):\r\n        super(UnetDecoderD, self).__init__()\r\n        self.conv_final = None\r\n        self.up_convs = []\r\n        outs = in_channels\r\n        for i in range(depth-1):\r\n            ins = outs\r\n            outs = ins // 2\r\n            # 512,256\r\n            # 256,128\r\n            # 128,64\r\n            # 64,32\r\n            up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,\r\n                              concat=concat)\r\n            self.up_convs.append(up_conv)\r\n        if is_final:\r\n            self.conv_final = conv1x1(outs, out_channels)\r\n        else:\r\n            up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,\r\n                              concat=concat)\r\n            self.up_convs.append(up_conv)\r\n        self.up_convs = nn.ModuleList(self.up_convs)\r\n        
reset_params(self)\r\n\r\n    def __call__(self, x, encoder_outs=None):\r\n        return self.forward(x, encoder_outs)\r\n\r\n    def forward(self, x, encoder_outs=None):\r\n        for i, up_conv in enumerate(self.up_convs):\r\n            before_pool = None\r\n            if encoder_outs is not None:\r\n                before_pool = encoder_outs[-(i+2)]\r\n            x = up_conv(x, before_pool)\r\n        if self.conv_final is not None:\r\n            x = self.conv_final(x)\r\n        return x\r\n\r\n\r\nclass UnetEncoderD(nn.Module):\r\n\r\n    def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True):\r\n        super(UnetEncoderD, self).__init__()\r\n        self.down_convs = []\r\n        outs = None\r\n        if type(blocks) is tuple:\r\n            blocks = blocks[0]\r\n        for i in range(depth):\r\n            ins = in_channels if i == 0 else outs\r\n            outs = start_filters*(2**i)\r\n            pooling = True if i < depth-1 else False\r\n            down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm)\r\n            self.down_convs.append(down_conv)\r\n        self.down_convs = nn.ModuleList(self.down_convs)\r\n        reset_params(self)\r\n\r\n    def __call__(self, x):\r\n        return self.forward(x)\r\n\r\n    def forward(self, x):\r\n        encoder_outs = []\r\n        for d_conv in self.down_convs:\r\n            x, before_pool = d_conv(x)\r\n            encoder_outs.append(before_pool)\r\n        return x, encoder_outs\r\n\r\n\r\n\r\nclass UnetVM(nn.Module):\r\n\r\n    def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1,\r\n                 out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True,\r\n                 transpose=True, concat=True, transfer_data=True, long_skip=False):\r\n        super(UnetVM, self).__init__()\r\n        self.transfer_data = 
transfer_data\r\n        self.shared = shared_depth\r\n        self.optimizer_encoder,  self.optimizer_image, self.optimizer_vm = None, None, None\r\n        self.optimizer_mask, self.optimizer_shared = None, None\r\n        if type(blocks) is not tuple:\r\n            blocks = (blocks, blocks, blocks, blocks, blocks)\r\n        if not transfer_data:\r\n            concat = False\r\n        self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0],\r\n                                    start_filters=start_filters, residual=residual, batch_norm=batch_norm)\r\n        self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),\r\n                                          out_channels=out_channels_image, depth=depth - shared_depth,\r\n                                          blocks=blocks[1], residual=residual, batch_norm=batch_norm,\r\n                                          transpose=transpose, concat=concat)\r\n        self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1),\r\n                                         out_channels=out_channels_mask, depth=depth,\r\n                                         blocks=blocks[2], residual=residual, batch_norm=batch_norm,\r\n                                         transpose=transpose, concat=concat)\r\n        self.vm_decoder = None\r\n        if use_vm_decoder:\r\n            self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),\r\n                                           out_channels=out_channels_image, depth=depth - shared_depth,\r\n                                           blocks=blocks[3], residual=residual, batch_norm=batch_norm,\r\n                                           transpose=transpose, concat=concat)\r\n        self.shared_decoder = None\r\n        self.long_skip = long_skip\r\n        self._forward = self.unshared_forward\r\n        if self.shared != 0:\r\n            
self._forward = self.shared_forward\r\n            self.shared_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1),\r\n                                               out_channels=start_filters * 2 ** (depth - shared_depth - 1),\r\n                                               depth=shared_depth, blocks=blocks[4], residual=residual,\r\n                                               batch_norm=batch_norm, transpose=transpose, concat=concat,\r\n                                               is_final=False)\r\n\r\n    def set_optimizers(self):\r\n        self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001)\r\n        self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001)\r\n        self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001)\r\n        if self.vm_decoder is not None:\r\n            self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001)\r\n        if self.shared != 0:\r\n            self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001)\r\n\r\n    def zero_grad_all(self):\r\n        self.optimizer_encoder.zero_grad()\r\n        self.optimizer_image.zero_grad()\r\n        self.optimizer_mask.zero_grad()\r\n        if self.vm_decoder is not None:\r\n            self.optimizer_vm.zero_grad()\r\n        if self.shared != 0:\r\n            self.optimizer_shared.zero_grad()\r\n\r\n    def step_all(self):\r\n        self.optimizer_encoder.step()\r\n        self.optimizer_image.step()\r\n        self.optimizer_mask.step()\r\n        if self.vm_decoder is not None:\r\n            self.optimizer_vm.step()\r\n        if self.shared != 0:\r\n            self.optimizer_shared.step()\r\n\r\n    def step_optimizer_image(self):\r\n        self.optimizer_image.step()\r\n\r\n    def __call__(self, synthesized):\r\n        return self._forward(synthesized)\r\n\r\n    def forward(self, synthesized):\r\n        return 
self._forward(synthesized)\r\n\r\n    def unshared_forward(self, synthesized):\r\n        image_code, before_pool = self.encoder(synthesized)\r\n        if not self.transfer_data:\r\n            before_pool = None\r\n        reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool))\r\n        reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))\r\n        if self.vm_decoder is not None:\r\n            reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool))\r\n            return reconstructed_image, reconstructed_mask, reconstructed_vm\r\n        return reconstructed_image, reconstructed_mask\r\n\r\n    def shared_forward(self, synthesized):\r\n        image_code, before_pool = self.encoder(synthesized)\r\n        if self.transfer_data:\r\n            shared_before_pool = before_pool[- self.shared - 1:]\r\n            unshared_before_pool = before_pool[: - self.shared]\r\n        else:\r\n            before_pool = None\r\n            shared_before_pool = None\r\n            unshared_before_pool = None\r\n        x = self.shared_decoder(image_code, shared_before_pool)\r\n        reconstructed_image = torch.tanh(self.image_decoder(x, unshared_before_pool))\r\n        if self.long_skip:\r\n            reconstructed_image = reconstructed_image + synthesized\r\n\r\n        reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))\r\n        if self.vm_decoder is not None:\r\n            reconstructed_vm = torch.tanh(self.vm_decoder(x, unshared_before_pool))\r\n            if self.long_skip:\r\n                reconstructed_vm = reconstructed_vm + synthesized\r\n            return reconstructed_image, reconstructed_mask, reconstructed_vm\r\n        return reconstructed_image, reconstructed_mask\r\n"
  },
  {
    "path": "scripts/utils/__init__.py",
    "content": "from __future__ import absolute_import\n\nfrom .evaluation import *\nfrom .imutils import *\nfrom .logger import *\nfrom .misc import *\nfrom .osutils import *\nfrom .transforms import *\n"
  },
  {
    "path": "scripts/utils/evaluation.py",
    "content": "from __future__ import absolute_import\n\nimport math\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom random import randint\n\nfrom .misc import *\nfrom .transforms import transform, transform_preds\n\n__all__ = ['accuracy', 'AverageMeter']\n\ndef get_preds(scores):\n    ''' get predictions from score maps in torch Tensor\n        return type: torch.LongTensor\n    '''\n    assert scores.dim() == 4, 'Score maps should be 4-dim'\n    maxval, idx = torch.max(scores.view(scores.size(0), scores.size(1), -1), 2)\n\n    maxval = maxval.view(scores.size(0), scores.size(1), 1)\n    idx = idx.view(scores.size(0), scores.size(1), 1) + 1\n\n    preds = idx.repeat(1, 1, 2).float()\n\n    preds[:,:,0] = (preds[:,:,0] - 1) % scores.size(3) + 1\n    preds[:,:,1] = torch.floor((preds[:,:,1] - 1) / scores.size(2)) + 1\n\n    pred_mask = maxval.gt(0).repeat(1, 1, 2).float()\n    preds *= pred_mask\n    return preds\n\ndef calc_dists(preds, target, normalize):\n    preds = preds.float()\n    target = target.float()\n    dists = torch.zeros(preds.size(1), preds.size(0))\n    for n in range(preds.size(0)):\n        for c in range(preds.size(1)):\n            if target[n,c,0] > 1 and target[n, c, 1] > 1:\n                dists[c, n] = torch.dist(preds[n,c,:], target[n,c,:])/normalize[n]\n            else:\n                dists[c, n] = -1\n    return dists\n\ndef dist_acc(dists, thr=0.5):\n    ''' Return percentage below threshold while ignoring values with a -1 '''\n    if dists.ne(-1).sum() > 0:\n        return dists.le(thr).eq(dists.ne(-1)).sum()*1.0 / dists.ne(-1).sum()\n    else:\n        return -1\n\n\n\ndef accuracy(output, target, thr=0.5):\n    ''' Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations\n        First value to be returned is average accuracy across 'idxs', followed by individual accuracies\n    '''\n    # output_mask = torch.gt(output,thr);\n    # target_mask = torch.gt(target,thr);\n    # 
equal_mask = torch.eq(output_mask,target_mask);\n    # fp_equal_mask = torch.lt(output_mask,target_mask);\n    # fn_equal_mask = torch.gt(output_mask,target_mask);\n\n\n    # tp = torch.sum(equal_mask);\n    # fn = torch.sum(fn_equal_mask);\n    # fp = torch.sum(fp_equal_mask);\n\n    # return 2*tp / (2*tp+fn+fp)\n\n\n    # both branches of the original if/else were identical, so a single call suffices\n    _, i = torch.max(output, 1)\n    return torch.sum(target.long() == i).float() / target.numel()\n\ndef final_preds(output, center, scale, res):\n    coords = get_preds(output) # float type\n\n    # post-processing\n    for n in range(coords.size(0)):\n        for p in range(coords.size(1)):\n            hm = output[n][p]\n            px = int(math.floor(coords[n][p][0]))\n            py = int(math.floor(coords[n][p][1]))\n            if px > 1 and px < res[0] and py > 1 and py < res[1]:\n                diff = torch.Tensor([hm[py - 1][px] - hm[py - 1][px - 2], hm[py][px - 1]-hm[py - 2][px - 1]])\n                coords[n][p] += diff.sign() * .25\n    coords += 0.5\n    preds = coords.clone()\n\n    # Transform back\n    for i in range(coords.size(0)):\n        preds[i] = transform_preds(coords[i], center[i], scale[i], res)\n\n    if preds.dim() < 3:\n        preds = preds.view(1, *preds.size())  # was view(1, preds.size()), which raises a TypeError\n\n    return preds\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n"
  },
  {
    "path": "scripts/utils/imutils.py",
    "content": "from __future__ import absolute_import\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport scipy.misc\n\nfrom .misc import *\n\ndef im_to_numpy(img):\n    img = to_numpy(img)\n    img = np.transpose(img, (1, 2, 0)) # H*W*C\n    return img\n\ndef im_to_torch(img):\n    img = np.transpose(img, (2, 0, 1)) # C*H*W\n    img = to_torch(img).float()\n    if img.max() > 1:\n        img /= 255\n    return img\n\ndef load_image(img_path):\n    # H x W x C => C x H x W\n    return im_to_torch(scipy.misc.imread(img_path, mode='RGB'))\n\ndef imread_all(img_path):\n    return scipy.misc.imread(img_path, mode='RGB')\n\ndef load_image_gray(img_path):\n    # H x W x C => C x H x W\n    x = scipy.misc.imread(img_path, mode='L')\n    x = x[:,:,np.newaxis]\n    return im_to_torch(x)\n\ndef resize(img, owidth, oheight):\n    img = im_to_numpy(img)\n\n    if img.shape[2] == 1:\n        img = scipy.misc.imresize(img.squeeze(),(oheight,owidth))\n        img = img[:,:,np.newaxis]\n    else:\n        img = scipy.misc.imresize(\n                img,\n                (oheight, owidth)\n            )\n    img = im_to_torch(img)\n    # print('%f %f' % (img.min(), img.max()))\n    return img\n\n# =============================================================================\n# Helpful functions generating groundtruth labelmap \n# =============================================================================\n\ndef gaussian(shape=(7,7),sigma=1):\n    \"\"\"\n    2D gaussian mask - should give the same result as MATLAB's\n    fspecial('gaussian',[shape],[sigma])\n    \"\"\"\n    m,n = [(ss-1.)/2. 
for ss in shape]\n    y,x = np.ogrid[-m:m+1,-n:n+1]\n    h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )\n    h[ h < np.finfo(h.dtype).eps*h.max() ] = 0\n    return to_torch(h).float()\n\ndef draw_labelmap(img, pt, sigma, type='Gaussian'):\n    # Draw a 2D gaussian \n    # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py\n    img = to_numpy(img)\n\n    # Check that any part of the gaussian is in-bounds\n    ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]\n    br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]\n    if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or\n            br[0] < 0 or br[1] < 0):\n        # If not, just return the image as is\n        return to_torch(img)\n\n    # Generate gaussian\n    size = 6 * sigma + 1\n    x = np.arange(0, size, 1, float)\n    y = x[:, np.newaxis]\n    x0 = y0 = size // 2\n    # The gaussian is not normalized, we want the center value to equal 1\n    if type == 'Gaussian':\n        g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))\n    elif type == 'Cauchy':\n        g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)\n\n\n    # Usable gaussian range\n    g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]\n    g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]\n    # Image range\n    img_x = max(0, ul[0]), min(br[0], img.shape[1])\n    img_y = max(0, ul[1]), min(br[1], img.shape[0])\n\n    img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]\n    return to_torch(img)\n\n# =============================================================================\n# Helpful display functions\n# =============================================================================\n\ndef gauss(x, a, b, c, d=0):\n    return a * np.exp(-(x - b)**2 / (2 * c**2)) + d\n\ndef color_heatmap(x):\n    x = to_numpy(x)\n    color = np.zeros((x.shape[0],x.shape[1],3))\n    color[:,:,0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)\n    
color[:,:,1] = gauss(x, 1, .5, .3)\n    color[:,:,2] = gauss(x, 1, .2, .3)\n    color[color > 1] = 1\n    color = (color * 255).astype(np.uint8)\n    return color\n\ndef imshow(img):\n    npimg = im_to_numpy(img*255).astype(np.uint8)\n    plt.imshow(npimg)\n    plt.axis('off')\n\ndef show_joints(img, pts):\n    imshow(img)\n    \n    for i in range(pts.size(0)):\n        if pts[i, 2] > 0:\n            plt.plot(pts[i, 0], pts[i, 1], 'yo')\n    plt.axis('off')\n\ndef show_sample(inputs, target):\n    num_sample = inputs.size(0)\n    num_joints = target.size(1)\n    height = target.size(2)\n    width = target.size(3)\n\n    for n in range(num_sample):\n        inp = resize(inputs[n], width, height)\n        out = inp\n        for p in range(num_joints):\n            tgt = inp*0.5 + color_heatmap(target[n,p,:,:])*0.5\n            out = torch.cat((out, tgt), 2)\n        \n        imshow(out)\n        plt.show()\n\ndef sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None):\n    inp = to_numpy(inp * 255)\n    out = to_numpy(out)\n\n    img = np.zeros((inp.shape[1], inp.shape[2], inp.shape[0]))\n    for i in range(3):\n        img[:, :, i] = inp[i, :, :]\n\n    if parts_to_show is None:\n        parts_to_show = np.arange(out.shape[0])\n\n    # Generate a single image to display input/output pair\n    num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows))\n    size = img.shape[0] // num_rows\n\n    full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), np.uint8)\n    full_img[:img.shape[0], :img.shape[1]] = img\n\n    inp_small = scipy.misc.imresize(img, [size, size])\n\n    # Set up heatmap display for each part\n    for i, part in enumerate(parts_to_show):\n        part_idx = part\n        out_resized = scipy.misc.imresize(out[part_idx], [size, size])\n        out_resized = out_resized.astype(float)/255\n        out_img = inp_small.copy() * .3\n        color_hm = color_heatmap(out_resized)\n        out_img += color_hm * .7\n\n        
col_offset = (i % num_cols + num_rows) * size\n        row_offset = (i // num_cols) * size\n        full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img\n\n    return full_img\n\ndef batch_with_heatmap(inputs, outputs, mean=torch.Tensor([0.5, 0.5, 0.5]), num_rows=2, parts_to_show=None):\n    batch_img = []\n    for n in range(min(inputs.size(0), 4)):\n        inp = inputs[n] + mean.view(3, 1, 1).expand_as(inputs[n])\n        batch_img.append(\n            sample_with_heatmap(inp.clamp(0, 1), outputs[n], num_rows=num_rows, parts_to_show=parts_to_show)\n        )\n    return np.concatenate(batch_img)\n\n\ndef normalize_batch(batch):\n    # normalize using imagenet mean and std\n    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)\n    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)\n    batch = batch/255.0\n    return (batch - mean) / std\n\ndef show_image_tensor(tensor):\n    re = []\n    for i in range(tensor.size(0)):\n        inp = tensor[i].data.cpu() # c,h,w\n        inp = inp.numpy().transpose((1, 2, 0))\n        mean = np.array([0.485, 0.456, 0.406])\n        std = np.array([0.229, 0.224, 0.225])\n        inp = std * inp + mean\n        inp = np.clip(inp, 0, 1).transpose((2,0,1))\n        re.append(torch.from_numpy(inp).unsqueeze(0))\n    return torch.cat(re,0)\n\n\ndef get_jet():\n    from matplotlib import cm  # cm was referenced without being imported at module level\n    colormap_int = np.zeros((256, 3), np.uint8)\n\n    for i in range(0, 256, 1):\n        colormap_int[i, 0] = np.int_(np.round(cm.jet(i)[0] * 255.0))\n        colormap_int[i, 1] = np.int_(np.round(cm.jet(i)[1] * 255.0))\n        colormap_int[i, 2] = np.int_(np.round(cm.jet(i)[2] * 255.0))\n\n    return colormap_int\n\ndef clamp(num, min_value, max_value):\n    return max(min(num, max_value), min_value)\n\ndef gray2color(gray_array, color_map):\n\n    rows, cols = gray_array.shape\n    color_array = np.zeros((rows, cols, 3), np.uint8)\n\n    for i in range(0, rows):\n        for j in range(0, cols):\n#             log(256,2) 
= 8 , log(1,2) = 0 * 8\n            color_array[i, j] = color_map[clamp(int(abs(gray_array[i, j])*10),0,255)]\n    \n    return color_array\n\nclass objectview(object):\n    def __init__(self, *args, **kwargs):\n        d = dict(*args, **kwargs)\n        self.__dict__ = d"
  },
  {
    "path": "scripts/utils/logger.py",
    "content": "# A simple torch style logger\n# (C) Wei YANG 2017\nfrom __future__ import absolute_import\n\nimport os\nimport sys\nimport numpy as np\nimport matplotlib.pyplot as plt\n\n__all__ = ['Logger', 'LoggerMonitor', 'savefig']\n\ndef savefig(fname, dpi=None):\n    dpi = 150 if dpi == None else dpi\n    plt.savefig(fname, dpi=dpi)\n    \ndef plot_overlap(logger, names=None):\n    names = logger.names if names == None else names\n    numbers = logger.numbers\n    for _, name in enumerate(names):\n        x = np.arange(len(numbers[name]))\n        plt.plot(x, np.asarray(numbers[name]))\n    return [logger.title + '(' + name + ')' for name in names]\n\nclass Logger(object):\n    '''Save training process to log file with simple plot function.'''\n    def __init__(self, fpath, title=None, resume=False): \n        self.file = None\n        self.resume = resume\n        self.title = '' if title == None else title\n        if fpath is not None:\n            if resume: \n                self.file = open(fpath, 'r') \n                name = self.file.readline()\n                self.names = name.rstrip().split('\\t')\n                self.numbers = {}\n                for _, name in enumerate(self.names):\n                    self.numbers[name] = []\n\n                for numbers in self.file:\n                    numbers = numbers.rstrip().split('\\t')\n                    for i in range(0, len(numbers)):\n                        self.numbers[self.names[i]].append(numbers[i])\n                self.file.close()\n                self.file = open(fpath, 'a')  \n            else:\n                self.file = open(fpath, 'w')\n\n    def set_names(self, names):\n        if self.resume: \n            pass\n        # initialize numbers as empty list\n        self.numbers = {}\n        self.names = names\n        for _, name in enumerate(self.names):\n            self.file.write(name)\n            self.file.write('\\t')\n            self.numbers[name] = []\n        
self.file.write('\\n')\n        self.file.flush()\n\n\n    def append(self, numbers):\n        assert len(self.names) == len(numbers), 'Numbers do not match names'\n        for index, num in enumerate(numbers):\n            self.file.write(\"{0:.6f}\".format(num))\n            self.file.write('\\t')\n            self.numbers[self.names[index]].append(num)\n        self.file.write('\\n')\n        self.file.flush()\n\n    def plot(self, names=None):\n        names = self.names if names is None else names\n        numbers = self.numbers\n        for _, name in enumerate(names):\n            x = np.arange(len(numbers[name]))\n            plt.plot(x, np.asarray(numbers[name]))\n        plt.legend([self.title + '(' + name + ')' for name in names])\n        plt.grid(True)\n\n    def close(self):\n        if self.file is not None:\n            self.file.close()\n\nclass LoggerMonitor(object):\n    '''Load and visualize multiple logs.'''\n    def __init__(self, paths):\n        '''paths is a dictionary with {name: filepath} pairs'''\n        self.loggers = []\n        for title, path in paths.items():\n            logger = Logger(path, title=title, resume=True)\n            self.loggers.append(logger)\n\n    def plot(self, names=None):\n        plt.figure()\n        plt.subplot(121)\n        legend_text = []\n        for logger in self.loggers:\n            legend_text += plot_overlap(logger, names)\n        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)\n        plt.grid(True)\n\nif __name__ == '__main__':\n    # # Example\n    # logger = Logger('test.txt')\n    # logger.set_names(['Train loss', 'Valid loss','Test loss'])\n\n    # length = 100\n    # t = np.arange(length)\n    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1\n    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1\n    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1\n\n    # for i in range(0, length):\n    #     
logger.append([train_loss[i], valid_loss[i], test_loss[i]])\n    # logger.plot()\n\n    # Example: logger monitor\n    paths = {\n    'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', \n    'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',\n    'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',\n    }\n\n    field = ['Valid Acc.']\n\n    monitor = LoggerMonitor(paths)\n    monitor.plot(names=field)\n    savefig('test.eps')"
  },
  {
    "path": "scripts/utils/losses.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom scripts.models.vgg import Vgg19\nfrom torchvision import models\nfrom scripts.utils.misc import resize_to_match\n# from pytorch_msssim import SSIM, MS_SSIM\nimport pytorch_ssim\n\nclass WeightedBCE(nn.Module):\n    def __init__(self):\n        super(WeightedBCE, self).__init__()\n\n    def forward(self, pred, gt):\n        eposion = 1e-10\n        sigmoid_pred = torch.sigmoid(pred)\n        count_pos = torch.sum(gt)*1.0+eposion\n        count_neg = torch.sum(1.-gt)*1.0\n        beta = count_neg/count_pos\n        beta_back = count_pos / (count_pos + count_neg)\n\n        bce1 = nn.BCEWithLogitsLoss(pos_weight=beta)\n        loss = beta_back*bce1(pred, gt)\n\n        return loss\n\n\ndef l1_relative(reconstructed, real, mask):\n    batch = real.size(0)\n    area = torch.sum(mask.view(batch,-1),dim=1)\n    reconstructed = reconstructed * mask\n    real = real * mask\n    \n    loss_l1 = torch.abs(reconstructed - real).view(batch, -1)\n    loss_l1 = torch.sum(loss_l1, dim=1) / area\n    loss_l1 = torch.sum(loss_l1) / batch\n    return loss_l1\n\n\ndef is_dic(x):\n    return type(x) == type([])\n\nclass Losses(nn.Module):\n    def __init__(self, argx, device):\n        super(Losses, self).__init__()\n        self.args = argx\n\n        if self.args.loss_type == 'l1bl2':\n            self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss()\n        elif self.args.loss_type == 'l1wbl2':\n            self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), WeightedBCE(), nn.MSELoss() \n        elif self.args.loss_type == 'l2wbl2':\n            self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), WeightedBCE(), nn.MSELoss()\n        elif self.args.loss_type == 'l2xbl2':\n            self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss()\n        else: # l2bl2\n            self.outputLoss, self.attLoss, 
self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss()\n\n        if self.args.style_loss > 0:\n            self.vggloss = VGGLoss(self.args.sltype).to(device)\n        \n        if self.args.ssim_loss > 0:\n            self.ssimloss =  pytorch_ssim.SSIM().to(device)\n\n        self.outputLoss = self.outputLoss.to(device)\n        self.attLoss = self.attLoss.to(device)\n        self.wrloss = self.wrloss.to(device)\n\n\n    def forward(self,imgx,target,attx,mask,wmx,wm):\n        pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = 0,0,0,0,0\n\n        if is_dic(imgx):\n\n            if self.args.masked:\n            # calculate the overall loss and side output\n                pixel_loss = self.outputLoss(imgx[0],target) + sum([self.outputLoss(im,resize_to_match(mask,im)*resize_to_match(target,im)) for im in imgx[1:]])\n            else:\n                pixel_loss =  sum([self.outputLoss(im,resize_to_match(target,im)) for im in imgx])\n\n            if self.args.style_loss > 0:\n                vgg_loss = sum([self.vggloss(im,resize_to_match(target,im),resize_to_match(mask,im)) for im in imgx])\n\n            if self.args.ssim_loss > 0:\n                ssim_loss = sum([ 1 - self.ssimloss(im,resize_to_match(target,im)) for im in imgx])\n        else:\n\n            if self.args.masked:\n                pixel_loss = self.outputLoss(imgx,mask*target)\n            else:\n                pixel_loss =  self.outputLoss(imgx,target)\n\n            if self.args.style_loss > 0:\n                vgg_loss = self.vggloss(imgx,target,mask)\n\n            if self.args.ssim_loss > 0:\n                ssim_loss = 1 - self.ssimloss(imgx,target)\n\n        if is_dic(attx):\n            att_loss =  sum([self.attLoss(at,resize_to_match(mask,at)) for at in attx])\n        else:\n            att_loss =  self.attLoss(attx, mask)\n\n        if is_dic(wmx):\n            wm_loss = sum([self.wrloss(w,resize_to_match(wm,w)) for w in wmx])\n        else:\n            if self.args.masked:\n        
        wm_loss = self.wrloss(wmx,mask*wm)\n            else:\n                wm_loss = self.wrloss(wmx, wm)\n\n        return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss\n\n\n\ndef gram_matrix(feat):\n    # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py\n    (b, ch, h, w) = feat.size()\n    feat = feat.view(b, ch, h * w)\n    feat_t = feat.transpose(1, 2)\n    gram = torch.bmm(feat, feat_t) / (ch * h * w)\n    return gram\n\nclass MeanShift(nn.Conv2d):\n    def __init__(self, data_mean, data_std, data_range=1, norm=True):\n        \"\"\"norm (bool): normalize/denormalize the stats\"\"\"\n        c = len(data_mean)\n        super(MeanShift, self).__init__(c, c, kernel_size=1)\n        std = torch.Tensor(data_std)\n        self.weight.data = torch.eye(c).view(c, c, 1, 1)\n        if norm:\n            self.weight.data.div_(std.view(c, 1, 1, 1))\n            self.bias.data = -1 * data_range * torch.Tensor(data_mean)\n            self.bias.data.div_(std)\n        else:\n            self.weight.data.mul_(std.view(c, 1, 1, 1))\n            self.bias.data = data_range * torch.Tensor(data_mean)\n        # freeze the layer; assigning self.requires_grad on the module itself has no effect\n        for p in self.parameters():\n            p.requires_grad = False\n\n\ndef VGGLoss(losstype):\n    if losstype == 'vgg':\n        return VGGLossA()\n    elif losstype == 'vggx':\n        return VGGLossX(mask=False)\n    elif losstype == 'mvggx':\n        return VGGLossX(mask=True)\n    elif losstype == 'rvggx':\n        return VGGLossX(mask=True, relative=True)\n    else:\n        raise Exception(\"error in %s\" % losstype)\n\n\nclass VGGLossA(nn.Module):\n    def __init__(self, vgg=None, weights=None, indices=None, normalize=True):\n        super(VGGLossA, self).__init__()\n        if vgg is None:\n            self.vgg = Vgg19().cuda()\n        else:\n            self.vgg = vgg\n        self.criterion = nn.L1Loss()\n        self.weights = weights or [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]\n        self.indices = indices or [2, 7, 12, 21, 
30]\n        if normalize:\n            self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()\n        else:\n            self.normalize = None\n\n    def forward(self, x, y):\n        if self.normalize is not None:\n            x = self.normalize(x)\n            y = self.normalize(y)\n        x_vgg, y_vgg = self.vgg(x, self.indices), self.vgg(y, self.indices)\n        loss = 0\n        for i in range(len(x_vgg)):\n            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())\n        return loss\n\n\nclass VGG16FeatureExtractor(nn.Module):\n    def __init__(self):\n        super().__init__()\n        vgg16 = models.vgg16(pretrained=True)\n        self.enc_1 = nn.Sequential(*vgg16.features[:5])\n        self.enc_2 = nn.Sequential(*vgg16.features[5:10])\n        self.enc_3 = nn.Sequential(*vgg16.features[10:17])\n\n        # fix the encoder\n        for i in range(3):\n            for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():\n                param.requires_grad = False\n\n    def forward(self, image):\n        results = [image]\n        for i in range(3):\n            func = getattr(self, 'enc_{:d}'.format(i + 1))\n            results.append(func(results[-1]))\n        return results[1:]\n\nclass VGGLossX(nn.Module):\n    def __init__(self, normalize=True, mask=False, relative=False):\n        super(VGGLossX, self).__init__()\n        \n        self.vgg = VGG16FeatureExtractor().cuda()\n        self.criterion = nn.L1Loss().cuda() if not relative else l1_relative\n        self.use_mask= mask\n        self.relative = relative\n\n        if normalize:\n            self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()\n        else:\n            self.normalize = None\n\n    def forward(self, x, y, Xmask=None):\n        if not self.use_mask:\n            mask = torch.ones_like(x)[:,0:1,:,:]\n        else:\n            mask = Xmask\n\n        if 
self.normalize is not None:\n            x = self.normalize(x)\n            y = self.normalize(y)\n\n        x_vgg = self.vgg(x)\n        y_vgg = self.vgg(y)\n        \n        loss = 0\n        for i in range(3):\n            if self.relative:\n                loss += self.criterion(x_vgg[i],y_vgg[i].detach(),resize_to_match(mask,x_vgg[i]))\n            else:\n                loss += self.criterion(resize_to_match(mask,x_vgg[i])*x_vgg[i],resize_to_match(mask,y_vgg[i])*y_vgg[i].detach())\n\n        return loss\n\n\nclass GANLosses(object):\n    \"\"\"docstring for Loss\"\"\"\n    def __init__(self, gantype):\n        super(GANLosses, self).__init__()        \n        self.generator_loss = gen_gan(gantype)\n        self.discriminator_loss = dis_gan(gantype)\n        self.gantype = gantype\n\n    def g_loss(self,dis_fake):\n        if 'hinge' in self.gantype:\n            return gen_hinge(dis_fake)\n        else:\n            return self.generator_loss(dis_fake)\n\n    def d_loss(self,dis_fake,dis_real):\n        if 'hinge' in self.gantype:\n            return dis_hinge(dis_fake,dis_real)\n        else:\n            return self.discriminator_loss(dis_fake,dis_real)\n\n\nclass gen_gan(nn.Module):\n    def __init__(self,gantype):\n        super(gen_gan,self).__init__()\n        if gantype == 'lsgan':\n            self.criterion = nn.MSELoss()\n        elif gantype == 'naive':\n            self.criterion = nn.BCEWithLogitsLoss()\n        else:\n            raise Exception(\"error gan type\")\n    \n    def forward(self,dis_fake):\n        return self.criterion(dis_fake, torch.ones_like(dis_fake))\n\nclass dis_gan(nn.Module):\n    def __init__(self,gantype):\n        super(dis_gan,self).__init__()\n        if gantype == 'lsgan':\n            self.criterion = nn.MSELoss()\n        elif gantype == 'naive':\n            self.criterion = nn.BCEWithLogitsLoss()\n        else:\n            raise Exception(\"error gan type\")\n    \n    def forward(self,dis_fake,dis_real):\n    
    loss_fake = self.criterion(dis_fake, torch.zeros_like(dis_fake))\n        loss_real = self.criterion(dis_real, torch.ones_like(dis_real))\n        return loss_fake, loss_real\n\n# def gen_gan(dis_fake):\n#     # fake -> 1\n#     return F.binary_cross_entropy_with_logits(dis_fake,torch.ones_like(dis_fake))\n\n# def dis_gan(dis_fake,dis_real):\n#     # fake -> 0 , real ->1\n#     loss_fake = F.binary_cross_entropy_with_logits(dis_fake, torch.zeros_like(dis_real))\n#     loss_real = F.binary_cross_entropy_with_logits(dis_real, torch.ones_like(dis_fake))\n#     return loss_fake,loss_real \n\n# def gen_lsgan(dis_fake):\n#     loss = F.mse_loss(dis_fake,torch.ones_like(dis_fake)) # \n#     return loss\n\n# def dis_lsgan(dis_fake, dis_real):\n#     loss_fake = F.mse_loss(dis_fake, torch.zeros_like(dis_real))\n#     loss_real = F.mse_loss(dis_real, torch.ones_like(dis_real))\n#     return loss_fake,loss_real\n\ndef gen_hinge(dis_fake, dis_real=None):\n    return -torch.mean(dis_fake)\n\ndef dis_hinge(dis_fake, dis_real):\n    loss_fake = torch.mean(torch.relu(1. + dis_fake))\n    loss_real = torch.mean(torch.relu(1. - dis_real))\n    return loss_fake,loss_real\n\n"
  },
  {
    "path": "scripts/utils/misc.py",
    "content": "from __future__ import absolute_import\n\nimport os\nimport shutil\nimport torch \nimport math\nimport numpy as np\nimport scipy.io\nimport matplotlib.pyplot as plt\nimport torch.nn.functional as F\n\ndef to_numpy(tensor):\n    if torch.is_tensor(tensor):\n        return tensor.cpu().numpy()\n    elif type(tensor).__module__ != 'numpy':\n        raise ValueError(\"Cannot convert {} to numpy array\"\n                         .format(type(tensor)))\n    return tensor\n\ndef resize_to_match(fm,to):\n    # just use interpolate\n    # [1,3] = (h,w)\n    return F.interpolate(fm,to.size()[-2:],mode='bilinear',align_corners=False)\n\ndef to_torch(ndarray):\n    if type(ndarray).__module__ == 'numpy':\n        return torch.from_numpy(ndarray)\n    elif not torch.is_tensor(ndarray):\n        raise ValueError(\"Cannot convert {} to torch tensor\"\n                         .format(type(ndarray)))\n    return ndarray\n\n\ndef save_checkpoint(machine,filename='checkpoint.pth.tar', snapshot=None):\n    is_best = True if machine.best_acc < machine.metric else False\n\n    if is_best:\n        machine.best_acc = machine.metric\n\n    state = {\n                'epoch': machine.current_epoch + 1,\n                'arch': machine.args.arch,\n                'state_dict': machine.model.state_dict(),\n                'best_acc': machine.best_acc,\n                'optimizer' : machine.optimizer.state_dict(),\n            }\n\n    filepath = os.path.join(machine.args.checkpoint, filename)\n    torch.save(state, filepath)\n\n    if snapshot and state['epoch'] % snapshot == 0:\n        shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state.epoch)))\n       \n    if is_best:\n        machine.best_acc = machine.metric\n        print('Saving Best Metric with PSNR:%s'%machine.best_acc)\n        shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'model_best.pth.tar'))\n        \n\n\ndef save_pred(preds, 
checkpoint='checkpoint', filename='preds_valid.mat'):\n    preds = to_numpy(preds)\n    filepath = os.path.join(checkpoint, filename)\n    scipy.io.savemat(filepath, mdict={'preds' : preds})\n\n\ndef adjust_learning_rate(datasets, optimizer, epoch, lr, args):\n    \"\"\"Sets the learning rate to the initial LR decayed by schedule\"\"\"\n    if epoch in args.schedule:\n        lr *= args.gamma\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = lr\n\n    # decay sigma (the original applied the decay twice per dataset, a copy-paste bug)\n    for dset in datasets:\n        if args.sigma_decay > 0:\n            dset.dataset.sigma *= args.sigma_decay\n\n    return lr\n"
  },
  {
    "path": "scripts/utils/model_init.py",
    "content": "\n\nfrom torch.nn import init\n\n\ndef weights_init_normal(m):\n    classname = m.__class__.__name__\n    # print(classname)\n    if classname.find('Conv') != -1:\n        init.normal_(m.weight.data, 0.0, 0.02)\n    elif classname.find('Linear') != -1:\n        init.normal_(m.weight.data, 0.0, 0.02)\n    elif classname.find('BatchNorm2d') != -1:\n        init.normal_(m.weight.data, 1.0, 0.02)\n        init.constant_(m.bias.data, 0.0)\n\n\ndef weights_init_xavier(m):\n    classname = m.__class__.__name__\n    # print(classname)\n    if classname.find('Conv') != -1:\n        init.xavier_normal(m.weight.data, gain=0.02)\n    elif classname.find('Linear') != -1:\n        init.xavier_normal(m.weight.data, gain=0.02)\n    # elif classname.find('BatchNorm2d') != -1:\n    #     init.normal(m.weight.data, 1.0, 0.02)\n    #     init.constant(m.bias.data, 0.0)\n\n\ndef weights_init_kaiming(m):\n    classname = m.__class__.__name__\n    # print(classname)\n    if classname.find('Conv') != -1 and m.weight.requires_grad == True:\n        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n    elif classname.find('Linear') != -1 and m.weight.requires_grad == True:\n        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n    elif classname.find('BatchNorm2d') != -1 and m.weight.requires_grad == True:\n        init.normal_(m.weight.data, 1.0, 0.02)\n        init.constant_(m.bias.data, 0.0)\n\n\ndef weights_init_orthogonal(m):\n    classname = m.__class__.__name__\n    if classname.find('Conv') != -1:\n        init.orthogonal(m.weight.data, gain=1)\n    elif classname.find('Linear') != -1:\n        init.orthogonal(m.weight.data, gain=1)\n    # elif classname.find('BatchNorm2d') != -1:\n    #     init.normal(m.weight.data, 1.0, 0.02)\n    #     init.constant(m.bias.data, 0.0)"
  },
  {
    "path": "scripts/utils/osutils.py",
    "content": "from __future__ import absolute_import\n\nimport os\nimport errno\n\ndef mkdir_p(dir_path):\n    try:\n        os.makedirs(dir_path)\n    except OSError as e:\n        if e.errno != errno.EEXIST:\n            raise\n\ndef isfile(fname):\n    return os.path.isfile(fname) \n\ndef isdir(dirname):\n    return os.path.isdir(dirname)\n\ndef join(path, *paths):\n    return os.path.join(path, *paths)\n"
  },
  {
    "path": "scripts/utils/parallel.py",
    "content": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Hang Zhang, Rutgers University, Email: zhang.hang@rutgers.edu\n## Modified by Thomas Wolf, HuggingFace Inc., Email: thomas@huggingface.co\n## Copyright (c) 2017-2018\n##\n## This source code is licensed under the MIT-style license found in the\n## LICENSE file in the root directory of this source tree\n##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n\n\"\"\"Encoding Data Parallel\"\"\"\nimport threading\nimport functools\nimport torch\nfrom torch.autograd import Variable, Function\nimport torch.cuda.comm as comm\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.nn.parallel.data_parallel import DataParallel\nfrom torch.nn.parallel.parallel_apply import get_a_var\nfrom torch.nn.parallel.scatter_gather import gather\nfrom torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast\n\ntorch_ver = torch.__version__[:3]\n\n__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',\n           'patch_replication_callback']\n\ndef allreduce(*inputs):\n    \"\"\"Cross GPU all reduce autograd operation for calculate mean and\n    variance in SyncBN.\n    \"\"\"\n    return AllReduce.apply(*inputs)\n\nclass AllReduce(Function):\n    @staticmethod\n    def forward(ctx, num_inputs, *inputs):\n        ctx.num_inputs = num_inputs\n        ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]\n        inputs = [inputs[i:i + num_inputs]\n                 for i in range(0, len(inputs), num_inputs)]\n        # sort before reduce sum\n        inputs = sorted(inputs, key=lambda i: i[0].get_device())\n        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])\n        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)\n        return tuple([t for tensors in outputs for t in tensors])\n\n    @staticmethod\n    def backward(ctx, *inputs):\n        inputs = [i.data 
for i in inputs]\n        inputs = [inputs[i:i + ctx.num_inputs]\n                 for i in range(0, len(inputs), ctx.num_inputs)]\n        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])\n        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)\n        return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])\n\n\nclass Reduce(Function):\n    @staticmethod\n    def forward(ctx, *inputs):\n        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]\n        inputs = sorted(inputs, key=lambda i: i.get_device())\n        return comm.reduce_add(inputs)\n\n    @staticmethod\n    def backward(ctx, gradOutput):\n        return Broadcast.apply(ctx.target_gpus, gradOutput)\n\nclass DistributedDataParallelModel(DistributedDataParallel):\n    \"\"\"Implements data parallelism at the module level for the DistributedDataParallel module.\n    This container parallelizes the application of the given module by\n    splitting the input across the specified devices by chunking in the\n    batch dimension.\n    In the forward pass, the module is replicated on each device,\n    and each replica handles a portion of the input. During the backwards pass,\n    gradients from each replica are summed into the original module.\n    Note that the outputs are not gathered, please use compatible\n    :class:`encoding.parallel.DataParallelCriterion`.\n    The batch size should be larger than the number of GPUs used. It should\n    also be an integer multiple of the number of GPUs so that each chunk is\n    the same size (so that each GPU processes the same number of samples).\n    Args:\n        module: module to be parallelized\n        device_ids: CUDA devices (default: all devices)\n    Reference:\n        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,\n        Amit Agrawal. 
“Context Encoding for Semantic Segmentation.\n        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*\n    Example::\n        >>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2])\n        >>> y = net(x)\n    \"\"\"\n    def gather(self, outputs, output_device):\n        return outputs\n\nclass DataParallelModel(DataParallel):\n    \"\"\"Implements data parallelism at the module level.\n\n    This container parallelizes the application of the given module by\n    splitting the input across the specified devices by chunking in the\n    batch dimension.\n    In the forward pass, the module is replicated on each device,\n    and each replica handles a portion of the input. During the backwards pass,\n    gradients from each replica are summed into the original module.\n    Note that the outputs are not gathered, please use compatible\n    :class:`encoding.parallel.DataParallelCriterion`.\n\n    The batch size should be larger than the number of GPUs used. It should\n    also be an integer multiple of the number of GPUs so that each chunk is\n    the same size (so that each GPU processes the same number of samples).\n\n    Args:\n        module: module to be parallelized\n        device_ids: CUDA devices (default: all devices)\n\n    Reference:\n        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,\n        Amit Agrawal. 
“Context Encoding for Semantic Segmentation.\n        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*\n\n    Example::\n\n        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])\n        >>> y = net(x)\n    \"\"\"\n    def gather(self, outputs, output_device):\n        return outputs\n\n    def replicate(self, module, device_ids):\n        modules = super(DataParallelModel, self).replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n\nclass DataParallelCriterion(DataParallel):\n    \"\"\"\n    Calculate loss in multiple-GPUs, which balance the memory usage.\n    The targets are splitted across the specified devices by chunking in\n    the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.\n\n    Reference:\n        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,\n        Amit Agrawal. “Context Encoding for Semantic Segmentation.\n        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*\n\n    Example::\n\n        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])\n        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])\n        >>> y = net(x)\n        >>> loss = criterion(y, target)\n    \"\"\"\n    def forward(self, inputs, *targets, **kwargs):\n        # input should be already scatterd\n        # scattering the targets instead\n        if not self.device_ids:\n            return self.module(inputs, *targets, **kwargs)\n        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)\n        if len(self.device_ids) == 1:\n            return self.module(inputs, *targets[0], **kwargs[0])\n        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])\n        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)\n        #return Reduce.apply(*outputs) / len(outputs)\n      
  #return self.gather(outputs, self.output_device).mean()\n        return self.gather(outputs, self.output_device)\n\n\ndef _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):\n    assert len(modules) == len(inputs)\n    assert len(targets) == len(inputs)\n    if kwargs_tup:\n        assert len(modules) == len(kwargs_tup)\n    else:\n        kwargs_tup = ({},) * len(modules)\n    if devices is not None:\n        assert len(modules) == len(devices)\n    else:\n        devices = [None] * len(modules)\n\n    lock = threading.Lock()\n    results = {}\n    if torch_ver != \"0.3\":\n        grad_enabled = torch.is_grad_enabled()\n\n    def _worker(i, module, input, target, kwargs, device=None):\n        if torch_ver != \"0.3\":\n            torch.set_grad_enabled(grad_enabled)\n        if device is None:\n            device = get_a_var(input).get_device()\n        try:\n            with torch.cuda.device(device):\n                # this also avoids accidental slicing of `input` if it is a Tensor\n                if not isinstance(input, (list, tuple)):\n                    input = (input,)\n                if not isinstance(target, (list, tuple)):\n                    target = (target,)\n                output = module(*(input + target), **kwargs)\n            with lock:\n                results[i] = output\n        except Exception as e:\n            with lock:\n                results[i] = e\n\n    if len(modules) > 1:\n        threads = [threading.Thread(target=_worker,\n                                    args=(i, module, input, target,\n                                          kwargs, device),)\n                   for i, (module, input, target, kwargs, device) in\n                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]\n\n        for thread in threads:\n            thread.start()\n        for thread in threads:\n            thread.join()\n    else:\n        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], 
devices[0])\n\n    outputs = []\n    for i in range(len(inputs)):\n        output = results[i]\n        if isinstance(output, Exception):\n            raise output\n        outputs.append(output)\n    return outputs\n\n\n###########################################################################\n# Adapted from Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n#\nclass CallbackContext(object):\n    pass\n\n\ndef execute_replication_callbacks(modules):\n    \"\"\"\n    Execute an replication callback `__data_parallel_replicate__` on each module created\n    by original replication.\n\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Note that, as all modules are isomorphism, we assign each sub-module with a context\n    (shared among multiple copies of this module on different devices).\n    Through this context, different copies can share some information.\n\n    We guarantee that the callback on the master copy (the first copy) will be called ahead\n    of calling the callback of any slave copies.\n    \"\"\"\n    master_copy = modules[0]\n    nr_modules = len(list(master_copy.modules()))\n    ctxs = [CallbackContext() for _ in range(nr_modules)]\n\n    for i, module in enumerate(modules):\n        for j, m in enumerate(module.modules()):\n            if hasattr(m, '__data_parallel_replicate__'):\n                m.__data_parallel_replicate__(ctxs[j], i)\n\n\ndef patch_replication_callback(data_parallel):\n    \"\"\"\n    Monkey-patch an existing `DataParallel` object. 
Add the replication callback.\n    Useful when you have customized `DataParallel` implementation.\n\n    Examples:\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])\n        > patch_replication_callback(sync_bn)\n        # this is equivalent to\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n    \"\"\"\n\n    assert isinstance(data_parallel, DataParallel)\n\n    old_replicate = data_parallel.replicate\n\n    @functools.wraps(old_replicate)\n    def new_replicate(module, device_ids):\n        modules = old_replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n    data_parallel.replicate = new_replicate"
  },
  {
    "path": "scripts/utils/transforms.py",
    "content": "from __future__ import absolute_import\n\nimport os\nimport numpy as np\nimport scipy.misc\nimport matplotlib.pyplot as plt\nimport torch\nimport torchvision\n\nfrom .misc import *\nfrom .imutils import *\n\n\ndef color_normalize(x, mean, std):\n    if x.size(0) == 1:\n        x = x.repeat(3, x.size(1), x.size(2))\n\n    for t, m, s in zip(x, mean, std):\n        t.sub_(m)\n    return x\n\n\ndef flip_back(flip_output, dataset='mpii'):\n    \"\"\"\n    flip output map\n    \"\"\"\n    if dataset ==  'mpii':\n        matchedParts = (\n            [0,5],   [1,4],   [2,3],\n            [10,15], [11,14], [12,13]\n        )\n    else:\n        print('Not supported dataset: ' + dataset)\n\n    # flip output horizontally\n    flip_output = fliplr(flip_output.numpy())\n\n    # Change left-right parts\n    for pair in matchedParts:\n        tmp = np.copy(flip_output[:, pair[0], :, :])\n        flip_output[:, pair[0], :, :] = flip_output[:, pair[1], :, :]\n        flip_output[:, pair[1], :, :] = tmp\n\n    return torch.from_numpy(flip_output).float()\n\n\ndef shufflelr(x, width, dataset='mpii'):\n    \"\"\"\n    flip coords\n    \"\"\"\n    if dataset ==  'mpii':\n        matchedParts = (\n            [0,5],   [1,4],   [2,3],\n            [10,15], [11,14], [12,13]\n        )\n    else:\n        print('Not supported dataset: ' + dataset)\n\n    # Flip horizontal\n    x[:, 0] = width - x[:, 0]\n\n    # Change left-right parts\n    for pair in matchedParts:\n        tmp = x[pair[0], :].clone()\n        x[pair[0], :] = x[pair[1], :]\n        x[pair[1], :] = tmp\n\n    return x\n\n\ndef fliplr(x):\n    if x.ndim == 3:\n        x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1))\n    elif x.ndim == 4:\n        for i in range(x.shape[0]):\n            x[i] = np.transpose(np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1))\n    return x.astype(float)\n\n\ndef get_transform(center, scale, res, rot=0):\n    \"\"\"\n    General image processing 
functions\n    \"\"\"\n    # Generate transformation matrix\n    h = 200 * scale\n    t = np.zeros((3, 3))\n    t[0, 0] = float(res[1]) / h\n    t[1, 1] = float(res[0]) / h\n    t[0, 2] = res[1] * (-float(center[0]) / h + .5)\n    t[1, 2] = res[0] * (-float(center[1]) / h + .5)\n    t[2, 2] = 1\n    if not rot == 0:\n        rot = -rot # To match direction of rotation from cropping\n        rot_mat = np.zeros((3,3))\n        rot_rad = rot * np.pi / 180\n        sn,cs = np.sin(rot_rad), np.cos(rot_rad)\n        rot_mat[0,:2] = [cs, -sn]\n        rot_mat[1,:2] = [sn, cs]\n        rot_mat[2,2] = 1\n        # Need to rotate around center\n        t_mat = np.eye(3)\n        t_mat[0,2] = -res[1]/2\n        t_mat[1,2] = -res[0]/2\n        t_inv = t_mat.copy()\n        t_inv[:2,2] *= -1\n        t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))\n    return t\n\n\ndef transform(pt, center, scale, res, invert=0, rot=0):\n    # Transform pixel location to different reference\n    t = get_transform(center, scale, res, rot=rot)\n    if invert:\n        t = np.linalg.inv(t)\n    new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T\n    new_pt = np.dot(t, new_pt)\n    return new_pt[:2].astype(int) + 1\n\n\ndef transform_preds(coords, center, scale, res):\n    # size = coords.size()\n    # coords = coords.view(-1, coords.size(-1))\n    # print(coords.size())\n    for p in range(coords.size(0)):\n        coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0))\n    return coords\n\n\ndef crop(img, center, scale, res, rot=0):\n    img = im_to_numpy(img)\n\n    # Upper left point\n    ul = np.array(transform([0, 0], center, scale, res, invert=1))\n    # Bottom right point\n    br = np.array(transform(res, center, scale, res, invert=1))\n\n    # Padding so that when rotated proper amount of context is included\n    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)\n    if not rot == 0:\n        ul -= pad\n        br += pad\n\n    new_shape = [br[1] - 
ul[1], br[0] - ul[0]]\n    if len(img.shape) > 2:\n        new_shape += [img.shape[2]]\n    new_img = np.zeros(new_shape)\n\n    # Range to fill new array\n    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]\n    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]\n    # Range to sample from original image\n    old_x = max(0, ul[0]), min(len(img[0]), br[0])\n    old_y = max(0, ul[1]), min(len(img), br[1])\n    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]\n\n    # NOTE: scipy.misc.imrotate/imresize were removed in newer SciPy releases;\n    # this code requires an older SciPy (or a PIL/OpenCV equivalent)\n    if not rot == 0:\n        # Remove padding\n        new_img = scipy.misc.imrotate(new_img, rot)\n        new_img = new_img[pad:-pad, pad:-pad]\n\n    new_img = im_to_torch(scipy.misc.imresize(new_img, res))\n    return new_img\n\n\ndef get_right(img, gray=False):\n    img = im_to_numpy(img)  # H*W*C\n\n    new_img = img[:, 0:256, :]\n\n    new_img = im_to_torch(new_img)\n    if gray:\n        # take a single channel as the grayscale image\n        new_img = new_img[1, :, :]\n\n    return new_img\n\n\nclass NormalizeInverse(torchvision.transforms.Normalize):\n    \"\"\"\n    Undoes the normalization and returns the reconstructed images in the input domain.\n    \"\"\"\n\n    def __init__(self, mean, std):\n        mean = torch.as_tensor(mean)\n        std = torch.as_tensor(std)\n        std_inv = 1 / (std + 1e-7)\n        mean_inv = -mean * std_inv\n        super().__init__(mean=mean_inv, std=std_inv)\n\n    def __call__(self, tensor):\n        return super().__call__(tensor.clone())\n"
  },
  {
    "path": "test.py",
    "content": "from __future__ import print_function, absolute_import\n\nimport argparse\nimport torch\n\ntorch.backends.cudnn.benchmark = True\n\nfrom scripts.utils.misc import save_checkpoint, adjust_learning_rate\n\nimport scripts.datasets as datasets\nimport scripts.machines as machines\nfrom options import Options\n\ndef main(args):\n    \n    val_loader = torch.utils.data.DataLoader(datasets.COCO('val',args),batch_size=args.test_batch, shuffle=False,\n        num_workers=args.workers, pin_memory=True)\n\n    data_loaders = (None,val_loader)\n\n    Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)\n\n    Machine.test()\n\nif __name__ == '__main__':\n    parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal'))\n    main(parser.parse_args())\n"
  },
  {
    "path": "watermark_synthesis.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"SAVE ALL THE SETTING\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# watermark synthesis\\n\",\n    \"import os \\n\",\n    \"import random\\n\",\n    \"import shutil\\n\",\n    \"from PIL import Image\\n\",\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"def trans_paste(bg_img,fg_img,mask,box=(0,0)):\\n\",\n    \"    fg_img_trans = Image.new(\\\"RGBA\\\",bg_img.size)\\n\",\n    \"    fg_img_trans.paste(fg_img,box,mask=mask)\\n\",\n    \"    new_img = Image.alpha_composite(bg_img,fg_img_trans)\\n\",\n    \"    return new_img,fg_img_trans\\n\",\n    \"\\n\",\n    \"if os.path.isdir('dataset'):\\n\",\n    \"    shutil.rmtree('dataset')\\n\",\n    \"\\n\",\n    \"os.mkdir('dataset')\\n\",\n    \"BASE_IMG_DIR = '/Users/oishii/Downloads/val2014/'\\n\",\n    \"WATERMARK_DIR = 'logos' #1080 \\n\",\n    \"images = sorted([os.path.join(BASE_IMG_DIR,x) for x in os.listdir(BASE_IMG_DIR) if '.jpg' in x])\\n\",\n    \"watermarks = sorted([os.path.join(WATERMARK_DIR,x).replace(' ','_') for x in os.listdir(WATERMARK_DIR) if '.png' in x])\\n\",\n    \"# rename all the watermark from replace ' ' to '_'\\n\",\n    \"\\n\",\n    \"random.shuffle(images)\\n\",\n    \"random.shuffle(watermarks)\\n\",\n    \"\\n\",\n    \"train_images = images[:int(len(images)*0.7)]\\n\",\n    \"val_images = images[int(len(images)*0.7):int(len(images)*0.8)]\\n\",\n    \"test_images = images[int(len(images)*0.8):]\\n\",\n    \"\\n\",\n    \"train_wms = watermarks[:int(len(watermarks)*0.7)]\\n\",\n    \"val_wms = watermarks[int(len(watermarks)*0.7):int(len(watermarks)*0.8)]\\n\",\n    \"test_wms = watermarks[int(len(watermarks)*0.8):]\\n\",\n    \"\\n\",\n    \"# save all the settings to file\\n\",\n    \"names = 
['train_images','val_images','test_images','train_wms','val_wms','test_wms']\\n\",\n    \"lists = [train_images,val_images,test_images,train_wms,val_wms,test_wms]\\n\",\n    \"dataset = dict(zip(names, lists))\\n\",\n    \"\\n\",\n    \"for name,content in dataset.items():\\n\",\n    \"    with open('dataset/%s.txt'%name,'w') as f:\\n\",\n    \"        f.write(\\\"\\\\n\\\".join(content))\\n\",\n    \"\\n\",\n    \"print('SAVE ALL THE SETTING')\\n\",\n    \"\\n\",\n    \"for name, images in dataset.items():\\n\",\n    \"    if 'images' not in name:\\n\",\n    \"        continue\\n\",\n    \"    # for each setting, synthesis the watermark\\n\",\n    \"    # for each image, add X(X=6) watermark in differnet position, alpha,\\n\",\n    \"    # save the synthesized image, watermark mask, reshaped mask,\\n\",\n    \"    save_path = 'dataset/%s/'%name\\n\",\n    \"    os.makedirs('%s/image'%(save_path))\\n\",\n    \"    os.makedirs('%s/mask'%(save_path))\\n\",\n    \"    os.makedirs('%s/wm'%(save_path))\\n\",\n    \"    \\n\",\n    \"    for img in images:\\n\",\n    \"        im = Image.open(img).convert('RGBA')\\n\",\n    \"        imw,imh = im.size\\n\",\n    \"        \\n\",\n    \"        for wmg in random.choices(dataset[name.replace('images','wms')],k=6):\\n\",\n    \"            wm = Image.open(wmg.replace('_',' ')).convert(\\\"RGBA\\\") # RGBA\\n\",\n    \"            # get the mask of wm\\n\",\n    \"            # data agumentation of wm\\n\",\n    \"            wm = wm.rotate(angle=random.randint(0,360),expand=True) # rotate\\n\",\n    \"            \\n\",\n    \"            # make sure the \\n\",\n    \"            imrw = random.randrange(int(0.4*imw),int(0.8*imw))\\n\",\n    \"            imrh = random.randrange(int(0.4*imh),int(0.8*imh))\\n\",\n    \"            wmsize = imrh if imrw > imrh else imrw\\n\",\n    \"            wm = wm.resize((wmsize,wmsize),Image.BILINEAR)\\n\",\n    \"            w,h = wm.size # new size \\n\",\n    \"            \\n\",\n    
\"            box_left = random.randint(0,imw-w)\\n\",\n    \"            box_upper = random.randint(0,imh-h)\\n\",\n    \"            wmm = wm.copy()\\n\",\n    \"            wm.putalpha(random.randint(int(255*0.4),int(255*0.8))) # alpha\\n\",\n    \"            \\n\",\n    \"            ims,wmc = trans_paste(im,wm,wmm,(box_left,box_upper))\\n\",\n    \"            \\n\",\n    \"            wmnp = np.array(wmc) # h,w,3\\n\",\n    \"            mask = np.sum(wmnp,axis=2)>0\\n\",\n    \"            mm = Image.fromarray(np.uint8(mask*255),mode='L')\\n\",\n    \"            \\n\",\n    \"            identifier = os.path.basename(img).split('.')[0] +'-'+os.path.basename(wmg).split('.')[0] + '.png'\\n\",\n    \"            # save \\n\",\n    \"            wmc.save('%s/wm/%s'%(save_path,identifier))\\n\",\n    \"            ims.save('%s/image/%s'%(save_path,identifier))\\n\",\n    \"            mm.save('%s/mask/%s'%(save_path,identifier))\\n\",\n    \"            \\n\",\n    \"            \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.4\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  }
]