Full Code of nie-lang/UDIS2 for AI

main b4158950ab14 cached
27 files
123.9 KB
38.7k tokens
67 symbols
1 requests
Download .txt
Repository: nie-lang/UDIS2
Branch: main
Commit: b4158950ab14
Files: 27
Total size: 123.9 KB

Directory structure:
gitextract_hn9kepfw/

├── Composition/
│   ├── Codes/
│   │   ├── dataset.py
│   │   ├── loss.py
│   │   ├── network.py
│   │   ├── test.py
│   │   ├── test_other.py
│   │   └── train.py
│   ├── model/
│   │   └── .txt
│   ├── readme.md
│   └── summary/
│       └── .txt
├── LICENSE
├── README.md
├── Warp/
│   ├── Codes/
│   │   ├── dataset.py
│   │   ├── grid_res.py
│   │   ├── loss.py
│   │   ├── network.py
│   │   ├── test.py
│   │   ├── test_other.py
│   │   ├── test_output.py
│   │   ├── train.py
│   │   └── utils/
│   │       ├── torch_DLT.py
│   │       ├── torch_homo_transform.py
│   │       ├── torch_tps_transform.py
│   │       └── torch_tps_transform2.py
│   ├── model/
│   │   └── .txt
│   ├── readme.md
│   └── summary/
│       └── .txt
└── environment.yml

================================================
FILE CONTENTS
================================================

================================================
FILE: Composition/Codes/dataset.py
================================================
from torch.utils.data import Dataset
import  numpy as np
import cv2, torch
import os
import glob
from collections import OrderedDict
import random


class TrainDataset(Dataset):
    """Training samples made of two warped images and their two masks.

    ``data_path`` is scanned for the sub-directories ``warp1``, ``warp2``,
    ``mask1`` and ``mask2``.  The ``*.jpg`` files of each sub-directory are
    sorted by name and paired by position, so the four folders are expected to
    hold files that sort in the same order.
    """

    # Sub-directories of ``data_path`` that are indexed.
    _SUBDIRS = ('warp1', 'warp2', 'mask1', 'mask2')

    def __init__(self, data_path):

        self.train_path = data_path
        self.datas = OrderedDict()

        for data in sorted(glob.glob(os.path.join(self.train_path, '*'))):
            # basename() handles both '/' and the platform separator, whereas
            # splitting on '/' finds nothing on Windows-style paths.
            data_name = os.path.basename(data)
            if data_name in self._SUBDIRS:
                self.datas[data_name] = {
                    'path': data,
                    'image': sorted(glob.glob(os.path.join(data, '*.jpg'))),
                }
        print(self.datas.keys())

    @staticmethod
    def _read_float(path):
        """Read ``path`` with OpenCV and return it as a float32 array."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('failed to read image: {}'.format(path))
        return image.astype(dtype=np.float32)

    @classmethod
    def _load_image(cls, path):
        """Load an image as a channel-first tensor scaled from [0, 255] to [-1, 1]."""
        image = (cls._read_float(path) / 127.5) - 1.0
        return torch.tensor(np.transpose(image, [2, 0, 1]))

    @classmethod
    def _load_mask(cls, path):
        """Load the first channel of a mask as a (1, H, W) tensor scaled to [0, 1]."""
        mask = np.expand_dims(cls._read_float(path)[:, :, 0], 2) / 255
        return torch.tensor(np.transpose(mask, [2, 0, 1]))

    def __getitem__(self, index):
        """Return ``(warp_a, warp_b, mask_a, mask_b)`` for sample ``index``.

        The two views are swapped with probability 1/2 as data augmentation.

        Raises:
            IOError: if one of the four files cannot be read.
        """
        warp1_tensor = self._load_image(self.datas['warp1']['image'][index])
        warp2_tensor = self._load_image(self.datas['warp2']['image'][index])
        mask1_tensor = self._load_mask(self.datas['mask1']['image'][index])
        mask2_tensor = self._load_mask(self.datas['mask2']['image'][index])

        # randomly exchange the roles of the two views
        if random.randint(0, 1) == 0:
            return (warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)
        return (warp2_tensor, warp1_tensor, mask2_tensor, mask1_tensor)

    def __len__(self):
        """Number of samples, taken from the ``warp1`` folder."""
        return len(self.datas['warp1']['image'])

class TestDataset(Dataset):
    """Test samples made of two warped images and their two masks.

    ``data_path`` is scanned for the sub-directories ``warp1``, ``warp2``,
    ``mask1`` and ``mask2``.  The ``*.jpg`` files of each sub-directory are
    sorted by name and paired by position.
    """

    # Sub-directories of ``data_path`` that are indexed.
    _SUBDIRS = ('warp1', 'warp2', 'mask1', 'mask2')

    def __init__(self, data_path):

        self.test_path = data_path
        self.datas = OrderedDict()

        for data in sorted(glob.glob(os.path.join(self.test_path, '*'))):
            # basename() handles both '/' and the platform separator, whereas
            # splitting on '/' finds nothing on Windows-style paths.
            data_name = os.path.basename(data)
            if data_name in self._SUBDIRS:
                self.datas[data_name] = {
                    'path': data,
                    'image': sorted(glob.glob(os.path.join(data, '*.jpg'))),
                }

        print(self.datas.keys())

    @staticmethod
    def _read_float(path):
        """Read ``path`` with OpenCV and return it as a float32 array."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('failed to read image: {}'.format(path))
        return image.astype(dtype=np.float32)

    @classmethod
    def _load_image(cls, path):
        """Load an image as a channel-first tensor scaled from [0, 255] to [-1, 1]."""
        image = (cls._read_float(path) / 127.5) - 1.0
        return torch.tensor(np.transpose(image, [2, 0, 1]))

    @classmethod
    def _load_mask(cls, path):
        """Load the first channel of a mask as a (1, H, W) tensor scaled to [0, 1]."""
        mask = np.expand_dims(cls._read_float(path)[:, :, 0], 2) / 255
        return torch.tensor(np.transpose(mask, [2, 0, 1]))

    def __getitem__(self, index):
        """Return ``(warp1, warp2, mask1, mask2)`` tensors for sample ``index``.

        Raises:
            IOError: if one of the four files cannot be read.
        """
        warp1_tensor = self._load_image(self.datas['warp1']['image'][index])
        warp2_tensor = self._load_image(self.datas['warp2']['image'][index])
        mask1_tensor = self._load_mask(self.datas['mask1']['image'][index])
        mask2_tensor = self._load_mask(self.datas['mask2']['image'][index])

        return (warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)

    def __len__(self):
        """Number of samples, taken from the ``warp1`` folder."""
        return len(self.datas['warp1']['image'])




================================================
FILE: Composition/Codes/loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F



# def get_vgg19_FeatureMap(vgg_model, input_255, layer_index):

#     vgg_mean = torch.tensor([123.6800, 116.7790, 103.9390]).reshape((1,3,1,1))
#     if torch.cuda.is_available():
#         vgg_mean = vgg_mean.cuda()
#     vgg_input = input_255-vgg_mean
#     #x = vgg_model.features[0](vgg_input)
#     #FeatureMap_list.append(x)


#     for i in range(0,layer_index+1):
#         if i == 0:
#             x = vgg_model.features[0](vgg_input)
#         else:
#             x = vgg_model.features[i](x)

#     return x



def l_num_loss(img1, img2, l_num=1):
    """Return the mean over all elements of ``|(img1 - img2) ** l_num|``."""
    residual = img1 - img2
    powered = residual ** l_num
    return powered.abs().mean()


def boundary_extraction(mask, num_dilations=7):
    """Return the band of ``mask`` that lies next to its zero region.

    The complement ``1 - mask`` is dilated ``num_dilations`` times with a 3x3
    window (a 3x3 all-ones convolution followed by a ``>= 1`` threshold) and
    the result is multiplied by ``mask``.  What remains are the mask pixels
    that are within ``num_dilations`` pixels of a zero pixel of the mask.
    The convolution zero-pads ``1 - mask``, so the image border by itself does
    not create a band.

    Args:
        mask: tensor of shape (N, 1, H, W); the single channel is required by
            the 1-in/1-out convolution kernel.
        num_dilations: number of 3x3 dilation steps (default 7, the value that
            was previously hard-coded).

    Returns:
        Tensor with the same shape, dtype and device as ``mask``.
    """
    ones = torch.ones_like(mask)
    zeros = torch.zeros_like(mask)

    # Build the kernel on the mask's own device/dtype.  Deciding from
    # torch.cuda.is_available() breaks for CPU tensors on a CUDA machine.
    kernel = torch.ones((1, 1, 3, 3), dtype=mask.dtype, device=mask.device)

    # dilation of the complement
    x = 1 - mask
    for _ in range(num_dilations):
        x = F.conv2d(x, kernel, stride=1, padding=1)
        x = torch.where(x < 1, zeros, ones)

    return x * mask

def cal_boundary_term(inpu1_tesnor, inpu2_tesnor, mask1_tesnor, mask2_tesnor, stitched_image):
    """Boundary loss tying the stitched image to each input near the other mask's edge.

    For each input, the boundary band of the *other* mask (from
    ``boundary_extraction``) is restricted to the input's own mask.  Inside
    that region the stitched image is compared with the input using an L1 loss.

    Returns:
        ``(loss, boundary_mask1)``: the sum of both L1 terms and the region
        used for the first input.
    """
    band_of_mask2 = boundary_extraction(mask2_tesnor)
    band_of_mask1 = boundary_extraction(mask1_tesnor)

    boundary_mask1 = mask1_tesnor * band_of_mask2
    boundary_mask2 = mask2_tesnor * band_of_mask1

    pairs = (
        (inpu1_tesnor, boundary_mask1),
        (inpu2_tesnor, boundary_mask2),
    )
    first_term, second_term = [
        l_num_loss(image * region, stitched_image * region, 1)
        for image, region in pairs
    ]

    return first_term + second_term, boundary_mask1


def cal_smooth_term_stitch(stitched_image, learned_mask1):
    """Smoothness loss on the stitched image, weighted by the mask gradient.

    For vertically and horizontally adjacent pixels, the absolute change of
    ``learned_mask1`` is multiplied by the absolute change of
    ``stitched_image``.  The loss is the mean of the vertical products plus
    the mean of the horizontal products.
    """

    def _neighbour_gap(t):
        # |t[i] - t[i + 1]| along the height axis and along the width axis
        vertical = torch.abs(t[:, :, :-1, :] - t[:, :, 1:, :])
        horizontal = torch.abs(t[:, :, :, :-1] - t[:, :, :, 1:])
        return vertical, horizontal

    mask_dh, mask_dw = _neighbour_gap(learned_mask1)
    image_dh, image_dw = _neighbour_gap(stitched_image)

    return torch.mean(mask_dh * image_dh) + torch.mean(mask_dw * image_dw)



def cal_smooth_term_diff(img1, img2, learned_mask1, overlap):
    """Smoothness loss on the squared difference of the two images.

    ``|img1 - img2| ** 2 * overlap`` is computed first.  For vertically and
    horizontally adjacent pixels, the absolute change of ``learned_mask1`` is
    multiplied by the *sum* of that difference map at the two pixels.  The
    loss is the mean of the vertical products plus the mean of the horizontal
    products.
    """
    sq_diff = torch.abs(img1 - img2) ** 2 * overlap

    # mask change between neighbours (height axis, then width axis)
    mask_dh = (learned_mask1[:, :, :-1, :] - learned_mask1[:, :, 1:, :]).abs()
    mask_dw = (learned_mask1[:, :, :, :-1] - learned_mask1[:, :, :, 1:]).abs()

    # summed difference of the two neighbouring pixels
    pair_dh = (sq_diff[:, :, :-1, :] + sq_diff[:, :, 1:, :]).abs()
    pair_dw = (sq_diff[:, :, :, :-1] + sq_diff[:, :, :, 1:]).abs()

    vertical = torch.mean(mask_dh * pair_dh)
    horizontal = torch.mean(mask_dw * pair_dw)

    return vertical + horizontal

    # dh_zeros = torch.zeros_like(dh_pixel)
    # dw_zeros = torch.zeros_like(dw_pixel)
    # if torch.cuda.is_available():
    #     dh_zeros = dh_zeros.cuda()
    #     dw_zeros = dw_zeros.cuda()


    # loss = l_num_loss(dh_pixel, dh_zeros, 1) + l_num_loss(dw_pixel, dw_zeros, 1)


    # return  loss, dh_pixel

================================================
FILE: Composition/Codes/network.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F



def build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor):
    """Run ``net`` and blend the two warped images with the predicted mask.

    Outside the overlap ``mask1 * mask2`` each image keeps its own mask.
    Inside the overlap image 1 is weighted by the network output and image 2
    by one minus it.  The images are shifted by +1 for blending and the
    result is shifted back by -1.

    Returns:
        dict with keys ``learned_mask1``, ``learned_mask2`` and
        ``stitched_image``.
    """
    out = net(warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)

    overlap = mask1_tensor * mask2_tensor
    learned_mask1 = (mask1_tensor - overlap) + overlap * out
    learned_mask2 = (mask2_tensor - overlap) + overlap * (1 - out)

    blended_1 = (warp1_tensor + 1.) * learned_mask1
    blended_2 = (warp2_tensor + 1.) * learned_mask2
    stitched_image = blended_1 + blended_2 - 1.

    return {
        'learned_mask1': learned_mask1,
        'learned_mask2': learned_mask2,
        'stitched_image': stitched_image,
    }


class DownBlock(nn.Module):
    """Encoder stage: optional 2x2 max-pool followed by two dilated 3x3 conv + ReLU.

    The convolutions use ``padding=1`` whatever the dilation, so for
    ``dilation > 1`` each convolution reduces height and width by
    ``2 * (dilation - 1)``.
    """

    def __init__(self, inchannels, outchannels, dilation, pool=True):
        super().__init__()

        stages = [nn.MaxPool2d(kernel_size=2, stride=2)] if pool else []
        stages += [
            nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1, dilation=dilation),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1, dilation=dilation),
            nn.ReLU(inplace=True),
        ]
        self.layer = nn.Sequential(*stages)

        # weight initialisation
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        """Apply the stage to ``x``."""
        return self.layer(x)

class UpBlock(nn.Module):
    """Decoder stage that merges a coarse feature map with a skip connection.

    The coarse input is resized to the skip connection's spatial size
    (nearest neighbour), reduced from ``inchannels`` to ``outchannels`` by a
    3x3 conv + ReLU, concatenated after the skip connection on the channel
    axis, and passed through two dilated 3x3 conv + ReLU layers.
    """

    def __init__(self, inchannels, outchannels, dilation):
        super().__init__()

        # channel reduction applied to the resized coarse features
        reduce_layers = [
            nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.halfChanelConv = nn.Sequential(*reduce_layers)

        # fusion of skip connection and reduced coarse features
        fuse_layers = [
            nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1, dilation=dilation),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1, dilation=dilation),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*fuse_layers)

        # weight initialisation
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x1, x2):
        """Fuse coarse features ``x1`` with skip features ``x2``."""
        target_hw = (x2.size()[2], x2.size()[3])
        coarse = F.interpolate(x1, size=target_hw, mode='nearest')
        coarse = self.halfChanelConv(coarse)
        merged = torch.cat([x2, coarse], dim=1)
        return self.conv(merged)

# predict the composition mask of img1
class Network(nn.Module):
    """U-Net style network predicting the composition mask of the first image.

    Both images go through the same five encoder stages.  The decoder works
    on the differences of the two feature pyramids and ends with a 1x1
    convolution and a sigmoid.
    """

    def __init__(self, nclasses=1):
        super().__init__()

        # shared encoder
        self.down1 = DownBlock(3, 32, 1, pool=False)
        self.down2 = DownBlock(32, 64, 2)
        self.down3 = DownBlock(64, 128, 3)
        self.down4 = DownBlock(128, 256, 4)
        self.down5 = DownBlock(256, 512, 5)

        # decoder
        self.up1 = UpBlock(512, 256, 4)
        self.up2 = UpBlock(256, 128, 3)
        self.up3 = UpBlock(128, 64, 2)
        self.up4 = UpBlock(64, 32, 1)

        # prediction head
        self.out = nn.Sequential(
            nn.Conv2d(32, nclasses, kernel_size=1),
            nn.Sigmoid(),
        )

        # weight initialisation
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _encode(self, image):
        """Return the five encoder feature maps of ``image``, finest first."""
        features = []
        current = image
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            current = stage(current)
            features.append(current)
        return features

    def forward(self, x, y, m1, m2):
        """Predict the mask from images ``x`` and ``y``.

        ``m1`` and ``m2`` are accepted for interface compatibility but are not
        used in the computation.
        """
        x1, x2, x3, x4, x5 = self._encode(x)
        y1, y2, y3, y4, y5 = self._encode(y)

        res = self.up1(x5 - y5, x4 - y4)
        for up_stage, skip in ((self.up2, x3 - y3), (self.up3, x2 - y2), (self.up4, x1 - y1)):
            res = up_stage(res, skip)

        return self.out(res)




================================================
FILE: Composition/Codes/test.py
================================================
# coding: utf-8
import argparse
import torch
from torch.utils.data import DataLoader
from network import build_model, Network
from dataset import *
import os
import numpy as np
import cv2


# Project directory = parent of the folder that contains this script.
# `__file__` must be the variable, not the string "__file__": with the string,
# dirname() is empty and the path is resolved from the current working
# directory instead of the script location.
last_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.pardir))
# Folder searched for `*.pth` checkpoints.
MODEL_DIR = os.path.join(last_path, 'model')



def test(args):
    """Run the composition network over a test set and save the results.

    The checkpoint whose file name sorts last among ``MODEL_DIR/*.pth`` is
    loaded.  For every test sample the two learned masks and the stitched
    image are written as ``NNNNNN.jpg`` (1-based, zero-padded) into
    ``../learn_mask1/``, ``../learn_mask2/`` and ``../composition/``, relative
    to the current working directory.

    Args:
        args: namespace providing ``gpu`` (value for CUDA_VISIBLE_DEVICES),
            ``batch_size`` and ``test_path``.
    """
    # Imported locally: at module level `glob` is only reachable through
    # `from dataset import *`.
    import glob

    # The variable honoured by CUDA is CUDA_DEVICE_ORDER (no "S").
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # dataset
    test_data = TestDataset(data_path=args.test_path)
    test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, num_workers=1, shuffle=False, drop_last=False)

    # define the network
    net = Network()
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()

    # load the latest checkpoint, if any
    ckpt_list = sorted(glob.glob(MODEL_DIR + "/*.pth"))
    if not ckpt_list:
        print('No checkpoint found!')
        return
    model_path = ckpt_list[-1]
    # map_location lets a checkpoint saved on GPU be read on a CPU-only
    # machine; load_state_dict then copies the weights into `net`.
    checkpoint = torch.load(model_path, map_location='cpu')
    net.load_state_dict(checkpoint['model'])
    print('load model from {}!'.format(model_path))

    path_learn_mask1 = '../learn_mask1/'
    path_learn_mask2 = '../learn_mask2/'
    path_final_composition = '../composition/'
    for out_dir in (path_learn_mask1, path_learn_mask2, path_final_composition):
        os.makedirs(out_dir, exist_ok=True)

    print("##################start testing#######################")
    net.eval()
    sample_idx = 0
    for batch_value in test_loader:

        warp1_tensor = batch_value[0].float()
        warp2_tensor = batch_value[1].float()
        mask1_tensor = batch_value[2].float()
        mask2_tensor = batch_value[3].float()

        if use_cuda:
            warp1_tensor = warp1_tensor.cuda()
            warp2_tensor = warp2_tensor.cuda()
            mask1_tensor = mask1_tensor.cuda()
            mask2_tensor = mask2_tensor.cuda()

        with torch.no_grad():
            batch_out = build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)

        # back to the [0, 255] range used by the image files
        stitched_batch = ((batch_out['stitched_image'] + 1) * 127.5).cpu().detach().numpy()
        mask1_batch = (batch_out['learned_mask1'] * 255).cpu().detach().numpy()
        mask2_batch = (batch_out['learned_mask2'] * 255).cpu().detach().numpy()

        # Save every sample of the batch, not only the first one, so that
        # batch_size > 1 does not silently drop results.
        for b in range(stitched_batch.shape[0]):
            sample_idx += 1
            file_name = str(sample_idx).zfill(6) + ".jpg"

            cv2.imwrite(path_learn_mask1 + file_name, mask1_batch[b].transpose(1, 2, 0))
            cv2.imwrite(path_learn_mask2 + file_name, mask2_batch[b].transpose(1, 2, 0))
            cv2.imwrite(path_final_composition + file_name, stitched_batch[b].transpose(1, 2, 0))

            print('i = {}'.format(sample_idx))



if __name__=="__main__":

    # command-line options: (flag, type, default)
    parser = argparse.ArgumentParser()
    option_table = (
        ('--gpu', str, '0'),
        ('--batch_size', int, 1),
        ('--test_path', str, '/opt/data/private/nl/Data/UDIS-D/testing/'),
    )
    for flag, kind, default in option_table:
        parser.add_argument(flag, type=kind, default=default)

    print('<==================== Loading data ===================>\n')

    args = parser.parse_args()
    print(args)

    # run the evaluation with the parsed options
    test(args)

================================================
FILE: Composition/Codes/test_other.py
================================================
# coding: utf-8
import argparse
import torch
from network import build_model, Network
import os
import numpy as np
import cv2
import glob



# Project directory = parent of the folder that contains this script.
# `__file__` must be the variable, not the string "__file__": with the string,
# dirname() is empty and the path is resolved from the current working
# directory instead of the script location.
last_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.pardir))
# Folder searched for `*.pth` checkpoints.
MODEL_DIR = os.path.join(last_path, 'model')

def loadSingleData(data_path):
    """Load one image pair and its masks from the directory ``data_path``.

    Reads ``warp1.jpg``, ``warp2.jpg``, ``mask1.jpg`` and ``mask2.jpg``.
    Images are scaled from [0, 255] to [-1, 1] and masks to [0, 1].  All
    arrays are made channel-first and get a leading batch dimension of 1.
    Unlike the dataset classes, the masks keep all channels returned by
    OpenCV.

    Args:
        data_path: directory holding the four files; a trailing separator is
            optional.

    Returns:
        ``(warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)``.

    Raises:
        IOError: if one of the files cannot be read.
    """

    def _read_float(file_name):
        # os.path.join works with or without a trailing '/' on data_path
        full_path = os.path.join(data_path, file_name)
        image = cv2.imread(full_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('failed to read image: {}'.format(full_path))
        return image.astype(dtype=np.float32)

    def _to_batch(array):
        # HWC -> CHW, then add the batch dimension
        return torch.tensor(np.transpose(array, [2, 0, 1])).unsqueeze(0)

    # load the two warped images
    warp1 = (_read_float("warp1.jpg") / 127.5) - 1.0
    warp2 = (_read_float("warp2.jpg") / 127.5) - 1.0

    # load the two masks
    mask1 = _read_float("mask1.jpg") / 255
    mask2 = _read_float("mask2.jpg") / 255

    # convert to tensor
    warp1_tensor = _to_batch(warp1)
    warp2_tensor = _to_batch(warp2)
    mask1_tensor = _to_batch(mask1)
    mask2_tensor = _to_batch(mask2)

    return warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor


def test_other(args):
    """Compose a single image pair stored in ``args.path`` and save the results.

    The checkpoint whose file name sorts last among ``MODEL_DIR/*.pth`` is
    loaded.  Four files are written into ``args.path``:
    ``composition_color.jpg`` (a colour-coded visualisation of the two
    contributions), ``learn_mask1.jpg``, ``learn_mask2.jpg`` and
    ``composition.jpg``.

    Args:
        args: namespace providing ``gpu`` (value for CUDA_VISIBLE_DEVICES) and
            ``path`` (directory with warp1/warp2/mask1/mask2 ``.jpg`` files).
    """
    # The variable honoured by CUDA is CUDA_DEVICE_ORDER (no "S").
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # define the network
    net = Network()
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()

    # load the latest checkpoint, if any
    ckpt_list = sorted(glob.glob(MODEL_DIR + "/*.pth"))
    if not ckpt_list:
        print('No checkpoint found!')
        return
    model_path = ckpt_list[-1]
    # map_location lets a checkpoint saved on GPU be read on a CPU-only
    # machine; load_state_dict then copies the weights into `net`.
    checkpoint = torch.load(model_path, map_location='cpu')
    net.load_state_dict(checkpoint['model'])
    print('load model from {}!'.format(model_path))

    # load dataset(only one pair of images)
    warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor = loadSingleData(data_path=args.path)
    if use_cuda:
        warp1_tensor = warp1_tensor.cuda()
        warp2_tensor = warp2_tensor.cuda()
        mask1_tensor = mask1_tensor.cuda()
        mask2_tensor = mask2_tensor.cuda()

    net.eval()
    with torch.no_grad():
        batch_out = build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)
    stitched_image = batch_out['stitched_image']
    learned_mask1 = batch_out['learned_mask1']
    learned_mask2 = batch_out['learned_mask2']

    # (optional) draw composition images with different colors like our paper:
    # channel 0 comes from image 2, channel 2 from image 1, channel 1 is their average
    s1 = ((warp1_tensor[0]+1)*127.5 * learned_mask1[0]).cpu().detach().numpy().transpose(1,2,0)
    s2 = ((warp2_tensor[0]+1)*127.5 * learned_mask2[0]).cpu().detach().numpy().transpose(1,2,0)
    fusion = np.zeros((warp1_tensor.shape[2],warp1_tensor.shape[3],3), np.uint8)
    fusion[...,0] = s2[...,0]
    fusion[...,1] = s1[...,1]*0.5 +  s2[...,1]*0.5
    fusion[...,2] = s1[...,2]
    # os.path.join works with or without a trailing '/' on args.path
    cv2.imwrite(os.path.join(args.path, "composition_color.jpg"), fusion)

    # save learned masks and final composition, back in the [0, 255] range
    stitched_image = ((stitched_image[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)
    learned_mask1 = (learned_mask1[0]*255).cpu().detach().numpy().transpose(1,2,0)
    learned_mask2 = (learned_mask2[0]*255).cpu().detach().numpy().transpose(1,2,0)

    cv2.imwrite(os.path.join(args.path, "learn_mask1.jpg"), learned_mask1)
    cv2.imwrite(os.path.join(args.path, "learn_mask2.jpg"), learned_mask2)
    cv2.imwrite(os.path.join(args.path, "composition.jpg"), stitched_image)



if __name__=="__main__":

    # command-line options: (flag, type, default)
    parser = argparse.ArgumentParser()
    option_table = (
        ('--gpu', str, '0'),
        ('--path', str, '../../Carpark-DHW/'),
    )
    for flag, kind, default in option_table:
        parser.add_argument(flag, type=kind, default=default)

    print('<==================== Loading data ===================>\n')

    args = parser.parse_args()
    print(args)

    # compose the single image pair found in args.path
    test_other(args)

================================================
FILE: Composition/Codes/train.py
================================================
import argparse
import torch
from torch.utils.data import DataLoader
import os
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from network import build_model, Network
from dataset import TrainDataset
import glob
from loss import cal_boundary_term, cal_smooth_term_stitch, cal_smooth_term_diff



# path of project
# Project directory = parent of the folder that contains this script.
# `__file__` must be the variable, not the string "__file__": with the string,
# dirname() is empty and the path is resolved from the current working
# directory instead of the script location.
last_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.pardir))

# path to save the summary files
SUMMARY_DIR = os.path.join(last_path, 'summary')

# path to save the model files
MODEL_DIR = os.path.join(last_path, 'model')

# create the folders if they do not exist (exist_ok avoids the race between
# an existence check and the creation)
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(SUMMARY_DIR, exist_ok=True)

# TensorBoard writer used by train() for images and scalars
writer = SummaryWriter(log_dir=SUMMARY_DIR)



def train(args):
    """Train the composition network and checkpoint it periodically.

    If ``MODEL_DIR`` contains ``*.pth`` files, training resumes from the one
    whose name sorts last (model, optimizer, epoch and global iteration are
    restored).  The loss is ``10000 * boundary + 1000 * smooth(stitched) +
    1000 * smooth(difference)``.  Averages over ``score_print_fre`` iterations
    are printed and written to TensorBoard, and a checkpoint is saved every 10
    epochs and after the last epoch.

    Args:
        args: namespace providing ``gpu`` (value for CUDA_VISIBLE_DEVICES),
            ``batch_size``, ``max_epoch`` and ``train_path``.
    """
    # The variable honoured by CUDA is CUDA_DEVICE_ORDER (no "S").
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # dataset
    train_data = TrainDataset(data_path=args.train_path)
    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=4, shuffle=True, drop_last=True)

    # define the network
    net = Network()
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()

    # define the optimizer and learning rate
    optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08)  # default as 0.0001
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)

    # load the latest checkpoint if one exists
    ckpt_list = sorted(glob.glob(MODEL_DIR + "/*.pth"))
    if ckpt_list:
        model_path = ckpt_list[-1]
        # map_location lets a checkpoint saved on GPU be read on a CPU-only
        # machine; the load_state_dict calls below move the values to where
        # the model parameters live.
        checkpoint = torch.load(model_path, map_location='cpu')

        net.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        glob_iter = checkpoint['glob_iter']
        scheduler.last_epoch = start_epoch
        print('load model from {}!'.format(model_path))
    else:
        start_epoch = 0
        glob_iter = 0
        print('training from stratch!')

    print("##################start training#######################")
    score_print_fre = 300

    for epoch in range(start_epoch, args.max_epoch):

        print("start epoch {}".format(epoch))
        net.train()
        sigma_total_loss = 0.
        sigma_boundary_loss = 0.
        sigma_smooth1_loss = 0.
        sigma_smooth2_loss = 0.

        print(epoch, 'lr={:.6f}'.format(optimizer.param_groups[0]['lr']))

        for i, batch_value in enumerate(train_loader):

            warp1_tensor = batch_value[0].float()
            warp2_tensor = batch_value[1].float()
            mask1_tensor = batch_value[2].float()
            mask2_tensor = batch_value[3].float()

            if use_cuda:
                warp1_tensor = warp1_tensor.cuda()
                warp2_tensor = warp2_tensor.cuda()
                mask1_tensor = mask1_tensor.cuda()
                mask2_tensor = mask2_tensor.cuda()

            # forward, backward, update weights
            optimizer.zero_grad()

            batch_out = build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)

            learned_mask1 = batch_out['learned_mask1']
            stitched_image = batch_out['stitched_image']

            # boundary term
            boundary_loss, boundary_mask1 = cal_boundary_term(warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor, stitched_image)
            boundary_loss = 10000 * boundary_loss

            # smooth term on the stitched image
            smooth1_loss = cal_smooth_term_stitch(stitched_image, learned_mask1)
            smooth1_loss = 1000 * smooth1_loss
            # smooth term on the difference image, restricted to the overlap
            smooth2_loss = cal_smooth_term_diff(warp1_tensor, warp2_tensor, learned_mask1, mask1_tensor*mask2_tensor)
            smooth2_loss = 1000 * smooth2_loss

            total_loss = boundary_loss + smooth1_loss + smooth2_loss
            total_loss.backward()
            # clip the gradient
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=3, norm_type=2)
            optimizer.step()

            sigma_boundary_loss += boundary_loss.item()
            sigma_smooth1_loss += smooth1_loss.item()
            sigma_smooth2_loss += smooth2_loss.item()
            sigma_total_loss += total_loss.item()

            print(glob_iter)
            # print loss etc.  (i + 1) makes every window hold exactly
            # score_print_fre iterations; testing `i` itself put 301
            # iterations into the first window of each epoch while still
            # dividing by 300.
            if (i + 1) % score_print_fre == 0:
                average_total_loss = sigma_total_loss / score_print_fre
                average_boundary_loss = sigma_boundary_loss / score_print_fre
                average_smooth1_loss = sigma_smooth1_loss / score_print_fre
                average_smooth2_loss = sigma_smooth2_loss / score_print_fre

                sigma_total_loss = 0.
                sigma_boundary_loss = 0.
                sigma_smooth1_loss = 0.
                sigma_smooth2_loss = 0.

                current_lr = optimizer.param_groups[0]['lr']

                print("Training: Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}]/[{:0>3}] Total Loss: {:.4f}   boundary loss: {:.4f}  smooth loss: {:.4f}  diff loss: {:.4f}   lr={:.8f}".format(epoch + 1, args.max_epoch, i + 1, len(train_loader), average_total_loss, average_boundary_loss, average_smooth1_loss, average_smooth2_loss, current_lr))

                # visualization (images are mapped from [-1, 1] to [0, 1])
                writer.add_image("inpu1", (warp1_tensor[0]+1.)/2., glob_iter)
                writer.add_image("inpu2", (warp2_tensor[0]+1.)/2., glob_iter)

                writer.add_image("stitched_image", (stitched_image[0]+1.)/2., glob_iter)
                writer.add_image("learned_mask1", learned_mask1[0], glob_iter)
                writer.add_image("boundary_mask1", boundary_mask1[0], glob_iter)

                writer.add_scalar('lr', current_lr, glob_iter)
                writer.add_scalar('total loss', average_total_loss, glob_iter)
                writer.add_scalar('average_boundary_loss', average_boundary_loss, glob_iter)
                writer.add_scalar('average_smooth1_loss', average_smooth1_loss, glob_iter)
                writer.add_scalar('average_smooth2_loss', average_smooth2_loss, glob_iter)

            glob_iter += 1

        scheduler.step()
        # save model
        if ((epoch+1) % 10 == 0 or (epoch+1)==args.max_epoch):
            filename ='epoch' + str(epoch+1).zfill(3) + '_model.pth'
            model_save_path = os.path.join(MODEL_DIR, filename)
            state = {'model': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch+1, "glob_iter": glob_iter}
            torch.save(state, model_save_path)
    print("##################end training#######################")


if __name__=="__main__":

    print('<==================== setting arguments ===================>\n')

    # command-line options: (flag, type, default)
    parser = argparse.ArgumentParser()
    option_table = (
        ('--gpu', str, '0'),
        ('--batch_size', int, 1),
        ('--max_epoch', int, 50),
        ('--train_path', str, '/opt/data/private/nl/Data/UDIS-D/training'),
    )
    for flag, kind, default in option_table:
        parser.add_argument(flag, type=kind, default=default)

    # parse the arguments
    args = parser.parse_args()
    print(args)

    print('<==================== jump into training function ===================>\n')
    # start training with the parsed options
    train(args)




================================================
FILE: Composition/model/.txt
================================================



================================================
FILE: Composition/readme.md
================================================
## Train on UDIS-D
Before training, the warped images and corresponding masks should be generated in the warp stage.

Then, set the training dataset path in Composition/Codes/train.py.

```
python train.py
```

## Test on UDIS-D
The pre-trained model of warp is available at [Google Drive](https://drive.google.com/file/d/1OaG0ayEwRPhKVV_OwQwvwHDFHC26iv30/view?usp=sharing) or [Baidu Cloud](https://pan.baidu.com/s/1qCGegzvxtzri6GiG7mNw6g)(Extraction code: 1234).

Set the testing dataset path in Composition/Codes/test.py.

```
python test.py
```
The composition masks and final fusion results on UDIS-D will be generated and saved in `../learn_mask1/`, `../learn_mask2/` and `../composition/`, relative to the directory the script is run from.


## Test on other datasets
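Set the `--path` argument (or its default value in Composition/Codes/test_other.py) to a directory that contains `warp1.jpg`, `warp2.jpg`, `mask1.jpg` and `mask2.jpg`; include the trailing '/'.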
```
python test_other.py
```
The results will be generated and saved at 'path'.


================================================
FILE: Composition/summary/.txt
================================================



================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# <p align="center">Parallax-Tolerant Unsupervised Deep Image Stitching (UDIS++ [paper](https://arxiv.org/abs/2302.08207))</p>
<p align="center">Lang Nie*, Chunyu Lin*, Kang Liao*, Shuaicheng Liu`, Yao Zhao*</p>
<p align="center">* Institute of Information Science, Beijing Jiaotong University</p>
<p align="center">` School of Information and Communication Engineering, University of Electronic Science and Technology of China</p>

![image](https://github.com/nie-lang/UDIS2/blob/main/fig1.png)

## Dataset (UDIS-D)
We use the UDIS-D dataset to train and evaluate our method. Please refer to [UDIS](https://github.com/nie-lang/UnsupervisedDeepImageStitching) for more details about this dataset.


## Code
#### Requirement
* numpy 1.19.5
* pytorch 1.7.1
* scikit-image 0.15.0
* tensorboard 2.9.0

We implemented this work on Ubuntu with an NVIDIA RTX 3090 Ti GPU and CUDA 11. Refer to [environment.yml](https://github.com/nie-lang/UDIS2/blob/main/environment.yml) for more details.

#### How to run it
Similar to UDIS, we also implement this solution in two stages:
* Stage 1 (unsupervised warp): please refer to [Warp/readme.md](https://github.com/nie-lang/UDIS2/blob/main/Warp/readme.md).
* Stage 2 (unsupervised composition): please refer to [Composition/readme.md](https://github.com/nie-lang/UDIS2/blob/main/Composition/readme.md).



## Meta
If you have any questions about this project, please feel free to drop me an email.

NIE Lang -- nielang@bjtu.edu.cn
```
@inproceedings{nie2023parallax,
  title={Parallax-Tolerant Unsupervised Deep Image Stitching},
  author={Nie, Lang and Lin, Chunyu and Liao, Kang and Liu, Shuaicheng and Zhao, Yao},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={7399--7408},
  year={2023}
}
```


## References
[1] L. Nie, C. Lin, K. Liao, M. Liu, and Y. Zhao, “A view-free image stitching network based on global homography,” Journal of Visual Communication and Image Representation, p. 102950, 2020.  
[2] L. Nie, C. Lin, K. Liao, and Y. Zhao. Learning edge-preserved image stitching from multi-scale deep homography[J]. Neurocomputing, 2022, 491: 533-543.   
[3] L. Nie, C. Lin, K. Liao, S. Liu, and Y. Zhao. Unsupervised deep image stitching: Reconstructing stitched features to images[J]. IEEE Transactions on Image Processing, 2021, 30: 6184-6197.   
[4] L. Nie, C. Lin, K. Liao, S. Liu, and Y. Zhao. Deep rectangling for image stitching: a learning baseline[C]//Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022: 5740-5748.   


================================================
FILE: Warp/Codes/dataset.py
================================================
from torch.utils.data import Dataset
import  numpy as np
import cv2, torch
import os
import glob
from collections import OrderedDict
import random


class TrainDataset(Dataset):
    """Paired training images for the warp stage.

    ``data_path`` is scanned for sub-directories named ``input1`` and
    ``input2``; the ``*.jpg`` files of each are collected in sorted order so
    that the i-th entries of both lists form one pair.

    Each item is a tuple of two float32 tensors in channel-first layout,
    resized to ``self.width`` x ``self.height`` and rescaled with
    ``x / 127.5 - 1`` (i.e. to [-1, 1] for 8-bit images).  The order of the
    two images inside the tuple is swapped at random.
    """

    def __init__(self, data_path):

        self.width = 512
        self.height = 512
        self.train_path = data_path
        self.datas = OrderedDict()

        datas = glob.glob(os.path.join(self.train_path, '*'))
        for data in sorted(datas):
            # os.path.basename handles the platform's path separator,
            # unlike splitting on '/' which never matches on Windows.
            data_name = os.path.basename(data)
            if data_name in ('input1', 'input2'):
                self.datas[data_name] = {}
                self.datas[data_name]['path'] = data
                self.datas[data_name]['image'] = glob.glob(os.path.join(data, '*.jpg'))
                self.datas[data_name]['image'].sort()
        print(self.datas.keys())

    def _load_image(self, path):
        """Read ``path`` and return it as a resized, normalised CHW tensor.

        Raises:
            OSError: if OpenCV cannot read the file (``cv2.imread`` returns
                ``None`` instead of raising).
        """
        image = cv2.imread(path)
        if image is None:
            raise OSError('cv2.imread failed to load image: {}'.format(path))
        image = cv2.resize(image, (self.width, self.height))
        image = image.astype(dtype=np.float32)
        image = (image / 127.5) - 1.0
        image = np.transpose(image, [2, 0, 1])
        return torch.tensor(image)

    def __getitem__(self, index):
        input1_tensor = self._load_image(self.datas['input1']['image'][index])
        input2_tensor = self._load_image(self.datas['input2']['image'][index])

        # randomly exchange the roles of the two images
        if_exchange = random.randint(0, 1)
        if if_exchange == 0:
            return (input1_tensor, input2_tensor)
        return (input2_tensor, input1_tensor)

    def __len__(self):
        # the number of pairs is taken from the 'input1' list
        return len(self.datas['input1']['image'])

class TestDataset(Dataset):
    """Paired test images for the warp stage.

    ``data_path`` is scanned for sub-directories named ``input1`` and
    ``input2``; the ``*.jpg`` files of each are collected in sorted order so
    that the i-th entries of both lists form one pair.

    Each item is ``(input1, input2)`` as float32 tensors in channel-first
    layout, rescaled with ``x / 127.5 - 1``.  Unlike the training set the
    images keep their original resolution and are never swapped.
    """

    def __init__(self, data_path):

        # kept for parity with TrainDataset; not used when loading here
        self.width = 512
        self.height = 512
        self.test_path = data_path
        self.datas = OrderedDict()

        datas = glob.glob(os.path.join(self.test_path, '*'))
        for data in sorted(datas):
            # os.path.basename handles the platform's path separator,
            # unlike splitting on '/' which never matches on Windows.
            data_name = os.path.basename(data)
            if data_name in ('input1', 'input2'):
                self.datas[data_name] = {}
                self.datas[data_name]['path'] = data
                self.datas[data_name]['image'] = glob.glob(os.path.join(data, '*.jpg'))
                self.datas[data_name]['image'].sort()
        print(self.datas.keys())

    def _load_image(self, path):
        """Read ``path`` and return it as a normalised CHW tensor.

        Raises:
            OSError: if OpenCV cannot read the file (``cv2.imread`` returns
                ``None`` instead of raising).
        """
        image = cv2.imread(path)
        if image is None:
            raise OSError('cv2.imread failed to load image: {}'.format(path))
        image = image.astype(dtype=np.float32)
        image = (image / 127.5) - 1.0
        image = np.transpose(image, [2, 0, 1])
        return torch.tensor(image)

    def __getitem__(self, index):
        input1_tensor = self._load_image(self.datas['input1']['image'][index])
        input2_tensor = self._load_image(self.datas['input2']['image'][index])

        return (input1_tensor, input2_tensor)

    def __len__(self):
        # the number of pairs is taken from the 'input1' list
        return len(self.datas['input1']['image'])





================================================
FILE: Warp/Codes/grid_res.py
================================================

# Mesh resolution shared by the loss and network code: the warp mesh has
# GRID_H x GRID_W cells, i.e. (GRID_H+1) * (GRID_W+1) control points.
GRID_H = 12
GRID_W = 12

================================================
FILE: Warp/Codes/loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

import grid_res
grid_h = grid_res.GRID_H
grid_w = grid_res.GRID_W


def l_num_loss(img1, img2, l_num=1):
    """Return the mean over all elements of ``|(img1 - img2) ** l_num|``."""
    residual = img1 - img2
    powered = residual ** l_num
    return torch.mean(torch.abs(powered))


def _color_balanced_l1(reference, warped, mask, eps=1e-6):
    """L1 distance between ``warped`` and a brightness-shifted ``reference``.

    A per-sample, per-channel offset is added to ``reference`` so that its
    masked mean matches the mean of ``warped`` over the mask; the shifted
    reference is then masked and compared with ``warped``.

    The mask sum is clamped to ``eps`` so that an all-zero mask (nothing of
    the warped image left inside the frame) yields a zero offset instead of
    a 0/0 = NaN that would poison the whole loss.
    """
    mask_sum = torch.sum(mask, [2, 3]).clamp(min=eps)
    delta = (torch.sum(warped, [2, 3]) - torch.sum(reference * mask, [2, 3])) / mask_sum
    balanced = reference + delta.unsqueeze(2).unsqueeze(3)
    return l_num_loss(balanced * mask, warped, 1)


def cal_lp_loss(input1, input2, output_H, output_H_inv, warp_mesh, warp_mesh_mask):
    """Photometric alignment loss for the homography and mesh warps.

    Args:
        input1, input2: the two input images (4-D tensors, channels on dim 1).
        output_H: channels 0:3 are used as the warped image and channels 3:6
            as its mask; compared against ``input1``.
        output_H_inv: same layout as ``output_H``; compared against ``input2``.
        warp_mesh: mesh-warped image, compared against ``input1``.
        warp_mesh_mask: mask belonging to ``warp_mesh``.

    Returns:
        ``3 * (homography term + inverse homography term) + 1 * mesh term``.
    """
    # part one: sym homo loss with color balance
    lp_loss_1 = (
        _color_balanced_l1(input1, output_H[:, 0:3, :, :], output_H[:, 3:6, :, :])
        + _color_balanced_l1(input2, output_H_inv[:, 0:3, :, :], output_H_inv[:, 3:6, :, :])
    )

    # part two: tps loss with color balance
    lp_loss_2 = _color_balanced_l1(input1, warp_mesh, warp_mesh_mask)

    lp_loss = 3. * lp_loss_1 + 1. * lp_loss_2

    return lp_loss

def cal_lp_loss2(input1, warp_mesh, warp_mesh_mask):
    """Photometric loss between ``input1`` and the mesh-warped image.

    A per-sample, per-channel offset is added to ``input1`` so that its
    masked mean matches the mean of ``warp_mesh`` over ``warp_mesh_mask``;
    the L1 distance between the masked, shifted ``input1`` and ``warp_mesh``
    is returned.
    """
    # Clamp the denominator: an all-zero mask would otherwise give 0/0 = NaN.
    mask_sum = torch.sum(warp_mesh_mask, [2, 3]).clamp(min=1e-6)
    delta3 = (torch.sum(warp_mesh, [2, 3]) - torch.sum(input1 * warp_mesh_mask, [2, 3])) / mask_sum
    # (batch, channel) offset broadcast over the spatial dimensions
    input1_newbalance = input1 + delta3.unsqueeze(2).unsqueeze(3)

    lp_loss_2 = l_num_loss(input1_newbalance * warp_mesh_mask, warp_mesh, 1)
    lp_loss = 1. * lp_loss_2

    return lp_loss

def inter_grid_loss(overlap, mesh):
    """Angle-preserving penalty between successive mesh edges.

    Args:
        overlap: ``(batch, rows, cols)`` per-cell weights; a pair of
            neighbouring cells contributes only when the first cell's weight
            is non-zero, scaled down by how much the two weights differ.
        mesh: ``(batch, rows+1, cols+1, 2)`` mesh vertices.

    Returns:
        Scalar tensor: mean weighted ``1 - cos`` of the angle between
        successive horizontal edges plus the same for vertical edges.

    The grid size is read from ``mesh`` itself, and the cosine is computed
    with ``F.cosine_similarity`` so that a zero-length edge (two coinciding
    vertices) gives a finite value instead of a 0/0 = NaN.
    """
    rows = mesh.size(1) - 1
    cols = mesh.size(2) - 1

    ##############################
    # horizontal edges between neighbouring vertices of each mesh row
    w_edges = mesh[:, :, 0:cols, :] - mesh[:, :, 1:cols + 1, :]
    # cosine of the angle between two successive horizontal edges
    cos_w = F.cosine_similarity(w_edges[:, :, 0:cols - 1, :], w_edges[:, :, 1:cols, :], dim=3)
    delta_w_angle = 1 - cos_w
    # each cell pair is bounded by an upper and a lower edge pair: add both
    delta_w_angle = delta_w_angle[:, 0:rows, :] + delta_w_angle[:, 1:rows + 1, :]
    ##############################

    ##############################
    # vertical edges between neighbouring vertices of each mesh column
    h_edges = mesh[:, 0:rows, :, :] - mesh[:, 1:rows + 1, :, :]
    # cosine of the angle between two successive vertical edges
    cos_h = F.cosine_similarity(h_edges[:, 0:rows - 1, :, :], h_edges[:, 1:rows, :, :], dim=3)
    delta_h_angle = 1 - cos_h
    # each cell pair is bounded by a left and a right edge pair: add both
    delta_h_angle = delta_h_angle[:, :, 0:cols] + delta_h_angle[:, :, 1:cols + 1]
    ##############################

    # weight of each horizontally adjacent cell pair
    depth_diff_w = (1 - torch.abs(overlap[:, :, 0:cols - 1] - overlap[:, :, 1:cols])) * overlap[:, :, 0:cols - 1]
    error_w = depth_diff_w * delta_w_angle
    # weight of each vertically adjacent cell pair
    depth_diff_h = (1 - torch.abs(overlap[:, 0:rows - 1, :] - overlap[:, 1:rows, :])) * overlap[:, 0:rows - 1, :]
    error_h = depth_diff_h * delta_h_angle

    return torch.mean(error_w) + torch.mean(error_h)



# intra-grid constraint
def intra_grid_loss(pts, img_h=512, img_w=512):
    """Penalise mesh cells that are stretched beyond twice their rigid size.

    Args:
        pts: ``(batch, rows+1, cols+1, 2)`` mesh vertices; index 0 of the
            last dimension is differenced along the column axis and index 1
            along the row axis.
        img_h, img_w: image size the mesh coordinates refer to.  The defaults
            reproduce the previously hard-coded 512 x 512.

    Returns:
        Scalar tensor: mean of ``relu(dx - 2 * img_w / cols)`` plus mean of
        ``relu(dy - 2 * img_h / rows)`` over all neighbouring vertex pairs.
    """
    # grid size is taken from the mesh itself
    rows = pts.size(1) - 1
    cols = pts.size(2) - 1

    max_w = img_w / cols * 2
    max_h = img_h / rows * 2

    delta_x = pts[:, :, 1:cols + 1, 0] - pts[:, :, 0:cols, 0]
    delta_y = pts[:, 1:rows + 1, :, 1] - pts[:, 0:rows, :, 1]

    loss_x = F.relu(delta_x - max_w)
    loss_y = F.relu(delta_y - max_h)
    loss = torch.mean(loss_x) + torch.mean(loss_y)

    return loss





================================================
FILE: Warp/Codes/network.py
================================================
import torch
import torch.nn as nn
import utils.torch_DLT as torch_DLT
import utils.torch_homo_transform as torch_homo_transform
import utils.torch_tps_transform as torch_tps_transform
import ssl
import torch.nn.functional as F
import cv2
import numpy as np
import torchvision.models as models

import torchvision.transforms as T
resize_512 = T.Resize((512,512))

import grid_res
grid_h = grid_res.GRID_H
grid_w = grid_res.GRID_W


# draw mesh on image
# warp: h*w*3
# f_local: (grid_h+1)*(grid_w+1)*2 mesh vertices, indexed [row, col, (x, y)]
def draw_mesh_on_warp(warp, f_local):
    """Draw the mesh edges in green on ``warp`` and return the image.

    For every vertex a line is drawn to the vertex below it and to the
    vertex to its right, where those neighbours exist.  Coordinates are
    truncated to ``int`` before drawing.
    """
    canvas = np.ascontiguousarray(warp)

    green = (0, 255, 0)  # BGR
    line_width = 2
    line_type = 8

    def vertex(row, col):
        # integer pixel position of one mesh vertex
        return (int(f_local[row, col, 0]), int(f_local[row, col, 1]))

    for i in range(grid_h + 1):
        for j in range(grid_w + 1):
            if i < grid_h:
                # edge to the vertex in the next row
                cv2.line(canvas, vertex(i, j), vertex(i + 1, j), green, line_width, line_type)
            if j < grid_w:
                # edge to the vertex in the next column
                cv2.line(canvas, vertex(i, j), vertex(i, j + 1), green, line_width, line_type)

    return canvas


# Convert a global homography into a mesh
def H2Mesh(H, rigid_mesh):
    """Map every vertex of ``rigid_mesh`` through the inverse of ``H``.

    Args:
        H: ``(batch, 3, 3)`` homography matrices.
        rigid_mesh: ``(batch, rows, cols, 2)`` vertex coordinates.

    Returns:
        Tensor with the shape of ``rigid_mesh`` holding the transformed
        (dehomogenised) vertices, on the device of ``H``.

    Intermediate tensors are created on ``H``'s device rather than being
    moved to CUDA whenever CUDA happens to be available, so the function
    also works for CPU tensors on a CUDA machine.
    """
    batch_size = rigid_mesh.size()[0]
    H_inv = torch.inverse(H)

    ori_pt = rigid_mesh.reshape(batch_size, -1, 2).to(device=H.device, dtype=H.dtype)
    ones = torch.ones(batch_size, ori_pt.size(1), 1, device=H.device, dtype=H.dtype)

    ori_pt = torch.cat((ori_pt, ones), 2)  # bs*N*3 homogeneous coordinates
    tar_pt = torch.matmul(H_inv, ori_pt.permute(0, 2, 1))  # bs*3*N

    # divide by the homogeneous coordinate
    mesh_x = torch.unsqueeze(tar_pt[:, 0, :] / tar_pt[:, 2, :], 2)
    mesh_y = torch.unsqueeze(tar_pt[:, 1, :] / tar_pt[:, 2, :], 2)
    mesh = torch.cat((mesh_x, mesh_y), 2).reshape(rigid_mesh.shape)

    return mesh

# get rigid mesh
def get_rigid_mesh(batch_size, height, width):
    """Return the regular (undeformed) mesh covering a height x width image.

    The result has shape ``batch_size * (grid_h+1) * (grid_w+1) * 2``; the
    last dimension holds the evenly spaced horizontal coordinate in
    ``[0, width]`` followed by the vertical coordinate in ``[0, height]``.
    The mesh is placed on the GPU when CUDA is available.
    """
    col_pos = torch.linspace(0., float(width), grid_w+1)
    row_pos = torch.linspace(0.0, float(height), grid_h+1)

    # replicate the 1-D coordinates over the whole grid
    ww = col_pos.unsqueeze(0).expand(grid_h+1, -1)
    hh = row_pos.unsqueeze(1).expand(-1, grid_w+1)
    if torch.cuda.is_available():
        ww, hh = ww.cuda(), hh.cuda()

    grid = torch.stack((ww, hh), 2)  # (grid_h+1)*(grid_w+1)*2
    return grid.unsqueeze(0).expand(batch_size, -1, -1, -1)

# normalize mesh from -1 ~ 1
def get_norm_mesh(mesh, height, width):
    """Rescale mesh coordinates from pixel units to ``[-1, 1]``.

    Index 0 of the last dimension is divided by ``width`` and index 1 by
    ``height`` (``v * 2 / size - 1``).  The 4-D ``bs * rows * cols * 2``
    mesh is returned flattened to ``bs * (rows*cols) * 2``.
    """
    num_meshes = mesh.size()[0]
    x_norm = mesh[..., 0] * 2. / float(width) - 1.
    y_norm = mesh[..., 1] * 2. / float(height) - 1.
    stacked = torch.stack((x_norm, y_norm), 3)
    return stacked.reshape([num_meshes, -1, 2])



# random augmentation
# it seems to do nothing to the performance
def data_aug(img1, img2):
    """Apply independent random brightness and color gains to two images.

    Each image is multiplied by one scalar brightness gain and by one gain
    per color channel (3 channels), all drawn uniformly from [0.7, 1.3] and
    shared across the batch, then clamped to [-1, 1].  The inputs are not
    modified.

    The random gains are moved to the device of the corresponding input
    instead of unconditionally to CUDA, so the function also runs on CPU.
    """
    def _gains(count, device):
        # `count` gains drawn from U(0.7, 1.3), placed next to the image
        return torch.empty(count).uniform_(0.7, 1.3).to(device)

    # Randomly shift brightness
    img1_aug = img1 * _gains(1, img1.device)
    img2_aug = img2 * _gains(1, img2.device)

    # Randomly shift color (one gain per channel, broadcast over batch/space)
    img1_aug = img1_aug * _gains(3, img1.device).view(1, 3, 1, 1)
    img2_aug = img2_aug * _gains(3, img2.device).view(1, 3, 1, 1)

    # clip
    img1_aug = torch.clamp(img1_aug, -1, 1)
    img2_aug = torch.clamp(img2_aug, -1, 1)

    return img1_aug, img2_aug


# for train.py / test.py
def build_model(net, input1_tensor, input2_tensor, is_training = True):
    """Run ``net`` on an image pair and build the homography / mesh warps.

    ``net`` must return ``(H_motion, mesh_motion)``: the former is reshaped
    to 4 corner offsets, the latter to ``(grid_h+1)*(grid_w+1)`` vertex
    offsets.  When ``is_training`` is ``True`` the network sees augmented
    copies of the inputs; the warps are always applied to the originals.

    Returns a dict with:
        output_H:        input2 + ones-mask warped with the homography
        output_H_inv:    input1 + ones-mask warped with the inverse homography
        warp_mesh:       channels 0:3 of the mesh (TPS) warp of input2 + mask
        warp_mesh_mask:  channels 3:6 of the same warp
        mesh1:           the rigid mesh
        mesh2:           homography mesh plus predicted vertex offsets
        overlap:         grid_h*grid_w map, 1 where the cell mean of the
                         back-warped mask is below 0.9 and 0 elsewhere
    """
    batch_size, _, img_h, img_w = input1_tensor.size()
    on_gpu = torch.cuda.is_available()
    out_size = (img_h, img_w)

    # network (augmentation only affects what the network sees)
    if is_training == True:
        net_in1, net_in2 = data_aug(input1_tensor, input2_tensor)
    else:
        net_in1, net_in2 = input1_tensor, input2_tensor
    H_motion, mesh_motion = net(net_in1, net_in2)

    H_motion = H_motion.reshape(-1, 4, 2)
    mesh_motion = mesh_motion.reshape(-1, grid_h+1, grid_w+1, 2)

    # the four image corners, bs x 4 x 2, and their predicted targets
    corners = torch.tensor([[0., 0.], [img_w, 0.], [0., img_h], [img_w, img_h]])
    if on_gpu:
        corners = corners.cuda()
    corners = corners.unsqueeze(0).expand(batch_size, -1, -1)
    moved_corners = corners + H_motion
    # solve homo using DLT
    H = torch_DLT.tensor_DLT(corners, moved_corners)

    # M is used to re-express H for the homography transformer
    # (presumably pixel <-> normalised coordinates; see torch_homo_transform)
    M_tensor = torch.tensor([[img_w / 2.0, 0., img_w / 2.0],
                             [0., img_h / 2.0, img_h / 2.0],
                             [0., 0., 1.]])
    if on_gpu:
        M_tensor = M_tensor.cuda()
    M_tile = M_tensor.unsqueeze(0).expand(batch_size, -1, -1)
    M_tile_inv = torch.inverse(M_tensor).unsqueeze(0).expand(batch_size, -1, -1)

    # an all-ones mask is warped together with each image
    mask = torch.ones_like(input2_tensor)
    if on_gpu:
        mask = mask.cuda()

    H_mat = torch.matmul(torch.matmul(M_tile_inv, H), M_tile)
    output_H = torch_homo_transform.transformer(torch.cat((input2_tensor, mask), 1), H_mat, out_size)

    H_inv_mat = torch.matmul(torch.matmul(M_tile_inv, torch.inverse(H)), M_tile)
    output_H_inv = torch_homo_transform.transformer(torch.cat((input1_tensor, mask), 1), H_inv_mat, out_size)

    # mesh = homography applied to the rigid mesh + predicted residual motion
    rigid_mesh = get_rigid_mesh(batch_size, img_h, img_w)
    mesh = H2Mesh(H, rigid_mesh) + mesh_motion

    norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)
    norm_mesh = get_norm_mesh(mesh, img_h, img_w)

    output_tps = torch_tps_transform.transformer(torch.cat((input2_tensor, mask), 1), norm_mesh, norm_rigid_mesh, out_size)
    warp_mesh = output_tps[:,0:3,...]
    warp_mesh_mask = output_tps[:,3:6,...]

    # per-cell region map used by the shape-preserving constraint:
    # warp the mask back, average it inside every grid cell, then threshold
    cell_h = int(img_h/grid_h)
    cell_w = int(img_w/grid_w)
    overlap = torch_tps_transform.transformer(warp_mesh_mask, norm_rigid_mesh, norm_mesh, out_size)
    overlap = overlap.permute(0, 2, 3, 1).unfold(1, cell_h, cell_h).unfold(2, cell_w, cell_w)
    overlap = torch.mean(overlap.reshape(batch_size, grid_h, grid_w, -1), 3)
    overlap = torch.where(overlap<0.9, torch.ones_like(overlap), torch.zeros_like(overlap))

    return {
        'output_H': output_H,
        'output_H_inv': output_H_inv,
        'warp_mesh': warp_mesh,
        'warp_mesh_mask': warp_mesh_mask,
        'mesh1': rigid_mesh,
        'mesh2': mesh,
        'overlap': overlap,
    }

# for train_ft.py
def build_new_ft_model(net, input1_tensor, input2_tensor):
    """Run ``net`` on an image pair and mesh-warp input2 (no augmentation).

    ``net`` must return ``(H_motion, mesh_motion)``: 4 corner offsets and
    ``(grid_h+1)*(grid_w+1)`` vertex offsets after reshaping.

    Returns a dict with:
        warp_mesh:       channels 0:3 of the mesh (TPS) warp of input2 + mask
        warp_mesh_mask:  channels 3:6 of the same warp
        rigid_mesh:      the rigid mesh
        mesh:            homography mesh plus predicted vertex offsets
    """
    batch_size, _, img_h, img_w = input1_tensor.size()
    on_gpu = torch.cuda.is_available()

    H_motion, mesh_motion = net(input1_tensor, input2_tensor)
    corner_offsets = H_motion.reshape(-1, 4, 2)
    vertex_offsets = mesh_motion.reshape(-1, grid_h+1, grid_w+1, 2)

    # the four image corners, bs x 4 x 2
    corners = torch.tensor([[0., 0.], [img_w, 0.], [0., img_h], [img_w, img_h]])
    if on_gpu:
        corners = corners.cuda()
    corners = corners.unsqueeze(0).expand(batch_size, -1, -1)
    # solve homo using DLT from the corners and their predicted targets
    H = torch_DLT.tensor_DLT(corners, corners + corner_offsets)

    # mesh = homography applied to the rigid mesh + predicted residual motion
    rigid_mesh = get_rigid_mesh(batch_size, img_h, img_w)
    mesh = H2Mesh(H, rigid_mesh) + vertex_offsets

    norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)
    norm_mesh = get_norm_mesh(mesh, img_h, img_w)

    # an all-ones mask is warped together with the image
    mask = torch.ones_like(input2_tensor)
    if on_gpu:
        mask = mask.cuda()
    image_and_mask = torch.cat((input2_tensor, mask), 1)
    output_tps = torch_tps_transform.transformer(image_and_mask, norm_mesh, norm_rigid_mesh, (img_h, img_w))

    return {
        'warp_mesh': output_tps[:,0:3,...],
        'warp_mesh_mask': output_tps[:,3:6,...],
        'rigid_mesh': rigid_mesh,
        'mesh': mesh,
    }

# for train_ft.py
def get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh):
    """Compose input1 and the mesh-warped input2 on a common canvas.

    ``rigid_mesh`` and ``mesh`` are rescaled by ``img_w/512`` and
    ``img_h/512`` (i.e. they are taken to be expressed for a 512 x 512
    image).  The canvas is the bounding box of the image frame and the
    warped mesh.

    Returns a dict with (pixel values scaled to 0..255):
        warp1 / mask1:  input1 and its mask placed on the canvas
        warp2 / mask2:  input2 and its mask warped onto the canvas
        stitched:       blend of warp1 and warp2, each weighted by its own
                        share of the summed intensity
        stitched_mesh:  numpy image of the first sample with the mesh drawn

    Tensors are created on the devices of the inputs instead of calling
    ``.cuda()`` unconditionally, so the function no longer requires CUDA.
    """
    batch_size, _, img_h, img_w = input1_tensor.size()
    device = input1_tensor.device

    rigid_mesh = torch.stack([rigid_mesh[...,0]*img_w/512, rigid_mesh[...,1]*img_h/512], 3)
    mesh = torch.stack([mesh[...,0]*img_w/512, mesh[...,1]*img_h/512], 3)

    ######################################
    # bounding box of the image frame [0, img_w] x [0, img_h] and the mesh
    zero = torch.tensor(0, device=mesh.device)
    width_max = torch.maximum(torch.tensor(img_w, device=mesh.device), torch.max(mesh[...,0]))
    width_min = torch.minimum(zero, torch.min(mesh[...,0]))
    height_max = torch.maximum(torch.tensor(img_h, device=mesh.device), torch.max(mesh[...,1]))
    height_min = torch.minimum(zero, torch.min(mesh[...,1]))

    out_width = width_max - width_min
    out_height = height_max - height_min
    print(out_width)
    print(out_height)

    # position of input1's top-left corner on the canvas
    top = int(torch.abs(height_min))
    left = int(torch.abs(width_min))

    warp1 = torch.zeros([batch_size, 3, out_height.int(), out_width.int()], device=device)
    warp1[:, :, top:top+img_h, left:left+img_w] = (input1_tensor+1)*127.5

    mask1 = torch.zeros([batch_size, 3, out_height.int(), out_width.int()], device=device)
    mask1[:, :, top:top+img_h, left:left+img_w] = 255

    # ones_like already places the mask on input2's device
    mask = torch.ones_like(input2_tensor)

    # get warped img2
    mesh_trans = torch.stack([mesh[...,0]-width_min, mesh[...,1]-height_min], 3)
    norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)
    norm_mesh = get_norm_mesh(mesh_trans, out_height, out_width)

    stitch_tps_out = torch_tps_transform.transformer(torch.cat([input2_tensor+1, mask], 1), norm_mesh, norm_rigid_mesh, (out_height.int(), out_width.int()))
    warp2 = stitch_tps_out[:,0:3,:,:]*127.5
    mask2 = stitch_tps_out[:,3:6,:,:]*255

    # 1e-6 keeps the weights finite where both images are black
    stitched = warp1*(warp1/(warp1+warp2+1e-6)) + warp2*(warp2/(warp1+warp2+1e-6))

    stitched_mesh = draw_mesh_on_warp(stitched[0].cpu().detach().numpy().transpose(1,2,0), mesh_trans[0].cpu().detach().numpy())

    out_dict = {}
    out_dict.update(warp1 = warp1, mask1 = mask1, warp2 = warp2, mask2 = mask2, stitched = stitched, stitched_mesh = stitched_mesh)

    return out_dict


# for test_output.py
def build_output_model(net, input1_tensor, input2_tensor):
    """Warp both input images onto one shared canvas at their native resolution.

    The network is run on 512x512 resized copies of the inputs; the predicted
    corner and mesh motions are then rescaled to the original image size.
    Image 1 is placed on the canvas with a pure translation, image 2 is warped
    with the TPS transform driven by the predicted mesh.

    Args:
        net: callable returning ``(H_motion, mesh_motion)`` for two image batches.
        input1_tensor: reference image batch, 4-D (batch, channels, height, width).
            The ``+1`` / ``-1`` arithmetic below suggests values in [-1, 1] -- TODO confirm.
        input2_tensor: target image batch with the same layout.

    Returns:
        dict with keys ``final_warp1``, ``final_warp1_mask``, ``final_warp2``,
        ``final_warp2_mask`` (channels 0:3 / 3:6 of the two transformer outputs,
        the image channels shifted back by -1), ``mesh1`` (the rigid mesh) and
        ``mesh2`` (the predicted mesh translated into canvas coordinates).
    """
    batch_size, _, img_h, img_w = input1_tensor.size()
    # Create helper tensors on the inputs' device instead of calling .cuda()
    # unconditionally, so the function does not hard-fail without a GPU.
    device = input1_tensor.device

    resized_input1 = resize_512(input1_tensor)
    resized_input2 = resize_512(input2_tensor)
    H_motion, mesh_motion = net(resized_input1, resized_input2)

    # The motions are predicted for 512x512 inputs; rescale x by img_w/512 and y by img_h/512.
    H_motion = H_motion.reshape(-1, 4, 2)
    H_motion = torch.stack([H_motion[...,0]*img_w/512, H_motion[...,1]*img_h/512], 2)
    mesh_motion = mesh_motion.reshape(-1, grid_h+1, grid_w+1, 2)
    mesh_motion = torch.stack([mesh_motion[...,0]*img_w/512, mesh_motion[...,1]*img_h/512], 3)

    # initialize the source points bs x 4 x 2 (the four image corners)
    src_p = torch.tensor([[0., 0.], [img_w, 0.], [0., img_h], [img_w, img_h]], device=device)
    src_p = src_p.unsqueeze(0).expand(batch_size, -1, -1)
    # target points
    dst_p = src_p + H_motion
    # solve homo using DLT
    H = torch_DLT.tensor_DLT(src_p, dst_p)

    rigid_mesh = get_rigid_mesh(batch_size, img_h, img_w)
    ini_mesh = H2Mesh(H, rigid_mesh)
    mesh = ini_mesh + mesh_motion

    # Canvas bounds: the union of image 1's extent [0, img_w] x [0, img_h] and
    # the extent of the warped mesh, taken over the whole batch.
    width_max = torch.clamp(torch.max(mesh[...,0]), min=img_w)
    width_min = torch.clamp(torch.min(mesh[...,0]), max=0)
    height_max = torch.clamp(torch.max(mesh[...,1]), min=img_h)
    height_min = torch.clamp(torch.min(mesh[...,1]), max=0)

    out_width = width_max - width_min
    out_height = height_max - height_min

    # get warped img1
    # M / N have the form [[s/2, 0, s/2], [0, t/2, t/2], [0, 0, 1]] for the canvas
    # size and the image size respectively; presumably they convert between
    # normalized and pixel coordinates for the homography transformer.
    half_out_w = float(out_width) / 2.0
    half_out_h = float(out_height) / 2.0
    M_tensor = torch.tensor([[half_out_w, 0., half_out_w],
                      [0., half_out_h, half_out_h],
                      [0., 0., 1.]], device=device)
    N_tensor = torch.tensor([[img_w / 2.0, 0., img_w / 2.0],
                      [0., img_h / 2.0, img_h / 2.0],
                      [0., 0., 1.]], device=device)
    N_tensor_inv = torch.inverse(N_tensor)

    # Pure translation by (width_min, height_min) in pixel coordinates.
    I_ = torch.tensor([[1., 0., float(width_min)],
                      [0., 1., float(height_min)],
                      [0., 0., 1.]], device=device)
    # ones_like already lives on input2_tensor's device.
    mask = torch.ones_like(input2_tensor)
    # NOTE(review): I_mat has a leading batch dimension of 1 -- confirm the
    # transformer handles batch_size > 1.
    I_mat = torch.matmul(torch.matmul(N_tensor_inv, I_), M_tensor).unsqueeze(0)

    homo_output = torch_homo_transform.transformer(torch.cat((input1_tensor+1, mask), 1), I_mat, (out_height.int(), out_width.int()))

    torch.cuda.empty_cache()
    # get warped img2
    mesh_trans = torch.stack([mesh[...,0]-width_min, mesh[...,1]-height_min], 3)
    norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)
    norm_mesh = get_norm_mesh(mesh_trans, out_height, out_width)
    tps_output = torch_tps_transform.transformer(torch.cat([input2_tensor+1, mask],1), norm_mesh, norm_rigid_mesh, (out_height.int(), out_width.int()))

    out_dict = {}
    out_dict.update(final_warp1=homo_output[:, 0:3, ...]-1, final_warp1_mask = homo_output[:, 3:6, ...], final_warp2=tps_output[:, 0:3, ...]-1, final_warp2_mask = tps_output[:, 3:6, ...], mesh1=rigid_mesh, mesh2=mesh_trans)

    return out_dict



# define and forward
class Network(nn.Module):
    """Two-stage regressor producing a global homography motion and a mesh motion.

    Stage 1 correlates the deeper backbone features of both images and
    regresses 8 values (reshaped to 4 points x 2 coordinates in ``forward``).
    Stage 2 warps the shallower features of image 2 with the stage-1
    homography, correlates again and regresses ``(grid_w+1)*(grid_h+1)*2``
    values.
    """

    def __init__(self):
        super(Network, self).__init__()

        # Stage-1 head: consumes the 2-channel flow map produced by CCL.
        self.regressNet1_part1 = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )

        # The flattened part1 output must have 4096 features (256 channels x 16
        # positions), which fixes the spatial size of the stage-1 correlation map.
        self.regressNet1_part2 = nn.Sequential(
            nn.Linear(in_features=4096, out_features=4096, bias=True),
            nn.ReLU(inplace=True),

            nn.Linear(in_features=4096, out_features=1024, bias=True),
            nn.ReLU(inplace=True),

            nn.Linear(in_features=1024, out_features=8, bias=True)
        )

        # Stage-2 head: one more conv/pool level than stage 1.
        self.regressNet2_part1 = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )

        self.regressNet2_part2 = nn.Sequential(
            nn.Linear(in_features=8192, out_features=4096, bias=True),
            nn.ReLU(inplace=True),

            nn.Linear(in_features=4096, out_features=2048, bias=True),
            nn.ReLU(inplace=True),

            nn.Linear(in_features=2048, out_features=(grid_w+1)*(grid_h+1)*2, bias=True)

        )

        # This loop runs BEFORE the ResNet backbone is attached below, so only
        # the regression heads are re-initialised and the pretrained backbone
        # weights are left untouched. Keep this ordering.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # NOTE(security): this disables HTTPS certificate verification for the
        # whole process (not just the weight download below). Kept for
        # backward compatibility; prefer supplying the weights locally.
        ssl._create_default_https_context = ssl._create_unverified_context
        resnet50_model = models.resnet.resnet50(pretrained=True)

        if torch.cuda.is_available():
            resnet50_model = resnet50_model.cuda()
        self.feature_extractor_stage1, self.feature_extractor_stage2 = self.get_res50_FeatureMap(resnet50_model)

    def get_res50_FeatureMap(self, resnet50_model):
        """Split a ResNet-50 into two sequential feature extractors.

        Returns:
            (stage1, stage2): stage1 is conv1 -> bn1 -> relu -> maxpool ->
            layer1 -> layer2; stage2 is layer3 alone and is meant to be applied
            to stage1's output.
        """
        stage1_layers = [
            resnet50_model.conv1,
            resnet50_model.bn1,
            resnet50_model.relu,
            resnet50_model.maxpool,
            resnet50_model.layer1,
            resnet50_model.layer2,
        ]

        feature_extractor_stage1 = nn.Sequential(*stage1_layers)
        feature_extractor_stage2 = nn.Sequential(resnet50_model.layer3)

        return feature_extractor_stage1, feature_extractor_stage2

    # forward
    def forward(self, input1_tesnor, input2_tesnor):
        """Predict the stage-1 and stage-2 motions for a pair of image batches.

        Args:
            input1_tesnor: reference image batch, 4-D (batch, channels, height, width).
            input2_tesnor: target image batch with the same layout.

        Returns:
            (offset_1, offset_2): ``offset_1`` has 8 values per sample (4 corner
            offsets); ``offset_2`` has ``(grid_w+1)*(grid_h+1)*2`` values per sample.
        """
        batch_size, _, img_h, img_w = input1_tesnor.size()
        # Helper tensors follow the input's device rather than a global CUDA check.
        device = input1_tesnor.device

        feature_1_64 = self.feature_extractor_stage1(input1_tesnor)
        feature_1_32 = self.feature_extractor_stage2(feature_1_64)
        feature_2_64 = self.feature_extractor_stage1(input2_tesnor)
        feature_2_32 = self.feature_extractor_stage2(feature_2_64)

        ######### stage 1
        correlation_32 = self.CCL(feature_1_32, feature_2_32)
        temp_1 = self.regressNet1_part1(correlation_32)
        temp_1 = temp_1.view(temp_1.size()[0], -1)
        offset_1 = self.regressNet1_part2(temp_1)
        H_motion_1 = offset_1.reshape(-1, 4, 2)

        # Four image corners and their displaced positions.
        src_p = torch.tensor([[0., 0.], [img_w, 0.], [0., img_h], [img_w, img_h]], device=device)
        src_p = src_p.unsqueeze(0).expand(batch_size, -1, -1)
        dst_p = src_p + H_motion_1
        # The homography is solved in coordinates scaled by 1/8, matching the
        # (img_h/8, img_w/8) size used for the feature warp below.
        H = torch_DLT.tensor_DLT(src_p/8, dst_p/8)

        M_tensor = torch.tensor([[img_w/8 / 2.0, 0., img_w/8 / 2.0],
                      [0., img_h/8 / 2.0, img_h/8 / 2.0],
                      [0., 0., 1.]], device=device)

        M_tile = M_tensor.unsqueeze(0).expand(batch_size, -1, -1)
        M_tensor_inv = torch.inverse(M_tensor)
        M_tile_inv = M_tensor_inv.unsqueeze(0).expand(batch_size, -1, -1)
        H_mat = torch.matmul(torch.matmul(M_tile_inv, H), M_tile)

        warp_feature_2_64 = torch_homo_transform.transformer(feature_2_64, H_mat, (int(img_h/8), int(img_w/8)))

        ######### stage 2
        correlation_64 = self.CCL(feature_1_64, warp_feature_2_64)
        temp_2 = self.regressNet2_part1(correlation_64)
        temp_2 = temp_2.view(temp_2.size()[0], -1)
        offset_2 = self.regressNet2_part2(temp_2)

        return offset_1, offset_2

    def extract_patches(self, x, kernel=3, stride=1):
        """Return all ``kernel`` x ``kernel`` patches of a (bs, c, h, w) tensor.

        The input is zero-padded by 1 on every side unless ``kernel == 1``.
        The result has shape (bs, out_h, out_w, c, kernel, kernel).
        """
        if kernel != 1:
            x = nn.ZeroPad2d(1)(x)
        x = x.permute(0, 2, 3, 1)
        all_patches = x.unfold(1, kernel, stride).unfold(2, kernel, stride)
        return all_patches

    def CCL(self, feature_1, feature_2):
        """Correlate two feature maps and reduce the result to a 2-channel flow.

        Every 3x3 patch of the L2-normalised ``feature_2`` is used as a
        convolution filter over the L2-normalised ``feature_1``. The resulting
        matching volume is turned into probabilities with a scaled softmax and
        the expected (column, row) displacement is returned.

        Args:
            feature_1: tensor of shape (bs, c, h, w).
            feature_2: tensor of the same shape.

        Returns:
            Tensor of shape (bs, 2, h, w); channel 0 is the horizontal flow and
            channel 1 the vertical flow.
        """
        bs, c, h, w = feature_1.size()
        device = feature_1.device

        norm_feature_1 = F.normalize(feature_1, p=2, dim=1)
        norm_feature_2 = F.normalize(feature_2, p=2, dim=1)

        # (bs, h, w, c, 3, 3) -> (bs, h*w, c, 3, 3): one filter per location of
        # feature_2, indexed row-major (index = row * w + col).
        patches = self.extract_patches(norm_feature_2)
        matching_filters = patches.reshape((patches.size()[0], -1, patches.size()[3], patches.size()[4], patches.size()[5]))

        # The filters differ per sample, so each sample is convolved separately.
        match_vol = torch.cat(
            [F.conv2d(norm_feature_1[i].unsqueeze(0), matching_filters[i], padding=1) for i in range(bs)],
            0,
        )

        # scale softmax
        softmax_scale = 10
        match_vol = F.softmax(match_vol*softmax_scale,1)

        channel = match_vol.size()[1]

        # Index tensors are kept broadcastable instead of being expanded to the
        # full (bs, channel, h, w) volume.
        h_one = torch.linspace(0, h-1, h, device=device).view(1, 1, h, 1)
        w_one = torch.linspace(0, w-1, w, device=device).view(1, 1, 1, w)
        c_one = torch.linspace(0, channel-1, channel, device=device).view(1, channel, 1, 1)

        # Channel k corresponds to location (k // w, k % w) of feature_2, so the
        # displacement is that location minus the current pixel position.
        flow_h = match_vol*(c_one//w - h_one)
        flow_h = torch.sum(flow_h, dim=1, keepdim=True)
        flow_w = match_vol*(c_one%w - w_one)
        flow_w = torch.sum(flow_w, dim=1, keepdim=True)

        feature_flow = torch.cat([flow_w, flow_h], 1)

        return feature_flow


================================================
FILE: Warp/Codes/test.py
================================================
# coding: utf-8
import argparse
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import imageio
from network import build_model, Network
from dataset import *
import os
import numpy as np
import skimage
import cv2


# NOTE: "__file__" is a string literal here, so os.path.dirname() returns ''
# and last_path resolves to the parent of the current working directory, not
# the parent of this script's directory. Run the script from inside Codes/.
last_path = os.path.abspath(os.path.join(os.path.dirname("__file__"), os.path.pardir))
# Directory that is searched for *.pth checkpoints.
MODEL_DIR = os.path.join(last_path, 'model')

def create_gif(image_list, gif_name, duration=0.5):
    """Write the frames in ``image_list`` to ``gif_name`` as an animated GIF.

    Args:
        image_list: iterable of frames accepted by ``imageio.mimsave``.
        gif_name: output file path.
        duration: value forwarded to ``imageio.mimsave``. It was previously
            ignored in favour of a hard-coded 0.5; the default is now 0.5 so
            callers that omit it get the same result as before.
    """
    frames = list(image_list)
    imageio.mimsave(gif_name, frames, 'GIF', duration=duration)
    return


def _print_tiered_means(values):
    """Sort ``values`` in descending order (in place) and print the tier means.

    The tiers are the top 30%, the next 30% and the remaining 40% of the
    samples. An empty tier is reported as nan.
    """
    values.sort(reverse=True)
    count = len(values)
    cut_30 = count * 3 // 10
    cut_60 = count * 6 // 10
    tiers = (
        ("top 30%", values[:cut_30]),
        ("top 30~60%", values[cut_30:cut_60]),
        # The slice runs to the end so the lowest score is included.
        ("top 60~100%", values[cut_60:]),
    )
    for label, tier in tiers:
        print(label, np.mean(tier) if tier else float('nan'))


def test(args):
    """Evaluate the latest checkpoint with PSNR / SSIM over the test set.

    For every batch the warped image 2 is compared with image 1 inside the
    warp mask. Only the first sample of each batch (index 0) is scored.

    Args:
        args: namespace providing ``gpu`` (value for CUDA_VISIBLE_DEVICES),
            ``test_path`` (dataset root) and ``batch_size``.
    """
    # The CUDA variable is CUDA_DEVICE_ORDER; the former 'CUDA_DEVICES_ORDER' had no effect.
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # dataset
    test_data = TestDataset(data_path=args.test_path)
    test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, num_workers=1, shuffle=False, drop_last=False)

    # define the network
    net = Network()
    if torch.cuda.is_available():
        net = net.cuda()

    # load the newest checkpoint (last file name in sorted order) if one exists
    ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
    ckpt_list.sort()
    if len(ckpt_list) != 0:
        model_path = ckpt_list[-1]
        checkpoint = torch.load(model_path)
        net.load_state_dict(checkpoint['model'])
        print('load model from {}!'.format(model_path))
    else:
        print('No checkpoint found!')

    print("##################start testing#######################")
    psnr_list = []
    ssim_list = []
    net.eval()
    for i, batch_value in enumerate(test_loader):

        inpu1_tesnor = batch_value[0].float()
        inpu2_tesnor = batch_value[1].float()

        if torch.cuda.is_available():
            inpu1_tesnor = inpu1_tesnor.cuda()
            inpu2_tesnor = inpu2_tesnor.cuda()

        # The evaluation is no longer nested inside the CUDA check above, so a
        # machine without a GPU does not silently skip every sample.
        with torch.no_grad():
            batch_out = build_model(net, inpu1_tesnor, inpu2_tesnor, is_training=False)

        warp_mesh_mask = batch_out['warp_mesh_mask']
        warp_mesh = batch_out['warp_mesh']

        # Map images from [-1, 1] back to [0, 255] and move channels last.
        warp_mesh_np = ((warp_mesh[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)
        warp_mesh_mask_np = warp_mesh_mask[0].cpu().detach().numpy().transpose(1,2,0)
        inpu1_np = ((inpu1_tesnor[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)

        # calculate psnr/ssim on the masked region only
        psnr = skimage.measure.compare_psnr(inpu1_np*warp_mesh_mask_np, warp_mesh_np*warp_mesh_mask_np, 255)
        ssim = skimage.measure.compare_ssim(inpu1_np*warp_mesh_mask_np, warp_mesh_np*warp_mesh_mask_np, data_range=255, multichannel=True)

        print('i = {}, psnr = {:.6f}'.format( i+1, psnr))

        psnr_list.append(psnr)
        ssim_list.append(ssim)
        torch.cuda.empty_cache()

    print("=================== Analysis ==================")
    print("psnr")
    _print_tiered_means(psnr_list)
    print('average psnr:', np.mean(psnr_list) if psnr_list else float('nan'))

    _print_tiered_means(ssim_list)
    print('average ssim:', np.mean(ssim_list) if ssim_list else float('nan'))
    print("##################end testing#######################")


if __name__ == "__main__":

    # Command-line options as (flag, type, default) triples.
    option_specs = (
        ('--gpu', str, '0'),
        ('--batch_size', int, 1),
        ('--test_path', str, '/opt/data/private/nl/Data/UDIS-D/testing/'),
    )

    parser = argparse.ArgumentParser()
    for option_flag, option_type, option_default in option_specs:
        parser.add_argument(option_flag, type=option_type, default=option_default)

    print('<==================== Loading data ===================>\n')

    # Parse, echo the effective configuration, then run the evaluation.
    args = parser.parse_args()
    print(args)
    test(args)


================================================
FILE: Warp/Codes/test_other.py
================================================
import argparse
import torch

import numpy as np
import os
import torch.nn as nn
import torch.optim as optim

import cv2
#from torch_homography_model import build_model
from network import get_stitched_result, Network, build_new_ft_model

import glob
from loss import cal_lp_loss2
import torchvision.transforms as T

#import PIL
# Resizes an image (tensor) to 512x512; train() below applies it to both inputs.
resize_512 = T.Resize((512,512))


def _load_normalized_image(image_path):
    """Read one image and return it as a (1, C, H, W) float32 tensor in [-1, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError('cannot read image: {}'.format(image_path))
    image = image.astype(dtype=np.float32)
    image = (image / 127.5) - 1.0
    # H x W x C -> C x H x W
    image = np.transpose(image, [2, 0, 1])
    return torch.tensor(image).unsqueeze(0)


def loadSingleData(data_path, img1_name, img2_name):
    """Load one image pair as normalized tensors.

    Args:
        data_path: directory prefix; it is concatenated with the file names
            as-is, so it must end with a path separator.
        img1_name: file name of the first image.
        img2_name: file name of the second image.

    Returns:
        (input1_tensor, input2_tensor): two float32 tensors of shape
        (1, C, H, W) with values scaled from [0, 255] to [-1, 1].

    Raises:
        FileNotFoundError: if either image cannot be read.
    """
    input1_tensor = _load_normalized_image(data_path+img1_name)
    input2_tensor = _load_normalized_image(data_path+img2_name)
    return (input1_tensor, input2_tensor)



# Project root.
# NOTE: "__file__" is a string literal here, so os.path.dirname() returns ''
# and the result is the parent of the current working directory, not the
# parent of this script's directory. Run the script from inside Codes/.
last_path = os.path.abspath(os.path.join(os.path.dirname("__file__"), os.path.pardir))


# Directory holding the model checkpoint files.
MODEL_DIR = os.path.join(last_path, 'model')

# Create the checkpoint directory if it does not exist.
if not os.path.exists(MODEL_DIR):
    os.makedirs(MODEL_DIR)


def _tensor_to_hwc(batch_tensor):
    """Return the first sample of a batched tensor as a channels-last numpy array."""
    return batch_tensor[0].cpu().detach().numpy().transpose(1,2,0)


def _save_final_outputs(output, out_prefix, iter_tag):
    """Write the stitched image, its mesh overlay, both warps and both masks."""
    cv2.imwrite(out_prefix + "iter-" + iter_tag + ".jpg", _tensor_to_hwc(output['stitched']))
    cv2.imwrite(out_prefix + "iter-" + iter_tag + "_mesh.jpg", output['stitched_mesh'])
    cv2.imwrite( out_prefix+'warp1.jpg', _tensor_to_hwc(output['warp1']))
    cv2.imwrite( out_prefix+'warp2.jpg', _tensor_to_hwc(output['warp2']))
    cv2.imwrite( out_prefix+'mask1.jpg', _tensor_to_hwc(output['mask1']))
    cv2.imwrite( out_prefix+'mask2.jpg', _tensor_to_hwc(output['mask2']))


def train(args):
    """Fine-tune the network on a single image pair and save the stitched result.

    The loop stops early once the loss has changed by at most 1e-4 across three
    consecutive iterations, otherwise it runs ``args.max_iter`` iterations.
    The result before optimization is written after the first iteration and
    the final result when the loop stops.

    Args:
        args: namespace providing ``gpu``, ``max_iter``, ``path`` (directory
            prefix used for both reading and writing; must end with a path
            separator), ``img1_name`` and ``img2_name``.
    """
    # The CUDA variable is CUDA_DEVICE_ORDER; the former 'CUDA_DEVICES_ORDER' had no effect.
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # define the network
    net = Network()
    if torch.cuda.is_available():
        net = net.cuda()

    # define the optimizer and learning rate
    optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08)  # default as 0.0001
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)

    # load the newest checkpoint (last file name in sorted order) if one exists
    ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
    ckpt_list.sort()
    if len(ckpt_list) != 0:
        model_path = ckpt_list[-1]
        checkpoint = torch.load(model_path)

        net.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        scheduler.last_epoch = start_epoch
        print('load model from {}!'.format(model_path))
    else:
        start_epoch = 0
        print('training from stratch!')

    # load dataset(only one pair of images)
    input1_tensor, input2_tensor = loadSingleData(data_path=args.path, img1_name = args.img1_name, img2_name = args.img2_name)
    if torch.cuda.is_available():
        input1_tensor = input1_tensor.cuda()
        input2_tensor = input2_tensor.cuda()

    # The network is optimized on 512x512 copies; stitching uses the originals.
    input1_tensor_512 = resize_512(input1_tensor)
    input2_tensor_512 = resize_512(input2_tensor)

    loss_list = []

    print("##################start iteration#######################")
    for epoch in range(start_epoch, start_epoch + args.max_iter):
        net.train()

        optimizer.zero_grad()

        batch_out = build_new_ft_model(net, input1_tensor_512, input2_tensor_512)
        warp_mesh = batch_out['warp_mesh']
        warp_mesh_mask = batch_out['warp_mesh_mask']
        rigid_mesh = batch_out['rigid_mesh']
        mesh = batch_out['mesh']

        total_loss = cal_lp_loss2(input1_tensor_512, warp_mesh, warp_mesh_mask)
        total_loss.backward()
        # clip the gradient
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=3, norm_type=2)
        optimizer.step()

        current_iter = epoch-start_epoch+1
        print("Training: Iteration[{:0>3}/{:0>3}] Total Loss: {:.4f} lr={:.8f}".format(current_iter, args.max_iter, total_loss, optimizer.state_dict()['param_groups'][0]['lr']))

        # Keep only the detached value so the history does not hold on to the autograd graph.
        loss_list.append(total_loss.detach())

        if current_iter == 1:
            with torch.no_grad():
                output = get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh)
            cv2.imwrite( args.path+ 'before_optimization.jpg', _tensor_to_hwc(output['stitched']))
            cv2.imwrite( args.path+ 'before_optimization_mesh.jpg', output['stitched_mesh'])

        # Converged when the last four losses differ pairwise by at most 1e-4.
        recent = loss_list[-4:]
        converged = current_iter >= 4 and all(
            torch.abs(previous - current) <= 1e-4
            for previous, current in zip(recent, recent[1:])
        )

        if converged or current_iter == args.max_iter:
            with torch.no_grad():
                output = get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh)
            _save_final_outputs(output, args.path, str(current_iter).zfill(3))
            if converged:
                break

        scheduler.step()

    print("##################end iteration#######################")


if __name__ == "__main__":

    print('<==================== setting arguments ===================>\n')

    # Command-line options as (flag, type, default) triples.
    option_specs = (
        ('--gpu', str, '0'),
        ('--max_iter', int, 50),
        ('--path', str, '../../Carpark-DHW/'),
        ('--img1_name', str, 'input1.jpg'),
        ('--img2_name', str, 'input2.jpg'),
    )

    # Build the argument parser and register every option.
    parser = argparse.ArgumentParser()
    for option_flag, option_type, option_default in option_specs:
        parser.add_argument(option_flag, type=option_type, default=option_default)

    # Parse the arguments and echo the effective configuration.
    args = parser.parse_args()
    print(args)

    # Run the per-pair optimization.
    train(args)




================================================
FILE: Warp/Codes/test_output.py
================================================
# coding: utf-8
import argparse
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import imageio
from network import build_output_model, Network
from dataset import *
import os
import cv2

import grid_res
# Mesh resolution taken from grid_res; draw_mesh_on_warp below iterates over
# (grid_h + 1) x (grid_w + 1) vertices.
grid_h = grid_res.GRID_H
grid_w = grid_res.GRID_W

# NOTE: "__file__" is a string literal here, so os.path.dirname() returns ''
# and last_path resolves to the parent of the current working directory, not
# the parent of this script's directory. Run the script from inside Codes/.
last_path = os.path.abspath(os.path.join(os.path.dirname("__file__"), os.path.pardir))
# Directory that is searched for *.pth checkpoints.
MODEL_DIR = os.path.join(last_path, 'model')


def draw_mesh_on_warp(warp, f_local):
    """Draw the mesh edges onto ``warp`` and return it.

    Args:
        warp: image array accepted by ``cv2.line``; it is drawn on in place.
        f_local: vertex array indexed as ``f_local[i, j]`` with
            ``i`` in ``0..grid_h`` and ``j`` in ``0..grid_w``; element 0 is used
            as the x coordinate and element 1 as the y coordinate.

    Returns:
        The same ``warp`` object with the mesh drawn on it.
    """
    point_color = (0, 255, 0) # BGR
    thickness = 2
    lineType = 8

    def vertex(row, col):
        # cv2.line needs integer pixel coordinates.
        return (int(f_local[row,col,0]), int(f_local[row,col,1]))

    for i in range(grid_h+1):
        for j in range(grid_w+1):
            # Edge to the vertex below; the last row has none.
            if i < grid_h:
                cv2.line(warp, vertex(i, j), vertex(i+1, j), point_color, thickness, lineType)
            # Edge to the vertex on the right; the last column has none.
            if j < grid_w:
                cv2.line(warp, vertex(i, j), vertex(i, j+1), point_color, thickness, lineType)

    return warp

def create_gif(image_list, gif_name, duration=0.5):
    """Write the frames in ``image_list`` to ``gif_name`` as an animated GIF.

    Args:
        image_list: iterable of frames accepted by ``imageio.mimsave``.
        gif_name: output file path.
        duration: value forwarded to ``imageio.mimsave``. It was previously
            ignored in favour of a hard-coded 0.5; the default is now 0.5 so
            callers that omit it get the same result as before.
    """
    frames = list(image_list)
    imageio.mimsave(gif_name, frames, 'GIF', duration=duration)
    return




def _first_sample_hwc(batch_tensor):
    """Return the first sample of a batched tensor as a channels-last numpy array."""
    return batch_tensor[0].cpu().detach().numpy().transpose(1,2,0)


def test(args):
    """Run the latest checkpoint over a dataset and write warps, masks and a fusion.

    For every batch the first sample's warped images and masks are written
    under ``args.test_path`` (``warp1/``, ``warp2/``, ``mask1/``, ``mask2/``)
    and a simple fusion of the two warps is written to ``../ave_fusion/``.
    Files are named by the 1-based batch index, zero-padded to 6 digits.

    Args:
        args: namespace providing ``gpu`` (value for CUDA_VISIBLE_DEVICES),
            ``test_path`` (dataset root; must end with a path separator) and
            ``batch_size``.
    """
    # The CUDA variable is CUDA_DEVICE_ORDER; the former 'CUDA_DEVICES_ORDER' had no effect.
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # dataset
    test_data = TestDataset(data_path=args.test_path)
    test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, num_workers=1, shuffle=False, drop_last=False)

    # define the network
    net = Network()
    if torch.cuda.is_available():
        net = net.cuda()

    # load the newest checkpoint (last file name in sorted order) if one exists
    ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
    ckpt_list.sort()
    if len(ckpt_list) != 0:
        model_path = ckpt_list[-1]
        checkpoint = torch.load(model_path)

        net.load_state_dict(checkpoint['model'])
        print('load model from {}!'.format(model_path))
    else:
        print('No checkpoint found!')

    print("##################start testing#######################")
    # create the output folders if they do not exist
    path_ave_fusion = '../ave_fusion/'
    path_warp1 = args.test_path + 'warp1/'
    path_warp2 = args.test_path + 'warp2/'
    path_mask1 = args.test_path + 'mask1/'
    path_mask2 = args.test_path + 'mask2/'
    for out_dir in (path_ave_fusion, path_warp1, path_warp2, path_mask1, path_mask2):
        os.makedirs(out_dir, exist_ok=True)

    net.eval()
    for i, batch_value in enumerate(test_loader):

        inpu1_tesnor = batch_value[0].float()
        inpu2_tesnor = batch_value[1].float()

        if torch.cuda.is_available():
            inpu1_tesnor = inpu1_tesnor.cuda()
            inpu2_tesnor = inpu2_tesnor.cuda()

        with torch.no_grad():
            batch_out = build_output_model(net, inpu1_tesnor, inpu2_tesnor)

        # Map the warped images from [-1, 1] back to [0, 255]; masks are scaled when written.
        final_warp1 = _first_sample_hwc((batch_out['final_warp1']+1)*127.5)
        final_warp2 = _first_sample_hwc((batch_out['final_warp2']+1)*127.5)
        final_warp1_mask = _first_sample_hwc(batch_out['final_warp1_mask'])
        final_warp2_mask = _first_sample_hwc(batch_out['final_warp2_mask'])

        file_name = str(i+1).zfill(6) + ".jpg"
        cv2.imwrite(path_warp1 + file_name, final_warp1)
        cv2.imwrite(path_warp2 + file_name, final_warp2)
        cv2.imwrite(path_mask1 + file_name, final_warp1_mask*255)
        cv2.imwrite(path_mask2 + file_name, final_warp2_mask*255)

        # Intensity-weighted blend: each warp is weighted by its own share of
        # the summed intensity; 1e-6 avoids division by zero where both are 0.
        total_intensity = final_warp1+final_warp2+1e-6
        ave_fusion = final_warp1 * (final_warp1/ total_intensity) + final_warp2 * (final_warp2/ total_intensity)
        cv2.imwrite(path_ave_fusion + file_name, ave_fusion)

        print('i = {}'.format( i+1))

        torch.cuda.empty_cache()

    print("##################end testing#######################")


if __name__ == "__main__":

    # Command-line options as (flag, type, default) triples.
    # --test_path may point at either the testing or the training split, e.g.
    # /opt/data/private/nl/Data/UDIS-D/testing/  or  /opt/data/private/nl/Data/UDIS-D/training/
    option_specs = (
        ('--gpu', str, '0'),
        ('--batch_size', int, 1),
        ('--test_path', str, '/opt/data/private/nl/Data/UDIS-D/testing/'),
    )

    parser = argparse.ArgumentParser()
    for option_flag, option_type, option_default in option_specs:
        parser.add_argument(option_flag, type=option_type, default=option_default)

    print('<==================== Loading data ===================>\n')

    # Parse, echo the effective configuration, then generate the outputs.
    args = parser.parse_args()
    print(args)
    test(args)


================================================
FILE: Warp/Codes/train.py
================================================
import argparse
import torch
from torch.utils.data import DataLoader
import os
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from network import build_model, Network
from dataset import TrainDataset
import glob
from loss import cal_lp_loss, inter_grid_loss, intra_grid_loss



# NOTE: "__file__" is a string literal here, so os.path.dirname() returns ''
# and last_path resolves to the parent of the current working directory, not
# the parent of this script's directory. Run the script from inside Codes/.
last_path = os.path.abspath(os.path.join(os.path.dirname("__file__"), os.path.pardir))
# path to save the summary files
SUMMARY_DIR = os.path.join(last_path, 'summary')
# Created at import time, i.e. before the directory checks below run.
writer = SummaryWriter(log_dir=SUMMARY_DIR)
# path to save the model files
MODEL_DIR = os.path.join(last_path, 'model')
# create the folders if they do not exist
if not os.path.exists(MODEL_DIR):
    os.makedirs(MODEL_DIR)
if not os.path.exists(SUMMARY_DIR):
    os.makedirs(SUMMARY_DIR)



def train(args):
    """Train the warp network, logging to TensorBoard and saving checkpoints.

    Reads ``args.gpu``, ``args.train_path``, ``args.batch_size`` and
    ``args.max_epoch``. Uses the module-level ``writer`` and ``MODEL_DIR``.

    If ``MODEL_DIR`` contains ``*.pth`` files, training resumes from the one
    that sorts last; otherwise it starts from epoch 0. A checkpoint holding the
    model state, optimizer state, next epoch index and global iteration
    counter is written every 10 epochs and after the final epoch.

    Args:
        args: parsed command-line namespace (see the ``__main__`` section).
    """

    # The variable CUDA reads is CUDA_DEVICE_ORDER (singular "DEVICE"); the
    # previously used key 'CUDA_DEVICES_ORDER' is not recognised, so the
    # PCI_BUS_ID ordering was never actually applied.
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # define dataset
    train_data = TrainDataset(data_path=args.train_path)
    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=4, shuffle=True, drop_last=True)

    # define the network
    net = Network()
    if torch.cuda.is_available():
        net = net.cuda()

    # define the optimizer and learning rate
    optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08)  # default as 0.0001
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)

    # load the latest existing checkpoint, if any
    ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
    ckpt_list.sort()
    if ckpt_list:
        model_path = ckpt_list[-1]
        checkpoint = torch.load(model_path)

        net.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        glob_iter = checkpoint['glob_iter']
        scheduler.last_epoch = start_epoch
        print('load model from {}!'.format(model_path))
    else:
        start_epoch = 0
        glob_iter = 0
        print('training from stratch!')

    print("##################start training#######################")
    # number of iterations between two console/TensorBoard reports
    score_print_fre = 300

    for epoch in range(start_epoch, args.max_epoch):

        print("start epoch {}".format(epoch))
        net.train()
        # running sums over the current reporting window
        loss_sigma = 0.0
        overlap_loss_sigma = 0.
        nonoverlap_loss_sigma = 0.
        # iterations accumulated in the running sums; the first window of an
        # epoch spans i = 0..score_print_fre (one more than the later ones),
        # so the averages divide by this count instead of a fixed constant
        window_iters = 0

        print(epoch, 'lr={:.6f}'.format(optimizer.state_dict()['param_groups'][0]['lr']))

        for i, batch_value in enumerate(train_loader):

            inpu1_tesnor = batch_value[0].float()
            inpu2_tesnor = batch_value[1].float()

            if torch.cuda.is_available():
                inpu1_tesnor = inpu1_tesnor.cuda()
                inpu2_tesnor = inpu2_tesnor.cuda()

            # forward, backward, update weights
            optimizer.zero_grad()

            batch_out = build_model(net, inpu1_tesnor, inpu2_tesnor)
            # result
            output_H = batch_out['output_H']
            output_H_inv = batch_out['output_H_inv']
            warp_mesh = batch_out['warp_mesh']
            warp_mesh_mask = batch_out['warp_mesh_mask']
            mesh2 = batch_out['mesh2']
            overlap = batch_out['overlap']

            # calculate loss for overlapping regions
            overlap_loss = cal_lp_loss(inpu1_tesnor, inpu2_tesnor, output_H, output_H_inv, warp_mesh, warp_mesh_mask)
            # calculate loss for non-overlapping regions
            nonoverlap_loss = 10*inter_grid_loss(overlap, mesh2) + 10*intra_grid_loss(mesh2)

            total_loss = overlap_loss + nonoverlap_loss
            total_loss.backward()

            # clip the gradient
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=3, norm_type=2)
            optimizer.step()

            overlap_loss_sigma += overlap_loss.item()
            nonoverlap_loss_sigma += nonoverlap_loss.item()
            loss_sigma += total_loss.item()
            window_iters += 1

            print(glob_iter)

            # record loss and images in tensorboard
            if i % score_print_fre == 0 and i != 0:
                average_loss = loss_sigma / window_iters
                average_overlap_loss = overlap_loss_sigma / window_iters
                average_nonoverlap_loss = nonoverlap_loss_sigma / window_iters
                loss_sigma = 0.0
                overlap_loss_sigma = 0.
                nonoverlap_loss_sigma = 0.
                window_iters = 0

                print("Training: Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}]/[{:0>3}] Total Loss: {:.4f}  Overlap Loss: {:.4f}  Non-overlap Loss: {:.4f} lr={:.8f}".format(epoch + 1, args.max_epoch, i + 1, len(train_loader),
                                          average_loss, average_overlap_loss, average_nonoverlap_loss, optimizer.state_dict()['param_groups'][0]['lr']))
                # visualization; images are rescaled with (v + 1) / 2 before logging
                writer.add_image("inpu1", (inpu1_tesnor[0]+1.)/2., glob_iter)
                writer.add_image("inpu2", (inpu2_tesnor[0]+1.)/2., glob_iter)
                writer.add_image("warp_H", (output_H[0,0:3,:,:]+1.)/2., glob_iter)
                writer.add_image("warp_mesh", (warp_mesh[0]+1.)/2., glob_iter)
                writer.add_scalar('lr', optimizer.state_dict()['param_groups'][0]['lr'], glob_iter)
                writer.add_scalar('total loss', average_loss, glob_iter)
                writer.add_scalar('overlap loss', average_overlap_loss, glob_iter)
                writer.add_scalar('nonoverlap loss', average_nonoverlap_loss, glob_iter)

            glob_iter += 1

        # decay the learning rate once per epoch
        scheduler.step()
        # save model
        if ((epoch+1) % 10 == 0 or (epoch+1)==args.max_epoch):
            filename ='epoch' + str(epoch+1).zfill(3) + '_model.pth'
            model_save_path = os.path.join(MODEL_DIR, filename)
            # 'epoch' stores the index of the next epoch to run when resuming
            state = {'model': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch+1, "glob_iter": glob_iter}
            torch.save(state, model_save_path)
    print("##################end training#######################")


if __name__ == "__main__":
    # Script entry point: declare the command-line options, echo the parsed
    # namespace and start training with it.
    print('<==================== setting arguments ===================>\n')

    parser = argparse.ArgumentParser()

    # (flag, type, default) for every supported option
    for flag, arg_type, arg_default in (
        ('--gpu', str, '0'),
        ('--batch_size', int, 4),
        ('--max_epoch', int, 100),
        ('--train_path', str, '/opt/data/private/nl/Data/UDIS-D/training/'),
    ):
        parser.add_argument(flag, type=arg_type, default=arg_default)

    args = parser.parse_args()
    print(args)

    train(args)




================================================
FILE: Warp/Codes/utils/torch_DLT.py
================================================
import torch
import numpy as np
import cv2

# src_p: shape=(bs, 4, 2)
# dst_p: shape=(bs, 4, 2)
#
#                                     | h1 |
#                                     | h2 |                   
#                                     | h3 |
# | x1 y1 1  0  0  0  -x1x2  -y1x2 |  | h4 |  =  | x2 |
# | 0  0  0  x1 y1 1  -x1y2  -y1y2 |  | h5 |     | y2 |
#                                     | h6 |
#                                     | h7 |
#                                     | h8 |

def tensor_DLT(src_p, dst_p):
    """Solve a batch of 4-point homographies with the Direct Linear Transform.

    For every sample the 8x8 linear system ``A h = b`` sketched in the comment
    above this function is assembled from the four point correspondences and
    solved for ``h1..h8``; the ninth entry of the homography is fixed to 1.

    Args:
        src_p: tensor of shape (bs, 4, 2) holding the source points (x1, y1).
        dst_p: tensor of shape (bs, 4, 2) holding the destination points
            (x2, y2); expected on the same device as ``src_p``.

    Returns:
        Tensor of shape (bs, 3, 3) with one homography per sample.
    """
    bs, _, _ = src_p.shape

    # Create the helper on the inputs' device instead of forcing it onto the
    # default CUDA device whenever CUDA is merely available; this keeps the
    # function usable with CPU tensors and with non-default GPUs.
    ones = torch.ones(bs, 4, 1, device=src_p.device)
    xy1 = torch.cat((src_p, ones), 2)      # [bs, 4, 3] rows (x1, y1, 1)
    zeros = torch.zeros_like(xy1)          # inherits device and dtype of xy1

    # Interleave the rows (x1 y1 1 0 0 0) and (0 0 0 x1 y1 1) per point.
    xyu, xyd = torch.cat((xy1, zeros), 2), torch.cat((zeros, xy1), 2)
    M1 = torch.cat((xyu, xyd), 2).reshape(bs, -1, 6)   # [bs, 8, 6]
    # Outer products give the (x1*x2, y1*x2) / (x1*y2, y1*y2) terms.
    M2 = torch.matmul(
        dst_p.reshape(-1, 2, 1),
        src_p.reshape(-1, 1, 2),
    ).reshape(bs, -1, 2)                               # [bs, 8, 2]

    # Ah = b
    A = torch.cat((M1, -M2), 2)            # [bs, 8, 8]
    b = dst_p.reshape(bs, -1, 1)           # [bs, 8, 1]

    # h = A^{-1} b
    Ainv = torch.inverse(A)
    h8 = torch.matmul(Ainv, b).reshape(bs, 8)

    # Append h9 = 1 and reshape to a 3x3 matrix.
    H = torch.cat((h8, ones[:,0,:]), 1).reshape(bs, 3, 3)
    return H

================================================
FILE: Warp/Codes/utils/torch_homo_transform.py
================================================
import torch
import numpy as np


def transformer(U, theta, out_size, **kwargs):
    """Warp a batch of images with per-sample 3x3 homographies.

    A regular grid of normalized output coordinates in [-1, 1] is pushed
    through ``theta`` (with perspective division) and ``U`` is bilinearly
    sampled at the resulting positions.

    Args:
        U: input tensor of shape (num_batch, num_channels, height, width).
        theta: homography parameters, reshaped internally to (-1, 3, 3);
            expected on the same device as ``U``.
        out_size: sequence whose first two entries are the output height
            and width.
        **kwargs: accepted for call compatibility and ignored.

    Returns:
        Tensor of shape (num_batch, num_channels, out_height, out_width).
    """
    # Helper tensors follow U's device. Previously they were pushed to the
    # default CUDA device whenever CUDA was available, which broke CPU inputs
    # on CUDA machines and inputs living on a non-default GPU.
    device = U.device

    def _repeat(x, n_repeats):
        # Repeat each element of x n_repeats times consecutively (int32).
        rep = torch.ones([n_repeats, ]).unsqueeze(0)
        rep = rep.int()
        x = x.int()

        x = torch.matmul(x.reshape([-1,1]), rep)
        return x.reshape([-1])

    def _interpolate(im, x, y, out_size):
        # Bilinearly sample im at the flat normalized coordinates (x, y).

        num_batch, num_channels , height, width = im.size()

        out_height, out_width = out_size[0], out_size[1]

        zero = 0
        max_y = height - 1
        max_x = width - 1

        # map normalized [-1, 1] coordinates to pixel units in [0, size]
        x = (x + 1.0)*(width) / 2.0
        y = (y + 1.0) * (height) / 2.0

        # integer corners surrounding every sampling position
        x0 = torch.floor(x).int()
        x1 = x0 + 1
        y0 = torch.floor(y).int()
        y1 = y0 + 1

        # keep the corners inside the image
        x0 = torch.clamp(x0, zero, max_x)
        x1 = torch.clamp(x1, zero, max_x)
        y0 = torch.clamp(y0, zero, max_y)
        y1 = torch.clamp(y1, zero, max_y)
        dim2 = width            # flat-index stride of one row
        dim1 = width * height   # flat-index stride of one image

        # per-output-pixel offset of the owning batch element in im_flat
        base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width)
        base = base.to(device)
        base_y0 = base + y0 * dim2
        base_y1 = base + y1 * dim2
        idx_a = base_y0 + x0
        idx_b = base_y1 + x0
        idx_c = base_y0 + x1
        idx_d = base_y1 + x1

        # channels last, then flatten to [num_batch*height*width, channels]
        im = im.permute(0,2,3,1)
        im_flat = im.reshape([-1, num_channels]).float()

        idx_a = idx_a.unsqueeze(-1).long()
        idx_a = idx_a.expand(out_height * out_width * num_batch,num_channels)
        Ia = torch.gather(im_flat, 0, idx_a)

        idx_b = idx_b.unsqueeze(-1).long()
        idx_b = idx_b.expand(out_height * out_width * num_batch, num_channels)
        Ib = torch.gather(im_flat, 0, idx_b)

        idx_c = idx_c.unsqueeze(-1).long()
        idx_c = idx_c.expand(out_height * out_width * num_batch, num_channels)
        Ic = torch.gather(im_flat, 0, idx_c)

        idx_d = idx_d.unsqueeze(-1).long()
        idx_d = idx_d.expand(out_height * out_width * num_batch, num_channels)
        Id = torch.gather(im_flat, 0, idx_d)

        x0_f = x0.float()
        x1_f = x1.float()
        y0_f = y0.float()
        y1_f = y1.float()

        # bilinear weights of the four corners
        wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
        wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
        wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
        wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
        output = wa*Ia+wb*Ib+wc*Ic+wd*Id

        return output

    def _meshgrid(height, width):
        # Homogeneous grid [3, height*width] with rows (x, y, 1) in [-1, 1].

        x_t = torch.matmul(torch.ones([height, 1]),
                               torch.transpose(torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 1), 1, 0))
        y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1),
                               torch.ones([1, width]))

        x_t_flat = x_t.reshape((1, -1)).float()
        y_t_flat = y_t.reshape((1, -1)).float()

        ones = torch.ones_like(x_t_flat)
        grid = torch.cat([x_t_flat, y_t_flat, ones], 0)
        return grid.to(device)

    def _transform(theta, input_dim, out_size):
        num_batch, num_channels , height, width = input_dim.size()
        theta = theta.reshape([-1, 3, 3]).float()

        out_height, out_width = out_size[0], out_size[1]
        # replicate the grid for every batch element: [num_batch, 3, h*w]
        grid = _meshgrid(out_height, out_width)
        grid = grid.unsqueeze(0).reshape([1,-1])
        shape = grid.size()
        grid = grid.expand(num_batch,shape[1])
        grid = grid.reshape([num_batch, 3, -1])

        T_g = torch.matmul(theta, grid)
        x_s = T_g[:,0,:]
        y_s = T_g[:,1,:]
        t_s = T_g[:,2,:]

        t_s_flat = t_s.reshape([-1])

        # add 1e-6 wherever |t| < 1e-7 so the perspective division below
        # never divides by (almost) zero
        small = 1e-7
        smallers = 1e-6*(1.0 - torch.ge(torch.abs(t_s_flat), small).float())

        t_s_flat = t_s_flat + smallers
        x_s_flat = x_s.reshape([-1]) / t_s_flat
        y_s_flat = y_s.reshape([-1]) / t_s_flat

        input_transformed = _interpolate( input_dim, x_s_flat, y_s_flat,out_size)

        output = input_transformed.reshape([num_batch, out_height, out_width, num_channels])
        output = output.permute(0,3,1,2)
        return output


    output = _transform(theta, U, out_size)
    return output

================================================
FILE: Warp/Codes/utils/torch_tps_transform.py
================================================
import torch
import numpy as np

# Transforms an image (U) from target (control points) to source (control points).
# All points should be normalized to the range [-1, 1].

def transformer(U, source, target, out_size):
    """Warp a batch of images with a control-point driven spline mapping.

    A mapping is fitted so that each ``source`` control point is sent to the
    matching ``target`` point (radial basis ``d^2 * log(d^2 + 1e-6)`` plus an
    affine part). It is evaluated on a regular output grid in [-1, 1] and
    ``U`` is bilinearly sampled at the resulting positions.

    Args:
        U: input tensor of shape (num_batch, num_channels, height, width).
        source: control points of shape [bn, pn, 2], normalized to [-1, 1].
        target: control points of shape [bn, pn, 2], normalized to [-1, 1].
        out_size: sequence whose first two entries are the output height
            and width.

    All three tensors are expected on the same device.

    Returns:
        Tensor of shape (num_batch, num_channels, out_height, out_width).
    """

    def _repeat(x, n_repeats):
        # Repeat each element of x n_repeats times consecutively (int32).
        rep = torch.ones([n_repeats, ]).unsqueeze(0)
        rep = rep.int()
        x = x.int()

        x = torch.matmul(x.reshape([-1,1]), rep)
        return x.reshape([-1])

    def _interpolate(im, x, y, out_size):
        # Bilinearly sample im at the flat normalized coordinates (x, y).

        num_batch, num_channels , height, width = im.size()

        out_height, out_width = out_size[0], out_size[1]

        zero = 0
        max_y = height - 1
        max_x = width - 1

        # map normalized [-1, 1] coordinates to pixel units in [0, size]
        x = (x + 1.0)*(width) / 2.0
        y = (y + 1.0) * (height) / 2.0

        # integer corners surrounding every sampling position
        x0 = torch.floor(x).int()
        x1 = x0 + 1
        y0 = torch.floor(y).int()
        y1 = y0 + 1

        # keep the corners inside the image
        x0 = torch.clamp(x0, zero, max_x)
        x1 = torch.clamp(x1, zero, max_x)
        y0 = torch.clamp(y0, zero, max_y)
        y1 = torch.clamp(y1, zero, max_y)
        dim2 = width            # flat-index stride of one row
        dim1 = width * height   # flat-index stride of one image

        # Per-output-pixel offset of the owning batch element. It is moved to
        # the device of the coordinates rather than to the default CUDA
        # device, so CPU tensors and non-default GPUs work as well.
        base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width)
        base = base.to(x.device)
        base_y0 = base + y0 * dim2
        base_y1 = base + y1 * dim2
        idx_a = base_y0 + x0
        idx_b = base_y1 + x0
        idx_c = base_y0 + x1
        idx_d = base_y1 + x1

        # channels last, then flatten to [num_batch*height*width, channels]
        im = im.permute(0,2,3,1)
        im_flat = im.reshape([-1, num_channels]).float()

        idx_a = idx_a.unsqueeze(-1).long()
        idx_a = idx_a.expand(out_height * out_width * num_batch,num_channels)
        Ia = torch.gather(im_flat, 0, idx_a)

        idx_b = idx_b.unsqueeze(-1).long()
        idx_b = idx_b.expand(out_height * out_width * num_batch, num_channels)
        Ib = torch.gather(im_flat, 0, idx_b)

        idx_c = idx_c.unsqueeze(-1).long()
        idx_c = idx_c.expand(out_height * out_width * num_batch, num_channels)
        Ic = torch.gather(im_flat, 0, idx_c)

        idx_d = idx_d.unsqueeze(-1).long()
        idx_d = idx_d.expand(out_height * out_width * num_batch, num_channels)
        Id = torch.gather(im_flat, 0, idx_d)

        x0_f = x0.float()
        x1_f = x1.float()
        y0_f = y0.float()
        y1_f = y1.float()

        # bilinear weights of the four corners
        wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
        wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
        wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
        wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
        output = wa*Ia+wb*Ib+wc*Ic+wd*Id

        return output

    def _meshgrid(height, width, source):
        # Basis values (1, x, y, r_1..r_pn) for every output pixel.
        device = source.device

        x_t = torch.matmul(torch.ones([height, 1]), torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 0))
        y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1), torch.ones([1, width]))
        x_t = x_t.to(device)
        y_t = y_t.to(device)

        x_t_flat = x_t.reshape([1, 1, -1])
        y_t_flat = y_t.reshape([1, 1, -1])

        num_batch = source.size()[0]
        px = torch.unsqueeze(source[:,:,0], 2)  # [bn, pn, 1]
        py = torch.unsqueeze(source[:,:,1], 2)  # [bn, pn, 1]
        # squared distance of every pixel to every source control point
        d2 = torch.square(x_t_flat - px) + torch.square(y_t_flat - py)
        r = d2 * torch.log(d2 + 1e-6) # [bn, pn, h*w]
        x_t_flat_g = x_t_flat.expand(num_batch, -1, -1)  # [bn, 1, h*w]
        y_t_flat_g = y_t_flat.expand(num_batch, -1, -1)  # [bn, 1, h*w]
        ones = torch.ones_like(x_t_flat_g) # [bn, 1, h*w]

        grid = torch.cat((ones, x_t_flat_g, y_t_flat_g, r), 1) # [bn, 3+pn, h*w]
        return grid

    def _transform(T, source, input_dim, out_size):
        num_batch, num_channels, height, width = input_dim.size()

        out_height, out_width = out_size[0], out_size[1]
        grid = _meshgrid(out_height, out_width, source) # [bn, 3+pn, h*w]

        # transform A x (1, x_t, y_t, r1, r2, ..., rn) -> (x_s, y_s)
        # [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w]
        T_g = torch.matmul(T, grid)
        x_s = T_g[:,0,:]
        y_s = T_g[:,1,:]
        x_s_flat = x_s.reshape([-1])
        y_s_flat = y_s.reshape([-1])

        input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat,out_size)

        output = input_transformed.reshape([num_batch, out_height, out_width, num_channels])

        output = output.permute(0,3,1,2)
        return output


    def _solve_system(source, target):
        # Fit the mapping coefficients T ([bn, 2, pn+3]) from the control points.
        num_batch  = source.size()[0]
        num_point  = source.size()[1]

        # kept from the original code: sets NumPy's global print precision
        np.set_printoptions(precision=8)

        ones = torch.ones(num_batch, num_point, 1, device=source.device).float()
        p = torch.cat([ones, source], 2) # [bn, pn, 3]

        p_1 = p.reshape([num_batch, -1, 1, 3]) # [bn, pn, 1, 3]
        p_2 = p.reshape([num_batch, 1, -1, 3])  # [bn, 1, pn, 3]
        d2 = torch.sum(torch.square(p_1-p_2), 3) # p1 - p2: [bn, pn, pn, 3]   final output: [bn, pn, pn]

        r = d2 * torch.log(d2 + 1e-6) # [bn, pn, pn]

        zeros = torch.zeros(num_batch, 3, 3, device=source.device).float()
        W_0 = torch.cat((p, r), 2) # [bn, pn, 3+pn]
        W_1 = torch.cat((zeros, p.permute(0,2,1)), 2) # [bn, 3, pn+3]
        W = torch.cat((W_0, W_1), 1) # [bn, pn+3, pn+3]

        # invert in float64; the result is cast back to float32 on return
        W_inv = torch.inverse(W.type(torch.float64))

        zeros2 = torch.zeros(num_batch, 3, 2, device=target.device)
        tp = torch.cat((target, zeros2), 1) # [bn, pn+3, 2]

        T = torch.matmul(W_inv, tp.type(torch.float64)) # [bn, pn+3, 2]
        T = T.permute(0, 2, 1) # [bn, 2, pn+3]

        return T.type(torch.float32)

    T = _solve_system(source, target)

    output = _transform(T, source, U, out_size)

    return output

================================================
FILE: Warp/Codes/utils/torch_tps_transform2.py
================================================
import torch
import numpy as np


# Transforms an image (U) from target (control points) to source (control points).
# All points should be normalized to the range [-1, 1].

# Compared with torch_tps_transform.py, this version moves some operations from GPU to CPU to save GPU memory.

def transformer(U, source, target, out_size):
    """Warp a batch of images with a control-point driven spline mapping.

    A mapping is fitted so that each ``source`` control point is sent to the
    matching ``target`` point (radial basis ``d^2 * log(d^2 + 1e-6)`` plus an
    affine part). It is evaluated on a regular output grid in [-1, 1] and
    ``U`` is bilinearly sampled at the resulting positions.

    The per-pixel basis grid is built on the CPU and only the finished grid is
    moved to the device of ``source``.

    Args:
        U: input tensor of shape (num_batch, num_channels, height, width).
        source: control points of shape [bn, pn, 2], normalized to [-1, 1].
        target: control points of shape [bn, pn, 2], normalized to [-1, 1].
        out_size: sequence whose first two entries are the output height
            and width.

    All three tensors are expected on the same device.

    Returns:
        Tensor of shape (num_batch, num_channels, out_height, out_width).
    """

    def _repeat(x, n_repeats):
        # Repeat each element of x n_repeats times consecutively (int32).
        rep = torch.ones([n_repeats, ]).unsqueeze(0)
        rep = rep.int()
        x = x.int()

        x = torch.matmul(x.reshape([-1,1]), rep)
        return x.reshape([-1])

    def _interpolate(im, x, y, out_size):
        # Bilinearly sample im at the flat normalized coordinates (x, y).

        num_batch, num_channels , height, width = im.size()

        out_height, out_width = out_size[0], out_size[1]

        zero = 0
        max_y = height - 1
        max_x = width - 1

        # map normalized [-1, 1] coordinates to pixel units in [0, size]
        x = (x + 1.0)*(width) / 2.0
        y = (y + 1.0) * (height) / 2.0

        # integer corners surrounding every sampling position
        x0 = torch.floor(x).int()
        x1 = x0 + 1
        y0 = torch.floor(y).int()
        y1 = y0 + 1

        # keep the corners inside the image
        x0 = torch.clamp(x0, zero, max_x)
        x1 = torch.clamp(x1, zero, max_x)
        y0 = torch.clamp(y0, zero, max_y)
        y1 = torch.clamp(y1, zero, max_y)
        dim2 = width            # flat-index stride of one row
        dim1 = width * height   # flat-index stride of one image

        # Per-output-pixel offset of the owning batch element. It is moved to
        # the device of the coordinates rather than to the default CUDA
        # device, so CPU tensors and non-default GPUs work as well.
        base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width)
        base = base.to(x.device)
        base_y0 = base + y0 * dim2
        base_y1 = base + y1 * dim2
        idx_a = base_y0 + x0
        idx_b = base_y1 + x0
        idx_c = base_y0 + x1
        idx_d = base_y1 + x1

        # channels last, then flatten to [num_batch*height*width, channels]
        im = im.permute(0,2,3,1)
        im_flat = im.reshape([-1, num_channels]).float()

        idx_a = idx_a.unsqueeze(-1).long()
        idx_a = idx_a.expand(out_height * out_width * num_batch,num_channels)
        Ia = torch.gather(im_flat, 0, idx_a)

        idx_b = idx_b.unsqueeze(-1).long()
        idx_b = idx_b.expand(out_height * out_width * num_batch, num_channels)
        Ib = torch.gather(im_flat, 0, idx_b)

        idx_c = idx_c.unsqueeze(-1).long()
        idx_c = idx_c.expand(out_height * out_width * num_batch, num_channels)
        Ic = torch.gather(im_flat, 0, idx_c)

        idx_d = idx_d.unsqueeze(-1).long()
        idx_d = idx_d.expand(out_height * out_width * num_batch, num_channels)
        Id = torch.gather(im_flat, 0, idx_d)

        x0_f = x0.float()
        x1_f = x1.float()
        y0_f = y0.float()
        y1_f = y1.float()

        # bilinear weights of the four corners
        wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
        wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
        wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
        wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
        output = wa*Ia+wb*Ib+wc*Ic+wd*Id

        return output

    def _meshgrid(height, width, source):
        # Basis values (1, x, y, r_1..r_pn) for every output pixel, computed
        # on the CPU and moved to the caller's device at the very end.

        # Remember where the control points live before copying them to CPU.
        device = source.device
        source = source.cpu()

        x_t = torch.matmul(torch.ones([height, 1]), torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 0))
        y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1), torch.ones([1, width]))

        x_t_flat = x_t.reshape([1, 1, -1])
        y_t_flat = y_t.reshape([1, 1, -1])

        num_batch = source.size()[0]
        px = torch.unsqueeze(source[:,:,0], 2)  # [bn, pn, 1]
        py = torch.unsqueeze(source[:,:,1], 2)  # [bn, pn, 1]

        # squared distance of every pixel to every source control point
        d2 = torch.square(x_t_flat - px) + torch.square(y_t_flat - py)
        r = d2 * torch.log(d2 + 1e-6) # [bn, pn, h*w]
        x_t_flat_g = x_t_flat.expand(num_batch, -1, -1)  # [bn, 1, h*w]
        y_t_flat_g = y_t_flat.expand(num_batch, -1, -1)  # [bn, 1, h*w]
        ones = torch.ones_like(x_t_flat_g) # [bn, 1, h*w]

        grid = torch.cat((ones, x_t_flat_g, y_t_flat_g, r), 1) # [bn, 3+pn, h*w]

        # The original unconditional grid.cuda() raised on machines without
        # CUDA; returning to the source's device keeps GPU behaviour and
        # makes the CPU path work.
        grid = grid.to(device)
        return grid

    def _transform(T, source, input_dim, out_size):
        num_batch, num_channels, height, width = input_dim.size()

        out_height, out_width = out_size[0], out_size[1]
        grid = _meshgrid(out_height, out_width, source) # [bn, 3+pn, h*w]

        # transform A x (1, x_t, y_t, r1, r2, ..., rn) -> (x_s, y_s)
        # [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w]
        T_g = torch.matmul(T, grid)
        x_s = T_g[:,0,:]
        y_s = T_g[:,1,:]
        x_s_flat = x_s.reshape([-1])
        y_s_flat = y_s.reshape([-1])

        input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat,out_size)

        output = input_transformed.reshape([num_batch, out_height, out_width, num_channels])

        output = output.permute(0,3,1,2)
        return output


    def _solve_system(source, target):
        # Fit the mapping coefficients T ([bn, 2, pn+3]) from the control points.
        num_batch  = source.size()[0]
        num_point  = source.size()[1]

        # kept from the original code: sets NumPy's global print precision
        np.set_printoptions(precision=8)

        ones = torch.ones(num_batch, num_point, 1, device=source.device).float()
        p = torch.cat([ones, source], 2) # [bn, pn, 3]

        p_1 = p.reshape([num_batch, -1, 1, 3]) # [bn, pn, 1, 3]
        p_2 = p.reshape([num_batch, 1, -1, 3])  # [bn, 1, pn, 3]
        d2 = torch.sum(torch.square(p_1-p_2), 3) # p1 - p2: [bn, pn, pn, 3]   final output: [bn, pn, pn]

        r = d2 * torch.log(d2 + 1e-6) # [bn, pn, pn]

        zeros = torch.zeros(num_batch, 3, 3, device=source.device).float()
        W_0 = torch.cat((p, r), 2) # [bn, pn, 3+pn]
        W_1 = torch.cat((zeros, p.permute(0,2,1)), 2) # [bn, 3, pn+3]
        W = torch.cat((W_0, W_1), 1) # [bn, pn+3, pn+3]

        # invert in float64; the result is cast back to float32 on return
        W_inv = torch.inverse(W.type(torch.float64))

        zeros2 = torch.zeros(num_batch, 3, 2, device=target.device)
        tp = torch.cat((target, zeros2), 1) # [bn, pn+3, 2]

        T = torch.matmul(W_inv, tp.type(torch.float64)) # [bn, pn+3, 2]
        T = T.permute(0, 2, 1) # [bn, 2, pn+3]

        return T.type(torch.float32)

    T = _solve_system(source, target)

    output = _transform(T, source, U, out_size)

    return output

================================================
FILE: Warp/model/.txt
================================================



================================================
FILE: Warp/readme.md
================================================
## Train on UDIS-D
Set the training dataset path in Warp/Codes/train.py.

```
python train.py
```

## Test on UDIS-D
The pre-trained model of warp is available at [Google Drive](https://drive.google.com/file/d/1GBwB0y3tUUsOYHErSqxDxoC_Om3BJUEt/view?usp=sharing) or [Baidu Cloud](https://pan.baidu.com/s/1Fx6YnQi9B2wvP_TOVAaBEA)(Extraction code: 1234).
#### Calculate PSNR/SSIM
Set the testing dataset path in Warp/Codes/test.py.

```
python test.py
```

#### Generate the warped images and corresponding masks
Set the training/testing dataset path in Warp/Codes/test_output.py.

```
python test_output.py
```
The warped images and masks will be generated and saved at the original training/testing dataset path. The results of average fusion will be saved at the current path.

## Test on other datasets
When testing on other datasets with different scenes and resolutions, we apply the iterative warp adaption to get better alignment performance.

Set the 'path/img1_name/img2_name' in Warp/Codes/test_other.py. (By default, both img1 and img2 are placed under 'path')
```
python test_other.py
```
The results before/after adaption will be generated and saved at 'path'.



================================================
FILE: Warp/summary/.txt
================================================



================================================
FILE: environment.yml
================================================
name: nl
channels:
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  - defaults
dependencies:
  - _anaconda_depends=2020.07=py38_0
  - _ipyw_jlab_nb_ext_conf=0.1.0=py38_0
  - _libgcc_mutex=0.1=main
  - alabaster=0.7.12=py_0
  - anaconda=custom=py38_1
  - anaconda-client=1.7.2=py38_0
  - anaconda-navigator=1.10.0=py38_0
  - anaconda-project=0.8.4=py_0
  - argh=0.26.2=py38_0
  - argon2-cffi=20.1.0=py38h7b6447c_1
  - asn1crypto=1.4.0=py_0
  - astroid=2.4.2=py38_0
  - astropy=4.0.2=py38h7b6447c_0
  - async_generator=1.10=py_0
  - atomicwrites=1.4.0=py_0
  - attrs=20.3.0=pyhd3eb1b0_0
  - autopep8=1.5.4=py_0
  - babel=2.8.1=pyhd3eb1b0_0
  - backcall=0.2.0=py_0
  - backports=1.0=py_2
  - backports.functools_lru_cache=1.6.1=py_0
  - backports.shutil_get_terminal_size=1.0.0=py38_2
  - backports.tempfile=1.0=py_1
  - backports.weakref=1.0.post1=py_1
  - beautifulsoup4=4.9.3=pyhb0f4dca_0
  - bitarray=1.6.1=py38h27cfd23_0
  - bkcharts=0.2=py38_0
  - blas=1.0=mkl
  - bleach=3.2.1=py_0
  - blosc=1.20.1=hd408876_0
  - bokeh=2.2.3=py38_0
  - boto=2.49.0=py38_0
  - bottleneck=1.3.2=py38heb32a55_1
  - brotlipy=0.7.0=py38h7b6447c_1000
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2021.4.13=h06a4308_1
  - cairo=1.14.12=h8948797_3
  - certifi=2020.12.5=py38h06a4308_0
  - cffi=1.14.3=py38he30daa8_0
  - chardet=3.0.4=py38_1003
  - click=7.1.2=py_0
  - cloudpickle=1.6.0=py_0
  - clyent=1.2.2=py38_1
  - colorama=0.4.4=py_0
  - conda-package-handling=1.7.2=py38h03888b9_0
  - conda-verify=3.4.2=py_1
  - contextlib2=0.6.0.post1=py_0
  - cryptography=3.1.1=py38h1ba5d50_0
  - cudatoolkit=11.0.221=h6bb024c_0
  - curl=7.71.1=hbc83047_1
  - cycler=0.10.0=py38_0
  - cython=0.29.21=py38he6710b0_0
  - cytoolz=0.11.0=py38h7b6447c_0
  - dask=2.30.0=py_0
  - dask-core=2.30.0=py_0
  - dbus=1.13.18=hb2f20db_0
  - decorator=4.4.2=py_0
  - defusedxml=0.6.0=py_0
  - diff-match-patch=20200713=py_0
  - distributed=2.30.1=py38h06a4308_0
  - docutils=0.16=py38_1
  - entrypoints=0.3=py38_0
  - et_xmlfile=1.0.1=py_1001
  - expat=2.2.10=he6710b0_2
  - fastcache=1.1.0=py38h7b6447c_0
  - filelock=3.0.12=py_0
  - flake8=3.8.4=py_0
  - flask=1.1.2=py_0
  - fontconfig=2.13.0=h9420a91_0
  - freetype=2.10.4=h5ab3b9f_0
  - fribidi=1.0.10=h7b6447c_0
  - fsspec=0.8.3=py_0
  - future=0.18.2=py38_1
  - get_terminal_size=1.0.0=haa9412d_0
  - gevent=20.9.0=py38h7b6447c_0
  - glib=2.66.1=h92f7085_0
  - glob2=0.7=py_0
  - gmp=6.1.2=h6c8ec71_1
  - gmpy2=2.0.8=py38hd5f6e3b_3
  - graphite2=1.3.14=h23475e2_0
  - greenlet=0.4.17=py38h7b6447c_0
  - gst-plugins-base=1.14.0=hbbd80ab_1
  - gstreamer=1.14.0=hb31296c_0
  - h5py=2.10.0=py38h7918eee_0
  - harfbuzz=2.4.0=hca77d97_1
  - hdf5=1.10.4=hb1b8bf9_0
  - heapdict=1.0.1=py_0
  - html5lib=1.1=py_0
  - icu=58.2=he6710b0_3
  - idna=2.10=py_0
  - imageio=2.9.0=py_0
  - imagesize=1.2.0=py_0
  - importlib_metadata=2.0.0=1
  - iniconfig=1.1.1=py_0
  - intel-openmp=2020.2=254
  - intervaltree=3.1.0=py_0
  - ipykernel=5.3.4=py38h5ca1d4c_0
  - ipython=7.19.0=py38hb070fc8_0
  - ipython_genutils=0.2.0=py38_0
  - ipywidgets=7.5.1=py_1
  - isort=5.6.4=py_0
  - itsdangerous=1.1.0=py_0
  - jbig=2.1=hdba287a_0
  - jdcal=1.4.1=py_0
  - jedi=0.17.1=py38_0
  - jeepney=0.5.0=pyhd3eb1b0_0
  - jinja2=2.11.2=py_0
  - joblib=0.17.0=py_0
  - jpeg=9b=h024ee3a_2
  - json5=0.9.5=py_0
  - jsonschema=3.2.0=py_2
  - jupyter=1.0.0=py38_7
  - jupyter_client=6.1.7=py_0
  - jupyter_console=6.2.0=py_0
  - jupyter_core=4.6.3=py38_0
  - jupyterlab=2.2.6=py_0
  - jupyterlab_pygments=0.1.2=py_0
  - jupyterlab_server=1.2.0=py_0
  - keyring=21.4.0=py38_1
  - kiwisolver=1.3.0=py38h2531618_0
  - krb5=1.18.2=h173b8e3_0
  - lazy-object-proxy=1.4.3=py38h7b6447c_0
  - lcms2=2.11=h396b838_0
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libarchive=3.4.2=h62408e4_0
  - libcurl=7.71.1=h20c2e04_1
  - libedit=3.1.20191231=h14c3975_1
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - liblief=0.10.1=he6710b0_0
  - libllvm10=10.0.1=hbcb73fb_5
  - libllvm9=9.0.1=h4a3c616_1
  - libpng=1.6.37=hbc83047_0
  - libsodium=1.0.18=h7b6447c_0
  - libspatialindex=1.9.3=he6710b0_0
  - libssh2=1.9.0=h1ba5d50_1
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtiff=4.1.0=h2733197_1
  - libtool=2.4.6=h7b6447c_1005
  - libuuid=1.0.3=h1bed415_2
  - libuv=1.40.0=h7b6447c_0
  - libxcb=1.14=h7b6447c_0
  - libxml2=2.9.10=hb55368b_3
  - libxslt=1.1.34=hc22bd24_0
  - llvmlite=0.34.0=py38h269e1b5_4
  - locket=0.2.0=py38_1
  - lxml=4.6.1=py38hefd8a0e_0
  - lz4-c=1.9.2=heb0550a_3
  - lzo=2.10=h7b6447c_2
  - markupsafe=1.1.1=py38h7b6447c_0
  - matplotlib=3.3.2=0
  - matplotlib-base=3.3.2=py38h817c723_0
  - mccabe=0.6.1=py38_1
  - mistune=0.8.4=py38h7b6447c_1000
  - mkl=2020.2=256
  - mkl-service=2.3.0=py38he904b0f_0
  - mkl_fft=1.2.0=py38h23d657b_0
  - mkl_random=1.1.1=py38h0573a6f_0
  - mock=4.0.2=py_0
  - more-itertools=8.6.0=pyhd3eb1b0_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpmath=1.1.0=py38_0
  - msgpack-python=1.0.0=py38hfd86e86_1
  - multipledispatch=0.6.0=py38_0
  - navigator-updater=0.2.1=py38_0
  - nbclient=0.5.1=py_0
  - nbconvert=6.0.7=py38_0
  - nbformat=5.0.8=py_0
  - ncurses=6.2=he6710b0_1
  - nest-asyncio=1.4.2=pyhd3eb1b0_0
  - networkx=2.5=py_0
  - ninja=1.10.2=hff7bd54_1
  - nltk=3.5=py_0
  - nose=1.3.7=py38_2
  - notebook=6.1.4=py38_0
  - numba=0.51.2=py38h0573a6f_1
  - numexpr=2.7.1=py38h423224d_0
  - numpy-base=1.19.2=py38hfa32c7d_0
  - numpydoc=1.1.0=pyhd3eb1b0_1
  - olefile=0.46=py_0
  - openpyxl=3.0.5=py_0
  - openssl=1.1.1k=h27cfd23_0
  - packaging=20.4=py_0
  - pandas=1.1.3=py38he6710b0_0
  - pandoc=2.11=hb0f4dca_0
  - pandocfilters=1.4.3=py38h06a4308_1
  - pango=1.45.3=hd140c19_0
  - parso=0.7.0=py_0
  - partd=1.1.0=py_0
  - patchelf=0.12=he6710b0_0
  - path=15.0.0=py38_0
  - path.py=12.5.0=0
  - pathlib2=2.3.5=py38_0
  - pathtools=0.1.2=py_1
  - patsy=0.5.1=py38_0
  - pcre=8.44=he6710b0_0
  - pep8=1.7.1=py38_0
  - pexpect=4.8.0=py38_0
  - pickleshare=0.7.5=py38_1000
  - pip=20.2.4=py38h06a4308_0
  - pixman=0.40.0=h7b6447c_0
  - pkginfo=1.6.1=py38h06a4308_0
  - pluggy=0.13.1=py38_0
  - ply=3.11=py38_0
  - prometheus_client=0.8.0=py_0
  - prompt-toolkit=3.0.8=py_0
  - prompt_toolkit=3.0.8=0
  - psutil=5.7.2=py38h7b6447c_0
  - ptyprocess=0.6.0=py38_0
  - py=1.9.0=py_0
  - py-lief=0.10.1=py38h403a769_0
  - pycodestyle=2.6.0=py_0
  - pycosat=0.6.3=py38h7b6447c_1
  - pycparser=2.20=py_2
  - pycurl=7.43.0.6=py38h1ba5d50_0
  - pydocstyle=5.1.1=py_0
  - pyflakes=2.2.0=py_0
  - pygments=2.7.2=pyhd3eb1b0_0
  - pylint=2.6.0=py38_0
  - pyodbc=4.0.30=py38he6710b0_0
  - pyopenssl=19.1.0=py_1
  - pyparsing=2.4.7=py_0
  - pyqt=5.9.2=py38h05f1152_4
  - pyrsistent=0.17.3=py38h7b6447c_0
  - pysocks=1.7.1=py38_0
  - pytables=3.6.1=py38h9fd0a39_0
  - pytest=6.1.1=py38_0
  - python=3.8.5=h7579374_1
  - python-dateutil=2.8.1=py_0
  - python-jsonrpc-server=0.4.0=py_0
  - python-language-server=0.35.1=py_0
  - python-libarchive-c=2.9=py_0
  - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0
  - pytz=2020.1=py_0
  - pywavelets=1.1.1=py38h7b6447c_2
  - pyxdg=0.27=pyhd3eb1b0_0
  - pyyaml=5.3.1=py38h7b6447c_1
  - pyzmq=19.0.2=py38he6710b0_1
  - qdarkstyle=2.8.1=py_0
  - qt=5.9.7=h5867ecd_1
  - qtawesome=1.0.1=py_0
  - qtconsole=4.7.7=py_0
  - qtpy=1.9.0=py_0
  - readline=8.0=h7b6447c_0
  - regex=2020.10.15=py38h7b6447c_0
  - requests=2.24.0=py_0
  - ripgrep=12.1.1=0
  - rope=0.18.0=py_0
  - rtree=0.9.4=py38_1
  - ruamel_yaml=0.15.87=py38h7b6447c_1
  - scikit-learn=0.23.2=py38h0573a6f_0
  - scipy=1.5.2=py38h0b6359f_0
  - seaborn=0.11.0=py_0
  - secretstorage=3.1.2=py38_0
  - send2trash=1.5.0=py38_0
  - setuptools=50.3.1=py38h06a4308_1
  - simplegeneric=0.8.1=py38_2
  - singledispatch=3.4.0.3=py_1001
  - sip=4.19.13=py38he6710b0_0
  - six=1.15.0=py38h06a4308_0
  - snappy=1.1.8=he6710b0_0
  - snowballstemmer=2.0.0=py_0
  - sortedcollections=1.2.1=py_0
  - sortedcontainers=2.2.2=py_0
  - soupsieve=2.0.1=py_0
  - sphinx=3.2.1=py_0
  - sphinxcontrib=1.0=py38_1
  - sphinxcontrib-applehelp=1.0.2=py_0
  - sphinxcontrib-devhelp=1.0.2=py_0
  - sphinxcontrib-htmlhelp=1.0.3=py_0
  - sphinxcontrib-jsmath=1.0.1=py_0
  - sphinxcontrib-qthelp=1.0.3=py_0
  - sphinxcontrib-serializinghtml=1.1.4=py_0
  - sphinxcontrib-websupport=1.2.4=py_0
  - spyder=4.1.5=py38_0
  - spyder-kernels=1.9.4=py38_0
  - sqlalchemy=1.3.20=py38h7b6447c_0
  - sqlite=3.33.0=h62c20be_0
  - statsmodels=0.12.0=py38h7b6447c_0
  - sympy=1.6.2=py38h06a4308_1
  - tbb=2020.3=hfd86e86_0
  - tblib=1.7.0=py_0
  - terminado=0.9.1=py38_0
  - testpath=0.4.4=py_0
  - threadpoolctl=2.1.0=pyh5ca1d4c_0
  - tifffile=2020.10.1=py38hdd07704_2
  - tk=8.6.10=hbc83047_0
  - toml=0.10.1=py_0
  - toolz=0.11.1=py_0
  - torchvision=0.8.2=py38_cu110
  - tornado=6.0.4=py38h7b6447c_1
  - tqdm=4.50.2=py_0
  - traitlets=5.0.5=py_0
  - typing_extensions=3.7.4.3=py_0
  - ujson=4.0.1=py38he6710b0_0
  - unicodecsv=0.14.1=py38_0
  - unixodbc=2.3.9=h7b6447c_0
  - urllib3=1.25.11=py_0
  - watchdog=0.10.3=py38_0
  - wcwidth=0.2.5=py_0
  - webencodings=0.5.1=py38_1
  - werkzeug=1.0.1=py_0
  - wheel=0.35.1=py_0
  - widgetsnbextension=3.5.1=py38_0
  - wrapt=1.11.2=py38h7b6447c_0
  - wurlitzer=2.0.1=py38_0
  - xlrd=1.2.0=py_0
  - xlsxwriter=1.3.7=py_0
  - xlwt=1.3.0=py38_0
  - xmltodict=0.12.0=py_0
  - xz=5.2.5=h7b6447c_0
  - yaml=0.2.5=h7b6447c_0
  - yapf=0.30.0=py_0
  - zeromq=4.3.3=he6710b0_3
  - zict=2.0.0=py_0
  - zipp=3.4.0=pyhd3eb1b0_0
  - zlib=1.2.11=h7b6447c_3
  - zope=1.0=py38_1
  - zope.event=4.5.0=py38_0
  - zope.interface=5.1.2=py38h7b6447c_0
  - zstd=1.4.5=h9ceee32_0
  - pip:
    - absl-py==0.12.0
    - cachetools==5.2.0
    - einops==0.3.0
    - google-auth==2.8.0
    - google-auth-oauthlib==0.4.6
    - grpcio==1.46.3
    - importlib-metadata==4.11.4
    - markdown==3.3.7
    - medpy==0.4.0
    - ml-collections==0.1.0
    - numpy==1.19.5
    - oauthlib==3.2.0
    - opencv-python-headless==4.5.1.48
    - pillow==9.1.1
    - protobuf==3.17.0
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - requests-oauthlib==1.3.1
    - rsa==4.8
    - scikit-image==0.15.0
    - simpleitk==2.0.2
    - tensorboard==2.9.0
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.1
    - timm==0.4.9
    - yacs==0.1.6
prefix: /root/anaconda3/envs/nl
Download .txt
gitextract_hn9kepfw/

├── Composition/
│   ├── Codes/
│   │   ├── dataset.py
│   │   ├── loss.py
│   │   ├── network.py
│   │   ├── test.py
│   │   ├── test_other.py
│   │   └── train.py
│   ├── model/
│   │   └── .txt
│   ├── readme.md
│   └── summary/
│       └── .txt
├── LICENSE
├── README.md
├── Warp/
│   ├── Codes/
│   │   ├── dataset.py
│   │   ├── grid_res.py
│   │   ├── loss.py
│   │   ├── network.py
│   │   ├── test.py
│   │   ├── test_other.py
│   │   ├── test_output.py
│   │   ├── train.py
│   │   └── utils/
│   │       ├── torch_DLT.py
│   │       ├── torch_homo_transform.py
│   │       ├── torch_tps_transform.py
│   │       └── torch_tps_transform2.py
│   ├── model/
│   │   └── .txt
│   ├── readme.md
│   └── summary/
│       └── .txt
└── environment.yml
Download .txt
SYMBOL INDEX (67 symbols across 17 files)

FILE: Composition/Codes/dataset.py
  class TrainDataset (line 10) | class TrainDataset(Dataset):
    method __init__ (line 11) | def __init__(self, data_path):
    method __getitem__ (line 26) | def __getitem__(self, index):
    method __len__ (line 69) | def __len__(self):
  class TestDataset (line 73) | class TestDataset(Dataset):
    method __init__ (line 74) | def __init__(self, data_path):
    method __getitem__ (line 90) | def __getitem__(self, index):
    method __len__ (line 125) | def __len__(self):

FILE: Composition/Codes/loss.py
  function l_num_loss (line 27) | def l_num_loss(img1, img2, l_num=1):
  function boundary_extraction (line 31) | def boundary_extraction(mask):
  function cal_boundary_term (line 66) | def cal_boundary_term(inpu1_tesnor, inpu2_tesnor, mask1_tesnor, mask2_te...
  function cal_smooth_term_stitch (line 76) | def cal_smooth_term_stitch(stitched_image, learned_mask1):
  function cal_smooth_term_diff (line 94) | def cal_smooth_term_diff(img1, img2, learned_mask1, overlap):

FILE: Composition/Codes/network.py
  function build_model (line 7) | def build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_ten...
  class DownBlock (line 22) | class DownBlock(nn.Module):
    method __init__ (line 23) | def __init__(self, inchannels, outchannels, dilation, pool=True):
    method forward (line 41) | def forward(self, x):
  class UpBlock (line 44) | class UpBlock(nn.Module):
    method __init__ (line 45) | def __init__(self, inchannels, outchannels, dilation):
    method forward (line 67) | def forward(self, x1, x2):
  class Network (line 76) | class Network(nn.Module):
    method __init__ (line 77) | def __init__(self, nclasses=1):
    method forward (line 105) | def forward(self, x, y, m1, m2):

FILE: Composition/Codes/test.py
  function test (line 17) | def test(args):

FILE: Composition/Codes/test_other.py
  function loadSingleData (line 15) | def loadSingleData(data_path):
  function test_other (line 50) | def test_other(args):

FILE: Composition/Codes/train.py
  function train (line 32) | def train(args):

FILE: Warp/Codes/dataset.py
  class TrainDataset (line 10) | class TrainDataset(Dataset):
    method __init__ (line 11) | def __init__(self, data_path):
    method __getitem__ (line 28) | def __getitem__(self, index):
    method __len__ (line 57) | def __len__(self):
  class TestDataset (line 61) | class TestDataset(Dataset):
    method __init__ (line 62) | def __init__(self, data_path):
    method __getitem__ (line 79) | def __getitem__(self, index):
    method __len__ (line 101) | def __len__(self):

FILE: Warp/Codes/loss.py
  function l_num_loss (line 10) | def l_num_loss(img1, img2, l_num=1):
  function cal_lp_loss (line 14) | def cal_lp_loss(input1, input2, output_H, output_H_inv, warp_mesh, warp_...
  function cal_lp_loss2 (line 37) | def cal_lp_loss2(input1, warp_mesh, warp_mesh_mask):
  function inter_grid_loss (line 48) | def inter_grid_loss(overlap, mesh):
  function intra_grid_loss (line 84) | def intra_grid_loss(pts):

FILE: Warp/Codes/network.py
  function draw_mesh_on_warp (line 23) | def draw_mesh_on_warp(warp, f_local):
  function H2Mesh (line 50) | def H2Mesh(H, rigid_mesh):
  function get_rigid_mesh (line 69) | def get_rigid_mesh(batch_size, height, width):
  function get_norm_mesh (line 83) | def get_norm_mesh(mesh, height, width):
  function data_aug (line 95) | def data_aug(img1, img2):
  function build_model (line 120) | def build_model(net, input1_tensor, input2_tensor, is_training = True):
  function build_new_ft_model (line 191) | def build_new_ft_model(net, input1_tensor, input2_tensor):
  function get_stitched_result (line 235) | def get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh):
  function build_output_model (line 286) | def build_output_model(net, input1_tensor, input2_tensor):
  class Network (line 366) | class Network(nn.Module):
    method __init__ (line 368) | def __init__(self):
    method get_res50_FeatureMap (line 455) | def get_res50_FeatureMap(self, resnet50_model):
    method forward (line 475) | def forward(self, input1_tesnor, input2_tesnor):
    method extract_patches (line 522) | def extract_patches(self, x, kernel=3, stride=1):
    method CCL (line 530) | def CCL(self, feature_1, feature_2):

FILE: Warp/Codes/test.py
  function create_gif (line 18) | def create_gif(image_list, gif_name, duration=0.35):
  function test (line 26) | def test(args):

FILE: Warp/Codes/test_other.py
  function loadSingleData (line 21) | def loadSingleData(data_path, img1_name, img2_name):
  function train (line 56) | def train(args):

FILE: Warp/Codes/test_output.py
  function draw_mesh_on_warp (line 20) | def draw_mesh_on_warp(warp, f_local):
  function create_gif (line 44) | def create_gif(image_list, gif_name, duration=0.35):
  function test (line 54) | def test(args):

FILE: Warp/Codes/train.py
  function train (line 28) | def train(args):

FILE: Warp/Codes/utils/torch_DLT.py
  function tensor_DLT (line 17) | def tensor_DLT(src_p, dst_p):

FILE: Warp/Codes/utils/torch_homo_transform.py
  function transformer (line 5) | def transformer(U, theta, out_size, **kwargs):

FILE: Warp/Codes/utils/torch_tps_transform.py
  function transformer (line 7) | def transformer(U, source, target, out_size):

FILE: Warp/Codes/utils/torch_tps_transform2.py
  function transformer (line 10) | def transformer(U, source, target, out_size):
Condensed preview — 27 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (133K chars).
[
  {
    "path": "Composition/Codes/dataset.py",
    "chars": 4488,
    "preview": "from torch.utils.data import Dataset\r\nimport  numpy as np\r\nimport cv2, torch\r\nimport os\r\nimport glob\r\nfrom collections i"
  },
  {
    "path": "Composition/Codes/loss.py",
    "chars": 3800,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\n\n# def get_vgg19_FeatureMap(vgg_model, input_255, l"
  },
  {
    "path": "Composition/Codes/network.py",
    "chars": 4116,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\n\ndef build_model(net, warp1_tensor, warp2_tensor, m"
  },
  {
    "path": "Composition/Codes/test.py",
    "chars": 3524,
    "preview": "# coding: utf-8\nimport argparse\nimport torch\nfrom torch.utils.data import DataLoader\nfrom network import build_model, Ne"
  },
  {
    "path": "Composition/Codes/test_other.py",
    "chars": 3992,
    "preview": "# coding: utf-8\nimport argparse\nimport torch\nfrom network import build_model, Network\nimport os\nimport numpy as np\nimpor"
  },
  {
    "path": "Composition/Codes/train.py",
    "chars": 7660,
    "preview": "import argparse\r\nimport torch\r\nfrom torch.utils.data import DataLoader\r\nimport os\r\nimport torch.optim as optim\r\nfrom tor"
  },
  {
    "path": "Composition/model/.txt",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "Composition/readme.md",
    "chars": 817,
    "preview": "## Train on UDIS-D\nBefore training, the warped images and corresponding masks should be generated in the warp stage.\n\nTh"
  },
  {
    "path": "Composition/summary/.txt",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 2536,
    "preview": "# <p align=\"center\">Parallax-Tolerant Unsupervised Deep Image Stitching (UDIS++ [paper](https://arxiv.org/abs/2302.08207"
  },
  {
    "path": "Warp/Codes/dataset.py",
    "chars": 3597,
    "preview": "from torch.utils.data import Dataset\r\nimport  numpy as np\r\nimport cv2, torch\r\nimport os\r\nimport glob\r\nfrom collections i"
  },
  {
    "path": "Warp/Codes/grid_res.py",
    "chars": 81,
    "preview": "\n#define control point resolution (GRID_H+1) * (GRID_W+1)\nGRID_H = 12\nGRID_W = 12"
  },
  {
    "path": "Warp/Codes/loss.py",
    "chars": 4164,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport grid_res\ngrid_h = grid_res.GRID_H\ngrid_w = gr"
  },
  {
    "path": "Warp/Codes/network.py",
    "chars": 22382,
    "preview": "import torch\nimport torch.nn as nn\nimport utils.torch_DLT as torch_DLT\nimport utils.torch_homo_transform as torch_homo_t"
  },
  {
    "path": "Warp/Codes/test.py",
    "chars": 4141,
    "preview": "# coding: utf-8\r\nimport argparse\r\nimport torch\r\nfrom torch.utils.data import DataLoader\r\nimport torch.nn as nn\r\nimport i"
  },
  {
    "path": "Warp/Codes/test_other.py",
    "chars": 6954,
    "preview": "import argparse\nimport torch\n\nimport numpy as np\nimport os\nimport torch.nn as nn\nimport torch.optim as optim\n\nimport cv2"
  },
  {
    "path": "Warp/Codes/test_output.py",
    "chars": 5868,
    "preview": "# coding: utf-8\nimport argparse\nimport torch\nfrom torch.utils.data import DataLoader\nimport torch.nn as nn\nimport imagei"
  },
  {
    "path": "Warp/Codes/train.py",
    "chars": 6760,
    "preview": "import argparse\r\nimport torch\r\nfrom torch.utils.data import DataLoader\r\nimport os\r\nimport torch.optim as optim\r\nfrom tor"
  },
  {
    "path": "Warp/Codes/utils/torch_DLT.py",
    "chars": 1280,
    "preview": "import torch\nimport numpy as np\nimport cv2\n\n# src_p: shape=(bs, 4, 2)\n# det_p: shape=(bs, 4, 2)\n#\n#                     "
  },
  {
    "path": "Warp/Codes/utils/torch_homo_transform.py",
    "chars": 4997,
    "preview": "import torch\nimport numpy as np\n\n\ndef transformer(U, theta, out_size, **kwargs):\n\n\n    def _repeat(x, n_repeats):\n\n     "
  },
  {
    "path": "Warp/Codes/utils/torch_tps_transform.py",
    "chars": 6386,
    "preview": "import torch\nimport numpy as np\n\n# transforming an image (U) from target (control points) to source (control points)\n# a"
  },
  {
    "path": "Warp/Codes/utils/torch_tps_transform2.py",
    "chars": 6339,
    "preview": "import torch\nimport numpy as np\n\n\n# transforming an image (U) from target (control points) to source (control points)\n# "
  },
  {
    "path": "Warp/model/.txt",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "Warp/readme.md",
    "chars": 1173,
    "preview": "## Train on UDIS-D\nSet the training dataset path in Warp/Codes/train.py.\n\n```\npython train.py\n```\n\n## Test on UDIS-D\nThe"
  },
  {
    "path": "Warp/summary/.txt",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "environment.yml",
    "chars": 10467,
    "preview": "name: nl\nchannels:\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch\n  - https://mirrors.tuna.tsinghua.edu"
  }
]

About this extraction

This page contains the full source code of the nie-lang/UDIS2 GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 27 files (123.9 KB), approximately 38.7k tokens, and a symbol index with 67 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!