Repository: vinthony/deep-blind-watermark-removal Branch: main Commit: 72f0e61b9f06 Files: 35 Total size: 174.8 KB Directory structure: gitextract_p5y2m_4m/ ├── README.md ├── examples/ │ ├── evaluate.sh │ └── test.sh ├── main.py ├── options.py ├── requirements.txt ├── scripts/ │ ├── __init__.py │ ├── datasets/ │ │ ├── BIH.py │ │ ├── COCO.py │ │ └── __init__.py │ ├── machines/ │ │ ├── BasicMachine.py │ │ ├── S2AM.py │ │ ├── VX.py │ │ └── __init__.py │ ├── models/ │ │ ├── __init__.py │ │ ├── backbone_unet.py │ │ ├── blocks.py │ │ ├── discriminator.py │ │ ├── rasc.py │ │ ├── sa_resunet.py │ │ ├── unet.py │ │ ├── vgg.py │ │ └── vmu.py │ └── utils/ │ ├── __init__.py │ ├── evaluation.py │ ├── imutils.py │ ├── logger.py │ ├── losses.py │ ├── misc.py │ ├── model_init.py │ ├── osutils.py │ ├── parallel.py │ └── transforms.py ├── test.py └── watermark_synthesis.ipynb ================================================ FILE CONTENTS ================================================ ================================================ FILE: README.md ================================================ This repo contains the code and results of the AAAI 2021 paper: [Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal](https://arxiv.org/abs/2012.07007)
[Xiaodong Cun](http://vinthony.github.io), [Chi-Man Pun*](http://www.cis.umac.mo/~cmpun/)
[University of Macau](http://um.edu.mo/) [Datasets](#Resources) | [Models](#Resources) | [Paper](https://arxiv.org/abs/2012.07007) | [🔥Online Demo!](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing)(Google CoLab)
nn The overview of the proposed two-stage framework. Firstly, we propose a multi-task network, SplitNet, for watermark detection, removal, and recovery. Then, we propose the RefineNet to smooth the learned region with the predicted mask and the recovered background from the previous stage. As a consequence, our network can be trained in an end-to-end fashion without any manual intervention. Note that, for clarity, we do not show any skip-connections between all the encoders and decoders.
> The whole project will be released in the January of 2021 (almost). ### Datasets We synthesized four different datasets for training and testing, you can download the dataset via [huggingface](https://huggingface.co/datasets/vinthony/watermark-removal-logo/tree/main). ![image](https://user-images.githubusercontent.com/4397546/104273158-74413900-54d9-11eb-95fa-c6bee94de0ea.png) ### Pre-trained Models * [27kpng_model_best.pth.tar (google drive)](https://drive.google.com/file/d/1KpSJ6385CHN6WlAINqB3CYrJdleQTJBc/view?usp=sharing) > Other Pre-trained Models are still reorganizing and uploading, it will be released soon. ### Demos An easy-to-use online demo can be founded in [google colab](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing). The local demo will be released soon. ### Pre-requirements ``` pip install -r requirements.txt ``` ### Train Besides training our methods, here, we also give an example of how to train the [s2am](https://github.com/vinthony/s2am) under our framework. More details can be found in the shell scripts. ``` bash examples/evaluation.sh ``` ### Test ``` bash examples/test.sh ``` ## **Acknowledgements** The author would like to thanks Nan Chen for her helpful discussion. Part of the code is based upon our previous work on image harmonization [s2am](https://github.com/vinthony/s2am) ## **Citation** If you find our work useful in your research, please consider citing: ``` @misc{cun2020split, title={Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal}, author={Xiaodong Cun and Chi-Man Pun}, year={2020}, eprint={2012.07007}, archivePrefix={arXiv}, primaryClass={cs.CV} } ``` ## **Contact** Please contact me if there is any question (Xiaodong Cun yb87432@um.edu.mo) ================================================ FILE: examples/evaluate.sh ================================================ set -ex # example training scripts for AAAI-21 # Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal CUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/main.py --epochs 100\ --schedule 100\ --lr 1e-3\ -c eval/10kgray/1e3_bs4_256_hybrid_ssim_vgg\ --arch vvv4n\ --sltype vggx\ --style-loss 0.025\ --ssim-loss 0.15\ --masked True\ --loss-type hybrid\ --limited-dataset 1\ --machine vx\ --input-size 256\ --train-batch 4\ --test-batch 1\ --base-dir $HOME/watermark/10kgray/\ --data _images # example training scripts for TIP-20 # Improving the Harmony of the Composite Image by Spatial-Separated Attention Module # * in the original version, the res = False # suitable for the iHarmony4 dataset. python /data/home/yb87432/mypaper/s2am/main.py --epochs 200\ --schedule 150\ --lr 1e-3\ -c checkpoint/normal_rasc_HAdobe5k_res \ --arch rascv2\ --style-loss 0\ --ssim-loss 0\ --limited-dataset 0\ --res True\ --machine s2am\ --input-size 256\ --train-batch 16\ --test-batch 1\ --base-dir $HOME/Datasets/\ --data HAdobe5k ================================================ FILE: examples/test.sh ================================================ set -ex CUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/test.py \ -c test/10kgray_ssim\ --resume /data/home/yb87432/s2am/eval/10kgray/1e3_bs6_256_hybrid_ssim_vgg_vx__images_vvv4n/model_best.pth.tar\ --arch vvv4n\ --machine vx\ --input-size 256\ --test-batch 1\ --evaluate\ --base-dir $HOME/watermark/10kgray/\ --data _images ================================================ FILE: main.py ================================================ from __future__ import print_function, absolute_import import argparse import torch,time,os torch.backends.cudnn.benchmark = True from scripts.utils.misc import save_checkpoint, adjust_learning_rate import scripts.datasets as datasets import scripts.machines as machines from options import Options def main(args): if 'HFlickr' or 'HCOCO' or 'Hday2night' or 'HAdobe5k' in args.base_dir: dataset_func = datasets.BIH else: dataset_func = datasets.COCO train_loader = torch.utils.data.DataLoader(dataset_func('train',args),batch_size=args.train_batch, shuffle=True, num_workers=args.workers, pin_memory=True) val_loader = torch.utils.data.DataLoader(dataset_func('val',args),batch_size=args.test_batch, shuffle=False, num_workers=args.workers, pin_memory=True) lr = args.lr data_loaders = (train_loader,val_loader) Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args) print('============================ Initization Finish && Training Start =============================================') for epoch in range(Machine.args.start_epoch, Machine.args.epochs): print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr)) lr = adjust_learning_rate(data_loaders, Machine.optimizer, epoch, lr, args) Machine.record('lr',lr, epoch) Machine.train(epoch) if args.freq < 0: Machine.validate(epoch) Machine.flush() Machine.save_checkpoint() if __name__ == '__main__': parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal')) args = parser.parse_args() print('==================================== WaterMark Removal =============================================') print('==> {:50}: {:<}'.format("Start Time",time.ctime(time.time()))) print('==> {:50}: {:<}'.format("USE GPU",os.environ['CUDA_VISIBLE_DEVICES'])) print('==================================== Stable Parameters =============================================') for arg in vars(args): if type(getattr(args, arg)) == type([]): if ','.join([ str(i) for i in getattr(args, arg)]) == ','.join([ str(i) for i in parser.get_default(arg)]): print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)]))) else: if getattr(args, arg) == parser.get_default(arg): print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg))) print('==================================== Changed Parameters =============================================') for arg in vars(args): if type(getattr(args, arg)) == type([]): if ','.join([ str(i) for i in getattr(args, arg)]) != ','.join([ str(i) for i in parser.get_default(arg)]): print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)]))) else: if getattr(args, arg) != parser.get_default(arg): print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg))) print('==================================== Start Init Model ===============================================') main(args) print('==================================== FINISH WITHOUT ERROR =============================================') ================================================ FILE: options.py ================================================ import scripts.models as models model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name])) class Options(): """docstring for Options""" def __init__(self): pass def init(self, parser): # Model structure parser.add_argument('--arch', '-a', metavar='ARCH', default='dhn', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)') parser.add_argument('--darch', metavar='ARCH', default='dhn', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)') parser.add_argument('--machine', '-m', metavar='NACHINE', default='basic') # Training strategy parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=30, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)') parser.add_argument('--train-batch', default=64, type=int, metavar='N', help='train batchsize') parser.add_argument('--test-batch', default=6, type=int, metavar='N', help='test batchsize') parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,metavar='LR', help='initial learning rate') parser.add_argument('--dlr', '--dlearning-rate', default=1e-3, type=float, help='initial learning rate') parser.add_argument('--beta1', default=0.9, type=float, help='initial learning rate') parser.add_argument('--beta2', default=0.999, type=float, help='initial learning rate') parser.add_argument('--momentum', default=0, type=float, metavar='M', help='momentum') parser.add_argument('--weight-decay', '--wd', default=0, type=float, metavar='W', help='weight decay (default: 0)') parser.add_argument('--schedule', type=int, nargs='+', default=[5, 10], help='Decrease learning rate at these epochs.') parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') # Data processing parser.add_argument('-f', '--flip', dest='flip', action='store_true', help='flip the input during validation') parser.add_argument('--lambdaL1', type=float, default=1, help='the weight of L1.') parser.add_argument('--alpha', type=float, default=0.5, help='Groundtruth Gaussian sigma.') parser.add_argument('--sigma-decay', type=float, default=0, help='Sigma decay rate for each epoch.') # Miscs parser.add_argument('--base-dir', default='/PATH_TO_DATA_FOLDER/', type=str, metavar='PATH') parser.add_argument('--data', default='', type=str, metavar='PATH', help='path to save checkpoint (default: checkpoint)') parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='path to save checkpoint (default: checkpoint)') parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('--finetune', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') parser.add_argument('--style-loss', default=0, type=float, help='preception loss') parser.add_argument('--ssim-loss', default=0, type=float,help='msssim loss') parser.add_argument('--att-loss', default=1, type=float,help='msssim loss') parser.add_argument('--default-loss',default=False,type=bool) parser.add_argument('--sltype', default='vggx', type=str) parser.add_argument('-da', '--data-augumentation', default=False, type=bool, help='preception loss') parser.add_argument('-d', '--debug', dest='debug', action='store_true', help='show intermediate results') parser.add_argument('--input-size', default=256, type=int, metavar='N', help='train batchsize') parser.add_argument('--freq', default=-1, type=int, metavar='N', help='evaluation frequence') parser.add_argument('--normalized-input', default=False, type=bool, help='train batchsize') parser.add_argument('--res', default=False, type=bool,help='residual learning for s2am') parser.add_argument('--requires-grad', default=False, type=bool, help='train batchsize') parser.add_argument('--limited-dataset', default=0, type=int, metavar='N') parser.add_argument('--gpu',default=True,type=bool) parser.add_argument('--masked',default=False,type=bool) parser.add_argument('--gan-norm', default=False,type=bool, help='train batchsize') parser.add_argument('--hl', default=False,type=bool, help='homogenious leanring') parser.add_argument('--loss-type', default='l2',type=str, help='train batchsize') return parser ================================================ FILE: requirements.txt ================================================ numpy==1.19.1 opencv-python==3.4.8.29 Pillow scikit-image==0.14.5 scikit-learn==0.23.1 scipy==1.2.1 sklearn==0.0 tensorboardX torch>=1.0.0 torchvision ================================================ FILE: scripts/__init__.py ================================================ from __future__ import absolute_import from . import datasets from . import models from . import utils # import os, sys # sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) # from progress.bar import Bar as Bar # __version__ = '0.1.0' ================================================ FILE: scripts/datasets/BIH.py ================================================ from __future__ import print_function, absolute_import import os import csv import numpy as np import json import random import math import matplotlib.pyplot as plt from collections import namedtuple from os import listdir from os.path import isfile, join import torch import torch.utils.data as data from scripts.utils.osutils import * from scripts.utils.imutils import * from scripts.utils.transforms import * import torchvision.transforms as transforms from PIL import Image from PIL import ImageEnhance from PIL import ImageFilter from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True class BIH(data.Dataset): def __init__(self,train,config=None, sample=[],gan_norm=False): self.train = [] self.anno = [] self.mask = [] self.wm = [] self.input_size = config.input_size self.normalized_input = config.normalized_input self.base_folder = config.base_dir +'/' + config.data self.dataset = config.data if config == None: self.data_augumentation = False else: self.data_augumentation = config.data_augumentation self.istrain = False if train.find('train') == -1 else True self.sample = sample self.gan_norm = gan_norm mypath = join(self.base_folder,self.dataset+'_'+train+'.txt') with open(mypath) as f: # here we get the filenames file_names = [ im.strip() for im in f.readlines() ] if config.limited_dataset > 0: xtrain = sorted(list(set([ file_name.split('-')[0] for file_name in file_names ]))) tmp = [] for x in xtrain: tmp.append([y for y in file_names if x in y][0]) file_names = tmp else: file_names = file_names for file_name in file_names: self.train.append(os.path.join(self.base_folder,'images',file_name)) self.mask.append(os.path.join(self.base_folder,'masks','_'.join(file_name.split('_')[0:2])+'.png')) self.anno.append(os.path.join(self.base_folder,'reals',file_name.split('_')[0]+'.jpg')) if len(self.sample) > 0 : self.train = [ self.train[i] for i in self.sample ] self.mask = [ self.mask[i] for i in self.sample ] self.anno = [ self.anno[i] for i in self.sample ] self.trans = transforms.Compose([ transforms.Resize((self.input_size,self.input_size)), transforms.ToTensor() ]) print('total Dataset of '+self.dataset+' is : ', len(self.train)) def __getitem__(self, index): img = Image.open(self.train[index]).convert('RGB') mask = Image.open(self.mask[index]).convert('L') anno = Image.open(self.anno[index]).convert('RGB') # for shadow removal and blind image harmonization, here is no ground truth wm # wm = Image.open(self.wm[index]).convert('RGB') return {"image": self.trans(img), "target": self.trans(anno), "mask": self.trans(mask), "name": self.train[index].split('/')[-1], "imgurl":self.train[index], "maskurl":self.mask[index], "targeturl":self.anno[index], } def __len__(self): return len(self.train) ================================================ FILE: scripts/datasets/COCO.py ================================================ from __future__ import print_function, absolute_import import os import csv import numpy as np import json import random import math import matplotlib.pyplot as plt from collections import namedtuple from os import listdir from os.path import isfile, join import torch import torch.utils.data as data from scripts.utils.osutils import * from scripts.utils.imutils import * from scripts.utils.transforms import * import torchvision.transforms as transforms from PIL import Image from PIL import ImageEnhance from PIL import ImageFilter from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True class COCO(data.Dataset): def __init__(self,train,config=None, sample=[],gan_norm=False): self.train = [] self.anno = [] self.mask = [] self.wm = [] self.input_size = config.input_size self.normalized_input = config.normalized_input self.base_folder = config.base_dir self.dataset = train+config.data if config == None: self.data_augumentation = False else: self.data_augumentation = config.data_augumentation self.istrain = False if self.dataset.find('train') == -1 else True self.sample = sample self.gan_norm = gan_norm mypath = join(self.base_folder,self.dataset) file_names = sorted([f for f in listdir(join(mypath,'image')) if isfile(join(mypath,'image', f)) ]) if config.limited_dataset > 0: xtrain = sorted(list(set([ file_name.split('-')[0] for file_name in file_names ]))) tmp = [] for x in xtrain: # get the file_name by identifier tmp.append([y for y in file_names if x in y][0]) file_names = tmp else: file_names = file_names for file_name in file_names: self.train.append(os.path.join(mypath,'image',file_name)) self.mask.append(os.path.join(mypath,'mask',file_name)) self.wm.append(os.path.join(mypath,'wm',file_name)) self.anno.append(os.path.join(self.base_folder,'natural',file_name.split('-')[0]+'.jpg')) if len(self.sample) > 0 : self.train = [ self.train[i] for i in self.sample ] self.mask = [ self.mask[i] for i in self.sample ] self.anno = [ self.anno[i] for i in self.sample ] self.trans = transforms.Compose([ transforms.Resize((self.input_size,self.input_size)), transforms.ToTensor() ]) print('total Dataset of '+self.dataset+' is : ', len(self.train)) def __getitem__(self, index): img = Image.open(self.train[index]).convert('RGB') mask = Image.open(self.mask[index]).convert('L') anno = Image.open(self.anno[index]).convert('RGB') wm = Image.open(self.wm[index]).convert('RGB') return {"image": self.trans(img), "target": self.trans(anno), "mask": self.trans(mask), "wm": self.trans(wm), "name": self.train[index].split('/')[-1], "imgurl":self.train[index], "maskurl":self.mask[index], "targeturl":self.anno[index], "wmurl":self.wm[index] } def __len__(self): return len(self.train) ================================================ FILE: scripts/datasets/__init__.py ================================================ from .COCO import COCO from .BIH import BIH __all__ = ('COCO','BIH') ================================================ FILE: scripts/machines/BasicMachine.py ================================================ import torch import torch.nn as nn import torch.backends.cudnn as cudnn from progress.bar import Bar import json import numpy as np from tensorboardX import SummaryWriter from scripts.utils.evaluation import accuracy, AverageMeter, final_preds from scripts.utils.osutils import mkdir_p, isfile, isdir, join from scripts.utils.parallel import DataParallelModel, DataParallelCriterion import pytorch_ssim as pytorch_ssim import torch.optim import sys,shutil,os import time import scripts.models as archs from math import log10 from torch.autograd import Variable from scripts.utils.losses import VGGLoss from scripts.utils.imutils import im_to_numpy import skimage.io from skimage.measure import compare_psnr,compare_ssim class BasicMachine(object): def __init__(self, datasets =(None,None), models = None, args = None, **kwargs): super(BasicMachine, self).__init__() self.args = args # create model print("==> creating model ") self.model = archs.__dict__[self.args.arch]() print("==> creating model [Finish]") self.train_loader, self.val_loader = datasets self.loss = torch.nn.MSELoss() self.title = '_'+args.machine + '_' + args.data + '_' + args.arch self.args.checkpoint = args.checkpoint + self.title self.device = torch.device('cuda') # create checkpoint dir if not isdir(self.args.checkpoint): mkdir_p(self.args.checkpoint) self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=args.lr, betas=(args.beta1,args.beta2), weight_decay=args.weight_decay) if not self.args.evaluate: self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt') self.best_acc = 0 self.is_best = False self.current_epoch = 0 self.metric = -100000 self.hl = 6 if self.args.hl else 1 self.count_gpu = len(range(torch.cuda.device_count())) if self.args.style_loss > 0: # init perception loss self.vggloss = VGGLoss(self.args.sltype).to(self.device) if self.count_gpu > 1 : # multiple # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count())) # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count())) self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count())) self.model.to(self.device) self.loss.to(self.device) print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0)) print('==> Total devices: %d' % (torch.cuda.device_count())) print('==> Current Checkpoint: %s' % (self.args.checkpoint)) if self.args.resume != '': self.resume(self.args.resume) def train(self,epoch): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() lossvgg = AverageMeter() # switch to train mode self.model.train() end = time.time() bar = Bar('Processing', max=len(self.train_loader)*self.hl) for _ in range(self.hl): for i, batches in enumerate(self.train_loader): # measure data loading time inputs = batches['image'] target = batches['target'].to(self.device) mask =batches['mask'].to(self.device) current_index = len(self.train_loader) * epoch + i if self.args.hl: feeded = torch.cat([inputs,mask],dim=1) else: feeded = inputs feeded = feeded.to(self.device) output = self.model(feeded) L2_loss = self.loss(output,target) if self.args.style_loss > 0: vgg_loss = self.vggloss(output,target,mask) else: vgg_loss = 0 total_loss = L2_loss + self.args.style_loss * vgg_loss # compute gradient and do SGD step self.optimizer.zero_grad() total_loss.backward() self.optimizer.step() # measure accuracy and record loss losses.update(L2_loss.item(), inputs.size(0)) if self.args.style_loss > 0 : lossvgg.update(vgg_loss.item(), inputs.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() # plot progress suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format( batch=i + 1, size=len(self.train_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss_label=losses.avg, loss_vgg=lossvgg.avg ) if current_index % 1000 == 0: print(suffix) if self.args.freq > 0 and current_index % self.args.freq == 0: self.validate(current_index) self.flush() self.save_checkpoint() self.record('train/loss_L2', losses.avg, current_index) def test(self, ): # switch to evaluate mode self.model.eval() ssimes = AverageMeter() psnres = AverageMeter() with torch.no_grad(): for i, batches in enumerate(self.val_loader): inputs = batches['image'].to(self.device) target = batches['target'].to(self.device) mask =batches['mask'].to(self.device) outputs = self.model(inputs) # select the outputs by the giving arch if type(outputs) == type(inputs): output = outputs elif type(outputs[0]) == type([]): output = outputs[0][0] else: output = outputs[0] # recover the image to 255 output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8) target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8) skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output) psnr = compare_psnr(target,output) ssim = compare_ssim(target,output,multichannel=True) psnres.update(psnr, inputs.size(0)) ssimes.update(ssim, inputs.size(0)) print("%s:PSNR:%s,SSIM:%s"%(self.args.checkpoint,psnres.avg,ssimes.avg)) print("DONE.\n") def validate(self, epoch): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() ssimes = AverageMeter() psnres = AverageMeter() # switch to evaluate mode self.model.eval() end = time.time() with torch.no_grad(): for i, batches in enumerate(self.val_loader): inputs = batches['image'].to(self.device) target = batches['target'].to(self.device) mask =batches['mask'].to(self.device) if self.args.hl: feeded = torch.cat([inputs,torch.zeros((1,4,self.args.input_size,self.args.input_size)).to(self.device)],dim=1) else: feeded = inputs output = self.model(feeded) L2_loss = self.loss(output, target) psnr = 10 * log10(1 / L2_loss.item()) ssim = pytorch_ssim.ssim(output, target) losses.update(L2_loss.item(), inputs.size(0)) psnres.update(psnr, inputs.size(0)) ssimes.update(ssim.item(), inputs.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() print("Epoches:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f"%(epoch+1, losses.avg,psnres.avg,ssimes.avg)) self.record('val/loss_L2', losses.avg, epoch) self.record('val/PSNR', psnres.avg, epoch) self.record('val/SSIM', ssimes.avg, epoch) self.metric = psnres.avg def resume(self,resume_path): if isfile(resume_path): print("=> loading checkpoint '{}'".format(resume_path)) current_checkpoint = torch.load(resume_path) if isinstance(current_checkpoint['state_dict'], torch.nn.DataParallel): current_checkpoint['state_dict'] = current_checkpoint['state_dict'].module if isinstance(current_checkpoint['optimizer'], torch.nn.DataParallel): current_checkpoint['optimizer'] = current_checkpoint['optimizer'].module self.args.start_epoch = current_checkpoint['epoch'] self.metric = current_checkpoint['best_acc'] self.model.load_state_dict(current_checkpoint['state_dict']) # self.optimizer.load_state_dict(current_checkpoint['optimizer']) print("=> loaded checkpoint '{}' (epoch {})" .format(resume_path, current_checkpoint['epoch'])) else: raise Exception("=> no checkpoint found at '{}'".format(resume_path)) def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None): is_best = True if self.best_acc < self.metric else False if is_best: self.best_acc = self.metric state = { 'epoch': self.current_epoch + 1, 'arch': self.args.arch, 'state_dict': self.model.state_dict(), 'best_acc': self.best_acc, 'optimizer' : self.optimizer.state_dict() if self.optimizer else None, } filepath = os.path.join(self.args.checkpoint, filename) torch.save(state, filepath) if snapshot and state['epoch'] % snapshot == 0: shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state.epoch))) if is_best: self.best_acc = self.metric print('Saving Best Metric with PSNR:%s'%self.best_acc) shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar')) def clean(self): self.writer.close() def record(self,k,v,epoch): self.writer.add_scalar(k, v, epoch) def flush(self): self.writer.flush() sys.stdout.flush() def norm(self,x): if self.args.gan_norm: return x*2.0 - 1.0 else: return x def denorm(self,x): if self.args.gan_norm: return (x+1.0)/2.0 else: return x ================================================ FILE: scripts/machines/S2AM.py ================================================ import torch import torch.nn as nn import torch.backends.cudnn as cudnn from progress.bar import Bar import json import numpy as np from tensorboardX import SummaryWriter from scripts.utils.evaluation import accuracy, AverageMeter, final_preds from scripts.utils.osutils import mkdir_p, isfile, isdir, join from scripts.utils.parallel import DataParallelModel, DataParallelCriterion import pytorch_ssim as pytorch_ssim import torch.optim import sys,shutil,os import time import scripts.models as archs from math import log10 from torch.autograd import Variable from scripts.utils.losses import VGGLoss from scripts.utils.imutils import im_to_numpy import skimage.io from skimage.measure import compare_psnr,compare_ssim class S2AM(object): def __init__(self, datasets =(None,None), models = None, args = None, **kwargs): super(S2AM, self).__init__() self.args = args # create model print("==> creating model ") self.model = archs.__dict__[self.args.arch]() print("==> creating model [Finish]") self.train_loader, self.val_loader = datasets self.loss = torch.nn.MSELoss() self.title = '_'+args.machine + '_' + args.data + '_' + args.arch self.args.checkpoint = args.checkpoint + self.title self.device = torch.device('cuda') # create checkpoint dir if not isdir(self.args.checkpoint): mkdir_p(self.args.checkpoint) self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=args.lr, betas=(args.beta1,args.beta2), weight_decay=args.weight_decay) if not self.args.evaluate: self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt') self.best_acc = 0 self.is_best = False self.current_epoch = 0 self.hl = 1 self.metric = -100000 self.count_gpu = len(range(torch.cuda.device_count())) if self.args.style_loss > 0: # init perception loss self.vggloss = VGGLoss(self.args.sltype).to(self.device) if self.count_gpu > 1 : # multiple # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count())) # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count())) self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count())) self.model.to(self.device) self.loss.to(self.device) print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0)) print('==> Total devices: %d' % (torch.cuda.device_count())) print('==> Current Checkpoint: %s' % (self.args.checkpoint)) if self.args.resume != '': self.resume(self.args.resume) def train(self,epoch): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() lossvgg = AverageMeter() # switch to train mode self.model.train() end = time.time() bar = Bar('Processing', max=len(self.train_loader)*self.hl) for _ in range(self.hl): for i, batches in enumerate(self.train_loader): # measure data loading time inputs = batches['image'].to(self.device) target = batches['target'].to(self.device) mask =batches['mask'].to(self.device) current_index = len(self.train_loader) * epoch + i feeded = torch.cat([inputs,mask],dim=1) feeded = feeded.to(self.device) output = self.model(feeded) if self.args.res: output = output + inputs L2_loss = self.loss(output,target) if self.args.style_loss > 0: vgg_loss = self.vggloss(output,target,mask) else: vgg_loss = 0 total_loss = L2_loss + self.args.style_loss * vgg_loss # compute gradient and do SGD step self.optimizer.zero_grad() total_loss.backward() self.optimizer.step() # measure accuracy and record loss losses.update(L2_loss.item(), inputs.size(0)) if self.args.style_loss > 0 : lossvgg.update(vgg_loss.item(), inputs.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() # plot progress suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format( batch=i + 1, size=len(self.train_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss_label=losses.avg, loss_vgg=lossvgg.avg ) if current_index % 1000 == 0: print(suffix) if self.args.freq > 0 and current_index % self.args.freq == 0: self.validate(current_index) self.flush() self.save_checkpoint() self.record('train/loss_L2', losses.avg, current_index) def test(self, ): # switch to evaluate mode self.model.eval() ssimes = AverageMeter() psnres = AverageMeter() with torch.no_grad(): for i, batches in enumerate(self.val_loader): inputs = batches['image'].to(self.device) target = batches['target'].to(self.device) mask =batches['mask'].to(self.device) feeded = torch.cat([inputs,mask],dim=1) feeded = feeded.to(self.device) output = self.model(feeded) if self.args.res: output = output + inputs # recover the image to 255 output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8) target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8) skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output) psnr = compare_psnr(target,output) ssim = compare_ssim(target,output,multichannel=True) psnres.update(psnr, inputs.size(0)) ssimes.update(ssim, inputs.size(0)) print("%s:PSNR:%s,SSIM:%s"%(self.args.checkpoint,psnres.avg,ssimes.avg)) print("DONE.\n") def validate(self, epoch): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() ssimes = AverageMeter() psnres = AverageMeter() # switch to evaluate mode self.model.eval() end = time.time() with torch.no_grad(): for i, batches in enumerate(self.val_loader): inputs = batches['image'].to(self.device) target = batches['target'].to(self.device) mask =batches['mask'].to(self.device) feeded = torch.cat([inputs,mask],dim=1) feeded = feeded.to(self.device) output = self.model(feeded) if self.args.res: output = output + inputs L2_loss = self.loss(output, target) psnr = 10 * log10(1 / L2_loss.item()) ssim = pytorch_ssim.ssim(output, target) losses.update(L2_loss.item(), inputs.size(0)) psnres.update(psnr, inputs.size(0)) ssimes.update(ssim.item(), inputs.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() print("Epoches:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f"%(epoch+1, losses.avg,psnres.avg,ssimes.avg)) self.record('val/loss_L2', losses.avg, epoch) self.record('val/PSNR', psnres.avg, epoch) self.record('val/SSIM', ssimes.avg, epoch) self.metric = psnres.avg def resume(self,resume_path): if isfile(resume_path): print("=> loading checkpoint '{}'".format(resume_path)) current_checkpoint = torch.load(resume_path) if isinstance(current_checkpoint['state_dict'], torch.nn.DataParallel): current_checkpoint['state_dict'] = current_checkpoint['state_dict'].module if isinstance(current_checkpoint['optimizer'], torch.nn.DataParallel): current_checkpoint['optimizer'] = current_checkpoint['optimizer'].module self.args.start_epoch = current_checkpoint['epoch'] self.metric = current_checkpoint['best_acc'] self.model.load_state_dict(current_checkpoint['state_dict']) # self.optimizer.load_state_dict(current_checkpoint['optimizer']) print("=> loaded checkpoint '{}' (epoch {})" .format(resume_path, current_checkpoint['epoch'])) else: raise Exception("=> no checkpoint found at '{}'".format(resume_path)) def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None): is_best = True if self.best_acc < self.metric else False if is_best: self.best_acc = self.metric state = { 'epoch': self.current_epoch + 1, 'arch': self.args.arch, 'state_dict': self.model.state_dict(), 'best_acc': self.best_acc, 'optimizer' : self.optimizer.state_dict() if self.optimizer else None, } filepath = os.path.join(self.args.checkpoint, filename) torch.save(state, filepath) if snapshot and state['epoch'] % snapshot == 0: shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state.epoch))) if is_best: self.best_acc = self.metric print('Saving Best Metric with PSNR:%s'%self.best_acc) shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar')) def clean(self): self.writer.close() def record(self,k,v,epoch): self.writer.add_scalar(k, v, epoch) def flush(self): self.writer.flush() sys.stdout.flush() def norm(self,x): if self.args.gan_norm: return x*2.0 - 1.0 else: return x def denorm(self,x): if self.args.gan_norm: return (x+1.0)/2.0 else: return x ================================================ FILE: scripts/machines/VX.py ================================================ import torch import torch.nn as nn from progress.bar import Bar from tqdm import tqdm import pytorch_ssim import json import sys,time,os import torchvision from math import log10 import numpy as np from .BasicMachine import BasicMachine from scripts.utils.evaluation import accuracy, AverageMeter, final_preds from scripts.utils.misc import resize_to_match from torch.autograd import Variable import torch.nn.functional as F from scripts.utils.parallel import DataParallelModel, DataParallelCriterion from scripts.utils.losses import VGGLoss, l1_relative,is_dic from scripts.utils.imutils import im_to_numpy import skimage.io from skimage.measure import compare_psnr,compare_ssim class Losses(nn.Module): def __init__(self, argx, device, norm_func=None, denorm_func=None): super(Losses, self).__init__() self.args = argx if self.args.loss_type == 'l1bl2': self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss() elif self.args.loss_type == 'l2xbl2': self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss() elif self.args.loss_type == 'relative' or self.args.loss_type == 'hybrid': self.outputLoss, self.attLoss, self.wrloss = l1_relative, nn.BCELoss(), l1_relative else: # l2bl2 self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss() self.default = nn.L1Loss() if self.args.style_loss > 0: self.vggloss = VGGLoss(self.args.sltype).to(device) if self.args.ssim_loss > 0: self.ssimloss = pytorch_ssim.SSIM().to(device) self.norm = norm_func self.denorm = denorm_func def forward(self,pred_ims,target,pred_ms,mask,pred_wms,wm): pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = [0]*5 pred_ims = pred_ims if is_dic(pred_ims) else [pred_ims] # try the loss in the masked region if self.args.masked and 'hybrid' in self.args.loss_type: # masked loss pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims]) pixel_loss += sum([self.default(pred_im*pred_ms,target*mask) for pred_im in pred_ims]) recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ] wm_loss += self.wrloss(pred_wms, wm, mask) wm_loss += self.default(pred_wms*pred_ms, wm*mask) elif self.args.masked and 'relative' in self.args.loss_type: # masked loss pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims]) recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ] wm_loss = self.wrloss(pred_wms, wm, mask) elif self.args.masked: pixel_loss += sum([self.outputLoss(pred_im*mask, target*mask) for pred_im in pred_ims]) recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ] wm_loss = self.wrloss(pred_wms*mask, wm*mask) else: pixel_loss += sum([self.outputLoss(pred_im*pred_ms, target*mask) for pred_im in pred_ims]) recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ] wm_loss = self.wrloss(pred_wms*pred_ms,wm*mask) pixel_loss += sum([self.default(im,target) for im in recov_imgs]) if self.args.style_loss > 0: vgg_loss = sum([self.vggloss(im,target,mask) for im in recov_imgs]) if self.args.ssim_loss > 0: ssim_loss = sum([ 1 - self.ssimloss(im,target) for im in recov_imgs]) att_loss = self.attLoss(pred_ms, mask) return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss class VX(BasicMachine): def __init__(self,**kwargs): BasicMachine.__init__(self,**kwargs) self.loss = Losses(self.args, self.device, self.norm, self.denorm) self.model.set_optimizers() self.optimizer = None def train(self,epoch): self.current_epoch = epoch batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() lossMask = AverageMeter() lossWM = AverageMeter() lossMX = AverageMeter() lossvgg = AverageMeter() lossssim = AverageMeter() # switch to train mode self.model.train() end = time.time() bar = Bar('Processing {} '.format(self.args.arch), max=len(self.train_loader)) for i, batches in enumerate(self.train_loader): current_index = len(self.train_loader) * epoch + i inputs = batches['image'].to(self.device) target = batches['target'].to(self.device) mask = batches['mask'].to(self.device) wm = batches['wm'].to(self.device) outputs = self.model(self.norm(inputs)) self.model.zero_grad_all() l2_loss,att_loss,wm_loss,style_loss,ssim_loss = self.loss(outputs[0],self.norm(target),outputs[1],mask,outputs[2],self.norm(wm)) total_loss = 2*l2_loss + self.args.att_loss * att_loss + wm_loss + self.args.style_loss * style_loss + self.args.ssim_loss * ssim_loss # compute gradient and do SGD step total_loss.backward() self.model.step_all() # measure accuracy and record loss losses.update(l2_loss.item(), inputs.size(0)) lossMask.update(att_loss.item(), inputs.size(0)) lossWM.update(wm_loss.item(), inputs.size(0)) if self.args.style_loss > 0 : lossvgg.update(style_loss.item(), inputs.size(0)) if self.args.ssim_loss > 0 : lossssim.update(ssim_loss.item(), inputs.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() # plot progress suffix = "({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss Mask: {loss_mask:.4f} | loss WM: {loss_wm:.4f} | loss VGG: {loss_vgg:.4f} | loss SSIM: {loss_ssim:.4f}| loss MX: {loss_mx:.4f}".format( batch=i + 1, size=len(self.train_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss_label=losses.avg, loss_mask=lossMask.avg, loss_wm=lossWM.avg, loss_vgg=lossvgg.avg, loss_ssim=lossssim.avg, loss_mx=lossMX.avg ) if current_index % 1000 == 0: print(suffix) if self.args.freq > 0 and current_index % self.args.freq == 0: self.validate(current_index) self.flush() self.save_checkpoint() self.record('train/loss_L2', losses.avg, epoch) self.record('train/loss_Mask', lossMask.avg, epoch) self.record('train/loss_WM', lossWM.avg, epoch) self.record('train/loss_VGG', lossvgg.avg, epoch) self.record('train/loss_SSIM', lossssim.avg, epoch) self.record('train/loss_MX', lossMX.avg, epoch) def validate(self, epoch): self.current_epoch = epoch batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() lossMask = AverageMeter() psnres = AverageMeter() ssimes = AverageMeter() # switch to evaluate mode self.model.eval() end = time.time() bar = Bar('Processing {} '.format(self.args.arch), max=len(self.val_loader)) with torch.no_grad(): for i, batches in enumerate(self.val_loader): current_index = len(self.val_loader) * epoch + i inputs = batches['image'].to(self.device) target = batches['target'].to(self.device) outputs = self.model(self.norm(inputs)) imoutput,immask,imwatermark = outputs imoutput = imoutput[0] if is_dic(imoutput) else imoutput imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask)) if i % 300 == 0: # save the sample images ims = torch.cat([inputs,target,imfinal,immask.repeat(1,3,1,1)],dim=3) torchvision.utils.save_image(ims,os.path.join(self.args.checkpoint,'%s_%s.jpg'%(i,epoch))) # here two choice: mseLoss or NLLLoss psnr = 10 * log10(1 / F.mse_loss(imfinal,target).item()) ssim = pytorch_ssim.ssim(imfinal,target) psnres.update(psnr, inputs.size(0)) ssimes.update(ssim, inputs.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() # plot progress bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_L2: {loss_label:.4f} | Loss_Mask: {loss_mask:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}'.format( batch=i + 1, size=len(self.val_loader), data=data_time.val, bt=batch_time.val, total=bar.elapsed_td, eta=bar.eta_td, loss_label=losses.avg, loss_mask=lossMask.avg, psnr=psnres.avg, ssim=ssimes.avg ) bar.next() bar.finish() print("Iter:%s,Losses:%s,PSNR:%.4f,SSIM:%.4f"%(epoch, losses.avg,psnres.avg,ssimes.avg)) self.record('val/loss_L2', losses.avg, epoch) self.record('val/lossMask', lossMask.avg, epoch) self.record('val/PSNR', psnres.avg, epoch) self.record('val/SSIM', ssimes.avg, epoch) self.metric = psnres.avg self.model.train() def test(self, ): # switch to evaluate mode self.model.eval() print("==> testing VM model ") ssimes = AverageMeter() psnres = AverageMeter() ssimesx = AverageMeter() psnresx = AverageMeter() with torch.no_grad(): for i, batches in enumerate(tqdm(self.val_loader)): inputs = batches['image'].to(self.device) target = batches['target'].to(self.device) mask =batches['mask'].to(self.device) # select the outputs by the giving arch outputs = self.model(self.norm(inputs)) imoutput,immask,imwatermark = outputs imoutput = imoutput[0] if is_dic(imoutput) else imoutput imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask)) psnrx = 10 * log10(1 / F.mse_loss(imfinal,target).item()) ssimx = pytorch_ssim.ssim(imfinal,target) # recover the image to 255 imfinal = im_to_numpy(torch.clamp(imfinal[0]*255,min=0.0,max=255.0)).astype(np.uint8) target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8) skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), imfinal) psnr = compare_psnr(target,imfinal) ssim = compare_ssim(target,imfinal,multichannel=True) psnres.update(psnr, inputs.size(0)) ssimes.update(ssim, inputs.size(0)) psnresx.update(psnrx, inputs.size(0)) ssimesx.update(ssimx, inputs.size(0)) print("%s:PSNR:%.5f(%.5f),SSIM:%.5f(%.5f)"%(self.args.checkpoint,psnres.avg,psnresx.avg,ssimes.avg,ssimesx.avg)) print("DONE.\n") ================================================ FILE: scripts/machines/__init__.py ================================================ from .BasicMachine import BasicMachine from .VX import VX from .S2AM import S2AM def basic(**kwargs): return BasicMachine(**kwargs) def s2am(**kwargs): return S2AM(**kwargs) def vx(**kwargs): return VX(**kwargs) ================================================ FILE: scripts/models/__init__.py ================================================ from .vgg import * from .backbone_unet import * from .discriminator import * ================================================ FILE: scripts/models/backbone_unet.py ================================================ import torch import torchvision import torch.nn as nn import torch.nn.functional as F import numpy as np import functools import math from scripts.utils.model_init import * from scripts.models.rasc import * from scripts.models.unet import UnetGenerator,MinimalUnetV2 from scripts.models.vmu import UnetVM from scripts.models.sa_resunet import UnetVMS2AMv4 # our method def vvv4n(**kwargs): return UnetVMS2AMv4(shared_depth=2, blocks=3, long_skip=True, use_vm_decoder=True,s2am='vms2am') # BVMR def vm3(**kwargs): return UnetVM(shared_depth=2, blocks=3, use_vm_decoder=True) # Blind version of S2AM def urasc(**kwargs): model = UnetGenerator(3,3,is_attention_layer=True,attention_model=URASC,basicblock=MinimalUnetV2) model.apply(weights_init_kaiming) return model # Improving the Harmony of the Composite Image by Spatial-Separated Attention Module # Xiaodong Cun and Chi-Man Pun # University of Macau # Trans. on Image Processing, vol. 29, pp. 4759-4771, 2020. def rascv2(**kwargs): model = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2) model.apply(weights_init_kaiming) return model # just original unet def unet(**kwargs): model = UnetGenerator(3,3) model.apply(weights_init_kaiming) return model ================================================ FILE: scripts/models/blocks.py ================================================ import torch import torchvision import torch.nn as nn import torch.nn.functional as F import numpy as np import functools import math import numbers from scripts.utils.model_init import * from scripts.models.vgg import Vgg16 from torch import nn, cuda from torch.autograd import Variable class BasicLearningBlock(nn.Module): """docstring for BasicLearningBlock""" def __init__(self,channel): super(BasicLearningBlock, self).__init__() self.rconv1 = nn.Conv2d(channel,channel*2,3,padding=1,bias=False) self.rbn1 = nn.BatchNorm2d(channel*2) self.rconv2 = nn.Conv2d(channel*2,channel,3,padding=1,bias=False) self.rbn2 = nn.BatchNorm2d(channel) def forward(self,feature): return F.elu(self.rbn2(self.rconv2(F.elu(self.rbn1(self.rconv1(feature)))))) # From https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/3 class GaussianSmoothing(nn.Module): """ Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input using a depthwise convolution. Arguments: channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well. kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the gaussian kernel. dim (int, optional): The number of dimensions of the data. Default value is 2 (spatial). """ def __init__(self, channels, kernel_size, sigma, dim=2): super(GaussianSmoothing, self).__init__() if isinstance(kernel_size, numbers.Number): kernel_size = [kernel_size] * dim if isinstance(sigma, numbers.Number): sigma = [sigma] * dim # The gaussian kernel is the product of the # gaussian function of each dimension. kernel = 1 meshgrids = torch.meshgrid( [ torch.arange(size, dtype=torch.float32) for size in kernel_size ] ) for size, std, mgrid in zip(kernel_size, sigma, meshgrids): mean = (size - 1) / 2 kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ torch.exp(-((mgrid - mean) / (2 * std)) ** 2) # Make sure sum of values in gaussian kernel equals 1. kernel = kernel / torch.sum(kernel) # Reshape to depthwise convolutional weight kernel = kernel.view(1, 1, *kernel.size()) kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) self.register_buffer('weight', kernel) self.groups = channels if dim == 1: self.conv = F.conv1d elif dim == 2: self.conv = F.conv2d elif dim == 3: self.conv = F.conv3d else: raise RuntimeError( 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) ) def forward(self, input): """ Apply gaussian filter to input. Arguments: input (torch.Tensor): Input to apply gaussian filter on. Returns: filtered (torch.Tensor): Filtered output. """ return self.conv(input, weight=self.weight, groups=self.groups) class ChannelPool(nn.Module): def __init__(self,types): super(ChannelPool, self).__init__() if types == 'avg': self.poolingx = nn.AdaptiveAvgPool1d(1) elif types == 'max': self.poolingx = nn.AdaptiveMaxPool1d(1) else: raise 'inner error' def forward(self, input): n, c, w, h = input.size() input = input.view(n,c,w*h).permute(0,2,1) pooled = self.poolingx(input)# b,w*h,c -> b,w*h,1 _, _, c = pooled.size() return pooled.view(n,c,w,h) class SEBlock(nn.Module): """docstring for SEBlock""" def __init__(self, channel,reducation=16): super(SEBlock, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel,channel//reducation), nn.ReLU(inplace=True), nn.Linear(channel//reducation,channel), nn.Sigmoid()) def forward(self,x): b,c,w,h = x.size() y1 = self.avg_pool(x).view(b,c) y = self.fc(y1).view(b,c,1,1) return x*y class GlobalAttentionModule(nn.Module): """docstring for GlobalAttentionModule""" def __init__(self, channel,reducation=16): super(GlobalAttentionModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) self.fc = nn.Sequential( nn.Linear(channel*2,channel//reducation), nn.ReLU(inplace=True), nn.Linear(channel//reducation,channel), nn.Sigmoid()) def forward(self,x): b,c,w,h = x.size() y1 = self.avg_pool(x).view(b,c) y2 = self.max_pool(x).view(b,c) y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1) return x*y class SpatialAttentionModule(nn.Module): """docstring for SpatialAttentionModule""" def __init__(self, channel,reducation=16): super(SpatialAttentionModule, self).__init__() self.avg_pool = ChannelPool('avg') self.max_pool = ChannelPool('max') self.fc = nn.Sequential( nn.Conv2d(2,reducation,7,stride=1,padding=3), nn.ReLU(inplace=True), nn.Conv2d(reducation,1,7,stride=1,padding=3), nn.Sigmoid()) def forward(self,x): b,c,w,h = x.size() y1 = self.avg_pool(x) y2 = self.max_pool(x) y = self.fc(torch.cat([y1,y2],1)) yr = 1-y return y,yr class GlobalAttentionModuleJustSigmoid(nn.Module): """docstring for GlobalAttentionModule""" def __init__(self, channel,reducation=16): super(GlobalAttentionModuleJustSigmoid, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) self.fc = nn.Sequential( nn.Linear(channel*2,channel//reducation), nn.ReLU(inplace=True), nn.Linear(channel//reducation,channel), nn.Sigmoid()) def forward(self,x): b,c,w,h = x.size() y1 = self.avg_pool(x).view(b,c) y2 = self.max_pool(x).view(b,c) y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1) return y class BasicBlock(nn.Module): def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): super(BasicBlock, self).__init__() self.out_channels = out_planes self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None self.relu = nn.ReLU() if relu else None def forward(self, x): x = self.conv(x) if self.bn is not None: x = self.bn(x) if self.relu is not None: x = self.relu(x) return x class Flatten(nn.Module): def forward(self, x): return x.view(x.size(0), -1) class ChannelGate(nn.Module): def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): super(ChannelGate, self).__init__() self.gate_channels = gate_channels self.mlp = nn.Sequential( Flatten(), nn.Linear(gate_channels, gate_channels // reduction_ratio), nn.ReLU(), nn.Linear(gate_channels // reduction_ratio, gate_channels) ) self.pool_types = pool_types def forward(self, x): channel_att_sum = None for pool_type in self.pool_types: if pool_type=='avg': avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( avg_pool ) elif pool_type=='max': max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( max_pool ) elif pool_type=='lp': lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( lp_pool ) elif pool_type=='lse': # LSE pool only lse_pool = logsumexp_2d(x) channel_att_raw = self.mlp( lse_pool ) if channel_att_sum is None: channel_att_sum = channel_att_raw else: channel_att_sum = channel_att_sum + channel_att_raw scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) return x * scale def logsumexp_2d(tensor): tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() return outputs class ChannelPoolX(nn.Module): def forward(self, x): return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) class SpatialGate(nn.Module): def __init__(self): super(SpatialGate, self).__init__() kernel_size = 7 self.compress = ChannelPoolX() self.spatial = BasicBlock(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) def forward(self, x): x_compress = self.compress(x) x_out = self.spatial(x_compress) scale = F.sigmoid(x_out) # broadcasting return x * scale class CBAM(nn.Module): def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): super(CBAM, self).__init__() self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) self.no_spatial=no_spatial if not no_spatial: self.SpatialGate = SpatialGate() def forward(self, x): x_out = self.ChannelGate(x) if not self.no_spatial: x_out = self.SpatialGate(x_out) return x_out ================================================ FILE: scripts/models/discriminator.py ================================================ import numpy as np import functools import math import torch from torch.autograd import Variable import torch.nn.functional as F from torch import nn from torch import Tensor from torch.nn import Parameter from scripts.utils.model_init import * from torch.optim.optimizer import Optimizer, required __all__ = ['patchgan','sngan','maskedsngan'] class SNCoXvWithActivation(torch.nn.Module): """ SN convolution for spetral normalization conv """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)): super(SNCoXvWithActivation, self).__init__() self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) self.conv2d = torch.nn.utils.spectral_norm(self.conv2d) self.activation = activation for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) def forward(self, input): x = self.conv2d(input) if self.activation is not None: return self.activation(x) else: return x def l2normalize(v, eps=1e-12): return v / (v.norm() + eps) class SpectralNorm(nn.Module): def __init__(self, module, name='weight', power_iterations=1): super(SpectralNorm, self).__init__() self.module = module self.name = name self.power_iterations = power_iterations if not self._made_params(): self._make_params() def _update_u_v(self): u = getattr(self.module, self.name + "_u") v = getattr(self.module, self.name + "_v") w = getattr(self.module, self.name + "_bar") height = w.data.shape[0] for _ in range(self.power_iterations): v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) sigma = u.dot(w.view(height, -1).mv(v)) setattr(self.module, self.name, w / sigma.expand_as(w)) def _made_params(self): try: u = getattr(self.module, self.name + "_u") v = getattr(self.module, self.name + "_v") w = getattr(self.module, self.name + "_bar") return True except AttributeError: return False def _make_params(self): w = getattr(self.module, self.name) height = w.data.shape[0] width = w.view(height, -1).data.shape[1] u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) u.data = l2normalize(u.data) v.data = l2normalize(v.data) w_bar = Parameter(w.data) del self.module._parameters[self.name] self.module.register_parameter(self.name + "_u", u) self.module.register_parameter(self.name + "_v", v) self.module.register_parameter(self.name + "_bar", w_bar) def forward(self, *args): self._update_u_v() return self.module.forward(*args) def get_pad(in_, ksize, stride, atrous=1): out_ = np.ceil(float(in_)/stride) return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2) class SNDiscriminator(nn.Module): def __init__(self,channel=6): super(SNDiscriminator, self).__init__() cnum = 32 self.discriminator_net = nn.Sequential( SNCoXvWithActivation(channel, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)), SNCoXvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)), SNCoXvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)), SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)), SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)), # 8*8*256 # SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)), # 4*4*256 # SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(4, 5, 2)), # 2*2*256 ) # self.linear = nn.Linear(2*2*256,1) def forward(self, img_A, img_B): # Concatenate image and condition image by channels to produce input img_input = torch.cat((img_A, img_B), 1) x = self.discriminator_net(img_input) # x = x.view((x.size(0),-1)) # x = self.linear(x) return x class Discriminator(nn.Module): def __init__(self, in_channels=3): super(Discriminator, self).__init__() def discriminator_block(in_filters, out_filters, normalization=True): """Returns downsampling layers of each discriminator block""" layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] if normalization: layers.append(nn.InstanceNorm2d(out_filters)) layers.append(nn.LeakyReLU(0.2, inplace=True)) return layers self.model = nn.Sequential( *discriminator_block(in_channels*2, 64, normalization=False), *discriminator_block(64, 128), *discriminator_block(128, 256), *discriminator_block(256, 512), nn.ZeroPad2d((1, 0, 1, 0)), nn.Conv2d(512, 1, 4, padding=1, bias=False) ) def forward(self, img_A, img_B): # Concatenate image and condition image by channels to produce input img_input = torch.cat((img_A, img_B), 1) return self.model(img_input) def patchgan(): model = Discriminator() model.apply(weights_init_kaiming) return model def sngan(): model = SNDiscriminator() model.apply(weights_init_kaiming) return model def maskedsngan(): model = SNDiscriminator(channel=7) model.apply(weights_init_kaiming) return model ================================================ FILE: scripts/models/rasc.py ================================================ import torch import torchvision import torch.nn as nn import torch.nn.functional as F import numpy as np import math from scripts.utils.model_init import * from scripts.models.vgg import Vgg16 from scripts.models.blocks import * class CAWapper(nn.Module): """docstring for SENet""" def __init__(self, channel, type_of_connection=BasicLearningBlock): super(CAWapper, self).__init__() self.attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10, fuse=True, use_cuda=True) def forward(self, feature, mask): _, _, w, _ = feature.size() _, _, mw, _ = mask.size() # binaryfiy # selected the feature from the background as the additional feature to masked splicing feature. mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w)) result = self.attention(feature,mask) return result class NLWapper(nn.Module): """docstring for SENet""" def __init__(self, channel, type_of_connection=BasicLearningBlock): super(NLWapper, self).__init__() self.attention = NONLocalBlock2D(channel) def forward(self, feature, mask): _, _, w, _ = feature.size() _, _, mw, _ = mask.size() # binaryfiy # selected the feature from the background as the additional feature to masked splicing feature. # mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w)) result = self.attention(feature) return result class SENet(nn.Module): """docstring for SENet""" def __init__(self,channel,type_of_connection=BasicLearningBlock): super(SENet, self).__init__() self.attention = SEBlock(channel,16) def forward(self,feature,mask): _,_,w,_ = feature.size() _,_,mw,_ = mask.size() # binaryfiy # selected the feature from the background as the additional feature to masked splicing feature. mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w)) result = self.attention(feature) return result class CBAMConnect(nn.Module): def __init__(self,channel): super(CBAMConnect, self).__init__() self.attention = CBAM(channel) def forward(self,feature,mask): results = self.attention(feature) return results class RASC(nn.Module): def __init__(self,channel,type_of_connection=BasicLearningBlock): super(RASC, self).__init__() self.connection = type_of_connection(channel) self.background_attention = GlobalAttentionModule(channel,16) self.mixed_attention = GlobalAttentionModule(channel,16) self.spliced_attention = GlobalAttentionModule(channel,16) self.gaussianMask = GaussianSmoothing(1,5,1) def forward(self,feature,mask): _,_,w,_ = feature.size() _,_,mw,_ = mask.size() # binaryfiy # selected the feature from the background as the additional feature to masked splicing feature. if w != mw: mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w)) reverse_mask = -1*(mask-1) # here we add gaussin filter to mask and reverse_mask for better harimoization of edges. mask = self.gaussianMask(F.pad(mask,(2,2,2,2),mode='reflect')) reverse_mask = self.gaussianMask(F.pad(reverse_mask,(2,2,2,2),mode='reflect')) background = self.background_attention(feature) * reverse_mask selected_feature = self.mixed_attention(feature) spliced_feature = self.spliced_attention(feature) spliced = ( self.connection(spliced_feature) + selected_feature ) * mask return background + spliced class UNO(nn.Module): def __init__(self,channel): super(UNO, self).__init__() def forward(self,feature,_m): return feature class URASC(nn.Module): def __init__(self,channel,type_of_connection=BasicLearningBlock): super(URASC, self).__init__() self.connection = type_of_connection(channel) self.background_attention = GlobalAttentionModule(channel,16) self.mixed_attention = GlobalAttentionModule(channel,16) self.spliced_attention = GlobalAttentionModule(channel,16) self.mask_attention = SpatialAttentionModule(channel,16) def forward(self,feature, m=None): _,_,w,_ = feature.size() mask, reverse_mask = self.mask_attention(feature) background = self.background_attention(feature) * reverse_mask selected_feature = self.mixed_attention(feature) spliced_feature = self.spliced_attention(feature) spliced = ( self.connection(spliced_feature) + selected_feature ) * mask return background + spliced class MaskedURASC(nn.Module): def __init__(self,channel,type_of_connection=BasicLearningBlock): super(MaskedURASC, self).__init__() self.connection = type_of_connection(channel) self.background_attention = GlobalAttentionModule(channel,16) self.mixed_attention = GlobalAttentionModule(channel,16) self.spliced_attention = GlobalAttentionModule(channel,16) self.mask_attention = SpatialAttentionModule(channel,16) def forward(self,feature): _,_,w,_ = feature.size() mask, reverse_mask = self.mask_attention(feature) background = self.background_attention(feature) * reverse_mask selected_feature = self.mixed_attention(feature) spliced_feature = self.spliced_attention(feature) spliced = ( self.connection(spliced_feature) + selected_feature ) * mask return background + spliced, mask ================================================ FILE: scripts/models/sa_resunet.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from scripts.models.blocks import SEBlock from scripts.models.rasc import * from scripts.models.unet import UnetGenerator,MinimalUnetV2 def weight_init(m): if isinstance(m, nn.Conv2d): nn.init.xavier_normal_(m.weight) nn.init.constant_(m.bias, 0) def reset_params(model): for i, m in enumerate(model.modules()): weight_init(m) def conv3x3(in_channels, out_channels, stride=1, padding=1, bias=True, groups=1): return nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=bias, groups=groups) def up_conv2x2(in_channels, out_channels, transpose=True): if transpose: return nn.ConvTranspose2d( in_channels, out_channels, kernel_size=2, stride=2) else: return nn.Sequential( nn.Upsample(mode='bilinear', scale_factor=2), conv1x1(in_channels, out_channels)) def conv1x1(in_channels, out_channels, groups=1): return nn.Conv2d( in_channels, out_channels, kernel_size=1, groups=groups, stride=1) class UpCoXvD(nn.Module): def __init__(self, in_channels, out_channels, blocks, residual=True,norm=nn.BatchNorm2d, act=F.relu,batch_norm=True, transpose=True,concat=True,use_att=False): super(UpCoXvD, self).__init__() self.concat = concat self.residual = residual self.batch_norm = batch_norm self.bn = None self.conv2 = [] self.use_att = use_att self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose) self.norm0 = norm(out_channels) if self.use_att: self.s2am = RASC(2 * out_channels) else: self.s2am = None if self.concat: self.conv1 = conv3x3(2 * out_channels, out_channels) self.norm1 = norm(out_channels , out_channels) else: self.conv1 = conv3x3(out_channels, out_channels) self.norm1 = norm(out_channels , out_channels) for _ in range(blocks): self.conv2.append(conv3x3(out_channels, out_channels)) if self.batch_norm: self.bn = [] for _ in range(blocks): self.bn.append(norm(out_channels)) self.bn = nn.ModuleList(self.bn) self.conv2 = nn.ModuleList(self.conv2) self.act = act def forward(self, from_up, from_down, mask=None,se=None): from_up = self.act(self.norm0(self.up_conv(from_up))) if self.concat: x1 = torch.cat((from_up, from_down), 1) else: if from_down is not None: x1 = from_up + from_down else: x1 = from_up if self.use_att: x1 = self.s2am(x1,mask) x1 = self.act(self.norm1(self.conv1(x1))) x2 = None for idx, conv in enumerate(self.conv2): x2 = conv(x1) if self.batch_norm: x2 = self.bn[idx](x2) if (se is not None) and (idx == len(self.conv2) - 1): # last x2 = se(x2) if self.residual: x2 = x2 + x1 x2 = self.act(x2) x1 = x2 return x2 class DownCoXvD(nn.Module): def __init__(self, in_channels, out_channels, blocks, pooling=True, norm=nn.BatchNorm2d,act=F.relu,residual=True, batch_norm=True): super(DownCoXvD, self).__init__() self.pooling = pooling self.residual = residual self.batch_norm = batch_norm self.bn = None self.pool = None self.conv1 = conv3x3(in_channels, out_channels) self.norm1 = norm(out_channels) self.conv2 = [] for _ in range(blocks): self.conv2.append(conv3x3(out_channels, out_channels)) if self.batch_norm: self.bn = [] for _ in range(blocks): self.bn.append(norm(out_channels)) self.bn = nn.ModuleList(self.bn) if self.pooling: self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.conv2 = nn.ModuleList(self.conv2) self.act = act def __call__(self, x): return self.forward(x) def forward(self, x): x1 = self.act(self.norm1(self.conv1(x))) x2 = None for idx, conv in enumerate(self.conv2): x2 = conv(x1) if self.batch_norm: x2 = self.bn[idx](x2) if self.residual: x2 = x2 + x1 x2 = self.act(x2) x1 = x2 before_pool = x2 if self.pooling: x2 = self.pool(x2) return x2, before_pool class UnetDecoderD(nn.Module): def __init__(self, in_channels=512, out_channels=3, norm=nn.BatchNorm2d,act=F.relu, depth=5, blocks=1, residual=True, batch_norm=True, transpose=True, concat=True, is_final=True, use_att=False): super(UnetDecoderD, self).__init__() self.conv_final = None self.up_convs = [] self.atts = [] self.use_att = use_att outs = in_channels for i in range(depth-1): # depth = 1 ins = outs outs = ins // 2 # 512,256 # 256,128 # 128,64 # 64,32 up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat, norm=norm, act=act) if self.use_att: self.atts.append(SEBlock(outs)) self.up_convs.append(up_conv) if is_final: self.conv_final = conv1x1(outs, out_channels) else: up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat,norm=norm, act=act) if self.use_att: self.atts.append(SEBlock(out_channels)) self.up_convs.append(up_conv) self.up_convs = nn.ModuleList(self.up_convs) self.atts = nn.ModuleList(self.atts) reset_params(self) def __call__(self, x, encoder_outs=None): return self.forward(x, encoder_outs) def forward(self, x, encoder_outs=None): for i, up_conv in enumerate(self.up_convs): before_pool = None if encoder_outs is not None: before_pool = encoder_outs[-(i+2)] x = up_conv(x, before_pool) if self.use_att: x = self.atts[i](x) if self.conv_final is not None: x = self.conv_final(x) return x class UnetDecoderDatt(nn.Module): def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True, transpose=True, concat=True, is_final=True, norm=nn.BatchNorm2d,act=F.relu): super(UnetDecoderDatt, self).__init__() self.conv_final = None self.up_convs = [] self.im_atts = [] self.vm_atts = [] self.mask_atts = [] outs = in_channels for i in range(depth-1): # depth = 5 [0,1,2,3] ins = outs outs = ins // 2 # 512,256 # 256,128 # 128,64 # 64,32 up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat, norm=nn.BatchNorm2d,act=F.relu) self.up_convs.append(up_conv) self.im_atts.append(SEBlock(outs)) self.vm_atts.append(SEBlock(outs)) self.mask_atts.append(SEBlock(outs)) if is_final: self.conv_final = conv1x1(outs, out_channels) else: up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat, norm=nn.BatchNorm2d,act=F.relu) self.up_convs.append(up_conv) self.im_atts.append(SEBlock(out_channels)) self.vm_atts.append(SEBlock(out_channels)) self.mask_atts.append(SEBlock(out_channels)) self.up_convs = nn.ModuleList(self.up_convs) self.im_atts = nn.ModuleList(self.im_atts) self.vm_atts = nn.ModuleList(self.vm_atts) self.mask_atts = nn.ModuleList(self.mask_atts) reset_params(self) def forward(self, input, encoder_outs=None): # im branch x = input for i, up_conv in enumerate(self.up_convs): before_pool = None if encoder_outs is not None: before_pool = encoder_outs[-(i+2)] x = up_conv(x, before_pool,se=self.im_atts[i]) x_im = x x = input for i, up_conv in enumerate(self.up_convs): before_pool = None if encoder_outs is not None: before_pool = encoder_outs[-(i+2)] x = up_conv(x, before_pool, se = self.mask_atts[i]) x_mask = x x = input for i, up_conv in enumerate(self.up_convs): before_pool = None if encoder_outs is not None: before_pool = encoder_outs[-(i+2)] x = up_conv(x, before_pool, se=self.vm_atts[i]) x_vm = x return x_im,x_mask,x_vm class UnetEncoderD(nn.Module): def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True, norm=nn.BatchNorm2d, act=F.relu): super(UnetEncoderD, self).__init__() self.down_convs = [] outs = None if type(blocks) is tuple: blocks = blocks[0] for i in range(depth): ins = in_channels if i == 0 else outs outs = start_filters*(2**i) pooling = True if i < depth-1 else False down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm, norm=nn.BatchNorm2d, act=F.relu) self.down_convs.append(down_conv) self.down_convs = nn.ModuleList(self.down_convs) reset_params(self) def __call__(self, x): return self.forward(x) def forward(self, x): encoder_outs = [] for d_conv in self.down_convs: x, before_pool = d_conv(x) encoder_outs.append(before_pool) return x, encoder_outs class ResDown(nn.Module): def __init__(self, in_size, out_size, pooling=True, use_att=False): super(ResDown, self).__init__() self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling) def forward(self, x): return self.model(x) class ResUp(nn.Module): def __init__(self, in_size, out_size, use_att=False): super(ResUp, self).__init__() self.model = UpCoXvD(in_size, out_size, 3, use_att=use_att) def forward(self, x, skip_input, mask=None): return self.model(x,skip_input,mask) class ResDownNew(nn.Module): def __init__(self, in_size, out_size, pooling=True, use_att=False): super(ResDownNew, self).__init__() self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling, norm=nn.InstanceNorm2d, act=F.leaky_relu) def forward(self, x): return self.model(x) class ResUpNew(nn.Module): def __init__(self, in_size, out_size, use_att=False): super(ResUpNew, self).__init__() self.model = UpCoXvD(in_size, out_size, 3, use_att=use_att, norm=nn.InstanceNorm2d) def forward(self, x, skip_input, mask=None): return self.model(x,skip_input,mask) class VMSingle(nn.Module): def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32, res=True,use_att=False): super(VMSingle, self).__init__() self.down1 = down(in_channels, ngf) self.down2 = down(ngf, ngf*2) self.down3 = down(ngf*2, ngf*4) self.down4 = down(ngf*4, ngf*8) self.down5 = down(ngf*8, ngf*16, pooling=False) self.up1 = up(ngf*16, ngf*8) self.up2 = up(ngf*8, ngf*4, use_att=use_att) self.up3 = up(ngf*4, ngf*2, use_att=use_att) self.up4 = up(ngf*2, ngf*1, use_att=use_att) self.im = nn.Conv2d(ngf, 3, 1) self.res = res def forward(self, input): img, mask = input[:,0:3,:,:],input[:,3:4,:,:] # U-Net generator with skip connections from encoder to decoder x,d1 = self.down1(input) # 128,256 x,d2 = self.down2(x) # 64,128 x,d3 = self.down3(x) # 32,64 x,d4 = self.down4(x) # 16,32 x,_ = self.down5(x) # 8,16 x = self.up1(x, d4) # 16 x = self.up2(x, d3, mask) # 32 x = self.up3(x, d2, mask) # 64 x = self.up4(x, d1, mask) # 128 im = self.im(x) return im class VMSingleS2AM(nn.Module): def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32): super(VMSingleS2AM, self).__init__() self.down1 = down(in_channels, ngf) self.down2 = down(ngf, ngf*2) self.down3 = down(ngf*2, ngf*4) self.down4 = down(ngf*4, ngf*8) self.down5 = down(ngf*8, ngf*16, pooling=False) self.up1 = up(ngf*16, ngf*8) self.up2 = up(ngf*8, ngf*4) self.s2am2 = RASC(ngf*4) self.up3 = up(ngf*4, ngf*2) self.s2am3 = RASC(ngf*2) self.up4 = up(ngf*2, ngf*1) self.s2am4 = RASC(ngf) self.im = nn.Conv2d(ngf, 3, 1) def forward(self, input): img, mask = input[:,0:3,:,:],input[:,3:4,:,:] # U-Net generator with skip connections from encoder to decoder x,d1 = self.down1(input) # 128,256 x,d2 = self.down2(x) # 64,128 x,d3 = self.down3(x) # 32,64 x,d4 = self.down4(x) # 16,32 x,_ = self.down5(x) # 8,16 x = self.up1(x, d4) # 16 x = self.up2(x, d3) # 32 x = self.s2am2(x, mask) x = self.up3(x, d2) # 64 x = self.s2am3(x, mask) x = self.up4(x, d1) # 128 x = self.s2am4(x, mask) im = self.im(x) return im class UnetVMS2AMv4(nn.Module): def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1, out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True, transpose=True, concat=True, transfer_data=True, long_skip=False, s2am='unet', use_coarser=True,no_stage2=False): super(UnetVMS2AMv4, self).__init__() self.transfer_data = transfer_data self.shared = shared_depth self.optimizer_encoder, self.optimizer_image, self.optimizer_vm = None, None, None self.optimizer_mask, self.optimizer_shared = None, None if type(blocks) is not tuple: blocks = (blocks, blocks, blocks, blocks, blocks) if not transfer_data: concat = False self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0], start_filters=start_filters, residual=residual, batch_norm=batch_norm,norm=nn.InstanceNorm2d,act=F.leaky_relu) self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1), out_channels=out_channels_image, depth=depth - shared_depth, blocks=blocks[1], residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat,norm=nn.InstanceNorm2d) self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1), out_channels=out_channels_mask, depth=depth - shared_depth, blocks=blocks[2], residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat,norm=nn.InstanceNorm2d) self.vm_decoder = None if use_vm_decoder: self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1), out_channels=out_channels_image, depth=depth - shared_depth, blocks=blocks[3], residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat,norm=nn.InstanceNorm2d) self.shared_decoder = None self.use_coarser = use_coarser self.long_skip = long_skip self.no_stage2 = no_stage2 self._forward = self.unshared_forward if self.shared != 0: self._forward = self.shared_forward self.shared_decoder = UnetDecoderDatt(in_channels=start_filters * 2 ** (depth - 1), out_channels=start_filters * 2 ** (depth - shared_depth - 1), depth=shared_depth, blocks=blocks[4], residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat, is_final=False,norm=nn.InstanceNorm2d) if s2am == 'unet': self.s2am = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2) elif s2am == 'vm': self.s2am = VMSingle(4) elif s2am == 'vms2am': self.s2am = VMSingleS2AM(4,down=ResDownNew,up=ResUpNew) def set_optimizers(self): self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001) self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001) self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001) self.optimizer_s2am = torch.optim.Adam(self.s2am.parameters(), lr=0.001) if self.vm_decoder is not None: self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001) if self.shared != 0: self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001) def zero_grad_all(self): self.optimizer_encoder.zero_grad() self.optimizer_image.zero_grad() self.optimizer_mask.zero_grad() self.optimizer_s2am.zero_grad() if self.vm_decoder is not None: self.optimizer_vm.zero_grad() if self.shared != 0: self.optimizer_shared.zero_grad() def step_all(self): self.optimizer_encoder.step() self.optimizer_image.step() self.optimizer_mask.step() self.optimizer_s2am.step() if self.vm_decoder is not None: self.optimizer_vm.step() if self.shared != 0: self.optimizer_shared.step() def step_optimizer_image(self): self.optimizer_image.step() def __call__(self, synthesized): return self._forward(synthesized) def forward(self, synthesized): return self._forward(synthesized) def unshared_forward(self, synthesized): image_code, before_pool = self.encoder(synthesized) if not self.transfer_data: before_pool = None reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool)) reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool)) if self.vm_decoder is not None: reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool)) return reconstructed_image, reconstructed_mask, reconstructed_vm return reconstructed_image, reconstructed_mask def shared_forward(self, synthesized): image_code, before_pool = self.encoder(synthesized) if self.transfer_data: shared_before_pool = before_pool[- self.shared - 1:] unshared_before_pool = before_pool[: - self.shared] else: before_pool = None shared_before_pool = None unshared_before_pool = None im,mask,vm = self.shared_decoder(image_code, shared_before_pool) reconstructed_image = torch.tanh(self.image_decoder(im, unshared_before_pool)) if self.long_skip: reconstructed_image = reconstructed_image + synthesized reconstructed_mask = torch.sigmoid(self.mask_decoder(mask, unshared_before_pool)) if self.vm_decoder is not None: reconstructed_vm = torch.tanh(self.vm_decoder(vm, unshared_before_pool)) if self.long_skip: reconstructed_vm = reconstructed_vm + synthesized coarser = reconstructed_image * reconstructed_mask + (1-reconstructed_mask)* synthesized if self.use_coarser: refine = torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) + coarser elif self.no_stage2: refine = torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) else: refine = torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) + synthesized # final = refine * reconstructed_mask + (1-reconstructed_mask)* synthesized if self.vm_decoder is not None: return [refine, reconstructed_image], reconstructed_mask, reconstructed_vm else: return [refine, reconstructed_image], reconstructed_mask ================================================ FILE: scripts/models/unet.py ================================================ import torch import torch.nn as nn from torch.nn import init import functools from scripts.models.blocks import * from scripts.models.rasc import * class MinimalUnetV2(nn.Module): """docstring for MinimalUnet""" def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwags): super(MinimalUnetV2, self).__init__() self.down = nn.Sequential(*down) self.up = nn.Sequential(*up) self.sub = submodule self.attention = attention self.withoutskip = withoutskip self.is_attention = not self.attention == None self.is_sub = not submodule == None def forward(self,x,mask=None): if self.is_sub: x_up,_ = self.sub(self.down(x),mask) else: x_up = self.down(x) if self.withoutskip: #outer or inner. x_out = self.up(x_up) else: if self.is_attention: x_out = (self.attention(torch.cat([x,self.up(x_up)],1),mask),mask) else: x_out = (torch.cat([x,self.up(x_up)],1),mask) return x_out class MinimalUnet(nn.Module): """docstring for MinimalUnet""" def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwags): super(MinimalUnet, self).__init__() self.down = nn.Sequential(*down) self.up = nn.Sequential(*up) self.sub = submodule self.attention = attention self.withoutskip = withoutskip self.is_attention = not self.attention == None self.is_sub = not submodule == None def forward(self,x,mask=None): if self.is_sub: x_up,_ = self.sub(self.down(x),mask) else: x_up = self.down(x) if self.is_attention: x = self.attention(x,mask) if self.withoutskip: #outer or inner. x_out = self.up(x_up) else: x_out = (torch.cat([x,self.up(x_up)],1),mask) return x_out # Defines the submodule with skip connection. # X -------------------identity---------------------- X # |-- downsampling -- |submodule| -- upsampling --| class UnetSkipConnectionBlock(nn.Module): def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,is_attention_layer=False, attention_model=RASC,basicblock=MinimalUnet,outermostattention=False): super(UnetSkipConnectionBlock, self).__init__() self.outermost = outermost if type(norm_layer) == functools.partial: use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d if input_nc is None: input_nc = outer_nc downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) downrelu = nn.LeakyReLU(0.2, True) downnorm = norm_layer(inner_nc) uprelu = nn.ReLU(True) upnorm = norm_layer(outer_nc) if outermost: upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1) down = [downconv] up = [uprelu, upconv] model = basicblock(down,up,submodule,withoutskip=outermost) elif innermost: upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) down = [downrelu, downconv] up = [uprelu, upconv, upnorm] model = basicblock(down,up) else: upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) down = [downrelu, downconv, downnorm] up = [uprelu, upconv, upnorm] if is_attention_layer: if MinimalUnetV2.__qualname__ in basicblock.__qualname__ : attention_model = attention_model(input_nc*2) else: attention_model = attention_model(input_nc) else: attention_model = None if use_dropout: model = basicblock(down,up.append(nn.Dropout(0.5)),submodule,attention_model,outermostattention=outermostattention) else: model = basicblock(down,up,submodule,attention_model,outermostattention=outermostattention) self.model = model def forward(self, x,mask=None): # build the mask for attention use return self.model(x,mask) class UnetGenerator(nn.Module): def __init__(self, input_nc, output_nc, num_downs=8, ngf=64,norm_layer=nn.BatchNorm2d, use_dropout=False, is_attention_layer=False,attention_model=RASC,use_inner_attention=False,basicblock=MinimalUnet): super(UnetGenerator, self).__init__() # 8 for 256x256 # 9 for 512x512 # construct unet structure self.need_mask = not input_nc == output_nc unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True,basicblock=basicblock) # 1 for i in range(num_downs - 5): #3 times unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout,is_attention_layer=use_inner_attention,attention_model=attention_model,basicblock=basicblock) # 8,4,2 unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #16 unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #32 unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock, outermostattention=True) #64 unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, basicblock=basicblock, norm_layer=norm_layer) # 128 self.model = unet_block def forward(self, input): if self.need_mask: return self.model(input,input[:,3:4,:,:]) else: return self.model(input[:,0:3,:,:],input[:,3:4,:,:]) ================================================ FILE: scripts/models/vgg.py ================================================ from collections import namedtuple import torch from torchvision import models class Vgg16(torch.nn.Module): def __init__(self, requires_grad=False): super(Vgg16, self).__init__() vgg_pretrained_features = models.vgg16(pretrained=True).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() for x in range(4): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(4, 9): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(9, 16): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(16, 23): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(23,30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1_2 = h h = self.slice2(h) h_relu2_2 = h h = self.slice3(h) h_relu3_3 = h h = self.slice4(h) h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h # vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3','relu5_3']) # out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return (h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) class Vgg19(torch.nn.Module): def __init__(self, requires_grad=False): super(Vgg19, self).__init__() # vgg_pretrained_features = models.vgg19(pretrained=True).features self.vgg_pretrained_features = models.vgg19(pretrained=True).features # self.slice1 = torch.nn.Sequential() # self.slice2 = torch.nn.Sequential() # self.slice3 = torch.nn.Sequential() # self.slice4 = torch.nn.Sequential() # self.slice5 = torch.nn.Sequential() # for x in range(2): # self.slice1.add_module(str(x), vgg_pretrained_features[x]) # for x in range(2, 7): # self.slice2.add_module(str(x), vgg_pretrained_features[x]) # for x in range(7, 12): # self.slice3.add_module(str(x), vgg_pretrained_features[x]) # for x in range(12, 21): # self.slice4.add_module(str(x), vgg_pretrained_features[x]) # for x in range(21, 30): # self.slice5.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X, indices=None): if indices is None: indices = [2, 7, 12, 21, 30] out = [] #indices = sorted(indices) for i in range(indices[-1]): X = self.vgg_pretrained_features[i](X) if (i+1) in indices: out.append(X) return out ================================================ FILE: scripts/models/vmu.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from scripts.models.blocks import SEBlock from scripts.models.rasc import * from scripts.models.unet import UnetGenerator,MinimalUnetV2 def weight_init(m): if isinstance(m, nn.Conv2d): nn.init.xavier_normal_(m.weight) nn.init.constant_(m.bias, 0) def reset_params(model): for i, m in enumerate(model.modules()): weight_init(m) def conv3x3(in_channels, out_channels, stride=1, padding=1, bias=True, groups=1): return nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=bias, groups=groups) def up_conv2x2(in_channels, out_channels, transpose=True): if transpose: return nn.ConvTranspose2d( in_channels, out_channels, kernel_size=2, stride=2) else: return nn.Sequential( nn.Upsample(mode='bilinear', scale_factor=2), conv1x1(in_channels, out_channels)) def conv1x1(in_channels, out_channels, groups=1): return nn.Conv2d( in_channels, out_channels, kernel_size=1, groups=groups, stride=1) class UpCoXvD(nn.Module): def __init__(self, in_channels, out_channels, blocks, residual=True, batch_norm=True, transpose=True,concat=True,use_att=False): super(UpCoXvD, self).__init__() self.concat = concat self.residual = residual self.batch_norm = batch_norm self.bn = None self.conv2 = [] self.use_att = use_att self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose) if self.use_att: self.s2am = RASC(2 * out_channels) else: self.s2am = None if self.concat: self.conv1 = conv3x3(2 * out_channels, out_channels) else: self.conv1 = conv3x3(out_channels, out_channels) for _ in range(blocks): self.conv2.append(conv3x3(out_channels, out_channels)) if self.batch_norm: self.bn = [] for _ in range(blocks): self.bn.append(nn.BatchNorm2d(out_channels)) self.bn = nn.ModuleList(self.bn) self.conv2 = nn.ModuleList(self.conv2) def forward(self, from_up, from_down, mask=None): from_up = self.up_conv(from_up) if self.concat: x1 = torch.cat((from_up, from_down), 1) else: if from_down is not None: x1 = from_up + from_down else: x1 = from_up if self.use_att: x1 = self.s2am(x1,mask) x1 = F.relu(self.conv1(x1)) x2 = None for idx, conv in enumerate(self.conv2): x2 = conv(x1) if self.batch_norm: x2 = self.bn[idx](x2) if self.residual: x2 = x2 + x1 x2 = F.relu(x2) x1 = x2 return x2 class DownCoXvD(nn.Module): def __init__(self, in_channels, out_channels, blocks, pooling=True, residual=True, batch_norm=True): super(DownCoXvD, self).__init__() self.pooling = pooling self.residual = residual self.batch_norm = batch_norm self.bn = None self.pool = None self.conv1 = conv3x3(in_channels, out_channels) self.conv2 = [] for _ in range(blocks): self.conv2.append(conv3x3(out_channels, out_channels)) if self.batch_norm: self.bn = [] for _ in range(blocks): self.bn.append(nn.BatchNorm2d(out_channels)) self.bn = nn.ModuleList(self.bn) if self.pooling: self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.conv2 = nn.ModuleList(self.conv2) def __call__(self, x): return self.forward(x) def forward(self, x): x1 = F.relu(self.conv1(x)) x2 = None for idx, conv in enumerate(self.conv2): x2 = conv(x1) if self.batch_norm: x2 = self.bn[idx](x2) if self.residual: x2 = x2 + x1 x2 = F.relu(x2) x1 = x2 before_pool = x2 if self.pooling: x2 = self.pool(x2) return x2, before_pool class UnetDecoderD(nn.Module): def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True, transpose=True, concat=True, is_final=True): super(UnetDecoderD, self).__init__() self.conv_final = None self.up_convs = [] outs = in_channels for i in range(depth-1): ins = outs outs = ins // 2 # 512,256 # 256,128 # 128,64 # 64,32 up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat) self.up_convs.append(up_conv) if is_final: self.conv_final = conv1x1(outs, out_channels) else: up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat) self.up_convs.append(up_conv) self.up_convs = nn.ModuleList(self.up_convs) reset_params(self) def __call__(self, x, encoder_outs=None): return self.forward(x, encoder_outs) def forward(self, x, encoder_outs=None): for i, up_conv in enumerate(self.up_convs): before_pool = None if encoder_outs is not None: before_pool = encoder_outs[-(i+2)] x = up_conv(x, before_pool) if self.conv_final is not None: x = self.conv_final(x) return x class UnetEncoderD(nn.Module): def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True): super(UnetEncoderD, self).__init__() self.down_convs = [] outs = None if type(blocks) is tuple: blocks = blocks[0] for i in range(depth): ins = in_channels if i == 0 else outs outs = start_filters*(2**i) pooling = True if i < depth-1 else False down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm) self.down_convs.append(down_conv) self.down_convs = nn.ModuleList(self.down_convs) reset_params(self) def __call__(self, x): return self.forward(x) def forward(self, x): encoder_outs = [] for d_conv in self.down_convs: x, before_pool = d_conv(x) encoder_outs.append(before_pool) return x, encoder_outs class UnetVM(nn.Module): def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1, out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True, transpose=True, concat=True, transfer_data=True, long_skip=False): super(UnetVM, self).__init__() self.transfer_data = transfer_data self.shared = shared_depth self.optimizer_encoder, self.optimizer_image, self.optimizer_vm = None, None, None self.optimizer_mask, self.optimizer_shared = None, None if type(blocks) is not tuple: blocks = (blocks, blocks, blocks, blocks, blocks) if not transfer_data: concat = False self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0], start_filters=start_filters, residual=residual, batch_norm=batch_norm) self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1), out_channels=out_channels_image, depth=depth - shared_depth, blocks=blocks[1], residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat) self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1), out_channels=out_channels_mask, depth=depth, blocks=blocks[2], residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat) self.vm_decoder = None if use_vm_decoder: self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1), out_channels=out_channels_image, depth=depth - shared_depth, blocks=blocks[3], residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat) self.shared_decoder = None self.long_skip = long_skip self._forward = self.unshared_forward if self.shared != 0: self._forward = self.shared_forward self.shared_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1), out_channels=start_filters * 2 ** (depth - shared_depth - 1), depth=shared_depth, blocks=blocks[4], residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat, is_final=False) def set_optimizers(self): self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001) self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001) self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001) if self.vm_decoder is not None: self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001) if self.shared != 0: self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001) def zero_grad_all(self): self.optimizer_encoder.zero_grad() self.optimizer_image.zero_grad() self.optimizer_mask.zero_grad() if self.vm_decoder is not None: self.optimizer_vm.zero_grad() if self.shared != 0: self.optimizer_shared.zero_grad() def step_all(self): self.optimizer_encoder.step() self.optimizer_image.step() self.optimizer_mask.step() if self.vm_decoder is not None: self.optimizer_vm.step() if self.shared != 0: self.optimizer_shared.step() def step_optimizer_image(self): self.optimizer_image.step() def __call__(self, synthesized): return self._forward(synthesized) def forward(self, synthesized): return self._forward(synthesized) def unshared_forward(self, synthesized): image_code, before_pool = self.encoder(synthesized) if not self.transfer_data: before_pool = None reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool)) reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool)) if self.vm_decoder is not None: reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool)) return reconstructed_image, reconstructed_mask, reconstructed_vm return reconstructed_image, reconstructed_mask def shared_forward(self, synthesized): image_code, before_pool = self.encoder(synthesized) if self.transfer_data: shared_before_pool = before_pool[- self.shared - 1:] unshared_before_pool = before_pool[: - self.shared] else: before_pool = None shared_before_pool = None unshared_before_pool = None x = self.shared_decoder(image_code, shared_before_pool) reconstructed_image = torch.tanh(self.image_decoder(x, unshared_before_pool)) if self.long_skip: reconstructed_image = reconstructed_image + synthesized reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool)) if self.vm_decoder is not None: reconstructed_vm = torch.tanh(self.vm_decoder(x, unshared_before_pool)) if self.long_skip: reconstructed_vm = reconstructed_vm + synthesized return reconstructed_image, reconstructed_mask, reconstructed_vm return reconstructed_image, reconstructed_mask ================================================ FILE: scripts/utils/__init__.py ================================================ from __future__ import absolute_import from .evaluation import * from .imutils import * from .logger import * from .misc import * from .osutils import * from .transforms import * ================================================ FILE: scripts/utils/evaluation.py ================================================ from __future__ import absolute_import import math import numpy as np import matplotlib.pyplot as plt from random import randint from .misc import * from .transforms import transform, transform_preds __all__ = ['accuracy', 'AverageMeter'] def get_preds(scores): ''' get predictions from score maps in torch Tensor return type: torch.LongTensor ''' assert scores.dim() == 4, 'Score maps should be 4-dim' maxval, idx = torch.max(scores.view(scores.size(0), scores.size(1), -1), 2) maxval = maxval.view(scores.size(0), scores.size(1), 1) idx = idx.view(scores.size(0), scores.size(1), 1) + 1 preds = idx.repeat(1, 1, 2).float() preds[:,:,0] = (preds[:,:,0] - 1) % scores.size(3) + 1 preds[:,:,1] = torch.floor((preds[:,:,1] - 1) / scores.size(2)) + 1 pred_mask = maxval.gt(0).repeat(1, 1, 2).float() preds *= pred_mask return preds def calc_dists(preds, target, normalize): preds = preds.float() target = target.float() dists = torch.zeros(preds.size(1), preds.size(0)) for n in range(preds.size(0)): for c in range(preds.size(1)): if target[n,c,0] > 1 and target[n, c, 1] > 1: dists[c, n] = torch.dist(preds[n,c,:], target[n,c,:])/normalize[n] else: dists[c, n] = -1 return dists def dist_acc(dists, thr=0.5): ''' Return percentage below threshold while ignoring values with a -1 ''' if dists.ne(-1).sum() > 0: return dists.le(thr).eq(dists.ne(-1)).sum()*1.0 / dists.ne(-1).sum() else: return -1 def accuracy(output, target, thr=0.5): ''' Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations First value to be returned is average accuracy across 'idxs', followed by individual accuracies ''' # output_mask = torch.gt(output,thr); # target_mask = torch.gt(target,thr); # equal_mask = torch.eq(output_mask,target_mask); # fp_equal_mask = torch.lt(output_mask,target_mask); # fn_equal_mask = torch.gt(output_mask,target_mask); # tp = torch.sum(equal_mask); # fn = torch.sum(fn_equal_mask); # fp = torch.sum(fp_equal_mask); # return 2*tp / (2*tp+fn+fp) if output.dim() > 2: v,i = torch.max(output,1); else: v,i = torch.max(output,1); return torch.sum(target.long() == i).float()/target.numel() def final_preds(output, center, scale, res): coords = get_preds(output) # float type # pose-processing for n in range(coords.size(0)): for p in range(coords.size(1)): hm = output[n][p] px = int(math.floor(coords[n][p][0])) py = int(math.floor(coords[n][p][1])) if px > 1 and px < res[0] and py > 1 and py < res[1]: diff = torch.Tensor([hm[py - 1][px] - hm[py - 1][px - 2], hm[py][px - 1]-hm[py - 2][px - 1]]) coords[n][p] += diff.sign() * .25 coords += 0.5 preds = coords.clone() # Transform back for i in range(coords.size(0)): preds[i] = transform_preds(coords[i], center[i], scale[i], res) if preds.dim() < 3: preds = preds.view(1, preds.size()) return preds class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count ================================================ FILE: scripts/utils/imutils.py ================================================ from __future__ import absolute_import import torch import torch.nn as nn import numpy as np import scipy.misc from .misc import * def im_to_numpy(img): img = to_numpy(img) img = np.transpose(img, (1, 2, 0)) # H*W*C return img def im_to_torch(img): img = np.transpose(img, (2, 0, 1)) # C*H*W img = to_torch(img).float() if img.max() > 1: img /= 255 return img def load_image(img_path): # H x W x C => C x H x W return im_to_torch(scipy.misc.imread(img_path, mode='RGB')) def imread_all(img_path): return scipy.misc.imread(img_path, mode='RGB') def load_image_gray(img_path): # H x W x C => C x H x W x = scipy.misc.imread(img_path, mode='L') x = x[:,:,np.newaxis] return im_to_torch(x) def resize(img, owidth, oheight): img = im_to_numpy(img) if img.shape[2] == 1: img = scipy.misc.imresize(img.squeeze(),(oheight,owidth)) img = img[:,:,np.newaxis] else: img = scipy.misc.imresize( img, (oheight, owidth) ) img = im_to_torch(img) # print('%f %f' % (img.min(), img.max())) return img # ============================================================================= # Helpful functions generating groundtruth labelmap # ============================================================================= def gaussian(shape=(7,7),sigma=1): """ 2D gaussian mask - should give the same result as MATLAB's fspecial('gaussian',[shape],[sigma]) """ m,n = [(ss-1.)/2. for ss in shape] y,x = np.ogrid[-m:m+1,-n:n+1] h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) ) h[ h < np.finfo(h.dtype).eps*h.max() ] = 0 return to_torch(h).float() def draw_labelmap(img, pt, sigma, type='Gaussian'): # Draw a 2D gaussian # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py img = to_numpy(img) # Check that any part of the gaussian is in-bounds ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)] if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or br[0] < 0 or br[1] < 0): # If not, just return the image as is return to_torch(img) # Generate gaussian size = 6 * sigma + 1 x = np.arange(0, size, 1, float) y = x[:, np.newaxis] x0 = y0 = size // 2 # The gaussian is not normalized, we want the center value to equal 1 if type == 'Gaussian': g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) elif type == 'Cauchy': g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5) # Usable gaussian range g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0] g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] # Image range img_x = max(0, ul[0]), min(br[0], img.shape[1]) img_y = max(0, ul[1]), min(br[1], img.shape[0]) img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]] return to_torch(img) # ============================================================================= # Helpful display functions # ============================================================================= def gauss(x, a, b, c, d=0): return a * np.exp(-(x - b)**2 / (2 * c**2)) + d def color_heatmap(x): x = to_numpy(x) color = np.zeros((x.shape[0],x.shape[1],3)) color[:,:,0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3) color[:,:,1] = gauss(x, 1, .5, .3) color[:,:,2] = gauss(x, 1, .2, .3) color[color > 1] = 1 color = (color * 255).astype(np.uint8) return color def imshow(img): npimg = im_to_numpy(img*255).astype(np.uint8) plt.imshow(npimg) plt.axis('off') def show_joints(img, pts): imshow(img) for i in range(pts.size(0)): if pts[i, 2] > 0: plt.plot(pts[i, 0], pts[i, 1], 'yo') plt.axis('off') def show_sample(inputs, target): num_sample = inputs.size(0) num_joints = target.size(1) height = target.size(2) width = target.size(3) for n in range(num_sample): inp = resize(inputs[n], width, height) out = inp for p in range(num_joints): tgt = inp*0.5 + color_heatmap(target[n,p,:,:])*0.5 out = torch.cat((out, tgt), 2) imshow(out) plt.show() def sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None): inp = to_numpy(inp * 255) out = to_numpy(out) img = np.zeros((inp.shape[1], inp.shape[2], inp.shape[0])) for i in range(3): img[:, :, i] = inp[i, :, :] if parts_to_show is None: parts_to_show = np.arange(out.shape[0]) # Generate a single image to display input/output pair num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows)) size = img.shape[0] // num_rows full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), np.uint8) full_img[:img.shape[0], :img.shape[1]] = img inp_small = scipy.misc.imresize(img, [size, size]) # Set up heatmap display for each part for i, part in enumerate(parts_to_show): part_idx = part out_resized = scipy.misc.imresize(out[part_idx], [size, size]) out_resized = out_resized.astype(float)/255 out_img = inp_small.copy() * .3 color_hm = color_heatmap(out_resized) out_img += color_hm * .7 col_offset = (i % num_cols + num_rows) * size row_offset = (i // num_cols) * size full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img return full_img def batch_with_heatmap(inputs, outputs, mean=torch.Tensor([0.5, 0.5, 0.5]), num_rows=2, parts_to_show=None): batch_img = [] for n in range(min(inputs.size(0), 4)): inp = inputs[n] + mean.view(3, 1, 1).expand_as(inputs[n]) batch_img.append( sample_with_heatmap(inp.clamp(0, 1), outputs[n], num_rows=num_rows, parts_to_show=parts_to_show) ) return np.concatenate(batch_img) def normalize_batch(batch): # normalize using imagenet mean and std mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) batch = batch/255.0 return (batch - mean) / std def show_image_tensor(tensor): re = [] for i in range(tensor.size(0)): inp = tensor[i].data.cpu() #w,h,c inp = inp.numpy().transpose((1, 2, 0)) mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) inp = std * inp + mean inp = np.clip(inp, 0, 1).transpose((2,0,1)) re.append(torch.from_numpy(inp).unsqueeze(0)) return torch.cat(re,0) def get_jet(): colormap_int = np.zeros((256, 3), np.uint8) for i in range(0, 256, 1): colormap_int[i, 0] = np.int_(np.round(cm.jet(i)[0] * 255.0)) colormap_int[i, 1] = np.int_(np.round(cm.jet(i)[1] * 255.0)) colormap_int[i, 2] = np.int_(np.round(cm.jet(i)[2] * 255.0)) return colormap_int def clamp(num, min_value, max_value): return max(min(num, max_value), min_value) def gray2color(gray_array, color_map): rows, cols = gray_array.shape color_array = np.zeros((rows, cols, 3), np.uint8) for i in range(0, rows): for j in range(0, cols): # log(256,2) = 8 , log(1,2) = 0 * 8 color_array[i, j] = color_map[clamp(int(abs(gray_array[i, j])*10),0,255)] return color_array class objectview(object): def __init__(self, *args, **kwargs): d = dict(*args, **kwargs) self.__dict__ = d ================================================ FILE: scripts/utils/logger.py ================================================ # A simple torch style logger # (C) Wei YANG 2017 from __future__ import absolute_import import os import sys import numpy as np import matplotlib.pyplot as plt __all__ = ['Logger', 'LoggerMonitor', 'savefig'] def savefig(fname, dpi=None): dpi = 150 if dpi == None else dpi plt.savefig(fname, dpi=dpi) def plot_overlap(logger, names=None): names = logger.names if names == None else names numbers = logger.numbers for _, name in enumerate(names): x = np.arange(len(numbers[name])) plt.plot(x, np.asarray(numbers[name])) return [logger.title + '(' + name + ')' for name in names] class Logger(object): '''Save training process to log file with simple plot function.''' def __init__(self, fpath, title=None, resume=False): self.file = None self.resume = resume self.title = '' if title == None else title if fpath is not None: if resume: self.file = open(fpath, 'r') name = self.file.readline() self.names = name.rstrip().split('\t') self.numbers = {} for _, name in enumerate(self.names): self.numbers[name] = [] for numbers in self.file: numbers = numbers.rstrip().split('\t') for i in range(0, len(numbers)): self.numbers[self.names[i]].append(numbers[i]) self.file.close() self.file = open(fpath, 'a') else: self.file = open(fpath, 'w') def set_names(self, names): if self.resume: pass # initialize numbers as empty list self.numbers = {} self.names = names for _, name in enumerate(self.names): self.file.write(name) self.file.write('\t') self.numbers[name] = [] self.file.write('\n') self.file.flush() def append(self, numbers): assert len(self.names) == len(numbers), 'Numbers do not match names' for index, num in enumerate(numbers): self.file.write("{0:.6f}".format(num)) self.file.write('\t') self.numbers[self.names[index]].append(num) self.file.write('\n') self.file.flush() def plot(self, names=None): names = self.names if names == None else names numbers = self.numbers for _, name in enumerate(names): x = np.arange(len(numbers[name])) plt.plot(x, np.asarray(numbers[name])) plt.legend([self.title + '(' + name + ')' for name in names]) plt.grid(True) def close(self): if self.file is not None: self.file.close() class LoggerMonitor(object): '''Load and visualize multiple logs.''' def __init__ (self, paths): '''paths is a distionary with {name:filepath} pair''' self.loggers = [] for title, path in paths.items(): logger = Logger(path, title=title, resume=True) self.loggers.append(logger) def plot(self, names=None): plt.figure() plt.subplot(121) legend_text = [] for logger in self.loggers: legend_text += plot_overlap(logger, names) plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) plt.grid(True) if __name__ == '__main__': # # Example # logger = Logger('test.txt') # logger.set_names(['Train loss', 'Valid loss','Test loss']) # length = 100 # t = np.arange(length) # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 # for i in range(0, length): # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) # logger.plot() # Example: logger monitor paths = { 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', } field = ['Valid Acc.'] monitor = LoggerMonitor(paths) monitor.plot(names=field) savefig('test.eps') ================================================ FILE: scripts/utils/losses.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from scripts.models.vgg import Vgg19 from torchvision import models from scripts.utils.misc import resize_to_match # from pytorch_msssim import SSIM, MS_SSIM import pytorch_ssim class WeightedBCE(nn.Module): def __init__(self): super(WeightedBCE, self).__init__() def forward(self, pred, gt): eposion = 1e-10 sigmoid_pred = torch.sigmoid(pred) count_pos = torch.sum(gt)*1.0+eposion count_neg = torch.sum(1.-gt)*1.0 beta = count_neg/count_pos beta_back = count_pos / (count_pos + count_neg) bce1 = nn.BCEWithLogitsLoss(pos_weight=beta) loss = beta_back*bce1(pred, gt) return loss def l1_relative(reconstructed, real, mask): batch = real.size(0) area = torch.sum(mask.view(batch,-1),dim=1) reconstructed = reconstructed * mask real = real * mask loss_l1 = torch.abs(reconstructed - real).view(batch, -1) loss_l1 = torch.sum(loss_l1, dim=1) / area loss_l1 = torch.sum(loss_l1) / batch return loss_l1 def is_dic(x): return type(x) == type([]) class Losses(nn.Module): def __init__(self, argx, device): super(Losses, self).__init__() self.args = argx if self.args.loss_type == 'l1bl2': self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss() elif self.args.loss_type == 'l1wbl2': self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), WeightedBCE(), nn.MSELoss() elif self.args.loss_type == 'l2wbl2': self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), WeightedBCE(), nn.MSELoss() elif self.args.loss_type == 'l2xbl2': self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss() else: # l2bl2 self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss() if self.args.style_loss > 0: self.vggloss = VGGLoss(self.args.sltype).to(device) if self.args.ssim_loss > 0: self.ssimloss = pytorch_ssim.SSIM().to(device) self.outputLoss = self.outputLoss.to(device) self.attLoss = self.attLoss.to(device) self.wrloss = self.wrloss.to(device) def forward(self,imgx,target,attx,mask,wmx,wm): pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = 0,0,0,0,0 if is_dic(imgx): if self.args.masked: # calculate the overall loss and side output pixel_loss = self.outputLoss(imgx[0],target) + sum([self.outputLoss(im,resize_to_match(mask,im)*resize_to_match(target,im)) for im in imgx[1:]]) else: pixel_loss = sum([self.outputLoss(im,resize_to_match(target,im)) for im in imgx]) if self.args.style_loss > 0: vgg_loss = sum([self.vggloss(im,resize_to_match(target,im),resize_to_match(mask,im)) for im in imgx]) if self.args.ssim_loss > 0: ssim_loss = sum([ 1 - self.ssimloss(im,resize_to_match(target,im)) for im in imgx]) else: if self.args.masked: pixel_loss = self.outputLoss(imgx,mask*target) else: pixel_loss = self.outputLoss(imgx,target) if self.args.style_loss > 0: vgg_loss = self.vggloss(imgx,target,mask) if self.args.ssim_loss > 0: ssim_loss = 1 - self.ssimloss(imgx,target) if is_dic(attx): att_loss = sum([self.attLoss(at,resize_to_match(mask,at)) for at in attx]) else: att_loss = self.attLoss(attx, mask) if is_dic(wmx): wm_loss = sum([self.wrloss(w,resize_to_match(wm,w)) for w in wmx]) else: if self.args.masked: wm_loss = self.wrloss(wmx,mask*wm) else: wm_loss = self.wrloss(wmx, wm) return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss def gram_matrix(feat): # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py (b, ch, h, w) = feat.size() feat = feat.view(b, ch, h * w) feat_t = feat.transpose(1, 2) gram = torch.bmm(feat, feat_t) / (ch * h * w) return gram class MeanShift(nn.Conv2d): def __init__(self, data_mean, data_std, data_range=1, norm=True): """norm (bool): normalize/denormalize the stats""" c = len(data_mean) super(MeanShift, self).__init__(c, c, kernel_size=1) std = torch.Tensor(data_std) self.weight.data = torch.eye(c).view(c, c, 1, 1) if norm: self.weight.data.div_(std.view(c, 1, 1, 1)) self.bias.data = -1 * data_range * torch.Tensor(data_mean) self.bias.data.div_(std) else: self.weight.data.mul_(std.view(c, 1, 1, 1)) self.bias.data = data_range * torch.Tensor(data_mean) self.requires_grad = False def VGGLoss(losstype): if losstype == 'vgg': return VGGLossA() elif losstype == 'vggx': return VGGLossX(mask=False) elif losstype == 'mvggx': return VGGLossX(mask=True) elif losstype == 'rvggx': return VGGLossX(mask=True,relative=True) else: raise Exception("error in %s"%losstype) class VGGLossA(nn.Module): def __init__(self, vgg=None, weights=None, indices=None, normalize=True): super(VGGLossA, self).__init__() if vgg is None: self.vgg = Vgg19().cuda() else: self.vgg = vgg self.criterion = nn.L1Loss() self.weights = weights or [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5] self.indices = indices or [2, 7, 12, 21, 30] if normalize: self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda() else: self.normalize = None def forward(self, x, y): if self.normalize is not None: x = self.normalize(x) y = self.normalize(y) x_vgg, y_vgg = self.vgg(x, self.indices), self.vgg(y, self.indices) loss = 0 for i in range(len(x_vgg)): loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) return loss class VGG16FeatureExtractor(nn.Module): def __init__(self): super().__init__() vgg16 = models.vgg16(pretrained=True) self.enc_1 = nn.Sequential(*vgg16.features[:5]) self.enc_2 = nn.Sequential(*vgg16.features[5:10]) self.enc_3 = nn.Sequential(*vgg16.features[10:17]) # fix the encoder for i in range(3): for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters(): param.requires_grad = False def forward(self, image): results = [image] for i in range(3): func = getattr(self, 'enc_{:d}'.format(i + 1)) results.append(func(results[-1])) return results[1:] class VGGLossX(nn.Module): def __init__(self, normalize=True, mask=False, relative=False): super(VGGLossX, self).__init__() self.vgg = VGG16FeatureExtractor().cuda() self.criterion = nn.L1Loss().cuda() if not relative else l1_relative self.use_mask= mask self.relative = relative if normalize: self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda() else: self.normalize = None def forward(self, x, y, Xmask=None): if not self.use_mask: mask = torch.ones_like(x)[:,0:1,:,:] else: mask = Xmask if self.normalize is not None: x = self.normalize(x) y = self.normalize(y) x_vgg = self.vgg(x) y_vgg = self.vgg(y) loss = 0 for i in range(3): if self.relative: loss += self.criterion(x_vgg[i],y_vgg[i].detach(),resize_to_match(mask,x_vgg[i])) else: loss += self.criterion(resize_to_match(mask,x_vgg[i])*x_vgg[i],resize_to_match(mask,y_vgg[i])*y_vgg[i].detach()) return loss class GANLosses(object): """docstring for Loss""" def __init__(self, gantype): super(GANLosses, self).__init__() self.generator_loss = gen_gan(gantype) self.discriminator_loss = dis_gan(gantype) self.gantype = gantype def g_loss(self,dis_fake): if 'hinge' in self.gantype: return gen_hinge(dis_fake) else: return self.generator_loss(dis_fake) def d_loss(self,dis_fake,dis_real): if 'hinge' in self.gantype: return dis_hinge(dis_fake,dis_real) else: return self.discriminator_loss(dis_fake,dis_real) class gen_gan(nn.Module): def __init__(self,gantype): super(gen_gan,self).__init__() if gantype == 'lsgan': self.criterion = nn.MSELoss() elif gantype == 'naive': self.criterion = nn.BCEWithLogitsLoss() else: raise Exception("error gan type") def forward(self,dis_fake): return self.criterion(dis_fake, torch.ones_like(dis_fake)) class dis_gan(nn.Module): def __init__(self,gantype): super(dis_gan,self).__init__() if gantype == 'lsgan': self.criterion = nn.MSELoss() elif gantype == 'naive': self.criterion = nn.BCEWithLogitsLoss() else: raise Exception("error gan type") def forward(self,dis_fake,dis_real): loss_fake = self.criterion(dis_fake, torch.zeros_like(dis_fake)) loss_real = self.criterion(dis_real, torch.ones_like(dis_real)) return loss_fake, loss_real # def gen_gan(dis_fake): # # fake -> 1 # return F.binary_cross_entropy_with_logits(dis_fake,torch.ones_like(dis_fake)) # def dis_gan(dis_fake,dis_real): # # fake -> 0 , real ->1 # loss_fake = F.binary_cross_entropy_with_logits(dis_fake, torch.zeros_like(dis_real)) # loss_real = F.binary_cross_entropy_with_logits(dis_real, torch.ones_like(dis_fake)) # return loss_fake,loss_real # def gen_lsgan(dis_fake): # loss = F.mse_loss(dis_fake,torch.ones_like(dis_fake)) # # return loss # def dis_lsgan(dis_fake, dis_real): # loss_fake = F.mse_loss(dis_fake, torch.zeros_like(dis_real)) # loss_real = F.mse_loss(dis_real, torch.ones_like(dis_real)) # return loss_fake,loss_real def gen_hinge(dis_fake, dis_real=None): return -torch.mean(dis_fake) def dis_hinge(dis_fake, dis_real): loss_fake = torch.mean(torch.relu(1. + dis_fake)) loss_real = torch.mean(torch.relu(1. - dis_real)) return loss_fake,loss_real ================================================ FILE: scripts/utils/misc.py ================================================ from __future__ import absolute_import import os import shutil import torch import math import numpy as np import scipy.io import matplotlib.pyplot as plt import torch.nn.functional as F def to_numpy(tensor): if torch.is_tensor(tensor): return tensor.cpu().numpy() elif type(tensor).__module__ != 'numpy': raise ValueError("Cannot convert {} to numpy array" .format(type(tensor))) return tensor def resize_to_match(fm,to): # just use interpolate # [1,3] = (h,w) return F.interpolate(fm,to.size()[-2:],mode='bilinear',align_corners=False) def to_torch(ndarray): if type(ndarray).__module__ == 'numpy': return torch.from_numpy(ndarray) elif not torch.is_tensor(ndarray): raise ValueError("Cannot convert {} to torch tensor" .format(type(ndarray))) return ndarray def save_checkpoint(machine,filename='checkpoint.pth.tar', snapshot=None): is_best = True if machine.best_acc < machine.metric else False if is_best: machine.best_acc = machine.metric state = { 'epoch': machine.current_epoch + 1, 'arch': machine.args.arch, 'state_dict': machine.model.state_dict(), 'best_acc': machine.best_acc, 'optimizer' : machine.optimizer.state_dict(), } filepath = os.path.join(machine.args.checkpoint, filename) torch.save(state, filepath) if snapshot and state['epoch'] % snapshot == 0: shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state.epoch))) if is_best: machine.best_acc = machine.metric print('Saving Best Metric with PSNR:%s'%machine.best_acc) shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'model_best.pth.tar')) def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'): preds = to_numpy(preds) filepath = os.path.join(checkpoint, filename) scipy.io.savemat(filepath, mdict={'preds' : preds}) def adjust_learning_rate(datasets,optimizer, epoch, lr,args): """Sets the learning rate to the initial LR decayed by schedule""" if epoch in args.schedule: lr *= args.gamma for param_group in optimizer.param_groups: param_group['lr'] = lr # decay sigma for dset in datasets: if args.sigma_decay > 0: dset.dataset.sigma *= args.sigma_decay dset.dataset.sigma *= args.sigma_decay return lr ================================================ FILE: scripts/utils/model_init.py ================================================ from torch.nn import init def weights_init_normal(m): classname = m.__class__.__name__ # print(classname) if classname.find('Conv') != -1: init.normal_(m.weight.data, 0.0, 0.02) elif classname.find('Linear') != -1: init.normal_(m.weight.data, 0.0, 0.02) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, 0.02) init.constant_(m.bias.data, 0.0) def weights_init_xavier(m): classname = m.__class__.__name__ # print(classname) if classname.find('Conv') != -1: init.xavier_normal(m.weight.data, gain=0.02) elif classname.find('Linear') != -1: init.xavier_normal(m.weight.data, gain=0.02) # elif classname.find('BatchNorm2d') != -1: # init.normal(m.weight.data, 1.0, 0.02) # init.constant(m.bias.data, 0.0) def weights_init_kaiming(m): classname = m.__class__.__name__ # print(classname) if classname.find('Conv') != -1 and m.weight.requires_grad == True: init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif classname.find('Linear') != -1 and m.weight.requires_grad == True: init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif classname.find('BatchNorm2d') != -1 and m.weight.requires_grad == True: init.normal_(m.weight.data, 1.0, 0.02) init.constant_(m.bias.data, 0.0) def weights_init_orthogonal(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: init.orthogonal(m.weight.data, gain=1) elif classname.find('Linear') != -1: init.orthogonal(m.weight.data, gain=1) # elif classname.find('BatchNorm2d') != -1: # init.normal(m.weight.data, 1.0, 0.02) # init.constant(m.bias.data, 0.0) ================================================ FILE: scripts/utils/osutils.py ================================================ from __future__ import absolute_import import os import errno def mkdir_p(dir_path): try: os.makedirs(dir_path) except OSError as e: if e.errno != errno.EEXIST: raise def isfile(fname): return os.path.isfile(fname) def isdir(dirname): return os.path.isdir(dirname) def join(path, *paths): return os.path.join(path, *paths) ================================================ FILE: scripts/utils/parallel.py ================================================ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ## Created by: Hang Zhang, Rutgers University, Email: zhang.hang@rutgers.edu ## Modified by Thomas Wolf, HuggingFace Inc., Email: thomas@huggingface.co ## Copyright (c) 2017-2018 ## ## This source code is licensed under the MIT-style license found in the ## LICENSE file in the root directory of this source tree ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ """Encoding Data Parallel""" import threading import functools import torch from torch.autograd import Variable, Function import torch.cuda.comm as comm from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel.data_parallel import DataParallel from torch.nn.parallel.parallel_apply import get_a_var from torch.nn.parallel.scatter_gather import gather from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast torch_ver = torch.__version__[:3] __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 'patch_replication_callback'] def allreduce(*inputs): """Cross GPU all reduce autograd operation for calculate mean and variance in SyncBN. """ return AllReduce.apply(*inputs) class AllReduce(Function): @staticmethod def forward(ctx, num_inputs, *inputs): ctx.num_inputs = num_inputs ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)] inputs = [inputs[i:i + num_inputs] for i in range(0, len(inputs), num_inputs)] # sort before reduce sum inputs = sorted(inputs, key=lambda i: i[0].get_device()) results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) outputs = comm.broadcast_coalesced(results, ctx.target_gpus) return tuple([t for tensors in outputs for t in tensors]) @staticmethod def backward(ctx, *inputs): inputs = [i.data for i in inputs] inputs = [inputs[i:i + ctx.num_inputs] for i in range(0, len(inputs), ctx.num_inputs)] results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) outputs = comm.broadcast_coalesced(results, ctx.target_gpus) return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors]) class Reduce(Function): @staticmethod def forward(ctx, *inputs): ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] inputs = sorted(inputs, key=lambda i: i.get_device()) return comm.reduce_add(inputs) @staticmethod def backward(ctx, gradOutput): return Broadcast.apply(ctx.target_gpus, gradOutput) class DistributedDataParallelModel(DistributedDataParallel): """Implements data parallelism at the module level for the DistributedDataParallel module. This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension. In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module. Note that the outputs are not gathered, please use compatible :class:`encoding.parallel.DataParallelCriterion`. The batch size should be larger than the number of GPUs used. It should also be an integer multiple of the number of GPUs so that each chunk is the same size (so that each GPU processes the same number of samples). Args: module: module to be parallelized device_ids: CUDA devices (default: all devices) Reference: Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, Amit Agrawal. “Context Encoding for Semantic Segmentation. *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* Example:: >>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2]) >>> y = net(x) """ def gather(self, outputs, output_device): return outputs class DataParallelModel(DataParallel): """Implements data parallelism at the module level. This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension. In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module. Note that the outputs are not gathered, please use compatible :class:`encoding.parallel.DataParallelCriterion`. The batch size should be larger than the number of GPUs used. It should also be an integer multiple of the number of GPUs so that each chunk is the same size (so that each GPU processes the same number of samples). Args: module: module to be parallelized device_ids: CUDA devices (default: all devices) Reference: Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, Amit Agrawal. “Context Encoding for Semantic Segmentation. *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* Example:: >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) >>> y = net(x) """ def gather(self, outputs, output_device): return outputs def replicate(self, module, device_ids): modules = super(DataParallelModel, self).replicate(module, device_ids) execute_replication_callbacks(modules) return modules class DataParallelCriterion(DataParallel): """ Calculate loss in multiple-GPUs, which balance the memory usage. The targets are splitted across the specified devices by chunking in the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. Reference: Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, Amit Agrawal. “Context Encoding for Semantic Segmentation. *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* Example:: >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2]) >>> y = net(x) >>> loss = criterion(y, target) """ def forward(self, inputs, *targets, **kwargs): # input should be already scatterd # scattering the targets instead if not self.device_ids: return self.module(inputs, *targets, **kwargs) targets, kwargs = self.scatter(targets, kwargs, self.device_ids) if len(self.device_ids) == 1: return self.module(inputs, *targets[0], **kwargs[0]) replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs) #return Reduce.apply(*outputs) / len(outputs) #return self.gather(outputs, self.output_device).mean() return self.gather(outputs, self.output_device) def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): assert len(modules) == len(inputs) assert len(targets) == len(inputs) if kwargs_tup: assert len(modules) == len(kwargs_tup) else: kwargs_tup = ({},) * len(modules) if devices is not None: assert len(modules) == len(devices) else: devices = [None] * len(modules) lock = threading.Lock() results = {} if torch_ver != "0.3": grad_enabled = torch.is_grad_enabled() def _worker(i, module, input, target, kwargs, device=None): if torch_ver != "0.3": torch.set_grad_enabled(grad_enabled) if device is None: device = get_a_var(input).get_device() try: with torch.cuda.device(device): # this also avoids accidental slicing of `input` if it is a Tensor if not isinstance(input, (list, tuple)): input = (input,) if not isinstance(target, (list, tuple)): target = (target,) output = module(*(input + target), **kwargs) with lock: results[i] = output except Exception as e: with lock: results[i] = e if len(modules) > 1: threads = [threading.Thread(target=_worker, args=(i, module, input, target, kwargs, device),) for i, (module, input, target, kwargs, device) in enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] for thread in threads: thread.start() for thread in threads: thread.join() else: _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) outputs = [] for i in range(len(inputs)): output = results[i] if isinstance(output, Exception): raise output outputs.append(output) return outputs ########################################################################### # Adapted from Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # class CallbackContext(object): pass def execute_replication_callbacks(modules): """ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` Note that, as all modules are isomorphism, we assign each sub-module with a context (shared among multiple copies of this module on different devices). Through this context, different copies can share some information. We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback of any slave copies. """ master_copy = modules[0] nr_modules = len(list(master_copy.modules())) ctxs = [CallbackContext() for _ in range(nr_modules)] for i, module in enumerate(modules): for j, m in enumerate(module.modules()): if hasattr(m, '__data_parallel_replicate__'): m.__data_parallel_replicate__(ctxs[j], i) def patch_replication_callback(data_parallel): """ Monkey-patch an existing `DataParallel` object. Add the replication callback. Useful when you have customized `DataParallel` implementation. Examples: > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) > patch_replication_callback(sync_bn) # this is equivalent to > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) """ assert isinstance(data_parallel, DataParallel) old_replicate = data_parallel.replicate @functools.wraps(old_replicate) def new_replicate(module, device_ids): modules = old_replicate(module, device_ids) execute_replication_callbacks(modules) return modules data_parallel.replicate = new_replicate ================================================ FILE: scripts/utils/transforms.py ================================================ from __future__ import absolute_import import os import numpy as np import scipy.misc import matplotlib.pyplot as plt import torch import torchvision from .misc import * from .imutils import * def color_normalize(x, mean, std): if x.size(0) == 1: x = x.repeat(3, x.size(1), x.size(2)) for t, m, s in zip(x, mean, std): t.sub_(m) return x def flip_back(flip_output, dataset='mpii'): """ flip output map """ if dataset == 'mpii': matchedParts = ( [0,5], [1,4], [2,3], [10,15], [11,14], [12,13] ) else: print('Not supported dataset: ' + dataset) # flip output horizontally flip_output = fliplr(flip_output.numpy()) # Change left-right parts for pair in matchedParts: tmp = np.copy(flip_output[:, pair[0], :, :]) flip_output[:, pair[0], :, :] = flip_output[:, pair[1], :, :] flip_output[:, pair[1], :, :] = tmp return torch.from_numpy(flip_output).float() def shufflelr(x, width, dataset='mpii'): """ flip coords """ if dataset == 'mpii': matchedParts = ( [0,5], [1,4], [2,3], [10,15], [11,14], [12,13] ) else: print('Not supported dataset: ' + dataset) # Flip horizontal x[:, 0] = width - x[:, 0] # Change left-right parts for pair in matchedParts: tmp = x[pair[0], :].clone() x[pair[0], :] = x[pair[1], :] x[pair[1], :] = tmp return x def fliplr(x): if x.ndim == 3: x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1)) elif x.ndim == 4: for i in range(x.shape[0]): x[i] = np.transpose(np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1)) return x.astype(float) def get_transform(center, scale, res, rot=0): """ General image processing functions """ # Generate transformation matrix h = 200 * scale t = np.zeros((3, 3)) t[0, 0] = float(res[1]) / h t[1, 1] = float(res[0]) / h t[0, 2] = res[1] * (-float(center[0]) / h + .5) t[1, 2] = res[0] * (-float(center[1]) / h + .5) t[2, 2] = 1 if not rot == 0: rot = -rot # To match direction of rotation from cropping rot_mat = np.zeros((3,3)) rot_rad = rot * np.pi / 180 sn,cs = np.sin(rot_rad), np.cos(rot_rad) rot_mat[0,:2] = [cs, -sn] rot_mat[1,:2] = [sn, cs] rot_mat[2,2] = 1 # Need to rotate around center t_mat = np.eye(3) t_mat[0,2] = -res[1]/2 t_mat[1,2] = -res[0]/2 t_inv = t_mat.copy() t_inv[:2,2] *= -1 t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t))) return t def transform(pt, center, scale, res, invert=0, rot=0): # Transform pixel location to different reference t = get_transform(center, scale, res, rot=rot) if invert: t = np.linalg.inv(t) new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T new_pt = np.dot(t, new_pt) return new_pt[:2].astype(int) + 1 def transform_preds(coords, center, scale, res): # size = coords.size() # coords = coords.view(-1, coords.size(-1)) # print(coords.size()) for p in range(coords.size(0)): coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0)) return coords def crop(img, center, scale, res, rot=0): img = im_to_numpy(img) # Upper left point ul = np.array(transform([0, 0], center, scale, res, invert=1)) # Bottom right point br = np.array(transform(res, center, scale, res, invert=1)) # Padding so that when rotated proper amount of context is included pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) if not rot == 0: ul -= pad br += pad new_shape = [br[1] - ul[1], br[0] - ul[0]] if len(img.shape) > 2: new_shape += [img.shape[2]] new_img = np.zeros(new_shape) # Range to fill new array new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] # Range to sample from original image old_x = max(0, ul[0]), min(len(img[0]), br[0]) old_y = max(0, ul[1]), min(len(img), br[1]) new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] if not rot == 0: # Remove padding new_img = scipy.misc.imrotate(new_img, rot) new_img = new_img[pad:-pad, pad:-pad] new_img = im_to_torch(scipy.misc.imresize(new_img, res)) return new_img def get_right(img,gray=False): img = im_to_numpy(img) #H*W*C new_img = img[:,0:256,:] new_img = im_to_torch(new_img) if gray == True: new_img = new_img[1,:,:]; return new_img class NormalizeInverse(torchvision.transforms.Normalize): """ Undoes the normalization and returns the reconstructed images in the input domain. """ def __init__(self, mean, std): mean = torch.as_tensor(mean) std = torch.as_tensor(std) std_inv = 1 / (std + 1e-7) mean_inv = -mean * std_inv super().__init__(mean=mean_inv, std=std_inv) def __call__(self, tensor): return super().__call__(tensor.clone()) ================================================ FILE: test.py ================================================ from __future__ import print_function, absolute_import import argparse import torch torch.backends.cudnn.benchmark = True from scripts.utils.misc import save_checkpoint, adjust_learning_rate import scripts.datasets as datasets import scripts.machines as machines from options import Options def main(args): val_loader = torch.utils.data.DataLoader(datasets.COCO('val',args),batch_size=args.test_batch, shuffle=False, num_workers=args.workers, pin_memory=True) data_loaders = (None,val_loader) Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args) Machine.test() if __name__ == '__main__': parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal')) main(parser.parse_args()) ================================================ FILE: watermark_synthesis.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SAVE ALL THE SETTING\n" ] } ], "source": [ "# watermark synthesis\n", "import os \n", "import random\n", "import shutil\n", "from PIL import Image\n", "import numpy as np\n", "\n", "def trans_paste(bg_img,fg_img,mask,box=(0,0)):\n", " fg_img_trans = Image.new(\"RGBA\",bg_img.size)\n", " fg_img_trans.paste(fg_img,box,mask=mask)\n", " new_img = Image.alpha_composite(bg_img,fg_img_trans)\n", " return new_img,fg_img_trans\n", "\n", "if os.path.isdir('dataset'):\n", " shutil.rmtree('dataset')\n", "\n", "os.mkdir('dataset')\n", "BASE_IMG_DIR = '/Users/oishii/Downloads/val2014/'\n", "WATERMARK_DIR = 'logos' #1080 \n", "images = sorted([os.path.join(BASE_IMG_DIR,x) for x in os.listdir(BASE_IMG_DIR) if '.jpg' in x])\n", "watermarks = sorted([os.path.join(WATERMARK_DIR,x).replace(' ','_') for x in os.listdir(WATERMARK_DIR) if '.png' in x])\n", "# rename all the watermark from replace ' ' to '_'\n", "\n", "random.shuffle(images)\n", "random.shuffle(watermarks)\n", "\n", "train_images = images[:int(len(images)*0.7)]\n", "val_images = images[int(len(images)*0.7):int(len(images)*0.8)]\n", "test_images = images[int(len(images)*0.8):]\n", "\n", "train_wms = watermarks[:int(len(watermarks)*0.7)]\n", "val_wms = watermarks[int(len(watermarks)*0.7):int(len(watermarks)*0.8)]\n", "test_wms = watermarks[int(len(watermarks)*0.8):]\n", "\n", "# save all the settings to file\n", "names = ['train_images','val_images','test_images','train_wms','val_wms','test_wms']\n", "lists = [train_images,val_images,test_images,train_wms,val_wms,test_wms]\n", "dataset = dict(zip(names, lists))\n", "\n", "for name,content in dataset.items():\n", " with open('dataset/%s.txt'%name,'w') as f:\n", " f.write(\"\\n\".join(content))\n", "\n", "print('SAVE ALL THE SETTING')\n", "\n", "for name, images in dataset.items():\n", " if 'images' not in name:\n", " continue\n", " # for each setting, synthesis the watermark\n", " # for each image, add X(X=6) watermark in differnet position, alpha,\n", " # save the synthesized image, watermark mask, reshaped mask,\n", " save_path = 'dataset/%s/'%name\n", " os.makedirs('%s/image'%(save_path))\n", " os.makedirs('%s/mask'%(save_path))\n", " os.makedirs('%s/wm'%(save_path))\n", " \n", " for img in images:\n", " im = Image.open(img).convert('RGBA')\n", " imw,imh = im.size\n", " \n", " for wmg in random.choices(dataset[name.replace('images','wms')],k=6):\n", " wm = Image.open(wmg.replace('_',' ')).convert(\"RGBA\") # RGBA\n", " # get the mask of wm\n", " # data agumentation of wm\n", " wm = wm.rotate(angle=random.randint(0,360),expand=True) # rotate\n", " \n", " # make sure the \n", " imrw = random.randrange(int(0.4*imw),int(0.8*imw))\n", " imrh = random.randrange(int(0.4*imh),int(0.8*imh))\n", " wmsize = imrh if imrw > imrh else imrw\n", " wm = wm.resize((wmsize,wmsize),Image.BILINEAR)\n", " w,h = wm.size # new size \n", " \n", " box_left = random.randint(0,imw-w)\n", " box_upper = random.randint(0,imh-h)\n", " wmm = wm.copy()\n", " wm.putalpha(random.randint(int(255*0.4),int(255*0.8))) # alpha\n", " \n", " ims,wmc = trans_paste(im,wm,wmm,(box_left,box_upper))\n", " \n", " wmnp = np.array(wmc) # h,w,3\n", " mask = np.sum(wmnp,axis=2)>0\n", " mm = Image.fromarray(np.uint8(mask*255),mode='L')\n", " \n", " identifier = os.path.basename(img).split('.')[0] +'-'+os.path.basename(wmg).split('.')[0] + '.png'\n", " # save \n", " wmc.save('%s/wm/%s'%(save_path,identifier))\n", " ims.save('%s/image/%s'%(save_path,identifier))\n", " mm.save('%s/mask/%s'%(save_path,identifier))\n", " \n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }