Repository: nie-lang/UDIS2
Branch: main
Commit: b4158950ab14
Files: 27
Total size: 123.9 KB
Directory structure:
gitextract_hn9kepfw/
├── Composition/
│ ├── Codes/
│ │ ├── dataset.py
│ │ ├── loss.py
│ │ ├── network.py
│ │ ├── test.py
│ │ ├── test_other.py
│ │ └── train.py
│ ├── model/
│ │ └── .txt
│ ├── readme.md
│ └── summary/
│ └── .txt
├── LICENSE
├── README.md
├── Warp/
│ ├── Codes/
│ │ ├── dataset.py
│ │ ├── grid_res.py
│ │ ├── loss.py
│ │ ├── network.py
│ │ ├── test.py
│ │ ├── test_other.py
│ │ ├── test_output.py
│ │ ├── train.py
│ │ └── utils/
│ │ ├── torch_DLT.py
│ │ ├── torch_homo_transform.py
│ │ ├── torch_tps_transform.py
│ │ └── torch_tps_transform2.py
│ ├── model/
│ │ └── .txt
│ ├── readme.md
│ └── summary/
│ └── .txt
└── environment.yml
================================================
FILE CONTENTS
================================================
================================================
FILE: Composition/Codes/dataset.py
================================================
from torch.utils.data import Dataset
import numpy as np
import cv2, torch
import os
import glob
from collections import OrderedDict
import random
class TrainDataset(Dataset):
def __init__(self, data_path):
self.train_path = data_path
self.datas = OrderedDict()
datas = glob.glob(os.path.join(self.train_path, '*'))
for data in sorted(datas):
            data_name = os.path.basename(data)
            if data_name in ('warp1', 'warp2', 'mask1', 'mask2'):
self.datas[data_name] = {}
self.datas[data_name]['path'] = data
self.datas[data_name]['image'] = glob.glob(os.path.join(data, '*.jpg'))
self.datas[data_name]['image'].sort()
print(self.datas.keys())
def __getitem__(self, index):
# load image1
warp1 = cv2.imread(self.datas['warp1']['image'][index])
warp1 = warp1.astype(dtype=np.float32)
warp1 = (warp1 / 127.5) - 1.0
warp1 = np.transpose(warp1, [2, 0, 1])
# load image2
warp2 = cv2.imread(self.datas['warp2']['image'][index])
warp2 = warp2.astype(dtype=np.float32)
warp2 = (warp2 / 127.5) - 1.0
warp2 = np.transpose(warp2, [2, 0, 1])
# load mask1
mask1 = cv2.imread(self.datas['mask1']['image'][index])
mask1 = mask1.astype(dtype=np.float32)
mask1 = np.expand_dims(mask1[:,:,0], 2) / 255
mask1 = np.transpose(mask1, [2, 0, 1])
# load mask2
mask2 = cv2.imread(self.datas['mask2']['image'][index])
mask2 = mask2.astype(dtype=np.float32)
mask2 = np.expand_dims(mask2[:,:,0], 2) / 255
mask2 = np.transpose(mask2, [2, 0, 1])
# convert to tensor
warp1_tensor = torch.tensor(warp1)
warp2_tensor = torch.tensor(warp2)
mask1_tensor = torch.tensor(mask1)
mask2_tensor = torch.tensor(mask2)
#return (input1_tensor, input2_tensor, mask1_tensor, mask2_tensor)
if_exchange = random.randint(0,1)
if if_exchange == 0:
#print(if_exchange)
return (warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)
else:
#print(if_exchange)
return (warp2_tensor, warp1_tensor, mask2_tensor, mask1_tensor)
def __len__(self):
return len(self.datas['warp1']['image'])
class TestDataset(Dataset):
def __init__(self, data_path):
self.test_path = data_path
self.datas = OrderedDict()
datas = glob.glob(os.path.join(self.test_path, '*'))
for data in sorted(datas):
            data_name = os.path.basename(data)
            if data_name in ('warp1', 'warp2', 'mask1', 'mask2'):
self.datas[data_name] = {}
self.datas[data_name]['path'] = data
self.datas[data_name]['image'] = glob.glob(os.path.join(data, '*.jpg'))
self.datas[data_name]['image'].sort()
print(self.datas.keys())
def __getitem__(self, index):
# load image1
warp1 = cv2.imread(self.datas['warp1']['image'][index])
warp1 = warp1.astype(dtype=np.float32)
warp1 = (warp1 / 127.5) - 1.0
warp1 = np.transpose(warp1, [2, 0, 1])
# load image2
warp2 = cv2.imread(self.datas['warp2']['image'][index])
warp2 = warp2.astype(dtype=np.float32)
warp2 = (warp2 / 127.5) - 1.0
warp2 = np.transpose(warp2, [2, 0, 1])
# load mask1
mask1 = cv2.imread(self.datas['mask1']['image'][index])
mask1 = mask1.astype(dtype=np.float32)
mask1 = np.expand_dims(mask1[:,:,0], 2) / 255
mask1 = np.transpose(mask1, [2, 0, 1])
# load mask2
mask2 = cv2.imread(self.datas['mask2']['image'][index])
mask2 = mask2.astype(dtype=np.float32)
mask2 = np.expand_dims(mask2[:,:,0], 2) / 255
mask2 = np.transpose(mask2, [2, 0, 1])
# convert to tensor
warp1_tensor = torch.tensor(warp1)
warp2_tensor = torch.tensor(warp2)
mask1_tensor = torch.tensor(mask1)
mask2_tensor = torch.tensor(mask2)
return (warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)
def __len__(self):
return len(self.datas['warp1']['image'])
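# Minimal usage sketch: both datasets expect `data_path` to contain warp1/,
# warp2/, mask1/, and mask2/ subfolders of .jpg files produced by the warp
# stage. The path below is only a placeholder.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    loader = DataLoader(TrainDataset('/path/to/UDIS-D/training/'), batch_size=1, shuffle=True)
    warp1, warp2, mask1, mask2 = next(iter(loader))
    print(warp1.shape, mask1.shape)  # [1, 3, H, W] in [-1, 1] and [1, 1, H, W] in [0, 1]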
================================================
FILE: Composition/Codes/loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
# def get_vgg19_FeatureMap(vgg_model, input_255, layer_index):
# vgg_mean = torch.tensor([123.6800, 116.7790, 103.9390]).reshape((1,3,1,1))
# if torch.cuda.is_available():
# vgg_mean = vgg_mean.cuda()
# vgg_input = input_255-vgg_mean
# #x = vgg_model.features[0](vgg_input)
# #FeatureMap_list.append(x)
# for i in range(0,layer_index+1):
# if i == 0:
# x = vgg_model.features[0](vgg_input)
# else:
# x = vgg_model.features[i](x)
# return x
def l_num_loss(img1, img2, l_num=1):
return torch.mean(torch.abs((img1 - img2)**l_num))
def boundary_extraction(mask):
ones = torch.ones_like(mask)
zeros = torch.zeros_like(mask)
#define kernel
in_channel = 1
out_channel = 1
kernel = [[1, 1, 1],
[1, 1, 1],
[1, 1, 1]]
kernel = torch.FloatTensor(kernel).expand(out_channel,in_channel,3,3)
if torch.cuda.is_available():
kernel = kernel.cuda()
ones = ones.cuda()
zeros = zeros.cuda()
weight = nn.Parameter(data=kernel, requires_grad=False)
#dilation
x = F.conv2d(1-mask,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
x = F.conv2d(x,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
x = F.conv2d(x,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
x = F.conv2d(x,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
x = F.conv2d(x,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
x = F.conv2d(x,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
x = F.conv2d(x,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
return x*mask
def cal_boundary_term(input1_tensor, input2_tensor, mask1_tensor, mask2_tensor, stitched_image):
    boundary_mask1 = mask1_tensor * boundary_extraction(mask2_tensor)
    boundary_mask2 = mask2_tensor * boundary_extraction(mask1_tensor)
    loss1 = l_num_loss(input1_tensor*boundary_mask1, stitched_image*boundary_mask1, 1)
    loss2 = l_num_loss(input2_tensor*boundary_mask2, stitched_image*boundary_mask2, 1)
    return loss1+loss2, boundary_mask1
def cal_smooth_term_stitch(stitched_image, learned_mask1):
delta = 1
dh_mask = torch.abs(learned_mask1[:,:,0:-1*delta,:] - learned_mask1[:,:,delta:,:])
dw_mask = torch.abs(learned_mask1[:,:,:,0:-1*delta] - learned_mask1[:,:,:,delta:])
dh_diff_img = torch.abs(stitched_image[:,:,0:-1*delta,:] - stitched_image[:,:,delta:,:])
dw_diff_img = torch.abs(stitched_image[:,:,:,0:-1*delta] - stitched_image[:,:,:,delta:])
dh_pixel = dh_mask * dh_diff_img
dw_pixel = dw_mask * dw_diff_img
loss = torch.mean(dh_pixel) + torch.mean(dw_pixel)
return loss
def cal_smooth_term_diff(img1, img2, learned_mask1, overlap):
diff_feature = torch.abs(img1-img2)**2 * overlap
delta = 1
dh_mask = torch.abs(learned_mask1[:,:,0:-1*delta,:] - learned_mask1[:,:,delta:,:])
dw_mask = torch.abs(learned_mask1[:,:,:,0:-1*delta] - learned_mask1[:,:,:,delta:])
dh_diff_img = torch.abs(diff_feature[:,:,0:-1*delta,:] + diff_feature[:,:,delta:,:])
dw_diff_img = torch.abs(diff_feature[:,:,:,0:-1*delta] + diff_feature[:,:,:,delta:])
dh_pixel = dh_mask * dh_diff_img
dw_pixel = dw_mask * dw_diff_img
loss = torch.mean(dh_pixel) + torch.mean(dw_pixel)
return loss
# dh_zeros = torch.zeros_like(dh_pixel)
# dw_zeros = torch.zeros_like(dw_pixel)
# if torch.cuda.is_available():
# dh_zeros = dh_zeros.cuda()
# dw_zeros = dw_zeros.cuda()
# loss = l_num_loss(dh_pixel, dh_zeros, 1) + l_num_loss(dw_pixel, dw_zeros, 1)
# return loss, dh_pixel
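# Sanity sketch for boundary_extraction: the seven conv+threshold rounds dilate
# the mask complement by 7 px, so x*mask keeps a 7-px band just inside the mask
# border. A toy check on a centered square mask:
if __name__ == '__main__':
    mask = torch.zeros(1, 1, 32, 32)
    mask[:, :, 8:24, 8:24] = 1.
    if torch.cuda.is_available():
        mask = mask.cuda()
    band = boundary_extraction(mask)
    print(band.sum().item(), mask.sum().item())  # the band covers far fewer pixels than the mask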
================================================
FILE: Composition/Codes/network.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
def build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor):
out = net(warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)
learned_mask1 = (mask1_tensor - mask1_tensor*mask2_tensor) + mask1_tensor*mask2_tensor*out
learned_mask2 = (mask2_tensor - mask1_tensor*mask2_tensor) + mask1_tensor*mask2_tensor*(1-out)
stitched_image = (warp1_tensor+1.) * learned_mask1 + (warp2_tensor+1.)*learned_mask2 - 1.
out_dict = {}
out_dict.update(learned_mask1=learned_mask1, learned_mask2=learned_mask2, stitched_image = stitched_image)
return out_dict
class DownBlock(nn.Module):
def __init__(self, inchannels, outchannels, dilation, pool=True):
super(DownBlock, self).__init__()
blk = []
if pool:
blk.append(nn.MaxPool2d(kernel_size=2, stride=2))
blk.append(nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1, dilation = dilation))
blk.append(nn.ReLU(inplace=True))
blk.append(nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1, dilation = dilation))
blk.append(nn.ReLU(inplace=True))
self.layer = nn.Sequential(*blk)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x):
return self.layer(x)
class UpBlock(nn.Module):
def __init__(self, inchannels, outchannels, dilation):
super(UpBlock, self).__init__()
#self.convt = nn.ConvTranspose2d(inchannels, outchannels, kernel_size=2, stride=2)
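        # 'halfChanelConv' halves the channel count of the upsampled feature map;
        # the misspelled name is kept on purpose, since renaming it would change
        # the state_dict keys and break loading existing checkpoints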
self.halfChanelConv = nn.Sequential(
nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1),
nn.ReLU(inplace=True)
)
self.conv = nn.Sequential(
nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1, dilation = dilation),
nn.ReLU(inplace=True),
nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1, dilation = dilation),
nn.ReLU(inplace=True)
)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x1, x2):
x1 = F.interpolate(x1, size = (x2.size()[2], x2.size()[3]), mode='nearest')
x1 = self.halfChanelConv(x1)
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
# predict the composition mask of img1
class Network(nn.Module):
def __init__(self, nclasses=1):
super(Network, self).__init__()
self.down1 = DownBlock(3, 32, 1, pool=False)
self.down2 = DownBlock(32, 64, 2)
self.down3 = DownBlock(64, 128,3)
self.down4 = DownBlock(128, 256, 4)
self.down5 = DownBlock(256, 512, 5)
self.up1 = UpBlock(512, 256, 4)
self.up2 = UpBlock(256, 128, 3)
self.up3 = UpBlock(128, 64, 2)
self.up4 = UpBlock(64, 32, 1)
self.out = nn.Sequential(
nn.Conv2d(32, nclasses, kernel_size=1),
nn.Sigmoid()
)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x, y, m1, m2):
x1 = self.down1(x)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
x5 = self.down5(x4)
y1 = self.down1(y)
y2 = self.down2(y1)
y3 = self.down3(y2)
y4 = self.down4(y3)
y5 = self.down5(y4)
res = self.up1(x5-y5, x4-y4)
res = self.up2(res, x3-y3)
res = self.up3(res, x2-y2)
res = self.up4(res, x1-y1)
res = self.out(res)
return res
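# Sanity sketch for build_model: outside the overlap each learned mask equals the
# input mask, and inside the overlap the network output splits it, so the two
# learned masks always sum to the union mask1 + mask2 - mask1*mask2. A toy check
# with a randomly initialized network:
if __name__ == '__main__':
    net = Network()
    warp1 = torch.rand(1, 3, 64, 64) * 2 - 1
    warp2 = torch.rand(1, 3, 64, 64) * 2 - 1
    mask1 = torch.zeros(1, 1, 64, 64)
    mask1[:, :, :, :48] = 1.
    mask2 = torch.zeros(1, 1, 64, 64)
    mask2[:, :, :, 16:] = 1.
    if torch.cuda.is_available():
        net = net.cuda()
        warp1, warp2, mask1, mask2 = warp1.cuda(), warp2.cuda(), mask1.cuda(), mask2.cuda()
    with torch.no_grad():
        out = build_model(net, warp1, warp2, mask1, mask2)
    union = mask1 + mask2 - mask1 * mask2
    print(torch.allclose(out['learned_mask1'] + out['learned_mask2'], union))  # True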
================================================
FILE: Composition/Codes/test.py
================================================
# coding: utf-8
import argparse
import torch
from torch.utils.data import DataLoader
from network import build_model, Network
from dataset import *
import os
import numpy as np
import cv2
import glob
last_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
MODEL_DIR = os.path.join(last_path, 'model')
def test(args):
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# dataset
test_data = TestDataset(data_path=args.test_path)
test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, num_workers=1, shuffle=False, drop_last=False)
# define the network
net = Network()
if torch.cuda.is_available():
net = net.cuda()
#load the existing models if it exists
ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
ckpt_list.sort()
if len(ckpt_list) != 0:
model_path = ckpt_list[-1]
checkpoint = torch.load(model_path)
net.load_state_dict(checkpoint['model'])
print('load model from {}!'.format(model_path))
else:
print('No checkpoint found!')
return
path_learn_mask1 = '../learn_mask1/'
if not os.path.exists(path_learn_mask1):
os.makedirs(path_learn_mask1)
path_learn_mask2 = '../learn_mask2/'
if not os.path.exists(path_learn_mask2):
os.makedirs(path_learn_mask2)
path_final_composition = '../composition/'
if not os.path.exists(path_final_composition):
os.makedirs(path_final_composition)
print("##################start testing#######################")
net.eval()
for i, batch_value in enumerate(test_loader):
warp1_tensor = batch_value[0].float()
warp2_tensor = batch_value[1].float()
mask1_tensor = batch_value[2].float()
mask2_tensor = batch_value[3].float()
if torch.cuda.is_available():
warp1_tensor = warp1_tensor.cuda()
warp2_tensor = warp2_tensor.cuda()
mask1_tensor = mask1_tensor.cuda()
mask2_tensor = mask2_tensor.cuda()
# if inpu1_tesnor.size()[2]*inpu1_tesnor.size()[3] > 1200000:
# print("oversize")
# continue
with torch.no_grad():
batch_out = build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)
stitched_image = batch_out['stitched_image']
learned_mask1 = batch_out['learned_mask1']
learned_mask2 = batch_out['learned_mask2']
stitched_image = ((stitched_image[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)
learned_mask1 = (learned_mask1[0]*255).cpu().detach().numpy().transpose(1,2,0)
learned_mask2 = (learned_mask2[0]*255).cpu().detach().numpy().transpose(1,2,0)
path = path_learn_mask1 + str(i+1).zfill(6) + ".jpg"
cv2.imwrite(path, learned_mask1)
path = path_learn_mask2 + str(i+1).zfill(6) + ".jpg"
cv2.imwrite(path, learned_mask2)
path = path_final_composition + str(i+1).zfill(6) + ".jpg"
cv2.imwrite(path, stitched_image)
print('i = {}'.format( i+1))
if __name__=="__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--test_path', type=str, default='/opt/data/private/nl/Data/UDIS-D/testing/')
print('<==================== Loading data ===================>\n')
args = parser.parse_args()
print(args)
test(args)
================================================
FILE: Composition/Codes/test_other.py
================================================
# coding: utf-8
import argparse
import torch
from network import build_model, Network
import os
import numpy as np
import cv2
import glob
last_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
MODEL_DIR = os.path.join(last_path, 'model')
def loadSingleData(data_path):
# load image1
warp1 = cv2.imread(data_path+"warp1.jpg")
warp1 = warp1.astype(dtype=np.float32)
warp1 = (warp1 / 127.5) - 1.0
warp1 = np.transpose(warp1, [2, 0, 1])
# load image2
warp2 = cv2.imread(data_path+"warp2.jpg")
warp2 = warp2.astype(dtype=np.float32)
warp2 = (warp2 / 127.5) - 1.0
warp2 = np.transpose(warp2, [2, 0, 1])
# load mask1
mask1 = cv2.imread(data_path+"mask1.jpg")
mask1 = mask1.astype(dtype=np.float32)
mask1 = mask1 / 255
mask1 = np.transpose(mask1, [2, 0, 1])
# load mask2
mask2 = cv2.imread(data_path+"mask2.jpg")
mask2 = mask2.astype(dtype=np.float32)
mask2 = mask2 / 255
mask2 = np.transpose(mask2, [2, 0, 1])
# convert to tensor
warp1_tensor = torch.tensor(warp1).unsqueeze(0)
warp2_tensor = torch.tensor(warp2).unsqueeze(0)
mask1_tensor = torch.tensor(mask1).unsqueeze(0)
mask2_tensor = torch.tensor(mask2).unsqueeze(0)
return warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor
def test_other(args):
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# define the network
net = Network()
if torch.cuda.is_available():
net = net.cuda()
#load the existing models if it exists
ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
ckpt_list.sort()
if len(ckpt_list) != 0:
model_path = ckpt_list[-1]
checkpoint = torch.load(model_path)
net.load_state_dict(checkpoint['model'])
print('load model from {}!'.format(model_path))
else:
print('No checkpoint found!')
return
# load dataset(only one pair of images)
warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor = loadSingleData(data_path=args.path)
if torch.cuda.is_available():
warp1_tensor = warp1_tensor.cuda()
warp2_tensor = warp2_tensor.cuda()
mask1_tensor = mask1_tensor.cuda()
mask2_tensor = mask2_tensor.cuda()
net.eval()
with torch.no_grad():
batch_out = build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)
stitched_image = batch_out['stitched_image']
learned_mask1 = batch_out['learned_mask1']
learned_mask2 = batch_out['learned_mask2']
# (optional) draw composition images with different colors like our paper
s1 = ((warp1_tensor[0]+1)*127.5 * learned_mask1[0]).cpu().detach().numpy().transpose(1,2,0)
s2 = ((warp2_tensor[0]+1)*127.5 * learned_mask2[0]).cpu().detach().numpy().transpose(1,2,0)
fusion = np.zeros((warp1_tensor.shape[2],warp1_tensor.shape[3],3), np.uint8)
fusion[...,0] = s2[...,0]
fusion[...,1] = s1[...,1]*0.5 + s2[...,1]*0.5
fusion[...,2] = s1[...,2]
path = args.path + "composition_color.jpg"
cv2.imwrite(path, fusion)
# save learned masks and final composition
stitched_image = ((stitched_image[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)
learned_mask1 = (learned_mask1[0]*255).cpu().detach().numpy().transpose(1,2,0)
learned_mask2 = (learned_mask2[0]*255).cpu().detach().numpy().transpose(1,2,0)
path = args.path + "learn_mask1.jpg"
cv2.imwrite(path, learned_mask1)
path = args.path + "learn_mask2.jpg"
cv2.imwrite(path, learned_mask2)
path = args.path + "composition.jpg"
cv2.imwrite(path, stitched_image)
if __name__=="__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--path', type=str, default='../../Carpark-DHW/')
print('<==================== Loading data ===================>\n')
args = parser.parse_args()
print(args)
test_other(args)
================================================
FILE: Composition/Codes/train.py
================================================
import argparse
import torch
from torch.utils.data import DataLoader
import os
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from network import build_model, Network
from dataset import TrainDataset
import glob
from loss import cal_boundary_term, cal_smooth_term_stitch, cal_smooth_term_diff
# path of project
last_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
# path to save the summary files
SUMMARY_DIR = os.path.join(last_path, 'summary')
writer = SummaryWriter(log_dir=SUMMARY_DIR)
# path to save the model files
MODEL_DIR = os.path.join(last_path, 'model')
# create folders if they do not exist
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
if not os.path.exists(SUMMARY_DIR):
os.makedirs(SUMMARY_DIR)
def train(args):
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# dataset
train_data = TrainDataset(data_path=args.train_path)
train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=4, shuffle=True, drop_last=True)
# define the network
net = Network()
if torch.cuda.is_available():
net = net.cuda()
# define the optimizer and learning rate
optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08) # default as 0.0001
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)
#load the existing models if it exists
ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
ckpt_list.sort()
if len(ckpt_list) != 0:
model_path = ckpt_list[-1]
checkpoint = torch.load(model_path)
net.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
glob_iter = checkpoint['glob_iter']
scheduler.last_epoch = start_epoch
print('load model from {}!'.format(model_path))
else:
start_epoch = 0
glob_iter = 0
        print('training from scratch!')
print("##################start training#######################")
score_print_fre = 300
for epoch in range(start_epoch, args.max_epoch):
print("start epoch {}".format(epoch))
net.train()
sigma_total_loss = 0.
sigma_boundary_loss = 0.
sigma_smooth1_loss = 0.
sigma_smooth2_loss = 0.
print(epoch, 'lr={:.6f}'.format(optimizer.state_dict()['param_groups'][0]['lr']))
for i, batch_value in enumerate(train_loader):
warp1_tensor = batch_value[0].float()
warp2_tensor = batch_value[1].float()
mask1_tensor = batch_value[2].float()
mask2_tensor = batch_value[3].float()
if torch.cuda.is_available():
warp1_tensor = warp1_tensor.cuda()
warp2_tensor = warp2_tensor.cuda()
mask1_tensor = mask1_tensor.cuda()
mask2_tensor = mask2_tensor.cuda()
# forward, backward, update weights
optimizer.zero_grad()
batch_out = build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)
learned_mask1 = batch_out['learned_mask1']
learned_mask2 = batch_out['learned_mask2']
stitched_image = batch_out['stitched_image']
# boundary term
boundary_loss, boundary_mask1 = cal_boundary_term( warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor, stitched_image)
boundary_loss = 10000 * boundary_loss
# smooth term
# on stitched image
smooth1_loss = cal_smooth_term_stitch(stitched_image, learned_mask1)
smooth1_loss = 1000* smooth1_loss
# on different image
smooth2_loss = cal_smooth_term_diff( warp1_tensor, warp2_tensor, learned_mask1, mask1_tensor*mask2_tensor)
smooth2_loss = 1000 * smooth2_loss
total_loss = boundary_loss + smooth1_loss + smooth2_loss
total_loss.backward()
# clip the gradient
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=3, norm_type=2)
optimizer.step()
sigma_boundary_loss += boundary_loss.item()
sigma_smooth1_loss += smooth1_loss.item()
sigma_smooth2_loss += smooth2_loss.item()
sigma_total_loss += total_loss.item()
print(glob_iter)
# print loss etc.
if i % score_print_fre == 0 and i != 0:
average_total_loss = sigma_total_loss / score_print_fre
average_boundary_loss = sigma_boundary_loss/ score_print_fre
average_smooth1_loss = sigma_smooth1_loss/ score_print_fre
average_smooth2_loss = sigma_smooth2_loss/ score_print_fre
sigma_total_loss = 0.
sigma_boundary_loss = 0.
sigma_smooth1_loss = 0.
sigma_smooth2_loss = 0.
print("Training: Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}]/[{:0>3}] Total Loss: {:.4f} boundary loss: {:.4f} smooth loss: {:.4f} diff loss: {:.4f} lr={:.8f}".format(epoch + 1, args.max_epoch, i + 1, len(train_loader), average_total_loss, average_boundary_loss, average_smooth1_loss, average_smooth2_loss, optimizer.state_dict()['param_groups'][0]['lr']))
# visualization
writer.add_image("inpu1", (warp1_tensor[0]+1.)/2., glob_iter)
writer.add_image("inpu2", (warp2_tensor[0]+1.)/2., glob_iter)
writer.add_image("stitched_image", (stitched_image[0]+1.)/2., glob_iter)
writer.add_image("learned_mask1", learned_mask1[0], glob_iter)
writer.add_image("boundary_mask1", boundary_mask1[0], glob_iter)
writer.add_scalar('lr', optimizer.state_dict()['param_groups'][0]['lr'], glob_iter)
writer.add_scalar('total loss', average_total_loss, glob_iter)
writer.add_scalar('average_boundary_loss', average_boundary_loss, glob_iter)
writer.add_scalar('average_smooth1_loss', average_smooth1_loss, glob_iter)
writer.add_scalar('average_smooth2_loss', average_smooth2_loss, glob_iter)
glob_iter += 1
scheduler.step()
# save model
if ((epoch+1) % 10 == 0 or (epoch+1)==args.max_epoch):
filename ='epoch' + str(epoch+1).zfill(3) + '_model.pth'
model_save_path = os.path.join(MODEL_DIR, filename)
state = {'model': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch+1, "glob_iter": glob_iter}
torch.save(state, model_save_path)
print("##################end training#######################")
if __name__=="__main__":
print('<==================== setting arguments ===================>\n')
#nl: create the argument parser
parser = argparse.ArgumentParser()
#nl: add arguments
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--max_epoch', type=int, default=50)
parser.add_argument('--train_path', type=str, default='/opt/data/private/nl/Data/UDIS-D/training')
#nl: parse the arguments
args = parser.parse_args()
print(args)
print('<==================== jump into training function ===================>\n')
    #nl: train
train(args)
================================================
FILE: Composition/model/.txt
================================================
================================================
FILE: Composition/readme.md
================================================
## Train on UDIS-D
Before training, the warped images and corresponding masks should be generated in the warp stage.
Then, set the training dataset path in Composition/Codes/train.py.
```
python train.py
```
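For reference, the loaders in Composition/Codes/dataset.py expect the training path to contain four subfolders of .jpg files generated by the warp stage:
```
training/
├── warp1/
├── warp2/
├── mask1/
└── mask2/
```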
## Test on UDIS-D
The pre-trained composition model is available at [Google Drive](https://drive.google.com/file/d/1OaG0ayEwRPhKVV_OwQwvwHDFHC26iv30/view?usp=sharing) or [Baidu Cloud](https://pan.baidu.com/s/1qCGegzvxtzri6GiG7mNw6g)(Extraction code: 1234).
Set the testing dataset path in Composition/Codes/test.py.
```
python test.py
```
The composition masks and final fusion results on UDIS-D will be generated and saved in Composition/learn_mask1/, Composition/learn_mask2/, and Composition/composition/.
## Test on other datasets
Set the 'path' in Composition/Codes/test_other.py to the folder that contains your image pair.
```
python test_other.py
```
The results will be generated and saved at 'path'.
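Concretely, test_other.py expects 'path' to hold one pair of warp-stage outputs and writes its results alongside them:
```
path/
├── warp1.jpg
├── warp2.jpg
├── mask1.jpg
└── mask2.jpg
```
The outputs are learn_mask1.jpg, learn_mask2.jpg, composition.jpg, and composition_color.jpg.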
================================================
FILE: Composition/summary/.txt
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# <p align="center">Parallax-Tolerant Unsupervised Deep Image Stitching (UDIS++ [paper](https://arxiv.org/abs/2302.08207))</p>
<p align="center">Lang Nie*, Chunyu Lin*, Kang Liao*, Shuaicheng Liu`, Yao Zhao*</p>
<p align="center">* Institute of Information Science, Beijing Jiaotong University</p>
<p align="center">` School of Information and Communication Engineering, University of Electronic Science and Technology of China</p>

## Dataset (UDIS-D)
We use the UDIS-D dataset to train and evaluate our method. Please refer to [UDIS](https://github.com/nie-lang/UnsupervisedDeepImageStitching) for more details about this dataset.
## Code
#### Requirement
* numpy 1.19.5
* pytorch 1.7.1
* scikit-image 0.15.0
* tensorboard 2.9.0
We implement this work on Ubuntu with an RTX 3090 Ti GPU and CUDA 11. Refer to [environment.yml](https://github.com/nie-lang/UDIS2/blob/main/environment.yml) for more details.
#### How to run it
Similar to UDIS, we also implement this solution in two stages:
* Stage 1 (unsupervised warp): please refer to [Warp/readme.md](https://github.com/nie-lang/UDIS2/blob/main/Warp/readme.md).
* Stage 2 (unsupervised composition): please refer to [Composition/readme.md](https://github.com/nie-lang/UDIS2/blob/main/Composition/readme.md).
## Meta
If you have any questions about this project, please feel free to drop me an email.
NIE Lang -- nielang@bjtu.edu.cn
```
@inproceedings{nie2023parallax,
title={Parallax-Tolerant Unsupervised Deep Image Stitching},
author={Nie, Lang and Lin, Chunyu and Liao, Kang and Liu, Shuaicheng and Zhao, Yao},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages={7399--7408},
year={2023}
}
```
## References
[1] L. Nie, C. Lin, K. Liao, M. Liu, and Y. Zhao, "A view-free image stitching network based on global homography," Journal of Visual Communication and Image Representation, p. 102950, 2020.
[2] L. Nie, C. Lin, K. Liao, and Y. Zhao, "Learning edge-preserved image stitching from multi-scale deep homography," Neurocomputing, vol. 491, pp. 533-543, 2022.
[3] L. Nie, C. Lin, K. Liao, S. Liu, and Y. Zhao, "Unsupervised deep image stitching: Reconstructing stitched features to images," IEEE Transactions on Image Processing, vol. 30, pp. 6184-6197, 2021.
[4] L. Nie, C. Lin, K. Liao, S. Liu, and Y. Zhao, "Deep rectangling for image stitching: A learning baseline," in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 5740-5748, 2022.
================================================
FILE: Warp/Codes/dataset.py
================================================
from torch.utils.data import Dataset
import numpy as np
import cv2, torch
import os
import glob
from collections import OrderedDict
import random
class TrainDataset(Dataset):
def __init__(self, data_path):
self.width = 512
self.height = 512
self.train_path = data_path
self.datas = OrderedDict()
datas = glob.glob(os.path.join(self.train_path, '*'))
for data in sorted(datas):
            data_name = os.path.basename(data)
            if data_name in ('input1', 'input2'):
self.datas[data_name] = {}
self.datas[data_name]['path'] = data
self.datas[data_name]['image'] = glob.glob(os.path.join(data, '*.jpg'))
self.datas[data_name]['image'].sort()
print(self.datas.keys())
def __getitem__(self, index):
# load image1
input1 = cv2.imread(self.datas['input1']['image'][index])
input1 = cv2.resize(input1, (self.width, self.height))
input1 = input1.astype(dtype=np.float32)
input1 = (input1 / 127.5) - 1.0
input1 = np.transpose(input1, [2, 0, 1])
# load image2
input2 = cv2.imread(self.datas['input2']['image'][index])
input2 = cv2.resize(input2, (self.width, self.height))
input2 = input2.astype(dtype=np.float32)
input2 = (input2 / 127.5) - 1.0
input2 = np.transpose(input2, [2, 0, 1])
# convert to tensor
input1_tensor = torch.tensor(input1)
input2_tensor = torch.tensor(input2)
#print("fasdf")
if_exchange = random.randint(0,1)
if if_exchange == 0:
#print(if_exchange)
return (input1_tensor, input2_tensor)
else:
#print(if_exchange)
return (input2_tensor, input1_tensor)
def __len__(self):
return len(self.datas['input1']['image'])
class TestDataset(Dataset):
def __init__(self, data_path):
self.width = 512
self.height = 512
self.test_path = data_path
self.datas = OrderedDict()
datas = glob.glob(os.path.join(self.test_path, '*'))
for data in sorted(datas):
            data_name = os.path.basename(data)
            if data_name in ('input1', 'input2'):
self.datas[data_name] = {}
self.datas[data_name]['path'] = data
self.datas[data_name]['image'] = glob.glob(os.path.join(data, '*.jpg'))
self.datas[data_name]['image'].sort()
print(self.datas.keys())
def __getitem__(self, index):
# load image1
input1 = cv2.imread(self.datas['input1']['image'][index])
#input1 = cv2.resize(input1, (self.width, self.height))
input1 = input1.astype(dtype=np.float32)
input1 = (input1 / 127.5) - 1.0
input1 = np.transpose(input1, [2, 0, 1])
# load image2
input2 = cv2.imread(self.datas['input2']['image'][index])
#input2 = cv2.resize(input2, (self.width, self.height))
input2 = input2.astype(dtype=np.float32)
input2 = (input2 / 127.5) - 1.0
input2 = np.transpose(input2, [2, 0, 1])
# convert to tensor
input1_tensor = torch.tensor(input1)
input2_tensor = torch.tensor(input2)
return (input1_tensor, input2_tensor)
def __len__(self):
return len(self.datas['input1']['image'])
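# Note the asymmetry: TrainDataset resizes every pair to 512x512, while
# TestDataset keeps the native resolution (its resize calls are commented out).
# Minimal usage sketch with a placeholder path:
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    loader = DataLoader(TrainDataset('/path/to/UDIS-D/training/'), batch_size=2)
    input1, input2 = next(iter(loader))
    print(input1.shape)  # torch.Size([2, 3, 512, 512]), values in [-1, 1]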
================================================
FILE: Warp/Codes/grid_res.py
================================================
# define the control point resolution: (GRID_H+1) * (GRID_W+1)
GRID_H = 12
GRID_W = 12
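# e.g. with GRID_H = GRID_W = 12, the mesh has 13 * 13 = 169 control points,
# and the network regresses a 2D motion vector for each of them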
================================================
FILE: Warp/Codes/loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import grid_res
grid_h = grid_res.GRID_H
grid_w = grid_res.GRID_W
def l_num_loss(img1, img2, l_num=1):
return torch.mean(torch.abs((img1 - img2)**l_num))
def cal_lp_loss(input1, input2, output_H, output_H_inv, warp_mesh, warp_mesh_mask):
batch_size, _, img_h, img_w = input1.size()
# part one: sym homo loss with color balance
delta1 = ( torch.sum(output_H[:,0:3,:,:], [2,3]) - torch.sum(input1*output_H[:,3:6,:,:], [2,3]) ) / torch.sum(output_H[:,3:6,:,:], [2,3])
input1_balance = input1 + delta1.unsqueeze(2).unsqueeze(3).expand(-1, -1, img_h, img_w)
delta2 = ( torch.sum(output_H_inv[:,0:3,:,:], [2,3]) - torch.sum(input2*output_H_inv[:,3:6,:,:], [2,3]) ) / torch.sum(output_H_inv[:,3:6,:,:], [2,3])
input2_balance = input2 + delta2.unsqueeze(2).unsqueeze(3).expand(-1, -1, img_h, img_w)
lp_loss_1 = l_num_loss(input1_balance*output_H[:,3:6,:,:], output_H[:,0:3,:,:], 1) + l_num_loss(input2_balance*output_H_inv[:,3:6,:,:], output_H_inv[:,0:3,:,:], 1)
# part two: tps loss with color balance
delta3 = ( torch.sum(warp_mesh, [2,3]) - torch.sum(input1*warp_mesh_mask, [2,3]) ) / torch.sum(warp_mesh_mask, [2,3])
input1_newbalance = input1 + delta3.unsqueeze(2).unsqueeze(3).expand(-1, -1, img_h, img_w)
lp_loss_2 = l_num_loss(input1_newbalance*warp_mesh_mask, warp_mesh, 1)
lp_loss = 3. * lp_loss_1 + 1. * lp_loss_2
return lp_loss
def cal_lp_loss2(input1, warp_mesh, warp_mesh_mask):
batch_size, _, img_h, img_w = input1.size()
delta3 = ( torch.sum(warp_mesh, [2,3]) - torch.sum(input1*warp_mesh_mask, [2,3]) ) / torch.sum(warp_mesh_mask, [2,3])
input1_newbalance = input1 + delta3.unsqueeze(2).unsqueeze(3).expand(-1, -1, img_h, img_w)
lp_loss_2 = l_num_loss(input1_newbalance*warp_mesh_mask, warp_mesh, 1)
lp_loss = 1. * lp_loss_2
return lp_loss
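# inter-grid constraint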
def inter_grid_loss(overlap, mesh):
##############################
# compute horizontal edges
w_edges = mesh[:,:,0:grid_w,:] - mesh[:,:,1:grid_w+1,:]
# compute angles of two successive horizontal edges
cos_w = torch.sum(w_edges[:,:,0:grid_w-1,:] * w_edges[:,:,1:grid_w,:],3) / (torch.sqrt(torch.sum(w_edges[:,:,0:grid_w-1,:]*w_edges[:,:,0:grid_w-1,:],3))*torch.sqrt(torch.sum(w_edges[:,:,1:grid_w,:]*w_edges[:,:,1:grid_w,:],3)))
# horizontal angle-preserving error for two successive horizontal edges
delta_w_angle = 1 - cos_w
# horizontal angle-preserving error for two successive horizontal grids
delta_w_angle = delta_w_angle[:,0:grid_h,:] + delta_w_angle[:,1:grid_h+1,:]
##############################
##############################
# compute vertical edges
h_edges = mesh[:,0:grid_h,:,:] - mesh[:,1:grid_h+1,:,:]
# compute angles of two successive vertical edges
cos_h = torch.sum(h_edges[:,0:grid_h-1,:,:] * h_edges[:,1:grid_h,:,:],3) / (torch.sqrt(torch.sum(h_edges[:,0:grid_h-1,:,:]*h_edges[:,0:grid_h-1,:,:],3))*torch.sqrt(torch.sum(h_edges[:,1:grid_h,:,:]*h_edges[:,1:grid_h,:,:],3)))
# vertical angle-preserving error for two successive vertical edges
delta_h_angle = 1 - cos_h
# vertical angle-preserving error for two successive vertical grids
delta_h_angle = delta_h_angle[:,:,0:grid_w] + delta_h_angle[:,:,1:grid_w+1]
##############################
    # restrict the error to adjacent cells outside the overlap (the overlap map is 1 there)
depth_diff_w = (1-torch.abs(overlap[:,:,0:grid_w-1] - overlap[:,:,1:grid_w])) * overlap[:,:,0:grid_w-1]
error_w = depth_diff_w * delta_w_angle
    # same weighting along the vertical direction
depth_diff_h = (1-torch.abs(overlap[:,0:grid_h-1,:] - overlap[:,1:grid_h,:])) * overlap[:,0:grid_h-1,:]
error_h = depth_diff_h * delta_h_angle
return torch.mean(error_w) + torch.mean(error_h)
# intra-grid constraint
def intra_grid_loss(pts):
max_w = 512/grid_w * 2
max_h = 512/grid_h * 2
delta_x = pts[:,:,1:grid_w+1,0] - pts[:,:,0:grid_w,0]
delta_y = pts[:,1:grid_h+1,:,1] - pts[:,0:grid_h,:,1]
loss_x = F.relu(delta_x - max_w)
loss_y = F.relu(delta_y - max_h)
loss = torch.mean(loss_x) + torch.mean(loss_y)
return loss
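# Sanity sketch for intra_grid_loss: a rigid 512x512 mesh has cell edges of
# exactly 512/grid_w and 512/grid_h, below the 2x thresholds, so its loss is 0.
if __name__ == '__main__':
    ww = torch.linspace(0., 512., grid_w + 1).unsqueeze(0).repeat(grid_h + 1, 1)
    hh = torch.linspace(0., 512., grid_h + 1).unsqueeze(1).repeat(1, grid_w + 1)
    rigid_mesh = torch.stack([ww, hh], 2).unsqueeze(0)  # 1 x (grid_h+1) x (grid_w+1) x 2
    print(intra_grid_loss(rigid_mesh).item())  # 0.0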
================================================
FILE: Warp/Codes/network.py
================================================
import torch
import torch.nn as nn
import utils.torch_DLT as torch_DLT
import utils.torch_homo_transform as torch_homo_transform
import utils.torch_tps_transform as torch_tps_transform
import ssl
import torch.nn.functional as F
import cv2
import numpy as np
import torchvision.models as models
import torchvision.transforms as T
resize_512 = T.Resize((512,512))
import grid_res
grid_h = grid_res.GRID_H
grid_w = grid_res.GRID_W
# draw mesh on image
# warp: h*w*3
# f_local: grid_h*grid_w*2
def draw_mesh_on_warp(warp, f_local):
warp = np.ascontiguousarray(warp)
point_color = (0, 255, 0) # BGR
thickness = 2
lineType = 8
num = 1
for i in range(grid_h+1):
for j in range(grid_w+1):
num = num + 1
if j == grid_w and i == grid_h:
continue
elif j == grid_w:
cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i+1,j,0]), int(f_local[i+1,j,1])), point_color, thickness, lineType)
elif i == grid_h:
cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i,j+1,0]), int(f_local[i,j+1,1])), point_color, thickness, lineType)
else :
cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i+1,j,0]), int(f_local[i+1,j,1])), point_color, thickness, lineType)
cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i,j+1,0]), int(f_local[i,j+1,1])), point_color, thickness, lineType)
return warp
# convert the global homography into an initial mesh
def H2Mesh(H, rigid_mesh):
H_inv = torch.inverse(H)
ori_pt = rigid_mesh.reshape(rigid_mesh.size()[0], -1, 2)
ones = torch.ones(rigid_mesh.size()[0], (grid_h+1)*(grid_w+1),1)
if torch.cuda.is_available():
ori_pt = ori_pt.cuda()
ones = ones.cuda()
ori_pt = torch.cat((ori_pt, ones), 2) # bs*(grid_h+1)*(grid_w+1)*3
tar_pt = torch.matmul(H_inv, ori_pt.permute(0,2,1)) # bs*3*(grid_h+1)*(grid_w+1)
mesh_x = torch.unsqueeze(tar_pt[:,0,:]/tar_pt[:,2,:], 2)
mesh_y = torch.unsqueeze(tar_pt[:,1,:]/tar_pt[:,2,:], 2)
mesh = torch.cat((mesh_x, mesh_y), 2).reshape([rigid_mesh.size()[0], grid_h+1, grid_w+1, 2])
return mesh
# get rigid mesh
def get_rigid_mesh(batch_size, height, width):
ww = torch.matmul(torch.ones([grid_h+1, 1]), torch.unsqueeze(torch.linspace(0., float(width), grid_w+1), 0))
hh = torch.matmul(torch.unsqueeze(torch.linspace(0.0, float(height), grid_h+1), 1), torch.ones([1, grid_w+1]))
if torch.cuda.is_available():
ww = ww.cuda()
hh = hh.cuda()
ori_pt = torch.cat((ww.unsqueeze(2), hh.unsqueeze(2)),2) # (grid_h+1)*(grid_w+1)*2
ori_pt = ori_pt.unsqueeze(0).expand(batch_size, -1, -1, -1)
return ori_pt
# normalize mesh coordinates to [-1, 1]
def get_norm_mesh(mesh, height, width):
batch_size = mesh.size()[0]
mesh_w = mesh[...,0]*2./float(width) - 1.
mesh_h = mesh[...,1]*2./float(height) - 1.
norm_mesh = torch.stack([mesh_w, mesh_h], 3) # bs*(grid_h+1)*(grid_w+1)*2
return norm_mesh.reshape([batch_size, -1, 2]) # bs*-1*2
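# e.g. get_norm_mesh maps pixel coordinates to the normalized grid used by the
# homography/TPS transformers: x = 0 -> -1 and x = width -> +1 (likewise for y)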
# random augmentation
# it seems to have little effect on performance
def data_aug(img1, img2):
# Randomly shift brightness
random_brightness = torch.randn(1).uniform_(0.7,1.3).cuda()
img1_aug = img1 * random_brightness
random_brightness = torch.randn(1).uniform_(0.7,1.3).cuda()
img2_aug = img2 * random_brightness
# Randomly shift color
white = torch.ones([img1.size()[0], img1.size()[2], img1.size()[3]]).cuda()
random_colors = torch.randn(3).uniform_(0.7,1.3).cuda()
color_image = torch.stack([white * random_colors[i] for i in range(3)], axis=1)
img1_aug *= color_image
random_colors = torch.randn(3).uniform_(0.7,1.3).cuda()
color_image = torch.stack([white * random_colors[i] for i in range(3)], axis=1)
img2_aug *= color_image
# clip
img1_aug = torch.clamp(img1_aug, -1, 1)
img2_aug = torch.clamp(img2_aug, -1, 1)
return img1_aug, img2_aug
# for train.py / test.py
def build_model(net, input1_tensor, input2_tensor, is_training = True):
batch_size, _, img_h, img_w = input1_tensor.size()
# network
    if is_training:
aug_input1_tensor, aug_input2_tensor = data_aug(input1_tensor, input2_tensor)
H_motion, mesh_motion = net(aug_input1_tensor, aug_input2_tensor)
else:
H_motion, mesh_motion = net(input1_tensor, input2_tensor)
H_motion = H_motion.reshape(-1, 4, 2)
mesh_motion = mesh_motion.reshape(-1, grid_h+1, grid_w+1, 2)
# initialize the source points bs x 4 x 2
src_p = torch.tensor([[0., 0.], [img_w, 0.], [0., img_h], [img_w, img_h]])
if torch.cuda.is_available():
src_p = src_p.cuda()
src_p = src_p.unsqueeze(0).expand(batch_size, -1, -1)
# target points
dst_p = src_p + H_motion
# solve homo using DLT
H = torch_DLT.tensor_DLT(src_p, dst_p)
M_tensor = torch.tensor([[img_w / 2.0, 0., img_w / 2.0],
[0., img_h / 2.0, img_h / 2.0],
[0., 0., 1.]])
if torch.cuda.is_available():
M_tensor = M_tensor.cuda()
M_tile = M_tensor.unsqueeze(0).expand(batch_size, -1, -1)
M_tensor_inv = torch.inverse(M_tensor)
M_tile_inv = M_tensor_inv.unsqueeze(0).expand(batch_size, -1, -1)
H_mat = torch.matmul(torch.matmul(M_tile_inv, H), M_tile)
mask = torch.ones_like(input2_tensor)
if torch.cuda.is_available():
mask = mask.cuda()
output_H = torch_homo_transform.transformer(torch.cat((input2_tensor, mask), 1), H_mat, (img_h, img_w))
H_inv_mat = torch.matmul(torch.matmul(M_tile_inv, torch.inverse(H)), M_tile)
output_H_inv = torch_homo_transform.transformer(torch.cat((input1_tensor, mask), 1), H_inv_mat, (img_h, img_w))
rigid_mesh = get_rigid_mesh(batch_size, img_h, img_w)
ini_mesh = H2Mesh(H, rigid_mesh)
mesh = ini_mesh + mesh_motion
norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)
norm_mesh = get_norm_mesh(mesh, img_h, img_w)
output_tps = torch_tps_transform.transformer(torch.cat((input2_tensor, mask), 1), norm_mesh, norm_rigid_mesh, (img_h, img_w))
warp_mesh = output_tps[:,0:3,...]
warp_mesh_mask = output_tps[:,3:6,...]
    # locate the grid cells outside the overlap (mask coverage < 0.9), where the shape-preserving constraints are applied
overlap = torch_tps_transform.transformer(warp_mesh_mask, norm_rigid_mesh, norm_mesh, (img_h, img_w))
overlap = overlap.permute(0, 2, 3, 1).unfold(1, int(img_h/grid_h), int(img_h/grid_h)).unfold(2, int(img_w/grid_w), int(img_w/grid_w))
overlap = torch.mean(overlap.reshape(batch_size, grid_h, grid_w, -1), 3)
overlap_one = torch.ones_like(overlap)
overlap_zero = torch.zeros_like(overlap)
overlap = torch.where(overlap<0.9, overlap_one, overlap_zero)
out_dict = {}
out_dict.update(output_H=output_H, output_H_inv = output_H_inv, warp_mesh = warp_mesh, warp_mesh_mask = warp_mesh_mask, mesh1 = rigid_mesh, mesh2 = mesh, overlap = overlap)
return out_dict
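# build_model returns:
#   output_H / output_H_inv: homography-warped image2 (resp. image1) concatenated
#     with an all-ones validity mask along channels (3+3), for the symmetric homography loss
#   warp_mesh / warp_mesh_mask: the TPS-warped image2 and its valid-region mask
#   mesh1 / mesh2: the rigid mesh and the predicted mesh
#   overlap: a grid_h x grid_w binary map, 1 for cells whose warped-back mask
#     coverage falls below 0.9, used to weight the shape-preserving loss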
# for train_ft.py
def build_new_ft_model(net, input1_tensor, input2_tensor):
batch_size, _, img_h, img_w = input1_tensor.size()
H_motion, mesh_motion = net(input1_tensor, input2_tensor)
H_motion = H_motion.reshape(-1, 4, 2)
#H_motion = torch.stack([H_motion[...,0]*img_w/512, H_motion[...,1]*img_h/512], 2)
mesh_motion = mesh_motion.reshape(-1, grid_h+1, grid_w+1, 2)
#mesh_motion = torch.stack([mesh_motion[...,0]*img_w/512, mesh_motion[...,1]*img_h/512], 3)
# initialize the source points bs x 4 x 2
src_p = torch.tensor([[0., 0.], [img_w, 0.], [0., img_h], [img_w, img_h]])
if torch.cuda.is_available():
src_p = src_p.cuda()
src_p = src_p.unsqueeze(0).expand(batch_size, -1, -1)
# target points
dst_p = src_p + H_motion
# solve homo using DLT
H = torch_DLT.tensor_DLT(src_p, dst_p)
rigid_mesh = get_rigid_mesh(batch_size, img_h, img_w)
ini_mesh = H2Mesh(H, rigid_mesh)
mesh = ini_mesh + mesh_motion
norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)
norm_mesh = get_norm_mesh(mesh, img_h, img_w)
mask = torch.ones_like(input2_tensor)
if torch.cuda.is_available():
mask = mask.cuda()
output_tps = torch_tps_transform.transformer(torch.cat((input2_tensor, mask), 1), norm_mesh, norm_rigid_mesh, (img_h, img_w))
warp_mesh = output_tps[:,0:3,...]
warp_mesh_mask = output_tps[:,3:6,...]
out_dict = {}
out_dict.update(warp_mesh = warp_mesh, warp_mesh_mask = warp_mesh_mask, rigid_mesh = rigid_mesh, mesh = mesh)
return out_dict
# for train_ft.py
def get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh):
batch_size, _, img_h, img_w = input1_tensor.size()
rigid_mesh = torch.stack([rigid_mesh[...,0]*img_w/512, rigid_mesh[...,1]*img_h/512], 3)
mesh = torch.stack([mesh[...,0]*img_w/512, mesh[...,1]*img_h/512], 3)
######################################
width_max = torch.max(mesh[...,0])
width_max = torch.maximum(torch.tensor(img_w).cuda(), width_max)
width_min = torch.min(mesh[...,0])
width_min = torch.minimum(torch.tensor(0).cuda(), width_min)
height_max = torch.max(mesh[...,1])
height_max = torch.maximum(torch.tensor(img_h).cuda(), height_max)
height_min = torch.min(mesh[...,1])
height_min = torch.minimum(torch.tensor(0).cuda(), height_min)
out_width = width_max - width_min
out_height = height_max - height_min
print(out_width)
print(out_height)
warp1 = torch.zeros([batch_size, 3, out_height.int(), out_width.int()]).cuda()
warp1[:,:, int(torch.abs(height_min)):int(torch.abs(height_min))+img_h, int(torch.abs(width_min)):int(torch.abs(width_min))+img_w] = (input1_tensor+1)*127.5
mask1 = torch.zeros([batch_size, 3, out_height.int(), out_width.int()]).cuda()
mask1[:,:, int(torch.abs(height_min)):int(torch.abs(height_min))+img_h, int(torch.abs(width_min)):int(torch.abs(width_min))+img_w] = 255
mask = torch.ones_like(input2_tensor)
if torch.cuda.is_available():
mask = mask.cuda()
# get warped img2
mesh_trans = torch.stack([mesh[...,0]-width_min, mesh[...,1]-height_min], 3)
norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)
norm_mesh = get_norm_mesh(mesh_trans, out_height, out_width)
stitch_tps_out = torch_tps_transform.transformer(torch.cat([input2_tensor+1, mask], 1), norm_mesh, norm_rigid_mesh, (out_height.int(), out_width.int()))
warp2 = stitch_tps_out[:,0:3,:,:]*127.5
mask2 = stitch_tps_out[:,3:6,:,:]*255
stitched = warp1*(warp1/(warp1+warp2+1e-6)) + warp2*(warp2/(warp1+warp2+1e-6))
stitched_mesh = draw_mesh_on_warp(stitched[0].cpu().detach().numpy().transpose(1,2,0), mesh_trans[0].cpu().detach().numpy())
out_dict = {}
out_dict.update(warp1 = warp1, mask1 = mask1, warp2 = warp2, mask2 = mask2, stitched = stitched, stitched_mesh = stitched_mesh)
return out_dict
# for test_output.py
def build_output_model(net, input1_tensor, input2_tensor):
batch_size, _, img_h, img_w = input1_tensor.size()
resized_input1 = resize_512(input1_tensor)
resized_input2 = resize_512(input2_tensor)
H_motion, mesh_motion = net(resized_input1, resized_input2)
H_motion = H_motion.reshape(-1, 4, 2)
H_motion = torch.stack([H_motion[...,0]*img_w/512, H_motion[...,1]*img_h/512], 2)
mesh_motion = mesh_motion.reshape(-1, grid_h+1, grid_w+1, 2)
mesh_motion = torch.stack([mesh_motion[...,0]*img_w/512, mesh_motion[...,1]*img_h/512], 3)
# initialize the source points bs x 4 x 2
src_p = torch.tensor([[0., 0.], [img_w, 0.], [0., img_h], [img_w, img_h]])
if torch.cuda.is_available():
src_p = src_p.cuda()
src_p = src_p.unsqueeze(0).expand(batch_size, -1, -1)
# target points
dst_p = src_p + H_motion
# solve homo using DLT
H = torch_DLT.tensor_DLT(src_p, dst_p)
rigid_mesh = get_rigid_mesh(batch_size, img_h, img_w)
ini_mesh = H2Mesh(H, rigid_mesh)
mesh = ini_mesh + mesh_motion
width_max = torch.max(mesh[...,0])
width_max = torch.maximum(torch.tensor(img_w).cuda(), width_max)
width_min = torch.min(mesh[...,0])
width_min = torch.minimum(torch.tensor(0).cuda(), width_min)
height_max = torch.max(mesh[...,1])
height_max = torch.maximum(torch.tensor(img_h).cuda(), height_max)
height_min = torch.min(mesh[...,1])
height_min = torch.minimum(torch.tensor(0).cuda(), height_min)
out_width = width_max - width_min
out_height = height_max - height_min
#print(out_width)
#print(out_height)
# get warped img1
M_tensor = torch.tensor([[out_width / 2.0, 0., out_width / 2.0],
[0., out_height / 2.0, out_height / 2.0],
[0., 0., 1.]])
N_tensor = torch.tensor([[img_w / 2.0, 0., img_w / 2.0],
[0., img_h / 2.0, img_h / 2.0],
[0., 0., 1.]])
if torch.cuda.is_available():
M_tensor = M_tensor.cuda()
N_tensor = N_tensor.cuda()
N_tensor_inv = torch.inverse(N_tensor)
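    # express the pixel-space translation I_ in normalized coordinates:
    # M maps normalized output coords to output pixels, I_ shifts them into
    # input pixel coords, and N^-1 normalizes them back for the transformer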
I_ = torch.tensor([[1., 0., width_min],
[0., 1., height_min],
[0., 0., 1.]])#.unsqueeze(0)
mask = torch.ones_like(input2_tensor)
if torch.cuda.is_available():
I_ = I_.cuda()
mask = mask.cuda()
I_mat = torch.matmul(torch.matmul(N_tensor_inv, I_), M_tensor).unsqueeze(0)
homo_output = torch_homo_transform.transformer(torch.cat((input1_tensor+1, mask), 1), I_mat, (out_height.int(), out_width.int()))
torch.cuda.empty_cache()
# get warped img2
mesh_trans = torch.stack([mesh[...,0]-width_min, mesh[...,1]-height_min], 3)
norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)
norm_mesh = get_norm_mesh(mesh_trans, out_height, out_width)
tps_output = torch_tps_transform.transformer(torch.cat([input2_tensor+1, mask],1), norm_mesh, norm_rigid_mesh, (out_height.int(), out_width.int()))
out_dict = {}
out_dict.update(final_warp1=homo_output[:, 0:3, ...]-1, final_warp1_mask = homo_output[:, 3:6, ...], final_warp2=tps_output[:, 0:3, ...]-1, final_warp2_mask = tps_output[:, 3:6, ...], mesh1=rigid_mesh, mesh2=mesh_trans)
return out_dict
# define and forward
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
self.regressNet1_part1 = nn.Sequential(
nn.Conv2d(2, 64, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2)
)
self.regressNet1_part2 = nn.Sequential(
nn.Linear(in_features=4096, out_features=4096, bias=True),
nn.ReLU(inplace=True),
nn.Linear(in_features=4096, out_features=1024, bias=True),
nn.ReLU(inplace=True),
nn.Linear(in_features=1024, out_features=8, bias=True)
)
self.regressNet2_part1 = nn.Sequential(
nn.Conv2d(2, 64, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2)
)
self.regressNet2_part2 = nn.Sequential(
nn.Linear(in_features=8192, out_features=4096, bias=True),
nn.ReLU(inplace=True),
nn.Linear(in_features=4096, out_features=2048, bias=True),
nn.ReLU(inplace=True),
nn.Linear(in_features=2048, out_features=(grid_w+1)*(grid_h+1)*2, bias=True)
)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
ssl._create_default_https_context = ssl._create_unverified_context
resnet50_model = models.resnet.resnet50(pretrained=True)
if torch.cuda.is_available():
resnet50_model = resnet50_model.cuda()
self.feature_extractor_stage1, self.feature_extractor_stage2 = self.get_res50_FeatureMap(resnet50_model)
#-----------------------------------------
def get_res50_FeatureMap(self, resnet50_model):
layers_list = []
layers_list.append(resnet50_model.conv1)
layers_list.append(resnet50_model.bn1)
layers_list.append(resnet50_model.relu)
layers_list.append(resnet50_model.maxpool)
layers_list.append(resnet50_model.layer1)
layers_list.append(resnet50_model.layer2)
feature_extractor_stage1 = nn.Sequential(*layers_list)
feature_extractor_stage2 = nn.Sequential(resnet50_model.layer3)
#layers_list.append(resnet50_model.layer3)
return feature_extractor_stage1, feature_extractor_stage2
# forward
    def forward(self, input1_tensor, input2_tensor):
        batch_size, _, img_h, img_w = input1_tensor.size()
        feature_1_64 = self.feature_extractor_stage1(input1_tensor)
        feature_1_32 = self.feature_extractor_stage2(feature_1_64)
        feature_2_64 = self.feature_extractor_stage1(input2_tensor)
        feature_2_32 = self.feature_extractor_stage2(feature_2_64)
######### stage 1
correlation_32 = self.CCL(feature_1_32, feature_2_32)
temp_1 = self.regressNet1_part1(correlation_32)
temp_1 = temp_1.view(temp_1.size()[0], -1)
offset_1 = self.regressNet1_part2(temp_1)
H_motion_1 = offset_1.reshape(-1, 4, 2)
src_p = torch.tensor([[0., 0.], [img_w, 0.], [0., img_h], [img_w, img_h]])
if torch.cuda.is_available():
src_p = src_p.cuda()
src_p = src_p.unsqueeze(0).expand(batch_size, -1, -1)
dst_p = src_p + H_motion_1
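        # solve the homography at 1/8 scale so that it can directly warp the
        # stride-8 feature map feature_2_64 below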
H = torch_DLT.tensor_DLT(src_p/8, dst_p/8)
M_tensor = torch.tensor([[img_w/8 / 2.0, 0., img_w/8 / 2.0],
[0., img_h/8 / 2.0, img_h/8 / 2.0],
[0., 0., 1.]])
if torch.cuda.is_available():
M_tensor = M_tensor.cuda()
M_tile = M_tensor.unsqueeze(0).expand(batch_size, -1, -1)
M_tensor_inv = torch.inverse(M_tensor)
M_tile_inv = M_tensor_inv.unsqueeze(0).expand(batch_size, -1, -1)
H_mat = torch.matmul(torch.matmul(M_tile_inv, H), M_tile)
warp_feature_2_64 = torch_homo_transform.transformer(feature_2_64, H_mat, (int(img_h/8), int(img_w/8)))
######### stage 2
correlation_64 = self.CCL(feature_1_64, warp_feature_2_64)
temp_2 = self.regressNet2_part1(correlation_64)
temp_2 = temp_2.view(temp_2.size()[0], -1)
offset_2 = self.regressNet2_part2(temp_2)
return offset_1, offset_2
def extract_patches(self, x, kernel=3, stride=1):
if kernel != 1:
x = nn.ZeroPad2d(1)(x)
x = x.permute(0, 2, 3, 1)
all_patches = x.unfold(1, kernel, stride).unfold(2, kernel, stride)
return all_patches
def CCL(self, feature_1, feature_2):
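        # contextual correlation layer: every 3x3 patch of feature_2 serves as
        # a convolution filter over feature_1, producing an (h*w)-channel
        # matching volume; a scaled softmax turns it into soft correspondences
        # that are reduced to a dense feature flow below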
bs, c, h, w = feature_1.size()
norm_feature_1 = F.normalize(feature_1, p=2, dim=1)
norm_feature_2 = F.normalize(feature_2, p=2, dim=1)
#print(norm_feature_2.size())
patches = self.extract_patches(norm_feature_2)
if torch.cuda.is_available():
patches = patches.cuda()
matching_filters = patches.reshape((patches.size()[0], -1, patches.size()[3], patches.size()[4], patches.size()[5]))
match_vol = []
for i in range(bs):
single_match = F.conv2d(norm_feature_1[i].unsqueeze(0), matching_filters[i], padding=1)
match_vol.append(single_match)
match_vol = torch.cat(match_vol, 0)
#print(match_vol .size())
# scale softmax
softmax_scale = 10
match_vol = F.softmax(match_vol*softmax_scale,1)
channel = match_vol.size()[1]
h_one = torch.linspace(0, h-1, h)
one1w = torch.ones(1, w)
if torch.cuda.is_available():
h_one = h_one.cuda()
one1w = one1w.cuda()
h_one = torch.matmul(h_one.unsqueeze(1), one1w)
h_one = h_one.unsqueeze(0).unsqueeze(0).expand(bs, channel, -1, -1)
w_one = torch.linspace(0, w-1, w)
oneh1 = torch.ones(h, 1)
if torch.cuda.is_available():
w_one = w_one.cuda()
oneh1 = oneh1.cuda()
w_one = torch.matmul(oneh1, w_one.unsqueeze(0))
w_one = w_one.unsqueeze(0).unsqueeze(0).expand(bs, channel, -1, -1)
c_one = torch.linspace(0, channel-1, channel)
if torch.cuda.is_available():
c_one = c_one.cuda()
c_one = c_one.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand(bs, -1, h, w)
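        # decode each matching channel index k into its 2-D position
        # (k // w, k % w) and take the softmax-weighted expected displacement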
flow_h = match_vol*(c_one//w - h_one)
flow_h = torch.sum(flow_h, dim=1, keepdim=True)
flow_w = match_vol*(c_one%w - w_one)
flow_w = torch.sum(flow_w, dim=1, keepdim=True)
feature_flow = torch.cat([flow_w, flow_h], 1)
#print(flow.size())
return feature_flow
================================================
FILE: Warp/Codes/test.py
================================================
# coding: utf-8
import argparse
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import imageio
from network import build_model, Network
from dataset import *
import os
import numpy as np
import skimage.measure
import cv2
import glob
last_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
MODEL_DIR = os.path.join(last_path, 'model')
def create_gif(image_list, gif_name, duration=0.35):
    frames = []
    for image_name in image_list:
        frames.append(image_name)
    imageio.mimsave(gif_name, frames, 'GIF', duration=duration)
    return
def test(args):
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# dataset
test_data = TestDataset(data_path=args.test_path)
test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, num_workers=1, shuffle=False, drop_last=False)
# define the network
    net = Network()
if torch.cuda.is_available():
net = net.cuda()
#load the existing models if it exists
ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
ckpt_list.sort()
if len(ckpt_list) != 0:
model_path = ckpt_list[-1]
checkpoint = torch.load(model_path)
net.load_state_dict(checkpoint['model'])
print('load model from {}!'.format(model_path))
else:
print('No checkpoint found!')
print("##################start testing#######################")
psnr_list = []
ssim_list = []
net.eval()
for i, batch_value in enumerate(test_loader):
        input1_tensor = batch_value[0].float()
        input2_tensor = batch_value[1].float()
        if torch.cuda.is_available():
            input1_tensor = input1_tensor.cuda()
            input2_tensor = input2_tensor.cuda()
        with torch.no_grad():
            batch_out = build_model(net, input1_tensor, input2_tensor, is_training=False)
warp_mesh_mask = batch_out['warp_mesh_mask']
warp_mesh = batch_out['warp_mesh']
warp_mesh_np = ((warp_mesh[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)
warp_mesh_mask_np = warp_mesh_mask[0].cpu().detach().numpy().transpose(1,2,0)
        input1_np = ((input1_tensor[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)
        # calculate psnr/ssim (compare_psnr/compare_ssim require scikit-image < 0.16)
        psnr = skimage.measure.compare_psnr(input1_np*warp_mesh_mask_np, warp_mesh_np*warp_mesh_mask_np, 255)
        ssim = skimage.measure.compare_ssim(input1_np*warp_mesh_mask_np, warp_mesh_np*warp_mesh_mask_np, data_range=255, multichannel=True)
print('i = {}, psnr = {:.6f}'.format( i+1, psnr))
psnr_list.append(psnr)
ssim_list.append(ssim)
torch.cuda.empty_cache()
print("=================== Analysis ==================")
print("psnr")
psnr_list.sort(reverse = True)
    psnr_list_30 = psnr_list[0 : 331]    # top 30% of the 1,106 test pairs
    psnr_list_60 = psnr_list[331 : 663]  # 30~60%
    psnr_list_100 = psnr_list[663 :]     # 60~100% (do not drop the last element)
print("top 30%", np.mean(psnr_list_30))
print("top 30~60%", np.mean(psnr_list_60))
print("top 60~100%", np.mean(psnr_list_100))
print('average psnr:', np.mean(psnr_list))
ssim_list.sort(reverse = True)
    ssim_list_30 = ssim_list[0 : 331]    # top 30% of the 1,106 test pairs
    ssim_list_60 = ssim_list[331 : 663]  # 30~60%
    ssim_list_100 = ssim_list[663 :]     # 60~100% (do not drop the last element)
print("top 30%", np.mean(ssim_list_30))
print("top 30~60%", np.mean(ssim_list_60))
print("top 60~100%", np.mean(ssim_list_100))
print('average ssim:', np.mean(ssim_list))
print("##################end testing#######################")
if __name__=="__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--test_path', type=str, default='/opt/data/private/nl/Data/UDIS-D/testing/')
print('<==================== Loading data ===================>\n')
args = parser.parse_args()
print(args)
test(args)
================================================
FILE: Warp/Codes/test_other.py
================================================
import argparse
import torch
import numpy as np
import os
import torch.nn as nn
import torch.optim as optim
import cv2
#from torch_homography_model import build_model
from network import get_stitched_result, Network, build_new_ft_model
import glob
from loss import cal_lp_loss2
import torchvision.transforms as T
#import PIL
resize_512 = T.Resize((512,512))
def loadSingleData(data_path, img1_name, img2_name):
# load image1
input1 = cv2.imread(data_path+img1_name)
input1 = input1.astype(dtype=np.float32)
input1 = (input1 / 127.5) - 1.0
input1 = np.transpose(input1, [2, 0, 1])
# load image2
input2 = cv2.imread(data_path+img2_name)
input2 = input2.astype(dtype=np.float32)
input2 = (input2 / 127.5) - 1.0
input2 = np.transpose(input2, [2, 0, 1])
# convert to tensor
input1_tensor = torch.tensor(input1).unsqueeze(0)
input2_tensor = torch.tensor(input2).unsqueeze(0)
return (input1_tensor, input2_tensor)
# path of project
#nl: os.path.dirname(__file__) ----- the directory of the current file
#nl: os.path.pardir ---- the parent directory
last_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
#nl: path to save the model files
MODEL_DIR = os.path.join(last_path, 'model')
#nl: create folders if they do not exist
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
def train(args):
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# define the network
net = Network()
if torch.cuda.is_available():
net = net.cuda()
# define the optimizer and learning rate
optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08) # default as 0.0001
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)
#load the existing models if it exists
ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
ckpt_list.sort()
if len(ckpt_list) != 0:
model_path = ckpt_list[-1]
checkpoint = torch.load(model_path)
net.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
scheduler.last_epoch = start_epoch
print('load model from {}!'.format(model_path))
else:
start_epoch = 0
        print('training from scratch!')
# load dataset(only one pair of images)
input1_tensor, input2_tensor = loadSingleData(data_path=args.path, img1_name = args.img1_name, img2_name = args.img2_name)
if torch.cuda.is_available():
input1_tensor = input1_tensor.cuda()
input2_tensor = input2_tensor.cuda()
input1_tensor_512 = resize_512(input1_tensor)
input2_tensor_512 = resize_512(input2_tensor)
loss_list = []
print("##################start iteration#######################")
for epoch in range(start_epoch, start_epoch + args.max_iter):
net.train()
optimizer.zero_grad()
batch_out = build_new_ft_model(net, input1_tensor_512, input2_tensor_512)
warp_mesh = batch_out['warp_mesh']
warp_mesh_mask = batch_out['warp_mesh_mask']
rigid_mesh = batch_out['rigid_mesh']
mesh = batch_out['mesh']
total_loss = cal_lp_loss2(input1_tensor_512, warp_mesh, warp_mesh_mask)
total_loss.backward()
# clip the gradient
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=3, norm_type=2)
optimizer.step()
current_iter = epoch-start_epoch+1
print("Training: Iteration[{:0>3}/{:0>3}] Total Loss: {:.4f} lr={:.8f}".format(current_iter, args.max_iter, total_loss, optimizer.state_dict()['param_groups'][0]['lr']))
        loss_list.append(total_loss.item())
if current_iter == 1:
with torch.no_grad():
output = get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh)
cv2.imwrite( args.path+ 'before_optimization.jpg', output['stitched'][0].cpu().detach().numpy().transpose(1,2,0))
cv2.imwrite( args.path+ 'before_optimization_mesh.jpg', output['stitched_mesh'])
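        # early stop: once three consecutive loss changes all fall below 1e-4,
        # the warp adaptation is considered converged and the results are saved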
        if current_iter >= 4:
            if abs(loss_list[current_iter-4]-loss_list[current_iter-3]) <= 1e-4 and abs(loss_list[current_iter-3]-loss_list[current_iter-2]) <= 1e-4 \
                    and abs(loss_list[current_iter-2]-loss_list[current_iter-1]) <= 1e-4:
with torch.no_grad():
output = get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh)
path = args.path + "iter-" + str(epoch-start_epoch+1).zfill(3) + ".jpg"
cv2.imwrite(path, output['stitched'][0].cpu().detach().numpy().transpose(1,2,0))
cv2.imwrite(args.path + "iter-" + str(epoch-start_epoch+1).zfill(3) + "_mesh.jpg", output['stitched_mesh'])
cv2.imwrite( args.path+'warp1.jpg', output['warp1'][0].cpu().detach().numpy().transpose(1,2,0))
cv2.imwrite( args.path+'warp2.jpg', output['warp2'][0].cpu().detach().numpy().transpose(1,2,0))
cv2.imwrite( args.path+'mask1.jpg', output['mask1'][0].cpu().detach().numpy().transpose(1,2,0))
cv2.imwrite( args.path+'mask2.jpg', output['mask2'][0].cpu().detach().numpy().transpose(1,2,0))
break
if current_iter == args.max_iter:
with torch.no_grad():
output = get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh)
path = args.path + "iter-" + str(epoch-start_epoch+1).zfill(3) + ".jpg"
cv2.imwrite(path, output['stitched'][0].cpu().detach().numpy().transpose(1,2,0))
cv2.imwrite(args.path + "iter-" + str(epoch-start_epoch+1).zfill(3) + "_mesh.jpg", output['stitched_mesh'])
cv2.imwrite( args.path+'warp1.jpg', output['warp1'][0].cpu().detach().numpy().transpose(1,2,0))
cv2.imwrite( args.path+'warp2.jpg', output['warp2'][0].cpu().detach().numpy().transpose(1,2,0))
cv2.imwrite( args.path+'mask1.jpg', output['mask1'][0].cpu().detach().numpy().transpose(1,2,0))
cv2.imwrite( args.path+'mask2.jpg', output['mask2'][0].cpu().detach().numpy().transpose(1,2,0))
scheduler.step()
print("##################end iteration#######################")
if __name__=="__main__":
print('<==================== setting arguments ===================>\n')
#nl: create the argument parser
parser = argparse.ArgumentParser()
#nl: add arguments
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--max_iter', type=int, default=50)
parser.add_argument('--path', type=str, default='../../Carpark-DHW/')
parser.add_argument('--img1_name', type=str, default='input1.jpg')
parser.add_argument('--img2_name', type=str, default='input2.jpg')
#nl: parse the arguments
args = parser.parse_args()
print(args)
#nl: train
train(args)
================================================
FILE: Warp/Codes/test_output.py
================================================
# coding: utf-8
import argparse
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import imageio
from network import build_output_model, Network
from dataset import *
import os
import cv2
import grid_res
grid_h = grid_res.GRID_H
grid_w = grid_res.GRID_W
last_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
MODEL_DIR = os.path.join(last_path, 'model')
def draw_mesh_on_warp(warp, f_local):
point_color = (0, 255, 0) # BGR
thickness = 2
lineType = 8
    for i in range(grid_h+1):
        for j in range(grid_w+1):
if j == grid_w and i == grid_h:
continue
elif j == grid_w:
cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i+1,j,0]), int(f_local[i+1,j,1])), point_color, thickness, lineType)
elif i == grid_h:
cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i,j+1,0]), int(f_local[i,j+1,1])), point_color, thickness, lineType)
else :
cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i+1,j,0]), int(f_local[i+1,j,1])), point_color, thickness, lineType)
cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i,j+1,0]), int(f_local[i,j+1,1])), point_color, thickness, lineType)
return warp
def create_gif(image_list, gif_name, duration=0.35):
    frames = []
    for image_name in image_list:
        frames.append(image_name)
    imageio.mimsave(gif_name, frames, 'GIF', duration=duration)
    return
def test(args):
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# dataset
test_data = TestDataset(data_path=args.test_path)
#nl: set num_workers = the number of cpus
test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, num_workers=1, shuffle=False, drop_last=False)
# define the network
    net = Network()
if torch.cuda.is_available():
net = net.cuda()
#load the existing models if it exists
ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
ckpt_list.sort()
if len(ckpt_list) != 0:
model_path = ckpt_list[-1]
#model_path = '/opt/data/private/nl/Repository/Unsupervised_Mesh_Stitching/UDISv2-88/UDISv2-Homo_TPS88-10grid_NO-res50-new3/model/epoch150_model.pth'
checkpoint = torch.load(model_path)
net.load_state_dict(checkpoint['model'])
print('load model from {}!'.format(model_path))
else:
print('No checkpoint found!')
print("##################start testing#######################")
    # create folders if they do not exist
path_ave_fusion = '../ave_fusion/'
if not os.path.exists(path_ave_fusion):
os.makedirs(path_ave_fusion)
path_warp1 = args.test_path + 'warp1/'
if not os.path.exists(path_warp1):
os.makedirs(path_warp1)
path_warp2 = args.test_path + 'warp2/'
if not os.path.exists(path_warp2):
os.makedirs(path_warp2)
path_mask1 = args.test_path + 'mask1/'
if not os.path.exists(path_mask1):
os.makedirs(path_mask1)
path_mask2 = args.test_path + 'mask2/'
if not os.path.exists(path_mask2):
os.makedirs(path_mask2)
net.eval()
for i, batch_value in enumerate(test_loader):
        input1_tensor = batch_value[0].float()
        input2_tensor = batch_value[1].float()
        if torch.cuda.is_available():
            input1_tensor = input1_tensor.cuda()
            input2_tensor = input2_tensor.cuda()
        with torch.no_grad():
            batch_out = build_output_model(net, input1_tensor, input2_tensor)
final_warp1 = batch_out['final_warp1']
final_warp1_mask = batch_out['final_warp1_mask']
final_warp2 = batch_out['final_warp2']
final_warp2_mask = batch_out['final_warp2_mask']
final_mesh1 = batch_out['mesh1']
final_mesh2 = batch_out['mesh2']
final_warp1 = ((final_warp1[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)
final_warp2 = ((final_warp2[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)
final_warp1_mask = final_warp1_mask[0].cpu().detach().numpy().transpose(1,2,0)
final_warp2_mask = final_warp2_mask[0].cpu().detach().numpy().transpose(1,2,0)
final_mesh1 = final_mesh1[0].cpu().detach().numpy()
final_mesh2 = final_mesh2[0].cpu().detach().numpy()
path = path_warp1 + str(i+1).zfill(6) + ".jpg"
cv2.imwrite(path, final_warp1)
path = path_warp2 + str(i+1).zfill(6) + ".jpg"
cv2.imwrite(path, final_warp2)
path = path_mask1 + str(i+1).zfill(6) + ".jpg"
cv2.imwrite(path, final_warp1_mask*255)
path = path_mask2 + str(i+1).zfill(6) + ".jpg"
cv2.imwrite(path, final_warp2_mask*255)
ave_fusion = final_warp1 * (final_warp1/ (final_warp1+final_warp2+1e-6)) + final_warp2 * (final_warp2/ (final_warp1+final_warp2+1e-6))
path = path_ave_fusion + str(i+1).zfill(6) + ".jpg"
cv2.imwrite(path, ave_fusion)
print('i = {}'.format( i+1))
torch.cuda.empty_cache()
print("##################end testing#######################")
if __name__=="__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--batch_size', type=int, default=1)
# /opt/data/private/nl/Data/UDIS-D/testing/ or /opt/data/private/nl/Data/UDIS-D/training/
parser.add_argument('--test_path', type=str, default='/opt/data/private/nl/Data/UDIS-D/testing/')
print('<==================== Loading data ===================>\n')
args = parser.parse_args()
print(args)
test(args)
================================================
FILE: Warp/Codes/train.py
================================================
import argparse
import torch
from torch.utils.data import DataLoader
import os
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from network import build_model, Network
from dataset import TrainDataset
import glob
from loss import cal_lp_loss, inter_grid_loss, intra_grid_loss
last_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
# path to save the summary files
SUMMARY_DIR = os.path.join(last_path, 'summary')
writer = SummaryWriter(log_dir=SUMMARY_DIR)
# path to save the model files
MODEL_DIR = os.path.join(last_path, 'model')
# create folders if they do not exist
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
if not os.path.exists(SUMMARY_DIR):
os.makedirs(SUMMARY_DIR)
def train(args):
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# define dataset
train_data = TrainDataset(data_path=args.train_path)
train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=4, shuffle=True, drop_last=True)
# define the network
net = Network()
if torch.cuda.is_available():
net = net.cuda()
# define the optimizer and learning rate
optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08) # default as 0.0001
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)
#load the existing models if it exists
ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
ckpt_list.sort()
if len(ckpt_list) != 0:
model_path = ckpt_list[-1]
checkpoint = torch.load(model_path)
net.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
glob_iter = checkpoint['glob_iter']
scheduler.last_epoch = start_epoch
print('load model from {}!'.format(model_path))
else:
start_epoch = 0
glob_iter = 0
        print('training from scratch!')
print("##################start training#######################")
score_print_fre = 300
for epoch in range(start_epoch, args.max_epoch):
print("start epoch {}".format(epoch))
net.train()
loss_sigma = 0.0
overlap_loss_sigma = 0.
nonoverlap_loss_sigma = 0.
print(epoch, 'lr={:.6f}'.format(optimizer.state_dict()['param_groups'][0]['lr']))
for i, batch_value in enumerate(train_loader):
            input1_tensor = batch_value[0].float()
            input2_tensor = batch_value[1].float()
            if torch.cuda.is_available():
                input1_tensor = input1_tensor.cuda()
                input2_tensor = input2_tensor.cuda()
# forward, backward, update weights
optimizer.zero_grad()
batch_out = build_model(net, inpu1_tesnor, inpu2_tesnor)
# result
output_H = batch_out['output_H']
output_H_inv = batch_out['output_H_inv']
warp_mesh = batch_out['warp_mesh']
warp_mesh_mask = batch_out['warp_mesh_mask']
mesh1 = batch_out['mesh1']
mesh2 = batch_out['mesh2']
overlap = batch_out['overlap']
# calculate loss for overlapping regions
            overlap_loss = cal_lp_loss(input1_tensor, input2_tensor, output_H, output_H_inv, warp_mesh, warp_mesh_mask)
# calculate loss for non-overlapping regions
nonoverlap_loss = 10*inter_grid_loss(overlap, mesh2) + 10*intra_grid_loss(mesh2)
total_loss = overlap_loss + nonoverlap_loss
total_loss.backward()
# clip the gradient
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=3, norm_type=2)
optimizer.step()
overlap_loss_sigma += overlap_loss.item()
nonoverlap_loss_sigma += nonoverlap_loss.item()
loss_sigma += total_loss.item()
            #print(glob_iter)
# record loss and images in tensorboard
if i % score_print_fre == 0 and i != 0:
average_loss = loss_sigma / score_print_fre
average_overlap_loss = overlap_loss_sigma/ score_print_fre
average_nonoverlap_loss = nonoverlap_loss_sigma/ score_print_fre
loss_sigma = 0.0
overlap_loss_sigma = 0.
nonoverlap_loss_sigma = 0.
print("Training: Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}]/[{:0>3}] Total Loss: {:.4f} Overlap Loss: {:.4f} Non-overlap Loss: {:.4f} lr={:.8f}".format(epoch + 1, args.max_epoch, i + 1, len(train_loader),
average_loss, average_overlap_loss, average_nonoverlap_loss, optimizer.state_dict()['param_groups'][0]['lr']))
# visualization
writer.add_image("inpu1", (inpu1_tesnor[0]+1.)/2., glob_iter)
writer.add_image("inpu2", (inpu2_tesnor[0]+1.)/2., glob_iter)
writer.add_image("warp_H", (output_H[0,0:3,:,:]+1.)/2., glob_iter)
writer.add_image("warp_mesh", (warp_mesh[0]+1.)/2., glob_iter)
writer.add_scalar('lr', optimizer.state_dict()['param_groups'][0]['lr'], glob_iter)
writer.add_scalar('total loss', average_loss, glob_iter)
writer.add_scalar('overlap loss', average_overlap_loss, glob_iter)
writer.add_scalar('nonoverlap loss', average_nonoverlap_loss, glob_iter)
glob_iter += 1
scheduler.step()
# save model
if ((epoch+1) % 10 == 0 or (epoch+1)==args.max_epoch):
filename ='epoch' + str(epoch+1).zfill(3) + '_model.pth'
model_save_path = os.path.join(MODEL_DIR, filename)
state = {'model': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch+1, "glob_iter": glob_iter}
torch.save(state, model_save_path)
print("##################end training#######################")
if __name__=="__main__":
print('<==================== setting arguments ===================>\n')
# create the argument parser
parser = argparse.ArgumentParser()
# add arguments
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--max_epoch', type=int, default=100)
parser.add_argument('--train_path', type=str, default='/opt/data/private/nl/Data/UDIS-D/training/')
# parse the arguments
args = parser.parse_args()
print(args)
# train
train(args)
================================================
FILE: Warp/Codes/utils/torch_DLT.py
================================================
import torch
import numpy as np
import cv2
# src_p: shape=(bs, 4, 2)
# dst_p: shape=(bs, 4, 2)
#
# Solves the linear system A h = b for the eight homography parameters
# (h9 is fixed to 1). Each correspondence (x1, y1) -> (x2, y2) contributes
# two rows:
#
# | x1 y1 1  0  0  0  -x1*x2  -y1*x2 |   | h1 ... h8 |^T  =  | x2 |
# | 0  0  0  x1 y1 1  -x1*y2  -y1*y2 |                       | y2 |
def tensor_DLT(src_p, dst_p):
bs, _, _ = src_p.shape
ones = torch.ones(bs, 4, 1)
if torch.cuda.is_available():
ones = ones.cuda()
xy1 = torch.cat((src_p, ones), 2)
zeros = torch.zeros_like(xy1)
if torch.cuda.is_available():
zeros = zeros.cuda()
xyu, xyd = torch.cat((xy1, zeros), 2), torch.cat((zeros, xy1), 2)
M1 = torch.cat((xyu, xyd), 2).reshape(bs, -1, 6)
M2 = torch.matmul(
dst_p.reshape(-1, 2, 1),
src_p.reshape(-1, 1, 2),
).reshape(bs, -1, 2)
# Ah = b
A = torch.cat((M1, -M2), 2)
b = dst_p.reshape(bs, -1, 1)
#h = A^{-1}b
Ainv = torch.inverse(A)
h8 = torch.matmul(Ainv, b).reshape(bs, 8)
H = torch.cat((h8, ones[:,0,:]), 1).reshape(bs, 3, 3)
return H
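if __name__ == "__main__":
    # Editor's sketch (not part of the original pipeline): sanity-check that
    # tensor_DLT recovers a known homography from four correspondences.
    src = torch.tensor([[[0., 0.], [512., 0.], [0., 512.], [512., 512.]]])
    H_true = torch.tensor([[1.1, 0.02, 5.], [0.01, 0.95, -3.], [1e-4, 0., 1.]])
    ones = torch.ones(1, 4, 1)
    dst_h = torch.matmul(torch.cat([src, ones], 2), H_true.t().unsqueeze(0))
    dst = dst_h[..., 0:2] / dst_h[..., 2:3]
    if torch.cuda.is_available():
        src, dst = src.cuda(), dst.cuda()
    print(tensor_DLT(src, dst))  # should be close to H_true (h9 fixed to 1)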
================================================
FILE: Warp/Codes/utils/torch_homo_transform.py
================================================
import torch
import numpy as np
def transformer(U, theta, out_size, **kwargs):
def _repeat(x, n_repeats):
rep = torch.ones([n_repeats, ]).unsqueeze(0)
rep = rep.int()
x = x.int()
x = torch.matmul(x.reshape([-1,1]), rep)
return x.reshape([-1])
def _interpolate(im, x, y, out_size):
num_batch, num_channels , height, width = im.size()
height_f = height
width_f = width
out_height, out_width = out_size[0], out_size[1]
zero = 0
max_y = height - 1
max_x = width - 1
x = (x + 1.0)*(width_f) / 2.0
y = (y + 1.0) * (height_f) / 2.0
# do sampling
x0 = torch.floor(x).int()
x1 = x0 + 1
y0 = torch.floor(y).int()
y1 = y0 + 1
x0 = torch.clamp(x0, zero, max_x)
x1 = torch.clamp(x1, zero, max_x)
y0 = torch.clamp(y0, zero, max_y)
y1 = torch.clamp(y1, zero, max_y)
dim2 = torch.from_numpy( np.array(width) )
dim1 = torch.from_numpy( np.array(width * height) )
base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width)
if torch.cuda.is_available():
dim2 = dim2.cuda()
dim1 = dim1.cuda()
y0 = y0.cuda()
y1 = y1.cuda()
x0 = x0.cuda()
x1 = x1.cuda()
base = base.cuda()
base_y0 = base + y0 * dim2
base_y1 = base + y1 * dim2
idx_a = base_y0 + x0
idx_b = base_y1 + x0
idx_c = base_y0 + x1
idx_d = base_y1 + x1
# channels dim
im = im.permute(0,2,3,1)
im_flat = im.reshape([-1, num_channels]).float()
idx_a = idx_a.unsqueeze(-1).long()
idx_a = idx_a.expand(out_height * out_width * num_batch,num_channels)
Ia = torch.gather(im_flat, 0, idx_a)
idx_b = idx_b.unsqueeze(-1).long()
idx_b = idx_b.expand(out_height * out_width * num_batch, num_channels)
Ib = torch.gather(im_flat, 0, idx_b)
idx_c = idx_c.unsqueeze(-1).long()
idx_c = idx_c.expand(out_height * out_width * num_batch, num_channels)
Ic = torch.gather(im_flat, 0, idx_c)
idx_d = idx_d.unsqueeze(-1).long()
idx_d = idx_d.expand(out_height * out_width * num_batch, num_channels)
Id = torch.gather(im_flat, 0, idx_d)
x0_f = x0.float()
x1_f = x1.float()
y0_f = y0.float()
y1_f = y1.float()
wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
output = wa*Ia+wb*Ib+wc*Ic+wd*Id
return output
def _meshgrid(height, width):
x_t = torch.matmul(torch.ones([height, 1]),
torch.transpose(torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 1), 1, 0))
y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1),
torch.ones([1, width]))
x_t_flat = x_t.reshape((1, -1)).float()
y_t_flat = y_t.reshape((1, -1)).float()
ones = torch.ones_like(x_t_flat)
grid = torch.cat([x_t_flat, y_t_flat, ones], 0)
if torch.cuda.is_available():
grid = grid.cuda()
return grid
def _transform(theta, input_dim, out_size):
num_batch, num_channels , height, width = input_dim.size()
        # theta: batch of 3x3 homographies
        theta = theta.reshape([-1, 3, 3]).float()
out_height, out_width = out_size[0], out_size[1]
grid = _meshgrid(out_height, out_width)
grid = grid.unsqueeze(0).reshape([1,-1])
shape = grid.size()
grid = grid.expand(num_batch,shape[1])
grid = grid.reshape([num_batch, 3, -1])
T_g = torch.matmul(theta, grid)
x_s = T_g[:,0,:]
y_s = T_g[:,1,:]
t_s = T_g[:,2,:]
t_s_flat = t_s.reshape([-1])
        # guard against (near-)zero homogeneous coordinates before dividing
        small = 1e-7
        smallers = 1e-6*(1.0 - torch.ge(torch.abs(t_s_flat), small).float())
        t_s_flat = t_s_flat + smallers
x_s_flat = x_s.reshape([-1]) / t_s_flat
y_s_flat = y_s.reshape([-1]) / t_s_flat
input_transformed = _interpolate( input_dim, x_s_flat, y_s_flat,out_size)
output = input_transformed.reshape([num_batch, out_height, out_width, num_channels])
output = output.permute(0,3,1,2)
return output#, condition
output = _transform(theta, U, out_size)
return output#, condition
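if __name__ == "__main__":
    # Editor's sketch (not part of the original pipeline): an identity
    # homography should approximately reproduce the input; small differences
    # remain because the image is resampled bilinearly on the output grid.
    U = torch.rand(2, 3, 64, 64)
    theta = torch.eye(3).reshape(1, 9).expand(2, -1)
    if torch.cuda.is_available():
        U, theta = U.cuda(), theta.cuda()
    out = transformer(U, theta, (64, 64))
    print(out.shape, (out - U).abs().mean())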
================================================
FILE: Warp/Codes/utils/torch_tps_transform.py
================================================
import torch
import numpy as np
# transforming an image (U) from the target control points to the source control points
# all points should be normalized to [-1, 1]
def transformer(U, source, target, out_size):
def _repeat(x, n_repeats):
rep = torch.ones([n_repeats, ]).unsqueeze(0)
rep = rep.int()
x = x.int()
x = torch.matmul(x.reshape([-1,1]), rep)
return x.reshape([-1])
def _interpolate(im, x, y, out_size):
num_batch, num_channels , height, width = im.size()
height_f = height
width_f = width
out_height, out_width = out_size[0], out_size[1]
zero = 0
max_y = height - 1
max_x = width - 1
x = (x + 1.0)*(width_f) / 2.0
y = (y + 1.0) * (height_f) / 2.0
# do sampling
x0 = torch.floor(x).int()
x1 = x0 + 1
y0 = torch.floor(y).int()
y1 = y0 + 1
x0 = torch.clamp(x0, zero, max_x)
x1 = torch.clamp(x1, zero, max_x)
y0 = torch.clamp(y0, zero, max_y)
y1 = torch.clamp(y1, zero, max_y)
dim2 = torch.from_numpy( np.array(width) )
dim1 = torch.from_numpy( np.array(width * height) )
base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width)
if torch.cuda.is_available():
dim2 = dim2.cuda()
dim1 = dim1.cuda()
y0 = y0.cuda()
y1 = y1.cuda()
x0 = x0.cuda()
x1 = x1.cuda()
base = base.cuda()
base_y0 = base + y0 * dim2
base_y1 = base + y1 * dim2
idx_a = base_y0 + x0
idx_b = base_y1 + x0
idx_c = base_y0 + x1
idx_d = base_y1 + x1
# channels dim
im = im.permute(0,2,3,1)
im_flat = im.reshape([-1, num_channels]).float()
idx_a = idx_a.unsqueeze(-1).long()
idx_a = idx_a.expand(out_height * out_width * num_batch,num_channels)
Ia = torch.gather(im_flat, 0, idx_a)
idx_b = idx_b.unsqueeze(-1).long()
idx_b = idx_b.expand(out_height * out_width * num_batch, num_channels)
Ib = torch.gather(im_flat, 0, idx_b)
idx_c = idx_c.unsqueeze(-1).long()
idx_c = idx_c.expand(out_height * out_width * num_batch, num_channels)
Ic = torch.gather(im_flat, 0, idx_c)
idx_d = idx_d.unsqueeze(-1).long()
idx_d = idx_d.expand(out_height * out_width * num_batch, num_channels)
Id = torch.gather(im_flat, 0, idx_d)
x0_f = x0.float()
x1_f = x1.float()
y0_f = y0.float()
y1_f = y1.float()
wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
output = wa*Ia+wb*Ib+wc*Ic+wd*Id
return output
def _meshgrid(height, width, source):
x_t = torch.matmul(torch.ones([height, 1]), torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 0))
y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1), torch.ones([1, width]))
if torch.cuda.is_available():
x_t = x_t.cuda()
y_t = y_t.cuda()
x_t_flat = x_t.reshape([1, 1, -1])
y_t_flat = y_t.reshape([1, 1, -1])
num_batch = source.size()[0]
px = torch.unsqueeze(source[:,:,0], 2) # [bn, pn, 1]
py = torch.unsqueeze(source[:,:,1], 2) # [bn, pn, 1]
if torch.cuda.is_available():
px = px.cuda()
py = py.cuda()
d2 = torch.square(x_t_flat - px) + torch.square(y_t_flat - py)
r = d2 * torch.log(d2 + 1e-6) # [bn, pn, h*w]
x_t_flat_g = x_t_flat.expand(num_batch, -1, -1) # [bn, 1, h*w]
y_t_flat_g = y_t_flat.expand(num_batch, -1, -1) # [bn, 1, h*w]
ones = torch.ones_like(x_t_flat_g) # [bn, 1, h*w]
if torch.cuda.is_available():
ones = ones.cuda()
grid = torch.cat((ones, x_t_flat_g, y_t_flat_g, r), 1) # [bn, 3+pn, h*w]
return grid
def _transform(T, source, input_dim, out_size):
num_batch, num_channels, height, width = input_dim.size()
out_height, out_width = out_size[0], out_size[1]
grid = _meshgrid(out_height, out_width, source) # [bn, 3+pn, h*w]
# transform A x (1, x_t, y_t, r1, r2, ..., rn) -> (x_s, y_s)
# [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w]
T_g = torch.matmul(T, grid)
x_s = T_g[:,0,:]
y_s = T_g[:,1,:]
x_s_flat = x_s.reshape([-1])
y_s_flat = y_s.reshape([-1])
input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat,out_size)
output = input_transformed.reshape([num_batch, out_height, out_width, num_channels])
output = output.permute(0,3,1,2)
return output#, condition
def _solve_system(source, target):
num_batch = source.size()[0]
num_point = source.size()[1]
np.set_printoptions(precision=8)
ones = torch.ones(num_batch, num_point, 1).float()
if torch.cuda.is_available():
ones = ones.cuda()
p = torch.cat([ones, source], 2) # [bn, pn, 3]
p_1 = p.reshape([num_batch, -1, 1, 3]) # [bn, pn, 1, 3]
p_2 = p.reshape([num_batch, 1, -1, 3]) # [bn, 1, pn, 3]
d2 = torch.sum(torch.square(p_1-p_2), 3) # p1 - p2: [bn, pn, pn, 3] final output: [bn, pn, pn]
r = d2 * torch.log(d2 + 1e-6) # [bn, pn, pn]
zeros = torch.zeros(num_batch, 3, 3).float()
if torch.cuda.is_available():
zeros = zeros.cuda()
W_0 = torch.cat((p, r), 2) # [bn, pn, 3+pn]
W_1 = torch.cat((zeros, p.permute(0,2,1)), 2) # [bn, 3, pn+3]
W = torch.cat((W_0, W_1), 1) # [bn, pn+3, pn+3]
W_inv = torch.inverse(W.type(torch.float64))
zeros2 = torch.zeros(num_batch, 3, 2)
if torch.cuda.is_available():
zeros2 = zeros2.cuda()
tp = torch.cat((target, zeros2), 1) # [bn, pn+3, 2]
T = torch.matmul(W_inv, tp.type(torch.float64)) # [bn, pn+3, 2]
T = T.permute(0, 2, 1) # [bn, 2, pn+3]
return T.type(torch.float32)
T = _solve_system(source, target)
output = _transform(T, source, U, out_size)
return output
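if __name__ == "__main__":
    # Editor's sketch (not part of the original pipeline): identical source
    # and target control points define a (near-)identity TPS warp.
    U = torch.rand(1, 3, 64, 64)
    ys, xs = torch.meshgrid(torch.linspace(-1., 1., 4), torch.linspace(-1., 1., 4))
    pts = torch.stack([xs.reshape(-1), ys.reshape(-1)], 1).unsqueeze(0)  # (1, 16, 2)
    if torch.cuda.is_available():
        U, pts = U.cuda(), pts.cuda()
    out = transformer(U, pts, pts, (64, 64))
    print(out.shape, (out - U).abs().mean())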
================================================
FILE: Warp/Codes/utils/torch_tps_transform2.py
================================================
import torch
import numpy as np
# transforming an image (U) from the target control points to the source control points
# all points should be normalized to [-1, 1]
# compared with torch_tps_transform.py, this version moves some operations from the GPU to the CPU to save GPU memory
def transformer(U, source, target, out_size):
def _repeat(x, n_repeats):
rep = torch.ones([n_repeats, ]).unsqueeze(0)
rep = rep.int()
x = x.int()
x = torch.matmul(x.reshape([-1,1]), rep)
return x.reshape([-1])
def _interpolate(im, x, y, out_size):
num_batch, num_channels , height, width = im.size()
height_f = height
width_f = width
out_height, out_width = out_size[0], out_size[1]
zero = 0
max_y = height - 1
max_x = width - 1
x = (x + 1.0)*(width_f) / 2.0
y = (y + 1.0) * (height_f) / 2.0
# do sampling
x0 = torch.floor(x).int()
x1 = x0 + 1
y0 = torch.floor(y).int()
y1 = y0 + 1
x0 = torch.clamp(x0, zero, max_x)
x1 = torch.clamp(x1, zero, max_x)
y0 = torch.clamp(y0, zero, max_y)
y1 = torch.clamp(y1, zero, max_y)
dim2 = torch.from_numpy( np.array(width) )
dim1 = torch.from_numpy( np.array(width * height) )
base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width)
if torch.cuda.is_available():
dim2 = dim2.cuda()
dim1 = dim1.cuda()
y0 = y0.cuda()
y1 = y1.cuda()
x0 = x0.cuda()
x1 = x1.cuda()
base = base.cuda()
base_y0 = base + y0 * dim2
base_y1 = base + y1 * dim2
idx_a = base_y0 + x0
idx_b = base_y1 + x0
idx_c = base_y0 + x1
idx_d = base_y1 + x1
# channels dim
im = im.permute(0,2,3,1)
im_flat = im.reshape([-1, num_channels]).float()
idx_a = idx_a.unsqueeze(-1).long()
idx_a = idx_a.expand(out_height * out_width * num_batch,num_channels)
Ia = torch.gather(im_flat, 0, idx_a)
idx_b = idx_b.unsqueeze(-1).long()
idx_b = idx_b.expand(out_height * out_width * num_batch, num_channels)
Ib = torch.gather(im_flat, 0, idx_b)
idx_c = idx_c.unsqueeze(-1).long()
idx_c = idx_c.expand(out_height * out_width * num_batch, num_channels)
Ic = torch.gather(im_flat, 0, idx_c)
idx_d = idx_d.unsqueeze(-1).long()
idx_d = idx_d.expand(out_height * out_width * num_batch, num_channels)
Id = torch.gather(im_flat, 0, idx_d)
x0_f = x0.float()
x1_f = x1.float()
y0_f = y0.float()
y1_f = y1.float()
wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
output = wa*Ia+wb*Ib+wc*Ic+wd*Id
return output
def _meshgrid(height, width, source):
source = source.cpu()
x_t = torch.matmul(torch.ones([height, 1]), torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 0))
y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1), torch.ones([1, width]))
x_t_flat = x_t.reshape([1, 1, -1])
y_t_flat = y_t.reshape([1, 1, -1])
num_batch = source.size()[0]
px = torch.unsqueeze(source[:,:,0], 2) # [bn, pn, 1]
py = torch.unsqueeze(source[:,:,1], 2) # [bn, pn, 1]
d2 = torch.square(x_t_flat - px) + torch.square(y_t_flat - py)
r = d2 * torch.log(d2 + 1e-6) # [bn, pn, h*w]
x_t_flat_g = x_t_flat.expand(num_batch, -1, -1) # [bn, 1, h*w]
y_t_flat_g = y_t_flat.expand(num_batch, -1, -1) # [bn, 1, h*w]
ones = torch.ones_like(x_t_flat_g) # [bn, 1, h*w]
grid = torch.cat((ones, x_t_flat_g, y_t_flat_g, r), 1) # [bn, 3+pn, h*w]
        if torch.cuda.is_available():
            grid = grid.cuda()
return grid
def _transform(T, source, input_dim, out_size):
num_batch, num_channels, height, width = input_dim.size()
out_height, out_width = out_size[0], out_size[1]
grid = _meshgrid(out_height, out_width, source) # [bn, 3+pn, h*w]
#print(grid.device)
# transform A x (1, x_t, y_t, r1, r2, ..., rn) -> (x_s, y_s)
# [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w]
T_g = torch.matmul(T, grid)
x_s = T_g[:,0,:]
y_s = T_g[:,1,:]
x_s_flat = x_s.reshape([-1])
y_s_flat = y_s.reshape([-1])
input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat,out_size)
output = input_transformed.reshape([num_batch, out_height, out_width, num_channels])
output = output.permute(0,3,1,2)
#print(output.device)
return output#, condition
def _solve_system(source, target):
num_batch = source.size()[0]
num_point = source.size()[1]
np.set_printoptions(precision=8)
ones = torch.ones(num_batch, num_point, 1).float()
if torch.cuda.is_available():
ones = ones.cuda()
p = torch.cat([ones, source], 2) # [bn, pn, 3]
p_1 = p.reshape([num_batch, -1, 1, 3]) # [bn, pn, 1, 3]
p_2 = p.reshape([num_batch, 1, -1, 3]) # [bn, 1, pn, 3]
d2 = torch.sum(torch.square(p_1-p_2), 3) # p1 - p2: [bn, pn, pn, 3] final output: [bn, pn, pn]
r = d2 * torch.log(d2 + 1e-6) # [bn, pn, pn]
zeros = torch.zeros(num_batch, 3, 3).float()
if torch.cuda.is_available():
zeros = zeros.cuda()
W_0 = torch.cat((p, r), 2) # [bn, pn, 3+pn]
W_1 = torch.cat((zeros, p.permute(0,2,1)), 2) # [bn, 3, pn+3]
W = torch.cat((W_0, W_1), 1) # [bn, pn+3, pn+3]
W_inv = torch.inverse(W.type(torch.float64))
zeros2 = torch.zeros(num_batch, 3, 2)
if torch.cuda.is_available():
zeros2 = zeros2.cuda()
tp = torch.cat((target, zeros2), 1) # [bn, pn+3, 2]
T = torch.matmul(W_inv, tp.type(torch.float64)) # [bn, pn+3, 2]
T = T.permute(0, 2, 1) # [bn, 2, pn+3]
return T.type(torch.float32)
T = _solve_system(source, target)
output = _transform(T, source, U, out_size)
return output#, condition
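if __name__ == "__main__":
    # Editor's sketch (not part of the original pipeline): this variant builds
    # the TPS sampling grid on the CPU, so peak GPU memory stays lower than in
    # torch_tps_transform.py for large outputs.
    if torch.cuda.is_available():
        U = torch.rand(1, 3, 512, 512).cuda()
        ys, xs = torch.meshgrid(torch.linspace(-1., 1., 4), torch.linspace(-1., 1., 4))
        pts = torch.stack([xs.reshape(-1), ys.reshape(-1)], 1).unsqueeze(0).cuda()
        torch.cuda.reset_peak_memory_stats()
        out = transformer(U, pts, pts, (512, 512))
        print(out.shape, torch.cuda.max_memory_allocated() // 2**20, "MiB peak")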
================================================
FILE: Warp/model/.txt
================================================
================================================
FILE: Warp/readme.md
================================================
## Train on UDIS-D
Set the training dataset path in Warp/Codes/train.py.
```
python train.py
```
## Test on UDIS-D
The pre-trained warp model is available at [Google Drive](https://drive.google.com/file/d/1GBwB0y3tUUsOYHErSqxDxoC_Om3BJUEt/view?usp=sharing) or [Baidu Cloud](https://pan.baidu.com/s/1Fx6YnQi9B2wvP_TOVAaBEA) (extraction code: 1234).
#### Calculate PSNR/SSIM
Set the testing dataset path in Warp/Codes/test.py.
```
python test.py
```
#### Generate the warped images and corresponding masks
Set the training/testing dataset path in Warp/Codes/test_output.py.
```
python test_output.py
```
The warped images and masks will be generated and saved under the original training/testing dataset path. The average-fusion results will be saved under the current path.
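For reference, the saved warp1/warp2 images can be fused offline with the same intensity-weighted average that test_output.py uses. A minimal sketch (the file name `000001.jpg` is just an example of the generated numbering):
```
import cv2
import numpy as np

warp1 = cv2.imread('warp1/000001.jpg').astype(np.float32)
warp2 = cv2.imread('warp2/000001.jpg').astype(np.float32)
# intensity-weighted average: valid (brighter) pixels dominate at each location
denom = warp1 + warp2 + 1e-6
ave_fusion = warp1 * (warp1 / denom) + warp2 * (warp2 / denom)
cv2.imwrite('ave_fusion.jpg', ave_fusion)
```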
## Test on other datasets
When testing on other datasets with different scenes and resolutions, we apply iterative warp adaptation to achieve better alignment.
Set 'path/img1_name/img2_name' in Warp/Codes/test_other.py. (By default, both img1 and img2 are placed under 'path'.)
```
python test_other.py
```
The results before/after adaptation will be generated and saved to 'path'.
================================================
FILE: Warp/summary/.txt
================================================
================================================
FILE: environment.yml
================================================
name: nl
channels:
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
- defaults
dependencies:
- _anaconda_depends=2020.07=py38_0
- _ipyw_jlab_nb_ext_conf=0.1.0=py38_0
- _libgcc_mutex=0.1=main
- alabaster=0.7.12=py_0
- anaconda=custom=py38_1
- anaconda-client=1.7.2=py38_0
- anaconda-navigator=1.10.0=py38_0
- anaconda-project=0.8.4=py_0
- argh=0.26.2=py38_0
- argon2-cffi=20.1.0=py38h7b6447c_1
- asn1crypto=1.4.0=py_0
- astroid=2.4.2=py38_0
- astropy=4.0.2=py38h7b6447c_0
- async_generator=1.10=py_0
- atomicwrites=1.4.0=py_0
- attrs=20.3.0=pyhd3eb1b0_0
- autopep8=1.5.4=py_0
- babel=2.8.1=pyhd3eb1b0_0
- backcall=0.2.0=py_0
- backports=1.0=py_2
- backports.functools_lru_cache=1.6.1=py_0
- backports.shutil_get_terminal_size=1.0.0=py38_2
- backports.tempfile=1.0=py_1
- backports.weakref=1.0.post1=py_1
- beautifulsoup4=4.9.3=pyhb0f4dca_0
- bitarray=1.6.1=py38h27cfd23_0
- bkcharts=0.2=py38_0
- blas=1.0=mkl
- bleach=3.2.1=py_0
- blosc=1.20.1=hd408876_0
- bokeh=2.2.3=py38_0
- boto=2.49.0=py38_0
- bottleneck=1.3.2=py38heb32a55_1
- brotlipy=0.7.0=py38h7b6447c_1000
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2021.4.13=h06a4308_1
- cairo=1.14.12=h8948797_3
- certifi=2020.12.5=py38h06a4308_0
- cffi=1.14.3=py38he30daa8_0
- chardet=3.0.4=py38_1003
- click=7.1.2=py_0
- cloudpickle=1.6.0=py_0
- clyent=1.2.2=py38_1
- colorama=0.4.4=py_0
- conda-package-handling=1.7.2=py38h03888b9_0
- conda-verify=3.4.2=py_1
- contextlib2=0.6.0.post1=py_0
- cryptography=3.1.1=py38h1ba5d50_0
- cudatoolkit=11.0.221=h6bb024c_0
- curl=7.71.1=hbc83047_1
- cycler=0.10.0=py38_0
- cython=0.29.21=py38he6710b0_0
- cytoolz=0.11.0=py38h7b6447c_0
- dask=2.30.0=py_0
- dask-core=2.30.0=py_0
- dbus=1.13.18=hb2f20db_0
- decorator=4.4.2=py_0
- defusedxml=0.6.0=py_0
- diff-match-patch=20200713=py_0
- distributed=2.30.1=py38h06a4308_0
- docutils=0.16=py38_1
- entrypoints=0.3=py38_0
- et_xmlfile=1.0.1=py_1001
- expat=2.2.10=he6710b0_2
- fastcache=1.1.0=py38h7b6447c_0
- filelock=3.0.12=py_0
- flake8=3.8.4=py_0
- flask=1.1.2=py_0
- fontconfig=2.13.0=h9420a91_0
- freetype=2.10.4=h5ab3b9f_0
- fribidi=1.0.10=h7b6447c_0
- fsspec=0.8.3=py_0
- future=0.18.2=py38_1
- get_terminal_size=1.0.0=haa9412d_0
- gevent=20.9.0=py38h7b6447c_0
- glib=2.66.1=h92f7085_0
- glob2=0.7=py_0
- gmp=6.1.2=h6c8ec71_1
- gmpy2=2.0.8=py38hd5f6e3b_3
- graphite2=1.3.14=h23475e2_0
- greenlet=0.4.17=py38h7b6447c_0
- gst-plugins-base=1.14.0=hbbd80ab_1
- gstreamer=1.14.0=hb31296c_0
- h5py=2.10.0=py38h7918eee_0
- harfbuzz=2.4.0=hca77d97_1
- hdf5=1.10.4=hb1b8bf9_0
- heapdict=1.0.1=py_0
- html5lib=1.1=py_0
- icu=58.2=he6710b0_3
- idna=2.10=py_0
- imageio=2.9.0=py_0
- imagesize=1.2.0=py_0
- importlib_metadata=2.0.0=1
- iniconfig=1.1.1=py_0
- intel-openmp=2020.2=254
- intervaltree=3.1.0=py_0
- ipykernel=5.3.4=py38h5ca1d4c_0
- ipython=7.19.0=py38hb070fc8_0
- ipython_genutils=0.2.0=py38_0
- ipywidgets=7.5.1=py_1
- isort=5.6.4=py_0
- itsdangerous=1.1.0=py_0
- jbig=2.1=hdba287a_0
- jdcal=1.4.1=py_0
- jedi=0.17.1=py38_0
- jeepney=0.5.0=pyhd3eb1b0_0
- jinja2=2.11.2=py_0
- joblib=0.17.0=py_0
- jpeg=9b=h024ee3a_2
- json5=0.9.5=py_0
- jsonschema=3.2.0=py_2
- jupyter=1.0.0=py38_7
- jupyter_client=6.1.7=py_0
- jupyter_console=6.2.0=py_0
- jupyter_core=4.6.3=py38_0
- jupyterlab=2.2.6=py_0
- jupyterlab_pygments=0.1.2=py_0
- jupyterlab_server=1.2.0=py_0
- keyring=21.4.0=py38_1
- kiwisolver=1.3.0=py38h2531618_0
- krb5=1.18.2=h173b8e3_0
- lazy-object-proxy=1.4.3=py38h7b6447c_0
- lcms2=2.11=h396b838_0
- ld_impl_linux-64=2.33.1=h53a641e_7
- libarchive=3.4.2=h62408e4_0
- libcurl=7.71.1=h20c2e04_1
- libedit=3.1.20191231=h14c3975_1
- libffi=3.3=he6710b0_2
- libgcc-ng=9.1.0=hdf63c60_0
- libgfortran-ng=7.3.0=hdf63c60_0
- liblief=0.10.1=he6710b0_0
- libllvm10=10.0.1=hbcb73fb_5
- libllvm9=9.0.1=h4a3c616_1
- libpng=1.6.37=hbc83047_0
- libsodium=1.0.18=h7b6447c_0
- libspatialindex=1.9.3=he6710b0_0
- libssh2=1.9.0=h1ba5d50_1
- libstdcxx-ng=9.1.0=hdf63c60_0
- libtiff=4.1.0=h2733197_1
- libtool=2.4.6=h7b6447c_1005
- libuuid=1.0.3=h1bed415_2
- libuv=1.40.0=h7b6447c_0
- libxcb=1.14=h7b6447c_0
- libxml2=2.9.10=hb55368b_3
- libxslt=1.1.34=hc22bd24_0
- llvmlite=0.34.0=py38h269e1b5_4
- locket=0.2.0=py38_1
- lxml=4.6.1=py38hefd8a0e_0
- lz4-c=1.9.2=heb0550a_3
- lzo=2.10=h7b6447c_2
- markupsafe=1.1.1=py38h7b6447c_0
- matplotlib=3.3.2=0
- matplotlib-base=3.3.2=py38h817c723_0
- mccabe=0.6.1=py38_1
- mistune=0.8.4=py38h7b6447c_1000
- mkl=2020.2=256
- mkl-service=2.3.0=py38he904b0f_0
- mkl_fft=1.2.0=py38h23d657b_0
- mkl_random=1.1.1=py38h0573a6f_0
- mock=4.0.2=py_0
- more-itertools=8.6.0=pyhd3eb1b0_0
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpmath=1.1.0=py38_0
- msgpack-python=1.0.0=py38hfd86e86_1
- multipledispatch=0.6.0=py38_0
- navigator-updater=0.2.1=py38_0
- nbclient=0.5.1=py_0
- nbconvert=6.0.7=py38_0
- nbformat=5.0.8=py_0
- ncurses=6.2=he6710b0_1
- nest-asyncio=1.4.2=pyhd3eb1b0_0
- networkx=2.5=py_0
- ninja=1.10.2=hff7bd54_1
- nltk=3.5=py_0
- nose=1.3.7=py38_2
- notebook=6.1.4=py38_0
- numba=0.51.2=py38h0573a6f_1
- numexpr=2.7.1=py38h423224d_0
- numpy-base=1.19.2=py38hfa32c7d_0
- numpydoc=1.1.0=pyhd3eb1b0_1
- olefile=0.46=py_0
- openpyxl=3.0.5=py_0
- openssl=1.1.1k=h27cfd23_0
- packaging=20.4=py_0
- pandas=1.1.3=py38he6710b0_0
- pandoc=2.11=hb0f4dca_0
- pandocfilters=1.4.3=py38h06a4308_1
- pango=1.45.3=hd140c19_0
- parso=0.7.0=py_0
- partd=1.1.0=py_0
- patchelf=0.12=he6710b0_0
- path=15.0.0=py38_0
- path.py=12.5.0=0
- pathlib2=2.3.5=py38_0
- pathtools=0.1.2=py_1
- patsy=0.5.1=py38_0
- pcre=8.44=he6710b0_0
- pep8=1.7.1=py38_0
- pexpect=4.8.0=py38_0
- pickleshare=0.7.5=py38_1000
- pip=20.2.4=py38h06a4308_0
- pixman=0.40.0=h7b6447c_0
- pkginfo=1.6.1=py38h06a4308_0
- pluggy=0.13.1=py38_0
- ply=3.11=py38_0
- prometheus_client=0.8.0=py_0
- prompt-toolkit=3.0.8=py_0
- prompt_toolkit=3.0.8=0
- psutil=5.7.2=py38h7b6447c_0
- ptyprocess=0.6.0=py38_0
- py=1.9.0=py_0
- py-lief=0.10.1=py38h403a769_0
- pycodestyle=2.6.0=py_0
- pycosat=0.6.3=py38h7b6447c_1
- pycparser=2.20=py_2
- pycurl=7.43.0.6=py38h1ba5d50_0
- pydocstyle=5.1.1=py_0
- pyflakes=2.2.0=py_0
- pygments=2.7.2=pyhd3eb1b0_0
- pylint=2.6.0=py38_0
- pyodbc=4.0.30=py38he6710b0_0
- pyopenssl=19.1.0=py_1
- pyparsing=2.4.7=py_0
- pyqt=5.9.2=py38h05f1152_4
- pyrsistent=0.17.3=py38h7b6447c_0
- pysocks=1.7.1=py38_0
- pytables=3.6.1=py38h9fd0a39_0
- pytest=6.1.1=py38_0
- python=3.8.5=h7579374_1
- python-dateutil=2.8.1=py_0
- python-jsonrpc-server=0.4.0=py_0
- python-language-server=0.35.1=py_0
- python-libarchive-c=2.9=py_0
- pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0
- pytz=2020.1=py_0
- pywavelets=1.1.1=py38h7b6447c_2
- pyxdg=0.27=pyhd3eb1b0_0
- pyyaml=5.3.1=py38h7b6447c_1
- pyzmq=19.0.2=py38he6710b0_1
- qdarkstyle=2.8.1=py_0
- qt=5.9.7=h5867ecd_1
- qtawesome=1.0.1=py_0
- qtconsole=4.7.7=py_0
- qtpy=1.9.0=py_0
- readline=8.0=h7b6447c_0
- regex=2020.10.15=py38h7b6447c_0
- requests=2.24.0=py_0
- ripgrep=12.1.1=0
- rope=0.18.0=py_0
- rtree=0.9.4=py38_1
- ruamel_yaml=0.15.87=py38h7b6447c_1
- scikit-learn=0.23.2=py38h0573a6f_0
- scipy=1.5.2=py38h0b6359f_0
- seaborn=0.11.0=py_0
- secretstorage=3.1.2=py38_0
- send2trash=1.5.0=py38_0
- setuptools=50.3.1=py38h06a4308_1
- simplegeneric=0.8.1=py38_2
- singledispatch=3.4.0.3=py_1001
- sip=4.19.13=py38he6710b0_0
- six=1.15.0=py38h06a4308_0
- snappy=1.1.8=he6710b0_0
- snowballstemmer=2.0.0=py_0
- sortedcollections=1.2.1=py_0
- sortedcontainers=2.2.2=py_0
- soupsieve=2.0.1=py_0
- sphinx=3.2.1=py_0
- sphinxcontrib=1.0=py38_1
- sphinxcontrib-applehelp=1.0.2=py_0
- sphinxcontrib-devhelp=1.0.2=py_0
- sphinxcontrib-htmlhelp=1.0.3=py_0
- sphinxcontrib-jsmath=1.0.1=py_0
- sphinxcontrib-qthelp=1.0.3=py_0
- sphinxcontrib-serializinghtml=1.1.4=py_0
- sphinxcontrib-websupport=1.2.4=py_0
- spyder=4.1.5=py38_0
- spyder-kernels=1.9.4=py38_0
- sqlalchemy=1.3.20=py38h7b6447c_0
- sqlite=3.33.0=h62c20be_0
- statsmodels=0.12.0=py38h7b6447c_0
- sympy=1.6.2=py38h06a4308_1
- tbb=2020.3=hfd86e86_0
- tblib=1.7.0=py_0
- terminado=0.9.1=py38_0
- testpath=0.4.4=py_0
- threadpoolctl=2.1.0=pyh5ca1d4c_0
- tifffile=2020.10.1=py38hdd07704_2
- tk=8.6.10=hbc83047_0
- toml=0.10.1=py_0
- toolz=0.11.1=py_0
- torchvision=0.8.2=py38_cu110
- tornado=6.0.4=py38h7b6447c_1
- tqdm=4.50.2=py_0
- traitlets=5.0.5=py_0
- typing_extensions=3.7.4.3=py_0
- ujson=4.0.1=py38he6710b0_0
- unicodecsv=0.14.1=py38_0
- unixodbc=2.3.9=h7b6447c_0
- urllib3=1.25.11=py_0
- watchdog=0.10.3=py38_0
- wcwidth=0.2.5=py_0
- webencodings=0.5.1=py38_1
- werkzeug=1.0.1=py_0
- wheel=0.35.1=py_0
- widgetsnbextension=3.5.1=py38_0
- wrapt=1.11.2=py38h7b6447c_0
- wurlitzer=2.0.1=py38_0
- xlrd=1.2.0=py_0
- xlsxwriter=1.3.7=py_0
- xlwt=1.3.0=py38_0
- xmltodict=0.12.0=py_0
- xz=5.2.5=h7b6447c_0
- yaml=0.2.5=h7b6447c_0
- yapf=0.30.0=py_0
- zeromq=4.3.3=he6710b0_3
- zict=2.0.0=py_0
- zipp=3.4.0=pyhd3eb1b0_0
- zlib=1.2.11=h7b6447c_3
- zope=1.0=py38_1
- zope.event=4.5.0=py38_0
- zope.interface=5.1.2=py38h7b6447c_0
- zstd=1.4.5=h9ceee32_0
- pip:
- absl-py==0.12.0
- cachetools==5.2.0
- einops==0.3.0
- google-auth==2.8.0
- google-auth-oauthlib==0.4.6
- grpcio==1.46.3
- importlib-metadata==4.11.4
- markdown==3.3.7
- medpy==0.4.0
- ml-collections==0.1.0
- numpy==1.19.5
- oauthlib==3.2.0
- opencv-python-headless==4.5.1.48
- pillow==9.1.1
- protobuf==3.17.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- requests-oauthlib==1.3.1
- rsa==4.8
- scikit-image==0.15.0
- simpleitk==2.0.2
- tensorboard==2.9.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- timm==0.4.9
- yacs==0.1.6
prefix: /root/anaconda3/envs/nl
================================================
SYMBOL INDEX (67 symbols across 17 files)
================================================
FILE: Composition/Codes/dataset.py
class TrainDataset (line 10) | class TrainDataset(Dataset):
method __init__ (line 11) | def __init__(self, data_path):
method __getitem__ (line 26) | def __getitem__(self, index):
method __len__ (line 69) | def __len__(self):
class TestDataset (line 73) | class TestDataset(Dataset):
method __init__ (line 74) | def __init__(self, data_path):
method __getitem__ (line 90) | def __getitem__(self, index):
method __len__ (line 125) | def __len__(self):
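
Both classes implement the standard torch.utils.data.Dataset contract, so they drop straight into a DataLoader. A minimal usage sketch; 'path/to/training' is a placeholder for a directory holding warp1/, warp2/, mask1/ and mask2/ subfolders of .jpg files, the layout the constructors glob for:

# Minimal DataLoader usage for Composition/Codes/dataset.py.
from torch.utils.data import DataLoader
from dataset import TrainDataset

train_data = TrainDataset(data_path='path/to/training')  # placeholder path
# batch_size=1 sidesteps collation errors if the pre-generated warps
# differ in size across pairs; raise it if your data is uniform.
loader = DataLoader(train_data, batch_size=1, shuffle=True, num_workers=4)

for warp1, warp2, mask1, mask2 in loader:
    # images: (B, 3, H, W) scaled to [-1, 1]; masks: (B, 1, H, W) in [0, 1]
    print(warp1.shape, mask1.shape)
    break
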
FILE: Composition/Codes/loss.py
function l_num_loss (line 27) | def l_num_loss(img1, img2, l_num=1):
function boundary_extraction (line 31) | def boundary_extraction(mask):
function cal_boundary_term (line 66) | def cal_boundary_term(inpu1_tesnor, inpu2_tesnor, mask1_tesnor, mask2_te...
function cal_smooth_term_stitch (line 76) | def cal_smooth_term_stitch(stitched_image, learned_mask1):
function cal_smooth_term_diff (line 94) | def cal_smooth_term_diff(img1, img2, learned_mask1, overlap):
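
l_num_loss(img1, img2, l_num=1) reads as a generic L_n photometric distance that the boundary and smoothness terms build on. A re-implementation sketch of such a loss, as an assumption about the internals rather than a copy of the repo's code:

import torch

def l_num_loss_sketch(img1, img2, l_num=1):
    # Mean of element-wise |img1 - img2| ** l_num: l_num=1 is an L1
    # photometric loss, l_num=2 an (unrooted) L2 one.
    return torch.mean(torch.abs(img1 - img2) ** l_num)

a = torch.rand(1, 3, 64, 64)
b = torch.rand(1, 3, 64, 64)
print(l_num_loss_sketch(a, b))  # about 1/3 in expectation for uniforms
print(l_num_loss_sketch(a, a))  # exactly 0
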
FILE: Composition/Codes/network.py
function build_model (line 7) | def build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_ten...
class DownBlock (line 22) | class DownBlock(nn.Module):
method __init__ (line 23) | def __init__(self, inchannels, outchannels, dilation, pool=True):
method forward (line 41) | def forward(self, x):
class UpBlock (line 44) | class UpBlock(nn.Module):
method __init__ (line 45) | def __init__(self, inchannels, outchannels, dilation):
method forward (line 67) | def forward(self, x1, x2):
class Network (line 76) | class Network(nn.Module):
method __init__ (line 77) | def __init__(self, nclasses=1):
method forward (line 105) | def forward(self, x, y, m1, m2):
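
The signatures sketch a dilated-convolution encoder-decoder: DownBlocks reduce resolution, UpBlocks fuse a skip connection (forward(self, x1, x2)), and Network(nclasses=1) predicts a single composition mask from two warped images and their masks. What a DownBlock with this signature could look like is shown below; the layer count and kernel size are assumptions, not the repo's exact architecture:

import torch
import torch.nn as nn

class DownBlockSketch(nn.Module):
    # Illustrative stand-in for DownBlock(inchannels, outchannels,
    # dilation, pool=True): optional 2x2 max-pool, then two dilated
    # 3x3 convolutions that preserve spatial size (padding=dilation).
    def __init__(self, inchannels, outchannels, dilation, pool=True):
        super().__init__()
        layers = [nn.MaxPool2d(2)] if pool else []
        layers += [
            nn.Conv2d(inchannels, outchannels, 3, padding=dilation, dilation=dilation),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannels, outchannels, 3, padding=dilation, dilation=dilation),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

x = torch.rand(1, 3, 128, 128)
print(DownBlockSketch(3, 32, dilation=1)(x).shape)  # (1, 32, 64, 64)
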
FILE: Composition/Codes/test.py
function test (line 17) | def test(args):
FILE: Composition/Codes/test_other.py
function loadSingleData (line 15) | def loadSingleData(data_path):
function test_other (line 50) | def test_other(args):
FILE: Composition/Codes/train.py
function train (line 32) | def train(args):
FILE: Warp/Codes/dataset.py
class TrainDataset (line 10) | class TrainDataset(Dataset):
method __init__ (line 11) | def __init__(self, data_path):
method __getitem__ (line 28) | def __getitem__(self, index):
method __len__ (line 57) | def __len__(self):
class TestDataset (line 61) | class TestDataset(Dataset):
method __init__ (line 62) | def __init__(self, data_path):
method __getitem__ (line 79) | def __getitem__(self, index):
method __len__ (line 101) | def __len__(self):
FILE: Warp/Codes/loss.py
function l_num_loss (line 10) | def l_num_loss(img1, img2, l_num=1):
function cal_lp_loss (line 14) | def cal_lp_loss(input1, input2, output_H, output_H_inv, warp_mesh, warp_...
function cal_lp_loss2 (line 37) | def cal_lp_loss2(input1, warp_mesh, warp_mesh_mask):
function inter_grid_loss (line 48) | def inter_grid_loss(overlap, mesh):
function intra_grid_loss (line 84) | def intra_grid_loss(pts):
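
inter_grid_loss(overlap, mesh) and intra_grid_loss(pts) regularize the predicted TPS control-point mesh, a (GRID_H+1) x (GRID_W+1) grid per grid_res.py. As a toy illustration of this family of mesh terms, and explicitly not the repo's formulation, one can penalize bending by comparing consecutive edge vectors along each row and column:

import torch

def mesh_bending_sketch(mesh):
    # mesh: (B, H+1, W+1, 2) control points. The difference between
    # consecutive horizontal (and vertical) edge vectors is zero for
    # an unbent, affine grid, so its magnitude penalizes distortion.
    dx = mesh[:, :, 1:, :] - mesh[:, :, :-1, :]   # horizontal edges
    dy = mesh[:, 1:, :, :] - mesh[:, :-1, :, :]   # vertical edges
    bend_x = torch.mean(torch.abs(dx[:, :, 1:, :] - dx[:, :, :-1, :]))
    bend_y = torch.mean(torch.abs(dy[:, 1:, :, :] - dy[:, :-1, :, :]))
    return bend_x + bend_y

rigid = torch.stack(torch.meshgrid(torch.arange(13.), torch.arange(13.)), dim=-1)
print(mesh_bending_sketch(rigid.unsqueeze(0)))  # tensor(0.) for a rigid grid
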
FILE: Warp/Codes/network.py
function draw_mesh_on_warp (line 23) | def draw_mesh_on_warp(warp, f_local):
function H2Mesh (line 50) | def H2Mesh(H, rigid_mesh):
function get_rigid_mesh (line 69) | def get_rigid_mesh(batch_size, height, width):
function get_norm_mesh (line 83) | def get_norm_mesh(mesh, height, width):
function data_aug (line 95) | def data_aug(img1, img2):
function build_model (line 120) | def build_model(net, input1_tensor, input2_tensor, is_training = True):
function build_new_ft_model (line 191) | def build_new_ft_model(net, input1_tensor, input2_tensor):
function get_stitched_result (line 235) | def get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh):
function build_output_model (line 286) | def build_output_model(net, input1_tensor, input2_tensor):
class Network (line 366) | class Network(nn.Module):
method __init__ (line 368) | def __init__(self):
method get_res50_FeatureMap (line 455) | def get_res50_FeatureMap(self, resnet50_model):
method forward (line 475) | def forward(self, input1_tesnor, input2_tesnor):
method extract_patches (line 522) | def extract_patches(self, x, kernel=3, stride=1):
method CCL (line 530) | def CCL(self, feature_1, feature_2):
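
get_rigid_mesh(batch_size, height, width) presumably lays out the undeformed (GRID_H+1) x (GRID_W+1) control points over an image of the given size, which H2Mesh and the TPS transforms then displace, and get_norm_mesh rescales them to normalized coordinates. A standalone sketch under that assumption (grid_res.py pins GRID_H = GRID_W = 12; the exact spacing and (x, y) ordering here are guesses):

import torch

GRID_H, GRID_W = 12, 12  # control-point resolution from Warp/Codes/grid_res.py

def get_rigid_mesh_sketch(batch_size, height, width):
    # Evenly spaced (GRID_H+1) x (GRID_W+1) control points spanning the
    # image, returned as (B, GRID_H+1, GRID_W+1, 2) in (x, y) order.
    ys = torch.linspace(0, height, GRID_H + 1)
    xs = torch.linspace(0, width, GRID_W + 1)
    grid_y, grid_x = torch.meshgrid(ys, xs)
    mesh = torch.stack([grid_x, grid_y], dim=-1)  # (GRID_H+1, GRID_W+1, 2)
    return mesh.unsqueeze(0).expand(batch_size, -1, -1, -1)

print(get_rigid_mesh_sketch(2, 512, 512).shape)  # torch.Size([2, 13, 13, 2])
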
FILE: Warp/Codes/test.py
function create_gif (line 18) | def create_gif(image_list, gif_name, duration=0.35):
function test (line 26) | def test(args):
FILE: Warp/Codes/test_other.py
function loadSingleData (line 21) | def loadSingleData(data_path, img1_name, img2_name):
function train (line 56) | def train(args):
FILE: Warp/Codes/test_output.py
function draw_mesh_on_warp (line 20) | def draw_mesh_on_warp(warp, f_local):
function create_gif (line 44) | def create_gif(image_list, gif_name, duration=0.35):
function test (line 54) | def test(args):
FILE: Warp/Codes/train.py
function train (line 28) | def train(args):
FILE: Warp/Codes/utils/torch_DLT.py
function tensor_DLT (line 17) | def tensor_DLT(src_p, dst_p):
FILE: Warp/Codes/utils/torch_homo_transform.py
function transformer (line 5) | def transformer(U, theta, out_size, **kwargs):
FILE: Warp/Codes/utils/torch_tps_transform.py
function transformer (line 7) | def transformer(U, source, target, out_size):
FILE: Warp/Codes/utils/torch_tps_transform2.py
function transformer (line 10) | def transformer(U, source, target, out_size):
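
Taken together, the utils implement the classic four-point homography pipeline: tensor_DLT solves a batched direct linear transform from four correspondences, and the transformer functions warp an image by the result (homography or TPS). A usage sketch for the DLT step, assuming tensor_DLT returns (bs, 3, 3) homographies; the (bs, 4, 2) input shapes are documented at the top of torch_DLT.py:

import torch
import utils.torch_DLT as torch_DLT  # run from Warp/Codes/, as network.py does

bs = 2
# Four source corners and slightly perturbed destinations, (bs, 4, 2).
src_p = torch.tensor([[[0., 0.], [1., 0.], [0., 1.], [1., 1.]]]).repeat(bs, 1, 1)
dst_p = src_p + torch.randn(bs, 4, 2) * 0.05

H = torch_DLT.tensor_DLT(src_p, dst_p)  # assumed output: (bs, 3, 3)

# Sanity check: with exactly four correspondences the DLT solution is
# exact, so H should map each source corner onto its destination.
ones = torch.ones(bs, 4, 1)
proj = torch.cat([src_p, ones], dim=2) @ H.transpose(1, 2)
proj = proj[..., :2] / proj[..., 2:]
print(torch.allclose(proj, dst_p, atol=1e-4))  # expected: True
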