[
  {
    "path": "Composition/Codes/dataset.py",
    "content": "from torch.utils.data import Dataset\r\nimport  numpy as np\r\nimport cv2, torch\r\nimport os\r\nimport glob\r\nfrom collections import OrderedDict\r\nimport random\r\n\r\n\r\nclass TrainDataset(Dataset):\r\n    def __init__(self, data_path):\r\n\r\n        self.train_path = data_path\r\n        self.datas = OrderedDict()\r\n\r\n        datas = glob.glob(os.path.join(self.train_path, '*'))\r\n        for data in sorted(datas):\r\n            data_name = data.split('/')[-1]\r\n            if data_name == 'warp1' or data_name == 'warp2' or data_name == 'mask1' or data_name == 'mask2':\r\n                self.datas[data_name] = {}\r\n                self.datas[data_name]['path'] = data\r\n                self.datas[data_name]['image'] = glob.glob(os.path.join(data, '*.jpg'))\r\n                self.datas[data_name]['image'].sort()\r\n        print(self.datas.keys())\r\n\r\n    def __getitem__(self, index):\r\n\r\n        # load image1\r\n        warp1 = cv2.imread(self.datas['warp1']['image'][index])\r\n        warp1 = warp1.astype(dtype=np.float32)\r\n        warp1 = (warp1 / 127.5) - 1.0\r\n        warp1 = np.transpose(warp1, [2, 0, 1])\r\n\r\n        # load image2\r\n        warp2 = cv2.imread(self.datas['warp2']['image'][index])\r\n        warp2 = warp2.astype(dtype=np.float32)\r\n        warp2 = (warp2 / 127.5) - 1.0\r\n        warp2 = np.transpose(warp2, [2, 0, 1])\r\n\r\n        # load mask1\r\n        mask1 = cv2.imread(self.datas['mask1']['image'][index])\r\n        mask1 = mask1.astype(dtype=np.float32)\r\n        mask1 = np.expand_dims(mask1[:,:,0], 2) / 255\r\n        mask1 = np.transpose(mask1, [2, 0, 1])\r\n\r\n        # load mask2\r\n        mask2 = cv2.imread(self.datas['mask2']['image'][index])\r\n        mask2 = mask2.astype(dtype=np.float32)\r\n        mask2 = np.expand_dims(mask2[:,:,0], 2) / 255\r\n        mask2 = np.transpose(mask2, [2, 0, 1])\r\n\r\n        # convert to tensor\r\n        warp1_tensor = torch.tensor(warp1)\r\n        warp2_tensor = torch.tensor(warp2)\r\n        mask1_tensor = torch.tensor(mask1)\r\n        mask2_tensor = torch.tensor(mask2)\r\n\r\n        #return (input1_tensor, input2_tensor, mask1_tensor, mask2_tensor)\r\n\r\n        if_exchange = random.randint(0,1)\r\n        if if_exchange == 0:\r\n            #print(if_exchange)\r\n            return (warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)\r\n        else:\r\n            #print(if_exchange)\r\n            return (warp2_tensor, warp1_tensor, mask2_tensor, mask1_tensor)\r\n\r\n\r\n    def __len__(self):\r\n\r\n        return len(self.datas['warp1']['image'])\r\n\r\nclass TestDataset(Dataset):\r\n    def __init__(self, data_path):\r\n\r\n        self.test_path = data_path\r\n        self.datas = OrderedDict()\r\n\r\n        datas = glob.glob(os.path.join(self.test_path, '*'))\r\n        for data in sorted(datas):\r\n            data_name = data.split('/')[-1]\r\n            if data_name == 'warp1' or data_name == 'warp2' or data_name == 'mask1' or data_name == 'mask2':\r\n                self.datas[data_name] = {}\r\n                self.datas[data_name]['path'] = data\r\n                self.datas[data_name]['image'] = glob.glob(os.path.join(data, '*.jpg'))\r\n                self.datas[data_name]['image'].sort()\r\n\r\n        print(self.datas.keys())\r\n\r\n    def __getitem__(self, index):\r\n\r\n\r\n                # load image1\r\n        warp1 = cv2.imread(self.datas['warp1']['image'][index])\r\n        warp1 = warp1.astype(dtype=np.float32)\r\n        warp1 = 
(warp1 / 127.5) - 1.0\r\n        warp1 = np.transpose(warp1, [2, 0, 1])\r\n\r\n        # load image2\r\n        warp2 = cv2.imread(self.datas['warp2']['image'][index])\r\n        warp2 = warp2.astype(dtype=np.float32)\r\n        warp2 = (warp2 / 127.5) - 1.0\r\n        warp2 = np.transpose(warp2, [2, 0, 1])\r\n\r\n        # load mask1\r\n        mask1 = cv2.imread(self.datas['mask1']['image'][index])\r\n        mask1 = mask1.astype(dtype=np.float32)\r\n        mask1 = np.expand_dims(mask1[:,:,0], 2) / 255\r\n        mask1 = np.transpose(mask1, [2, 0, 1])\r\n\r\n        # load mask2\r\n        mask2 = cv2.imread(self.datas['mask2']['image'][index])\r\n        mask2 = mask2.astype(dtype=np.float32)\r\n        mask2 = np.expand_dims(mask2[:,:,0], 2) / 255\r\n        mask2 = np.transpose(mask2, [2, 0, 1])\r\n\r\n        # convert to tensor\r\n        warp1_tensor = torch.tensor(warp1)\r\n        warp2_tensor = torch.tensor(warp2)\r\n        mask1_tensor = torch.tensor(mask1)\r\n        mask2_tensor = torch.tensor(mask2)\r\n\r\n        return (warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)\r\n\r\n    def __len__(self):\r\n\r\n        return len(self.datas['warp1']['image'])\r\n\r\n\r\n"
  },
  {
    "path": "Composition/Codes/loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\n\n# def get_vgg19_FeatureMap(vgg_model, input_255, layer_index):\n\n#     vgg_mean = torch.tensor([123.6800, 116.7790, 103.9390]).reshape((1,3,1,1))\n#     if torch.cuda.is_available():\n#         vgg_mean = vgg_mean.cuda()\n#     vgg_input = input_255-vgg_mean\n#     #x = vgg_model.features[0](vgg_input)\n#     #FeatureMap_list.append(x)\n\n\n#     for i in range(0,layer_index+1):\n#         if i == 0:\n#             x = vgg_model.features[0](vgg_input)\n#         else:\n#             x = vgg_model.features[i](x)\n\n#     return x\n\n\n\ndef l_num_loss(img1, img2, l_num=1):\n    return torch.mean(torch.abs((img1 - img2)**l_num))\n\n\ndef boundary_extraction(mask):\n\n    ones = torch.ones_like(mask)\n    zeros = torch.zeros_like(mask)\n    #define kernel\n    in_channel = 1\n    out_channel = 1\n    kernel = [[1, 1, 1],\n               [1, 1, 1],\n               [1, 1, 1]]\n    kernel = torch.FloatTensor(kernel).expand(out_channel,in_channel,3,3)\n    if torch.cuda.is_available():\n        kernel = kernel.cuda()\n        ones = ones.cuda()\n        zeros = zeros.cuda()\n    weight = nn.Parameter(data=kernel, requires_grad=False)\n\n    #dilation\n    x = F.conv2d(1-mask,weight,stride=1,padding=1)\n    x = torch.where(x < 1, zeros, ones)\n    x = F.conv2d(x,weight,stride=1,padding=1)\n    x = torch.where(x < 1, zeros, ones)\n    x = F.conv2d(x,weight,stride=1,padding=1)\n    x = torch.where(x < 1, zeros, ones)\n    x = F.conv2d(x,weight,stride=1,padding=1)\n    x = torch.where(x < 1, zeros, ones)\n    x = F.conv2d(x,weight,stride=1,padding=1)\n    x = torch.where(x < 1, zeros, ones)\n    x = F.conv2d(x,weight,stride=1,padding=1)\n    x = torch.where(x < 1, zeros, ones)\n    x = F.conv2d(x,weight,stride=1,padding=1)\n    x = torch.where(x < 1, zeros, ones)\n\n    return x*mask\n\ndef cal_boundary_term(inpu1_tesnor, inpu2_tesnor, mask1_tesnor, mask2_tesnor, stitched_image):\n    boundary_mask1 = mask1_tesnor * boundary_extraction(mask2_tesnor)\n    boundary_mask2 = mask2_tesnor * boundary_extraction(mask1_tesnor)\n\n    loss1 = l_num_loss(inpu1_tesnor*boundary_mask1, stitched_image*boundary_mask1, 1)\n    loss2 = l_num_loss(inpu2_tesnor*boundary_mask2, stitched_image*boundary_mask2, 1)\n\n    return loss1+loss2, boundary_mask1\n\n\ndef cal_smooth_term_stitch(stitched_image, learned_mask1):\n\n\n    delta = 1\n    dh_mask = torch.abs(learned_mask1[:,:,0:-1*delta,:] - learned_mask1[:,:,delta:,:])\n    dw_mask = torch.abs(learned_mask1[:,:,:,0:-1*delta] - learned_mask1[:,:,:,delta:])\n    dh_diff_img = torch.abs(stitched_image[:,:,0:-1*delta,:] - stitched_image[:,:,delta:,:])\n    dw_diff_img = torch.abs(stitched_image[:,:,:,0:-1*delta] - stitched_image[:,:,:,delta:])\n\n    dh_pixel = dh_mask * dh_diff_img\n    dw_pixel = dw_mask * dw_diff_img\n\n    loss = torch.mean(dh_pixel) + torch.mean(dw_pixel)\n\n    return loss\n\n\n\ndef cal_smooth_term_diff(img1, img2, learned_mask1, overlap):\n\n    diff_feature = torch.abs(img1-img2)**2 * overlap\n\n    delta = 1\n    dh_mask = torch.abs(learned_mask1[:,:,0:-1*delta,:] - learned_mask1[:,:,delta:,:])\n    dw_mask = torch.abs(learned_mask1[:,:,:,0:-1*delta] - learned_mask1[:,:,:,delta:])\n    dh_diff_img = torch.abs(diff_feature[:,:,0:-1*delta,:] + diff_feature[:,:,delta:,:])\n    dw_diff_img = torch.abs(diff_feature[:,:,:,0:-1*delta] + diff_feature[:,:,:,delta:])\n\n    dh_pixel = dh_mask * dh_diff_img\n    dw_pixel = dw_mask * dw_diff_img\n\n    loss = 
torch.mean(dh_pixel) + torch.mean(dw_pixel)\n\n    return loss\n\n    # dh_zeros = torch.zeros_like(dh_pixel)\n    # dw_zeros = torch.zeros_like(dw_pixel)\n    # if torch.cuda.is_available():\n    #     dh_zeros = dh_zeros.cuda()\n    #     dw_zeros = dw_zeros.cuda()\n\n\n    # loss = l_num_loss(dh_pixel, dh_zeros, 1) + l_num_loss(dw_pixel, dw_zeros, 1)\n\n\n    # return  loss, dh_pixel"
  },
  {
    "path": "Composition/Codes/network.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\n\ndef build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor):\n\n    out  = net(warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)\n\n    learned_mask1 = (mask1_tensor - mask1_tensor*mask2_tensor) + mask1_tensor*mask2_tensor*out\n    learned_mask2 = (mask2_tensor - mask1_tensor*mask2_tensor) + mask1_tensor*mask2_tensor*(1-out)\n    stitched_image = (warp1_tensor+1.) * learned_mask1 + (warp2_tensor+1.)*learned_mask2 - 1.\n\n    out_dict = {}\n    out_dict.update(learned_mask1=learned_mask1, learned_mask2=learned_mask2, stitched_image = stitched_image)\n\n\n    return out_dict\n\n\nclass DownBlock(nn.Module):\n    def __init__(self, inchannels, outchannels, dilation, pool=True):\n        super(DownBlock, self).__init__()\n        blk = []\n        if pool:\n            blk.append(nn.MaxPool2d(kernel_size=2, stride=2))\n        blk.append(nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1, dilation = dilation))\n        blk.append(nn.ReLU(inplace=True))\n        blk.append(nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1, dilation = dilation))\n        blk.append(nn.ReLU(inplace=True))\n        self.layer = nn.Sequential(*blk)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight)\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def forward(self, x):\n        return self.layer(x)\n\nclass UpBlock(nn.Module):\n    def __init__(self, inchannels, outchannels, dilation):\n        super(UpBlock, self).__init__()\n        #self.convt = nn.ConvTranspose2d(inchannels, outchannels, kernel_size=2, stride=2)\n        self.halfChanelConv = nn.Sequential(\n            nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True)\n            )\n\n        self.conv = nn.Sequential(\n            nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1, dilation = dilation),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1, dilation = dilation),\n            nn.ReLU(inplace=True)\n        )\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight)\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def forward(self, x1, x2):\n\n        x1 = F.interpolate(x1, size = (x2.size()[2], x2.size()[3]), mode='nearest')\n        x1 = self.halfChanelConv(x1)\n        x = torch.cat([x2, x1], dim=1)\n        x = self.conv(x)\n        return x\n\n# predict the composition mask of img1\nclass Network(nn.Module):\n    def __init__(self, nclasses=1):\n        super(Network, self).__init__()\n\n\n        self.down1 = DownBlock(3, 32, 1, pool=False)\n        self.down2 = DownBlock(32, 64, 2)\n        self.down3 = DownBlock(64, 128,3)\n        self.down4 = DownBlock(128, 256, 4)\n        self.down5 = DownBlock(256, 512, 5)\n        self.up1 = UpBlock(512, 256, 4)\n        self.up2 = UpBlock(256, 128, 3)\n        self.up3 = UpBlock(128, 64, 2)\n        self.up4 = UpBlock(64, 32, 1)\n\n\n        self.out = nn.Sequential(\n            nn.Conv2d(32, nclasses, kernel_size=1),\n            nn.Sigmoid()\n        )\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n           
     nn.init.kaiming_normal_(m.weight)\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n\n    def forward(self, x, y, m1, m2):\n\n\n        x1 = self.down1(x)\n        x2 = self.down2(x1)\n        x3 = self.down3(x2)\n        x4 = self.down4(x3)\n        x5 = self.down5(x4)\n\n        y1 = self.down1(y)\n        y2 = self.down2(y1)\n        y3 = self.down3(y2)\n        y4 = self.down4(y3)\n        y5 = self.down5(y4)\n\n        res = self.up1(x5-y5, x4-y4)\n        res = self.up2(res, x3-y3)\n        res = self.up3(res, x2-y2)\n        res = self.up4(res, x1-y1)\n        res = self.out(res)\n\n        return res\n\n\n"
  },
  {
    "path": "Composition/Codes/test.py",
    "content": "# coding: utf-8\nimport argparse\nimport torch\nfrom torch.utils.data import DataLoader\nfrom network import build_model, Network\nfrom dataset import *\nimport os\nimport numpy as np\nimport cv2\n\n\nlast_path = os.path.abspath(os.path.join(os.path.dirname(\"__file__\"), os.path.pardir))\nMODEL_DIR = os.path.join(last_path, 'model')\n\n\n\ndef test(args):\n\n    os.environ['CUDA_DEVICES_ORDER'] = \"PCI_BUS_ID\"\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\n\n    # dataset\n    test_data = TestDataset(data_path=args.test_path)\n    test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, num_workers=1, shuffle=False, drop_last=False)\n\n    # define the network\n    net = Network()\n    if torch.cuda.is_available():\n        net = net.cuda()\n\n    #load the existing models if it exists\n    ckpt_list = glob.glob(MODEL_DIR + \"/*.pth\")\n    ckpt_list.sort()\n    if len(ckpt_list) != 0:\n        model_path = ckpt_list[-1]\n        checkpoint = torch.load(model_path)\n        net.load_state_dict(checkpoint['model'])\n        print('load model from {}!'.format(model_path))\n    else:\n        print('No checkpoint found!')\n        return\n\n\n    path_learn_mask1 = '../learn_mask1/'\n    if not os.path.exists(path_learn_mask1):\n        os.makedirs(path_learn_mask1)\n    path_learn_mask2 = '../learn_mask2/'\n    if not os.path.exists(path_learn_mask2):\n        os.makedirs(path_learn_mask2)\n    path_final_composition = '../composition/'\n    if not os.path.exists(path_final_composition):\n        os.makedirs(path_final_composition)\n\n\n    print(\"##################start testing#######################\")\n    net.eval()\n    for i, batch_value in enumerate(test_loader):\n\n        warp1_tensor = batch_value[0].float()\n        warp2_tensor = batch_value[1].float()\n        mask1_tensor = batch_value[2].float()\n        mask2_tensor = batch_value[3].float()\n\n        if torch.cuda.is_available():\n            warp1_tensor = warp1_tensor.cuda()\n            warp2_tensor = warp2_tensor.cuda()\n            mask1_tensor = mask1_tensor.cuda()\n            mask2_tensor = mask2_tensor.cuda()\n\n        # if inpu1_tesnor.size()[2]*inpu1_tesnor.size()[3] > 1200000:\n        #     print(\"oversize\")\n        #     continue\n\n        with torch.no_grad():\n            batch_out = build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)\n\n        stitched_image = batch_out['stitched_image']\n        learned_mask1 = batch_out['learned_mask1']\n        learned_mask2 = batch_out['learned_mask2']\n\n        stitched_image = ((stitched_image[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)\n        learned_mask1 = (learned_mask1[0]*255).cpu().detach().numpy().transpose(1,2,0)\n        learned_mask2 = (learned_mask2[0]*255).cpu().detach().numpy().transpose(1,2,0)\n\n        path = path_learn_mask1 + str(i+1).zfill(6) + \".jpg\"\n        cv2.imwrite(path, learned_mask1)\n        path = path_learn_mask2 + str(i+1).zfill(6) + \".jpg\"\n        cv2.imwrite(path, learned_mask2)\n        path = path_final_composition + str(i+1).zfill(6) + \".jpg\"\n        cv2.imwrite(path, stitched_image)\n\n\n        print('i = {}'.format( i+1))\n\n\n\nif __name__==\"__main__\":\n\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--gpu', type=str, default='0')\n    parser.add_argument('--batch_size', type=int, default=1)\n    parser.add_argument('--test_path', type=str, default='/opt/data/private/nl/Data/UDIS-D/testing/')\n\n    
print('<==================== Loading data ===================>\\n')\n\n    args = parser.parse_args()\n    print(args)\n\n    test(args)"
  },
  {
    "path": "Composition/Codes/test_other.py",
    "content": "# coding: utf-8\nimport argparse\nimport torch\nfrom network import build_model, Network\nimport os\nimport numpy as np\nimport cv2\nimport glob\n\n\n\nlast_path = os.path.abspath(os.path.join(os.path.dirname(\"__file__\"), os.path.pardir))\nMODEL_DIR = os.path.join(last_path, 'model')\n\ndef loadSingleData(data_path):\n\n    # load image1\n    warp1 = cv2.imread(data_path+\"warp1.jpg\")\n    warp1 = warp1.astype(dtype=np.float32)\n    warp1 = (warp1 / 127.5) - 1.0\n    warp1 = np.transpose(warp1, [2, 0, 1])\n\n    # load image2\n    warp2 = cv2.imread(data_path+\"warp2.jpg\")\n    warp2 = warp2.astype(dtype=np.float32)\n    warp2 = (warp2 / 127.5) - 1.0\n    warp2 = np.transpose(warp2, [2, 0, 1])\n\n    # load mask1\n    mask1 = cv2.imread(data_path+\"mask1.jpg\")\n    mask1 = mask1.astype(dtype=np.float32)\n    mask1 = mask1 / 255\n    mask1 = np.transpose(mask1, [2, 0, 1])\n\n    # load mask2\n    mask2 = cv2.imread(data_path+\"mask2.jpg\")\n    mask2 = mask2.astype(dtype=np.float32)\n    mask2 = mask2 / 255\n    mask2 = np.transpose(mask2, [2, 0, 1])\n\n    # convert to tensor\n    warp1_tensor = torch.tensor(warp1).unsqueeze(0)\n    warp2_tensor = torch.tensor(warp2).unsqueeze(0)\n    mask1_tensor = torch.tensor(mask1).unsqueeze(0)\n    mask2_tensor = torch.tensor(mask2).unsqueeze(0)\n\n    return warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor\n\n\ndef test_other(args):\n\n    os.environ['CUDA_DEVICES_ORDER'] = \"PCI_BUS_ID\"\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\n\n    # define the network\n    net = Network()\n    if torch.cuda.is_available():\n        net = net.cuda()\n\n    #load the existing models if it exists\n    ckpt_list = glob.glob(MODEL_DIR + \"/*.pth\")\n    ckpt_list.sort()\n    if len(ckpt_list) != 0:\n        model_path = ckpt_list[-1]\n        checkpoint = torch.load(model_path)\n        net.load_state_dict(checkpoint['model'])\n        print('load model from {}!'.format(model_path))\n    else:\n        print('No checkpoint found!')\n        return\n\n\n    # load dataset(only one pair of images)\n    warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor = loadSingleData(data_path=args.path)\n    if torch.cuda.is_available():\n        warp1_tensor = warp1_tensor.cuda()\n        warp2_tensor = warp2_tensor.cuda()\n        mask1_tensor = mask1_tensor.cuda()\n        mask2_tensor = mask2_tensor.cuda()\n\n    net.eval()\n    with torch.no_grad():\n        batch_out = build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)\n    stitched_image = batch_out['stitched_image']\n    learned_mask1 = batch_out['learned_mask1']\n    learned_mask2 = batch_out['learned_mask2']\n\n    # (optional) draw composition images with different colors like our paper\n    s1 = ((warp1_tensor[0]+1)*127.5 * learned_mask1[0]).cpu().detach().numpy().transpose(1,2,0)\n    s2 = ((warp2_tensor[0]+1)*127.5 * learned_mask2[0]).cpu().detach().numpy().transpose(1,2,0)\n    fusion = np.zeros((warp1_tensor.shape[2],warp1_tensor.shape[3],3), np.uint8)\n    fusion[...,0] = s2[...,0]\n    fusion[...,1] = s1[...,1]*0.5 +  s2[...,1]*0.5\n    fusion[...,2] = s1[...,2]\n    path = args.path + \"composition_color.jpg\"\n    cv2.imwrite(path, fusion)\n\n\n    # save learned masks and final composition\n    stitched_image = ((stitched_image[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)\n    learned_mask1 = (learned_mask1[0]*255).cpu().detach().numpy().transpose(1,2,0)\n    learned_mask2 = (learned_mask2[0]*255).cpu().detach().numpy().transpose(1,2,0)\n\n   
 path = args.path + \"learn_mask1.jpg\"\n    cv2.imwrite(path, learned_mask1)\n    path = args.path + \"learn_mask2.jpg\"\n    cv2.imwrite(path, learned_mask2)\n    path = args.path + \"composition.jpg\"\n    cv2.imwrite(path, stitched_image)\n\n\n\nif __name__==\"__main__\":\n\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--gpu', type=str, default='0')\n    parser.add_argument('--path', type=str, default='../../Carpark-DHW/')\n\n    print('<==================== Loading data ===================>\\n')\n\n    args = parser.parse_args()\n    print(args)\n\n    test_other(args)"
  },
  {
    "path": "Composition/Codes/train.py",
    "content": "import argparse\r\nimport torch\r\nfrom torch.utils.data import DataLoader\r\nimport os\r\nimport torch.optim as optim\r\nfrom torch.utils.tensorboard import SummaryWriter\r\nfrom network import build_model, Network\r\nfrom dataset import TrainDataset\r\nimport glob\r\nfrom loss import cal_boundary_term, cal_smooth_term_stitch, cal_smooth_term_diff\r\n\r\n\r\n\r\n# path of project\r\nlast_path = os.path.abspath(os.path.join(os.path.dirname(\"__file__\"), os.path.pardir))\r\n\r\n# path to save the summary files\r\nSUMMARY_DIR = os.path.join(last_path, 'summary')\r\nwriter = SummaryWriter(log_dir=SUMMARY_DIR)\r\n\r\n# path to save the model files\r\nMODEL_DIR = os.path.join(last_path, 'model')\r\n\r\n# create folders if it dose not exist\r\nif not os.path.exists(MODEL_DIR):\r\n    os.makedirs(MODEL_DIR)\r\nif not os.path.exists(SUMMARY_DIR):\r\n    os.makedirs(SUMMARY_DIR)\r\n\r\n\r\n\r\ndef train(args):\r\n\r\n    os.environ['CUDA_DEVICES_ORDER'] = \"PCI_BUS_ID\"\r\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\r\n\r\n    # dataset\r\n    train_data = TrainDataset(data_path=args.train_path)\r\n    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=4, shuffle=True, drop_last=True)\r\n\r\n    # define the network\r\n    net = Network()\r\n\r\n    if torch.cuda.is_available():\r\n        net = net.cuda()\r\n\r\n\r\n    # define the optimizer and learning rate\r\n    optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08)  # default as 0.0001\r\n    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)\r\n\r\n    #load the existing models if it exists\r\n    ckpt_list = glob.glob(MODEL_DIR + \"/*.pth\")\r\n    ckpt_list.sort()\r\n    if len(ckpt_list) != 0:\r\n        model_path = ckpt_list[-1]\r\n        checkpoint = torch.load(model_path)\r\n\r\n        net.load_state_dict(checkpoint['model'])\r\n        optimizer.load_state_dict(checkpoint['optimizer'])\r\n        start_epoch = checkpoint['epoch']\r\n        glob_iter = checkpoint['glob_iter']\r\n        scheduler.last_epoch = start_epoch\r\n        print('load model from {}!'.format(model_path))\r\n    else:\r\n        start_epoch = 0\r\n        glob_iter = 0\r\n        print('training from stratch!')\r\n\r\n\r\n\r\n    print(\"##################start training#######################\")\r\n    score_print_fre = 300\r\n\r\n    for epoch in range(start_epoch, args.max_epoch):\r\n\r\n        print(\"start epoch {}\".format(epoch))\r\n        net.train()\r\n        sigma_total_loss = 0.\r\n        sigma_boundary_loss = 0.\r\n        sigma_smooth1_loss = 0.\r\n        sigma_smooth2_loss = 0.\r\n\r\n        print(epoch, 'lr={:.6f}'.format(optimizer.state_dict()['param_groups'][0]['lr']))\r\n\r\n        for i, batch_value in enumerate(train_loader):\r\n\r\n            warp1_tensor = batch_value[0].float()\r\n            warp2_tensor = batch_value[1].float()\r\n            mask1_tensor = batch_value[2].float()\r\n            mask2_tensor = batch_value[3].float()\r\n\r\n            if torch.cuda.is_available():\r\n                warp1_tensor = warp1_tensor.cuda()\r\n                warp2_tensor = warp2_tensor.cuda()\r\n                mask1_tensor = mask1_tensor.cuda()\r\n                mask2_tensor = mask2_tensor.cuda()\r\n\r\n\r\n            # forward, backward, update weights\r\n            optimizer.zero_grad()\r\n\r\n            batch_out = build_model(net,  warp1_tensor,  warp2_tensor, mask1_tensor, mask2_tensor)\r\n\r\n            learned_mask1 = 
batch_out['learned_mask1']\r\n            learned_mask2 = batch_out['learned_mask2']\r\n            stitched_image = batch_out['stitched_image']\r\n\r\n            # boundary term\r\n            boundary_loss, boundary_mask1 = cal_boundary_term( warp1_tensor,  warp2_tensor, mask1_tensor, mask2_tensor, stitched_image)\r\n            boundary_loss = 10000 * boundary_loss\r\n\r\n            #  smooth term\r\n            # on stitched image\r\n            smooth1_loss = cal_smooth_term_stitch(stitched_image, learned_mask1)\r\n            smooth1_loss = 1000* smooth1_loss\r\n            # on different image\r\n            smooth2_loss = cal_smooth_term_diff( warp1_tensor,  warp2_tensor, learned_mask1, mask1_tensor*mask2_tensor)\r\n            smooth2_loss = 1000 * smooth2_loss\r\n\r\n\r\n            total_loss = boundary_loss + smooth1_loss + smooth2_loss\r\n            total_loss.backward()\r\n            # clip the gradient\r\n            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=3, norm_type=2)\r\n            optimizer.step()\r\n\r\n\r\n            sigma_boundary_loss += boundary_loss.item()\r\n            sigma_smooth1_loss += smooth1_loss.item()\r\n            sigma_smooth2_loss += smooth2_loss.item()\r\n            sigma_total_loss += total_loss.item()\r\n\r\n            print(glob_iter)\r\n            # print loss etc.\r\n            if i % score_print_fre == 0 and i != 0:\r\n                average_total_loss = sigma_total_loss / score_print_fre\r\n                average_boundary_loss = sigma_boundary_loss/ score_print_fre\r\n                average_smooth1_loss = sigma_smooth1_loss/ score_print_fre\r\n                average_smooth2_loss = sigma_smooth2_loss/ score_print_fre\r\n\r\n                sigma_total_loss = 0.\r\n                sigma_boundary_loss = 0.\r\n                sigma_smooth1_loss = 0.\r\n                sigma_smooth2_loss = 0.\r\n\r\n                print(\"Training: Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}]/[{:0>3}] Total Loss: {:.4f}   boundary loss: {:.4f}  smooth loss: {:.4f}  diff loss: {:.4f}   lr={:.8f}\".format(epoch + 1, args.max_epoch, i + 1, len(train_loader), average_total_loss, average_boundary_loss, average_smooth1_loss, average_smooth2_loss, optimizer.state_dict()['param_groups'][0]['lr']))\r\n\r\n                # visualization\r\n                writer.add_image(\"inpu1\", (warp1_tensor[0]+1.)/2., glob_iter)\r\n                writer.add_image(\"inpu2\", (warp2_tensor[0]+1.)/2., glob_iter)\r\n\r\n\r\n                writer.add_image(\"stitched_image\", (stitched_image[0]+1.)/2., glob_iter)\r\n                writer.add_image(\"learned_mask1\", learned_mask1[0], glob_iter)\r\n                writer.add_image(\"boundary_mask1\", boundary_mask1[0], glob_iter)\r\n\r\n\r\n                writer.add_scalar('lr', optimizer.state_dict()['param_groups'][0]['lr'], glob_iter)\r\n                writer.add_scalar('total loss', average_total_loss, glob_iter)\r\n                writer.add_scalar('average_boundary_loss', average_boundary_loss, glob_iter)\r\n                writer.add_scalar('average_smooth1_loss', average_smooth1_loss, glob_iter)\r\n                writer.add_scalar('average_smooth2_loss', average_smooth2_loss, glob_iter)\r\n\r\n            glob_iter += 1\r\n\r\n\r\n        scheduler.step()\r\n        # save model\r\n        if ((epoch+1) % 10 == 0 or (epoch+1)==args.max_epoch):\r\n            filename ='epoch' + str(epoch+1).zfill(3) + '_model.pth'\r\n            model_save_path = os.path.join(MODEL_DIR, filename)\r\n            
state = {'model': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch+1, \"glob_iter\": glob_iter}\r\n            torch.save(state, model_save_path)\r\n    print(\"##################end training#######################\")\r\n\r\n\r\nif __name__==\"__main__\":\r\n\r\n\r\n    print('<==================== setting arguments ===================>\\n')\r\n\r\n    #nl: create the argument parser\r\n    parser = argparse.ArgumentParser()\r\n\r\n    #nl: add arguments\r\n    parser.add_argument('--gpu', type=str, default='0')\r\n    parser.add_argument('--batch_size', type=int, default=1)\r\n    parser.add_argument('--max_epoch', type=int, default=50)\r\n    parser.add_argument('--train_path', type=str, default='/opt/data/private/nl/Data/UDIS-D/training')\r\n\r\n    #nl: parse the arguments\r\n    args = parser.parse_args()\r\n    print(args)\r\n\r\n    print('<==================== jump into training function ===================>\\n')\r\n    #nl: rain\r\n    train(args)\r\n\r\n\r\n"
  },
  {
    "path": "Composition/model/.txt",
    "content": "\n"
  },
  {
    "path": "Composition/readme.md",
    "content": "## Train on UDIS-D\nBefore training, the warped images and corresponding masks should be generated in the warp stage.\n\nThen, set the training dataset path in Composition/Codes/train.py.\n\n```\npython train_H.py\n```\n\n## Test on UDIS-D\nThe pre-trained model of warp is available at [Google Drive](https://drive.google.com/file/d/1OaG0ayEwRPhKVV_OwQwvwHDFHC26iv30/view?usp=sharing) or [Baidu Cloud](https://pan.baidu.com/s/1qCGegzvxtzri6GiG7mNw6g)(Extraction code: 1234).\n\nSet the testing dataset path in Composition/Codes/test.py.\n\n```\npython test.py\n```\nThe composition masks and final fusion results on UDIS-D will be generated and saved at the current path.\n\n\n## Test on other datasets\nSet the 'path/' in Composition/Codes/test_other.py. \n```\npython test_other.py\n```\nThe results will be generated and saved at 'path'.\n"
  },
  {
    "path": "Composition/summary/.txt",
    "content": "\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# <p align=\"center\">Parallax-Tolerant Unsupervised Deep Image Stitching (UDIS++ [paper](https://arxiv.org/abs/2302.08207))</p>\n<p align=\"center\">Lang Nie*, Chunyu Lin*, Kang Liao*, Shuaicheng Liu`, Yao Zhao*</p>\n<p align=\"center\">* Institute of Information Science, Beijing Jiaotong University</p>\n<p align=\"center\">` School of Information and Communication Engineering, University of Electronic Science and Technology of China</p>\n\n![image](https://github.com/nie-lang/UDIS2/blob/main/fig1.png)\n\n## Dataset (UDIS-D)\nWe use the UDIS-D dataset to train and evaluate our method. Please refer to [UDIS](https://github.com/nie-lang/UnsupervisedDeepImageStitching) for more details about this dataset.\n\n\n## Code\n#### Requirement\n* numpy 1.19.5\n* pytorch 1.7.1\n* scikit-image 0.15.0\n* tensorboard 2.9.0\n\nWe implement this work with Ubuntu, 3090Ti, and CUDA11. Refer to [environment.yml](https://github.com/nie-lang/UDIS2/blob/main/environment.yml) for more details.\n\n#### How to run it\nSimilar to UDIS, we also implement this solution in two stages:\n* Stage 1 (unsupervised warp): please refer to  [Warp/readme.md](https://github.com/nie-lang/UDIS2/blob/main/Warp/readme.md).\n* Stage 2 (unsupervised composition): please refer to [Composition/readme.md](https://github.com/nie-lang/UDIS2/blob/main/Composition/readme.md).\n\n\n\n## Meta\nIf you have any questions about this project, please feel free to drop me an email.\n\nNIE Lang -- nielang@bjtu.edu.cn\n```\n@inproceedings{nie2023parallax,\n  title={Parallax-Tolerant Unsupervised Deep Image Stitching},\n  author={Nie, Lang and Lin, Chunyu and Liao, Kang and Liu, Shuaicheng and Zhao, Yao},\n  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},\n  pages={7399--7408},\n  year={2023}\n}\n```\n\n\n## References\n[1] L. Nie, C. Lin, K. Liao, M. Liu, and Y. Zhao, “A view-free image stitching network based on global homography,” Journal of Visual Communication and Image Representation, p. 102950, 2020.  \n[2] L. Nie, C. Lin, K. Liao, and Y. Zhao. Learning edge-preserved image stitching from multi-scale deep homography[J]. Neurocomputing, 2022, 491: 533-543.   \n[3] L. Nie, C. Lin, K. Liao, S. Liu, and Y. Zhao. Unsupervised deep image stitching: Reconstructing stitched features to images[J]. IEEE Transactions on Image Processing, 2021, 30: 6184-6197.   \n[4] L. Nie, C. Lin, K. Liao, S. Liu, and Y. Zhao. Deep rectangling for image stitching: a learning baseline[C]//Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022: 5740-5748.   \n"
  },
  {
    "path": "Warp/Codes/dataset.py",
    "content": "from torch.utils.data import Dataset\r\nimport  numpy as np\r\nimport cv2, torch\r\nimport os\r\nimport glob\r\nfrom collections import OrderedDict\r\nimport random\r\n\r\n\r\nclass TrainDataset(Dataset):\r\n    def __init__(self, data_path):\r\n\r\n        self.width = 512\r\n        self.height = 512\r\n        self.train_path = data_path\r\n        self.datas = OrderedDict()\r\n        \r\n        datas = glob.glob(os.path.join(self.train_path, '*'))\r\n        for data in sorted(datas):\r\n            data_name = data.split('/')[-1]\r\n            if data_name == 'input1' or data_name == 'input2' :\r\n                self.datas[data_name] = {}\r\n                self.datas[data_name]['path'] = data\r\n                self.datas[data_name]['image'] = glob.glob(os.path.join(data, '*.jpg'))\r\n                self.datas[data_name]['image'].sort()\r\n        print(self.datas.keys())\r\n\r\n    def __getitem__(self, index):\r\n        \r\n        # load image1\r\n        input1 = cv2.imread(self.datas['input1']['image'][index])\r\n        input1 = cv2.resize(input1, (self.width, self.height))\r\n        input1 = input1.astype(dtype=np.float32)\r\n        input1 = (input1 / 127.5) - 1.0\r\n        input1 = np.transpose(input1, [2, 0, 1])\r\n        \r\n        # load image2\r\n        input2 = cv2.imread(self.datas['input2']['image'][index])\r\n        input2 = cv2.resize(input2, (self.width, self.height))\r\n        input2 = input2.astype(dtype=np.float32)\r\n        input2 = (input2 / 127.5) - 1.0\r\n        input2 = np.transpose(input2, [2, 0, 1])\r\n        \r\n        # convert to tensor\r\n        input1_tensor = torch.tensor(input1)\r\n        input2_tensor = torch.tensor(input2)\r\n        \r\n        #print(\"fasdf\")\r\n        if_exchange = random.randint(0,1)\r\n        if if_exchange == 0:\r\n            #print(if_exchange)\r\n            return (input1_tensor, input2_tensor)\r\n        else:\r\n            #print(if_exchange)\r\n            return (input2_tensor, input1_tensor)\r\n\r\n    def __len__(self):\r\n\r\n        return len(self.datas['input1']['image'])\r\n\r\nclass TestDataset(Dataset):\r\n    def __init__(self, data_path):\r\n\r\n        self.width = 512\r\n        self.height = 512\r\n        self.test_path = data_path\r\n        self.datas = OrderedDict()\r\n        \r\n        datas = glob.glob(os.path.join(self.test_path, '*'))\r\n        for data in sorted(datas):\r\n            data_name = data.split('/')[-1]\r\n            if data_name == 'input1' or data_name == 'input2' :\r\n                self.datas[data_name] = {}\r\n                self.datas[data_name]['path'] = data\r\n                self.datas[data_name]['image'] = glob.glob(os.path.join(data, '*.jpg'))\r\n                self.datas[data_name]['image'].sort()\r\n        print(self.datas.keys())\r\n\r\n    def __getitem__(self, index):\r\n        \r\n        # load image1\r\n        input1 = cv2.imread(self.datas['input1']['image'][index])\r\n        #input1 = cv2.resize(input1, (self.width, self.height))\r\n        input1 = input1.astype(dtype=np.float32)\r\n        input1 = (input1 / 127.5) - 1.0\r\n        input1 = np.transpose(input1, [2, 0, 1])\r\n        \r\n        # load image2\r\n        input2 = cv2.imread(self.datas['input2']['image'][index])\r\n        #input2 = cv2.resize(input2, (self.width, self.height))\r\n        input2 = input2.astype(dtype=np.float32)\r\n        input2 = (input2 / 127.5) - 1.0\r\n        input2 = np.transpose(input2, [2, 0, 1])\r\n        \r\n       
 # convert to tensor\r\n        input1_tensor = torch.tensor(input1)\r\n        input2_tensor = torch.tensor(input2)\r\n\r\n        return (input1_tensor, input2_tensor)\r\n\r\n    def __len__(self):\r\n\r\n        return len(self.datas['input1']['image'])\r\n\r\n\r\n\r\n"
  },
  {
    "path": "Warp/Codes/grid_res.py",
    "content": "\n#define control point resolution (GRID_H+1) * (GRID_W+1)\nGRID_H = 12\nGRID_W = 12"
  },
  {
    "path": "Warp/Codes/loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport grid_res\ngrid_h = grid_res.GRID_H\ngrid_w = grid_res.GRID_W\n\n\ndef l_num_loss(img1, img2, l_num=1):\n    return torch.mean(torch.abs((img1 - img2)**l_num))\n\n\ndef cal_lp_loss(input1, input2, output_H, output_H_inv, warp_mesh, warp_mesh_mask):\n    batch_size, _, img_h, img_w = input1.size()\n\n    # part one: sym homo loss with color balance\n    delta1 = ( torch.sum(output_H[:,0:3,:,:], [2,3])  -   torch.sum(input1*output_H[:,3:6,:,:], [2,3]) ) /  torch.sum(output_H[:,3:6,:,:], [2,3])\n    input1_balance = input1 + delta1.unsqueeze(2).unsqueeze(3).expand(-1, -1, img_h, img_w)\n\n    delta2 = ( torch.sum(output_H_inv[:,0:3,:,:], [2,3])  -   torch.sum(input2*output_H_inv[:,3:6,:,:], [2,3]) ) /  torch.sum(output_H_inv[:,3:6,:,:], [2,3])\n    input2_balance = input2 + delta2.unsqueeze(2).unsqueeze(3).expand(-1, -1, img_h, img_w)\n\n    lp_loss_1 = l_num_loss(input1_balance*output_H[:,3:6,:,:], output_H[:,0:3,:,:], 1) + l_num_loss(input2_balance*output_H_inv[:,3:6,:,:], output_H_inv[:,0:3,:,:], 1)\n\n    # part two: tps loss with color balance\n    delta3 = ( torch.sum(warp_mesh, [2,3])  -   torch.sum(input1*warp_mesh_mask, [2,3]) ) /  torch.sum(warp_mesh_mask, [2,3])\n    input1_newbalance = input1 + delta3.unsqueeze(2).unsqueeze(3).expand(-1, -1, img_h, img_w)\n\n    lp_loss_2 = l_num_loss(input1_newbalance*warp_mesh_mask, warp_mesh, 1)\n\n\n    lp_loss = 3. * lp_loss_1 + 1. * lp_loss_2\n\n    return lp_loss\n\ndef cal_lp_loss2(input1, warp_mesh, warp_mesh_mask):\n    batch_size, _, img_h, img_w = input1.size()\n\n    delta3 = ( torch.sum(warp_mesh, [2,3])  -   torch.sum(input1*warp_mesh_mask, [2,3]) ) /  torch.sum(warp_mesh_mask, [2,3])\n    input1_newbalance = input1 + delta3.unsqueeze(2).unsqueeze(3).expand(-1, -1, img_h, img_w)\n\n    lp_loss_2 = l_num_loss(input1_newbalance*warp_mesh_mask, warp_mesh, 1)\n    lp_loss =  1. 
* lp_loss_2\n\n    return lp_loss\n\ndef inter_grid_loss(overlap, mesh):\n\n    ##############################\n    # compute horizontal edges\n    w_edges = mesh[:,:,0:grid_w,:] - mesh[:,:,1:grid_w+1,:]\n    # compute angles of two successive horizontal edges\n    cos_w = torch.sum(w_edges[:,:,0:grid_w-1,:] * w_edges[:,:,1:grid_w,:],3) / (torch.sqrt(torch.sum(w_edges[:,:,0:grid_w-1,:]*w_edges[:,:,0:grid_w-1,:],3))*torch.sqrt(torch.sum(w_edges[:,:,1:grid_w,:]*w_edges[:,:,1:grid_w,:],3)))\n    # horizontal angle-preserving error for two successive horizontal edges\n    delta_w_angle = 1 - cos_w\n    # horizontal angle-preserving error for two successive horizontal grids\n    delta_w_angle = delta_w_angle[:,0:grid_h,:] + delta_w_angle[:,1:grid_h+1,:]\n    ##############################\n\n    ##############################\n    # compute vertical edges\n    h_edges = mesh[:,0:grid_h,:,:] - mesh[:,1:grid_h+1,:,:]\n    # compute angles of two successive vertical edges\n    cos_h = torch.sum(h_edges[:,0:grid_h-1,:,:] * h_edges[:,1:grid_h,:,:],3) / (torch.sqrt(torch.sum(h_edges[:,0:grid_h-1,:,:]*h_edges[:,0:grid_h-1,:,:],3))*torch.sqrt(torch.sum(h_edges[:,1:grid_h,:,:]*h_edges[:,1:grid_h,:,:],3)))\n    # vertical angle-preserving error for two successive vertical edges\n    delta_h_angle = 1 - cos_h\n    # vertical angle-preserving error for two successive vertical grids\n    delta_h_angle = delta_h_angle[:,:,0:grid_w] + delta_h_angle[:,:,1:grid_w+1]\n    ##############################\n\n    # on overlapping regions\n    depth_diff_w = (1-torch.abs(overlap[:,:,0:grid_w-1] - overlap[:,:,1:grid_w])) * overlap[:,:,0:grid_w-1]\n    error_w = depth_diff_w * delta_w_angle\n    # on overlapping regions\n    depth_diff_h = (1-torch.abs(overlap[:,0:grid_h-1,:] - overlap[:,1:grid_h,:])) * overlap[:,0:grid_h-1,:]\n    error_h = depth_diff_h * delta_h_angle\n\n    return torch.mean(error_w) + torch.mean(error_h)\n\n\n\n# intra-grid constraint\ndef intra_grid_loss(pts):\n\n    max_w = 512/grid_w * 2\n    max_h = 512/grid_h * 2\n\n    delta_x = pts[:,:,1:grid_w+1,0] - pts[:,:,0:grid_w,0]\n    delta_y = pts[:,1:grid_h+1,:,1] - pts[:,0:grid_h,:,1]\n\n    loss_x = F.relu(delta_x - max_w)\n    loss_y = F.relu(delta_y - max_h)\n    loss = torch.mean(loss_x) + torch.mean(loss_y)\n\n\n    return loss\n\n\n\n"
  },
  {
    "path": "Warp/Codes/network.py",
    "content": "import torch\nimport torch.nn as nn\nimport utils.torch_DLT as torch_DLT\nimport utils.torch_homo_transform as torch_homo_transform\nimport utils.torch_tps_transform as torch_tps_transform\nimport ssl\nimport torch.nn.functional as F\nimport cv2\nimport numpy as np\nimport torchvision.models as models\n\nimport torchvision.transforms as T\nresize_512 = T.Resize((512,512))\n\nimport grid_res\ngrid_h = grid_res.GRID_H\ngrid_w = grid_res.GRID_W\n\n\n# draw mesh on image\n# warp: h*w*3\n# f_local: grid_h*grid_w*2\ndef draw_mesh_on_warp(warp, f_local):\n\n    warp = np.ascontiguousarray(warp)\n\n    point_color = (0, 255, 0) # BGR\n    thickness = 2\n    lineType = 8\n\n    num = 1\n    for i in range(grid_h+1):\n        for j in range(grid_w+1):\n\n            num = num + 1\n            if j == grid_w and i == grid_h:\n                continue\n            elif j == grid_w:\n                cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i+1,j,0]), int(f_local[i+1,j,1])), point_color, thickness, lineType)\n            elif i == grid_h:\n                cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i,j+1,0]), int(f_local[i,j+1,1])), point_color, thickness, lineType)\n            else :\n                cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i+1,j,0]), int(f_local[i+1,j,1])), point_color, thickness, lineType)\n                cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i,j+1,0]), int(f_local[i,j+1,1])), point_color, thickness, lineType)\n\n    return warp\n\n\n#Covert global homo into mesh\ndef H2Mesh(H, rigid_mesh):\n\n    H_inv = torch.inverse(H)\n    ori_pt = rigid_mesh.reshape(rigid_mesh.size()[0], -1, 2)\n    ones = torch.ones(rigid_mesh.size()[0], (grid_h+1)*(grid_w+1),1)\n    if torch.cuda.is_available():\n        ori_pt = ori_pt.cuda()\n        ones = ones.cuda()\n\n    ori_pt = torch.cat((ori_pt, ones), 2) # bs*(grid_h+1)*(grid_w+1)*3\n    tar_pt = torch.matmul(H_inv, ori_pt.permute(0,2,1)) # bs*3*(grid_h+1)*(grid_w+1)\n\n    mesh_x = torch.unsqueeze(tar_pt[:,0,:]/tar_pt[:,2,:], 2)\n    mesh_y = torch.unsqueeze(tar_pt[:,1,:]/tar_pt[:,2,:], 2)\n    mesh = torch.cat((mesh_x, mesh_y), 2).reshape([rigid_mesh.size()[0], grid_h+1, grid_w+1, 2])\n\n    return mesh\n\n# get rigid mesh\ndef get_rigid_mesh(batch_size, height, width):\n\n    ww = torch.matmul(torch.ones([grid_h+1, 1]), torch.unsqueeze(torch.linspace(0., float(width), grid_w+1), 0))\n    hh = torch.matmul(torch.unsqueeze(torch.linspace(0.0, float(height), grid_h+1), 1), torch.ones([1, grid_w+1]))\n    if torch.cuda.is_available():\n        ww = ww.cuda()\n        hh = hh.cuda()\n\n    ori_pt = torch.cat((ww.unsqueeze(2), hh.unsqueeze(2)),2) # (grid_h+1)*(grid_w+1)*2\n    ori_pt = ori_pt.unsqueeze(0).expand(batch_size, -1, -1, -1)\n\n    return ori_pt\n\n# normalize mesh from -1 ~ 1\ndef get_norm_mesh(mesh, height, width):\n    batch_size = mesh.size()[0]\n    mesh_w = mesh[...,0]*2./float(width) - 1.\n    mesh_h = mesh[...,1]*2./float(height) - 1.\n    norm_mesh = torch.stack([mesh_w, mesh_h], 3) # bs*(grid_h+1)*(grid_w+1)*2\n\n    return norm_mesh.reshape([batch_size, -1, 2]) # bs*-1*2\n\n\n\n# random augmentation\n# it seems to do nothing to the performance\ndef data_aug(img1, img2):\n    # Randomly shift brightness\n    random_brightness = torch.randn(1).uniform_(0.7,1.3).cuda()\n    img1_aug = img1 * random_brightness\n    random_brightness = torch.randn(1).uniform_(0.7,1.3).cuda()\n    img2_aug = img2 
* random_brightness\n\n    # Randomly shift color\n    white = torch.ones([img1.size()[0], img1.size()[2], img1.size()[3]]).cuda()\n    random_colors = torch.randn(3).uniform_(0.7,1.3).cuda()\n    color_image = torch.stack([white * random_colors[i] for i in range(3)], axis=1)\n    img1_aug  *= color_image\n\n    random_colors = torch.randn(3).uniform_(0.7,1.3).cuda()\n    color_image = torch.stack([white * random_colors[i] for i in range(3)], axis=1)\n    img2_aug  *= color_image\n\n    # clip\n    img1_aug = torch.clamp(img1_aug, -1, 1)\n    img2_aug = torch.clamp(img2_aug, -1, 1)\n\n    return img1_aug, img2_aug\n\n\n# for train.py / test.py\ndef build_model(net, input1_tensor, input2_tensor, is_training = True):\n    batch_size, _, img_h, img_w = input1_tensor.size()\n\n    # network\n    if is_training == True:\n        aug_input1_tensor, aug_input2_tensor = data_aug(input1_tensor, input2_tensor)\n        H_motion, mesh_motion = net(aug_input1_tensor, aug_input2_tensor)\n    else:\n        H_motion, mesh_motion = net(input1_tensor, input2_tensor)\n\n    H_motion = H_motion.reshape(-1, 4, 2)\n    mesh_motion = mesh_motion.reshape(-1, grid_h+1, grid_w+1, 2)\n\n    # initialize the source points bs x 4 x 2\n    src_p = torch.tensor([[0., 0.], [img_w, 0.], [0., img_h], [img_w, img_h]])\n    if torch.cuda.is_available():\n        src_p = src_p.cuda()\n    src_p = src_p.unsqueeze(0).expand(batch_size, -1, -1)\n    # target points\n    dst_p = src_p + H_motion\n    # solve homo using DLT\n    H = torch_DLT.tensor_DLT(src_p, dst_p)\n\n    M_tensor = torch.tensor([[img_w / 2.0, 0., img_w / 2.0],\n                      [0., img_h / 2.0, img_h / 2.0],\n                      [0., 0., 1.]])\n\n    if torch.cuda.is_available():\n        M_tensor = M_tensor.cuda()\n\n    M_tile = M_tensor.unsqueeze(0).expand(batch_size, -1, -1)\n    M_tensor_inv = torch.inverse(M_tensor)\n    M_tile_inv = M_tensor_inv.unsqueeze(0).expand(batch_size, -1, -1)\n    H_mat = torch.matmul(torch.matmul(M_tile_inv, H), M_tile)\n\n    mask = torch.ones_like(input2_tensor)\n    if torch.cuda.is_available():\n        mask = mask.cuda()\n    output_H = torch_homo_transform.transformer(torch.cat((input2_tensor, mask), 1), H_mat, (img_h, img_w))\n\n    H_inv_mat = torch.matmul(torch.matmul(M_tile_inv, torch.inverse(H)), M_tile)\n    output_H_inv = torch_homo_transform.transformer(torch.cat((input1_tensor, mask), 1), H_inv_mat, (img_h, img_w))\n\n    rigid_mesh = get_rigid_mesh(batch_size, img_h, img_w)\n    ini_mesh = H2Mesh(H, rigid_mesh)\n    mesh = ini_mesh + mesh_motion\n\n\n    norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)\n    norm_mesh = get_norm_mesh(mesh, img_h, img_w)\n\n    output_tps = torch_tps_transform.transformer(torch.cat((input2_tensor, mask), 1), norm_mesh, norm_rigid_mesh, (img_h, img_w))\n    warp_mesh = output_tps[:,0:3,...]\n    warp_mesh_mask = output_tps[:,3:6,...]\n\n    # calculate the overlapping regions to apply shape-preserving constraints\n    overlap = torch_tps_transform.transformer(warp_mesh_mask, norm_rigid_mesh, norm_mesh, (img_h, img_w))\n    overlap = overlap.permute(0, 2, 3, 1).unfold(1, int(img_h/grid_h), int(img_h/grid_h)).unfold(2, int(img_w/grid_w), int(img_w/grid_w))\n    overlap = torch.mean(overlap.reshape(batch_size, grid_h, grid_w, -1), 3)\n    overlap_one = torch.ones_like(overlap)\n    overlap_zero = torch.zeros_like(overlap)\n    overlap = torch.where(overlap<0.9, overlap_one, overlap_zero)\n\n\n    out_dict = {}\n    out_dict.update(output_H=output_H, output_H_inv = 
output_H_inv, warp_mesh = warp_mesh, warp_mesh_mask = warp_mesh_mask, mesh1 = rigid_mesh, mesh2 = mesh, overlap = overlap)\n\n\n    return out_dict\n\n# for train_ft.py\ndef build_new_ft_model(net, input1_tensor, input2_tensor):\n    batch_size, _, img_h, img_w = input1_tensor.size()\n\n    H_motion, mesh_motion = net(input1_tensor, input2_tensor)\n\n    H_motion = H_motion.reshape(-1, 4, 2)\n    #H_motion = torch.stack([H_motion[...,0]*img_w/512, H_motion[...,1]*img_h/512], 2)\n\n    mesh_motion = mesh_motion.reshape(-1, grid_h+1, grid_w+1, 2)\n    #mesh_motion = torch.stack([mesh_motion[...,0]*img_w/512, mesh_motion[...,1]*img_h/512], 3)\n\n    # initialize the source points bs x 4 x 2\n    src_p = torch.tensor([[0., 0.], [img_w, 0.], [0., img_h], [img_w, img_h]])\n    if torch.cuda.is_available():\n        src_p = src_p.cuda()\n    src_p = src_p.unsqueeze(0).expand(batch_size, -1, -1)\n    # target points\n    dst_p = src_p + H_motion\n    # solve homo using DLT\n    H = torch_DLT.tensor_DLT(src_p, dst_p)\n\n\n    rigid_mesh = get_rigid_mesh(batch_size, img_h, img_w)\n    ini_mesh = H2Mesh(H, rigid_mesh)\n    mesh = ini_mesh + mesh_motion\n\n    norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)\n    norm_mesh = get_norm_mesh(mesh, img_h, img_w)\n\n    mask = torch.ones_like(input2_tensor)\n    if torch.cuda.is_available():\n        mask = mask.cuda()\n    output_tps = torch_tps_transform.transformer(torch.cat((input2_tensor, mask), 1), norm_mesh, norm_rigid_mesh, (img_h, img_w))\n    warp_mesh = output_tps[:,0:3,...]\n    warp_mesh_mask = output_tps[:,3:6,...]\n\n\n    out_dict = {}\n    out_dict.update(warp_mesh = warp_mesh, warp_mesh_mask = warp_mesh_mask, rigid_mesh = rigid_mesh, mesh = mesh)\n\n\n    return out_dict\n\n# for train_ft.py\ndef get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh):\n    batch_size, _, img_h, img_w = input1_tensor.size()\n\n    rigid_mesh = torch.stack([rigid_mesh[...,0]*img_w/512, rigid_mesh[...,1]*img_h/512], 3)\n    mesh = torch.stack([mesh[...,0]*img_w/512, mesh[...,1]*img_h/512], 3)\n\n    ######################################\n    width_max = torch.max(mesh[...,0])\n    width_max = torch.maximum(torch.tensor(img_w).cuda(), width_max)\n    width_min = torch.min(mesh[...,0])\n    width_min = torch.minimum(torch.tensor(0).cuda(), width_min)\n    height_max = torch.max(mesh[...,1])\n    height_max = torch.maximum(torch.tensor(img_h).cuda(), height_max)\n    height_min = torch.min(mesh[...,1])\n    height_min = torch.minimum(torch.tensor(0).cuda(), height_min)\n\n    out_width = width_max - width_min\n    out_height = height_max - height_min\n    print(out_width)\n    print(out_height)\n\n    warp1 = torch.zeros([batch_size, 3, out_height.int(), out_width.int()]).cuda()\n    warp1[:,:, int(torch.abs(height_min)):int(torch.abs(height_min))+img_h,  int(torch.abs(width_min)):int(torch.abs(width_min))+img_w] = (input1_tensor+1)*127.5\n\n    mask1 = torch.zeros([batch_size, 3, out_height.int(), out_width.int()]).cuda()\n    mask1[:,:, int(torch.abs(height_min)):int(torch.abs(height_min))+img_h,  int(torch.abs(width_min)):int(torch.abs(width_min))+img_w] = 255\n\n    mask = torch.ones_like(input2_tensor)\n    if torch.cuda.is_available():\n        mask = mask.cuda()\n\n    # get warped img2\n    mesh_trans = torch.stack([mesh[...,0]-width_min, mesh[...,1]-height_min], 3)\n    norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)\n    norm_mesh = get_norm_mesh(mesh_trans, out_height, out_width)\n\n    stitch_tps_out = 
torch_tps_transform.transformer(torch.cat([input2_tensor+1, mask], 1), norm_mesh, norm_rigid_mesh, (out_height.int(), out_width.int()))\n    warp2 = stitch_tps_out[:,0:3,:,:]*127.5\n    mask2 = stitch_tps_out[:,3:6,:,:]*255\n\n    stitched = warp1*(warp1/(warp1+warp2+1e-6)) + warp2*(warp2/(warp1+warp2+1e-6))\n\n    stitched_mesh = draw_mesh_on_warp(stitched[0].cpu().detach().numpy().transpose(1,2,0), mesh_trans[0].cpu().detach().numpy())\n\n    out_dict = {}\n    out_dict.update(warp1 = warp1, mask1 = mask1, warp2 = warp2, mask2 = mask2, stitched = stitched, stitched_mesh = stitched_mesh)\n\n    return out_dict\n\n\n# for test_output.py\ndef build_output_model(net, input1_tensor, input2_tensor):\n    batch_size, _, img_h, img_w = input1_tensor.size()\n\n    resized_input1 = resize_512(input1_tensor)\n    resized_input2 = resize_512(input2_tensor)\n    H_motion, mesh_motion = net(resized_input1, resized_input2)\n\n    H_motion = H_motion.reshape(-1, 4, 2)\n    H_motion = torch.stack([H_motion[...,0]*img_w/512, H_motion[...,1]*img_h/512], 2)\n    mesh_motion = mesh_motion.reshape(-1, grid_h+1, grid_w+1, 2)\n    mesh_motion = torch.stack([mesh_motion[...,0]*img_w/512, mesh_motion[...,1]*img_h/512], 3)\n\n    # initialize the source points bs x 4 x 2\n    src_p = torch.tensor([[0., 0.], [img_w, 0.], [0., img_h], [img_w, img_h]])\n    if torch.cuda.is_available():\n        src_p = src_p.cuda()\n    src_p = src_p.unsqueeze(0).expand(batch_size, -1, -1)\n    # target points\n    dst_p = src_p + H_motion\n    # solve homo using DLT\n    H = torch_DLT.tensor_DLT(src_p, dst_p)\n\n\n    rigid_mesh = get_rigid_mesh(batch_size, img_h, img_w)\n    ini_mesh = H2Mesh(H, rigid_mesh)\n    mesh = ini_mesh + mesh_motion\n\n    width_max = torch.max(mesh[...,0])\n    width_max = torch.maximum(torch.tensor(img_w).cuda(), width_max)\n    width_min = torch.min(mesh[...,0])\n    width_min = torch.minimum(torch.tensor(0).cuda(), width_min)\n    height_max = torch.max(mesh[...,1])\n    height_max = torch.maximum(torch.tensor(img_h).cuda(), height_max)\n    height_min = torch.min(mesh[...,1])\n    height_min = torch.minimum(torch.tensor(0).cuda(), height_min)\n\n    out_width = width_max - width_min\n    out_height = height_max - height_min\n    #print(out_width)\n    #print(out_height)\n\n    # get warped img1\n    M_tensor = torch.tensor([[out_width / 2.0, 0., out_width / 2.0],\n                      [0., out_height / 2.0, out_height / 2.0],\n                      [0., 0., 1.]])\n    N_tensor = torch.tensor([[img_w / 2.0, 0., img_w / 2.0],\n                      [0., img_h / 2.0, img_h / 2.0],\n                      [0., 0., 1.]])\n    if torch.cuda.is_available():\n        M_tensor = M_tensor.cuda()\n        N_tensor = N_tensor.cuda()\n    N_tensor_inv = torch.inverse(N_tensor)\n\n    I_ = torch.tensor([[1., 0., width_min],\n                      [0., 1., height_min],\n                      [0., 0., 1.]])#.unsqueeze(0)\n    mask = torch.ones_like(input2_tensor)\n    if torch.cuda.is_available():\n        I_ = I_.cuda()\n        mask = mask.cuda()\n    I_mat = torch.matmul(torch.matmul(N_tensor_inv, I_), M_tensor).unsqueeze(0)\n\n    homo_output = torch_homo_transform.transformer(torch.cat((input1_tensor+1, mask), 1), I_mat, (out_height.int(), out_width.int()))\n\n    torch.cuda.empty_cache()\n    # get warped img2\n    mesh_trans = torch.stack([mesh[...,0]-width_min, mesh[...,1]-height_min], 3)\n    norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)\n    norm_mesh = get_norm_mesh(mesh_trans, out_height, 
out_width)\n    tps_output = torch_tps_transform.transformer(torch.cat([input2_tensor+1, mask],1), norm_mesh, norm_rigid_mesh, (out_height.int(), out_width.int()))\n\n\n    out_dict = {}\n    out_dict.update(final_warp1=homo_output[:, 0:3, ...]-1, final_warp1_mask = homo_output[:, 3:6, ...], final_warp2=tps_output[:, 0:3, ...]-1, final_warp2_mask = tps_output[:, 3:6, ...], mesh1=rigid_mesh, mesh2=mesh_trans)\n\n    return out_dict\n\n\n\n# define and forward\nclass Network(nn.Module):\n\n    def __init__(self):\n        super(Network, self).__init__()\n\n        self.regressNet1_part1 = nn.Sequential(\n            nn.Conv2d(2, 64, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(2, 2),\n\n            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(2, 2),\n\n            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(2, 2)\n        )\n\n        self.regressNet1_part2 = nn.Sequential(\n            nn.Linear(in_features=4096, out_features=4096, bias=True),\n            nn.ReLU(inplace=True),\n\n            nn.Linear(in_features=4096, out_features=1024, bias=True),\n            nn.ReLU(inplace=True),\n\n            nn.Linear(in_features=1024, out_features=8, bias=True)\n        )\n\n\n        self.regressNet2_part1 = nn.Sequential(\n            nn.Conv2d(2, 64, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(2, 2),\n\n            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(2, 2),\n\n            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(2, 2),\n\n            nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(2, 2)\n        )\n\n        self.regressNet2_part2 = nn.Sequential(\n            nn.Linear(in_features=8192, out_features=4096, bias=True),\n            nn.ReLU(inplace=True),\n\n            nn.Linear(in_features=4096, out_features=2048, bias=True),\n            nn.ReLU(inplace=True),\n\n            nn.Linear(in_features=2048, out_features=(grid_w+1)*(grid_h+1)*2, bias=True)\n\n        )\n\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight)\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n        ssl._create_default_https_context = ssl._create_unverified_context\n        resnet50_model = models.resnet.resnet50(pretrained=True)\n\n    
    if torch.cuda.is_available():\n            resnet50_model = resnet50_model.cuda()\n        self.feature_extractor_stage1, self.feature_extractor_stage2 = self.get_res50_FeatureMap(resnet50_model)\n        #-----------------------------------------\n\n    def get_res50_FeatureMap(self, resnet50_model):\n\n        layers_list = []\n\n        layers_list.append(resnet50_model.conv1)\n        layers_list.append(resnet50_model.bn1)\n        layers_list.append(resnet50_model.relu)\n        layers_list.append(resnet50_model.maxpool)\n        layers_list.append(resnet50_model.layer1)\n        layers_list.append(resnet50_model.layer2)\n\n        feature_extractor_stage1 = nn.Sequential(*layers_list)\n\n        feature_extractor_stage2 = nn.Sequential(resnet50_model.layer3)\n\n        #layers_list.append(resnet50_model.layer3)\n\n        return feature_extractor_stage1, feature_extractor_stage2\n\n    # forward\n    def forward(self, input1_tesnor, input2_tesnor):\n        batch_size, _, img_h, img_w = input1_tesnor.size()\n\n        feature_1_64 = self.feature_extractor_stage1(input1_tesnor)\n        feature_1_32 = self.feature_extractor_stage2(feature_1_64)\n        feature_2_64 = self.feature_extractor_stage1(input2_tesnor)\n        feature_2_32 = self.feature_extractor_stage2(feature_2_64)\n\n        ######### stage 1\n        correlation_32 = self.CCL(feature_1_32, feature_2_32)\n        temp_1 = self.regressNet1_part1(correlation_32)\n        temp_1 = temp_1.view(temp_1.size()[0], -1)\n        offset_1 = self.regressNet1_part2(temp_1)\n        H_motion_1 = offset_1.reshape(-1, 4, 2)\n\n\n        src_p = torch.tensor([[0., 0.], [img_w, 0.], [0., img_h], [img_w, img_h]])\n        if torch.cuda.is_available():\n            src_p = src_p.cuda()\n        src_p = src_p.unsqueeze(0).expand(batch_size, -1, -1)\n        dst_p = src_p + H_motion_1\n        H = torch_DLT.tensor_DLT(src_p/8, dst_p/8)\n\n        M_tensor = torch.tensor([[img_w/8 / 2.0, 0., img_w/8 / 2.0],\n                      [0., img_h/8 / 2.0, img_h/8 / 2.0],\n                      [0., 0., 1.]])\n\n        if torch.cuda.is_available():\n            M_tensor = M_tensor.cuda()\n\n        M_tile = M_tensor.unsqueeze(0).expand(batch_size, -1, -1)\n        M_tensor_inv = torch.inverse(M_tensor)\n        M_tile_inv = M_tensor_inv.unsqueeze(0).expand(batch_size, -1, -1)\n        H_mat = torch.matmul(torch.matmul(M_tile_inv, H), M_tile)\n\n        warp_feature_2_64 = torch_homo_transform.transformer(feature_2_64, H_mat, (int(img_h/8), int(img_w/8)))\n\n        ######### stage 2\n        correlation_64 = self.CCL(feature_1_64, warp_feature_2_64)\n        temp_2 = self.regressNet2_part1(correlation_64)\n        temp_2 = temp_2.view(temp_2.size()[0], -1)\n        offset_2 = self.regressNet2_part2(temp_2)\n\n\n        return offset_1, offset_2\n\n\n    def extract_patches(self, x, kernel=3, stride=1):\n        if kernel != 1:\n            x = nn.ZeroPad2d(1)(x)\n        x = x.permute(0, 2, 3, 1)\n        all_patches = x.unfold(1, kernel, stride).unfold(2, kernel, stride)\n        return all_patches\n\n\n    def CCL(self, feature_1, feature_2):\n        bs, c, h, w = feature_1.size()\n\n        norm_feature_1 = F.normalize(feature_1, p=2, dim=1)\n        norm_feature_2 = F.normalize(feature_2, p=2, dim=1)\n        #print(norm_feature_2.size())\n\n        patches = self.extract_patches(norm_feature_2)\n        if torch.cuda.is_available():\n            patches = patches.cuda()\n\n        matching_filters  = patches.reshape((patches.size()[0], 
-1, patches.size()[3], patches.size()[4], patches.size()[5]))\n\n        match_vol = []\n        for i in range(bs):\n            single_match = F.conv2d(norm_feature_1[i].unsqueeze(0), matching_filters[i], padding=1)\n            match_vol.append(single_match)\n\n        match_vol = torch.cat(match_vol, 0)\n        #print(match_vol .size())\n\n        # scale softmax\n        softmax_scale = 10\n        match_vol = F.softmax(match_vol*softmax_scale,1)\n\n        channel = match_vol.size()[1]\n\n        h_one = torch.linspace(0, h-1, h)\n        one1w = torch.ones(1, w)\n        if torch.cuda.is_available():\n            h_one = h_one.cuda()\n            one1w = one1w.cuda()\n        h_one = torch.matmul(h_one.unsqueeze(1), one1w)\n        h_one = h_one.unsqueeze(0).unsqueeze(0).expand(bs, channel, -1, -1)\n\n        w_one = torch.linspace(0, w-1, w)\n        oneh1 = torch.ones(h, 1)\n        if torch.cuda.is_available():\n            w_one = w_one.cuda()\n            oneh1 = oneh1.cuda()\n        w_one = torch.matmul(oneh1, w_one.unsqueeze(0))\n        w_one = w_one.unsqueeze(0).unsqueeze(0).expand(bs, channel, -1, -1)\n\n        c_one = torch.linspace(0, channel-1, channel)\n        if torch.cuda.is_available():\n            c_one = c_one.cuda()\n        c_one = c_one.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand(bs, -1, h, w)\n\n        flow_h = match_vol*(c_one//w - h_one)\n        flow_h = torch.sum(flow_h, dim=1, keepdim=True)\n        flow_w = match_vol*(c_one%w - w_one)\n        flow_w = torch.sum(flow_w, dim=1, keepdim=True)\n\n        feature_flow = torch.cat([flow_w, flow_h], 1)\n        #print(flow.size())\n\n        return feature_flow\n"
  },
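  {
    "path": "Warp/Codes/demo_mesh_sanity.py",
    "content": "# Hypothetical sanity-check sketch, not part of the original release; the file name and\n# the image size are illustrative. Run it from Warp/Codes/ so the imports in network.py\n# resolve. It shows two properties of the mesh helpers above: H2Mesh with an identity\n# homography returns the rigid mesh unchanged, and get_norm_mesh maps the mesh corners to\n# the [-1, 1] extremes expected by the homography/TPS transformers.\nimport torch\nfrom network import get_rigid_mesh, get_norm_mesh, H2Mesh\n\nbatch_size, img_h, img_w = 1, 512, 512\nrigid_mesh = get_rigid_mesh(batch_size, img_h, img_w)  # bs*(grid_h+1)*(grid_w+1)*2, pixel coords\n\nidentity_H = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1)\nif torch.cuda.is_available():  # the helpers move their buffers to the GPU when available\n    identity_H = identity_H.cuda()\nmesh = H2Mesh(identity_H, rigid_mesh)  # identity homography -> mesh == rigid mesh\nprint(torch.allclose(mesh, rigid_mesh, atol=1e-4))  # True\n\nnorm_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)  # bs*N*2 in [-1, 1]\nprint(norm_mesh[0, 0], norm_mesh[0, -1])  # (-1, -1) and (1, 1) corners\n"
  },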
  {
    "path": "Warp/Codes/test.py",
    "content": "# coding: utf-8\r\nimport argparse\r\nimport torch\r\nfrom torch.utils.data import DataLoader\r\nimport torch.nn as nn\r\nimport imageio\r\nfrom network import build_model, Network\r\nfrom dataset import *\r\nimport os\r\nimport numpy as np\r\nimport skimage\r\nimport cv2\r\n\r\n\r\nlast_path = os.path.abspath(os.path.join(os.path.dirname(\"__file__\"), os.path.pardir))\r\nMODEL_DIR = os.path.join(last_path, 'model')\r\n\r\ndef create_gif(image_list, gif_name, duration=0.35):\r\n    frames = []\r\n    for image_name in image_list:\r\n        frames.append(image_name)\r\n    imageio.mimsave(gif_name, frames, 'GIF', duration=0.5)\r\n    return\r\n\r\n\r\ndef test(args):\r\n\r\n    os.environ['CUDA_DEVICES_ORDER'] = \"PCI_BUS_ID\"\r\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\r\n\r\n    # dataset\r\n    test_data = TestDataset(data_path=args.test_path)\r\n    test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, num_workers=1, shuffle=False, drop_last=False)\r\n\r\n    # define the network\r\n    net = Network()#build_model(args.model_name)\r\n    if torch.cuda.is_available():\r\n        net = net.cuda()\r\n\r\n    #load the existing models if it exists\r\n    ckpt_list = glob.glob(MODEL_DIR + \"/*.pth\")\r\n    ckpt_list.sort()\r\n    if len(ckpt_list) != 0:\r\n        model_path = ckpt_list[-1]\r\n        checkpoint = torch.load(model_path)\r\n        net.load_state_dict(checkpoint['model'])\r\n        print('load model from {}!'.format(model_path))\r\n    else:\r\n        print('No checkpoint found!')\r\n\r\n\r\n\r\n    print(\"##################start testing#######################\")\r\n    psnr_list = []\r\n    ssim_list = []\r\n    net.eval()\r\n    for i, batch_value in enumerate(test_loader):\r\n\r\n        inpu1_tesnor = batch_value[0].float()\r\n        inpu2_tesnor = batch_value[1].float()\r\n\r\n        if torch.cuda.is_available():\r\n            inpu1_tesnor = inpu1_tesnor.cuda()\r\n            inpu2_tesnor = inpu2_tesnor.cuda()\r\n\r\n            with torch.no_grad():\r\n                batch_out = build_model(net, inpu1_tesnor, inpu2_tesnor, is_training=False)\r\n\r\n            warp_mesh_mask = batch_out['warp_mesh_mask']\r\n            warp_mesh = batch_out['warp_mesh']\r\n\r\n\r\n            warp_mesh_np = ((warp_mesh[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)\r\n            warp_mesh_mask_np = warp_mesh_mask[0].cpu().detach().numpy().transpose(1,2,0)\r\n            inpu1_np = ((inpu1_tesnor[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)\r\n\r\n\r\n            # calculate psnr/ssim\r\n            psnr = skimage.measure.compare_psnr(inpu1_np*warp_mesh_mask_np, warp_mesh_np*warp_mesh_mask_np, 255)\r\n            ssim = skimage.measure.compare_ssim(inpu1_np*warp_mesh_mask_np, warp_mesh_np*warp_mesh_mask_np, data_range=255, multichannel=True)\r\n\r\n\r\n            print('i = {}, psnr = {:.6f}'.format( i+1, psnr))\r\n\r\n            psnr_list.append(psnr)\r\n            ssim_list.append(ssim)\r\n            torch.cuda.empty_cache()\r\n\r\n    print(\"=================== Analysis ==================\")\r\n    print(\"psnr\")\r\n    psnr_list.sort(reverse = True)\r\n    psnr_list_30 = psnr_list[0 : 331]\r\n    psnr_list_60 = psnr_list[331: 663]\r\n    psnr_list_100 = psnr_list[663: -1]\r\n    print(\"top 30%\", np.mean(psnr_list_30))\r\n    print(\"top 30~60%\", np.mean(psnr_list_60))\r\n    print(\"top 60~100%\", np.mean(psnr_list_100))\r\n    print('average psnr:', np.mean(psnr_list))\r\n\r\n    ssim_list.sort(reverse = 
True)\r\n    ssim_list_30 = ssim_list[0 : 331]\r\n    ssim_list_60 = ssim_list[331: 663]\r\n    ssim_list_100 = ssim_list[663: -1]\r\n    print(\"top 30%\", np.mean(ssim_list_30))\r\n    print(\"top 30~60%\", np.mean(ssim_list_60))\r\n    print(\"top 60~100%\", np.mean(ssim_list_100))\r\n    print('average ssim:', np.mean(ssim_list))\r\n    print(\"##################end testing#######################\")\r\n\r\n\r\nif __name__==\"__main__\":\r\n\r\n    parser = argparse.ArgumentParser()\r\n\r\n    parser.add_argument('--gpu', type=str, default='0')\r\n    parser.add_argument('--batch_size', type=int, default=1)\r\n    parser.add_argument('--test_path', type=str, default='/opt/data/private/nl/Data/UDIS-D/testing/')\r\n\r\n    print('<==================== Loading data ===================>\\n')\r\n\r\n    args = parser.parse_args()\r\n    print(args)\r\n    test(args)\r\n"
  },
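  {
    "path": "Warp/Codes/demo_masked_psnr.py",
    "content": "# Hypothetical sketch, not part of the original release, of the masked-PSNR convention\n# used in test.py: both images are multiplied by the same warp mask before the metric, so\n# pixels outside the overlap compare 0 vs 0 and contribute no error. Note the zeroed\n# region still enters the mean, which inflates PSNR relative to cropping the overlap.\nimport numpy as np\nfrom skimage.metrics import peak_signal_noise_ratio\n\nref = (np.random.rand(64, 64, 3) * 255).astype(np.float32)\nwarp = ref + np.random.randn(64, 64, 3).astype(np.float32)  # small alignment error\nmask = np.zeros_like(ref)\nmask[:, :32] = 1.0  # pretend only the left half overlaps\n\nprint(peak_signal_noise_ratio(ref, warp, data_range=255))\nprint(peak_signal_noise_ratio(ref * mask, warp * mask, data_range=255))  # ~3 dB higher\n"
  },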
  {
    "path": "Warp/Codes/test_other.py",
    "content": "import argparse\nimport torch\n\nimport numpy as np\nimport os\nimport torch.nn as nn\nimport torch.optim as optim\n\nimport cv2\n#from torch_homography_model import build_model\nfrom network import get_stitched_result, Network, build_new_ft_model\n\nimport glob\nfrom loss import cal_lp_loss2\nimport torchvision.transforms as T\n\n#import PIL\nresize_512 = T.Resize((512,512))\n\n\ndef loadSingleData(data_path, img1_name, img2_name):\n\n    # load image1\n    input1 = cv2.imread(data_path+img1_name)\n    input1 = input1.astype(dtype=np.float32)\n    input1 = (input1 / 127.5) - 1.0\n    input1 = np.transpose(input1, [2, 0, 1])\n\n    # load image2\n    input2 = cv2.imread(data_path+img2_name)\n    input2 = input2.astype(dtype=np.float32)\n    input2 = (input2 / 127.5) - 1.0\n    input2 = np.transpose(input2, [2, 0, 1])\n\n    # convert to tensor\n    input1_tensor = torch.tensor(input1).unsqueeze(0)\n    input2_tensor = torch.tensor(input2).unsqueeze(0)\n    return (input1_tensor, input2_tensor)\n\n\n\n# path of project\n#nl: os.path.dirname(\"__file__\") ----- the current absolute path\n#nl: os.path.pardir ---- the last path\nlast_path = os.path.abspath(os.path.join(os.path.dirname(\"__file__\"), os.path.pardir))\n\n\n#nl: path to save the model files\nMODEL_DIR = os.path.join(last_path, 'model')\n\n#nl: create folders if it dose not exist\nif not os.path.exists(MODEL_DIR):\n    os.makedirs(MODEL_DIR)\n\n\ndef train(args):\n\n    os.environ['CUDA_DEVICES_ORDER'] = \"PCI_BUS_ID\"\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\n    \n    # define the network\n    net = Network()\n    if torch.cuda.is_available():\n        net = net.cuda()\n\n    # define the optimizer and learning rate\n    optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08)  # default as 0.0001\n    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)\n\n    #load the existing models if it exists\n    ckpt_list = glob.glob(MODEL_DIR + \"/*.pth\")\n    ckpt_list.sort()\n    if len(ckpt_list) != 0:\n        model_path = ckpt_list[-1]\n        checkpoint = torch.load(model_path)\n\n        net.load_state_dict(checkpoint['model'])\n        optimizer.load_state_dict(checkpoint['optimizer'])\n        start_epoch = checkpoint['epoch']\n        scheduler.last_epoch = start_epoch\n        print('load model from {}!'.format(model_path))\n    else:\n        start_epoch = 0\n        print('training from stratch!')\n\n    # load dataset(only one pair of images)\n    input1_tensor, input2_tensor = loadSingleData(data_path=args.path, img1_name = args.img1_name, img2_name = args.img2_name)\n    if torch.cuda.is_available():\n        input1_tensor = input1_tensor.cuda()\n        input2_tensor = input2_tensor.cuda()\n\n    input1_tensor_512 = resize_512(input1_tensor)\n    input2_tensor_512 = resize_512(input2_tensor)\n\n    loss_list = []\n\n    print(\"##################start iteration#######################\")\n    for epoch in range(start_epoch, start_epoch + args.max_iter):\n        net.train()\n\n        optimizer.zero_grad()\n\n        batch_out = build_new_ft_model(net, input1_tensor_512, input2_tensor_512)\n        warp_mesh = batch_out['warp_mesh']\n        warp_mesh_mask = batch_out['warp_mesh_mask']\n        rigid_mesh = batch_out['rigid_mesh']\n        mesh = batch_out['mesh']\n\n        total_loss = cal_lp_loss2(input1_tensor_512, warp_mesh, warp_mesh_mask)\n        total_loss.backward()\n        # clip the gradient\n        
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=3, norm_type=2)\n        optimizer.step()\n\n        current_iter = epoch-start_epoch+1\n        print(\"Training: Iteration[{:0>3}/{:0>3}] Total Loss: {:.4f} lr={:.8f}\".format(current_iter, args.max_iter, total_loss, optimizer.state_dict()['param_groups'][0]['lr']))\n\n\n        loss_list.append(total_loss)\n\n\n        if current_iter == 1:\n            with torch.no_grad():\n                output = get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh)\n            cv2.imwrite( args.path+ 'before_optimization.jpg', output['stitched'][0].cpu().detach().numpy().transpose(1,2,0))\n            cv2.imwrite( args.path+ 'before_optimization_mesh.jpg', output['stitched_mesh'])\n\n\n        if current_iter >= 4:\n            if torch.abs(loss_list[current_iter-4]-loss_list[current_iter-3]) <= 1e-4 and torch.abs(loss_list[current_iter-3]-loss_list[current_iter-2]) <= 1e-4 \\\n            and torch.abs(loss_list[current_iter-2]-loss_list[current_iter-1]) <= 1e-4:\n                with torch.no_grad():\n                    output = get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh)\n\n                path = args.path + \"iter-\" + str(epoch-start_epoch+1).zfill(3) + \".jpg\"\n                cv2.imwrite(path, output['stitched'][0].cpu().detach().numpy().transpose(1,2,0))\n                cv2.imwrite(args.path + \"iter-\" + str(epoch-start_epoch+1).zfill(3) + \"_mesh.jpg\", output['stitched_mesh'])\n                cv2.imwrite( args.path+'warp1.jpg', output['warp1'][0].cpu().detach().numpy().transpose(1,2,0))\n                cv2.imwrite( args.path+'warp2.jpg', output['warp2'][0].cpu().detach().numpy().transpose(1,2,0))\n                cv2.imwrite( args.path+'mask1.jpg', output['mask1'][0].cpu().detach().numpy().transpose(1,2,0))\n                cv2.imwrite( args.path+'mask2.jpg', output['mask2'][0].cpu().detach().numpy().transpose(1,2,0))\n                break\n\n        if current_iter == args.max_iter:\n            with torch.no_grad():\n                output = get_stitched_result(input1_tensor, input2_tensor, rigid_mesh, mesh)\n\n            path = args.path + \"iter-\" + str(epoch-start_epoch+1).zfill(3) + \".jpg\"\n            cv2.imwrite(path, output['stitched'][0].cpu().detach().numpy().transpose(1,2,0))\n            cv2.imwrite(args.path + \"iter-\" + str(epoch-start_epoch+1).zfill(3) + \"_mesh.jpg\", output['stitched_mesh'])\n            cv2.imwrite( args.path+'warp1.jpg', output['warp1'][0].cpu().detach().numpy().transpose(1,2,0))\n            cv2.imwrite( args.path+'warp2.jpg', output['warp2'][0].cpu().detach().numpy().transpose(1,2,0))\n            cv2.imwrite( args.path+'mask1.jpg', output['mask1'][0].cpu().detach().numpy().transpose(1,2,0))\n            cv2.imwrite( args.path+'mask2.jpg', output['mask2'][0].cpu().detach().numpy().transpose(1,2,0))\n\n        scheduler.step()\n\n    print(\"##################end iteration#######################\")\n\n\nif __name__==\"__main__\":\n\n\n    print('<==================== setting arguments ===================>\\n')\n\n    #nl: create the argument parser\n    parser = argparse.ArgumentParser()\n\n    #nl: add arguments\n    parser.add_argument('--gpu', type=str, default='0')\n    parser.add_argument('--max_iter', type=int, default=50)\n    parser.add_argument('--path', type=str, default='../../Carpark-DHW/')\n    parser.add_argument('--img1_name', type=str, default='input1.jpg')\n    parser.add_argument('--img2_name', type=str, 
default='input2.jpg')\n\n    #nl: parse the arguments\n    args = parser.parse_args()\n    print(args)\n\n    #nl: rain\n    train(args)\n\n\n"
  },
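  {
    "path": "Warp/Codes/demo_convergence_rule.py",
    "content": "# Hypothetical sketch, not part of the original release, of the early-stop rule used in\n# the iterative warp adaptation above: optimization stops once three successive absolute\n# loss deltas all fall below 1e-4, i.e. the per-pair fine-tuning has plateaued.\ndef has_converged(losses, tol=1e-4, window=3):\n    # need window+1 loss values to form window deltas\n    if len(losses) < window + 1:\n        return False\n    return all(abs(losses[-k] - losses[-k-1]) <= tol for k in range(1, window + 1))\n\nprint(has_converged([1.0, 0.5, 0.4]))                        # False: too few values\nprint(has_converged([1.0, 0.5, 0.49995, 0.49991, 0.49990]))  # True: plateaued\n"
  },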
  {
    "path": "Warp/Codes/test_output.py",
    "content": "# coding: utf-8\nimport argparse\nimport torch\nfrom torch.utils.data import DataLoader\nimport torch.nn as nn\nimport imageio\nfrom network import build_output_model, Network\nfrom dataset import *\nimport os\nimport cv2\n\nimport grid_res\ngrid_h = grid_res.GRID_H\ngrid_w = grid_res.GRID_W\n\nlast_path = os.path.abspath(os.path.join(os.path.dirname(\"__file__\"), os.path.pardir))\nMODEL_DIR = os.path.join(last_path, 'model')\n\n\ndef draw_mesh_on_warp(warp, f_local):\n\n\n    point_color = (0, 255, 0) # BGR\n    thickness = 2\n    lineType = 8\n\n    num = 1\n    for i in range(grid_h+1):\n        for j in range(grid_w+1):\n\n            num = num + 1\n            if j == grid_w and i == grid_h:\n                continue\n            elif j == grid_w:\n                cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i+1,j,0]), int(f_local[i+1,j,1])), point_color, thickness, lineType)\n            elif i == grid_h:\n                cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i,j+1,0]), int(f_local[i,j+1,1])), point_color, thickness, lineType)\n            else :\n                cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i+1,j,0]), int(f_local[i+1,j,1])), point_color, thickness, lineType)\n                cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i,j+1,0]), int(f_local[i,j+1,1])), point_color, thickness, lineType)\n\n    return warp\n\ndef create_gif(image_list, gif_name, duration=0.35):\n    frames = []\n    for image_name in image_list:\n        frames.append(image_name)\n    imageio.mimsave(gif_name, frames, 'GIF', duration=0.5)\n    return\n\n\n\n\ndef test(args):\n\n    os.environ['CUDA_DEVICES_ORDER'] = \"PCI_BUS_ID\"\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\n    \n    # dataset\n    test_data = TestDataset(data_path=args.test_path)\n    #nl: set num_workers = the number of cpus\n    test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, num_workers=1, shuffle=False, drop_last=False)\n\n    # define the network\n    net = Network()#build_model(args.model_name)\n    if torch.cuda.is_available():\n        net = net.cuda()\n\n    #load the existing models if it exists\n    ckpt_list = glob.glob(MODEL_DIR + \"/*.pth\")\n    ckpt_list.sort()\n    if len(ckpt_list) != 0:\n        model_path = ckpt_list[-1]\n        #model_path = '/opt/data/private/nl/Repository/Unsupervised_Mesh_Stitching/UDISv2-88/UDISv2-Homo_TPS88-10grid_NO-res50-new3/model/epoch150_model.pth'\n        checkpoint = torch.load(model_path)\n\n        net.load_state_dict(checkpoint['model'])\n        print('load model from {}!'.format(model_path))\n    else:\n        print('No checkpoint found!')\n\n\n\n    print(\"##################start testing#######################\")\n    # create folders if it dose not exist\n\n    path_ave_fusion = '../ave_fusion/'\n    if not os.path.exists(path_ave_fusion):\n        os.makedirs(path_ave_fusion)\n    path_warp1 = args.test_path + 'warp1/'\n    if not os.path.exists(path_warp1):\n        os.makedirs(path_warp1)\n    path_warp2 = args.test_path + 'warp2/'\n    if not os.path.exists(path_warp2):\n        os.makedirs(path_warp2)\n    path_mask1 = args.test_path + 'mask1/'\n    if not os.path.exists(path_mask1):\n        os.makedirs(path_mask1)\n    path_mask2 = args.test_path + 'mask2/'\n    if not os.path.exists(path_mask2):\n        os.makedirs(path_mask2)\n\n\n\n    net.eval()\n    for i, batch_value in enumerate(test_loader):\n\n   
     #if i != 975:\n        #    continue\n\n        inpu1_tesnor = batch_value[0].float()\n        inpu2_tesnor = batch_value[1].float()\n\n        if torch.cuda.is_available():\n            inpu1_tesnor = inpu1_tesnor.cuda()\n            inpu2_tesnor = inpu2_tesnor.cuda()\n\n        with torch.no_grad():\n            batch_out = build_output_model(net, inpu1_tesnor, inpu2_tesnor)\n\n        final_warp1 = batch_out['final_warp1']\n        final_warp1_mask = batch_out['final_warp1_mask']\n        final_warp2 = batch_out['final_warp2']\n        final_warp2_mask = batch_out['final_warp2_mask']\n        final_mesh1 = batch_out['mesh1']\n        final_mesh2 = batch_out['mesh2']\n\n\n        final_warp1 = ((final_warp1[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)\n        final_warp2 = ((final_warp2[0]+1)*127.5).cpu().detach().numpy().transpose(1,2,0)\n        final_warp1_mask = final_warp1_mask[0].cpu().detach().numpy().transpose(1,2,0)\n        final_warp2_mask = final_warp2_mask[0].cpu().detach().numpy().transpose(1,2,0)\n        final_mesh1 = final_mesh1[0].cpu().detach().numpy()\n        final_mesh2 = final_mesh2[0].cpu().detach().numpy()\n\n\n\n        path = path_warp1 + str(i+1).zfill(6) + \".jpg\"\n        cv2.imwrite(path, final_warp1)\n        path = path_warp2 + str(i+1).zfill(6) + \".jpg\"\n        cv2.imwrite(path, final_warp2)\n        path = path_mask1 + str(i+1).zfill(6) + \".jpg\"\n        cv2.imwrite(path, final_warp1_mask*255)\n        path = path_mask2 + str(i+1).zfill(6) + \".jpg\"\n        cv2.imwrite(path, final_warp2_mask*255)\n\n        ave_fusion = final_warp1 * (final_warp1/ (final_warp1+final_warp2+1e-6)) + final_warp2 * (final_warp2/ (final_warp1+final_warp2+1e-6))\n        path = path_ave_fusion + str(i+1).zfill(6) + \".jpg\"\n        cv2.imwrite(path, ave_fusion)\n\n        print('i = {}'.format( i+1))\n\n        torch.cuda.empty_cache()\n\n\n\n\n    print(\"##################end testing#######################\")\n\n\nif __name__==\"__main__\":\n\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--gpu', type=str, default='0')\n    parser.add_argument('--batch_size', type=int, default=1)\n    \n    # /opt/data/private/nl/Data/UDIS-D/testing/  or  /opt/data/private/nl/Data/UDIS-D/training/\n    parser.add_argument('--test_path', type=str, default='/opt/data/private/nl/Data/UDIS-D/testing/')\n\n\n    print('<==================== Loading data ===================>\\n')\n\n    args = parser.parse_args()\n    print(args)\n    test(args)\n"
  },
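  {
    "path": "Warp/Codes/demo_average_fusion.py",
    "content": "# Hypothetical sketch, not part of the original release, of the average-fusion rule in\n# test_output.py: each warp is weighted by its own relative intensity,\n# w_i = I_i / (I_1 + I_2), so black (invalid) border pixels get zero weight, each\n# exclusive region is kept intact, and overlap pixels blend toward the stronger value.\nimport numpy as np\n\nwarp1 = np.array([200., 0., 100.])  # 0 marks pixels outside warp1\nwarp2 = np.array([0., 150., 100.])  # 0 marks pixels outside warp2\nfused = warp1 * (warp1 / (warp1 + warp2 + 1e-6)) + warp2 * (warp2 / (warp1 + warp2 + 1e-6))\nprint(fused)  # [200. 150. 100.]: exclusive pixels kept, equal overlap averaged\n"
  },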
  {
    "path": "Warp/Codes/train.py",
    "content": "import argparse\r\nimport torch\r\nfrom torch.utils.data import DataLoader\r\nimport os\r\nimport torch.optim as optim\r\nfrom torch.utils.tensorboard import SummaryWriter\r\nfrom network import build_model, Network\r\nfrom dataset import TrainDataset\r\nimport glob\r\nfrom loss import cal_lp_loss, inter_grid_loss, intra_grid_loss\r\n\r\n\r\n\r\nlast_path = os.path.abspath(os.path.join(os.path.dirname(\"__file__\"), os.path.pardir))\r\n# path to save the summary files\r\nSUMMARY_DIR = os.path.join(last_path, 'summary')\r\nwriter = SummaryWriter(log_dir=SUMMARY_DIR)\r\n# path to save the model files\r\nMODEL_DIR = os.path.join(last_path, 'model')\r\n# create folders if it dose not exist\r\nif not os.path.exists(MODEL_DIR):\r\n    os.makedirs(MODEL_DIR)\r\nif not os.path.exists(SUMMARY_DIR):\r\n    os.makedirs(SUMMARY_DIR)\r\n\r\n\r\n\r\ndef train(args):\r\n\r\n    os.environ['CUDA_DEVICES_ORDER'] = \"PCI_BUS_ID\"\r\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\r\n    \r\n    # define dataset\r\n    train_data = TrainDataset(data_path=args.train_path)\r\n    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=4, shuffle=True, drop_last=True)\r\n\r\n    # define the network\r\n    net = Network()\r\n    if torch.cuda.is_available():\r\n        net = net.cuda()\r\n\r\n    # define the optimizer and learning rate\r\n    optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08)  # default as 0.0001\r\n    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)\r\n\r\n    #load the existing models if it exists\r\n    ckpt_list = glob.glob(MODEL_DIR + \"/*.pth\")\r\n    ckpt_list.sort()\r\n    if len(ckpt_list) != 0:\r\n        model_path = ckpt_list[-1]\r\n        checkpoint = torch.load(model_path)\r\n\r\n        net.load_state_dict(checkpoint['model'])\r\n        optimizer.load_state_dict(checkpoint['optimizer'])\r\n        start_epoch = checkpoint['epoch']\r\n        glob_iter = checkpoint['glob_iter']\r\n        scheduler.last_epoch = start_epoch\r\n        print('load model from {}!'.format(model_path))\r\n    else:\r\n        start_epoch = 0\r\n        glob_iter = 0\r\n        print('training from stratch!')\r\n\r\n\r\n\r\n    print(\"##################start training#######################\")\r\n    score_print_fre = 300\r\n\r\n    for epoch in range(start_epoch, args.max_epoch):\r\n\r\n        print(\"start epoch {}\".format(epoch))\r\n        net.train()\r\n        loss_sigma = 0.0\r\n        overlap_loss_sigma = 0.\r\n        nonoverlap_loss_sigma = 0.\r\n\r\n        print(epoch, 'lr={:.6f}'.format(optimizer.state_dict()['param_groups'][0]['lr']))\r\n\r\n        for i, batch_value in enumerate(train_loader):\r\n\r\n            inpu1_tesnor = batch_value[0].float()\r\n            inpu2_tesnor = batch_value[1].float()\r\n\r\n            if torch.cuda.is_available():\r\n                inpu1_tesnor = inpu1_tesnor.cuda()\r\n                inpu2_tesnor = inpu2_tesnor.cuda()\r\n\r\n            # forward, backward, update weights\r\n            optimizer.zero_grad()\r\n\r\n            batch_out = build_model(net, inpu1_tesnor, inpu2_tesnor)\r\n            # result\r\n            output_H = batch_out['output_H']\r\n            output_H_inv = batch_out['output_H_inv']\r\n            warp_mesh = batch_out['warp_mesh']\r\n            warp_mesh_mask = batch_out['warp_mesh_mask']\r\n            mesh1 = batch_out['mesh1']\r\n            mesh2 = batch_out['mesh2']\r\n            overlap = 
batch_out['overlap']\r\n\r\n            # calculate loss for overlapping regions\r\n            overlap_loss = cal_lp_loss(inpu1_tesnor, inpu2_tesnor, output_H, output_H_inv, warp_mesh, warp_mesh_mask)\r\n            # calculate loss for non-overlapping regions\r\n            nonoverlap_loss = 10*inter_grid_loss(overlap, mesh2) + 10*intra_grid_loss(mesh2)\r\n\r\n            total_loss = overlap_loss + nonoverlap_loss\r\n            total_loss.backward()\r\n\r\n            # clip the gradient\r\n            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=3, norm_type=2)\r\n            optimizer.step()\r\n\r\n            overlap_loss_sigma += overlap_loss.item()\r\n            nonoverlap_loss_sigma += nonoverlap_loss.item()\r\n            loss_sigma += total_loss.item()\r\n\r\n            print(glob_iter)\r\n\r\n            # record loss and images in tensorboard\r\n            if i % score_print_fre == 0 and i != 0:\r\n                average_loss = loss_sigma / score_print_fre\r\n                average_overlap_loss = overlap_loss_sigma/ score_print_fre\r\n                average_nonoverlap_loss = nonoverlap_loss_sigma/ score_print_fre\r\n                loss_sigma = 0.0\r\n                overlap_loss_sigma = 0.\r\n                nonoverlap_loss_sigma = 0.\r\n\r\n                print(\"Training: Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}]/[{:0>3}] Total Loss: {:.4f}  Overlap Loss: {:.4f}  Non-overlap Loss: {:.4f} lr={:.8f}\".format(epoch + 1, args.max_epoch, i + 1, len(train_loader),\r\n                                          average_loss, average_overlap_loss, average_nonoverlap_loss, optimizer.state_dict()['param_groups'][0]['lr']))\r\n                # visualization\r\n                writer.add_image(\"inpu1\", (inpu1_tesnor[0]+1.)/2., glob_iter)\r\n                writer.add_image(\"inpu2\", (inpu2_tesnor[0]+1.)/2., glob_iter)\r\n                writer.add_image(\"warp_H\", (output_H[0,0:3,:,:]+1.)/2., glob_iter)\r\n                writer.add_image(\"warp_mesh\", (warp_mesh[0]+1.)/2., glob_iter)\r\n                writer.add_scalar('lr', optimizer.state_dict()['param_groups'][0]['lr'], glob_iter)\r\n                writer.add_scalar('total loss', average_loss, glob_iter)\r\n                writer.add_scalar('overlap loss', average_overlap_loss, glob_iter)\r\n                writer.add_scalar('nonoverlap loss', average_nonoverlap_loss, glob_iter)\r\n\r\n            glob_iter += 1\r\n\r\n\r\n        scheduler.step()\r\n        # save model\r\n        if ((epoch+1) % 10 == 0 or (epoch+1)==args.max_epoch):\r\n            filename ='epoch' + str(epoch+1).zfill(3) + '_model.pth'\r\n            model_save_path = os.path.join(MODEL_DIR, filename)\r\n            state = {'model': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch+1, \"glob_iter\": glob_iter}\r\n            torch.save(state, model_save_path)\r\n    print(\"##################end training#######################\")\r\n\r\n\r\nif __name__==\"__main__\":\r\n\r\n\r\n    print('<==================== setting arguments ===================>\\n')\r\n\r\n    # create the argument parser\r\n    parser = argparse.ArgumentParser()\r\n\r\n    # add arguments\r\n    parser.add_argument('--gpu', type=str, default='0')\r\n    parser.add_argument('--batch_size', type=int, default=4)\r\n    parser.add_argument('--max_epoch', type=int, default=100)\r\n    parser.add_argument('--train_path', type=str, default='/opt/data/private/nl/Data/UDIS-D/training/')\r\n\r\n    # parse the arguments\r\n    args = 
parser.parse_args()\r\n    print(args)\r\n\r\n    # train\r\n    train(args)\r\n\r\n\r\n"
  },
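  {
    "path": "Warp/Codes/demo_lr_schedule.py",
    "content": "# Hypothetical sketch, not part of the original release, of the learning-rate schedule\n# used in train.py: ExponentialLR multiplies the Adam lr by gamma=0.97 once per epoch, so\n# after 100 epochs the lr decays to 1e-4 * 0.97**100, roughly 4.8e-6.\nimport torch\n\nnet = torch.nn.Linear(2, 2)  # stand-in model\noptimizer = torch.optim.Adam(net.parameters(), lr=1e-4)\nscheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)\nfor epoch in range(100):\n    optimizer.step()   # one (dummy) epoch of updates\n    scheduler.step()   # decay once per epoch, as in train.py\nprint(optimizer.param_groups[0]['lr'])  # ~4.76e-06\n"
  },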
  {
    "path": "Warp/Codes/utils/torch_DLT.py",
    "content": "import torch\nimport numpy as np\nimport cv2\n\n# src_p: shape=(bs, 4, 2)\n# det_p: shape=(bs, 4, 2)\n#\n#                                     | h1 |\n#                                     | h2 |                   \n#                                     | h3 |\n# | x1 y1 1  0  0  0  -x1x2  -y1x2 |  | h4 |  =  | x2 |\n# | 0  0  0  x1 y1 1  -x1y2  -y1y2 |  | h5 |     | y2 |\n#                                     | h6 |\n#                                     | h7 |\n#                                     | h8 |\n\ndef tensor_DLT(src_p, dst_p):\n   \n    bs, _, _ = src_p.shape\n\n    ones = torch.ones(bs, 4, 1)\n    if torch.cuda.is_available():\n        ones = ones.cuda()\n    xy1 = torch.cat((src_p, ones), 2)\n    zeros = torch.zeros_like(xy1)\n    if torch.cuda.is_available():\n        zeros = zeros.cuda()\n\n    xyu, xyd = torch.cat((xy1, zeros), 2), torch.cat((zeros, xy1), 2)\n    M1 = torch.cat((xyu, xyd), 2).reshape(bs, -1, 6)\n    M2 = torch.matmul(\n        dst_p.reshape(-1, 2, 1), \n        src_p.reshape(-1, 1, 2),\n    ).reshape(bs, -1, 2)\n    \n    # Ah = b\n    A = torch.cat((M1, -M2), 2)\n    b = dst_p.reshape(bs, -1, 1)\n    \n    #h = A^{-1}b\n    Ainv = torch.inverse(A)\n    h8 = torch.matmul(Ainv, b).reshape(bs, 8)\n \n    H = torch.cat((h8, ones[:,0,:]), 1).reshape(bs, 3, 3)\n    return H"
  },
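  {
    "path": "Warp/Codes/utils/demo_DLT_check.py",
    "content": "# Hypothetical sanity check, not part of the original release; the file name is\n# illustrative and it is assumed to sit next to torch_DLT.py. tensor_DLT solves the 8-DoF\n# linear system above with h9 fixed to 1, so its result should match OpenCV's\n# getPerspectiveTransform on the same four correspondences.\nimport numpy as np\nimport cv2\nimport torch\nfrom torch_DLT import tensor_DLT\n\nsrc = np.array([[0., 0.], [512., 0.], [0., 512.], [512., 512.]], dtype=np.float32)\ndst = src + np.array([[12., 7.], [-9., 4.], [5., -11.], [-6., -3.]], dtype=np.float32)\n\nH_cv = cv2.getPerspectiveTransform(src, dst)  # reference solution\n\nsrc_t = torch.from_numpy(src).unsqueeze(0)  # bs = 1\ndst_t = torch.from_numpy(dst).unsqueeze(0)\nif torch.cuda.is_available():  # tensor_DLT builds its helper tensors on the GPU when available\n    src_t, dst_t = src_t.cuda(), dst_t.cuda()\nH_torch = tensor_DLT(src_t, dst_t)[0].cpu().numpy()\n\nprint(np.abs(H_cv - H_torch).max())  # close to zero, up to float32 round-off\n"
  },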
  {
    "path": "Warp/Codes/utils/torch_homo_transform.py",
    "content": "import torch\nimport numpy as np\n\n\ndef transformer(U, theta, out_size, **kwargs):\n\n\n    def _repeat(x, n_repeats):\n\n        rep = torch.ones([n_repeats, ]).unsqueeze(0)\n        rep = rep.int()\n        x = x.int()\n\n        x = torch.matmul(x.reshape([-1,1]), rep)\n        return x.reshape([-1])\n\n    def _interpolate(im, x, y, out_size):\n\n        num_batch, num_channels , height, width = im.size()\n\n        height_f = height\n        width_f = width\n        out_height, out_width = out_size[0], out_size[1]\n\n        zero = 0\n        max_y = height - 1\n        max_x = width - 1\n\n        x = (x + 1.0)*(width_f) / 2.0\n        y = (y + 1.0) * (height_f) / 2.0\n\n        # do sampling\n        x0 = torch.floor(x).int()\n        x1 = x0 + 1\n        y0 = torch.floor(y).int()\n        y1 = y0 + 1\n\n        x0 = torch.clamp(x0, zero, max_x)\n        x1 = torch.clamp(x1, zero, max_x)\n        y0 = torch.clamp(y0, zero, max_y)\n        y1 = torch.clamp(y1, zero, max_y)\n        dim2 = torch.from_numpy( np.array(width) )\n        dim1 = torch.from_numpy( np.array(width * height) )\n\n        base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width)\n        if torch.cuda.is_available():\n            dim2 = dim2.cuda()\n            dim1 = dim1.cuda()\n            y0 = y0.cuda()\n            y1 = y1.cuda()\n            x0 = x0.cuda()\n            x1 = x1.cuda()\n            base = base.cuda()\n        base_y0 = base + y0 * dim2\n        base_y1 = base + y1 * dim2\n        idx_a = base_y0 + x0\n        idx_b = base_y1 + x0\n        idx_c = base_y0 + x1\n        idx_d = base_y1 + x1\n\n        # channels dim\n        im = im.permute(0,2,3,1)\n        im_flat = im.reshape([-1, num_channels]).float()\n\n        idx_a = idx_a.unsqueeze(-1).long()\n        idx_a = idx_a.expand(out_height * out_width * num_batch,num_channels)\n        Ia = torch.gather(im_flat, 0, idx_a)\n\n        idx_b = idx_b.unsqueeze(-1).long()\n        idx_b = idx_b.expand(out_height * out_width * num_batch, num_channels)\n        Ib = torch.gather(im_flat, 0, idx_b)\n\n        idx_c = idx_c.unsqueeze(-1).long()\n        idx_c = idx_c.expand(out_height * out_width * num_batch, num_channels)\n        Ic = torch.gather(im_flat, 0, idx_c)\n\n        idx_d = idx_d.unsqueeze(-1).long()\n        idx_d = idx_d.expand(out_height * out_width * num_batch, num_channels)\n        Id = torch.gather(im_flat, 0, idx_d)\n\n        x0_f = x0.float()\n        x1_f = x1.float()\n        y0_f = y0.float()\n        y1_f = y1.float()\n\n        wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)\n        wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)\n        wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)\n        wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)\n        output = wa*Ia+wb*Ib+wc*Ic+wd*Id\n\n        return output\n\n    def _meshgrid(height, width):\n\n        x_t = torch.matmul(torch.ones([height, 1]),\n                               torch.transpose(torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 1), 1, 0))\n        y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1),\n                               torch.ones([1, width]))\n        #x_t = torch.matmul(torch.ones([height, 1]),\n        #                       torch.transpose(torch.unsqueeze(torch.linspace(0.0, width.float(), width), 1), 1, 0))\n        #y_t = torch.matmul(torch.unsqueeze(torch.linspace(0.0, height.float(), height), 1),\n        #                       torch.ones([1, width]))\n\n        
x_t_flat = x_t.reshape((1, -1)).float()\n        y_t_flat = y_t.reshape((1, -1)).float()\n\n        ones = torch.ones_like(x_t_flat)\n        grid = torch.cat([x_t_flat, y_t_flat, ones], 0)\n        if torch.cuda.is_available():\n            grid = grid.cuda()\n        return grid\n\n    def _transform(theta, input_dim, out_size):\n        num_batch, num_channels , height, width = input_dim.size()\n        #  Changed\n        theta = theta.reshape([-1, 3, 3]).float()\n\n        out_height, out_width = out_size[0], out_size[1]\n        grid = _meshgrid(out_height, out_width)\n        grid = grid.unsqueeze(0).reshape([1,-1])\n        shape = grid.size()\n        grid = grid.expand(num_batch,shape[1])\n        grid = grid.reshape([num_batch, 3, -1])\n\n        T_g = torch.matmul(theta, grid)\n        x_s = T_g[:,0,:]\n        y_s = T_g[:,1,:]\n        t_s = T_g[:,2,:]\n\n        t_s_flat = t_s.reshape([-1])\n\n        # smaller\n        small = 1e-7\n        smallers = 1e-6*(1.0 - torch.ge(torch.abs(t_s_flat), small).float())\n\n        t_s_flat = t_s_flat + smallers\n        #condition = torch.sum(torch.gt(torch.abs(t_s_flat), small).float())\n        # Ty changed\n        x_s_flat = x_s.reshape([-1]) / t_s_flat\n        y_s_flat = y_s.reshape([-1]) / t_s_flat\n\n        input_transformed = _interpolate( input_dim, x_s_flat, y_s_flat,out_size)\n\n        output = input_transformed.reshape([num_batch, out_height, out_width, num_channels])\n        output = output.permute(0,3,1,2)\n        return output#, condition\n\n\n    output = _transform(theta, U, out_size)\n    return output#, condition"
  },
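  {
    "path": "Warp/Codes/utils/demo_homo_warp.py",
    "content": "# Hypothetical usage sketch, not part of the original release; assumed to sit next to\n# torch_homo_transform.py. transformer() expects a homography in NORMALIZED [-1, 1]\n# coordinates and backward-warps: output(p) = input(theta @ p). network.py converts a\n# pixel-space homography H with the sandwich M_inv @ H @ M; the same trick is shown here\n# for a pure translation.\nimport torch\nfrom torch_homo_transform import transformer\n\nimg_h, img_w = 128, 128\nimg = torch.rand(1, 3, img_h, img_w)\n\n# pixel-space homography: translation by (20, 10)\nH_pix = torch.tensor([[1., 0., 20.],\n                      [0., 1., 10.],\n                      [0., 0., 1.]]).unsqueeze(0)\nM = torch.tensor([[img_w / 2.0, 0., img_w / 2.0],\n                  [0., img_h / 2.0, img_h / 2.0],\n                  [0., 0., 1.]]).unsqueeze(0)\ntheta = torch.matmul(torch.matmul(torch.inverse(M), H_pix), M)  # normalized coordinates\n\nif torch.cuda.is_available():\n    img, theta = img.cuda(), theta.cuda()\nout = transformer(img, theta, (img_h, img_w))\nprint(out.shape)  # torch.Size([1, 3, 128, 128]); content shifts by (-20, -10) pixels\n"
  },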
  {
    "path": "Warp/Codes/utils/torch_tps_transform.py",
    "content": "import torch\nimport numpy as np\n\n# transforming an image (U) from target (control points) to source (control points)\n# all the points should be normalized from -1 ~1\n\ndef transformer(U, source, target, out_size):\n\n    def _repeat(x, n_repeats):\n\n        rep = torch.ones([n_repeats, ]).unsqueeze(0)\n        rep = rep.int()\n        x = x.int()\n\n        x = torch.matmul(x.reshape([-1,1]), rep)\n        return x.reshape([-1])\n\n    def _interpolate(im, x, y, out_size):\n\n        num_batch, num_channels , height, width = im.size()\n\n        height_f = height\n        width_f = width\n        out_height, out_width = out_size[0], out_size[1]\n\n        zero = 0\n        max_y = height - 1\n        max_x = width - 1\n\n        x = (x + 1.0)*(width_f) / 2.0\n        y = (y + 1.0) * (height_f) / 2.0\n\n        # do sampling\n        x0 = torch.floor(x).int()\n        x1 = x0 + 1\n        y0 = torch.floor(y).int()\n        y1 = y0 + 1\n\n        x0 = torch.clamp(x0, zero, max_x)\n        x1 = torch.clamp(x1, zero, max_x)\n        y0 = torch.clamp(y0, zero, max_y)\n        y1 = torch.clamp(y1, zero, max_y)\n        dim2 = torch.from_numpy( np.array(width) )\n        dim1 = torch.from_numpy( np.array(width * height) )\n\n        base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width)\n        if torch.cuda.is_available():\n            dim2 = dim2.cuda()\n            dim1 = dim1.cuda()\n            y0 = y0.cuda()\n            y1 = y1.cuda()\n            x0 = x0.cuda()\n            x1 = x1.cuda()\n            base = base.cuda()\n        base_y0 = base + y0 * dim2\n        base_y1 = base + y1 * dim2\n        idx_a = base_y0 + x0\n        idx_b = base_y1 + x0\n        idx_c = base_y0 + x1\n        idx_d = base_y1 + x1\n\n        # channels dim\n        im = im.permute(0,2,3,1)\n        im_flat = im.reshape([-1, num_channels]).float()\n\n\n        idx_a = idx_a.unsqueeze(-1).long()\n        idx_a = idx_a.expand(out_height * out_width * num_batch,num_channels)\n        Ia = torch.gather(im_flat, 0, idx_a)\n\n        idx_b = idx_b.unsqueeze(-1).long()\n        idx_b = idx_b.expand(out_height * out_width * num_batch, num_channels)\n        Ib = torch.gather(im_flat, 0, idx_b)\n\n        idx_c = idx_c.unsqueeze(-1).long()\n        idx_c = idx_c.expand(out_height * out_width * num_batch, num_channels)\n        Ic = torch.gather(im_flat, 0, idx_c)\n\n        idx_d = idx_d.unsqueeze(-1).long()\n        idx_d = idx_d.expand(out_height * out_width * num_batch, num_channels)\n        Id = torch.gather(im_flat, 0, idx_d)\n\n        x0_f = x0.float()\n        x1_f = x1.float()\n        y0_f = y0.float()\n        y1_f = y1.float()\n\n        wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)\n        wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)\n        wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)\n        wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)\n        output = wa*Ia+wb*Ib+wc*Ic+wd*Id\n\n        return output\n\n    def _meshgrid(height, width, source):\n\n        x_t = torch.matmul(torch.ones([height, 1]), torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 0))\n        y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1), torch.ones([1, width]))\n        if torch.cuda.is_available():\n            x_t = x_t.cuda()\n            y_t = y_t.cuda()\n\n        x_t_flat = x_t.reshape([1, 1, -1])\n        y_t_flat = y_t.reshape([1, 1, -1])\n\n        num_batch = source.size()[0]\n        px = torch.unsqueeze(source[:,:,0], 2)  # [bn, 
pn, 1]\n        py = torch.unsqueeze(source[:,:,1], 2)  # [bn, pn, 1]\n        if torch.cuda.is_available():\n            px = px.cuda()\n            py = py.cuda()\n        d2 = torch.square(x_t_flat - px) + torch.square(y_t_flat - py)\n        r = d2 * torch.log(d2 + 1e-6) # [bn, pn, h*w]\n        x_t_flat_g = x_t_flat.expand(num_batch, -1, -1)  # [bn, 1, h*w]\n        y_t_flat_g = y_t_flat.expand(num_batch, -1, -1)  # [bn, 1, h*w]\n        ones = torch.ones_like(x_t_flat_g) # [bn, 1, h*w]\n        if torch.cuda.is_available():\n            ones = ones.cuda()\n\n        grid = torch.cat((ones, x_t_flat_g, y_t_flat_g, r), 1) # [bn, 3+pn, h*w]\n\n        #if torch.cuda.is_available():\n        #    grid = grid.cuda()\n        return grid\n\n    def _transform(T, source, input_dim, out_size):\n        num_batch, num_channels, height, width = input_dim.size()\n\n        out_height, out_width = out_size[0], out_size[1]\n        grid = _meshgrid(out_height, out_width, source) # [bn, 3+pn, h*w]\n\n        # transform A x (1, x_t, y_t, r1, r2, ..., rn) -> (x_s, y_s)\n        # [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w]\n        T_g = torch.matmul(T, grid)\n        x_s = T_g[:,0,:]\n        y_s = T_g[:,1,:]\n        x_s_flat = x_s.reshape([-1])\n        y_s_flat = y_s.reshape([-1])\n\n        input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat,out_size)\n\n        output = input_transformed.reshape([num_batch, out_height, out_width, num_channels])\n\n        output = output.permute(0,3,1,2)\n        return output#, condition\n\n\n    def _solve_system(source, target):\n        num_batch  = source.size()[0]\n        num_point  = source.size()[1]\n\n        np.set_printoptions(precision=8)\n\n        ones = torch.ones(num_batch, num_point, 1).float()\n        if torch.cuda.is_available():\n            ones = ones.cuda()\n        p = torch.cat([ones, source], 2) # [bn, pn, 3]\n\n        p_1 = p.reshape([num_batch, -1, 1, 3]) # [bn, pn, 1, 3]\n        p_2 = p.reshape([num_batch, 1, -1, 3])  # [bn, 1, pn, 3]\n        d2 = torch.sum(torch.square(p_1-p_2), 3) # p1 - p2: [bn, pn, pn, 3]   final output: [bn, pn, pn]\n\n        r = d2 * torch.log(d2 + 1e-6) # [bn, pn, pn]\n\n        zeros = torch.zeros(num_batch, 3, 3).float()\n        if torch.cuda.is_available():\n            zeros = zeros.cuda()\n        W_0 = torch.cat((p, r), 2) # [bn, pn, 3+pn]\n        W_1 = torch.cat((zeros, p.permute(0,2,1)), 2) # [bn, 3, pn+3]\n        W = torch.cat((W_0, W_1), 1) # [bn, pn+3, pn+3]\n\n        W_inv = torch.inverse(W.type(torch.float64))\n\n\n        zeros2 = torch.zeros(num_batch, 3, 2)\n        if torch.cuda.is_available():\n            zeros2 = zeros2.cuda()\n        tp = torch.cat((target, zeros2), 1) # [bn, pn+3, 2]\n\n        T = torch.matmul(W_inv, tp.type(torch.float64)) # [bn, pn+3, 2]\n        T = T.permute(0, 2, 1) # [bn, 2, pn+3]\n\n\n        return T.type(torch.float32)\n\n    T = _solve_system(source, target)\n\n    output = _transform(T, source, U, out_size)\n\n    return output"
  },
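  {
    "path": "Warp/Codes/utils/demo_tps_warp.py",
    "content": "# Hypothetical usage sketch, not part of the original release; assumed to sit next to\n# torch_tps_transform.py, and the control-grid size is illustrative. Control points live\n# in normalized [-1, 1] coordinates; network.py passes (norm_mesh, norm_rigid_mesh) as\n# (source, target), i.e. the deformed mesh as source and the regular mesh as target. A\n# lightly perturbed source mesh is used here so the warp is a small, smooth deformation.\nimport torch\nfrom torch_tps_transform import transformer\n\nimg_h, img_w, gh, gw = 128, 128, 6, 8\nimg = torch.rand(1, 3, img_h, img_w)\n\n# regular (gh+1)x(gw+1) mesh in [-1, 1], flattened to bs*N*2\nxs = torch.linspace(-1., 1., gw + 1).unsqueeze(0).repeat(gh + 1, 1)\nys = torch.linspace(-1., 1., gh + 1).unsqueeze(1).repeat(1, gw + 1)\ntarget = torch.stack([xs, ys], 2).reshape(1, -1, 2)\nsource = target + 0.02 * torch.randn_like(target)  # perturbed control points\n\nif torch.cuda.is_available():\n    img, source, target = img.cuda(), source.cuda(), target.cuda()\nout = transformer(img, source, target, (img_h, img_w))\nprint(out.shape)  # torch.Size([1, 3, 128, 128])\n"
  },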
  {
    "path": "Warp/Codes/utils/torch_tps_transform2.py",
    "content": "import torch\nimport numpy as np\n\n\n# transforming an image (U) from target (control points) to source (control points)\n# all the points should be normalized from -1 ~1\n\n# compared with torch_tps_transform.py, this version move some operations from GPU to CPU to save GPU memory\n\ndef transformer(U, source, target, out_size):\n\n    def _repeat(x, n_repeats):\n\n        rep = torch.ones([n_repeats, ]).unsqueeze(0)\n        rep = rep.int()\n        x = x.int()\n\n        x = torch.matmul(x.reshape([-1,1]), rep)\n        return x.reshape([-1])\n\n    def _interpolate(im, x, y, out_size):\n\n        num_batch, num_channels , height, width = im.size()\n\n        height_f = height\n        width_f = width\n        out_height, out_width = out_size[0], out_size[1]\n\n        zero = 0\n        max_y = height - 1\n        max_x = width - 1\n\n        x = (x + 1.0)*(width_f) / 2.0\n        y = (y + 1.0) * (height_f) / 2.0\n\n        # do sampling\n        x0 = torch.floor(x).int()\n        x1 = x0 + 1\n        y0 = torch.floor(y).int()\n        y1 = y0 + 1\n\n        x0 = torch.clamp(x0, zero, max_x)\n        x1 = torch.clamp(x1, zero, max_x)\n        y0 = torch.clamp(y0, zero, max_y)\n        y1 = torch.clamp(y1, zero, max_y)\n        dim2 = torch.from_numpy( np.array(width) )\n        dim1 = torch.from_numpy( np.array(width * height) )\n\n        base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width)\n        if torch.cuda.is_available():\n            dim2 = dim2.cuda()\n            dim1 = dim1.cuda()\n            y0 = y0.cuda()\n            y1 = y1.cuda()\n            x0 = x0.cuda()\n            x1 = x1.cuda()\n            base = base.cuda()\n        base_y0 = base + y0 * dim2\n        base_y1 = base + y1 * dim2\n        idx_a = base_y0 + x0\n        idx_b = base_y1 + x0\n        idx_c = base_y0 + x1\n        idx_d = base_y1 + x1\n\n        # channels dim\n        im = im.permute(0,2,3,1)\n        im_flat = im.reshape([-1, num_channels]).float()\n\n\n        idx_a = idx_a.unsqueeze(-1).long()\n        idx_a = idx_a.expand(out_height * out_width * num_batch,num_channels)\n        Ia = torch.gather(im_flat, 0, idx_a)\n\n        idx_b = idx_b.unsqueeze(-1).long()\n        idx_b = idx_b.expand(out_height * out_width * num_batch, num_channels)\n        Ib = torch.gather(im_flat, 0, idx_b)\n\n        idx_c = idx_c.unsqueeze(-1).long()\n        idx_c = idx_c.expand(out_height * out_width * num_batch, num_channels)\n        Ic = torch.gather(im_flat, 0, idx_c)\n\n        idx_d = idx_d.unsqueeze(-1).long()\n        idx_d = idx_d.expand(out_height * out_width * num_batch, num_channels)\n        Id = torch.gather(im_flat, 0, idx_d)\n\n        x0_f = x0.float()\n        x1_f = x1.float()\n        y0_f = y0.float()\n        y1_f = y1.float()\n\n        wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)\n        wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)\n        wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)\n        wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)\n        output = wa*Ia+wb*Ib+wc*Ic+wd*Id\n\n        return output\n\n    def _meshgrid(height, width, source):\n\n        source = source.cpu()\n\n        x_t = torch.matmul(torch.ones([height, 1]), torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 0))\n        y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1), torch.ones([1, width]))\n\n        x_t_flat = x_t.reshape([1, 1, -1])\n        y_t_flat = y_t.reshape([1, 1, -1])\n\n        num_batch = source.size()[0]\n        
px = torch.unsqueeze(source[:, :, 0], 2)  # [bn, pn, 1]\n        py = torch.unsqueeze(source[:, :, 1], 2)  # [bn, pn, 1]\n\n        d2 = torch.square(x_t_flat - px) + torch.square(y_t_flat - py)\n        r = d2 * torch.log(d2 + 1e-6)  # [bn, pn, h*w]\n        x_t_flat_g = x_t_flat.expand(num_batch, -1, -1)  # [bn, 1, h*w]\n        y_t_flat_g = y_t_flat.expand(num_batch, -1, -1)  # [bn, 1, h*w]\n        ones = torch.ones_like(x_t_flat_g)  # [bn, 1, h*w]\n\n        grid = torch.cat((ones, x_t_flat_g, y_t_flat_g, r), 1)  # [bn, 3+pn, h*w]\n\n        # move the grid back to the GPU only when one is available\n        if torch.cuda.is_available():\n            grid = grid.cuda()\n        return grid\n\n    def _transform(T, source, input_dim, out_size):\n        num_batch, num_channels, height, width = input_dim.size()\n\n        out_height, out_width = out_size[0], out_size[1]\n        grid = _meshgrid(out_height, out_width, source)  # [bn, 3+pn, h*w]\n\n        # transform A x (1, x_t, y_t, r1, r2, ..., rn) -> (x_s, y_s)\n        # [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w]\n        T_g = torch.matmul(T, grid)\n        x_s = T_g[:, 0, :]\n        y_s = T_g[:, 1, :]\n        x_s_flat = x_s.reshape([-1])\n        y_s_flat = y_s.reshape([-1])\n\n        input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, out_size)\n\n        output = input_transformed.reshape([num_batch, out_height, out_width, num_channels])\n\n        output = output.permute(0, 3, 1, 2)\n        return output\n\n\n    def _solve_system(source, target):\n        num_batch = source.size()[0]\n        num_point = source.size()[1]\n\n        ones = torch.ones(num_batch, num_point, 1).float()\n        if torch.cuda.is_available():\n            ones = ones.cuda()\n        p = torch.cat([ones, source], 2)  # [bn, pn, 3]\n\n        p_1 = p.reshape([num_batch, -1, 1, 3])  # [bn, pn, 1, 3]\n        p_2 = p.reshape([num_batch, 1, -1, 3])  # [bn, 1, pn, 3]\n        d2 = torch.sum(torch.square(p_1 - p_2), 3)  # p_1 - p_2: [bn, pn, pn, 3]   final output: [bn, pn, pn]\n\n        r = d2 * torch.log(d2 + 1e-6)  # [bn, pn, pn]\n\n        # assemble the standard TPS system W = [[P | K], [0 | P^T]]\n        # (P: homogeneous control points, K: radial terms) and solve W T = [target; 0]\n        zeros = torch.zeros(num_batch, 3, 3).float()\n        if torch.cuda.is_available():\n            zeros = zeros.cuda()\n        W_0 = torch.cat((p, r), 2)  # [bn, pn, 3+pn]\n        W_1 = torch.cat((zeros, p.permute(0, 2, 1)), 2)  # [bn, 3, pn+3]\n        W = torch.cat((W_0, W_1), 1)  # [bn, pn+3, pn+3]\n\n        # invert in float64 for numerical stability\n        W_inv = torch.inverse(W.type(torch.float64))\n\n        zeros2 = torch.zeros(num_batch, 3, 2)\n        if torch.cuda.is_available():\n            zeros2 = zeros2.cuda()\n        tp = torch.cat((target, zeros2), 1)  # [bn, pn+3, 2]\n\n        T = torch.matmul(W_inv, tp.type(torch.float64))  # [bn, pn+3, 2]\n        T = T.permute(0, 2, 1)  # [bn, 2, pn+3]\n\n        return T.type(torch.float32)\n\n    T = _solve_system(source, target)\n\n    output = _transform(T, source, U, out_size)\n\n    return output\n\n
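# Example usage (illustrative sketch, not part of the original file):\n#   bs, h, w = 1, 360, 480\n#   img = torch.rand(bs, 3, h, w)              # image to warp\n#   src = torch.rand(bs, 12, 2) * 2 - 1        # control points, normalized to [-1, 1]\n#   dst = src + 0.05 * torch.randn(bs, 12, 2)  # perturbed target points\n#   out = transformer(img, src, dst, (h, w))   # warped image, [bs, 3, h, w]\n# (when CUDA is available, the input tensors should live on the GPU to match the guards above)\n"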
  },
  {
    "path": "Warp/model/.txt",
    "content": "\n"
  },
  {
    "path": "Warp/readme.md",
    "content": "## Train on UDIS-D\nSet the training dataset path in Warp/Codes/train.py.\n\n```\npython train.py\n```\n\n## Test on UDIS-D\nThe pre-trained model of warp is available at [Google Drive](https://drive.google.com/file/d/1GBwB0y3tUUsOYHErSqxDxoC_Om3BJUEt/view?usp=sharing) or [Baidu Cloud](https://pan.baidu.com/s/1Fx6YnQi9B2wvP_TOVAaBEA)(Extraction code: 1234).\n#### Calculate PSNR/SSIM\nSet the testing dataset path in Warp/Codes/test.py.\n\n```\npython test.py\n```\n\n#### Generate the warped images and corresponding masks\nSet the training/testing dataset path in Warp/Codes/test_output.py.\n\n```\npython test_output.py\n```\nThe warped images and masks will be generated and saved at the original training/testing dataset path. The results of average fusion will be saved at the current path.\n\n## Test on other datasets\nWhen testing on other datasets with different scenes and resolutions, we apply the iterative warp adaption to get better alignment performance.\n\nSet the 'path/img1_name/img2_name' in Warp/Codes/test_other.py. (By default, both img1 and img2 are placed under 'path')\n```\npython test_other.py\n```\nThe results before/after adaption will be generated and saved at 'path'.\n\n"
  },
  {
    "path": "Warp/summary/.txt",
    "content": "\n"
  },
  {
    "path": "environment.yml",
    "content": "name: nl\nchannels:\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/\n  - defaults\ndependencies:\n  - _anaconda_depends=2020.07=py38_0\n  - _ipyw_jlab_nb_ext_conf=0.1.0=py38_0\n  - _libgcc_mutex=0.1=main\n  - alabaster=0.7.12=py_0\n  - anaconda=custom=py38_1\n  - anaconda-client=1.7.2=py38_0\n  - anaconda-navigator=1.10.0=py38_0\n  - anaconda-project=0.8.4=py_0\n  - argh=0.26.2=py38_0\n  - argon2-cffi=20.1.0=py38h7b6447c_1\n  - asn1crypto=1.4.0=py_0\n  - astroid=2.4.2=py38_0\n  - astropy=4.0.2=py38h7b6447c_0\n  - async_generator=1.10=py_0\n  - atomicwrites=1.4.0=py_0\n  - attrs=20.3.0=pyhd3eb1b0_0\n  - autopep8=1.5.4=py_0\n  - babel=2.8.1=pyhd3eb1b0_0\n  - backcall=0.2.0=py_0\n  - backports=1.0=py_2\n  - backports.functools_lru_cache=1.6.1=py_0\n  - backports.shutil_get_terminal_size=1.0.0=py38_2\n  - backports.tempfile=1.0=py_1\n  - backports.weakref=1.0.post1=py_1\n  - beautifulsoup4=4.9.3=pyhb0f4dca_0\n  - bitarray=1.6.1=py38h27cfd23_0\n  - bkcharts=0.2=py38_0\n  - blas=1.0=mkl\n  - bleach=3.2.1=py_0\n  - blosc=1.20.1=hd408876_0\n  - bokeh=2.2.3=py38_0\n  - boto=2.49.0=py38_0\n  - bottleneck=1.3.2=py38heb32a55_1\n  - brotlipy=0.7.0=py38h7b6447c_1000\n  - bzip2=1.0.8=h7b6447c_0\n  - ca-certificates=2021.4.13=h06a4308_1\n  - cairo=1.14.12=h8948797_3\n  - certifi=2020.12.5=py38h06a4308_0\n  - cffi=1.14.3=py38he30daa8_0\n  - chardet=3.0.4=py38_1003\n  - click=7.1.2=py_0\n  - cloudpickle=1.6.0=py_0\n  - clyent=1.2.2=py38_1\n  - colorama=0.4.4=py_0\n  - conda-package-handling=1.7.2=py38h03888b9_0\n  - conda-verify=3.4.2=py_1\n  - contextlib2=0.6.0.post1=py_0\n  - cryptography=3.1.1=py38h1ba5d50_0\n  - cudatoolkit=11.0.221=h6bb024c_0\n  - curl=7.71.1=hbc83047_1\n  - cycler=0.10.0=py38_0\n  - cython=0.29.21=py38he6710b0_0\n  - cytoolz=0.11.0=py38h7b6447c_0\n  - dask=2.30.0=py_0\n  - dask-core=2.30.0=py_0\n  - dbus=1.13.18=hb2f20db_0\n  - decorator=4.4.2=py_0\n  - defusedxml=0.6.0=py_0\n  - diff-match-patch=20200713=py_0\n  - distributed=2.30.1=py38h06a4308_0\n  - docutils=0.16=py38_1\n  - entrypoints=0.3=py38_0\n  - et_xmlfile=1.0.1=py_1001\n  - expat=2.2.10=he6710b0_2\n  - fastcache=1.1.0=py38h7b6447c_0\n  - filelock=3.0.12=py_0\n  - flake8=3.8.4=py_0\n  - flask=1.1.2=py_0\n  - fontconfig=2.13.0=h9420a91_0\n  - freetype=2.10.4=h5ab3b9f_0\n  - fribidi=1.0.10=h7b6447c_0\n  - fsspec=0.8.3=py_0\n  - future=0.18.2=py38_1\n  - get_terminal_size=1.0.0=haa9412d_0\n  - gevent=20.9.0=py38h7b6447c_0\n  - glib=2.66.1=h92f7085_0\n  - glob2=0.7=py_0\n  - gmp=6.1.2=h6c8ec71_1\n  - gmpy2=2.0.8=py38hd5f6e3b_3\n  - graphite2=1.3.14=h23475e2_0\n  - greenlet=0.4.17=py38h7b6447c_0\n  - gst-plugins-base=1.14.0=hbbd80ab_1\n  - gstreamer=1.14.0=hb31296c_0\n  - h5py=2.10.0=py38h7918eee_0\n  - harfbuzz=2.4.0=hca77d97_1\n  - hdf5=1.10.4=hb1b8bf9_0\n  - heapdict=1.0.1=py_0\n  - html5lib=1.1=py_0\n  - icu=58.2=he6710b0_3\n  - idna=2.10=py_0\n  - imageio=2.9.0=py_0\n  - imagesize=1.2.0=py_0\n  - importlib_metadata=2.0.0=1\n  - iniconfig=1.1.1=py_0\n  - intel-openmp=2020.2=254\n  - intervaltree=3.1.0=py_0\n  - ipykernel=5.3.4=py38h5ca1d4c_0\n  - ipython=7.19.0=py38hb070fc8_0\n  - ipython_genutils=0.2.0=py38_0\n  - ipywidgets=7.5.1=py_1\n  - isort=5.6.4=py_0\n  - itsdangerous=1.1.0=py_0\n  - jbig=2.1=hdba287a_0\n  - jdcal=1.4.1=py_0\n  - 
jedi=0.17.1=py38_0\n  - jeepney=0.5.0=pyhd3eb1b0_0\n  - jinja2=2.11.2=py_0\n  - joblib=0.17.0=py_0\n  - jpeg=9b=h024ee3a_2\n  - json5=0.9.5=py_0\n  - jsonschema=3.2.0=py_2\n  - jupyter=1.0.0=py38_7\n  - jupyter_client=6.1.7=py_0\n  - jupyter_console=6.2.0=py_0\n  - jupyter_core=4.6.3=py38_0\n  - jupyterlab=2.2.6=py_0\n  - jupyterlab_pygments=0.1.2=py_0\n  - jupyterlab_server=1.2.0=py_0\n  - keyring=21.4.0=py38_1\n  - kiwisolver=1.3.0=py38h2531618_0\n  - krb5=1.18.2=h173b8e3_0\n  - lazy-object-proxy=1.4.3=py38h7b6447c_0\n  - lcms2=2.11=h396b838_0\n  - ld_impl_linux-64=2.33.1=h53a641e_7\n  - libarchive=3.4.2=h62408e4_0\n  - libcurl=7.71.1=h20c2e04_1\n  - libedit=3.1.20191231=h14c3975_1\n  - libffi=3.3=he6710b0_2\n  - libgcc-ng=9.1.0=hdf63c60_0\n  - libgfortran-ng=7.3.0=hdf63c60_0\n  - liblief=0.10.1=he6710b0_0\n  - libllvm10=10.0.1=hbcb73fb_5\n  - libllvm9=9.0.1=h4a3c616_1\n  - libpng=1.6.37=hbc83047_0\n  - libsodium=1.0.18=h7b6447c_0\n  - libspatialindex=1.9.3=he6710b0_0\n  - libssh2=1.9.0=h1ba5d50_1\n  - libstdcxx-ng=9.1.0=hdf63c60_0\n  - libtiff=4.1.0=h2733197_1\n  - libtool=2.4.6=h7b6447c_1005\n  - libuuid=1.0.3=h1bed415_2\n  - libuv=1.40.0=h7b6447c_0\n  - libxcb=1.14=h7b6447c_0\n  - libxml2=2.9.10=hb55368b_3\n  - libxslt=1.1.34=hc22bd24_0\n  - llvmlite=0.34.0=py38h269e1b5_4\n  - locket=0.2.0=py38_1\n  - lxml=4.6.1=py38hefd8a0e_0\n  - lz4-c=1.9.2=heb0550a_3\n  - lzo=2.10=h7b6447c_2\n  - markupsafe=1.1.1=py38h7b6447c_0\n  - matplotlib=3.3.2=0\n  - matplotlib-base=3.3.2=py38h817c723_0\n  - mccabe=0.6.1=py38_1\n  - mistune=0.8.4=py38h7b6447c_1000\n  - mkl=2020.2=256\n  - mkl-service=2.3.0=py38he904b0f_0\n  - mkl_fft=1.2.0=py38h23d657b_0\n  - mkl_random=1.1.1=py38h0573a6f_0\n  - mock=4.0.2=py_0\n  - more-itertools=8.6.0=pyhd3eb1b0_0\n  - mpc=1.1.0=h10f8cd9_1\n  - mpfr=4.0.2=hb69a4c5_1\n  - mpmath=1.1.0=py38_0\n  - msgpack-python=1.0.0=py38hfd86e86_1\n  - multipledispatch=0.6.0=py38_0\n  - navigator-updater=0.2.1=py38_0\n  - nbclient=0.5.1=py_0\n  - nbconvert=6.0.7=py38_0\n  - nbformat=5.0.8=py_0\n  - ncurses=6.2=he6710b0_1\n  - nest-asyncio=1.4.2=pyhd3eb1b0_0\n  - networkx=2.5=py_0\n  - ninja=1.10.2=hff7bd54_1\n  - nltk=3.5=py_0\n  - nose=1.3.7=py38_2\n  - notebook=6.1.4=py38_0\n  - numba=0.51.2=py38h0573a6f_1\n  - numexpr=2.7.1=py38h423224d_0\n  - numpy-base=1.19.2=py38hfa32c7d_0\n  - numpydoc=1.1.0=pyhd3eb1b0_1\n  - olefile=0.46=py_0\n  - openpyxl=3.0.5=py_0\n  - openssl=1.1.1k=h27cfd23_0\n  - packaging=20.4=py_0\n  - pandas=1.1.3=py38he6710b0_0\n  - pandoc=2.11=hb0f4dca_0\n  - pandocfilters=1.4.3=py38h06a4308_1\n  - pango=1.45.3=hd140c19_0\n  - parso=0.7.0=py_0\n  - partd=1.1.0=py_0\n  - patchelf=0.12=he6710b0_0\n  - path=15.0.0=py38_0\n  - path.py=12.5.0=0\n  - pathlib2=2.3.5=py38_0\n  - pathtools=0.1.2=py_1\n  - patsy=0.5.1=py38_0\n  - pcre=8.44=he6710b0_0\n  - pep8=1.7.1=py38_0\n  - pexpect=4.8.0=py38_0\n  - pickleshare=0.7.5=py38_1000\n  - pip=20.2.4=py38h06a4308_0\n  - pixman=0.40.0=h7b6447c_0\n  - pkginfo=1.6.1=py38h06a4308_0\n  - pluggy=0.13.1=py38_0\n  - ply=3.11=py38_0\n  - prometheus_client=0.8.0=py_0\n  - prompt-toolkit=3.0.8=py_0\n  - prompt_toolkit=3.0.8=0\n  - psutil=5.7.2=py38h7b6447c_0\n  - ptyprocess=0.6.0=py38_0\n  - py=1.9.0=py_0\n  - py-lief=0.10.1=py38h403a769_0\n  - pycodestyle=2.6.0=py_0\n  - pycosat=0.6.3=py38h7b6447c_1\n  - pycparser=2.20=py_2\n  - pycurl=7.43.0.6=py38h1ba5d50_0\n  - pydocstyle=5.1.1=py_0\n  - pyflakes=2.2.0=py_0\n  - pygments=2.7.2=pyhd3eb1b0_0\n  - pylint=2.6.0=py38_0\n  - pyodbc=4.0.30=py38he6710b0_0\n  - pyopenssl=19.1.0=py_1\n  - 
pyparsing=2.4.7=py_0\n  - pyqt=5.9.2=py38h05f1152_4\n  - pyrsistent=0.17.3=py38h7b6447c_0\n  - pysocks=1.7.1=py38_0\n  - pytables=3.6.1=py38h9fd0a39_0\n  - pytest=6.1.1=py38_0\n  - python=3.8.5=h7579374_1\n  - python-dateutil=2.8.1=py_0\n  - python-jsonrpc-server=0.4.0=py_0\n  - python-language-server=0.35.1=py_0\n  - python-libarchive-c=2.9=py_0\n  - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0\n  - pytz=2020.1=py_0\n  - pywavelets=1.1.1=py38h7b6447c_2\n  - pyxdg=0.27=pyhd3eb1b0_0\n  - pyyaml=5.3.1=py38h7b6447c_1\n  - pyzmq=19.0.2=py38he6710b0_1\n  - qdarkstyle=2.8.1=py_0\n  - qt=5.9.7=h5867ecd_1\n  - qtawesome=1.0.1=py_0\n  - qtconsole=4.7.7=py_0\n  - qtpy=1.9.0=py_0\n  - readline=8.0=h7b6447c_0\n  - regex=2020.10.15=py38h7b6447c_0\n  - requests=2.24.0=py_0\n  - ripgrep=12.1.1=0\n  - rope=0.18.0=py_0\n  - rtree=0.9.4=py38_1\n  - ruamel_yaml=0.15.87=py38h7b6447c_1\n  - scikit-learn=0.23.2=py38h0573a6f_0\n  - scipy=1.5.2=py38h0b6359f_0\n  - seaborn=0.11.0=py_0\n  - secretstorage=3.1.2=py38_0\n  - send2trash=1.5.0=py38_0\n  - setuptools=50.3.1=py38h06a4308_1\n  - simplegeneric=0.8.1=py38_2\n  - singledispatch=3.4.0.3=py_1001\n  - sip=4.19.13=py38he6710b0_0\n  - six=1.15.0=py38h06a4308_0\n  - snappy=1.1.8=he6710b0_0\n  - snowballstemmer=2.0.0=py_0\n  - sortedcollections=1.2.1=py_0\n  - sortedcontainers=2.2.2=py_0\n  - soupsieve=2.0.1=py_0\n  - sphinx=3.2.1=py_0\n  - sphinxcontrib=1.0=py38_1\n  - sphinxcontrib-applehelp=1.0.2=py_0\n  - sphinxcontrib-devhelp=1.0.2=py_0\n  - sphinxcontrib-htmlhelp=1.0.3=py_0\n  - sphinxcontrib-jsmath=1.0.1=py_0\n  - sphinxcontrib-qthelp=1.0.3=py_0\n  - sphinxcontrib-serializinghtml=1.1.4=py_0\n  - sphinxcontrib-websupport=1.2.4=py_0\n  - spyder=4.1.5=py38_0\n  - spyder-kernels=1.9.4=py38_0\n  - sqlalchemy=1.3.20=py38h7b6447c_0\n  - sqlite=3.33.0=h62c20be_0\n  - statsmodels=0.12.0=py38h7b6447c_0\n  - sympy=1.6.2=py38h06a4308_1\n  - tbb=2020.3=hfd86e86_0\n  - tblib=1.7.0=py_0\n  - terminado=0.9.1=py38_0\n  - testpath=0.4.4=py_0\n  - threadpoolctl=2.1.0=pyh5ca1d4c_0\n  - tifffile=2020.10.1=py38hdd07704_2\n  - tk=8.6.10=hbc83047_0\n  - toml=0.10.1=py_0\n  - toolz=0.11.1=py_0\n  - torchvision=0.8.2=py38_cu110\n  - tornado=6.0.4=py38h7b6447c_1\n  - tqdm=4.50.2=py_0\n  - traitlets=5.0.5=py_0\n  - typing_extensions=3.7.4.3=py_0\n  - ujson=4.0.1=py38he6710b0_0\n  - unicodecsv=0.14.1=py38_0\n  - unixodbc=2.3.9=h7b6447c_0\n  - urllib3=1.25.11=py_0\n  - watchdog=0.10.3=py38_0\n  - wcwidth=0.2.5=py_0\n  - webencodings=0.5.1=py38_1\n  - werkzeug=1.0.1=py_0\n  - wheel=0.35.1=py_0\n  - widgetsnbextension=3.5.1=py38_0\n  - wrapt=1.11.2=py38h7b6447c_0\n  - wurlitzer=2.0.1=py38_0\n  - xlrd=1.2.0=py_0\n  - xlsxwriter=1.3.7=py_0\n  - xlwt=1.3.0=py38_0\n  - xmltodict=0.12.0=py_0\n  - xz=5.2.5=h7b6447c_0\n  - yaml=0.2.5=h7b6447c_0\n  - yapf=0.30.0=py_0\n  - zeromq=4.3.3=he6710b0_3\n  - zict=2.0.0=py_0\n  - zipp=3.4.0=pyhd3eb1b0_0\n  - zlib=1.2.11=h7b6447c_3\n  - zope=1.0=py38_1\n  - zope.event=4.5.0=py38_0\n  - zope.interface=5.1.2=py38h7b6447c_0\n  - zstd=1.4.5=h9ceee32_0\n  - pip:\n    - absl-py==0.12.0\n    - cachetools==5.2.0\n    - einops==0.3.0\n    - google-auth==2.8.0\n    - google-auth-oauthlib==0.4.6\n    - grpcio==1.46.3\n    - importlib-metadata==4.11.4\n    - markdown==3.3.7\n    - medpy==0.4.0\n    - ml-collections==0.1.0\n    - numpy==1.19.5\n    - oauthlib==3.2.0\n    - opencv-python-headless==4.5.1.48\n    - pillow==9.1.1\n    - protobuf==3.17.0\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - requests-oauthlib==1.3.1\n    - rsa==4.8\n    - 
scikit-image==0.15.0\n    - simpleitk==2.0.2\n    - tensorboard==2.9.0\n    - tensorboard-data-server==0.6.1\n    - tensorboard-plugin-wit==1.8.1\n    - timm==0.4.9\n    - yacs==0.1.6\nprefix: /root/anaconda3/envs/nl\n"
  }
]