Repository: vinthony/deep-blind-watermark-removal
Branch: main
Commit: 72f0e61b9f06
Files: 35
Total size: 174.8 KB
Directory structure:
gitextract_p5y2m_4m/
├── README.md
├── examples/
│ ├── evaluate.sh
│ └── test.sh
├── main.py
├── options.py
├── requirements.txt
├── scripts/
│ ├── __init__.py
│ ├── datasets/
│ │ ├── BIH.py
│ │ ├── COCO.py
│ │ └── __init__.py
│ ├── machines/
│ │ ├── BasicMachine.py
│ │ ├── S2AM.py
│ │ ├── VX.py
│ │ └── __init__.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── backbone_unet.py
│ │ ├── blocks.py
│ │ ├── discriminator.py
│ │ ├── rasc.py
│ │ ├── sa_resunet.py
│ │ ├── unet.py
│ │ ├── vgg.py
│ │ └── vmu.py
│ └── utils/
│ ├── __init__.py
│ ├── evaluation.py
│ ├── imutils.py
│ ├── logger.py
│ ├── losses.py
│ ├── misc.py
│ ├── model_init.py
│ ├── osutils.py
│ ├── parallel.py
│ └── transforms.py
├── test.py
└── watermark_synthesis.ipynb
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
This repo contains the code and results of the AAAI 2021 paper:
[Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal](https://arxiv.org/abs/2012.07007)
[Xiaodong Cun](http://vinthony.github.io), [Chi-Man Pun*](http://www.cis.umac.mo/~cmpun/)
[University of Macau](http://um.edu.mo/)
[Datasets](#Resources) | [Models](#Resources) | [Paper](https://arxiv.org/abs/2012.07007) | [🔥Online Demo!](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing)(Google Colab)
The overview of the proposed two-stage framework. First, we propose a multi-task network, SplitNet, for watermark detection, removal, and recovery. Then, we propose RefineNet to smooth the learned region using the predicted mask and the recovered background from the previous stage. As a result, our network can be trained end-to-end without any manual intervention. Note that, for clarity, we do not show the skip-connections between the encoders and decoders.
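The two-stage pipeline described above can be sketched as a pair of modules; this is a minimal illustrative stand-in (hypothetical module and head names, single conv layers instead of the actual ResUNets), not the repo's real architecture:

```python
# Hypothetical sketch of the "split then refine" pipeline -- names and layers
# are illustrative stand-ins, not the repo's actual implementation.
import torch
import torch.nn as nn

class SplitNetSketch(nn.Module):
    """Stage 1: jointly predict the watermark mask and a coarse background."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 16, 3, padding=1)  # stands in for a ResUNet
        self.mask_head = nn.Conv2d(16, 1, 1)
        self.bg_head = nn.Conv2d(16, 3, 1)

    def forward(self, x):
        feat = torch.relu(self.encoder(x))
        return torch.sigmoid(self.mask_head(feat)), torch.sigmoid(self.bg_head(feat))

class RefineNetSketch(nn.Module):
    """Stage 2: refine the coarse background guided by the predicted mask."""
    def __init__(self):
        super().__init__()
        self.refine = nn.Conv2d(3 + 1 + 3, 3, 3, padding=1)  # image + mask + coarse bg

    def forward(self, image, mask, coarse_bg):
        return torch.sigmoid(self.refine(torch.cat([image, mask, coarse_bg], dim=1)))

x = torch.rand(1, 3, 64, 64)                  # a watermarked input
mask, coarse = SplitNetSketch()(x)            # stage 1: split
refined = RefineNetSketch()(x, mask, coarse)  # stage 2: refine
print(tuple(refined.shape))                   # (1, 3, 64, 64)
```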
> The whole project will be released in January 2021 (almost).
### Datasets
We synthesized four different datasets for training and testing; you can download them via [huggingface](https://huggingface.co/datasets/vinthony/watermark-removal-logo/tree/main).
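The synthesis follows the standard visible-watermark compositing model (see `watermark_synthesis.ipynb`). A minimal numpy sketch, assuming plain alpha blending inside the mask region (variable names are illustrative, not the notebook's exact code):

```python
# Sketch of visible-watermark compositing: J = (1 - a*M) * I + a*M * W,
# assuming simple alpha blending inside the mask (illustrative only).
import numpy as np

def composite(host, watermark, mask, alpha=0.6):
    """host, watermark: HxWx3 floats in [0,1]; mask: HxW in {0,1}."""
    m = mask[..., None] * alpha              # per-pixel blending weight
    return (1.0 - m) * host + m * watermark  # blend only where the mask is on

host = np.full((4, 4, 3), 0.5)               # gray host image
wm = np.zeros((4, 4, 3))                     # black watermark
mask = np.zeros((4, 4))
mask[:2, :2] = 1.0                           # watermark covers the top-left corner
out = composite(host, wm, mask, alpha=0.6)
print(out[0, 0, 0], out[3, 3, 0])            # 0.2 inside the mask, 0.5 outside
```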

### Pre-trained Models
* [27kpng_model_best.pth.tar (google drive)](https://drive.google.com/file/d/1KpSJ6385CHN6WlAINqB3CYrJdleQTJBc/view?usp=sharing)
> Other pre-trained models are still being reorganized and uploaded; they will be released soon.
### Demos
An easy-to-use online demo can be found on [Google Colab](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing).
The local demo will be released soon.
### Pre-requirements
```
pip install -r requirements.txt
```
### Train
Besides training our method, we also give an example of how to train [s2am](https://github.com/vinthony/s2am) under our framework. More details can be found in the shell scripts.
```
bash examples/evaluate.sh
```
### Test
```
bash examples/test.sh
```
## **Acknowledgements**
The authors would like to thank Nan Chen for her helpful discussions.
Part of the code is based upon our previous work on image harmonization, [s2am](https://github.com/vinthony/s2am).
## **Citation**
If you find our work useful in your research, please consider citing:
```
@misc{cun2020split,
title={Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal},
author={Xiaodong Cun and Chi-Man Pun},
year={2020},
eprint={2012.07007},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
## **Contact**
Please contact me if you have any questions (Xiaodong Cun, yb87432@um.edu.mo).
================================================
FILE: examples/evaluate.sh
================================================
set -ex
# example training scripts for AAAI-21
# Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal
CUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/main.py --epochs 100\
--schedule 100\
--lr 1e-3\
-c eval/10kgray/1e3_bs4_256_hybrid_ssim_vgg\
--arch vvv4n\
--sltype vggx\
--style-loss 0.025\
--ssim-loss 0.15\
--masked True\
--loss-type hybrid\
--limited-dataset 1\
--machine vx\
--input-size 256\
--train-batch 4\
--test-batch 1\
--base-dir $HOME/watermark/10kgray/\
--data _images
# example training scripts for TIP-20
# Improving the Harmony of the Composite Image by Spatial-Separated Attention Module
# * in the original version, the res = False
# suitable for the iHarmony4 dataset.
python /data/home/yb87432/mypaper/s2am/main.py --epochs 200\
--schedule 150\
--lr 1e-3\
-c checkpoint/normal_rasc_HAdobe5k_res \
--arch rascv2\
--style-loss 0\
--ssim-loss 0\
--limited-dataset 0\
--res True\
--machine s2am\
--input-size 256\
--train-batch 16\
--test-batch 1\
--base-dir $HOME/Datasets/\
--data HAdobe5k
================================================
FILE: examples/test.sh
================================================
set -ex
CUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/test.py \
-c test/10kgray_ssim\
--resume /data/home/yb87432/s2am/eval/10kgray/1e3_bs6_256_hybrid_ssim_vgg_vx__images_vvv4n/model_best.pth.tar\
--arch vvv4n\
--machine vx\
--input-size 256\
--test-batch 1\
--evaluate\
--base-dir $HOME/watermark/10kgray/\
--data _images
================================================
FILE: main.py
================================================
from __future__ import print_function, absolute_import
import argparse
import torch,time,os
torch.backends.cudnn.benchmark = True
from scripts.utils.misc import save_checkpoint, adjust_learning_rate
import scripts.datasets as datasets
import scripts.machines as machines
from options import Options
def main(args):
    # pick the dataset by base dir; the original `if 'HFlickr' or ...` was always
    # truthy (it tested the string literal, not membership), so test each name
    if any(name in args.base_dir for name in ('HFlickr', 'HCOCO', 'Hday2night', 'HAdobe5k')):
        dataset_func = datasets.BIH
    else:
        dataset_func = datasets.COCO

    train_loader = torch.utils.data.DataLoader(dataset_func('train', args), batch_size=args.train_batch, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(dataset_func('val', args), batch_size=args.test_batch, shuffle=False,
                                             num_workers=args.workers, pin_memory=True)

    lr = args.lr
    data_loaders = (train_loader, val_loader)
    Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)

    print('============================ Initialization Finish && Training Start =============================================')
    for epoch in range(Machine.args.start_epoch, Machine.args.epochs):
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))
        lr = adjust_learning_rate(data_loaders, Machine.optimizer, epoch, lr, args)
        Machine.record('lr', lr, epoch)
        Machine.train(epoch)
        if args.freq < 0:
            Machine.validate(epoch)
            Machine.flush()
            Machine.save_checkpoint()
if __name__ == '__main__':
    parser = Options().init(argparse.ArgumentParser(description='WaterMark Removal'))
    args = parser.parse_args()

    print('==================================== WaterMark Removal =============================================')
    print('==> {:50}: {:<}'.format("Start Time", time.ctime(time.time())))
    print('==> {:50}: {:<}'.format("USE GPU", os.environ['CUDA_VISIBLE_DEVICES']))

    print('==================================== Stable Parameters =============================================')
    for arg in vars(args):
        if type(getattr(args, arg)) == type([]):
            if ','.join([str(i) for i in getattr(args, arg)]) == ','.join([str(i) for i in parser.get_default(arg)]):
                print('==> {:50}: {:<}({:<})'.format(arg, ','.join([str(i) for i in getattr(args, arg)]), ','.join([str(i) for i in parser.get_default(arg)])))
        else:
            if getattr(args, arg) == parser.get_default(arg):
                print('==> {:50}: {:<}({:<})'.format(arg, getattr(args, arg), parser.get_default(arg)))

    print('==================================== Changed Parameters =============================================')
    for arg in vars(args):
        if type(getattr(args, arg)) == type([]):
            if ','.join([str(i) for i in getattr(args, arg)]) != ','.join([str(i) for i in parser.get_default(arg)]):
                print('==> {:50}: {:<}({:<})'.format(arg, ','.join([str(i) for i in getattr(args, arg)]), ','.join([str(i) for i in parser.get_default(arg)])))
        else:
            if getattr(args, arg) != parser.get_default(arg):
                print('==> {:50}: {:<}({:<})'.format(arg, getattr(args, arg), parser.get_default(arg)))

    print('==================================== Start Init Model ===============================================')
    main(args)
    print('==================================== FINISH WITHOUT ERROR =============================================')
================================================
FILE: options.py
================================================
import scripts.models as models
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
class Options():
    """docstring for Options"""
    def __init__(self):
        pass

    def init(self, parser):
        # Model structure
        parser.add_argument('--arch', '-a', metavar='ARCH', default='dhn',
                            choices=model_names,
                            help='model architecture: ' +
                                 ' | '.join(model_names) +
                                 ' (default: dhn)')
        parser.add_argument('--darch', metavar='ARCH', default='dhn',
                            choices=model_names,
                            help='discriminator architecture: ' +
                                 ' | '.join(model_names) +
                                 ' (default: dhn)')
        parser.add_argument('--machine', '-m', metavar='MACHINE', default='basic')
        # Training strategy
        parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                            help='number of data loading workers (default: 2)')
        parser.add_argument('--epochs', default=30, type=int, metavar='N',
                            help='number of total epochs to run')
        parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                            help='manual epoch number (useful on restarts)')
        parser.add_argument('--train-batch', default=64, type=int, metavar='N',
                            help='train batch size')
        parser.add_argument('--test-batch', default=6, type=int, metavar='N',
                            help='test batch size')
        parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
                            metavar='LR', help='initial learning rate')
        parser.add_argument('--dlr', '--dlearning-rate', default=1e-3, type=float,
                            help='initial learning rate of the discriminator')
        parser.add_argument('--beta1', default=0.9, type=float, help='beta1 of Adam')
        parser.add_argument('--beta2', default=0.999, type=float, help='beta2 of Adam')
        parser.add_argument('--momentum', default=0, type=float, metavar='M',
                            help='momentum')
        parser.add_argument('--weight-decay', '--wd', default=0, type=float,
                            metavar='W', help='weight decay (default: 0)')
        parser.add_argument('--schedule', type=int, nargs='+', default=[5, 10],
                            help='decrease learning rate at these epochs')
        parser.add_argument('--gamma', type=float, default=0.1,
                            help='LR is multiplied by gamma on schedule')
        # Data processing
        parser.add_argument('-f', '--flip', dest='flip', action='store_true',
                            help='flip the input during validation')
        parser.add_argument('--lambdaL1', type=float, default=1, help='the weight of the L1 loss')
        parser.add_argument('--alpha', type=float, default=0.5,
                            help='groundtruth Gaussian sigma')
        parser.add_argument('--sigma-decay', type=float, default=0,
                            help='sigma decay rate for each epoch')
        # Miscs
        parser.add_argument('--base-dir', default='/PATH_TO_DATA_FOLDER/', type=str, metavar='PATH')
        parser.add_argument('--data', default='', type=str, metavar='PATH',
                            help='dataset sub-folder under --base-dir')
        parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                            help='path to save checkpoint (default: checkpoint)')
        parser.add_argument('--resume', default='', type=str, metavar='PATH',
                            help='path to latest checkpoint (default: none)')
        parser.add_argument('--finetune', default='', type=str, metavar='PATH',
                            help='path to the checkpoint to finetune from (default: none)')
        parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                            help='evaluate model on validation set')
        parser.add_argument('--style-loss', default=0, type=float,
                            help='weight of the perception (VGG) loss')
        parser.add_argument('--ssim-loss', default=0, type=float, help='weight of the ms-ssim loss')
        parser.add_argument('--att-loss', default=1, type=float, help='weight of the attention loss')
        parser.add_argument('--default-loss', default=False, type=bool)
        parser.add_argument('--sltype', default='vggx', type=str)
        parser.add_argument('-da', '--data-augumentation', default=False, type=bool,
                            help='enable data augmentation')
        parser.add_argument('-d', '--debug', dest='debug', action='store_true',
                            help='show intermediate results')
        parser.add_argument('--input-size', default=256, type=int, metavar='N',
                            help='input image size')
        parser.add_argument('--freq', default=-1, type=int, metavar='N',
                            help='evaluation frequency in iterations (<0 means once per epoch)')
        parser.add_argument('--normalized-input', default=False, type=bool,
                            help='normalize the input')
        parser.add_argument('--res', default=False, type=bool, help='residual learning for s2am')
        parser.add_argument('--requires-grad', default=False, type=bool)
        parser.add_argument('--limited-dataset', default=0, type=int, metavar='N')
        parser.add_argument('--gpu', default=True, type=bool)
        parser.add_argument('--masked', default=False, type=bool)
        parser.add_argument('--gan-norm', default=False, type=bool, help='normalize images to [-1,1] for GAN training')
        parser.add_argument('--hl', default=False, type=bool, help='homogeneous learning')
        parser.add_argument('--loss-type', default='l2', type=str, help='reconstruction loss type')
        return parser
================================================
FILE: requirements.txt
================================================
numpy==1.19.1
opencv-python==3.4.8.29
Pillow
scikit-image==0.14.5
scikit-learn==0.23.1
scipy==1.2.1
sklearn==0.0
tensorboardX
torch>=1.0.0
torchvision
================================================
FILE: scripts/__init__.py
================================================
from __future__ import absolute_import
from . import datasets
from . import models
from . import utils
# import os, sys
# sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
# from progress.bar import Bar as Bar
# __version__ = '0.1.0'
================================================
FILE: scripts/datasets/BIH.py
================================================
from __future__ import print_function, absolute_import
import os
import csv
import numpy as np
import json
import random
import math
import matplotlib.pyplot as plt
from collections import namedtuple
from os import listdir
from os.path import isfile, join
import torch
import torch.utils.data as data
from scripts.utils.osutils import *
from scripts.utils.imutils import *
from scripts.utils.transforms import *
import torchvision.transforms as transforms
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageFilter
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
class BIH(data.Dataset):
    def __init__(self, train, config=None, sample=[], gan_norm=False):
        self.train = []
        self.anno = []
        self.mask = []
        self.wm = []
        self.input_size = config.input_size
        self.normalized_input = config.normalized_input
        self.base_folder = config.base_dir + '/' + config.data
        self.dataset = config.data
        # config is read above, so it cannot be None at this point; keep a safe default anyway
        self.data_augumentation = config.data_augumentation if config is not None else False
        self.istrain = False if train.find('train') == -1 else True
        self.sample = sample
        self.gan_norm = gan_norm

        mypath = join(self.base_folder, self.dataset + '_' + train + '.txt')
        with open(mypath) as f:
            # here we get the filenames
            file_names = [im.strip() for im in f.readlines()]

        if config.limited_dataset > 0:
            # keep only the first file of each identifier
            xtrain = sorted(list(set([file_name.split('-')[0] for file_name in file_names])))
            tmp = []
            for x in xtrain:
                tmp.append([y for y in file_names if x in y][0])
            file_names = tmp

        for file_name in file_names:
            self.train.append(os.path.join(self.base_folder, 'images', file_name))
            self.mask.append(os.path.join(self.base_folder, 'masks', '_'.join(file_name.split('_')[0:2]) + '.png'))
            self.anno.append(os.path.join(self.base_folder, 'reals', file_name.split('_')[0] + '.jpg'))

        if len(self.sample) > 0:
            self.train = [self.train[i] for i in self.sample]
            self.mask = [self.mask[i] for i in self.sample]
            self.anno = [self.anno[i] for i in self.sample]

        self.trans = transforms.Compose([
            transforms.Resize((self.input_size, self.input_size)),
            transforms.ToTensor()])

        print('total Dataset of ' + self.dataset + ' is : ', len(self.train))

    def __getitem__(self, index):
        img = Image.open(self.train[index]).convert('RGB')
        mask = Image.open(self.mask[index]).convert('L')
        anno = Image.open(self.anno[index]).convert('RGB')
        # for shadow removal and blind image harmonization, there is no ground-truth wm
        # wm = Image.open(self.wm[index]).convert('RGB')
        return {"image": self.trans(img),
                "target": self.trans(anno),
                "mask": self.trans(mask),
                "name": self.train[index].split('/')[-1],
                "imgurl": self.train[index],
                "maskurl": self.mask[index],
                "targeturl": self.anno[index],
                }

    def __len__(self):
        return len(self.train)
================================================
FILE: scripts/datasets/COCO.py
================================================
from __future__ import print_function, absolute_import
import os
import csv
import numpy as np
import json
import random
import math
import matplotlib.pyplot as plt
from collections import namedtuple
from os import listdir
from os.path import isfile, join
import torch
import torch.utils.data as data
from scripts.utils.osutils import *
from scripts.utils.imutils import *
from scripts.utils.transforms import *
import torchvision.transforms as transforms
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageFilter
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
class COCO(data.Dataset):
    def __init__(self, train, config=None, sample=[], gan_norm=False):
        self.train = []
        self.anno = []
        self.mask = []
        self.wm = []
        self.input_size = config.input_size
        self.normalized_input = config.normalized_input
        self.base_folder = config.base_dir
        self.dataset = train + config.data
        # config is read above, so it cannot be None at this point; keep a safe default anyway
        self.data_augumentation = config.data_augumentation if config is not None else False
        self.istrain = False if self.dataset.find('train') == -1 else True
        self.sample = sample
        self.gan_norm = gan_norm

        mypath = join(self.base_folder, self.dataset)
        file_names = sorted([f for f in listdir(join(mypath, 'image')) if isfile(join(mypath, 'image', f))])

        if config.limited_dataset > 0:
            # get the first file_name of each identifier
            xtrain = sorted(list(set([file_name.split('-')[0] for file_name in file_names])))
            tmp = []
            for x in xtrain:
                tmp.append([y for y in file_names if x in y][0])
            file_names = tmp

        for file_name in file_names:
            self.train.append(os.path.join(mypath, 'image', file_name))
            self.mask.append(os.path.join(mypath, 'mask', file_name))
            self.wm.append(os.path.join(mypath, 'wm', file_name))
            self.anno.append(os.path.join(self.base_folder, 'natural', file_name.split('-')[0] + '.jpg'))

        if len(self.sample) > 0:
            self.train = [self.train[i] for i in self.sample]
            self.mask = [self.mask[i] for i in self.sample]
            self.anno = [self.anno[i] for i in self.sample]

        self.trans = transforms.Compose([
            transforms.Resize((self.input_size, self.input_size)),
            transforms.ToTensor()])

        print('total Dataset of ' + self.dataset + ' is : ', len(self.train))

    def __getitem__(self, index):
        img = Image.open(self.train[index]).convert('RGB')
        mask = Image.open(self.mask[index]).convert('L')
        anno = Image.open(self.anno[index]).convert('RGB')
        wm = Image.open(self.wm[index]).convert('RGB')
        return {"image": self.trans(img),
                "target": self.trans(anno),
                "mask": self.trans(mask),
                "wm": self.trans(wm),
                "name": self.train[index].split('/')[-1],
                "imgurl": self.train[index],
                "maskurl": self.mask[index],
                "targeturl": self.anno[index],
                "wmurl": self.wm[index]
                }

    def __len__(self):
        return len(self.train)
================================================
FILE: scripts/datasets/__init__.py
================================================
from .COCO import COCO
from .BIH import BIH
__all__ = ('COCO','BIH')
================================================
FILE: scripts/machines/BasicMachine.py
================================================
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from progress.bar import Bar
import json
import numpy as np
from tensorboardX import SummaryWriter
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.osutils import mkdir_p, isfile, isdir, join
from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
import pytorch_ssim as pytorch_ssim
import torch.optim
import sys,shutil,os
import time
import scripts.models as archs
from math import log10
from torch.autograd import Variable
from scripts.utils.losses import VGGLoss
from scripts.utils.imutils import im_to_numpy
import skimage.io
from skimage.measure import compare_psnr,compare_ssim
class BasicMachine(object):
    def __init__(self, datasets=(None, None), models=None, args=None, **kwargs):
        super(BasicMachine, self).__init__()
        self.args = args

        # create model
        print("==> creating model ")
        self.model = archs.__dict__[self.args.arch]()
        print("==> creating model [Finish]")

        self.train_loader, self.val_loader = datasets
        self.loss = torch.nn.MSELoss()
        self.title = '_' + args.machine + '_' + args.data + '_' + args.arch
        self.args.checkpoint = args.checkpoint + self.title
        self.device = torch.device('cuda')

        # create checkpoint dir
        if not isdir(self.args.checkpoint):
            mkdir_p(self.args.checkpoint)

        self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()),
                                          lr=args.lr,
                                          betas=(args.beta1, args.beta2),
                                          weight_decay=args.weight_decay)

        if not self.args.evaluate:
            self.writer = SummaryWriter(self.args.checkpoint + '/' + 'ckpt')

        self.best_acc = 0
        self.is_best = False
        self.current_epoch = 0
        self.metric = -100000
        self.hl = 6 if self.args.hl else 1
        self.count_gpu = torch.cuda.device_count()

        if self.args.style_loss > 0:
            # init perception loss
            self.vggloss = VGGLoss(self.args.sltype).to(self.device)

        if self.count_gpu > 1:  # multiple GPUs
            # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count()))
            # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))

        self.model.to(self.device)
        self.loss.to(self.device)

        print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters()) / 1000000.0))
        print('==> Total devices: %d' % (torch.cuda.device_count()))
        print('==> Current Checkpoint: %s' % (self.args.checkpoint))

        if self.args.resume != '':
            self.resume(self.args.resume)
    def train(self, epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        lossvgg = AverageMeter()

        # switch to train mode
        self.model.train()
        end = time.time()

        bar = Bar('Processing', max=len(self.train_loader) * self.hl)
        for _ in range(self.hl):
            for i, batches in enumerate(self.train_loader):
                inputs = batches['image']
                target = batches['target'].to(self.device)
                mask = batches['mask'].to(self.device)
                current_index = len(self.train_loader) * epoch + i

                if self.args.hl:
                    feeded = torch.cat([inputs, mask], dim=1)
                else:
                    feeded = inputs
                feeded = feeded.to(self.device)

                output = self.model(feeded)
                L2_loss = self.loss(output, target)

                if self.args.style_loss > 0:
                    vgg_loss = self.vggloss(output, target, mask)
                else:
                    vgg_loss = 0

                total_loss = L2_loss + self.args.style_loss * vgg_loss

                # compute gradient and do the optimizer step
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()

                # record losses
                losses.update(L2_loss.item(), inputs.size(0))
                if self.args.style_loss > 0:
                    lossvgg.update(vgg_loss.item(), inputs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format(
                    batch=i + 1,
                    size=len(self.train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss_label=losses.avg,
                    loss_vgg=lossvgg.avg)

                if current_index % 1000 == 0:
                    print(suffix)

                if self.args.freq > 0 and current_index % self.args.freq == 0:
                    self.validate(current_index)
                    self.flush()
                    self.save_checkpoint()

        self.record('train/loss_L2', losses.avg, current_index)
    def test(self):
        # switch to evaluate mode
        self.model.eval()

        ssimes = AverageMeter()
        psnres = AverageMeter()

        with torch.no_grad():
            for i, batches in enumerate(self.val_loader):
                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)
                mask = batches['mask'].to(self.device)

                outputs = self.model(inputs)
                # select the output by the given arch
                if type(outputs) == type(inputs):
                    output = outputs
                elif type(outputs[0]) == type([]):
                    output = outputs[0][0]
                else:
                    output = outputs[0]

                # recover the image to [0,255]
                output = im_to_numpy(torch.clamp(output[0] * 255, min=0.0, max=255.0)).astype(np.uint8)
                target = im_to_numpy(torch.clamp(target[0] * 255, min=0.0, max=255.0)).astype(np.uint8)

                skimage.io.imsave('%s/%s' % (self.args.checkpoint, batches['name'][0]), output)

                psnr = compare_psnr(target, output)
                ssim = compare_ssim(target, output, multichannel=True)
                psnres.update(psnr, inputs.size(0))
                ssimes.update(ssim, inputs.size(0))

        print("%s:PSNR:%s,SSIM:%s" % (self.args.checkpoint, psnres.avg, ssimes.avg))
        print("DONE.\n")
    def validate(self, epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        ssimes = AverageMeter()
        psnres = AverageMeter()

        # switch to evaluate mode
        self.model.eval()
        end = time.time()

        with torch.no_grad():
            for i, batches in enumerate(self.val_loader):
                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)
                mask = batches['mask'].to(self.device)

                if self.args.hl:
                    feeded = torch.cat([inputs, torch.zeros((1, 4, self.args.input_size, self.args.input_size)).to(self.device)], dim=1)
                else:
                    feeded = inputs

                output = self.model(feeded)
                L2_loss = self.loss(output, target)

                psnr = 10 * log10(1 / L2_loss.item())
                ssim = pytorch_ssim.ssim(output, target)

                losses.update(L2_loss.item(), inputs.size(0))
                psnres.update(psnr, inputs.size(0))
                ssimes.update(ssim.item(), inputs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

        print("Epoches:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f" % (epoch + 1, losses.avg, psnres.avg, ssimes.avg))
        self.record('val/loss_L2', losses.avg, epoch)
        self.record('val/PSNR', psnres.avg, epoch)
        self.record('val/SSIM', ssimes.avg, epoch)
        self.metric = psnres.avg
    def resume(self, resume_path):
        if isfile(resume_path):
            print("=> loading checkpoint '{}'".format(resume_path))
            current_checkpoint = torch.load(resume_path)
            if isinstance(current_checkpoint['state_dict'], torch.nn.DataParallel):
                current_checkpoint['state_dict'] = current_checkpoint['state_dict'].module
            if isinstance(current_checkpoint['optimizer'], torch.nn.DataParallel):
                current_checkpoint['optimizer'] = current_checkpoint['optimizer'].module

            self.args.start_epoch = current_checkpoint['epoch']
            self.metric = current_checkpoint['best_acc']
            self.model.load_state_dict(current_checkpoint['state_dict'])
            # self.optimizer.load_state_dict(current_checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(resume_path, current_checkpoint['epoch']))
        else:
            raise Exception("=> no checkpoint found at '{}'".format(resume_path))
    def save_checkpoint(self, filename='checkpoint.pth.tar', snapshot=None):
        is_best = True if self.best_acc < self.metric else False
        if is_best:
            self.best_acc = self.metric

        state = {
            'epoch': self.current_epoch + 1,
            'arch': self.args.arch,
            'state_dict': self.model.state_dict(),
            'best_acc': self.best_acc,
            'optimizer': self.optimizer.state_dict() if self.optimizer else None,
        }

        filepath = os.path.join(self.args.checkpoint, filename)
        torch.save(state, filepath)

        if snapshot and state['epoch'] % snapshot == 0:
            # the original used `state.epoch`, which raises AttributeError on a dict
            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))

        if is_best:
            print('Saving Best Metric with PSNR:%s' % self.best_acc)
            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar'))
    def clean(self):
        self.writer.close()

    def record(self, k, v, epoch):
        self.writer.add_scalar(k, v, epoch)

    def flush(self):
        self.writer.flush()
        sys.stdout.flush()

    def norm(self, x):
        if self.args.gan_norm:
            return x * 2.0 - 1.0
        else:
            return x

    def denorm(self, x):
        if self.args.gan_norm:
            return (x + 1.0) / 2.0
        else:
            return x
================================================
FILE: scripts/machines/S2AM.py
================================================
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from progress.bar import Bar
import json
import numpy as np
from tensorboardX import SummaryWriter
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.osutils import mkdir_p, isfile, isdir, join
from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
import pytorch_ssim as pytorch_ssim
import torch.optim
import sys,shutil,os
import time
import scripts.models as archs
from math import log10
from torch.autograd import Variable
from scripts.utils.losses import VGGLoss
from scripts.utils.imutils import im_to_numpy
import skimage.io
from skimage.measure import compare_psnr,compare_ssim
class S2AM(object):
    def __init__(self, datasets=(None, None), models=None, args=None, **kwargs):
        super(S2AM, self).__init__()
        self.args = args

        # create model
        print("==> creating model ")
        self.model = archs.__dict__[self.args.arch]()
        print("==> creating model [Finish]")

        self.train_loader, self.val_loader = datasets
        self.loss = torch.nn.MSELoss()
        self.title = '_' + args.machine + '_' + args.data + '_' + args.arch
        self.args.checkpoint = args.checkpoint + self.title
        self.device = torch.device('cuda')

        # create checkpoint dir
        if not isdir(self.args.checkpoint):
            mkdir_p(self.args.checkpoint)

        self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()),
                                          lr=args.lr,
                                          betas=(args.beta1, args.beta2),
                                          weight_decay=args.weight_decay)

        if not self.args.evaluate:
            self.writer = SummaryWriter(self.args.checkpoint + '/' + 'ckpt')

        self.best_acc = 0
        self.is_best = False
        self.current_epoch = 0
        self.hl = 1
        self.metric = -100000
        self.count_gpu = torch.cuda.device_count()

        if self.args.style_loss > 0:
            # init perception loss
            self.vggloss = VGGLoss(self.args.sltype).to(self.device)

        if self.count_gpu > 1:  # multiple GPUs
            # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count()))
            # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))

        self.model.to(self.device)
        self.loss.to(self.device)

        print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters()) / 1000000.0))
        print('==> Total devices: %d' % (torch.cuda.device_count()))
        print('==> Current Checkpoint: %s' % (self.args.checkpoint))

        if self.args.resume != '':
            self.resume(self.args.resume)
    def train(self, epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        lossvgg = AverageMeter()

        # switch to train mode
        self.model.train()
        end = time.time()

        bar = Bar('Processing', max=len(self.train_loader) * self.hl)
        for _ in range(self.hl):
            for i, batches in enumerate(self.train_loader):
                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)
                mask = batches['mask'].to(self.device)
                current_index = len(self.train_loader) * epoch + i

                feeded = torch.cat([inputs, mask], dim=1)
                feeded = feeded.to(self.device)

                output = self.model(feeded)
                if self.args.res:
                    output = output + inputs

                L2_loss = self.loss(output, target)
                if self.args.style_loss > 0:
                    vgg_loss = self.vggloss(output, target, mask)
                else:
                    vgg_loss = 0

                total_loss = L2_loss + self.args.style_loss * vgg_loss

                # compute gradient and do the optimizer step
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()

                # record losses
                losses.update(L2_loss.item(), inputs.size(0))
                if self.args.style_loss > 0:
                    lossvgg.update(vgg_loss.item(), inputs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format(
                    batch=i + 1,
                    size=len(self.train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss_label=losses.avg,
                    loss_vgg=lossvgg.avg)

                if current_index % 1000 == 0:
                    print(suffix)

                if self.args.freq > 0 and current_index % self.args.freq == 0:
                    self.validate(current_index)
                    self.flush()
                    self.save_checkpoint()

        self.record('train/loss_L2', losses.avg, current_index)
    def test(self):
# switch to evaluate mode
self.model.eval()
ssimes = AverageMeter()
psnres = AverageMeter()
with torch.no_grad():
for i, batches in enumerate(self.val_loader):
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
                mask = batches['mask'].to(self.device)
feeded = torch.cat([inputs,mask],dim=1)
feeded = feeded.to(self.device)
output = self.model(feeded)
if self.args.res:
output = output + inputs
# recover the image to 255
output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8)
target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)
skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output)
psnr = compare_psnr(target,output)
ssim = compare_ssim(target,output,multichannel=True)
psnres.update(psnr, inputs.size(0))
ssimes.update(ssim, inputs.size(0))
print("%s:PSNR:%s,SSIM:%s"%(self.args.checkpoint,psnres.avg,ssimes.avg))
print("DONE.\n")
def validate(self, epoch):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
ssimes = AverageMeter()
psnres = AverageMeter()
# switch to evaluate mode
self.model.eval()
end = time.time()
with torch.no_grad():
for i, batches in enumerate(self.val_loader):
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
                mask = batches['mask'].to(self.device)
feeded = torch.cat([inputs,mask],dim=1)
feeded = feeded.to(self.device)
output = self.model(feeded)
if self.args.res:
output = output + inputs
L2_loss = self.loss(output, target)
psnr = 10 * log10(1 / L2_loss.item())
ssim = pytorch_ssim.ssim(output, target)
losses.update(L2_loss.item(), inputs.size(0))
psnres.update(psnr, inputs.size(0))
ssimes.update(ssim.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
        print("Epoch:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f"%(epoch+1, losses.avg,psnres.avg,ssimes.avg))
self.record('val/loss_L2', losses.avg, epoch)
self.record('val/PSNR', psnres.avg, epoch)
self.record('val/SSIM', ssimes.avg, epoch)
self.metric = psnres.avg
def resume(self,resume_path):
if isfile(resume_path):
print("=> loading checkpoint '{}'".format(resume_path))
current_checkpoint = torch.load(resume_path)
            # checkpoints saved from nn.DataParallel prefix every key with
            # 'module.'; strip it so the weights also load into a single-GPU model
            state_dict = current_checkpoint['state_dict']
            if any(k.startswith('module.') for k in state_dict):
                current_checkpoint['state_dict'] = {k[len('module.'):]: v for k, v in state_dict.items()}
self.args.start_epoch = current_checkpoint['epoch']
self.metric = current_checkpoint['best_acc']
self.model.load_state_dict(current_checkpoint['state_dict'])
# self.optimizer.load_state_dict(current_checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(resume_path, current_checkpoint['epoch']))
else:
raise Exception("=> no checkpoint found at '{}'".format(resume_path))
def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):
        is_best = self.metric > self.best_acc
if is_best:
self.best_acc = self.metric
state = {
'epoch': self.current_epoch + 1,
'arch': self.args.arch,
'state_dict': self.model.state_dict(),
'best_acc': self.best_acc,
'optimizer' : self.optimizer.state_dict() if self.optimizer else None,
}
filepath = os.path.join(self.args.checkpoint, filename)
torch.save(state, filepath)
if snapshot and state['epoch'] % snapshot == 0:
            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
if is_best:
print('Saving Best Metric with PSNR:%s'%self.best_acc)
shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar'))
def clean(self):
self.writer.close()
def record(self,k,v,epoch):
self.writer.add_scalar(k, v, epoch)
def flush(self):
self.writer.flush()
sys.stdout.flush()
def norm(self,x):
if self.args.gan_norm:
return x*2.0 - 1.0
else:
return x
def denorm(self,x):
if self.args.gan_norm:
return (x+1.0)/2.0
else:
return x
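The `norm`/`denorm` pair above maps images between `[0, 1]` and `[-1, 1]` when `gan_norm` is set, and is the identity otherwise. A minimal sketch of the round trip (tensor names are illustrative):

```python
import torch

x = torch.rand(4, 3, 8, 8)      # images in [0, 1]
y = x * 2.0 - 1.0               # norm(): map to [-1, 1]
x_back = (y + 1.0) / 2.0        # denorm(): inverse mapping
assert torch.allclose(x, x_back, atol=1e-6)
```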
================================================
FILE: scripts/machines/VX.py
================================================
import torch
import torch.nn as nn
from progress.bar import Bar
from tqdm import tqdm
import pytorch_ssim
import json
import sys,time,os
import torchvision
from math import log10
import numpy as np
from .BasicMachine import BasicMachine
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.misc import resize_to_match
from torch.autograd import Variable
import torch.nn.functional as F
from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
from scripts.utils.losses import VGGLoss, l1_relative,is_dic
from scripts.utils.imutils import im_to_numpy
import skimage.io
try:
    from skimage.measure import compare_psnr, compare_ssim
except ImportError:  # moved in scikit-image >= 0.16
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
    from skimage.metrics import structural_similarity as compare_ssim
class Losses(nn.Module):
def __init__(self, argx, device, norm_func=None, denorm_func=None):
super(Losses, self).__init__()
self.args = argx
if self.args.loss_type == 'l1bl2':
self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss()
elif self.args.loss_type == 'l2xbl2':
self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss()
elif self.args.loss_type == 'relative' or self.args.loss_type == 'hybrid':
self.outputLoss, self.attLoss, self.wrloss = l1_relative, nn.BCELoss(), l1_relative
else: # l2bl2
self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss()
self.default = nn.L1Loss()
if self.args.style_loss > 0:
self.vggloss = VGGLoss(self.args.sltype).to(device)
if self.args.ssim_loss > 0:
self.ssimloss = pytorch_ssim.SSIM().to(device)
self.norm = norm_func
self.denorm = denorm_func
def forward(self,pred_ims,target,pred_ms,mask,pred_wms,wm):
pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = [0]*5
pred_ims = pred_ims if is_dic(pred_ims) else [pred_ims]
# try the loss in the masked region
if self.args.masked and 'hybrid' in self.args.loss_type: # masked loss
pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims])
pixel_loss += sum([self.default(pred_im*pred_ms,target*mask) for pred_im in pred_ims])
recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ]
wm_loss += self.wrloss(pred_wms, wm, mask)
wm_loss += self.default(pred_wms*pred_ms, wm*mask)
elif self.args.masked and 'relative' in self.args.loss_type: # masked loss
pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims])
recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ]
wm_loss = self.wrloss(pred_wms, wm, mask)
elif self.args.masked:
pixel_loss += sum([self.outputLoss(pred_im*mask, target*mask) for pred_im in pred_ims])
recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ]
wm_loss = self.wrloss(pred_wms*mask, wm*mask)
else:
pixel_loss += sum([self.outputLoss(pred_im*pred_ms, target*mask) for pred_im in pred_ims])
recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ]
wm_loss = self.wrloss(pred_wms*pred_ms,wm*mask)
pixel_loss += sum([self.default(im,target) for im in recov_imgs])
if self.args.style_loss > 0:
vgg_loss = sum([self.vggloss(im,target,mask) for im in recov_imgs])
if self.args.ssim_loss > 0:
ssim_loss = sum([ 1 - self.ssimloss(im,target) for im in recov_imgs])
att_loss = self.attLoss(pred_ms, mask)
return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss
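`l1_relative` (imported from `scripts.utils.losses`) is called above with the signature `outputLoss(pred, target, mask)`. A plausible self-contained sketch of a mask-normalized L1 of that shape, purely for illustration (the exact repo implementation may differ):

```python
import torch

def l1_relative_sketch(pred, target, mask, eps=1e-6):
    # L1 error restricted to the masked region, normalized by the masked area
    diff = (pred - target).abs() * mask
    return diff.sum() / (mask.sum() * pred.size(1) + eps)

pred = torch.rand(2, 3, 8, 8)
mask = (torch.rand(2, 1, 8, 8) > 0.5).float()
zero = l1_relative_sketch(pred, pred, mask)   # identical inputs give zero loss
```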
class VX(BasicMachine):
def __init__(self,**kwargs):
BasicMachine.__init__(self,**kwargs)
self.loss = Losses(self.args, self.device, self.norm, self.denorm)
self.model.set_optimizers()
self.optimizer = None
def train(self,epoch):
self.current_epoch = epoch
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
lossMask = AverageMeter()
lossWM = AverageMeter()
lossMX = AverageMeter()
lossvgg = AverageMeter()
lossssim = AverageMeter()
# switch to train mode
self.model.train()
end = time.time()
bar = Bar('Processing {} '.format(self.args.arch), max=len(self.train_loader))
for i, batches in enumerate(self.train_loader):
current_index = len(self.train_loader) * epoch + i
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
mask = batches['mask'].to(self.device)
wm = batches['wm'].to(self.device)
outputs = self.model(self.norm(inputs))
self.model.zero_grad_all()
l2_loss,att_loss,wm_loss,style_loss,ssim_loss = self.loss(outputs[0],self.norm(target),outputs[1],mask,outputs[2],self.norm(wm))
total_loss = 2*l2_loss + self.args.att_loss * att_loss + wm_loss + self.args.style_loss * style_loss + self.args.ssim_loss * ssim_loss
# compute gradient and do SGD step
total_loss.backward()
self.model.step_all()
# measure accuracy and record loss
losses.update(l2_loss.item(), inputs.size(0))
lossMask.update(att_loss.item(), inputs.size(0))
lossWM.update(wm_loss.item(), inputs.size(0))
if self.args.style_loss > 0 :
lossvgg.update(style_loss.item(), inputs.size(0))
if self.args.ssim_loss > 0 :
lossssim.update(ssim_loss.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
# plot progress
suffix = "({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss Mask: {loss_mask:.4f} | loss WM: {loss_wm:.4f} | loss VGG: {loss_vgg:.4f} | loss SSIM: {loss_ssim:.4f}| loss MX: {loss_mx:.4f}".format(
batch=i + 1,
size=len(self.train_loader),
data=data_time.val,
bt=batch_time.val,
total=bar.elapsed_td,
eta=bar.eta_td,
loss_label=losses.avg,
loss_mask=lossMask.avg,
loss_wm=lossWM.avg,
loss_vgg=lossvgg.avg,
loss_ssim=lossssim.avg,
loss_mx=lossMX.avg
)
if current_index % 1000 == 0:
print(suffix)
if self.args.freq > 0 and current_index % self.args.freq == 0:
self.validate(current_index)
self.flush()
self.save_checkpoint()
self.record('train/loss_L2', losses.avg, epoch)
self.record('train/loss_Mask', lossMask.avg, epoch)
self.record('train/loss_WM', lossWM.avg, epoch)
self.record('train/loss_VGG', lossvgg.avg, epoch)
self.record('train/loss_SSIM', lossssim.avg, epoch)
self.record('train/loss_MX', lossMX.avg, epoch)
def validate(self, epoch):
self.current_epoch = epoch
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
lossMask = AverageMeter()
psnres = AverageMeter()
ssimes = AverageMeter()
# switch to evaluate mode
self.model.eval()
end = time.time()
bar = Bar('Processing {} '.format(self.args.arch), max=len(self.val_loader))
with torch.no_grad():
for i, batches in enumerate(self.val_loader):
current_index = len(self.val_loader) * epoch + i
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
outputs = self.model(self.norm(inputs))
imoutput,immask,imwatermark = outputs
imoutput = imoutput[0] if is_dic(imoutput) else imoutput
imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))
if i % 300 == 0:
# save the sample images
ims = torch.cat([inputs,target,imfinal,immask.repeat(1,3,1,1)],dim=3)
torchvision.utils.save_image(ims,os.path.join(self.args.checkpoint,'%s_%s.jpg'%(i,epoch)))
                # compute PSNR from the MSE between the composited output and the target
psnr = 10 * log10(1 / F.mse_loss(imfinal,target).item())
ssim = pytorch_ssim.ssim(imfinal,target)
psnres.update(psnr, inputs.size(0))
                ssimes.update(ssim.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
# plot progress
bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_L2: {loss_label:.4f} | Loss_Mask: {loss_mask:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}'.format(
batch=i + 1,
size=len(self.val_loader),
data=data_time.val,
bt=batch_time.val,
total=bar.elapsed_td,
eta=bar.eta_td,
loss_label=losses.avg,
loss_mask=lossMask.avg,
psnr=psnres.avg,
ssim=ssimes.avg
)
bar.next()
bar.finish()
print("Iter:%s,Losses:%s,PSNR:%.4f,SSIM:%.4f"%(epoch, losses.avg,psnres.avg,ssimes.avg))
self.record('val/loss_L2', losses.avg, epoch)
self.record('val/lossMask', lossMask.avg, epoch)
self.record('val/PSNR', psnres.avg, epoch)
self.record('val/SSIM', ssimes.avg, epoch)
self.metric = psnres.avg
self.model.train()
    def test(self):
# switch to evaluate mode
self.model.eval()
print("==> testing VM model ")
ssimes = AverageMeter()
psnres = AverageMeter()
ssimesx = AverageMeter()
psnresx = AverageMeter()
with torch.no_grad():
for i, batches in enumerate(tqdm(self.val_loader)):
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
                mask = batches['mask'].to(self.device)
                # select the outputs according to the given arch
outputs = self.model(self.norm(inputs))
imoutput,immask,imwatermark = outputs
imoutput = imoutput[0] if is_dic(imoutput) else imoutput
imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))
psnrx = 10 * log10(1 / F.mse_loss(imfinal,target).item())
ssimx = pytorch_ssim.ssim(imfinal,target)
# recover the image to 255
imfinal = im_to_numpy(torch.clamp(imfinal[0]*255,min=0.0,max=255.0)).astype(np.uint8)
target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)
skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), imfinal)
psnr = compare_psnr(target,imfinal)
ssim = compare_ssim(target,imfinal,multichannel=True)
psnres.update(psnr, inputs.size(0))
ssimes.update(ssim, inputs.size(0))
psnresx.update(psnrx, inputs.size(0))
                ssimesx.update(ssimx.item(), inputs.size(0))
print("%s:PSNR:%.5f(%.5f),SSIM:%.5f(%.5f)"%(self.args.checkpoint,psnres.avg,psnresx.avg,ssimes.avg,ssimesx.avg))
print("DONE.\n")
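`test` above derives PSNR from the batch MSE via `10 * log10(1 / mse)`, which assumes intensities scaled to `[0, 1]`. A tiny worked check with a uniform error of 0.1 (tensor values are illustrative):

```python
from math import log10
import torch
import torch.nn.functional as F

a = torch.zeros(1, 3, 16, 16)
b = torch.full((1, 3, 16, 16), 0.1)   # constant error of 0.1 -> MSE = 0.01
mse = F.mse_loss(b, a).item()
psnr = 10 * log10(1 / mse)            # 10 * log10(100) = 20 dB
```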
================================================
FILE: scripts/machines/__init__.py
================================================
from .BasicMachine import BasicMachine
from .VX import VX
from .S2AM import S2AM
def basic(**kwargs):
return BasicMachine(**kwargs)
def s2am(**kwargs):
return S2AM(**kwargs)
def vx(**kwargs):
return VX(**kwargs)
================================================
FILE: scripts/models/__init__.py
================================================
from .vgg import *
from .backbone_unet import *
from .discriminator import *
================================================
FILE: scripts/models/backbone_unet.py
================================================
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
import math
from scripts.utils.model_init import *
from scripts.models.rasc import *
from scripts.models.unet import UnetGenerator,MinimalUnetV2
from scripts.models.vmu import UnetVM
from scripts.models.sa_resunet import UnetVMS2AMv4
# our method
def vvv4n(**kwargs):
return UnetVMS2AMv4(shared_depth=2, blocks=3, long_skip=True, use_vm_decoder=True,s2am='vms2am')
# BVMR
def vm3(**kwargs):
return UnetVM(shared_depth=2, blocks=3, use_vm_decoder=True)
# Blind version of S2AM
def urasc(**kwargs):
model = UnetGenerator(3,3,is_attention_layer=True,attention_model=URASC,basicblock=MinimalUnetV2)
model.apply(weights_init_kaiming)
return model
# Improving the Harmony of the Composite Image by Spatial-Separated Attention Module
# Xiaodong Cun and Chi-Man Pun
# University of Macau
# Trans. on Image Processing, vol. 29, pp. 4759-4771, 2020.
def rascv2(**kwargs):
model = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2)
model.apply(weights_init_kaiming)
return model
# just original unet
def unet(**kwargs):
model = UnetGenerator(3,3)
model.apply(weights_init_kaiming)
return model
================================================
FILE: scripts/models/blocks.py
================================================
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
import math
import numbers
from scripts.utils.model_init import *
from scripts.models.vgg import Vgg16
from torch import nn, cuda
from torch.autograd import Variable
class BasicLearningBlock(nn.Module):
"""docstring for BasicLearningBlock"""
def __init__(self,channel):
super(BasicLearningBlock, self).__init__()
self.rconv1 = nn.Conv2d(channel,channel*2,3,padding=1,bias=False)
self.rbn1 = nn.BatchNorm2d(channel*2)
self.rconv2 = nn.Conv2d(channel*2,channel,3,padding=1,bias=False)
self.rbn2 = nn.BatchNorm2d(channel)
def forward(self,feature):
return F.elu(self.rbn2(self.rconv2(F.elu(self.rbn1(self.rconv1(feature))))))
# From https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/3
class GaussianSmoothing(nn.Module):
"""
Apply gaussian smoothing on a
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
in the input using a depthwise convolution.
Arguments:
channels (int, sequence): Number of channels of the input tensors. Output will
have this number of channels as well.
kernel_size (int, sequence): Size of the gaussian kernel.
sigma (float, sequence): Standard deviation of the gaussian kernel.
dim (int, optional): The number of dimensions of the data.
Default value is 2 (spatial).
"""
def __init__(self, channels, kernel_size, sigma, dim=2):
super(GaussianSmoothing, self).__init__()
if isinstance(kernel_size, numbers.Number):
kernel_size = [kernel_size] * dim
if isinstance(sigma, numbers.Number):
sigma = [sigma] * dim
# The gaussian kernel is the product of the
# gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid(
[
torch.arange(size, dtype=torch.float32)
for size in kernel_size
]
)
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
            # standard Gaussian: exp(-((x - mean)^2) / (2 * std^2)); the kernel is
            # normalized below, so std matches the requested sigma exactly
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                      torch.exp(-(((mgrid - mean) / std) ** 2) / 2)
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / torch.sum(kernel)
# Reshape to depthwise convolutional weight
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.groups = channels
if dim == 1:
self.conv = F.conv1d
elif dim == 2:
self.conv = F.conv2d
elif dim == 3:
self.conv = F.conv3d
else:
raise RuntimeError(
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
)
def forward(self, input):
"""
Apply gaussian filter to input.
Arguments:
input (torch.Tensor): Input to apply gaussian filter on.
Returns:
filtered (torch.Tensor): Filtered output.
"""
return self.conv(input, weight=self.weight, groups=self.groups)
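`GaussianSmoothing` performs a valid (unpadded) depthwise convolution, so callers reflect-pad by `kernel_size // 2` first (as `RASC` in `scripts/models/rasc.py` does with `F.pad(..., (2, 2, 2, 2))` for a 5-tap kernel). A self-contained sketch of the same idea; `gaussian_kernel2d` is a hypothetical helper, not part of the repo:

```python
import torch
import torch.nn.functional as F

def gaussian_kernel2d(channels, kernel_size, sigma):
    # separable 2-D Gaussian, normalized to sum to 1, shaped for a depthwise conv
    coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    g = torch.exp(-(coords / sigma) ** 2 / 2)
    k = g.unsqueeze(1) * g.unsqueeze(0)
    k = k / k.sum()
    return k.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1)

mask = torch.rand(2, 1, 32, 32)
weight = gaussian_kernel2d(1, 5, 1.0)
# reflect-pad by kernel_size // 2 so the valid conv preserves the spatial size
blurred = F.conv2d(F.pad(mask, (2, 2, 2, 2), mode='reflect'), weight, groups=1)
```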
class ChannelPool(nn.Module):
def __init__(self,types):
super(ChannelPool, self).__init__()
if types == 'avg':
self.poolingx = nn.AdaptiveAvgPool1d(1)
elif types == 'max':
self.poolingx = nn.AdaptiveMaxPool1d(1)
else:
            raise ValueError('ChannelPool: unknown pooling type %r' % types)
def forward(self, input):
n, c, w, h = input.size()
input = input.view(n,c,w*h).permute(0,2,1)
pooled = self.poolingx(input)# b,w*h,c -> b,w*h,1
_, _, c = pooled.size()
return pooled.view(n,c,w,h)
class SEBlock(nn.Module):
"""docstring for SEBlock"""
def __init__(self, channel,reducation=16):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel,channel//reducation),
nn.ReLU(inplace=True),
nn.Linear(channel//reducation,channel),
nn.Sigmoid())
def forward(self,x):
b,c,w,h = x.size()
y1 = self.avg_pool(x).view(b,c)
y = self.fc(y1).view(b,c,1,1)
return x*y
class GlobalAttentionModule(nn.Module):
"""docstring for GlobalAttentionModule"""
def __init__(self, channel,reducation=16):
super(GlobalAttentionModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel*2,channel//reducation),
nn.ReLU(inplace=True),
nn.Linear(channel//reducation,channel),
nn.Sigmoid())
def forward(self,x):
b,c,w,h = x.size()
y1 = self.avg_pool(x).view(b,c)
y2 = self.max_pool(x).view(b,c)
y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1)
return x*y
class SpatialAttentionModule(nn.Module):
"""docstring for SpatialAttentionModule"""
def __init__(self, channel,reducation=16):
super(SpatialAttentionModule, self).__init__()
self.avg_pool = ChannelPool('avg')
self.max_pool = ChannelPool('max')
self.fc = nn.Sequential(
nn.Conv2d(2,reducation,7,stride=1,padding=3),
nn.ReLU(inplace=True),
nn.Conv2d(reducation,1,7,stride=1,padding=3),
nn.Sigmoid())
def forward(self,x):
b,c,w,h = x.size()
y1 = self.avg_pool(x)
y2 = self.max_pool(x)
y = self.fc(torch.cat([y1,y2],1))
yr = 1-y
return y,yr
class GlobalAttentionModuleJustSigmoid(nn.Module):
    """GlobalAttentionModule variant that returns only the sigmoid gate."""
def __init__(self, channel,reducation=16):
super(GlobalAttentionModuleJustSigmoid, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel*2,channel//reducation),
nn.ReLU(inplace=True),
nn.Linear(channel//reducation,channel),
nn.Sigmoid())
def forward(self,x):
b,c,w,h = x.size()
y1 = self.avg_pool(x).view(b,c)
y2 = self.max_pool(x).view(b,c)
y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1)
return y
class BasicBlock(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
super(BasicBlock, self).__init__()
self.out_channels = out_planes
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
self.relu = nn.ReLU() if relu else None
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
class ChannelGate(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
super(ChannelGate, self).__init__()
self.gate_channels = gate_channels
self.mlp = nn.Sequential(
Flatten(),
nn.Linear(gate_channels, gate_channels // reduction_ratio),
nn.ReLU(),
nn.Linear(gate_channels // reduction_ratio, gate_channels)
)
self.pool_types = pool_types
def forward(self, x):
channel_att_sum = None
for pool_type in self.pool_types:
if pool_type=='avg':
avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( avg_pool )
elif pool_type=='max':
max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( max_pool )
elif pool_type=='lp':
lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( lp_pool )
elif pool_type=='lse':
# LSE pool only
lse_pool = logsumexp_2d(x)
channel_att_raw = self.mlp( lse_pool )
if channel_att_sum is None:
channel_att_sum = channel_att_raw
else:
channel_att_sum = channel_att_sum + channel_att_raw
        scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
return x * scale
def logsumexp_2d(tensor):
tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
return outputs
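`logsumexp_2d` is the standard max-shifted, numerically stable log-sum-exp over spatial positions; the same quantity is computed by `torch.logsumexp` over the flattened spatial dimension:

```python
import torch

x = torch.randn(2, 4, 8, 8)
flat = x.view(2, 4, -1)
s, _ = flat.max(dim=2, keepdim=True)
manual = s + (flat - s).exp().sum(dim=2, keepdim=True).log()  # as in logsumexp_2d
builtin = torch.logsumexp(flat, dim=2, keepdim=True)
assert torch.allclose(manual, builtin, atol=1e-5)
```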
class ChannelPoolX(nn.Module):
def forward(self, x):
return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
class SpatialGate(nn.Module):
def __init__(self):
super(SpatialGate, self).__init__()
kernel_size = 7
self.compress = ChannelPoolX()
self.spatial = BasicBlock(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
def forward(self, x):
x_compress = self.compress(x)
x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out) # broadcasting
return x * scale
class CBAM(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
super(CBAM, self).__init__()
self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
self.no_spatial=no_spatial
if not no_spatial:
self.SpatialGate = SpatialGate()
def forward(self, x):
x_out = self.ChannelGate(x)
if not self.no_spatial:
x_out = self.SpatialGate(x_out)
return x_out
================================================
FILE: scripts/models/discriminator.py
================================================
import numpy as np
import functools
import math
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torch.nn import Parameter
from scripts.utils.model_init import *
from torch.optim.optimizer import Optimizer, required
__all__ = ['patchgan','sngan','maskedsngan']
class SNCoXvWithActivation(torch.nn.Module):
    """
    Spectrally-normalized convolution with an optional activation.
    """
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
super(SNCoXvWithActivation, self).__init__()
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
self.activation = activation
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
def forward(self, input):
x = self.conv2d(input)
if self.activation is not None:
return self.activation(x)
else:
return x
def l2normalize(v, eps=1e-12):
return v / (v.norm() + eps)
class SpectralNorm(nn.Module):
def __init__(self, module, name='weight', power_iterations=1):
super(SpectralNorm, self).__init__()
self.module = module
self.name = name
self.power_iterations = power_iterations
if not self._made_params():
self._make_params()
def _update_u_v(self):
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
height = w.data.shape[0]
for _ in range(self.power_iterations):
v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
sigma = u.dot(w.view(height, -1).mv(v))
setattr(self.module, self.name, w / sigma.expand_as(w))
def _made_params(self):
try:
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
return True
except AttributeError:
return False
def _make_params(self):
w = getattr(self.module, self.name)
height = w.data.shape[0]
width = w.view(height, -1).data.shape[1]
u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
u.data = l2normalize(u.data)
v.data = l2normalize(v.data)
w_bar = Parameter(w.data)
del self.module._parameters[self.name]
self.module.register_parameter(self.name + "_u", u)
self.module.register_parameter(self.name + "_v", v)
self.module.register_parameter(self.name + "_bar", w_bar)
def forward(self, *args):
self._update_u_v()
return self.module.forward(*args)
def get_pad(in_, ksize, stride, atrous=1):
out_ = np.ceil(float(in_)/stride)
return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2)
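`get_pad` picks a padding so that a strided convolution produces roughly `ceil(in_ / stride)` outputs ("same"-style downsampling). A quick arithmetic check for the 256 -> 128 case used in `SNDiscriminator` (note the convs there use kernel size 4 with padding computed for `ksize=5`):

```python
import math

def get_pad(in_, ksize, stride, atrous=1):
    out_ = math.ceil(float(in_) / stride)
    return int(((out_ - 1) * stride + atrous * (ksize - 1) + 1 - in_) / 2)

p = get_pad(256, 5, 2)                 # padding of 1
out = (256 + 2 * p - 4) // 2 + 1       # conv output size for k=4, s=2
```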
class SNDiscriminator(nn.Module):
def __init__(self,channel=6):
super(SNDiscriminator, self).__init__()
cnum = 32
self.discriminator_net = nn.Sequential(
SNCoXvWithActivation(channel, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)),
SNCoXvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)),
SNCoXvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)),
SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)),
SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)), # 8*8*256
# SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)), # 4*4*256
# SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(4, 5, 2)), # 2*2*256
)
# self.linear = nn.Linear(2*2*256,1)
def forward(self, img_A, img_B):
# Concatenate image and condition image by channels to produce input
img_input = torch.cat((img_A, img_B), 1)
x = self.discriminator_net(img_input)
# x = x.view((x.size(0),-1))
# x = self.linear(x)
return x
class Discriminator(nn.Module):
def __init__(self, in_channels=3):
super(Discriminator, self).__init__()
def discriminator_block(in_filters, out_filters, normalization=True):
"""Returns downsampling layers of each discriminator block"""
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
if normalization:
layers.append(nn.InstanceNorm2d(out_filters))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*discriminator_block(in_channels*2, 64, normalization=False),
*discriminator_block(64, 128),
*discriminator_block(128, 256),
*discriminator_block(256, 512),
nn.ZeroPad2d((1, 0, 1, 0)),
nn.Conv2d(512, 1, 4, padding=1, bias=False)
)
def forward(self, img_A, img_B):
# Concatenate image and condition image by channels to produce input
img_input = torch.cat((img_A, img_B), 1)
return self.model(img_input)
def patchgan():
model = Discriminator()
model.apply(weights_init_kaiming)
return model
def sngan():
model = SNDiscriminator()
model.apply(weights_init_kaiming)
return model
def maskedsngan():
model = SNDiscriminator(channel=7)
model.apply(weights_init_kaiming)
return model
================================================
FILE: scripts/models/rasc.py
================================================
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from scripts.utils.model_init import *
from scripts.models.vgg import Vgg16
from scripts.models.blocks import *
class CAWapper(nn.Module):
    """Wrapper for contextual attention."""
def __init__(self, channel, type_of_connection=BasicLearningBlock):
super(CAWapper, self).__init__()
self.attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10, fuse=True, use_cuda=True)
def forward(self, feature, mask):
_, _, w, _ = feature.size()
_, _, mw, _ = mask.size()
        # binarize the mask and downsample it to the feature resolution
        # select features from the background as additional features for the masked spliced region
mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w))
result = self.attention(feature,mask)
return result
class NLWapper(nn.Module):
    """Wrapper for a non-local block."""
def __init__(self, channel, type_of_connection=BasicLearningBlock):
super(NLWapper, self).__init__()
self.attention = NONLocalBlock2D(channel)
def forward(self, feature, mask):
_, _, w, _ = feature.size()
_, _, mw, _ = mask.size()
        # binarize the mask to the feature resolution (unused: the non-local block ignores the mask)
        # mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w))
result = self.attention(feature)
return result
class SENet(nn.Module):
"""docstring for SENet"""
def __init__(self,channel,type_of_connection=BasicLearningBlock):
super(SENet, self).__init__()
self.attention = SEBlock(channel,16)
def forward(self,feature,mask):
_,_,w,_ = feature.size()
_,_,mw,_ = mask.size()
        # binarize the mask and downsample it to the feature resolution
        # select features from the background as additional features for the masked spliced region
mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w))
result = self.attention(feature)
return result
class CBAMConnect(nn.Module):
def __init__(self,channel):
super(CBAMConnect, self).__init__()
self.attention = CBAM(channel)
def forward(self,feature,mask):
results = self.attention(feature)
return results
class RASC(nn.Module):
def __init__(self,channel,type_of_connection=BasicLearningBlock):
super(RASC, self).__init__()
self.connection = type_of_connection(channel)
self.background_attention = GlobalAttentionModule(channel,16)
self.mixed_attention = GlobalAttentionModule(channel,16)
self.spliced_attention = GlobalAttentionModule(channel,16)
self.gaussianMask = GaussianSmoothing(1,5,1)
def forward(self,feature,mask):
_,_,w,_ = feature.size()
_,_,mw,_ = mask.size()
# binarize the mask
# select background features as additional cues for the masked spliced region.
if w != mw:
mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w))
reverse_mask = 1 - mask
# apply a Gaussian filter to mask and reverse_mask for better harmonization of edges.
mask = self.gaussianMask(F.pad(mask,(2,2,2,2),mode='reflect'))
reverse_mask = self.gaussianMask(F.pad(reverse_mask,(2,2,2,2),mode='reflect'))
background = self.background_attention(feature) * reverse_mask
selected_feature = self.mixed_attention(feature)
spliced_feature = self.spliced_attention(feature)
spliced = ( self.connection(spliced_feature) + selected_feature ) * mask
return background + spliced
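RASC resizes the mask with `F.avg_pool2d(mask, 2, stride=mw//w)` only when the resolutions differ. A standalone size check (plain Python, no torch; `pooled_size` is a hypothetical helper), assuming `mw` is a multiple of `w`:

```python
def pooled_size(mw, w):
    # output length of avg_pool2d with kernel_size=2, stride=mw // w, no padding
    stride = mw // w
    return (mw - 2) // stride + 1

# the pooled mask matches the feature resolution whenever mw > w
for mw, w in [(256, 32), (256, 64), (128, 16)]:
    assert pooled_size(mw, w) == w

# with mw == w the stride becomes 1 and the size shrinks by one,
# which is why RASC guards the pooling with `if w != mw`
assert pooled_size(64, 64) == 63
```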
class UNO(nn.Module):
def __init__(self,channel):
super(UNO, self).__init__()
def forward(self,feature,_m):
return feature
class URASC(nn.Module):
def __init__(self,channel,type_of_connection=BasicLearningBlock):
super(URASC, self).__init__()
self.connection = type_of_connection(channel)
self.background_attention = GlobalAttentionModule(channel,16)
self.mixed_attention = GlobalAttentionModule(channel,16)
self.spliced_attention = GlobalAttentionModule(channel,16)
self.mask_attention = SpatialAttentionModule(channel,16)
def forward(self,feature, m=None):
_,_,w,_ = feature.size()
mask, reverse_mask = self.mask_attention(feature)
background = self.background_attention(feature) * reverse_mask
selected_feature = self.mixed_attention(feature)
spliced_feature = self.spliced_attention(feature)
spliced = ( self.connection(spliced_feature) + selected_feature ) * mask
return background + spliced
class MaskedURASC(nn.Module):
def __init__(self,channel,type_of_connection=BasicLearningBlock):
super(MaskedURASC, self).__init__()
self.connection = type_of_connection(channel)
self.background_attention = GlobalAttentionModule(channel,16)
self.mixed_attention = GlobalAttentionModule(channel,16)
self.spliced_attention = GlobalAttentionModule(channel,16)
self.mask_attention = SpatialAttentionModule(channel,16)
def forward(self,feature):
_,_,w,_ = feature.size()
mask, reverse_mask = self.mask_attention(feature)
background = self.background_attention(feature) * reverse_mask
selected_feature = self.mixed_attention(feature)
spliced_feature = self.spliced_attention(feature)
spliced = ( self.connection(spliced_feature) + selected_feature ) * mask
return background + spliced, mask
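The URASC/MaskedURASC output is a mask-weighted blend of three attended streams. A scalar toy version of that blend (hypothetical helper, no torch; assumes the learned reverse mask behaves like `1 - mask`):

```python
def urasc_blend(background, mixed, connected, mask):
    # attended background outside the learned mask;
    # learned connection plus mixed attention inside it
    return background * (1 - mask) + (connected + mixed) * mask

assert urasc_blend(1.0, 0.5, 0.25, 0.0) == 1.0   # pure background region
assert urasc_blend(1.0, 0.5, 0.25, 1.0) == 0.75  # pure spliced region
```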
================================================
FILE: scripts/models/sa_resunet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from scripts.models.blocks import SEBlock
from scripts.models.rasc import *
from scripts.models.unet import UnetGenerator,MinimalUnetV2
def weight_init(m):
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def reset_params(model):
for i, m in enumerate(model.modules()):
weight_init(m)
def conv3x3(in_channels, out_channels, stride=1,
padding=1, bias=True, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=padding,
bias=bias,
groups=groups)
def up_conv2x2(in_channels, out_channels, transpose=True):
if transpose:
return nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=2,
stride=2)
else:
return nn.Sequential(
nn.Upsample(mode='bilinear', scale_factor=2),
conv1x1(in_channels, out_channels))
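`up_conv2x2` doubles the spatial size either with a stride-2 transposed convolution or with bilinear upsampling plus a 1x1 conv. A quick check of the transposed-conv output-size formula, out = (in - 1) * stride - 2 * pad + kernel (hypothetical helper, no torch):

```python
def conv_transpose_out(n, kernel, stride, padding=0):
    # output size of nn.ConvTranspose2d along one spatial dimension
    return (n - 1) * stride - 2 * padding + kernel

# kernel_size=2, stride=2 exactly doubles the resolution
for n in (8, 16, 128):
    assert conv_transpose_out(n, kernel=2, stride=2) == 2 * n
```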
def conv1x1(in_channels, out_channels, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
groups=groups,
stride=1)
class UpCoXvD(nn.Module):
def __init__(self, in_channels, out_channels, blocks, residual=True,norm=nn.BatchNorm2d, act=F.relu,batch_norm=True, transpose=True,concat=True,use_att=False):
super(UpCoXvD, self).__init__()
self.concat = concat
self.residual = residual
self.batch_norm = batch_norm
self.bn = None
self.conv2 = []
self.use_att = use_att
self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose)
self.norm0 = norm(out_channels)
if self.use_att:
self.s2am = RASC(2 * out_channels)
else:
self.s2am = None
if self.concat:
self.conv1 = conv3x3(2 * out_channels, out_channels)
self.norm1 = norm(out_channels)  # norm takes a single channel count; a second positional arg would be eps
else:
self.conv1 = conv3x3(out_channels, out_channels)
self.norm1 = norm(out_channels)
for _ in range(blocks):
self.conv2.append(conv3x3(out_channels, out_channels))
if self.batch_norm:
self.bn = []
for _ in range(blocks):
self.bn.append(norm(out_channels))
self.bn = nn.ModuleList(self.bn)
self.conv2 = nn.ModuleList(self.conv2)
self.act = act
def forward(self, from_up, from_down, mask=None,se=None):
from_up = self.act(self.norm0(self.up_conv(from_up)))
if self.concat:
x1 = torch.cat((from_up, from_down), 1)
else:
if from_down is not None:
x1 = from_up + from_down
else:
x1 = from_up
if self.use_att:
x1 = self.s2am(x1,mask)
x1 = self.act(self.norm1(self.conv1(x1)))
x2 = None
for idx, conv in enumerate(self.conv2):
x2 = conv(x1)
if self.batch_norm:
x2 = self.bn[idx](x2)
if (se is not None) and (idx == len(self.conv2) - 1): # last
x2 = se(x2)
if self.residual:
x2 = x2 + x1
x2 = self.act(x2)
x1 = x2
return x2
class DownCoXvD(nn.Module):
def __init__(self, in_channels, out_channels, blocks, pooling=True, norm=nn.BatchNorm2d,act=F.relu,residual=True, batch_norm=True):
super(DownCoXvD, self).__init__()
self.pooling = pooling
self.residual = residual
self.batch_norm = batch_norm
self.bn = None
self.pool = None
self.conv1 = conv3x3(in_channels, out_channels)
self.norm1 = norm(out_channels)
self.conv2 = []
for _ in range(blocks):
self.conv2.append(conv3x3(out_channels, out_channels))
if self.batch_norm:
self.bn = []
for _ in range(blocks):
self.bn.append(norm(out_channels))
self.bn = nn.ModuleList(self.bn)
if self.pooling:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.ModuleList(self.conv2)
self.act = act
def __call__(self, x):
return self.forward(x)
def forward(self, x):
x1 = self.act(self.norm1(self.conv1(x)))
x2 = None
for idx, conv in enumerate(self.conv2):
x2 = conv(x1)
if self.batch_norm:
x2 = self.bn[idx](x2)
if self.residual:
x2 = x2 + x1
x2 = self.act(x2)
x1 = x2
before_pool = x2
if self.pooling:
x2 = self.pool(x2)
return x2, before_pool
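Each `DownCoXvD` stacks `blocks` residual convolutions, feeding every output back in as the next input. A minimal stand-in with identity "convolutions" (plain Python, hypothetical helper; norm and activation omitted) shows the chaining:

```python
def residual_chain(x, blocks):
    # mimic the DownCoXvD loop: x2 = conv(x1) + x1; x1 = x2
    x1 = x
    for conv in blocks:
        x2 = conv(x1) + x1  # residual add
        x1 = x2
    return x1

# three identity blocks double the activation three times
assert residual_chain(1.0, [lambda v: v] * 3) == 8.0
```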
class UnetDecoderD(nn.Module):
def __init__(self, in_channels=512, out_channels=3, norm=nn.BatchNorm2d,act=F.relu, depth=5, blocks=1, residual=True, batch_norm=True,
transpose=True, concat=True, is_final=True, use_att=False):
super(UnetDecoderD, self).__init__()
self.conv_final = None
self.up_convs = []
self.atts = []
self.use_att = use_att
outs = in_channels
for i in range(depth-1): # depth = 1
ins = outs
outs = ins // 2
# 512,256
# 256,128
# 128,64
# 64,32
up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
concat=concat, norm=norm, act=act)
if self.use_att:
self.atts.append(SEBlock(outs))
self.up_convs.append(up_conv)
if is_final:
self.conv_final = conv1x1(outs, out_channels)
else:
up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
concat=concat,norm=norm, act=act)
if self.use_att:
self.atts.append(SEBlock(out_channels))
self.up_convs.append(up_conv)
self.up_convs = nn.ModuleList(self.up_convs)
self.atts = nn.ModuleList(self.atts)
reset_params(self)
def __call__(self, x, encoder_outs=None):
return self.forward(x, encoder_outs)
def forward(self, x, encoder_outs=None):
for i, up_conv in enumerate(self.up_convs):
before_pool = None
if encoder_outs is not None:
before_pool = encoder_outs[-(i+2)]
x = up_conv(x, before_pool)
if self.use_att:
x = self.atts[i](x)
if self.conv_final is not None:
x = self.conv_final(x)
return x
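The inline comments above trace the channel halving per up stage; a throwaway helper reproducing the loop's channel plan (hypothetical name, no torch):

```python
def decoder_channels(in_channels, depth):
    # replicate UnetDecoderD's loop: each up stage halves the width
    plan, outs = [], in_channels
    for _ in range(depth - 1):
        ins, outs = outs, outs // 2
        plan.append((ins, outs))
    return plan

assert decoder_channels(512, 5) == [(512, 256), (256, 128), (128, 64), (64, 32)]
```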
class UnetDecoderDatt(nn.Module):
def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True,
transpose=True, concat=True, is_final=True, norm=nn.BatchNorm2d,act=F.relu):
super(UnetDecoderDatt, self).__init__()
self.conv_final = None
self.up_convs = []
self.im_atts = []
self.vm_atts = []
self.mask_atts = []
outs = in_channels
for i in range(depth-1): # depth = 5 [0,1,2,3]
ins = outs
outs = ins // 2
# 512,256
# 256,128
# 128,64
# 64,32
up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
concat=concat, norm=norm, act=act)
self.up_convs.append(up_conv)
self.im_atts.append(SEBlock(outs))
self.vm_atts.append(SEBlock(outs))
self.mask_atts.append(SEBlock(outs))
if is_final:
self.conv_final = conv1x1(outs, out_channels)
else:
up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
concat=concat, norm=norm, act=act)
self.up_convs.append(up_conv)
self.im_atts.append(SEBlock(out_channels))
self.vm_atts.append(SEBlock(out_channels))
self.mask_atts.append(SEBlock(out_channels))
self.up_convs = nn.ModuleList(self.up_convs)
self.im_atts = nn.ModuleList(self.im_atts)
self.vm_atts = nn.ModuleList(self.vm_atts)
self.mask_atts = nn.ModuleList(self.mask_atts)
reset_params(self)
def forward(self, input, encoder_outs=None):
# im branch
x = input
for i, up_conv in enumerate(self.up_convs):
before_pool = None
if encoder_outs is not None:
before_pool = encoder_outs[-(i+2)]
x = up_conv(x, before_pool,se=self.im_atts[i])
x_im = x
x = input
for i, up_conv in enumerate(self.up_convs):
before_pool = None
if encoder_outs is not None:
before_pool = encoder_outs[-(i+2)]
x = up_conv(x, before_pool, se = self.mask_atts[i])
x_mask = x
x = input
for i, up_conv in enumerate(self.up_convs):
before_pool = None
if encoder_outs is not None:
before_pool = encoder_outs[-(i+2)]
x = up_conv(x, before_pool, se=self.vm_atts[i])
x_vm = x
return x_im,x_mask,x_vm
class UnetEncoderD(nn.Module):
def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True, norm=nn.BatchNorm2d, act=F.relu):
super(UnetEncoderD, self).__init__()
self.down_convs = []
outs = None
if type(blocks) is tuple:
blocks = blocks[0]
for i in range(depth):
ins = in_channels if i == 0 else outs
outs = start_filters*(2**i)
pooling = True if i < depth-1 else False
down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm, norm=norm, act=act)
self.down_convs.append(down_conv)
self.down_convs = nn.ModuleList(self.down_convs)
reset_params(self)
def __call__(self, x):
return self.forward(x)
def forward(self, x):
encoder_outs = []
for d_conv in self.down_convs:
x, before_pool = d_conv(x)
encoder_outs.append(before_pool)
return x, encoder_outs
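`UnetEncoderD` doubles the filter count at each depth and pools everywhere except the bottleneck. A sketch of the resulting stage plan (hypothetical helper, no torch):

```python
def encoder_plan(in_channels=3, depth=5, start_filters=32):
    # (in, out, pooling?) per stage, mirroring UnetEncoderD.__init__
    plan = []
    for i in range(depth):
        ins = in_channels if i == 0 else start_filters * (2 ** (i - 1))
        outs = start_filters * (2 ** i)
        plan.append((ins, outs, i < depth - 1))
    return plan

assert encoder_plan()[0] == (3, 32, True)
assert encoder_plan()[-1] == (256, 512, False)  # bottleneck keeps its resolution
```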
class ResDown(nn.Module):
def __init__(self, in_size, out_size, pooling=True, use_att=False):
super(ResDown, self).__init__()
self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling)
def forward(self, x):
return self.model(x)
class ResUp(nn.Module):
def __init__(self, in_size, out_size, use_att=False):
super(ResUp, self).__init__()
self.model = UpCoXvD(in_size, out_size, 3, use_att=use_att)
def forward(self, x, skip_input, mask=None):
return self.model(x,skip_input,mask)
class ResDownNew(nn.Module):
def __init__(self, in_size, out_size, pooling=True, use_att=False):
super(ResDownNew, self).__init__()
self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling, norm=nn.InstanceNorm2d, act=F.leaky_relu)
def forward(self, x):
return self.model(x)
class ResUpNew(nn.Module):
def __init__(self, in_size, out_size, use_att=False):
super(ResUpNew, self).__init__()
self.model = UpCoXvD(in_size, out_size, 3, use_att=use_att, norm=nn.InstanceNorm2d)
def forward(self, x, skip_input, mask=None):
return self.model(x,skip_input,mask)
class VMSingle(nn.Module):
def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32, res=True,use_att=False):
super(VMSingle, self).__init__()
self.down1 = down(in_channels, ngf)
self.down2 = down(ngf, ngf*2)
self.down3 = down(ngf*2, ngf*4)
self.down4 = down(ngf*4, ngf*8)
self.down5 = down(ngf*8, ngf*16, pooling=False)
self.up1 = up(ngf*16, ngf*8)
self.up2 = up(ngf*8, ngf*4, use_att=use_att)
self.up3 = up(ngf*4, ngf*2, use_att=use_att)
self.up4 = up(ngf*2, ngf*1, use_att=use_att)
self.im = nn.Conv2d(ngf, 3, 1)
self.res = res
def forward(self, input):
img, mask = input[:,0:3,:,:],input[:,3:4,:,:]
# U-Net generator with skip connections from encoder to decoder
x,d1 = self.down1(input) # 128,256
x,d2 = self.down2(x) # 64,128
x,d3 = self.down3(x) # 32,64
x,d4 = self.down4(x) # 16,32
x,_ = self.down5(x) # 8,16
x = self.up1(x, d4) # 16
x = self.up2(x, d3, mask) # 32
x = self.up3(x, d2, mask) # 64
x = self.up4(x, d1, mask) # 128
im = self.im(x)
return im
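VMSingle's decoder mirrors the encoder so that each `UpCoXvD` stage concatenates a skip connection of matching width. A quick width check with ngf=32 (plain Python):

```python
ngf = 32
enc_outs = [ngf, ngf * 2, ngf * 4, ngf * 8, ngf * 16]  # down1..down5 output widths
dec_outs = [ngf * 8, ngf * 4, ngf * 2, ngf]            # up1..up4 output widths

# each up stage concatenates the mirrored encoder skip, so widths must match
assert dec_outs == enc_outs[-2::-1]
```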
class VMSingleS2AM(nn.Module):
def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32):
super(VMSingleS2AM, self).__init__()
self.down1 = down(in_channels, ngf)
self.down2 = down(ngf, ngf*2)
self.down3 = down(ngf*2, ngf*4)
self.down4 = down(ngf*4, ngf*8)
self.down5 = down(ngf*8, ngf*16, pooling=False)
self.up1 = up(ngf*16, ngf*8)
self.up2 = up(ngf*8, ngf*4)
self.s2am2 = RASC(ngf*4)
self.up3 = up(ngf*4, ngf*2)
self.s2am3 = RASC(ngf*2)
self.up4 = up(ngf*2, ngf*1)
self.s2am4 = RASC(ngf)
self.im = nn.Conv2d(ngf, 3, 1)
def forward(self, input):
img, mask = input[:,0:3,:,:],input[:,3:4,:,:]
# U-Net generator with skip connections from encoder to decoder
x,d1 = self.down1(input) # 128,256
x,d2 = self.down2(x) # 64,128
x,d3 = self.down3(x) # 32,64
x,d4 = self.down4(x) # 16,32
x,_ = self.down5(x) # 8,16
x = self.up1(x, d4) # 16
x = self.up2(x, d3) # 32
x = self.s2am2(x, mask)
x = self.up3(x, d2) # 64
x = self.s2am3(x, mask)
x = self.up4(x, d1) # 128
x = self.s2am4(x, mask)
im = self.im(x)
return im
class UnetVMS2AMv4(nn.Module):
def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1,
out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True,
transpose=True, concat=True, transfer_data=True, long_skip=False, s2am='unet', use_coarser=True,no_stage2=False):
super(UnetVMS2AMv4, self).__init__()
self.transfer_data = transfer_data
self.shared = shared_depth
self.optimizer_encoder, self.optimizer_image, self.optimizer_vm = None, None, None
self.optimizer_mask, self.optimizer_shared = None, None
if type(blocks) is not tuple:
blocks = (blocks, blocks, blocks, blocks, blocks)
if not transfer_data:
concat = False
self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0],
start_filters=start_filters, residual=residual, batch_norm=batch_norm,norm=nn.InstanceNorm2d,act=F.leaky_relu)
self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
out_channels=out_channels_image, depth=depth - shared_depth,
blocks=blocks[1], residual=residual, batch_norm=batch_norm,
transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)
self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
out_channels=out_channels_mask, depth=depth - shared_depth,
blocks=blocks[2], residual=residual, batch_norm=batch_norm,
transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)
self.vm_decoder = None
if use_vm_decoder:
self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
out_channels=out_channels_image, depth=depth - shared_depth,
blocks=blocks[3], residual=residual, batch_norm=batch_norm,
transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)
self.shared_decoder = None
self.use_coarser = use_coarser
self.long_skip = long_skip
self.no_stage2 = no_stage2
self._forward = self.unshared_forward
if self.shared != 0:
self._forward = self.shared_forward
self.shared_decoder = UnetDecoderDatt(in_channels=start_filters * 2 ** (depth - 1),
out_channels=start_filters * 2 ** (depth - shared_depth - 1),
depth=shared_depth, blocks=blocks[4], residual=residual,
batch_norm=batch_norm, transpose=transpose, concat=concat,
is_final=False,norm=nn.InstanceNorm2d)
if s2am == 'unet':
self.s2am = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2)
elif s2am == 'vm':
self.s2am = VMSingle(4)
elif s2am == 'vms2am':
self.s2am = VMSingleS2AM(4,down=ResDownNew,up=ResUpNew)
def set_optimizers(self):
self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001)
self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001)
self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001)
self.optimizer_s2am = torch.optim.Adam(self.s2am.parameters(), lr=0.001)
if self.vm_decoder is not None:
self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001)
if self.shared != 0:
self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001)
def zero_grad_all(self):
self.optimizer_encoder.zero_grad()
self.optimizer_image.zero_grad()
self.optimizer_mask.zero_grad()
self.optimizer_s2am.zero_grad()
if self.vm_decoder is not None:
self.optimizer_vm.zero_grad()
if self.shared != 0:
self.optimizer_shared.zero_grad()
def step_all(self):
self.optimizer_encoder.step()
self.optimizer_image.step()
self.optimizer_mask.step()
self.optimizer_s2am.step()
if self.vm_decoder is not None:
self.optimizer_vm.step()
if self.shared != 0:
self.optimizer_shared.step()
def step_optimizer_image(self):
self.optimizer_image.step()
def __call__(self, synthesized):
return self._forward(synthesized)
def forward(self, synthesized):
return self._forward(synthesized)
def unshared_forward(self, synthesized):
image_code, before_pool = self.encoder(synthesized)
if not self.transfer_data:
before_pool = None
reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool))
reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))
if self.vm_decoder is not None:
reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool))
return reconstructed_image, reconstructed_mask, reconstructed_vm
return reconstructed_image, reconstructed_mask
def shared_forward(self, synthesized):
image_code, before_pool = self.encoder(synthesized)
if self.transfer_data:
shared_before_pool = before_pool[- self.shared - 1:]
unshared_before_pool = before_pool[: - self.shared]
else:
before_pool = None
shared_before_pool = None
unshared_before_pool = None
im,mask,vm = self.shared_decoder(image_code, shared_before_pool)
reconstructed_image = torch.tanh(self.image_decoder(im, unshared_before_pool))
if self.long_skip:
reconstructed_image = reconstructed_image + synthesized
reconstructed_mask = torch.sigmoid(self.mask_decoder(mask, unshared_before_pool))
if self.vm_decoder is not None:
reconstructed_vm = torch.tanh(self.vm_decoder(vm, unshared_before_pool))
if self.long_skip:
reconstructed_vm = reconstructed_vm + synthesized
coarser = reconstructed_image * reconstructed_mask + (1-reconstructed_mask)* synthesized
if self.use_coarser:
refine = torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) + coarser
elif self.no_stage2:
refine = torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1)))
else:
refine = torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) + synthesized
# final = refine * reconstructed_mask + (1-reconstructed_mask)* synthesized
if self.vm_decoder is not None:
return [refine, reconstructed_image], reconstructed_mask, reconstructed_vm
else:
return [refine, reconstructed_image], reconstructed_mask
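The coarse result in `shared_forward` composites the reconstructed image inside the predicted mask with the input elsewhere. A scalar version of that per-pixel blend (hypothetical helper):

```python
def composite(pred, mask, synthesized):
    # per-pixel blend: prediction inside the watermark mask, input outside
    return pred * mask + (1 - mask) * synthesized

assert composite(0.2, 1.0, 0.8) == 0.2  # fully masked pixel: take the prediction
assert composite(0.2, 0.0, 0.8) == 0.8  # unmasked pixel: keep the input
assert composite(0.2, 0.5, 0.8) == 0.5  # soft mask blends linearly
```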
================================================
FILE: scripts/models/unet.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools
from scripts.models.blocks import *
from scripts.models.rasc import *
class MinimalUnetV2(nn.Module):
"""docstring for MinimalUnet"""
def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwags):
super(MinimalUnetV2, self).__init__()
self.down = nn.Sequential(*down)
self.up = nn.Sequential(*up)
self.sub = submodule
self.attention = attention
self.withoutskip = withoutskip
self.is_attention = self.attention is not None
self.is_sub = submodule is not None
def forward(self,x,mask=None):
if self.is_sub:
x_up,_ = self.sub(self.down(x),mask)
else:
x_up = self.down(x)
if self.withoutskip: #outer or inner.
x_out = self.up(x_up)
else:
if self.is_attention:
x_out = (self.attention(torch.cat([x,self.up(x_up)],1),mask),mask)
else:
x_out = (torch.cat([x,self.up(x_up)],1),mask)
return x_out
class MinimalUnet(nn.Module):
"""docstring for MinimalUnet"""
def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwags):
super(MinimalUnet, self).__init__()
self.down = nn.Sequential(*down)
self.up = nn.Sequential(*up)
self.sub = submodule
self.attention = attention
self.withoutskip = withoutskip
self.is_attention = self.attention is not None
self.is_sub = submodule is not None
def forward(self,x,mask=None):
if self.is_sub:
x_up,_ = self.sub(self.down(x),mask)
else:
x_up = self.down(x)
if self.is_attention:
x = self.attention(x,mask)
if self.withoutskip: #outer or inner.
x_out = self.up(x_up)
else:
x_out = (torch.cat([x,self.up(x_up)],1),mask)
return x_out
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, input_nc=None,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,is_attention_layer=False,
attention_model=RASC,basicblock=MinimalUnet,outermostattention=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
if input_nc is None:
input_nc = outer_nc
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1)
down = [downconv]
up = [uprelu, upconv]
model = basicblock(down,up,submodule,withoutskip=outermost)
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = basicblock(down,up)
else:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if is_attention_layer:
if MinimalUnetV2.__qualname__ in basicblock.__qualname__ :
attention_model = attention_model(input_nc*2)
else:
attention_model = attention_model(input_nc)
else:
attention_model = None
if use_dropout:
up = up + [nn.Dropout(0.5)]  # list.append returns None, so the dropout must be added before the call
model = basicblock(down,up,submodule,attention_model,outermostattention=outermostattention)
self.model = model
def forward(self, x,mask=None):
# build the mask for attention use
return self.model(x,mask)
class UnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, num_downs=8, ngf=64,norm_layer=nn.BatchNorm2d, use_dropout=False,
is_attention_layer=False,attention_model=RASC,use_inner_attention=False,basicblock=MinimalUnet):
super(UnetGenerator, self).__init__()
# 8 for 256x256
# 9 for 512x512
# construct unet structure
self.need_mask = not input_nc == output_nc
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True,basicblock=basicblock) # 1
for i in range(num_downs - 5): #3 times
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout,is_attention_layer=use_inner_attention,attention_model=attention_model,basicblock=basicblock) # 8,4,2
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #16
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #32
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock, outermostattention=True) #64
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, basicblock=basicblock, norm_layer=norm_layer) # 128
self.model = unet_block
def forward(self, input):
if self.need_mask:
return self.model(input,input[:,3:4,:,:])
else:
return self.model(input[:,0:3,:,:],input[:,3:4,:,:])
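As the constructor comments note ("8 for 256x256", "9 for 512x512"), `num_downs` must reduce the input to a 1x1 bottleneck: each `UnetSkipConnectionBlock` halves the resolution once with a stride-2 convolution. A quick sanity check (hypothetical helper):

```python
def bottleneck_size(input_size, num_downs):
    # each nested block halves the spatial size once
    return input_size // (2 ** num_downs)

assert bottleneck_size(256, 8) == 1  # 8 downsamplings for 256x256 inputs
assert bottleneck_size(512, 9) == 1  # 9 downsamplings for 512x512 inputs
```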
================================================
FILE: scripts/models/vgg.py
================================================
from collections import namedtuple
import torch
from torchvision import models
class Vgg16(torch.nn.Module):
def __init__(self, requires_grad=False):
super(Vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23,30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
# vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3','relu5_3'])
# out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return (h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
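The five slices tile `vgg16.features` contiguously, each ending just after the ReLU whose activation is returned. A quick index check of the ranges used above (plain Python):

```python
# (start, stop) pairs used by Vgg16's slice1..slice5
slices = [(0, 4), (4, 9), (9, 16), (16, 23), (23, 30)]

# the slices cover layers 0..29 with no gap or overlap
for (_, stop), (start, _) in zip(slices, slices[1:]):
    assert stop == start
assert slices[0][0] == 0 and slices[-1][1] == 30
```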
class Vgg19(torch.nn.Module):
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
# vgg_pretrained_features = models.vgg19(pretrained=True).features
self.vgg_pretrained_features = models.vgg19(pretrained=True).features
# self.slice1 = torch.nn.Sequential()
# self.slice2 = torch.nn.Sequential()
# self.slice3 = torch.nn.Sequential()
# self.slice4 = torch.nn.Sequential()
# self.slice5 = torch.nn.Sequential()
# for x in range(2):
# self.slice1.add_module(str(x), vgg_pretrained_features[x])
# for x in range(2, 7):
# self.slice2.add_module(str(x), vgg_pretrained_features[x])
# for x in range(7, 12):
# self.slice3.add_module(str(x), vgg_pretrained_features[x])
# for x in range(12, 21):
# self.slice4.add_module(str(x), vgg_pretrained_features[x])
# for x in range(21, 30):
# self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X, indices=None):
if indices is None:
indices = [2, 7, 12, 21, 30]
out = []
#indices = sorted(indices)
for i in range(indices[-1]):
X = self.vgg_pretrained_features[i](X)
if (i+1) in indices:
out.append(X)
return out
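`Vgg19.forward` runs layers 0..indices[-1]-1 and records the activation right after layer i whenever i+1 is in `indices`. A torch-free mock of that collection loop (hypothetical helper):

```python
def collected_layers(indices):
    # mirror Vgg19.forward: record the index of the layer just applied
    out = []
    for i in range(indices[-1]):
        if (i + 1) in indices:
            out.append(i)
    return out

# default indices [2, 7, 12, 21, 30] record layers 1, 6, 11, 20, 29
assert collected_layers([2, 7, 12, 21, 30]) == [1, 6, 11, 20, 29]
```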
================================================
FILE: scripts/models/vmu.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from scripts.models.blocks import SEBlock
from scripts.models.rasc import *
from scripts.models.unet import UnetGenerator,MinimalUnetV2
def weight_init(m):
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def reset_params(model):
for i, m in enumerate(model.modules()):
weight_init(m)
def conv3x3(in_channels, out_channels, stride=1,
padding=1, bias=True, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=padding,
bias=bias,
groups=groups)
def up_conv2x2(in_channels, out_channels, transpose=True):
if transpose:
return nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=2,
stride=2)
else:
return nn.Sequential(
nn.Upsample(mode='bilinear', scale_factor=2),
conv1x1(in_channels, out_channels))
def conv1x1(in_channels, out_channels, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
groups=groups,
stride=1)
class UpCoXvD(nn.Module):
def __init__(self, in_channels, out_channels, blocks, residual=True, batch_norm=True, transpose=True,concat=True,use_att=False):
super(UpCoXvD, self).__init__()
self.concat = concat
self.residual = residual
self.batch_norm = batch_norm
self.bn = None
self.conv2 = []
self.use_att = use_att
self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose)
if self.use_att:
self.s2am = RASC(2 * out_channels)
else:
self.s2am = None
if self.concat:
self.conv1 = conv3x3(2 * out_channels, out_channels)
else:
self.conv1 = conv3x3(out_channels, out_channels)
for _ in range(blocks):
self.conv2.append(conv3x3(out_channels, out_channels))
if self.batch_norm:
self.bn = []
for _ in range(blocks):
self.bn.append(nn.BatchNorm2d(out_channels))
self.bn = nn.ModuleList(self.bn)
self.conv2 = nn.ModuleList(self.conv2)
def forward(self, from_up, from_down, mask=None):
from_up = self.up_conv(from_up)
if self.concat:
x1 = torch.cat((from_up, from_down), 1)
else:
if from_down is not None:
x1 = from_up + from_down
else:
x1 = from_up
if self.use_att:
x1 = self.s2am(x1,mask)
x1 = F.relu(self.conv1(x1))
x2 = None
for idx, conv in enumerate(self.conv2):
x2 = conv(x1)
if self.batch_norm:
x2 = self.bn[idx](x2)
if self.residual:
x2 = x2 + x1
x2 = F.relu(x2)
x1 = x2
return x2
class DownCoXvD(nn.Module):
def __init__(self, in_channels, out_channels, blocks, pooling=True, residual=True, batch_norm=True):
super(DownCoXvD, self).__init__()
self.pooling = pooling
self.residual = residual
self.batch_norm = batch_norm
self.bn = None
self.pool = None
self.conv1 = conv3x3(in_channels, out_channels)
self.conv2 = []
for _ in range(blocks):
self.conv2.append(conv3x3(out_channels, out_channels))
if self.batch_norm:
self.bn = []
for _ in range(blocks):
self.bn.append(nn.BatchNorm2d(out_channels))
self.bn = nn.ModuleList(self.bn)
if self.pooling:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.ModuleList(self.conv2)
def __call__(self, x):
return self.forward(x)
def forward(self, x):
x1 = F.relu(self.conv1(x))
x2 = None
for idx, conv in enumerate(self.conv2):
x2 = conv(x1)
if self.batch_norm:
x2 = self.bn[idx](x2)
if self.residual:
x2 = x2 + x1
x2 = F.relu(x2)
x1 = x2
before_pool = x2
if self.pooling:
x2 = self.pool(x2)
return x2, before_pool
class UnetDecoderD(nn.Module):
def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True,
transpose=True, concat=True, is_final=True):
super(UnetDecoderD, self).__init__()
self.conv_final = None
self.up_convs = []
outs = in_channels
for i in range(depth-1):
ins = outs
outs = ins // 2
# 512,256
# 256,128
# 128,64
# 64,32
up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
concat=concat)
self.up_convs.append(up_conv)
if is_final:
self.conv_final = conv1x1(outs, out_channels)
else:
up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
concat=concat)
self.up_convs.append(up_conv)
self.up_convs = nn.ModuleList(self.up_convs)
reset_params(self)
def __call__(self, x, encoder_outs=None):
return self.forward(x, encoder_outs)
def forward(self, x, encoder_outs=None):
for i, up_conv in enumerate(self.up_convs):
before_pool = None
if encoder_outs is not None:
before_pool = encoder_outs[-(i+2)]
x = up_conv(x, before_pool)
if self.conv_final is not None:
x = self.conv_final(x)
return x
class UnetEncoderD(nn.Module):
def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True):
super(UnetEncoderD, self).__init__()
self.down_convs = []
outs = None
if type(blocks) is tuple:
blocks = blocks[0]
for i in range(depth):
ins = in_channels if i == 0 else outs
outs = start_filters*(2**i)
pooling = True if i < depth-1 else False
down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm)
self.down_convs.append(down_conv)
self.down_convs = nn.ModuleList(self.down_convs)
reset_params(self)
def __call__(self, x):
return self.forward(x)
def forward(self, x):
encoder_outs = []
for d_conv in self.down_convs:
x, before_pool = d_conv(x)
encoder_outs.append(before_pool)
return x, encoder_outs
class UnetVM(nn.Module):
def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1,
out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True,
transpose=True, concat=True, transfer_data=True, long_skip=False):
super(UnetVM, self).__init__()
self.transfer_data = transfer_data
self.shared = shared_depth
self.optimizer_encoder, self.optimizer_image, self.optimizer_vm = None, None, None
self.optimizer_mask, self.optimizer_shared = None, None
if type(blocks) is not tuple:
blocks = (blocks, blocks, blocks, blocks, blocks)
if not transfer_data:
concat = False
self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0],
start_filters=start_filters, residual=residual, batch_norm=batch_norm)
self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
out_channels=out_channels_image, depth=depth - shared_depth,
blocks=blocks[1], residual=residual, batch_norm=batch_norm,
transpose=transpose, concat=concat)
self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1),
out_channels=out_channels_mask, depth=depth,
blocks=blocks[2], residual=residual, batch_norm=batch_norm,
transpose=transpose, concat=concat)
self.vm_decoder = None
if use_vm_decoder:
self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
out_channels=out_channels_image, depth=depth - shared_depth,
blocks=blocks[3], residual=residual, batch_norm=batch_norm,
transpose=transpose, concat=concat)
self.shared_decoder = None
self.long_skip = long_skip
self._forward = self.unshared_forward
if self.shared != 0:
self._forward = self.shared_forward
self.shared_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1),
out_channels=start_filters * 2 ** (depth - shared_depth - 1),
depth=shared_depth, blocks=blocks[4], residual=residual,
batch_norm=batch_norm, transpose=transpose, concat=concat,
is_final=False)
def set_optimizers(self):
self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001)
self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001)
self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001)
if self.vm_decoder is not None:
self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001)
if self.shared != 0:
self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001)
def zero_grad_all(self):
self.optimizer_encoder.zero_grad()
self.optimizer_image.zero_grad()
self.optimizer_mask.zero_grad()
if self.vm_decoder is not None:
self.optimizer_vm.zero_grad()
if self.shared != 0:
self.optimizer_shared.zero_grad()
def step_all(self):
self.optimizer_encoder.step()
self.optimizer_image.step()
self.optimizer_mask.step()
if self.vm_decoder is not None:
self.optimizer_vm.step()
if self.shared != 0:
self.optimizer_shared.step()
def step_optimizer_image(self):
self.optimizer_image.step()
def __call__(self, synthesized):
return self._forward(synthesized)
def forward(self, synthesized):
return self._forward(synthesized)
def unshared_forward(self, synthesized):
image_code, before_pool = self.encoder(synthesized)
if not self.transfer_data:
before_pool = None
reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool))
reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))
if self.vm_decoder is not None:
reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool))
return reconstructed_image, reconstructed_mask, reconstructed_vm
return reconstructed_image, reconstructed_mask
def shared_forward(self, synthesized):
image_code, before_pool = self.encoder(synthesized)
if self.transfer_data:
shared_before_pool = before_pool[- self.shared - 1:]
unshared_before_pool = before_pool[: - self.shared]
else:
before_pool = None
shared_before_pool = None
unshared_before_pool = None
x = self.shared_decoder(image_code, shared_before_pool)
reconstructed_image = torch.tanh(self.image_decoder(x, unshared_before_pool))
if self.long_skip:
reconstructed_image = reconstructed_image + synthesized
reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))
if self.vm_decoder is not None:
reconstructed_vm = torch.tanh(self.vm_decoder(x, unshared_before_pool))
if self.long_skip:
reconstructed_vm = reconstructed_vm + synthesized
return reconstructed_image, reconstructed_mask, reconstructed_vm
return reconstructed_image, reconstructed_mask
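The `before_pool` slicing in `shared_forward` divides the encoder skip connections between the shared decoder and the task-specific decoders. A pure-Python sketch with `depth=5`, `shared_depth=2` (illustrative values; note that the two slices intentionally overlap at the boundary level, since each decoder indexes skips with `encoder_outs[-(i+2)]`):

```python
depth, shared = 5, 2
before_pool = [f"skip{i}" for i in range(depth)]   # one skip per encoder level
shared_before_pool = before_pool[-shared - 1:]     # deepest levels -> shared decoder
unshared_before_pool = before_pool[:-shared]       # shallow levels -> image/vm decoders
print(shared_before_pool)    # ['skip2', 'skip3', 'skip4']
print(unshared_before_pool)  # ['skip0', 'skip1', 'skip2']
```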
================================================
FILE: scripts/utils/__init__.py
================================================
from __future__ import absolute_import
from .evaluation import *
from .imutils import *
from .logger import *
from .misc import *
from .osutils import *
from .transforms import *
================================================
FILE: scripts/utils/evaluation.py
================================================
from __future__ import absolute_import
import math
import numpy as np
import matplotlib.pyplot as plt
from random import randint
from .misc import *
from .transforms import transform, transform_preds
__all__ = ['accuracy', 'AverageMeter']
def get_preds(scores):
''' get predictions from score maps in torch Tensor
return type: torch.LongTensor
'''
assert scores.dim() == 4, 'Score maps should be 4-dim'
maxval, idx = torch.max(scores.view(scores.size(0), scores.size(1), -1), 2)
maxval = maxval.view(scores.size(0), scores.size(1), 1)
idx = idx.view(scores.size(0), scores.size(1), 1) + 1
preds = idx.repeat(1, 1, 2).float()
preds[:,:,0] = (preds[:,:,0] - 1) % scores.size(3) + 1
preds[:,:,1] = torch.floor((preds[:,:,1] - 1) / scores.size(2)) + 1
pred_mask = maxval.gt(0).repeat(1, 1, 2).float()
preds *= pred_mask
return preds
def calc_dists(preds, target, normalize):
preds = preds.float()
target = target.float()
dists = torch.zeros(preds.size(1), preds.size(0))
for n in range(preds.size(0)):
for c in range(preds.size(1)):
if target[n,c,0] > 1 and target[n, c, 1] > 1:
dists[c, n] = torch.dist(preds[n,c,:], target[n,c,:])/normalize[n]
else:
dists[c, n] = -1
return dists
def dist_acc(dists, thr=0.5):
''' Return percentage below threshold while ignoring values with a -1 '''
if dists.ne(-1).sum() > 0:
return dists.le(thr).eq(dists.ne(-1)).sum()*1.0 / dists.ne(-1).sum()
else:
return -1
def accuracy(output, target, thr=0.5):
''' Pixel-wise accuracy: the fraction of locations where the argmax over channels
matches the target labels. (Unlike PCK-style keypoint accuracy, this returns a
single scalar.)
'''
# output_mask = torch.gt(output,thr);
# target_mask = torch.gt(target,thr);
# equal_mask = torch.eq(output_mask,target_mask);
# fp_equal_mask = torch.lt(output_mask,target_mask);
# fn_equal_mask = torch.gt(output_mask,target_mask);
# tp = torch.sum(equal_mask);
# fn = torch.sum(fn_equal_mask);
# fp = torch.sum(fp_equal_mask);
# return 2*tp / (2*tp+fn+fp)
# both branches of the original if/else were identical; argmax over dim 1 gives the predicted class
v, i = torch.max(output, 1)
return torch.sum(target.long() == i).float()/target.numel()
def final_preds(output, center, scale, res):
coords = get_preds(output) # float type
# pose-processing
for n in range(coords.size(0)):
for p in range(coords.size(1)):
hm = output[n][p]
px = int(math.floor(coords[n][p][0]))
py = int(math.floor(coords[n][p][1]))
if px > 1 and px < res[0] and py > 1 and py < res[1]:
diff = torch.Tensor([hm[py - 1][px] - hm[py - 1][px - 2], hm[py][px - 1]-hm[py - 2][px - 1]])
coords[n][p] += diff.sign() * .25
coords += 0.5
preds = coords.clone()
# Transform back
for i in range(coords.size(0)):
preds[i] = transform_preds(coords[i], center[i], scale[i], res)
if preds.dim() < 3:
preds = preds.view(1, preds.size())
return preds
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
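Typical use of `AverageMeter` in a training loop, weighting each update by the batch size (a standalone copy of the class so the snippet runs on its own):

```python
class AverageMeter:
    """Computes and stores the current value and the running average."""
    def __init__(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

meter = AverageMeter()
for loss, batch_size in [(2.0, 4), (1.0, 4)]:   # two toy batches
    meter.update(loss, n=batch_size)
print(meter.avg)  # 1.5: (2.0*4 + 1.0*4) / 8
```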
================================================
FILE: scripts/utils/imutils.py
================================================
from __future__ import absolute_import
import torch
import torch.nn as nn
import numpy as np
import scipy.misc
import matplotlib.pyplot as plt
from matplotlib import cm
from .misc import *
def im_to_numpy(img):
img = to_numpy(img)
img = np.transpose(img, (1, 2, 0)) # H*W*C
return img
def im_to_torch(img):
img = np.transpose(img, (2, 0, 1)) # C*H*W
img = to_torch(img).float()
if img.max() > 1:
img /= 255
return img
def load_image(img_path):
# H x W x C => C x H x W
# note: scipy.misc.imread was removed in SciPy 1.2; this needs scipy < 1.2 (or a Pillow/imageio replacement)
return im_to_torch(scipy.misc.imread(img_path, mode='RGB'))
def imread_all(img_path):
return scipy.misc.imread(img_path, mode='RGB')
def load_image_gray(img_path):
# H x W x C => C x H x W
x = scipy.misc.imread(img_path, mode='L')
x = x[:,:,np.newaxis]
return im_to_torch(x)
def resize(img, owidth, oheight):
img = im_to_numpy(img)
if img.shape[2] == 1:
img = scipy.misc.imresize(img.squeeze(),(oheight,owidth))
img = img[:,:,np.newaxis]
else:
img = scipy.misc.imresize(
img,
(oheight, owidth)
)
img = im_to_torch(img)
# print('%f %f' % (img.min(), img.max()))
return img
# =============================================================================
# Helpful functions generating groundtruth labelmap
# =============================================================================
def gaussian(shape=(7,7),sigma=1):
"""
2D gaussian mask - should give the same result as MATLAB's
fspecial('gaussian',[shape],[sigma])
"""
m,n = [(ss-1.)/2. for ss in shape]
y,x = np.ogrid[-m:m+1,-n:n+1]
h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
return to_torch(h).float()
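The kernel returned by `gaussian` is unnormalised, so its centre entry is exactly 1. A dependency-free check of the same formula, using the default 7x7 shape and sigma=1:

```python
import math

sigma, shape = 1.0, (7, 7)
m, n = [(s - 1) / 2 for s in shape]
h = [[math.exp(-(x * x + y * y) / (2 * sigma * sigma))
      for x in range(-int(n), int(n) + 1)]
     for y in range(-int(m), int(m) + 1)]
print(h[3][3])          # centre of the 7x7 kernel: exactly 1.0
print(h[0][0] < 1e-3)   # corners decay towards zero: True
```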
def draw_labelmap(img, pt, sigma, type='Gaussian'):
# Draw a 2D gaussian
# Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py
img = to_numpy(img)
# Check that any part of the gaussian is in-bounds
ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
br[0] < 0 or br[1] < 0):
# If not, just return the image as is
return to_torch(img)
# Generate gaussian
size = 6 * sigma + 1
x = np.arange(0, size, 1, float)
y = x[:, np.newaxis]
x0 = y0 = size // 2
# The gaussian is not normalized, we want the center value to equal 1
if type == 'Gaussian':
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
elif type == 'Cauchy':
g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
# Usable gaussian range
g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
# Image range
img_x = max(0, ul[0]), min(br[0], img.shape[1])
img_y = max(0, ul[1]), min(br[1], img.shape[0])
img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
return to_torch(img)
# =============================================================================
# Helpful display functions
# =============================================================================
def gauss(x, a, b, c, d=0):
return a * np.exp(-(x - b)**2 / (2 * c**2)) + d
def color_heatmap(x):
x = to_numpy(x)
color = np.zeros((x.shape[0],x.shape[1],3))
color[:,:,0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
color[:,:,1] = gauss(x, 1, .5, .3)
color[:,:,2] = gauss(x, 1, .2, .3)
color[color > 1] = 1
color = (color * 255).astype(np.uint8)
return color
def imshow(img):
npimg = im_to_numpy(img*255).astype(np.uint8)
plt.imshow(npimg)
plt.axis('off')
def show_joints(img, pts):
imshow(img)
for i in range(pts.size(0)):
if pts[i, 2] > 0:
plt.plot(pts[i, 0], pts[i, 1], 'yo')
plt.axis('off')
def show_sample(inputs, target):
num_sample = inputs.size(0)
num_joints = target.size(1)
height = target.size(2)
width = target.size(3)
for n in range(num_sample):
inp = resize(inputs[n], width, height)
out = inp
for p in range(num_joints):
tgt = inp*0.5 + color_heatmap(target[n,p,:,:])*0.5
out = torch.cat((out, tgt), 2)
imshow(out)
plt.show()
def sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None):
inp = to_numpy(inp * 255)
out = to_numpy(out)
img = np.zeros((inp.shape[1], inp.shape[2], inp.shape[0]))
for i in range(3):
img[:, :, i] = inp[i, :, :]
if parts_to_show is None:
parts_to_show = np.arange(out.shape[0])
# Generate a single image to display input/output pair
num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows))
size = img.shape[0] // num_rows
full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), np.uint8)
full_img[:img.shape[0], :img.shape[1]] = img
inp_small = scipy.misc.imresize(img, [size, size])
# Set up heatmap display for each part
for i, part in enumerate(parts_to_show):
part_idx = part
out_resized = scipy.misc.imresize(out[part_idx], [size, size])
out_resized = out_resized.astype(float)/255
out_img = inp_small.copy() * .3
color_hm = color_heatmap(out_resized)
out_img += color_hm * .7
col_offset = (i % num_cols + num_rows) * size
row_offset = (i // num_cols) * size
full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img
return full_img
def batch_with_heatmap(inputs, outputs, mean=torch.Tensor([0.5, 0.5, 0.5]), num_rows=2, parts_to_show=None):
batch_img = []
for n in range(min(inputs.size(0), 4)):
inp = inputs[n] + mean.view(3, 1, 1).expand_as(inputs[n])
batch_img.append(
sample_with_heatmap(inp.clamp(0, 1), outputs[n], num_rows=num_rows, parts_to_show=parts_to_show)
)
return np.concatenate(batch_img)
def normalize_batch(batch):
# normalize using imagenet mean and std
mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
batch = batch/255.0
return (batch - mean) / std
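`normalize_batch` rescales a 0-255 batch and applies the standard ImageNet mean/std per channel. The per-pixel arithmetic, sketched for a single RGB pixel (toy pixel values):

```python
mean = [0.485, 0.456, 0.406]   # ImageNet channel means
std  = [0.229, 0.224, 0.225]   # ImageNet channel stds
pixel = [124.0, 116.0, 104.0]  # one RGB pixel in the 0-255 range
normed = [(p / 255.0 - m) / s for p, m, s in zip(pixel, mean, std)]
print([round(v, 3) for v in normed])  # values near zero for pixels near the channel mean
```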
def show_image_tensor(tensor):
re = []
for i in range(tensor.size(0)):
inp = tensor[i].data.cpu() #w,h,c
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1).transpose((2,0,1))
re.append(torch.from_numpy(inp).unsqueeze(0))
return torch.cat(re,0)
def get_jet():
colormap_int = np.zeros((256, 3), np.uint8)
for i in range(0, 256, 1):
colormap_int[i, 0] = np.int_(np.round(cm.jet(i)[0] * 255.0))
colormap_int[i, 1] = np.int_(np.round(cm.jet(i)[1] * 255.0))
colormap_int[i, 2] = np.int_(np.round(cm.jet(i)[2] * 255.0))
return colormap_int
def clamp(num, min_value, max_value):
return max(min(num, max_value), min_value)
def gray2color(gray_array, color_map):
rows, cols = gray_array.shape
color_array = np.zeros((rows, cols, 3), np.uint8)
for i in range(0, rows):
for j in range(0, cols):
# log(256,2) = 8 , log(1,2) = 0 * 8
color_array[i, j] = color_map[clamp(int(abs(gray_array[i, j])*10),0,255)]
return color_array
class objectview(object):
def __init__(self, *args, **kwargs):
d = dict(*args, **kwargs)
self.__dict__ = d
================================================
FILE: scripts/utils/logger.py
================================================
# A simple torch style logger
# (C) Wei YANG 2017
from __future__ import absolute_import
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
__all__ = ['Logger', 'LoggerMonitor', 'savefig']
def savefig(fname, dpi=None):
dpi = 150 if dpi is None else dpi
plt.savefig(fname, dpi=dpi)
def plot_overlap(logger, names=None):
names = logger.names if names is None else names
numbers = logger.numbers
for _, name in enumerate(names):
x = np.arange(len(numbers[name]))
plt.plot(x, np.asarray(numbers[name]))
return [logger.title + '(' + name + ')' for name in names]
class Logger(object):
'''Save training process to log file with simple plot function.'''
def __init__(self, fpath, title=None, resume=False):
self.file = None
self.resume = resume
self.title = '' if title is None else title
if fpath is not None:
if resume:
self.file = open(fpath, 'r')
name = self.file.readline()
self.names = name.rstrip().split('\t')
self.numbers = {}
for _, name in enumerate(self.names):
self.numbers[name] = []
for numbers in self.file:
numbers = numbers.rstrip().split('\t')
for i in range(0, len(numbers)):
self.numbers[self.names[i]].append(numbers[i])
self.file.close()
self.file = open(fpath, 'a')
else:
self.file = open(fpath, 'w')
def set_names(self, names):
if self.resume:
pass
# initialize numbers as empty list
self.numbers = {}
self.names = names
for _, name in enumerate(self.names):
self.file.write(name)
self.file.write('\t')
self.numbers[name] = []
self.file.write('\n')
self.file.flush()
def append(self, numbers):
assert len(self.names) == len(numbers), 'Numbers do not match names'
for index, num in enumerate(numbers):
self.file.write("{0:.6f}".format(num))
self.file.write('\t')
self.numbers[self.names[index]].append(num)
self.file.write('\n')
self.file.flush()
def plot(self, names=None):
names = self.names if names is None else names
numbers = self.numbers
for _, name in enumerate(names):
x = np.arange(len(numbers[name]))
plt.plot(x, np.asarray(numbers[name]))
plt.legend([self.title + '(' + name + ')' for name in names])
plt.grid(True)
def close(self):
if self.file is not None:
self.file.close()
class LoggerMonitor(object):
'''Load and visualize multiple logs.'''
def __init__ (self, paths):
'''paths is a dictionary with {name: filepath} pairs'''
self.loggers = []
for title, path in paths.items():
logger = Logger(path, title=title, resume=True)
self.loggers.append(logger)
def plot(self, names=None):
plt.figure()
plt.subplot(121)
legend_text = []
for logger in self.loggers:
legend_text += plot_overlap(logger, names)
plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.grid(True)
if __name__ == '__main__':
# # Example
# logger = Logger('test.txt')
# logger.set_names(['Train loss', 'Valid loss','Test loss'])
# length = 100
# t = np.arange(length)
# train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
# valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
# test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
# for i in range(0, length):
# logger.append([train_loss[i], valid_loss[i], test_loss[i]])
# logger.plot()
# Example: logger monitor
paths = {
'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
}
field = ['Valid Acc.']
monitor = LoggerMonitor(paths)
monitor.plot(names=field)
savefig('test.eps')
================================================
FILE: scripts/utils/losses.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from scripts.models.vgg import Vgg19
from torchvision import models
from scripts.utils.misc import resize_to_match
# from pytorch_msssim import SSIM, MS_SSIM
import pytorch_ssim
class WeightedBCE(nn.Module):
def __init__(self):
super(WeightedBCE, self).__init__()
def forward(self, pred, gt):
epsilon = 1e-10  # avoid division by zero for an all-negative mask
count_pos = torch.sum(gt)*1.0 + epsilon
count_neg = torch.sum(1.-gt)*1.0
beta = count_neg/count_pos
beta_back = count_pos / (count_pos + count_neg)
bce1 = nn.BCEWithLogitsLoss(pos_weight=beta)
loss = beta_back*bce1(pred, gt)
return loss
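The class balancing in `WeightedBCE` weights positives by the negative/positive ratio and rescales the loss by the positive fraction. The arithmetic for a toy mask with 2 positive and 6 negative pixels (illustrative values):

```python
gt = [1, 1, 0, 0, 0, 0, 0, 0]          # toy binary mask
eps = 1e-10                            # guards against an all-negative mask
count_pos = sum(gt) + eps
count_neg = len(gt) - sum(gt)
beta = count_neg / count_pos           # pos_weight handed to BCEWithLogitsLoss
beta_back = count_pos / (count_pos + count_neg)
print(round(beta, 6), round(beta_back, 6))  # 3.0 0.25
```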
def l1_relative(reconstructed, real, mask):
batch = real.size(0)
area = torch.sum(mask.view(batch,-1),dim=1)
reconstructed = reconstructed * mask
real = real * mask
loss_l1 = torch.abs(reconstructed - real).view(batch, -1)
loss_l1 = torch.sum(loss_l1, dim=1) / area
loss_l1 = torch.sum(loss_l1) / batch
return loss_l1
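`l1_relative` normalises the masked L1 difference by the mask area instead of the full image size, so the loss magnitude is comparable across watermark sizes. The same computation on a flat toy example (illustrative values):

```python
recon = [0.5, 0.8, 0.2, 0.9]   # reconstructed "pixels"
real  = [0.5, 0.6, 0.2, 0.5]   # ground-truth "pixels"
mask  = [0.0, 1.0, 0.0, 1.0]   # watermark region covers 2 of 4 pixels
area = sum(mask)
loss = sum(abs(r * m - t * m) for r, t, m in zip(recon, real, mask)) / area
print(round(loss, 6))  # 0.3: (|0.8-0.6| + |0.9-0.5|) / 2
```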
def is_dic(x):
# despite the name, this tests for a list (of side outputs)
return isinstance(x, list)
class Losses(nn.Module):
def __init__(self, argx, device):
super(Losses, self).__init__()
self.args = argx
if self.args.loss_type == 'l1bl2':
self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss()
elif self.args.loss_type == 'l1wbl2':
self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), WeightedBCE(), nn.MSELoss()
elif self.args.loss_type == 'l2wbl2':
self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), WeightedBCE(), nn.MSELoss()
elif self.args.loss_type == 'l2xbl2':
self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss()
else: # l2bl2
self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss()
if self.args.style_loss > 0:
self.vggloss = VGGLoss(self.args.sltype).to(device)
if self.args.ssim_loss > 0:
self.ssimloss = pytorch_ssim.SSIM().to(device)
self.outputLoss = self.outputLoss.to(device)
self.attLoss = self.attLoss.to(device)
self.wrloss = self.wrloss.to(device)
def forward(self,imgx,target,attx,mask,wmx,wm):
pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = 0,0,0,0,0
if is_dic(imgx):
if self.args.masked:
# calculate the overall loss and side output
pixel_loss = self.outputLoss(imgx[0],target) + sum([self.outputLoss(im,resize_to_match(mask,im)*resize_to_match(target,im)) for im in imgx[1:]])
else:
pixel_loss = sum([self.outputLoss(im,resize_to_match(target,im)) for im in imgx])
if self.args.style_loss > 0:
vgg_loss = sum([self.vggloss(im,resize_to_match(target,im),resize_to_match(mask,im)) for im in imgx])
if self.args.ssim_loss > 0:
ssim_loss = sum([ 1 - self.ssimloss(im,resize_to_match(target,im)) for im in imgx])
else:
if self.args.masked:
pixel_loss = self.outputLoss(imgx,mask*target)
else:
pixel_loss = self.outputLoss(imgx,target)
if self.args.style_loss > 0:
vgg_loss = self.vggloss(imgx,target,mask)
if self.args.ssim_loss > 0:
ssim_loss = 1 - self.ssimloss(imgx,target)
if is_dic(attx):
att_loss = sum([self.attLoss(at,resize_to_match(mask,at)) for at in attx])
else:
att_loss = self.attLoss(attx, mask)
if is_dic(wmx):
wm_loss = sum([self.wrloss(w,resize_to_match(wm,w)) for w in wmx])
else:
if self.args.masked:
wm_loss = self.wrloss(wmx,mask*wm)
else:
wm_loss = self.wrloss(wmx, wm)
return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss
def gram_matrix(feat):
# https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py
(b, ch, h, w) = feat.size()
feat = feat.view(b, ch, h * w)
feat_t = feat.transpose(1, 2)
gram = torch.bmm(feat, feat_t) / (ch * h * w)
return gram
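`gram_matrix` flattens each feature map and takes channel-by-channel inner products, normalised by `ch * h * w`. The same computation for one sample with 2 channels of a 1x2 feature map (toy numbers):

```python
ch, h, w = 2, 1, 2
feat = [[1.0, 2.0],   # channel 0, flattened to length h*w
        [3.0, 4.0]]   # channel 1
gram = [[sum(a * b for a, b in zip(fi, fj)) / (ch * h * w)
         for fj in feat] for fi in feat]
print(gram)  # [[1.25, 2.75], [2.75, 6.25]] -- symmetric, as a Gram matrix must be
```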
class MeanShift(nn.Conv2d):
def __init__(self, data_mean, data_std, data_range=1, norm=True):
"""norm (bool): normalize/denormalize the stats"""
c = len(data_mean)
super(MeanShift, self).__init__(c, c, kernel_size=1)
std = torch.Tensor(data_std)
self.weight.data = torch.eye(c).view(c, c, 1, 1)
if norm:
self.weight.data.div_(std.view(c, 1, 1, 1))
self.bias.data = -1 * data_range * torch.Tensor(data_mean)
self.bias.data.div_(std)
else:
self.weight.data.mul_(std.view(c, 1, 1, 1))
self.bias.data = data_range * torch.Tensor(data_mean)
# freeze the fixed normalization layer (setting an attribute on the module would not stop gradients)
for p in self.parameters():
p.requires_grad = False
def VGGLoss(losstype):
if losstype == 'vgg':
return VGGLossA()
elif losstype == 'vggx':
return VGGLossX(mask=False)
elif losstype == 'mvggx':
return VGGLossX(mask=True)
elif losstype == 'rvggx':
return VGGLossX(mask=True,relative=True)
else:
raise Exception("error in %s"%losstype)
class VGGLossA(nn.Module):
def __init__(self, vgg=None, weights=None, indices=None, normalize=True):
super(VGGLossA, self).__init__()
if vgg is None:
self.vgg = Vgg19().cuda()
else:
self.vgg = vgg
self.criterion = nn.L1Loss()
self.weights = weights or [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
self.indices = indices or [2, 7, 12, 21, 30]
if normalize:
self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
else:
self.normalize = None
def forward(self, x, y):
if self.normalize is not None:
x = self.normalize(x)
y = self.normalize(y)
x_vgg, y_vgg = self.vgg(x, self.indices), self.vgg(y, self.indices)
loss = 0
for i in range(len(x_vgg)):
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss
class VGG16FeatureExtractor(nn.Module):
def __init__(self):
super().__init__()
vgg16 = models.vgg16(pretrained=True)
self.enc_1 = nn.Sequential(*vgg16.features[:5])
self.enc_2 = nn.Sequential(*vgg16.features[5:10])
self.enc_3 = nn.Sequential(*vgg16.features[10:17])
# fix the encoder
for i in range(3):
for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
param.requires_grad = False
def forward(self, image):
results = [image]
for i in range(3):
func = getattr(self, 'enc_{:d}'.format(i + 1))
results.append(func(results[-1]))
return results[1:]
class VGGLossX(nn.Module):
def __init__(self, normalize=True, mask=False, relative=False):
super(VGGLossX, self).__init__()
self.vgg = VGG16FeatureExtractor().cuda()
self.criterion = nn.L1Loss().cuda() if not relative else l1_relative
self.use_mask= mask
self.relative = relative
if normalize:
self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
else:
self.normalize = None
def forward(self, x, y, Xmask=None):
if not self.use_mask:
mask = torch.ones_like(x)[:,0:1,:,:]
else:
mask = Xmask
if self.normalize is not None:
x = self.normalize(x)
y = self.normalize(y)
x_vgg = self.vgg(x)
y_vgg = self.vgg(y)
loss = 0
for i in range(3):
if self.relative:
loss += self.criterion(x_vgg[i],y_vgg[i].detach(),resize_to_match(mask,x_vgg[i]))
else:
loss += self.criterion(resize_to_match(mask,x_vgg[i])*x_vgg[i],resize_to_match(mask,y_vgg[i])*y_vgg[i].detach())
return loss
class GANLosses(object):
"""docstring for Loss"""
def __init__(self, gantype):
super(GANLosses, self).__init__()
self.generator_loss = gen_gan(gantype)
self.discriminator_loss = dis_gan(gantype)
self.gantype = gantype
def g_loss(self,dis_fake):
if 'hinge' in self.gantype:
return gen_hinge(dis_fake)
else:
return self.generator_loss(dis_fake)
def d_loss(self,dis_fake,dis_real):
if 'hinge' in self.gantype:
return dis_hinge(dis_fake,dis_real)
else:
return self.discriminator_loss(dis_fake,dis_real)
class gen_gan(nn.Module):
def __init__(self,gantype):
super(gen_gan,self).__init__()
if gantype == 'lsgan':
self.criterion = nn.MSELoss()
elif gantype == 'naive':
self.criterion = nn.BCEWithLogitsLoss()
else:
raise Exception("error gan type")
def forward(self,dis_fake):
return self.criterion(dis_fake, torch.ones_like(dis_fake))
class dis_gan(nn.Module):
def __init__(self,gantype):
super(dis_gan,self).__init__()
if gantype == 'lsgan':
self.criterion = nn.MSELoss()
elif gantype == 'naive':
self.criterion = nn.BCEWithLogitsLoss()
else:
raise Exception("error gan type")
def forward(self,dis_fake,dis_real):
loss_fake = self.criterion(dis_fake, torch.zeros_like(dis_fake))
loss_real = self.criterion(dis_real, torch.ones_like(dis_real))
return loss_fake, loss_real
# def gen_gan(dis_fake):
# # fake -> 1
# return F.binary_cross_entropy_with_logits(dis_fake,torch.ones_like(dis_fake))
# def dis_gan(dis_fake,dis_real):
# # fake -> 0 , real ->1
# loss_fake = F.binary_cross_entropy_with_logits(dis_fake, torch.zeros_like(dis_real))
# loss_real = F.binary_cross_entropy_with_logits(dis_real, torch.ones_like(dis_fake))
# return loss_fake,loss_real
# def gen_lsgan(dis_fake):
# loss = F.mse_loss(dis_fake,torch.ones_like(dis_fake)) #
# return loss
# def dis_lsgan(dis_fake, dis_real):
# loss_fake = F.mse_loss(dis_fake, torch.zeros_like(dis_real))
# loss_real = F.mse_loss(dis_real, torch.ones_like(dis_real))
# return loss_fake,loss_real
def gen_hinge(dis_fake, dis_real=None):
return -torch.mean(dis_fake)
def dis_hinge(dis_fake, dis_real):
loss_fake = torch.mean(torch.relu(1. + dis_fake))
loss_real = torch.mean(torch.relu(1. - dis_real))
return loss_fake,loss_real
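The hinge losses above only penalise scores inside the margin: the discriminator wants real scores above +1 and fake scores below -1. A dependency-free sketch over plain lists of scores (illustrative values):

```python
def dis_hinge(fake_scores, real_scores):
    # relu(1 + fake) penalises fakes above -1; relu(1 - real) penalises reals below +1
    loss_fake = sum(max(0.0, 1.0 + f) for f in fake_scores) / len(fake_scores)
    loss_real = sum(max(0.0, 1.0 - r) for r in real_scores) / len(real_scores)
    return loss_fake, loss_real

lf, lr = dis_hinge(fake_scores=[-2.0, 0.5], real_scores=[2.0, 0.5])
print(lf, lr)  # 0.75 0.25 -- only the in-margin scores contribute
```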
================================================
FILE: scripts/utils/misc.py
================================================
from __future__ import absolute_import
import os
import shutil
import torch
import math
import numpy as np
import scipy.io
import matplotlib.pyplot as plt
import torch.nn.functional as F
def to_numpy(tensor):
if torch.is_tensor(tensor):
return tensor.cpu().numpy()
elif type(tensor).__module__ != 'numpy':
raise ValueError("Cannot convert {} to numpy array"
.format(type(tensor)))
return tensor
def resize_to_match(fm,to):
# bilinearly interpolate fm to the spatial size (h, w) of to
return F.interpolate(fm,to.size()[-2:],mode='bilinear',align_corners=False)
def to_torch(ndarray):
if type(ndarray).__module__ == 'numpy':
return torch.from_numpy(ndarray)
elif not torch.is_tensor(ndarray):
raise ValueError("Cannot convert {} to torch tensor"
.format(type(ndarray)))
return ndarray
def save_checkpoint(machine,filename='checkpoint.pth.tar', snapshot=None):
is_best = machine.metric > machine.best_acc
if is_best:
machine.best_acc = machine.metric
state = {
'epoch': machine.current_epoch + 1,
'arch': machine.args.arch,
'state_dict': machine.model.state_dict(),
'best_acc': machine.best_acc,
'optimizer' : machine.optimizer.state_dict(),
}
filepath = os.path.join(machine.args.checkpoint, filename)
torch.save(state, filepath)
if snapshot and state['epoch'] % snapshot == 0:
shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
if is_best:
print('Saving Best Metric with PSNR:%s'%machine.best_acc)
shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'model_best.pth.tar'))
def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'):
preds = to_numpy(preds)
filepath = os.path.join(checkpoint, filename)
scipy.io.savemat(filepath, mdict={'preds' : preds})
def adjust_learning_rate(datasets,optimizer, epoch, lr,args):
"""Sets the learning rate to the initial LR decayed by schedule"""
if epoch in args.schedule:
lr *= args.gamma
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# decay sigma
for dset in datasets:
if args.sigma_decay > 0:
dset.dataset.sigma *= args.sigma_decay
return lr
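`adjust_learning_rate` implements a step schedule: the rate is multiplied by `gamma` whenever the epoch hits a milestone in `args.schedule`. A minimal sketch of that policy (the milestones and gamma below are illustrative):

```python
def step_lr(epoch, lr, schedule, gamma):
    # decay lr by gamma each time `epoch` is one of the milestones
    if epoch in schedule:
        lr *= gamma
    return lr

lr, history = 0.001, []
for epoch in range(6):
    lr = step_lr(epoch, lr, schedule=[2, 4], gamma=0.1)
    history.append(round(lr, 6))
print(history)  # [0.001, 0.001, 0.0001, 0.0001, 1e-05, 1e-05]
```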
================================================
FILE: scripts/utils/model_init.py
================================================
from torch.nn import init
def weights_init_normal(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('Linear') != -1:
init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)
def weights_init_xavier(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
init.xavier_normal_(m.weight.data, gain=0.02)
elif classname.find('Linear') != -1:
init.xavier_normal_(m.weight.data, gain=0.02)
# elif classname.find('BatchNorm2d') != -1:
# init.normal(m.weight.data, 1.0, 0.02)
# init.constant(m.bias.data, 0.0)
def weights_init_kaiming(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1 and m.weight.requires_grad == True:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif classname.find('Linear') != -1 and m.weight.requires_grad == True:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif classname.find('BatchNorm2d') != -1 and m.weight.requires_grad == True:
init.normal_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)
def weights_init_orthogonal(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.orthogonal_(m.weight.data, gain=1)
elif classname.find('Linear') != -1:
init.orthogonal_(m.weight.data, gain=1)
# elif classname.find('BatchNorm2d') != -1:
# init.normal(m.weight.data, 1.0, 0.02)
# init.constant(m.bias.data, 0.0)
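These initializers are intended to be passed to `module.apply(...)`, and they dispatch on the module's class name. A torch-free sketch of the same class-name dispatch pattern (the `*Stub` classes are hypothetical stand-ins for `nn.Conv2d` / `nn.Linear` / `nn.ReLU`):

```python
# Torch-free sketch of the class-name dispatch used by the weights_init_* functions.
# Conv2dStub / LinearStub / ReLUStub are hypothetical stand-ins for torch modules.
class Conv2dStub: pass
class LinearStub: pass
class ReLUStub: pass

def init_kind(m):
    """Mirror the classname.find(...) dispatch and report which rule fires."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        return 'kaiming'
    elif classname.find('Linear') != -1:
        return 'kaiming'
    elif classname.find('BatchNorm2d') != -1:
        return 'normal'
    return 'skip'

kinds = [init_kind(m()) for m in (Conv2dStub, LinearStub, ReLUStub)]
# → ['kaiming', 'kaiming', 'skip']
```

In the real code, `net.apply(weights_init_kaiming)` walks every sub-module and applies the matching rule.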
================================================
FILE: scripts/utils/osutils.py
================================================
from __future__ import absolute_import
import os
import errno
def mkdir_p(dir_path):
try:
os.makedirs(dir_path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def isfile(fname):
return os.path.isfile(fname)
def isdir(dirname):
return os.path.isdir(dirname)
def join(path, *paths):
return os.path.join(path, *paths)
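`mkdir_p` mirrors the shell's `mkdir -p`: creating a directory that already exists is a no-op rather than an error. A self-contained sketch of the same errno-based pattern:

```python
# mkdir_p swallows EEXIST so repeated calls are safe; any other OSError
# (e.g. permission denied) is re-raised.
import os, errno, tempfile

def mkdir_p(dir_path):
    try:
        os.makedirs(dir_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

base = tempfile.mkdtemp()
target = os.path.join(base, 'a', 'b')
mkdir_p(target)
mkdir_p(target)  # second call does not raise
exists = os.path.isdir(target)
# → True
```

On Python 3.2+, `os.makedirs(dir_path, exist_ok=True)` is the idiomatic one-liner equivalent.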
================================================
FILE: scripts/utils/parallel.py
================================================
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang, Rutgers University, Email: zhang.hang@rutgers.edu
## Modified by Thomas Wolf, HuggingFace Inc., Email: thomas@huggingface.co
## Copyright (c) 2017-2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Encoding Data Parallel"""
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel.scatter_gather import gather
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
torch_ver = torch.__version__[:3]
__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
'patch_replication_callback']
def allreduce(*inputs):
"""Cross-GPU all-reduce autograd operation for calculating the mean and
variance in SyncBN.
"""
return AllReduce.apply(*inputs)
class AllReduce(Function):
@staticmethod
def forward(ctx, num_inputs, *inputs):
ctx.num_inputs = num_inputs
ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
inputs = [inputs[i:i + num_inputs]
for i in range(0, len(inputs), num_inputs)]
# sort before reduce sum
inputs = sorted(inputs, key=lambda i: i[0].get_device())
results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
return tuple([t for tensors in outputs for t in tensors])
@staticmethod
def backward(ctx, *inputs):
inputs = [i.data for i in inputs]
inputs = [inputs[i:i + ctx.num_inputs]
for i in range(0, len(inputs), ctx.num_inputs)]
results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
class Reduce(Function):
@staticmethod
def forward(ctx, *inputs):
ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
inputs = sorted(inputs, key=lambda i: i.get_device())
return comm.reduce_add(inputs)
@staticmethod
def backward(ctx, gradOutput):
return Broadcast.apply(ctx.target_gpus, gradOutput)
class DistributedDataParallelModel(DistributedDataParallel):
"""Implements data parallelism at the module level for the DistributedDataParallel module.
This container parallelizes the application of the given module by
splitting the input across the specified devices by chunking in the
batch dimension.
In the forward pass, the module is replicated on each device,
and each replica handles a portion of the input. During the backwards pass,
gradients from each replica are summed into the original module.
Note that the outputs are not gathered; please use the compatible
:class:`encoding.parallel.DataParallelCriterion`.
The batch size should be larger than the number of GPUs used. It should
also be an integer multiple of the number of GPUs so that each chunk is
the same size (so that each GPU processes the same number of samples).
Args:
module: module to be parallelized
device_ids: CUDA devices (default: all devices)
Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
Amit Agrawal. "Context Encoding for Semantic Segmentation."
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
Example::
>>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2])
>>> y = net(x)
"""
def gather(self, outputs, output_device):
return outputs
class DataParallelModel(DataParallel):
"""Implements data parallelism at the module level.
This container parallelizes the application of the given module by
splitting the input across the specified devices by chunking in the
batch dimension.
In the forward pass, the module is replicated on each device,
and each replica handles a portion of the input. During the backwards pass,
gradients from each replica are summed into the original module.
Note that the outputs are not gathered; please use the compatible
:class:`encoding.parallel.DataParallelCriterion`.
The batch size should be larger than the number of GPUs used. It should
also be an integer multiple of the number of GPUs so that each chunk is
the same size (so that each GPU processes the same number of samples).
Args:
module: module to be parallelized
device_ids: CUDA devices (default: all devices)
Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
Amit Agrawal. "Context Encoding for Semantic Segmentation."
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
Example::
>>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
>>> y = net(x)
"""
def gather(self, outputs, output_device):
return outputs
def replicate(self, module, device_ids):
modules = super(DataParallelModel, self).replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
class DataParallelCriterion(DataParallel):
"""
Calculate loss on multiple GPUs, which balances the memory usage.
The targets are split across the specified devices by chunking in
the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
Amit Agrawal. "Context Encoding for Semantic Segmentation."
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
Example::
>>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
>>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
>>> y = net(x)
>>> loss = criterion(y, target)
"""
def forward(self, inputs, *targets, **kwargs):
# inputs should already be scattered;
# scatter the targets instead
if not self.device_ids:
return self.module(inputs, *targets, **kwargs)
targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
if len(self.device_ids) == 1:
return self.module(inputs, *targets[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
#return Reduce.apply(*outputs) / len(outputs)
#return self.gather(outputs, self.output_device).mean()
return self.gather(outputs, self.output_device)
def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
assert len(modules) == len(inputs)
assert len(targets) == len(inputs)
if kwargs_tup:
assert len(modules) == len(kwargs_tup)
else:
kwargs_tup = ({},) * len(modules)
if devices is not None:
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)
lock = threading.Lock()
results = {}
if torch_ver != "0.3":
grad_enabled = torch.is_grad_enabled()
def _worker(i, module, input, target, kwargs, device=None):
if torch_ver != "0.3":
torch.set_grad_enabled(grad_enabled)
if device is None:
device = get_a_var(input).get_device()
try:
with torch.cuda.device(device):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)
if not isinstance(target, (list, tuple)):
target = (target,)
output = module(*(input + target), **kwargs)
with lock:
results[i] = output
except Exception as e:
with lock:
results[i] = e
if len(modules) > 1:
threads = [threading.Thread(target=_worker,
args=(i, module, input, target,
kwargs, device),)
for i, (module, input, target, kwargs, device) in
enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, Exception):
raise output
outputs.append(output)
return outputs
###########################################################################
# Adapted from Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
#
class CallbackContext(object):
pass
def execute_replication_callbacks(modules):
"""
Execute a replication callback `__data_parallel_replicate__` on each module created
by the original replication.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
Note that, as all modules are isomorphic, we assign each sub-module a context
(shared among multiple copies of this module on different devices).
Through this context, different copies can share some information.
We guarantee that the callback on the master copy (the first copy) will be called ahead
of the callbacks on any slave copies.
"""
master_copy = modules[0]
nr_modules = len(list(master_copy.modules()))
ctxs = [CallbackContext() for _ in range(nr_modules)]
for i, module in enumerate(modules):
for j, m in enumerate(module.modules()):
if hasattr(m, '__data_parallel_replicate__'):
m.__data_parallel_replicate__(ctxs[j], i)
def patch_replication_callback(data_parallel):
"""
Monkey-patch an existing `DataParallel` object. Add the replication callback.
Useful when you have customized `DataParallel` implementation.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
> patch_replication_callback(sync_bn)
# this is equivalent to
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
"""
assert isinstance(data_parallel, DataParallel)
old_replicate = data_parallel.replicate
@functools.wraps(old_replicate)
def new_replicate(module, device_ids):
modules = old_replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
data_parallel.replicate = new_replicate
================================================
FILE: scripts/utils/transforms.py
================================================
from __future__ import absolute_import
import os
import numpy as np
import scipy.misc
import matplotlib.pyplot as plt
import torch
import torchvision
from .misc import *
from .imutils import *
def color_normalize(x, mean, std):
if x.size(0) == 1:
x = x.repeat(3, 1, 1)  # expand a single-channel input to 3 channels
for t, m, s in zip(x, mean, std):
t.sub_(m)
return x
def flip_back(flip_output, dataset='mpii'):
"""
flip output map
"""
if dataset == 'mpii':
matchedParts = (
[0,5], [1,4], [2,3],
[10,15], [11,14], [12,13]
)
else:
raise ValueError('Not supported dataset: ' + dataset)
# flip output horizontally
flip_output = fliplr(flip_output.numpy())
# Change left-right parts
for pair in matchedParts:
tmp = np.copy(flip_output[:, pair[0], :, :])
flip_output[:, pair[0], :, :] = flip_output[:, pair[1], :, :]
flip_output[:, pair[1], :, :] = tmp
return torch.from_numpy(flip_output).float()
def shufflelr(x, width, dataset='mpii'):
"""
flip coords
"""
if dataset == 'mpii':
matchedParts = (
[0,5], [1,4], [2,3],
[10,15], [11,14], [12,13]
)
else:
raise ValueError('Not supported dataset: ' + dataset)
# Flip horizontal
x[:, 0] = width - x[:, 0]
# Change left-right parts
for pair in matchedParts:
tmp = x[pair[0], :].clone()
x[pair[0], :] = x[pair[1], :]
x[pair[1], :] = tmp
return x
def fliplr(x):
if x.ndim == 3:
x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1))
elif x.ndim == 4:
for i in range(x.shape[0]):
x[i] = np.transpose(np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1))
return x.astype(float)
def get_transform(center, scale, res, rot=0):
"""
General image processing functions
"""
# Generate transformation matrix
h = 200 * scale
t = np.zeros((3, 3))
t[0, 0] = float(res[1]) / h
t[1, 1] = float(res[0]) / h
t[0, 2] = res[1] * (-float(center[0]) / h + .5)
t[1, 2] = res[0] * (-float(center[1]) / h + .5)
t[2, 2] = 1
if not rot == 0:
rot = -rot # To match direction of rotation from cropping
rot_mat = np.zeros((3,3))
rot_rad = rot * np.pi / 180
sn,cs = np.sin(rot_rad), np.cos(rot_rad)
rot_mat[0,:2] = [cs, -sn]
rot_mat[1,:2] = [sn, cs]
rot_mat[2,2] = 1
# Need to rotate around center
t_mat = np.eye(3)
t_mat[0,2] = -res[1]/2
t_mat[1,2] = -res[0]/2
t_inv = t_mat.copy()
t_inv[:2,2] *= -1
t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
return t
def transform(pt, center, scale, res, invert=0, rot=0):
# Transform pixel location to different reference
t = get_transform(center, scale, res, rot=rot)
if invert:
t = np.linalg.inv(t)
new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2].astype(int) + 1
def transform_preds(coords, center, scale, res):
# size = coords.size()
# coords = coords.view(-1, coords.size(-1))
# print(coords.size())
for p in range(coords.size(0)):
coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0))
return coords
def crop(img, center, scale, res, rot=0):
img = im_to_numpy(img)
# Upper left point
ul = np.array(transform([0, 0], center, scale, res, invert=1))
# Bottom right point
br = np.array(transform(res, center, scale, res, invert=1))
# Padding so that when rotated proper amount of context is included
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
if not rot == 0:
ul -= pad
br += pad
new_shape = [br[1] - ul[1], br[0] - ul[0]]
if len(img.shape) > 2:
new_shape += [img.shape[2]]
new_img = np.zeros(new_shape)
# Range to fill new array
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
# Range to sample from original image
old_x = max(0, ul[0]), min(len(img[0]), br[0])
old_y = max(0, ul[1]), min(len(img), br[1])
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
if not rot == 0:
# Remove padding
# Note: scipy.misc.imrotate/imresize were removed in SciPy >= 1.2;
# this path requires an older SciPy (or a port to PIL/skimage).
new_img = scipy.misc.imrotate(new_img, rot)
new_img = new_img[pad:-pad, pad:-pad]
new_img = im_to_torch(scipy.misc.imresize(new_img, res))
return new_img
def get_right(img,gray=False):
img = im_to_numpy(img) #H*W*C
new_img = img[:,0:256,:]
new_img = im_to_torch(new_img)
if gray:
new_img = new_img[1,:,:]
return new_img
class NormalizeInverse(torchvision.transforms.Normalize):
"""
Undoes the normalization and returns the reconstructed images in the input domain.
"""
def __init__(self, mean, std):
mean = torch.as_tensor(mean)
std = torch.as_tensor(std)
std_inv = 1 / (std + 1e-7)
mean_inv = -mean * std_inv
super().__init__(mean=mean_inv, std=std_inv)
def __call__(self, tensor):
return super().__call__(tensor.clone())
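`NormalizeInverse` undoes `Normalize` by composing a second normalization: if `y = (x - mean) / std`, then applying `Normalize(mean_inv, std_inv)` with `mean_inv = -mean / std` and `std_inv = 1 / std` recovers `x`. A scalar sketch of the arithmetic (the class applies this per channel):

```python
# NormalizeInverse inverts Normalize by composing another Normalize:
# (y - mean_inv) / std_inv = y * std + mean = x (up to the 1e-7 epsilon).
mean, std = 0.485, 0.229   # ImageNet-style per-channel stats, one channel shown
x = 0.75
y = (x - mean) / std                 # forward normalization
std_inv = 1 / (std + 1e-7)
mean_inv = -mean * std_inv
x_back = (y - mean_inv) / std_inv    # NormalizeInverse's formula
# x_back is approximately x (error on the order of 1e-7)
```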
================================================
FILE: test.py
================================================
from __future__ import print_function, absolute_import
import argparse
import torch
torch.backends.cudnn.benchmark = True
from scripts.utils.misc import save_checkpoint, adjust_learning_rate
import scripts.datasets as datasets
import scripts.machines as machines
from options import Options
def main(args):
val_loader = torch.utils.data.DataLoader(datasets.COCO('val',args),batch_size=args.test_batch, shuffle=False,
num_workers=args.workers, pin_memory=True)
data_loaders = (None,val_loader)
Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)
Machine.test()
if __name__ == '__main__':
parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal'))
main(parser.parse_args())
================================================
FILE: watermark_synthesis.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SAVE ALL THE SETTING\n"
]
}
],
"source": [
"# watermark synthesis\n",
"import os \n",
"import random\n",
"import shutil\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"def trans_paste(bg_img,fg_img,mask,box=(0,0)):\n",
" fg_img_trans = Image.new(\"RGBA\",bg_img.size)\n",
" fg_img_trans.paste(fg_img,box,mask=mask)\n",
" new_img = Image.alpha_composite(bg_img,fg_img_trans)\n",
" return new_img,fg_img_trans\n",
"\n",
"if os.path.isdir('dataset'):\n",
" shutil.rmtree('dataset')\n",
"\n",
"os.mkdir('dataset')\n",
"BASE_IMG_DIR = '/Users/oishii/Downloads/val2014/'\n",
"WATERMARK_DIR = 'logos' #1080 \n",
"images = sorted([os.path.join(BASE_IMG_DIR,x) for x in os.listdir(BASE_IMG_DIR) if '.jpg' in x])\n",
"watermarks = sorted([os.path.join(WATERMARK_DIR,x).replace(' ','_') for x in os.listdir(WATERMARK_DIR) if '.png' in x])\n",
"# rename all the watermarks, replacing ' ' with '_'\n",
"\n",
"random.shuffle(images)\n",
"random.shuffle(watermarks)\n",
"\n",
"train_images = images[:int(len(images)*0.7)]\n",
"val_images = images[int(len(images)*0.7):int(len(images)*0.8)]\n",
"test_images = images[int(len(images)*0.8):]\n",
"\n",
"train_wms = watermarks[:int(len(watermarks)*0.7)]\n",
"val_wms = watermarks[int(len(watermarks)*0.7):int(len(watermarks)*0.8)]\n",
"test_wms = watermarks[int(len(watermarks)*0.8):]\n",
"\n",
"# save all the settings to file\n",
"names = ['train_images','val_images','test_images','train_wms','val_wms','test_wms']\n",
"lists = [train_images,val_images,test_images,train_wms,val_wms,test_wms]\n",
"dataset = dict(zip(names, lists))\n",
"\n",
"for name,content in dataset.items():\n",
" with open('dataset/%s.txt'%name,'w') as f:\n",
" f.write(\"\\n\".join(content))\n",
"\n",
"print('SAVE ALL THE SETTING')\n",
"\n",
"for name, images in dataset.items():\n",
" if 'images' not in name:\n",
" continue\n",
" # for each setting, synthesize the watermarked images\n",
" # for each image, add X (X=6) watermarks at different positions and alphas,\n",
" # save the synthesized image, watermark mask, reshaped mask,\n",
" save_path = 'dataset/%s/'%name\n",
" os.makedirs('%s/image'%(save_path))\n",
" os.makedirs('%s/mask'%(save_path))\n",
" os.makedirs('%s/wm'%(save_path))\n",
" \n",
" for img in images:\n",
" im = Image.open(img).convert('RGBA')\n",
" imw,imh = im.size\n",
" \n",
" for wmg in random.choices(dataset[name.replace('images','wms')],k=6):\n",
" wm = Image.open(wmg.replace('_',' ')).convert(\"RGBA\") # RGBA\n",
" # get the mask of wm\n",
" # data augmentation of wm\n",
" wm = wm.rotate(angle=random.randint(0,360),expand=True) # rotate\n",
" \n",
" # make sure the watermark fits inside the image\n",
" imrw = random.randrange(int(0.4*imw),int(0.8*imw))\n",
" imrh = random.randrange(int(0.4*imh),int(0.8*imh))\n",
" wmsize = imrh if imrw > imrh else imrw\n",
" wm = wm.resize((wmsize,wmsize),Image.BILINEAR)\n",
" w,h = wm.size # new size \n",
" \n",
" box_left = random.randint(0,imw-w)\n",
" box_upper = random.randint(0,imh-h)\n",
" wmm = wm.copy()\n",
" wm.putalpha(random.randint(int(255*0.4),int(255*0.8))) # alpha\n",
" \n",
" ims,wmc = trans_paste(im,wm,wmm,(box_left,box_upper))\n",
" \n",
" wmnp = np.array(wmc) # h,w,3\n",
" mask = np.sum(wmnp,axis=2)>0\n",
" mm = Image.fromarray(np.uint8(mask*255),mode='L')\n",
" \n",
" identifier = os.path.basename(img).split('.')[0] +'-'+os.path.basename(wmg).split('.')[0] + '.png'\n",
" # save \n",
" wmc.save('%s/wm/%s'%(save_path,identifier))\n",
" ims.save('%s/image/%s'%(save_path,identifier))\n",
" mm.save('%s/mask/%s'%(save_path,identifier))\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}