Repository: lisiyao21/AnimeInbet Branch: main Commit: cc5554feb9d8 Files: 35 Total size: 217.5 KB Directory structure: gitextract_oa9f83a9/ ├── .gitignore ├── README.md ├── compute_cd.py ├── configs/ │ └── cr_inbetweener_full.yaml ├── corr/ │ ├── configs/ │ │ └── vtx_corr.yaml │ ├── datasets/ │ │ ├── __init__.py │ │ └── ml_dataset.py │ ├── experiments/ │ │ └── vtx_corr/ │ │ └── ckpt/ │ │ └── .gitkeep │ ├── main.py │ ├── models/ │ │ ├── __init__.py │ │ └── supergluet.py │ ├── srun.sh │ ├── utils/ │ │ ├── log.py │ │ └── visualize_vtx_corr.py │ └── vtx_matching.py ├── data/ │ └── README.md ├── datasets/ │ ├── __init__.py │ ├── ml_seq.py │ └── vd_seq.py ├── download.sh ├── experiments/ │ └── inbetweener_full/ │ └── ckpt/ │ └── .gitkeep ├── inbetween.py ├── inbetween_results/ │ └── .gitkeep ├── main.py ├── models/ │ ├── __init__.py │ ├── inbetweener_with_mask2.py │ └── inbetweener_with_mask_with_spec.py ├── requirement.txt ├── srun.sh └── utils/ ├── chamfer_distance.py ├── log.py ├── visualize_inbetween.py ├── visualize_inbetween2.py ├── visualize_inbetween3.py └── visualize_video.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ */__pycache__/* *.pt *.jpg *.pyc data/ml100_norm/ data/ml144* data/*.zip ================================================ FILE: README.md ================================================ # AnimeInbet Code for ICCV 2023 paper "Deep Geometrized Cartoon Line Inbetweening" [[Paper]](https://openaccess.thecvf.com/content/ICCV2023/papers/Siyao_Deep_Geometrized_Cartoon_Line_Inbetweening_ICCV_2023_paper.pdf) | [[Video Demo]](https://youtu.be/iUF-LsqFKpI?si=9FViAZUyFdSfZzS5) | [[Data (Google Drive)]](https://drive.google.com/file/d/1SNRGajIECxNwRp6ZJ0IlY7AEl2mRm2DR/view?usp=sharing) ✨ Do not hesitate to give a star! Thank you! ✨ ![image](https://github.com/lisiyao21/AnimeInbet/blob/main/figures/inbet_gif.gif) > We aim to address a significant but understudied problem in the anime industry, namely the inbetweening of cartoon line drawings. Inbetweening involves generating intermediate frames between two black-and-white line drawings and is a time-consuming and expensive process that can benefit from automation. However, existing frame interpolation methods that rely on matching and warping whole raster images are unsuitable for line inbetweening and often produce blurring artifacts that damage the intricate line structures. To preserve the precision and detail of the line drawings, we propose a new approach, AnimeInbet, which geometrizes raster line drawings into graphs of endpoints and reframes the inbetweening task as a graph fusion problem with vertex repositioning. Our method can effectively capture the sparsity and unique structure of line drawings while preserving the details during inbetweening. This is made possible via our novel modules, i.e., vertex geometric embedding, a vertex correspondence Transformer, an effective mechanism for vertex repositioning and a visibility predictor. To train our method, we introduce MixamoLine240, a new dataset of line drawings with ground truth vectorization and matching labels. Our experiments demonstrate that AnimeInbet synthesizes high-quality, clean, and complete intermediate line drawings, outperforming existing methods quantitatively and qualitatively, especially in cases with large motions. 
# ML240 Data

The implementation of AnimeInbet depends on matching the line vertices of two adjacent frames. To supervise the learning of vertex correspondence, we built **MixamoLine240** (ML240), a large-scale dataset of sequential cartoon line drawings. ML240 contains a training set (100 sequences), a validation set (44 sequences) and a test set (100 sequences). Each sequence is a folder of consecutive line-art frames, each paired with a per-frame vectorization label (see the layout below).

To use the data, please first download it from [link](https://drive.google.com/file/d/1SNRGajIECxNwRp6ZJ0IlY7AEl2mRm2DR/view?usp=sharing) and uncompress it into the **data** folder under this project directory. After decompression, the data will be organized like

    data
    |_ ml100_norm
    |  |_ all
    |     |_ frames
    |     |  |_ chip_abe
    |     |  |  |_ Image0001.png
    |     |  |  |_ Image0002.png
    |     |  |  |_ ...
    |     |  |_ ...
    |     |_ labels
    |        |_ chip_abe
    |        |  |_ Line0001.json
    |        |  |_ Line0002.json
    |        |  |_ ...
    |        |_ ...
    |_ ml144_norm_100_44_split
       |_ test
       |  |_ frames
       |  |  |_ breakdance_1990_police
       |  |  |  |_ Image0001.png
       |  |  |  |_ Image0002.png
       |  |  |  |_ ...
       |  |  |_ ...
       |  |_ labels
       |     |_ breakdance_1990_police
       |     |  |_ Line0001.json
       |     |  |_ Line0002.json
       |     |  |_ ...
       |     |_ ...
       |_ train
          |_ frames
          |  |_ breakdance_1990_ganfaul
          |  |  |_ Image0001.png
          |  |  |_ Image0002.png
          |  |  |_ ...
          |  |_ ...
          |_ labels
             |_ breakdance_1990_ganfaul
             |  |_ Line0001.json
             |  |_ Line0002.json
             |  |_ ...
             |_ ...

Each json file in the "labels" folder (for example, `ml100_norm/all/labels/chip_abe/Line0001.json`) holds the vectorization/geometrization labels of the corresponding image in the "frames" folder (`ml100_norm/all/frames/chip_abe/Image0001.png`). Each json file contains three components: (1) **vertex location**: the 2D positions of the line-art vertices, (2) **connection**: the adjacency table of the vector graph, and (3) **original index**: the index of each vertex in the original 3D mesh. A minimal loading sketch is given at the end of this section.

# Code

## Environment

    conda create -n inbetween python=3.8
    conda activate inbetween
    conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=10.1 -c pytorch
    pip install -r requirement.txt

![image](https://github.com/lisiyao21/AnimeInbet/blob/main/figures/pipeline.png)

In this code, the whole pipeline is separated into two parts: (1) vertex correspondence and (2) inbetweening/synthesis. The first part is trained to match the vertices of the two input vector graphs and covers the "vertex embedding" and "vertex corr. Transformer" modules; "repositioning propagation" and "graph fusion" are done in the second part. The first part lives under ./corr, and the second part is everything else.

We provide a pretrained correspondence network weight ([link](https://drive.google.com/file/d/1Edc-XGyMXqXDdfBYoglDMkBf7_AYZU0p/view?usp=sharing)) and a pretrained whole-pipeline weight ([link](https://drive.google.com/file/d/1cemJCBNdcTvJ9LWCA_5LmDDorwEb-u7M/view?usp=sharing)). For correspondence, please decompress the weight (epoch_50.pt) into ./corr/experiments/vtx_corr/ckpt. For the whole pipeline, please decompress the weight (epoch_20.pt) into ./experiments/inbetweener_full/ckpt/.

## Train & test corr.

For training, first cd into the ./corr folder and then run

    sh srun.sh configs/vtx_corr.yaml train [your node name] 1

If you don't use Slurm on your computer/cluster, you can run

    python -u main.py --config configs/vtx_corr.yaml --train

For testing the correspondence network, please run

    sh srun.sh configs/vtx_corr.yaml eval [your node name] 1

or

    python -u main.py --config configs/vtx_corr.yaml --eval

You may directly run the test code after downloading the weights, without training.
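As a quick sanity check on the label format described in the Data section above, here is a minimal sketch that loads one label JSON and re-rasterizes its vector graph. The path is illustrative; `read_json` in `corr/datasets/ml_dataset.py` performs the same parsing.

    import json
    import cv2
    import numpy as np

    # Illustrative path; any label file in the dataset works.
    label_path = 'data/ml100_norm/all/labels/chip_abe/Line0001.json'

    with open(label_path) as f:
        label = json.load(f)

    v2d = np.array(label['vertex location'])      # (N, 2) vertex positions in pixels
    topo = label['connection']                    # adjacency table: topo[i] lists the neighbors of vertex i
    orig_idx = np.array(label['original index'])  # per-vertex index into the source 3D mesh

    # Re-rasterize the vector graph onto a white 720x720 canvas
    # (frames are 720x720; the data loaders clamp coordinates to [0, 719]).
    canvas = np.full((720, 720, 3), 255, np.uint8)
    for v, nbs in enumerate(topo):
        for nb in nbs:
            p, q = v2d[v], v2d[nb]
            cv2.line(canvas, (int(p[0]), int(p[1])), (int(q[0]), int(q[1])), (0, 0, 0), 1)
    cv2.imwrite('label_check.png', canvas)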
## Train & test the whole inbetweening pipeline

For training the whole pipeline, please first cd back from ./corr to the project root folder and run

    sh srun.sh configs/cr_inbetweener_full.yaml train [your node name] 1

or

    python -u main.py --config configs/cr_inbetweener_full.yaml --train

For testing, please run

    sh srun.sh configs/cr_inbetweener_full.yaml eval [your node name] 1

or

    python -u main.py --config configs/cr_inbetweener_full.yaml --eval

The inbetweened results will be stored in the ./inbetween_results folder.

### Compute CD values

The CD code is under utils/chamfer_distance.py. Please run

    python compute_cd.py --gt ./data/ml100_norm/all/frames --generated ./inbetween_results/test_gap=5

If everything goes right, the score will match the one reported in the paper.

# Citation

If you use our code or data, or find our work inspiring, please kindly cite our paper:

    @inproceedings{siyao2023inbetween,
        title={Deep Geometrized Cartoon Line Inbetweening},
        author={Siyao, Li and Gu, Tianpei and Xiao, Weiye and Ding, Henghui and Liu, Ziwei and Loy, Chen Change},
        booktitle={ICCV},
        year={2023}
    }

# License

ML240 is released under CC BY-NC-SA 4.0. The code is released for non-commercial use only.

================================================
FILE: compute_cd.py
================================================
import argparse
import os

import cv2
import numpy as np

from utils.chamfer_distance import cd_score

if __name__ == "__main__":
    cds = []
    parser = argparse.ArgumentParser()
    parser.add_argument('--generated', type=str)
    parser.add_argument('--gt', type=str)
    args = parser.parse_args()
    gen_dir = args.generated
    gt_dir = args.gt

    print('computing CD...', flush=True)
    for subfolder in os.listdir(gt_dir):
        # print(subfolder, len(cds), flush=True)
        for img in os.listdir(os.path.join(gt_dir, subfolder)):
            if not img.endswith('.png'):
                continue
            img_gt = cv2.imread(os.path.join(gt_dir, subfolder, img))
            # generated frames are named '<clip>_LineXXXX.png'
            pred_name = subfolder + '_' + img.replace('Image', 'Line')
            if not os.path.exists(os.path.join(gen_dir, pred_name)):
                continue
            img_pred = cv2.imread(os.path.join(gen_dir, pred_name))
            this_cd = cd_score(img_gt, img_pred)
            cds.append(this_cd)
            # print(this_cd, flush=True)

    print('GT: ', gt_dir)
    print('>>> Gen: ', gen_dir)
    # report the mean CD in units of 1e-5, together with the number of evaluated frames
    print('>>> CD: ', np.mean(cds) / 1e-5, '(averaged over', len(cds), 'frames)')

================================================
FILE: configs/cr_inbetweener_full.yaml
================================================
model:
  name: InbetweenerTM
  corr_model:
    descriptor_dim: 128
    keypoint_encoder: [32, 64, 128]
    GNN_layer_num: 12
    sinkhorn_iterations: 20
    match_threshold: 0.2
  descriptor_dim: 128
  pos_weight: 0.2

optimizer:
  type: Adam
  kwargs:
    lr: 0.0001
    betas: [0.9, 0.999]
    weight_decay: 0
  schedular_kwargs:
    milestones: [80]
    gamma: 0.1

data:
  train:
    root: 'data/ml144_norm_100_44_split/'
    batch_size: 1
    gap: 5
    type: 'train'
    model: None
    action: None
    mode: 'train'
  test:
    root: 'data/ml100_norm/'
    batch_size: 1
    gap: 5
    type: 'all'
    model: None
    action: None
    mode: 'eval'
    use_vs: False

testing:
  ckpt_epoch: 20
batch_size: 8
corr_weights: './corr/experiments/vtx_corr/ckpt/epoch_50.pt'
imwrite_dir: ./inbetween_results/test_gap=5
expname: inbetweener_full
epoch: 20
save_per_epochs: 1
log_per_updates: 1
test_freq: 10
seed: 42

================================================
FILE: corr/configs/vtx_corr.yaml
================================================
model:
  name: SuperGlueT
  descriptor_dim: 128
  keypoint_encoder: [32, 64, 128]
  GNN_layer_num: 12
  sinkhorn_iterations: 20
  match_threshold: 0.2
  descriptor_dim: 128

optimizer:
  type: Adam
  kwargs:
    lr: 0.00001
    betas: [0.9,
0.999] weight_decay: 0 schedular_kwargs: milestones: [50, 150] gamma: 0.1 data: train: batch_size: 1 gap: 5 model: None action: None type: 'train' mode: 'train' test: batch_size: 1 gap: 5 type: 'test' model: None action: None mode: 'eval' testing: ckpt_epoch: 50 batch_size: 8 expname: vtx_corr epoch: 50 save_per_epochs: 1 log_per_updates: 1 test_freq: 1 seed: 42 ================================================ FILE: corr/datasets/__init__.py ================================================ from .ml_dataset import MixamoLineArt from .ml_dataset import fetch_dataloader __all__ = ['MixamoLineArt', 'fetch_dataloader'] ================================================ FILE: corr/datasets/ml_dataset.py ================================================ import numpy as np import torch import torch.utils.data as data import torch.nn.functional as F # import networkx as nx import os import math import random from glob import glob import os.path as osp import sys import argparse import cv2 from collections import Counter import json import sknetwork from sknetwork.embedding import Spectral def read_json(file_path): """ input: json file path output: 2d vertex, connections, and index numbers in original 3D space """ with open(file_path) as file: data = json.load(file) vertex2d = np.array(data['vertex location']) topology = data['connection'] index = np.array(data['original index']) return vertex2d, topology, index def ids_to_mat(id1, id2): """ inputs are two list of vertex index in original 3D mesh """ corr1 = np.zeros(len(id1)) - 1.0 corr2 = np.zeros(len(id2)) - 1.0 id1 = np.array(id1).astype(int)[:, None] id2 = np.array(id2).astype(int) mat = (id1 == id2) pos12 = np.arange(len(id2))[None].repeat(len(id1), 0) pos21 = np.arange(len(id1))[None].repeat(len(id2), 0) corr1[mat.astype(int).sum(1).astype(bool)] = pos12[mat] corr2[mat.transpose().astype(int).sum(1).astype(bool)] = pos21[mat.transpose()] return mat, corr1, corr2 def adj_matrix(topology): """ topology is the adj table; returns adj matrix """ gsize = len(topology) adj = np.zeros((gsize, gsize)).astype(float) for v in range(gsize): adj[v][v] = 1.0 for nb in topology[v]: adj[v][nb] = 1.0 adj[nb][v] = 1.0 return adj class MixamoLineArt(data.Dataset): def __init__(self, root, gap=0, split='train', model=None, action=None, mode='train', max_len=3050, use_vs=False): """ input: root: the root folder of the line art data gap: how many frames between two frames split: train or test model: indicate a specific character (default None) action: indicate a specific action (default None) """ super(MixamoLineArt, self).__init__() if model == 'None': model = None if action == 'None': action = None self.is_train = True if mode == 'train' else False self.is_eval = True if mode == 'eval' else False # self.is_train = False self.max_len = max_len self.image_list = [] self.label_list = [] if use_vs: label_root = osp.join(root, split, 'labels_vs') else: label_root = osp.join(root, split, 'labels') image_root = osp.join(root, split, 'frames') self.spectral = Spectral(64, normalized=False) for clip in os.listdir(image_root): skip = False if model != None: for mm in model: if mm in clip: skip = True if action != None: for aa in action: if aa in clip: skip = True if skip: continue image_list = sorted(glob(osp.join(image_root, clip, '*.png'))) label_list = sorted(glob(osp.join(label_root, clip, '*.json'))) if len(image_list) != len(label_list): print(image_root, flush=True) continue for i in range(len(image_list) - (gap+1)): self.image_list += [ [image_list[i], 
image_list[i+gap+1]] ] for i in range(len(label_list) - (gap+1)): self.label_list += [ [label_list[i], label_list[i+gap+1]] ] # print(clip) print('Len of Frame is ', len(self.image_list)) print('Len of Label is ', len(self.label_list)) def __getitem__(self, index): # load image/label files # image crop to a square, 2d label same operation # index to index matching # spectral embedding # test does not need index matching index = index % len(self.image_list) file_name = self.label_list[index][0][:-4] img1 = cv2.imread(self.image_list[index][0]) img2 = cv2.imread(self.image_list[index][1]) v2d1, topo1, id1 = read_json(self.label_list[index][0]) v2d2, topo2, id2 = read_json(self.label_list[index][1]) for ii in range(len(topo1)): # if not len(topo1[ii]): topo1[ii].append(ii) for ii in range(len(topo2)): topo2[ii].append(ii) m, n = len(v2d1), len(v2d2) # img1, v2d1 = crop_img(img1, np.array(v2d1)) # img2, v2d2 = crop_img(img2, np.array(v2d2)) if len(img1.shape) == 2: img1 = np.tile(img1[...,None], (1, 1, 3)) img2 = np.tile(img2[...,None], (1, 1, 3)) else: img1 = img1[..., :3] img2 = img2[..., :3] img1 = torch.from_numpy(img1).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 img2 = torch.from_numpy(img2).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 v2d1 = torch.from_numpy(v2d1) v2d2 = torch.from_numpy(v2d2) v2d1[v2d1 > 719] = 719 v2d1[v2d1 < 0] = 0 v2d2[v2d2 > 719] = 719 v2d2[v2d2 < 0] = 0 adj1 = sknetwork.data.from_adjacency_list(topo1, matrix_only=True, reindex=False).toarray() adj2 = sknetwork.data.from_adjacency_list(topo2, matrix_only=True, reindex=False).toarray() # note here we compute the spectral embedding of adj matrix in data loading period # since it needs cpu computation and is not friendy to our cluster's computation # put them here to use multi-cpu pre-computing before network forward flow spec0, spec1 = np.abs(self.spectral.fit_transform(adj1)), np.abs(self.spectral.fit_transform(adj2)) mat_index, corr1, corr2 = ids_to_mat(id1, id2) mat_index = torch.from_numpy(mat_index).float() corr1 = torch.from_numpy(corr1).float() corr2 = torch.from_numpy(corr2).float() if self.is_train: # if False: v2d1 = torch.nn.functional.pad(v2d1, (0, 0, 0, self.max_len - m), mode='constant', value=0) v2d2 = torch.nn.functional.pad(v2d2, (0, 0, 0, self.max_len - n), mode='constant', value=0) corr1 = torch.nn.functional.pad(corr1, (0, self.max_len - m), mode='constant', value=0) corr2 = torch.nn.functional.pad(corr2, (0, self.max_len - n), mode='constant', value=0) mask0, mask1 = torch.zeros(self.max_len).float(), torch.zeros(self.max_len).float() mask0[:m] = 1 mask1[:n] = 1 else: mask0, mask1 = torch.ones(m).float(), torch.ones(n).float() # not return id anymore. 
too slow if self.is_eval: return{ 'keypoints0': v2d1, 'keypoints1': v2d2, 'topo0': [topo1], 'topo1': [topo2], # 'id0': id1, # 'id1': id2, 'adj_mat0': spec0, 'adj_mat1': spec1, 'image0': img1, 'image1': img2, 'all_matches': corr1, 'm01': corr1, 'm10': corr2, 'ms': m, 'ns': n, 'mask0': mask0, 'mask1': mask1, 'file_name': file_name, # 'with_match': True } if not self.is_train: return{ 'keypoints0': v2d1, 'keypoints1': v2d2, # 'topo0': topo1, # 'topo1': topo2, # 'id0': id1, # 'id1': id2, 'adj_mat0': spec0, 'adj_mat1': spec1, 'image0': img1, 'image1': img2, 'all_matches': corr1, 'm01': corr1, 'm10': corr2, 'ms': m, 'ns': n, 'mask0': mask0, 'mask1': mask1, 'file_name': file_name, # 'with_match': True } else: return{ 'keypoints0': v2d1, 'keypoints1': v2d2, # 'topo0': topo1, # 'topo1': topo2, # 'id0': id1, # 'id1': id2, 'adj_mat0': spec0, 'adj_mat1': spec1, 'image0': img1, 'image1': img2, 'all_matches': corr1, 'm01': corr1, 'm10': corr2, 'ms': m, 'ns': n, 'mask0': mask0, 'mask1': mask1, 'file_name': file_name, # 'with_match': True } def __rmul__(self, v): self.index_list = v * self.index_list self.seg_list = v * self.seg_list self.image_list = v * self.image_list return self def __len__(self): return len(self.image_list) def worker_init_fn(worker_id): np.random.seed(np.random.get_state()[1][0] + worker_id) def fetch_dataloader(args, type='train',): lineart = MixamoLineArt(root=args.root if hasattr(args, 'root') else '../data/ml144_norm_100_44_split/', gap=args.gap, split=args.type, model=args.model, action=args.action, mode=args.mode if hasattr(args, 'mode') else 'train', use_vs=args.use_vs if hasattr(args, 'use_vs') else False) train_loader = data.DataLoader(lineart, batch_size=args.batch_size, pin_memory=True, shuffle=True, num_workers=8, drop_last=True, worker_init_fn=worker_init_fn) if args.mode != 'train': loader = data.DataLoader(lineart, batch_size=args.batch_size, pin_memory=True, shuffle=False, num_workers=8) return train_loader if __name__ == '__main__': torch.multiprocessing.set_sharing_strategy('file_system') args = argparse.Namespace() # args.subset = 'agent' args.batch_size = 1 args.gap = 5 args.type = 'test' args.model = ['ganfaul', 'firlscout', 'jolleen', 'kachujin', 'knight', 'maria', 'michelle', 'peasant', 'timmy', 'uriel'] args.action = ['hip_hop', 'slash'] # args.model = None # args.action = None args.use_vs = False # args.model = ['warrok', 'police'] args.action = ['breakdance', 'capoeira', 'chapa-', 'fist_fight', 'flying', 'climb', 'running', 'reaction', 'magic', 'tripping'] args.mode = 'eval' args.root='/mnt/lustre/syli/inbetween/data/12by12/ml144_norm_100_44_split/' # args.stage = 'anime' # args.image_size = (368, 368) # lineart = MixamoLineArt(root='/mnt/lustre/syli/inbetween/data/12by12/ml144/', gap=0, split='train') lineart = fetch_dataloader(args) # lineart = MixamoLineArt(root='/mnt/cache/syli/inbetween/data/ml100_norm/', gap=args.gap, split=args.type, model=args.model, action=args.action, mode=args.mode if hasattr(args, 'mode') else 'train') # train_loader = data.DataLoader(lineart, batch_size=args.batch_size, percentage = 0.0 vertex_num = 0.0 vertex_shift = 0.0 vertex_max_shift = 0.0 edges = 0.0 # for data in loader: # print(data) # break unmatched_all = [] max_node = 0 for dict in lineart: # print(dict['file_name']) # print(dict['file_name'][0], flush=True) v2d1 = dict['keypoints0'].numpy().astype(int)[0] v2d2 = dict['keypoints1'].numpy().astype(int)[0] ms = dict['ms'][0] ns = dict['ns'][0] # this_edges topo = dict['topo0'][0] for ii in range(len(topo)): edges += 
len(topo[ii]) # print(len(topo), flush=True) # print(ms, ns, flush=True) # print(dict['keypoints0'], flush=True) # print(dict['image0'].size(), flush=True) v2d1 = v2d1[:ms] v2d2 = v2d2[:ns] m01 = dict['m01'][0][:ms] # print(m01.shape) # print(np.arange(len(m01))[m01 != -1], m01[m01 != -1]) # print(v2d2.shape, v2d1.shape) shift = np.sqrt(((v2d2[m01[m01 != -1].int(), :] * 1.0 - v2d1[np.arange(len(m01))[m01 != -1],:]) ** 2).sum(-1)) vertex_num += len(v2d1) vertex_shift += shift.mean() vertex_max_shift += shift.max() percentage += ((m01!=-1).float().sum() * 1.0 / len(m01) * 100.0) print('>>>> gap=', args.gap, ' percentage=', percentage / len(lineart), ' vertex num=', vertex_num*1.0/len(lineart), 'edges num=', edges*1.0/len(lineart)/2, 'vertex shift=', vertex_shift/len(lineart), ' vertex max shift=', vertex_max_shift/len(lineart), flush=True) # if len(v2d1) > max_node: # max_node = len(v2d1) # if len(v2d2) > max_node: # max_node = len(v2d2) # print(max_node) # print(v2d1.shape) # img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy() # img2 = ((dict['image1'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy() # # print(v2d1.shape, img1.shape, flush=True) # for node, nbs in enumerate(dict['topo0']): # for nb in nbs: # cv2.line(img1, [v2d1[node][0], v2d1[node][1]], [v2d1[nb][0], v2d1[nb][1]], [255, 180, 180], 2) # colors1, colors2 = {}, {} # id1 = dict['id0'][0].numpy() # id2 = dict['id1'][0].numpy() # for index in id1: # # print(index) # color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)] # # for ii in index: # colors1[index] = color # colors1, colors2 = {}, {} # for index in id1: # # print(index) # color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)] # colors1[index] = color # for i, p in enumerate(v2d1): # ii = id1[i] # # print(ii) # cv2.circle(img1, [int(p[0]), int(p[1])], 1, colors1[ii], 2) # unmatched = 0 # for ii in id2: # color = [0, 0, 0] # this_is_umatched = 1 # colors2[ii] = colors1[ii] if ii in colors1 else color # if ii in colors1: # this_is_umatched = 0 # # if ii not in colors1: # unmatched += this_is_umatched # for i, p in enumerate(v2d2): # ii = id2[i] # # print(p) # cv2.circle(img2, [int(p[0]), int( p[1])], 1, colors2[ii], 2) # # print('Unmatched in Img 2: ', , '%') # unmatched_all.append(100 - unmatched * 100.0/len(v2d2)) # im_h = cv2.hconcat([img1, img2]) # print('/mnt/lustre/syli/inbetween/AnimeInbetween/corr/datasets/data_check_norm/' + dict['file_name'][0].replace('/', '_') + '.jpg', flush=True) # cv2.imwrite('/mnt/lustre/syli/inbetween/AnimeInbetween/corr/datasets/data_check_norm/' + dict['file_name'][0].replace('/', '_') + '.jpg', im_h) # print(np.mean(unmatched_all)) ================================================ FILE: corr/experiments/vtx_corr/ckpt/.gitkeep ================================================ ================================================ FILE: corr/main.py ================================================ from vtx_matching import VtxMat import argparse import os import yaml from pprint import pprint from easydict import EasyDict def parse_args(): parser = argparse.ArgumentParser( description='Anime segment matching') parser.add_argument('--config', default='') # exclusive arguments group = parser.add_mutually_exclusive_group(required=True) group.add_argument('--train', action='store_true') group.add_argument('--eval', action='store_true') return parser.parse_args() def main(): # parse arguments and load 
config args = parse_args() with open(args.config) as f: config = yaml.load(f) for k, v in vars(args).items(): config[k] = v pprint(config) config = EasyDict(config) agent = VtxMat(config) print(config) if args.train: agent.train() elif args.eval: agent.eval() if __name__ == '__main__': main() ================================================ FILE: corr/models/__init__.py ================================================ from .supergluet import SuperGlueT # from .supergluet_wo_OT import SuperGlueTwoOT # from .supergluenp import SuperGlue as SuperGlueNP # from .supergluei import SuperGlue as SuperGlueI # from .supergluet2 import SuperGlueT2 __all__ = ['SuperGlueT'] ================================================ FILE: corr/models/supergluet.py ================================================ import numpy as np from copy import deepcopy from pathlib import Path import torch from torch import nn import argparse from sknetwork.embedding import Spectral def MLP(channels: list, do_bn=True): """ Multi-layer perceptron """ n = len(channels) layers = [] for i in range(1, n): layers.append( nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) if i < (n-1): if do_bn: layers.append(nn.InstanceNorm1d(channels[i])) layers.append(nn.ReLU()) return nn.Sequential(*layers) def normalize_keypoints(kpts, image_shape): """ Normalize keypoints locations based on image image_shape""" _, _, height, width = image_shape one = kpts.new_tensor(1) size = torch.stack([one*width, one*height])[None] center = size / 2 scaling = size.max(1, keepdim=True).values * 0.7 return (kpts - center[:, None, :]) / scaling[:, None, :] class ThreeLayerEncoder(nn.Module): """ Joint encoding of visual appearance and location using MLPs""" def __init__(self, enc_dim): super().__init__() # input must be 3 channel (r, g, b) self.layer1 = nn.Conv2d(3, enc_dim//4, 7, padding=3) self.non_linear1 = nn.ReLU() self.layer2 = nn.Conv2d(enc_dim//4, enc_dim//2, 3, padding=1) self.non_linear2 = nn.ReLU() self.layer3 = nn.Conv2d(enc_dim//2, enc_dim, 3, padding=1) self.norm1 = nn.InstanceNorm2d(enc_dim//4) self.norm2 = nn.InstanceNorm2d(enc_dim//2) self.norm3 = nn.InstanceNorm2d(enc_dim) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') nn.init.constant_(m.bias, 0.0) def forward(self, img): x = self.non_linear1(self.norm1(self.layer1(img))) x = self.non_linear2(self.norm2(self.layer2(x))) x = self.norm3(self.layer3(x)) return x class VertexDescriptor(nn.Module): """ Joint encoding of visual appearance and location using MLPs""" def __init__(self, enc_dim): super().__init__() self.encoder = ThreeLayerEncoder(enc_dim) def forward(self, img, vtx): x = self.encoder(img) n, c, h, w = x.size() assert((h, w) == img.size()[2:4]) return x[:, :, torch.round(vtx[0, :, 1]).long(), torch.round(vtx[0, :, 0]).long()] class KeypointEncoder(nn.Module): """ Joint encoding of visual appearance and location using MLPs""" def __init__(self, feature_dim, layers): super().__init__() self.encoder = MLP([2] + layers + [feature_dim]) # for m in self.encoder.modules(): # if isinstance(m, nn.Conv2d): # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') # nn.init.constant_(m.bias, 0.0) nn.init.constant_(self.encoder[-1].bias, 0.0) def forward(self, kpts): inputs = kpts.transpose(1, 2) x = self.encoder(inputs) return x class TopoEncoder(nn.Module): """ Joint encoding of visual appearance and location using MLPs""" def __init__(self, feature_dim, layers): super().__init__() 
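# The 64-dim inputs to this encoder are the (absolute) spectral embeddings of
# each graph's adjacency matrix, precomputed in the dataloader with
# Spectral(64) (see corr/datasets/ml_dataset.py). The MLP below lifts them to
# descriptor_dim so they can be added to the visual vertex descriptors.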
self.encoder = MLP([64] + layers + [feature_dim]) # for m in self.encoder.modules(): # if isinstance(m, nn.Conv2d): # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') # nn.init.constant_(m.bias, 0.0) nn.init.constant_(self.encoder[-1].bias, 0.0) def forward(self, kpts): inputs = kpts.transpose(1, 2) x = self.encoder(inputs) return x def attention(query, key, value, mask=None): dim = query.shape[1] scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5 if mask is not None: scores = scores.masked_fill(mask==0, float('-inf')) prob = torch.nn.functional.softmax(scores, dim=-1) return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob class MultiHeadedAttention(nn.Module): """ Multi-head attention to increase model expressivitiy """ def __init__(self, num_heads: int, d_model: int): super().__init__() assert d_model % num_heads == 0 self.dim = d_model // num_heads self.num_heads = num_heads self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) def forward(self, query, key, value, mask=None): batch_dim = query.size(0) query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1) for l, x in zip(self.proj, (query, key, value))] x, prob = attention(query, key, value, mask) return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1)) class AttentionalPropagation(nn.Module): def __init__(self, feature_dim: int, num_heads: int): super().__init__() self.attn = MultiHeadedAttention(num_heads, feature_dim) self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim]) nn.init.constant_(self.mlp[-1].bias, 0.0) def forward(self, x, source, mask=None): message = self.attn(x, source, source, mask) return self.mlp(torch.cat([x, message], dim=1)) class AttentionalGNN(nn.Module): def __init__(self, feature_dim: int, layer_names: list): super().__init__() self.layers = nn.ModuleList([ AttentionalPropagation(feature_dim, 4) for _ in range(len(layer_names))]) self.names = layer_names def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None, mask10=None): for layer, name in zip(self.layers, self.names): layer.attn.prob = [] if name == 'cross': src0, src1 = desc1, desc0 mask0, mask1 = mask01[:, None], mask10[:, None] else: # if name == 'self': src0, src1 = desc0, desc1 mask0, mask1 = mask00[:, None], mask11[:, None] delta0, delta1 = layer(desc0, src0, mask0), layer(desc1, src1, mask1) desc0, desc1 = (desc0 + delta0), (desc1 + delta1) return desc0, desc1 def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int): """ Perform Sinkhorn Normalization in Log-space for stability""" u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) for _ in range(iters): u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2) v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1) return Z + u.unsqueeze(2) + v.unsqueeze(1) def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None): """ Perform Differentiable Optimal Transport in Log-space for stability""" b, m, n = scores.shape one = scores.new_tensor(1) if ms is None or ns is None: ms, ns = (m*one).to(scores), (n*one).to(scores) # else: # ms, ns = ms.to(scores)[:, None], ns.to(scores)[:, None] # here m,n should be parameters not shape # ms, ns: (b, ) bins0 = alpha.expand(b, m, 1) bins1 = alpha.expand(b, 1, n) alpha = alpha.expand(b, 1, 1) # pad additional scores for unmatcheed (to -1) # alpha is the learned threshold couplings = torch.cat([torch.cat([scores, bins0], -1), torch.cat([bins1, alpha], -1)], 1) norm = - (ms + ns).log() # 
(b, ) # print(scores.min(), flush=True) if ms.size()[0] > 0: norm = norm[:, None] log_mu = torch.cat([norm.expand(b, m), ns.log()[:, None] + norm], dim=-1) # (m + 1) log_nu = torch.cat([norm.expand(b, n), ms.log()[:, None] + norm], dim=-1) # print(log_nu.min(), log_mu.min(), flush=True) else: log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) # (m + 1) log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm]) log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1) Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters) if ms.size()[0] > 1: norm = norm[:, :, None] Z = Z - norm # multiply probabilities by M+N return Z def arange_like(x, dim: int): return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1 class SuperGlueT(nn.Module): def __init__(self, config=None): super().__init__() default_config = argparse.Namespace() default_config.descriptor_dim = 128 # default_config.weights = default_config.keypoint_encoder = [32, 64, 128] default_config.GNN_layers = ['self', 'cross'] * 9 default_config.sinkhorn_iterations = 100 default_config.match_threshold = 0.2 # self.config = {**self.default_config, **config} if config is None: self.config = default_config else: self.config = config self.config.GNN_layers = ['self', 'cross'] * self.config.GNN_layer_num # print('WULA!', self.config.GNN_layer_num) self.kenc = KeypointEncoder( self.config.descriptor_dim, self.config.keypoint_encoder) self.tenc = TopoEncoder( self.config.descriptor_dim, [96]) self.gnn = AttentionalGNN( self.config.descriptor_dim, self.config.GNN_layers) self.final_proj = nn.Conv1d( self.config.descriptor_dim, self.config.descriptor_dim, kernel_size=1, bias=True) bin_score = torch.nn.Parameter(torch.tensor(1.)) self.register_parameter('bin_score', bin_score) self.vertex_desc = VertexDescriptor(self.config.descriptor_dim) def forward(self, data): """Run SuperGlue on a pair of keypoints and descriptors""" kpts0, kpts1 = data['keypoints0'].float(), data['keypoints1'].float() ori_mask0, ori_mask1 = data['mask0'].float(), data['mask1'].float() dim_m, dim_n = data['ms'].float(), data['ns'].float() spec0, spec1 = data['adj_mat0'], data['adj_mat1'] mmax = dim_m.int().max() nmax = dim_n.int().max() mask0 = ori_mask0[:, :mmax] mask1 = ori_mask1[:, :nmax] kpts0 = kpts0[:, :mmax] kpts1 = kpts1[:, :nmax] desc0, desc1 = self.vertex_desc(data['image0'], kpts0.float()), self.vertex_desc(data['image1'], kpts1.float()) # spec0, spec1 = np.abs(self.spectral.fit_transform(topo0[0].cpu().numpy())), np.abs(self.spectral.fit_transform(topo1[0].cpu().numpy())) desc0 = desc0 + self.tenc(desc0.new_tensor(spec0)) desc1 = desc1 + self.tenc(desc1.new_tensor(spec1)) mask00 = torch.ones_like(mask0)[:, :, None] * mask0[:, None, :] mask11 = torch.ones_like(mask1)[:, :, None] * mask1[:, None, :] mask01 = torch.ones_like(mask0)[:, :, None] * mask1[:, None, :] mask10 = torch.ones_like(mask1)[:, :, None] * mask0[:, None, :] if kpts0.shape[1] < 2 or kpts1.shape[1] < 2: # no keypoints shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1] # print(data['file_name']) return { 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0], # 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0], 'matching_scores0': kpts0.new_zeros(shape0)[0], # 'matching_scores1': kpts1.new_zeros(shape1)[0], 'skip_train': True } file_name = data['file_name'] all_matches = data['all_matches'] if 'all_matches' in data else None# shape = (1, K1) # positional embedding # Keypoint normalization. 
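# normalize_keypoints (defined above) recenters pixel coordinates on the image
# center and divides by 0.7 * max(H, W), so inputs to the keypoint MLP encoder
# land roughly in [-0.7, 0.7] regardless of resolution.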
kpts0 = normalize_keypoints(kpts0, data['image0'].shape) kpts1 = normalize_keypoints(kpts1, data['image1'].shape) # Keypoint MLP encoder. pos0 = self.kenc(kpts0) pos1 = self.kenc(kpts1) desc0 = desc0 + pos0 desc1 = desc1 + pos1 # Multi-layer Transformer network. desc0, desc1 = self.gnn(desc0, desc1, mask00, mask11, mask01, mask10) # Final MLP projection. mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) # Compute matching descriptor distance. scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) # b k1 k2 scores = scores / self.config.descriptor_dim**.5 mask01 = mask0[:, :, None] * mask1[:, None, :] scores = scores.masked_fill(mask01 == 0, float('-inf')) # Run the optimal transport. scores = log_optimal_transport( scores, self.bin_score, iters=self.config.sinkhorn_iterations, ms=dim_m, ns=dim_n) # Get the matches with score above "match_threshold". max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) indices0, indices1 = max0.indices, max1.indices mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) zero = scores.new_tensor(0) mscores0 = torch.where(mutual0, max0.values.exp(), zero) mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) valid0 = mutual0 & (mscores0 > self.config.match_threshold) valid1 = mutual1 & valid0.gather(1, indices1) valid0 = mscores0 > self.config.match_threshold valid1 = valid0.gather(1, indices1) indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) # check if indexed correctly loss = [] if all_matches is not None: for b in range(len(dim_m)): for i in range(int(dim_m[b])): x = i y = all_matches[b][i].long() loss.append(-scores[b][x][y] ) # check batch size == 1 ? loss_mean = torch.mean(torch.stack(loss)) loss_mean = torch.reshape(loss_mean, (1, -1)) return { 'matches0': indices0, # use -1 for invalid match 'matches1': indices1, # use -1 for invalid match 'matching_scores0': mscores0, # 'matching_scores1': mscores1[0], 'loss': loss_mean, 'skip_train': False, 'accuracy': (((all_matches[:, :mmax] == indices0) & mask0.bool()).sum() / mask0.sum()).item(), 'valid_accuracy': (((all_matches[:, :mmax] == indices0) & (all_matches[:, :mmax] != -1) & mask0.bool()).float().sum() / ((all_matches[:, :mmax] != -1) & mask0.bool()).float().sum()).item(), } else: return { 'matches0': indices0[0], # use -1 for invalid match 'matching_scores0': mscores0[0], 'loss': -1, 'skip_train': True, 'accuracy': -1, 'area_accuracy': -1, 'valid_accuracy': -1, } if __name__ == '__main__': args = argparse.Namespace() args.batch_size = 1 args.gap = 0 args.type = 'train' args.model = 'jolleen' args.action = 'slash' ss = SuperGlue() loader = fetch_dataloader(args) # #print(len(loader)) for data in loader: # p1, p2, s1, s2, mi = data dict1 = data kp1 = dict1['keypoints0'] kp2 = dict1['keypoints1'] p1 = dict1['image0'] p2 = dict1['image1'] # #print(s1) # #print(s1.type) mi = dict1['all_matches'] fname = dict1['file_name'] print(kp1.shape, p1.shape, mi.shape) # #print(mi.size()) # #print(mi) # break a = ss(data) print(dict1['file_name']) print(a['loss']) a['loss'].backward() # print(a['matches0'].size()) # print(a['accuracy'], a['valid_accuracy']) ================================================ FILE: corr/srun.sh ================================================ #!/bin/sh currenttime=`date "+%Y%m%d%H%M%S"` if [ ! 
-d log ]; then mkdir log fi echo "[Usage] ./srun.sh config_path [train|eval] partition gpunum" # check config exists if [ ! -e $1 ] then echo "[ERROR] configuration file: $1 does not exists!" exit fi if [ ! -d ${expname} ]; then mkdir ${expname} fi echo "[INFO] saving results to, or loading files from: "$expname if [ "$3" == "" ]; then echo "[ERROR] enter partition name" exit fi partition_name=$3 echo "[INFO] partition name: $partition_name" if [ "$4" == "" ]; then echo "[ERROR] enter gpu num" exit fi gpunum=$4 gpunum=$(($gpunum<8?$gpunum:8)) echo "[INFO] GPU num: $gpunum" ((ntask=$gpunum*3)) TOOLS="srun --partition=$partition_name --cpus-per-task=8 --gres=gpu:$gpunum -N 1 --job-name=${config_suffix}" PYTHONCMD="python -u main.py --config $1" if [ $2 == "train" ]; then $TOOLS $PYTHONCMD \ --train elif [ $2 == "eval" ]; then $TOOLS $PYTHONCMD \ --eval fi ================================================ FILE: corr/utils/log.py ================================================ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this open-source project. """ Define the Logger class to print log""" import os import sys import logging from datetime import datetime class Logger: def __init__(self, args, output_dir): log = logging.getLogger(output_dir) if not log.handlers: log.setLevel(logging.DEBUG) # if not os.path.exists(output_dir): # os.mkdir(args.data.output_dir) fh = logging.FileHandler(os.path.join(output_dir,'log.txt')) fh.setLevel(logging.INFO) ch = ProgressHandler() ch.setLevel(logging.DEBUG) formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S') fh.setFormatter(formatter) ch.setFormatter(formatter) log.addHandler(fh) log.addHandler(ch) self.log = log # setup TensorBoard # if args.tensorboard: # from tensorboardX import SummaryWriter # self.writer = SummaryWriter(log_dir=args.output_dir) # else: self.writer = None self.log_per_updates = args.log_per_updates def set_progress(self, epoch, total): self.log.info(f'Epoch: {epoch}') self.epoch = epoch self.i = 0 self.total = total self.start = datetime.now() def update(self, stats): self.i += 1 if self.i % self.log_per_updates == 0: remaining = str((datetime.now() - self.start) / self.i * (self.total - self.i)) remaining = remaining.split('.')[0] updates = stats.pop('updates') stats_str = ' '.join(f'{key}[{val:.8f}]' for key, val in stats.items()) self.log.info(f'> epoch [{self.epoch}] updates[{updates}] {stats_str} eta[{remaining}]') if self.writer: for key, val in stats.items(): self.writer.add_scalar(f'train/{key}', val, updates) if self.i == self.total: self.log.debug('\n') self.log.debug(f'elapsed time: {str(datetime.now() - self.start).split(".")[0]}') def log_eval(self, stats, metrics_group=None): stats_str = ' '.join(f'{key}: {val:.8f}' for key, val in stats.items()) self.log.info(f'valid {stats_str}') if self.writer: for key, val in stats.items(): self.writer.add_scalar(f'valid/{key}', val, self.epoch) # for mode, metrics in metrics_group.items(): # self.log.info(f'evaluation scores ({mode}):') # for key, (val, _) in metrics.items(): # self.log.info(f'\t{key} {val:.4f}') # if self.writer and metrics_group is not None: # for key, val in stats.items(): # self.writer.add_scalar(f'valid/{key}', val, self.epoch) # for key in list(metrics_group.values())[0]: # group = {} # for mode, metrics in metrics_group.items(): # group[mode] = metrics[key][0] # self.writer.add_scalars(f'valid/{key}', group, self.epoch) def __call__(self, msg): self.log.info(msg) 
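# ProgressHandler writes per-update training logs in place: records whose
# message starts with '> ' end with a carriage return so the next update
# overwrites the same console line; all other records get a normal newline.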
class ProgressHandler(logging.Handler): def __init__(self, level=logging.NOTSET): super().__init__(level) def emit(self, record): log_entry = self.format(record) if record.message.startswith('> '): sys.stdout.write('{}\r'.format(log_entry.rstrip())) sys.stdout.flush() else: sys.stdout.write('{}\n'.format(log_entry)) ================================================ FILE: corr/utils/visualize_vtx_corr.py ================================================ import numpy as np import torch import cv2 def make_inter_graph(v2d1, v2d2, topo1, topo2, match12): valid = (match12 != -1) marked2 = np.zeros(len(v2d2)).astype(bool) # print(match12[valid]) marked2[match12[valid]] = True id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2)) id1toh[valid] = np.arange(np.sum(valid)) id2toh[match12[valid]] = np.arange(np.sum(valid)) id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid) # print(marked2) id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2))) id1toh = id1toh.astype(int) id2toh = id2toh.astype(int) tot_len = len(v2d1) + np.sum(np.invert(marked2)) vin1 = v2d1[valid][:] vin2 = v2d2[match12[valid]][:] vh = 0.5 * (vin1 + vin2) vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0) topoh = [[] for ii in range(tot_len)] for node in range(len(topo1)): for nb in topo1[node]: if int(id1toh[nb]) not in topoh[id1toh[node]]: topoh[id1toh[node]].append(int(id1toh[nb])) for node in range(len(topo2)): for nb in topo2[node]: if int(id2toh[nb]) not in topoh[id2toh[node]]: topoh[id2toh[node]].append(int(id2toh[nb])) return vh, topoh def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12): valid = (match12 != -1) marked2 = np.zeros(len(v2d2)).astype(bool) # print(match12[valid]) marked2[match12[valid]] = True id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2)) id1toh[valid] = np.arange(np.sum(valid)) id2toh[match12[valid]] = np.arange(np.sum(valid)) id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid) # print(marked2) id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2))) id1toh = id1toh.astype(int) id2toh = id2toh.astype(int) tot_len = len(v2d1) + np.sum(np.invert(marked2)) vin1 = v2d1[valid][:] vin2 = v2d2[match12[valid]][:] vh = 0.5 * (vin1 + vin2) # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0) # topoh = [[] for ii in range(tot_len)] topoh = [[] for ii in range(np.sum(valid))] for node in range(len(topo1)): if not valid[node]: continue for nb in topo1[node]: if int(id1toh[nb]) not in topoh[id1toh[node]]: if valid[nb]: topoh[id1toh[node]].append(int(id1toh[nb])) for node in range(len(topo2)): if not marked2[node]: continue for nb in topo2[node]: if int(id2toh[nb]) not in topoh[id2toh[node]]: if marked2[nb]: topoh[id2toh[node]].append(int(id2toh[nb])) return vh, topoh def visualize(dict): # print(dict['keypoints0'].size(), flush=True) img1 = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy() img2 = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy() img1p = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy() img2p = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy() img1[:, :, 0] += 255 img1[:, :, 1] += 180 img1[:, :, 2] += 180 img1[img1 > 255] = 255 img2[:, :, 0] += 255 img2[:, :, 1] += 180 img2[:, :, 2] += 180 img2[img2 > 255] = 255 img1p[:, :, 0] += 255 img1p[:, :, 1] += 180 img1p[:, :, 2] += 180 img1p[img1p 
> 255] = 255 img2p[:, :, 0] += 255 img2p[:, :, 1] += 180 img2p[:, :, 2] += 180 img2p[img2p > 255] = 255 img1, img2, img1p, img2p = img1.astype(np.uint8), img2.astype(np.uint8), img1p.astype(np.uint8), img2p.astype(np.uint8) # print(v2d1.shape, img1.shape, flush=True) v2d1 = dict['keypoints0'].numpy().astype(int) v2d2 = dict['keypoints1'].numpy().astype(int) topo1 = dict['topo0'] topo2 = dict['topo1'] # print(topo1, flush=True) # for node, nbs in enumerate(dict['topo0']): # for nb in nbs: # cv2.line(img1, [v2d1[node][0], v2d1[node][1]], [v2d1[nb][0], v2d1[nb][1]], [255, 180, 180], 2) id1 = np.arange(len(v2d1)) id2 = np.arange(len(v2d2)) all_matches = dict['all_matches'].cpu().int().data.numpy() predicted = dict['matches0'].cpu().data.numpy()[0] predicted1 = dict['matches1'].cpu().data.numpy()[0] colors1_gt, colors2_gt = {}, {} colors1_pred, colors2_pred = {}, {} cross1_pred, cross2_pred = {}, {} for index in id1: color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)] # print(predicted.shape, flush=True) if all_matches[index] != -1: colors2_gt[all_matches[index]] = color if predicted[index] != -1: colors2_pred[predicted[index]] = color colors1_gt[index] = color if all_matches[index] != -1 else [0, 0, 0] colors1_pred[index] = color if predicted[index] != -1 else [0, 0, 0] # if predicted[index] == -1 and colors1_pred[index] != [0, 0, 0]: # color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)] # colors1_pred[index] = [0, 0, 0] # colors2_pred.pop(all_matches[index]) # whether predicted correctly if predicted[index] != all_matches[index]: cross1_pred[index] = True if predicted[index] != -1: cross2_pred[predicted[index]] = True for i, p in enumerate(v2d1): ii = id1[i] # print(ii) cv2.circle(img1, [int(p[0]), int(p[1])], 1, colors1_gt[i], 2) if ii in cross1_pred and cross1_pred[ii]: cv2.rectangle(img1p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors1_pred[i],-1) else: cv2.circle(img1p, [int(p[0]), int(p[1])], 1, colors1_pred[i], 2) for ii in id2: # print(ii) color = [0, 0, 0] this_is_umatched = 1 if ii not in colors2_gt: colors2_gt[ii] = color if ii not in colors2_pred: colors2_pred[ii] = color for i, p in enumerate(v2d2): ii = id2[i] # print(p) cv2.circle(img2, [int(p[0]), int( p[1])], 1, colors2_gt[ii], 2) if ii in cross2_pred and cross2_pred[ii]: cv2.rectangle(img2p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors2_pred[i], -1) else: cv2.circle(img2p, [int(p[0]), int(p[1])], 1, colors2_pred[i], 2) # print('Unmatched in Img 2: ', , '%') # unmatched_all.append(100 - unmatched * 100.0/len(v2d2)) cv2.putText(img2p, str(round(np.sum(all_matches == predicted) * 100.0 / len(predicted), 2)).format('.2f') + '%', \ (500, 100), cv2.FONT_HERSHEY_PLAIN, 3, (0, 0, 255), 2) vh_gt, topoh_gt = make_inter_graph(v2d1, v2d2, topo1, topo2, all_matches) vh_pred, topoh_pred = make_inter_graph(v2d1, v2d2, topo1, topo2, predicted) vh_gt_valid, topoh_gt_valid = make_inter_graph_valid(v2d1, v2d2, topo1, topo2, all_matches) vh_pred_valid, topoh_pred_valid = make_inter_graph_valid(v2d1, v2d2, topo1, topo2, predicted) v2d1t = ((v2d2[predicted] + v2d1) * 0.5).astype(int) v2d2t = ((v2d1[predicted1] + v2d2) * 0.5).astype(int) vh_gt = vh_gt.astype(int) vh_gt_valid = vh_gt_valid.astype(int) vh_pred = vh_pred.astype(int) vh_pred_valid = vh_pred_valid.astype(int) imgh = np.zeros_like(img1) + 255 imghp = np.zeros_like(img1) + 255 imgh_valid = np.zeros_like(img1) + 255 imghp_valid = np.zeros_like(img1) + 255 for node, nbs in 
enumerate(topoh_gt): for nb in nbs: cv2.line(imgh, [vh_gt[node][0], vh_gt[node][1]], [vh_gt[nb][0], vh_gt[nb][1]], [0, 0, 0], 2) for node, nbs in enumerate(topoh_pred): for nb in nbs: cv2.line(imghp, [vh_pred[node][0], vh_pred[node][1]], [vh_pred[nb][0], vh_pred[nb][1]], [0, 0, 0], 2) for node, nbs in enumerate(topoh_gt_valid): for nb in nbs: cv2.line(imgh_valid, [vh_gt_valid[node][0], vh_gt_valid[node][1]], [vh_gt_valid[nb][0], vh_gt_valid[nb][1]], [0, 0, 0], 2) for node, nbs in enumerate(topoh_pred_valid): for nb in nbs: cv2.line(imghp_valid, [vh_pred_valid[node][0], vh_pred_valid[node][1]], [vh_pred_valid[nb][0], vh_pred_valid[nb][1]], [0, 0, 0], 2) # for node, nbs in enumerate(topo1): # for nb in nbs: # cv2.line(imghp_valid, [v2d1t[node][0], v2d1t[node][1]], [v2d1t[nb][0], v2d1t[nb][1]], [0, 0, 0], 2) # for node, nbs in enumerate(topo2): # for nb in nbs: # cv2.line(imghp_valid, [v2d2t[node][0], v2d2t[node][1]], [v2d2t[nb][0], v2d2t[nb][1]], [0, 0, 0], 2) im_h = cv2.hconcat([img1, img2]) im_hp = cv2.hconcat([img1p, img2p]) img_inter = cv2.hconcat([imgh, imghp]) img_inter_valid = cv2.hconcat([imgh_valid, imghp_valid]) im_hv = cv2.vconcat([im_h, im_hp, img_inter, img_inter_valid]) return im_hv ================================================ FILE: corr/vtx_matching.py ================================================ """ This script handling the training process. """ import os import time import random import argparse import torch import torch.nn as nn import torch.optim as optim import torch.utils.data from datasets import fetch_dataloader import random from utils.log import Logger from torch.optim import * import warnings from tqdm import tqdm import itertools import pdb import numpy as np import models import datetime import sys import json import cv2 from utils.visualize_vtx_corr import visualize import matplotlib.cm as cm # from models.utils import make_matching_seg_plot warnings.filterwarnings('ignore') import matplotlib.pyplot as plt import pdb class VtxMat(): def __init__(self, args): self.config = args torch.backends.cudnn.benchmark = True torch.multiprocessing.set_sharing_strategy('file_system') self._build() def train(self): opt = self.config print(opt) model = self.model if hasattr(self.config, 'init_weight'): checkpoint = torch.load(self.config.init_weight) model.load_state_dict(checkpoint['model']) optimizer = self.optimizer schedular = self.schedular mean_loss = [] log = Logger(self.config, self.expdir) updates = 0 # set seed random.seed(opt.seed) torch.manual_seed(opt.seed) torch.cuda.manual_seed(opt.seed) np.random.seed(opt.seed) # start training for epoch in range(1, opt.epoch+1): np.random.seed(opt.seed + epoch) train_loader = self.train_loader log.set_progress(epoch, len(train_loader)) batch_loss = 0 batch_acc = 0 batch_valid_acc = 0 batch_iter = 0 model.train() avg_time = 0 avg_num = 0 # torch.cuda.synchronize() for i, pred in enumerate(train_loader): # tstart = time.time() # print(pred['file_name']) data = model(pred) if not data['skip_train']: loss = data['loss'] / opt.batch_size batch_loss += loss.item() batch_acc += data['accuracy'] batch_valid_acc += data['valid_accuracy'] loss.backward() batch_iter += 1 else: print('Skip!') ## Accumulate gradient for batch training if ((i + 1) % opt.batch_size == 0) or (i + 1 == len(train_loader)): optimizer.step() optimizer.zero_grad() batch_iter = 1 if batch_iter == 0 else batch_iter stats = { 'updates': updates, 'loss': batch_loss, 'accuracy': batch_acc / batch_iter, 'valid_accuracy': batch_valid_acc / batch_iter } 
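# Note: opt.batch_size acts as the number of gradient-accumulation steps here
# (the dataloader batch size is 1 per the config); 'loss' is the sum over the
# accumulated window, while the accuracies are averaged over the batch_iter
# non-skipped forward passes.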
log.update(stats) updates += 1 batch_loss = 0 batch_acc = 0 batch_valid_acc = 0 batch_iter = 0 # torch.cuda.synchronize() # avg_num += 1 # for name, params in model.named_parameters(): # print('-->name:, ', name, '-->grad mean', params.grad.mean()) # print("All time is ", avg_time, "AVG time is ", avg_time * 1.0 /avg_num, "number is ", avg_num, flush=True) # save checkpoint if epoch % opt.save_per_epochs == 0 or epoch == 1: checkpoint = { 'model': model.state_dict(), 'config': opt, 'epoch': epoch } filename = os.path.join(self.ckptdir, f'epoch_{epoch}.pt') torch.save(checkpoint, filename) # validate if epoch % opt.test_freq == 0: if not os.path.exists(os.path.join(self.visdir, 'epoch' + str(epoch))): os.mkdir(os.path.join(self.visdir, 'epoch' + str(epoch))) eval_output_dir = os.path.join(self.visdir, 'epoch' + str(epoch)) test_loader = self.test_loader with torch.no_grad(): # Visualize the matches. mean_acc = [] mean_valid_acc = [] model.eval() for i_eval, data in enumerate(tqdm(test_loader, desc='Predicting Vtx Corr...')): pred = model(data) # for k, v in data.items(): # pred[k] = v[0] # pred = {**pred, **data} mean_acc.append(pred['accuracy']) mean_valid_acc.append(pred['valid_accuracy']) log.log_eval({ 'updates': opt.epoch, 'Accuracy': np.mean(mean_acc), 'Valid Accuracy': np.mean(mean_valid_acc), }) print('Epoch [{}/{}]], Acc.: {:.4f}, Valid Acc.{:.4f}' .format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_valid_acc)) ) sys.stdout.flush() # make_matching_plot( # image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, # text, viz_path, stem, stem, True, # True, False, 'Matches') self.schedular.step() def eval(self): train_action = ['breakdance_1990', 'capoeira', 'chapa-giratoria', 'fist_fight', 'flying_knee', 'freehang_climb', 'running', 'shove', 'magic', 'tripping'] test_action = ['great_sword_slash', 'hip_hop_dancing'] train_model = ['ganfaul', 'girlscout', 'jolleen', 'kachujin', 'knight', 'maria_w_jj', 'michelle', 'peasant_girl', 'timmy', 'uriel_a_plotexia'] test_model = ['police', 'warrok'] log = Logger(self.config, self.expdir) with torch.no_grad(): model = self.model.eval() config = self.config epoch_tested = self.config.testing.ckpt_epoch ckpt_path = os.path.join(self.ckptdir, f"epoch_{epoch_tested}.pt") # self.device = torch.device('cuda' if config.cuda else 'cpu') print("Evaluation...") checkpoint = torch.load(ckpt_path) model.load_state_dict(checkpoint['model']) model.eval() if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested))): os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested))) if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons')): os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons')) eval_output_dir = os.path.join(self.evaldir, 'epoch' + str(epoch_tested)) test_loader = self.test_loader print(len(test_loader)) mean_acc = [] mean_valid_acc = [] mean_invalid_acc = [] # 144 data # 10x10 is for training , 2x10 (unseen model) + 10x2 (unseen action) + 2x2 (unseen model unseen action) is for test # record the accuracy for each mean_model_acc = [] mean_model_valid_acc = [] mean_action_acc = [] mean_action_valid_acc = [] mean_none_acc = [] mean_none_valid_acc = [] mean_matched = [] for i_eval, pred in enumerate(tqdm(test_loader, desc='Predicting Vtx Corr...')): data = model(pred) for k, v in pred.items(): pred[k] = v[0] pred = {**pred, **data} mean_acc.append(pred['accuracy']) mean_valid_acc.append(pred['valid_accuracy']) this_pred = (pred['matches0'] != 
-1).float().cpu().data.numpy().astype(np.float32) mean_matched.append(np.mean( this_pred)) unmarked = True for model_name in train_model: if model_name in pred['file_name']: mean_model_acc.append(pred['accuracy']) mean_model_valid_acc.append(pred['valid_accuracy']) unmarked = False break for action_name in train_action: if action_name in pred['file_name']: mean_action_acc.append(pred['accuracy']) mean_action_valid_acc.append(pred['valid_accuracy']) unmarked = False break if unmarked: mean_none_acc.append(pred['accuracy']) mean_action_valid_acc.append(pred['valid_accuracy']) if 'invalid_accuracy' in pred and pred['invalid_accuracy'] is not None: mean_invalid_acc.append(pred['invalid_accuracy']) img_vis = visualize(pred) cv2.imwrite(os.path.join(eval_output_dir, pred['file_name'].replace('/', '_') + '.jpg'), img_vis) log.log_eval({ 'updates': self.config.testing.ckpt_epoch, 'Accuracy': np.mean(mean_acc), 'Accuracy (Matched)': np.mean(mean_valid_acc), 'Unseen Action Accuracy': np.mean(mean_model_acc), 'Unseen Action Accuracy (Matched)': np.mean(mean_model_valid_acc), 'Unseen Model Accuracy': np.mean(mean_action_acc), 'Unseen Model Accuracy (Matched)': np.mean(mean_action_valid_acc), 'Unseen Both Accuracy': np.mean(mean_none_acc), 'Unseen Both Valid Accuracy': np.mean(mean_none_valid_acc), 'Matching Rate': np.mean(mean_matched) }) # print ('Epoch [{}/{}]], Acc.: {:.4f}, Valid Acc.{:.4f}' # .format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_valid_acc)) ) sys.stdout.flush() def _build(self): config = self.config self.start_epoch = 0 self._dir_setting() self._build_model() if not(hasattr(config, 'need_not_train_data') and config.need_not_train_data): self._build_train_loader() if not(hasattr(config, 'need_not_test_data') and config.need_not_train_data): self._build_test_loader() self._build_optimizer() def _build_model(self): """ Define Model """ config = self.config if hasattr(config.model, 'name'): print(f'Experiment Using {config.model.name}') model_class = getattr(models, config.model.name) model = model_class(config.model) else: raise NotImplementedError("Wrong Model Selection") model = nn.DataParallel(model) self.model = model.cuda() def _build_train_loader(self): config = self.config self.train_loader = fetch_dataloader(config.data.train, type='train') def _build_test_loader(self): config = self.config self.test_loader = fetch_dataloader(config.data.test, type='test') def _build_optimizer(self): #model = nn.DataParallel(model).to(device) config = self.config.optimizer try: optim = getattr(torch.optim, config.type) except Exception: raise NotImplementedError('not implemented optim method ' + config.type) self.optimizer = optim(itertools.chain(self.model.module.parameters(), ), **config.kwargs) self.schedular = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, **config.schedular_kwargs) def _dir_setting(self): data = self.config.data self.expname = self.config.expname self.experiment_dir = os.path.join(".", "experiments") self.expdir = os.path.join(self.experiment_dir, self.expname) if not os.path.exists(self.expdir): os.mkdir(self.expdir) self.visdir = os.path.join(self.expdir, "vis") # -- imgs, videos, jsons if not os.path.exists(self.visdir): os.mkdir(self.visdir) self.ckptdir = os.path.join(self.expdir, "ckpt") if not os.path.exists(self.ckptdir): os.mkdir(self.ckptdir) self.evaldir = os.path.join(self.expdir, "eval") if not os.path.exists(self.evaldir): os.mkdir(self.evaldir) # self.ckptdir = os.path.join(self.expdir, "ckpt") # if not os.path.exists(self.ckptdir): # 
================================================
FILE: data/README.md
================================================

================================================
FILE: datasets/__init__.py
================================================
from .ml_seq import fetch_dataloader
from .vd_seq import fetch_videoloader

__all__ = ['fetch_dataloader', 'fetch_videoloader']

================================================
FILE: datasets/ml_seq.py
================================================
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import os
import math
import random
from glob import glob
import os.path as osp
import sys
import argparse
import cv2
from collections import Counter
import time
import json
import sknetwork
from sknetwork.embedding import Spectral
import scipy


def read_json(file_path):
    """
    input: json file path
    output: 2d vertices, connections and vertex indices in the original 3D domain
    """
    with open(file_path) as file:
        data = json.load(file)
    vertex2d = np.array(data['vertex location'])
    topology = data['connection']
    index = np.array(data['original index'])
    return vertex2d, topology, index


def matched_motion(v2d1, v2d2, match12, motion_pre=None):
    motion = np.zeros_like(v2d1)
    motion[match12 != -1] = v2d2[match12[match12 != -1]] - v2d1[match12 != -1]
    if motion_pre is not None:
        motion[match12 != -1] = motion[match12 != -1] + motion_pre[match12[match12 != -1]]
    return motion


def unmatched_motion(topo1, v2d1, motion12, match12):
    pos = np.arange(len(topo1))
    masked = (match12 == -1)

    round = 0
    former_len = 0
    # propagate motion from matched neighbours until no further unmatched
    # vertex can be resolved
    while len(pos[masked]) > 0:
        this_len = len(pos[masked])
        if former_len == this_len:
            break
        former_len = this_len
        round += 1
        for v in pos[masked]:
            unmatched = masked[topo1[v]]
            if unmatched.sum() != len(topo1[v]):
                motion12[v] = np.average(motion12[topo1[v]][np.invert(unmatched)], axis=0)
                masked[v] = False

    if len(pos[masked]) > 0:
        # assign each remaining unlabelled point the motion of its nearest labelled point
        index = ((v2d1[pos[masked]][:, None, :] - v2d1[pos[np.invert(masked)]]) ** 2).sum(2).argmin(1)
        motion12[pos[masked]] = motion12[pos[np.invert(masked)]][index]
        masked[pos[masked]] = False
    return motion12


def ids_to_mat(id1, id2):
    """ inputs are two lists of vertex indices in the original 3D mesh """
    corr1 = np.zeros(len(id1)) - 1.0
    corr2 = np.zeros(len(id2)) - 1.0

    id1 = np.array(id1).astype(int)[:, None]
    id2 = np.array(id2).astype(int)
    mat = (id1 == id2)

    pos12 = np.arange(len(id2))[None].repeat(len(id1), 0)
    pos21 = np.arange(len(id1))[None].repeat(len(id2), 0)

    corr1[mat.astype(int).sum(1).astype(bool)] = pos12[mat]
    corr2[mat.transpose().astype(int).sum(1).astype(bool)] = pos21[mat.transpose()]
    return mat, corr1, corr2


def adj_matrix(topology):
    """ topology is the adjacency table; returns the adjacency matrix """
    gsize = len(topology)
    adj = np.zeros((gsize, gsize)).astype(float)
    for v in range(gsize):
        adj[v][v] = 1.0
        for nb in topology[v]:
            adj[v][nb] = 1.0
            adj[nb][v] = 1.0
    return adj


class MixamoLineArtMotionSequence(data.Dataset):
    def __init__(self, root, gap=0, split='train', model=None, action=None, mode='train', use_vs=False, max_len=3050):
        """
        input:
            root: the root folder of the line art data
            gap: how many frames between two frames. gap should be an odd number.
split: train or test model: indicate a specific character (default None) action: indicate a specific action (default None) output: image of sources (0, 1) and output (0.5) topo0, topo1 v2d0, v2d1 corr12, corr21 motion0-->0.5, motion1-->0.5 visibility0-->0.5, visibility 1-->0.5 """ super(MixamoLineArtMotionSequence, self).__init__() self.gap = gap if model == 'None': model = None if action == 'None': action = None assert(gap%2 != 0) self.is_train = True if mode == 'train' else False self.is_eval = True if mode == 'eval' else False # self.is_train = False self.max_len = max_len self.image_list = [] self.label_list = [] label_root = osp.join(root, split, 'labels') self.use_vs = False if use_vs: print('>>>>>>>> Using VS labels') self.use_vs = True label_root = osp.join(root, split, 'labels_vs') image_root = osp.join(root, split, 'frames') self.spectral = Spectral(64, normalized=False) for clip in os.listdir(image_root): skip = False if model != None: for mm in model: if mm in clip: skip = True if action != None: for aa in action: if aa in clip: skip = True if skip: continue image_list = sorted(glob(osp.join(image_root, clip, '*.png'))) label_list = sorted(glob(osp.join(label_root, clip, '*.json'))) if len(image_list) != len(label_list): print(clip, flush=True) continue for i in range(len(image_list) - (gap+1)): self.image_list += [ [image_list[jj] for jj in range(i, i + gap + 2)] ] for i in range(len(label_list) - (gap+1)): self.label_list += [ [label_list[jj] for jj in range(i, i + gap + 2)] ] # print(clip) print('Len of Frame is ', len(self.image_list), flush=True) print('Len of Label is ', len(self.label_list), flush=True) def __getitem__(self, index): # load image/label files # load labels: # (a) read json (b) load image (c) make pseudo labels # image crop to a square (720x720) before input, 2d label same operation # index to index matching # test does not need index matching index = index % len(self.image_list) file_name = self.label_list[index][len(self.label_list[index])//2][:-4] imgt = [cv2.imread(self.image_list[index][ii]) for ii in range(0, len(self.image_list[index]))] labelt = [] for ii in range(0, len(self.label_list[index])): v, t, id = read_json(self.label_list[index][ii]) v[v > imgt[0].shape[0] - 1] = imgt[0].shape[0] - 1 v[v < 0] = 0 labelt.append({'keypoints': v.astype(int), 'topo': t, 'id': id}) # make motion pseudo label motion = None motion01 = None start_frame = 0 gap = self.gap // 2 + 1 ######### forward direction for ii in reversed(range(start_frame + 1, start_frame + 2*gap + 1)): img1 = imgt[ii - 1] img2 = imgt[ii] v2d1 = labelt[ii - 1]['keypoints'].astype(int) v2d2 = labelt[ii]['keypoints'].astype(int) topo1 = labelt[ii - 1]['topo'] topo2 = labelt[ii ]['topo'] id1 = labelt[ii - 1]['id'] id2 = labelt[ii]['id'] if self.use_vs: id1 = np.arange(len(id1)) id2 = np.arange(len(id2)) _, match12, matc21 = ids_to_mat(id1, id2) if ii <= start_frame + gap: motion01 = matched_motion(v2d1, v2d2, match12.astype(int), motion01) motion01 = unmatched_motion(topo1, v2d1, motion01, match12.astype(int)) motion = matched_motion(v2d1, v2d2, match12.astype(int), motion) motion = unmatched_motion(topo1, v2d1, motion, match12.astype(int)) motion0 = motion.copy() img2 = imgt[start_frame + gap] v2d1 = labelt[start_frame]['keypoints'].astype(int) source0_topo = labelt[start_frame]['topo'] target = cv2.erode(img2, np.ones((3, 3), np.uint8), iterations=1) shift_plabel = v2d1 + motion01 visible = np.ones(len(v2d1)).astype(float) visible[shift_plabel[:, 0] < 0] = 0 visible[shift_plabel[:, 0] >= 
imgt[0].shape[0]] = 0 visible[shift_plabel[:, 1] < 0] = 0 visible[shift_plabel[:, 1] >= imgt[0].shape[0]] = 0 # vertex visibility visible[visible == 1] = (target[:, :, 0][shift_plabel[visible == 1][:, 1], shift_plabel[visible == 1][:, 0]] < 255 ).astype(float) visible01 = visible.copy() v2d1s = shift_plabel # edge visibility for node, nbs in enumerate(source0_topo): for nb in nbs: if visible01[nb] and visible01[node] and ((v2d1s[node] - v2d1s[nb]) ** 2).sum() / (((v2d1[node] - v2d1[nb]) ** 2).sum() + 1e-7) > 25: visible01[nb] = False visible01[node] = False ######## backward direction motion = None motion21 = None for ii in range(start_frame + 1, start_frame + gap + gap + 1): img2 = imgt[ii - 1] img1 = imgt[ii] v2d2 = labelt[ii - 1]['keypoints'].astype(int) v2d1 = labelt[ii]['keypoints'].astype(int) topo2 = labelt[ii - 1]['topo'] topo1 = labelt[ii ]['topo'] id1 = labelt[ii]['id'] id2 = labelt[ii - 1]['id'] if self.use_vs: id1 = np.arange(len(id1)) id2 = np.arange(len(id2)) _, match12, _ = ids_to_mat(id1, id2) if ii >= start_frame + gap + 1: motion21 = matched_motion(v2d1, v2d2, match12.astype(int), motion21) motion21 = unmatched_motion(topo1, v2d1, motion21, match12.astype(int)) motion = matched_motion(v2d1, v2d2, match12.astype(int), motion) motion = unmatched_motion(topo1, v2d1, motion, match12.astype(int)) motion2 = motion.copy() img1 = imgt[start_frame + 2*gap] img2 = imgt[start_frame + gap] v2d1 = labelt[start_frame + 2*gap]['keypoints'].astype(int) source2_topo = labelt[start_frame + 2*gap]['topo'] shift_plabel = v2d1 + motion21 visible = np.ones(len(v2d1)).astype(float) visible[shift_plabel[:, 0] < 0] = 0 visible[shift_plabel[:, 0] >= imgt[0].shape[0]] = 0 visible[shift_plabel[:, 1] < 0] = 0 visible[shift_plabel[:, 1] >= imgt[0].shape[0]] = 0 visible[visible == 1] = (target[:, :, 0][shift_plabel[visible == 1][:, 1], shift_plabel[visible == 1][:, 0]] < 255 ).astype(float) visible21 = visible.copy() v2d1s = shift_plabel for node, nbs in enumerate(source2_topo): for nb in nbs: if visible21[nb] and visible21[node] and ((v2d1s[node] - v2d1s[nb]) ** 2).sum() / (((v2d1[node] - v2d1[nb]) ** 2).sum() + 1e-7) > 25: visible21[nb] = False visible21[node] = False ###### prepare other data img2 = imgt[-1] img1 = imgt[0] v2d2 = labelt[-1]['keypoints'].astype(int) v2d1 = labelt[0]['keypoints'].astype(int) topo2 = labelt[-1]['topo'] topo1 = labelt[0]['topo'] m, n = len(v2d1), len(v2d2) if len(img1.shape) == 2: img1 = np.tile(img1[...,None], (1, 1, 3)) img2 = np.tile(img2[...,None], (1, 1, 3)) else: img1 = img1[..., :3] img2 = img2[..., :3] img1 = torch.from_numpy(img1).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 img2 = torch.from_numpy(img2).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 imgt = torch.from_numpy(imgt[start_frame + gap]).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 v2d1 = torch.from_numpy(v2d1) v2d2 = torch.from_numpy(v2d2) visible01 = torch.from_numpy(visible01) visible21 = torch.from_numpy(visible21) motion0 = torch.from_numpy(motion0) motion2 = torch.from_numpy(motion2) v2d1[v2d1 > imgt[0].shape[0] - 1 ] = imgt[0].shape[0] - 1 v2d1[v2d1 < 0] = 0 v2d2[v2d2 > imgt[0].shape[1] - 1] = imgt[0].shape[1] - 1 v2d2[v2d2 < 0] = 0 id1 = labelt[start_frame]['id'] id2 = labelt[-1]['id'] if self.use_vs: id1 = np.arange(len(id1)) id2 = np.arange(len(id2)) mat_index, corr1, corr2 = ids_to_mat(id1, id2) mat_index = torch.from_numpy(mat_index).float() corr1 = torch.from_numpy(corr1).float() corr2 = torch.from_numpy(corr2).float() if self.is_train: v2d1 = torch.nn.functional.pad(v2d1, (0, 0, 0, self.max_len 
- m), mode='constant', value=0)
            v2d2 = torch.nn.functional.pad(v2d2, (0, 0, 0, self.max_len - n), mode='constant', value=0)
            corr1 = torch.nn.functional.pad(corr1, (0, self.max_len - m), mode='constant', value=0)
            corr2 = torch.nn.functional.pad(corr2, (0, self.max_len - n), mode='constant', value=0)
            motion0 = torch.nn.functional.pad(motion0, (0, 0, 0, self.max_len - m), mode='constant', value=0)
            motion2 = torch.nn.functional.pad(motion2, (0, 0, 0, self.max_len - n), mode='constant', value=0)
            visible01 = torch.nn.functional.pad(visible01, (0, self.max_len - m), mode='constant', value=0)
            visible21 = torch.nn.functional.pad(visible21, (0, self.max_len - n), mode='constant', value=0)
            mask0, mask1 = torch.zeros(self.max_len).float(), torch.zeros(self.max_len).float()
            mask0[:m] = 1
            mask1[:n] = 1
        else:
            mask0, mask1 = torch.ones(m).float(), torch.ones(n).float()

        # add a self-loop to every vertex before building the adjacency matrices
        for ii in range(len(topo1)):
            topo1[ii].append(ii)
        for ii in range(len(topo2)):
            topo2[ii].append(ii)

        adj1 = sknetwork.data.from_adjacency_list(topo1, matrix_only=True, reindex=False).toarray()
        adj2 = sknetwork.data.from_adjacency_list(topo2, matrix_only=True, reindex=False).toarray()

        try:
            spec0, spec1 = np.abs(self.spectral.fit_transform(adj1)), np.abs(self.spectral.fit_transform(adj2))
        except Exception:
            print('>>>>' + file_name, flush=True)
            spec0, spec1 = np.zeros((len(adj1), 64)), np.zeros((len(adj2), 64))

        if self.is_eval:
            return {
                'keypoints0': v2d1, 'keypoints1': v2d2,
                'topo0': [topo1], 'topo1': [topo2],
                'adj_mat0': adj1, 'adj_mat1': adj2,
                'spec0': spec0, 'spec1': spec1,
                'imaget': imgt, 'image0': img1, 'image1': img2,
                'motion0': motion0, 'motion1': motion2,
                'visibility0': visible01, 'visibility1': visible21,
                'all_matches': corr1, 'm01': corr1, 'm10': corr2,
                'ms': m, 'ns': n,
                'mask0': mask0, 'mask1': mask1,
                'file_name': file_name,
            }
        elif not self.is_train:
            return {
                'keypoints0': v2d1, 'keypoints1': v2d2,
                'adj_mat0': adj1, 'adj_mat1': adj2,
                'spec0': spec0, 'spec1': spec1,
                'imaget': imgt, 'image0': img1, 'image1': img2,
                'motion0': motion0, 'motion1': motion2,
                'visibility0': visible01, 'visibility1': visible21,
                'all_matches': corr1, 'm01': corr1, 'm10': corr2,
                'ms': m, 'ns': n,
                'mask0': mask0, 'mask1': mask1,
                'file_name': file_name,
            }
        else:
            return {
                'keypoints0': v2d1, 'keypoints1': v2d2,
                'adj_mat0': adj1, 'adj_mat1': adj2,
                'spec0': spec0, 'spec1': spec1,
                'imaget': imgt,
                'motion0': motion0, 'motion1': motion2,
                'visibility0': visible01, 'visibility1': visible21,
                'image0': img1, 'image1': img2,
                'all_matches': corr1, 'm01': corr1, 'm10': corr2,
                'ms': m, 'ns': n,
                'mask0': mask0, 'mask1': mask1,
                'file_name': file_name,
            }

    def __rmul__(self, v):
        self.label_list = v * self.label_list
        self.image_list = v * self.image_list
        return self

    def __len__(self):
        return len(self.image_list)


def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)


def fetch_dataloader(args, type='train',):
    lineart = MixamoLineArtMotionSequence(root=args.root, gap=args.gap, split=args.type,
                                          model=args.model, action=args.action,
                                          mode=args.mode if hasattr(args, 'mode') else 'train',
                                          use_vs=args.use_vs if hasattr(args, 'use_vs') else False)
    if args.mode == 'train':
        # note: the training dataset is re-created here without the use_vs flag
        lineart = MixamoLineArtMotionSequence(root=args.root, gap=args.gap, split=args.type,
                                              model=args.model, action=args.action,
                                              mode=args.mode if hasattr(args, 'mode') else 'train')

    if args.mode == 'train':
        loader = data.DataLoader(lineart, batch_size=args.batch_size, pin_memory=True,
                                 shuffle=True, num_workers=16, drop_last=True,
                                 worker_init_fn=worker_init_fn)
    else:
        loader = data.DataLoader(lineart, batch_size=args.batch_size, pin_memory=True,
                                 shuffle=False, num_workers=8)
    return loader
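# Editor's sketch: the minimal set of fields fetch_dataloader() reads from its
# config node. All values below are illustrative, not project defaults.
def _example_loader():
    args = argparse.Namespace()
    args.root = 'data/ml_root'   # dataset root (illustrative path)
    args.type = 'train'          # split folder name
    args.gap = 1                 # must be odd; a window of gap + 2 consecutive frames is sampled
    args.model = None            # no character filter
    args.action = None           # no action filter
    args.mode = 'train'
    args.batch_size = 2
    return fetch_dataloader(args)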
================================================
FILE: datasets/vd_seq.py
================================================
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import os
import math
import random
from glob import glob
import os.path as osp
import sys
import argparse
import cv2
from collections import Counter
import time
import json
import sknetwork
from sknetwork.embedding import Spectral
import scipy


def read_json(file_path):
    """
    input: json file path
    output: 2d vertices, connections and vertex indices in the original 3D domain
    """
    with open(file_path) as file:
        data = json.load(file)
    vertex2d = np.array(data['vertex location'])
    topology = data['connection']
    index = np.array(data['original index'])
    return vertex2d, topology, index


class VideoLinSeq(data.Dataset):
    def __init__(self, root, split='train'):
        """
        input:
            root: the root folder of the line art data
            split: split folder
        output:
            image of sources (0, 1) and output (0.5)
            topo0, topo1
            v2d0, v2d1
        """
        super(VideoLinSeq, self).__init__()
        self.image_list = []
        self.label_list = []
        label_root = osp.join(root, split, 'labels')
        image_root = osp.join(root, split, 'frames')
        self.spectral = Spectral(64, normalized=False)

        for clip in os.listdir(image_root):
            label_list = sorted(glob(osp.join(label_root, clip, '*.json')))
            for i in range(len(label_list) - 1):
                self.label_list += [[label_list[jj] for jj in range(i, i + 2)]]
                self.image_list += [[label_list[jj].replace('labels', 'frames').replace('.json', '.png')
                                     for jj in range(i, i + 2)]]

        print('Len of Frame is ', len(self.image_list), flush=True)
        print('Len of Label is ', len(self.label_list), flush=True)

    def __getitem__(self, index):
        # prepare images and labels
        index = index % len(self.image_list)
        file_name0 = self.label_list[index][0][:-5].split('/')[-1]
        file_name1 = self.label_list[index][-1][:-5].split('/')[-1]
        folder0 = self.label_list[index][0][:-4].split('/')[-2]
        folder1 = self.label_list[index][-1][:-4].split('/')[-2]

        imgt = [cv2.imread(self.image_list[index][ii]) for ii in range(0, len(self.image_list[index]))]
        labelt = []
        for ii in range(0, len(self.label_list[index])):
            v, t, id = read_json(self.label_list[index][ii])
            v[v > imgt[0].shape[0] - 1] = imgt[0].shape[0] - 1
            v[v < 0] = 0
            labelt.append({'keypoints': v.astype(int), 'topo': t, 'id': id})

        # prepare other data
        img2 = imgt[-1]
        img1 = imgt[0]
        v2d2 = labelt[-1]['keypoints'].astype(int)
        v2d1 = labelt[0]['keypoints'].astype(int)
        topo2 = labelt[-1]['topo']
        topo1 = labelt[0]['topo']
        m, n = len(v2d1), len(v2d2)

        if len(img1.shape) == 2:
            img1 = np.tile(img1[..., None], (1, 1, 3))
            img2 = np.tile(img2[..., None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float() * 2 / 255.0 - 1.0
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float() * 2 / 255.0 - 1.0
        v2d1 = torch.from_numpy(v2d1)
        v2d2 = torch.from_numpy(v2d2)
        mask0, mask1 = torch.ones(m).float(), torch.ones(n).float()

        v2d1[v2d1 > imgt[0].shape[0] - 1] = imgt[0].shape[0] - 1
        v2d1[v2d1 < 0] = 0
        v2d2[v2d2 > imgt[0].shape[1] - 1] = imgt[0].shape[1] - 1
        v2d2[v2d2 < 0] = 0

        id1 = np.arange(len(v2d1))
        id2 = np.arange(len(v2d2))

        for ii in range(len(topo1)):
            topo1[ii].append(ii)
        for ii in range(len(topo2)):
            topo2[ii].append(ii)

        adj1 = sknetwork.data.from_adjacency_list(topo1, matrix_only=True, reindex=False).toarray()
        adj2 = sknetwork.data.from_adjacency_list(topo2, matrix_only=True, reindex=False).toarray()
        try:
            spec0, spec1 = np.abs(self.spectral.fit_transform(adj1)), np.abs(self.spectral.fit_transform(adj2))
        except Exception:
            print('>>>>' + file_name0, flush=True)
            spec0, spec1 = np.zeros((len(adj1), 64)), np.zeros((len(adj2), 64))

        return {
            'keypoints0': v2d1, 'keypoints1': v2d2,
            'topo0': [topo1], 'topo1': [topo2],
            'adj_mat0': adj1, 'adj_mat1': adj2,
            'spec0': spec0, 'spec1': spec1,
            'image0': img1, 'image1': img2,
            'ms': m, 'ns': n,
            'mask0': mask0, 'mask1': mask1,
            'gen_vid': True,
            'file_name0': file_name0, 'file_name1': file_name1,
            'folder_name0': folder0, 'folder_name1': folder1
        }

    def __rmul__(self, v):
        self.label_list = v * self.label_list
        self.image_list = v * self.image_list
        return self

    def __len__(self):
        return len(self.image_list)


def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)


def fetch_videoloader(args, type='train',):
    lineart = VideoLinSeq(root=args.root, split=args.type)
    loader = data.DataLoader(lineart, batch_size=args.batch_size, pin_memory=True,
                             shuffle=False, num_workers=8)
    return loader
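# Editor's sketch: the label format read_json() expects. A label json stores
# three parallel structures; this writes a 3-vertex toy graph and reads it
# back. The file name is illustrative.
def _example_label_roundtrip(tmp_path='toy_label.json'):
    toy = {
        'vertex location': [[10, 20], [12, 22], [15, 25]],  # 2D pixel coordinates
        'connection': [[1], [0, 2], [1]],                   # adjacency list per vertex
        'original index': [101, 102, 103],                  # vertex ids in the source 3D mesh
    }
    with open(tmp_path, 'w') as f:
        json.dump(toy, f)
    v2d, topo, idx = read_json(tmp_path)
    assert v2d.shape == (3, 2) and len(topo) == 3 and idx.shape == (3,)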
================================================
FILE: download.sh
================================================
cd data
gdown 1SNRGajIECxNwRp6ZJ0IlY7AEl2mRm2DR
unzip ml240data.zip

================================================
FILE: experiments/inbetweener_full/ckpt/.gitkeep
================================================

================================================
FILE: inbetween.py
================================================
"""
This script handles the training process.
""" import os import time import random import argparse import torch import torch.nn as nn import torch.optim as optim import torch.utils.data from datasets import fetch_dataloader from datasets import fetch_videoloader import random from utils.log import Logger from torch.optim import * import warnings from tqdm import tqdm import itertools import pdb import numpy as np import models import datetime import sys import json import cv2 from utils.visualize_inbetween3 import visualize # from utils.visualize_inbetween import visualize from utils.visualize_video import visvid as visgen import matplotlib.cm as cm # from models.utils import make_matching_seg_plot warnings.filterwarnings('ignore') # a, b, c, d = check_data_distribution('/mnt/lustre/lisiyao1/dance/dance2/DanceRevolution/data/aistpp_train') import matplotlib.pyplot as plt import pdb class DraftRefine(): def __init__(self, args): self.config = args torch.backends.cudnn.benchmark = True torch.multiprocessing.set_sharing_strategy('file_system') self._build() def train(self): opt = self.config print(opt) # store viz results # eval_output_dir = Path(self.expdir) # eval_output_dir.mkdir(exist_ok=True, parents=True) # print('Will write visualization images to', # 'directory \"{}\"'.format(eval_output_dir)) # load training data model = self.model checkpoint = torch.load(self.config.corr_weights) dict = {k.replace('module.', ''): checkpoint['model'][k] for k in checkpoint['model']} model.module.corr.load_state_dict(dict) if hasattr(self.config, 'init_weight'): checkpoint = torch.load(self.config.init_weight) model.load_state_dict(checkpoint['model']) # if torch.cuda.is_available(): # model.cuda() # make sure it trains on GPU # else: # print("### CUDA not available ###") # return optimizer = self.optimizer schedular = self.schedular mean_loss = [] log = Logger(self.config, self.expdir) updates = 0 # set seed random.seed(opt.seed) torch.manual_seed(opt.seed) torch.cuda.manual_seed(opt.seed) np.random.seed(opt.seed) # print(opt.seed) # start training for epoch in range(1, opt.epoch+1): np.random.seed(opt.seed + epoch) train_loader = self.train_loader log.set_progress(epoch, len(train_loader)) batch_loss = 0 batch_epe = 0 batch_acc = 0 batch_iter = 0 model.train() avg_time = 0 avg_num = 0 # torch.cuda.synchronize() for i, data in enumerate(train_loader): pred = model(data) if True: loss = pred['loss'].mean() # print(loss.item(), opt.batch_size) batch_loss += loss.item() / opt.batch_size batch_acc += pred['Visibility Acc'].mean().item() / opt.batch_size batch_epe += pred['EPE'].mean().item() / opt.batch_size loss.backward() batch_iter += 1 else: print('Skip!') if ((i + 1) % opt.batch_size == 0) or (i + 1 == len(train_loader)): optimizer.step() optimizer.zero_grad() batch_iter = 1 if batch_iter == 0 else batch_iter stats = { 'updates': updates, 'loss': batch_loss, 'accuracy': batch_acc, 'EPE': batch_epe } log.update(stats) updates += 1 batch_loss = 0 batch_acc = 0 batch_epe = 0 batch_iter = 0 # tend = time.time() # avg_time = (tend - tstart) # print('Time is ', avg_time) # torch.cuda.synchronize() # avg_num += 1 # for name, params in model.named_parameters(): # print('-->name:, ', name, '-->grad mean', params.grad.mean()) # print("All time is ", avg_time, "AVG time is ", avg_time * 1.0 /avg_num, "number is ", avg_num, flush=True) # save checkpoint if epoch % opt.save_per_epochs == 0 or epoch == 1: checkpoint = { 'model': model.state_dict(), 'config': opt, 'epoch': epoch } filename = os.path.join(self.ckptdir, f'epoch_{epoch}.pt') 
                torch.save(checkpoint, filename)

            # validate
            if epoch % opt.test_freq == 0:
                if not os.path.exists(os.path.join(self.visdir, 'epoch' + str(epoch))):
                    os.mkdir(os.path.join(self.visdir, 'epoch' + str(epoch)))
                eval_output_dir = os.path.join(self.visdir, 'epoch' + str(epoch))
                test_loader = self.test_loader

                with torch.no_grad():
                    mean_acc = []
                    mean_epe = []
                    model.eval()
                    for i_eval, data in enumerate(tqdm(test_loader, desc='Refining motion and visibility...')):
                        pred = model(data)
                        mean_acc.append(pred['Visibility Acc'].mean().item())
                        mean_epe.append(pred['EPE'].mean().item())
                    log.log_eval({
                        'updates': opt.epoch,
                        'Visibility Accuracy': np.mean(mean_acc),
                        'EPE': np.mean(mean_epe),
                    })
                    print('Epoch [{}/{}], Vis Acc.: {:.4f}, EPE: {:.4f}'
                          .format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_epe)))
                    sys.stdout.flush()

            self.schedular.step()
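    # Editor's sketch of the EPE metric reported above: the mean L2 norm of the
    # per-vertex difference between predicted and ground-truth motion, matching
    # the computation inside the model's forward pass. Shapes are (B, N, 2).
    @staticmethod
    def _example_epe(motion_pred, motion_gt):
        return ((motion_pred - motion_gt) ** 2).sum(dim=-1).sqrt().mean()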
    def eval(self):
        train_action = ['breakdance_1990', 'capoeira', 'chapa-giratoria', 'fist_fight', 'flying_knee',
                        'freehang_climb', 'running', 'shove', 'magic', 'tripping']
        test_action = ['great_sword_slash', 'hip_hop_dancing']
        train_model = ['ganfaul', 'girlscout', 'jolleen', 'kachujin', 'knight', 'maria_w_jj',
                       'michelle', 'peasant_girl', 'timmy', 'uriel_a_plotexia']
        test_model = ['police', 'warrok']

        config = self.config
        if not os.path.exists(config.imwrite_dir):
            os.mkdir(config.imwrite_dir)
        log = Logger(self.config, self.expdir)

        with torch.no_grad():
            model = self.model.eval()
            config = self.config

            epoch_tested = self.config.testing.ckpt_epoch
            if epoch_tested == 0 or epoch_tested == '0':
                # epoch 0 means: evaluate with only the pretrained correspondence weights
                checkpoint = torch.load(self.config.corr_weights)
                dict = {k.replace('module.', ''): checkpoint['model'][k] for k in checkpoint['model']}
                model.module.corr.load_state_dict(dict)
            else:
                ckpt_path = os.path.join(self.ckptdir, f"epoch_{epoch_tested}.pt")
                print("Evaluation...")
                checkpoint = torch.load(ckpt_path)
                model.load_state_dict(checkpoint['model'])
                model.eval()

            if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested))):
                os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested)))
            if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons')):
                os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons'))
            eval_output_dir = os.path.join(self.evaldir, 'epoch' + str(epoch_tested))

            test_loader = self.test_loader
            print(len(test_loader))

            # 144 sequences: 10x10 for training; 2x10 (unseen model) + 10x2 (unseen action)
            # + 2x2 (unseen model, unseen action) for test; record each bucket separately
            mean_model_acc = []
            mean_model_epe = []
            mean_action_acc = []
            mean_action_epe = []
            mean_none_acc = []
            mean_none_epe = []
            mean_acc = []
            mean_epe = []
            mean_cd = []
            model.eval()

            for i_eval, data in enumerate(tqdm(test_loader, desc='Predicting Vtx Corr...')):
                pred = model(data)
                pred = {**pred, **data}

                mean_acc.append(pred['Visibility Acc'].mean().item())
                mean_epe.append(pred['EPE'].mean().item())

                unmarked = True
                for model_name in train_model:
                    if model_name in pred['file_name']:
                        mean_model_acc.append(pred['Visibility Acc'])
                        mean_model_epe.append(pred['EPE'])
                        unmarked = False
                        break
                for action_name in train_action:
                    if action_name in pred['file_name']:
                        mean_action_acc.append(pred['Visibility Acc'])
                        mean_action_epe.append(pred['EPE'])
                        unmarked = False
                        break
                if unmarked:
                    mean_none_acc.append(pred['Visibility Acc'])
                    mean_none_epe.append(pred['EPE'])

                img_vis = visualize(pred)
                file_name = pred['file_name'][0].split('/')
                # file_name[-1] keeps a trailing '.' (the dataset slices off only
                # 'json'), so appending 'png' completes the '.png' extension
                cv2.imwrite(os.path.join(config.imwrite_dir, (file_name[-2] + '_' + file_name[-1]) + 'png'), img_vis)

            log.log_eval({
                'updates': self.config.testing.ckpt_epoch,
                # 'Visibility Accuracy': np.mean(mean_acc),
                # 'EPE': np.mean(mean_epe),
            })
            sys.stdout.flush()
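    # Editor's sketch: gen() below asks the model for full-gap motion and then
    # renders frames at several intermediate time steps. One simple placement
    # rule (not necessarily what utils/visualize_video.visvid does internally)
    # is to scale the motion linearly by t in (0, 1); eval() above corresponds
    # to the fixed midpoint t = 0.5.
    @staticmethod
    def _example_reposition(kpts0, motion0, t=0.5):
        return kpts0 + motion0 * t  # (B, N, 2) vertex positions at time t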
    def gen(self):
        log = Logger(self.config, self.viddir)
        with torch.no_grad():
            model = self.model.eval()
            config = self.config

            epoch_tested = self.config.testing.ckpt_epoch
            if epoch_tested == 0 or epoch_tested == '0':
                checkpoint = torch.load(self.config.corr_weights)
                dict = {k.replace('module.', ''): checkpoint['model'][k] for k in checkpoint['model']}
                model.module.corr.load_state_dict(dict)
            else:
                ckpt_path = os.path.join(self.ckptdir, f"epoch_{epoch_tested}.pt")
                print("Evaluation...")
                checkpoint = torch.load(ckpt_path)
                model.load_state_dict(checkpoint['model'])
                model.eval()

            if not os.path.exists(os.path.join(self.viddir, 'epoch' + str(epoch_tested))):
                os.mkdir(os.path.join(self.viddir, 'epoch' + str(epoch_tested)))
            if not os.path.exists(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'frames')):
                os.mkdir(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'frames'))
            if not os.path.exists(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'videos')):
                os.mkdir(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'videos'))
            gen_frame_dir = os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'frames')
            gen_video_dir = os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'videos')

            vid_loader = self.vid_loader
            print(len(vid_loader))
            model.eval()

            for i_eval, data in enumerate(tqdm(vid_loader, desc='Gen Video...')):
                pred = model(data)
                pred = {**pred, **data}
                img_vis = visgen(pred, config.inter_frames)

                if not os.path.exists(os.path.join(gen_frame_dir, pred['folder_name0'][0])):
                    os.mkdir(os.path.join(gen_frame_dir, pred['folder_name0'][0]))
                cv2.imwrite(os.path.join(gen_frame_dir, pred['folder_name0'][0],
                                         pred['file_name0'][0] + '_000.jpg'), img_vis[0])
                for tt in range(config.inter_frames):
                    cv2.imwrite(os.path.join(gen_frame_dir, pred['folder_name0'][0],
                                             pred['file_name0'][0] + '_' + '{:03d}'.format(tt + 1) + '.jpg'),
                                img_vis[tt + 1])
                cv2.imwrite(os.path.join(gen_frame_dir, pred['folder_name0'][0],
                                         pred['file_name1'][0] + '_000.jpg'), img_vis[-1])

            for ff in os.listdir(gen_frame_dir):
                frame_dir = os.path.join(gen_frame_dir, ff)
                video_file = os.path.join(gen_video_dir, f"{ff}.mp4")
                cmd = f"ffmpeg -r {config.fps} -pattern_type glob -i '{frame_dir}/*.jpg' -vb 20M -vcodec mpeg4 -y '{video_file}'"
                print(cmd, flush=True)
                os.system(cmd)

            log.log_eval({
                'updates': self.config.testing.ckpt_epoch,
            })
            sys.stdout.flush()

    def _build(self):
        config = self.config
        self.start_epoch = 0
        self._dir_setting()
        self._build_model()
        if not (hasattr(config, 'need_not_train_data') and config.need_not_train_data):
            self._build_train_loader()
        if not (hasattr(config, 'need_not_test_data') and config.need_not_test_data):
            self._build_test_loader()
        if hasattr(config, 'gen_video') and config.gen_video:
            self._build_video_loader()
        self._build_optimizer()

    def _build_model(self):
        """ Define Model """
        config = self.config
        if hasattr(config.model, 'name'):
            print(f'Experiment Using {config.model.name}')
            model_class = getattr(models, config.model.name)
            model = model_class(config.model)
        else:
            raise NotImplementedError("Wrong Model Selection")
        model = nn.DataParallel(model)
        self.model = model.cuda()

    def _build_train_loader(self):
        config = self.config
        self.train_loader = fetch_dataloader(config.data.train, type='train')

    def _build_test_loader(self):
        config = self.config
        self.test_loader = fetch_dataloader(config.data.test, type='test')

    def _build_video_loader(self):
        config = self.config
        self.vid_loader = fetch_videoloader(config.video)

    def _build_optimizer(self):
        config = self.config.optimizer
        try:
            optim = getattr(torch.optim, config.type)
        except Exception:
            raise NotImplementedError('not implemented optim method ' + config.type)
        self.optimizer = optim(itertools.chain(self.model.module.parameters()), **config.kwargs)
        self.schedular = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, **config.schedular_kwargs)

    def _dir_setting(self):
        self.expname = self.config.expname
        self.experiment_dir = 'experiments'
        self.expdir = os.path.join(self.experiment_dir, self.expname)
        if not os.path.exists(self.expdir):
            os.mkdir(self.expdir)

        self.visdir = os.path.join(self.expdir, "vis")  # imgs, videos, jsons
        if not os.path.exists(self.visdir):
            os.mkdir(self.visdir)

        self.ckptdir = os.path.join(self.expdir, "ckpt")
        if not os.path.exists(self.ckptdir):
            os.mkdir(self.ckptdir)

        self.evaldir = os.path.join(self.expdir, "eval")
        if not os.path.exists(self.evaldir):
            os.mkdir(self.evaldir)

        self.viddir = os.path.join(self.expdir, "video")
        if not os.path.exists(self.viddir):
            os.mkdir(self.viddir)

================================================
FILE: inbetween_results/.gitkeep
================================================

================================================
FILE: main.py
================================================
from inbetween import DraftRefine
import argparse
import os
import yaml
from pprint import pprint
from easydict import EasyDict


def parse_args():
    parser = argparse.ArgumentParser(
        description='Anime segment matching')
    parser.add_argument('--config',
default='') # exclusive arguments group = parser.add_mutually_exclusive_group(required=True) group.add_argument('--train', action='store_true') group.add_argument('--eval', action='store_true') group.add_argument('--gen', action='store_true') return parser.parse_args() def main(): # parse arguments and load config args = parse_args() with open(args.config) as f: config = yaml.load(f) for k, v in vars(args).items(): config[k] = v pprint(config) config = EasyDict(config) agent = DraftRefine(config) print(config) if args.train: agent.train() elif args.eval: agent.eval() elif args.gen: agent.gen() if __name__ == '__main__': main() ================================================ FILE: models/__init__.py ================================================ # from .transformer_refiner import Refiner # from .inbetweener import Inbetweener # from .inbetweener_with_mask import InbetweenerM # from .inbetweener_wo_rp import InbetweenerM as InbetweenerNRP from .inbetweener_with_mask_with_spec import InbetweenerTM # from .inbetweener_with_mask_with_spec_wo_OT import InbetweenerTMwoOT from .inbetweener_with_mask2 import InbetweenerM as InbetweenerM2 # from .inbetweener_with_mask_wo_pos import InbetweenerNP # from .inbetweener_with_mask_wo_pos_wo_spec import InbetweenerNPS # from .transformer_refiner2 import Refiner as Refiner2 # from .transformer_refiner3 import Refiner as Refiner3 # from .transformer_refiner4 import Refiner as Refiner4 # from .transformer_refiner5 import Refiner as Refiner5 # from .transformer_refiner_norm import Refiner as RefinerN __all__ = [ 'InbetweenerTM', 'InbetweenerM2'] ================================================ FILE: models/inbetweener_with_mask2.py ================================================ from copy import deepcopy from pathlib import Path import torch from torch import nn # from seg_desc import seg_descriptor import argparse import torch.nn.functional as F def MLP(channels: list, do_bn=True): """ Multi-layer perceptron """ n = len(channels) layers = [] for i in range(1, n): layers.append( nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) if i < (n-1): if do_bn: # layers.append(nn.BatchNorm1d(channels[i])) layers.append(nn.InstanceNorm1d(channels[i])) layers.append(nn.ReLU()) return nn.Sequential(*layers) def normalize_keypoints(kpts, image_shape): """ Normalize keypoints locations based on image image_shape""" _, _, height, width = image_shape one = kpts.new_tensor(1) size = torch.stack([one*width, one*height])[None] center = size / 2 scaling = size.max(1, keepdim=True).values * 0.7 return (kpts - center[:, None, :]) / scaling[:, None, :] class ThreeLayerEncoder(nn.Module): """ Joint encoding of visual appearance and location using MLPs""" def __init__(self, enc_dim): super().__init__() # input must be 3 channel (r, g, b) self.layer1 = nn.Conv2d(3, enc_dim//4, 7, padding=3) self.non_linear1 = nn.ReLU() self.layer2 = nn.Conv2d(enc_dim//4, enc_dim//2, 3, padding=1) self.non_linear2 = nn.ReLU() self.layer3 = nn.Conv2d(enc_dim//2, enc_dim, 3, padding=1) self.norm1 = nn.InstanceNorm2d(enc_dim//4) self.norm2 = nn.InstanceNorm2d(enc_dim//2) self.norm3 = nn.InstanceNorm2d(enc_dim) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') nn.init.constant_(m.bias, 0.0) def forward(self, img): x = self.non_linear1(self.norm1(self.layer1(img))) x = self.non_linear2(self.norm2(self.layer2(x))) x = self.norm3(self.layer3(x)) # x = self.non_linear1(self.layer1(img)) # x = 
self.non_linear2(self.layer2(x)) # x = self.layer3(x) return x class VertexDescriptor(nn.Module): """ Joint encoding of visual appearance and location using MLPs""" def __init__(self, enc_dim): super().__init__() self.encoder = ThreeLayerEncoder(enc_dim) # self.super_pixel_pooling = # use scatter # nn.init.constant_(self.encoder[-1].bias, 0.0) def forward(self, img, vtx): x = self.encoder(img) n, c, h, w = x.size() assert((h, w) == img.size()[2:4]) return x[:, :, torch.round(vtx[0, :, 1]).long(), torch.round(vtx[0, :, 0]).long()] # return super_pixel_pooling(x.view(n, c, -1), seg.view(-1).long(), reduce='mean') # here return size is [1]xCx|Seg| class KeypointEncoder(nn.Module): """ Joint encoding of visual appearance and location using MLPs""" def __init__(self, feature_dim, layers): super().__init__() self.encoder = MLP([2] + layers + [feature_dim]) # for m in self.encoder.modules(): # if isinstance(m, nn.Conv2d): # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') # nn.init.constant_(m.bias, 0.0) nn.init.constant_(self.encoder[-1].bias, 0.0) def forward(self, kpts): inputs = kpts.transpose(1, 2) # print(inputs.size(), 'wula!') x = self.encoder(inputs) # print(x.size()) return x def attention(query, key, value, mask=None): dim = query.shape[1] scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5 if mask is not None: # print(mask, flush=True) scores = scores.masked_fill(mask==0, float('-inf')) # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) # att = F.softmax(att, dim=-1) prob = torch.nn.functional.softmax(scores, dim=-1) # print(scores[1][1], prob[1][1], flush=True) # while True: # pass # prob = torch.exp(scores) /((torch.sum(torch.exp(scores), dim=-1)[:, :, :, None]) + 1e-7) return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob class MultiHeadedAttention(nn.Module): """ Multi-head attention to increase model expressivitiy """ def __init__(self, num_heads: int, d_model: int): super().__init__() assert d_model % num_heads == 0 self.dim = d_model // num_heads self.num_heads = num_heads self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) def forward(self, query, key, value, mask=None): batch_dim = query.size(0) query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1) for l, x in zip(self.proj, (query, key, value))] x, prob = attention(query, key, value, mask) # self.prob.append(prob) return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1)) class AttentionalPropagation(nn.Module): def __init__(self, feature_dim: int, num_heads: int): super().__init__() self.attn = MultiHeadedAttention(num_heads, feature_dim) self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim]) nn.init.constant_(self.mlp[-1].bias, 0.0) def forward(self, x, source, mask=None): message = self.attn(x, source, source, mask) return self.mlp(torch.cat([x, message], dim=1)) class AttentionalGNN(nn.Module): def __init__(self, feature_dim: int, layer_names: list): super().__init__() self.layers = nn.ModuleList([ AttentionalPropagation(feature_dim, 4) for _ in range(len(layer_names))]) self.names = layer_names def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None, mask10=None): for layer, name in zip(self.layers, self.names): layer.attn.prob = [] if name == 'cross': src0, src1 = desc1, desc0 mask0, mask1 = mask01[:, None], mask10[:, None] else: # if name == 'self': src0, src1 = desc0, desc1 
mask0, mask1 = mask00[:, None], mask11[:, None] delta0, delta1 = layer(desc0, src0, mask0), layer(desc1, src1, mask1) desc0, desc1 = (desc0 + delta0), (desc1 + delta1) return desc0, desc1 def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int): """ Perform Sinkhorn Normalization in Log-space for stability""" u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) for _ in range(iters): u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2) v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1) return Z + u.unsqueeze(2) + v.unsqueeze(1) def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None): """ Perform Differentiable Optimal Transport in Log-space for stability""" b, m, n = scores.shape one = scores.new_tensor(1) if ms is None or ns is None: ms, ns = (m*one).to(scores), (n*one).to(scores) # else: # ms, ns = ms.to(scores)[:, None], ns.to(scores)[:, None] # here m,n should be parameters not shape # ms, ns: (b, ) bins0 = alpha.expand(b, m, 1) bins1 = alpha.expand(b, 1, n) alpha = alpha.expand(b, 1, 1) # pad additional scores for unmatcheed (to -1) # alpha is the learned threshold couplings = torch.cat([torch.cat([scores, bins0], -1), torch.cat([bins1, alpha], -1)], 1) norm = - (ms + ns).log() # (b, ) # print(scores.min(), flush=True) if ms.size()[0] > 0: norm = norm[:, None] log_mu = torch.cat([norm.expand(b, m), ns.log()[:, None] + norm], dim=-1) # (m + 1) log_nu = torch.cat([norm.expand(b, n), ms.log()[:, None] + norm], dim=-1) # print(log_nu.min(), log_mu.min(), flush=True) else: log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) # (m + 1) log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm]) log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1) Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters) if ms.size()[0] > 1: norm = norm[:, :, None] Z = Z - norm # multiply probabilities by M+N return Z def arange_like(x, dim: int): return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1 class SuperGlueM(nn.Module): """SuperGlue feature matching middle-end Given two sets of keypoints and locations, we determine the correspondences by: 1. Keypoint Encoding (normalization + visual feature and location fusion) 2. Graph Neural Network with multiple self and cross-attention layers 3. Final projection layer 4. Optimal Transport Layer (a differentiable Hungarian matching algorithm) 5. Thresholding matrix based on mutual exclusivity and a match_threshold The correspondence ids use -1 to indicate non-matching points. Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural Networks. In CVPR, 2020. 
    https://arxiv.org/abs/1911.11763
    """

    def __init__(self, config=None):
        super().__init__()
        default_config = argparse.Namespace()
        default_config.descriptor_dim = 128
        default_config.keypoint_encoder = [32, 64, 128]
        default_config.GNN_layers = ['self', 'cross'] * 9
        default_config.sinkhorn_iterations = 100
        default_config.match_threshold = 0.2

        if config is None:
            self.config = default_config
        else:
            self.config = config
            self.config.GNN_layers = ['self', 'cross'] * self.config.GNN_layer_num

        self.kenc = KeypointEncoder(
            self.config.descriptor_dim, self.config.keypoint_encoder)
        self.gnn = AttentionalGNN(
            self.config.descriptor_dim, self.config.GNN_layers)
        self.final_proj = nn.Conv1d(
            self.config.descriptor_dim, self.config.descriptor_dim,
            kernel_size=1, bias=True)

        bin_score = torch.nn.Parameter(torch.tensor(1.))
        self.register_parameter('bin_score', bin_score)
        self.vertex_desc = VertexDescriptor(self.config.descriptor_dim)

    def forward(self, data):
        """Run SuperGlue on a pair of keypoints and descriptors"""
        kpts0, kpts1 = data['keypoints0'].float(), data['keypoints1'].float()
        ori_mask0, ori_mask1 = data['mask0'].float(), data['mask1'].float()
        dim_m, dim_n = data['ms'].float(), data['ns'].float()
        mmax = dim_m.int().max()
        nmax = dim_n.int().max()

        mask0 = ori_mask0[:, :mmax]
        mask1 = ori_mask1[:, :nmax]
        kpts0 = kpts0[:, :mmax]
        kpts1 = kpts1[:, :nmax]

        desc0, desc1 = self.vertex_desc(data['image0'], kpts0.float()), self.vertex_desc(data['image1'], kpts1.float())

        mask00 = torch.ones_like(mask0)[:, :, None] * mask0[:, None, :]
        mask11 = torch.ones_like(mask1)[:, :, None] * mask1[:, None, :]
        mask01 = torch.ones_like(mask0)[:, :, None] * mask1[:, None, :]
        mask10 = torch.ones_like(mask1)[:, :, None] * mask0[:, None, :]

        if kpts0.shape[1] < 2 or kpts1.shape[1] < 2:  # no keypoints
            shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
            return {
                'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],
                'matching_scores0': kpts0.new_zeros(shape0)[0],
                'skip_train': True
            }

        all_matches = data['all_matches'] if 'all_matches' in data else None  # shape = (1, K1)

        # Keypoint normalization.
        kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
        kpts1 = normalize_keypoints(kpts1, data['image1'].shape)

        # Keypoint MLP encoder (positional embedding).
        pos0 = self.kenc(kpts0)
        pos1 = self.kenc(kpts1)
        desc0 = desc0 + pos0
        desc1 = desc1 + pos1

        # Multi-layer Transformer network.
        desc0, desc1 = self.gnn(desc0, desc1, mask00, mask11, mask01, mask10)

        # Final MLP projection.
        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)

        # Compute matching descriptor distance.
        scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)   # cross-graph similarity, b x k1 x k2
        scores0 = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc0)  # within-graph similarity
        scores1 = torch.einsum('bdn,bdm->bnm', mdesc1, mdesc1)
        scores = scores / self.config.descriptor_dim**.5

        # Run the optimal transport.
        scores = log_optimal_transport(
            scores, self.bin_score,
            iters=self.config.sinkhorn_iterations,
            ms=dim_m, ns=dim_n)

        # Drop the dustbin row/column; the caller thresholds and mutual-checks.
        return scores[:, :-1, :-1], scores0, scores1, mdesc0, mdesc1
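# Editor's sketch: recovering hard matches from the dustbin-stripped scores
# returned by SuperGlueM.forward above. This mirrors the mutual-nearest-
# neighbour plus threshold logic in InbetweenerM's video branch below; the
# 0.2 threshold is the same illustrative value used there.
def example_extract_matches(score01, threshold=0.2):
    max0, max1 = score01.max(2), score01.max(1)  # best partner per row / per column
    indices0, indices1 = max0.indices, max1.indices
    mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
    mscores0 = torch.where(mutual0, max0.values.exp(), score01.new_tensor(0.))
    valid0 = mscores0 > threshold
    return torch.where(valid0, indices0, indices0.new_tensor(-1))  # -1 = unmatched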
def tensor_erode(bin_img, ksize=5):
    # pad the input first, so that erosion does not shrink the image
    B, C, H, W = bin_img.shape
    pad = (ksize - 1) // 2
    bin_img = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)

    # unfold the image into ksize x ksize patches
    patches = bin_img.unfold(dimension=2, size=ksize, step=1)
    patches = patches.unfold(dimension=3, size=ksize, step=1)
    # B x C x H x W x k x k

    # take the minimum of each patch, i.e., 0
    eroded, _ = patches.reshape(B, C, H, W, -1).min(dim=-1)
    return eroded


class InbetweenerM(nn.Module):
    """SuperGlue feature matching middle-end

    Given two sets of keypoints and locations, we determine the
    correspondences by:
      1. Keypoint Encoding (normalization + visual feature and location fusion)
      2. Graph Neural Network with multiple self and cross-attention layers
      3. Final projection layer
      4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
      5. Thresholding matrix based on mutual exclusivity and a match_threshold

    The correspondence ids use -1 to indicate non-matching points.

    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
    Networks. In CVPR, 2020.
https://arxiv.org/abs/1911.11763 """ # default_config = { # 'descriptor_dim': 128, # 'weights': 'indoor', # 'keypoint_encoder': [32, 64, 128], # 'GNN_layers': ['self', 'cross'] * 9, # 'sinkhorn_iterations': 100, # 'match_threshold': 0.2, # } def __init__(self, config=None): super().__init__() self.corr = SuperGlueM(config.corr_model) self.mask_map = MLP([config.corr_model.descriptor_dim, 32, 1]) self.pos_weight = config.pos_weight # self.motion_propagation = # assert self.config.weights in ['indoor', 'outdoor'] # path = Path(__file__).parent # path = path / 'weights/superglue_{}.pth'.format(self.config.weights) # self.load_state_dict(torch.load(path)) # print('Loaded SuperGlue model (\"{}\" weights)'.format( # self.config.weights)) def forward(self, data): if 'gen_vid' in data: dim_m, dim_n = data['ms'].float(), data['ns'].float() mmax = dim_m.int().max() nmax = dim_n.int().max() # with torch.no_grad(): # self.corr.eval() score01, score0, score1, dec0, dec1 = self.corr(data) kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2 ## print(kpts0.mean(), kpts1.mean(), flush=True) motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0 motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1 motion_pred0 = torch.softmax(score0, dim=-1) @ motion_pred0 motion_pred1 = torch.softmax(score1, dim=-1) @ motion_pred1 max0, max1 = score01.max(2), score01.max(1) indices0, indices1 = max0.indices, max1.indices mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) zero = score01.new_tensor(0) mscores0 = torch.where(mutual0, max0.values.exp(), zero) mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) # valid0 = mutual0 & (mscores0 > self.config.match_threshold) # valid1 = mutual1 & valid0.gather(1, indices1) valid0 = mscores0 > 0.2 valid1 = valid0.gather(1, indices1) indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) adj0, adj1 = data['adj_mat0'].float(), data['adj_mat1'].float() motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0 motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1 # score0.mask_off() motion_pred0 = torch.softmax(score0.masked_fill(adj0==0, float('-inf')), dim=-1) @ motion_pred0 motion_pred1 = torch.softmax(score1.masked_fill(adj1==0, float('-inf')), dim=-1) @ motion_pred1 vb0 = self.mask_map(dec0)[:, 0] vb1 = self.mask_map(dec1)[:, 0] vb0[:] = 1 vb1[:] = 1 im0_erode = data['image0'] im1_erode = data['image1'] im0_erode[im0_erode > 0] = 1 im0_erode[im0_erode <= 0] = 0 im1_erode[im1_erode > 0] = 1 im1_erode[im1_erode <= 0] = 0 im0_erode = tensor_erode(im0_erode, 3) im1_erode = tensor_erode(im1_erode, 3) motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone() ## print('>>>>> here', motion_pred0.mean(), motion_pred1.mean(), flush=True) kpt0t = kpts0 + motion_output0 * 1 kpt1t = kpts1 + motion_output1 * 1 if 'topo0' in data and 'topo1' in data: ## print(len(data['topo0'][0]), len(data['topo1']), flush=True) for node, nbs in enumerate(data['topo0'][0]): for nb in nbs: # print(nb, flush=True) # print(kpt0t.size(), 'fDsafdsafds', flush=True) # if vb0[0, nb] and vb0[0, node] and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 3: # vb0[0, nb] = -1 # vb0[0, node] = -1 # print(node.size()) center = ((kpt0t[0, node] + kpt0t[0, nb]) * 
0.5).int()[0] # print(center.size(), flush=True) if vb0[0, nb] and vb0[0, node] and im1_erode[0,:, center[1], center[0]].mean() > 0.8: vb0[0, nb] = -1 vb0[0, node] = -1 # center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.25).int()[0] # # print(center.size(), flush=True) # if vb0[0, nb] and vb0[0, node] and center[1] < 720 and center[0] < 720 and im1_erode[0,:, center[1], center[0]].mean() > 0.8: # vb0[0, nb] = -1 # vb0[0, node] = -1 # center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.75).int()[0] # # print(center.size(), flush=True) # if vb0[0, nb] and vb0[0, node] and center[1] < 720 and center[0] < 720 and im1_erode[0,:, center[1], center[0]].mean() > 0.8: # vb0[0, nb] = -1 # vb0[0, node] = -1 for node, nbs in enumerate(data['topo1'][0]): for nb in nbs: # if vb1[0, nb] and vb1[0, node] and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) >3: # vb1[0, nb] = -1 # vb1[0, node] = -1 center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.5).int()[0] if vb1[0, nb] and vb1[0, node] and im0_erode[0,:, center[1], center[0]].mean() > 0.95: vb1[0, nb] = -1 vb1[0, node] = -1 # center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.25).int()[0] # if vb1[0, nb] and vb1[0, node] and center[1] < 720 and center[0] < 720 and im0_erode[0,:, center[1], center[0]].mean() > 0.95: # vb1[0, nb] = -1 # vb1[0, node] = -1 # center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.75).int()[0] # if vb1[0, nb] and vb1[0, node] and center[1] < 720 and center[0] < 720 and im0_erode[0,:, center[1], center[0]].mean() > 0.95: # vb1[0, nb] = -1 # vb1[0, node] = -1 # print(vb0.mean(), vb1.mean(), flush=True) return {'r0': motion_output0, 'r1': motion_output1, 'vb0':(vb0 > 0).float(), 'vb1':(vb1 > 0).float(),} dim_m, dim_n = data['ms'].float(), data['ns'].float() mmax = dim_m.int().max() nmax = dim_n.int().max() # with torch.no_grad(): # self.corr.eval() score01, score0, score1, dec0, dec1 = self.corr(data) kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2 adj0, adj1 = data['adj_mat0'].float(), data['adj_mat1'].float() motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0 motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1 # score0.mask_off() motion_pred0 = torch.softmax(score0.masked_fill(adj0==0, float('-inf')), dim=-1) @ motion_pred0 motion_pred1 = torch.softmax(score1.masked_fill(adj1==0, float('-inf')), dim=-1) @ motion_pred1 vb0 = self.mask_map(dec0)[:, 0] vb1 = self.mask_map(dec1)[:, 0] # motion0_pred, vb0 = pred0[:, :2].permute(0, 2, 1), pred0[:, 2:][:, 0] # motion1_pred, vb1 = pred1[:, :2].permute(0, 2, 1), pred1[:, 2:][:, 0] # delta0, delta1 = motion_delta[:, :, :mmax].permute(0, 2, 1), motion_delta[:, :, mmax:].permute(0, 2, 1) # motion_output0, motion_output1 = motion0 + delta0, motion1 + delta1 motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone() # print(delta0.max(), delta1.max()) # vb0 = kpts0.new_ones(motion_pred0[:, :, 0].size()) + 1.0 # vb1 = kpts1.new_ones(motion_pred1[:, :, 0].size()) + 1.0 # vb0, vb1 = visibility[:, 0, :mmax], visibility[:, 0, mmax:] # mask0, mask1 = mask[:, :mmax].bool(), mask[:, mmax:].bool() # vb0_output = vb0.clone() # vb1_output = vb1.clone() # vb1_output[batch, corr01[corr01 != -1]] = 1.0 # motion_output0[valid0.bool()] = motion0[valid0.bool()] # motion_output1[valid1.bool()] = motion1[valid1.bool()] # vb0_output[vb0_output >= 0] = 1.0 # vb0_output[vb0_output < 0] = 0.0 # vb1_output[vb1_output >= 0] = 1.0 # vb1_output[vb1_output < 0 ] = 0.0 kpt0t = kpts0 + 
motion_output0 / 2 kpt1t = kpts1 + motion_output1 / 2 # kpt1t[batch, corr01[corr01 != -1]] = kpt0t[corr01 != -1] ################################################## ## Note Here the mini batch size is 1!!!!!!!! ## ################################################## if 'topo0' in data and 'topo1' in data: # print(len(data['topo0'][0]), len(data['topo1']), flush=True) for node, nbs in enumerate(data['topo0'][0]): for nb in nbs: if vb0[0, nb] and vb0[0, node] and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 5: vb0[0, nb] = -1 vb0[0, node] = -1 for node, nbs in enumerate(data['topo1'][0]): for nb in nbs: if vb1[0, nb] and vb1[0, node] and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) > 5: vb1[0, nb] = -1 vb1[0, node] = -1 if 'motion0' in data and 'motion1' in data: # valid_motion0 = motion_output0[mask0[:, :, None].repeat(1, 1, 2)] # gt_valid_motion0 = data['motion0'][:, :mmax][mask0[:, :, None].repeat(1, 1, 2)].float() # valid_motion1 = motion_output1[mask1[:, :, None].repeat(1, 1, 2)] # gt_valid_motion1 = data['motion1'][:, :nmax][mask1[:, :, None].repeat(1, 1, 2)].float() loss_motion = torch.nn.functional.l1_loss(motion_pred0, data['motion0'][:, :mmax]) +\ torch.nn.functional.l1_loss(motion_pred1, data['motion1'][:, :nmax]) # loss_valid0 = ((corr01 == -1) & (mask0 == 1)) # loss_valid1 = ((corr10 == -1) & (mask1 == 1)) EPE0 = ((motion_pred0 - data['motion0'][:, :mmax]) ** 2).sum(dim=-1).sqrt() EPE1 = ((motion_pred1 - data['motion1'][:, :nmax]) ** 2).sum(dim=-1).sqrt() # print(EPE0.size(), 'fdsafdsa') EPE = (EPE0.mean() + EPE1.mean()) * 0.5 # print(len(EPE0[mask0]), len(EPE1[mask1])) # print(vb0[:, :mmax][mask0], vb0[:, :mmax][mask0].shape, data['visibility0'][:, :mmax][mask0], data['visibility0'][:, :mmax][mask0].shape) # print(.size()) # print((vb0[:, :mmax] > 0).float().sum(), data['visibility0'][:, :mmax].float().sum()) # pos_weight=vb0.new_tensor([0.5]) if 'visibility0' in data and 'visibility1' in data: loss_visibility = torch.nn.functional.binary_cross_entropy_with_logits(vb0[:, :mmax].view(-1, 1), data['visibility0'][:, :mmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight])) + \ torch.nn.functional.binary_cross_entropy_with_logits(vb1[:, :nmax].view(-1, 1), data['visibility1'][:, :nmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight])) VB_Acc = ((((vb0 > 0).float() == data['visibility0'][:, :mmax]).float().sum() + ((vb1 > 0).float() == data['visibility1'][:, :nmax]).float().sum()) * 1.0 / (mmax + nmax)) else: loss_visibility = 0 VB_Acc = EPE.new_zeros([1]) loss = loss_motion + 10 * loss_visibility loss_mean = torch.mean(loss) # loss_mean = torch.reshape(loss_mean, (1, -1)) # print(loss_mean, flush=True) # print(all_matches[:, :mmax].size(), indices0.size(), mask0.size(), flush=True) #print((all_matches[0] == indices0[0]).sum()) # print(vb1.size(),corr01.size()) # kpt0t = torch.nn.functional.pad(kpts0 + motion_output0, (0, 0, 0, self.max_len - mmax, 0, 0), mode='constant', value=0) # kpt1t = torch.nn.functional.pad(kpts1 + motion_output1, (0, 0, 0, self.max_len - nmax, 0, 0), mode='constant', value=0), # kpt1t[:, :nmax][batch, corr01[corr01 != -1]] = kpt0t[:, :mmax][corr01 != -1] b, _, _ = motion_pred0.size() # batch = torch.arange(b)[:, None].repeat(1, mmax)[corr01 != -1].long() # # print(kpts0[corr01 != -1].size(), corr01[corr01 != -1].size()) # matched_intermediate = (kpts0[(corr01 != -1)] + kpts1[batch, corr01[corr01 != -1].long(), :]) * 0.5 # 
motion0[corr01 != -1] = matched_intermediate - kpts0[corr01 != -1] # motion1[batch, corr01[corr01 != -1].long(), :] = matched_intermediate - kpts1[batch, corr01[corr01 != -1].long(), :] # vb0 = torch.nn.functional.pad(vb0, (0, self.max_len - mmax, 0, 0), mode='constant', value=0), # vb1 = torch.nn.functional.pad(vb1, (0, self.max_len - nmax, 0, 0), mode='constant', value=0), # self.max_len = 3050 # VB_Acc = ((((vb0 > 0.5).float() == data['visibility0'][:, :mmax]).float().sum() + ((vb1 > 0.5).float() == data['visibility1'][:, :nmax]).float().sum()) * 1.0 / (mmax + nmax)) return { # 'matches0': indices0, # use -1 for invalid match # 'matches1': indices1[0], # use -1 for invalid match # 'matching_scores0': mscores0, # 'matching_scores1': mscores1[0], # 'keypointst0': torch.nn.functional.pad(kpts0 + motion_output0, (0, 0, 0, self.max_len - mmax, 0, 0), mode='constant', value=0), # 'keypointst1': torch.nn.functional.pad(kpts1 + motion_output1, (0, 0, 0, self.max_len - nmax, 0, 0), mode='constant', value=0), # 'vb0': torch.nn.functional.pad(vb0, (0, self.max_len - mmax, 0, 0), mode='constant', value=0), # 'vb1': torch.nn.functional.pad(vb1, (0, self.max_len - nmax, 0, 0), mode='constant', value=0), 'keypoints0t': kpt0t, 'keypoints1t': kpt1t, 'vb0': (vb0 > 0).float(), 'vb1': (vb1 > 0).float(), 'loss': loss_mean, 'EPE': EPE, 'Visibility Acc': VB_Acc # ((((vb0[mask0] > 0).float() == data['visibility0'][:, :mmax][mask0]).float().sum() + ((vb1[mask1] > 0).float() == data['visibility1'][:, :nmax][mask1]).float().sum()) * 1.0 / (mask0.float().sum() + mask1.float().sum())), # 'skip_train': [False], # 'accuracy': (((all_matches[:, :mmax] == indices0) & mask0.bool()).sum() / mask0.sum()).item(), # 'valid_accuracy': (((all_matches[:, :mmax] == indices0) & (all_matches[:, :mmax] != -1) & mask0.bool()).float().sum() / ((all_matches[:, :mmax] != -1) & mask0.bool()).float().sum()).item(), } else: return { 'loss': -1, 'skip_train': True, 'keypointst0': kpts0 + motion_output0, 'keypointst1': kpts1 + motion_output1, 'vb0': vb0, 'vb1': vb1, # 'accuracy': -1, # 'area_accuracy': -1, # 'valid_accuracy': -1, } if __name__ == '__main__': args = argparse.Namespace() args.batch_size = 2 args.gap = 5 args.type = 'train' args.model = None args.action = None ss = Refiner() loader = fetch_dataloader(args) # #print(len(loader)) for data in loader: # p1, p2, s1, s2, mi = data dict1 = data kp1 = dict1['keypoints0'] kp2 = dict1['keypoints1'] p1 = dict1['image0'] p2 = dict1['image1'] # #print(s1) # #print(s1.type) mi = dict1['m01'] fname = dict1['file_name'] print(dict1['keypoints0'].size(), dict1['keypoints1'].size(), dict1['m01'].size(), dict1['motion0'].size(), dict1['mask0'].size()) # print(kp1.shape, p1.shape, mi.shape) # #print(mi.size()) # #print(mi) # break a = ss(data) print(dict1['file_name']) print(a['loss']) print(a['EPE'], a['Visibility Acc'],flush=True) a['loss'].backward() ================================================ FILE: models/inbetweener_with_mask_with_spec.py ================================================ from copy import deepcopy from pathlib import Path import torch from torch import nn # from seg_desc import seg_descriptor import argparse import numpy as np import torch.nn.functional as F from sknetwork.embedding import Spectral def MLP(channels: list, do_bn=True): """ Multi-layer perceptron """ n = len(channels) layers = [] for i in range(1, n): layers.append( nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) if i < (n-1): if do_bn: # layers.append(nn.BatchNorm1d(channels[i])) 
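                # (editor's note) InstanceNorm1d replaces the BatchNorm1d
                # alternative above, presumably because training effectively
                # runs with batch size 1 (see the gradient-accumulation remark
                # in SuperGlueT.forward), where batch statistics degenerate.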
                layers.append(nn.InstanceNorm1d(channels[i]))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


def normalize_keypoints(kpts, image_shape):
    """ Normalize keypoint locations based on the image shape"""
    _, _, height, width = image_shape
    one = kpts.new_tensor(1)
    size = torch.stack([one*width, one*height])[None]
    center = size / 2
    scaling = size.max(1, keepdim=True).values * 0.7
    return (kpts - center[:, None, :]) / scaling[:, None, :]


class ThreeLayerEncoder(nn.Module):
    """ Three-layer convolutional encoder of image appearance"""
    def __init__(self, enc_dim):
        super().__init__()
        # input must be 3 channel (r, g, b)
        self.layer1 = nn.Conv2d(3, enc_dim//4, 7, padding=3)
        self.non_linear1 = nn.ReLU()
        self.layer2 = nn.Conv2d(enc_dim//4, enc_dim//2, 3, padding=1)
        self.non_linear2 = nn.ReLU()
        self.layer3 = nn.Conv2d(enc_dim//2, enc_dim, 3, padding=1)
        self.norm1 = nn.InstanceNorm2d(enc_dim//4)
        self.norm2 = nn.InstanceNorm2d(enc_dim//2)
        self.norm3 = nn.InstanceNorm2d(enc_dim)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0.0)

    def forward(self, img):
        x = self.non_linear1(self.norm1(self.layer1(img)))
        x = self.non_linear2(self.norm2(self.layer2(x)))
        x = self.norm3(self.layer3(x))
        # x = self.non_linear1(self.layer1(img))
        # x = self.non_linear2(self.layer2(x))
        # x = self.layer3(x)
        return x


class VertexDescriptor(nn.Module):
    """ Per-vertex appearance descriptor: samples CNN features at vertex locations"""
    def __init__(self, enc_dim):
        super().__init__()
        self.encoder = ThreeLayerEncoder(enc_dim)
        # self.super_pixel_pooling =
        # use scatter
        # nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, img, vtx):
        x = self.encoder(img)
        n, c, h, w = x.size()
        assert((h, w) == img.size()[2:4])
        return x[:, :, torch.round(vtx[0, :, 1]).long(), torch.round(vtx[0, :, 0]).long()]
        # return super_pixel_pooling(x.view(n, c, -1), seg.view(-1).long(), reduce='mean')
        # here return size is [1]xCx|Seg|


class KeypointEncoder(nn.Module):
    """ MLP encoding of 2D vertex locations (positional embedding)"""
    def __init__(self, feature_dim, layers):
        super().__init__()
        self.encoder = MLP([2] + layers + [feature_dim])
        # for m in self.encoder.modules():
        #     if isinstance(m, nn.Conv2d):
        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        #         nn.init.constant_(m.bias, 0.0)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts):
        inputs = kpts.transpose(1, 2)
        x = self.encoder(inputs)
        return x


class TopoEncoder(nn.Module):
    """ MLP encoding of the per-vertex spectral (topological) embedding"""
    def __init__(self, feature_dim, layers):
        super().__init__()
        self.encoder = MLP([64] + layers + [feature_dim])
        # for m in self.encoder.modules():
        #     if isinstance(m, nn.Conv2d):
        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        #         nn.init.constant_(m.bias, 0.0)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts):
        inputs = kpts.transpose(1, 2)
        x = self.encoder(inputs)
        return x


def attention(query, key, value, mask=None):
    dim = query.shape[1]
    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
    if mask is not None:
        scores = scores.masked_fill(mask==0, float('-inf'))
    prob = torch.nn.functional.softmax(scores, dim=-1)
    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob


class MultiHeadedAttention(nn.Module):
    """ Multi-head attention to increase model expressivity """
    def __init__(self, num_heads: int, d_model: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value, mask=None):
        batch_dim = query.size(0)
        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]
        x, prob = attention(query, key, value, mask)
        # self.prob.append(prob)
        return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))


class AttentionalPropagation(nn.Module):
    def __init__(self, feature_dim: int, num_heads: int):
        super().__init__()
        self.attn = MultiHeadedAttention(num_heads, feature_dim)
        self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
        nn.init.constant_(self.mlp[-1].bias, 0.0)

    def forward(self, x, source, mask=None):
        message = self.attn(x, source, source, mask)
        return self.mlp(torch.cat([x, message], dim=1))


class AttentionalGNN(nn.Module):
    def __init__(self, feature_dim: int, layer_names: list):
        super().__init__()
        self.layers = nn.ModuleList([
            AttentionalPropagation(feature_dim, 4)
            for _ in range(len(layer_names))])
        self.names = layer_names

    def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None, mask10=None):
        for layer, name in zip(self.layers, self.names):
            layer.attn.prob = []
            if name == 'cross':
                src0, src1 = desc1, desc0
                mask0, mask1 = mask01[:, None], mask10[:, None]
            else:  # if name == 'self':
                src0, src1 = desc0, desc1
                mask0, mask1 = mask00[:, None], mask11[:, None]
            delta0, delta1 = layer(desc0, src0, mask0), layer(desc1, src1, mask1)
            desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
        return desc0, desc1


def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
    """ Perform Sinkhorn Normalization in Log-space for stability"""
    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):
        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
    return Z + u.unsqueeze(2) + v.unsqueeze(1)


def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
    """ Perform Differentiable Optimal Transport in Log-space for stability"""
    b, m, n = scores.shape
    one = scores.new_tensor(1)
    if ms is None or ns is None:
        ms, ns = (m*one).to(scores), (n*one).to(scores)
    # else:
    #     ms, ns = ms.to(scores)[:, None], ns.to(scores)[:, None]
    # here m, n should be parameters, not the tensor shape
    # ms, ns: (b, )

    bins0 = alpha.expand(b, m, 1)
    bins1 = alpha.expand(b, 1, n)
    alpha = alpha.expand(b, 1, 1)

    # pad additional scores for unmatched points (to -1)
    # alpha is the learned threshold
    couplings = torch.cat([torch.cat([scores, bins0], -1),
                           torch.cat([bins1, alpha], -1)], 1)

    norm = - (ms + ns).log()  # (b, )
    if ms.size()[0] > 0:
        norm = norm[:, None]
        log_mu = torch.cat([norm.expand(b, m), ns.log()[:, None] + norm], dim=-1)  # (m + 1)
        log_nu = torch.cat([norm.expand(b, n), ms.log()[:, None] + norm], dim=-1)
    else:
        log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])  # (m + 1)
        log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
        log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)

    Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
    if ms.size()[0] > 1:
        norm = norm[:, :, None]
    Z = Z - norm  # multiply probabilities by M+N
    return Z


def arange_like(x, dim: int):
    return x.new_ones(x.shape[dim]).cumsum(0) - 1  # traceable in 1.1
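# Illustrative sketch (editor's addition, not called by the model): how
# log_optimal_transport is driven. A (b, m, n) score matrix is augmented with
# a learned "dustbin" row/column so unmatched vertices have somewhere to go,
# then Sinkhorn-normalized in log space:
#
#     scores = torch.randn(1, 5, 7)                 # toy descriptor scores
#     alpha = torch.nn.Parameter(torch.tensor(1.))  # learned dustbin score
#     Z = log_optimal_transport(scores, alpha, iters=100,
#                               ms=torch.tensor([5.]), ns=torch.tensor([7.]))
#     # Z is (1, 6, 8) log-couplings; Z[:, :-1, :-1] scores the real matches,
#     # which is exactly what SuperGlueT.forward below returns.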
class SuperGlueT(nn.Module):
    """SuperGlue feature matching middle-end

    Given two sets of keypoints and locations, we determine the
    correspondences by:
      1. Keypoint Encoding (normalization + visual feature and location fusion)
      2. Graph Neural Network with multiple self and cross-attention layers
      3. Final projection layer
      4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
      5. Thresholding matrix based on mutual exclusivity and a match_threshold

    The correspondence ids use -1 to indicate non-matching points.

    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
    Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763
    """
    # default_config = {
    #     'descriptor_dim': 128,
    #     'weights': 'indoor',
    #     'keypoint_encoder': [32, 64, 128],
    #     'GNN_layers': ['self', 'cross'] * 9,
    #     'sinkhorn_iterations': 100,
    #     'match_threshold': 0.2,
    # }

    def __init__(self, config=None):
        super().__init__()
        default_config = argparse.Namespace()
        default_config.descriptor_dim = 128
        default_config.keypoint_encoder = [32, 64, 128]
        default_config.GNN_layers = ['self', 'cross'] * 9
        default_config.sinkhorn_iterations = 100
        default_config.match_threshold = 0.2
        self.spectral = Spectral(64, normalized=False)

        if config is None:
            self.config = default_config
        else:
            self.config = config
            self.config.GNN_layers = ['self', 'cross'] * self.config.GNN_layer_num

        self.kenc = KeypointEncoder(
            self.config.descriptor_dim, self.config.keypoint_encoder)
        self.tenc = TopoEncoder(
            self.config.descriptor_dim, [96])
        self.gnn = AttentionalGNN(
            self.config.descriptor_dim, self.config.GNN_layers)
        self.final_proj = nn.Conv1d(
            self.config.descriptor_dim, self.config.descriptor_dim,
            kernel_size=1, bias=True)

        bin_score = torch.nn.Parameter(torch.tensor(1.))
        self.register_parameter('bin_score', bin_score)

        self.vertex_desc = VertexDescriptor(self.config.descriptor_dim)

    def forward(self, data):
        kpts0, kpts1 = data['keypoints0'].float(), data['keypoints1'].float()
        ori_mask0, ori_mask1 = data['mask0'].float(), data['mask1'].float()
        dim_m, dim_n = data['ms'].float(), data['ns'].float()

        # spectral embedding of the adjacency matrices
        # online computation of the spectral embedding is too slow during training,
        # so it is moved to the dataset pipeline, where it can be computed during
        # data preparation by multi-processing cpus
        spec0, spec1 = data['spec0'], data['spec1']
        # spec0, spec1 = np.abs(self.spectral.fit_transform(adj_mat0[0].cpu().numpy())), np.abs(self.spectral.fit_transform(adj_mat1[0].cpu().numpy()))

        mmax = dim_m.int().max()
        nmax = dim_n.int().max()
        mask0 = ori_mask0[:, :mmax]
        mask1 = ori_mask1[:, :nmax]
        kpts0 = kpts0[:, :mmax]
        kpts1 = kpts1[:, :nmax]

        # image context embedding
        desc0, desc1 = self.vertex_desc(data['image0'], kpts0.float()), self.vertex_desc(data['image1'], kpts1.float())

        # add topological embedding
        desc0 = desc0 + self.tenc(desc0.new_tensor(spec0))
        desc1 = desc1 + self.tenc(desc1.new_tensor(spec1))

        # the masks here were prepared for synchronized training with batch size > 1,
        # but this seems not to work well, so the current framework still uses
        # gradient accumulation
        mask00 = torch.ones_like(mask0)[:, :, None] * mask0[:, None, :]
        mask11 = torch.ones_like(mask1)[:, :, None] * mask1[:, None, :]
        mask01 = torch.ones_like(mask0)[:, :, None] * mask1[:, None, :]
        mask10 = torch.ones_like(mask1)[:, :, None] * mask0[:, None, :]

        if kpts0.shape[1] < 2 or kpts1.shape[1] < 2:  # no keypoints
            shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
            # print(data['file_name'])
            return {
                'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],
                # 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],
                'matching_scores0': kpts0.new_zeros(shape0)[0],
                # 'matching_scores1': kpts1.new_zeros(shape1)[0],
                'skip_train': True
            }
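        # (Editor's note) Expected keys of `data`, inferred from this forward
        # pass -- not an exhaustive spec:
        #   keypoints0/1 : (B, K, 2) padded vertex coordinates
        #   mask0/1      : (B, K) validity masks for the padding
        #   ms/ns        : (B,) true vertex counts per graph
        #   spec0/1      : 64-d spectral embeddings precomputed in the dataset
        #   image0/1     : (B, 3, H, W) rendered line drawings
        #   all_matches  : optional ground-truth correspondences (training only)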
        all_matches = data['all_matches'] if 'all_matches' in data else None  # shape = (1, K1)

        # positional embedding
        # Keypoint normalization.
        kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
        kpts1 = normalize_keypoints(kpts1, data['image1'].shape)

        # Keypoint MLP encoder.
        pos0 = self.kenc(kpts0)
        pos1 = self.kenc(kpts1)
        desc0 = desc0 + pos0
        desc1 = desc1 + pos1

        # Multi-layer Transformer network.
        desc0, desc1 = self.gnn(desc0, desc1, mask00, mask11, mask01, mask10)

        # Final MLP projection.
        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)

        # Compute matching descriptor distance.
        scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
        scores0 = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc0)
        scores1 = torch.einsum('bdn,bdm->bnm', mdesc1, mdesc1)
        # b k1 k2
        scores = scores / self.config.descriptor_dim**.5

        # Run the optimal transport.
        scores = log_optimal_transport(
            scores, self.bin_score,
            iters=self.config.sinkhorn_iterations, ms=dim_m, ns=dim_n)

        # Get the matches with score above "match_threshold".
        return scores[:, :-1, :-1], scores0, scores1, mdesc0, mdesc1


def tensor_erode(bin_img, ksize=5):
    B, C, H, W = bin_img.shape
    pad = (ksize - 1) // 2
    bin_img = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)
    patches = bin_img.unfold(dimension=2, size=ksize, step=1)
    patches = patches.unfold(dimension=3, size=ksize, step=1)
    # B x C x H x W x k x k
    eroded, _ = patches.reshape(B, C, H, W, -1).min(dim=-1)
    return eroded


class InbetweenerTM(nn.Module):
    """AnimeInbet

    The whole pipeline includes
      1. vertex correspondence (vertex embedding + correspondence transformer)
      2. repositioning propagation
      3. visibility mask

    The vertex correspondence code is modified from SuperGlue:
    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
    Networks. In CVPR, 2020.
https://arxiv.org/abs/1911.11763 """ def __init__(self, config=None): super().__init__() # vertex correspondence self.corr = SuperGlueT(config.corr_model) self.mask_map = MLP([config.corr_model.descriptor_dim, 32, 1]) self.pos_weight = config.pos_weight def forward(self, data): # if in the mode of video generating if 'gen_vid' in data: dim_m, dim_n = data['ms'].float(), data['ns'].float() mmax = dim_m.int().max() nmax = dim_n.int().max() with torch.no_grad(): self.corr.eval() score01, score0, score1, dec0, dec1 = self.corr(data) kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2 motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0 motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1 motion_pred0 = torch.softmax(score0, dim=-1) @ motion_pred0 motion_pred1 = torch.softmax(score1, dim=-1) @ motion_pred1 self.mask_map.eval() vb0 = self.mask_map(dec0)[:, 0] vb1 = self.mask_map(dec1)[:, 0] motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone() kpt0t = kpts0 + motion_output0 kpt1t = kpts1 + motion_output1 if 'topo0' in data and 'topo1' in data: # print(len(data['topo0'][0]), len(data['topo1']), flush=True) for node, nbs in enumerate(data['topo0'][0]): for nb in nbs: if vb0[0, nb] and vb0[0, node] and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 3: vb0[0, nb] = 0 vb0[0, node] = 0 for node, nbs in enumerate(data['topo1'][0]): for nb in nbs: if vb1[0, nb] and vb1[0, node] and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) > 3: vb1[0, nb] = 0 vb1[0, node] = 0 return {'r0': motion_output0, 'r1': motion_output1, 'vb0':vb0, 'vb1':vb1,} # in the normal train/test mode dim_m, dim_n = data['ms'].float(), data['ns'].float() mmax = dim_m.int().max() nmax = dim_n.int().max() # with torch.no_grad(): # self.corr.eval() score01, score0, score1, dec0, dec1 = self.corr(data) kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2 motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0 motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1 motion_pred0 = torch.softmax(score0, dim=-1) @ motion_pred0 motion_pred1 = torch.softmax(score1, dim=-1) @ motion_pred1 max0, max1 = score01.max(2), score01.max(1) indices0, indices1 = max0.indices, max1.indices mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) zero = score01.new_tensor(0) mscores0 = torch.where(mutual0, max0.values.exp(), zero) mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) # valid0 = mutual0 & (mscores0 > self.config.match_threshold) # valid1 = mutual1 & valid0.gather(1, indices1) valid0 = mscores0 > 0.2 valid1 = valid0.gather(1, indices1) indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) # motion_pred1[0][indices1[0]==-1] = 0 vb0 = self.mask_map(dec0)[:, 0] vb1 = self.mask_map(dec1)[:, 0] motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone() if not self.training: motion_pred0[0][indices0[0]!=-1] = kpts1[0][indices0[0][indices0[0]!=-1]] - kpts0[0][indices0[0]!=-1] # # motion_pred0[0][indices0[0]==-1] = 0 motion_pred1[0][indices1[0]!=-1] = kpts0[0][indices1[0][indices1[0]!=-1]] - kpts1[0][indices1[0]!=-1] vb0[:] = vb0[:] + 0.7 vb1[:] = vb1[:] + 0.7 # motion0_pred, 
vb0 = pred0[:, :2].permute(0, 2, 1), pred0[:, 2:][:, 0] # motion1_pred, vb1 = pred1[:, :2].permute(0, 2, 1), pred1[:, 2:][:, 0] # delta0, delta1 = motion_delta[:, :, :mmax].permute(0, 2, 1), motion_delta[:, :, mmax:].permute(0, 2, 1) # motion_output0, motion_output1 = motion0 + delta0, motion1 + delta1 motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone() im0_erode = data['image0'] im1_erode = data['image1'] im0_erode[im0_erode > 0] = 1 im0_erode[im0_erode <= 0] = 0 im1_erode[im1_erode > 0] = 1 im1_erode[im1_erode <= 0] = 0 im0_erode = tensor_erode(im0_erode, 7) im1_erode = tensor_erode(im1_erode, 7) kpt0t = kpts0 + motion_output0 / 2 kpt1t = kpts1 + motion_output1 / 2 ################################################## ## Note Here the mini batch size is 1!!!!!!!! ## ################################################## if 'topo0' in data and 'topo1' in data: # print(len(data['topo0'][0]), len(data['topo1']), flush=True) for node, nbs in enumerate(data['topo0'][0]): for nb in nbs: if vb0[0, nb] > 0 and vb0[0, node] > 0 and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 5: vb0[0, nb] = -1 vb0[0, node] = -1 for node, nbs in enumerate(data['topo1'][0]): for nb in nbs: if vb1[0, nb] > 0 and vb1[0, node] > 0 and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) > 5: vb1[0, nb] = -1 vb1[0, node] = -1 kpt0t = kpts0 + motion_output0 * 1 kpt1t = kpts1 + motion_output1 * 1 if 'topo0' in data and 'topo1' in data: ## print(len(data['topo0'][0]), len(data['topo1']), flush=True) for node, nbs in enumerate(data['topo0'][0]): for nb in nbs: center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.5).int()[0] if center[0] >= 720 or center[1] >= 720: continue if vb0[0, nb] > 0 and vb0[0, node] > 0 and im1_erode[0,:, center[1], center[0]].mean() > 0.8: vb0[0, nb] = -1 vb0[0, node] = -1 for node, nbs in enumerate(data['topo1'][0]): for nb in nbs: center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.5).int()[0] if vb1[0, nb] > 0 and vb1[0, node] > 0 and im0_erode[0,:, center[1], center[0]].mean() > 0.8: vb1[0, nb] = -1 vb1[0, node] = -1 kpt0t = kpts0 + motion_output0 / 2 kpt1t = kpts1 + motion_output1 / 2 if 'motion0' in data and 'motion1' in data: loss_motion = torch.nn.functional.l1_loss(motion_pred0, data['motion0'][:, :mmax]) +\ torch.nn.functional.l1_loss(motion_pred1, data['motion1'][:, :nmax]) EPE0 = ((motion_pred0 - data['motion0'][:, :mmax]) ** 2).sum(dim=-1).sqrt() EPE1 = ((motion_pred1 - data['motion1'][:, :nmax]) ** 2).sum(dim=-1).sqrt() # print(EPE0.size(), 'fdsafdsa') EPE = (EPE0.mean() + EPE1.mean()) * 0.5 if 'visibility0' in data and 'visibility1' in data: loss_visibility = torch.nn.functional.binary_cross_entropy_with_logits(vb0[:, :mmax].view(-1, 1), data['visibility0'][:, :mmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight])) + \ torch.nn.functional.binary_cross_entropy_with_logits(vb1[:, :nmax].view(-1, 1), data['visibility1'][:, :nmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight])) VB_Acc = ((((vb0 > 0).float() == data['visibility0'][:, :mmax]).float().sum() + ((vb1 > 0).float() == data['visibility1'][:, :nmax]).float().sum()) * 1.0 / (mmax + nmax)) else: loss_visibility = 0 VB_Acc = EPE.new_zeros([1]) loss = loss_motion + 10 * loss_visibility loss_mean = torch.mean(loss) b, _, _ = motion_pred0.size() return { 'keypoints0t': kpt0t, 'keypoints1t': kpt1t, 'vb0': (vb0 > 0).float(), 'vb1': (vb1 > 0).float(), 'r0': motion_output0, 'r1': motion_output1, 'loss': 
            loss_mean,
            'EPE': EPE,
            'Visibility Acc': VB_Acc
        }
    else:
        return {
            'loss': -1,
            'skip_train': True,
            'keypointst0': kpts0 + motion_output0,
            'keypointst1': kpts1 + motion_output1,
            'vb0': vb0,
            'vb1': vb1,
        }


if __name__ == '__main__':
    args = argparse.Namespace()
    args.batch_size = 2
    args.gap = 5
    args.type = 'train'
    args.model = None
    args.action = None
    # NOTE: 'Refiner' and 'fetch_dataloader' are not defined in this file;
    # this smoke test appears stale (the model class above is InbetweenerTM).
    ss = Refiner()
    loader = fetch_dataloader(args)
    # #print(len(loader))
    for data in loader:
        # p1, p2, s1, s2, mi = data
        dict1 = data
        kp1 = dict1['keypoints0']
        kp2 = dict1['keypoints1']
        p1 = dict1['image0']
        p2 = dict1['image1']
        # #print(s1)
        # #print(s1.type)
        mi = dict1['m01']
        fname = dict1['file_name']
        print(dict1['keypoints0'].size(), dict1['keypoints1'].size(), dict1['m01'].size(), dict1['motion0'].size(), dict1['mask0'].size())
        # print(kp1.shape, p1.shape, mi.shape)
        # #print(mi.size())
        # #print(mi)
        # break
        a = ss(data)
        print(dict1['file_name'])
        print(a['loss'])
        print(a['EPE'], a['Visibility Acc'],flush=True)
        a['loss'].backward()



================================================
FILE: requirement.txt
================================================
opencv-python
pyyaml==5.4.1
scikit-network
tqdm
matplotlib
easydict
gdown



================================================
FILE: srun.sh
================================================
#!/bin/sh
currenttime=`date "+%Y%m%d%H%M%S"`
if [ ! -d log ]; then
    mkdir log
fi
echo "[Usage] ./srun.sh config_path [train|eval] partition gpunum"

# check config exists
if [ ! -e $1 ]
then
    echo "[ERROR] configuration file: $1 does not exist!"
    exit
fi

# NOTE: ${expname} and ${config_suffix} are never assigned in this script;
# the two lines below and the --job-name setting appear to be stale.
if [ ! -d ${expname} ]; then
    mkdir ${expname}
fi
echo "[INFO] saving results to, or loading files from: "$expname

if [ "$3" == "" ]; then
    echo "[ERROR] enter partition name"
    exit
fi
partition_name=$3
echo "[INFO] partition name: $partition_name"

if [ "$4" == "" ]; then
    echo "[ERROR] enter gpu num"
    exit
fi
gpunum=$4
gpunum=$(($gpunum<8?$gpunum:8))
echo "[INFO] GPU num: $gpunum"
((ntask=$gpunum*3))

TOOLS="srun --partition=$partition_name -x SG-IDC2-10-51-5-44 --cpus-per-task=16 --gres=gpu:$gpunum -N 1 --mem-per-gpu=32G --job-name=${config_suffix}"
PYTHONCMD="python -u main.py --config $1"

if [ $2 == "train" ]; then
    $TOOLS $PYTHONCMD \
        --train
elif [ $2 == "eval" ]; then
    $TOOLS $PYTHONCMD \
        --eval
elif [ $2 == "gen" ]; then
    $TOOLS $PYTHONCMD \
        --gen
fi
# elif [ $2 == "visgt" ];
# then
#     $TOOLS $PYTHONCMD \
#         --visgt
# elif [ $2 == "anl" ];
# then
#     $TOOLS $PYTHONCMD \
#         --anl
# elif [ $2 == "sample" ];
# then
#     $TOOLS $PYTHONCMD \
#         --sample
# fi



================================================
FILE: utils/chamfer_distance.py
================================================
import os
import numpy as np
from time import time

import cv2
import pdb
import scipy
import scipy.ndimage
import torch
import torchmetrics

black_threshold = 255.0 * 0.99

def batch_edt(img, block=1024):
    expand = False
    bs,h,w = img.shape
    diam2 = h**2 + w**2
    odtype = img.dtype
    grid = (img.nelement()+block-1) // block

    # cupy implementation
    # default to scipy cpu implementation
    sums = img.sum(dim=(1,2))
    ans = torch.tensor(np.stack([
        scipy.ndimage.morphology.distance_transform_edt(i)
        if s!=0 else  # change scipy behavior for empty image
        np.ones_like(i) * np.sqrt(diam2)
        for i,s in zip(1-img, sums)
    ]), dtype=odtype)
    if expand:
        ans = ans.unsqueeze(1)
    return ans
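# Illustrative sketch (editor's addition, not part of the file): for a
# (bs, h, w) binary stack, batch_edt assigns each pixel its Euclidean distance
# to the nearest foreground (value-1) pixel; an all-zero image falls back to
# the image diagonal.
#
#     img = torch.zeros(1, 4, 4)
#     img[0, 1, 1] = 1
#     batch_edt(img)[0, 3, 3]  # sqrt(2**2 + 2**2) ~= 2.83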
############### DERIVED DISTANCES ###############
# input: (bs,h,w) or (bs,1,h,w)
# returns: (bs,)

# normalized s.t. metric is same across proportional image scales
# average of two asymmetric distances
# normalized by diameter and area
def batch_chamfer_distance(gt, pred, block=1024, return_more=False):
    t = batch_chamfer_distance_t(gt, pred, block=block)
    p = batch_chamfer_distance_p(gt, pred, block=block)
    cd = (t + p) / 2
    return cd

def batch_chamfer_distance_t(gt, pred, block=1024, return_more=False):
    #pdb.set_trace()
    assert gt.device==pred.device and gt.shape==pred.shape
    bs,h,w = gt.shape[0], gt.shape[-2], gt.shape[-1]
    dpred = batch_edt(pred, block=block)
    cd = (gt*dpred).float().mean((-2,-1)) / np.sqrt(h**2+w**2)
    if len(cd.shape)==2:
        assert cd.shape[1]==1
        cd = cd.squeeze(1)
    return cd

def batch_chamfer_distance_p(gt, pred, block=1024, return_more=False):
    assert gt.device==pred.device and gt.shape==pred.shape
    bs,h,w = gt.shape[0], gt.shape[-2], gt.shape[-1]
    dgt = batch_edt(gt, block=block)
    cd = (pred*dgt).float().mean((-2,-1)) / np.sqrt(h**2+w**2)
    if len(cd.shape)==2:
        assert cd.shape[1]==1
        cd = cd.squeeze(1)
    return cd

# normalized by diameter
# always between [0,1]
def batch_hausdorff_distance(gt, pred, block=1024, return_more=False):
    assert gt.device==pred.device and gt.shape==pred.shape
    bs,h,w = gt.shape[0], gt.shape[-2], gt.shape[-1]
    dgt = batch_edt(gt, block=block)
    dpred = batch_edt(pred, block=block)
    hd = torch.stack([
        (dgt*pred).amax(dim=(-2,-1)),
        (dpred*gt).amax(dim=(-2,-1)),
    ]).amax(dim=0).float() / np.sqrt(h**2+w**2)
    if len(hd.shape)==2:
        assert hd.shape[1]==1
        hd = hd.squeeze(1)
    return hd


############### TORCHMETRICS ###############

class ChamferDistance2dMetric(torchmetrics.Metric):
    full_state_update=False
    def __init__(
            self, block=1024, convert_dog=True,
            k=1.6, epsilon=0.01, kernel_factor=4, clip=False,
            **kwargs,
        ):
        super().__init__(**kwargs)
        self.block = block
        self.convert_dog = convert_dog
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
        return

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        dist = batch_chamfer_distance(target, preds, block=self.block)
        self.running_sum += dist.sum()
        self.running_count += len(dist)
        return

    def compute(self):
        return self.running_sum.float() / self.running_count

# NOTE: the two DoG-converting subclasses below reference a 'usketchers'
# module that is never imported in this repository, and self.dog_params is
# never set by the ChamferDistance2dMetric base class, so their update()
# would fail with the default convert_dog=True; cd_score below uses only
# ChamferDistance2dMetric, whose update() touches neither.
class ChamferDistance2dTMetric(ChamferDistance2dMetric):
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        if self.convert_dog:
            preds = (usketchers.batch_dog(preds, **self.dog_params)>0.5).float()
            target = (usketchers.batch_dog(target, **self.dog_params)>0.5).float()
        dist = batch_chamfer_distance_t(target, preds, block=self.block)
        self.running_sum += dist.sum()
        self.running_count += len(dist)
        return

class ChamferDistance2dPMetric(ChamferDistance2dMetric):
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        if self.convert_dog:
            preds = (usketchers.batch_dog(preds, **self.dog_params)>0.5).float()
            target = (usketchers.batch_dog(target, **self.dog_params)>0.5).float()
        dist = batch_chamfer_distance_p(target, preds, block=self.block)
        self.running_sum += dist.sum()
        self.running_count += len(dist)
        return

class HausdorffDistance2dMetric(torchmetrics.Metric):
    def __init__(
            self, block=1024, convert_dog=True,
            t=2.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=False,
            **kwargs,
        ):
        super().__init__(**kwargs)
        self.block = block
        self.convert_dog = convert_dog
        self.dog_params = {
            't': t, 'sigma': sigma, 'k': k, 'epsilon': epsilon,
            'kernel_factor': kernel_factor, 'clip': clip,
        }
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
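        # torchmetrics accumulates these two states across update() calls (and
        # sums them across processes under DDP, per dist_reduce_fx='sum');
        # compute() then returns running_sum / running_count, i.e. the mean
        # distance over everything scored so far.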
self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum') return def update(self, preds: torch.Tensor, target: torch.Tensor): if self.convert_dog: preds = (usketchers.batch_dog(preds, **self.dog_params)>0.5).float() target = (usketchers.batch_dog(target, **self.dog_params)>0.5).float() dist = batch_hausdorff_distance(target, preds, block=self.block) self.running_sum += dist.sum() self.running_count += len(dist) return def compute(self): return self.running_sum.float() / self.running_count def rgb2sketch(img, black_threshold): #pdb.set_trace() img[img < black_threshold] = 1 img[img >= black_threshold] = 0 #cv2.imwrite("grey.png",img*255) return torch.tensor(img) def rgb2gray(rgb): r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2] gray = 0.2989 * r + 0.5870 * g + 0.1140 * b return gray def cd_score(img1, img2): img1 = rgb2gray(img1.astype(float)) img2 = rgb2gray(img2.astype(float)) img1_sketch = rgb2sketch(img1, black_threshold) img2_sketch = rgb2sketch(img2, black_threshold) img1_sketch = img1_sketch.unsqueeze(0) img2_sketch = img2_sketch.unsqueeze(0) CD = ChamferDistance2dMetric() cd = CD(img1_sketch,img2_sketch) return cd ================================================ FILE: utils/log.py ================================================ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this open-source project. """ Define the Logger class to print log""" import os import sys import logging from datetime import datetime class Logger: def __init__(self, args, output_dir): log = logging.getLogger(output_dir) if not log.handlers: log.setLevel(logging.DEBUG) # if not os.path.exists(output_dir): # os.mkdir(args.data.output_dir) fh = logging.FileHandler(os.path.join(output_dir,'log.txt')) fh.setLevel(logging.INFO) ch = ProgressHandler() ch.setLevel(logging.DEBUG) formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S') fh.setFormatter(formatter) ch.setFormatter(formatter) log.addHandler(fh) log.addHandler(ch) self.log = log # setup TensorBoard # if args.tensorboard: # from tensorboardX import SummaryWriter # self.writer = SummaryWriter(log_dir=args.output_dir) # else: self.writer = None self.log_per_updates = args.log_per_updates def set_progress(self, epoch, total): self.log.info(f'Epoch: {epoch}') self.epoch = epoch self.i = 0 self.total = total self.start = datetime.now() def update(self, stats): self.i += 1 if self.i % self.log_per_updates == 0: remaining = str((datetime.now() - self.start) / self.i * (self.total - self.i)) remaining = remaining.split('.')[0] updates = stats.pop('updates') stats_str = ' '.join(f'{key}[{val:.8f}]' for key, val in stats.items()) self.log.info(f'> epoch [{self.epoch}] updates[{updates}] {stats_str} eta[{remaining}]') if self.writer: for key, val in stats.items(): self.writer.add_scalar(f'train/{key}', val, updates) if self.i == self.total: self.log.debug('\n') self.log.debug(f'elapsed time: {str(datetime.now() - self.start).split(".")[0]}') def log_eval(self, stats, metrics_group=None): stats_str = ' '.join(f'{key}: {val:.8f}' for key, val in stats.items()) self.log.info(f'valid {stats_str}') if self.writer: for key, val in stats.items(): self.writer.add_scalar(f'valid/{key}', val, self.epoch) # for mode, metrics in metrics_group.items(): # self.log.info(f'evaluation scores ({mode}):') # for key, (val, _) in metrics.items(): # self.log.info(f'\t{key} {val:.4f}') # if self.writer and metrics_group is not None: # for key, val in stats.items(): # 
self.writer.add_scalar(f'valid/{key}', val, self.epoch) # for key in list(metrics_group.values())[0]: # group = {} # for mode, metrics in metrics_group.items(): # group[mode] = metrics[key][0] # self.writer.add_scalars(f'valid/{key}', group, self.epoch) def __call__(self, msg): self.log.info(msg) class ProgressHandler(logging.Handler): def __init__(self, level=logging.NOTSET): super().__init__(level) def emit(self, record): log_entry = self.format(record) if record.message.startswith('> '): sys.stdout.write('{}\r'.format(log_entry.rstrip())) sys.stdout.flush() else: sys.stdout.write('{}\n'.format(log_entry)) ================================================ FILE: utils/visualize_inbetween.py ================================================ import numpy as np import torch import cv2 from .chamfer_distance import cd_score # def make_inter_graph(v2d1, v2d2, topo1, topo2, match12): # valid = (match12 != -1) # marked2 = np.zeros(len(v2d2)).astype(bool) # # print(match12[valid]) # marked2[match12[valid]] = True # id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2)) # id1toh[valid] = np.arange(np.sum(valid)) # id2toh[match12[valid]] = np.arange(np.sum(valid)) # id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid) # # print(marked2) # id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2))) # id1toh = id1toh.astype(int) # id2toh = id2toh.astype(int) # tot_len = len(v2d1) + np.sum(np.invert(marked2)) # vin1 = v2d1[valid][:] # vin2 = v2d2[match12[valid]][:] # vh = 0.5 * (vin1 + vin2) # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0) # topoh = [[] for ii in range(tot_len)] # for node in range(len(topo1)): # for nb in topo1[node]: # if int(id1toh[nb]) not in topoh[id1toh[node]]: # topoh[id1toh[node]].append(int(id1toh[nb])) # for node in range(len(topo2)): # for nb in topo2[node]: # if int(id2toh[nb]) not in topoh[id2toh[node]]: # topoh[id2toh[node]].append(int(id2toh[nb])) # return vh, topoh # def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12): # valid = (match12 != -1) # marked2 = np.zeros(len(v2d2)).astype(bool) # # print(match12[valid]) # marked2[match12[valid]] = True # id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2)) # id1toh[valid] = np.arange(np.sum(valid)) # id2toh[match12[valid]] = np.arange(np.sum(valid)) # id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid) # # print(marked2) # id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2))) # id1toh = id1toh.astype(int) # id2toh = id2toh.astype(int) # tot_len = len(v2d1) + np.sum(np.invert(marked2)) # vin1 = v2d1[valid][:] # vin2 = v2d2[match12[valid]][:] # vh = 0.5 * (vin1 + vin2) # # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0) # # topoh = [[] for ii in range(tot_len)] # topoh = [[] for ii in range(np.sum(valid))] # for node in range(len(topo1)): # if not valid[node]: # continue # for nb in topo1[node]: # if int(id1toh[nb]) not in topoh[id1toh[node]]: # if valid[nb]: # topoh[id1toh[node]].append(int(id1toh[nb])) # for node in range(len(topo2)): # if not marked2[node]: # continue # for nb in topo2[node]: # if int(id2toh[nb]) not in topoh[id2toh[node]]: # if marked2[nb]: # topoh[id2toh[node]].append(int(id2toh[nb])) # return vh, topoh def visualize(dict): # print(dict['keypoints0'].size(), flush=True) img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy() original_target = ((dict['imaget'][0].permute(1, 2, 
0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy() # img1p = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy() # img2p = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy() # img1[:, :, 0] += 255 # img1[:, :, 1] += 180 # img1[:, :, 2] += 180 # img1[img1 > 255] = 255 # img2[:, :, 0] += 255 # img2[:, :, 1] += 180 # img2[:, :, 2] += 180 # img2[img2 > 255] = 255 # img1p[:, :, 0] += 255 # img1p[:, :, 1] += 180 # img1p[:, :, 2] += 180 # img1p[img1p > 255] = 255 # img2p[:, :, 0] += 255 # img2p[:, :, 1] += 180 # img2p[:, :, 2] += 180 # img2p[img2p > 255] = 255 # img1, img2, img1p, img2p = img1.astype(np.uint8), img2.astype(np.uint8), img1p.astype(np.uint8), img2p.astype(np.uint8) motion01 = dict['motion0'][0].cpu().numpy().astype(int) motion21 = dict['motion1'][0].cpu().numpy().astype(int) source0_warp = dict['keypoints0t'][0].cpu().numpy().astype(int) source2_warp = dict['keypoints1t'][0].cpu().numpy().astype(int) source0 = dict['keypoints0'][0].cpu().numpy().astype(int) source2 = dict['keypoints1'][0].cpu().numpy().astype(int) source0_topo = dict['topo0'][0] # print(len(dict['topo0'])) source2_topo = dict['topo1'][0] visible01 = dict['vb0'][0].cpu().numpy().astype(int) visible21 = dict['vb1'][0].cpu().numpy().astype(int) # corr01 = dict['m01'][0].cpu().numpy().astype(int) # corr10 = dict['m10'][0].cpu().numpy().astype(int) # canvas = np.zeros_like(img1) + 255 # source0_warp2 = source0 + motion01 // 2 # source2_warp2 = source2 + motion21 // 2 # for node, nbs in enumerate(source0_topo): # for nb in nbs: # # print([source0_warp[nb][0], source0_warp[nb][1]]) # cv2.line(canvas, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2) # for node, nbs in enumerate(source2_topo): # for nb in nbs: # cv2.line(canvas, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2) # canvas6 = np.zeros_like(img1) + 255 # for node, nbs in enumerate(source0_topo): # for nb in nbs: # # print([source0_warp[nb][0], source0_warp[nb][1]]) # cv2.line(canvas6, [source0_warp2[node][0], source0_warp2[node][1]], [source0_warp2[nb][0], source0_warp2[nb][1]], [0, 0, 0], 2) # for node, nbs in enumerate(source2_topo): # for nb in nbs: # cv2.line(canvas6, [source2_warp2[node][0], source2_warp2[node][1]], [source2_warp2[nb][0], source2_warp2[nb][1]], [0, 0, 0], 2) canvas2 = np.zeros_like(img1) + 255 ## print('huala<<<', source0_warp.mean(), source2_warp.mean(), flush=True) # source0_warp = source0 + motion01 # source2_warp = source2 + motion21 for node, nbs in enumerate(source0_topo): for nb in nbs: # if visible01[node] and visible01[nb]: cv2.line(canvas2, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2) for node, nbs in enumerate(source2_topo): for nb in nbs: # if visible21[node] and visible21[nb]: cv2.line(canvas2, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2) # canvas2 # black_threshold = 255 // 2 # img1_sketch = rgb2sketch(img1, black_threshold) # img2_sketch = rgb2sketch(img2, black_threshold) # img1_sketch = img1_sketch.unsqueeze(0) # img2_sketch = img2_sketch.unsqueeze(0) # CD = ChamferDistance2dMetric() # cd = CD(img1_sketch,img2_sketch) canvas5 = np.zeros_like(img1) + 255 # source0_warp = source0 + motion01 # source2_warp = source2 + motion21 ## print('gulaa>>>', visible01.mean(), visible21.mean(), flush=True) for node, nbs in 
enumerate(source0_topo): for nb in nbs: if visible01[node] and visible01[nb]: cv2.line(canvas5, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2) for node, nbs in enumerate(source2_topo): for nb in nbs: if visible21[node] and visible21[nb]: cv2.line(canvas5, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2) canvas3 = np.zeros_like(img1) + 255 for node, nbs in enumerate(source0_topo): for nb in nbs: cv2.line(canvas3, [source0[node][0], source0[node][1]], [source0[nb][0], source0[nb][1]], [255, 180, 180], 2) for node, nbs in enumerate(source2_topo): for nb in nbs: cv2.line(canvas3, [source2[node][0], source2[node][1]], [source2[nb][0], source2[nb][1]], [180, 180, 255], 2) # canvas_corr1 = np.zeros_like(img1) + 255 # canvas_corr2 = np.zeros_like(img1) + 255 canvas_corr1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy() canvas_corr2 = ((dict['image1'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy() canvas_corr1[:, :, 0] += 255 canvas_corr1[:, :, 1] += 180 canvas_corr1[:, :, 2] += 180 canvas_corr1[canvas_corr1 > 255] = 255 canvas_corr2[:, :, 0] += 255 canvas_corr2[:, :, 1] += 180 canvas_corr2[:, :, 2] += 180 canvas_corr2[canvas_corr2 > 255] = 255 canvas_corr1 = canvas_corr1.astype(np.uint8) canvas_corr2 = canvas_corr2.astype(np.uint8) # colors1_gt, colors2_gt = {}, {} colors1_pred, colors2_pred = {}, {} # cross1_pred, cross2_pred = {}, {} id1 = np.arange(len(source0)) id2 = np.arange(len(source2)) # predicted = dict['matches0'].cpu().data.numpy()[0] # predicted1 = dict['matches1'].cpu().data.numpy()[0] # for index in id1: # color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)] # # print(predicted.shape, flush=True) # # if all_matches[index] != -1: # # colors2_gt[all_matches[index]] = color # if predicted[index] != -1: # colors2_pred[predicted[index]] = color # # else: # # colors2_pred[predicted[index]] = [0, 0, 0] # # colors1_gt[index] = color if all_matches[index] != -1 else [0, 0, 0] # colors1_pred[index] = color if predicted[index] != -1 else [0, 0, 0] # # if predicted[index] == -1 and colors1_pred[index] != [0, 0, 0]: # # color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)] # # colors1_pred[index] = [0, 0, 0] # # colors2_pred.pop(all_matches[index]) # # whether predicted correctly # # if predicted[index] != all_matches[index]: # # cross1_pred[index] = True # # if predicted[index] != -1: # # cross2_pred[predicted[index]] = True # for i, p in enumerate(source0): # ii = id1[i] # # print(ii) # cv2.circle(canvas_corr1, [int(p[0]), int(p[1])], 1, colors1_pred[i], 2) # # if ii in cross1_pred and cross1_pred[ii]: # # cv2.rectangle(img1p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors1_pred[i],-1) # # else: # # cv2.circle(img1p, [int(p[0]), int(p[1])], 1, colors1_pred[i], 2) # for ii in id2: # # print(ii) # color = [0, 0, 0] # this_is_umatched = 1 # if ii not in colors2_pred: # colors2_pred[ii] = color # for i, p in enumerate(source2): # ii = id2[i] # # print(p) # # cv2.circle(img2, [int(p[0]), int( p[1])], 1, colors2_gt[ii], 2) # # if ii in cross2_pred and cross2_pred[ii]: # # cv2.rectangle(img2p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors2_pred[i], -1) # # else: # cv2.circle(canvas_corr2, [int(p[0]), int(p[1])], 1, colors2_pred[i], 2) #canvas6, canvas5, canvas, im_h = cv2.hconcat([canvas3, original_target, canvas2, 
canvas5]) ## print('<<<< mean cavans5: ', canvas5.mean()) cd = cd_score(canvas5.copy(), original_target.copy()) * 1e5 cv2.putText(im_h, str(cd), \ (720, 100), cv2.FONT_HERSHEY_PLAIN, 3, (0, 0, 255), 2) return im_h, cd ================================================ FILE: utils/visualize_inbetween2.py ================================================ import numpy as np import torch import cv2 from .chamfer_distance import cd_score # def make_inter_graph(v2d1, v2d2, topo1, topo2, match12): # valid = (match12 != -1) # marked2 = np.zeros(len(v2d2)).astype(bool) # # print(match12[valid]) # marked2[match12[valid]] = True # id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2)) # id1toh[valid] = np.arange(np.sum(valid)) # id2toh[match12[valid]] = np.arange(np.sum(valid)) # id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid) # # print(marked2) # id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2))) # id1toh = id1toh.astype(int) # id2toh = id2toh.astype(int) # tot_len = len(v2d1) + np.sum(np.invert(marked2)) # vin1 = v2d1[valid][:] # vin2 = v2d2[match12[valid]][:] # vh = 0.5 * (vin1 + vin2) # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0) # topoh = [[] for ii in range(tot_len)] # for node in range(len(topo1)): # for nb in topo1[node]: # if int(id1toh[nb]) not in topoh[id1toh[node]]: # topoh[id1toh[node]].append(int(id1toh[nb])) # for node in range(len(topo2)): # for nb in topo2[node]: # if int(id2toh[nb]) not in topoh[id2toh[node]]: # topoh[id2toh[node]].append(int(id2toh[nb])) # return vh, topoh # def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12): # valid = (match12 != -1) # marked2 = np.zeros(len(v2d2)).astype(bool) # # print(match12[valid]) # marked2[match12[valid]] = True # id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2)) # id1toh[valid] = np.arange(np.sum(valid)) # id2toh[match12[valid]] = np.arange(np.sum(valid)) # id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid) # # print(marked2) # id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2))) # id1toh = id1toh.astype(int) # id2toh = id2toh.astype(int) # tot_len = len(v2d1) + np.sum(np.invert(marked2)) # vin1 = v2d1[valid][:] # vin2 = v2d2[match12[valid]][:] # vh = 0.5 * (vin1 + vin2) # # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0) # # topoh = [[] for ii in range(tot_len)] # topoh = [[] for ii in range(np.sum(valid))] # for node in range(len(topo1)): # if not valid[node]: # continue # for nb in topo1[node]: # if int(id1toh[nb]) not in topoh[id1toh[node]]: # if valid[nb]: # topoh[id1toh[node]].append(int(id1toh[nb])) # for node in range(len(topo2)): # if not marked2[node]: # continue # for nb in topo2[node]: # if int(id2toh[nb]) not in topoh[id2toh[node]]: # if marked2[nb]: # topoh[id2toh[node]].append(int(id2toh[nb])) # return vh, topoh def visualize(dict): # print(dict['keypoints0'].size(), flush=True) img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy() img2 = ((dict['image1'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy() original_target = ((dict['imaget'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy() # img1p = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy() # img2p = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy() # img1[:, :, 0] += 255 # img1[:, :, 1] += 180 # img1[:, 
:, 2] += 180 # img1[img1 > 255] = 255 # img2[:, :, 0] += 255 # img2[:, :, 1] += 180 # img2[:, :, 2] += 180 # img2[img2 > 255] = 255 # img1p[:, :, 0] += 255 # img1p[:, :, 1] += 180 # img1p[:, :, 2] += 180 # img1p[img1p > 255] = 255 # img2p[:, :, 0] += 255 # img2p[:, :, 1] += 180 # img2p[:, :, 2] += 180 # img2p[img2p > 255] = 255 # img1, img2, img1p, img2p = img1.astype(np.uint8), img2.astype(np.uint8), img1p.astype(np.uint8), img2p.astype(np.uint8) r0 = dict['r0'][0].cpu().numpy().astype(int) r1 = dict['r1'][0].cpu().numpy().astype(int) source0_warp = dict['keypoints0t'][0].cpu().numpy().astype(int) source2_warp = dict['keypoints1t'][0].cpu().numpy().astype(int) source0 = dict['keypoints0'][0].cpu().numpy().astype(int) source2 = dict['keypoints1'][0].cpu().numpy().astype(int) source0_topo = dict['topo0'][0] # print(len(dict['topo0'])) source2_topo = dict['topo1'][0] visible01 = dict['vb0'][0].cpu().numpy().astype(int) visible21 = dict['vb1'][0].cpu().numpy().astype(int) # corr01 = dict['m01'][0].cpu().numpy().astype(int) # corr10 = dict['m10'][0].cpu().numpy().astype(int) # canvas = np.zeros_like(img1) + 255 # source0_warp2 = source0 + motion01 // 2 # source2_warp2 = source2 + motion21 // 2 # for node, nbs in enumerate(source0_topo): # for nb in nbs: # # print([source0_warp[nb][0], source0_warp[nb][1]]) # cv2.line(canvas, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2) # for node, nbs in enumerate(source2_topo): # for nb in nbs: # cv2.line(canvas, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2) # canvas6 = np.zeros_like(img1) + 255 # for node, nbs in enumerate(source0_topo): # for nb in nbs: # # print([source0_warp[nb][0], source0_warp[nb][1]]) # cv2.line(canvas6, [source0_warp2[node][0], source0_warp2[node][1]], [source0_warp2[nb][0], source0_warp2[nb][1]], [0, 0, 0], 2) # for node, nbs in enumerate(source2_topo): # for nb in nbs: # cv2.line(canvas6, [source2_warp2[node][0], source2_warp2[node][1]], [source2_warp2[nb][0], source2_warp2[nb][1]], [0, 0, 0], 2) canvas2 = np.zeros_like(img1) + 255 # source0_warp = source0 + motion01 # source2_warp = source2 + motion21 for node, nbs in enumerate(source0_topo): for nb in nbs: # if visible01[node] and visible01[nb]: cv2.line(canvas2, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2) for node, nbs in enumerate(source2_topo): for nb in nbs: # if visible21[node] and visible21[nb]: cv2.line(canvas2, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2) # canvas2 # black_threshold = 255 // 2 # img1_sketch = rgb2sketch(img1, black_threshold) # img2_sketch = rgb2sketch(img2, black_threshold) # img1_sketch = img1_sketch.unsqueeze(0) # img2_sketch = img2_sketch.unsqueeze(0) # CD = ChamferDistance2dMetric() # cd = CD(img1_sketch,img2_sketch) canvases = [np.zeros_like(img1) + 255, np.zeros_like(img1) + 255, np.zeros_like(img1) + 255, np.zeros_like(img1) + 255] canvas5 = np.zeros_like(img1) + 255 # canvas7 = np.zeros_like(img1) + 255 # canvas8 = np.zeros_like(img1) + 255 # source0_warp = source0 + motion01 # source2_warp = source2 + motion21 for ii in range(len(canvases)): source0_warp = (source0 + (ii + 1.0) / (len(canvases) + 1.0) * r0).astype(int) source2_warp = (source2 + (1 - (ii + 1.0) / (len(canvases) + 1.0)) * r1).astype(int) for node, nbs in enumerate(source0_topo): for nb in nbs: if visible01[node] and visible01[nb]: 
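                    # draw this edge at intermediate step ii only when both
                    # endpoints are predicted visible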
cv2.line(canvases[ii], [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2) for node, nbs in enumerate(source2_topo): for nb in nbs: if visible21[node] and visible21[nb]: cv2.line(canvases[ii], [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2) canvas3 = np.zeros_like(img1) + 255 for node, nbs in enumerate(source0_topo): for nb in nbs: cv2.line(canvas3, [source0[node][0], source0[node][1]], [source0[nb][0], source0[nb][1]], [255, 180, 180], 2) for node, nbs in enumerate(source2_topo): for nb in nbs: cv2.line(canvas3, [source2[node][0], source2[node][1]], [source2[nb][0], source2[nb][1]], [180, 180, 255], 2) #canvas6, canvas5, canvas, # im_h = cv2.hconcat([canvas3, original_target, canvas2, canvas5]) im_h = cv2.hconcat([img1] + canvases + [img2]) cd = cd_score(canvas5.copy(), original_target.copy()) * 1e5 # cv2.putText(im_h, str(cd), \ # (720, 100), cv2.FONT_HERSHEY_PLAIN, 3, (0, 0, 255), 2) return im_h, cd ================================================ FILE: utils/visualize_inbetween3.py ================================================ import numpy as np import torch import cv2 from .chamfer_distance import cd_score # def make_inter_graph(v2d1, v2d2, topo1, topo2, match12): # valid = (match12 != -1) # marked2 = np.zeros(len(v2d2)).astype(bool) # # print(match12[valid]) # marked2[match12[valid]] = True # id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2)) # id1toh[valid] = np.arange(np.sum(valid)) # id2toh[match12[valid]] = np.arange(np.sum(valid)) # id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid) # # print(marked2) # id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2))) # id1toh = id1toh.astype(int) # id2toh = id2toh.astype(int) # tot_len = len(v2d1) + np.sum(np.invert(marked2)) # vin1 = v2d1[valid][:] # vin2 = v2d2[match12[valid]][:] # vh = 0.5 * (vin1 + vin2) # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0) # topoh = [[] for ii in range(tot_len)] # for node in range(len(topo1)): # for nb in topo1[node]: # if int(id1toh[nb]) not in topoh[id1toh[node]]: # topoh[id1toh[node]].append(int(id1toh[nb])) # for node in range(len(topo2)): # for nb in topo2[node]: # if int(id2toh[nb]) not in topoh[id2toh[node]]: # topoh[id2toh[node]].append(int(id2toh[nb])) # return vh, topoh # def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12): # valid = (match12 != -1) # marked2 = np.zeros(len(v2d2)).astype(bool) # # print(match12[valid]) # marked2[match12[valid]] = True # id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2)) # id1toh[valid] = np.arange(np.sum(valid)) # id2toh[match12[valid]] = np.arange(np.sum(valid)) # id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid) # # print(marked2) # id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2))) # id1toh = id1toh.astype(int) # id2toh = id2toh.astype(int) # tot_len = len(v2d1) + np.sum(np.invert(marked2)) # vin1 = v2d1[valid][:] # vin2 = v2d2[match12[valid]][:] # vh = 0.5 * (vin1 + vin2) # # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0) # # topoh = [[] for ii in range(tot_len)] # topoh = [[] for ii in range(np.sum(valid))] # for node in range(len(topo1)): # if not valid[node]: # continue # for nb in topo1[node]: # if int(id1toh[nb]) not in topoh[id1toh[node]]: # if valid[nb]: # topoh[id1toh[node]].append(int(id1toh[nb])) # for node in range(len(topo2)): # if not 
marked2[node]: # continue # for nb in topo2[node]: # if int(id2toh[nb]) not in topoh[id2toh[node]]: # if marked2[nb]: # topoh[id2toh[node]].append(int(id2toh[nb])) # return vh, topoh def visualize(dict): # print(dict['keypoints0'].size(), flush=True) img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy() original_target = ((dict['imaget'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy() # img1p = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy() # img2p = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy() # img1[:, :, 0] += 255 # img1[:, :, 1] += 180 # img1[:, :, 2] += 180 # img1[img1 > 255] = 255 # img2[:, :, 0] += 255 # img2[:, :, 1] += 180 # img2[:, :, 2] += 180 # img2[img2 > 255] = 255 # img1p[:, :, 0] += 255 # img1p[:, :, 1] += 180 # img1p[:, :, 2] += 180 # img1p[img1p > 255] = 255 # img2p[:, :, 0] += 255 # img2p[:, :, 1] += 180 # img2p[:, :, 2] += 180 # img2p[img2p > 255] = 255 # img1, img2, img1p, img2p = img1.astype(np.uint8), img2.astype(np.uint8), img1p.astype(np.uint8), img2p.astype(np.uint8) motion01 = dict['motion0'][0].cpu().numpy().astype(int) motion21 = dict['motion1'][0].cpu().numpy().astype(int) source0_warp = dict['keypoints0t'][0].cpu().numpy().astype(int) source2_warp = dict['keypoints1t'][0].cpu().numpy().astype(int) source0 = dict['keypoints0'][0].cpu().numpy().astype(int) source2 = dict['keypoints1'][0].cpu().numpy().astype(int) source0_topo = dict['topo0'][0] # print(len(dict['topo0'])) source2_topo = dict['topo1'][0] visible01 = dict['vb0'][0].cpu().numpy().astype(int) visible21 = dict['vb1'][0].cpu().numpy().astype(int) # corr01 = dict['m01'][0].cpu().numpy().astype(int) # corr10 = dict['m10'][0].cpu().numpy().astype(int) # canvas = np.zeros_like(img1) + 255 # source0_warp2 = source0 + motion01 // 2 # source2_warp2 = source2 + motion21 // 2 # for node, nbs in enumerate(source0_topo): # for nb in nbs: # # print([source0_warp[nb][0], source0_warp[nb][1]]) # cv2.line(canvas, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2) # for node, nbs in enumerate(source2_topo): # for nb in nbs: # cv2.line(canvas, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2) # canvas6 = np.zeros_like(img1) + 255 # for node, nbs in enumerate(source0_topo): # for nb in nbs: # # print([source0_warp[nb][0], source0_warp[nb][1]]) # cv2.line(canvas6, [source0_warp2[node][0], source0_warp2[node][1]], [source0_warp2[nb][0], source0_warp2[nb][1]], [0, 0, 0], 2) # for node, nbs in enumerate(source2_topo): # for nb in nbs: # cv2.line(canvas6, [source2_warp2[node][0], source2_warp2[node][1]], [source2_warp2[nb][0], source2_warp2[nb][1]], [0, 0, 0], 2) # canvas2 = np.zeros_like(img1) + 255 # ## print('huala<<<', source0_warp.mean(), source2_warp.mean(), flush=True) # # source0_warp = source0 + motion01 # # source2_warp = source2 + motion21 # for node, nbs in enumerate(source0_topo): # for nb in nbs: # # if visible01[node] and visible01[nb]: # cv2.line(canvas2, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2) # for node, nbs in enumerate(source2_topo): # for nb in nbs: # # if visible21[node] and visible21[nb]: # cv2.line(canvas2, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2) # canvas2 # black_threshold = 255 // 2 # img1_sketch = 

def visualize(data):
    # Convert normalized ([-1, 1]) CHW tensors back to uint8 HWC images.
    img1 = ((data['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()
    original_target = ((data['imaget'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()

    motion01 = data['motion0'][0].cpu().numpy().astype(int)
    motion21 = data['motion1'][0].cpu().numpy().astype(int)
    source0_warp = data['keypoints0t'][0].cpu().numpy().astype(int)  # frame-0 vertices warped to the target time
    source2_warp = data['keypoints1t'][0].cpu().numpy().astype(int)  # frame-1 vertices warped to the target time
    source0 = data['keypoints0'][0].cpu().numpy().astype(int)
    source2 = data['keypoints1'][0].cpu().numpy().astype(int)
    source0_topo = data['topo0'][0]
    source2_topo = data['topo1'][0]
    visible01 = data['vb0'][0].cpu().numpy().astype(int)
    visible21 = data['vb1'][0].cpu().numpy().astype(int)

    # Render the synthesized inbetween frame: draw every edge of both warped
    # graphs whose two endpoints are predicted visible.
    canvas5 = np.zeros_like(img1) + 255
    for node, nbs in enumerate(source0_topo):
        for nb in nbs:
            if visible01[node] and visible01[nb]:
                cv2.line(canvas5, (source0_warp[node][0], source0_warp[node][1]),
                         (source0_warp[nb][0], source0_warp[nb][1]), (0, 0, 0), 2)
    for node, nbs in enumerate(source2_topo):
        for nb in nbs:
            if visible21[node] and visible21[nb]:
                cv2.line(canvas5, (source2_warp[node][0], source2_warp[node][1]),
                         (source2_warp[nb][0], source2_warp[nb][1]), (0, 0, 0), 2)

    im_h = cv2.hconcat([canvas5])
    # cd = cd_score(canvas5.copy(), original_target.copy()) * 1e5
    return im_h
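
# Minimal smoke test (an illustrative sketch, not from the original repo):
# builds a tiny fake batch with the keys `visualize` reads and renders a
# three-vertex graph. Run from the repo root with
# `python -m utils.visualize_inbetween3` so the relative import above resolves.
if __name__ == '__main__':
    pts = torch.tensor([[[10.0, 10.0], [50.0, 10.0], [30.0, 50.0]]])
    img = torch.zeros(1, 3, 64, 64)  # normalized image; zeros decode to mid grey
    fake = {
        'image0': img, 'imaget': img,
        'keypoints0': pts, 'keypoints1': pts,
        'keypoints0t': pts + 2.0, 'keypoints1t': pts - 2.0,  # fake warped positions
        'motion0': torch.zeros_like(pts), 'motion1': torch.zeros_like(pts),
        'topo0': [[[1], [0, 2], [1]]], 'topo1': [[[1], [0, 2], [1]]],
        'vb0': torch.ones(1, 3), 'vb1': torch.ones(1, 3),
    }
    out = visualize(fake)
    cv2.imwrite('smoke_inbetween.png', out)  # hypothetical output path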
================================================
FILE: utils/visualize_video.py
================================================
import numpy as np
import torch
import cv2


def visvid(data, inter_frames=1):
    # Decode the two normalized input images back to uint8.
    img1 = ((data['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()
    img2 = ((data['image1'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()

    r0 = data['r0'][0].cpu().numpy()  # predicted repositioning vectors for frame-0 vertices
    r1 = data['r1'][0].cpu().numpy()  # predicted repositioning vectors for frame-1 vertices
    source0 = data['keypoints0'][0].cpu().numpy().astype(int)
    source2 = data['keypoints1'][0].cpu().numpy().astype(int)
    source0_topo = data['ntopo0'][0]
    source2_topo = data['ntopo1'][0]
    ori_source0_topo = data['topo0'][0]
    ori_source2_topo = data['topo1'][0]
    visible01 = data['vb0'][0].cpu().numpy().astype(int)
    visible21 = data['vb1'][0].cpu().numpy().astype(int)

    # Rasterize the two input graphs with their original topology.
    canvas1 = np.zeros_like(img1) + 255
    canvas2 = np.zeros_like(img1) + 255
    for node, nbs in enumerate(ori_source0_topo):
        for nb in nbs:
            cv2.line(canvas1, (source0[node][0], source0[node][1]),
                     (source0[nb][0], source0[nb][1]), (0, 0, 0), 2)
    for node, nbs in enumerate(ori_source2_topo):
        for nb in nbs:
            cv2.line(canvas2, (source2[node][0], source2[node][1]),
                     (source2[nb][0], source2[nb][1]), (0, 0, 0), 2)

    # Synthesize `inter_frames` intermediate frames: at interpolation time
    # t = (ii + 1) / (inter_frames + 1), frame-0 vertices move by t * r0 and
    # frame-1 vertices by (1 - t) * r1; only edges whose two endpoints are
    # predicted visible are drawn.
    canvases = [np.zeros_like(img1) + 255 for jj in range(inter_frames)]
    for ii in range(len(canvases)):
        t = (ii + 1.0) / (len(canvases) + 1.0)
        source0_warp = (source0 + t * r0).astype(int)
        source2_warp = (source2 + (1 - t) * r1).astype(int)
        for node, nbs in enumerate(source0_topo):
            for nb in nbs:
                if visible01[node] and visible01[nb]:
                    cv2.line(canvases[ii], (source0_warp[node][0], source0_warp[node][1]),
                             (source0_warp[nb][0], source0_warp[nb][1]), (0, 0, 0), 2)
        for node, nbs in enumerate(source2_topo):
            for nb in nbs:
                if visible21[node] and visible21[nb]:
                    cv2.line(canvases[ii], (source2_warp[node][0], source2_warp[node][1]),
                             (source2_warp[nb][0], source2_warp[nb][1]), (0, 0, 0), 2)

    # Each output frame shows input frame 0 on the left and the current
    # interpolated frame on the right; the sequence is book-ended by the two
    # inputs themselves.
    for ii in range(len(canvases)):
        canvases[ii] = cv2.hconcat([canvas1, canvases[ii]])
    images = [cv2.hconcat([canvas1, canvas1])] + canvases + [cv2.hconcat([canvas2, canvas2])]
    return images
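
# Illustrative helper (a sketch, not part of the original file): write the
# frame list returned by visvid() to a video file with OpenCV. The output
# path and frame rate are hypothetical choices.
def write_video(frames, path='inbetween_results/inbetween.mp4', fps=12):
    h, w = frames[0].shape[:2]
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    for frame in frames:
        writer.write(frame)  # frames are uint8 BGR images of identical size
    writer.release()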