Repository: vinthony/deep-blind-watermark-removal
Branch: main
Commit: 72f0e61b9f06
Files: 35
Total size: 174.8 KB
Directory structure:
gitextract_p5y2m_4m/
├── README.md
├── examples/
│ ├── evaluate.sh
│ └── test.sh
├── main.py
├── options.py
├── requirements.txt
├── scripts/
│ ├── __init__.py
│ ├── datasets/
│ │ ├── BIH.py
│ │ ├── COCO.py
│ │ └── __init__.py
│ ├── machines/
│ │ ├── BasicMachine.py
│ │ ├── S2AM.py
│ │ ├── VX.py
│ │ └── __init__.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── backbone_unet.py
│ │ ├── blocks.py
│ │ ├── discriminator.py
│ │ ├── rasc.py
│ │ ├── sa_resunet.py
│ │ ├── unet.py
│ │ ├── vgg.py
│ │ └── vmu.py
│ └── utils/
│ ├── __init__.py
│ ├── evaluation.py
│ ├── imutils.py
│ ├── logger.py
│ ├── losses.py
│ ├── misc.py
│ ├── model_init.py
│ ├── osutils.py
│ ├── parallel.py
│ └── transforms.py
├── test.py
└── watermark_synthesis.ipynb
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
This repo contains the code and results of the AAAI 2021 paper:
<i><b> [Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal](https://arxiv.org/abs/2012.07007)</b></i><br>
[Xiaodong Cun](http://vinthony.github.io), [Chi-Man Pun<sup>*</sup>](http://www.cis.umac.mo/~cmpun/) <br>
[University of Macau](http://um.edu.mo/)
[Datasets](#Resources) | [Models](#Resources) | [Paper](https://arxiv.org/abs/2012.07007) | [🔥Online Demo!](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing)(Google CoLab)
<hr>
<img width="726" alt="nn" src="https://user-images.githubusercontent.com/4397546/101241905-37915d80-3735-11eb-9fb9-2e1e46d63f15.png">
<i>The overview of the proposed two-stage framework. Firstly, we propose a multi-task network, SplitNet, for watermark detection, removal, and recovery. Then, we propose the RefineNet to smooth the learned region with the predicted mask and the recovered background from the previous stage. As a consequence, our network can be trained in an end-to-end fashion without any manual intervention. Note that, for clarity, we do not show any skip-connections between all the encoders and decoders.</i>
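The two-stage design above can be sketched in a few lines of PyTorch. This is a minimal stand-in, not the actual SplitNet/RefineNet from `scripts/models` (the real networks are stacked attention-guided ResUNets with skip-connections); it only illustrates how the stages connect: stage one predicts a coarse background, the watermark mask, and the watermark itself, and stage two refines the background guided by the mask.

```python
import torch
import torch.nn as nn

class SplitNet(nn.Module):
    """Toy stand-in for the multi-task first stage: predicts the coarse
    watermark-free background, the watermark mask, and the watermark."""
    def __init__(self, ch=8):
        super().__init__()
        self.enc = nn.Conv2d(3, ch, 3, padding=1)
        self.bg = nn.Conv2d(ch, 3, 3, padding=1)    # coarse background
        self.mask = nn.Conv2d(ch, 1, 3, padding=1)  # watermark mask
        self.wm = nn.Conv2d(ch, 3, 3, padding=1)    # recovered watermark
    def forward(self, x):
        f = torch.relu(self.enc(x))
        return self.bg(f), torch.sigmoid(self.mask(f)), self.wm(f)

class RefineNet(nn.Module):
    """Toy stand-in for the second stage: smooths the coarse background
    using the input image and the predicted mask as guidance."""
    def __init__(self, ch=8):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3 + 1 + 3, ch, 3, padding=1), nn.ReLU(),
            nn.Conv2d(ch, 3, 3, padding=1))
    def forward(self, img, mask, coarse):
        return self.body(torch.cat([img, mask, coarse], dim=1))

def remove_watermark(img):
    # end-to-end composition of the two stages
    split, refine = SplitNet(), RefineNet()
    coarse, mask, wm = split(img)
    return refine(img, mask, coarse), mask, wm
```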
<hr>
> The whole project will be released in January 2021 (almost).
### Datasets
We synthesized four different datasets for training and testing; you can download them via [huggingface](https://huggingface.co/datasets/vinthony/watermark-removal-logo/tree/main).
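The loader in `scripts/datasets/COCO.py` expects the following layout under `--base-dir` (shown here for a hypothetical `10kgray` dataset with `--data _images`; the split name is prepended to `--data`, and `natural/` holds the watermark-free ground-truth images):

```
10kgray/
├── natural/          # ground truth: <id>.jpg
├── train_images/
│   ├── image/        # watermarked inputs: <id>-<...>.png
│   ├── mask/         # watermark masks
│   └── wm/           # watermark layers
└── val_images/
    ├── image/
    ├── mask/
    └── wm/
```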

### Pre-trained Models
* [27kpng_model_best.pth.tar (google drive)](https://drive.google.com/file/d/1KpSJ6385CHN6WlAINqB3CYrJdleQTJBc/view?usp=sharing)
> The other pre-trained models are still being reorganized and uploaded; they will be released soon.
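Checkpoints saved by this repo (see `BasicMachine.save_checkpoint`) are dicts with the keys `epoch`, `arch`, `state_dict`, `best_acc`, and `optimizer`. Below is a minimal loading sketch, assuming you already built a compatible `model` instance; the `strip_module_prefix` helper is our addition, since training under `nn.DataParallel` prefixes parameter names with `module.`:

```python
import torch

# keys written by BasicMachine.save_checkpoint
CKPT_KEYS = ('epoch', 'arch', 'state_dict', 'best_acc', 'optimizer')

def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that nn.DataParallel adds to
    parameter names, so the weights load into an unwrapped model."""
    return {(k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in state_dict.items()}

def load_pretrained(model, path):
    # e.g. path = '27kpng_model_best.pth.tar'
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(strip_module_prefix(ckpt['state_dict']))
    return ckpt['epoch'], ckpt['best_acc']
```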
### Demos
An easy-to-use online demo can be found in [google colab](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing).
The local demo will be released soon.
### Pre-requirements
```
pip install -r requirements.txt
```
### Train
Besides training our method, we also give an example of how to train [s2am](https://github.com/vinthony/s2am) under our framework. More details can be found in the shell scripts.
```
bash examples/evaluate.sh
```
### Test
```
bash examples/test.sh
```
## **Acknowledgements**
The author would like to thank Nan Chen for her helpful discussion.
Part of the code is based upon our previous work on image harmonization, [s2am](https://github.com/vinthony/s2am).
## **Citation**
If you find our work useful in your research, please consider citing:
```
@misc{cun2020split,
title={Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal},
author={Xiaodong Cun and Chi-Man Pun},
year={2020},
eprint={2012.07007},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
## **Contact**
Please contact me (Xiaodong Cun, yb87432@um.edu.mo) if you have any questions.
================================================
FILE: examples/evaluate.sh
================================================
set -ex
# example training scripts for AAAI-21
# Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal
CUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/main.py --epochs 100\
--schedule 100\
--lr 1e-3\
-c eval/10kgray/1e3_bs4_256_hybrid_ssim_vgg\
--arch vvv4n\
--sltype vggx\
--style-loss 0.025\
--ssim-loss 0.15\
--masked True\
--loss-type hybrid\
--limited-dataset 1\
--machine vx\
--input-size 256\
--train-batch 4\
--test-batch 1\
--base-dir $HOME/watermark/10kgray/\
--data _images
# example training scripts for TIP-20
# Improving the Harmony of the Composite Image by Spatial-Separated Attention Module
# * in the original version, the res = False
# suitable for the iHarmony4 dataset.
python /data/home/yb87432/mypaper/s2am/main.py --epochs 200\
--schedule 150\
--lr 1e-3\
-c checkpoint/normal_rasc_HAdobe5k_res \
--arch rascv2\
--style-loss 0\
--ssim-loss 0\
--limited-dataset 0\
--res True\
--machine s2am\
--input-size 256\
--train-batch 16\
--test-batch 1\
--base-dir $HOME/Datasets/\
--data HAdobe5k
================================================
FILE: examples/test.sh
================================================
set -ex
CUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/test.py \
-c test/10kgray_ssim\
--resume /data/home/yb87432/s2am/eval/10kgray/1e3_bs6_256_hybrid_ssim_vgg_vx__images_vvv4n/model_best.pth.tar\
--arch vvv4n\
--machine vx\
--input-size 256\
--test-batch 1\
--evaluate\
--base-dir $HOME/watermark/10kgray/\
--data _images
================================================
FILE: main.py
================================================
from __future__ import print_function, absolute_import
import argparse
import torch,time,os
torch.backends.cudnn.benchmark = True
from scripts.utils.misc import save_checkpoint, adjust_learning_rate
import scripts.datasets as datasets
import scripts.machines as machines
from options import Options
def main(args):
# the old chained `'HFlickr' or ... in args.base_dir` was always truthy;
# test each dataset name against base_dir explicitly
if any(name in args.base_dir for name in ('HFlickr', 'HCOCO', 'Hday2night', 'HAdobe5k')):
dataset_func = datasets.BIH
else:
dataset_func = datasets.COCO
train_loader = torch.utils.data.DataLoader(dataset_func('train',args),batch_size=args.train_batch, shuffle=True,
num_workers=args.workers, pin_memory=True)
val_loader = torch.utils.data.DataLoader(dataset_func('val',args),batch_size=args.test_batch, shuffle=False,
num_workers=args.workers, pin_memory=True)
lr = args.lr
data_loaders = (train_loader,val_loader)
Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)
print('============================ Initialization Finished && Training Start =============================================')
for epoch in range(Machine.args.start_epoch, Machine.args.epochs):
print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))
lr = adjust_learning_rate(data_loaders, Machine.optimizer, epoch, lr, args)
Machine.record('lr',lr, epoch)
Machine.train(epoch)
if args.freq < 0:
Machine.validate(epoch)
Machine.flush()
Machine.save_checkpoint()
if __name__ == '__main__':
parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal'))
args = parser.parse_args()
print('==================================== WaterMark Removal =============================================')
print('==> {:50}: {:<}'.format("Start Time",time.ctime(time.time())))
print('==> {:50}: {:<}'.format("USE GPU",os.environ.get('CUDA_VISIBLE_DEVICES','(unset)')))
print('==================================== Stable Parameters =============================================')
for arg in vars(args):
if type(getattr(args, arg)) == type([]):
if ','.join([ str(i) for i in getattr(args, arg)]) == ','.join([ str(i) for i in parser.get_default(arg)]):
print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)])))
else:
if getattr(args, arg) == parser.get_default(arg):
print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg)))
print('==================================== Changed Parameters =============================================')
for arg in vars(args):
if type(getattr(args, arg)) == type([]):
if ','.join([ str(i) for i in getattr(args, arg)]) != ','.join([ str(i) for i in parser.get_default(arg)]):
print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)])))
else:
if getattr(args, arg) != parser.get_default(arg):
print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg)))
print('==================================== Start Init Model ===============================================')
main(args)
print('==================================== FINISH WITHOUT ERROR =============================================')
================================================
FILE: options.py
================================================
import scripts.models as models
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
class Options():
"""docstring for Options"""
def __init__(self):
pass
def init(self, parser):
# Model structure
parser.add_argument('--arch', '-a', metavar='ARCH', default='dhn',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: dhn)')
parser.add_argument('--darch', metavar='ARCH', default='dhn',
choices=model_names,
help='discriminator architecture: ' +
' | '.join(model_names) +
' (default: dhn)')
parser.add_argument('--machine', '-m', metavar='MACHINE', default='basic')
# Training strategy
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
help='number of data loading workers (default: 2)')
parser.add_argument('--epochs', default=30, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('--train-batch', default=64, type=int, metavar='N',
help='train batchsize')
parser.add_argument('--test-batch', default=6, type=int, metavar='N',
help='test batchsize')
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--dlr', '--dlearning-rate', default=1e-3, type=float, help='initial learning rate of the discriminator')
parser.add_argument('--beta1', default=0.9, type=float, help='beta1 of the Adam optimizer')
parser.add_argument('--beta2', default=0.999, type=float, help='beta2 of the Adam optimizer')
parser.add_argument('--momentum', default=0, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=0, type=float,
metavar='W', help='weight decay (default: 0)')
parser.add_argument('--schedule', type=int, nargs='+', default=[5, 10],
help='Decrease learning rate at these epochs.')
parser.add_argument('--gamma', type=float, default=0.1,
help='LR is multiplied by gamma on schedule.')
# Data processing
parser.add_argument('-f', '--flip', dest='flip', action='store_true',
help='flip the input during validation')
parser.add_argument('--lambdaL1', type=float, default=1, help='the weight of L1.')
parser.add_argument('--alpha', type=float, default=0.5,
help='Groundtruth Gaussian sigma.')
parser.add_argument('--sigma-decay', type=float, default=0,
help='Sigma decay rate for each epoch.')
# Miscs
parser.add_argument('--base-dir', default='/PATH_TO_DATA_FOLDER/', type=str, metavar='PATH')
parser.add_argument('--data', default='', type=str, metavar='PATH',
help='path to save checkpoint (default: checkpoint)')
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
help='path to save checkpoint (default: checkpoint)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--finetune', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--style-loss', default=0, type=float,
help='weight of the perceptual (VGG) loss')
parser.add_argument('--ssim-loss', default=0, type=float, help='weight of the SSIM loss')
parser.add_argument('--att-loss', default=1, type=float, help='weight of the attention loss')
# NOTE: argparse's type=bool treats any non-empty string (even 'False') as True
parser.add_argument('--default-loss', default=False, type=bool)
parser.add_argument('--sltype', default='vggx', type=str)
parser.add_argument('-da', '--data-augumentation', default=False, type=bool,
help='enable data augmentation')
parser.add_argument('-d', '--debug', dest='debug', action='store_true',
help='show intermediate results')
parser.add_argument('--input-size', default=256, type=int, metavar='N',
help='size of the (square) input image')
parser.add_argument('--freq', default=-1, type=int, metavar='N',
help='validation frequency in iterations (<0: validate once per epoch)')
parser.add_argument('--normalized-input', default=False, type=bool,
help='normalize the input image')
parser.add_argument('--res', default=False, type=bool, help='residual learning for s2am')
parser.add_argument('--requires-grad', default=False, type=bool)
parser.add_argument('--limited-dataset', default=0, type=int, metavar='N')
parser.add_argument('--gpu', default=True, type=bool)
parser.add_argument('--masked', default=False, type=bool)
parser.add_argument('--gan-norm', default=False, type=bool, help='normalize images to [-1, 1] (GAN-style)')
parser.add_argument('--hl', default=False, type=bool, help='homogeneous learning')
parser.add_argument('--loss-type', default='l2', type=str, help="reconstruction loss type (e.g. 'l2' or 'hybrid')")
return parser
================================================
FILE: requirements.txt
================================================
numpy==1.19.1
opencv-python==3.4.8.29
Pillow
scikit-image==0.14.5
scikit-learn==0.23.1
scipy==1.2.1
sklearn==0.0
tensorboardX
torch>=1.0.0
torchvision
================================================
FILE: scripts/__init__.py
================================================
from __future__ import absolute_import
from . import datasets
from . import models
from . import utils
# import os, sys
# sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
# from progress.bar import Bar as Bar
# __version__ = '0.1.0'
================================================
FILE: scripts/datasets/BIH.py
================================================
from __future__ import print_function, absolute_import
import os
import csv
import numpy as np
import json
import random
import math
import matplotlib.pyplot as plt
from collections import namedtuple
from os import listdir
from os.path import isfile, join
import torch
import torch.utils.data as data
from scripts.utils.osutils import *
from scripts.utils.imutils import *
from scripts.utils.transforms import *
import torchvision.transforms as transforms
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageFilter
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
class BIH(data.Dataset):
def __init__(self,train,config=None, sample=[],gan_norm=False):
self.train = []
self.anno = []
self.mask = []
self.wm = []
self.input_size = config.input_size
self.normalized_input = config.normalized_input
self.base_folder = config.base_dir + '/' + config.data
self.dataset = config.data
# config is required: its attributes are already read above, so the
# old `if config == None` fallback could never be reached
self.data_augumentation = config.data_augumentation
self.istrain = False if train.find('train') == -1 else True
self.sample = sample
self.gan_norm = gan_norm
mypath = join(self.base_folder,self.dataset+'_'+train+'.txt')
with open(mypath) as f:
# here we get the filenames
file_names = [ im.strip() for im in f.readlines() ]
if config.limited_dataset > 0:
# keep only the first file for each unique identifier (prefix before '-')
xtrain = sorted(list(set([ file_name.split('-')[0] for file_name in file_names ])))
tmp = []
for x in xtrain:
tmp.append([y for y in file_names if x in y][0])
file_names = tmp
for file_name in file_names:
self.train.append(os.path.join(self.base_folder,'images',file_name))
self.mask.append(os.path.join(self.base_folder,'masks','_'.join(file_name.split('_')[0:2])+'.png'))
self.anno.append(os.path.join(self.base_folder,'reals',file_name.split('_')[0]+'.jpg'))
if len(self.sample) > 0 :
self.train = [ self.train[i] for i in self.sample ]
self.mask = [ self.mask[i] for i in self.sample ]
self.anno = [ self.anno[i] for i in self.sample ]
self.trans = transforms.Compose([
transforms.Resize((self.input_size,self.input_size)),
transforms.ToTensor()
])
print('Total size of dataset %s: %d' % (self.dataset, len(self.train)))
def __getitem__(self, index):
img = Image.open(self.train[index]).convert('RGB')
mask = Image.open(self.mask[index]).convert('L')
anno = Image.open(self.anno[index]).convert('RGB')
# for shadow removal and blind image harmonization, here is no ground truth wm
# wm = Image.open(self.wm[index]).convert('RGB')
return {"image": self.trans(img),
"target": self.trans(anno),
"mask": self.trans(mask),
"name": self.train[index].split('/')[-1],
"imgurl":self.train[index],
"maskurl":self.mask[index],
"targeturl":self.anno[index],
}
def __len__(self):
return len(self.train)
================================================
FILE: scripts/datasets/COCO.py
================================================
from __future__ import print_function, absolute_import
import os
import csv
import numpy as np
import json
import random
import math
import matplotlib.pyplot as plt
from collections import namedtuple
from os import listdir
from os.path import isfile, join
import torch
import torch.utils.data as data
from scripts.utils.osutils import *
from scripts.utils.imutils import *
from scripts.utils.transforms import *
import torchvision.transforms as transforms
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageFilter
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
class COCO(data.Dataset):
def __init__(self,train,config=None, sample=[],gan_norm=False):
self.train = []
self.anno = []
self.mask = []
self.wm = []
self.input_size = config.input_size
self.normalized_input = config.normalized_input
self.base_folder = config.base_dir
self.dataset = train+config.data
# config is required: its attributes are already read above, so the
# old `if config == None` fallback could never be reached
self.data_augumentation = config.data_augumentation
self.istrain = False if self.dataset.find('train') == -1 else True
self.sample = sample
self.gan_norm = gan_norm
mypath = join(self.base_folder,self.dataset)
file_names = sorted([f for f in listdir(join(mypath,'image')) if isfile(join(mypath,'image', f)) ])
if config.limited_dataset > 0:
# keep only the first file for each unique identifier (prefix before '-')
xtrain = sorted(list(set([ file_name.split('-')[0] for file_name in file_names ])))
tmp = []
for x in xtrain:
tmp.append([y for y in file_names if x in y][0])
file_names = tmp
for file_name in file_names:
self.train.append(os.path.join(mypath,'image',file_name))
self.mask.append(os.path.join(mypath,'mask',file_name))
self.wm.append(os.path.join(mypath,'wm',file_name))
self.anno.append(os.path.join(self.base_folder,'natural',file_name.split('-')[0]+'.jpg'))
if len(self.sample) > 0 :
self.train = [ self.train[i] for i in self.sample ]
self.mask = [ self.mask[i] for i in self.sample ]
self.anno = [ self.anno[i] for i in self.sample ]
self.trans = transforms.Compose([
transforms.Resize((self.input_size,self.input_size)),
transforms.ToTensor()
])
print('Total size of dataset %s: %d' % (self.dataset, len(self.train)))
def __getitem__(self, index):
img = Image.open(self.train[index]).convert('RGB')
mask = Image.open(self.mask[index]).convert('L')
anno = Image.open(self.anno[index]).convert('RGB')
wm = Image.open(self.wm[index]).convert('RGB')
return {"image": self.trans(img),
"target": self.trans(anno),
"mask": self.trans(mask),
"wm": self.trans(wm),
"name": self.train[index].split('/')[-1],
"imgurl":self.train[index],
"maskurl":self.mask[index],
"targeturl":self.anno[index],
"wmurl":self.wm[index]
}
def __len__(self):
return len(self.train)
================================================
FILE: scripts/datasets/__init__.py
================================================
from .COCO import COCO
from .BIH import BIH
__all__ = ('COCO','BIH')
================================================
FILE: scripts/machines/BasicMachine.py
================================================
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from progress.bar import Bar
import json
import numpy as np
from tensorboardX import SummaryWriter
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.osutils import mkdir_p, isfile, isdir, join
from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
import pytorch_ssim as pytorch_ssim
import torch.optim
import sys,shutil,os
import time
import scripts.models as archs
from math import log10
from torch.autograd import Variable
from scripts.utils.losses import VGGLoss
from scripts.utils.imutils import im_to_numpy
import skimage.io
from skimage.measure import compare_psnr,compare_ssim
class BasicMachine(object):
def __init__(self, datasets =(None,None), models = None, args = None, **kwargs):
super(BasicMachine, self).__init__()
self.args = args
# create model
print("==> creating model ")
self.model = archs.__dict__[self.args.arch]()
print("==> creating model [Finish]")
self.train_loader, self.val_loader = datasets
self.loss = torch.nn.MSELoss()
self.title = '_'+args.machine + '_' + args.data + '_' + args.arch
self.args.checkpoint = args.checkpoint + self.title
self.device = torch.device('cuda')
# create checkpoint dir
if not isdir(self.args.checkpoint):
mkdir_p(self.args.checkpoint)
self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()),
lr=args.lr,
betas=(args.beta1,args.beta2),
weight_decay=args.weight_decay)
if not self.args.evaluate:
self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt')
self.best_acc = 0
self.is_best = False
self.current_epoch = 0
self.metric = -100000
self.hl = 6 if self.args.hl else 1
self.count_gpu = len(range(torch.cuda.device_count()))
if self.args.style_loss > 0:
# init perception loss
self.vggloss = VGGLoss(self.args.sltype).to(self.device)
if self.count_gpu > 1 : # multiple
# self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count()))
# self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count()))
self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
self.model.to(self.device)
self.loss.to(self.device)
print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0))
print('==> Total devices: %d' % (torch.cuda.device_count()))
print('==> Current Checkpoint: %s' % (self.args.checkpoint))
if self.args.resume != '':
self.resume(self.args.resume)
def train(self,epoch):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
lossvgg = AverageMeter()
# switch to train mode
self.model.train()
end = time.time()
bar = Bar('Processing', max=len(self.train_loader)*self.hl)
for _ in range(self.hl):
for i, batches in enumerate(self.train_loader):
# measure data loading time
inputs = batches['image']
target = batches['target'].to(self.device)
mask =batches['mask'].to(self.device)
current_index = len(self.train_loader) * epoch + i
if self.args.hl:
feeded = torch.cat([inputs,mask],dim=1)
else:
feeded = inputs
feeded = feeded.to(self.device)
output = self.model(feeded)
L2_loss = self.loss(output,target)
if self.args.style_loss > 0:
vgg_loss = self.vggloss(output,target,mask)
else:
vgg_loss = 0
total_loss = L2_loss + self.args.style_loss * vgg_loss
# compute gradient and do SGD step
self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()
# measure accuracy and record loss
losses.update(L2_loss.item(), inputs.size(0))
if self.args.style_loss > 0 :
lossvgg.update(vgg_loss.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
# plot progress
suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format(
batch=i + 1,
size=len(self.train_loader),
data=data_time.val,
bt=batch_time.val,
total=bar.elapsed_td,
eta=bar.eta_td,
loss_label=losses.avg,
loss_vgg=lossvgg.avg
)
if current_index % 1000 == 0:
print(suffix)
if self.args.freq > 0 and current_index % self.args.freq == 0:
self.validate(current_index)
self.flush()
self.save_checkpoint()
self.record('train/loss_L2', losses.avg, current_index)
def test(self, ):
# switch to evaluate mode
self.model.eval()
ssimes = AverageMeter()
psnres = AverageMeter()
with torch.no_grad():
for i, batches in enumerate(self.val_loader):
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
mask =batches['mask'].to(self.device)
outputs = self.model(inputs)
# select the outputs by the giving arch
if isinstance(outputs, torch.Tensor):
output = outputs
elif isinstance(outputs[0], list):
output = outputs[0][0]
else:
output = outputs[0]
# recover the image to 255
output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8)
target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)
skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output)
psnr = compare_psnr(target,output)
ssim = compare_ssim(target,output,multichannel=True)
psnres.update(psnr, inputs.size(0))
ssimes.update(ssim, inputs.size(0))
print("%s:PSNR:%s,SSIM:%s"%(self.args.checkpoint,psnres.avg,ssimes.avg))
print("DONE.\n")
def validate(self, epoch):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
ssimes = AverageMeter()
psnres = AverageMeter()
# switch to evaluate mode
self.model.eval()
end = time.time()
with torch.no_grad():
for i, batches in enumerate(self.val_loader):
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
mask =batches['mask'].to(self.device)
if self.args.hl:
feeded = torch.cat([inputs,torch.zeros((1,4,self.args.input_size,self.args.input_size)).to(self.device)],dim=1)
else:
feeded = inputs
output = self.model(feeded)
L2_loss = self.loss(output, target)
psnr = 10 * log10(1 / L2_loss.item())
ssim = pytorch_ssim.ssim(output, target)
losses.update(L2_loss.item(), inputs.size(0))
psnres.update(psnr, inputs.size(0))
ssimes.update(ssim.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
print("Epoches:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f"%(epoch+1, losses.avg,psnres.avg,ssimes.avg))
self.record('val/loss_L2', losses.avg, epoch)
self.record('val/PSNR', psnres.avg, epoch)
self.record('val/SSIM', ssimes.avg, epoch)
self.metric = psnres.avg
def resume(self,resume_path):
if isfile(resume_path):
print("=> loading checkpoint '{}'".format(resume_path))
current_checkpoint = torch.load(resume_path)
# a state_dict is a plain dict, never an nn.DataParallel instance (the old
# isinstance checks were always False); the real multi-GPU artifact is the
# 'module.' key prefix, so strip it when the current model is unwrapped
state_dict = current_checkpoint['state_dict']
if not isinstance(self.model, nn.DataParallel) and any(k.startswith('module.') for k in state_dict):
current_checkpoint['state_dict'] = { k[len('module.'):]: v for k, v in state_dict.items() }
self.args.start_epoch = current_checkpoint['epoch']
self.metric = current_checkpoint['best_acc']
self.model.load_state_dict(current_checkpoint['state_dict'])
# self.optimizer.load_state_dict(current_checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(resume_path, current_checkpoint['epoch']))
else:
raise Exception("=> no checkpoint found at '{}'".format(resume_path))
def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):
is_best = self.metric > self.best_acc
if is_best:
self.best_acc = self.metric
state = {
'epoch': self.current_epoch + 1,
'arch': self.args.arch,
'state_dict': self.model.state_dict(),
'best_acc': self.best_acc,
'optimizer' : self.optimizer.state_dict() if self.optimizer else None,
}
filepath = os.path.join(self.args.checkpoint, filename)
torch.save(state, filepath)
if snapshot and state['epoch'] % snapshot == 0:
shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
if is_best:
self.best_acc = self.metric
print('Saving Best Metric with PSNR:%s'%self.best_acc)
shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar'))
def clean(self):
self.writer.close()
def record(self,k,v,epoch):
self.writer.add_scalar(k, v, epoch)
def flush(self):
self.writer.flush()
sys.stdout.flush()
def norm(self,x):
if self.args.gan_norm:
return x*2.0 - 1.0
else:
return x
def denorm(self,x):
if self.args.gan_norm:
return (x+1.0)/2.0
else:
return x
================================================
FILE: scripts/machines/S2AM.py
================================================
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from progress.bar import Bar
import json
import numpy as np
from tensorboardX import SummaryWriter
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.osutils import mkdir_p, isfile, isdir, join
from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
import pytorch_ssim as pytorch_ssim
import torch.optim
import sys,shutil,os
import time
import scripts.models as archs
from math import log10
from torch.autograd import Variable
from scripts.utils.losses import VGGLoss
from scripts.utils.imutils import im_to_numpy
import skimage.io
from skimage.measure import compare_psnr,compare_ssim
class S2AM(object):
def __init__(self, datasets =(None,None), models = None, args = None, **kwargs):
super(S2AM, self).__init__()
self.args = args
# create model
print("==> creating model ")
self.model = archs.__dict__[self.args.arch]()
print("==> creating model [Finish]")
self.train_loader, self.val_loader = datasets
self.loss = torch.nn.MSELoss()
self.title = '_'+args.machine + '_' + args.data + '_' + args.arch
self.args.checkpoint = args.checkpoint + self.title
self.device = torch.device('cuda')
# create checkpoint dir
if not isdir(self.args.checkpoint):
mkdir_p(self.args.checkpoint)
self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()),
lr=args.lr,
betas=(args.beta1,args.beta2),
weight_decay=args.weight_decay)
if not self.args.evaluate:
self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt')
self.best_acc = 0
self.is_best = False
self.current_epoch = 0
self.hl = 1
self.metric = -100000
self.count_gpu = len(range(torch.cuda.device_count()))
if self.args.style_loss > 0:
# init perception loss
self.vggloss = VGGLoss(self.args.sltype).to(self.device)
if self.count_gpu > 1 : # multiple
# self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count()))
# self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count()))
self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
self.model.to(self.device)
self.loss.to(self.device)
print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0))
print('==> Total devices: %d' % (torch.cuda.device_count()))
print('==> Current Checkpoint: %s' % (self.args.checkpoint))
if self.args.resume != '':
self.resume(self.args.resume)
def train(self,epoch):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
lossvgg = AverageMeter()
# switch to train mode
self.model.train()
end = time.time()
bar = Bar('Processing', max=len(self.train_loader)*self.hl)
for _ in range(self.hl):
for i, batches in enumerate(self.train_loader):
# measure data loading time
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
mask =batches['mask'].to(self.device)
current_index = len(self.train_loader) * epoch + i
feeded = torch.cat([inputs,mask],dim=1)
feeded = feeded.to(self.device)
output = self.model(feeded)
if self.args.res:
output = output + inputs
L2_loss = self.loss(output,target)
if self.args.style_loss > 0:
vgg_loss = self.vggloss(output,target,mask)
else:
vgg_loss = 0
total_loss = L2_loss + self.args.style_loss * vgg_loss
# compute gradient and do SGD step
self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()
# measure accuracy and record loss
losses.update(L2_loss.item(), inputs.size(0))
if self.args.style_loss > 0 :
lossvgg.update(vgg_loss.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
# plot progress
suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format(
batch=i + 1,
size=len(self.train_loader),
data=data_time.val,
bt=batch_time.val,
total=bar.elapsed_td,
eta=bar.eta_td,
loss_label=losses.avg,
loss_vgg=lossvgg.avg
)
if current_index % 1000 == 0:
print(suffix)
if self.args.freq > 0 and current_index % self.args.freq == 0:
self.validate(current_index)
self.flush()
self.save_checkpoint()
self.record('train/loss_L2', losses.avg, current_index)
def test(self, ):
# switch to evaluate mode
self.model.eval()
ssimes = AverageMeter()
psnres = AverageMeter()
with torch.no_grad():
for i, batches in enumerate(self.val_loader):
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
mask =batches['mask'].to(self.device)
feeded = torch.cat([inputs,mask],dim=1)
feeded = feeded.to(self.device)
output = self.model(feeded)
if self.args.res:
output = output + inputs
# recover the image to 255
output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8)
target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)
skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output)
psnr = compare_psnr(target,output)
ssim = compare_ssim(target,output,multichannel=True)
psnres.update(psnr, inputs.size(0))
ssimes.update(ssim, inputs.size(0))
print("%s:PSNR:%s,SSIM:%s"%(self.args.checkpoint,psnres.avg,ssimes.avg))
print("DONE.\n")
def validate(self, epoch):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
ssimes = AverageMeter()
psnres = AverageMeter()
# switch to evaluate mode
self.model.eval()
end = time.time()
with torch.no_grad():
for i, batches in enumerate(self.val_loader):
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
mask =batches['mask'].to(self.device)
feeded = torch.cat([inputs,mask],dim=1)
feeded = feeded.to(self.device)
output = self.model(feeded)
if self.args.res:
output = output + inputs
L2_loss = self.loss(output, target)
psnr = 10 * log10(1 / L2_loss.item())
ssim = pytorch_ssim.ssim(output, target)
losses.update(L2_loss.item(), inputs.size(0))
psnres.update(psnr, inputs.size(0))
ssimes.update(ssim.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
print("Epoches:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f"%(epoch+1, losses.avg,psnres.avg,ssimes.avg))
self.record('val/loss_L2', losses.avg, epoch)
self.record('val/PSNR', psnres.avg, epoch)
self.record('val/SSIM', ssimes.avg, epoch)
self.metric = psnres.avg
def resume(self,resume_path):
if isfile(resume_path):
print("=> loading checkpoint '{}'".format(resume_path))
current_checkpoint = torch.load(resume_path)
            # checkpoints saved from an nn.DataParallel model prefix every key with 'module.';
            # strip the prefix when the current model is not wrapped.
            state_dict = current_checkpoint['state_dict']
            if not isinstance(self.model, torch.nn.DataParallel):
                state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v for k, v in state_dict.items()}
            current_checkpoint['state_dict'] = state_dict
self.args.start_epoch = current_checkpoint['epoch']
self.metric = current_checkpoint['best_acc']
self.model.load_state_dict(current_checkpoint['state_dict'])
# self.optimizer.load_state_dict(current_checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(resume_path, current_checkpoint['epoch']))
else:
raise Exception("=> no checkpoint found at '{}'".format(resume_path))
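Checkpoints saved from a model wrapped in `nn.DataParallel` carry a `module.` prefix on every state-dict key, which an unwrapped model cannot load. A minimal torch-free sketch of stripping that prefix (the plain dict below is a stand-in for a real state dict):

```python
def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that nn.DataParallel adds to state-dict keys."""
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

# stand-in for a checkpoint saved from a DataParallel-wrapped model
ckpt = {'module.conv1.weight': 1, 'module.bn1.bias': 2}
clean = strip_module_prefix(ckpt)
# keys are now loadable by an unwrapped model: 'conv1.weight', 'bn1.bias'
```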
    def save_checkpoint(self, filename='checkpoint.pth.tar', snapshot=None):
        is_best = self.metric > self.best_acc
        if is_best:
            self.best_acc = self.metric
        state = {
            'epoch': self.current_epoch + 1,
            'arch': self.args.arch,
            'state_dict': self.model.state_dict(),
            'best_acc': self.best_acc,
            'optimizer': self.optimizer.state_dict() if self.optimizer else None,
        }
        filepath = os.path.join(self.args.checkpoint, filename)
        torch.save(state, filepath)
        if snapshot and state['epoch'] % snapshot == 0:
            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
        if is_best:
            print('Saving Best Metric with PSNR:%s' % self.best_acc)
            shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar'))
def clean(self):
self.writer.close()
def record(self,k,v,epoch):
self.writer.add_scalar(k, v, epoch)
def flush(self):
self.writer.flush()
sys.stdout.flush()
def norm(self,x):
if self.args.gan_norm:
return x*2.0 - 1.0
else:
return x
def denorm(self,x):
if self.args.gan_norm:
return (x+1.0)/2.0
else:
return x
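When `gan_norm` is enabled, `norm` maps inputs from [0, 1] to [-1, 1] and `denorm` inverts it; a quick standalone check of the round trip:

```python
def norm(x):
    # [0, 1] -> [-1, 1], as used when args.gan_norm is enabled
    return x * 2.0 - 1.0

def denorm(x):
    # [-1, 1] -> [0, 1]
    return (x + 1.0) / 2.0

values = [0.0, 0.25, 0.5, 1.0]
roundtrip = [denorm(norm(v)) for v in values]  # identical to `values`
```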
================================================
FILE: scripts/machines/VX.py
================================================
import torch
import torch.nn as nn
from progress.bar import Bar
from tqdm import tqdm
import pytorch_ssim
import json
import sys,time,os
import torchvision
from math import log10
import numpy as np
from .BasicMachine import BasicMachine
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.misc import resize_to_match
from torch.autograd import Variable
import torch.nn.functional as F
from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
from scripts.utils.losses import VGGLoss, l1_relative,is_dic
from scripts.utils.imutils import im_to_numpy
import skimage.io
try:
    from skimage.measure import compare_psnr, compare_ssim
except ImportError:
    # scikit-image >= 0.18 moved these functions into skimage.metrics
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
    from skimage.metrics import structural_similarity as compare_ssim
class Losses(nn.Module):
def __init__(self, argx, device, norm_func=None, denorm_func=None):
super(Losses, self).__init__()
self.args = argx
if self.args.loss_type == 'l1bl2':
self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss()
elif self.args.loss_type == 'l2xbl2':
self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss()
elif self.args.loss_type == 'relative' or self.args.loss_type == 'hybrid':
self.outputLoss, self.attLoss, self.wrloss = l1_relative, nn.BCELoss(), l1_relative
else: # l2bl2
self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss()
self.default = nn.L1Loss()
if self.args.style_loss > 0:
self.vggloss = VGGLoss(self.args.sltype).to(device)
if self.args.ssim_loss > 0:
self.ssimloss = pytorch_ssim.SSIM().to(device)
self.norm = norm_func
self.denorm = denorm_func
def forward(self,pred_ims,target,pred_ms,mask,pred_wms,wm):
pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = [0]*5
pred_ims = pred_ims if is_dic(pred_ims) else [pred_ims]
# try the loss in the masked region
if self.args.masked and 'hybrid' in self.args.loss_type: # masked loss
pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims])
pixel_loss += sum([self.default(pred_im*pred_ms,target*mask) for pred_im in pred_ims])
recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ]
wm_loss += self.wrloss(pred_wms, wm, mask)
wm_loss += self.default(pred_wms*pred_ms, wm*mask)
elif self.args.masked and 'relative' in self.args.loss_type: # masked loss
pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims])
recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ]
wm_loss = self.wrloss(pred_wms, wm, mask)
elif self.args.masked:
pixel_loss += sum([self.outputLoss(pred_im*mask, target*mask) for pred_im in pred_ims])
recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ]
wm_loss = self.wrloss(pred_wms*mask, wm*mask)
else:
pixel_loss += sum([self.outputLoss(pred_im*pred_ms, target*mask) for pred_im in pred_ims])
recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ]
wm_loss = self.wrloss(pred_wms*pred_ms,wm*mask)
pixel_loss += sum([self.default(im,target) for im in recov_imgs])
if self.args.style_loss > 0:
vgg_loss = sum([self.vggloss(im,target,mask) for im in recov_imgs])
if self.args.ssim_loss > 0:
ssim_loss = sum([ 1 - self.ssimloss(im,target) for im in recov_imgs])
att_loss = self.attLoss(pred_ms, mask)
return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss
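The `recov_imgs` terms above composite the prediction into the (normalized) target outside the mask, so losses on them only penalize the masked region and its blending seam. A scalar sketch of that compositing rule:

```python
def composite(pred, target, mask):
    # keep the prediction where mask == 1, the target elsewhere
    return pred * mask + (1 - mask) * target

# inside the mask the output follows the prediction...
inside = composite(0.8, 0.2, 1.0)   # -> 0.8
# ...outside it is exactly the target, so a pixel loss there is zero
outside = composite(0.8, 0.2, 0.0)  # -> 0.2
```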
class VX(BasicMachine):
def __init__(self,**kwargs):
BasicMachine.__init__(self,**kwargs)
self.loss = Losses(self.args, self.device, self.norm, self.denorm)
self.model.set_optimizers()
self.optimizer = None
def train(self,epoch):
self.current_epoch = epoch
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
lossMask = AverageMeter()
lossWM = AverageMeter()
lossMX = AverageMeter()
lossvgg = AverageMeter()
lossssim = AverageMeter()
# switch to train mode
self.model.train()
end = time.time()
bar = Bar('Processing {} '.format(self.args.arch), max=len(self.train_loader))
for i, batches in enumerate(self.train_loader):
current_index = len(self.train_loader) * epoch + i
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
mask = batches['mask'].to(self.device)
wm = batches['wm'].to(self.device)
outputs = self.model(self.norm(inputs))
self.model.zero_grad_all()
l2_loss,att_loss,wm_loss,style_loss,ssim_loss = self.loss(outputs[0],self.norm(target),outputs[1],mask,outputs[2],self.norm(wm))
total_loss = 2*l2_loss + self.args.att_loss * att_loss + wm_loss + self.args.style_loss * style_loss + self.args.ssim_loss * ssim_loss
# compute gradient and do SGD step
total_loss.backward()
self.model.step_all()
# measure accuracy and record loss
losses.update(l2_loss.item(), inputs.size(0))
lossMask.update(att_loss.item(), inputs.size(0))
lossWM.update(wm_loss.item(), inputs.size(0))
if self.args.style_loss > 0 :
lossvgg.update(style_loss.item(), inputs.size(0))
if self.args.ssim_loss > 0 :
lossssim.update(ssim_loss.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
# plot progress
suffix = "({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss Mask: {loss_mask:.4f} | loss WM: {loss_wm:.4f} | loss VGG: {loss_vgg:.4f} | loss SSIM: {loss_ssim:.4f}| loss MX: {loss_mx:.4f}".format(
batch=i + 1,
size=len(self.train_loader),
data=data_time.val,
bt=batch_time.val,
total=bar.elapsed_td,
eta=bar.eta_td,
loss_label=losses.avg,
loss_mask=lossMask.avg,
loss_wm=lossWM.avg,
loss_vgg=lossvgg.avg,
loss_ssim=lossssim.avg,
loss_mx=lossMX.avg
)
if current_index % 1000 == 0:
print(suffix)
if self.args.freq > 0 and current_index % self.args.freq == 0:
self.validate(current_index)
self.flush()
self.save_checkpoint()
self.record('train/loss_L2', losses.avg, epoch)
self.record('train/loss_Mask', lossMask.avg, epoch)
self.record('train/loss_WM', lossWM.avg, epoch)
self.record('train/loss_VGG', lossvgg.avg, epoch)
self.record('train/loss_SSIM', lossssim.avg, epoch)
self.record('train/loss_MX', lossMX.avg, epoch)
def validate(self, epoch):
self.current_epoch = epoch
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
lossMask = AverageMeter()
psnres = AverageMeter()
ssimes = AverageMeter()
# switch to evaluate mode
self.model.eval()
end = time.time()
bar = Bar('Processing {} '.format(self.args.arch), max=len(self.val_loader))
with torch.no_grad():
for i, batches in enumerate(self.val_loader):
current_index = len(self.val_loader) * epoch + i
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
outputs = self.model(self.norm(inputs))
imoutput,immask,imwatermark = outputs
imoutput = imoutput[0] if is_dic(imoutput) else imoutput
imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))
if i % 300 == 0:
# save the sample images
ims = torch.cat([inputs,target,imfinal,immask.repeat(1,3,1,1)],dim=3)
torchvision.utils.save_image(ims,os.path.join(self.args.checkpoint,'%s_%s.jpg'%(i,epoch)))
                # PSNR computed from the MSE between the composited output and the target
                psnr = 10 * log10(1 / F.mse_loss(imfinal,target).item())
ssim = pytorch_ssim.ssim(imfinal,target)
psnres.update(psnr, inputs.size(0))
ssimes.update(ssim, inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
# plot progress
bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_L2: {loss_label:.4f} | Loss_Mask: {loss_mask:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}'.format(
batch=i + 1,
size=len(self.val_loader),
data=data_time.val,
bt=batch_time.val,
total=bar.elapsed_td,
eta=bar.eta_td,
loss_label=losses.avg,
loss_mask=lossMask.avg,
psnr=psnres.avg,
ssim=ssimes.avg
)
bar.next()
bar.finish()
print("Iter:%s,Losses:%s,PSNR:%.4f,SSIM:%.4f"%(epoch, losses.avg,psnres.avg,ssimes.avg))
self.record('val/loss_L2', losses.avg, epoch)
self.record('val/lossMask', lossMask.avg, epoch)
self.record('val/PSNR', psnres.avg, epoch)
self.record('val/SSIM', ssimes.avg, epoch)
self.metric = psnres.avg
self.model.train()
def test(self, ):
# switch to evaluate mode
self.model.eval()
print("==> testing VM model ")
ssimes = AverageMeter()
psnres = AverageMeter()
ssimesx = AverageMeter()
psnresx = AverageMeter()
with torch.no_grad():
for i, batches in enumerate(tqdm(self.val_loader)):
inputs = batches['image'].to(self.device)
target = batches['target'].to(self.device)
mask =batches['mask'].to(self.device)
                # run the model and split the outputs for the given arch
outputs = self.model(self.norm(inputs))
imoutput,immask,imwatermark = outputs
imoutput = imoutput[0] if is_dic(imoutput) else imoutput
imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))
psnrx = 10 * log10(1 / F.mse_loss(imfinal,target).item())
ssimx = pytorch_ssim.ssim(imfinal,target)
# recover the image to 255
imfinal = im_to_numpy(torch.clamp(imfinal[0]*255,min=0.0,max=255.0)).astype(np.uint8)
target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)
skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), imfinal)
psnr = compare_psnr(target,imfinal)
ssim = compare_ssim(target,imfinal,multichannel=True)
psnres.update(psnr, inputs.size(0))
ssimes.update(ssim, inputs.size(0))
psnresx.update(psnrx, inputs.size(0))
ssimesx.update(ssimx, inputs.size(0))
print("%s:PSNR:%.5f(%.5f),SSIM:%.5f(%.5f)"%(self.args.checkpoint,psnres.avg,psnresx.avg,ssimes.avg,ssimesx.avg))
print("DONE.\n")
================================================
FILE: scripts/machines/__init__.py
================================================
from .BasicMachine import BasicMachine
from .VX import VX
from .S2AM import S2AM
def basic(**kwargs):
return BasicMachine(**kwargs)
def s2am(**kwargs):
return S2AM(**kwargs)
def vx(**kwargs):
return VX(**kwargs)
================================================
FILE: scripts/models/__init__.py
================================================
from .vgg import *
from .backbone_unet import *
from .discriminator import *
================================================
FILE: scripts/models/backbone_unet.py
================================================
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
import math
from scripts.utils.model_init import *
from scripts.models.rasc import *
from scripts.models.unet import UnetGenerator,MinimalUnetV2
from scripts.models.vmu import UnetVM
from scripts.models.sa_resunet import UnetVMS2AMv4
# our method
def vvv4n(**kwargs):
return UnetVMS2AMv4(shared_depth=2, blocks=3, long_skip=True, use_vm_decoder=True,s2am='vms2am')
# BVMR
def vm3(**kwargs):
return UnetVM(shared_depth=2, blocks=3, use_vm_decoder=True)
# Blind version of S2AM
def urasc(**kwargs):
model = UnetGenerator(3,3,is_attention_layer=True,attention_model=URASC,basicblock=MinimalUnetV2)
model.apply(weights_init_kaiming)
return model
# Improving the Harmony of the Composite Image by Spatial-Separated Attention Module
# Xiaodong Cun and Chi-Man Pun
# University of Macau
# Trans. on Image Processing, vol. 29, pp. 4759-4771, 2020.
def rascv2(**kwargs):
model = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2)
model.apply(weights_init_kaiming)
return model
# just original unet
def unet(**kwargs):
model = UnetGenerator(3,3)
model.apply(weights_init_kaiming)
return model
================================================
FILE: scripts/models/blocks.py
================================================
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
import math
import numbers
from scripts.utils.model_init import *
from scripts.models.vgg import Vgg16
from torch import nn, cuda
from torch.autograd import Variable
class BasicLearningBlock(nn.Module):
"""docstring for BasicLearningBlock"""
def __init__(self,channel):
super(BasicLearningBlock, self).__init__()
self.rconv1 = nn.Conv2d(channel,channel*2,3,padding=1,bias=False)
self.rbn1 = nn.BatchNorm2d(channel*2)
self.rconv2 = nn.Conv2d(channel*2,channel,3,padding=1,bias=False)
self.rbn2 = nn.BatchNorm2d(channel)
def forward(self,feature):
return F.elu(self.rbn2(self.rconv2(F.elu(self.rbn1(self.rconv1(feature))))))
# From https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/3
class GaussianSmoothing(nn.Module):
"""
Apply gaussian smoothing on a
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
in the input using a depthwise convolution.
Arguments:
channels (int, sequence): Number of channels of the input tensors. Output will
have this number of channels as well.
kernel_size (int, sequence): Size of the gaussian kernel.
sigma (float, sequence): Standard deviation of the gaussian kernel.
dim (int, optional): The number of dimensions of the data.
Default value is 2 (spatial).
"""
def __init__(self, channels, kernel_size, sigma, dim=2):
super(GaussianSmoothing, self).__init__()
if isinstance(kernel_size, numbers.Number):
kernel_size = [kernel_size] * dim
if isinstance(sigma, numbers.Number):
sigma = [sigma] * dim
# The gaussian kernel is the product of the
# gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid(
[
torch.arange(size, dtype=torch.float32)
for size in kernel_size
]
)
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                torch.exp(-((mgrid - mean) / std) ** 2 / 2)
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / torch.sum(kernel)
# Reshape to depthwise convolutional weight
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.groups = channels
if dim == 1:
self.conv = F.conv1d
elif dim == 2:
self.conv = F.conv2d
elif dim == 3:
self.conv = F.conv3d
else:
raise RuntimeError(
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
)
def forward(self, input):
"""
Apply gaussian filter to input.
Arguments:
input (torch.Tensor): Input to apply gaussian filter on.
Returns:
filtered (torch.Tensor): Filtered output.
"""
return self.conv(input, weight=self.weight, groups=self.groups)
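`GaussianSmoothing` builds its depthwise kernel as an outer product of per-dimension 1-D Gaussians and renormalizes it to sum to 1 (so any constant scale factor cancels, though the exponent still sets the width). The 1-D construction can be checked without torch:

```python
import math

def gaussian_kernel_1d(size, sigma):
    # standard 1-D Gaussian, renormalized so the taps sum to 1
    mean = (size - 1) / 2
    kernel = [math.exp(-((x - mean) / sigma) ** 2 / 2) for x in range(size)]
    total = sum(kernel)
    return [k / total for k in kernel]

k = gaussian_kernel_1d(5, 1.0)
# symmetric, peaked at the center, sums to 1
```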
class ChannelPool(nn.Module):
def __init__(self,types):
super(ChannelPool, self).__init__()
if types == 'avg':
self.poolingx = nn.AdaptiveAvgPool1d(1)
elif types == 'max':
self.poolingx = nn.AdaptiveMaxPool1d(1)
        else:
            raise ValueError('unsupported pooling type: %s' % types)
def forward(self, input):
n, c, w, h = input.size()
input = input.view(n,c,w*h).permute(0,2,1)
pooled = self.poolingx(input)# b,w*h,c -> b,w*h,1
_, _, c = pooled.size()
return pooled.view(n,c,w,h)
class SEBlock(nn.Module):
"""docstring for SEBlock"""
def __init__(self, channel,reducation=16):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel,channel//reducation),
nn.ReLU(inplace=True),
nn.Linear(channel//reducation,channel),
nn.Sigmoid())
def forward(self,x):
b,c,w,h = x.size()
y1 = self.avg_pool(x).view(b,c)
y = self.fc(y1).view(b,c,1,1)
return x*y
class GlobalAttentionModule(nn.Module):
"""docstring for GlobalAttentionModule"""
def __init__(self, channel,reducation=16):
super(GlobalAttentionModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel*2,channel//reducation),
nn.ReLU(inplace=True),
nn.Linear(channel//reducation,channel),
nn.Sigmoid())
def forward(self,x):
b,c,w,h = x.size()
y1 = self.avg_pool(x).view(b,c)
y2 = self.max_pool(x).view(b,c)
y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1)
return x*y
class SpatialAttentionModule(nn.Module):
"""docstring for SpatialAttentionModule"""
def __init__(self, channel,reducation=16):
super(SpatialAttentionModule, self).__init__()
self.avg_pool = ChannelPool('avg')
self.max_pool = ChannelPool('max')
self.fc = nn.Sequential(
nn.Conv2d(2,reducation,7,stride=1,padding=3),
nn.ReLU(inplace=True),
nn.Conv2d(reducation,1,7,stride=1,padding=3),
nn.Sigmoid())
def forward(self,x):
b,c,w,h = x.size()
y1 = self.avg_pool(x)
y2 = self.max_pool(x)
y = self.fc(torch.cat([y1,y2],1))
yr = 1-y
return y,yr
class GlobalAttentionModuleJustSigmoid(nn.Module):
"""docstring for GlobalAttentionModule"""
def __init__(self, channel,reducation=16):
super(GlobalAttentionModuleJustSigmoid, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel*2,channel//reducation),
nn.ReLU(inplace=True),
nn.Linear(channel//reducation,channel),
nn.Sigmoid())
def forward(self,x):
b,c,w,h = x.size()
y1 = self.avg_pool(x).view(b,c)
y2 = self.max_pool(x).view(b,c)
y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1)
return y
class BasicBlock(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
super(BasicBlock, self).__init__()
self.out_channels = out_planes
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
self.relu = nn.ReLU() if relu else None
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
class ChannelGate(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
super(ChannelGate, self).__init__()
self.gate_channels = gate_channels
self.mlp = nn.Sequential(
Flatten(),
nn.Linear(gate_channels, gate_channels // reduction_ratio),
nn.ReLU(),
nn.Linear(gate_channels // reduction_ratio, gate_channels)
)
self.pool_types = pool_types
def forward(self, x):
channel_att_sum = None
for pool_type in self.pool_types:
if pool_type=='avg':
avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( avg_pool )
elif pool_type=='max':
max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( max_pool )
elif pool_type=='lp':
lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( lp_pool )
elif pool_type=='lse':
# LSE pool only
lse_pool = logsumexp_2d(x)
channel_att_raw = self.mlp( lse_pool )
if channel_att_sum is None:
channel_att_sum = channel_att_raw
else:
channel_att_sum = channel_att_sum + channel_att_raw
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
return x * scale
def logsumexp_2d(tensor):
tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
return outputs
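`logsumexp_2d` uses the standard max-shift trick: subtracting the maximum before exponentiating avoids overflow without changing the result. A scalar version with a known value:

```python
import math

def logsumexp(values):
    # subtract the max before exponentiating to keep exp() in a safe range
    s = max(values)
    return s + math.log(sum(math.exp(v - s) for v in values))

# logsumexp of two equal values v is v + ln(2); works even where
# a naive log(sum(exp(v))) would overflow (e.g. v = 1000)
result = logsumexp([1000.0, 1000.0])
```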
class ChannelPoolX(nn.Module):
def forward(self, x):
return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
class SpatialGate(nn.Module):
def __init__(self):
super(SpatialGate, self).__init__()
kernel_size = 7
self.compress = ChannelPoolX()
self.spatial = BasicBlock(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
def forward(self, x):
x_compress = self.compress(x)
x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out) # broadcasting
return x * scale
class CBAM(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
super(CBAM, self).__init__()
self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
self.no_spatial=no_spatial
if not no_spatial:
self.SpatialGate = SpatialGate()
def forward(self, x):
x_out = self.ChannelGate(x)
if not self.no_spatial:
x_out = self.SpatialGate(x_out)
return x_out
================================================
FILE: scripts/models/discriminator.py
================================================
import numpy as np
import functools
import math
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torch.nn import Parameter
from scripts.utils.model_init import *
from torch.optim.optimizer import Optimizer, required
__all__ = ['patchgan','sngan','maskedsngan']
class SNCoXvWithActivation(torch.nn.Module):
    """
    Convolution with spectral normalization (as in SN-GAN), followed by an optional activation.
    """
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
super(SNCoXvWithActivation, self).__init__()
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
self.activation = activation
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
def forward(self, input):
x = self.conv2d(input)
if self.activation is not None:
return self.activation(x)
else:
return x
def l2normalize(v, eps=1e-12):
return v / (v.norm() + eps)
class SpectralNorm(nn.Module):
def __init__(self, module, name='weight', power_iterations=1):
super(SpectralNorm, self).__init__()
self.module = module
self.name = name
self.power_iterations = power_iterations
if not self._made_params():
self._make_params()
def _update_u_v(self):
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
height = w.data.shape[0]
for _ in range(self.power_iterations):
v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
sigma = u.dot(w.view(height, -1).mv(v))
setattr(self.module, self.name, w / sigma.expand_as(w))
def _made_params(self):
try:
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
return True
except AttributeError:
return False
def _make_params(self):
w = getattr(self.module, self.name)
height = w.data.shape[0]
width = w.view(height, -1).data.shape[1]
u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
u.data = l2normalize(u.data)
v.data = l2normalize(v.data)
w_bar = Parameter(w.data)
del self.module._parameters[self.name]
self.module.register_parameter(self.name + "_u", u)
self.module.register_parameter(self.name + "_v", v)
self.module.register_parameter(self.name + "_bar", w_bar)
def forward(self, *args):
self._update_u_v()
return self.module.forward(*args)
def get_pad(in_, ksize, stride, atrous=1):
out_ = np.ceil(float(in_)/stride)
return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2)
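`get_pad` computes the symmetric padding that keeps "same"-style output sizes for a strided convolution. Note the discriminator below calls it with `ksize=5` while the convolutions themselves use kernel 4; that combination happens to halve the resolution exactly. A pure-Python check against the first layer (input 256, stride 2):

```python
import math

def get_pad(in_, ksize, stride, atrous=1):
    # padding needed so a strided conv produces ceil(in_/stride) outputs
    out_ = math.ceil(float(in_) / stride)
    return int(((out_ - 1) * stride + atrous * (ksize - 1) + 1 - in_) / 2)

p = get_pad(256, 5, 2)               # padding used by the first SN layer
out = (256 + 2 * p - 4) // 2 + 1     # conv output size with the actual kernel 4
```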
class SNDiscriminator(nn.Module):
def __init__(self,channel=6):
super(SNDiscriminator, self).__init__()
cnum = 32
self.discriminator_net = nn.Sequential(
SNCoXvWithActivation(channel, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)),
SNCoXvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)),
SNCoXvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)),
SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)),
SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)), # 8*8*256
# SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)), # 4*4*256
# SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(4, 5, 2)), # 2*2*256
)
# self.linear = nn.Linear(2*2*256,1)
def forward(self, img_A, img_B):
# Concatenate image and condition image by channels to produce input
img_input = torch.cat((img_A, img_B), 1)
x = self.discriminator_net(img_input)
# x = x.view((x.size(0),-1))
# x = self.linear(x)
return x
class Discriminator(nn.Module):
def __init__(self, in_channels=3):
super(Discriminator, self).__init__()
def discriminator_block(in_filters, out_filters, normalization=True):
"""Returns downsampling layers of each discriminator block"""
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
if normalization:
layers.append(nn.InstanceNorm2d(out_filters))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*discriminator_block(in_channels*2, 64, normalization=False),
*discriminator_block(64, 128),
*discriminator_block(128, 256),
*discriminator_block(256, 512),
nn.ZeroPad2d((1, 0, 1, 0)),
nn.Conv2d(512, 1, 4, padding=1, bias=False)
)
def forward(self, img_A, img_B):
# Concatenate image and condition image by channels to produce input
img_input = torch.cat((img_A, img_B), 1)
return self.model(img_input)
def patchgan():
model = Discriminator()
model.apply(weights_init_kaiming)
return model
def sngan():
model = SNDiscriminator()
model.apply(weights_init_kaiming)
return model
def maskedsngan():
model = SNDiscriminator(channel=7)
model.apply(weights_init_kaiming)
return model
================================================
FILE: scripts/models/rasc.py
================================================
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from scripts.utils.model_init import *
from scripts.models.vgg import Vgg16
from scripts.models.blocks import *
class CAWapper(nn.Module):
    """Wrapper around contextual attention, guided by the downsampled mask."""
def __init__(self, channel, type_of_connection=BasicLearningBlock):
super(CAWapper, self).__init__()
self.attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10, fuse=True, use_cuda=True)
def forward(self, feature, mask):
_, _, w, _ = feature.size()
_, _, mw, _ = mask.size()
        # binarize the downsampled mask and use the background features as
        # additional context for the masked splicing region.
mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w))
result = self.attention(feature,mask)
return result
class NLWapper(nn.Module):
    """Wrapper around a non-local attention block (the mask is unused)."""
def __init__(self, channel, type_of_connection=BasicLearningBlock):
super(NLWapper, self).__init__()
self.attention = NONLocalBlock2D(channel)
def forward(self, feature, mask):
_, _, w, _ = feature.size()
_, _, mw, _ = mask.size()
        # the non-local block attends over the whole feature map; the mask is unused here.
# mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w))
result = self.attention(feature)
return result
class SENet(nn.Module):
"""docstring for SENet"""
def __init__(self,channel,type_of_connection=BasicLearningBlock):
super(SENet, self).__init__()
self.attention = SEBlock(channel,16)
def forward(self,feature,mask):
_,_,w,_ = feature.size()
_,_,mw,_ = mask.size()
        # binarize the downsampled mask (the SE attention itself is mask-agnostic).
mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w))
result = self.attention(feature)
return result
class CBAMConnect(nn.Module):
def __init__(self,channel):
super(CBAMConnect, self).__init__()
self.attention = CBAM(channel)
def forward(self,feature,mask):
results = self.attention(feature)
return results
class RASC(nn.Module):
def __init__(self,channel,type_of_connection=BasicLearningBlock):
super(RASC, self).__init__()
self.connection = type_of_connection(channel)
self.background_attention = GlobalAttentionModule(channel,16)
self.mixed_attention = GlobalAttentionModule(channel,16)
self.spliced_attention = GlobalAttentionModule(channel,16)
self.gaussianMask = GaussianSmoothing(1,5,1)
def forward(self,feature,mask):
_,_,w,_ = feature.size()
_,_,mw,_ = mask.size()
        # binarize the mask and downsample it to the feature resolution
        # use features from the background region as extra context for the masked spliced region
if w != mw:
mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w))
reverse_mask = -1*(mask-1)
        # smooth the mask and its inverse with a Gaussian filter for better harmonization around the edges
mask = self.gaussianMask(F.pad(mask,(2,2,2,2),mode='reflect'))
reverse_mask = self.gaussianMask(F.pad(reverse_mask,(2,2,2,2),mode='reflect'))
background = self.background_attention(feature) * reverse_mask
selected_feature = self.mixed_attention(feature)
spliced_feature = self.spliced_attention(feature)
spliced = ( self.connection(spliced_feature) + selected_feature ) * mask
return background + spliced
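The mask handling in `RASC.forward` can be sketched in isolation. A minimal example with hypothetical shapes (a 64-channel feature map at 32x32 and a full-resolution 256x256 binary mask): the mask is average-pooled down to the feature resolution, re-binarized with `torch.round`, and then used to blend a background branch and a spliced branch.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: feature map at 32x32, mask at the full 256x256 input resolution.
feature = torch.randn(2, 64, 32, 32)
mask = (torch.rand(2, 1, 256, 256) > 0.5).float()

w, mw = feature.size(2), mask.size(2)
if w != mw:
    # same pooling as RASC.forward: kernel 2, stride mw//w, then re-binarize
    mask = torch.round(F.avg_pool2d(mask, 2, stride=mw // w))
reverse_mask = -1 * (mask - 1)  # i.e. 1 - mask

# the attended branches are then combined as background * reverse_mask + spliced * mask;
# plain stand-ins here in place of the attention modules
out = feature * reverse_mask + (feature * 2) * mask
```

With a 256 input, kernel 2 and stride 8, the pooled mask comes out at 32x32, matching the feature map.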
class UNO(nn.Module):
def __init__(self,channel):
super(UNO, self).__init__()
def forward(self,feature,_m):
return feature
class URASC(nn.Module):
def __init__(self,channel,type_of_connection=BasicLearningBlock):
super(URASC, self).__init__()
self.connection = type_of_connection(channel)
self.background_attention = GlobalAttentionModule(channel,16)
self.mixed_attention = GlobalAttentionModule(channel,16)
self.spliced_attention = GlobalAttentionModule(channel,16)
self.mask_attention = SpatialAttentionModule(channel,16)
def forward(self,feature, m=None):
_,_,w,_ = feature.size()
mask, reverse_mask = self.mask_attention(feature)
background = self.background_attention(feature) * reverse_mask
selected_feature = self.mixed_attention(feature)
spliced_feature = self.spliced_attention(feature)
spliced = ( self.connection(spliced_feature) + selected_feature ) * mask
return background + spliced
class MaskedURASC(nn.Module):
def __init__(self,channel,type_of_connection=BasicLearningBlock):
super(MaskedURASC, self).__init__()
self.connection = type_of_connection(channel)
self.background_attention = GlobalAttentionModule(channel,16)
self.mixed_attention = GlobalAttentionModule(channel,16)
self.spliced_attention = GlobalAttentionModule(channel,16)
self.mask_attention = SpatialAttentionModule(channel,16)
def forward(self,feature):
_,_,w,_ = feature.size()
mask, reverse_mask = self.mask_attention(feature)
background = self.background_attention(feature) * reverse_mask
selected_feature = self.mixed_attention(feature)
spliced_feature = self.spliced_attention(feature)
spliced = ( self.connection(spliced_feature) + selected_feature ) * mask
return background + spliced, mask
================================================
FILE: scripts/models/sa_resunet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from scripts.models.blocks import SEBlock
from scripts.models.rasc import *
from scripts.models.unet import UnetGenerator,MinimalUnetV2
def weight_init(m):
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def reset_params(model):
for i, m in enumerate(model.modules()):
weight_init(m)
def conv3x3(in_channels, out_channels, stride=1,
padding=1, bias=True, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=padding,
bias=bias,
groups=groups)
def up_conv2x2(in_channels, out_channels, transpose=True):
if transpose:
return nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=2,
stride=2)
else:
return nn.Sequential(
nn.Upsample(mode='bilinear', scale_factor=2),
conv1x1(in_channels, out_channels))
def conv1x1(in_channels, out_channels, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
groups=groups,
stride=1)
class UpCoXvD(nn.Module):
def __init__(self, in_channels, out_channels, blocks, residual=True,norm=nn.BatchNorm2d, act=F.relu,batch_norm=True, transpose=True,concat=True,use_att=False):
super(UpCoXvD, self).__init__()
self.concat = concat
self.residual = residual
self.batch_norm = batch_norm
self.bn = None
self.conv2 = []
self.use_att = use_att
self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose)
self.norm0 = norm(out_channels)
if self.use_att:
self.s2am = RASC(2 * out_channels)
else:
self.s2am = None
if self.concat:
            self.conv1 = conv3x3(2 * out_channels, out_channels)
            self.norm1 = norm(out_channels)  # norm layers take a single num_features argument
        else:
            self.conv1 = conv3x3(out_channels, out_channels)
            self.norm1 = norm(out_channels)
for _ in range(blocks):
self.conv2.append(conv3x3(out_channels, out_channels))
if self.batch_norm:
self.bn = []
for _ in range(blocks):
self.bn.append(norm(out_channels))
self.bn = nn.ModuleList(self.bn)
self.conv2 = nn.ModuleList(self.conv2)
self.act = act
def forward(self, from_up, from_down, mask=None,se=None):
from_up = self.act(self.norm0(self.up_conv(from_up)))
if self.concat:
x1 = torch.cat((from_up, from_down), 1)
else:
if from_down is not None:
x1 = from_up + from_down
else:
x1 = from_up
if self.use_att:
x1 = self.s2am(x1,mask)
x1 = self.act(self.norm1(self.conv1(x1)))
x2 = None
for idx, conv in enumerate(self.conv2):
x2 = conv(x1)
if self.batch_norm:
x2 = self.bn[idx](x2)
if (se is not None) and (idx == len(self.conv2) - 1): # last
x2 = se(x2)
if self.residual:
x2 = x2 + x1
x2 = self.act(x2)
x1 = x2
return x2
class DownCoXvD(nn.Module):
def __init__(self, in_channels, out_channels, blocks, pooling=True, norm=nn.BatchNorm2d,act=F.relu,residual=True, batch_norm=True):
super(DownCoXvD, self).__init__()
self.pooling = pooling
self.residual = residual
self.batch_norm = batch_norm
self.bn = None
self.pool = None
self.conv1 = conv3x3(in_channels, out_channels)
self.norm1 = norm(out_channels)
self.conv2 = []
for _ in range(blocks):
self.conv2.append(conv3x3(out_channels, out_channels))
if self.batch_norm:
self.bn = []
for _ in range(blocks):
self.bn.append(norm(out_channels))
self.bn = nn.ModuleList(self.bn)
if self.pooling:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.ModuleList(self.conv2)
self.act = act
def __call__(self, x):
return self.forward(x)
def forward(self, x):
x1 = self.act(self.norm1(self.conv1(x)))
x2 = None
for idx, conv in enumerate(self.conv2):
x2 = conv(x1)
if self.batch_norm:
x2 = self.bn[idx](x2)
if self.residual:
x2 = x2 + x1
x2 = self.act(x2)
x1 = x2
before_pool = x2
if self.pooling:
x2 = self.pool(x2)
return x2, before_pool
class UnetDecoderD(nn.Module):
def __init__(self, in_channels=512, out_channels=3, norm=nn.BatchNorm2d,act=F.relu, depth=5, blocks=1, residual=True, batch_norm=True,
transpose=True, concat=True, is_final=True, use_att=False):
super(UnetDecoderD, self).__init__()
self.conv_final = None
self.up_convs = []
self.atts = []
self.use_att = use_att
outs = in_channels
        for i in range(depth-1):  # one up-conv per level above the bottleneck
ins = outs
outs = ins // 2
# 512,256
# 256,128
# 128,64
# 64,32
up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
concat=concat, norm=norm, act=act)
if self.use_att:
self.atts.append(SEBlock(outs))
self.up_convs.append(up_conv)
if is_final:
self.conv_final = conv1x1(outs, out_channels)
else:
up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
concat=concat,norm=norm, act=act)
if self.use_att:
self.atts.append(SEBlock(out_channels))
self.up_convs.append(up_conv)
self.up_convs = nn.ModuleList(self.up_convs)
self.atts = nn.ModuleList(self.atts)
reset_params(self)
def __call__(self, x, encoder_outs=None):
return self.forward(x, encoder_outs)
def forward(self, x, encoder_outs=None):
for i, up_conv in enumerate(self.up_convs):
before_pool = None
if encoder_outs is not None:
before_pool = encoder_outs[-(i+2)]
x = up_conv(x, before_pool)
if self.use_att:
x = self.atts[i](x)
if self.conv_final is not None:
x = self.conv_final(x)
return x
class UnetDecoderDatt(nn.Module):
def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True,
transpose=True, concat=True, is_final=True, norm=nn.BatchNorm2d,act=F.relu):
super(UnetDecoderDatt, self).__init__()
self.conv_final = None
self.up_convs = []
self.im_atts = []
self.vm_atts = []
self.mask_atts = []
outs = in_channels
for i in range(depth-1): # depth = 5 [0,1,2,3]
ins = outs
outs = ins // 2
# 512,256
# 256,128
# 128,64
# 64,32
up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
concat=concat, norm=nn.BatchNorm2d,act=F.relu)
self.up_convs.append(up_conv)
self.im_atts.append(SEBlock(outs))
self.vm_atts.append(SEBlock(outs))
self.mask_atts.append(SEBlock(outs))
if is_final:
self.conv_final = conv1x1(outs, out_channels)
else:
up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
concat=concat, norm=nn.BatchNorm2d,act=F.relu)
self.up_convs.append(up_conv)
self.im_atts.append(SEBlock(out_channels))
self.vm_atts.append(SEBlock(out_channels))
self.mask_atts.append(SEBlock(out_channels))
self.up_convs = nn.ModuleList(self.up_convs)
self.im_atts = nn.ModuleList(self.im_atts)
self.vm_atts = nn.ModuleList(self.vm_atts)
self.mask_atts = nn.ModuleList(self.mask_atts)
reset_params(self)
def forward(self, input, encoder_outs=None):
# im branch
x = input
for i, up_conv in enumerate(self.up_convs):
before_pool = None
if encoder_outs is not None:
before_pool = encoder_outs[-(i+2)]
x = up_conv(x, before_pool,se=self.im_atts[i])
x_im = x
x = input
for i, up_conv in enumerate(self.up_convs):
before_pool = None
if encoder_outs is not None:
before_pool = encoder_outs[-(i+2)]
x = up_conv(x, before_pool, se = self.mask_atts[i])
x_mask = x
x = input
for i, up_conv in enumerate(self.up_convs):
before_pool = None
if encoder_outs is not None:
before_pool = encoder_outs[-(i+2)]
x = up_conv(x, before_pool, se=self.vm_atts[i])
x_vm = x
return x_im,x_mask,x_vm
class UnetEncoderD(nn.Module):
def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True, norm=nn.BatchNorm2d, act=F.relu):
super(UnetEncoderD, self).__init__()
self.down_convs = []
outs = None
if type(blocks) is tuple:
blocks = blocks[0]
for i in range(depth):
ins = in_channels if i == 0 else outs
outs = start_filters*(2**i)
pooling = True if i < depth-1 else False
down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm, norm=nn.BatchNorm2d, act=F.relu)
self.down_convs.append(down_conv)
self.down_convs = nn.ModuleList(self.down_convs)
reset_params(self)
def __call__(self, x):
return self.forward(x)
def forward(self, x):
encoder_outs = []
for d_conv in self.down_convs:
x, before_pool = d_conv(x)
encoder_outs.append(before_pool)
return x, encoder_outs
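As a quick sanity check of the encoder wiring above, the channel widths follow `start_filters*(2**i)` and pooling is applied at every level except the last. A small sketch of that progression for the default `depth=5`, `start_filters=32`:

```python
# Channel progression of UnetEncoderD for depth=5, start_filters=32
depth, start_filters = 5, 32
widths = [start_filters * (2 ** i) for i in range(depth)]
print(widths)  # [32, 64, 128, 256, 512]

# pooling halves the spatial size at every level except the bottleneck
pooling = [i < depth - 1 for i in range(depth)]
print(pooling)  # [True, True, True, True, False]
```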
class ResDown(nn.Module):
def __init__(self, in_size, out_size, pooling=True, use_att=False):
super(ResDown, self).__init__()
self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling)
def forward(self, x):
return self.model(x)
class ResUp(nn.Module):
def __init__(self, in_size, out_size, use_att=False):
super(ResUp, self).__init__()
self.model = UpCoXvD(in_size, out_size, 3, use_att=use_att)
def forward(self, x, skip_input, mask=None):
return self.model(x,skip_input,mask)
class ResDownNew(nn.Module):
def __init__(self, in_size, out_size, pooling=True, use_att=False):
super(ResDownNew, self).__init__()
self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling, norm=nn.InstanceNorm2d, act=F.leaky_relu)
def forward(self, x):
return self.model(x)
class ResUpNew(nn.Module):
def __init__(self, in_size, out_size, use_att=False):
super(ResUpNew, self).__init__()
self.model = UpCoXvD(in_size, out_size, 3, use_att=use_att, norm=nn.InstanceNorm2d)
def forward(self, x, skip_input, mask=None):
return self.model(x,skip_input,mask)
class VMSingle(nn.Module):
def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32, res=True,use_att=False):
super(VMSingle, self).__init__()
self.down1 = down(in_channels, ngf)
self.down2 = down(ngf, ngf*2)
self.down3 = down(ngf*2, ngf*4)
self.down4 = down(ngf*4, ngf*8)
self.down5 = down(ngf*8, ngf*16, pooling=False)
self.up1 = up(ngf*16, ngf*8)
self.up2 = up(ngf*8, ngf*4, use_att=use_att)
self.up3 = up(ngf*4, ngf*2, use_att=use_att)
self.up4 = up(ngf*2, ngf*1, use_att=use_att)
self.im = nn.Conv2d(ngf, 3, 1)
self.res = res
def forward(self, input):
img, mask = input[:,0:3,:,:],input[:,3:4,:,:]
# U-Net generator with skip connections from encoder to decoder
x,d1 = self.down1(input) # 128,256
x,d2 = self.down2(x) # 64,128
x,d3 = self.down3(x) # 32,64
x,d4 = self.down4(x) # 16,32
x,_ = self.down5(x) # 8,16
x = self.up1(x, d4) # 16
x = self.up2(x, d3, mask) # 32
x = self.up3(x, d2, mask) # 64
x = self.up4(x, d1, mask) # 128
im = self.im(x)
return im
class VMSingleS2AM(nn.Module):
def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32):
super(VMSingleS2AM, self).__init__()
self.down1 = down(in_channels, ngf)
self.down2 = down(ngf, ngf*2)
self.down3 = down(ngf*2, ngf*4)
self.down4 = down(ngf*4, ngf*8)
self.down5 = down(ngf*8, ngf*16, pooling=False)
self.up1 = up(ngf*16, ngf*8)
self.up2 = up(ngf*8, ngf*4)
self.s2am2 = RASC(ngf*4)
self.up3 = up(ngf*4, ngf*2)
self.s2am3 = RASC(ngf*2)
self.up4 = up(ngf*2, ngf*1)
self.s2am4 = RASC(ngf)
self.im = nn.Conv2d(ngf, 3, 1)
def forward(self, input):
img, mask = input[:,0:3,:,:],input[:,3:4,:,:]
# U-Net generator with skip connections from encoder to decoder
x,d1 = self.down1(input) # 128,256
x,d2 = self.down2(x) # 64,128
x,d3 = self.down3(x) # 32,64
x,d4 = self.down4(x) # 16,32
x,_ = self.down5(x) # 8,16
x = self.up1(x, d4) # 16
x = self.up2(x, d3) # 32
x = self.s2am2(x, mask)
x = self.up3(x, d2) # 64
x = self.s2am3(x, mask)
x = self.up4(x, d1) # 128
x = self.s2am4(x, mask)
im = self.im(x)
return im
class UnetVMS2AMv4(nn.Module):
def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1,
out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True,
transpose=True, concat=True, transfer_data=True, long_skip=False, s2am='unet', use_coarser=True,no_stage2=False):
super(UnetVMS2AMv4, self).__init__()
self.transfer_data = transfer_data
self.shared = shared_depth
self.optimizer_encoder, self.optimizer_image, self.optimizer_vm = None, None, None
self.optimizer_mask, self.optimizer_shared = None, None
if type(blocks) is not tuple:
blocks = (blocks, blocks, blocks, blocks, blocks)
if not transfer_data:
concat = False
self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0],
start_filters=start_filters, residual=residual, batch_norm=batch_norm,norm=nn.InstanceNorm2d,act=F.leaky_relu)
self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
out_channels=out_channels_image, depth=depth - shared_depth,
blocks=blocks[1], residual=residual, batch_norm=batch_norm,
transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)
self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
out_channels=out_channels_mask, depth=depth - shared_depth,
blocks=blocks[2], residual=residual, batch_norm=batch_norm,
transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)
self.vm_decoder = None
if use_vm_decoder:
self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
out_channels=out_channels_image, depth=depth - shared_depth,
blocks=blocks[3], residual=residual, batch_norm=batch_norm,
transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)
self.shared_decoder = None
self.use_coarser = use_coarser
self.long_skip = long_skip
self.no_stage2 = no_stage2
self._forward = self.unshared_forward
if self.shared != 0:
self._forward = self.shared_forward
self.shared_decoder = UnetDecoderDatt(in_channels=start_filters * 2 ** (depth - 1),
out_channels=start_filters * 2 ** (depth - shared_depth - 1),
depth=shared_depth, blocks=blocks[4], residual=residual,
batch_norm=batch_norm, transpose=transpose, concat=concat,
is_final=False,norm=nn.InstanceNorm2d)
if s2am == 'unet':
self.s2am = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2)
elif s2am == 'vm':
self.s2am = VMSingle(4)
elif s2am == 'vms2am':
self.s2am = VMSingleS2AM(4,down=ResDownNew,up=ResUpNew)
def set_optimizers(self):
self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001)
self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001)
self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001)
self.optimizer_s2am = torch.optim.Adam(self.s2am.parameters(), lr=0.001)
if self.vm_decoder is not None:
self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001)
if self.shared != 0:
self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001)
def zero_grad_all(self):
self.optimizer_encoder.zero_grad()
self.optimizer_image.zero_grad()
self.optimizer_mask.zero_grad()
self.optimizer_s2am.zero_grad()
if self.vm_decoder is not None:
self.optimizer_vm.zero_grad()
if self.shared != 0:
self.optimizer_shared.zero_grad()
def step_all(self):
self.optimizer_encoder.step()
self.optimizer_image.step()
self.optimizer_mask.step()
self.optimizer_s2am.step()
if self.vm_decoder is not None:
self.optimizer_vm.step()
if self.shared != 0:
self.optimizer_shared.step()
def step_optimizer_image(self):
self.optimizer_image.step()
def __call__(self, synthesized):
return self._forward(synthesized)
def forward(self, synthesized):
return self._forward(synthesized)
def unshared_forward(self, synthesized):
image_code, before_pool = self.encoder(synthesized)
if not self.transfer_data:
before_pool = None
reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool))
reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))
if self.vm_decoder is not None:
reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool))
return reconstructed_image, reconstructed_mask, reconstructed_vm
return reconstructed_image, reconstructed_mask
def shared_forward(self, synthesized):
image_code, before_pool = self.encoder(synthesized)
if self.transfer_data:
shared_before_pool = before_pool[- self.shared - 1:]
unshared_before_pool = before_pool[: - self.shared]
else:
before_pool = None
shared_before_pool = None
unshared_before_pool = None
im,mask,vm = self.shared_decoder(image_code, shared_before_pool)
reconstructed_image = torch.tanh(self.image_decoder(im, unshared_before_pool))
if self.long_skip:
reconstructed_image = reconstructed_image + synthesized
reconstructed_mask = torch.sigmoid(self.mask_decoder(mask, unshared_before_pool))
if self.vm_decoder is not None:
reconstructed_vm = torch.tanh(self.vm_decoder(vm, unshared_before_pool))
if self.long_skip:
reconstructed_vm = reconstructed_vm + synthesized
coarser = reconstructed_image * reconstructed_mask + (1-reconstructed_mask)* synthesized
if self.use_coarser:
refine = torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) + coarser
elif self.no_stage2:
refine = torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1)))
else:
refine = torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) + synthesized
# final = refine * reconstructed_mask + (1-reconstructed_mask)* synthesized
if self.vm_decoder is not None:
return [refine, reconstructed_image], reconstructed_mask, reconstructed_vm
else:
return [refine, reconstructed_image], reconstructed_mask
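The coarse composition in `shared_forward` can be illustrated with stand-in tensors (random here, in place of the real decoder outputs): the predicted background replaces the input only inside the predicted soft mask, and pixels where the mask is zero pass through unchanged.

```python
import torch

# Stand-ins for the SplitNet outputs (hypothetical values; real ones come from the decoders).
synthesized = torch.rand(1, 3, 256, 256)          # watermarked input
reconstructed_image = torch.rand(1, 3, 256, 256)  # predicted background
reconstructed_mask = torch.rand(1, 1, 256, 256)   # soft watermark mask in [0, 1]

# Same composition as shared_forward: reconstruction inside the mask, input outside it.
coarser = reconstructed_image * reconstructed_mask + (1 - reconstructed_mask) * synthesized

# With an all-zero mask the coarse result is exactly the input.
zero_mask = torch.zeros(1, 1, 256, 256)
identity = reconstructed_image * zero_mask + (1 - zero_mask) * synthesized
```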
================================================
FILE: scripts/models/unet.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools
from scripts.models.blocks import *
from scripts.models.rasc import *
class MinimalUnetV2(nn.Module):
"""docstring for MinimalUnet"""
def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwags):
super(MinimalUnetV2, self).__init__()
self.down = nn.Sequential(*down)
self.up = nn.Sequential(*up)
self.sub = submodule
self.attention = attention
self.withoutskip = withoutskip
        self.is_attention = self.attention is not None
        self.is_sub = submodule is not None
def forward(self,x,mask=None):
if self.is_sub:
x_up,_ = self.sub(self.down(x),mask)
else:
x_up = self.down(x)
if self.withoutskip: #outer or inner.
x_out = self.up(x_up)
else:
if self.is_attention:
x_out = (self.attention(torch.cat([x,self.up(x_up)],1),mask),mask)
else:
x_out = (torch.cat([x,self.up(x_up)],1),mask)
return x_out
class MinimalUnet(nn.Module):
"""docstring for MinimalUnet"""
def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwags):
super(MinimalUnet, self).__init__()
self.down = nn.Sequential(*down)
self.up = nn.Sequential(*up)
self.sub = submodule
self.attention = attention
self.withoutskip = withoutskip
        self.is_attention = self.attention is not None
        self.is_sub = submodule is not None
def forward(self,x,mask=None):
if self.is_sub:
x_up,_ = self.sub(self.down(x),mask)
else:
x_up = self.down(x)
if self.is_attention:
x = self.attention(x,mask)
if self.withoutskip: #outer or inner.
x_out = self.up(x_up)
else:
x_out = (torch.cat([x,self.up(x_up)],1),mask)
return x_out
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, input_nc=None,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,is_attention_layer=False,
attention_model=RASC,basicblock=MinimalUnet,outermostattention=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
if input_nc is None:
input_nc = outer_nc
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1)
down = [downconv]
up = [uprelu, upconv]
model = basicblock(down,up,submodule,withoutskip=outermost)
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = basicblock(down,up)
else:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if is_attention_layer:
if MinimalUnetV2.__qualname__ in basicblock.__qualname__ :
attention_model = attention_model(input_nc*2)
else:
attention_model = attention_model(input_nc)
else:
attention_model = None
            if use_dropout:
                # list.append returns None, so extend the list in the call instead of mutating `up`
                model = basicblock(down, up + [nn.Dropout(0.5)], submodule, attention_model, outermostattention=outermostattention)
            else:
                model = basicblock(down, up, submodule, attention_model, outermostattention=outermostattention)
self.model = model
def forward(self, x,mask=None):
# build the mask for attention use
return self.model(x,mask)
class UnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, num_downs=8, ngf=64,norm_layer=nn.BatchNorm2d, use_dropout=False,
is_attention_layer=False,attention_model=RASC,use_inner_attention=False,basicblock=MinimalUnet):
super(UnetGenerator, self).__init__()
# 8 for 256x256
# 9 for 512x512
# construct unet structure
        self.need_mask = input_nc != output_nc
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True,basicblock=basicblock) # 1
for i in range(num_downs - 5): #3 times
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout,is_attention_layer=use_inner_attention,attention_model=attention_model,basicblock=basicblock) # 8,4,2
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #16
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #32
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock, outermostattention=True) #64
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, basicblock=basicblock, norm_layer=norm_layer) # 128
self.model = unet_block
def forward(self, input):
if self.need_mask:
return self.model(input,input[:,3:4,:,:])
else:
return self.model(input[:,0:3,:,:],input[:,3:4,:,:])
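The `num_downs` comments above ("8 for 256x256", "9 for 512x512") follow from each `UnetSkipConnectionBlock` halving the spatial size with a stride-2 convolution, so the bottleneck reaches 1x1 exactly when `num_downs` equals log2 of the input resolution. A quick check:

```python
# Each skip block halves the spatial size, so the bottleneck resolution is input // 2**num_downs.
def bottleneck_size(input_size, num_downs):
    return input_size // (2 ** num_downs)

print(bottleneck_size(256, 8))  # 1
print(bottleneck_size(512, 9))  # 1
```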
================================================
FILE: scripts/models/vgg.py
================================================
from collections import namedtuple
import torch
from torchvision import models
class Vgg16(torch.nn.Module):
def __init__(self, requires_grad=False):
super(Vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23,30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
# vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3','relu5_3'])
# out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return (h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
class Vgg19(torch.nn.Module):
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
# vgg_pretrained_features = models.vgg19(pretrained=True).features
self.vgg_pretrained_features = models.vgg19(pretrained=True).features
# self.slice1 = torch.nn.Sequential()
# self.slice2 = torch.nn.Sequential()
# self.slice3 = torch.nn.Sequential()
# self.slice4 = torch.nn.Sequential()
# self.slice5 = torch.nn.Sequential()
# for x in range(2):
# self.slice1.add_module(str(x), vgg_pretrained_features[x])
# for x in range(2, 7):
# self.slice2.add_module(str(x), vgg_pretrained_features[x])
# for x in range(7, 12):
# self.slice3.add_module(str(x), vgg_pretrained_features[x])
# for x in range(12, 21):
# self.slice4.add_module(str(x), vgg_pretrained_features[x])
# for x in range(21, 30):
# self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X, indices=None):
if indices is None:
indices = [2, 7, 12, 21, 30]
out = []
#indices = sorted(indices)
for i in range(indices[-1]):
X = self.vgg_pretrained_features[i](X)
if (i+1) in indices:
out.append(X)
return out
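`Vgg19.forward` collects activations by running the layers up to the largest requested index and recording every output whose 1-based position appears in `indices`. A self-contained sketch of that loop, using a tiny stand-in `nn.Sequential` (hypothetical layers, so no pretrained download is needed):

```python
import torch
import torch.nn as nn

# Tiny stand-in for models.vgg19(pretrained=True).features: six small convs.
features = nn.Sequential(*[nn.Conv2d(3 if i == 0 else 4, 4, 3, padding=1) for i in range(6)])

def extract(x, indices):
    # same loop as Vgg19.forward: run up to the last index, collecting intermediate outputs
    out = []
    for i in range(indices[-1]):
        x = features[i](x)
        if (i + 1) in indices:
            out.append(x)
    return out

feats = extract(torch.randn(1, 3, 16, 16), indices=[2, 4, 6])
print(len(feats))  # 3
```

With the real VGG-19 features and the default `indices=[2, 7, 12, 21, 30]`, the returned list holds the relu1_1-through-relu5_1-style activations in order.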
================================================
FILE: scripts/models/vmu.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from scripts.models.blocks import SEBlock
from scripts.models.rasc import *
from scripts.models.unet import UnetGenerator,MinimalUnetV2
def weight_init(m):
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def reset_params(model):
for i, m in enumerate(model.modules()):
weight_init(m)
def conv3x3(in_channels, out_channels, stride=1,
padding=1, bias=True, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=padding,
bias=bias,
groups=groups)
def up_conv2x2(in_channels, out_channels, transpose=True):
if transpose:
return nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=2,
stride=2)
else:
return nn.Sequential(
nn.Upsample(mode='bilinear', scale_factor=2),
conv1x1(in_channels, out_channels))
def conv1x1(in_channels, out_channels, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
groups=groups,
stride=1)
class UpCoXvD(nn.Module):
def __init__(self, in_channels, out_channels, blocks, residual=True, batch_norm=True, transpose=True,concat=True,use_att=False):
super(UpCoXvD, self).__init__()
self.concat = concat
self.residual = residual
self.batch_norm = batch_norm
self.bn = None
self.conv2 = []
self.use_att = use_att
self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose)
if self.use_att:
self.s2am = RASC(2 * out_channels)
else:
self.s2am = None
if self.concat:
self.conv1 = conv3x3(2 * out_channels, out_channels)
else:
self.conv1 = conv3x3(out_channels, out_channels)
for _ in range(blocks):
self.conv2.append(conv3x3(out_channels, out_channels))
if self.batch_norm:
self.bn = []
for _ in range(blocks):
self.bn.append(nn.BatchNorm2d(out_channels))
self.bn = nn.ModuleList(self.bn)
self.conv2 = nn.ModuleList(self.conv2)
def forward(self, from_up, from_down, mask=None):
from_up = self.up_conv(from_up)
if self.concat:
x1 = torch.cat((from_up, from_down), 1)
else:
if from_down is not None:
x1 = from_up + from_down
else:
x1 = from_up
if self.use_att:
x1 = self.s2am(x1,mask)
x1 = F.relu(self.conv1(x1))
x2 = None
for idx, conv in enumerate(self.conv2):
x2 = conv(x1)
if self.batch_norm:
x2 = self.bn[idx](x2)
if self.residual:
x2 = x2 + x1
x2 = F.relu(x2)
x1 = x2
return x2
class DownCoXvD(nn.Module):
def __init__(self, in_channels, out_channels, blocks, pooling=True, residual=True, batch_norm=True):
super(DownCoXvD, self).__init__()
self.pooling = pooling
self.residual = residual
self.batch_norm = batch_norm
self.bn = None
self.pool = None
self.conv1 = conv3x3(in_channels, out_channels)
self.conv2 = []
for _ in range(blocks):
self.conv2.append(conv3x3(out_channels, out_channels))
if self.batch_norm:
self.bn = []
for _ in range(blocks):
self.bn.append(nn.BatchNorm2d(out_channels))
self.bn = nn.ModuleList(self.bn)
if self.pooling:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.ModuleList(self.conv2)
def __call__(self, x):
return self.forward(x)
def forward(self, x):
x1 = F.relu(self.conv1(x))
x2 = None
for idx, conv in enumerate(self.conv2):
x2 = conv(x1)
if self.batch_norm:
x2 = self.bn[idx](x2)
if self.residual:
x2 = x2 + x1
x2 = F.relu(x2)
x1 = x2
before_pool = x2
if self.pooling:
x2 = self.pool(x2)
return x2, before_pool
class UnetDecoderD(nn.Module):
def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True,
transpose=True, concat=True, is_final=True):
super(UnetDecoderD, self).__init__()
self.conv_final = None
self.up_convs = []
outs = in_channels
for i in range(depth-1):
ins = outs
outs = ins // 2
# 512,256
# 256,128
# 128,64
# 64,32
up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
concat=concat)
self.up_convs.append(up_conv)
if is_final:
self.conv_final = conv1x1(outs, out_channels)
else:
up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
concat=concat)
self.up_convs.append(up_conv)
self.up_convs = nn.ModuleList(self.up_convs)
reset_params(self)
def __call__(self, x, encoder_outs=None):
return self.forward(x, encoder_outs)
def forward(self, x, encoder_outs=None):
for i, up_conv in enumerate(self.up_convs):
before_pool = None
if encoder_outs is not None:
before_pool = encoder_outs[-(i+2)]
x = up_conv(x, before_pool)
if self.conv_final is not None:
x = self.conv_final(x)
return x
class UnetEncoderD(nn.Module):
def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True):
super(UnetEncoderD, self).__init__()
self.down_convs = []
outs = None
if type(blocks) is tuple:
blocks = blocks[0]
for i in range(depth):
ins = in_channels if i == 0 else outs
outs = start_filters*(2**i)
pooling = True if i < depth-1 else False
down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm)
self.down_convs.append(down_conv)
self.down_convs = nn.ModuleList(self.down_convs)
reset_params(self)
def __call__(self, x):
return self.forward(x)
def forward(self, x):
encoder_outs = []
for d_conv in self.down_convs:
x, before_pool = d_conv(x)
encoder_outs.append(before_pool)
return x, encoder_outs
class UnetVM(nn.Module):
def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1,
out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True,
transpose=True, concat=True, transfer_data=True, long_skip=False):
super(UnetVM, self).__init__()
self.transfer_data = transfer_data
self.shared = shared_depth
self.optimizer_encoder, self.optimizer_image, self.optimizer_vm = None, None, None
self.optimizer_mask, self.optimizer_shared = None, None
if type(blocks) is not tuple:
blocks = (blocks, blocks, blocks, blocks, blocks)
if not transfer_data:
concat = False
self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0],
start_filters=start_filters, residual=residual, batch_norm=batch_norm)
self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
out_channels=out_channels_image, depth=depth - shared_depth,
blocks=blocks[1], residual=residual, batch_norm=batch_norm,
transpose=transpose, concat=concat)
self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1),
out_channels=out_channels_mask, depth=depth,
blocks=blocks[2], residual=residual, batch_norm=batch_norm,
transpose=transpose, concat=concat)
self.vm_decoder = None
if use_vm_decoder:
self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
out_channels=out_channels_image, depth=depth - shared_depth,
blocks=blocks[3], residual=residual, batch_norm=batch_norm,
transpose=transpose, concat=concat)
self.shared_decoder = None
self.long_skip = long_skip
self._forward = self.unshared_forward
if self.shared != 0:
self._forward = self.shared_forward
self.shared_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1),
out_channels=start_filters * 2 ** (depth - shared_depth - 1),
depth=shared_depth, blocks=blocks[4], residual=residual,
batch_norm=batch_norm, transpose=transpose, concat=concat,
is_final=False)
def set_optimizers(self):
self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001)
self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001)
self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001)
if self.vm_decoder is not None:
self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001)
if self.shared != 0:
self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001)
def zero_grad_all(self):
self.optimizer_encoder.zero_grad()
self.optimizer_image.zero_grad()
self.optimizer_mask.zero_grad()
if self.vm_decoder is not None:
self.optimizer_vm.zero_grad()
if self.shared != 0:
self.optimizer_shared.zero_grad()
def step_all(self):
self.optimizer_encoder.step()
self.optimizer_image.step()
self.optimizer_mask.step()
if self.vm_decoder is not None:
self.optimizer_vm.step()
if self.shared != 0:
self.optimizer_shared.step()
def step_optimizer_image(self):
self.optimizer_image.step()
def __call__(self, synthesized):
return self._forward(synthesized)
def forward(self, synthesized):
return self._forward(synthesized)
def unshared_forward(self, synthesized):
image_code, before_pool = self.encoder(synthesized)
if not self.transfer_data:
before_pool = None
reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool))
reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))
if self.vm_decoder is not None:
reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool))
return reconstructed_image, reconstructed_mask, reconstructed_vm
return reconstructed_image, reconstructed_mask
def shared_forward(self, synthesized):
image_code, before_pool = self.encoder(synthesized)
if self.transfer_data:
shared_before_pool = before_pool[- self.shared - 1:]
unshared_before_pool = before_pool[: - self.shared]
else:
before_pool = None
shared_before_pool = None
unshared_before_pool = None
x = self.shared_decoder(image_code, shared_before_pool)
reconstructed_image = torch.tanh(self.image_decoder(x, unshared_before_pool))
if self.long_skip:
reconstructed_image = reconstructed_image + synthesized
reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))
if self.vm_decoder is not None:
reconstructed_vm = torch.tanh(self.vm_decoder(x, unshared_before_pool))
if self.long_skip:
reconstructed_vm = reconstructed_vm + synthesized
return reconstructed_image, reconstructed_mask, reconstructed_vm
return reconstructed_image, reconstructed_mask
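The channel widths of `UnetVM` follow directly from `depth`, `shared_depth`, and `start_filters`: the encoder doubles `start_filters` at every level, and the task-specific decoders pick up from the shared decoder at `start_filters * 2 ** (depth - shared_depth - 1)` channels. A small standalone sketch of that bookkeeping (helper names are illustrative, not part of the repo):

```python
# Hypothetical sketch of the channel-width arithmetic used by UnetVM above.
def encoder_widths(depth, start_filters=32):
    # UnetEncoderD: outs = start_filters * (2 ** i) at level i
    return [start_filters * (2 ** i) for i in range(depth)]

def image_decoder_in_channels(depth, shared_depth, start_filters=32):
    # where the image/vm decoders take over after the shared decoder
    return start_filters * 2 ** (depth - shared_depth - 1)

widths = encoder_widths(5)
assert widths == [32, 64, 128, 256, 512]
# with shared_depth=2 the shared decoder maps 512 -> 128 channels
assert image_decoder_in_channels(5, 2) == 128
# with no shared decoder the image decoder starts from the bottleneck width
assert image_decoder_in_channels(5, 0) == 512
```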
================================================
FILE: scripts/utils/__init__.py
================================================
from __future__ import absolute_import
from .evaluation import *
from .imutils import *
from .logger import *
from .misc import *
from .osutils import *
from .transforms import *
================================================
FILE: scripts/utils/evaluation.py
================================================
from __future__ import absolute_import
import math
import numpy as np
import matplotlib.pyplot as plt
from random import randint
from .misc import *
from .transforms import transform, transform_preds
__all__ = ['accuracy', 'AverageMeter']
def get_preds(scores):
''' get predictions from score maps in torch Tensor
return type: torch.LongTensor
'''
assert scores.dim() == 4, 'Score maps should be 4-dim'
maxval, idx = torch.max(scores.view(scores.size(0), scores.size(1), -1), 2)
maxval = maxval.view(scores.size(0), scores.size(1), 1)
idx = idx.view(scores.size(0), scores.size(1), 1) + 1
preds = idx.repeat(1, 1, 2).float()
preds[:,:,0] = (preds[:,:,0] - 1) % scores.size(3) + 1
preds[:,:,1] = torch.floor((preds[:,:,1] - 1) / scores.size(2)) + 1
pred_mask = maxval.gt(0).repeat(1, 1, 2).float()
preds *= pred_mask
return preds
def calc_dists(preds, target, normalize):
preds = preds.float()
target = target.float()
dists = torch.zeros(preds.size(1), preds.size(0))
for n in range(preds.size(0)):
for c in range(preds.size(1)):
if target[n,c,0] > 1 and target[n, c, 1] > 1:
dists[c, n] = torch.dist(preds[n,c,:], target[n,c,:])/normalize[n]
else:
dists[c, n] = -1
return dists
def dist_acc(dists, thr=0.5):
''' Return percentage below threshold while ignoring values with a -1 '''
if dists.ne(-1).sum() > 0:
return dists.le(thr).eq(dists.ne(-1)).sum()*1.0 / dists.ne(-1).sum()
else:
return -1
def accuracy(output, target, thr=0.5):
''' Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations
First value to be returned is average accuracy across 'idxs', followed by individual accuracies
'''
# output_mask = torch.gt(output,thr);
# target_mask = torch.gt(target,thr);
# equal_mask = torch.eq(output_mask,target_mask);
# fp_equal_mask = torch.lt(output_mask,target_mask);
# fn_equal_mask = torch.gt(output_mask,target_mask);
# tp = torch.sum(equal_mask);
# fn = torch.sum(fn_equal_mask);
# fp = torch.sum(fp_equal_mask);
# return 2*tp / (2*tp+fn+fp)
# both branches of the original if/else were identical; take the argmax over the class dimension
v, i = torch.max(output, 1)
return torch.sum(target.long() == i).float()/target.numel()
def final_preds(output, center, scale, res):
coords = get_preds(output) # float type
# pose-processing
for n in range(coords.size(0)):
for p in range(coords.size(1)):
hm = output[n][p]
px = int(math.floor(coords[n][p][0]))
py = int(math.floor(coords[n][p][1]))
if px > 1 and px < res[0] and py > 1 and py < res[1]:
diff = torch.Tensor([hm[py - 1][px] - hm[py - 1][px - 2], hm[py][px - 1]-hm[py - 2][px - 1]])
coords[n][p] += diff.sign() * .25
coords += 0.5
preds = coords.clone()
# Transform back
for i in range(coords.size(0)):
preds[i] = transform_preds(coords[i], center[i], scale[i], res)
if preds.dim() < 3:
preds = preds.view(1, preds.size(0), preds.size(1))
return preds
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
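To illustrate the running-average bookkeeping of `AverageMeter`, here is a minimal self-contained restatement of the class with a usage check; note that `update(val, n)` treats `val` as the mean of `n` samples, not a sum:

```python
# Minimal restatement of the AverageMeter above (same fields, same update rule).
class AverageMeter:
    def __init__(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n      # val is the mean over n samples
        self.count += n
        self.avg = self.sum / self.count

m = AverageMeter()
m.update(10, n=2)   # two samples with mean 10
m.update(4, n=1)    # one sample of value 4
assert m.sum == 24 and m.count == 3
assert abs(m.avg - 8.0) < 1e-9
```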
================================================
FILE: scripts/utils/imutils.py
================================================
from __future__ import absolute_import
import torch
import torch.nn as nn
import numpy as np
import scipy.misc
import matplotlib.pyplot as plt
from matplotlib import cm
from .misc import *
def im_to_numpy(img):
img = to_numpy(img)
img = np.transpose(img, (1, 2, 0)) # H*W*C
return img
def im_to_torch(img):
img = np.transpose(img, (2, 0, 1)) # C*H*W
img = to_torch(img).float()
if img.max() > 1:
img /= 255
return img
def load_image(img_path):
# H x W x C => C x H x W
return im_to_torch(scipy.misc.imread(img_path, mode='RGB'))
def imread_all(img_path):
return scipy.misc.imread(img_path, mode='RGB')
def load_image_gray(img_path):
# H x W x C => C x H x W
x = scipy.misc.imread(img_path, mode='L')
x = x[:,:,np.newaxis]
return im_to_torch(x)
def resize(img, owidth, oheight):
img = im_to_numpy(img)
if img.shape[2] == 1:
img = scipy.misc.imresize(img.squeeze(),(oheight,owidth))
img = img[:,:,np.newaxis]
else:
img = scipy.misc.imresize(
img,
(oheight, owidth)
)
img = im_to_torch(img)
# print('%f %f' % (img.min(), img.max()))
return img
# =============================================================================
# Helpful functions generating groundtruth labelmap
# =============================================================================
def gaussian(shape=(7,7),sigma=1):
"""
2D gaussian mask - should give the same result as MATLAB's
fspecial('gaussian',[shape],[sigma])
"""
m,n = [(ss-1.)/2. for ss in shape]
y,x = np.ogrid[-m:m+1,-n:n+1]
h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
return to_torch(h).float()
def draw_labelmap(img, pt, sigma, type='Gaussian'):
# Draw a 2D gaussian
# Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py
img = to_numpy(img)
# Check that any part of the gaussian is in-bounds
ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
br[0] < 0 or br[1] < 0):
# If not, just return the image as is
return to_torch(img)
# Generate gaussian
size = 6 * sigma + 1
x = np.arange(0, size, 1, float)
y = x[:, np.newaxis]
x0 = y0 = size // 2
# The gaussian is not normalized, we want the center value to equal 1
if type == 'Gaussian':
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
elif type == 'Cauchy':
g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
# Usable gaussian range
g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
# Image range
img_x = max(0, ul[0]), min(br[0], img.shape[1])
img_y = max(0, ul[1]), min(br[1], img.shape[0])
img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
return to_torch(img)
# =============================================================================
# Helpful display functions
# =============================================================================
def gauss(x, a, b, c, d=0):
return a * np.exp(-(x - b)**2 / (2 * c**2)) + d
def color_heatmap(x):
x = to_numpy(x)
color = np.zeros((x.shape[0],x.shape[1],3))
color[:,:,0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
color[:,:,1] = gauss(x, 1, .5, .3)
color[:,:,2] = gauss(x, 1, .2, .3)
color[color > 1] = 1
color = (color * 255).astype(np.uint8)
return color
def imshow(img):
npimg = im_to_numpy(img*255).astype(np.uint8)
plt.imshow(npimg)
plt.axis('off')
def show_joints(img, pts):
imshow(img)
for i in range(pts.size(0)):
if pts[i, 2] > 0:
plt.plot(pts[i, 0], pts[i, 1], 'yo')
plt.axis('off')
def show_sample(inputs, target):
num_sample = inputs.size(0)
num_joints = target.size(1)
height = target.size(2)
width = target.size(3)
for n in range(num_sample):
inp = resize(inputs[n], width, height)
out = inp
for p in range(num_joints):
tgt = inp*0.5 + color_heatmap(target[n,p,:,:])*0.5
out = torch.cat((out, tgt), 2)
imshow(out)
plt.show()
def sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None):
inp = to_numpy(inp * 255)
out = to_numpy(out)
img = np.zeros((inp.shape[1], inp.shape[2], inp.shape[0]))
for i in range(3):
img[:, :, i] = inp[i, :, :]
if parts_to_show is None:
parts_to_show = np.arange(out.shape[0])
# Generate a single image to display input/output pair
num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows))
size = img.shape[0] // num_rows
full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), np.uint8)
full_img[:img.shape[0], :img.shape[1]] = img
inp_small = scipy.misc.imresize(img, [size, size])
# Set up heatmap display for each part
for i, part in enumerate(parts_to_show):
part_idx = part
out_resized = scipy.misc.imresize(out[part_idx], [size, size])
out_resized = out_resized.astype(float)/255
out_img = inp_small.copy() * .3
color_hm = color_heatmap(out_resized)
out_img += color_hm * .7
col_offset = (i % num_cols + num_rows) * size
row_offset = (i // num_cols) * size
full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img
return full_img
def batch_with_heatmap(inputs, outputs, mean=torch.Tensor([0.5, 0.5, 0.5]), num_rows=2, parts_to_show=None):
batch_img = []
for n in range(min(inputs.size(0), 4)):
inp = inputs[n] + mean.view(3, 1, 1).expand_as(inputs[n])
batch_img.append(
sample_with_heatmap(inp.clamp(0, 1), outputs[n], num_rows=num_rows, parts_to_show=parts_to_show)
)
return np.concatenate(batch_img)
def normalize_batch(batch):
# normalize using imagenet mean and std
mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
batch = batch/255.0
return (batch - mean) / std
def show_image_tensor(tensor):
re = []
for i in range(tensor.size(0)):
inp = tensor[i].data.cpu() #w,h,c
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1).transpose((2,0,1))
re.append(torch.from_numpy(inp).unsqueeze(0))
return torch.cat(re,0)
def get_jet():
colormap_int = np.zeros((256, 3), np.uint8)
for i in range(0, 256, 1):
colormap_int[i, 0] = np.int_(np.round(cm.jet(i)[0] * 255.0))
colormap_int[i, 1] = np.int_(np.round(cm.jet(i)[1] * 255.0))
colormap_int[i, 2] = np.int_(np.round(cm.jet(i)[2] * 255.0))
return colormap_int
def clamp(num, min_value, max_value):
return max(min(num, max_value), min_value)
def gray2color(gray_array, color_map):
rows, cols = gray_array.shape
color_array = np.zeros((rows, cols, 3), np.uint8)
for i in range(0, rows):
for j in range(0, cols):
# log(256,2) = 8 , log(1,2) = 0 * 8
color_array[i, j] = color_map[clamp(int(abs(gray_array[i, j])*10),0,255)]
return color_array
class objectview(object):
def __init__(self, *args, **kwargs):
d = dict(*args, **kwargs)
self.__dict__ = d
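The `gaussian()` helper above builds an unnormalized 2D Gaussian whose center value is exactly 1, matching MATLAB's `fspecial('gaussian', shape, sigma)` before normalization. A NumPy-only restatement (no torch) that checks those properties:

```python
import numpy as np

# NumPy restatement of the gaussian() helper above.
def gaussian_np(shape=(7, 7), sigma=1):
    m, n = [(s - 1.0) / 2.0 for s in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x + y * y) / (2.0 * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0   # zero out negligible tails
    return h

g = gaussian_np((7, 7), sigma=1)
assert g.shape == (7, 7)
assert g[3, 3] == g.max() == 1.0   # unnormalized: center value is exactly 1
assert np.allclose(g, g.T)         # symmetric kernel
```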
================================================
FILE: scripts/utils/logger.py
================================================
# A simple torch style logger
# (C) Wei YANG 2017
from __future__ import absolute_import
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
__all__ = ['Logger', 'LoggerMonitor', 'savefig']
def savefig(fname, dpi=None):
dpi = 150 if dpi is None else dpi
plt.savefig(fname, dpi=dpi)
def plot_overlap(logger, names=None):
names = logger.names if names is None else names
numbers = logger.numbers
for _, name in enumerate(names):
x = np.arange(len(numbers[name]))
plt.plot(x, np.asarray(numbers[name]))
return [logger.title + '(' + name + ')' for name in names]
class Logger(object):
'''Save training process to log file with simple plot function.'''
def __init__(self, fpath, title=None, resume=False):
self.file = None
self.resume = resume
self.title = '' if title is None else title
if fpath is not None:
if resume:
self.file = open(fpath, 'r')
name = self.file.readline()
self.names = name.rstrip().split('\t')
self.numbers = {}
for _, name in enumerate(self.names):
self.numbers[name] = []
for numbers in self.file:
numbers = numbers.rstrip().split('\t')
for i in range(0, len(numbers)):
self.numbers[self.names[i]].append(numbers[i])
self.file.close()
self.file = open(fpath, 'a')
else:
self.file = open(fpath, 'w')
def set_names(self, names):
if self.resume:
pass
# initialize numbers as empty list
self.numbers = {}
self.names = names
for _, name in enumerate(self.names):
self.file.write(name)
self.file.write('\t')
self.numbers[name] = []
self.file.write('\n')
self.file.flush()
def append(self, numbers):
assert len(self.names) == len(numbers), 'Numbers do not match names'
for index, num in enumerate(numbers):
self.file.write("{0:.6f}".format(num))
self.file.write('\t')
self.numbers[self.names[index]].append(num)
self.file.write('\n')
self.file.flush()
def plot(self, names=None):
names = self.names if names is None else names
numbers = self.numbers
for _, name in enumerate(names):
x = np.arange(len(numbers[name]))
plt.plot(x, np.asarray(numbers[name]))
plt.legend([self.title + '(' + name + ')' for name in names])
plt.grid(True)
def close(self):
if self.file is not None:
self.file.close()
class LoggerMonitor(object):
'''Load and visualize multiple logs.'''
def __init__ (self, paths):
'''paths is a dictionary with {name: filepath} pairs'''
self.loggers = []
for title, path in paths.items():
logger = Logger(path, title=title, resume=True)
self.loggers.append(logger)
def plot(self, names=None):
plt.figure()
plt.subplot(121)
legend_text = []
for logger in self.loggers:
legend_text += plot_overlap(logger, names)
plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.grid(True)
if __name__ == '__main__':
# # Example
# logger = Logger('test.txt')
# logger.set_names(['Train loss', 'Valid loss','Test loss'])
# length = 100
# t = np.arange(length)
# train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
# valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
# test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
# for i in range(0, length):
# logger.append([train_loss[i], valid_loss[i], test_loss[i]])
# logger.plot()
# Example: logger monitor
paths = {
'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
}
field = ['Valid Acc.']
monitor = LoggerMonitor(paths)
monitor.plot(names=field)
savefig('test.eps')
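`Logger` persists training curves as a tab-separated text file: one header row of metric names, then one row of `"%.6f"`-formatted values per `append`, and on `resume=True` it re-parses that file back into `self.numbers`. A simplified in-memory sketch of that round trip (using `io.StringIO` in place of a real log file):

```python
import io

# Sketch of the tab-separated format Logger writes and re-reads on resume.
buf = io.StringIO()
names = ['Train loss', 'Valid loss']
buf.write('\t'.join(names) + '\n')
for row in [(0.5, 0.6), (0.4, 0.55)]:
    buf.write('\t'.join('{0:.6f}'.format(v) for v in row) + '\n')

# resume-style parse, mirroring Logger.__init__(resume=True)
lines = buf.getvalue().rstrip().split('\n')
parsed_names = lines[0].split('\t')
numbers = {n: [] for n in parsed_names}
for line in lines[1:]:
    for n, v in zip(parsed_names, line.split('\t')):
        numbers[n].append(float(v))

assert parsed_names == names
assert numbers['Train loss'] == [0.5, 0.4]
```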
================================================
FILE: scripts/utils/losses.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from scripts.models.vgg import Vgg19
from torchvision import models
from scripts.utils.misc import resize_to_match
# from pytorch_msssim import SSIM, MS_SSIM
import pytorch_ssim
class WeightedBCE(nn.Module):
def __init__(self):
super(WeightedBCE, self).__init__()
def forward(self, pred, gt):
epsilon = 1e-10
sigmoid_pred = torch.sigmoid(pred)
count_pos = torch.sum(gt) * 1.0 + epsilon
count_neg = torch.sum(1. - gt) * 1.0
beta = count_neg/count_pos
beta_back = count_pos / (count_pos + count_neg)
bce1 = nn.BCEWithLogitsLoss(pos_weight=beta)
loss = beta_back*bce1(pred, gt)
return loss
def l1_relative(reconstructed, real, mask):
batch = real.size(0)
area = torch.sum(mask.view(batch,-1),dim=1)
reconstructed = reconstructed * mask
real = real * mask
loss_l1 = torch.abs(reconstructed - real).view(batch, -1)
loss_l1 = torch.sum(loss_l1, dim=1) / area
loss_l1 = torch.sum(loss_l1) / batch
return loss_l1
def is_dic(x):
return isinstance(x, list)
class Losses(nn.Module):
def __init__(self, argx, device):
super(Losses, self).__init__()
self.args = argx
if self.args.loss_type == 'l1bl2':
self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss()
elif self.args.loss_type == 'l1wbl2':
self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), WeightedBCE(), nn.MSELoss()
elif self.args.loss_type == 'l2wbl2':
self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), WeightedBCE(), nn.MSELoss()
elif self.args.loss_type == 'l2xbl2':
self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss()
else: # l2bl2
self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss()
if self.args.style_loss > 0:
self.vggloss = VGGLoss(self.args.sltype).to(device)
if self.args.ssim_loss > 0:
self.ssimloss = pytorch_ssim.SSIM().to(device)
self.outputLoss = self.outputLoss.to(device)
self.attLoss = self.attLoss.to(device)
self.wrloss = self.wrloss.to(device)
def forward(self,imgx,target,attx,mask,wmx,wm):
pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = 0,0,0,0,0
if is_dic(imgx):
if self.args.masked:
# calculate the overall loss and side output
pixel_loss = self.outputLoss(imgx[0],target) + sum([self.outputLoss(im,resize_to_match(mask,im)*resize_to_match(target,im)) for im in imgx[1:]])
else:
pixel_loss = sum([self.outputLoss(im,resize_to_match(target,im)) for im in imgx])
if self.args.style_loss > 0:
vgg_loss = sum([self.vggloss(im,resize_to_match(target,im),resize_to_match(mask,im)) for im in imgx])
if self.args.ssim_loss > 0:
ssim_loss = sum([ 1 - self.ssimloss(im,resize_to_match(target,im)) for im in imgx])
else:
if self.args.masked:
pixel_loss = self.outputLoss(imgx,mask*target)
else:
pixel_loss = self.outputLoss(imgx,target)
if self.args.style_loss > 0:
vgg_loss = self.vggloss(imgx,target,mask)
if self.args.ssim_loss > 0:
ssim_loss = 1 - self.ssimloss(imgx,target)
if is_dic(attx):
att_loss = sum([self.attLoss(at,resize_to_match(mask,at)) for at in attx])
else:
att_loss = self.attLoss(attx, mask)
if is_dic(wmx):
wm_loss = sum([self.wrloss(w,resize_to_match(wm,w)) for w in wmx])
else:
if self.args.masked:
wm_loss = self.wrloss(wmx,mask*wm)
else:
wm_loss = self.wrloss(wmx, wm)
return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss
def gram_matrix(feat):
# https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py
(b, ch, h, w) = feat.size()
feat = feat.view(b, ch, h * w)
feat_t = feat.transpose(1, 2)
gram = torch.bmm(feat, feat_t) / (ch * h * w)
return gram
class MeanShift(nn.Conv2d):
def __init__(self, data_mean, data_std, data_range=1, norm=True):
"""norm (bool): normalize/denormalize the stats"""
c = len(data_mean)
super(MeanShift, self).__init__(c, c, kernel_size=1)
std = torch.Tensor(data_std)
self.weight.data = torch.eye(c).view(c, c, 1, 1)
if norm:
self.weight.data.div_(std.view(c, 1, 1, 1))
self.bias.data = -1 * data_range * torch.Tensor(data_mean)
self.bias.data.div_(std)
else:
self.weight.data.mul_(std.view(c, 1, 1, 1))
self.bias.data = data_range * torch.Tensor(data_mean)
for p in self.parameters():
p.requires_grad = False
def VGGLoss(losstype):
if losstype == 'vgg':
return VGGLossA()
elif losstype == 'vggx':
return VGGLossX(mask=False)
elif losstype == 'mvggx':
return VGGLossX(mask=True)
elif losstype == 'rvggx':
return VGGLossX(mask=True,relative=True)
else:
raise Exception("error in %s"%losstype)
class VGGLossA(nn.Module):
def __init__(self, vgg=None, weights=None, indices=None, normalize=True):
super(VGGLossA, self).__init__()
if vgg is None:
self.vgg = Vgg19().cuda()
else:
self.vgg = vgg
self.criterion = nn.L1Loss()
self.weights = weights or [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
self.indices = indices or [2, 7, 12, 21, 30]
if normalize:
self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
else:
self.normalize = None
def forward(self, x, y):
if self.normalize is not None:
x = self.normalize(x)
y = self.normalize(y)
x_vgg, y_vgg = self.vgg(x, self.indices), self.vgg(y, self.indices)
loss = 0
for i in range(len(x_vgg)):
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss
class VGG16FeatureExtractor(nn.Module):
def __init__(self):
super().__init__()
vgg16 = models.vgg16(pretrained=True)
self.enc_1 = nn.Sequential(*vgg16.features[:5])
self.enc_2 = nn.Sequential(*vgg16.features[5:10])
self.enc_3 = nn.Sequential(*vgg16.features[10:17])
# fix the encoder
for i in range(3):
for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
param.requires_grad = False
def forward(self, image):
results = [image]
for i in range(3):
func = getattr(self, 'enc_{:d}'.format(i + 1))
results.append(func(results[-1]))
return results[1:]
class VGGLossX(nn.Module):
def __init__(self, normalize=True, mask=False, relative=False):
super(VGGLossX, self).__init__()
self.vgg = VGG16FeatureExtractor().cuda()
self.criterion = nn.L1Loss().cuda() if not relative else l1_relative
self.use_mask= mask
self.relative = relative
if normalize:
self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
else:
self.normalize = None
def forward(self, x, y, Xmask=None):
if not self.use_mask:
mask = torch.ones_like(x)[:,0:1,:,:]
else:
mask = Xmask
if self.normalize is not None:
x = self.normalize(x)
y = self.normalize(y)
x_vgg = self.vgg(x)
y_vgg = self.vgg(y)
loss = 0
for i in range(3):
if self.relative:
loss += self.criterion(x_vgg[i],y_vgg[i].detach(),resize_to_match(mask,x_vgg[i]))
else:
loss += self.criterion(resize_to_match(mask,x_vgg[i])*x_vgg[i],resize_to_match(mask,y_vgg[i])*y_vgg[i].detach())
return loss
class GANLosses(object):
"""docstring for Loss"""
def __init__(self, gantype):
super(GANLosses, self).__init__()
self.generator_loss = gen_gan(gantype)
self.discriminator_loss = dis_gan(gantype)
self.gantype = gantype
def g_loss(self,dis_fake):
if 'hinge' in self.gantype:
return gen_hinge(dis_fake)
else:
return self.generator_loss(dis_fake)
def d_loss(self,dis_fake,dis_real):
if 'hinge' in self.gantype:
return dis_hinge(dis_fake,dis_real)
else:
return self.discriminator_loss(dis_fake,dis_real)
class gen_gan(nn.Module):
def __init__(self,gantype):
super(gen_gan,self).__init__()
if gantype == 'lsgan':
self.criterion = nn.MSELoss()
elif gantype == 'naive':
self.criterion = nn.BCEWithLogitsLoss()
else:
raise Exception("error gan type")
def forward(self,dis_fake):
return self.criterion(dis_fake, torch.ones_like(dis_fake))
class dis_gan(nn.Module):
def __init__(self,gantype):
super(dis_gan,self).__init__()
if gantype == 'lsgan':
self.criterion = nn.MSELoss()
elif gantype == 'naive':
self.criterion = nn.BCEWithLogitsLoss()
else:
raise Exception("error gan type")
def forward(self,dis_fake,dis_real):
loss_fake = self.criterion(dis_fake, torch.zeros_like(dis_fake))
loss_real = self.criterion(dis_real, torch.ones_like(dis_real))
return loss_fake, loss_real
# def gen_gan(dis_fake):
# # fake -> 1
# return F.binary_cross_entropy_with_logits(dis_fake,torch.ones_like(dis_fake))
# def dis_gan(dis_fake,dis_real):
# # fake -> 0 , real ->1
# loss_fake = F.binary_cross_entropy_with_logits(dis_fake, torch.zeros_like(dis_real))
# loss_real = F.binary_cross_entropy_with_logits(dis_real, torch.ones_like(dis_fake))
# return loss_fake,loss_real
# def gen_lsgan(dis_fake):
# loss = F.mse_loss(dis_fake,torch.ones_like(dis_fake)) #
# return loss
# def dis_lsgan(dis_fake, dis_real):
# loss_fake = F.mse_loss(dis_fake, torch.zeros_like(dis_real))
# loss_real = F.mse_loss(dis_real, torch.ones_like(dis_real))
# return loss_fake,loss_real
def gen_hinge(dis_fake, dis_real=None):
return -torch.mean(dis_fake)
def dis_hinge(dis_fake, dis_real):
loss_fake = torch.mean(torch.relu(1. + dis_fake))
loss_real = torch.mean(torch.relu(1. - dis_real))
return loss_fake,loss_real
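The `gram_matrix` helper above computes per-sample channel correlations of a feature map, normalized by `ch * h * w`, as used in style-loss formulations. A NumPy-only sketch that verifies the expected shape and symmetry:

```python
import numpy as np

# NumPy sketch of gram_matrix() above: batched ch x ch feature correlations.
def gram_matrix_np(feat):
    b, ch, h, w = feat.shape
    f = feat.reshape(b, ch, h * w)
    return f @ f.transpose(0, 2, 1) / (ch * h * w)

rng = np.random.default_rng(0)
feat = rng.standard_normal((2, 4, 8, 8))
g = gram_matrix_np(feat)
assert g.shape == (2, 4, 4)                  # one ch x ch matrix per sample
assert np.allclose(g, g.transpose(0, 2, 1))  # Gram matrices are symmetric
```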
================================================
FILE: scripts/utils/misc.py
================================================
from __future__ import absolute_import
import os
import shutil
import torch
import math
import numpy as np
import scipy.io
import matplotlib.pyplot as plt
import torch.nn.functional as F
def to_numpy(tensor):
if torch.is_tensor(tensor):
return tensor.cpu().numpy()
elif type(tensor).__module__ != 'numpy':
raise ValueError("Cannot convert {} to numpy array"
.format(type(tensor)))
return tensor
def resize_to_match(fm,to):
# just use interpolate
# [1,3] = (h,w)
return F.interpolate(fm,to.size()[-2:],mode='bilinear',align_corners=False)
def to_torch(ndarray):
if type(ndarray).__module__ == 'numpy':
return torch.from_numpy(ndarray)
elif not torch.is_tensor(ndarray):
raise ValueError("Cannot convert {} to torch tensor"
.format(type(ndarray)))
return ndarray
def save_checkpoint(machine,filename='checkpoint.pth.tar', snapshot=None):
is_best = True if machine.best_acc < machine.metric else False
if is_best:
machine.best_acc = machine.metric
state = {
'epoch': machine.current_epoch + 1,
'arch': machine.args.arch,
'state_dict': machine.model.state_dict(),
'best_acc': machine.best_acc,
'optimizer' : machine.optimizer.state_dict(),
}
filepath = os.path.join(machine.args.checkpoint, filename)
torch.save(state, filepath)
if snapshot and state['epoch'] % snapshot == 0:
shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
if is_best:
machine.best_acc = machine.metric
print('Saving Best Metric with PSNR:%s'%machine.best_acc)
shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'model_best.pth.tar'))
def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'):
preds = to_numpy(preds)
filepath = os.path.join(checkpoint, filename)
scipy.io.savemat(filepath, mdict={'preds' : preds})
def adjust_learning_rate(datasets,optimizer, epoch, lr,args):
"""Sets the learning rate to the initial LR decayed by schedule"""
if epoch in args.schedule:
lr *= args.gamma
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# decay sigma
for dset in datasets:
if args.sigma_decay > 0:
dset.dataset.sigma *= args.sigma_decay
return lr
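`adjust_learning_rate` implements a step schedule: the learning rate is multiplied by `args.gamma` whenever the epoch hits a milestone in `args.schedule`. A standalone sketch of the resulting curve (function name is illustrative):

```python
# Sketch of the step schedule implemented by adjust_learning_rate() above.
def stepped_lr(initial_lr, gamma, schedule, epochs):
    lr, out = initial_lr, []
    for epoch in range(epochs):
        if epoch in schedule:
            lr *= gamma          # decay at each milestone epoch
        out.append(lr)
    return out

lrs = stepped_lr(initial_lr=1e-3, gamma=0.1, schedule=[30, 60], epochs=90)
assert abs(lrs[29] - 1e-3) < 1e-12   # unchanged before the first milestone
assert abs(lrs[30] - 1e-4) < 1e-12   # decayed at epoch 30
assert abs(lrs[60] - 1e-5) < 1e-12   # decayed again at epoch 60
```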
================================================
FILE: scripts/utils/model_init.py
================================================
from torch.nn import init
def weights_init_normal(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('Linear') != -1:
init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)
def weights_init_xavier(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
init.xavier_normal_(m.weight.data, gain=0.02)
elif classname.find('Linear') != -1:
init.xavier_normal_(m.weight.data, gain=0.02)
# elif classname.find('BatchNorm2d') != -1:
# init.normal(m.weight.data, 1.0, 0.02)
# init.constant(m.bias.data, 0.0)
def weights_init_kaiming(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1 and m.weight.requires_grad == True:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif classname.find('Linear') != -1 and m.weight.requires_grad == True:
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif classname.find('BatchNorm2d') != -1 and m.weight.requires_grad == True:
init.normal_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)
def weights_init_orthogonal(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.orthogonal_(m.weight.data, gain=1)
elif classname.find('Linear') != -1:
init.orthogonal_(m.weight.data, gain=1)
# elif classname.find('BatchNorm2d') != -1:
# init.normal(m.weight.data, 1.0, 0.02)
# init.constant(m.bias.data, 0.0)
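These initializers are written for the common PyTorch idiom `net.apply(weights_init_kaiming)`, which visits every submodule and dispatches on the class name. A torch-free sketch of just that dispatch logic (`Conv2d`/`Linear`/`ReLU` here are stand-in dummy classes, not torch modules):

```python
class Conv2d: pass
class Linear: pass
class ReLU: pass

def init_kind(m):
    """Return which init rule the classname-based dispatch would pick."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        return 'kaiming'
    elif classname.find('Linear') != -1:
        return 'kaiming'
    elif classname.find('BatchNorm2d') != -1:
        return 'normal'
    return 'skip'  # e.g. activations carry no weights to initialize

kinds = [init_kind(m) for m in (Conv2d(), Linear(), ReLU())]
```

The substring match (`find('Conv')`) is deliberately loose so that it also catches `ConvTranspose2d` and similar subclasses.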
================================================
FILE: scripts/utils/osutils.py
================================================
from __future__ import absolute_import
import os
import errno
def mkdir_p(dir_path):
try:
os.makedirs(dir_path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def isfile(fname):
return os.path.isfile(fname)
def isdir(dirname):
return os.path.isdir(dirname)
def join(path, *paths):
return os.path.join(path, *paths)
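`mkdir_p` mirrors the shell's `mkdir -p`: create the directory tree and swallow only the "already exists" error. On Python 3 the one-liner `os.makedirs(dir_path, exist_ok=True)` is equivalent; a quick sketch showing the idempotent behaviour:

```python
import errno
import os
import tempfile

def mkdir_p(dir_path):
    # same behaviour as the helper above: ignore "already exists",
    # re-raise anything else (permissions, bad path, ...)
    try:
        os.makedirs(dir_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

base = tempfile.mkdtemp()
target = os.path.join(base, 'a', 'b')
mkdir_p(target)
mkdir_p(target)  # second call is a no-op instead of raising
```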
================================================
FILE: scripts/utils/parallel.py
================================================
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang, Rutgers University, Email: zhang.hang@rutgers.edu
## Modified by Thomas Wolf, HuggingFace Inc., Email: thomas@huggingface.co
## Copyright (c) 2017-2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Encoding Data Parallel"""
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel.scatter_gather import gather
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
torch_ver = torch.__version__[:3]
__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
'patch_replication_callback']
def allreduce(*inputs):
"""Cross GPU all reduce autograd operation for calculate mean and
variance in SyncBN.
"""
return AllReduce.apply(*inputs)
class AllReduce(Function):
@staticmethod
def forward(ctx, num_inputs, *inputs):
ctx.num_inputs = num_inputs
ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
inputs = [inputs[i:i + num_inputs]
for i in range(0, len(inputs), num_inputs)]
# sort before reduce sum
inputs = sorted(inputs, key=lambda i: i[0].get_device())
results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
return tuple([t for tensors in outputs for t in tensors])
@staticmethod
def backward(ctx, *inputs):
inputs = [i.data for i in inputs]
inputs = [inputs[i:i + ctx.num_inputs]
for i in range(0, len(inputs), ctx.num_inputs)]
results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
class Reduce(Function):
@staticmethod
def forward(ctx, *inputs):
ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
inputs = sorted(inputs, key=lambda i: i.get_device())
return comm.reduce_add(inputs)
@staticmethod
def backward(ctx, gradOutput):
return Broadcast.apply(ctx.target_gpus, gradOutput)
class DistributedDataParallelModel(DistributedDataParallel):
"""Implements data parallelism at the module level for the DistributedDataParallel module.
This container parallelizes the application of the given module by
splitting the input across the specified devices by chunking in the
batch dimension.
In the forward pass, the module is replicated on each device,
and each replica handles a portion of the input. During the backwards pass,
gradients from each replica are summed into the original module.
Note that the outputs are not gathered, please use compatible
:class:`encoding.parallel.DataParallelCriterion`.
The batch size should be larger than the number of GPUs used. It should
also be an integer multiple of the number of GPUs so that each chunk is
the same size (so that each GPU processes the same number of samples).
Args:
module: module to be parallelized
device_ids: CUDA devices (default: all devices)
Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
Amit Agrawal. “Context Encoding for Semantic Segmentation.
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
Example::
>>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2])
>>> y = net(x)
"""
def gather(self, outputs, output_device):
return outputs
class DataParallelModel(DataParallel):
"""Implements data parallelism at the module level.
This container parallelizes the application of the given module by
splitting the input across the specified devices by chunking in the
batch dimension.
In the forward pass, the module is replicated on each device,
and each replica handles a portion of the input. During the backwards pass,
gradients from each replica are summed into the original module.
Note that the outputs are not gathered, please use compatible
:class:`encoding.parallel.DataParallelCriterion`.
The batch size should be larger than the number of GPUs used. It should
also be an integer multiple of the number of GPUs so that each chunk is
the same size (so that each GPU processes the same number of samples).
Args:
module: module to be parallelized
device_ids: CUDA devices (default: all devices)
Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
Amit Agrawal. “Context Encoding for Semantic Segmentation.
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
Example::
>>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
>>> y = net(x)
"""
def gather(self, outputs, output_device):
return outputs
def replicate(self, module, device_ids):
modules = super(DataParallelModel, self).replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
class DataParallelCriterion(DataParallel):
"""
Calculate loss in multiple-GPUs, which balance the memory usage.
The targets are split across the specified devices by chunking in
the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
Reference:
Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
Amit Agrawal. “Context Encoding for Semantic Segmentation.
*The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
Example::
>>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
>>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
>>> y = net(x)
>>> loss = criterion(y, target)
"""
def forward(self, inputs, *targets, **kwargs):
# input should be already scattered
# scattering the targets instead
if not self.device_ids:
return self.module(inputs, *targets, **kwargs)
targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
if len(self.device_ids) == 1:
return self.module(inputs, *targets[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
#return Reduce.apply(*outputs) / len(outputs)
#return self.gather(outputs, self.output_device).mean()
return self.gather(outputs, self.output_device)
def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
assert len(modules) == len(inputs)
assert len(targets) == len(inputs)
if kwargs_tup:
assert len(modules) == len(kwargs_tup)
else:
kwargs_tup = ({},) * len(modules)
if devices is not None:
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)
lock = threading.Lock()
results = {}
if torch_ver != "0.3":
grad_enabled = torch.is_grad_enabled()
def _worker(i, module, input, target, kwargs, device=None):
if torch_ver != "0.3":
torch.set_grad_enabled(grad_enabled)
if device is None:
device = get_a_var(input).get_device()
try:
with torch.cuda.device(device):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)
if not isinstance(target, (list, tuple)):
target = (target,)
output = module(*(input + target), **kwargs)
with lock:
results[i] = output
except Exception as e:
with lock:
results[i] = e
if len(modules) > 1:
threads = [threading.Thread(target=_worker,
args=(i, module, input, target,
kwargs, device),)
for i, (module, input, target, kwargs, device) in
enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, Exception):
raise output
outputs.append(output)
return outputs
###########################################################################
# Adapted from Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
#
class CallbackContext(object):
pass
def execute_replication_callbacks(modules):
"""
Execute a replication callback `__data_parallel_replicate__` on each module created
by original replication.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Note that, as all modules are isomorphic, we assign each sub-module with a context
(shared among multiple copies of this module on different devices).
Through this context, different copies can share some information.
We guarantee that the callback on the master copy (the first copy) will be called ahead
of calling the callback of any slave copies.
"""
master_copy = modules[0]
nr_modules = len(list(master_copy.modules()))
ctxs = [CallbackContext() for _ in range(nr_modules)]
for i, module in enumerate(modules):
for j, m in enumerate(module.modules()):
if hasattr(m, '__data_parallel_replicate__'):
m.__data_parallel_replicate__(ctxs[j], i)
def patch_replication_callback(data_parallel):
"""
Monkey-patch an existing `DataParallel` object. Add the replication callback.
Useful when you have customized `DataParallel` implementation.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
> patch_replication_callback(sync_bn)
# this is equivalent to
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
"""
assert isinstance(data_parallel, DataParallel)
old_replicate = data_parallel.replicate
@functools.wraps(old_replicate)
def new_replicate(module, device_ids):
modules = old_replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
data_parallel.replicate = new_replicate
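`patch_replication_callback` works by monkey-patching `replicate` on an existing object: it keeps a reference to the old bound method, wraps it with `functools.wraps` so introspection still sees the original, and installs the wrapper as an instance attribute. A torch-free sketch of that exact pattern (`FakeParallel` is a dummy stand-in for `DataParallel`, and the hook just records how many replicas were made):

```python
import functools

class FakeParallel:
    """Stand-in for DataParallel; only the replicate() shape matters here."""
    def replicate(self, module, device_ids):
        return [module for _ in device_ids]

calls = []

def patch(dp):
    old_replicate = dp.replicate          # capture the bound method
    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        calls.append(len(modules))        # post-replication hook fires here
        return modules
    dp.replicate = new_replicate          # instance attr shadows the method

dp = FakeParallel()
patch(dp)
replicas = dp.replicate('net', [0, 1])
```

Patching the instance rather than the class means only this one `DataParallel` object gains the callback, which is why the helper is safe to apply to third-party wrappers.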
================================================
FILE: scripts/utils/transforms.py
================================================
from __future__ import absolute_import
import os
import numpy as np
import scipy.misc
import matplotlib.pyplot as plt
import torch
import torchvision
from .misc import *
from .imutils import *
def color_normalize(x, mean, std):
if x.size(0) == 1:
x = x.repeat(3, 1, 1)
for t, m, s in zip(x, mean, std):
t.sub_(m)
return x
def flip_back(flip_output, dataset='mpii'):
"""
flip output map
"""
if dataset == 'mpii':
matchedParts = (
[0,5], [1,4], [2,3],
[10,15], [11,14], [12,13]
)
else:
print('Not supported dataset: ' + dataset)
# flip output horizontally
flip_output = fliplr(flip_output.numpy())
# Change left-right parts
for pair in matchedParts:
tmp = np.copy(flip_output[:, pair[0], :, :])
flip_output[:, pair[0], :, :] = flip_output[:, pair[1], :, :]
flip_output[:, pair[1], :, :] = tmp
return torch.from_numpy(flip_output).float()
def shufflelr(x, width, dataset='mpii'):
"""
flip coords
"""
if dataset == 'mpii':
matchedParts = (
[0,5], [1,4], [2,3],
[10,15], [11,14], [12,13]
)
else:
print('Not supported dataset: ' + dataset)
# Flip horizontal
x[:, 0] = width - x[:, 0]
# Change left-right parts
for pair in matchedParts:
tmp = x[pair[0], :].clone()
x[pair[0], :] = x[pair[1], :]
x[pair[1], :] = tmp
return x
def fliplr(x):
if x.ndim == 3:
x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1))
elif x.ndim == 4:
for i in range(x.shape[0]):
x[i] = np.transpose(np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1))
return x.astype(float)
def get_transform(center, scale, res, rot=0):
"""
General image processing functions
"""
# Generate transformation matrix
h = 200 * scale
t = np.zeros((3, 3))
t[0, 0] = float(res[1]) / h
t[1, 1] = float(res[0]) / h
t[0, 2] = res[1] * (-float(center[0]) / h + .5)
t[1, 2] = res[0] * (-float(center[1]) / h + .5)
t[2, 2] = 1
if not rot == 0:
rot = -rot # To match direction of rotation from cropping
rot_mat = np.zeros((3,3))
rot_rad = rot * np.pi / 180
sn,cs = np.sin(rot_rad), np.cos(rot_rad)
rot_mat[0,:2] = [cs, -sn]
rot_mat[1,:2] = [sn, cs]
rot_mat[2,2] = 1
# Need to rotate around center
t_mat = np.eye(3)
t_mat[0,2] = -res[1]/2
t_mat[1,2] = -res[0]/2
t_inv = t_mat.copy()
t_inv[:2,2] *= -1
t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
return t
def transform(pt, center, scale, res, invert=0, rot=0):
# Transform pixel location to different reference
t = get_transform(center, scale, res, rot=rot)
if invert:
t = np.linalg.inv(t)
new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2].astype(int) + 1
def transform_preds(coords, center, scale, res):
# size = coords.size()
# coords = coords.view(-1, coords.size(-1))
# print(coords.size())
for p in range(coords.size(0)):
coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0))
return coords
def crop(img, center, scale, res, rot=0):
img = im_to_numpy(img)
# Upper left point
ul = np.array(transform([0, 0], center, scale, res, invert=1))
# Bottom right point
br = np.array(transform(res, center, scale, res, invert=1))
# Padding so that when rotated proper amount of context is included
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
if not rot == 0:
ul -= pad
br += pad
new_shape = [br[1] - ul[1], br[0] - ul[0]]
if len(img.shape) > 2:
new_shape += [img.shape[2]]
new_img = np.zeros(new_shape)
# Range to fill new array
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
# Range to sample from original image
old_x = max(0, ul[0]), min(len(img[0]), br[0])
old_y = max(0, ul[1]), min(len(img), br[1])
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
if not rot == 0:
# Remove padding
# NOTE: scipy.misc.imrotate/imresize were removed in SciPy >= 1.3;
# this path requires an old SciPy (or porting to PIL/skimage).
new_img = scipy.misc.imrotate(new_img, rot)
new_img = new_img[pad:-pad, pad:-pad]
new_img = im_to_torch(scipy.misc.imresize(new_img, res))
return new_img
def get_right(img,gray=False):
img = im_to_numpy(img) #H*W*C
new_img = img[:,0:256,:]
new_img = im_to_torch(new_img)
if gray:
new_img = new_img[1,:,:]
return new_img
class NormalizeInverse(torchvision.transforms.Normalize):
"""
Undoes the normalization and returns the reconstructed images in the input domain.
"""
def __init__(self, mean, std):
mean = torch.as_tensor(mean)
std = torch.as_tensor(std)
std_inv = 1 / (std + 1e-7)
mean_inv = -mean * std_inv
super().__init__(mean=mean_inv, std=std_inv)
def __call__(self, tensor):
return super().__call__(tensor.clone())
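`NormalizeInverse` undoes normalization by reusing the forward transform with substituted parameters: since `Normalize(mean, std)` computes `(x - mean) / std`, applying `Normalize(-mean/std, 1/std)` to the normalized tensor recovers `x`. A NumPy re-derivation of that identity (assumes per-channel affine normalization; the ImageNet-style mean/std values are just example inputs):

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

x = np.random.rand(3, 4, 4)  # C,H,W image in [0,1]
normed = (x - mean[:, None, None]) / std[:, None, None]

# NormalizeInverse's substituted parameters, including its 1e-7 guard
std_inv = 1.0 / (std + 1e-7)
mean_inv = -mean * std_inv

# applying the *forward* Normalize with the inverse parameters
restored = (normed - mean_inv[:, None, None]) / std_inv[:, None, None]
```

Algebraically: `(y + mean/std) * std = y*std + mean = x`, so the round-trip is exact up to the `1e-7` epsilon. The `tensor.clone()` in `__call__` keeps the in-place `Normalize` ops from mutating the caller's tensor.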
================================================
FILE: test.py
================================================
from __future__ import print_function, absolute_import
import argparse
import torch
torch.backends.cudnn.benchmark = True
from scripts.utils.misc import save_checkpoint, adjust_learning_rate
import scripts.datasets as datasets
import scripts.machines as machines
from options import Options
def main(args):
val_loader = torch.utils.data.DataLoader(datasets.COCO('val',args),batch_size=args.test_batch, shuffle=False,
num_workers=args.workers, pin_memory=True)
data_loaders = (None,val_loader)
Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)
Machine.test()
if __name__ == '__main__':
parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal'))
main(parser.parse_args())
================================================
FILE: watermark_synthesis.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SAVE ALL THE SETTING\n"
]
}
],
"source": [
"# watermark synthesis\n",
"import os \n",
"import random\n",
"import shutil\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"def trans_paste(bg_img,fg_img,mask,box=(0,0)):\n",
" fg_img_trans = Image.new(\"RGBA\",bg_img.size)\n",
" fg_img_trans.paste(fg_img,box,mask=mask)\n",
" new_img = Image.alpha_composite(bg_img,fg_img_trans)\n",
" return new_img,fg_img_trans\n",
"\n",
"if os.path.isdir('dataset'):\n",
" shutil.rmtree('dataset')\n",
"\n",
"os.mkdir('dataset')\n",
"BASE_IMG_DIR = '/Users/oishii/Downloads/val2014/'\n",
"WATERMARK_DIR = 'logos' #1080 \n",
"images = sorted([os.path.join(BASE_IMG_DIR,x) for x in os.listdir(BASE_IMG_DIR) if '.jpg' in x])\n",
"watermarks = sorted([os.path.join(WATERMARK_DIR,x).replace(' ','_') for x in os.listdir(WATERMARK_DIR) if '.png' in x])\n",
"# rename all the watermark from replace ' ' to '_'\n",
"\n",
"random.shuffle(images)\n",
"random.shuffle(watermarks)\n",
"\n",
"train_images = images[:int(len(images)*0.7)]\n",
"val_images = images[int(len(images)*0.7):int(len(images)*0.8)]\n",
"test_images = images[int(len(images)*0.8):]\n",
"\n",
"train_wms = watermarks[:int(len(watermarks)*0.7)]\n",
"val_wms = watermarks[int(len(watermarks)*0.7):int(len(watermarks)*0.8)]\n",
"test_wms = watermarks[int(len(watermarks)*0.8):]\n",
"\n",
"# save all the settings to file\n",
"names = ['train_images','val_images','test_images','train_wms','val_wms','test_wms']\n",
"lists = [train_images,val_images,test_images,train_wms,val_wms,test_wms]\n",
"dataset = dict(zip(names, lists))\n",
"\n",
"for name,content in dataset.items():\n",
" with open('dataset/%s.txt'%name,'w') as f:\n",
" f.write(\"\\n\".join(content))\n",
"\n",
"print('SAVE ALL THE SETTING')\n",
"\n",
"for name, images in dataset.items():\n",
" if 'images' not in name:\n",
" continue\n",
"    # for each setting, synthesize the watermarks\n",
"    # for each image, add X(X=6) watermarks at different positions and alphas,\n",
"    # save the synthesized image, watermark mask, reshaped mask,\n",
" save_path = 'dataset/%s/'%name\n",
" os.makedirs('%s/image'%(save_path))\n",
" os.makedirs('%s/mask'%(save_path))\n",
" os.makedirs('%s/wm'%(save_path))\n",
" \n",
" for img in images:\n",
" im = Image.open(img).convert('RGBA')\n",
" imw,imh = im.size\n",
" \n",
" for wmg in random.choices(dataset[name.replace('images','wms')],k=6):\n",
" wm = Image.open(wmg.replace('_',' ')).convert(\"RGBA\") # RGBA\n",
" # get the mask of wm\n",
"        # data augmentation of wm\n",
" wm = wm.rotate(angle=random.randint(0,360),expand=True) # rotate\n",
" \n",
"        # make sure the resized watermark fits inside the image\n",
" imrw = random.randrange(int(0.4*imw),int(0.8*imw))\n",
" imrh = random.randrange(int(0.4*imh),int(0.8*imh))\n",
" wmsize = imrh if imrw > imrh else imrw\n",
" wm = wm.resize((wmsize,wmsize),Image.BILINEAR)\n",
" w,h = wm.size # new size \n",
" \n",
" box_left = random.randint(0,imw-w)\n",
" box_upper = random.randint(0,imh-h)\n",
" wmm = wm.copy()\n",
" wm.putalpha(random.randint(int(255*0.4),int(255*0.8))) # alpha\n",
" \n",
" ims,wmc = trans_paste(im,wm,wmm,(box_left,box_upper))\n",
" \n",
"            wmnp = np.array(wmc) # h,w,4 (RGBA)\n",
" mask = np.sum(wmnp,axis=2)>0\n",
" mm = Image.fromarray(np.uint8(mask*255),mode='L')\n",
" \n",
" identifier = os.path.basename(img).split('.')[0] +'-'+os.path.basename(wmg).split('.')[0] + '.png'\n",
" # save \n",
" wmc.save('%s/wm/%s'%(save_path,identifier))\n",
" ims.save('%s/image/%s'%(save_path,identifier))\n",
" mm.save('%s/mask/%s'%(save_path,identifier))\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
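The notebook's `trans_paste` delegates blending to `Image.alpha_composite`; per pixel this is the standard "over" operator, `out = fg*alpha + bg*(1-alpha)`, with a single uniform alpha because the notebook calls `putalpha` with one random value. A NumPy sketch of that blend (uniform alpha, float RGB in [0,1] — a simplification of PIL's premultiplied integer math):

```python
import numpy as np

def composite_over(bg, fg, alpha):
    """Alpha-blend fg over bg with a uniform opacity, as in putalpha()."""
    return fg * alpha + bg * (1.0 - alpha)

bg = np.zeros((2, 2, 3))   # black background image
fg = np.ones((2, 2, 3))    # white watermark
out = composite_over(bg, fg, alpha=0.6)
```

The notebook then recovers the ground-truth mask by thresholding the pasted watermark layer (`np.sum(wmnp, axis=2) > 0`), i.e. any pixel the watermark touched at all, independent of its opacity.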
SYMBOL INDEX (353 symbols across 26 files)
FILE: main.py
function main (line 14) | def main(args):
FILE: options.py
class Options (line 8) | class Options():
method __init__ (line 10) | def __init__(self):
method init (line 13) | def init(self, parser):
FILE: scripts/datasets/BIH.py
class BIH (line 27) | class BIH(data.Dataset):
method __init__ (line 28) | def __init__(self,train,config=None, sample=[],gan_norm=False):
method __getitem__ (line 81) | def __getitem__(self, index):
method __len__ (line 98) | def __len__(self):
FILE: scripts/datasets/COCO.py
class COCO (line 27) | class COCO(data.Dataset):
method __init__ (line 28) | def __init__(self,train,config=None, sample=[],gan_norm=False):
method __getitem__ (line 80) | def __getitem__(self, index):
method __len__ (line 97) | def __len__(self):
FILE: scripts/machines/BasicMachine.py
class BasicMachine (line 25) | class BasicMachine(object):
method __init__ (line 26) | def __init__(self, datasets =(None,None), models = None, args = None, ...
method train (line 82) | def train(self,epoch):
method test (line 155) | def test(self, ):
method validate (line 196) | def validate(self, epoch):
method resume (line 240) | def resume(self,resume_path):
method save_checkpoint (line 259) | def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):
method clean (line 284) | def clean(self):
method record (line 287) | def record(self,k,v,epoch):
method flush (line 290) | def flush(self):
method norm (line 294) | def norm(self,x):
method denorm (line 300) | def denorm(self,x):
FILE: scripts/machines/S2AM.py
class S2AM (line 25) | class S2AM(object):
method __init__ (line 26) | def __init__(self, datasets =(None,None), models = None, args = None, ...
method train (line 82) | def train(self,epoch):
method test (line 156) | def test(self, ):
method validate (line 195) | def validate(self, epoch):
method resume (line 240) | def resume(self,resume_path):
method save_checkpoint (line 259) | def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):
method clean (line 284) | def clean(self):
method record (line 287) | def record(self,k,v,epoch):
method flush (line 290) | def flush(self):
method norm (line 294) | def norm(self,x):
method denorm (line 300) | def denorm(self,x):
FILE: scripts/machines/VX.py
class Losses (line 23) | class Losses(nn.Module):
method __init__ (line 24) | def __init__(self, argx, device, norm_func=None, denorm_func=None):
method forward (line 49) | def forward(self,pred_ims,target,pred_ms,mask,pred_wms,wm):
class VX (line 87) | class VX(BasicMachine):
method __init__ (line 88) | def __init__(self,**kwargs):
method train (line 94) | def train(self,epoch):
method validate (line 182) | def validate(self, epoch):
method test (line 254) | def test(self, ):
FILE: scripts/machines/__init__.py
function basic (line 6) | def basic(**kwargs):
function s2am (line 9) | def s2am(**kwargs):
function vx (line 12) | def vx(**kwargs):
FILE: scripts/models/backbone_unet.py
function vvv4n (line 19) | def vvv4n(**kwargs):
function vm3 (line 24) | def vm3(**kwargs):
function urasc (line 29) | def urasc(**kwargs):
function rascv2 (line 39) | def rascv2(**kwargs):
function unet (line 45) | def unet(**kwargs):
FILE: scripts/models/blocks.py
class BasicLearningBlock (line 15) | class BasicLearningBlock(nn.Module):
method __init__ (line 17) | def __init__(self,channel):
method forward (line 24) | def forward(self,feature):
class GaussianSmoothing (line 30) | class GaussianSmoothing(nn.Module):
method __init__ (line 43) | def __init__(self, channels, kernel_size, sigma, dim=2):
method forward (line 85) | def forward(self, input):
class ChannelPool (line 95) | class ChannelPool(nn.Module):
method __init__ (line 96) | def __init__(self,types):
method forward (line 105) | def forward(self, input):
class SEBlock (line 114) | class SEBlock(nn.Module):
method __init__ (line 116) | def __init__(self, channel,reducation=16):
method forward (line 125) | def forward(self,x):
class GlobalAttentionModule (line 133) | class GlobalAttentionModule(nn.Module):
method __init__ (line 135) | def __init__(self, channel,reducation=16):
method forward (line 145) | def forward(self,x):
class SpatialAttentionModule (line 152) | class SpatialAttentionModule(nn.Module):
method __init__ (line 154) | def __init__(self, channel,reducation=16):
method forward (line 164) | def forward(self,x):
class GlobalAttentionModuleJustSigmoid (line 174) | class GlobalAttentionModuleJustSigmoid(nn.Module):
method __init__ (line 176) | def __init__(self, channel,reducation=16):
method forward (line 186) | def forward(self,x):
class BasicBlock (line 195) | class BasicBlock(nn.Module):
method __init__ (line 196) | def __init__(self, in_planes, out_planes, kernel_size, stride=1, paddi...
method forward (line 203) | def forward(self, x):
class Flatten (line 211) | class Flatten(nn.Module):
method forward (line 212) | def forward(self, x):
class ChannelGate (line 215) | class ChannelGate(nn.Module):
method __init__ (line 216) | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg...
method forward (line 226) | def forward(self, x):
function logsumexp_2d (line 251) | def logsumexp_2d(tensor):
class ChannelPoolX (line 257) | class ChannelPoolX(nn.Module):
method forward (line 258) | def forward(self, x):
class SpatialGate (line 261) | class SpatialGate(nn.Module):
method __init__ (line 262) | def __init__(self):
method forward (line 267) | def forward(self, x):
class CBAM (line 273) | class CBAM(nn.Module):
method __init__ (line 274) | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg...
method forward (line 280) | def forward(self, x):
FILE: scripts/models/discriminator.py
class SNCoXvWithActivation (line 17) | class SNCoXvWithActivation(torch.nn.Module):
method __init__ (line 21) | def __init__(self, in_channels, out_channels, kernel_size, stride=1, p...
method forward (line 29) | def forward(self, input):
function l2normalize (line 36) | def l2normalize(v, eps=1e-12):
class SpectralNorm (line 40) | class SpectralNorm(nn.Module):
method __init__ (line 41) | def __init__(self, module, name='weight', power_iterations=1):
method _update_u_v (line 49) | def _update_u_v(self):
method _made_params (line 63) | def _made_params(self):
method _make_params (line 73) | def _make_params(self):
method forward (line 92) | def forward(self, *args):
function get_pad (line 97) | def get_pad(in_, ksize, stride, atrous=1):
class SNDiscriminator (line 101) | class SNDiscriminator(nn.Module):
method __init__ (line 102) | def __init__(self,channel=6):
method forward (line 116) | def forward(self, img_A, img_B):
class Discriminator (line 124) | class Discriminator(nn.Module):
method __init__ (line 125) | def __init__(self, in_channels=3):
method forward (line 145) | def forward(self, img_A, img_B):
function patchgan (line 151) | def patchgan():
function sngan (line 156) | def sngan():
function maskedsngan (line 161) | def maskedsngan():
FILE: scripts/models/rasc.py
class CAWapper (line 15) | class CAWapper(nn.Module):
method __init__ (line 18) | def __init__(self, channel, type_of_connection=BasicLearningBlock):
method forward (line 22) | def forward(self, feature, mask):
class NLWapper (line 34) | class NLWapper(nn.Module):
method __init__ (line 37) | def __init__(self, channel, type_of_connection=BasicLearningBlock):
method forward (line 41) | def forward(self, feature, mask):
class SENet (line 52) | class SENet(nn.Module):
method __init__ (line 54) | def __init__(self,channel,type_of_connection=BasicLearningBlock):
method forward (line 58) | def forward(self,feature,mask):
class CBAMConnect (line 69) | class CBAMConnect(nn.Module):
method __init__ (line 70) | def __init__(self,channel):
method forward (line 74) | def forward(self,feature,mask):
class RASC (line 80) | class RASC(nn.Module):
method __init__ (line 81) | def __init__(self,channel,type_of_connection=BasicLearningBlock):
method forward (line 89) | def forward(self,feature,mask):
class UNO (line 110) | class UNO(nn.Module):
method __init__ (line 111) | def __init__(self,channel):
method forward (line 114) | def forward(self,feature,_m):
class URASC (line 118) | class URASC(nn.Module):
method __init__ (line 119) | def __init__(self,channel,type_of_connection=BasicLearningBlock):
method forward (line 127) | def forward(self,feature, m=None):
class MaskedURASC (line 139) | class MaskedURASC(nn.Module):
method __init__ (line 140) | def __init__(self,channel,type_of_connection=BasicLearningBlock):
method forward (line 148) | def forward(self,feature):
FILE: scripts/models/sa_resunet.py
function weight_init (line 9) | def weight_init(m):
function reset_params (line 14) | def reset_params(model):
function conv3x3 (line 19) | def conv3x3(in_channels, out_channels, stride=1,
function up_conv2x2 (line 31) | def up_conv2x2(in_channels, out_channels, transpose=True):
function conv1x1 (line 44) | def conv1x1(in_channels, out_channels, groups=1):
class UpCoXvD (line 53) | class UpCoXvD(nn.Module):
method __init__ (line 55) | def __init__(self, in_channels, out_channels, blocks, residual=True,no...
method forward (line 88) | def forward(self, from_up, from_down, mask=None,se=None):
class DownCoXvD (line 118) | class DownCoXvD(nn.Module):
method __init__ (line 120) | def __init__(self, in_channels, out_channels, blocks, pooling=True, no...
method __call__ (line 143) | def __call__(self, x):
method forward (line 146) | def forward(self, x):
class UnetDecoderD (line 162) | class UnetDecoderD(nn.Module):
method __init__ (line 163) | def __init__(self, in_channels=512, out_channels=3, norm=nn.BatchNorm2...
method __call__ (line 200) | def __call__(self, x, encoder_outs=None):
method forward (line 203) | def forward(self, x, encoder_outs=None):
class UnetDecoderDatt (line 217) | class UnetDecoderDatt(nn.Module):
method __init__ (line 218) | def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1,...
method forward (line 258) | def forward(self, input, encoder_outs=None):
class UnetEncoderD (line 286) | class UnetEncoderD(nn.Module):
method __init__ (line 288) | def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32,...
method __call__ (line 303) | def __call__(self, x):
method forward (line 306) | def forward(self, x):
class ResDown (line 313) | class ResDown(nn.Module):
method __init__ (line 314) | def __init__(self, in_size, out_size, pooling=True, use_att=False):
method forward (line 318) | def forward(self, x):
class ResUp (line 321) | class ResUp(nn.Module):
method __init__ (line 322) | def __init__(self, in_size, out_size, use_att=False):
method forward (line 326) | def forward(self, x, skip_input, mask=None):
class ResDownNew (line 329) | class ResDownNew(nn.Module):
method __init__ (line 330) | def __init__(self, in_size, out_size, pooling=True, use_att=False):
method forward (line 334) | def forward(self, x):
class ResUpNew (line 337) | class ResUpNew(nn.Module):
method __init__ (line 338) | def __init__(self, in_size, out_size, use_att=False):
method forward (line 342) | def forward(self, x, skip_input, mask=None):
class VMSingle (line 347) | class VMSingle(nn.Module):
method __init__ (line 348) | def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=Res...
method forward (line 366) | def forward(self, input):
class VMSingleS2AM (line 385) | class VMSingleS2AM(nn.Module):
method __init__ (line 386) | def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=Res...
method forward (line 408) | def forward(self, input):
class UnetVMS2AMv4 (line 430) | class UnetVMS2AMv4(nn.Module):
method __init__ (line 432) | def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_deco...
method set_optimizers (line 480) | def set_optimizers(self):
method zero_grad_all (line 491) | def zero_grad_all(self):
method step_all (line 501) | def step_all(self):
method step_optimizer_image (line 511) | def step_optimizer_image(self):
method __call__ (line 514) | def __call__(self, synthesized):
method forward (line 517) | def forward(self, synthesized):
method unshared_forward (line 520) | def unshared_forward(self, synthesized):
method shared_forward (line 531) | def shared_forward(self, synthesized):
FILE: scripts/models/unet.py
class MinimalUnetV2 (line 9) | class MinimalUnetV2(nn.Module):
method __init__ (line 11) | def __init__(self, down=None,up=None,submodule=None,attention=None,wit...
method forward (line 22) | def forward(self,x,mask=None):
class MinimalUnet (line 39) | class MinimalUnet(nn.Module):
method __init__ (line 41) | def __init__(self, down=None,up=None,submodule=None,attention=None,wit...
method forward (line 52) | def forward(self,x,mask=None):
class UnetSkipConnectionBlock (line 72) | class UnetSkipConnectionBlock(nn.Module):
method __init__ (line 73) | def __init__(self, outer_nc, inner_nc, input_nc=None,
method forward (line 129) | def forward(self, x,mask=None):
class UnetGenerator (line 133) | class UnetGenerator(nn.Module):
method __init__ (line 134) | def __init__(self, input_nc, output_nc, num_downs=8, ngf=64,norm_layer...
method forward (line 153) | def forward(self, input):
FILE: scripts/models/vgg.py
class Vgg16 (line 7) | class Vgg16(torch.nn.Module):
method __init__ (line 8) | def __init__(self, requires_grad=False):
method forward (line 31) | def forward(self, X):
class Vgg19 (line 47) | class Vgg19(torch.nn.Module):
method __init__ (line 48) | def __init__(self, requires_grad=False):
method forward (line 71) | def forward(self, X, indices=None):
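`Vgg16`/`Vgg19` follow the common perceptual-loss pattern: the pretrained `features` sequence is frozen (`requires_grad=False`) and cut into slices, and `forward` returns each slice's activation (e.g. relu1_2 … relu4_3) rather than a single output. A framework-free sketch of that staged forward, with plain callables standing in for the torchvision modules (the helper name `make_stages` and the cut points are illustrative, not from the repo):

```python
def make_stages(layers, cut_points):
    # Split an ordered list of layer callables into sequential stages,
    # mirroring how Vgg16 groups `features` into relu1_2 ... relu4_3 slices.
    stages, prev = [], 0
    for cut in cut_points:
        stages.append(layers[prev:cut])
        prev = cut

    def forward(x):
        outs = []
        for stage in stages:
            for layer in stage:
                x = layer(x)
            outs.append(x)  # keep one intermediate feature per stage
        return outs
    return forward
```

With four `lambda v: v + 1` layers and cuts `[2, 4]`, `forward(0)` yields `[2, 4]` — one running value per stage, which is exactly what a perceptual loss compares between two images.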
FILE: scripts/models/vmu.py
function weight_init (line 9) | def weight_init(m):
function reset_params (line 14) | def reset_params(model):
function conv3x3 (line 19) | def conv3x3(in_channels, out_channels, stride=1,
function up_conv2x2 (line 31) | def up_conv2x2(in_channels, out_channels, transpose=True):
function conv1x1 (line 44) | def conv1x1(in_channels, out_channels, groups=1):
class UpCoXvD (line 55) | class UpCoXvD(nn.Module):
method __init__ (line 57) | def __init__(self, in_channels, out_channels, blocks, residual=True, b...
method forward (line 85) | def forward(self, from_up, from_down, mask=None):
class DownCoXvD (line 111) | class DownCoXvD(nn.Module):
method __init__ (line 113) | def __init__(self, in_channels, out_channels, blocks, pooling=True, re...
method __call__ (line 133) | def __call__(self, x):
method forward (line 136) | def forward(self, x):
class UnetDecoderD (line 152) | class UnetDecoderD(nn.Module):
method __init__ (line 153) | def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1,...
method __call__ (line 178) | def __call__(self, x, encoder_outs=None):
method forward (line 181) | def forward(self, x, encoder_outs=None):
class UnetEncoderD (line 192) | class UnetEncoderD(nn.Module):
method __init__ (line 194) | def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32,...
method __call__ (line 209) | def __call__(self, x):
method forward (line 212) | def forward(self, x):
class UnetVM (line 221) | class UnetVM(nn.Module):
method __init__ (line 223) | def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_deco...
method set_optimizers (line 262) | def set_optimizers(self):
method zero_grad_all (line 271) | def zero_grad_all(self):
method step_all (line 280) | def step_all(self):
method step_optimizer_image (line 289) | def step_optimizer_image(self):
method __call__ (line 292) | def __call__(self, synthesized):
method forward (line 295) | def forward(self, synthesized):
method unshared_forward (line 298) | def unshared_forward(self, synthesized):
method shared_forward (line 309) | def shared_forward(self, synthesized):
FILE: scripts/utils/evaluation.py
function get_preds (line 13) | def get_preds(scores):
function calc_dists (line 32) | def calc_dists(preds, target, normalize):
function dist_acc (line 44) | def dist_acc(dists, thr=0.5):
function accuracy (line 53) | def accuracy(output, target, thr=0.5):
function final_preds (line 77) | def final_preds(output, center, scale, res):
class AverageMeter (line 102) | class AverageMeter(object):
method __init__ (line 104) | def __init__(self):
method reset (line 107) | def reset(self):
method update (line 113) | def update(self, val, n=1):
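`AverageMeter` above is the standard torch-example utility for tracking running statistics during training. A minimal sketch matching the listed signatures (`reset`, `update(val, n=1)`); the attribute names follow the usual convention and are assumed, not copied from the repo:

```python
class AverageMeter(object):
    """Tracks the current value, running sum, count, and average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear all running statistics (e.g. at the start of each epoch).
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # `n` is typically the batch size: `val` is an average over `n` samples.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
```

Weighting by `n` makes the running average exact even when the last batch of an epoch is smaller than the rest.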
FILE: scripts/utils/imutils.py
function im_to_numpy (line 10) | def im_to_numpy(img):
function im_to_torch (line 15) | def im_to_torch(img):
function load_image (line 22) | def load_image(img_path):
function imread_all (line 26) | def imread_all(img_path):
function load_image_gray (line 29) | def load_image_gray(img_path):
function resize (line 35) | def resize(img, owidth, oheight):
function gaussian (line 54) | def gaussian(shape=(7,7),sigma=1):
function draw_labelmap (line 65) | def draw_labelmap(img, pt, sigma, type='Gaussian'):
function gauss (line 104) | def gauss(x, a, b, c, d=0):
function color_heatmap (line 107) | def color_heatmap(x):
function imshow (line 117) | def imshow(img):
function show_joints (line 122) | def show_joints(img, pts):
function show_sample (line 130) | def show_sample(inputs, target):
function sample_with_heatmap (line 146) | def sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None):
function batch_with_heatmap (line 181) | def batch_with_heatmap(inputs, outputs, mean=torch.Tensor([0.5, 0.5, 0.5...
function normalize_batch (line 191) | def normalize_batch(batch):
function show_image_tensor (line 198) | def show_image_tensor(tensor):
function get_jet (line 211) | def get_jet():
function clamp (line 221) | def clamp(num, min_value, max_value):
function gray2color (line 224) | def gray2color(gray_array, color_map):
class objectview (line 236) | class objectview(object):
method __init__ (line 237) | def __init__(self, *args, **kwargs):
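`gaussian(shape=(7,7), sigma=1)` in `imutils.py` produces the 2-D Gaussian kernel consumed by `draw_labelmap`. A sketch of the standard MATLAB-`fspecial`-style construction; the repo's exact normalization may differ:

```python
import numpy as np

def gaussian(shape=(7, 7), sigma=1):
    # Build a 2-D Gaussian centred in a window of the given shape,
    # normalized so the kernel sums to 1.
    m, n = [(s - 1) / 2 for s in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x + y * y) / (2.0 * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0  # drop numerical dust
    s = h.sum()
    if s != 0:
        h /= s
    return h
```

The resulting kernel is symmetric with its peak at the centre, which is what `draw_labelmap` needs to stamp a heatmap blob at a keypoint location.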
FILE: scripts/utils/logger.py
function savefig (line 12) | def savefig(fname, dpi=None):
function plot_overlap (line 16) | def plot_overlap(logger, names=None):
class Logger (line 24) | class Logger(object):
method __init__ (line 26) | def __init__(self, fpath, title=None, resume=False):
method set_names (line 48) | def set_names(self, names):
method append (line 62) | def append(self, numbers):
method plot (line 71) | def plot(self, names=None):
method close (line 80) | def close(self):
class LoggerMonitor (line 84) | class LoggerMonitor(object):
method __init__ (line 86) | def __init__ (self, paths):
method plot (line 93) | def plot(self, names=None):
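`Logger` (a torch-style logger credited to Wei Yang in the file header) records one tab-separated row of metrics per call. A minimal sketch of the `set_names`/`append` protocol; the exact number formatting and the `resume` handling in the repo may differ:

```python
class Logger(object):
    """Append-only, tab-separated metric log: one header row, then one row per epoch."""

    def __init__(self, fpath, title=None, resume=False):
        self.title = title
        mode = 'a' if resume else 'w'
        self.file = open(fpath, mode)

    def set_names(self, names):
        # Write the header row once, e.g. ['Epoch', 'Train Loss', 'Valid Loss'].
        self.names = names
        self.file.write('\t'.join(names) + '\n')
        self.file.flush()

    def append(self, numbers):
        assert len(numbers) == len(self.names), 'mismatched number of columns'
        self.file.write('\t'.join('{:.6f}'.format(n) for n in numbers) + '\n')
        self.file.flush()

    def close(self):
        self.file.close()
```

Flushing after every row keeps the log readable (and plottable via `plot_overlap`) while a long training run is still in progress.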
FILE: scripts/utils/losses.py
class WeightedBCE (line 10) | class WeightedBCE(nn.Module):
method __init__ (line 11) | def __init__(self):
method forward (line 14) | def forward(self, pred, gt):
function l1_relative (line 28) | def l1_relative(reconstructed, real, mask):
function is_dic (line 40) | def is_dic(x):
class Losses (line 43) | class Losses(nn.Module):
method __init__ (line 44) | def __init__(self, argx, device):
method forward (line 70) | def forward(self,imgx,target,attx,mask,wmx,wm):
function gram_matrix (line 116) | def gram_matrix(feat):
class MeanShift (line 124) | class MeanShift(nn.Conv2d):
method __init__ (line 125) | def __init__(self, data_mean, data_std, data_range=1, norm=True):
function VGGLoss (line 142) | def VGGLoss(losstype):
class VGGLossA (line 156) | class VGGLossA(nn.Module):
method __init__ (line 157) | def __init__(self, vgg=None, weights=None, indices=None, normalize=True):
method forward (line 171) | def forward(self, x, y):
class VGG16FeatureExtractor (line 182) | class VGG16FeatureExtractor(nn.Module):
method __init__ (line 183) | def __init__(self):
method forward (line 195) | def forward(self, image):
class VGGLossX (line 202) | class VGGLossX(nn.Module):
method __init__ (line 203) | def __init__(self, normalize=True, mask=False, relative=False):
method forward (line 216) | def forward(self, x, y, Xmask=None):
class GANLosses (line 239) | class GANLosses(object):
method __init__ (line 241) | def __init__(self, gantype):
method g_loss (line 247) | def g_loss(self,dis_fake):
method d_loss (line 253) | def d_loss(self,dis_fake,dis_real):
class gen_gan (line 260) | class gen_gan(nn.Module):
method __init__ (line 261) | def __init__(self,gantype):
method forward (line 270) | def forward(self,dis_fake):
class dis_gan (line 273) | class dis_gan(nn.Module):
method __init__ (line 274) | def __init__(self,gantype):
method forward (line 283) | def forward(self,dis_fake,dis_real):
function gen_hinge (line 307) | def gen_hinge(dis_fake, dis_real=None):
function dis_hinge (line 310) | def dis_hinge(dis_fake, dis_real):
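`gen_hinge`/`dis_hinge` implement the standard hinge adversarial objective (the formulation popularized by SAGAN-style GANs). A NumPy sketch of that formulation; the repo's torch version should be numerically equivalent on the same scores:

```python
import numpy as np

def gen_hinge(dis_fake, dis_real=None):
    # Generator wants the discriminator's score on fakes to be high.
    return -np.mean(dis_fake)

def dis_hinge(dis_fake, dis_real):
    # Discriminator: push real scores above +1 and fake scores below -1;
    # scores already past the margin contribute zero loss.
    return (np.mean(np.maximum(0.0, 1.0 - dis_real))
            + np.mean(np.maximum(0.0, 1.0 + dis_fake)))
```

The margin means a confident discriminator stops receiving gradient from easy samples, which tends to stabilize training relative to the plain non-saturating loss.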
FILE: scripts/utils/misc.py
function to_numpy (line 12) | def to_numpy(tensor):
function resize_to_match (line 20) | def resize_to_match(fm,to):
function to_torch (line 25) | def to_torch(ndarray):
function save_checkpoint (line 34) | def save_checkpoint(machine,filename='checkpoint.pth.tar', snapshot=None):
function save_pred (line 61) | def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'):
function adjust_learning_rate (line 67) | def adjust_learning_rate(datasets,optimizer, epoch, lr,args):
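`adjust_learning_rate` follows the usual pattern of rewriting each `param_group['lr']` at scheduled epochs. A hypothetical sketch of that pattern: the simplified signature, the milestone schedule, and the decay factor here are illustrative assumptions, not the repo's actual policy (which also consults `datasets` and `args`):

```python
def adjust_learning_rate(optimizer, epoch, lr, schedule=(30, 60), gamma=0.1):
    # Multiply the base lr by `gamma` once per milestone epoch already passed,
    # then write the result into every parameter group.
    decayed = lr * (gamma ** sum(1 for s in schedule if epoch >= s))
    for param_group in optimizer.param_groups:
        param_group['lr'] = decayed
    return decayed
```

Mutating `param_groups` directly is the idiomatic way to schedule learning rates without recreating the optimizer (and is what `torch.optim` schedulers do internally).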
FILE: scripts/utils/model_init.py
function weights_init_normal (line 6) | def weights_init_normal(m):
function weights_init_xavier (line 18) | def weights_init_xavier(m):
function weights_init_kaiming (line 30) | def weights_init_kaiming(m):
function weights_init_orthogonal (line 42) | def weights_init_orthogonal(m):
FILE: scripts/utils/osutils.py
function mkdir_p (line 6) | def mkdir_p(dir_path):
function isfile (line 13) | def isfile(fname):
function isdir (line 16) | def isdir(dirname):
function join (line 19) | def join(path, *paths):
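`mkdir_p` is the classic `mkdir -p` idiom, and the file preview confirms the `os.makedirs` + `errno` version. The equivalent pattern:

```python
import errno
import os

def mkdir_p(dir_path):
    # Create the directory and any missing parents; ignore "already exists",
    # re-raise anything else (permissions, bad path, ...).
    try:
        os.makedirs(dir_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
```

On Python ≥ 3.2 the same effect is available as `os.makedirs(dir_path, exist_ok=True)`; the `errno` form survives here for compatibility with older code.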
FILE: scripts/utils/parallel.py
function allreduce (line 27) | def allreduce(*inputs):
class AllReduce (line 33) | class AllReduce(Function):
method forward (line 35) | def forward(ctx, num_inputs, *inputs):
method backward (line 47) | def backward(ctx, *inputs):
class Reduce (line 56) | class Reduce(Function):
method forward (line 58) | def forward(ctx, *inputs):
method backward (line 64) | def backward(ctx, gradOutput):
class DistributedDataParallelModel (line 67) | class DistributedDataParallelModel(DistributedDataParallel):
method gather (line 91) | def gather(self, outputs, output_device):
class DataParallelModel (line 94) | class DataParallelModel(DataParallel):
method gather (line 124) | def gather(self, outputs, output_device):
method replicate (line 127) | def replicate(self, module, device_ids):
class DataParallelCriterion (line 133) | class DataParallelCriterion(DataParallel):
method forward (line 151) | def forward(self, inputs, *targets, **kwargs):
function _criterion_parallel_apply (line 166) | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None,...
class CallbackContext (line 229) | class CallbackContext(object):
function execute_replication_callbacks (line 233) | def execute_replication_callbacks(modules):
function patch_replication_callback (line 257) | def patch_replication_callback(data_parallel):
FILE: scripts/utils/transforms.py
function color_normalize (line 14) | def color_normalize(x, mean, std):
function flip_back (line 23) | def flip_back(flip_output, dataset='mpii'):
function shufflelr (line 47) | def shufflelr(x, width, dataset='mpii'):
function fliplr (line 71) | def fliplr(x):
function get_transform (line 80) | def get_transform(center, scale, res, rot=0):
function transform (line 110) | def transform(pt, center, scale, res, invert=0, rot=0):
function transform_preds (line 120) | def transform_preds(coords, center, scale, res):
function crop (line 129) | def crop(img, center, scale, res, rot=0):
function get_right (line 165) | def get_right(img,gray=False):
class NormalizeInverse (line 177) | class NormalizeInverse(torchvision.transforms.Normalize):
method __init__ (line 182) | def __init__(self, mean, std):
method __call__ (line 189) | def __call__(self, tensor):
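`NormalizeInverse` subclasses `torchvision.transforms.Normalize` to undo a normalization via the standard trick: applying `Normalize(mean' = -mean/std, std' = 1/std)` to a normalized tensor recovers the original pixel values. A NumPy sketch of the arithmetic (the repo's version operates on torch tensors):

```python
import numpy as np

def normalize(x, mean, std):
    # Forward transform: (x - mean) / std, as torchvision's Normalize does.
    return (x - mean) / std

def normalize_inverse(y, mean, std):
    # Equivalent to Normalize(mean=-mean/std, std=1/std):
    # (y - (-mean/std)) / (1/std) == y * std + mean.
    inv_mean = -mean / std
    inv_std = 1.0 / std
    return (y - inv_mean) / inv_std
```

This is handy for visualization: network inputs live in normalized space, but `NormalizeInverse` maps them back to displayable [0, 1] pixel values.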
FILE: test.py
function main (line 14) | def main(args):
Condensed preview — 35 files, each entry giving the file path, character count, and a short content snippet (188K characters in the full structured content).
[
{
"path": "README.md",
"chars": 3088,
"preview": "This repo contains the code and results of the AAAI 2021 paper:\n\n<i><b> [Split then Refine: Stacked Attention-guided Res"
},
{
"path": "examples/evaluate.sh",
"chars": 1107,
"preview": "set -ex\n\n\n\n# example training scripts for AAAI-21\n# Split then Refine: Stacked Attention-guided ResUNets for Blind Singl"
},
{
"path": "examples/test.sh",
"chars": 350,
"preview": "\nset -ex\n\nCUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/test.py \\\n -c test/10kgray_ssim\\\n --resume /data/home/"
},
{
"path": "main.py",
"chars": 3470,
"preview": "from __future__ import print_function, absolute_import\n\nimport argparse\nimport torch,time,os\n\ntorch.backends.cudnn.bench"
},
{
"path": "options.py",
"chars": 6050,
"preview": "\nimport scripts.models as models\n\nmodel_names = sorted(name for name in models.__dict__\n if name.islower() and not na"
},
{
"path": "requirements.txt",
"chars": 150,
"preview": "numpy==1.19.1\nopencv-python==3.4.8.29\nPillow\nscikit-image==0.14.5\nscikit-learn==0.23.1\nscipy==1.2.1\nsklearn==0.0\ntensorb"
},
{
"path": "scripts/__init__.py",
"chars": 255,
"preview": "from __future__ import absolute_import\n\nfrom . import datasets\nfrom . import models\nfrom . import utils\n\n# import os, sy"
},
{
"path": "scripts/datasets/BIH.py",
"chars": 3357,
"preview": "from __future__ import print_function, absolute_import\n\nimport os\nimport csv\nimport numpy as np\nimport json\nimport rando"
},
{
"path": "scripts/datasets/COCO.py",
"chars": 3353,
"preview": "from __future__ import print_function, absolute_import\n\nimport os\nimport csv\nimport numpy as np\nimport json\nimport rando"
},
{
"path": "scripts/datasets/__init__.py",
"chars": 69,
"preview": "from .COCO import COCO\nfrom .BIH import BIH\n\n__all__ = ('COCO','BIH')"
},
{
"path": "scripts/machines/BasicMachine.py",
"chars": 11420,
"preview": "import torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom progress.bar import Bar\nimport json\nimport "
},
{
"path": "scripts/machines/S2AM.py",
"chars": 11228,
"preview": "import torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom progress.bar import Bar\nimport json\nimport "
},
{
"path": "scripts/machines/VX.py",
"chars": 12188,
"preview": "import torch\nimport torch.nn as nn\nfrom progress.bar import Bar\nfrom tqdm import tqdm\nimport pytorch_ssim\nimport json\nim"
},
{
"path": "scripts/machines/__init__.py",
"chars": 225,
"preview": "\nfrom .BasicMachine import BasicMachine\nfrom .VX import VX\nfrom .S2AM import S2AM\n\ndef basic(**kwargs):\n\treturn BasicMac"
},
{
"path": "scripts/models/__init__.py",
"chars": 78,
"preview": "from .vgg import *\nfrom .backbone_unet import *\nfrom .discriminator import *\n\n"
},
{
"path": "scripts/models/backbone_unet.py",
"chars": 1301,
"preview": "\n\nimport torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport functo"
},
{
"path": "scripts/models/blocks.py",
"chars": 10365,
"preview": "import torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport functool"
},
{
"path": "scripts/models/discriminator.py",
"chars": 5929,
"preview": "import numpy as np\nimport functools\nimport math\nimport torch\nfrom torch.autograd import Variable\nimport torch.nn.functio"
},
{
"path": "scripts/models/rasc.py",
"chars": 5778,
"preview": "\r\n\r\nimport torch\r\nimport torchvision\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport numpy as np\r\nimport"
},
{
"path": "scripts/models/sa_resunet.py",
"chars": 22050,
"preview": "\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nfrom scripts.models.blocks import SEBlock\r\nfrom "
},
{
"path": "scripts/models/unet.py",
"chars": 7038,
"preview": "import torch\r\nimport torch.nn as nn\r\nfrom torch.nn import init\r\nimport functools\r\nfrom scripts.models.blocks import *\r\nf"
},
{
"path": "scripts/models/vgg.py",
"chars": 3144,
"preview": "from collections import namedtuple\n\nimport torch\nfrom torchvision import models\n\n\nclass Vgg16(torch.nn.Module):\n def "
},
{
"path": "scripts/models/vmu.py",
"chars": 13059,
"preview": "\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nfrom scripts.models.blocks import SEBlock\r\nfrom "
},
{
"path": "scripts/utils/__init__.py",
"chars": 180,
"preview": "from __future__ import absolute_import\n\nfrom .evaluation import *\nfrom .imutils import *\nfrom .logger import *\nfrom .mis"
},
{
"path": "scripts/utils/evaluation.py",
"chars": 3597,
"preview": "from __future__ import absolute_import\n\nimport math\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom random impor"
},
{
"path": "scripts/utils/imutils.py",
"chars": 7611,
"preview": "from __future__ import absolute_import\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport scipy.misc\n\nfrom .m"
},
{
"path": "scripts/utils/logger.py",
"chars": 4399,
"preview": "# A simple torch style logger\n# (C) Wei YANG 2017\nfrom __future__ import absolute_import\n\nimport os\nimport sys\nimport nu"
},
{
"path": "scripts/utils/losses.py",
"chars": 10807,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom scripts.models.vgg import Vgg19\nfrom torchvision"
},
{
"path": "scripts/utils/misc.py",
"chars": 2573,
"preview": "from __future__ import absolute_import\n\nimport os\nimport shutil\nimport torch \nimport math\nimport numpy as np\nimport scip"
},
{
"path": "scripts/utils/model_init.py",
"chars": 1752,
"preview": "\n\nfrom torch.nn import init\n\n\ndef weights_init_normal(m):\n classname = m.__class__.__name__\n # print(classname)\n "
},
{
"path": "scripts/utils/osutils.py",
"chars": 377,
"preview": "from __future__ import absolute_import\n\nimport os\nimport errno\n\ndef mkdir_p(dir_path):\n try:\n os.makedirs(dir_"
},
{
"path": "scripts/utils/parallel.py",
"chars": 11406,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Hang Zhang, Rutgers Universit"
},
{
"path": "scripts/utils/transforms.py",
"chars": 5225,
"preview": "from __future__ import absolute_import\n\nimport os\nimport numpy as np\nimport scipy.misc\nimport matplotlib.pyplot as plt\ni"
},
{
"path": "test.py",
"chars": 763,
"preview": "from __future__ import print_function, absolute_import\n\nimport argparse\nimport torch\n\ntorch.backends.cudnn.benchmark = T"
},
{
"path": "watermark_synthesis.ipynb",
"chars": 5294,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": null,\n \"metadata\": {},\n \"outputs\": [\n {\n \"nam"
}
]
About this extraction
This page contains the full source code of the vinthony/deep-blind-watermark-removal GitHub repository, extracted and formatted as plain text: 35 files (174.8 KB, approximately 45.1k tokens) plus a symbol index of 353 extracted functions, classes, methods, constants, and types.