Repository: lisiyao21/AnimeInbet
Branch: main
Commit: cc5554feb9d8
Files: 35
Total size: 217.5 KB
Directory structure:
gitextract_oa9f83a9/
├── .gitignore
├── README.md
├── compute_cd.py
├── configs/
│   └── cr_inbetweener_full.yaml
├── corr/
│   ├── configs/
│   │   └── vtx_corr.yaml
│   ├── datasets/
│   │   ├── __init__.py
│   │   └── ml_dataset.py
│   ├── experiments/
│   │   └── vtx_corr/
│   │       └── ckpt/
│   │           └── .gitkeep
│   ├── main.py
│   ├── models/
│   │   ├── __init__.py
│   │   └── supergluet.py
│   ├── srun.sh
│   ├── utils/
│   │   ├── log.py
│   │   └── visualize_vtx_corr.py
│   └── vtx_matching.py
├── data/
│   └── README.md
├── datasets/
│   ├── __init__.py
│   ├── ml_seq.py
│   └── vd_seq.py
├── download.sh
├── experiments/
│   └── inbetweener_full/
│       └── ckpt/
│           └── .gitkeep
├── inbetween.py
├── inbetween_results/
│   └── .gitkeep
├── main.py
├── models/
│   ├── __init__.py
│   ├── inbetweener_with_mask2.py
│   └── inbetweener_with_mask_with_spec.py
├── requirement.txt
├── srun.sh
└── utils/
    ├── chamfer_distance.py
    ├── log.py
    ├── visualize_inbetween.py
    ├── visualize_inbetween2.py
    ├── visualize_inbetween3.py
    └── visualize_video.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*/__pycache__/*
*.pt
*.jpg
*.pyc
data/ml100_norm/
data/ml144*
data/*.zip
================================================
FILE: README.md
================================================
# AnimeInbet
Code for ICCV 2023 paper "Deep Geometrized Cartoon Line Inbetweening"
[[Paper]](https://openaccess.thecvf.com/content/ICCV2023/papers/Siyao_Deep_Geometrized_Cartoon_Line_Inbetweening_ICCV_2023_paper.pdf) | [[Video Demo]](https://youtu.be/iUF-LsqFKpI?si=9FViAZUyFdSfZzS5) | [[Data (Google Drive)]](https://drive.google.com/file/d/1SNRGajIECxNwRp6ZJ0IlY7AEl2mRm2DR/view?usp=sharing)
✨ Do not hesitate to give a star! Thank you! ✨

> We aim to address a significant but understudied problem in the anime industry, namely the inbetweening of cartoon line drawings. Inbetweening involves generating intermediate frames between two black-and-white line drawings and is a time-consuming and expensive process that can benefit from automation. However, existing frame interpolation methods that rely on matching and warping whole raster images are unsuitable for line inbetweening and often produce blurring artifacts that damage the intricate line structures. To preserve the precision and detail of the line drawings, we propose a new approach, AnimeInbet, which geometrizes raster line drawings into graphs of endpoints and reframes the inbetweening task as a graph fusion problem with vertex repositioning. Our method can effectively capture the sparsity and unique structure of line drawings while preserving the details during inbetweening. This is made possible via our novel modules, i.e., vertex geometric embedding, a vertex correspondence Transformer, an effective mechanism for vertex repositioning and a visibility predictor. To train our method, we introduce MixamoLine240, a new dataset of line drawings with ground truth vectorization and matching labels. Our experiments demonstrate that AnimeInbet synthesizes high-quality, clean, and complete intermediate line drawings, outperforming existing methods quantitatively and qualitatively, especially in cases with large motions.
# ML240 Data
The implementation of AnimeInbet depends on matching the line vertices of two adjacent frames. To supervise the learning of vertex correspondence, we built a large-scale sequential cartoon line dataset, **MixamoLine240** (ML240). ML240 contains a training set (100 sequences), a validation set (44 sequences) and a test set (100 sequences).
To use the data, please first download it from [link](https://drive.google.com/file/d/1SNRGajIECxNwRp6ZJ0IlY7AEl2mRm2DR/view?usp=sharing) and uncompress it into the **data** folder under this project directory. After decompression, the data layout will be:
```
data
|_ ml100_norm
|  |_ all
|     |_ frames
|     |  |_ chip_abe
|     |  |  |_ Image0001.png
|     |  |  |_ Image0002.png
|     |  |  |_ ...
|     |  |_ ...
|     |_ labels
|        |_ chip_abe
|        |  |_ Line0001.json
|        |  |_ Line0002.json
|        |  |_ ...
|        |_ ...
|_ ml144_norm_100_44_split
   |_ test
   |  |_ frames
   |  |  |_ breakdance_1990_police
   |  |  |  |_ Image0001.png
   |  |  |  |_ Image0002.png
   |  |  |  |_ ...
   |  |  |_ ...
   |  |_ labels
   |     |_ breakdance_1990_police
   |     |  |_ Line0001.json
   |     |  |_ Line0002.json
   |     |  |_ ...
   |     |_ ...
   |_ train
      |_ frames
      |  |_ breakdance_1990_ganfaul
      |  |  |_ Image0001.png
      |  |  |_ Image0002.png
      |  |  |_ ...
      |  |_ ...
      |_ labels
         |_ breakdance_1990_ganfaul
         |  |_ Line0001.json
         |  |_ Line0002.json
         |  |_ ...
         |_ ...
```
The json files in the "labels" folder (for example, ml100_norm/all/labels/chip_abe/Line0001.json) store the vectorization/geometrization labels of the corresponding images in the "frames" folder (ml100_norm/all/frames/chip_abe/Image0001.png). Each json file contains three components: (1) **vertex location**: the 2D positions of the line-art vertices, (2) **connection**: the adjacency list of the vector graph, and (3) **original index**: the index of each vertex in the original 3D mesh.
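For illustration, a label file with the three components above could be loaded like this (a minimal sketch assuming exactly the keys described; `load_line_label` is our own name, not part of this repo — the repo's loader is `read_json` in corr/datasets/ml_dataset.py):

```python
import json
import numpy as np

def load_line_label(path):
    """Load one geometrization label (e.g. labels/chip_abe/Line0001.json)."""
    with open(path) as f:
        data = json.load(f)
    vertices = np.asarray(data['vertex location'])   # (N, 2) vertex positions
    adjacency = data['connection']                   # adjacency list, one list per vertex
    orig_index = np.asarray(data['original index'])  # vertex ids in the source 3D mesh
    return vertices, adjacency, orig_index
```

The `original index` field is what lets two frames of the same sequence be matched: vertices sharing the same 3D-mesh index are corresponding vertices.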
# Code
## Environment
```
conda create -n inbetween python=3.8
conda activate inbetween
conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=10.1 -c pytorch
pip install -r requirement.txt
```

In this code, the whole pipeline is separated into two parts: (1) vertex correspondence and (2) inbetweening/synthesis. The first part learns to match the vertices of two input vector graphs and contains the "vertex embedding" and "vertex corr. Transformer" modules. "Repositioning propagation" and "graph fusion" are then performed in the second part.
The first part lives in ./corr; the second part is everything else. We provide a pretrained correspondence network weight ([link](https://drive.google.com/file/d/1Edc-XGyMXqXDdfBYoglDMkBf7_AYZU0p/view?usp=sharing)) and a pretrained whole-pipeline weight ([link](https://drive.google.com/file/d/1cemJCBNdcTvJ9LWCA_5LmDDorwEb-u7M/view?usp=sharing)). For correspondence, please decompress the weight (epoch_50.pt) into ./corr/experiments/vtx_corr/ckpt/. For the whole pipeline, please decompress the weight (epoch_20.pt) into ./experiments/inbetweener_full/ckpt/.
## Train & test corr.
For training, first cd into the ./corr folder and then run
```
sh srun.sh configs/vtx_corr.yaml train [your node name] 1
```
If you don't use slurm on your computer/cluster, you can run
```
python -u main.py --config configs/vtx_corr.yaml --train
```
For testing the correspondence network, please run
```
sh srun.sh configs/vtx_corr.yaml eval [your node name] 1
```
or
```
python -u main.py --config configs/vtx_corr.yaml --eval
```
You may directly run the test code after downloading the weights, without training.
## Train & test the whole inbetweening pipeline
For training the whole pipeline, please first cd out of ./corr back to the project root folder and run
```
sh srun.sh configs/cr_inbetweener_full.yaml train [your node name] 1
```
or
```
python -u main.py --config configs/cr_inbetweener_full.yaml --train
```
For testing, please run
```
sh srun.sh configs/cr_inbetweener_full.yaml eval [your node name] 1
```
or
```
python -u main.py --config configs/cr_inbetweener_full.yaml --eval
```
The inbetweened results will be stored in the ./inbetween_results folder.
### Compute CD values
The CD code is in utils/chamfer_distance.py. Please run
```
python compute_cd.py --gt ./data/ml100_norm/all/frames --generated ./inbetween_results/test_gap=5
```
If everything goes right, the score should match the value reported in the paper.
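For reference, the idea behind the CD metric — a symmetric Chamfer distance between the sets of line pixels of the generated and ground-truth drawings — can be sketched as follows (an illustrative reimplementation with our own helper names, not the exact code in utils/chamfer_distance.py, which may differ in normalization):

```python
import numpy as np

def line_pixels(gray, thresh=128):
    """Coordinates (x, y) of dark line pixels in a grayscale image."""
    ys, xs = np.nonzero(gray < thresh)
    return np.stack([xs, ys], axis=1).astype(float)

def chamfer(pts_a, pts_b):
    """Symmetric Chamfer distance between two 2D point sets."""
    # pairwise distances between every pixel of a and every pixel of b
    d = np.linalg.norm(pts_a[:, None, :] - pts_b[None, :, :], axis=-1)
    # average nearest-neighbor distance, in both directions
    return 0.5 * (d.min(axis=1).mean() + d.min(axis=0).mean())
```

Note that compute_cd.py reports the mean score divided by 1e-5, i.e. the printed number is in units of 10^-5.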
# Citation
If you use our code or data, or find our work inspiring, please kindly cite our paper:
```
@inproceedings{siyao2023inbetween,
  title={Deep Geometrized Cartoon Line Inbetweening},
  author={Siyao, Li and Gu, Tianpei and Xiao, Weiye and Ding, Henghui and Liu, Ziwei and Loy, Chen Change},
  booktitle={ICCV},
  year={2023}
}
```
# License
ML240 is released under CC BY-NC-SA 4.0. The code is released for non-commercial use only.
================================================
FILE: compute_cd.py
================================================
import argparse
import cv2
import os
from utils.chamfer_distance import cd_score
import numpy as np
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--generated', type=str)
    parser.add_argument('--gt', type=str)
    args = parser.parse_args()

    gen_dir = args.generated
    gt_dir = args.gt

    cds = []
    print('computing CD...', flush=True)
    for subfolder in os.listdir(gt_dir):
        for img in os.listdir(os.path.join(gt_dir, subfolder)):
            if not img.endswith('.png'):
                continue
            img_gt = cv2.imread(os.path.join(gt_dir, subfolder, img))
            # generated frames are named <sequence>_LineXXXX.png
            pred_name = subfolder + '_' + img.replace('Image', 'Line')
            if not os.path.exists(os.path.join(gen_dir, pred_name)):
                continue
            img_pred = cv2.imread(os.path.join(gen_dir, pred_name))
            cds.append(cd_score(img_gt, img_pred))

    print('GT: ', gt_dir)
    print('>>> Gen: ', gen_dir)
    print('>>> #images: ', len(cds))
    print('>>> CD (x 1e-5): ', np.mean(cds) / 1e-5)
================================================
FILE: configs/cr_inbetweener_full.yaml
================================================
model:
  name: InbetweenerTM
  corr_model:
    descriptor_dim: 128
    keypoint_encoder: [32, 64, 128]
    GNN_layer_num: 12
    sinkhorn_iterations: 20
    match_threshold: 0.2
  descriptor_dim: 128
  pos_weight: 0.2
optimizer:
  type: Adam
  kwargs:
    lr: 0.0001
    betas: [0.9, 0.999]
    weight_decay: 0
  schedular_kwargs:
    milestones: [80]
    gamma: 0.1
data:
  train:
    root: 'data/ml144_norm_100_44_split/'
    batch_size: 1
    gap: 5
    type: 'train'
    model: None
    action: None
    mode: 'train'
  test:
    root: 'data/ml100_norm/'
    batch_size: 1
    gap: 5
    type: 'all'
    model: None
    action: None
    mode: 'eval'
    use_vs: False
testing:
  ckpt_epoch: 20
  batch_size: 8
corr_weights: './corr/experiments/vtx_corr/ckpt/epoch_50.pt'
imwrite_dir: ./inbetween_results/test_gap=5
expname: inbetweener_full
epoch: 20
save_per_epochs: 1
log_per_updates: 1
test_freq: 10
seed: 42
================================================
FILE: corr/configs/vtx_corr.yaml
================================================
model:
  name: SuperGlueT
  descriptor_dim: 128
  keypoint_encoder: [32, 64, 128]
  GNN_layer_num: 12
  sinkhorn_iterations: 20
  match_threshold: 0.2
optimizer:
  type: Adam
  kwargs:
    lr: 0.00001
    betas: [0.9, 0.999]
    weight_decay: 0
  schedular_kwargs:
    milestones: [50, 150]
    gamma: 0.1
data:
  train:
    batch_size: 1
    gap: 5
    model: None
    action: None
    type: 'train'
    mode: 'train'
  test:
    batch_size: 1
    gap: 5
    type: 'test'
    model: None
    action: None
    mode: 'eval'
testing:
  ckpt_epoch: 50
  batch_size: 8
expname: vtx_corr
epoch: 50
save_per_epochs: 1
log_per_updates: 1
test_freq: 1
seed: 42
================================================
FILE: corr/datasets/__init__.py
================================================
from .ml_dataset import MixamoLineArt
from .ml_dataset import fetch_dataloader
__all__ = ['MixamoLineArt', 'fetch_dataloader']
================================================
FILE: corr/datasets/ml_dataset.py
================================================
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
# import networkx as nx
import os
import math
import random
from glob import glob
import os.path as osp
import sys
import argparse
import cv2
from collections import Counter
import json
import sknetwork
from sknetwork.embedding import Spectral
def read_json(file_path):
    """
    input: json file path
    output: 2D vertices, connections, and index numbers in the original 3D space
    """
    with open(file_path) as file:
        data = json.load(file)
    vertex2d = np.array(data['vertex location'])
    topology = data['connection']
    index = np.array(data['original index'])
    return vertex2d, topology, index


def ids_to_mat(id1, id2):
    """
    inputs are two lists of vertex indices in the original 3D mesh;
    returns the boolean match matrix plus, for each side, the index of the
    matched vertex on the other side (-1 for unmatched)
    """
    corr1 = np.zeros(len(id1)) - 1.0
    corr2 = np.zeros(len(id2)) - 1.0
    id1 = np.array(id1).astype(int)[:, None]
    id2 = np.array(id2).astype(int)
    mat = (id1 == id2)
    pos12 = np.arange(len(id2))[None].repeat(len(id1), 0)
    pos21 = np.arange(len(id1))[None].repeat(len(id2), 0)
    corr1[mat.astype(int).sum(1).astype(bool)] = pos12[mat]
    corr2[mat.transpose().astype(int).sum(1).astype(bool)] = pos21[mat.transpose()]
    return mat, corr1, corr2


def adj_matrix(topology):
    """
    topology is the adjacency list; returns the adjacency matrix (with self-loops)
    """
    gsize = len(topology)
    adj = np.zeros((gsize, gsize)).astype(float)
    for v in range(gsize):
        adj[v][v] = 1.0
        for nb in topology[v]:
            adj[v][nb] = 1.0
            adj[nb][v] = 1.0
    return adj
class MixamoLineArt(data.Dataset):
    def __init__(self, root, gap=0, split='train', model=None, action=None, mode='train', max_len=3050, use_vs=False):
        """
        input:
            root: the root folder of the line art data
            gap: how many frames between the two input frames
            split: train or test
            model: skip clips whose name contains any of these characters (default None)
            action: skip clips whose name contains any of these actions (default None)
        """
        super(MixamoLineArt, self).__init__()
        if model == 'None':
            model = None
        if action == 'None':
            action = None
        self.is_train = (mode == 'train')
        self.is_eval = (mode == 'eval')
        self.max_len = max_len
        self.image_list = []
        self.label_list = []
        if use_vs:
            label_root = osp.join(root, split, 'labels_vs')
        else:
            label_root = osp.join(root, split, 'labels')
        image_root = osp.join(root, split, 'frames')
        self.spectral = Spectral(64, normalized=False)
        for clip in os.listdir(image_root):
            # skip clips of excluded characters/actions
            skip = False
            if model is not None:
                for mm in model:
                    if mm in clip:
                        skip = True
            if action is not None:
                for aa in action:
                    if aa in clip:
                        skip = True
            if skip:
                continue
            image_list = sorted(glob(osp.join(image_root, clip, '*.png')))
            label_list = sorted(glob(osp.join(label_root, clip, '*.json')))
            if len(image_list) != len(label_list):
                print(image_root, flush=True)
                continue
            # pair each frame with the frame (gap + 1) steps later
            for i in range(len(image_list) - (gap + 1)):
                self.image_list += [[image_list[i], image_list[i + gap + 1]]]
            for i in range(len(label_list) - (gap + 1)):
                self.label_list += [[label_list[i], label_list[i + gap + 1]]]
        print('Len of Frame is ', len(self.image_list))
        print('Len of Label is ', len(self.label_list))
    def __getitem__(self, index):
        # load image/label files; 2d labels get the same crop as the image;
        # build index-to-index matching and the spectral embedding
        index = index % len(self.image_list)
        file_name = self.label_list[index][0][:-4]
        img1 = cv2.imread(self.image_list[index][0])
        img2 = cv2.imread(self.image_list[index][1])
        v2d1, topo1, id1 = read_json(self.label_list[index][0])
        v2d2, topo2, id2 = read_json(self.label_list[index][1])
        # add self-loops to the adjacency lists
        for ii in range(len(topo1)):
            topo1[ii].append(ii)
        for ii in range(len(topo2)):
            topo2[ii].append(ii)
        m, n = len(v2d1), len(v2d2)
        if len(img1.shape) == 2:
            img1 = np.tile(img1[..., None], (1, 1, 3))
            img2 = np.tile(img2[..., None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]
        img1 = torch.from_numpy(img1).permute(2, 0, 1).float() * 2 / 255.0 - 1.0
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float() * 2 / 255.0 - 1.0
        v2d1 = torch.from_numpy(v2d1)
        v2d2 = torch.from_numpy(v2d2)
        # clamp vertex coordinates into the 720x720 canvas
        v2d1[v2d1 > 719] = 719
        v2d1[v2d1 < 0] = 0
        v2d2[v2d2 > 719] = 719
        v2d2[v2d2 < 0] = 0
        adj1 = sknetwork.data.from_adjacency_list(topo1, matrix_only=True, reindex=False).toarray()
        adj2 = sknetwork.data.from_adjacency_list(topo2, matrix_only=True, reindex=False).toarray()
        # note: we compute the spectral embedding of the adj matrix during data
        # loading since it needs cpu computation and is not friendly to our
        # cluster; doing it here uses multi-cpu pre-computing before the
        # network forward pass
        spec0, spec1 = np.abs(self.spectral.fit_transform(adj1)), np.abs(self.spectral.fit_transform(adj2))
        mat_index, corr1, corr2 = ids_to_mat(id1, id2)
        mat_index = torch.from_numpy(mat_index).float()
        corr1 = torch.from_numpy(corr1).float()
        corr2 = torch.from_numpy(corr2).float()
        if self.is_train:
            # pad everything to max_len so samples can be batched
            v2d1 = torch.nn.functional.pad(v2d1, (0, 0, 0, self.max_len - m), mode='constant', value=0)
            v2d2 = torch.nn.functional.pad(v2d2, (0, 0, 0, self.max_len - n), mode='constant', value=0)
            corr1 = torch.nn.functional.pad(corr1, (0, self.max_len - m), mode='constant', value=0)
            corr2 = torch.nn.functional.pad(corr2, (0, self.max_len - n), mode='constant', value=0)
            mask0, mask1 = torch.zeros(self.max_len).float(), torch.zeros(self.max_len).float()
            mask0[:m] = 1
            mask1[:n] = 1
        else:
            mask0, mask1 = torch.ones(m).float(), torch.ones(n).float()
        # not returning raw ids anymore: too slow
        ret = {
            'keypoints0': v2d1,
            'keypoints1': v2d2,
            'adj_mat0': spec0,
            'adj_mat1': spec1,
            'image0': img1,
            'image1': img2,
            'all_matches': corr1,
            'm01': corr1,
            'm10': corr2,
            'ms': m,
            'ns': n,
            'mask0': mask0,
            'mask1': mask1,
            'file_name': file_name,
        }
        if self.is_eval:
            # topology is only returned for evaluation/visualization
            ret['topo0'] = [topo1]
            ret['topo1'] = [topo2]
        return ret
    def __rmul__(self, v):
        # repeat the dataset v times
        self.image_list = v * self.image_list
        self.label_list = v * self.label_list
        return self

    def __len__(self):
        return len(self.image_list)
def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)


def fetch_dataloader(args, type='train'):
    mode = args.mode if hasattr(args, 'mode') else 'train'
    lineart = MixamoLineArt(
        root=args.root if hasattr(args, 'root') else '../data/ml144_norm_100_44_split/',
        gap=args.gap, split=args.type, model=args.model, action=args.action,
        mode=mode, use_vs=args.use_vs if hasattr(args, 'use_vs') else False)
    if mode != 'train':
        # deterministic loader for evaluation
        return data.DataLoader(lineart, batch_size=args.batch_size,
                               pin_memory=True, shuffle=False, num_workers=8)
    return data.DataLoader(lineart, batch_size=args.batch_size,
                           pin_memory=True, shuffle=True, num_workers=8,
                           drop_last=True, worker_init_fn=worker_init_fn)
if __name__ == '__main__':
    torch.multiprocessing.set_sharing_strategy('file_system')
    args = argparse.Namespace()
    args.batch_size = 1
    args.gap = 5
    args.type = 'test'
    args.model = ['ganfaul', 'firlscout', 'jolleen', 'kachujin', 'knight', 'maria',
                  'michelle', 'peasant', 'timmy', 'uriel']
    args.use_vs = False
    args.action = ['breakdance', 'capoeira', 'chapa-', 'fist_fight', 'flying', 'climb',
                   'running', 'reaction', 'magic', 'tripping']
    args.mode = 'eval'
    args.root = '/mnt/lustre/syli/inbetween/data/12by12/ml144_norm_100_44_split/'
    lineart = fetch_dataloader(args)

    # accumulate simple dataset statistics
    percentage = 0.0
    vertex_num = 0.0
    vertex_shift = 0.0
    vertex_max_shift = 0.0
    edges = 0.0
    unmatched_all = []
    max_node = 0
    for batch in lineart:
        v2d1 = batch['keypoints0'].numpy().astype(int)[0]
        v2d2 = batch['keypoints1'].numpy().astype(int)[0]
        ms = batch['ms'][0]
        ns = batch['ns'][0]
        topo = batch['topo0'][0]
        for ii in range(len(topo)):
            edges += len(topo[ii])
        v2d1 = v2d1[:ms]
        v2d2 = v2d2[:ns]
        m01 = batch['m01'][0][:ms]
        # displacement of matched vertices between the two frames
        shift = np.sqrt(((v2d2[m01[m01 != -1].int(), :] * 1.0
                          - v2d1[np.arange(len(m01))[m01 != -1], :]) ** 2).sum(-1))
        vertex_num += len(v2d1)
        vertex_shift += shift.mean()
        vertex_max_shift += shift.max()
        percentage += ((m01 != -1).float().sum() * 1.0 / len(m01) * 100.0)
    print('>>>> gap=', args.gap,
          ' percentage=', percentage / len(lineart),
          ' vertex num=', vertex_num * 1.0 / len(lineart),
          ' edges num=', edges * 1.0 / len(lineart) / 2,
          ' vertex shift=', vertex_shift / len(lineart),
          ' vertex max shift=', vertex_max_shift / len(lineart), flush=True)
# if len(v2d1) > max_node:
# max_node = len(v2d1)
# if len(v2d2) > max_node:
# max_node = len(v2d2)
# print(max_node)
# print(v2d1.shape)
# img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()
# img2 = ((dict['image1'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()
# # print(v2d1.shape, img1.shape, flush=True)
# for node, nbs in enumerate(dict['topo0']):
# for nb in nbs:
# cv2.line(img1, [v2d1[node][0], v2d1[node][1]], [v2d1[nb][0], v2d1[nb][1]], [255, 180, 180], 2)
# colors1, colors2 = {}, {}
# id1 = dict['id0'][0].numpy()
# id2 = dict['id1'][0].numpy()
# for index in id1:
# # print(index)
# color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]
# # for ii in index:
# colors1[index] = color
# colors1, colors2 = {}, {}
# for index in id1:
# # print(index)
# color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]
# colors1[index] = color
# for i, p in enumerate(v2d1):
# ii = id1[i]
# # print(ii)
# cv2.circle(img1, [int(p[0]), int(p[1])], 1, colors1[ii], 2)
# unmatched = 0
# for ii in id2:
# color = [0, 0, 0]
# this_is_umatched = 1
# colors2[ii] = colors1[ii] if ii in colors1 else color
# if ii in colors1:
# this_is_umatched = 0
# # if ii not in colors1:
# unmatched += this_is_umatched
# for i, p in enumerate(v2d2):
# ii = id2[i]
# # print(p)
# cv2.circle(img2, [int(p[0]), int( p[1])], 1, colors2[ii], 2)
# # print('Unmatched in Img 2: ', , '%')
# unmatched_all.append(100 - unmatched * 100.0/len(v2d2))
# im_h = cv2.hconcat([img1, img2])
# print('/mnt/lustre/syli/inbetween/AnimeInbetween/corr/datasets/data_check_norm/' + dict['file_name'][0].replace('/', '_') + '.jpg', flush=True)
# cv2.imwrite('/mnt/lustre/syli/inbetween/AnimeInbetween/corr/datasets/data_check_norm/' + dict['file_name'][0].replace('/', '_') + '.jpg', im_h)
# print(np.mean(unmatched_all))
================================================
FILE: corr/experiments/vtx_corr/ckpt/.gitkeep
================================================
================================================
FILE: corr/main.py
================================================
from vtx_matching import VtxMat
import argparse
import os
import yaml
from pprint import pprint
from easydict import EasyDict


def parse_args():
    parser = argparse.ArgumentParser(description='Anime segment matching')
    parser.add_argument('--config', default='')
    # mutually exclusive mode flags
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--train', action='store_true')
    group.add_argument('--eval', action='store_true')
    return parser.parse_args()


def main():
    # parse arguments and load the yaml config
    args = parse_args()
    with open(args.config) as f:
        config = yaml.safe_load(f)
    for k, v in vars(args).items():
        config[k] = v
    pprint(config)
    config = EasyDict(config)
    agent = VtxMat(config)
    if args.train:
        agent.train()
    elif args.eval:
        agent.eval()


if __name__ == '__main__':
    main()
================================================
FILE: corr/models/__init__.py
================================================
from .supergluet import SuperGlueT
# from .supergluet_wo_OT import SuperGlueTwoOT
# from .supergluenp import SuperGlue as SuperGlueNP
# from .supergluei import SuperGlue as SuperGlueI
# from .supergluet2 import SuperGlueT2
__all__ = ['SuperGlueT']
================================================
FILE: corr/models/supergluet.py
================================================
import numpy as np
from copy import deepcopy
from pathlib import Path
import torch
from torch import nn
import argparse
from sknetwork.embedding import Spectral
def MLP(channels: list, do_bn=True):
    """ Multi-layer perceptron built from 1x1 convolutions """
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(
            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
        if i < (n - 1):
            if do_bn:
                layers.append(nn.InstanceNorm1d(channels[i]))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


def normalize_keypoints(kpts, image_shape):
    """ Normalize keypoint locations based on the image shape """
    _, _, height, width = image_shape
    one = kpts.new_tensor(1)
    size = torch.stack([one * width, one * height])[None]
    center = size / 2
    scaling = size.max(1, keepdim=True).values * 0.7
    return (kpts - center[:, None, :]) / scaling[:, None, :]
class ThreeLayerEncoder(nn.Module):
    """ Three-layer convolutional encoder producing a per-pixel feature map """
    def __init__(self, enc_dim):
        super().__init__()
        # input must be 3 channels (r, g, b)
        self.layer1 = nn.Conv2d(3, enc_dim // 4, 7, padding=3)
        self.non_linear1 = nn.ReLU()
        self.layer2 = nn.Conv2d(enc_dim // 4, enc_dim // 2, 3, padding=1)
        self.non_linear2 = nn.ReLU()
        self.layer3 = nn.Conv2d(enc_dim // 2, enc_dim, 3, padding=1)
        self.norm1 = nn.InstanceNorm2d(enc_dim // 4)
        self.norm2 = nn.InstanceNorm2d(enc_dim // 2)
        self.norm3 = nn.InstanceNorm2d(enc_dim)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0.0)

    def forward(self, img):
        x = self.non_linear1(self.norm1(self.layer1(img)))
        x = self.non_linear2(self.norm2(self.layer2(x)))
        x = self.norm3(self.layer3(x))
        return x


class VertexDescriptor(nn.Module):
    """ Sample per-vertex visual descriptors from the encoded image """
    def __init__(self, enc_dim):
        super().__init__()
        self.encoder = ThreeLayerEncoder(enc_dim)

    def forward(self, img, vtx):
        x = self.encoder(img)
        n, c, h, w = x.size()
        assert (h, w) == img.size()[2:4]
        # index the feature map at the (rounded) vertex coordinates
        return x[:, :, torch.round(vtx[0, :, 1]).long(), torch.round(vtx[0, :, 0]).long()]
class KeypointEncoder(nn.Module):
    """ MLP encoding of 2D keypoint locations """
    def __init__(self, feature_dim, layers):
        super().__init__()
        self.encoder = MLP([2] + layers + [feature_dim])
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts):
        inputs = kpts.transpose(1, 2)
        return self.encoder(inputs)


class TopoEncoder(nn.Module):
    """ MLP encoding of the 64-d spectral topology embedding """
    def __init__(self, feature_dim, layers):
        super().__init__()
        self.encoder = MLP([64] + layers + [feature_dim])
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts):
        inputs = kpts.transpose(1, 2)
        return self.encoder(inputs)
def attention(query, key, value, mask=None):
    """ Scaled dot-product attention with an optional key mask """
    dim = query.shape[1]
    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    prob = torch.nn.functional.softmax(scores, dim=-1)
    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob


class MultiHeadedAttention(nn.Module):
    """ Multi-head attention to increase model expressivity """
    def __init__(self, num_heads: int, d_model: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value, mask=None):
        batch_dim = query.size(0)
        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]
        x, prob = attention(query, key, value, mask)
        return self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, -1))


class AttentionalPropagation(nn.Module):
    def __init__(self, feature_dim: int, num_heads: int):
        super().__init__()
        self.attn = MultiHeadedAttention(num_heads, feature_dim)
        self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim])
        nn.init.constant_(self.mlp[-1].bias, 0.0)

    def forward(self, x, source, mask=None):
        message = self.attn(x, source, source, mask)
        return self.mlp(torch.cat([x, message], dim=1))
class AttentionalGNN(nn.Module):
    def __init__(self, feature_dim: int, layer_names: list):
        super().__init__()
        self.layers = nn.ModuleList([
            AttentionalPropagation(feature_dim, 4)
            for _ in range(len(layer_names))])
        self.names = layer_names

    def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None, mask10=None):
        for layer, name in zip(self.layers, self.names):
            layer.attn.prob = []
            if name == 'cross':
                src0, src1 = desc1, desc0
                mask0, mask1 = mask01[:, None], mask10[:, None]
            else:  # name == 'self'
                src0, src1 = desc0, desc1
                mask0, mask1 = mask00[:, None], mask11[:, None]
            delta0, delta1 = layer(desc0, src0, mask0), layer(desc1, src1, mask1)
            desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
        return desc0, desc1
def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
    """ Perform Sinkhorn normalization in log-space for stability """
    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):
        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
    return Z + u.unsqueeze(2) + v.unsqueeze(1)


def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
    """ Perform differentiable optimal transport in log-space for stability """
    b, m, n = scores.shape
    one = scores.new_tensor(1)
    if ms is None or ns is None:
        ms, ns = (m * one).to(scores), (n * one).to(scores)
    # ms, ns: (b, ) -- the true (unpadded) vertex counts, not the padded shape
    bins0 = alpha.expand(b, m, 1)
    bins1 = alpha.expand(b, 1, n)
    alpha = alpha.expand(b, 1, 1)
    # pad an additional dustbin row/column with score alpha for unmatched
    # vertices; alpha is the learned threshold
    couplings = torch.cat([torch.cat([scores, bins0], -1),
                           torch.cat([bins1, alpha], -1)], 1)
    norm = - (ms + ns).log()  # (b, )
    if ms.size()[0] > 0:
        norm = norm[:, None]
        log_mu = torch.cat([norm.expand(b, m), ns.log()[:, None] + norm], dim=-1)  # (m + 1)
        log_nu = torch.cat([norm.expand(b, n), ms.log()[:, None] + norm], dim=-1)
    else:
        log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])  # (m + 1)
        log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
        log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
    Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
    if ms.size()[0] > 1:
        norm = norm[:, :, None]
    Z = Z - norm  # multiply probabilities by M + N
    return Z
class SuperGlueT(nn.Module):
def __init__(self, config=None):
super().__init__()
default_config = argparse.Namespace()
default_config.descriptor_dim = 128
# default_config.weights =
default_config.keypoint_encoder = [32, 64, 128]
default_config.GNN_layers = ['self', 'cross'] * 9
default_config.sinkhorn_iterations = 100
default_config.match_threshold = 0.2
# self.config = {**self.default_config, **config}
if config is None:
self.config = default_config
else:
self.config = config
self.config.GNN_layers = ['self', 'cross'] * self.config.GNN_layer_num
# print('WULA!', self.config.GNN_layer_num)
self.kenc = KeypointEncoder(
self.config.descriptor_dim, self.config.keypoint_encoder)
self.tenc = TopoEncoder(
self.config.descriptor_dim, [96])
self.gnn = AttentionalGNN(
self.config.descriptor_dim, self.config.GNN_layers)
self.final_proj = nn.Conv1d(
self.config.descriptor_dim, self.config.descriptor_dim,
kernel_size=1, bias=True)
bin_score = torch.nn.Parameter(torch.tensor(1.))
self.register_parameter('bin_score', bin_score)
self.vertex_desc = VertexDescriptor(self.config.descriptor_dim)
def forward(self, data):
"""Run SuperGlue on a pair of keypoints and descriptors"""
kpts0, kpts1 = data['keypoints0'].float(), data['keypoints1'].float()
ori_mask0, ori_mask1 = data['mask0'].float(), data['mask1'].float()
dim_m, dim_n = data['ms'].float(), data['ns'].float()
spec0, spec1 = data['adj_mat0'], data['adj_mat1']
mmax = dim_m.int().max()
nmax = dim_n.int().max()
mask0 = ori_mask0[:, :mmax]
mask1 = ori_mask1[:, :nmax]
kpts0 = kpts0[:, :mmax]
kpts1 = kpts1[:, :nmax]
desc0, desc1 = self.vertex_desc(data['image0'], kpts0.float()), self.vertex_desc(data['image1'], kpts1.float())
# spec0, spec1 = np.abs(self.spectral.fit_transform(topo0[0].cpu().numpy())), np.abs(self.spectral.fit_transform(topo1[0].cpu().numpy()))
desc0 = desc0 + self.tenc(desc0.new_tensor(spec0))
desc1 = desc1 + self.tenc(desc1.new_tensor(spec1))
mask00 = torch.ones_like(mask0)[:, :, None] * mask0[:, None, :]
mask11 = torch.ones_like(mask1)[:, :, None] * mask1[:, None, :]
mask01 = torch.ones_like(mask0)[:, :, None] * mask1[:, None, :]
mask10 = torch.ones_like(mask1)[:, :, None] * mask0[:, None, :]
if kpts0.shape[1] < 2 or kpts1.shape[1] < 2: # too few keypoints to match
shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
# print(data['file_name'])
return {
'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],
# 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],
'matching_scores0': kpts0.new_zeros(shape0)[0],
# 'matching_scores1': kpts1.new_zeros(shape1)[0],
'skip_train': True
}
file_name = data['file_name']
all_matches = data['all_matches'] if 'all_matches' in data else None # shape = (1, K1)
# positional embedding
# Keypoint normalization.
kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
kpts1 = normalize_keypoints(kpts1, data['image1'].shape)
# Keypoint MLP encoder.
pos0 = self.kenc(kpts0)
pos1 = self.kenc(kpts1)
desc0 = desc0 + pos0
desc1 = desc1 + pos1
# Multi-layer Transformer network.
desc0, desc1 = self.gnn(desc0, desc1, mask00, mask11, mask01, mask10)
# Final MLP projection.
mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
# Compute matching descriptor distance.
scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
# b k1 k2
scores = scores / self.config.descriptor_dim**.5
mask01 = mask0[:, :, None] * mask1[:, None, :]
scores = scores.masked_fill(mask01 == 0, float('-inf'))
# Run the optimal transport.
scores = log_optimal_transport(
scores, self.bin_score,
iters=self.config.sinkhorn_iterations,
ms=dim_m, ns=dim_n)
# Get the matches with score above "match_threshold".
max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
indices0, indices1 = max0.indices, max1.indices
mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
zero = scores.new_tensor(0)
mscores0 = torch.where(mutual0, max0.values.exp(), zero)
mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
# mutual-check variant from the original SuperGlue, overridden just below:
# valid0 = mutual0 & (mscores0 > self.config.match_threshold)
# valid1 = mutual1 & valid0.gather(1, indices1)
valid0 = mscores0 > self.config.match_threshold
valid1 = valid0.gather(1, indices1)
indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
# check if indexed correctly
loss = []
if all_matches is not None:
for b in range(len(dim_m)):
for i in range(int(dim_m[b])):
x = i
y = all_matches[b][i].long()
loss.append(-scores[b][x][y] ) # check batch size == 1 ?
loss_mean = torch.mean(torch.stack(loss))
loss_mean = torch.reshape(loss_mean, (1, -1))
return {
'matches0': indices0, # use -1 for invalid match
'matches1': indices1, # use -1 for invalid match
'matching_scores0': mscores0,
# 'matching_scores1': mscores1[0],
'loss': loss_mean,
'skip_train': False,
'accuracy': (((all_matches[:, :mmax] == indices0) & mask0.bool()).sum() / mask0.sum()).item(),
'valid_accuracy': (((all_matches[:, :mmax] == indices0) & (all_matches[:, :mmax] != -1) & mask0.bool()).float().sum() / ((all_matches[:, :mmax] != -1) & mask0.bool()).float().sum()).item(),
}
else:
return {
'matches0': indices0[0], # use -1 for invalid match
'matching_scores0': mscores0[0],
'loss': -1,
'skip_train': True,
'accuracy': -1,
'area_accuracy': -1,
'valid_accuracy': -1,
}
if __name__ == '__main__':
args = argparse.Namespace()
args.batch_size = 1
args.gap = 0
args.type = 'train'
args.model = 'jolleen'
args.action = 'slash'
ss = SuperGlueT()
loader = fetch_dataloader(args)
# #print(len(loader))
for data in loader:
# p1, p2, s1, s2, mi = data
dict1 = data
kp1 = dict1['keypoints0']
kp2 = dict1['keypoints1']
p1 = dict1['image0']
p2 = dict1['image1']
# #print(s1)
# #print(s1.type)
mi = dict1['all_matches']
fname = dict1['file_name']
print(kp1.shape, p1.shape, mi.shape)
# #print(mi.size())
# #print(mi)
# break
a = ss(data)
print(dict1['file_name'])
print(a['loss'])
a['loss'].backward()
# print(a['matches0'].size())
# print(a['accuracy'], a['valid_accuracy'])
================================================
FILE: corr/srun.sh
================================================
#!/bin/sh
currenttime=`date "+%Y%m%d%H%M%S"`
if [ ! -d log ]; then
mkdir log
fi
echo "[Usage] ./srun.sh config_path [train|eval] partition gpunum"
# check config exists
if [ ! -e "$1" ]
then
echo "[ERROR] configuration file: $1 does not exist!"
exit
fi
if [ ! -d ${expname} ]; then
mkdir ${expname}
fi
echo "[INFO] saving results to, or loading files from: "$expname
if [ "$3" == "" ]; then
echo "[ERROR] enter partition name"
exit
fi
partition_name=$3
echo "[INFO] partition name: $partition_name"
if [ "$4" == "" ]; then
echo "[ERROR] enter gpu num"
exit
fi
gpunum=$4
gpunum=$(($gpunum<8?$gpunum:8))
echo "[INFO] GPU num: $gpunum"
((ntask=$gpunum*3))
TOOLS="srun --partition=$partition_name --cpus-per-task=8 --gres=gpu:$gpunum -N 1 --job-name=${config_suffix}"
PYTHONCMD="python -u main.py --config $1"
if [ "$2" == "train" ];
then
$TOOLS $PYTHONCMD \
--train
elif [ "$2" == "eval" ];
then
$TOOLS $PYTHONCMD \
--eval
fi
================================================
FILE: corr/utils/log.py
================================================
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this open-source project.
""" Define the Logger class to print log"""
import os
import sys
import logging
from datetime import datetime
class Logger:
def __init__(self, args, output_dir):
log = logging.getLogger(output_dir)
if not log.handlers:
log.setLevel(logging.DEBUG)
# if not os.path.exists(output_dir):
# os.mkdir(args.data.output_dir)
fh = logging.FileHandler(os.path.join(output_dir,'log.txt'))
fh.setLevel(logging.INFO)
ch = ProgressHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
log.addHandler(fh)
log.addHandler(ch)
self.log = log
# setup TensorBoard
# if args.tensorboard:
# from tensorboardX import SummaryWriter
# self.writer = SummaryWriter(log_dir=args.output_dir)
# else:
self.writer = None
self.log_per_updates = args.log_per_updates
def set_progress(self, epoch, total):
self.log.info(f'Epoch: {epoch}')
self.epoch = epoch
self.i = 0
self.total = total
self.start = datetime.now()
def update(self, stats):
self.i += 1
if self.i % self.log_per_updates == 0:
remaining = str((datetime.now() - self.start) / self.i * (self.total - self.i))
remaining = remaining.split('.')[0]
updates = stats.pop('updates')
stats_str = ' '.join(f'{key}[{val:.8f}]' for key, val in stats.items())
self.log.info(f'> epoch [{self.epoch}] updates[{updates}] {stats_str} eta[{remaining}]')
if self.writer:
for key, val in stats.items():
self.writer.add_scalar(f'train/{key}', val, updates)
if self.i == self.total:
self.log.debug('\n')
self.log.debug(f'elapsed time: {str(datetime.now() - self.start).split(".")[0]}')
def log_eval(self, stats, metrics_group=None):
stats_str = ' '.join(f'{key}: {val:.8f}' for key, val in stats.items())
self.log.info(f'valid {stats_str}')
if self.writer:
for key, val in stats.items():
self.writer.add_scalar(f'valid/{key}', val, self.epoch)
# for mode, metrics in metrics_group.items():
# self.log.info(f'evaluation scores ({mode}):')
# for key, (val, _) in metrics.items():
# self.log.info(f'\t{key} {val:.4f}')
# if self.writer and metrics_group is not None:
# for key, val in stats.items():
# self.writer.add_scalar(f'valid/{key}', val, self.epoch)
# for key in list(metrics_group.values())[0]:
# group = {}
# for mode, metrics in metrics_group.items():
# group[mode] = metrics[key][0]
# self.writer.add_scalars(f'valid/{key}', group, self.epoch)
def __call__(self, msg):
self.log.info(msg)
class ProgressHandler(logging.Handler):
def __init__(self, level=logging.NOTSET):
super().__init__(level)
def emit(self, record):
log_entry = self.format(record)
if record.message.startswith('> '):
sys.stdout.write('{}\r'.format(log_entry.rstrip()))
sys.stdout.flush()
else:
sys.stdout.write('{}\n'.format(log_entry))
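The wiring in `Logger.__init__` above boils down to one INFO-level `FileHandler` writing `log.txt` plus a console handler, both sharing a timestamped formatter. A minimal stdlib-only sketch of that setup (the temp directory and the message text are illustrative):

```python
import logging
import os
import sys
import tempfile

# Mirror Logger.__init__ above: logger keyed by the output directory,
# a file handler for log.txt, and a console handler on stdout.
output_dir = tempfile.mkdtemp()
log = logging.getLogger(output_dir)
log.setLevel(logging.DEBUG)
fh = logging.FileHandler(os.path.join(output_dir, 'log.txt'))
fh.setLevel(logging.INFO)
ch = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(fmt='%(asctime)s %(message)s',
                              datefmt='%m/%d/%Y %I:%M:%S')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
log.addHandler(fh)
log.addHandler(ch)
log.info('> epoch [1] loss[0.12345678] eta[0:01:23]')
fh.flush()
```

The real `Logger` swaps the plain `StreamHandler` for the `ProgressHandler` above, which rewrites lines starting with `'> '` in place using a carriage return.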
================================================
FILE: corr/utils/visualize_vtx_corr.py
================================================
import numpy as np
import torch
import cv2
def make_inter_graph(v2d1, v2d2, topo1, topo2, match12):
valid = (match12 != -1)
marked2 = np.zeros(len(v2d2)).astype(bool)
# print(match12[valid])
marked2[match12[valid]] = True
id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))
id1toh[valid] = np.arange(np.sum(valid))
id2toh[match12[valid]] = np.arange(np.sum(valid))
id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)
# print(marked2)
id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))
id1toh = id1toh.astype(int)
id2toh = id2toh.astype(int)
tot_len = len(v2d1) + np.sum(np.invert(marked2))
vin1 = v2d1[valid][:]
vin2 = v2d2[match12[valid]][:]
vh = 0.5 * (vin1 + vin2)
vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)
topoh = [[] for ii in range(tot_len)]
for node in range(len(topo1)):
for nb in topo1[node]:
if int(id1toh[nb]) not in topoh[id1toh[node]]:
topoh[id1toh[node]].append(int(id1toh[nb]))
for node in range(len(topo2)):
for nb in topo2[node]:
if int(id2toh[nb]) not in topoh[id2toh[node]]:
topoh[id2toh[node]].append(int(id2toh[nb]))
return vh, topoh
def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12):
valid = (match12 != -1)
marked2 = np.zeros(len(v2d2)).astype(bool)
# print(match12[valid])
marked2[match12[valid]] = True
id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))
id1toh[valid] = np.arange(np.sum(valid))
id2toh[match12[valid]] = np.arange(np.sum(valid))
id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)
# print(marked2)
id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))
id1toh = id1toh.astype(int)
id2toh = id2toh.astype(int)
tot_len = len(v2d1) + np.sum(np.invert(marked2))
vin1 = v2d1[valid][:]
vin2 = v2d2[match12[valid]][:]
vh = 0.5 * (vin1 + vin2)
# vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)
# topoh = [[] for ii in range(tot_len)]
topoh = [[] for ii in range(np.sum(valid))]
for node in range(len(topo1)):
if not valid[node]:
continue
for nb in topo1[node]:
if int(id1toh[nb]) not in topoh[id1toh[node]]:
if valid[nb]:
topoh[id1toh[node]].append(int(id1toh[nb]))
for node in range(len(topo2)):
if not marked2[node]:
continue
for nb in topo2[node]:
if int(id2toh[nb]) not in topoh[id2toh[node]]:
if marked2[nb]:
topoh[id2toh[node]].append(int(id2toh[nb]))
return vh, topoh
def visualize(dict):
# print(dict['keypoints0'].size(), flush=True)
def tint(img_tensor):
    # scale from [-1, 1] back to [0, 255], then lift towards a light blue tint
    img = ((img_tensor.permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()
    img[:, :, 0] += 255
    img[:, :, 1] += 180
    img[:, :, 2] += 180
    img[img > 255] = 255
    return img.astype(np.uint8)
img1, img1p = tint(dict['image0']), tint(dict['image0'])
img2, img2p = tint(dict['image1']), tint(dict['image1'])
# print(v2d1.shape, img1.shape, flush=True)
v2d1 = dict['keypoints0'].numpy().astype(int)
v2d2 = dict['keypoints1'].numpy().astype(int)
topo1 = dict['topo0']
topo2 = dict['topo1']
# print(topo1, flush=True)
# for node, nbs in enumerate(dict['topo0']):
# for nb in nbs:
# cv2.line(img1, [v2d1[node][0], v2d1[node][1]], [v2d1[nb][0], v2d1[nb][1]], [255, 180, 180], 2)
id1 = np.arange(len(v2d1))
id2 = np.arange(len(v2d2))
all_matches = dict['all_matches'].cpu().int().data.numpy()
predicted = dict['matches0'].cpu().data.numpy()[0]
predicted1 = dict['matches1'].cpu().data.numpy()[0]
colors1_gt, colors2_gt = {}, {}
colors1_pred, colors2_pred = {}, {}
cross1_pred, cross2_pred = {}, {}
for index in id1:
color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]
# print(predicted.shape, flush=True)
if all_matches[index] != -1:
colors2_gt[all_matches[index]] = color
if predicted[index] != -1:
colors2_pred[predicted[index]] = color
colors1_gt[index] = color if all_matches[index] != -1 else [0, 0, 0]
colors1_pred[index] = color if predicted[index] != -1 else [0, 0, 0]
# if predicted[index] == -1 and colors1_pred[index] != [0, 0, 0]:
# color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]
# colors1_pred[index] = [0, 0, 0]
# colors2_pred.pop(all_matches[index])
# whether predicted correctly
if predicted[index] != all_matches[index]:
cross1_pred[index] = True
if predicted[index] != -1:
cross2_pred[predicted[index]] = True
for i, p in enumerate(v2d1):
ii = id1[i]
# print(ii)
cv2.circle(img1, [int(p[0]), int(p[1])], 1, colors1_gt[i], 2)
if ii in cross1_pred and cross1_pred[ii]:
cv2.rectangle(img1p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors1_pred[i],-1)
else:
cv2.circle(img1p, [int(p[0]), int(p[1])], 1, colors1_pred[i], 2)
for ii in id2:
# print(ii)
color = [0, 0, 0]
if ii not in colors2_gt:
colors2_gt[ii] = color
if ii not in colors2_pred:
colors2_pred[ii] = color
for i, p in enumerate(v2d2):
ii = id2[i]
# print(p)
cv2.circle(img2, [int(p[0]), int( p[1])], 1, colors2_gt[ii], 2)
if ii in cross2_pred and cross2_pred[ii]:
cv2.rectangle(img2p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors2_pred[i], -1)
else:
cv2.circle(img2p, [int(p[0]), int(p[1])], 1, colors2_pred[i], 2)
# print('Unmatched in Img 2: ', , '%')
# unmatched_all.append(100 - unmatched * 100.0/len(v2d2))
cv2.putText(img2p, '{:.2f}%'.format(np.sum(all_matches == predicted) * 100.0 / len(predicted)), \
(500, 100), cv2.FONT_HERSHEY_PLAIN, 3, (0, 0, 255), 2)
vh_gt, topoh_gt = make_inter_graph(v2d1, v2d2, topo1, topo2, all_matches)
vh_pred, topoh_pred = make_inter_graph(v2d1, v2d2, topo1, topo2, predicted)
vh_gt_valid, topoh_gt_valid = make_inter_graph_valid(v2d1, v2d2, topo1, topo2, all_matches)
vh_pred_valid, topoh_pred_valid = make_inter_graph_valid(v2d1, v2d2, topo1, topo2, predicted)
v2d1t = ((v2d2[predicted] + v2d1) * 0.5).astype(int)
v2d2t = ((v2d1[predicted1] + v2d2) * 0.5).astype(int)
vh_gt = vh_gt.astype(int)
vh_gt_valid = vh_gt_valid.astype(int)
vh_pred = vh_pred.astype(int)
vh_pred_valid = vh_pred_valid.astype(int)
imgh = np.zeros_like(img1) + 255
imghp = np.zeros_like(img1) + 255
imgh_valid = np.zeros_like(img1) + 255
imghp_valid = np.zeros_like(img1) + 255
for node, nbs in enumerate(topoh_gt):
for nb in nbs:
cv2.line(imgh, [vh_gt[node][0], vh_gt[node][1]], [vh_gt[nb][0], vh_gt[nb][1]], [0, 0, 0], 2)
for node, nbs in enumerate(topoh_pred):
for nb in nbs:
cv2.line(imghp, [vh_pred[node][0], vh_pred[node][1]], [vh_pred[nb][0], vh_pred[nb][1]], [0, 0, 0], 2)
for node, nbs in enumerate(topoh_gt_valid):
for nb in nbs:
cv2.line(imgh_valid, [vh_gt_valid[node][0], vh_gt_valid[node][1]], [vh_gt_valid[nb][0], vh_gt_valid[nb][1]], [0, 0, 0], 2)
for node, nbs in enumerate(topoh_pred_valid):
for nb in nbs:
cv2.line(imghp_valid, [vh_pred_valid[node][0], vh_pred_valid[node][1]], [vh_pred_valid[nb][0], vh_pred_valid[nb][1]], [0, 0, 0], 2)
# for node, nbs in enumerate(topo1):
# for nb in nbs:
# cv2.line(imghp_valid, [v2d1t[node][0], v2d1t[node][1]], [v2d1t[nb][0], v2d1t[nb][1]], [0, 0, 0], 2)
# for node, nbs in enumerate(topo2):
# for nb in nbs:
# cv2.line(imghp_valid, [v2d2t[node][0], v2d2t[node][1]], [v2d2t[nb][0], v2d2t[nb][1]], [0, 0, 0], 2)
im_h = cv2.hconcat([img1, img2])
im_hp = cv2.hconcat([img1p, img2p])
img_inter = cv2.hconcat([imgh, imghp])
img_inter_valid = cv2.hconcat([imgh_valid, imghp_valid])
im_hv = cv2.vconcat([im_h, im_hp, img_inter, img_inter_valid])
return im_hv
================================================
FILE: corr/vtx_matching.py
================================================
""" This script handling the training process. """
import os
import time
import random
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from datasets import fetch_dataloader
import random
from utils.log import Logger
from torch.optim import *
import warnings
from tqdm import tqdm
import itertools
import pdb
import numpy as np
import models
import datetime
import sys
import json
import cv2
from utils.visualize_vtx_corr import visualize
import matplotlib.cm as cm
# from models.utils import make_matching_seg_plot
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import pdb
class VtxMat():
def __init__(self, args):
self.config = args
torch.backends.cudnn.benchmark = True
torch.multiprocessing.set_sharing_strategy('file_system')
self._build()
def train(self):
opt = self.config
print(opt)
model = self.model
if hasattr(self.config, 'init_weight'):
checkpoint = torch.load(self.config.init_weight)
model.load_state_dict(checkpoint['model'])
optimizer = self.optimizer
schedular = self.schedular
mean_loss = []
log = Logger(self.config, self.expdir)
updates = 0
# set seed
random.seed(opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)
np.random.seed(opt.seed)
# start training
for epoch in range(1, opt.epoch+1):
np.random.seed(opt.seed + epoch)
train_loader = self.train_loader
log.set_progress(epoch, len(train_loader))
batch_loss = 0
batch_acc = 0
batch_valid_acc = 0
batch_iter = 0
model.train()
avg_time = 0
avg_num = 0
# torch.cuda.synchronize()
for i, pred in enumerate(train_loader):
# tstart = time.time()
# print(pred['file_name'])
data = model(pred)
if not data['skip_train']:
loss = data['loss'] / opt.batch_size
batch_loss += loss.item()
batch_acc += data['accuracy']
batch_valid_acc += data['valid_accuracy']
loss.backward()
batch_iter += 1
else:
print('Skip!')
## Accumulate gradient for batch training
if ((i + 1) % opt.batch_size == 0) or (i + 1 == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
batch_iter = 1 if batch_iter == 0 else batch_iter
stats = {
'updates': updates,
'loss': batch_loss,
'accuracy': batch_acc / batch_iter,
'valid_accuracy': batch_valid_acc / batch_iter
}
log.update(stats)
updates += 1
batch_loss = 0
batch_acc = 0
batch_valid_acc = 0
batch_iter = 0
# torch.cuda.synchronize()
# avg_num += 1
# for name, params in model.named_parameters():
# print('-->name:, ', name, '-->grad mean', params.grad.mean())
# print("All time is ", avg_time, "AVG time is ", avg_time * 1.0 /avg_num, "number is ", avg_num, flush=True)
# save checkpoint
if epoch % opt.save_per_epochs == 0 or epoch == 1:
checkpoint = {
'model': model.state_dict(),
'config': opt,
'epoch': epoch
}
filename = os.path.join(self.ckptdir, f'epoch_{epoch}.pt')
torch.save(checkpoint, filename)
# validate
if epoch % opt.test_freq == 0:
if not os.path.exists(os.path.join(self.visdir, 'epoch' + str(epoch))):
os.mkdir(os.path.join(self.visdir, 'epoch' + str(epoch)))
eval_output_dir = os.path.join(self.visdir, 'epoch' + str(epoch))
test_loader = self.test_loader
with torch.no_grad():
# Visualize the matches.
mean_acc = []
mean_valid_acc = []
model.eval()
for i_eval, data in enumerate(tqdm(test_loader, desc='Predicting Vtx Corr...')):
pred = model(data)
# for k, v in data.items():
# pred[k] = v[0]
# pred = {**pred, **data}
mean_acc.append(pred['accuracy'])
mean_valid_acc.append(pred['valid_accuracy'])
log.log_eval({
'updates': opt.epoch,
'Accuracy': np.mean(mean_acc),
'Valid Accuracy': np.mean(mean_valid_acc),
})
print('Epoch [{}/{}], Acc.: {:.4f}, Valid Acc.: {:.4f}'
.format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_valid_acc)))
sys.stdout.flush()
# make_matching_plot(
# image0, image1, kpts0, kpts1, mkpts0, mkpts1, color,
# text, viz_path, stem, stem, True,
# True, False, 'Matches')
self.schedular.step()
def eval(self):
train_action = ['breakdance_1990', 'capoeira', 'chapa-giratoria', 'fist_fight', 'flying_knee', 'freehang_climb', 'running', 'shove', 'magic', 'tripping']
test_action = ['great_sword_slash', 'hip_hop_dancing']
train_model = ['ganfaul', 'girlscout', 'jolleen', 'kachujin', 'knight', 'maria_w_jj', 'michelle', 'peasant_girl', 'timmy', 'uriel_a_plotexia']
test_model = ['police', 'warrok']
log = Logger(self.config, self.expdir)
with torch.no_grad():
model = self.model.eval()
config = self.config
epoch_tested = self.config.testing.ckpt_epoch
ckpt_path = os.path.join(self.ckptdir, f"epoch_{epoch_tested}.pt")
# self.device = torch.device('cuda' if config.cuda else 'cpu')
print("Evaluation...")
checkpoint = torch.load(ckpt_path)
model.load_state_dict(checkpoint['model'])
model.eval()
if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested))):
os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested)))
if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons')):
os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons'))
eval_output_dir = os.path.join(self.evaldir, 'epoch' + str(epoch_tested))
test_loader = self.test_loader
print(len(test_loader))
mean_acc = []
mean_valid_acc = []
mean_invalid_acc = []
# 144 data
# 10x10 is for training , 2x10 (unseen model) + 10x2 (unseen action) + 2x2 (unseen model unseen action) is for test
# record the accuracy for each
mean_model_acc = []
mean_model_valid_acc = []
mean_action_acc = []
mean_action_valid_acc = []
mean_none_acc = []
mean_none_valid_acc = []
mean_matched = []
for i_eval, pred in enumerate(tqdm(test_loader, desc='Predicting Vtx Corr...')):
data = model(pred)
for k, v in pred.items():
pred[k] = v[0]
pred = {**pred, **data}
mean_acc.append(pred['accuracy'])
mean_valid_acc.append(pred['valid_accuracy'])
this_pred = (pred['matches0'] != -1).float().cpu().data.numpy().astype(np.float32)
mean_matched.append(np.mean( this_pred))
unmarked = True
for model_name in train_model:
if model_name in pred['file_name']:
mean_model_acc.append(pred['accuracy'])
mean_model_valid_acc.append(pred['valid_accuracy'])
unmarked = False
break
for action_name in train_action:
if action_name in pred['file_name']:
mean_action_acc.append(pred['accuracy'])
mean_action_valid_acc.append(pred['valid_accuracy'])
unmarked = False
break
if unmarked:
mean_none_acc.append(pred['accuracy'])
mean_none_valid_acc.append(pred['valid_accuracy'])
if 'invalid_accuracy' in pred and pred['invalid_accuracy'] is not None:
mean_invalid_acc.append(pred['invalid_accuracy'])
img_vis = visualize(pred)
cv2.imwrite(os.path.join(eval_output_dir, pred['file_name'].replace('/', '_') + '.jpg'), img_vis)
log.log_eval({
'updates': self.config.testing.ckpt_epoch,
'Accuracy': np.mean(mean_acc),
'Accuracy (Matched)': np.mean(mean_valid_acc),
'Unseen Action Accuracy': np.mean(mean_model_acc),
'Unseen Action Accuracy (Matched)': np.mean(mean_model_valid_acc),
'Unseen Model Accuracy': np.mean(mean_action_acc),
'Unseen Model Accuracy (Matched)': np.mean(mean_action_valid_acc),
'Unseen Both Accuracy': np.mean(mean_none_acc),
'Unseen Both Valid Accuracy': np.mean(mean_none_valid_acc),
'Matching Rate': np.mean(mean_matched)
})
# print ('Epoch [{}/{}]], Acc.: {:.4f}, Valid Acc.{:.4f}'
# .format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_valid_acc)) )
sys.stdout.flush()
def _build(self):
config = self.config
self.start_epoch = 0
self._dir_setting()
self._build_model()
if not(hasattr(config, 'need_not_train_data') and config.need_not_train_data):
self._build_train_loader()
if not(hasattr(config, 'need_not_test_data') and config.need_not_test_data):
self._build_test_loader()
self._build_optimizer()
def _build_model(self):
""" Define Model """
config = self.config
if hasattr(config.model, 'name'):
print(f'Experiment Using {config.model.name}')
model_class = getattr(models, config.model.name)
model = model_class(config.model)
else:
raise NotImplementedError("Wrong Model Selection")
model = nn.DataParallel(model)
self.model = model.cuda()
def _build_train_loader(self):
config = self.config
self.train_loader = fetch_dataloader(config.data.train, type='train')
def _build_test_loader(self):
config = self.config
self.test_loader = fetch_dataloader(config.data.test, type='test')
def _build_optimizer(self):
#model = nn.DataParallel(model).to(device)
config = self.config.optimizer
try:
optim = getattr(torch.optim, config.type)
except Exception:
raise NotImplementedError('not implemented optim method ' + config.type)
self.optimizer = optim(itertools.chain(self.model.module.parameters(),
),
**config.kwargs)
self.schedular = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, **config.schedular_kwargs)
def _dir_setting(self):
data = self.config.data
self.expname = self.config.expname
self.experiment_dir = os.path.join(".", "experiments")
self.expdir = os.path.join(self.experiment_dir, self.expname)
if not os.path.exists(self.expdir):
os.mkdir(self.expdir)
self.visdir = os.path.join(self.expdir, "vis") # -- imgs, videos, jsons
if not os.path.exists(self.visdir):
os.mkdir(self.visdir)
self.ckptdir = os.path.join(self.expdir, "ckpt")
if not os.path.exists(self.ckptdir):
os.mkdir(self.ckptdir)
self.evaldir = os.path.join(self.expdir, "eval")
if not os.path.exists(self.evaldir):
os.mkdir(self.evaldir)
# self.ckptdir = os.path.join(self.expdir, "ckpt")
# if not os.path.exists(self.ckptdir):
# os.mkdir(self.ckptdir)
================================================
FILE: data/README.md
================================================
================================================
FILE: datasets/__init__.py
================================================
from .ml_seq import fetch_dataloader
from .vd_seq import fetch_videoloader
__all__ = ['fetch_dataloader', 'fetch_videoloader']
================================================
FILE: datasets/ml_seq.py
================================================
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import os
import math
import random
from glob import glob
import os.path as osp
import sys
import argparse
import cv2
from collections import Counter
import time
import json
import sknetwork
from sknetwork.embedding import Spectral
import scipy
def read_json(file_path):
"""
input: json file path
output: 2d vertex, connections and vertex index in original 3D domain
"""
with open(file_path) as file:
data = json.load(file)
vertex2d = np.array(data['vertex location'])
topology = data['connection']
index = np.array(data['original index'])
return vertex2d, topology, index
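The label json format that `read_json` expects can be seen from a tiny hand-built example. The snippet below writes a hypothetical three-vertex label to a temp file and parses it with the same three-field logic as the function above (the ids 7/42/99 are made up):

```python
import json
import os
import tempfile

import numpy as np

# Hypothetical minimal label: 3 vertices forming a path, with their ids
# in the original 3D mesh.
label = {
    "vertex location": [[10, 20], [30, 40], [50, 60]],
    "connection": [[1], [0, 2], [1]],
    "original index": [7, 42, 99],
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(label, f)
    path = f.name

with open(path) as file:  # mirrors read_json above
    data = json.load(file)
vertex2d = np.array(data["vertex location"])  # (N, 2) pixel coordinates
topology = data["connection"]                 # adjacency list
index = np.array(data["original index"])      # ids shared across frames
os.remove(path)
```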
def matched_motion(v2d1, v2d2, match12, motion_pre=None):
motion = np.zeros_like(v2d1)
motion[match12 != -1] = v2d2[match12[match12 != -1]] - v2d1[match12 != -1]
if motion_pre is not None:
motion[match12 != -1] = motion[match12 != -1] + motion_pre[match12[match12 != -1]]
return motion
def unmatched_motion(topo1, v2d1, motion12, match12):
pos = np.arange(len(topo1))
masked = (match12 == -1)
round = 0
former_len = 0
while(len(pos[masked]) > 0):
this_len = len(pos[masked])
if former_len == this_len:
break
former_len = this_len
round += 1
for v in pos[masked]:
unmatched = masked[topo1[v]]
if unmatched.sum() != len(topo1[v]):
motion12[v] = np.average(motion12[topo1[v]][np.invert(unmatched)], axis=0)
masked[v] = False
if len(pos[masked]) > 0:
# find the nearest labeled point for each remaining unlabeled point
index = ((v2d1[pos[masked]][:, None, :] - v2d1[pos[np.invert(masked)]]) ** 2).sum(2).argmin(1)
motion12[pos[masked]] = motion12[pos[np.invert(masked)]][index]
masked[pos[masked]] = False
return motion12
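The two helpers above build per-vertex motion pseudo-labels: matched vertices take the displacement to their correspondent, and unmatched ones then inherit motion from matched neighbours (or, failing that, from the nearest labeled vertex). A standalone sketch of the matched step, mirroring `matched_motion` above without the chained `motion_pre` term; the coordinates and matches are made up:

```python
import numpy as np

def matched_motion(v2d1, v2d2, match12):
    # Mirrors matched_motion above (without the motion_pre chaining term):
    # every matched vertex moves to its correspondent in the next frame.
    motion = np.zeros_like(v2d1)
    valid = match12 != -1
    motion[valid] = v2d2[match12[valid]] - v2d1[valid]
    return motion

v2d1 = np.array([[0, 0], [10, 10], [20, 20]])
v2d2 = np.array([[1, 2], [12, 13]])
match12 = np.array([0, 1, -1])  # vertex 2 has no correspondent
m = matched_motion(v2d1, v2d2, match12)
```

In the full pipeline, `unmatched_motion` would then fill vertex 2's zero row by averaging the motion of its matched neighbours in `topo1`.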
def ids_to_mat(id1, id2):
"""
inputs are two list of vertex index in original 3D mesh
"""
corr1 = np.zeros(len(id1)) - 1.0
corr2 = np.zeros(len(id2)) - 1.0
id1 = np.array(id1).astype(int)[:, None]
id2 = np.array(id2).astype(int)
mat = (id1 == id2)
pos12 = np.arange(len(id2))[None].repeat(len(id1), 0)
pos21 = np.arange(len(id1))[None].repeat(len(id2), 0)
corr1[mat.astype(int).sum(1).astype(bool)] = pos12[mat]
corr2[mat.transpose().astype(int).sum(1).astype(bool)] = pos21[mat.transpose()]
return mat, corr1, corr2
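A small worked example of the id-matching above. The snippet restates `ids_to_mat` so it runs standalone (it assumes, as the original appears to, that ids are unique within each frame); the mesh ids 7/42/99/5 are made up:

```python
import numpy as np

def ids_to_mat(id1, id2):
    # Mirrors ids_to_mat above: vertices match when they share a 3D-mesh id.
    corr1 = np.zeros(len(id1)) - 1.0
    corr2 = np.zeros(len(id2)) - 1.0
    id1 = np.array(id1).astype(int)[:, None]
    id2 = np.array(id2).astype(int)
    mat = (id1 == id2)  # (len1, len2) boolean match matrix
    pos12 = np.arange(len(id2))[None].repeat(len(id1), 0)
    pos21 = np.arange(len(id1))[None].repeat(len(id2), 0)
    corr1[mat.sum(1).astype(bool)] = pos12[mat]
    corr2[mat.T.sum(1).astype(bool)] = pos21[mat.T]
    return mat, corr1, corr2

# Frame 1 keeps mesh vertices {7, 42, 99}; frame 2 sees {42, 5, 99, 7}.
mat, corr1, corr2 = ids_to_mat([7, 42, 99], [42, 5, 99, 7])
```

Vertex 0 of frame 1 (id 7) maps to slot 3 of frame 2, while frame 2's vertex 1 (id 5) has no correspondent and stays -1.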
def adj_matrix(topology):
"""
topology is the adj table; returns adj matrix
"""
gsize = len(topology)
adj = np.zeros((gsize, gsize)).astype(float)
for v in range(gsize):
adj[v][v] = 1.0
for nb in topology[v]:
adj[v][nb] = 1.0
adj[nb][v] = 1.0
return adj
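A quick check of the adjacency construction above on a three-vertex path (the snippet restates `adj_matrix` so it runs standalone):

```python
import numpy as np

def adj_matrix(topology):
    # Mirrors adj_matrix above: adjacency list -> symmetric 0/1 matrix
    # with self-loops on the diagonal.
    gsize = len(topology)
    adj = np.zeros((gsize, gsize), dtype=float)
    for v in range(gsize):
        adj[v, v] = 1.0
        for nb in topology[v]:
            adj[v, nb] = 1.0
            adj[nb, v] = 1.0
    return adj

A = adj_matrix([[1], [0, 2], [1]])  # a 3-vertex path 0-1-2
```

The result is symmetric even if the input list only records each edge once, since both directions are written explicitly.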
class MixamoLineArtMotionSequence(data.Dataset):
def __init__(self, root, gap=0, split='train', model=None, action=None, mode='train', use_vs=False, max_len=3050):
"""
input:
root: the root folder of the line art data
gap: number of frames between the two input frames; gap should be an odd number.
split: train or test
model: indicate a specific character (default None)
action: indicate a specific action (default None)
output:
image of sources (0, 1) and output (0.5)
topo0, topo1
v2d0, v2d1
corr12, corr21
motion0-->0.5, motion1-->0.5
visibility0-->0.5, visibility 1-->0.5
"""
super(MixamoLineArtMotionSequence, self).__init__()
self.gap = gap
if model == 'None':
model = None
if action == 'None':
action = None
assert(gap%2 != 0)
self.is_train = True if mode == 'train' else False
self.is_eval = True if mode == 'eval' else False
# self.is_train = False
self.max_len = max_len
self.image_list = []
self.label_list = []
label_root = osp.join(root, split, 'labels')
self.use_vs = False
if use_vs:
print('>>>>>>>> Using VS labels')
self.use_vs = True
label_root = osp.join(root, split, 'labels_vs')
image_root = osp.join(root, split, 'frames')
self.spectral = Spectral(64, normalized=False)
for clip in os.listdir(image_root):
skip = False
if model != None:
for mm in model:
if mm in clip:
skip = True
if action != None:
for aa in action:
if aa in clip:
skip = True
if skip:
continue
image_list = sorted(glob(osp.join(image_root, clip, '*.png')))
label_list = sorted(glob(osp.join(label_root, clip, '*.json')))
if len(image_list) != len(label_list):
print(clip, flush=True)
continue
for i in range(len(image_list) - (gap+1)):
self.image_list += [ [image_list[jj] for jj in range(i, i + gap + 2)] ]
for i in range(len(label_list) - (gap+1)):
self.label_list += [ [label_list[jj] for jj in range(i, i + gap + 2)] ]
# print(clip)
print('Len of Frame is ', len(self.image_list), flush=True)
print('Len of Label is ', len(self.label_list), flush=True)
def __getitem__(self, index):
# load image/label files
# load labels:
# (a) read json (b) load image (c) make pseudo labels
# image crop to a square (720x720) before input, 2d label same operation
# index to index matching
# test does not need index matching
index = index % len(self.image_list)
file_name = self.label_list[index][len(self.label_list[index])//2][:-4]
imgt = [cv2.imread(self.image_list[index][ii]) for ii in range(0, len(self.image_list[index]))]
labelt = []
for ii in range(0, len(self.label_list[index])):
v, t, id = read_json(self.label_list[index][ii])
v[v > imgt[0].shape[0] - 1] = imgt[0].shape[0] - 1
v[v < 0] = 0
labelt.append({'keypoints': v.astype(int), 'topo': t, 'id': id})
# make motion pseudo label
motion = None
motion01 = None
start_frame = 0
gap = self.gap // 2 + 1
######### forward direction
for ii in reversed(range(start_frame + 1, start_frame + 2*gap + 1)):
img1 = imgt[ii - 1]
img2 = imgt[ii]
v2d1 = labelt[ii - 1]['keypoints'].astype(int)
v2d2 = labelt[ii]['keypoints'].astype(int)
topo1 = labelt[ii - 1]['topo']
topo2 = labelt[ii ]['topo']
id1 = labelt[ii - 1]['id']
id2 = labelt[ii]['id']
if self.use_vs:
id1 = np.arange(len(id1))
id2 = np.arange(len(id2))
_, match12, match21 = ids_to_mat(id1, id2)
if ii <= start_frame + gap:
motion01 = matched_motion(v2d1, v2d2, match12.astype(int), motion01)
motion01 = unmatched_motion(topo1, v2d1, motion01, match12.astype(int))
motion = matched_motion(v2d1, v2d2, match12.astype(int), motion)
motion = unmatched_motion(topo1, v2d1, motion, match12.astype(int))
motion0 = motion.copy()
img2 = imgt[start_frame + gap]
v2d1 = labelt[start_frame]['keypoints'].astype(int)
source0_topo = labelt[start_frame]['topo']
target = cv2.erode(img2, np.ones((3, 3), np.uint8), iterations=1)
shift_plabel = v2d1 + motion01
visible = np.ones(len(v2d1)).astype(float)
visible[shift_plabel[:, 0] < 0] = 0
visible[shift_plabel[:, 0] >= imgt[0].shape[0]] = 0
visible[shift_plabel[:, 1] < 0] = 0
visible[shift_plabel[:, 1] >= imgt[0].shape[0]] = 0
# vertex visibility
visible[visible == 1] = (target[:, :, 0][shift_plabel[visible == 1][:, 1], shift_plabel[visible == 1][:, 0]] < 255 ).astype(float)
visible01 = visible.copy()
v2d1s = shift_plabel
# edge visibility
for node, nbs in enumerate(source0_topo):
for nb in nbs:
if visible01[nb] and visible01[node] and ((v2d1s[node] - v2d1s[nb]) ** 2).sum() / (((v2d1[node] - v2d1[nb]) ** 2).sum() + 1e-7) > 25:
visible01[nb] = False
visible01[node] = False
######## backward direction
motion = None
motion21 = None
for ii in range(start_frame + 1, start_frame + gap + gap + 1):
img2 = imgt[ii - 1]
img1 = imgt[ii]
v2d2 = labelt[ii - 1]['keypoints'].astype(int)
v2d1 = labelt[ii]['keypoints'].astype(int)
topo2 = labelt[ii - 1]['topo']
topo1 = labelt[ii ]['topo']
id1 = labelt[ii]['id']
id2 = labelt[ii - 1]['id']
if self.use_vs:
id1 = np.arange(len(id1))
id2 = np.arange(len(id2))
_, match12, _ = ids_to_mat(id1, id2)
if ii >= start_frame + gap + 1:
motion21 = matched_motion(v2d1, v2d2, match12.astype(int), motion21)
motion21 = unmatched_motion(topo1, v2d1, motion21, match12.astype(int))
motion = matched_motion(v2d1, v2d2, match12.astype(int), motion)
motion = unmatched_motion(topo1, v2d1, motion, match12.astype(int))
motion2 = motion.copy()
img1 = imgt[start_frame + 2*gap]
img2 = imgt[start_frame + gap]
v2d1 = labelt[start_frame + 2*gap]['keypoints'].astype(int)
source2_topo = labelt[start_frame + 2*gap]['topo']
shift_plabel = v2d1 + motion21
visible = np.ones(len(v2d1)).astype(float)
visible[shift_plabel[:, 0] < 0] = 0
visible[shift_plabel[:, 0] >= imgt[0].shape[0]] = 0
visible[shift_plabel[:, 1] < 0] = 0
visible[shift_plabel[:, 1] >= imgt[0].shape[0]] = 0
visible[visible == 1] = (target[:, :, 0][shift_plabel[visible == 1][:, 1], shift_plabel[visible == 1][:, 0]] < 255 ).astype(float)
visible21 = visible.copy()
v2d1s = shift_plabel
for node, nbs in enumerate(source2_topo):
for nb in nbs:
if visible21[nb] and visible21[node] and ((v2d1s[node] - v2d1s[nb]) ** 2).sum() / (((v2d1[node] - v2d1[nb]) ** 2).sum() + 1e-7) > 25:
visible21[nb] = False
visible21[node] = False
###### prepare other data
img2 = imgt[-1]
img1 = imgt[0]
v2d2 = labelt[-1]['keypoints'].astype(int)
v2d1 = labelt[0]['keypoints'].astype(int)
topo2 = labelt[-1]['topo']
topo1 = labelt[0]['topo']
m, n = len(v2d1), len(v2d2)
if len(img1.shape) == 2:
img1 = np.tile(img1[...,None], (1, 1, 3))
img2 = np.tile(img2[...,None], (1, 1, 3))
else:
img1 = img1[..., :3]
img2 = img2[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float() * 2 / 255.0 - 1.0
img2 = torch.from_numpy(img2).permute(2, 0, 1).float() * 2 / 255.0 - 1.0
imgt = torch.from_numpy(imgt[start_frame + gap]).permute(2, 0, 1).float() * 2 / 255.0 - 1.0
v2d1 = torch.from_numpy(v2d1)
v2d2 = torch.from_numpy(v2d2)
visible01 = torch.from_numpy(visible01)
visible21 = torch.from_numpy(visible21)
motion0 = torch.from_numpy(motion0)
motion2 = torch.from_numpy(motion2)
v2d1[v2d1 > imgt[0].shape[0] - 1 ] = imgt[0].shape[0] - 1
v2d1[v2d1 < 0] = 0
v2d2[v2d2 > imgt[0].shape[1] - 1] = imgt[0].shape[1] - 1
v2d2[v2d2 < 0] = 0
id1 = labelt[start_frame]['id']
id2 = labelt[-1]['id']
if self.use_vs:
id1 = np.arange(len(id1))
id2 = np.arange(len(id2))
mat_index, corr1, corr2 = ids_to_mat(id1, id2)
mat_index = torch.from_numpy(mat_index).float()
corr1 = torch.from_numpy(corr1).float()
corr2 = torch.from_numpy(corr2).float()
if self.is_train:
v2d1 = torch.nn.functional.pad(v2d1, (0, 0, 0, self.max_len - m), mode='constant', value=0)
v2d2 = torch.nn.functional.pad(v2d2, (0, 0, 0, self.max_len - n), mode='constant', value=0)
corr1 = torch.nn.functional.pad(corr1, (0, self.max_len - m), mode='constant', value=0)
corr2 = torch.nn.functional.pad(corr2, (0, self.max_len - n), mode='constant', value=0)
motion0 = torch.nn.functional.pad(motion0, (0, 0, 0, self.max_len - m), mode='constant', value=0)
motion2 = torch.nn.functional.pad(motion2, (0, 0, 0, self.max_len - n), mode='constant', value=0)
visible01 = torch.nn.functional.pad(visible01, (0, self.max_len - m), mode='constant', value=0)
visible21 = torch.nn.functional.pad(visible21, (0, self.max_len - n), mode='constant', value=0)
mask0, mask1 = torch.zeros(self.max_len).float(), torch.zeros(self.max_len).float()
mask0[:m] = 1
mask1[:n] = 1
else:
mask0, mask1 = torch.ones(m).float(), torch.ones(n).float()
for ii in range(len(topo1)):
# if not len(topo1[ii]):
topo1[ii].append(ii)
for ii in range(len(topo2)):
topo2[ii].append(ii)
adj1 = sknetwork.data.from_adjacency_list(topo1, matrix_only=True, reindex=False).toarray()
adj2 = sknetwork.data.from_adjacency_list(topo2, matrix_only=True, reindex=False).toarray()
try:
spec0, spec1 = np.abs(self.spectral.fit_transform(adj1)), np.abs(self.spectral.fit_transform(adj2))
except:
print('>>>>' + file_name, flush=True)
spec0, spec1 = np.zeros((len(adj1), 64)), np.zeros((len(adj2), 64))
# else:
# print('<<<<' + file_name, flush=True)
# adj2 = adj2 + np.eye(len(adj2))
if self.is_eval:
return{
'keypoints0': v2d1,
'keypoints1': v2d2,
'topo0': [topo1],
'topo1': [topo2],
# 'id0': id1,
# 'id1': id2,
'adj_mat0': adj1,
'adj_mat1': adj2,
'spec0': spec0,
'spec1': spec1,
'imaget': imgt,
'image0': img1,
'image1': img2,
'motion0': motion0,
'motion1': motion2,
'visibility0': visible01,
'visibility1': visible21,
'all_matches': corr1,
'm01': corr1,
'm10': corr2,
'ms': m,
'ns': n,
'mask0': mask0,
'mask1': mask1,
'file_name': file_name,
# 'with_match': True
}
elif not self.is_train:
return{
'keypoints0': v2d1,
'keypoints1': v2d2,
# 'topo0': [topo1],
# 'topo1': [topo2],
# 'id0': id1,
# 'id1': id2,
'adj_mat0': adj1,
'adj_mat1': adj2,
'spec0': spec0,
'spec1': spec1,
'imaget': imgt,
'image0': img1,
'image1': img2,
'motion0': motion0,
'motion1': motion2,
'visibility0': visible01,
'visibility1': visible21,
'all_matches': corr1,
'm01': corr1,
'm10': corr2,
'ms': m,
'ns': n,
'mask0': mask0,
'mask1': mask1,
'file_name': file_name,
# 'with_match': True
}
else:
return{
'keypoints0': v2d1,
'keypoints1': v2d2,
# 'topo0': topo1,
# 'topo1': topo2,
# 'id0': id1,
# 'id1': id2,
'adj_mat0': adj1,
'adj_mat1': adj2,
'spec0': spec0,
'spec1': spec1,
'imaget': imgt,
'motion0': motion0,
'motion1': motion2,
'visibility0': visible01,
'visibility1': visible21,
'image0': img1,
'image1': img2,
'all_matches': corr1,
'm01': corr1,
'm10': corr2,
'ms': m,
'ns': n,
'mask0': mask0,
'mask1': mask1,
'file_name': file_name,
# 'with_match': True
}
def __rmul__(self, v):
self.label_list = v * self.label_list
self.image_list = v * self.image_list
return self
def __len__(self):
return len(self.image_list)
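As a standalone illustration of the sample construction in `__init__` above: with `gap` dropped frames, each training sample spans `gap + 2` consecutive frames (the two key frames plus the in-betweens). A minimal sketch with hypothetical file names:

```python
def make_windows(files, gap):
    """Group `files` into overlapping windows of gap + 2 consecutive
    frames, mirroring the image/label list construction in __init__."""
    span = gap + 2  # two key frames plus gap skipped frames
    return [files[i:i + span] for i in range(len(files) - (gap + 1))]

frames = ['f0.png', 'f1.png', 'f2.png', 'f3.png', 'f4.png']
windows = make_windows(frames, gap=1)
# each window holds 3 frames: key frame, middle (ground truth), key frame
print(windows)
```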
def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
def fetch_dataloader(args, type='train',):
mode = args.mode if hasattr(args, 'mode') else 'train'
# VS labels only take effect outside training; the train loader always uses the default labels
use_vs = (args.use_vs if hasattr(args, 'use_vs') else False) and mode != 'train'
lineart = MixamoLineArtMotionSequence(root=args.root, gap=args.gap, split=args.type, model=args.model, action=args.action, mode=mode, use_vs=use_vs)
if mode == 'train':
loader = data.DataLoader(lineart, batch_size=args.batch_size,
pin_memory=True, shuffle=True, num_workers=16, drop_last=True, worker_init_fn=worker_init_fn)
else:
loader = data.DataLoader(lineart, batch_size=args.batch_size,
pin_memory=True, shuffle=False, num_workers=8)
return loader
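The train branch of `__getitem__` above zero-pads every variable-length vertex array to `self.max_len` and records a 0/1 validity mask so ragged samples can be collated into fixed-size batches. A torch-free numpy sketch of that padding scheme (sizes are hypothetical):

```python
import numpy as np

def pad_with_mask(points, max_len):
    """Zero-pad an (m, 2) point array to (max_len, 2) and return a
    mask that is 1 for real points and 0 for padding, mirroring the
    F.pad + mask construction in the train branch above."""
    m = len(points)
    padded = np.zeros((max_len, 2), dtype=points.dtype)
    padded[:m] = points
    mask = np.zeros(max_len, dtype=np.float32)
    mask[:m] = 1.0
    return padded, mask

pts = np.array([[3, 4], [5, 6], [7, 8]])
padded, mask = pad_with_mask(pts, max_len=5)
print(padded.shape, mask.tolist())
```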
================================================
FILE: datasets/vd_seq.py
================================================
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
# import networkx as nx
import os
import math
import random
from glob import glob
import os.path as osp
import sys
import argparse
import cv2
from collections import Counter
import time
import json
import sknetwork
from sknetwork.embedding import Spectral
import scipy
def read_json(file_path):
"""
input: json file path
output: 2d vertex
"""
with open(file_path) as file:
data = json.load(file)
vertex2d = np.array(data['vertex location'])
topology = data['connection']
index = np.array(data['original index'])
# index, vertex2d, topology = union_pixel(vertex2d, index, topology)
# index, vertex2d, topology = union_pixel2d(vertex2d, index, topology)
return vertex2d, topology, index
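`read_json` above expects label files with three top-level keys. A self-contained sketch of the schema (key names are the ones the code reads; the two-vertex values are invented, and `read_label` simply repeats the parsing logic for illustration):

```python
import json
import os
import tempfile
import numpy as np

def read_label(file_path):
    """Stand-in that repeats read_json's parsing logic."""
    with open(file_path) as f:
        data = json.load(f)
    return (np.array(data['vertex location']),
            data['connection'],
            np.array(data['original index']))

# hypothetical two-vertex line drawing in the expected label format
label = {
    'vertex location': [[10, 20], [30, 40]],  # (x, y) per vertex
    'connection': [[1], [0]],                 # adjacency lists
    'original index': [7, 9],                 # ids used for cross-frame matching
}
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump(label, f)
    path = f.name

v, topo, idx = read_label(path)
os.unlink(path)
print(v.shape, topo, idx.tolist())
```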
class VideoLinSeq(data.Dataset):
def __init__(self, root, split='train'):
"""
input:
root: the root folder of the line art data
split: split folder
output:
image of sources (0, 1) and output (0.5)
topo0, topo1
v2d0, v2d1
"""
super(VideoLinSeq, self).__init__()
self.image_list = []
self.label_list = []
label_root = osp.join(root, split, 'labels')
image_root = osp.join(root, split, 'frames')
self.spectral = Spectral(64, normalized=False)
for clip in os.listdir(image_root):
label_list = sorted(glob(osp.join(label_root, clip, '*.json')))
for i in range(len(label_list) - 1):
self.label_list += [ [label_list[jj] for jj in range(i, i + 2)] ]
self.image_list += [ [label_list[jj].replace('labels', 'frames').replace('.json', '.png') for jj in range(i, i + 2)] ]
# print(clip)
print('Len of Frame is ', len(self.image_list), flush=True)
print('Len of Label is ', len(self.label_list), flush=True)
def __getitem__(self, index):
# prepare images
index = index % len(self.image_list)
file_name0 = self.label_list[index][0][:-5].split('/')[-1]
file_name1 = self.label_list[index][-1][:-5].split('/')[-1]
folder0 = self.label_list[index][0][:-4].split('/')[-2]
folder1 = self.label_list[index][-1][:-4].split('/')[-2]
imgt = [cv2.imread(self.image_list[index][ii]) for ii in range(0, len(self.image_list[index]))]
labelt = []
for ii in range(0, len(self.label_list[index])):
v, t, id = read_json(self.label_list[index][ii])
v[v > imgt[0].shape[0] - 1] = imgt[0].shape[0] - 1
v[v < 0] = 0
labelt.append({'keypoints': v.astype(int), 'topo': t, 'id': id})
# make motion pseudo label
###### prepare other data
img2 = imgt[-1]
img1 = imgt[0]
v2d2 = labelt[-1]['keypoints'].astype(int)
v2d1 = labelt[0]['keypoints'].astype(int)
topo2 = labelt[-1]['topo']
topo1 = labelt[0]['topo']
m, n = len(v2d1), len(v2d2)
if len(img1.shape) == 2:
img1 = np.tile(img1[...,None], (1, 1, 3))
img2 = np.tile(img2[...,None], (1, 1, 3))
else:
img1 = img1[..., :3]
img2 = img2[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float() * 2 / 255.0 - 1.0
img2 = torch.from_numpy(img2).permute(2, 0, 1).float() * 2 / 255.0 - 1.0
v2d1 = torch.from_numpy(v2d1)
v2d2 = torch.from_numpy(v2d2)
mask0, mask1 = torch.ones(m).float(), torch.ones(n).float()
v2d1[v2d1 > imgt[0].shape[0] - 1 ] = imgt[0].shape[0] - 1
v2d1[v2d1 < 0] = 0
v2d2[v2d2 > imgt[0].shape[1] - 1] = imgt[0].shape[1] - 1
v2d2[v2d2 < 0] = 0
id1 = np.arange(len(v2d1))
id2 = np.arange(len(v2d2))
for ii in range(len(topo1)):
topo1[ii].append(ii)
for ii in range(len(topo2)):
topo2[ii].append(ii)
adj1 = sknetwork.data.from_adjacency_list(topo1, matrix_only=True, reindex=False).toarray()
adj2 = sknetwork.data.from_adjacency_list(topo2, matrix_only=True, reindex=False).toarray()
try:
spec0, spec1 = np.abs(self.spectral.fit_transform(adj1)), np.abs(self.spectral.fit_transform(adj2))
except:
print('>>>>' + file_name0, flush=True)
spec0, spec1 = np.zeros((len(adj1), 64)), np.zeros((len(adj2), 64))
return{
'keypoints0': v2d1,
'keypoints1': v2d2,
'topo0': [topo1],
'topo1': [topo2],
'adj_mat0': adj1,
'adj_mat1': adj2,
'spec0': spec0,
'spec1': spec1,
'image0': img1,
'image1': img2,
'ms': m,
'ns': n,
'mask0': mask0,
'mask1': mask1,
'gen_vid': True,
'file_name0': file_name0,
'file_name1': file_name1,
'folder_name0': folder0,
'folder_name1': folder1
}
def __rmul__(self, v):
self.label_list = v * self.label_list
self.image_list = v * self.image_list
return self
def __len__(self):
return len(self.image_list)
def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
def fetch_videoloader(args, type='train',):
lineart = VideoLinSeq(root=args.root, split=args.type, )
loader = data.DataLoader(lineart, batch_size=args.batch_size,
pin_memory=True, shuffle=False, num_workers=8)
return loader
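Both dataset classes append `i` to `topo[i]` (adding self-loops) before building an adjacency matrix and feeding it to `Spectral(64)`. sknetwork's `Spectral` assigns each vertex a vector derived from an eigendecomposition of the graph; as a rough numpy illustration of the idea (not sknetwork's exact pipeline) on a tiny path graph:

```python
import numpy as np

# adjacency of a 4-node path graph with self-loops (the code above
# appends i to topo[i], which puts ones on the diagonal)
adj = np.array([[1, 1, 0, 0],
                [1, 1, 1, 0],
                [0, 1, 1, 1],
                [0, 0, 1, 1]], dtype=float)

# eigenvectors of the symmetric adjacency give per-node coordinates
# that reflect graph structure; Spectral builds on the same family of
# decompositions, though its normalization differs
vals, vecs = np.linalg.eigh(adj)
k = 2  # the dataset keeps 64 dimensions; 2 is enough for a sketch
embedding = np.abs(vecs[:, -k:])  # abs(), as applied in the code above
print(embedding.shape)  # one k-dim vector per node
```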
================================================
FILE: download.sh
================================================
cd data
gdown 1SNRGajIECxNwRp6ZJ0IlY7AEl2mRm2DR
unzip ml240data.zip
================================================
FILE: experiments/inbetweener_full/ckpt/.gitkeep
================================================
================================================
FILE: inbetween.py
================================================
""" This script handling the training process. """
import os
import time
import random
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from datasets import fetch_dataloader
from datasets import fetch_videoloader
from utils.log import Logger
from torch.optim import *
import warnings
from tqdm import tqdm
import itertools
import pdb
import numpy as np
import models
import datetime
import sys
import json
import cv2
from utils.visualize_inbetween3 import visualize
# from utils.visualize_inbetween import visualize
from utils.visualize_video import visvid as visgen
import matplotlib.cm as cm
# from models.utils import make_matching_seg_plot
warnings.filterwarnings('ignore')
# a, b, c, d = check_data_distribution('/mnt/lustre/lisiyao1/dance/dance2/DanceRevolution/data/aistpp_train')
import matplotlib.pyplot as plt
class DraftRefine():
def __init__(self, args):
self.config = args
torch.backends.cudnn.benchmark = True
torch.multiprocessing.set_sharing_strategy('file_system')
self._build()
def train(self):
opt = self.config
print(opt)
# store viz results
# eval_output_dir = Path(self.expdir)
# eval_output_dir.mkdir(exist_ok=True, parents=True)
# print('Will write visualization images to',
# 'directory \"{}\"'.format(eval_output_dir))
# load training data
model = self.model
checkpoint = torch.load(self.config.corr_weights)
state_dict = {k.replace('module.', ''): checkpoint['model'][k] for k in checkpoint['model']}
model.module.corr.load_state_dict(state_dict)
if hasattr(self.config, 'init_weight'):
checkpoint = torch.load(self.config.init_weight)
model.load_state_dict(checkpoint['model'])
# if torch.cuda.is_available():
# model.cuda() # make sure it trains on GPU
# else:
# print("### CUDA not available ###")
# return
optimizer = self.optimizer
schedular = self.schedular
mean_loss = []
log = Logger(self.config, self.expdir)
updates = 0
# set seed
random.seed(opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)
np.random.seed(opt.seed)
# print(opt.seed)
# start training
for epoch in range(1, opt.epoch+1):
np.random.seed(opt.seed + epoch)
train_loader = self.train_loader
log.set_progress(epoch, len(train_loader))
batch_loss = 0
batch_epe = 0
batch_acc = 0
batch_iter = 0
model.train()
avg_time = 0
avg_num = 0
# torch.cuda.synchronize()
for i, data in enumerate(train_loader):
pred = model(data)
if True:
loss = pred['loss'].mean()
# print(loss.item(), opt.batch_size)
batch_loss += loss.item() / opt.batch_size
batch_acc += pred['Visibility Acc'].mean().item() / opt.batch_size
batch_epe += pred['EPE'].mean().item() / opt.batch_size
loss.backward()
batch_iter += 1
else:
print('Skip!')
if ((i + 1) % opt.batch_size == 0) or (i + 1 == len(train_loader)):
optimizer.step()
optimizer.zero_grad()
batch_iter = 1 if batch_iter == 0 else batch_iter
stats = {
'updates': updates,
'loss': batch_loss,
'accuracy': batch_acc,
'EPE': batch_epe
}
log.update(stats)
updates += 1
batch_loss = 0
batch_acc = 0
batch_epe = 0
batch_iter = 0
# tend = time.time()
# avg_time = (tend - tstart)
# print('Time is ', avg_time)
# torch.cuda.synchronize()
# avg_num += 1
# for name, params in model.named_parameters():
# print('-->name:, ', name, '-->grad mean', params.grad.mean())
# print("All time is ", avg_time, "AVG time is ", avg_time * 1.0 /avg_num, "number is ", avg_num, flush=True)
# save checkpoint
if epoch % opt.save_per_epochs == 0 or epoch == 1:
checkpoint = {
'model': model.state_dict(),
'config': opt,
'epoch': epoch
}
filename = os.path.join(self.ckptdir, f'epoch_{epoch}.pt')
torch.save(checkpoint, filename)
# validate
if epoch % opt.test_freq == 0:
if not os.path.exists(os.path.join(self.visdir, 'epoch' + str(epoch))):
os.mkdir(os.path.join(self.visdir, 'epoch' + str(epoch)))
eval_output_dir = os.path.join(self.visdir, 'epoch' + str(epoch))
test_loader = self.test_loader
with torch.no_grad():
# Visualize the matches.
mean_acc = []
mean_epe = []
model.eval()
for i_eval, data in enumerate(tqdm(test_loader, desc='Refining motion and visibility...')):
pred = model(data)
# for k, v in data.items():
# pred[k] = v[0]
# pred = {**pred, **data}
mean_acc.append(pred['Visibility Acc'].mean().item())
mean_epe.append(pred['EPE'].mean().item())
log.log_eval({
'updates': opt.epoch,
'Visibility Accuracy': np.mean(mean_acc),
'EPE': np.mean(mean_epe),
})
print('Epoch [{}/{}], Vis Acc.: {:.4f}, EPE: {:.4f}'
.format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_epe)) )
sys.stdout.flush()
# make_matching_plot(
# image0, image1, kpts0, kpts1, mkpts0, mkpts1, color,
# text, viz_path, stem, stem, True,
# True, False, 'Matches')
self.schedular.step()
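The inner loop of `train()` calls `loss.backward()` on every sample but only steps the optimizer every `opt.batch_size` iterations, i.e. gradient accumulation with an effective batch of `batch_size` samples. The control flow, reduced to a torch-free numeric sketch (the gradient values are hypothetical):

```python
def accumulate_and_step(grads, batch_size, lr=0.1, w=0.0):
    """One SGD step per `batch_size` accumulated gradients, mirroring
    the (i + 1) % batch_size == 0 pattern in train() above."""
    acc, history = 0.0, []
    for i, g in enumerate(grads):
        acc += g                       # loss.backward() adds into .grad
        if (i + 1) % batch_size == 0 or i + 1 == len(grads):
            w -= lr * acc              # optimizer.step()
            acc = 0.0                  # optimizer.zero_grad()
            history.append(w)
    return history

print(accumulate_and_step([1.0, 1.0, 2.0], batch_size=2))
```

The trailing `or i + 1 == len(grads)` flushes a partial batch at the end of the epoch, exactly as the original loop does with `i + 1 == len(train_loader)`.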
def eval(self):
train_action = ['breakdance_1990', 'capoeira', 'chapa-giratoria', 'fist_fight', 'flying_knee', 'freehang_climb', 'running', 'shove', 'magic', 'tripping']
test_action = ['great_sword_slash', 'hip_hop_dancing']
train_model = ['ganfaul', 'girlscout', 'jolleen', 'kachujin', 'knight', 'maria_w_jj', 'michelle', 'peasant_girl', 'timmy', 'uriel_a_plotexia']
test_model = ['police', 'warrok']
config = self.config
if not os.path.exists(config.imwrite_dir):
os.mkdir(config.imwrite_dir)
log = Logger(self.config, self.expdir)
with torch.no_grad():
model = self.model.eval()
config = self.config
epoch_tested = self.config.testing.ckpt_epoch
if epoch_tested == 0 or epoch_tested == '0':
checkpoint = torch.load(self.config.corr_weights)
state_dict = {k.replace('module.', ''): checkpoint['model'][k] for k in checkpoint['model']}
model.module.corr.load_state_dict(state_dict)
else:
ckpt_path = os.path.join(self.ckptdir, f"epoch_{epoch_tested}.pt")
# self.device = torch.device('cuda' if config.cuda else 'cpu')
print("Evaluation...")
checkpoint = torch.load(ckpt_path)
model.load_state_dict(checkpoint['model'])
model.eval()
if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested))):
os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested)))
if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons')):
os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons'))
eval_output_dir = os.path.join(self.evaldir, 'epoch' + str(epoch_tested))
test_loader = self.test_loader
print(len(test_loader))
mean_acc = []
mean_valid_acc = []
mean_invalid_acc = []
# of the 144 (12 models x 12 actions) clips, 10x10 are for training; 2x10 (unseen model) + 10x2 (unseen action) + 2x2 (unseen model, unseen action) are for test
# record the accuracy for each test subset
mean_model_acc = []
mean_model_epe = []
mean_action_acc = []
mean_action_epe = []
mean_none_acc = []
mean_none_epe = []
mean_acc = []
mean_epe = []
mean_cd = []
model.eval()
# for i_eval, data in enumerate(tqdm(test_loader, desc='Refining motion and visibility...')):
# pred = model(data)
# # for k, v in data.items():
# # pred[k] = v[0]
# # pred = {**pred, **data}
# mean_acc.append(pred['Visibility Acc'].mean().item())
# mean_epe.append(pred['EPE'].mean().item())
# log.log_eval({
# 'updates': opt.epoch,
# 'Visibility Accuracy': np.mean(mean_acc),
# 'EPE': np.mean(mean_epe),
# })
for i_eval, data in enumerate(tqdm(test_loader, desc='Predicting Vtx Corr...')):
# if i_eval == 34:
# continue
pred = model(data)
for k, v in pred.items():
# print(k, flush=True)
pred[k] = v
pred = {**pred, **data}
mean_acc.append(pred['Visibility Acc'].mean().item())
mean_epe.append(pred['EPE'].mean().item())
unmarked = True
for model_name in train_model:
if model_name in pred['file_name']:
mean_model_acc.append(pred['Visibility Acc'])
mean_model_epe.append(pred['EPE'])
unmarked = False
break
for action_name in train_action:
if action_name in pred['file_name']:
mean_action_acc.append(pred['Visibility Acc'])
mean_action_epe.append(pred['EPE'])
unmarked = False
break
if unmarked:
mean_none_acc.append(pred['Visibility Acc'])
mean_none_epe.append(pred['EPE'])
# if 'invalid_accuracy' in pred and pred['invalid_accuracy'] is not None:
# mean_invalid_acc.append(pred['invalid_accuracy'])
img_vis = visualize(pred)
# mean_cd.append(cd.item())
file_name = pred['file_name'][0].split('/')
cv2.imwrite(os.path.join(config.imwrite_dir, (file_name[-2] + '_' + file_name[-1]) + 'png'), img_vis)
# cv2.imwrite(os.path.join(eval_output_dir, pred['file_name'][0].replace('/', '_') + '.jpg'), img_vis)
log.log_eval({
'updates': self.config.testing.ckpt_epoch,
# 'mean CD': np.mean(mean_cd),
# 'Visibility Accuracy': np.mean(mean_acc),
# 'EPE': np.mean(mean_epe),
# 'Unseen Action Accuracy': np.mean(mean_model_acc),
# 'Unseen Action EPE': np.mean(mean_model_epe),
# 'Unseen Model Accuracy': np.mean(mean_action_acc),
# 'Unseen Model EPE': np.mean(mean_action_epe),
# 'Unseen Both Accuracy': np.mean(mean_none_acc),
# 'Unseen Both Valid Accuracy': np.mean(mean_none_epe)
})
# print ('Epoch [{}/{}]], Acc.: {:.4f}, Valid Acc.{:.4f}'
# .format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_valid_acc)) )
sys.stdout.flush()
def gen(self):
log = Logger(self.config, self.viddir)
with torch.no_grad():
model = self.model.eval()
config = self.config
epoch_tested = self.config.testing.ckpt_epoch
if epoch_tested == 0 or epoch_tested == '0':
checkpoint = torch.load(self.config.corr_weights)
state_dict = {k.replace('module.', ''): checkpoint['model'][k] for k in checkpoint['model']}
model.module.corr.load_state_dict(state_dict)
else:
ckpt_path = os.path.join(self.ckptdir, f"epoch_{epoch_tested}.pt")
# self.device = torch.device('cuda' if config.cuda else 'cpu')
print("Evaluation...")
checkpoint = torch.load(ckpt_path)
model.load_state_dict(checkpoint['model'])
model.eval()
if not os.path.exists(os.path.join(self.viddir, 'epoch' + str(epoch_tested))):
os.mkdir(os.path.join(self.viddir, 'epoch' + str(epoch_tested)))
if not os.path.exists(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'frames')):
os.mkdir(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'frames'))
if not os.path.exists(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'videos')):
os.mkdir(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'videos'))
gen_frame_dir = os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'frames')
gen_video_dir = os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'videos')
vid_loader = self.vid_loader
print(len(vid_loader))
mean_acc = []
mean_valid_acc = []
mean_invalid_acc = []
model.eval()
for i_eval, data in enumerate(tqdm(vid_loader, desc='Gen Video...')):
pred = model(data)
for k, v in pred.items():
pred[k] = v
pred = {**pred, **data}
img_vis = visgen(pred, config.inter_frames)
if not os.path.exists(os.path.join(gen_frame_dir, pred['folder_name0'][0])):
os.mkdir(os.path.join(gen_frame_dir, pred['folder_name0'][0]))
cv2.imwrite(os.path.join(gen_frame_dir, pred['folder_name0'][0], pred['file_name0'][0] + '_000.jpg'),img_vis[0])
for tt in range(config.inter_frames):
cv2.imwrite(os.path.join(gen_frame_dir, pred['folder_name0'][0], pred['file_name0'][0] + '_' + '{:03d}'.format(tt + 1) + '.jpg'), img_vis[tt + 1])
cv2.imwrite(os.path.join(gen_frame_dir, pred['folder_name0'][0], pred['file_name1'][0] + '_000.jpg'),img_vis[-1])
for ff in os.listdir(gen_frame_dir):
frame_dir = os.path.join(gen_frame_dir, ff)
video_file = os.path.join(gen_video_dir, f"{ff}.mp4")
cmd = f"ffmpeg -r {config.fps} -pattern_type glob -i '{frame_dir}/*.jpg' -vb 20M -vcodec mpeg4 -y '{video_file}'"
print(cmd, flush=True)
os.system(cmd)
log.log_eval({
'updates': self.config.testing.ckpt_epoch,
})
sys.stdout.flush()
def _build(self):
config = self.config
self.start_epoch = 0
self._dir_setting()
self._build_model()
if not(hasattr(config, 'need_not_train_data') and config.need_not_train_data):
self._build_train_loader()
if not(hasattr(config, 'need_not_test_data') and config.need_not_test_data):
self._build_test_loader()
if hasattr(config, 'gen_video') and config.gen_video:
self._build_video_loader()
self._build_optimizer()
def _build_model(self):
""" Define Model """
config = self.config
if hasattr(config.model, 'name'):
print(f'Experiment Using {config.model.name}')
model_class = getattr(models, config.model.name)
model = model_class(config.model)
else:
raise NotImplementedError("Wrong Model Selection")
model = nn.DataParallel(model)
self.model = model.cuda()
def _build_train_loader(self):
config = self.config
self.train_loader = fetch_dataloader(config.data.train, type='train')
def _build_test_loader(self):
config = self.config
self.test_loader = fetch_dataloader(config.data.test, type='test')
def _build_video_loader(self):
config = self.config
self.vid_loader = fetch_videoloader(config.video)
def _build_optimizer(self):
#model = nn.DataParallel(model).to(device)
config = self.config.optimizer
try:
optim = getattr(torch.optim, config.type)
except Exception:
raise NotImplementedError('not implemented optim method ' + config.type)
self.optimizer = optim(self.model.module.parameters(), **config.kwargs)
self.schedular = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, **config.schedular_kwargs)
def _dir_setting(self):
self.expname = self.config.expname
# self.experiment_dir = os.path.join("/mnt/cache/syli/inbetween", "experiments")
self.experiment_dir = 'experiments'
self.expdir = os.path.join(self.experiment_dir, self.expname)
if not os.path.exists(self.expdir):
os.mkdir(self.expdir)
self.visdir = os.path.join(self.expdir, "vis") # -- imgs, videos, jsons
if not os.path.exists(self.visdir):
os.mkdir(self.visdir)
self.ckptdir = os.path.join(self.expdir, "ckpt")
if not os.path.exists(self.ckptdir):
os.mkdir(self.ckptdir)
self.evaldir = os.path.join(self.expdir, "eval")
if not os.path.exists(self.evaldir):
os.mkdir(self.evaldir)
self.viddir = os.path.join(self.expdir, "video")
if not os.path.exists(self.viddir):
os.mkdir(self.viddir)
# self.ckptdir = os.path.join(self.expdir, "ckpt")
# if not os.path.exists(self.ckptdir):
# os.mkdir(self.ckptdir)
================================================
FILE: inbetween_results/.gitkeep
================================================
================================================
FILE: main.py
================================================
from inbetween import DraftRefine
import argparse
import os
import yaml
from pprint import pprint
from easydict import EasyDict
def parse_args():
parser = argparse.ArgumentParser(
description='Anime segment matching')
parser.add_argument('--config', default='')
# exclusive arguments
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--train', action='store_true')
group.add_argument('--eval', action='store_true')
group.add_argument('--gen', action='store_true')
return parser.parse_args()
def main():
# parse arguments and load config
args = parse_args()
with open(args.config) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
for k, v in vars(args).items():
config[k] = v
pprint(config)
config = EasyDict(config)
agent = DraftRefine(config)
print(config)
if args.train:
agent.train()
elif args.eval:
agent.eval()
elif args.gen:
agent.gen()
if __name__ == '__main__':
main()
================================================
FILE: models/__init__.py
================================================
# from .transformer_refiner import Refiner
# from .inbetweener import Inbetweener
# from .inbetweener_with_mask import InbetweenerM
# from .inbetweener_wo_rp import InbetweenerM as InbetweenerNRP
from .inbetweener_with_mask_with_spec import InbetweenerTM
# from .inbetweener_with_mask_with_spec_wo_OT import InbetweenerTMwoOT
from .inbetweener_with_mask2 import InbetweenerM as InbetweenerM2
# from .inbetweener_with_mask_wo_pos import InbetweenerNP
# from .inbetweener_with_mask_wo_pos_wo_spec import InbetweenerNPS
# from .transformer_refiner2 import Refiner as Refiner2
# from .transformer_refiner3 import Refiner as Refiner3
# from .transformer_refiner4 import Refiner as Refiner4
# from .transformer_refiner5 import Refiner as Refiner5
# from .transformer_refiner_norm import Refiner as RefinerN
__all__ = [ 'InbetweenerTM', 'InbetweenerM2']
================================================
FILE: models/inbetweener_with_mask2.py
================================================
from copy import deepcopy
from pathlib import Path
import torch
from torch import nn
# from seg_desc import seg_descriptor
import argparse
import torch.nn.functional as F
def MLP(channels: list, do_bn=True):
""" Multi-layer perceptron """
n = len(channels)
layers = []
for i in range(1, n):
layers.append(
nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
if i < (n-1):
if do_bn:
# layers.append(nn.BatchNorm1d(channels[i]))
layers.append(nn.InstanceNorm1d(channels[i]))
layers.append(nn.ReLU())
return nn.Sequential(*layers)
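`MLP` above stacks `Conv1d` layers with `kernel_size=1`, which act as one shared linear layer applied independently to every point in the sequence. The same shape contract in plain numpy (random weights, purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
B, C_in, C_out, N = 1, 2, 8, 5       # batch, in/out channels, num points
x = rng.normal(size=(B, C_in, N))    # same (B, C, N) layout Conv1d uses
W = rng.normal(size=(C_out, C_in))   # one weight matrix shared by all N points

# a kernel_size=1 Conv1d (bias omitted) is a per-point matmul over channels
y = np.einsum('oc,bcn->bon', W, x)
print(y.shape)  # channels mapped point-by-point, N unchanged
```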
def normalize_keypoints(kpts, image_shape):
""" Normalize keypoints locations based on image image_shape"""
_, _, height, width = image_shape
one = kpts.new_tensor(1)
size = torch.stack([one*width, one*height])[None]
center = size / 2
scaling = size.max(1, keepdim=True).values * 0.7
return (kpts - center[:, None, :]) / scaling[:, None, :]
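`normalize_keypoints` maps pixel coordinates to a centered range: it subtracts the image center and divides by 0.7 times the larger image side. A torch-free numpy restatement of the same arithmetic:

```python
import numpy as np

def normalize_np(kpts, height, width):
    """Numpy version of normalize_keypoints' arithmetic above."""
    size = np.array([width, height], dtype=float)
    center = size / 2
    scaling = size.max() * 0.7
    return (kpts - center) / scaling

kpts = np.array([[0.0, 0.0], [640.0, 480.0], [320.0, 240.0]])
out = normalize_np(kpts, height=480, width=640)
# the image center maps to (0, 0); corners land at roughly +/- 0.7
print(out)
```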
class ThreeLayerEncoder(nn.Module):
""" Joint encoding of visual appearance and location using MLPs"""
def __init__(self, enc_dim):
super().__init__()
# input must be 3 channel (r, g, b)
self.layer1 = nn.Conv2d(3, enc_dim//4, 7, padding=3)
self.non_linear1 = nn.ReLU()
self.layer2 = nn.Conv2d(enc_dim//4, enc_dim//2, 3, padding=1)
self.non_linear2 = nn.ReLU()
self.layer3 = nn.Conv2d(enc_dim//2, enc_dim, 3, padding=1)
self.norm1 = nn.InstanceNorm2d(enc_dim//4)
self.norm2 = nn.InstanceNorm2d(enc_dim//2)
self.norm3 = nn.InstanceNorm2d(enc_dim)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.constant_(m.bias, 0.0)
def forward(self, img):
x = self.non_linear1(self.norm1(self.layer1(img)))
x = self.non_linear2(self.norm2(self.layer2(x)))
x = self.norm3(self.layer3(x))
# x = self.non_linear1(self.layer1(img))
# x = self.non_linear2(self.layer2(x))
# x = self.layer3(x)
return x
class VertexDescriptor(nn.Module):
""" Per-vertex visual descriptors sampled from the encoded image"""
def __init__(self, enc_dim):
super().__init__()
self.encoder = ThreeLayerEncoder(enc_dim)
# self.super_pixel_pooling =
# use scatter
# nn.init.constant_(self.encoder[-1].bias, 0.0)
def forward(self, img, vtx):
x = self.encoder(img)
n, c, h, w = x.size()
assert((h, w) == img.size()[2:4])
return x[:, :, torch.round(vtx[0, :, 1]).long(), torch.round(vtx[0, :, 0]).long()]
# return super_pixel_pooling(x.view(n, c, -1), seg.view(-1).long(), reduce='mean')
# here return size is [1]xCx|Seg|
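The feature sampling in `VertexDescriptor.forward` amounts to indexing a C x H x W map at rounded vertex coordinates to get a C x |V| descriptor matrix; a minimal list-based sketch (helper name ours):

```python
def sample_at_vertices(feat, vtx):
    """Plain-Python sketch of VertexDescriptor.forward above: index a
    C x H x W feature map (nested lists) at rounded (x, y) vertex
    positions, yielding one descriptor column per vertex."""
    return [[channel[round(y)][round(x)] for (x, y) in vtx]
            for channel in feat]
```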
class KeypointEncoder(nn.Module):
""" Encoding of keypoint locations using an MLP"""
def __init__(self, feature_dim, layers):
super().__init__()
self.encoder = MLP([2] + layers + [feature_dim])
# for m in self.encoder.modules():
# if isinstance(m, nn.Conv2d):
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# nn.init.constant_(m.bias, 0.0)
nn.init.constant_(self.encoder[-1].bias, 0.0)
def forward(self, kpts):
inputs = kpts.transpose(1, 2)
# print(inputs.size(), 'wula!')
x = self.encoder(inputs)
# print(x.size())
return x
def attention(query, key, value, mask=None):
dim = query.shape[1]
scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
if mask is not None:
# print(mask, flush=True)
scores = scores.masked_fill(mask==0, float('-inf'))
# att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
# att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
# att = F.softmax(att, dim=-1)
prob = torch.nn.functional.softmax(scores, dim=-1)
# print(scores[1][1], prob[1][1], flush=True)
# while True:
# pass
# prob = torch.exp(scores) /((torch.sum(torch.exp(scores), dim=-1)[:, :, :, None]) + 1e-7)
return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
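The `attention` helper above is standard scaled dot-product attention with optional `-inf` masking before the softmax; a single-query, single-head plain-Python sketch of the same steps (helper name ours):

```python
import math

def scaled_dot_attention(query, key, value, mask=None):
    """Plain-Python sketch of attention() above for one query vector:
    scores = q . k / sqrt(d), masked positions set to -inf, softmax,
    then a probability-weighted sum of the value vectors."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, kv)) / math.sqrt(d)
              for kv in key]
    if mask is not None:
        scores = [s if m else float('-inf') for s, m in zip(scores, mask)]
    mx = max(scores)  # subtract max for numerical stability
    exps = [math.exp(s - mx) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    out = [sum(p * v[i] for p, v in zip(probs, value))
           for i in range(len(value[0]))]
    return out, probs
```

Masked-out positions get zero probability, mirroring the `masked_fill(mask==0, -inf)` used in the torch version.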
class MultiHeadedAttention(nn.Module):
""" Multi-head attention to increase model expressivity """
def __init__(self, num_heads: int, d_model: int):
super().__init__()
assert d_model % num_heads == 0
self.dim = d_model // num_heads
self.num_heads = num_heads
self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
def forward(self, query, key, value, mask=None):
batch_dim = query.size(0)
query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
for l, x in zip(self.proj, (query, key, value))]
x, prob = attention(query, key, value, mask)
# self.prob.append(prob)
return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))
class AttentionalPropagation(nn.Module):
def __init__(self, feature_dim: int, num_heads: int):
super().__init__()
self.attn = MultiHeadedAttention(num_heads, feature_dim)
self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
nn.init.constant_(self.mlp[-1].bias, 0.0)
def forward(self, x, source, mask=None):
message = self.attn(x, source, source, mask)
return self.mlp(torch.cat([x, message], dim=1))
class AttentionalGNN(nn.Module):
def __init__(self, feature_dim: int, layer_names: list):
super().__init__()
self.layers = nn.ModuleList([
AttentionalPropagation(feature_dim, 4)
for _ in range(len(layer_names))])
self.names = layer_names
def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None, mask10=None):
for layer, name in zip(self.layers, self.names):
layer.attn.prob = []
if name == 'cross':
src0, src1 = desc1, desc0
mask0, mask1 = mask01[:, None], mask10[:, None]
else: # if name == 'self':
src0, src1 = desc0, desc1
mask0, mask1 = mask00[:, None], mask11[:, None]
delta0, delta1 = layer(desc0, src0, mask0), layer(desc1, src1, mask1)
desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
return desc0, desc1
def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
""" Perform Sinkhorn Normalization in Log-space for stability"""
u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
for _ in range(iters):
u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
return Z + u.unsqueeze(2) + v.unsqueeze(1)
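The log-space Sinkhorn loop above alternately enforces the row and column marginals of the transport plan; a torch-free sketch for a single score matrix (function names ours), which converges on a toy example:

```python
import math

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_sinkhorn(Z, log_mu, log_nu, iters):
    """Plain-Python sketch of log_sinkhorn_iterations above for one
    m x n matrix: alternately rescale rows and columns in log-space so
    exp(Z) approaches row sums exp(log_mu) and column sums exp(log_nu)."""
    m, n = len(Z), len(Z[0])
    u, v = [0.0] * m, [0.0] * n
    for _ in range(iters):
        u = [log_mu[i] - logsumexp([Z[i][j] + v[j] for j in range(n)])
             for i in range(m)]
        v = [log_nu[j] - logsumexp([Z[i][j] + u[i] for i in range(m)])
             for j in range(n)]
    return [[Z[i][j] + u[i] + v[j] for j in range(n)] for i in range(m)]
```

Working in log-space avoids the under/overflow that plagues the naive multiplicative Sinkhorn updates.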
def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
""" Perform Differentiable Optimal Transport in Log-space for stability"""
b, m, n = scores.shape
one = scores.new_tensor(1)
if ms is None or ns is None:
ms, ns = (m*one).to(scores), (n*one).to(scores)
# else:
# ms, ns = ms.to(scores)[:, None], ns.to(scores)[:, None]
# here m,n should be parameters not shape
# ms, ns: (b, )
bins0 = alpha.expand(b, m, 1)
bins1 = alpha.expand(b, 1, n)
alpha = alpha.expand(b, 1, 1)
# pad an additional dustbin row/column for unmatched points
# alpha is the learned threshold
couplings = torch.cat([torch.cat([scores, bins0], -1),
torch.cat([bins1, alpha], -1)], 1)
norm = - (ms + ns).log() # (b, )
# print(scores.min(), flush=True)
if ms.size()[0] > 0:
norm = norm[:, None]
log_mu = torch.cat([norm.expand(b, m), ns.log()[:, None] + norm], dim=-1) # (m + 1)
log_nu = torch.cat([norm.expand(b, n), ms.log()[:, None] + norm], dim=-1)
# print(log_nu.min(), log_mu.min(), flush=True)
else:
log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) # (m + 1)
log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
if ms.size()[0] > 1:
norm = norm[:, :, None]
Z = Z - norm # multiply probabilities by M+N
return Z
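The coupling construction in `log_optimal_transport` appends the learned dustbin score `alpha` as an extra row and column, so points with no counterpart can be absorbed there instead of being forced into a match; a minimal sketch of just that augmentation (helper name ours):

```python
def augment_with_dustbin(scores, alpha):
    """Plain-Python sketch of the `couplings` construction in
    log_optimal_transport above: append a dustbin column (bins0), a
    dustbin row (bins1), and the corner entry alpha to an m x n score
    matrix, producing an (m+1) x (n+1) matrix."""
    n = len(scores[0])
    out = [row + [alpha] for row in scores]  # bins0: extra column
    out.append([alpha] * (n + 1))            # bins1 row + corner alpha
    return out
```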
def arange_like(x, dim: int):
return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1
class SuperGlueM(nn.Module):
"""SuperGlue feature matching middle-end
Given two sets of keypoints and locations, we determine the
correspondences by:
1. Keypoint Encoding (normalization + visual feature and location fusion)
2. Graph Neural Network with multiple self and cross-attention layers
3. Final projection layer
4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
5. Thresholding matrix based on mutual exclusivity and a match_threshold
The correspondence ids use -1 to indicate non-matching points.
Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763
"""
# default_config = {
# 'descriptor_dim': 128,
# 'weights': 'indoor',
# 'keypoint_encoder': [32, 64, 128],
# 'GNN_layers': ['self', 'cross'] * 9,
# 'sinkhorn_iterations': 100,
# 'match_threshold': 0.2,
# }
def __init__(self, config=None):
super().__init__()
default_config = argparse.Namespace()
default_config.descriptor_dim = 128
# default_config.weights =
default_config.keypoint_encoder = [32, 64, 128]
default_config.GNN_layers = ['self', 'cross'] * 9
default_config.sinkhorn_iterations = 100
default_config.match_threshold = 0.2
# self.config = {**self.default_config, **config}
if config is None:
self.config = default_config
else:
self.config = config
self.config.GNN_layers = ['self', 'cross'] * self.config.GNN_layer_num
# print('WULA!', self.config.GNN_layer_num)
self.kenc = KeypointEncoder(
self.config.descriptor_dim, self.config.keypoint_encoder)
self.gnn = AttentionalGNN(
self.config.descriptor_dim, self.config.GNN_layers)
self.final_proj = nn.Conv1d(
self.config.descriptor_dim, self.config.descriptor_dim,
kernel_size=1, bias=True)
bin_score = torch.nn.Parameter(torch.tensor(1.))
self.register_parameter('bin_score', bin_score)
self.vertex_desc = VertexDescriptor(self.config.descriptor_dim)
# assert self.config.weights in ['indoor', 'outdoor']
# path = Path(__file__).parent
# path = path / 'weights/superglue_{}.pth'.format(self.config.weights)
# self.load_state_dict(torch.load(path))
# print('Loaded SuperGlue model (\"{}\" weights)'.format(
# self.config.weights))
def forward(self, data):
"""Run SuperGlue on a pair of keypoints and descriptors"""
# print(data['segment0'].size())
# desc0, desc1 = data['descriptors0'].float()(), data['descriptors1'].float()()
# print(desc0.size())
kpts0, kpts1 = data['keypoints0'].float(), data['keypoints1'].float()
ori_mask0, ori_mask1 = data['mask0'].float(), data['mask1'].float()
dim_m, dim_n = data['ms'].float(), data['ns'].float()
mmax = dim_m.int().max()
nmax = dim_n.int().max()
mask0 = ori_mask0[:, :mmax]
mask1 = ori_mask1[:, :nmax]
kpts0 = kpts0[:, :mmax]
kpts1 = kpts1[:, :nmax]
desc0, desc1 = self.vertex_desc(data['image0'], kpts0.float()), self.vertex_desc(data['image1'], kpts1.float())
# print(desc0.size(), flush=True)
mask00 = torch.ones_like(mask0)[:, :, None] * mask0[:, None, :]
# print(mask00[1], flush=True)
mask11 = torch.ones_like(mask1)[:, :, None] * mask1[:, None, :]
mask01 = torch.ones_like(mask0)[:, :, None] * mask1[:, None, :]
mask10 = torch.ones_like(mask1)[:, :, None] * mask0[:, None, :]
# desc0 = desc0.transpose(0,1)
# desc1 = desc1.transpose(0,1)
# kpts0 = torch.reshape(kpts0, (1, -1, 2))
# kpts1 = torch.reshape(kpts1, (1, -1, 2))
if kpts0.shape[1] < 2 or kpts1.shape[1] < 2: # no keypoints
shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
# print(data['file_name'])
return {
'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],
# 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],
'matching_scores0': kpts0.new_zeros(shape0)[0],
# 'matching_scores1': kpts1.new_zeros(shape1)[0],
'skip_train': True
}
# file_name = data['file_name']
all_matches = data['all_matches'] if 'all_matches' in data else None  # shape = (1, K1)
# .permute(1,2,0) # shape=torch.Size([1, 87,])
# positional embedding
# Keypoint normalization.
kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
kpts1 = normalize_keypoints(kpts1, data['image1'].shape)
# Keypoint MLP encoder.
# print(data['file_name'])
# print(kpts0.size())
pos0 = self.kenc(kpts0)
pos1 = self.kenc(kpts1)
# print(desc0.size(), pos0.size())
# print(desc0.size(), pos0.size())
desc0 = desc0 + pos0
desc1 = desc1 + pos1
# self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
# .view(1, 1, config.block_size, config.block_size))
# mask0 = ...
# mask1 = ...
# Multi-layer Transformer network.
desc0, desc1 = self.gnn(desc0, desc1, mask00, mask11, mask01, mask10)
# Final MLP projection.
mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
# Compute matching descriptor distance.
# print(mdesc0.size(), mdesc1.size())
scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
scores0 = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc0)
scores1 = torch.einsum('bdn,bdm->bnm', mdesc1, mdesc1)
# #print('here1!!', scores.size())
# b k1 k2
scores = scores / self.config.descriptor_dim**.5
# print(scores.size(), mask01.size())
# mask01 = mask0[:, :, None] * mask1[:, None, :]
# scores = scores.masked_fill(mask01 == 0, float('-inf'))
# print(scores.size())
# Run the optimal transport.
# print(dim_m.size(), dim_m, flush=True)
scores = log_optimal_transport(
scores, self.bin_score,
iters=self.config.sinkhorn_iterations,
ms=dim_m, ns=dim_n)
# print(scores)
# print(scores.sum())
# print(scores.sum(1))
# print(scores.sum(0))
# Get the matches with score above "match_threshold".
return scores[:, :-1, :-1], scores0, scores1, mdesc0, mdesc1
def tensor_erode(bin_img, ksize=5):
# Pad the input first so the eroded image keeps its original size
B, C, H, W = bin_img.shape
pad = (ksize - 1) // 2
bin_img = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)
# Unfold the image into k x k patches
patches = bin_img.unfold(dimension=2, size=ksize, step=1)
patches = patches.unfold(dimension=3, size=ksize, step=1)
# B x C x H x W x k x k
# Take the minimum value in each patch (0 if any background pixel is present)
eroded, _ = patches.reshape(B, C, H, W, -1).min(dim=-1)
return eroded
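`tensor_erode` implements binary erosion as a min-pool over k x k windows: a pixel stays foreground only if its whole neighbourhood is foreground. An equivalent plain-Python sketch on a nested-list grid (helper name ours):

```python
def erode(grid, ksize=3):
    """Plain-Python sketch of tensor_erode above: zero-pad a binary
    grid, then take the min over each k x k window centred on every
    pixel of the original image."""
    h, w = len(grid), len(grid[0])
    pad = (ksize - 1) // 2
    padded = [[0] * (w + 2 * pad) for _ in range(h + 2 * pad)]
    for i in range(h):
        for j in range(w):
            padded[i + pad][j + pad] = grid[i][j]
    return [[min(padded[i + di][j + dj]
                 for di in range(ksize) for dj in range(ksize))
             for j in range(w)] for i in range(h)]
```

With zero padding, any foreground pixel within `pad` of the border is always eroded away, matching the `value=0` padding in the torch version.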
class InbetweenerM(nn.Module):
"""SuperGlue feature matching middle-end
Given two sets of keypoints and locations, we determine the
correspondences by:
1. Keypoint Encoding (normalization + visual feature and location fusion)
2. Graph Neural Network with multiple self and cross-attention layers
3. Final projection layer
4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
5. Thresholding matrix based on mutual exclusivity and a match_threshold
The correspondence ids use -1 to indicate non-matching points.
Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763
"""
# default_config = {
# 'descriptor_dim': 128,
# 'weights': 'indoor',
# 'keypoint_encoder': [32, 64, 128],
# 'GNN_layers': ['self', 'cross'] * 9,
# 'sinkhorn_iterations': 100,
# 'match_threshold': 0.2,
# }
def __init__(self, config=None):
super().__init__()
self.corr = SuperGlueM(config.corr_model)
self.mask_map = MLP([config.corr_model.descriptor_dim, 32, 1])
self.pos_weight = config.pos_weight
# self.motion_propagation =
# assert self.config.weights in ['indoor', 'outdoor']
# path = Path(__file__).parent
# path = path / 'weights/superglue_{}.pth'.format(self.config.weights)
# self.load_state_dict(torch.load(path))
# print('Loaded SuperGlue model (\"{}\" weights)'.format(
# self.config.weights))
def forward(self, data):
if 'gen_vid' in data:
dim_m, dim_n = data['ms'].float(), data['ns'].float()
mmax = dim_m.int().max()
nmax = dim_n.int().max()
# with torch.no_grad():
# self.corr.eval()
score01, score0, score1, dec0, dec1 = self.corr(data)
kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2
## print(kpts0.mean(), kpts1.mean(), flush=True)
motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0
motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1
motion_pred0 = torch.softmax(score0, dim=-1) @ motion_pred0
motion_pred1 = torch.softmax(score1, dim=-1) @ motion_pred1
max0, max1 = score01.max(2), score01.max(1)
indices0, indices1 = max0.indices, max1.indices
mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
zero = score01.new_tensor(0)
mscores0 = torch.where(mutual0, max0.values.exp(), zero)
mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
# valid0 = mutual0 & (mscores0 > self.config.match_threshold)
# valid1 = mutual1 & valid0.gather(1, indices1)
valid0 = mscores0 > 0.2
valid1 = valid0.gather(1, indices1)
indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
adj0, adj1 = data['adj_mat0'].float(), data['adj_mat1'].float()
motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0
motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1
# score0.mask_off()
motion_pred0 = torch.softmax(score0.masked_fill(adj0==0, float('-inf')), dim=-1) @ motion_pred0
motion_pred1 = torch.softmax(score1.masked_fill(adj1==0, float('-inf')), dim=-1) @ motion_pred1
vb0 = self.mask_map(dec0)[:, 0]
vb1 = self.mask_map(dec1)[:, 0]
vb0[:] = 1
vb1[:] = 1
im0_erode = data['image0']
im1_erode = data['image1']
im0_erode[im0_erode > 0] = 1
im0_erode[im0_erode <= 0] = 0
im1_erode[im1_erode > 0] = 1
im1_erode[im1_erode <= 0] = 0
im0_erode = tensor_erode(im0_erode, 3)
im1_erode = tensor_erode(im1_erode, 3)
motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone()
## print('>>>>> here', motion_pred0.mean(), motion_pred1.mean(), flush=True)
kpt0t = kpts0 + motion_output0 * 1
kpt1t = kpts1 + motion_output1 * 1
if 'topo0' in data and 'topo1' in data:
## print(len(data['topo0'][0]), len(data['topo1']), flush=True)
for node, nbs in enumerate(data['topo0'][0]):
for nb in nbs:
# print(nb, flush=True)
# print(kpt0t.size(), 'fDsafdsafds', flush=True)
# if vb0[0, nb] and vb0[0, node] and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 3:
# vb0[0, nb] = -1
# vb0[0, node] = -1
# print(node.size())
center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.5).int()[0]
# print(center.size(), flush=True)
if vb0[0, nb] and vb0[0, node] and im1_erode[0,:, center[1], center[0]].mean() > 0.8:
vb0[0, nb] = -1
vb0[0, node] = -1
# center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.25).int()[0]
# # print(center.size(), flush=True)
# if vb0[0, nb] and vb0[0, node] and center[1] < 720 and center[0] < 720 and im1_erode[0,:, center[1], center[0]].mean() > 0.8:
# vb0[0, nb] = -1
# vb0[0, node] = -1
# center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.75).int()[0]
# # print(center.size(), flush=True)
# if vb0[0, nb] and vb0[0, node] and center[1] < 720 and center[0] < 720 and im1_erode[0,:, center[1], center[0]].mean() > 0.8:
# vb0[0, nb] = -1
# vb0[0, node] = -1
for node, nbs in enumerate(data['topo1'][0]):
for nb in nbs:
# if vb1[0, nb] and vb1[0, node] and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) >3:
# vb1[0, nb] = -1
# vb1[0, node] = -1
center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.5).int()[0]
if vb1[0, nb] and vb1[0, node] and im0_erode[0,:, center[1], center[0]].mean() > 0.95:
vb1[0, nb] = -1
vb1[0, node] = -1
# center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.25).int()[0]
# if vb1[0, nb] and vb1[0, node] and center[1] < 720 and center[0] < 720 and im0_erode[0,:, center[1], center[0]].mean() > 0.95:
# vb1[0, nb] = -1
# vb1[0, node] = -1
# center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.75).int()[0]
# if vb1[0, nb] and vb1[0, node] and center[1] < 720 and center[0] < 720 and im0_erode[0,:, center[1], center[0]].mean() > 0.95:
# vb1[0, nb] = -1
# vb1[0, node] = -1
# print(vb0.mean(), vb1.mean(), flush=True)
return {'r0': motion_output0, 'r1': motion_output1, 'vb0': (vb0 > 0).float(), 'vb1': (vb1 > 0).float()}
dim_m, dim_n = data['ms'].float(), data['ns'].float()
mmax = dim_m.int().max()
nmax = dim_n.int().max()
# with torch.no_grad():
# self.corr.eval()
score01, score0, score1, dec0, dec1 = self.corr(data)
kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2
adj0, adj1 = data['adj_mat0'].float(), data['adj_mat1'].float()
motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0
motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1
# score0.mask_off()
motion_pred0 = torch.softmax(score0.masked_fill(adj0==0, float('-inf')), dim=-1) @ motion_pred0
motion_pred1 = torch.softmax(score1.masked_fill(adj1==0, float('-inf')), dim=-1) @ motion_pred1
vb0 = self.mask_map(dec0)[:, 0]
vb1 = self.mask_map(dec1)[:, 0]
# motion0_pred, vb0 = pred0[:, :2].permute(0, 2, 1), pred0[:, 2:][:, 0]
# motion1_pred, vb1 = pred1[:, :2].permute(0, 2, 1), pred1[:, 2:][:, 0]
# delta0, delta1 = motion_delta[:, :, :mmax].permute(0, 2, 1), motion_delta[:, :, mmax:].permute(0, 2, 1)
# motion_output0, motion_output1 = motion0 + delta0, motion1 + delta1
motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone()
# print(delta0.max(), delta1.max())
# vb0 = kpts0.new_ones(motion_pred0[:, :, 0].size()) + 1.0
# vb1 = kpts1.new_ones(motion_pred1[:, :, 0].size()) + 1.0
# vb0, vb1 = visibility[:, 0, :mmax], visibility[:, 0, mmax:]
# mask0, mask1 = mask[:, :mmax].bool(), mask[:, mmax:].bool()
# vb0_output = vb0.clone()
# vb1_output = vb1.clone()
# vb1_output[batch, corr01[corr01 != -1]] = 1.0
# motion_output0[valid0.bool()] = motion0[valid0.bool()]
# motion_output1[valid1.bool()] = motion1[valid1.bool()]
# vb0_output[vb0_output >= 0] = 1.0
# vb0_output[vb0_output < 0] = 0.0
# vb1_output[vb1_output >= 0] = 1.0
# vb1_output[vb1_output < 0 ] = 0.0
kpt0t = kpts0 + motion_output0 / 2
kpt1t = kpts1 + motion_output1 / 2
# kpt1t[batch, corr01[corr01 != -1]] = kpt0t[corr01 != -1]
##################################################
## NOTE: the mini-batch size must be 1 here ##
##################################################
if 'topo0' in data and 'topo1' in data:
# print(len(data['topo0'][0]), len(data['topo1']), flush=True)
for node, nbs in enumerate(data['topo0'][0]):
for nb in nbs:
if vb0[0, nb] and vb0[0, node] and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 5:
vb0[0, nb] = -1
vb0[0, node] = -1
for node, nbs in enumerate(data['topo1'][0]):
for nb in nbs:
if vb1[0, nb] and vb1[0, node] and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) > 5:
vb1[0, nb] = -1
vb1[0, node] = -1
if 'motion0' in data and 'motion1' in data:
# valid_motion0 = motion_output0[mask0[:, :, None].repeat(1, 1, 2)]
# gt_valid_motion0 = data['motion0'][:, :mmax][mask0[:, :, None].repeat(1, 1, 2)].float()
# valid_motion1 = motion_output1[mask1[:, :, None].repeat(1, 1, 2)]
# gt_valid_motion1 = data['motion1'][:, :nmax][mask1[:, :, None].repeat(1, 1, 2)].float()
loss_motion = torch.nn.functional.l1_loss(motion_pred0, data['motion0'][:, :mmax]) +\
torch.nn.functional.l1_loss(motion_pred1, data['motion1'][:, :nmax])
# loss_valid0 = ((corr01 == -1) & (mask0 == 1))
# loss_valid1 = ((corr10 == -1) & (mask1 == 1))
EPE0 = ((motion_pred0 - data['motion0'][:, :mmax]) ** 2).sum(dim=-1).sqrt()
EPE1 = ((motion_pred1 - data['motion1'][:, :nmax]) ** 2).sum(dim=-1).sqrt()
# print(EPE0.size(), 'fdsafdsa')
EPE = (EPE0.mean() + EPE1.mean()) * 0.5
# print(len(EPE0[mask0]), len(EPE1[mask1]))
# print(vb0[:, :mmax][mask0], vb0[:, :mmax][mask0].shape, data['visibility0'][:, :mmax][mask0], data['visibility0'][:, :mmax][mask0].shape)
# print(.size())
# print((vb0[:, :mmax] > 0).float().sum(), data['visibility0'][:, :mmax].float().sum())
# pos_weight=vb0.new_tensor([0.5])
if 'visibility0' in data and 'visibility1' in data:
loss_visibility = torch.nn.functional.binary_cross_entropy_with_logits(vb0[:, :mmax].view(-1, 1), data['visibility0'][:, :mmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight])) + \
torch.nn.functional.binary_cross_entropy_with_logits(vb1[:, :nmax].view(-1, 1), data['visibility1'][:, :nmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight]))
VB_Acc = ((((vb0 > 0).float() == data['visibility0'][:, :mmax]).float().sum() + ((vb1 > 0).float() == data['visibility1'][:, :nmax]).float().sum()) * 1.0 / (mmax + nmax))
else:
loss_visibility = 0
VB_Acc = EPE.new_zeros([1])
loss = loss_motion + 10 * loss_visibility
loss_mean = torch.mean(loss)
# loss_mean = torch.reshape(loss_mean, (1, -1))
# print(loss_mean, flush=True)
# print(all_matches[:, :mmax].size(), indices0.size(), mask0.size(), flush=True)
#print((all_matches[0] == indices0[0]).sum())
# print(vb1.size(),corr01.size())
# kpt0t = torch.nn.functional.pad(kpts0 + motion_output0, (0, 0, 0, self.max_len - mmax, 0, 0), mode='constant', value=0)
# kpt1t = torch.nn.functional.pad(kpts1 + motion_output1, (0, 0, 0, self.max_len - nmax, 0, 0), mode='constant', value=0),
# kpt1t[:, :nmax][batch, corr01[corr01 != -1]] = kpt0t[:, :mmax][corr01 != -1]
b, _, _ = motion_pred0.size()
# batch = torch.arange(b)[:, None].repeat(1, mmax)[corr01 != -1].long()
# # print(kpts0[corr01 != -1].size(), corr01[corr01 != -1].size())
# matched_intermediate = (kpts0[(corr01 != -1)] + kpts1[batch, corr01[corr01 != -1].long(), :]) * 0.5
# motion0[corr01 != -1] = matched_intermediate - kpts0[corr01 != -1]
# motion1[batch, corr01[corr01 != -1].long(), :] = matched_intermediate - kpts1[batch, corr01[corr01 != -1].long(), :]
# vb0 = torch.nn.functional.pad(vb0, (0, self.max_len - mmax, 0, 0), mode='constant', value=0),
# vb1 = torch.nn.functional.pad(vb1, (0, self.max_len - nmax, 0, 0), mode='constant', value=0),
# self.max_len = 3050
# VB_Acc = ((((vb0 > 0.5).float() == data['visibility0'][:, :mmax]).float().sum() + ((vb1 > 0.5).float() == data['visibility1'][:, :nmax]).float().sum()) * 1.0 / (mmax + nmax))
return {
# 'matches0': indices0, # use -1 for invalid match
# 'matches1': indices1[0], # use -1 for invalid match
# 'matching_scores0': mscores0,
# 'matching_scores1': mscores1[0],
# 'keypointst0': torch.nn.functional.pad(kpts0 + motion_output0, (0, 0, 0, self.max_len - mmax, 0, 0), mode='constant', value=0),
# 'keypointst1': torch.nn.functional.pad(kpts1 + motion_output1, (0, 0, 0, self.max_len - nmax, 0, 0), mode='constant', value=0),
# 'vb0': torch.nn.functional.pad(vb0, (0, self.max_len - mmax, 0, 0), mode='constant', value=0),
# 'vb1': torch.nn.functional.pad(vb1, (0, self.max_len - nmax, 0, 0), mode='constant', value=0),
'keypoints0t': kpt0t,
'keypoints1t': kpt1t,
'vb0': (vb0 > 0).float(),
'vb1': (vb1 > 0).float(),
'loss': loss_mean,
'EPE': EPE,
'Visibility Acc': VB_Acc
# ((((vb0[mask0] > 0).float() == data['visibility0'][:, :mmax][mask0]).float().sum() + ((vb1[mask1] > 0).float() == data['visibility1'][:, :nmax][mask1]).float().sum()) * 1.0 / (mask0.float().sum() + mask1.float().sum())),
# 'skip_train': [False],
# 'accuracy': (((all_matches[:, :mmax] == indices0) & mask0.bool()).sum() / mask0.sum()).item(),
# 'valid_accuracy': (((all_matches[:, :mmax] == indices0) & (all_matches[:, :mmax] != -1) & mask0.bool()).float().sum() / ((all_matches[:, :mmax] != -1) & mask0.bool()).float().sum()).item(),
}
else:
return {
'loss': -1,
'skip_train': True,
'keypointst0': kpts0 + motion_output0,
'keypointst1': kpts1 + motion_output1,
'vb0': vb0,
'vb1': vb1,
# 'accuracy': -1,
# 'area_accuracy': -1,
# 'valid_accuracy': -1,
}
if __name__ == '__main__':
args = argparse.Namespace()
args.batch_size = 2
args.gap = 5
args.type = 'train'
args.model = None
args.action = None
# NOTE: Refiner and fetch_dataloader are not defined or imported in this
# file; this smoke test assumes they are provided by the caller's environment.
ss = Refiner()
loader = fetch_dataloader(args)
# #print(len(loader))
for data in loader:
# p1, p2, s1, s2, mi = data
dict1 = data
kp1 = dict1['keypoints0']
kp2 = dict1['keypoints1']
p1 = dict1['image0']
p2 = dict1['image1']
# #print(s1)
# #print(s1.type)
mi = dict1['m01']
fname = dict1['file_name']
print(dict1['keypoints0'].size(), dict1['keypoints1'].size(), dict1['m01'].size(), dict1['motion0'].size(), dict1['mask0'].size())
# print(kp1.shape, p1.shape, mi.shape)
# #print(mi.size())
# #print(mi)
# break
a = ss(data)
print(dict1['file_name'])
print(a['loss'])
print(a['EPE'], a['Visibility Acc'],flush=True)
a['loss'].backward()
================================================
FILE: models/inbetweener_with_mask_with_spec.py
================================================
from copy import deepcopy
from pathlib import Path
import torch
from torch import nn
# from seg_desc import seg_descriptor
import argparse
import numpy as np
import torch.nn.functional as F
from sknetwork.embedding import Spectral
def MLP(channels: list, do_bn=True):
""" Multi-layer perceptron """
n = len(channels)
layers = []
for i in range(1, n):
layers.append(
nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
if i < (n-1):
if do_bn:
# layers.append(nn.BatchNorm1d(channels[i]))
layers.append(nn.InstanceNorm1d(channels[i]))
layers.append(nn.ReLU())
return nn.Sequential(*layers)
def normalize_keypoints(kpts, image_shape):
""" Normalize keypoint locations based on the image shape"""
_, _, height, width = image_shape
one = kpts.new_tensor(1)
size = torch.stack([one*width, one*height])[None]
center = size / 2
scaling = size.max(1, keepdim=True).values * 0.7
return (kpts - center[:, None, :]) / scaling[:, None, :]
class ThreeLayerEncoder(nn.Module):
""" Three-layer convolutional encoder for image appearance features"""
def __init__(self, enc_dim):
super().__init__()
# input must be 3 channel (r, g, b)
self.layer1 = nn.Conv2d(3, enc_dim//4, 7, padding=3)
self.non_linear1 = nn.ReLU()
self.layer2 = nn.Conv2d(enc_dim//4, enc_dim//2, 3, padding=1)
self.non_linear2 = nn.ReLU()
self.layer3 = nn.Conv2d(enc_dim//2, enc_dim, 3, padding=1)
self.norm1 = nn.InstanceNorm2d(enc_dim//4)
self.norm2 = nn.InstanceNorm2d(enc_dim//2)
self.norm3 = nn.InstanceNorm2d(enc_dim)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.constant_(m.bias, 0.0)
def forward(self, img):
x = self.non_linear1(self.norm1(self.layer1(img)))
x = self.non_linear2(self.norm2(self.layer2(x)))
x = self.norm3(self.layer3(x))
# x = self.non_linear1(self.layer1(img))
# x = self.non_linear2(self.layer2(x))
# x = self.layer3(x)
return x
class VertexDescriptor(nn.Module):
""" Per-vertex visual descriptors sampled from the encoded image"""
def __init__(self, enc_dim):
super().__init__()
self.encoder = ThreeLayerEncoder(enc_dim)
# self.super_pixel_pooling =
# use scatter
# nn.init.constant_(self.encoder[-1].bias, 0.0)
def forward(self, img, vtx):
x = self.encoder(img)
n, c, h, w = x.size()
assert((h, w) == img.size()[2:4])
return x[:, :, torch.round(vtx[0, :, 1]).long(), torch.round(vtx[0, :, 0]).long()]
# return super_pixel_pooling(x.view(n, c, -1), seg.view(-1).long(), reduce='mean')
# here return size is [1]xCx|Seg|
class KeypointEncoder(nn.Module):
""" Encoding of keypoint locations using an MLP"""
def __init__(self, feature_dim, layers):
super().__init__()
self.encoder = MLP([2] + layers + [feature_dim])
# for m in self.encoder.modules():
# if isinstance(m, nn.Conv2d):
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# nn.init.constant_(m.bias, 0.0)
nn.init.constant_(self.encoder[-1].bias, 0.0)
def forward(self, kpts):
inputs = kpts.transpose(1, 2)
x = self.encoder(inputs)
return x
class TopoEncoder(nn.Module):
""" Encoding of 64-d spectral topology embeddings using an MLP"""
def __init__(self, feature_dim, layers):
super().__init__()
self.encoder = MLP([64] + layers + [feature_dim])
# for m in self.encoder.modules():
# if isinstance(m, nn.Conv2d):
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# nn.init.constant_(m.bias, 0.0)
nn.init.constant_(self.encoder[-1].bias, 0.0)
def forward(self, kpts):
inputs = kpts.transpose(1, 2)
x = self.encoder(inputs)
return x
def attention(query, key, value, mask=None):
dim = query.shape[1]
scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
if mask is not None:
scores = scores.masked_fill(mask==0, float('-inf'))
prob = torch.nn.functional.softmax(scores, dim=-1)
return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
class MultiHeadedAttention(nn.Module):
""" Multi-head attention to increase model expressivity """
def __init__(self, num_heads: int, d_model: int):
super().__init__()
assert d_model % num_heads == 0
self.dim = d_model // num_heads
self.num_heads = num_heads
self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
def forward(self, query, key, value, mask=None):
batch_dim = query.size(0)
query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
for l, x in zip(self.proj, (query, key, value))]
x, prob = attention(query, key, value, mask)
# self.prob.append(prob)
return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))
class AttentionalPropagation(nn.Module):
def __init__(self, feature_dim: int, num_heads: int):
super().__init__()
self.attn = MultiHeadedAttention(num_heads, feature_dim)
self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
nn.init.constant_(self.mlp[-1].bias, 0.0)
def forward(self, x, source, mask=None):
message = self.attn(x, source, source, mask)
return self.mlp(torch.cat([x, message], dim=1))
class AttentionalGNN(nn.Module):
def __init__(self, feature_dim: int, layer_names: list):
super().__init__()
self.layers = nn.ModuleList([
AttentionalPropagation(feature_dim, 4)
for _ in range(len(layer_names))])
self.names = layer_names
def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None, mask10=None):
for layer, name in zip(self.layers, self.names):
layer.attn.prob = []
if name == 'cross':
src0, src1 = desc1, desc0
mask0, mask1 = mask01[:, None], mask10[:, None]
else: # if name == 'self':
src0, src1 = desc0, desc1
mask0, mask1 = mask00[:, None], mask11[:, None]
delta0, delta1 = layer(desc0, src0, mask0), layer(desc1, src1, mask1)
desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
return desc0, desc1
def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
""" Perform Sinkhorn Normalization in Log-space for stability"""
u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
for _ in range(iters):
u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
return Z + u.unsqueeze(2) + v.unsqueeze(1)
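A quick numerical check of what `log_sinkhorn_iterations` computes (the function is reproduced under a local name so the snippet is self-contained; sizes and values are illustrative): the exponentiated output is a transport plan whose row and column sums converge to the prescribed marginals.

```python
import math
import torch

def log_sinkhorn(Z, log_mu, log_nu, iters):
    # Identical to log_sinkhorn_iterations above: alternate row/column
    # scaling in log-space for numerical stability.
    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):
        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
    return Z + u.unsqueeze(2) + v.unsqueeze(1)

torch.manual_seed(0)
b, m, n = 1, 3, 4
Z = torch.randn(b, m, n)
log_mu = torch.full((b, m), -math.log(m))   # uniform row marginals, 1/m each
log_nu = torch.full((b, n), -math.log(n))   # uniform column marginals, 1/n each
P = log_sinkhorn(Z, log_mu, log_nu, iters=200).exp()
```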
def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
""" Perform Differentiable Optimal Transport in Log-space for stability"""
b, m, n = scores.shape
one = scores.new_tensor(1)
if ms is None or ns is None:
ms, ns = (m*one).to(scores), (n*one).to(scores)
# else:
# ms, ns = ms.to(scores)[:, None], ns.to(scores)[:, None]
# here m,n should be parameters not shape
# ms, ns: (b, )
bins0 = alpha.expand(b, m, 1)
bins1 = alpha.expand(b, 1, n)
alpha = alpha.expand(b, 1, 1)
# pad an additional dustbin row/column of scores for unmatched points
# alpha is the learned dustbin (unmatched) score
couplings = torch.cat([torch.cat([scores, bins0], -1),
torch.cat([bins1, alpha], -1)], 1)
norm = - (ms + ns).log() # (b, )
if ms.dim() > 0: # batched marginals; the default path above yields 0-dim scalars
norm = norm[:, None]
log_mu = torch.cat([norm.expand(b, m), ns.log()[:, None] + norm], dim=-1) # (m + 1)
log_nu = torch.cat([norm.expand(b, n), ms.log()[:, None] + norm], dim=-1)
else:
log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) # (m + 1)
log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
if ms.dim() > 0 and ms.size(0) > 1:
norm = norm[:, :, None]
Z = Z - norm # multiply probabilities by M+N
return Z
def arange_like(x, dim: int):
return x.new_ones(x.shape[dim]).cumsum(0) - 1 # like torch.arange, but traceable in PyTorch 1.1
class SuperGlueT(nn.Module):
"""SuperGlue feature matching middle-end
Given two sets of keypoints and locations, we determine the
correspondences by:
1. Keypoint Encoding (normalization + visual feature and location fusion)
2. Graph Neural Network with multiple self and cross-attention layers
3. Final projection layer
4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
5. Thresholding matrix based on mutual exclusivity and a match_threshold
The correspondence ids use -1 to indicate non-matching points.
Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763
"""
# default_config = {
# 'descriptor_dim': 128,
# 'weights': 'indoor',
# 'keypoint_encoder': [32, 64, 128],
# 'GNN_layers': ['self', 'cross'] * 9,
# 'sinkhorn_iterations': 100,
# 'match_threshold': 0.2,
# }
def __init__(self, config=None):
super().__init__()
default_config = argparse.Namespace()
default_config.descriptor_dim = 128
default_config.keypoint_encoder = [32, 64, 128]
default_config.GNN_layers = ['self', 'cross'] * 9
default_config.sinkhorn_iterations = 100
default_config.match_threshold = 0.2
self.spectral = Spectral(64, normalized=False)
if config is None:
self.config = default_config
else:
self.config = config
self.config.GNN_layers = ['self', 'cross'] * self.config.GNN_layer_num
self.kenc = KeypointEncoder(
self.config.descriptor_dim, self.config.keypoint_encoder)
self.tenc = TopoEncoder(
self.config.descriptor_dim, [96])
self.gnn = AttentionalGNN(
self.config.descriptor_dim, self.config.GNN_layers)
self.final_proj = nn.Conv1d(
self.config.descriptor_dim, self.config.descriptor_dim,
kernel_size=1, bias=True)
bin_score = torch.nn.Parameter(torch.tensor(1.))
self.register_parameter('bin_score', bin_score)
self.vertex_desc = VertexDescriptor(self.config.descriptor_dim)
def forward(self, data):
kpts0, kpts1 = data['keypoints0'].float(), data['keypoints1'].float()
ori_mask0, ori_mask1 = data['mask0'].float(), data['mask1'].float()
dim_m, dim_n = data['ms'].float(), data['ns'].float()
# spectral embedding of the adjacency matrices
# computing the spectral embedding online is too slow during training,
# so it is moved into the dataset pipeline, where it can be precomputed
# during data preparation by multi-process CPU workers
spec0, spec1 = data['spec0'], data['spec1']
# spec0, spec1 = np.abs(self.spectral.fit_transform(adj_mat0[0].cpu().numpy())), np.abs(self.spectral.fit_transform(adj_mat1[0].cpu().numpy()))
mmax = dim_m.int().max()
nmax = dim_n.int().max()
mask0 = ori_mask0[:, :mmax]
mask1 = ori_mask1[:, :nmax]
kpts0 = kpts0[:, :mmax]
kpts1 = kpts1[:, :nmax]
# image context embedding
desc0, desc1 = self.vertex_desc(data['image0'], kpts0.float()), self.vertex_desc(data['image1'], kpts1.float())
# add topological embedding
desc0 = desc0 + self.tenc(desc0.new_tensor(spec0))
desc1 = desc1 + self.tenc(desc1.new_tensor(spec1))
# these masks were prepared for synchronized training with batch size > 1, but that did not work well,
# so the current framework still uses gradient accumulation
mask00 = torch.ones_like(mask0)[:, :, None] * mask0[:, None, :]
mask11 = torch.ones_like(mask1)[:, :, None] * mask1[:, None, :]
mask01 = torch.ones_like(mask0)[:, :, None] * mask1[:, None, :]
mask10 = torch.ones_like(mask1)[:, :, None] * mask0[:, None, :]
if kpts0.shape[1] < 2 or kpts1.shape[1] < 2: # no keypoints
shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
# print(data['file_name'])
return {
'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],
# 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],
'matching_scores0': kpts0.new_zeros(shape0)[0],
# 'matching_scores1': kpts1.new_zeros(shape1)[0],
'skip_train': True
}
all_matches = data['all_matches'] if 'all_matches' in data else None  # shape = (1, K1)
# positional embedding
# Keypoint normalization.
kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
kpts1 = normalize_keypoints(kpts1, data['image1'].shape)
# Keypoint MLP encoder.
pos0 = self.kenc(kpts0)
pos1 = self.kenc(kpts1)
desc0 = desc0 + pos0
desc1 = desc1 + pos1
# Multi-layer Transformer network.
desc0, desc1 = self.gnn(desc0, desc1, mask00, mask11, mask01, mask10)
# Final MLP projection.
mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
# Compute matching descriptor distance.
scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
scores0 = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc0)
scores1 = torch.einsum('bdn,bdm->bnm', mdesc1, mdesc1)
# b k1 k2
scores = scores / self.config.descriptor_dim**.5
# Run the optimal transport.
scores = log_optimal_transport(
scores, self.bin_score,
iters=self.config.sinkhorn_iterations,
ms=dim_m, ns=dim_n)
# Get the matches with score above "match_threshold".
return scores[:, :-1, :-1], scores0, scores1, mdesc0, mdesc1
def tensor_erode(bin_img, ksize=5):
B, C, H, W = bin_img.shape
pad = (ksize - 1) // 2
bin_img = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)
patches = bin_img.unfold(dimension=2, size=ksize, step=1)
patches = patches.unfold(dimension=3, size=ksize, step=1)
# B x C x H x W x k x k
eroded, _ = patches.reshape(B, C, H, W, -1).min(dim=-1)
return eroded
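`tensor_erode` realizes binary erosion as a min-pool over k × k patches: a pixel survives only if its entire neighborhood is foreground. A self-contained sketch with a toy input (the 3 × 3 kernel size is illustrative):

```python
import torch
import torch.nn.functional as F

def erode(bin_img, ksize=3):
    # Same construction as tensor_erode above: zero-pad, unfold
    # ksize x ksize patches, and min-pool, so a pixel stays 1 only
    # if its whole neighborhood is 1.
    B, C, H, W = bin_img.shape
    pad = (ksize - 1) // 2
    padded = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)
    patches = padded.unfold(2, ksize, 1).unfold(3, ksize, 1)
    eroded, _ = patches.reshape(B, C, H, W, -1).min(dim=-1)
    return eroded

img = torch.zeros(1, 1, 7, 7)
img[0, 0, 1:6, 1:6] = 1        # a 5x5 foreground block
out = erode(img, ksize=3)      # erosion shrinks it to its 3x3 interior
```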
class InbetweenerTM(nn.Module):
"""AnimeInbet
The whole pipeline includes
1. vertex correspondence (vertex embedding + correspondence transformer)
2. repositioning propagation
3. vis mask
The vertex correspondence code is modified from SuperGlue:
Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763
"""
def __init__(self, config=None):
super().__init__()
# vertex correspondence
self.corr = SuperGlueT(config.corr_model)
self.mask_map = MLP([config.corr_model.descriptor_dim, 32, 1])
self.pos_weight = config.pos_weight
def forward(self, data):
# if in the mode of video generating
if 'gen_vid' in data:
dim_m, dim_n = data['ms'].float(), data['ns'].float()
mmax = dim_m.int().max()
nmax = dim_n.int().max()
with torch.no_grad():
self.corr.eval()
score01, score0, score1, dec0, dec1 = self.corr(data)
kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2
motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0
motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1
motion_pred0 = torch.softmax(score0, dim=-1) @ motion_pred0
motion_pred1 = torch.softmax(score1, dim=-1) @ motion_pred1
self.mask_map.eval()
vb0 = self.mask_map(dec0)[:, 0]
vb1 = self.mask_map(dec1)[:, 0]
motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone()
kpt0t = kpts0 + motion_output0
kpt1t = kpts1 + motion_output1
if 'topo0' in data and 'topo1' in data:
# print(len(data['topo0'][0]), len(data['topo1']), flush=True)
for node, nbs in enumerate(data['topo0'][0]):
for nb in nbs:
if vb0[0, nb] and vb0[0, node] and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 3:
vb0[0, nb] = 0
vb0[0, node] = 0
for node, nbs in enumerate(data['topo1'][0]):
for nb in nbs:
if vb1[0, nb] and vb1[0, node] and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) > 3:
vb1[0, nb] = 0
vb1[0, node] = 0
return {'r0': motion_output0, 'r1': motion_output1, 'vb0': vb0, 'vb1': vb1}
# in the normal train/test mode
dim_m, dim_n = data['ms'].float(), data['ns'].float()
mmax = dim_m.int().max()
nmax = dim_n.int().max()
# with torch.no_grad():
# self.corr.eval()
score01, score0, score1, dec0, dec1 = self.corr(data)
kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2
motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0
motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1
motion_pred0 = torch.softmax(score0, dim=-1) @ motion_pred0
motion_pred1 = torch.softmax(score1, dim=-1) @ motion_pred1
max0, max1 = score01.max(2), score01.max(1)
indices0, indices1 = max0.indices, max1.indices
mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
zero = score01.new_tensor(0)
mscores0 = torch.where(mutual0, max0.values.exp(), zero)
mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
# valid0 = mutual0 & (mscores0 > self.config.match_threshold)
# valid1 = mutual1 & valid0.gather(1, indices1)
valid0 = mscores0 > 0.2
valid1 = valid0.gather(1, indices1)
indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
# motion_pred1[0][indices1[0]==-1] = 0
vb0 = self.mask_map(dec0)[:, 0]
vb1 = self.mask_map(dec1)[:, 0]
motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone()
if not self.training:
motion_pred0[0][indices0[0]!=-1] = kpts1[0][indices0[0][indices0[0]!=-1]] - kpts0[0][indices0[0]!=-1]
# # motion_pred0[0][indices0[0]==-1] = 0
motion_pred1[0][indices1[0]!=-1] = kpts0[0][indices1[0][indices1[0]!=-1]] - kpts1[0][indices1[0]!=-1]
vb0[:] = vb0[:] + 0.7 # bias the visibility logits toward "visible" at test time
vb1[:] = vb1[:] + 0.7
# motion0_pred, vb0 = pred0[:, :2].permute(0, 2, 1), pred0[:, 2:][:, 0]
# motion1_pred, vb1 = pred1[:, :2].permute(0, 2, 1), pred1[:, 2:][:, 0]
# delta0, delta1 = motion_delta[:, :, :mmax].permute(0, 2, 1), motion_delta[:, :, mmax:].permute(0, 2, 1)
# motion_output0, motion_output1 = motion0 + delta0, motion1 + delta1
motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone()
im0_erode = data['image0']
im1_erode = data['image1']
im0_erode[im0_erode > 0] = 1
im0_erode[im0_erode <= 0] = 0
im1_erode[im1_erode > 0] = 1
im1_erode[im1_erode <= 0] = 0
im0_erode = tensor_erode(im0_erode, 7)
im1_erode = tensor_erode(im1_erode, 7)
kpt0t = kpts0 + motion_output0 / 2
kpt1t = kpts1 + motion_output1 / 2
##################################################
##  NOTE: the mini-batch size here must be 1!   ##
##################################################
if 'topo0' in data and 'topo1' in data:
# print(len(data['topo0'][0]), len(data['topo1']), flush=True)
for node, nbs in enumerate(data['topo0'][0]):
for nb in nbs:
if vb0[0, nb] > 0 and vb0[0, node] > 0 and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 5:
vb0[0, nb] = -1
vb0[0, node] = -1
for node, nbs in enumerate(data['topo1'][0]):
for nb in nbs:
if vb1[0, nb] > 0 and vb1[0, node] > 0 and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) > 5:
vb1[0, nb] = -1
vb1[0, node] = -1
kpt0t = kpts0 + motion_output0 * 1
kpt1t = kpts1 + motion_output1 * 1
if 'topo0' in data and 'topo1' in data:
## print(len(data['topo0'][0]), len(data['topo1']), flush=True)
for node, nbs in enumerate(data['topo0'][0]):
for nb in nbs:
center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.5).int()[0]
if center[0] >= 720 or center[1] >= 720:
continue
if vb0[0, nb] > 0 and vb0[0, node] > 0 and im1_erode[0,:, center[1], center[0]].mean() > 0.8:
vb0[0, nb] = -1
vb0[0, node] = -1
for node, nbs in enumerate(data['topo1'][0]):
for nb in nbs:
center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.5).int()[0]
if center[0] >= 720 or center[1] >= 720: # same bounds guard as the frame-0 loop above
continue
if vb1[0, nb] > 0 and vb1[0, node] > 0 and im0_erode[0,:, center[1], center[0]].mean() > 0.8:
vb1[0, nb] = -1
vb1[0, node] = -1
kpt0t = kpts0 + motion_output0 / 2
kpt1t = kpts1 + motion_output1 / 2
if 'motion0' in data and 'motion1' in data:
loss_motion = torch.nn.functional.l1_loss(motion_pred0, data['motion0'][:, :mmax]) +\
torch.nn.functional.l1_loss(motion_pred1, data['motion1'][:, :nmax])
EPE0 = ((motion_pred0 - data['motion0'][:, :mmax]) ** 2).sum(dim=-1).sqrt()
EPE1 = ((motion_pred1 - data['motion1'][:, :nmax]) ** 2).sum(dim=-1).sqrt()
# print(EPE0.size(), 'fdsafdsa')
EPE = (EPE0.mean() + EPE1.mean()) * 0.5
if 'visibility0' in data and 'visibility1' in data:
loss_visibility = torch.nn.functional.binary_cross_entropy_with_logits(vb0[:, :mmax].view(-1, 1), data['visibility0'][:, :mmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight])) + \
torch.nn.functional.binary_cross_entropy_with_logits(vb1[:, :nmax].view(-1, 1), data['visibility1'][:, :nmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight]))
VB_Acc = ((((vb0 > 0).float() == data['visibility0'][:, :mmax]).float().sum() + ((vb1 > 0).float() == data['visibility1'][:, :nmax]).float().sum()) * 1.0 / (mmax + nmax))
else:
loss_visibility = 0
VB_Acc = EPE.new_zeros([1])
loss = loss_motion + 10 * loss_visibility
loss_mean = torch.mean(loss)
b, _, _ = motion_pred0.size()
return {
'keypoints0t': kpt0t,
'keypoints1t': kpt1t,
'vb0': (vb0 > 0).float(),
'vb1': (vb1 > 0).float(),
'r0': motion_output0,
'r1': motion_output1,
'loss': loss_mean,
'EPE': EPE,
'Visibility Acc': VB_Acc
}
else:
return {
'loss': -1,
'skip_train': True,
'keypoints0t': kpts0 + motion_output0,
'keypoints1t': kpts1 + motion_output1,
'vb0': vb0,
'vb1': vb1,
}
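The match-extraction block inside `forward` (the `max`/`gather` logic) is mutual nearest-neighbor filtering: a pair is accepted only if each point is the other's argmax and the exponentiated score clears a threshold. A self-contained sketch of that block (the function name and toy scores are illustrative):

```python
import torch

def mutual_nn_matches(scores, threshold=0.2):
    # scores: (b, m, n) log-assignment scores (e.g. the Sinkhorn output
    # without the dustbin row/column). A pair (i, j) is kept only when
    # j is i's best column AND i is j's best row, and the exponentiated
    # score clears the threshold; unmatched points get -1.
    max0, max1 = scores.max(2), scores.max(1)
    indices0, indices1 = max0.indices, max1.indices
    arange0 = torch.arange(scores.shape[1])[None]
    mutual0 = arange0 == indices1.gather(1, indices0)
    mscores0 = torch.where(mutual0, max0.values.exp(), scores.new_tensor(0.))
    valid0 = mscores0 > threshold
    return torch.where(valid0, indices0, indices0.new_tensor(-1))

# Diagonal-dominant toy scores -> the identity matching survives
scores = torch.full((1, 3, 3), -5.0)
scores[0, 0, 0] = scores[0, 1, 1] = scores[0, 2, 2] = -0.1
matches = mutual_nn_matches(scores)
```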
if __name__ == '__main__':
args = argparse.Namespace()
args.batch_size = 2
args.gap = 5
args.type = 'train'
args.model = None
args.action = None
ss = Refiner()  # NOTE: `Refiner` is not defined in this file; likely a stale name for InbetweenerTM(config)
loader = fetch_dataloader(args)
# #print(len(loader))
for data in loader:
# p1, p2, s1, s2, mi = data
dict1 = data
kp1 = dict1['keypoints0']
kp2 = dict1['keypoints1']
p1 = dict1['image0']
p2 = dict1['image1']
# #print(s1)
# #print(s1.type)
mi = dict1['m01']
fname = dict1['file_name']
print(dict1['keypoints0'].size(), dict1['keypoints1'].size(), dict1['m01'].size(), dict1['motion0'].size(), dict1['mask0'].size())
# print(kp1.shape, p1.shape, mi.shape)
# #print(mi.size())
# #print(mi)
# break
a = ss(data)
print(dict1['file_name'])
print(a['loss'])
print(a['EPE'], a['Visibility Acc'],flush=True)
a['loss'].backward()
================================================
FILE: requirement.txt
================================================
opencv-python
pyyaml==5.4.1
scikit-network
tqdm
matplotlib
easydict
gdown
================================================
FILE: srun.sh
================================================
#!/bin/sh
currenttime=`date "+%Y%m%d%H%M%S"`
if [ ! -d log ]; then
mkdir log
fi
echo "[Usage] ./srun.sh config_path [train|eval] partition gpunum"
# check config exists
if [ ! -e $1 ]
then
echo "[ERROR] configuration file: $1 does not exists!"
exit
fi
# NOTE: ${expname} (and ${config_suffix} below) are not set in this script;
# they must be exported by the caller before use.
if [ ! -d ${expname} ]; then
mkdir ${expname}
fi
echo "[INFO] saving results to, or loading files from: "$expname
if [ "$3" == "" ]; then
echo "[ERROR] enter partition name"
exit
fi
partition_name=$3
echo "[INFO] partition name: $partition_name"
if [ "$4" == "" ]; then
echo "[ERROR] enter gpu num"
exit
fi
gpunum=$4
gpunum=$(($gpunum<8?$gpunum:8))
echo "[INFO] GPU num: $gpunum"
((ntask=$gpunum*3))
TOOLS="srun --partition=$partition_name -x SG-IDC2-10-51-5-44 --cpus-per-task=16 --gres=gpu:$gpunum -N 1 --mem-per-gpu=32G --job-name=${config_suffix}"
PYTHONCMD="python -u main.py --config $1"
if [ $2 == "train" ];
then
$TOOLS $PYTHONCMD \
--train
elif [ $2 == "eval" ];
then
$TOOLS $PYTHONCMD \
--eval
elif [ $2 == "gen" ];
then
$TOOLS $PYTHONCMD \
--gen
fi
# elif [ $2 == "visgt" ];
# then
# $TOOLS $PYTHONCMD \
# --visgt
# elif [ $2 == "anl" ];
# then
# $TOOLS $PYTHONCMD \
# --anl
# elif [ $2 == "sample" ];
# then
# $TOOLS $PYTHONCMD \
# --sample
# fi
================================================
FILE: utils/chamfer_distance.py
================================================
import os
import numpy as np
from time import time
import cv2
import pdb
import scipy
import scipy.ndimage
import torch
import torchmetrics
black_threshold = 255.0 * 0.99
def batch_edt(img, block=1024):
expand = False
bs,h,w = img.shape
diam2 = h**2 + w**2
odtype = img.dtype
grid = (img.nelement()+block-1) // block
# cupy implementation
# default to scipy cpu implementation
sums = img.sum(dim=(1,2))
ans = torch.tensor(np.stack([
scipy.ndimage.distance_transform_edt(i)  # `scipy.ndimage.morphology` is deprecated
if s!=0 else # change scipy behavior for empty image
np.ones_like(i) * np.sqrt(diam2)
for i,s in zip(1-img, sums)
]), dtype=odtype)
if expand:
ans = ans.unsqueeze(1)
return ans
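`batch_edt` computes, per image, the Euclidean distance from every pixel to the nearest foreground pixel (via `distance_transform_edt` on the complement), with an all-zero image mapped to the image diagonal as a sentinel. A minimal single-batch sketch (no CUDA/blocking, shapes illustrative):

```python
import numpy as np
import torch
from scipy.ndimage import distance_transform_edt

def edt_to_foreground(img):
    # Per image: Euclidean distance from each pixel to the nearest
    # foreground (value 1) pixel; an all-zero image maps to the image
    # diagonal as a sentinel, mirroring batch_edt above.
    bs, h, w = img.shape
    diam = np.sqrt(h ** 2 + w ** 2)
    out = np.stack([
        distance_transform_edt(1 - i) if i.sum() != 0 else np.ones_like(i) * diam
        for i in img.numpy()
    ])
    return torch.tensor(out, dtype=img.dtype)

img = torch.zeros(1, 5, 5)
img[0, 2, 2] = 1               # single foreground pixel at the center
d = edt_to_foreground(img)
```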
############### DERIVED DISTANCES ###############
# input: (bs,h,w) or (bs,1,h,w)
# returns: (bs,)
# normalized s.t. metric is same across proportional image scales
# average of two asymmetric distances
# normalized by diameter and area
def batch_chamfer_distance(gt, pred, block=1024, return_more=False):
t = batch_chamfer_distance_t(gt, pred, block=block)
p = batch_chamfer_distance_p(gt, pred, block=block)
cd = (t + p) / 2
return cd
def batch_chamfer_distance_t(gt, pred, block=1024, return_more=False):
#pdb.set_trace()
assert gt.device==pred.device and gt.shape==pred.shape
bs,h,w = gt.shape[0], gt.shape[-2], gt.shape[-1]
dpred = batch_edt(pred, block=block)
cd = (gt*dpred).float().mean((-2,-1)) / np.sqrt(h**2+w**2)
if len(cd.shape)==2:
assert cd.shape[1]==1
cd = cd.squeeze(1)
return cd
def batch_chamfer_distance_p(gt, pred, block=1024, return_more=False):
assert gt.device==pred.device and gt.shape==pred.shape
bs,h,w = gt.shape[0], gt.shape[-2], gt.shape[-1]
dgt = batch_edt(gt, block=block)
cd = (pred*dgt).float().mean((-2,-1)) / np.sqrt(h**2+w**2)
if len(cd.shape)==2:
assert cd.shape[1]==1
cd = cd.squeeze(1)
return cd
# normalized by diameter
# always between [0,1]
def batch_hausdorff_distance(gt, pred, block=1024, return_more=False):
assert gt.device==pred.device and gt.shape==pred.shape
bs,h,w = gt.shape[0], gt.shape[-2], gt.shape[-1]
dgt = batch_edt(gt, block=block)
dpred = batch_edt(pred, block=block)
hd = torch.stack([
(dgt*pred).amax(dim=(-2,-1)),
(dpred*gt).amax(dim=(-2,-1)),
]).amax(dim=0).float() / np.sqrt(h**2+w**2)
if len(hd.shape)==2:
assert hd.shape[1]==1
hd = hd.squeeze(1)
return hd
############### TORCHMETRICS ###############
class ChamferDistance2dMetric(torchmetrics.Metric):
full_state_update=False
def __init__(
self, block=1024, convert_dog=True, t=2.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=False,
**kwargs,
):
super().__init__(**kwargs)
self.block = block
self.convert_dog = convert_dog
# store the DoG parameters needed by the T/P subclasses below, which
# call the external `usketchers.batch_dog` helper (not part of this repo)
self.dog_params = {
't': t, 'sigma': sigma, 'k': k, 'epsilon': epsilon,
'kernel_factor': kernel_factor, 'clip': clip,
}
self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
return
def update(self, preds: torch.Tensor, target: torch.Tensor):
dist = batch_chamfer_distance(target, preds, block=self.block)
self.running_sum += dist.sum()
self.running_count += len(dist)
return
def compute(self):
return self.running_sum.float() / self.running_count
class ChamferDistance2dTMetric(ChamferDistance2dMetric):
def update(self, preds: torch.Tensor, target: torch.Tensor):
if self.convert_dog:
preds = (usketchers.batch_dog(preds, **self.dog_params)>0.5).float()
target = (usketchers.batch_dog(target, **self.dog_params)>0.5).float()
dist = batch_chamfer_distance_t(target, preds, block=self.block)
self.running_sum += dist.sum()
self.running_count += len(dist)
return
class ChamferDistance2dPMetric(ChamferDistance2dMetric):
def update(self, preds: torch.Tensor, target: torch.Tensor):
if self.convert_dog:
preds = (usketchers.batch_dog(preds, **self.dog_params)>0.5).float()
target = (usketchers.batch_dog(target, **self.dog_params)>0.5).float()
dist = batch_chamfer_distance_p(target, preds, block=self.block)
self.running_sum += dist.sum()
self.running_count += len(dist)
return
class HausdorffDistance2dMetric(torchmetrics.Metric):
def __init__(
self, block=1024, convert_dog=True,
t=2.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=False,
**kwargs,
):
super().__init__(**kwargs)
self.block = block
self.convert_dog = convert_dog
self.dog_params = {
't': t, 'sigma': sigma, 'k': k, 'epsilon': epsilon,
'kernel_factor': kernel_factor, 'clip': clip,
}
self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
return
def update(self, preds: torch.Tensor, target: torch.Tensor):
if self.convert_dog:
preds = (usketchers.batch_dog(preds, **self.dog_params)>0.5).float()
target = (usketchers.batch_dog(target, **self.dog_params)>0.5).float()
dist = batch_hausdorff_distance(target, preds, block=self.block)
self.running_sum += dist.sum()
self.running_count += len(dist)
return
def compute(self):
return self.running_sum.float() / self.running_count
def rgb2sketch(img, black_threshold):
#pdb.set_trace()
img[img < black_threshold] = 1
img[img >= black_threshold] = 0
#cv2.imwrite("grey.png",img*255)
return torch.tensor(img)
def rgb2gray(rgb):
r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2]
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
return gray
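`rgb2gray` uses the ITU-R BT.601 luma weights, so for example a pure-red image maps every pixel to 0.2989 × 255. A self-contained check:

```python
import numpy as np

def rgb_to_gray(rgb):
    # ITU-R BT.601 luma weights, matching rgb2gray above.
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    return 0.2989 * r + 0.5870 * g + 0.1140 * b

red = np.zeros((2, 2, 3))
red[:, :, 0] = 255.0
gray = rgb_to_gray(red)
```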
def cd_score(img1, img2):
img1 = rgb2gray(img1.astype(float))
img2 = rgb2gray(img2.astype(float))
img1_sketch = rgb2sketch(img1, black_threshold)
img2_sketch = rgb2sketch(img2, black_threshold)
img1_sketch = img1_sketch.unsqueeze(0)
img2_sketch = img2_sketch.unsqueeze(0)
CD = ChamferDistance2dMetric()
cd = CD(img1_sketch,img2_sketch)
return cd
================================================
FILE: utils/log.py
================================================
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this open-source project.
""" Define the Logger class to print log"""
import os
import sys
import logging
from datetime import datetime
class Logger:
def __init__(self, args, output_dir):
log = logging.getLogger(output_dir)
if not log.handlers:
log.setLevel(logging.DEBUG)
# if not os.path.exists(output_dir):
# os.mkdir(args.data.output_dir)
fh = logging.FileHandler(os.path.join(output_dir,'log.txt'))
fh.setLevel(logging.INFO)
ch = ProgressHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
log.addHandler(fh)
log.addHandler(ch)
self.log = log
# setup TensorBoard
# if args.tensorboard:
# from tensorboardX import SummaryWriter
# self.writer = SummaryWriter(log_dir=args.output_dir)
# else:
self.writer = None
self.log_per_updates = args.log_per_updates
def set_progress(self, epoch, total):
self.log.info(f'Epoch: {epoch}')
self.epoch = epoch
self.i = 0
self.total = total
self.start = datetime.now()
def update(self, stats):
self.i += 1
if self.i % self.log_per_updates == 0:
remaining = str((datetime.now() - self.start) / self.i * (self.total - self.i))
remaining = remaining.split('.')[0]
updates = stats.pop('updates')
stats_str = ' '.join(f'{key}[{val:.8f}]' for key, val in stats.items())
self.log.info(f'> epoch [{self.epoch}] updates[{updates}] {stats_str} eta[{remaining}]')
if self.writer:
for key, val in stats.items():
self.writer.add_scalar(f'train/{key}', val, updates)
if self.i == self.total:
self.log.debug('\n')
self.log.debug(f'elapsed time: {str(datetime.now() - self.start).split(".")[0]}')
def log_eval(self, stats, metrics_group=None):
stats_str = ' '.join(f'{key}: {val:.8f}' for key, val in stats.items())
self.log.info(f'valid {stats_str}')
if self.writer:
for key, val in stats.items():
self.writer.add_scalar(f'valid/{key}', val, self.epoch)
# for mode, metrics in metrics_group.items():
# self.log.info(f'evaluation scores ({mode}):')
# for key, (val, _) in metrics.items():
# self.log.info(f'\t{key} {val:.4f}')
# if self.writer and metrics_group is not None:
# for key, val in stats.items():
# self.writer.add_scalar(f'valid/{key}', val, self.epoch)
# for key in list(metrics_group.values())[0]:
# group = {}
# for mode, metrics in metrics_group.items():
# group[mode] = metrics[key][0]
# self.writer.add_scalars(f'valid/{key}', group, self.epoch)
def __call__(self, msg):
self.log.info(msg)
class ProgressHandler(logging.Handler):
def __init__(self, level=logging.NOTSET):
super().__init__(level)
def emit(self, record):
log_entry = self.format(record)
if record.message.startswith('> '):
sys.stdout.write('{}\r'.format(log_entry.rstrip()))
sys.stdout.flush()
else:
sys.stdout.write('{}\n'.format(log_entry))
================================================
FILE: utils/visualize_inbetween.py
================================================
import numpy as np
import torch
import cv2
from .chamfer_distance import cd_score
# def make_inter_graph(v2d1, v2d2, topo1, topo2, match12):
# valid = (match12 != -1)
# marked2 = np.zeros(len(v2d2)).astype(bool)
# # print(match12[valid])
# marked2[match12[valid]] = True
# id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))
# id1toh[valid] = np.arange(np.sum(valid))
# id2toh[match12[valid]] = np.arange(np.sum(valid))
# id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)
# # print(marked2)
# id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))
# id1toh = id1toh.astype(int)
# id2toh = id2toh.astype(int)
# tot_len = len(v2d1) + np.sum(np.invert(marked2))
# vin1 = v2d1[valid][:]
# vin2 = v2d2[match12[valid]][:]
# vh = 0.5 * (vin1 + vin2)
# vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)
# topoh = [[] for ii in range(tot_len)]
# for node in range(len(topo1)):
# for nb in topo1[node]:
# if int(id1toh[nb]) not in topoh[id1toh[node]]:
# topoh[id1toh[node]].append(int(id1toh[nb]))
# for node in range(len(topo2)):
# for nb in topo2[node]:
# if int(id2toh[nb]) not in topoh[id2toh[node]]:
# topoh[id2toh[node]].append(int(id2toh[nb]))
# return vh, topoh
# def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12):
# valid = (match12 != -1)
# marked2 = np.zeros(len(v2d2)).astype(bool)
# # print(match12[valid])
# marked2[match12[valid]] = True
# id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))
# id1toh[valid] = np.arange(np.sum(valid))
# id2toh[match12[valid]] = np.arange(np.sum(valid))
# id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)
# # print(marked2)
# id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))
# id1toh = id1toh.astype(int)
# id2toh = id2toh.astype(int)
# tot_len = len(v2d1) + np.sum(np.invert(marked2))
# vin1 = v2d1[valid][:]
# vin2 = v2d2[match12[valid]][:]
# vh = 0.5 * (vin1 + vin2)
# # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)
# # topoh = [[] for ii in range(tot_len)]
# topoh = [[] for ii in range(np.sum(valid))]
# for node in range(len(topo1)):
# if not valid[node]:
# continue
# for nb in topo1[node]:
# if int(id1toh[nb]) not in topoh[id1toh[node]]:
# if valid[nb]:
# topoh[id1toh[node]].append(int(id1toh[nb]))
# for node in range(len(topo2)):
# if not marked2[node]:
# continue
# for nb in topo2[node]:
# if int(id2toh[nb]) not in topoh[id2toh[node]]:
# if marked2[nb]:
# topoh[id2toh[node]].append(int(id2toh[nb]))
# return vh, topoh
def visualize(dict):
# print(dict['keypoints0'].size(), flush=True)
img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()
original_target = ((dict['imaget'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()
# img1p = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()
# img2p = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()
# img1[:, :, 0] += 255
# img1[:, :, 1] += 180
# img1[:, :, 2] += 180
# img1[img1 > 255] = 255
# img2[:, :, 0] += 255
# img2[:, :, 1] += 180
# img2[:, :, 2] += 180
# img2[img2 > 255] = 255
# img1p[:, :, 0] += 255
# img1p[:, :, 1] += 180
# img1p[:, :, 2] += 180
# img1p[img1p > 255] = 255
# img2p[:, :, 0] += 255
# img2p[:, :, 1] += 180
# img2p[:, :, 2] += 180
# img2p[img2p > 255] = 255
# img1, img2, img1p, img2p = img1.astype(np.uint8), img2.astype(np.uint8), img1p.astype(np.uint8), img2p.astype(np.uint8)
motion01 = dict['motion0'][0].cpu().numpy().astype(int)
motion21 = dict['motion1'][0].cpu().numpy().astype(int)
source0_warp = dict['keypoints0t'][0].cpu().numpy().astype(int)
source2_warp = dict['keypoints1t'][0].cpu().numpy().astype(int)
source0 = dict['keypoints0'][0].cpu().numpy().astype(int)
source2 = dict['keypoints1'][0].cpu().numpy().astype(int)
source0_topo = dict['topo0'][0]
# print(len(dict['topo0']))
source2_topo = dict['topo1'][0]
visible01 = dict['vb0'][0].cpu().numpy().astype(int)
visible21 = dict['vb1'][0].cpu().numpy().astype(int)
# corr01 = dict['m01'][0].cpu().numpy().astype(int)
# corr10 = dict['m10'][0].cpu().numpy().astype(int)
# canvas = np.zeros_like(img1) + 255
# source0_warp2 = source0 + motion01 // 2
# source2_warp2 = source2 + motion21 // 2
# for node, nbs in enumerate(source0_topo):
# for nb in nbs:
# # print([source0_warp[nb][0], source0_warp[nb][1]])
    #         cv2.line(canvas, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)
    # for node, nbs in enumerate(source2_topo):
    #     for nb in nbs:
    #         cv2.line(canvas, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)
    # canvas6 = np.zeros_like(img1) + 255
    # for node, nbs in enumerate(source0_topo):
    #     for nb in nbs:
    #         # print([source0_warp[nb][0], source0_warp[nb][1]])
    #         cv2.line(canvas6, [source0_warp2[node][0], source0_warp2[node][1]], [source0_warp2[nb][0], source0_warp2[nb][1]], [0, 0, 0], 2)
    # for node, nbs in enumerate(source2_topo):
    #     for nb in nbs:
    #         cv2.line(canvas6, [source2_warp2[node][0], source2_warp2[node][1]], [source2_warp2[nb][0], source2_warp2[nb][1]], [0, 0, 0], 2)
    canvas2 = np.zeros_like(img1) + 255
    ## print('huala<<<', source0_warp.mean(), source2_warp.mean(), flush=True)
    # source0_warp = source0 + motion01
    # source2_warp = source2 + motion21
    for node, nbs in enumerate(source0_topo):
        for nb in nbs:
            # if visible01[node] and visible01[nb]:
            cv2.line(canvas2, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)
    for node, nbs in enumerate(source2_topo):
        for nb in nbs:
            # if visible21[node] and visible21[nb]:
            cv2.line(canvas2, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)
    # canvas2
    # black_threshold = 255 // 2
    # img1_sketch = rgb2sketch(img1, black_threshold)
    # img2_sketch = rgb2sketch(img2, black_threshold)
    # img1_sketch = img1_sketch.unsqueeze(0)
    # img2_sketch = img2_sketch.unsqueeze(0)
    # CD = ChamferDistance2dMetric()
    # cd = CD(img1_sketch, img2_sketch)
    canvas5 = np.zeros_like(img1) + 255
    # source0_warp = source0 + motion01
    # source2_warp = source2 + motion21
    ## print('gulaa>>>', visible01.mean(), visible21.mean(), flush=True)
    for node, nbs in enumerate(source0_topo):
        for nb in nbs:
            if visible01[node] and visible01[nb]:
                cv2.line(canvas5, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)
    for node, nbs in enumerate(source2_topo):
        for nb in nbs:
            if visible21[node] and visible21[nb]:
                cv2.line(canvas5, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)
    canvas3 = np.zeros_like(img1) + 255
    for node, nbs in enumerate(source0_topo):
        for nb in nbs:
            cv2.line(canvas3, [source0[node][0], source0[node][1]], [source0[nb][0], source0[nb][1]], [255, 180, 180], 2)
    for node, nbs in enumerate(source2_topo):  # original truncated here by the extraction preview; completed to mirror the source0 loop above
        for nb in nbs:
            cv2.line(canvas3, [source2[node][0], source2[node][1]], [source2[nb][0], source2[nb][1]], [180, 180, 255], 2)  # second color is an assumption
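The loops above render a warped vertex graph edge-by-edge with `cv2.line`, and the `canvas5` variant additionally gates each edge on per-vertex visibility. A minimal, dependency-light sketch of that edge-collection step (the function name `visible_edges` is hypothetical, not from the repo) that gathers the segments which would then be drawn:

```python
import numpy as np

def visible_edges(verts, topo, visible):
    """Collect drawable line segments from a 2D vertex graph.

    verts:   (N, 2) integer array of vertex positions
    topo:    per-vertex neighbor lists; topo[i] = neighbor indices of vertex i
    visible: (N,) bool array; an edge is kept only if both endpoints are visible

    Returns a list of ((x0, y0), (x1, y1)) segments, deduplicated so each
    undirected edge appears once.
    """
    segments = []
    seen = set()
    for node, nbs in enumerate(topo):
        for nb in nbs:
            key = (min(node, nb), max(node, nb))  # canonical undirected edge id
            if key in seen:
                continue
            seen.add(key)
            if visible[node] and visible[nb]:
                segments.append((tuple(verts[node]), tuple(verts[nb])))
    return segments
```

Each returned segment could then be passed to `cv2.line(canvas, seg[0], seg[1], color, thickness)`, matching the drawing pattern used in the repo's visualizers.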
SYMBOL INDEX (196 symbols across 18 files)
FILE: corr/datasets/ml_dataset.py
function read_json (line 21) | def read_json(file_path):
function ids_to_mat (line 36) | def ids_to_mat(id1, id2):
function adj_matrix (line 55) | def adj_matrix(topology):
class MixamoLineArt (line 68) | class MixamoLineArt(data.Dataset):
method __init__ (line 69) | def __init__(self, root, gap=0, split='train', model=None, action=None...
method __getitem__ (line 127) | def __getitem__(self, index):
method __rmul__ (line 272) | def __rmul__(self, v):
method __len__ (line 278) | def __len__(self):
function worker_init_fn (line 282) | def worker_init_fn(worker_id):
function fetch_dataloader (line 285) | def fetch_dataloader(args, type='train',):
FILE: corr/main.py
function parse_args (line 10) | def parse_args():
function main (line 23) | def main():
FILE: corr/models/supergluet.py
function MLP (line 10) | def MLP(channels: list, do_bn=True):
function normalize_keypoints (line 24) | def normalize_keypoints(kpts, image_shape):
class ThreeLayerEncoder (line 33) | class ThreeLayerEncoder(nn.Module):
method __init__ (line 35) | def __init__(self, enc_dim):
method forward (line 53) | def forward(self, img):
class VertexDescriptor (line 61) | class VertexDescriptor(nn.Module):
method __init__ (line 63) | def __init__(self, enc_dim):
method forward (line 68) | def forward(self, img, vtx):
class KeypointEncoder (line 76) | class KeypointEncoder(nn.Module):
method __init__ (line 78) | def __init__(self, feature_dim, layers):
method forward (line 87) | def forward(self, kpts):
class TopoEncoder (line 94) | class TopoEncoder(nn.Module):
method __init__ (line 96) | def __init__(self, feature_dim, layers):
method forward (line 105) | def forward(self, kpts):
function attention (line 113) | def attention(query, key, value, mask=None):
class MultiHeadedAttention (line 125) | class MultiHeadedAttention(nn.Module):
method __init__ (line 127) | def __init__(self, num_heads: int, d_model: int):
method forward (line 135) | def forward(self, query, key, value, mask=None):
class AttentionalPropagation (line 144) | class AttentionalPropagation(nn.Module):
method __init__ (line 145) | def __init__(self, feature_dim: int, num_heads: int):
method forward (line 151) | def forward(self, x, source, mask=None):
class AttentionalGNN (line 156) | class AttentionalGNN(nn.Module):
method __init__ (line 157) | def __init__(self, feature_dim: int, layer_names: list):
method forward (line 164) | def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None,...
function log_sinkhorn_iterations (line 179) | def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
function log_optimal_transport (line 188) | def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
function arange_like (line 229) | def arange_like(x, dim: int):
class SuperGlueT (line 233) | class SuperGlueT(nn.Module):
method __init__ (line 235) | def __init__(self, config=None):
method forward (line 274) | def forward(self, data):
FILE: corr/utils/log.py
class Logger (line 12) | class Logger:
method __init__ (line 13) | def __init__(self, args, output_dir):
method set_progress (line 38) | def set_progress(self, epoch, total):
method update (line 45) | def update(self, stats):
method log_eval (line 62) | def log_eval(self, stats, metrics_group=None):
method __call__ (line 81) | def __call__(self, msg):
class ProgressHandler (line 85) | class ProgressHandler(logging.Handler):
method __init__ (line 86) | def __init__(self, level=logging.NOTSET):
method emit (line 89) | def emit(self, record):
FILE: corr/utils/visualize_vtx_corr.py
function make_inter_graph (line 6) | def make_inter_graph(v2d1, v2d2, topo1, topo2, match12):
function make_inter_graph_valid (line 47) | def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12):
function visualize (line 94) | def visualize(dict):
FILE: corr/vtx_matching.py
class VtxMat (line 36) | class VtxMat():
method __init__ (line 37) | def __init__(self, args):
method train (line 43) | def train(self):
method eval (line 172) | def eval(self):
method _build (line 269) | def _build(self):
method _build_model (line 280) | def _build_model(self):
method _build_train_loader (line 293) | def _build_train_loader(self):
method _build_test_loader (line 297) | def _build_test_loader(self):
method _build_optimizer (line 301) | def _build_optimizer(self):
method _dir_setting (line 314) | def _dir_setting(self):
FILE: datasets/ml_seq.py
function read_json (line 23) | def read_json(file_path):
function matched_motion (line 39) | def matched_motion(v2d1, v2d2, match12, motion_pre=None):
function unmatched_motion (line 47) | def unmatched_motion(topo1, v2d1, motion12, match12):
function ids_to_mat (line 76) | def ids_to_mat(id1, id2):
function adj_matrix (line 97) | def adj_matrix(topology):
class MixamoLineArtMotionSequence (line 110) | class MixamoLineArtMotionSequence(data.Dataset):
method __init__ (line 111) | def __init__(self, root, gap=0, split='train', model=None, action=None...
method __getitem__ (line 184) | def __getitem__(self, index):
method __rmul__ (line 505) | def __rmul__(self, v):
method __len__ (line 510) | def __len__(self):
function worker_init_fn (line 514) | def worker_init_fn(worker_id):
function fetch_dataloader (line 517) | def fetch_dataloader(args, type='train',):
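Both dataset modules expose an `adj_matrix(topology)` helper (line 97 of `datasets/ml_seq.py`, line 55 of `corr/datasets/ml_dataset.py`). The repo's implementation may pad, normalize, or add self-loops differently; a minimal sketch of the idea, assuming `topology` is a per-vertex list of neighbor indices:

```python
import numpy as np

def adj_matrix(topology):
    """Build a dense symmetric adjacency matrix from per-vertex neighbor lists.

    topology[i] is the list of neighbor indices of vertex i, as stored in
    the line-art graph annotations.
    """
    n = len(topology)
    adj = np.zeros((n, n), dtype=np.float32)
    for node, nbs in enumerate(topology):
        for nb in nbs:
            adj[node, nb] = 1.0
            adj[nb, node] = 1.0  # keep the matrix symmetric even if lists are one-sided
    return adj
```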
FILE: datasets/vd_seq.py
function read_json (line 23) | def read_json(file_path):
class VideoLinSeq (line 42) | class VideoLinSeq(data.Dataset):
method __init__ (line 43) | def __init__(self, root, split='train'):
method __getitem__ (line 78) | def __getitem__(self, index):
method __rmul__ (line 171) | def __rmul__(self, v):
method __len__ (line 176) | def __len__(self):
function worker_init_fn (line 180) | def worker_init_fn(worker_id):
function fetch_videoloader (line 183) | def fetch_videoloader(args, type='train',):
FILE: inbetween.py
class DraftRefine (line 40) | class DraftRefine():
method __init__ (line 41) | def __init__(self, args):
method train (line 47) | def train(self):
method eval (line 197) | def eval(self):
method gen (line 325) | def gen(self):
method _build (line 393) | def _build(self):
method _build_model (line 406) | def _build_model(self):
method _build_train_loader (line 419) | def _build_train_loader(self):
method _build_test_loader (line 423) | def _build_test_loader(self):
method _build_video_loader (line 426) | def _build_video_loader(self):
method _build_optimizer (line 430) | def _build_optimizer(self):
method _dir_setting (line 443) | def _dir_setting(self):
FILE: main.py
function parse_args (line 10) | def parse_args():
function main (line 24) | def main():
FILE: models/inbetweener_with_mask2.py
function MLP (line 9) | def MLP(channels: list, do_bn=True):
function normalize_keypoints (line 24) | def normalize_keypoints(kpts, image_shape):
class ThreeLayerEncoder (line 33) | class ThreeLayerEncoder(nn.Module):
method __init__ (line 35) | def __init__(self, enc_dim):
method forward (line 53) | def forward(self, img):
class VertexDescriptor (line 63) | class VertexDescriptor(nn.Module):
method __init__ (line 65) | def __init__(self, enc_dim):
method forward (line 72) | def forward(self, img, vtx):
class KeypointEncoder (line 81) | class KeypointEncoder(nn.Module):
method __init__ (line 83) | def __init__(self, feature_dim, layers):
method forward (line 92) | def forward(self, kpts):
function attention (line 100) | def attention(query, key, value, mask=None):
class MultiHeadedAttention (line 119) | class MultiHeadedAttention(nn.Module):
method __init__ (line 121) | def __init__(self, num_heads: int, d_model: int):
method forward (line 129) | def forward(self, query, key, value, mask=None):
class AttentionalPropagation (line 138) | class AttentionalPropagation(nn.Module):
method __init__ (line 139) | def __init__(self, feature_dim: int, num_heads: int):
method forward (line 145) | def forward(self, x, source, mask=None):
class AttentionalGNN (line 150) | class AttentionalGNN(nn.Module):
method __init__ (line 151) | def __init__(self, feature_dim: int, layer_names: list):
method forward (line 158) | def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None,...
function log_sinkhorn_iterations (line 173) | def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
function log_optimal_transport (line 182) | def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
function arange_like (line 223) | def arange_like(x, dim: int):
class SuperGlueM (line 227) | class SuperGlueM(nn.Module):
method __init__ (line 254) | def __init__(self, config=None):
method forward (line 294) | def forward(self, data):
function tensor_erode (line 402) | def tensor_erode(bin_img, ksize=5):
class InbetweenerM (line 417) | class InbetweenerM(nn.Module):
method __init__ (line 444) | def __init__(self, config=None):
method forward (line 458) | def forward(self, data):
FILE: models/inbetweener_with_mask_with_spec.py
function MLP (line 11) | def MLP(channels: list, do_bn=True):
function normalize_keypoints (line 26) | def normalize_keypoints(kpts, image_shape):
class ThreeLayerEncoder (line 35) | class ThreeLayerEncoder(nn.Module):
method __init__ (line 37) | def __init__(self, enc_dim):
method forward (line 55) | def forward(self, img):
class VertexDescriptor (line 65) | class VertexDescriptor(nn.Module):
method __init__ (line 67) | def __init__(self, enc_dim):
method forward (line 74) | def forward(self, img, vtx):
class KeypointEncoder (line 83) | class KeypointEncoder(nn.Module):
method __init__ (line 85) | def __init__(self, feature_dim, layers):
method forward (line 94) | def forward(self, kpts):
class TopoEncoder (line 100) | class TopoEncoder(nn.Module):
method __init__ (line 102) | def __init__(self, feature_dim, layers):
method forward (line 111) | def forward(self, kpts):
function attention (line 117) | def attention(query, key, value, mask=None):
class MultiHeadedAttention (line 128) | class MultiHeadedAttention(nn.Module):
method __init__ (line 130) | def __init__(self, num_heads: int, d_model: int):
method forward (line 138) | def forward(self, query, key, value, mask=None):
class AttentionalPropagation (line 147) | class AttentionalPropagation(nn.Module):
method __init__ (line 148) | def __init__(self, feature_dim: int, num_heads: int):
method forward (line 154) | def forward(self, x, source, mask=None):
class AttentionalGNN (line 159) | class AttentionalGNN(nn.Module):
method __init__ (line 160) | def __init__(self, feature_dim: int, layer_names: list):
method forward (line 167) | def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None,...
function log_sinkhorn_iterations (line 182) | def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
function log_optimal_transport (line 191) | def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
function arange_like (line 231) | def arange_like(x, dim: int):
class SuperGlueT (line 235) | class SuperGlueT(nn.Module):
method __init__ (line 262) | def __init__(self, config=None):
method forward (line 300) | def forward(self, data):
function tensor_erode (line 390) | def tensor_erode(bin_img, ksize=5):
class InbetweenerTM (line 402) | class InbetweenerTM(nn.Module):
method __init__ (line 418) | def __init__(self, config=None):
method forward (line 426) | def forward(self, data):
FILE: utils/chamfer_distance.py
function batch_edt (line 14) | def batch_edt(img, block=1024):
function batch_chamfer_distance (line 46) | def batch_chamfer_distance(gt, pred, block=1024, return_more=False):
function batch_chamfer_distance_t (line 51) | def batch_chamfer_distance_t(gt, pred, block=1024, return_more=False):
function batch_chamfer_distance_p (line 61) | def batch_chamfer_distance_p(gt, pred, block=1024, return_more=False):
function batch_hausdorff_distance (line 73) | def batch_hausdorff_distance(gt, pred, block=1024, return_more=False):
class ChamferDistance2dMetric (line 90) | class ChamferDistance2dMetric(torchmetrics.Metric):
method __init__ (line 92) | def __init__(
method update (line 104) | def update(self, preds: torch.Tensor, target: torch.Tensor):
method compute (line 110) | def compute(self):
class ChamferDistance2dTMetric (line 113) | class ChamferDistance2dTMetric(ChamferDistance2dMetric):
method update (line 114) | def update(self, preds: torch.Tensor, target: torch.Tensor):
class ChamferDistance2dPMetric (line 122) | class ChamferDistance2dPMetric(ChamferDistance2dMetric):
method update (line 123) | def update(self, preds: torch.Tensor, target: torch.Tensor):
class HausdorffDistance2dMetric (line 132) | class HausdorffDistance2dMetric(torchmetrics.Metric):
method __init__ (line 133) | def __init__(
method update (line 148) | def update(self, preds: torch.Tensor, target: torch.Tensor):
method compute (line 156) | def compute(self):
function rgb2sketch (line 162) | def rgb2sketch(img, black_threshold):
function rgb2gray (line 168) | def rgb2gray(rgb):
function cd_score (line 174) | def cd_score(img1, img2):
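`utils/chamfer_distance.py` computes its metrics via a batched Euclidean distance transform (`batch_edt`). A brute-force, dependency-free sketch of the underlying symmetric Chamfer distance between two binary sketches — O(N·M) rather than the repo's EDT-based approach, and the function name is hypothetical:

```python
import numpy as np

def chamfer_distance_2d(sketch_a, sketch_b):
    """Symmetric Chamfer distance between two binary sketches.

    sketch_a, sketch_b: 2D bool arrays where True marks an ink pixel.
    For each ink pixel in one image, take the distance to the nearest
    ink pixel in the other; average the two directional means.
    """
    pts_a = np.argwhere(sketch_a).astype(np.float64)
    pts_b = np.argwhere(sketch_b).astype(np.float64)
    if len(pts_a) == 0 or len(pts_b) == 0:
        raise ValueError("both sketches must contain at least one ink pixel")
    # pairwise Euclidean distances, shape (N, M), via broadcasting
    d = np.linalg.norm(pts_a[:, None, :] - pts_b[None, :, :], axis=-1)
    return 0.5 * (d.min(axis=1).mean() + d.min(axis=0).mean())
```

The repo's EDT-based version scales to full images; this sketch is only practical for small point sets but makes the metric's definition explicit.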
FILE: utils/log.py
class Logger (line 12) | class Logger:
method __init__ (line 13) | def __init__(self, args, output_dir):
method set_progress (line 38) | def set_progress(self, epoch, total):
method update (line 45) | def update(self, stats):
method log_eval (line 62) | def log_eval(self, stats, metrics_group=None):
method __call__ (line 81) | def __call__(self, msg):
class ProgressHandler (line 85) | class ProgressHandler(logging.Handler):
method __init__ (line 86) | def __init__(self, level=logging.NOTSET):
method emit (line 89) | def emit(self, record):
FILE: utils/visualize_inbetween.py
function visualize (line 95) | def visualize(dict):
FILE: utils/visualize_inbetween2.py
function visualize (line 95) | def visualize(dict):
FILE: utils/visualize_inbetween3.py
function visualize (line 95) | def visualize(dict):
FILE: utils/visualize_video.py
function visvid (line 7) | def visvid(dict, inter_frames=1):
About this extraction
This page contains the full source code of the lisiyao21/AnimeInbet GitHub repository, extracted and formatted as plain text for AI agents and large language models. The extraction covers 35 files (217.5 KB, approximately 64.5k tokens) and a symbol index of 196 functions, classes, methods, constants, and types.
Extracted by GitExtract, a free GitHub-repo-to-text converter built by Nikandr Surkov.