Full Code of lisiyao21/AnimeInbet for AI

Repository: lisiyao21/AnimeInbet
Branch: main
Commit: cc5554feb9d8
Files: 35
Total size: 217.5 KB

Directory structure:
gitextract_oa9f83a9/

├── .gitignore
├── README.md
├── compute_cd.py
├── configs/
│   └── cr_inbetweener_full.yaml
├── corr/
│   ├── configs/
│   │   └── vtx_corr.yaml
│   ├── datasets/
│   │   ├── __init__.py
│   │   └── ml_dataset.py
│   ├── experiments/
│   │   └── vtx_corr/
│   │       └── ckpt/
│   │           └── .gitkeep
│   ├── main.py
│   ├── models/
│   │   ├── __init__.py
│   │   └── supergluet.py
│   ├── srun.sh
│   ├── utils/
│   │   ├── log.py
│   │   └── visualize_vtx_corr.py
│   └── vtx_matching.py
├── data/
│   └── README.md
├── datasets/
│   ├── __init__.py
│   ├── ml_seq.py
│   └── vd_seq.py
├── download.sh
├── experiments/
│   └── inbetweener_full/
│       └── ckpt/
│           └── .gitkeep
├── inbetween.py
├── inbetween_results/
│   └── .gitkeep
├── main.py
├── models/
│   ├── __init__.py
│   ├── inbetweener_with_mask2.py
│   └── inbetweener_with_mask_with_spec.py
├── requirement.txt
├── srun.sh
└── utils/
    ├── chamfer_distance.py
    ├── log.py
    ├── visualize_inbetween.py
    ├── visualize_inbetween2.py
    ├── visualize_inbetween3.py
    └── visualize_video.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*/__pycache__/*
*.pt
*.jpg
*.pyc
data/ml100_norm/
data/ml144*
data/*.zip


================================================
FILE: README.md
================================================
# AnimeInbet

Code for ICCV 2023 paper "Deep Geometrized Cartoon Line Inbetweening"

[[Paper]](https://openaccess.thecvf.com/content/ICCV2023/papers/Siyao_Deep_Geometrized_Cartoon_Line_Inbetweening_ICCV_2023_paper.pdf) | [[Video Demo]](https://youtu.be/iUF-LsqFKpI?si=9FViAZUyFdSfZzS5) | [[Data (Google Drive)]](https://drive.google.com/file/d/1SNRGajIECxNwRp6ZJ0IlY7AEl2mRm2DR/view?usp=sharing) 

✨ Do not hesitate to give a star! Thank you! ✨


![image](https://github.com/lisiyao21/AnimeInbet/blob/main/figures/inbet_gif.gif)

> We aim to address a significant but understudied problem in the anime industry, namely the inbetweening of cartoon line drawings. Inbetweening involves generating intermediate frames between two black-and-white line drawings and is a time-consuming and expensive process that can benefit from automation. However, existing frame interpolation methods that rely on matching and warping whole raster images are unsuitable for line inbetweening and often produce blurring artifacts that damage the intricate line structures. To preserve the precision and detail of the line drawings, we propose a new approach, AnimeInbet, which geometrizes raster line drawings into graphs of endpoints and reframes the inbetweening task as a graph fusion problem with vertex repositioning. Our method can effectively capture the sparsity and unique structure of line drawings while preserving the details during inbetweening. This is made possible via our novel modules, i.e., vertex geometric embedding, a vertex correspondence Transformer, an effective mechanism for vertex repositioning and a visibility predictor. To train our method, we introduce MixamoLine240, a new dataset of line drawings with ground truth vectorization and matching labels. Our experiments demonstrate that AnimeInbet synthesizes high-quality, clean, and complete intermediate line drawings, outperforming existing methods quantitatively and qualitatively, especially in cases with large motions.

# ML240 Data

The implementation of AnimeInbet depends on matching the line vertices of two adjacent frames. To supervise the learning of vertex correspondence, we built a large-scale sequential cartoon line dataset, **MixamoLine240** (ML240). ML240 contains a training set (100 sequences), a validation set (44 sequences) and a test set (100 sequences).

To use the data, please first download it from [link](https://drive.google.com/file/d/1SNRGajIECxNwRp6ZJ0IlY7AEl2mRm2DR/view?usp=sharing) and uncompress it into the **data** folder under this project directory. After decompression, the data layout will look like:

        data
          |_ml100_norm
          |        |_ all
          |             |_frames  
          |             |    |_chip_abe
          |             |    |     |_Image0001.png
          |             |    |     |_Image0002.png
          |             |    |     |
          |             |    |     ...  
          |             |    ... 
          |             |
          |             |_labels
          |                  |_chip_abe
          |                  |     |_Line0001.json
          |                  |     |_Line0002.json
          |                  |     |
          |                  |     ...  
          |                  ...
          | 
          |_ml144_norm_100_44_split  
                  |_ test
                  |    |_frames  
                  |    |    |_breakdance_1990_police
                  |    |    |     |_Image0001.png
                  |    |    |     |_Image0002.png
                  |    |    |     |
                  |    |    |     ...  
                  |    |    ... 
                  |    |
                  |    |_labels
                  |         |_breakdance_1990_police
                  |         |     |_Line0001.json
                  |         |     |_Line0002.json
                  |         |     |
                  |         |     ...  
                  |         ...
                  |_ train
                      |_frames  
                      |    |_breakdance_1990_ganfaul
                      |    |     |_Image0001.png
                      |    |     |_Image0002.png
                      |    |     |
                      |    |     ...  
                      |    ... 
                      |
                      |_labels
                          |_breakdance_1990_ganfaul
                          |     |_Line0001.json
                          |     |_Line0002.json
                          |     |
                          |     ...  
                          ...


The json files in the "labels" folder (for example, ml100_norm/all/labels/chip_abe/Line0001.json) are the vectorization/geometrization labels of the corresponding images in the "frames" folder (ml100_norm/all/frames/chip_abe/Image0001.png). Each json file contains three components: (1) **vertex location**: the 2D positions of the line art vertices, (2) **connection**: the adjacency list of the vector graph, and (3) **original index**: the index number of each vertex in the original 3D mesh.
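As an illustration, a label file can be parsed with a few lines of standard Python. `load_label` is a hypothetical helper name (the repository's own loader is `read_json` in corr/datasets/ml_dataset.py), and the commented path is just an example from the layout above:

```python
import json
import numpy as np

def load_label(path):
    """Read one geometrization label file: vertex positions, adjacency, mesh ids."""
    with open(path) as f:
        data = json.load(f)
    vertex2d = np.array(data['vertex location'])   # (N, 2) vertex positions
    topology = data['connection']                  # adjacency list, one list per vertex
    index = np.array(data['original index'])       # per-vertex id in the source 3D mesh
    return vertex2d, topology, index

# v2d, topo, idx = load_label('data/ml100_norm/all/labels/chip_abe/Line0001.json')
```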


# Code

## Environment 

    conda create -n inbetween python=3.8
    conda activate inbetween
    conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=10.1 -c pytorch
    pip install -r requirement.txt


![image](https://github.com/lisiyao21/AnimeInbet/blob/main/figures/pipeline.png)

In this code, the whole pipeline is separated into two parts: (1) vertex correspondence and (2) inbetweening/synthesis. The first part, comprising the "vertex embedding" and the "vertex corr. Transformer", is trained to match the vertices of two input vector graphs. The "repositioning propagation" and "graph fusion" steps are then done in the second part.

The first part lives under ./corr, and the second part is everything else. We provide a pretrained correspondence network weight ([link](https://drive.google.com/file/d/1Edc-XGyMXqXDdfBYoglDMkBf7_AYZU0p/view?usp=sharing)) and a pretrained whole-pipeline weight ([link](https://drive.google.com/file/d/1cemJCBNdcTvJ9LWCA_5LmDDorwEb-u7M/view?usp=sharing)). For correspondence, please decompress the weight (epoch_50.pt) to ./corr/experiments/vtx_corr/ckpt. For the whole pipeline, please decompress the weight (epoch_20.pt) to ./experiments/inbetweener_full/ckpt/.


## Train & test corr.

For training, please first cd into the ./corr folder and then run

    sh srun.sh configs/vtx_corr.yaml train [your node name] 1

If you don't use Slurm on your machine/cluster, you can run

    python -u main.py --config configs/vtx_corr.yaml --train

For testing the correspondence network, please run

    sh srun.sh configs/vtx_corr.yaml test [your node name] 1

or

    python -u main.py --config configs/vtx_corr.yaml --eval

After downloading the weights, you may run the test code directly without training.

## Train & test the whole inbetweening pipeline

For training the whole pipeline, please first cd back from ./corr to the project root folder and run

    sh srun.sh configs/cr_inbetweener_full.yaml train [your node name] 1

or

    python -u main.py --config configs/cr_inbetweener_full.yaml --train

For testing, please run

    sh srun.sh configs/cr_inbetweener_full.yaml test [your node name] 1

or

    python -u main.py --config configs/cr_inbetweener_full.yaml --test

Inbetweened results will be stored in the ./inbetween_results folder.

### Compute CD values

The CD code is under utils/chamfer_distance.py. Please run

    python compute_cd.py --gt ./data/ml100_norm/all/frames --generated ./inbetween_results/test_gap=5

If everything goes right, the reported score should match the one in the paper.
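For intuition, the Chamfer distance (CD) is a symmetric nearest-neighbor distance between two sets. Below is a minimal NumPy sketch on toy 2D point sets; `chamfer` here is just an illustrative name, and the actual image-based metric used for the reported numbers is the one in utils/chamfer_distance.py:

```python
import numpy as np

def chamfer(a, b):
    # pairwise distances between the two point sets, shape (len(a), len(b))
    d = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    # symmetric: mean nearest-neighbor distance a->b plus b->a
    return d.min(axis=1).mean() + d.min(axis=0).mean()

a = np.array([[0.0, 0.0], [1.0, 0.0]])
b = np.array([[0.0, 1.0], [1.0, 1.0]])
chamfer(a, b)  # -> 2.0 (each point is distance 1 from its nearest neighbor)
```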


# Citation

If you use our code or data, or find our work inspiring, please kindly cite our paper:

    @inproceedings{siyao2023inbetween,
        title={Deep Geometrized Cartoon Line Inbetweening},
        author={Siyao, Li and Gu, Tianpei and Xiao, Weiye and Ding, Henghui and Liu, Ziwei and Loy, Chen Change},
        booktitle={ICCV},
        year={2023}
    }

# License

ML240 is released under CC BY-NC-SA 4.0. The code is released for non-commercial use only.



================================================
FILE: compute_cd.py
================================================
import argparse
import cv2
import os
from utils.chamfer_distance import cd_score
import numpy as np




if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--generated', type=str)
    parser.add_argument('--gt', type=str)
    args = parser.parse_args()

    gen_dir = args.generated
    gt_dir = args.gt

    cds = []
    print('computing CD...', flush=True)

    for subfolder in os.listdir(gt_dir):
        for img in os.listdir(os.path.join(gt_dir, subfolder)):
            if not img.endswith('.png'):
                continue
            img_gt = cv2.imread(os.path.join(gt_dir, subfolder, img))

            # generated frames are stored flat as "<sequence>_Line<number>.png"
            pred_name = subfolder + '_' + img.replace('Image', 'Line')
            if not os.path.exists(os.path.join(gen_dir, pred_name)):
                continue
            img_pred = cv2.imread(os.path.join(gen_dir, pred_name))

            cds.append(cd_score(img_gt, img_pred))

    print('GT: ', gt_dir)
    print('>>> Gen: ', gen_dir)
    # report the mean CD in units of 1e-5, together with the number of pairs evaluated
    print('>>> CD: ', np.mean(cds) / 1e-5, '({} image pairs)'.format(len(cds)))





================================================
FILE: configs/cr_inbetweener_full.yaml
================================================
model:
    name: InbetweenerTM
    corr_model:
        descriptor_dim: 128
        keypoint_encoder: [32, 64, 128]
        GNN_layer_num: 12
        sinkhorn_iterations: 20
        match_threshold: 0.2
    pos_weight: 0.2

optimizer:
    type: Adam
    kwargs:
        lr: 0.0001
        betas: [0.9, 0.999]
        weight_decay: 0
    schedular_kwargs:
        milestones: [80]
        gamma: 0.1

data:
    train:
        root: 'data/ml144_norm_100_44_split/'
        batch_size: 1
        gap: 5
        type: 'train'
        model: None
        action: None
        mode: 'train'
    test:
        root: 'data/ml100_norm/'
        batch_size: 1
        gap: 5
        type: 'all'
        model: None
        action: None
        mode: 'eval'
        use_vs: False

testing:
    ckpt_epoch: 20
    
batch_size: 8

corr_weights: './corr/experiments/vtx_corr/ckpt/epoch_50.pt'

imwrite_dir: ./inbetween_results/test_gap=5

expname: inbetweener_full
epoch: 20
save_per_epochs: 1
log_per_updates: 1
test_freq: 10
seed: 42


================================================
FILE: corr/configs/vtx_corr.yaml
================================================
model:
    name: SuperGlueT
    descriptor_dim: 128
    keypoint_encoder: [32, 64, 128]
    GNN_layer_num: 12
    sinkhorn_iterations: 20
    match_threshold: 0.2

optimizer:
    type: Adam
    kwargs:
        lr: 0.00001
        betas: [0.9, 0.999]
        weight_decay: 0
    schedular_kwargs:
        milestones: [50, 150]
        gamma: 0.1

data:
    train:
        batch_size: 1
        gap: 5
        model: None
        action: None
        type: 'train'
        mode: 'train'
    test:
        batch_size: 1
        gap: 5
        type: 'test'
        model: None
        action: None
        mode: 'eval'

testing:
    ckpt_epoch: 50
batch_size: 8

expname: vtx_corr
epoch: 50
save_per_epochs: 1
log_per_updates: 1
test_freq: 1
seed: 42


================================================
FILE: corr/datasets/__init__.py
================================================
from .ml_dataset import MixamoLineArt
from .ml_dataset import fetch_dataloader

__all__ = ['MixamoLineArt', 'fetch_dataloader']

================================================
FILE: corr/datasets/ml_dataset.py
================================================
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
# import networkx as nx
import os
import math
import random
from glob import glob
import os.path as osp

import sys
import argparse
import cv2
from collections import Counter

import json
import sknetwork
from sknetwork.embedding import Spectral

def read_json(file_path):
    """
        input: json file path
        output: 2d vertex, connections, and index numbers in original 3D space
    """

    with open(file_path) as file:
        data = json.load(file)
        vertex2d = np.array(data['vertex location'])
        
        topology = data['connection']
        index = np.array(data['original index'])

    return vertex2d, topology, index

def ids_to_mat(id1, id2):
    """
        inputs are two list of vertex index in original 3D mesh
    """
    corr1 = np.zeros(len(id1)) - 1.0
    corr2 = np.zeros(len(id2)) - 1.0

    id1 = np.array(id1).astype(int)[:, None]
    id2 = np.array(id2).astype(int)
    
    mat = (id1 == id2)

    pos12 = np.arange(len(id2))[None].repeat(len(id1), 0)
    pos21 = np.arange(len(id1))[None].repeat(len(id2), 0)
    corr1[mat.astype(int).sum(1).astype(bool)] = pos12[mat]
    corr2[mat.transpose().astype(int).sum(1).astype(bool)] = pos21[mat.transpose()]

    return mat, corr1, corr2
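To make the index semantics concrete, here is a standalone sketch on a toy pair of id lists; the function body replicates `ids_to_mat` above so the example runs on its own and is not part of the repository file:

```python
import numpy as np

def ids_to_mat(id1, id2):
    # same NumPy steps as the repository function above
    corr1 = np.zeros(len(id1)) - 1.0
    corr2 = np.zeros(len(id2)) - 1.0
    id1 = np.array(id1).astype(int)[:, None]
    id2 = np.array(id2).astype(int)
    mat = (id1 == id2)                 # (len(id1), len(id2)) match matrix
    pos12 = np.arange(len(id2))[None].repeat(len(id1), 0)
    pos21 = np.arange(len(id1))[None].repeat(len(id2), 0)
    corr1[mat.astype(int).sum(1).astype(bool)] = pos12[mat]
    corr2[mat.transpose().astype(int).sum(1).astype(bool)] = pos21[mat.transpose()]
    return mat, corr1, corr2

# mesh id 7 appears at position 0 in frame A and position 2 in frame B;
# mesh id 9 at position 2 in A and position 0 in B; id 5 only exists in A
mat, corr1, corr2 = ids_to_mat([7, 5, 9], [9, 3, 7])
# corr1[i] = index in id2 matched to id1[i], or -1 if unmatched:
# corr1 -> [2., -1., 0.] and corr2 -> [2., -1., 0.]
```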

def adj_matrix(topology):
    """
        topology is the adj table; returns adj matrix
    """
    gsize = len(topology)
    adj = np.zeros((gsize, gsize)).astype(float)
    for v in range(gsize):
        adj[v][v] = 1.0
        for nb in topology[v]:
            adj[v][nb] = 1.0
            adj[nb][v] = 1.0
    return adj
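A standalone illustration of `adj_matrix` on a 3-vertex chain; the function body is replicated from above so the snippet runs on its own and is not part of the repository file:

```python
import numpy as np

def adj_matrix(topology):
    # adjacency list -> dense symmetric adjacency matrix with self-loops
    gsize = len(topology)
    adj = np.zeros((gsize, gsize), dtype=float)
    for v in range(gsize):
        adj[v][v] = 1.0
        for nb in topology[v]:
            adj[v][nb] = 1.0
            adj[nb][v] = 1.0
    return adj

adj_matrix([[1], [0, 2], [1]])
# -> [[1., 1., 0.],
#     [1., 1., 1.],
#     [0., 1., 1.]]
```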

class MixamoLineArt(data.Dataset):
    def __init__(self, root, gap=0, split='train', model=None, action=None, mode='train', max_len=3050, use_vs=False):
        """
            input:
                root: the root folder of the line art data
                gap: how many frames between two frames
                split: train or test
                model: indicate a specific character (default None)
                action: indicate a specific action (default None)
        """
        super(MixamoLineArt, self).__init__()


        if model == 'None':
            model = None
        if action == 'None':
            action = None

        self.is_train = True if mode == 'train' else False
        self.is_eval = True if mode == 'eval' else False
        # self.is_train = False
        self.max_len = max_len

        self.image_list = []
        self.label_list = []
        
        if use_vs:
            label_root = osp.join(root, split, 'labels_vs')
        else:
            label_root = osp.join(root, split, 'labels')
        image_root = osp.join(root, split, 'frames')
        self.spectral = Spectral(64,  normalized=False)

        for clip in os.listdir(image_root):
            skip = False
            if model != None:
                for mm in model:
                    if mm in clip:
                        skip = True
                
            if action != None:
                for aa in action:
                    if aa in clip:
                        skip = True
            if skip:
                continue
            image_list = sorted(glob(osp.join(image_root, clip, '*.png')))
            label_list = sorted(glob(osp.join(label_root, clip, '*.json')))
            if len(image_list) != len(label_list):
                print(image_root, flush=True)
                continue
            for i in range(len(image_list) - (gap+1)):
                self.image_list += [ [image_list[i], image_list[i+gap+1]] ]
            for i in range(len(label_list) - (gap+1)):
                self.label_list += [ [label_list[i], label_list[i+gap+1]] ]
        # print(clip)
        print('Len of Frame is ', len(self.image_list))
        print('Len of Label is ', len(self.label_list))

    def __getitem__(self, index):

        # load image/label files
        # image crop to a square, 2d label same operation
        # index to index matching
        # spectral embedding

        # test does not need index matching
        
        index = index % len(self.image_list)
        file_name = self.label_list[index][0][:-4]
  
        img1 = cv2.imread(self.image_list[index][0])
        img2 = cv2.imread(self.image_list[index][1])
        v2d1, topo1, id1 = read_json(self.label_list[index][0])
        v2d2, topo2, id2 = read_json(self.label_list[index][1])
        for ii in range(len(topo1)):
            # if not len(topo1[ii]):
            topo1[ii].append(ii)
        for ii in range(len(topo2)):
            topo2[ii].append(ii)

    
        m, n = len(v2d1), len(v2d2)

        # img1, v2d1 = crop_img(img1, np.array(v2d1))
        # img2, v2d2 = crop_img(img2, np.array(v2d2))

        if len(img1.shape) == 2:
            img1 = np.tile(img1[...,None], (1, 1, 3))
            img2 = np.tile(img2[...,None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]
        
        img1 = torch.from_numpy(img1).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float() * 2 / 255.0 - 1.0

        v2d1 = torch.from_numpy(v2d1)
        v2d2 = torch.from_numpy(v2d2)

        v2d1[v2d1 > 719] = 719
        v2d1[v2d1 < 0] = 0
        v2d2[v2d2 > 719] = 719
        v2d2[v2d2 < 0] = 0


        adj1 = sknetwork.data.from_adjacency_list(topo1, matrix_only=True, reindex=False).toarray()
        adj2 = sknetwork.data.from_adjacency_list(topo2, matrix_only=True, reindex=False).toarray()

        # note: we compute the spectral embedding of the adjacency matrix during data loading,
        # since it needs CPU computation and is not friendly to our cluster's setup;
        # doing it here lets multiple dataloader workers pre-compute it before the network forward pass
        spec0, spec1 = np.abs(self.spectral.fit_transform(adj1)), np.abs(self.spectral.fit_transform(adj2))

        mat_index, corr1, corr2 = ids_to_mat(id1, id2)
        mat_index = torch.from_numpy(mat_index).float()
        corr1 = torch.from_numpy(corr1).float()
        corr2 = torch.from_numpy(corr2).float()
        if self.is_train:
        # if False:
            v2d1 = torch.nn.functional.pad(v2d1, (0, 0, 0, self.max_len - m), mode='constant', value=0)
            v2d2 = torch.nn.functional.pad(v2d2, (0, 0, 0, self.max_len - n), mode='constant', value=0)
            corr1 = torch.nn.functional.pad(corr1, (0, self.max_len - m), mode='constant', value=0)
            corr2 = torch.nn.functional.pad(corr2, (0, self.max_len - n), mode='constant', value=0)

            mask0, mask1 = torch.zeros(self.max_len).float(), torch.zeros(self.max_len).float()
            mask0[:m] = 1
            mask1[:n] = 1
        else:
            mask0, mask1 = torch.ones(m).float(), torch.ones(n).float()

        # id lists are not returned anymore: too slow
        ret = {
            'keypoints0': v2d1,
            'keypoints1': v2d2,
            'adj_mat0': spec0,
            'adj_mat1': spec1,
            'image0': img1,
            'image1': img2,

            'all_matches': corr1,
            'm01': corr1,
            'm10': corr2,
            'ms': m,
            'ns': n,
            'mask0': mask0,
            'mask1': mask1,
            'file_name': file_name,
        }
        # only evaluation additionally needs the raw topology (e.g. for visualization)
        if self.is_eval:
            ret['topo0'] = [topo1]
            ret['topo1'] = [topo2]
        return ret

        

    def __rmul__(self, v):
        # repeat the dataset v times
        self.image_list = v * self.image_list
        self.label_list = v * self.label_list
        return self
        
    def __len__(self):
        return len(self.image_list)
        

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

def fetch_dataloader(args, type='train'):
    lineart = MixamoLineArt(root=args.root if hasattr(args, 'root') else '../data/ml144_norm_100_44_split/', gap=args.gap, split=args.type, model=args.model, action=args.action, mode=args.mode if hasattr(args, 'mode') else 'train', use_vs=args.use_vs if hasattr(args, 'use_vs') else False)

    if args.mode != 'train':
        # evaluation/test: keep the original order and do not drop the tail batch
        return data.DataLoader(lineart, batch_size=args.batch_size,
            pin_memory=True, shuffle=False, num_workers=8)

    return data.DataLoader(lineart, batch_size=args.batch_size,
        pin_memory=True, shuffle=True, num_workers=8, drop_last=True,
        worker_init_fn=worker_init_fn)


if __name__ == '__main__':
    torch.multiprocessing.set_sharing_strategy('file_system')
    args = argparse.Namespace()
    # args.subset = 'agent'
    args.batch_size = 1
    args.gap = 5
    args.type = 'test'
    args.model = ['ganfaul', 'firlscout', 'jolleen', 'kachujin', 'knight', 'maria', 'michelle', 'peasant', 'timmy', 'uriel']
    args.action = ['hip_hop', 'slash']
    # args.model = None
    # args.action = None
    args.use_vs = False
    # args.model = ['warrok', 'police']
    args.action = ['breakdance', 'capoeira', 'chapa-', 'fist_fight', 'flying', 'climb', 'running', 'reaction', 'magic', 'tripping']
        
    args.mode = 'eval'
    args.root='/mnt/lustre/syli/inbetween/data/12by12/ml144_norm_100_44_split/'
    # args.stage = 'anime'
    # args.image_size = (368, 368)
    # lineart = MixamoLineArt(root='/mnt/lustre/syli/inbetween/data/12by12/ml144/', gap=0, split='train')
    lineart = fetch_dataloader(args)
    # lineart = MixamoLineArt(root='/mnt/cache/syli/inbetween/data/ml100_norm/', gap=args.gap, split=args.type, model=args.model, action=args.action, mode=args.mode if hasattr(args, 'mode') else 'train')
    # train_loader = data.DataLoader(lineart, batch_size=args.batch_size, 

    percentage = 0.0
    vertex_num = 0.0
    vertex_shift = 0.0
    vertex_max_shift = 0.0
    edges = 0.0
    # for data in loader:
    #     print(data)
    #     break
    unmatched_all = []
    max_node = 0
    for sample in lineart:
        v2d1 = sample['keypoints0'].numpy().astype(int)[0]
        v2d2 = sample['keypoints1'].numpy().astype(int)[0]

        ms = sample['ms'][0]
        ns = sample['ns'][0]
        topo = sample['topo0'][0]
        for ii in range(len(topo)):
            edges += len(topo[ii])

        v2d1 = v2d1[:ms]
        v2d2 = v2d2[:ns]
        m01 = sample['m01'][0][:ms]
        # per-vertex displacement between matched vertices of the two frames
        shift = np.sqrt(((v2d2[m01[m01 != -1].int(), :] * 1.0 - v2d1[np.arange(len(m01))[m01 != -1], :]) ** 2).sum(-1))
        vertex_num += len(v2d1)
        vertex_shift += shift.mean()
        vertex_max_shift += shift.max()
        percentage += ((m01 != -1).float().sum() * 1.0 / len(m01) * 100.0)

    print('>>>> gap=', args.gap, ' percentage=', percentage / len(lineart), ' vertex num=', vertex_num * 1.0 / len(lineart), 'edges num=', edges * 1.0 / len(lineart) / 2, 'vertex shift=', vertex_shift / len(lineart), ' vertex max shift=', vertex_max_shift / len(lineart), flush=True)
        

        # if len(v2d1) > max_node:
        #     max_node = len(v2d1)
        # if len(v2d2) > max_node:
        #     max_node = len(v2d2)
    # print(max_node)
        # print(v2d1.shape)
        # img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()
        # img2 = ((dict['image1'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()

        # # print(v2d1.shape, img1.shape, flush=True)

        # for node, nbs in enumerate(dict['topo0']):
        #     for nb in nbs:
        #         cv2.line(img1, [v2d1[node][0], v2d1[node][1]], [v2d1[nb][0], v2d1[nb][1]], [255, 180, 180], 2)
        # colors1, colors2 = {}, {}

        # id1 = dict['id0'][0].numpy()
        # id2 = dict['id1'][0].numpy()
        # for index in id1:
        #     # print(index)
        #     color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]
        #     # for ii in index:
        #     colors1[index] = color
        
        # colors1, colors2 = {}, {}


        # for index in id1:
        #     # print(index)
        #     color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]
        #     colors1[index] = color

        # for i, p in enumerate(v2d1):
        #     ii = id1[i]
        #     # print(ii)
        #     cv2.circle(img1, [int(p[0]), int(p[1])], 1, colors1[ii], 2)

        # unmatched = 0
        # for ii in id2:
        #     color = [0, 0, 0]
        #     this_is_umatched = 1
        #     colors2[ii] = colors1[ii] if ii in colors1 else color
        #     if ii in colors1:
        #         this_is_umatched = 0
        #     # if ii not in colors1:
        #     unmatched += this_is_umatched

        # for i, p in enumerate(v2d2):
        #     ii = id2[i]
        #     # print(p)
        #     cv2.circle(img2, [int(p[0]), int( p[1])], 1, colors2[ii], 2)

        # # print('Unmatched in Img 2: ', , '%')
        # unmatched_all.append(100 - unmatched * 100.0/len(v2d2))

        # im_h = cv2.hconcat([img1, img2])
        # print('/mnt/lustre/syli/inbetween/AnimeInbetween/corr/datasets/data_check_norm/' + dict['file_name'][0].replace('/', '_') + '.jpg', flush=True)
        # cv2.imwrite('/mnt/lustre/syli/inbetween/AnimeInbetween/corr/datasets/data_check_norm/' + dict['file_name'][0].replace('/', '_') + '.jpg', im_h)

    # print(np.mean(unmatched_all))
 



================================================
FILE: corr/experiments/vtx_corr/ckpt/.gitkeep
================================================


================================================
FILE: corr/main.py
================================================
from vtx_matching import VtxMat
import argparse
import os
import yaml
from pprint import pprint
from easydict import EasyDict



def parse_args():
    parser = argparse.ArgumentParser(
        description='Anime segment matching')
    parser.add_argument('--config', default='')
    # exclusive arguments
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--train', action='store_true')
    group.add_argument('--eval', action='store_true')


    return parser.parse_args()


def main():
    # parse arguments and load config
    args = parse_args()
    with open(args.config) as f:
        config = yaml.safe_load(f)

    for k, v in vars(args).items():
        config[k] = v
    pprint(config)

    config = EasyDict(config)
    agent = VtxMat(config)
    print(config)

    if args.train:
        agent.train()
    elif args.eval:
        agent.eval()



if __name__ == '__main__':
    main()


================================================
FILE: corr/models/__init__.py
================================================
from .supergluet import SuperGlueT
# from .supergluet_wo_OT import SuperGlueTwoOT
# from .supergluenp import SuperGlue as SuperGlueNP
# from .supergluei import SuperGlue as SuperGlueI
# from .supergluet2 import SuperGlueT2

__all__ = ['SuperGlueT']

================================================
FILE: corr/models/supergluet.py
================================================
import numpy as np
from copy import deepcopy
from pathlib import Path
import torch
from torch import nn

import argparse
from sknetwork.embedding import Spectral

def MLP(channels: list, do_bn=True):
    """ Multi-layer perceptron """
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(
            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
        if i < (n-1):
            if do_bn:
                layers.append(nn.InstanceNorm1d(channels[i]))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


def normalize_keypoints(kpts, image_shape):
    """ Normalize keypoints locations based on image image_shape"""
    _, _, height, width = image_shape
    one = kpts.new_tensor(1)
    size = torch.stack([one*width, one*height])[None]
    center = size / 2
    scaling = size.max(1, keepdim=True).values * 0.7
    return (kpts - center[:, None, :]) / scaling[:, None, :]
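For intuition, the same normalization in plain NumPy: shift to the image center, then divide by 0.7 times the larger image dimension. This is a standalone sketch mirroring the torch code above; `normalize_keypoints_np` is not part of the repository:

```python
import numpy as np

def normalize_keypoints_np(kpts, height, width):
    # mirrors the torch version: center the coordinates,
    # then scale by 0.7 * max(height, width)
    size = np.array([width, height], dtype=float)
    center = size / 2
    scaling = size.max() * 0.7
    return (kpts - center) / scaling

pts = np.array([[360.0, 360.0], [0.0, 0.0]])
normalize_keypoints_np(pts, 720, 720)
# the image center maps to (0, 0); the top-left corner to about (-0.714, -0.714)
```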

class ThreeLayerEncoder(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""
    def __init__(self, enc_dim):
        super().__init__()
        # input must be 3 channel (r, g, b)
        self.layer1 = nn.Conv2d(3, enc_dim//4, 7, padding=3)
        self.non_linear1 = nn.ReLU()
        self.layer2 = nn.Conv2d(enc_dim//4, enc_dim//2, 3, padding=1)
        self.non_linear2 = nn.ReLU()
        self.layer3 = nn.Conv2d(enc_dim//2, enc_dim, 3, padding=1)

        self.norm1 = nn.InstanceNorm2d(enc_dim//4)
        self.norm2 = nn.InstanceNorm2d(enc_dim//2)
        self.norm3 = nn.InstanceNorm2d(enc_dim)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0.0)

    def forward(self, img):
        x = self.non_linear1(self.norm1(self.layer1(img)))
        x = self.non_linear2(self.norm2(self.layer2(x)))
        x = self.norm3(self.layer3(x))

        return x


class VertexDescriptor(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""
    def __init__(self, enc_dim):
        super().__init__()
        self.encoder = ThreeLayerEncoder(enc_dim)


    def forward(self, img, vtx):
        x = self.encoder(img)
        n, c, h, w = x.size()
        assert((h, w) == img.size()[2:4])
        return x[:, :, torch.round(vtx[0, :, 1]).long(), torch.round(vtx[0, :, 0]).long()]



class KeypointEncoder(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""
    def __init__(self, feature_dim, layers):
        super().__init__()
        self.encoder = MLP([2] + layers + [feature_dim])
        # for m in self.encoder.modules():
        #     if isinstance(m, nn.Conv2d):
        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        #         nn.init.constant_(m.bias, 0.0)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts):
        inputs = kpts.transpose(1, 2)

        x = self.encoder(inputs)

        return x

class TopoEncoder(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""
    def __init__(self, feature_dim, layers):
        super().__init__()
        self.encoder = MLP([64] + layers + [feature_dim])
        # for m in self.encoder.modules():
        #     if isinstance(m, nn.Conv2d):
        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        #         nn.init.constant_(m.bias, 0.0)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts):
        inputs = kpts.transpose(1, 2)

        x = self.encoder(inputs)

        return x


def attention(query, key, value, mask=None):
    dim = query.shape[1]
    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
    if mask is not None:
        scores = scores.masked_fill(mask==0, float('-inf'))

    prob = torch.nn.functional.softmax(scores, dim=-1)


    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
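
A minimal single-head version of this scaled dot-product attention on plain Python lists (a toy sketch with made-up vectors; the einsum version above handles batches, heads, and masking):

```python
import math

def attend(query, key, value):
    """Single-head scaled dot-product attention on lists of vectors."""
    d = len(query[0])
    out = []
    for q in query:
        # scaled dot-product scores against every key
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in key]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]   # numerically stable softmax
        z = sum(weights)
        weights = [w / z for w in weights]
        # output is a convex combination of the value vectors
        out.append([sum(w * v[j] for w, v in zip(weights, value))
                    for j in range(len(value[0]))])
    return out

# a query aligned with the first key attends almost entirely to it
print(attend([[10.0, 0.0]], [[10.0, 0.0], [0.0, 10.0]], [[1.0, 0.0], [0.0, 1.0]]))
```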


class MultiHeadedAttention(nn.Module):
    """ Multi-head attention to increase model expressivitiy """
    def __init__(self, num_heads: int, d_model: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value, mask=None):
        batch_dim = query.size(0)
        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]
        x, prob = attention(query, key, value, mask)

        return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))


class AttentionalPropagation(nn.Module):
    def __init__(self, feature_dim: int, num_heads: int):
        super().__init__()
        self.attn = MultiHeadedAttention(num_heads, feature_dim)
        self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
        nn.init.constant_(self.mlp[-1].bias, 0.0)

    def forward(self, x, source, mask=None):
        message = self.attn(x, source, source, mask)
        return self.mlp(torch.cat([x, message], dim=1))


class AttentionalGNN(nn.Module):
    def __init__(self, feature_dim: int, layer_names: list):
        super().__init__()
        self.layers = nn.ModuleList([
            AttentionalPropagation(feature_dim, 4)
            for _ in range(len(layer_names))])
        self.names = layer_names

    def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None, mask10=None):
        for layer, name in zip(self.layers, self.names):
            layer.attn.prob = []
            if name == 'cross':
                src0, src1 = desc1, desc0
                mask0, mask1 = mask01[:, None], mask10[:, None] 
            else:  # if name == 'self':
                src0, src1 = desc0, desc1
                mask0, mask1 = mask00[:, None], mask11[:, None]

            delta0, delta1 = layer(desc0, src0, mask0), layer(desc1, src1, mask1)
            desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
        return desc0, desc1


def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
    """ Perform Sinkhorn Normalization in Log-space for stability"""
    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):
        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
    return Z + u.unsqueeze(2) + v.unsqueeze(1)
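
The normalization can be checked on a tiny example: after enough iterations the row and column sums of exp(Z) match the target marginals mu and nu. A hedged stdlib-only sketch of the same alternating rescaling:

```python
import math

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_sinkhorn(Z, log_mu, log_nu, iters):
    """Same alternating row/column rescaling as above, on nested lists."""
    n, m = len(Z), len(Z[0])
    u, v = [0.0] * n, [0.0] * m
    for _ in range(iters):
        u = [log_mu[i] - logsumexp([Z[i][j] + v[j] for j in range(m)]) for i in range(n)]
        v = [log_nu[j] - logsumexp([Z[i][j] + u[i] for i in range(n)]) for j in range(m)]
    return [[Z[i][j] + u[i] + v[j] for j in range(m)] for i in range(n)]

Z = [[0.0, 1.0], [2.0, 0.5]]          # arbitrary log scores
half = math.log(0.5)                  # uniform marginals summing to 1
out = log_sinkhorn(Z, [half, half], [half, half], 100)
row_sums = [sum(math.exp(x) for x in row) for row in out]
print(row_sums)  # each row of exp(out) sums to ~0.5, matching mu
```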


def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
    """ Perform Differentiable Optimal Transport in Log-space for stability"""
    b, m, n = scores.shape
    one = scores.new_tensor(1)
    if ms is None or ns is None:
        ms, ns = (m*one).to(scores), (n*one).to(scores)
    # else:
    #     ms, ns = ms.to(scores)[:, None], ns.to(scores)[:, None]
    # here m,n should be parameters not shape

    # ms, ns: (b, )
    bins0 = alpha.expand(b, m, 1)
    bins1 = alpha.expand(b, 1, n)
    alpha = alpha.expand(b, 1, 1)

    # pad an extra row/column of "dustbin" scores for unmatched keypoints;
    # alpha is the learned matchability threshold
    couplings = torch.cat([torch.cat([scores, bins0], -1),
                           torch.cat([bins1, alpha], -1)], 1)

    norm = - (ms + ns).log() # (b, )
    # print(scores.min(), flush=True)
    if ms.size()[0] > 0:
        norm = norm[:, None]
        log_mu = torch.cat([norm.expand(b, m), ns.log()[:, None] + norm], dim=-1) # (m + 1)
        log_nu = torch.cat([norm.expand(b, n), ms.log()[:, None] + norm], dim=-1)
        # print(log_nu.min(), log_mu.min(), flush=True)
    else:
        log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) # (m + 1)
        log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
        log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)

    
    Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)

    if ms.size()[0] > 1:
        norm = norm[:, :, None]
    Z = Z - norm  # multiply probabilities by M+N
    return Z


def arange_like(x, dim: int):
    return x.new_ones(x.shape[dim]).cumsum(0) - 1  # traceable in 1.1


class SuperGlueT(nn.Module):

    def __init__(self, config=None):
        super().__init__()

        default_config = argparse.Namespace()
        default_config.descriptor_dim = 128
        # default_config.weights = 
        default_config.keypoint_encoder = [32, 64, 128]
        default_config.GNN_layers = ['self', 'cross'] * 9
        default_config.sinkhorn_iterations = 100
        default_config.match_threshold = 0.2
        # self.config = {**self.default_config, **config}

        if config is None:
            self.config = default_config
        else:
            self.config = config   
            self.config.GNN_layers = ['self', 'cross'] * self.config.GNN_layer_num
            # print('WULA!', self.config.GNN_layer_num)

        self.kenc = KeypointEncoder(
            self.config.descriptor_dim, self.config.keypoint_encoder)

        self.tenc = TopoEncoder(
            self.config.descriptor_dim, [96])


        self.gnn = AttentionalGNN(
            self.config.descriptor_dim, self.config.GNN_layers)

        self.final_proj = nn.Conv1d(
            self.config.descriptor_dim, self.config.descriptor_dim,
            kernel_size=1, bias=True)

        bin_score = torch.nn.Parameter(torch.tensor(1.))
        self.register_parameter('bin_score', bin_score)
        self.vertex_desc = VertexDescriptor(self.config.descriptor_dim)
       


    def forward(self, data):
        """Run SuperGlue on a pair of keypoints and descriptors"""

        kpts0, kpts1 = data['keypoints0'].float(), data['keypoints1'].float()

        ori_mask0, ori_mask1 = data['mask0'].float(), data['mask1'].float()
        dim_m, dim_n = data['ms'].float(), data['ns'].float()

        spec0, spec1 = data['adj_mat0'], data['adj_mat1']

        mmax = dim_m.int().max()
        nmax = dim_n.int().max()

        mask0 = ori_mask0[:, :mmax]
        mask1 = ori_mask1[:, :nmax]

        kpts0 = kpts0[:, :mmax]
        kpts1 = kpts1[:, :nmax]

        desc0, desc1 = self.vertex_desc(data['image0'], kpts0.float()), self.vertex_desc(data['image1'], kpts1.float())
        # spec0, spec1 = np.abs(self.spectral.fit_transform(topo0[0].cpu().numpy())), np.abs(self.spectral.fit_transform(topo1[0].cpu().numpy()))

        desc0 = desc0 + self.tenc(desc0.new_tensor(spec0))
        desc1 = desc1 + self.tenc(desc1.new_tensor(spec1))

        mask00 = torch.ones_like(mask0)[:, :, None] * mask0[:, None, :]
        
        mask11 = torch.ones_like(mask1)[:, :, None] * mask1[:, None, :]
        mask01 = torch.ones_like(mask0)[:, :, None] * mask1[:, None, :]
        mask10 = torch.ones_like(mask1)[:, :, None] * mask0[:, None, :]


        if kpts0.shape[1] < 2 or kpts1.shape[1] < 2:  # no keypoints
            shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
            # print(data['file_name'])
            return {
                'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],
                # 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],
                'matching_scores0': kpts0.new_zeros(shape0)[0],
                # 'matching_scores1': kpts1.new_zeros(shape1)[0],
                'skip_train': True
            }

        file_name = data['file_name']
        all_matches = data['all_matches'] if 'all_matches' in data else None  # shape = (1, K1)

        
        # positional embedding
        # Keypoint normalization.
        kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
        kpts1 = normalize_keypoints(kpts1, data['image1'].shape)

        # Keypoint MLP encoder.
    
        pos0 = self.kenc(kpts0)
        pos1 = self.kenc(kpts1)

        desc0 = desc0 + pos0
        desc1 = desc1 + pos1

       
        # Multi-layer Transformer network.
        desc0, desc1 = self.gnn(desc0, desc1, mask00, mask11, mask01, mask10)

        # Final MLP projection.
        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)

        # Compute matching descriptor distance.
        scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)

        # b k1 k2
        scores = scores / self.config.descriptor_dim**.5

        mask01 = mask0[:, :, None] * mask1[:, None, :]
        scores = scores.masked_fill(mask01 == 0, float('-inf'))


        # Run the optimal transport.
        scores = log_optimal_transport(
            scores, self.bin_score,
            iters=self.config.sinkhorn_iterations,
            ms=dim_m, ns=dim_n)


        # Get the matches with score above "match_threshold".
        max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
        indices0, indices1 = max0.indices, max1.indices
        mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
        mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
        zero = scores.new_tensor(0)
        mscores0 = torch.where(mutual0, max0.values.exp(), zero)
        mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
        valid0 = mutual0 & (mscores0 > self.config.match_threshold)
        valid1 = mutual1 & valid0.gather(1, indices1)

        # NOTE: the mutual-consistency check above is overridden here;
        # only the score threshold decides whether a match is kept
        valid0 = mscores0 > self.config.match_threshold
        valid1 = valid0.gather(1, indices1)
        indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
        indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))

        # check if indexed correctly

        loss = []

        

        if all_matches is not None:
            for b in range(len(dim_m)):

                for i in range(int(dim_m[b])):
      
                    x = i
                    y = all_matches[b][i].long()

                    loss.append(-scores[b][x][y] ) # check batch size == 1 ?

            loss_mean = torch.mean(torch.stack(loss))
            loss_mean = torch.reshape(loss_mean, (1, -1))

            return {
                'matches0': indices0, # use -1 for invalid match
                'matches1': indices1, # use -1 for invalid match
                'matching_scores0': mscores0,
                # 'matching_scores1': mscores1[0],
                'loss': loss_mean,
                'skip_train': False,
                'accuracy': (((all_matches[:, :mmax] == indices0) & mask0.bool()).sum() / mask0.sum()).item(),
                'valid_accuracy': (((all_matches[:, :mmax] == indices0) & (all_matches[:, :mmax] != -1) & mask0.bool()).float().sum() / ((all_matches[:, :mmax] != -1) & mask0.bool()).float().sum()).item(),
            }
        else:
            return {
                'matches0': indices0[0], # use -1 for invalid match
                'matching_scores0': mscores0[0],
                'loss': -1,
                'skip_train': True,
                'accuracy': -1,
                'area_accuracy': -1,
                'valid_accuracy': -1,
            }
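
The mutual-nearest-neighbour filtering used in `forward` can be illustrated on a plain score matrix (a toy sketch with made-up numbers; the real code operates on the batched log-transport output):

```python
def mutual_matches(scores, threshold):
    """Keep (i, j) pairs that are each other's argmax and score above threshold."""
    n, m = len(scores), len(scores[0])
    best0 = [max(range(m), key=lambda j: scores[i][j]) for i in range(n)]  # row argmax
    best1 = [max(range(n), key=lambda i: scores[i][j]) for j in range(m)]  # column argmax
    return [(i, j) for i, j in enumerate(best0)
            if best1[j] == i and scores[i][j] > threshold]

scores = [[0.9, 0.1, 0.0],
          [0.2, 0.8, 0.1],
          [0.1, 0.7, 0.3]]
print(mutual_matches(scores, 0.2))  # [(0, 0), (1, 1)] -- row 2 loses column 1 to row 1
```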


if __name__ == '__main__':

    from datasets import fetch_dataloader  # smoke test only; run from within corr/

    args = argparse.Namespace()
    args.batch_size = 1
    args.gap = 0
    args.type = 'train'
    args.model = 'jolleen'
    args.action = 'slash'
    ss = SuperGlueT()

    loader = fetch_dataloader(args)
    # #print(len(loader))
    for data in loader:
        # p1, p2, s1, s2, mi = data
        dict1 = data

        kp1 = dict1['keypoints0']
        kp2 = dict1['keypoints1']
        p1 = dict1['image0']
        p2 = dict1['image1']  

        # #print(s1)
        # #print(s1.type)
        mi = dict1['all_matches']
        fname = dict1['file_name'] 
        print(kp1.shape, p1.shape, mi.shape)  
        # #print(mi.size())  
        # #print(mi)
        # break

        a = ss(data)
        print(dict1['file_name'])
        print(a['loss'])
        a['loss'].backward()
        # print(a['matches0'].size())
        # print(a['accuracy'], a['valid_accuracy'])

================================================
FILE: corr/srun.sh
================================================
#!/bin/sh
currenttime=`date "+%Y%m%d%H%M%S"`
if [ ! -d log ]; then
    mkdir log
fi

echo "[Usage] ./srun.sh config_path [train|eval] partition gpunum"
# check config exists
if [ ! -e $1 ]
then
    echo "[ERROR] configuration file: $1 does not exist!"
    exit
fi

# derive the experiment name from the config file name
# (e.g. configs/vtx_corr.yaml -> experiments/vtx_corr)
config_suffix=$(basename $1 .yaml)
expname=experiments/$config_suffix

if [ ! -d ${expname} ]; then
    mkdir -p ${expname}
fi

echo "[INFO] saving results to, or loading files from: "$expname

if [ "$3" == "" ]; then
    echo "[ERROR] enter partition name"
    exit
fi
partition_name=$3
echo "[INFO] partition name: $partition_name"

if [ "$4" == "" ]; then
    echo "[ERROR] enter gpu num"
    exit
fi
gpunum=$4
gpunum=$(($gpunum<8?$gpunum:8))
echo "[INFO] GPU num: $gpunum"
((ntask=$gpunum*3))


TOOLS="srun --partition=$partition_name --cpus-per-task=8 --gres=gpu:$gpunum   -N 1 --job-name=${config_suffix}"
PYTHONCMD="python -u main.py --config $1"

if [ $2 == "train" ];
then
    $TOOLS $PYTHONCMD \
    --train 
elif [ $2 == "eval" ];
then
    $TOOLS $PYTHONCMD \
    --eval 
fi


================================================
FILE: corr/utils/log.py
================================================
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this open-source project.


""" Define the Logger class to print log"""
import os
import sys
import logging
from datetime import datetime


class Logger:
    def __init__(self, args, output_dir):

        log = logging.getLogger(output_dir)
        if not log.handlers:
            log.setLevel(logging.DEBUG)
            # if not os.path.exists(output_dir):
            #     os.mkdir(args.data.output_dir)
            fh = logging.FileHandler(os.path.join(output_dir,'log.txt'))
            fh.setLevel(logging.INFO)
            ch = ProgressHandler()
            ch.setLevel(logging.DEBUG)
            formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
            fh.setFormatter(formatter)
            ch.setFormatter(formatter)
            log.addHandler(fh)
            log.addHandler(ch)
        self.log = log
        # setup TensorBoard
        # if args.tensorboard:
        #     from tensorboardX import SummaryWriter
        #     self.writer = SummaryWriter(log_dir=args.output_dir)
        # else:
        self.writer = None
        self.log_per_updates = args.log_per_updates

    def set_progress(self, epoch, total):
        self.log.info(f'Epoch: {epoch}')
        self.epoch = epoch
        self.i = 0
        self.total = total
        self.start = datetime.now()

    def update(self, stats):
        self.i += 1
        if self.i % self.log_per_updates == 0:
            remaining = str((datetime.now() - self.start) / self.i * (self.total - self.i))
            remaining = remaining.split('.')[0]
            updates = stats.pop('updates')
            stats_str = ' '.join(f'{key}[{val:.8f}]' for key, val in stats.items())
            
            self.log.info(f'> epoch [{self.epoch}] updates[{updates}] {stats_str} eta[{remaining}]')
            
            if self.writer:
                for key, val in stats.items():
                    self.writer.add_scalar(f'train/{key}', val, updates)
        if self.i == self.total:
            self.log.debug('\n')
            self.log.debug(f'elapsed time: {str(datetime.now() - self.start).split(".")[0]}')

    def log_eval(self, stats, metrics_group=None):
        stats_str = ' '.join(f'{key}: {val:.8f}' for key, val in stats.items())
        self.log.info(f'valid {stats_str}')
        if self.writer:
            for key, val in stats.items():
                self.writer.add_scalar(f'valid/{key}', val, self.epoch)
        # for mode, metrics in metrics_group.items():
        #     self.log.info(f'evaluation scores ({mode}):')
        #     for key, (val, _) in metrics.items():
        #         self.log.info(f'\t{key} {val:.4f}')
        # if self.writer and metrics_group is not None:
        #     for key, val in stats.items():
        #         self.writer.add_scalar(f'valid/{key}', val, self.epoch)
        #     for key in list(metrics_group.values())[0]:
        #         group = {}
        #         for mode, metrics in metrics_group.items():
        #             group[mode] = metrics[key][0]
        #         self.writer.add_scalars(f'valid/{key}', group, self.epoch)

    def __call__(self, msg):
        self.log.info(msg)


class ProgressHandler(logging.Handler):
    def __init__(self, level=logging.NOTSET):
        super().__init__(level)

    def emit(self, record):
        log_entry = self.format(record)
        if record.message.startswith('> '):
            sys.stdout.write('{}\r'.format(log_entry.rstrip()))
            sys.stdout.flush()
        else:
            sys.stdout.write('{}\n'.format(log_entry))



================================================
FILE: corr/utils/visualize_vtx_corr.py
================================================
import numpy as np
import torch
import cv2


def make_inter_graph(v2d1, v2d2, topo1, topo2, match12):
    valid = (match12 != -1)
    marked2 = np.zeros(len(v2d2)).astype(bool)
    # print(match12[valid])
    marked2[match12[valid]] = True

    id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))
    id1toh[valid] = np.arange(np.sum(valid))
    id2toh[match12[valid]] = np.arange(np.sum(valid))
    id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)
    # print(marked2)
    id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))

    id1toh = id1toh.astype(int)
    id2toh = id2toh.astype(int)

    tot_len = len(v2d1) + np.sum(np.invert(marked2))

    vin1 = v2d1[valid][:]
    vin2 = v2d2[match12[valid]][:]
    vh = 0.5 * (vin1 + vin2)
    vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)

    topoh = [[] for ii in range(tot_len)]


    for node in range(len(topo1)):
        
        for nb in topo1[node]:
            if int(id1toh[nb]) not in topoh[id1toh[node]]:
                topoh[id1toh[node]].append(int(id1toh[nb]))


    for node in range(len(topo2)):
        for nb in topo2[node]:
            if int(id2toh[nb]) not in topoh[id2toh[node]]:
                topoh[id2toh[node]].append(int(id2toh[nb]))

    return vh, topoh
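
The core of the merge can be sketched with plain tuples: matched vertex pairs are averaged into one intermediate vertex, and unmatched vertices from either frame are kept as-is (a hedged toy example; the real function also remaps the adjacency lists):

```python
def merge_vertices(v1, v2, match12):
    """Average matched pairs; keep unmatched vertices from both frames."""
    merged = [((x1 + v2[j][0]) / 2, (y1 + v2[j][1]) / 2)
              for (x1, y1), j in zip(v1, match12) if j != -1]
    merged += [p for p, j in zip(v1, match12) if j == -1]         # unmatched in frame 1
    matched2 = {j for j in match12 if j != -1}
    merged += [p for k, p in enumerate(v2) if k not in matched2]  # unmatched in frame 2
    return merged

# vertex 0 of frame 1 matches vertex 0 of frame 2; the rest are unmatched
print(merge_vertices([(0, 0), (2, 2)], [(2, 0), (9, 9)], [0, -1]))
# [(1.0, 0.0), (2, 2), (9, 9)]
```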


def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12):
    valid = (match12 != -1)
    marked2 = np.zeros(len(v2d2)).astype(bool)
    # print(match12[valid])
    marked2[match12[valid]] = True

    id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))
    id1toh[valid] = np.arange(np.sum(valid))
    id2toh[match12[valid]] = np.arange(np.sum(valid))
    id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)
    # print(marked2)
    id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))

    id1toh = id1toh.astype(int)
    id2toh = id2toh.astype(int)

    tot_len = len(v2d1) + np.sum(np.invert(marked2))

    vin1 = v2d1[valid][:]
    vin2 = v2d2[match12[valid]][:]
    vh = 0.5 * (vin1 + vin2)
    # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)

    # topoh = [[] for ii in range(tot_len)]
    topoh = [[] for ii in range(np.sum(valid))]

    for node in range(len(topo1)):
        if not valid[node]:
            continue
        for nb in topo1[node]:
            if int(id1toh[nb]) not in topoh[id1toh[node]]:
                if valid[nb]:
                    topoh[id1toh[node]].append(int(id1toh[nb]))


    for node in range(len(topo2)):
        if not marked2[node]:
            continue
        for nb in topo2[node]:
            if int(id2toh[nb]) not in topoh[id2toh[node]]:
                if marked2[nb]:
                    topoh[id2toh[node]].append(int(id2toh[nb]))

    return vh, topoh



def visualize(dict):
    # print(dict['keypoints0'].size(), flush=True)
    img1 = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()
    img2 = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()
    img1p = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()
    img2p = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()

    img1[:, :, 0] += 255
    img1[:, :, 1] += 180
    img1[:, :, 2] += 180
    img1[img1 > 255] = 255

    img2[:, :, 0] += 255
    img2[:, :, 1] += 180
    img2[:, :, 2] += 180
    img2[img2 > 255] = 255
    
    img1p[:, :, 0] += 255
    img1p[:, :, 1] += 180
    img1p[:, :, 2] += 180
    img1p[img1p > 255] = 255
    
    img2p[:, :, 0] += 255
    img2p[:, :, 1] += 180
    img2p[:, :, 2] += 180
    img2p[img2p > 255] = 255

    img1, img2, img1p, img2p = img1.astype(np.uint8), img2.astype(np.uint8), img1p.astype(np.uint8), img2p.astype(np.uint8)
    

    # print(v2d1.shape, img1.shape, flush=True)
    v2d1 = dict['keypoints0'].numpy().astype(int)
    v2d2 = dict['keypoints1'].numpy().astype(int)
    topo1 = dict['topo0']
    topo2 = dict['topo1']
    # print(topo1, flush=True)
    # for node, nbs in enumerate(dict['topo0']):
    #     for nb in nbs:
    #         cv2.line(img1, [v2d1[node][0], v2d1[node][1]], [v2d1[nb][0], v2d1[nb][1]], [255, 180, 180], 2)


    id1 = np.arange(len(v2d1))
    id2 = np.arange(len(v2d2))
    all_matches = dict['all_matches'].cpu().int().data.numpy()
    predicted = dict['matches0'].cpu().data.numpy()[0]
    predicted1 = dict['matches1'].cpu().data.numpy()[0]
    
    colors1_gt, colors2_gt = {}, {}
    colors1_pred, colors2_pred = {}, {}
    cross1_pred, cross2_pred = {}, {}

    for index in id1:
        color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]
            # print(predicted.shape, flush=True)
        if all_matches[index] != -1:
            colors2_gt[all_matches[index]] = color
        if predicted[index] != -1:
            colors2_pred[predicted[index]] = color

        colors1_gt[index] = color if all_matches[index] != -1 else [0, 0, 0]
        colors1_pred[index] = color if predicted[index] != -1 else [0, 0, 0]

        # if predicted[index] == -1 and colors1_pred[index] != [0, 0, 0]:
        #     color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]
        #     colors1_pred[index] = [0, 0, 0]
        #     colors2_pred.pop(all_matches[index])
        # whether predicted correctly
        if predicted[index] != all_matches[index]:
            cross1_pred[index] = True
            if predicted[index] != -1:
                cross2_pred[predicted[index]] = True
        
    for i, p in enumerate(v2d1):
        ii = id1[i]
        # print(ii)
        cv2.circle(img1, [int(p[0]), int(p[1])], 1, colors1_gt[i], 2)
        if ii in cross1_pred and cross1_pred[ii]:
            cv2.rectangle(img1p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors1_pred[i],-1)
        else:
            cv2.circle(img1p, [int(p[0]), int(p[1])], 1, colors1_pred[i], 2)
        
    for ii in id2:
        # vertices in frame 2 not matched by any ground-truth or prediction are black
        color = [0, 0, 0]
        if ii not in colors2_gt:
            colors2_gt[ii] = color
        if ii not in colors2_pred:
            colors2_pred[ii] = color

    for i, p in enumerate(v2d2):
        ii = id2[i]
        # print(p)
        cv2.circle(img2, [int(p[0]), int( p[1])], 1, colors2_gt[ii], 2)
        if ii in cross2_pred and cross2_pred[ii]:
            cv2.rectangle(img2p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors2_pred[i], -1)
        else:
            cv2.circle(img2p, [int(p[0]), int(p[1])], 1, colors2_pred[i], 2)

    # print('Unmatched in Img 2: ', , '%')
    # unmatched_all.append(100 - unmatched * 100.0/len(v2d2))
    cv2.putText(img2p, str(round(np.sum(all_matches == predicted) * 100.0 / len(predicted), 2)).format('.2f') + '%', \
        (500, 100), cv2.FONT_HERSHEY_PLAIN, 3, (0, 0, 255), 2)



    vh_gt, topoh_gt = make_inter_graph(v2d1, v2d2, topo1, topo2, all_matches)
    vh_pred, topoh_pred = make_inter_graph(v2d1, v2d2, topo1, topo2, predicted)
    vh_gt_valid, topoh_gt_valid = make_inter_graph_valid(v2d1, v2d2, topo1, topo2, all_matches)
    vh_pred_valid, topoh_pred_valid = make_inter_graph_valid(v2d1, v2d2, topo1, topo2, predicted)
    v2d1t = ((v2d2[predicted] + v2d1) * 0.5).astype(int)
    v2d2t = ((v2d1[predicted1] + v2d2) * 0.5).astype(int)

    vh_gt = vh_gt.astype(int)
    vh_gt_valid = vh_gt_valid.astype(int)
    vh_pred = vh_pred.astype(int)
    vh_pred_valid = vh_pred_valid.astype(int)

    imgh = np.zeros_like(img1) + 255
    imghp = np.zeros_like(img1) + 255
    imgh_valid = np.zeros_like(img1) + 255
    imghp_valid = np.zeros_like(img1) + 255

    for node, nbs in enumerate(topoh_gt):
        for nb in nbs:
            cv2.line(imgh, [vh_gt[node][0], vh_gt[node][1]], [vh_gt[nb][0], vh_gt[nb][1]], [0, 0, 0], 2)
    
    for node, nbs in enumerate(topoh_pred):
        for nb in nbs:
            cv2.line(imghp, [vh_pred[node][0], vh_pred[node][1]], [vh_pred[nb][0], vh_pred[nb][1]], [0, 0, 0], 2)
    
    for node, nbs in enumerate(topoh_gt_valid):
        for nb in nbs:
            cv2.line(imgh_valid, [vh_gt_valid[node][0], vh_gt_valid[node][1]], [vh_gt_valid[nb][0], vh_gt_valid[nb][1]], [0, 0, 0], 2)
    
    for node, nbs in enumerate(topoh_pred_valid):
        for nb in nbs:
            cv2.line(imghp_valid, [vh_pred_valid[node][0], vh_pred_valid[node][1]], [vh_pred_valid[nb][0], vh_pred_valid[nb][1]], [0, 0, 0], 2)
    
    # for node, nbs in enumerate(topo1):
    #     for nb in nbs:
    #         cv2.line(imghp_valid, [v2d1t[node][0], v2d1t[node][1]], [v2d1t[nb][0], v2d1t[nb][1]], [0, 0, 0], 2)
    
    # for node, nbs in enumerate(topo2):
    #     for nb in nbs:
    #         cv2.line(imghp_valid, [v2d2t[node][0], v2d2t[node][1]], [v2d2t[nb][0], v2d2t[nb][1]], [0, 0, 0], 2)
    


    im_h = cv2.hconcat([img1, img2])
    im_hp = cv2.hconcat([img1p, img2p])
    img_inter = cv2.hconcat([imgh, imghp])
    img_inter_valid = cv2.hconcat([imgh_valid, imghp_valid])
    im_hv = cv2.vconcat([im_h, im_hp, img_inter, img_inter_valid])

    return im_hv


================================================
FILE: corr/vtx_matching.py
================================================
""" This script handling the training process. """
import os
import time
import random
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from datasets import fetch_dataloader
import random
from utils.log import Logger

from torch.optim import *
import warnings
from tqdm import tqdm
import itertools
import pdb
import numpy as np
import models
import datetime
import sys
import json
import cv2

from utils.visualize_vtx_corr import visualize
import matplotlib.cm as cm
# from models.utils import make_matching_seg_plot

warnings.filterwarnings('ignore')


import matplotlib.pyplot as plt
import pdb

class VtxMat():
    def __init__(self, args):
        self.config = args
        torch.backends.cudnn.benchmark = True
        torch.multiprocessing.set_sharing_strategy('file_system')
        self._build()

    def train(self):
        
        opt = self.config
        print(opt)

        model = self.model

        if hasattr(self.config, 'init_weight'):
            checkpoint = torch.load(self.config.init_weight)
            model.load_state_dict(checkpoint['model'])

        optimizer = self.optimizer
        schedular = self.schedular
        mean_loss = []
        log = Logger(self.config, self.expdir)
        updates = 0
        
        # set seed
        random.seed(opt.seed)
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed(opt.seed)
        np.random.seed(opt.seed)

        # start training
        for epoch in range(1, opt.epoch+1):
            np.random.seed(opt.seed + epoch)
            train_loader = self.train_loader
            log.set_progress(epoch, len(train_loader))
            batch_loss = 0
            batch_acc = 0 
            batch_valid_acc = 0
            batch_iter = 0
            model.train()
            avg_time = 0
            avg_num = 0
            # torch.cuda.synchronize()
            
            for i, pred in enumerate(train_loader):
                # tstart = time.time()
                # print(pred['file_name'])
                data = model(pred)


                if not data['skip_train']:
                    loss = data['loss'] / opt.batch_size
                    batch_loss += loss.item()
                    batch_acc += data['accuracy'] 
                    batch_valid_acc += data['valid_accuracy'] 
                    loss.backward()
                    batch_iter += 1
                else:
                    print('Skip!')

                ## Accumulate gradient for batch training
                if ((i + 1) % opt.batch_size == 0) or (i + 1 == len(train_loader)):
                    optimizer.step()
                    optimizer.zero_grad()
                    batch_iter = 1 if batch_iter == 0 else batch_iter               
                    stats = {
                        'updates': updates,
                        'loss': batch_loss,
                        'accuracy': batch_acc / batch_iter,
                        'valid_accuracy': batch_valid_acc / batch_iter
                    }
                    log.update(stats)
                    updates += 1
                    batch_loss = 0
                    batch_acc = 0 
                    batch_valid_acc = 0
                    batch_iter = 0


            # save checkpoint 
            if epoch % opt.save_per_epochs == 0 or epoch == 1:
                checkpoint = {
                    'model': model.state_dict(),
                    'config': opt,
                    'epoch': epoch
                }

                filename = os.path.join(self.ckptdir, f'epoch_{epoch}.pt')
                torch.save(checkpoint, filename)
                
            # validate
            if epoch % opt.test_freq == 0:

                if not os.path.exists(os.path.join(self.visdir, 'epoch' + str(epoch))):
                    os.mkdir(os.path.join(self.visdir, 'epoch' + str(epoch)))
                eval_output_dir = os.path.join(self.visdir, 'epoch' + str(epoch))    
                
                test_loader = self.test_loader

                with torch.no_grad():
                    # Visualize the matches.
                    mean_acc = []
                    mean_valid_acc = []
                    model.eval()
                    for i_eval, data in enumerate(tqdm(test_loader, desc='Predicting Vtx Corr...')):
                        pred = model(data)

                        mean_acc.append(pred['accuracy'])
                        mean_valid_acc.append(pred['valid_accuracy'])
                    log.log_eval({
                        'updates': opt.epoch,
                        'Accuracy': np.mean(mean_acc),
                        'Valid Accuracy': np.mean(mean_valid_acc),
                        })
                    print('Epoch [{}/{}], Acc.: {:.4f}, Valid Acc.: {:.4f}'
                        .format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_valid_acc)))
                    sys.stdout.flush()
        
            self.schedular.step()

            


    def eval(self):
        train_action = ['breakdance_1990', 'capoeira', 'chapa-giratoria', 'fist_fight', 'flying_knee', 'freehang_climb', 'running', 'shove', 'magic', 'tripping']
        test_action = ['great_sword_slash', 'hip_hop_dancing']

        train_model = ['ganfaul', 'girlscout', 'jolleen', 'kachujin', 'knight', 'maria_w_jj', 'michelle', 'peasant_girl', 'timmy', 'uriel_a_plotexia']
        test_model = ['police', 'warrok']

        log = Logger(self.config, self.expdir)
        with torch.no_grad():
            model = self.model.eval()
            config = self.config
            epoch_tested = self.config.testing.ckpt_epoch
            ckpt_path = os.path.join(self.ckptdir, f"epoch_{epoch_tested}.pt")
            # self.device = torch.device('cuda' if config.cuda else 'cpu')
            print("Evaluation...")
            checkpoint = torch.load(ckpt_path)
            model.load_state_dict(checkpoint['model'])

            model.eval()

            if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested))):
                os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested)))
            if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons')):
                os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons'))
            eval_output_dir = os.path.join(self.evaldir, 'epoch' + str(epoch_tested))    
                
            test_loader = self.test_loader
            print(len(test_loader))
            mean_acc = []
            mean_valid_acc = []
            mean_invalid_acc = []

            # 144 data 
            # 10x10 is for training , 2x10 (unseen model) + 10x2 (unseen action) + 2x2 (unseen model unseen action) is for test
            # record the accuracy for each
            mean_model_acc = []
            mean_model_valid_acc = []
            mean_action_acc = []
            mean_action_valid_acc = []
            
            mean_none_acc = []
            mean_none_valid_acc = []

            mean_matched = []

            for i_eval, pred in enumerate(tqdm(test_loader, desc='Predicting Vtx Corr...')):
                data = model(pred)
                # un-batch the inputs, then overlay the model outputs
                for k, v in pred.items():
                    pred[k] = v[0]
                pred = {**pred, **data}
            
                mean_acc.append(pred['accuracy'])
                mean_valid_acc.append(pred['valid_accuracy'])
                this_pred = (pred['matches0'] != -1).float().cpu().numpy().astype(np.float32)
                mean_matched.append(np.mean(this_pred))

                unmarked = True
                for model_name in train_model:
                    if model_name in pred['file_name']:
                        mean_model_acc.append(pred['accuracy'])
                        mean_model_valid_acc.append(pred['valid_accuracy'])
                        unmarked = False
                        break

                for action_name in train_action:
                    if action_name in pred['file_name']:
                        mean_action_acc.append(pred['accuracy'])
                        mean_action_valid_acc.append(pred['valid_accuracy'])
                        unmarked = False
                        break
                
                if unmarked:
                    mean_none_acc.append(pred['accuracy'])
                    mean_none_valid_acc.append(pred['valid_accuracy'])

                if 'invalid_accuracy' in pred and pred['invalid_accuracy'] is not None:
                    mean_invalid_acc.append(pred['invalid_accuracy'])
                
                img_vis = visualize(pred)
                cv2.imwrite(os.path.join(eval_output_dir, pred['file_name'].replace('/', '_') + '.jpg'), img_vis)
                
            log.log_eval({
                'updates': self.config.testing.ckpt_epoch,
                'Accuracy': np.mean(mean_acc),
                'Accuracy (Matched)': np.mean(mean_valid_acc),
                'Unseen Action Accuracy': np.mean(mean_model_acc),
                'Unseen Action Accuracy (Matched)': np.mean(mean_model_valid_acc),
                'Unseen Model Accuracy': np.mean(mean_action_acc),
                'Unseen Model Accuracy (Matched)': np.mean(mean_action_valid_acc),
                'Unseen Both Accuracy': np.mean(mean_none_acc),
                'Unseen Both Valid Accuracy': np.mean(mean_none_valid_acc),
                'Matching Rate': np.mean(mean_matched)
                })
            sys.stdout.flush()

    def _build(self):
        config = self.config
        self.start_epoch = 0
        self._dir_setting()
        self._build_model()
        if not (hasattr(config, 'need_not_train_data') and config.need_not_train_data):
            self._build_train_loader()
        if not (hasattr(config, 'need_not_test_data') and config.need_not_test_data):
            self._build_test_loader()
        self._build_optimizer()

    def _build_model(self):
        """ Define Model """
        config = self.config 
        if hasattr(config.model, 'name'):
            print(f'Experiment Using {config.model.name}')
            model_class = getattr(models, config.model.name)
            model = model_class(config.model)
        else:
            raise NotImplementedError("Wrong Model Selection")
        
        model = nn.DataParallel(model)
        self.model = model.cuda()

    def _build_train_loader(self):
        config = self.config
        self.train_loader = fetch_dataloader(config.data.train, type='train')

    def _build_test_loader(self):
        config = self.config
        self.test_loader = fetch_dataloader(config.data.test, type='test')

    def _build_optimizer(self):
        #model = nn.DataParallel(model).to(device)
        config = self.config.optimizer
        try:
            optim = getattr(torch.optim, config.type)
        except Exception:
            raise NotImplementedError('not implemented optim method ' + config.type)

        self.optimizer = optim(self.model.module.parameters(), **config.kwargs)
        self.schedular = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, **config.schedular_kwargs)

    def _dir_setting(self):
        data = self.config.data
        self.expname = self.config.expname
        self.experiment_dir = os.path.join(".", "experiments")
        self.expdir = os.path.join(self.experiment_dir, self.expname)

        if not os.path.exists(self.expdir):
            os.mkdir(self.expdir)

        self.visdir = os.path.join(self.expdir, "vis")  # -- imgs, videos, jsons
        if not os.path.exists(self.visdir):
            os.mkdir(self.visdir)

        self.ckptdir = os.path.join(self.expdir, "ckpt")
        if not os.path.exists(self.ckptdir):
            os.mkdir(self.ckptdir)

        self.evaldir = os.path.join(self.expdir, "eval")
        if not os.path.exists(self.evaldir):
            os.mkdir(self.evaldir)

        

        # self.ckptdir = os.path.join(self.expdir, "ckpt")
        # if not os.path.exists(self.ckptdir):
        #     os.mkdir(self.ckptdir)



        






================================================
FILE: data/README.md
================================================


================================================
FILE: datasets/__init__.py
================================================

from .ml_seq import fetch_dataloader
from .vd_seq import fetch_videoloader

__all__ = ['fetch_dataloader', 'fetch_videoloader']

================================================
FILE: datasets/ml_seq.py
================================================
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F

import os
import math
import random
from glob import glob
import os.path as osp

import sys
import argparse
import cv2
from collections import Counter
import time
import json
import sknetwork
from sknetwork.embedding import Spectral

import scipy

def read_json(file_path):
    """
        input: json file path
        output: 2d vertex, connections and vertex index in original 3D domain
    """

    with open(file_path) as file:
        data = json.load(file)
        vertex2d = np.array(data['vertex location'])
        
        topology = data['connection']
        index = np.array(data['original index'])

    return vertex2d, topology, index


def matched_motion(v2d1, v2d2, match12, motion_pre=None):
    motion = np.zeros_like(v2d1)

    motion[match12 != -1] = v2d2[match12[match12 != -1]] - v2d1[match12 != -1]
    if motion_pre is not None:
        motion[match12 != -1] = motion[match12 != -1] + motion_pre[match12[match12 != -1]]
    return motion
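A small sketch of what `matched_motion` computes, with the helper copied from above and hypothetical coordinates: matched vertices get the displacement to their counterpart, unmatched ones stay at zero (to be filled by `unmatched_motion`).

```python
import numpy as np

# copy of the helper above: per-vertex displacement for matched vertices
def matched_motion(v2d1, v2d2, match12, motion_pre=None):
    motion = np.zeros_like(v2d1)
    motion[match12 != -1] = v2d2[match12[match12 != -1]] - v2d1[match12 != -1]
    if motion_pre is not None:
        motion[match12 != -1] = motion[match12 != -1] + motion_pre[match12[match12 != -1]]
    return motion

# hypothetical 3-vertex frame matched against a 2-vertex frame
v2d1 = np.array([[0, 0], [10, 10], [5, 5]])
v2d2 = np.array([[1, 2], [12, 9]])
match12 = np.array([0, 1, -1])   # vertex 2 has no correspondence
motion = matched_motion(v2d1, v2d2, match12)
# matched vertices move by v2d2[j] - v2d1[i]; the unmatched vertex stays zero
```

The optional `motion_pre` argument chains displacements across frames, which is how the dataset composes per-frame motion into frame-0-to-midpoint motion.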

def unmatched_motion(topo1, v2d1, motion12, match12):
    pos = np.arange(len(topo1))
    masked = (match12 == -1)

    former_len = 0
    n_round = 0  # renamed to avoid shadowing the built-in round()
    while len(pos[masked]) > 0:
        this_len = len(pos[masked])
        if former_len == this_len:
            break
        former_len = this_len
        n_round += 1
        for v in pos[masked]:
            unmatched = masked[topo1[v]]

            if unmatched.sum() != len(topo1[v]):
                motion12[v] = np.average(motion12[topo1[v]][np.invert(unmatched)], axis=0)
                masked[v] = False

                
    if len(pos[masked]) > 0:
        # find the nearest labeled point for each unlabeled point
        index = ((v2d1[pos[masked]][:, None, :] - v2d1[pos[np.invert(masked)]]) ** 2).sum(2).argmin(1)
        motion12[pos[masked]] = motion12[pos[np.invert(masked)]][index]
        masked[pos[masked]] = False

    return motion12
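To illustrate the propagation step, here is a lightly condensed copy of `unmatched_motion` (with the length check parenthesized as intended) on a hypothetical 3-vertex chain: the unmatched middle vertex inherits the average motion of its matched neighbors.

```python
import numpy as np

# condensed copy of unmatched_motion above: propagate motion from matched
# neighbors, falling back to the nearest labeled vertex if propagation stalls
def unmatched_motion(topo1, v2d1, motion12, match12):
    pos = np.arange(len(topo1))
    masked = (match12 == -1)
    former_len = 0
    while len(pos[masked]) > 0:
        this_len = len(pos[masked])
        if former_len == this_len:
            break
        former_len = this_len
        for v in pos[masked]:
            unmatched = masked[topo1[v]]
            if unmatched.sum() != len(topo1[v]):
                motion12[v] = np.average(motion12[topo1[v]][np.invert(unmatched)], axis=0)
                masked[v] = False
    if len(pos[masked]) > 0:
        index = ((v2d1[pos[masked]][:, None, :] - v2d1[pos[np.invert(masked)]]) ** 2).sum(2).argmin(1)
        motion12[pos[masked]] = motion12[pos[np.invert(masked)]][index]
    return motion12

# chain 0-1-2; the middle vertex is unmatched (hypothetical values)
topo1 = [[1], [0, 2], [1]]
v2d1 = np.array([[0., 0.], [1., 0.], [2., 0.]])
motion12 = np.array([[2., 2.], [0., 0.], [4., 4.]])
match12 = np.array([0, -1, 1])
motion12 = unmatched_motion(topo1, v2d1, motion12, match12)
# vertex 1 receives the average of its neighbors' motions
```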


def ids_to_mat(id1, id2):
    """
        inputs are two list of vertex index in original 3D mesh
    """
    corr1 = np.zeros(len(id1)) - 1.0
    corr2 = np.zeros(len(id2)) - 1.0
    
    id1 = np.array(id1).astype(int)[:, None]
    id2 = np.array(id2).astype(int)
    
    mat = (id1 == id2)


    pos12 = np.arange(len(id2))[None].repeat(len(id1), 0)
    pos21 = np.arange(len(id1))[None].repeat(len(id2), 0)
    corr1[mat.astype(int).sum(1).astype(bool)] = pos12[mat]
    corr2[mat.transpose().astype(int).sum(1).astype(bool)] = pos21[mat.transpose()]


    return mat, corr1, corr2
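A sketch of the correspondence encoding, with `ids_to_mat` copied from above and hypothetical id lists: `corr1[i]` is the index in the second frame matching vertex `i` of the first frame, or `-1` when that 3D vertex is not visible there.

```python
import numpy as np

# copy of ids_to_mat above: align two frames by their original 3D vertex ids
def ids_to_mat(id1, id2):
    corr1 = np.zeros(len(id1)) - 1.0
    corr2 = np.zeros(len(id2)) - 1.0
    id1 = np.array(id1).astype(int)[:, None]
    id2 = np.array(id2).astype(int)
    mat = (id1 == id2)
    pos12 = np.arange(len(id2))[None].repeat(len(id1), 0)
    pos21 = np.arange(len(id1))[None].repeat(len(id2), 0)
    corr1[mat.astype(int).sum(1).astype(bool)] = pos12[mat]
    corr2[mat.transpose().astype(int).sum(1).astype(bool)] = pos21[mat.transpose()]
    return mat, corr1, corr2

# hypothetical ids: 3D vertex 9 is only visible in the first frame
mat, corr1, corr2 = ids_to_mat([3, 7, 9], [7, 3])
# mat is the boolean match matrix; corr1/corr2 are the two index maps
```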

def adj_matrix(topology):
    """
        topology is the adj table; returns adj matrix
    """
    gsize = len(topology)
    adj = np.zeros((gsize, gsize)).astype(float)
    for v in range(gsize):
        adj[v][v] = 1.0
        for nb in topology[v]:
            adj[v][nb] = 1.0
            adj[nb][v] = 1.0
    return adj
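A quick illustration of the adjacency-table-to-matrix conversion, with `adj_matrix` copied from above on a hypothetical 3-vertex chain: the result is symmetric with ones on the diagonal (self-loops).

```python
import numpy as np

# copy of adj_matrix above: adjacency table -> dense matrix with self-loops
def adj_matrix(topology):
    gsize = len(topology)
    adj = np.zeros((gsize, gsize)).astype(float)
    for v in range(gsize):
        adj[v][v] = 1.0
        for nb in topology[v]:
            adj[v][nb] = 1.0
            adj[nb][v] = 1.0
    return adj

# chain 0-1-2: each inner list holds a vertex's neighbors
adj = adj_matrix([[1], [0, 2], [1]])
```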

class MixamoLineArtMotionSequence(data.Dataset):
    def __init__(self, root, gap=0, split='train', model=None, action=None, mode='train', use_vs=False, max_len=3050):
        """
            input:
                root: the root folder of the line art data
                gap: number of frames between the two input frames; gap must be an odd number.
                split: train or test
                model: indicate a specific character (default None)
                action: indicate a specific action (default None)

            output:
                image of sources (0, 1) and output (0.5)
                topo0, topo1
                v2d0, v2d1
                
                corr12, corr21

                motion0-->0.5, motion1-->0.5
                visibility0-->0.5, visibility1-->0.5

        """
        super(MixamoLineArtMotionSequence, self).__init__()

        self.gap = gap
        if model == 'None':
            model = None
        if action == 'None':
            action = None

        assert gap % 2 != 0, 'gap must be an odd number'

        self.is_train = True if mode == 'train' else False
        self.is_eval = True if mode == 'eval' else False
        # self.is_train = False
        self.max_len = max_len

        self.image_list = []
        self.label_list = []

        label_root = osp.join(root, split, 'labels')
        self.use_vs = False
        if use_vs:
            print('>>>>>>>> Using VS labels')
            self.use_vs = True
            label_root = osp.join(root, split, 'labels_vs')
        image_root = osp.join(root, split, 'frames')
        self.spectral = Spectral(64,  normalized=False)

        for clip in os.listdir(image_root):
            skip = False
            if model is not None:
                for mm in model:
                    if mm in clip:
                        skip = True

            if action is not None:
                for aa in action:
                    if aa in clip:
                        skip = True
            if skip:
                continue
            image_list = sorted(glob(osp.join(image_root, clip, '*.png')))
            label_list = sorted(glob(osp.join(label_root, clip, '*.json')))
            if len(image_list) != len(label_list):
                print(clip, flush=True)
                continue
            for i in range(len(image_list) - (gap+1)):
                self.image_list += [ [image_list[jj] for jj in range(i, i + gap + 2)] ]
            for i in range(len(label_list) - (gap+1)):
                self.label_list += [ [label_list[jj] for jj in range(i, i + gap + 2)] ]
        # print(clip)
        print('Len of Frame is ', len(self.image_list), flush=True)
        print('Len of Label is ', len(self.label_list), flush=True)

    def __getitem__(self, index):


        # load image/label files
        # load labels: 
        #   (a) read json (b) load image (c) make pseudo labels

        # image crop to a square (720x720) before input, 2d label same operation
        # index to index matching

        # test does not need index matching
        
        index = index % len(self.image_list)
        file_name = self.label_list[index][len(self.label_list[index])//2][:-5]  # strip the '.json' extension

        imgt = [cv2.imread(self.image_list[index][ii]) for ii in range(0, len(self.image_list[index]))]

        labelt = []
        for ii in range(0, len(self.label_list[index])):
            v, t, id = read_json(self.label_list[index][ii])
            v[v > imgt[0].shape[0] - 1] = imgt[0].shape[0] - 1
            v[v < 0] = 0
            labelt.append({'keypoints': v.astype(int), 'topo': t, 'id': id})

        # make motion pseudo label
        motion = None
        motion01 = None

        start_frame = 0
        gap = self.gap // 2 + 1


        ######### forward direction
        for ii in reversed(range(start_frame + 1, start_frame + 2*gap + 1)):
            img1 = imgt[ii - 1]
            img2 = imgt[ii] 

            v2d1 = labelt[ii - 1]['keypoints'].astype(int)
            v2d2 = labelt[ii]['keypoints'].astype(int)

            topo1 = labelt[ii - 1]['topo']
            topo2 = labelt[ii ]['topo']

            id1 = labelt[ii - 1]['id']
            id2 = labelt[ii]['id']

            if self.use_vs:
                id1 = np.arange(len(id1))
                id2 = np.arange(len(id2))

            _, match12, matc21 = ids_to_mat(id1, id2)

            if ii <= start_frame + gap:
                motion01 = matched_motion(v2d1, v2d2, match12.astype(int), motion01)
                motion01 = unmatched_motion(topo1, v2d1, motion01, match12.astype(int))

            motion = matched_motion(v2d1, v2d2, match12.astype(int), motion)
            motion = unmatched_motion(topo1, v2d1, motion, match12.astype(int))
        motion0 = motion.copy()
 
        img2 = imgt[start_frame + gap]
        
        v2d1 = labelt[start_frame]['keypoints'].astype(int)
        source0_topo = labelt[start_frame]['topo']

        target = cv2.erode(img2, np.ones((3, 3), np.uint8), iterations=1)

        shift_plabel = v2d1 + motion01
        visible = np.ones(len(v2d1)).astype(float)
        visible[shift_plabel[:, 0] < 0] = 0
        visible[shift_plabel[:, 0] >= imgt[0].shape[0]] = 0
        visible[shift_plabel[:, 1] < 0] = 0
        visible[shift_plabel[:, 1] >= imgt[0].shape[0]] = 0

        # vertex visibility
        visible[visible == 1] = (target[:, :, 0][shift_plabel[visible == 1][:, 1], shift_plabel[visible == 1][:, 0]] < 255 ).astype(float)

        visible01 = visible.copy()
        v2d1s = shift_plabel

        # edge visibility
        for node, nbs in enumerate(source0_topo):
            for nb in nbs:
                if visible01[nb] and visible01[node] and ((v2d1s[node] - v2d1s[nb]) ** 2).sum() / (((v2d1[node] - v2d1[nb]) ** 2).sum() + 1e-7) > 25:
                    visible01[nb] = False
                    visible01[node] = False

        ######## backward direction
        motion = None
        motion21 = None

        for ii in range(start_frame + 1, start_frame + gap + gap + 1):
            img2 = imgt[ii - 1]
            img1 = imgt[ii] 

            v2d2 = labelt[ii - 1]['keypoints'].astype(int)
            v2d1 = labelt[ii]['keypoints'].astype(int)

            topo2 = labelt[ii - 1]['topo']
            topo1 = labelt[ii ]['topo']

            
            id1 = labelt[ii]['id']
            id2 = labelt[ii - 1]['id']
            if self.use_vs:
                id1 = np.arange(len(id1))
                id2 = np.arange(len(id2))
            _, match12, _ = ids_to_mat(id1, id2)

            if ii >= start_frame + gap + 1:
                motion21 = matched_motion(v2d1, v2d2, match12.astype(int), motion21)
                motion21 = unmatched_motion(topo1, v2d1, motion21, match12.astype(int))

            motion = matched_motion(v2d1, v2d2, match12.astype(int), motion)
            motion = unmatched_motion(topo1, v2d1, motion, match12.astype(int))

        motion2 = motion.copy()
        
        img1 = imgt[start_frame + 2*gap]
        img2 = imgt[start_frame + gap]
        
        v2d1 = labelt[start_frame + 2*gap]['keypoints'].astype(int)
        source2_topo = labelt[start_frame + 2*gap]['topo']

        shift_plabel = v2d1 + motion21
        visible = np.ones(len(v2d1)).astype(float)
        visible[shift_plabel[:, 0] < 0] = 0
        visible[shift_plabel[:, 0] >= imgt[0].shape[0]] = 0
        visible[shift_plabel[:, 1] < 0] = 0
        visible[shift_plabel[:, 1] >= imgt[0].shape[0]] = 0

        visible[visible == 1] = (target[:, :, 0][shift_plabel[visible == 1][:, 1], shift_plabel[visible == 1][:, 0]] < 255 ).astype(float)

        visible21 = visible.copy()

        v2d1s = shift_plabel

        for node, nbs in enumerate(source2_topo):
            for nb in nbs:
                if visible21[nb] and visible21[node] and ((v2d1s[node] - v2d1s[nb]) ** 2).sum() / (((v2d1[node] - v2d1[nb]) ** 2).sum() + 1e-7) > 25:
                    visible21[nb] = False
                    visible21[node] = False


        ###### prepare other data
        img2 = imgt[-1]
        img1 = imgt[0] 

        v2d2 = labelt[-1]['keypoints'].astype(int)
        v2d1 = labelt[0]['keypoints'].astype(int)

        topo2 = labelt[-1]['topo']
        topo1 = labelt[0]['topo']

        m, n = len(v2d1), len(v2d2)

        if len(img1.shape) == 2:
            img1 = np.tile(img1[...,None], (1, 1, 3))
            img2 = np.tile(img2[...,None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float() * 2 / 255.0 - 1.0
        imgt = torch.from_numpy(imgt[start_frame + gap]).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 

        v2d1 = torch.from_numpy(v2d1)
        v2d2 = torch.from_numpy(v2d2)

        visible01 = torch.from_numpy(visible01)
        visible21 = torch.from_numpy(visible21)
        motion0 = torch.from_numpy(motion0)
        motion2 = torch.from_numpy(motion2)

        v2d1[v2d1 > imgt[0].shape[0] - 1 ] = imgt[0].shape[0] - 1
        v2d1[v2d1 < 0] = 0
        v2d2[v2d2 > imgt[0].shape[1] - 1] = imgt[0].shape[1] - 1
        v2d2[v2d2 < 0] = 0

        
        id1 = labelt[start_frame]['id']
        id2 = labelt[-1]['id']
        if self.use_vs:
            id1 = np.arange(len(id1))
            id2 = np.arange(len(id2))

        mat_index, corr1, corr2 = ids_to_mat(id1, id2)
        mat_index = torch.from_numpy(mat_index).float()
        corr1 = torch.from_numpy(corr1).float()
        corr2 = torch.from_numpy(corr2).float()

        if self.is_train:
            v2d1 = torch.nn.functional.pad(v2d1, (0, 0, 0, self.max_len - m), mode='constant', value=0)
            v2d2 = torch.nn.functional.pad(v2d2, (0, 0, 0, self.max_len - n), mode='constant', value=0)
            corr1 = torch.nn.functional.pad(corr1, (0, self.max_len - m), mode='constant', value=0)
            corr2 = torch.nn.functional.pad(corr2, (0, self.max_len - n), mode='constant', value=0)
            motion0 = torch.nn.functional.pad(motion0, (0, 0, 0, self.max_len - m), mode='constant', value=0)
            motion2 = torch.nn.functional.pad(motion2, (0, 0, 0, self.max_len - n), mode='constant', value=0)
            visible01 = torch.nn.functional.pad(visible01, (0, self.max_len - m), mode='constant', value=0)
            visible21 = torch.nn.functional.pad(visible21, (0, self.max_len - n), mode='constant', value=0)

            mask0, mask1 = torch.zeros(self.max_len).float(), torch.zeros(self.max_len).float()
            mask0[:m] = 1
            mask1[:n] = 1
        else:
            mask0, mask1 = torch.ones(m).float(), torch.ones(n).float()
        
        # append a self-loop to every vertex before building the adjacency matrices
        for ii in range(len(topo1)):
            topo1[ii].append(ii)
        for ii in range(len(topo2)):
            topo2[ii].append(ii)
        adj1 = sknetwork.data.from_adjacency_list(topo1, matrix_only=True, reindex=False).toarray()
        adj2 = sknetwork.data.from_adjacency_list(topo2, matrix_only=True, reindex=False).toarray()

        try:
            spec0, spec1 = np.abs(self.spectral.fit_transform(adj1)), np.abs(self.spectral.fit_transform(adj2))
        except Exception:
            print('>>>>' + file_name, flush=True)
            spec0, spec1 = np.zeros((len(adj1), 64)), np.zeros((len(adj2), 64))

        if self.is_eval:
            return{
                'keypoints0': v2d1,
                'keypoints1': v2d2,
                'topo0': [topo1],
                'topo1': [topo2],
                # 'id0': id1,
                # 'id1': id2,
                'adj_mat0': adj1,
                'adj_mat1': adj2,
                'spec0': spec0,
                'spec1': spec1,
                'imaget': imgt,
                'image0': img1,
                'image1': img2,
                'motion0': motion0,
                'motion1': motion2,
                'visibility0': visible01,
                'visibility1': visible21,

                'all_matches': corr1,
                'm01': corr1,
                'm10': corr2,
                'ms': m,
                'ns': n,
                'mask0': mask0,
                'mask1': mask1,
                'file_name': file_name,
                # 'with_match': True
            }
        elif not self.is_train:
            return{
                'keypoints0': v2d1,
                'keypoints1': v2d2,
                # 'topo0': [topo1],
                # 'topo1': [topo2],
                # 'id0': id1,
                # 'id1': id2,
                'adj_mat0': adj1,
                'adj_mat1': adj2,
                'spec0': spec0,
                'spec1': spec1,
                'imaget': imgt,
                'image0': img1,
                'image1': img2,
                'motion0': motion0,
                'motion1': motion2,
                'visibility0': visible01,
                'visibility1': visible21,

                'all_matches': corr1,
                'm01': corr1,
                'm10': corr2,
                'ms': m,
                'ns': n,
                'mask0': mask0,
                'mask1': mask1,
                'file_name': file_name,
                # 'with_match': True
            }
        
        else:
            return{
                'keypoints0': v2d1,
                'keypoints1': v2d2,
                # 'topo0': topo1,
                # 'topo1': topo2,
                # 'id0': id1,
                # 'id1': id2,
                'adj_mat0': adj1,
                'adj_mat1': adj2,
                'spec0': spec0,
                'spec1': spec1,
                'imaget': imgt,
                'motion0': motion0,
                'motion1': motion2,
                'visibility0': visible01,
                'visibility1': visible21,

                'image0': img1,
                'image1': img2,

                'all_matches': corr1,
                'm01': corr1,
                'm10': corr2,
                'ms': m,
                'ns': n,
                'mask0': mask0,
                'mask1': mask1,
                'file_name': file_name,
                # 'with_match': True
            } 

        

    def __rmul__(self, v):
        self.label_list = v * self.label_list
        self.image_list = v * self.image_list
        return self
        
    def __len__(self):
        return len(self.image_list)
        

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)
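The `worker_init_fn` above gives each DataLoader worker its own NumPy stream by offsetting the parent process's seed state with the worker id; without it, forked workers would all inherit identical RNG state and produce identical augmentations. A sketch with hypothetical seeds:

```python
import numpy as np

# copy of worker_init_fn above: reseed from the parent state plus the worker id
def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)

np.random.seed(7)       # hypothetical base seed set in the parent process
worker_init_fn(0)       # worker 0 reseeds deterministically from that state
a = np.random.rand()

np.random.seed(7)
worker_init_fn(1)       # worker 1 derives a different stream
b = np.random.rand()
# a and b come from different streams, yet each is reproducible per worker
```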

def fetch_dataloader(args, type='train'):
    mode = args.mode if hasattr(args, 'mode') else 'train'
    use_vs = args.use_vs if hasattr(args, 'use_vs') else False
    if args.mode == 'train':
        # training always uses the original labels; use_vs only affects test/eval
        lineart = MixamoLineArtMotionSequence(root=args.root, gap=args.gap, split=args.type,
            model=args.model, action=args.action, mode=mode)
        loader = data.DataLoader(lineart, batch_size=args.batch_size,
            pin_memory=True, shuffle=True, num_workers=16, drop_last=True, worker_init_fn=worker_init_fn)
    else:
        lineart = MixamoLineArtMotionSequence(root=args.root, gap=args.gap, split=args.type,
            model=args.model, action=args.action, mode=mode, use_vs=use_vs)
        loader = data.DataLoader(lineart, batch_size=args.batch_size,
            pin_memory=True, shuffle=False, num_workers=8)
    return loader



================================================
FILE: datasets/vd_seq.py
================================================
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
# import networkx as nx
import os
import math
import random
from glob import glob
import os.path as osp

import sys
import argparse
import cv2
from collections import Counter
import time
import json
import sknetwork
from sknetwork.embedding import Spectral

import scipy

def read_json(file_path):
    """
        input: json file path
        output: 2d vertex locations, connection table and vertex index in the original 3D domain
    """

    with open(file_path) as file:
        data = json.load(file)
        vertex2d = np.array(data['vertex location'])
        
        topology = data['connection']
        index = np.array(data['original index'])


    return vertex2d, topology, index
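To show the label schema `read_json` expects, here is the helper copied from above applied to a hypothetical minimal label file written to a temporary path (the keys match the dataset's json files):

```python
import json
import os
import tempfile
import numpy as np

# copy of read_json above: parse one vertex-graph label file
def read_json(file_path):
    with open(file_path) as file:
        data = json.load(file)
        vertex2d = np.array(data['vertex location'])
        topology = data['connection']
        index = np.array(data['original index'])
    return vertex2d, topology, index

# hypothetical two-vertex label in the same schema
payload = {'vertex location': [[1, 2], [3, 4]],
           'connection': [[1], [0]],
           'original index': [10, 11]}
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump(payload, f)
    path = f.name
v2d, topo, idx = read_json(path)
os.remove(path)
```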


class VideoLinSeq(data.Dataset):
    def __init__(self, root, split='train'):
        """
            input:
                root: the root folder of the line art data
                split: split folder

            output:
                image of sources (0, 1) and output (0.5)
                topo0, topo1
                v2d0, v2d1


        """
        super(VideoLinSeq, self).__init__()

        self.image_list = []
        self.label_list = []

        label_root = osp.join(root, split, 'labels')
        image_root = osp.join(root, split, 'frames')

        self.spectral = Spectral(64,  normalized=False)

        for clip in os.listdir(image_root):
            
            label_list = sorted(glob(osp.join(label_root, clip, '*.json')))

            for i in range(len(label_list) - 1):
                self.label_list += [ [label_list[jj] for jj in range(i, i + 2)] ]
                self.image_list += [ [label_list[jj].replace('labels', 'frames').replace('.json', '.png') for jj in range(i, i + 2)] ]

        # print(clip)
        print('Len of Frame is ', len(self.image_list), flush=True)
        print('Len of Label is ', len(self.label_list), flush=True)

    def __getitem__(self, index):
        # prepare images
        index = index % len(self.image_list)
        file_name0 = self.label_list[index][0][:-5].split('/')[-1]
        file_name1 = self.label_list[index][-1][:-5].split('/')[-1]
        folder0 = self.label_list[index][0][:-4].split('/')[-2]
        folder1 = self.label_list[index][-1][:-4].split('/')[-2]


        imgt = [cv2.imread(self.image_list[index][ii]) for ii in range(0, len(self.image_list[index]))]

        labelt = []
        for ii in range(0, len(self.label_list[index])):
            v, t, id = read_json(self.label_list[index][ii])
            v[v > imgt[0].shape[0] - 1] = imgt[0].shape[0] - 1
            v[v < 0] = 0
            labelt.append({'keypoints': v.astype(int), 'topo': t, 'id': id})

        # make motion pseudo label

        ###### prepare other data
        img2 = imgt[-1]
        img1 = imgt[0] 

        v2d2 = labelt[-1]['keypoints'].astype(int)
        v2d1 = labelt[0]['keypoints'].astype(int)

        topo2 = labelt[-1]['topo']
        topo1 = labelt[0]['topo']

        m, n = len(v2d1), len(v2d2)

        if len(img1.shape) == 2:
            img1 = np.tile(img1[...,None], (1, 1, 3))
            img2 = np.tile(img2[...,None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float() * 2 / 255.0 - 1.0

        v2d1 = torch.from_numpy(v2d1)
        v2d2 = torch.from_numpy(v2d2)

        mask0, mask1 = torch.ones(m).float(), torch.ones(n).float()

        v2d1[v2d1 > imgt[0].shape[0] - 1 ] = imgt[0].shape[0] - 1
        v2d1[v2d1 < 0] = 0
        v2d2[v2d2 > imgt[0].shape[1] - 1] = imgt[0].shape[1] - 1
        v2d2[v2d2 < 0] = 0

     
        id1 = np.arange(len(v2d1))
        id2 = np.arange(len(v2d2))

       
        for ii in range(len(topo1)):
            topo1[ii].append(ii)
        for ii in range(len(topo2)):
            topo2[ii].append(ii)
        adj1 = sknetwork.data.from_adjacency_list(topo1, matrix_only=True, reindex=False).toarray()
        adj2 = sknetwork.data.from_adjacency_list(topo2, matrix_only=True, reindex=False).toarray()

        try:
            spec0, spec1 = np.abs(self.spectral.fit_transform(adj1)), np.abs(self.spectral.fit_transform(adj2))
        except:
            print('>>>>' + file_name0, flush=True)
            spec0, spec1 = np.zeros((len(adj1), 64)), np.zeros((len(adj2), 64))

        return{
            'keypoints0': v2d1,
            'keypoints1': v2d2,
            'topo0': [topo1],
            'topo1': [topo2],
            'adj_mat0': adj1,
            'adj_mat1': adj2,
            'spec0': spec0,
            'spec1': spec1,
            'image0': img1,
            'image1': img2,
            'ms': m,
            'ns': n,
            'mask0': mask0,
            'mask1': mask1,
            'gen_vid': True,
            'file_name0': file_name0,
            'file_name1': file_name1,
            'folder_name0': folder0,
            'folder_name1': folder1
        }


    def __rmul__(self, v):
        self.label_list = v * self.label_list
        self.image_list = v * self.image_list
        return self
        
    def __len__(self):
        return len(self.image_list)
        

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

def fetch_videoloader(args, type='train'):
    # the split is taken from args.type; the `type` argument is kept for API symmetry
    lineart = VideoLinSeq(root=args.root, split=args.type)
    
    loader = data.DataLoader(lineart, batch_size=args.batch_size, 
            pin_memory=True, shuffle=False, num_workers=8)
    return loader



================================================
FILE: download.sh
================================================
cd data
gdown 1SNRGajIECxNwRp6ZJ0IlY7AEl2mRm2DR
unzip ml240data.zip


================================================
FILE: experiments/inbetweener_full/ckpt/.gitkeep
================================================


================================================
FILE: inbetween.py
================================================
""" This script handling the training process. """
import os
import time
import random
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from datasets import fetch_dataloader
from datasets import fetch_videoloader
import random
from utils.log import Logger

from torch.optim import *
import warnings
from tqdm import tqdm
import itertools
import pdb
import numpy as np
import models
import datetime
import sys
import json
import cv2

from utils.visualize_inbetween3 import visualize
# from utils.visualize_inbetween import visualize
from utils.visualize_video import visvid as visgen
import matplotlib.cm as cm
# from models.utils import make_matching_seg_plot

warnings.filterwarnings('ignore')

# a, b, c, d = check_data_distribution('/mnt/lustre/lisiyao1/dance/dance2/DanceRevolution/data/aistpp_train')

import matplotlib.pyplot as plt
import pdb

class DraftRefine():
    def __init__(self, args):
        self.config = args
        torch.backends.cudnn.benchmark = True
        torch.multiprocessing.set_sharing_strategy('file_system')
        self._build()

    def train(self):
        
        opt = self.config
        print(opt)

        # store viz results
        # eval_output_dir = Path(self.expdir)
        # eval_output_dir.mkdir(exist_ok=True, parents=True)

        # print('Will write visualization images to',
        #     'directory \"{}\"'.format(eval_output_dir))

        # load training data
        
        model = self.model

        checkpoint = torch.load(self.config.corr_weights)
        # strip the DataParallel 'module.' prefix before loading into the bare submodule
        state_dict = {k.replace('module.', ''): checkpoint['model'][k] for k in checkpoint['model']}
        model.module.corr.load_state_dict(state_dict)

        if hasattr(self.config, 'init_weight'):
            checkpoint = torch.load(self.config.init_weight)
            model.load_state_dict(checkpoint['model'])

        # if torch.cuda.is_available():
        #     model.cuda() # make sure it trains on GPU
        # else:
        #     print("### CUDA not available ###")
            # return
        optimizer = self.optimizer
        schedular = self.schedular
        mean_loss = []
        log = Logger(self.config, self.expdir)
        updates = 0
        
        # set seed
        random.seed(opt.seed)
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed(opt.seed)
        np.random.seed(opt.seed)
        # print(opt.seed)
        # start training

        for epoch in range(1, opt.epoch+1):
            np.random.seed(opt.seed + epoch)
            train_loader = self.train_loader
            log.set_progress(epoch, len(train_loader))
            batch_loss = 0
            batch_epe = 0 
            batch_acc = 0
            batch_iter = 0
            model.train()
            avg_time = 0
            avg_num = 0
            # torch.cuda.synchronize()
            
            for i, data in enumerate(train_loader):
                pred = model(data)
                if True:
                    loss = pred['loss'].mean() 
                    # print(loss.item(), opt.batch_size)
                    batch_loss += loss.item() / opt.batch_size
                    batch_acc += pred['Visibility Acc'].mean().item() / opt.batch_size
                    batch_epe += pred['EPE'].mean().item() / opt.batch_size 
                    loss.backward()
                    batch_iter += 1
                else:
                    print('Skip!')



                if ((i + 1) % opt.batch_size == 0) or (i + 1 == len(train_loader)):
                    optimizer.step()

                    optimizer.zero_grad()
                    batch_iter = 1 if batch_iter == 0 else batch_iter               
                    stats = {
                        'updates': updates,
                        'loss': batch_loss,
                        'accuracy': batch_acc,
                        'EPE': batch_epe
                    }
                    log.update(stats)
                    updates += 1
                    batch_loss = 0
                    batch_acc = 0 
                    batch_epe = 0
                    batch_iter = 0
                # tend = time.time()
                # avg_time = (tend - tstart)
                # print('Time is ', avg_time)

            # torch.cuda.synchronize()

            # avg_num += 1
                # for name, params in model.named_parameters():
                #     print('-->name:, ', name, '-->grad mean', params.grad.mean())
            # print("All time is ", avg_time, "AVG time is ", avg_time * 1.0 /avg_num,  "number is ", avg_num, flush=True)

            # save checkpoint 
            if epoch % opt.save_per_epochs == 0 or epoch == 1:
                checkpoint = {
                    'model': model.state_dict(),
                    'config': opt,
                    'epoch': epoch
                }

                filename = os.path.join(self.ckptdir, f'epoch_{epoch}.pt')
                torch.save(checkpoint, filename)
                
            # validate
            if epoch % opt.test_freq == 0:

                if not os.path.exists(os.path.join(self.visdir, 'epoch' + str(epoch))):
                    os.mkdir(os.path.join(self.visdir, 'epoch' + str(epoch)))
                eval_output_dir = os.path.join(self.visdir, 'epoch' + str(epoch))    
                
                test_loader = self.test_loader

                with torch.no_grad():
                    # Visualize the matches.
                    mean_acc = []
                    mean_epe = []
                    model.eval()
                    for i_eval, data in enumerate(tqdm(test_loader, desc='Refining motion and visibility...')):
                        pred = model(data)
                        # for k, v in data.items():
                        #     pred[k] = v[0]
                        #     pred = {**pred, **data}

                        mean_acc.append(pred['Visibility Acc'].mean().item())
                        mean_epe.append(pred['EPE'].mean().item())
                    log.log_eval({
                        'updates': opt.epoch,
                        'Visibility Accuracy': np.mean(mean_acc),
                        'EPE': np.mean(mean_epe),
                        })
                    print('Epoch [{}/{}]], Vis Acc.: {:.4f}, EPE: {:.4f}' 
                        .format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_epe)) )
                    sys.stdout.flush()
                        # make_matching_plot(
                        #     image0, image1, kpts0, kpts1, mkpts0, mkpts1, color,
                        #     text, viz_path, stem, stem, True,
                        #     True, False, 'Matches')
                        

            self.schedular.step()
            


    def eval(self):
        train_action = ['breakdance_1990', 'capoeira', 'chapa-giratoria', 'fist_fight', 'flying_knee', 'freehang_climb', 'running', 'shove', 'magic', 'tripping']
        test_action = ['great_sword_slash', 'hip_hop_dancing']

        train_model = ['ganfaul', 'girlscout', 'jolleen', 'kachujin', 'knight', 'maria_w_jj', 'michelle', 'peasant_girl', 'timmy', 'uriel_a_plotexia']
        test_model = ['police', 'warrok']

        config = self.config
        if not os.path.exists(config.imwrite_dir):
            os.mkdir(config.imwrite_dir)
            
        log = Logger(self.config, self.expdir)
        with torch.no_grad():
            model = self.model.eval()
            config = self.config
            epoch_tested = self.config.testing.ckpt_epoch
            if epoch_tested == 0 or epoch_tested == '0':
                checkpoint = torch.load(self.config.corr_weights)
                state_dict = {k.replace('module.', ''): checkpoint['model'][k] for k in checkpoint['model']}
                model.module.corr.load_state_dict(state_dict)
            else:
                ckpt_path = os.path.join(self.ckptdir, f"epoch_{epoch_tested}.pt")
                # self.device = torch.device('cuda' if config.cuda else 'cpu')
                print("Evaluation...")
                checkpoint = torch.load(ckpt_path)
                model.load_state_dict(checkpoint['model'])
            model.eval()

            if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested))):
                os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested)))
            if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons')):
                os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons'))
            eval_output_dir = os.path.join(self.evaldir, 'epoch' + str(epoch_tested))    
                
            test_loader = self.test_loader
            print(len(test_loader))
            mean_acc = []
            mean_valid_acc = []
            mean_invalid_acc = []

            # 144 data: 10x10 (model x action) pairs are for training; 2x10 (unseen model) + 10x2 (unseen action) + 2x2 (unseen model & unseen action) are for test
            # record the accuracy for each subset separately
            mean_model_acc = []
            mean_model_epe = []
            mean_action_acc = []
            mean_action_epe = []
            
            mean_none_acc = []
            mean_none_epe = []

            mean_acc = []
            mean_epe = []

            mean_cd = []
            model.eval()
            # for i_eval, data in enumerate(tqdm(test_loader, desc='Refining motion and visibility...')):
            #     pred = model(data)
            #     # for k, v in data.items():
            #     #     pred[k] = v[0]
            #     #     pred = {**pred, **data}

            #     mean_acc.append(pred['Visibility Acc'].mean().item())
            #     mean_epe.append(pred['EPE'].mean().item())
            # log.log_eval({
            #     'updates': opt.epoch,
            #     'Visibility Accuracy': np.mean(mean_acc),
            #     'EPE': np.mean(mean_epe),
            #     })

            for i_eval, data in enumerate(tqdm(test_loader, desc='Predicting Vtx Corr...')):
                # if i_eval == 34:
                #     continue
                
                pred = model(data)
                for k, v in pred.items():
                    # print(k, flush=True)
                    pred[k] = v
                    pred = {**pred, **data}
            
                mean_acc.append(pred['Visibility Acc'].mean().item())
                mean_epe.append(pred['EPE'].mean().item())

                unmarked = True
                for model_name in train_model:
                    if model_name in pred['file_name']:
                        mean_model_acc.append(pred['Visibility Acc'])
                        mean_model_epe.append(pred['EPE'])
                        unmarked = False
                        break

                for action_name in train_action:
                    if action_name in pred['file_name']:
                        mean_action_acc.append(pred['Visibility Acc'])
                        mean_action_epe.append(pred['EPE'])
                        unmarked = False
                        break
                
                if unmarked:
                    mean_none_acc.append(pred['Visibility Acc'])
                    mean_none_epe.append(pred['EPE'])

                # if 'invalid_accuracy' in pred and pred['invalid_accuracy'] is not None:
                #     mean_invalid_acc.append(pred['invalid_accuracy'])
                
                img_vis = visualize(pred)
                # mean_cd.append(cd.item())
                file_name = pred['file_name'][0].split('/')
                cv2.imwrite(os.path.join(config.imwrite_dir, (file_name[-2] + '_' + file_name[-1]) + '.png'), img_vis)

                # cv2.imwrite(os.path.join(eval_output_dir, pred['file_name'][0].replace('/', '_') + '.jpg'), img_vis)
                
            log.log_eval({
                'updates': self.config.testing.ckpt_epoch,
                # 'mean CD': np.mean(mean_cd),
                # 'Visibility Accuracy': np.mean(mean_acc),
                # 'EPE': np.mean(mean_epe),
                # 'Unseen Action Accuracy': np.mean(mean_model_acc),
                # 'Unseen Action EPE': np.mean(mean_model_epe),
                # 'Unseen Model Accuracy': np.mean(mean_action_acc),
                # 'Unseen Model EPE': np.mean(mean_action_epe),
                # 'Unseen Both Accuracy': np.mean(mean_none_acc),
                # 'Unseen Both Valid Accuracy': np.mean(mean_none_epe)
                })
                # print ('Epoch [{}/{}]], Acc.: {:.4f}, Valid Acc.{:.4f}' 
                #     .format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_valid_acc)) )
            sys.stdout.flush()


    def gen(self):
        log = Logger(self.config, self.viddir)
        with torch.no_grad():
            model = self.model.eval()
            config = self.config
            epoch_tested = self.config.testing.ckpt_epoch
            if epoch_tested == 0 or epoch_tested == '0':
                checkpoint = torch.load(self.config.corr_weights)
                state_dict = {k.replace('module.', ''): checkpoint['model'][k] for k in checkpoint['model']}
                model.module.corr.load_state_dict(state_dict)
            else:
                ckpt_path = os.path.join(self.ckptdir, f"epoch_{epoch_tested}.pt")
                # self.device = torch.device('cuda' if config.cuda else 'cpu')
                print("Evaluation...")
                checkpoint = torch.load(ckpt_path)
                model.load_state_dict(checkpoint['model'])
            model.eval()

            if not os.path.exists(os.path.join(self.viddir, 'epoch' + str(epoch_tested))):
                os.mkdir(os.path.join(self.viddir, 'epoch' + str(epoch_tested)))
            if not os.path.exists(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'frames')):
                os.mkdir(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'frames'))
            if not os.path.exists(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'videos')):
                os.mkdir(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'videos'))

            gen_frame_dir = os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'frames')  
            gen_video_dir = os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'videos')    
                
            vid_loader = self.vid_loader
            print(len(vid_loader))
            mean_acc = []
            mean_valid_acc = []
            mean_invalid_acc = []

            model.eval()

            for i_eval, data in enumerate(tqdm(vid_loader, desc='Gen Video...')):
                
                pred = model(data)
                for k, v in pred.items():
                    pred[k] = v
                    pred = {**pred, **data}
            

                img_vis = visgen(pred, config.inter_frames)

                if not os.path.exists(os.path.join(gen_frame_dir, pred['folder_name0'][0])):
                    os.mkdir(os.path.join(gen_frame_dir, pred['folder_name0'][0]))
                
                cv2.imwrite(os.path.join(gen_frame_dir, pred['folder_name0'][0], pred['file_name0'][0] + '_000.jpg'),img_vis[0])
                for tt in range(config.inter_frames):
                    cv2.imwrite(os.path.join(gen_frame_dir, pred['folder_name0'][0], pred['file_name0'][0] + '_' + '{:03d}'.format(tt + 1) + '.jpg'), img_vis[tt + 1])
                cv2.imwrite(os.path.join(gen_frame_dir, pred['folder_name0'][0], pred['file_name1'][0] + '_000.jpg'),img_vis[-1])
            
            for ff in os.listdir(gen_frame_dir):
                frame_dir = os.path.join(gen_frame_dir, ff)
                video_file = os.path.join(gen_video_dir, f"{ff}.mp4")
                cmd = f"ffmpeg -r {config.fps} -pattern_type glob -i '{frame_dir}/*.jpg' -vb 20M -vcodec mpeg4 -y '{video_file}'"
                
                print(cmd, flush=True)
                os.system(cmd)
                

            log.log_eval({
                'updates': self.config.testing.ckpt_epoch,
                })
            sys.stdout.flush()

    def _build(self):
        config = self.config
        self.start_epoch = 0
        self._dir_setting()
        self._build_model()
        if not(hasattr(config, 'need_not_train_data') and config.need_not_train_data):
            self._build_train_loader()
        if not(hasattr(config, 'need_not_test_data') and config.need_not_test_data):
            self._build_test_loader()
        if hasattr(config, 'gen_video') and config.gen_video:
            self._build_video_loader()
        self._build_optimizer()

    def _build_model(self):
        """ Define Model """
        config = self.config 
        if hasattr(config.model, 'name'):
            print(f'Experiment Using {config.model.name}')
            model_class = getattr(models, config.model.name)
            model = model_class(config.model)
        else:
            raise NotImplementedError("Wrong Model Selection")
        
        model = nn.DataParallel(model)
        self.model = model.cuda()

    def _build_train_loader(self):
        config = self.config
        self.train_loader = fetch_dataloader(config.data.train, type='train')

    def _build_test_loader(self):
        config = self.config
        self.test_loader = fetch_dataloader(config.data.test, type='test')

    def _build_video_loader(self):
        config = self.config
        self.vid_loader = fetch_videoloader(config.video)

    def _build_optimizer(self):
        #model = nn.DataParallel(model).to(device)
        config = self.config.optimizer
        try:
            optim = getattr(torch.optim, config.type)
        except Exception:
            raise NotImplementedError('not implemented optim method ' + config.type)

        self.optimizer = optim(itertools.chain(self.model.module.parameters(),
                                             ),
                                             **config.kwargs)
        self.schedular = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, **config.schedular_kwargs)

    def _dir_setting(self):
        self.expname = self.config.expname
        # self.experiment_dir = os.path.join("/mnt/cache/syli/inbetween", "experiments")

        self.experiment_dir = 'experiments'
        self.expdir = os.path.join(self.experiment_dir, self.expname)

        if not os.path.exists(self.expdir):
            os.mkdir(self.expdir)

        self.visdir = os.path.join(self.expdir, "vis")  # -- imgs, videos, jsons
        if not os.path.exists(self.visdir):
            os.mkdir(self.visdir)

        self.ckptdir = os.path.join(self.expdir, "ckpt")
        if not os.path.exists(self.ckptdir):
            os.mkdir(self.ckptdir)

        self.evaldir = os.path.join(self.expdir, "eval")
        if not os.path.exists(self.evaldir):
            os.mkdir(self.evaldir)

        self.viddir = os.path.join(self.expdir, "video")
        if not os.path.exists(self.viddir):
            os.mkdir(self.viddir)

        

        # self.ckptdir = os.path.join(self.expdir, "ckpt")
        # if not os.path.exists(self.ckptdir):
        #     os.mkdir(self.ckptdir)



        






================================================
FILE: inbetween_results/.gitkeep
================================================


================================================
FILE: main.py
================================================
from inbetween import DraftRefine
import argparse
import os
import yaml
from pprint import pprint
from easydict import EasyDict



def parse_args():
    parser = argparse.ArgumentParser(
        description='Anime segment matching')
    parser.add_argument('--config', default='')
    # exclusive arguments
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--train', action='store_true')
    group.add_argument('--eval', action='store_true')
    group.add_argument('--gen', action='store_true')


    return parser.parse_args()


def main():
    # parse arguments and load config
    args = parse_args()
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    for k, v in vars(args).items():
        config[k] = v
    pprint(config)

    config = EasyDict(config)
    agent = DraftRefine(config)
    print(config)

    if args.train:
        agent.train()
    elif args.eval:
        agent.eval()
    elif args.gen:
        agent.gen()


if __name__ == '__main__':
    main()


================================================
FILE: models/__init__.py
================================================
# from .transformer_refiner import Refiner
# from .inbetweener import Inbetweener
# from .inbetweener_with_mask import InbetweenerM
# from .inbetweener_wo_rp import InbetweenerM as InbetweenerNRP
from .inbetweener_with_mask_with_spec import InbetweenerTM
# from .inbetweener_with_mask_with_spec_wo_OT import InbetweenerTMwoOT
from .inbetweener_with_mask2 import InbetweenerM as InbetweenerM2
# from .inbetweener_with_mask_wo_pos import InbetweenerNP
# from .inbetweener_with_mask_wo_pos_wo_spec import InbetweenerNPS
# from .transformer_refiner2 import Refiner as Refiner2
# from .transformer_refiner3 import Refiner as Refiner3
# from .transformer_refiner4 import Refiner as Refiner4
# from .transformer_refiner5 import Refiner as Refiner5
# from .transformer_refiner_norm import Refiner as RefinerN

__all__ = [ 'InbetweenerTM', 'InbetweenerM2']


================================================
FILE: models/inbetweener_with_mask2.py
================================================
from copy import deepcopy
from pathlib import Path
import torch
from torch import nn
# from seg_desc import seg_descriptor
import argparse
import torch.nn.functional as F

def MLP(channels: list, do_bn=True):
    """ Multi-layer perceptron """
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(
            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
        if i < (n-1):
            if do_bn:
                # layers.append(nn.BatchNorm1d(channels[i]))
                layers.append(nn.InstanceNorm1d(channels[i]))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


def normalize_keypoints(kpts, image_shape):
    """ Normalize keypoints locations based on image image_shape"""
    _, _, height, width = image_shape
    one = kpts.new_tensor(1)
    size = torch.stack([one*width, one*height])[None]
    center = size / 2
    scaling = size.max(1, keepdim=True).values * 0.7
    return (kpts - center[:, None, :]) / scaling[:, None, :]
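# The effect of this normalization can be checked with a small pure-Python
# sketch (hypothetical toy coordinates; the real function above operates on
# batched torch tensors):

```python
# Pure-Python sketch of the normalization above (assumed convention: (x, y)
# pixel coordinates on an H x W image).
def normalize_xy(x, y, height, width):
    cx, cy = width / 2.0, height / 2.0          # image center
    scale = max(width, height) * 0.7            # 0.7 * longest side
    return (x - cx) / scale, (y - cy) / scale

center = normalize_xy(320, 240, 480, 640)       # image center maps to (0, 0)
corner = normalize_xy(640, 480, 480, 640)       # far corner stays within ~[-1, 1]
```

# The center maps to the origin and every pixel lands within roughly [-1, 1]
# on each axis, keeping the keypoint encoder's inputs in a stable range.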

class ThreeLayerEncoder(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""
    def __init__(self, enc_dim):
        super().__init__()
        # input must be 3 channel (r, g, b)
        self.layer1 = nn.Conv2d(3, enc_dim//4, 7, padding=3)
        self.non_linear1 = nn.ReLU()
        self.layer2 = nn.Conv2d(enc_dim//4, enc_dim//2, 3, padding=1)
        self.non_linear2 = nn.ReLU()
        self.layer3 = nn.Conv2d(enc_dim//2, enc_dim, 3, padding=1)

        self.norm1 = nn.InstanceNorm2d(enc_dim//4)
        self.norm2 = nn.InstanceNorm2d(enc_dim//2)
        self.norm3 = nn.InstanceNorm2d(enc_dim)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0.0)

    def forward(self, img):
        x = self.non_linear1(self.norm1(self.layer1(img)))
        x = self.non_linear2(self.norm2(self.layer2(x)))
        x = self.norm3(self.layer3(x))
        # x = self.non_linear1(self.layer1(img))
        # x = self.non_linear2(self.layer2(x))
        # x = self.layer3(x)
        return x


class VertexDescriptor(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""
    def __init__(self, enc_dim):
        super().__init__()
        self.encoder = ThreeLayerEncoder(enc_dim)
        # self.super_pixel_pooling = 
        # use scatter
        # nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, img, vtx):
        x = self.encoder(img)
        n, c, h, w = x.size()
        assert((h, w) == img.size()[2:4])
        return x[:, :, torch.round(vtx[0, :, 1]).long(), torch.round(vtx[0, :, 0]).long()]
        # return super_pixel_pooling(x.view(n, c, -1), seg.view(-1).long(), reduce='mean')
        # here return size is [1]xCx|Seg|


class KeypointEncoder(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""
    def __init__(self, feature_dim, layers):
        super().__init__()
        self.encoder = MLP([2] + layers + [feature_dim])
        # for m in self.encoder.modules():
        #     if isinstance(m, nn.Conv2d):
        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        #         nn.init.constant_(m.bias, 0.0)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts):
        inputs = kpts.transpose(1, 2)
        # print(inputs.size(), 'wula!')
        x = self.encoder(inputs)
        # print(x.size())
        return x


def attention(query, key, value, mask=None):
    dim = query.shape[1]
    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
    if mask is not None:
        # print(mask, flush=True)
        scores = scores.masked_fill(mask==0, float('-inf'))

    # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    # att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
    # att = F.softmax(att, dim=-1)
    prob = torch.nn.functional.softmax(scores, dim=-1)

    # print(scores[1][1], prob[1][1], flush=True)
    # while True:
    #     pass 
    # prob = torch.exp(scores) /((torch.sum(torch.exp(scores), dim=-1)[:, :, :, None]) + 1e-7)
    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
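# A single query row of this masked, scaled dot-product attention can be
# sketched in pure Python (toy 2-D features and hypothetical sizes; the real
# function above works on multi-head batched tensors):

```python
import math

# One query row of masked scaled dot-product attention, mirroring the
# function above: scores are q.k / sqrt(d); masked entries go to -inf
# before the softmax, so they receive exactly zero probability.
def attention_row(q, keys, mask):
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    scores = [s if m else float('-inf') for s, m in zip(scores, mask)]
    exps = [math.exp(s) for s in scores]  # exp(-inf) == 0.0
    z = sum(exps)
    return [e / z for e in exps]

# third key is masked out, so it contributes nothing despite its large values
prob = attention_row([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]], [1, 1, 0])
```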


class MultiHeadedAttention(nn.Module):
    """ Multi-head attention to increase model expressivitiy """
    def __init__(self, num_heads: int, d_model: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value, mask=None):
        batch_dim = query.size(0)
        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]
        x, prob = attention(query, key, value, mask)
        # self.prob.append(prob)
        return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))


class AttentionalPropagation(nn.Module):
    def __init__(self, feature_dim: int, num_heads: int):
        super().__init__()
        self.attn = MultiHeadedAttention(num_heads, feature_dim)
        self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
        nn.init.constant_(self.mlp[-1].bias, 0.0)

    def forward(self, x, source, mask=None):
        message = self.attn(x, source, source, mask)
        return self.mlp(torch.cat([x, message], dim=1))


class AttentionalGNN(nn.Module):
    def __init__(self, feature_dim: int, layer_names: list):
        super().__init__()
        self.layers = nn.ModuleList([
            AttentionalPropagation(feature_dim, 4)
            for _ in range(len(layer_names))])
        self.names = layer_names

    def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None, mask10=None):
        for layer, name in zip(self.layers, self.names):
            layer.attn.prob = []
            if name == 'cross':
                src0, src1 = desc1, desc0
                mask0, mask1 = mask01[:, None], mask10[:, None] 
            else:  # if name == 'self':
                src0, src1 = desc0, desc1
                mask0, mask1 = mask00[:, None], mask11[:, None]

            delta0, delta1 = layer(desc0, src0, mask0), layer(desc1, src1, mask1)
            desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
        return desc0, desc1


def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
    """ Perform Sinkhorn Normalization in Log-space for stability"""
    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):
        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
    return Z + u.unsqueeze(2) + v.unsqueeze(1)
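# The fixed point these alternating updates converge to can be checked with a
# small pure-Python version (hypothetical 2x2 scores with uniform marginals;
# the real function above runs on batched torch tensors):

```python
import math

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

# 2x2 log-domain Sinkhorn, mirroring the updates above: u rescales rows so
# their log-sums match log_mu, v rescales columns to match log_nu.
Z = [[0.0, 1.0], [2.0, 0.5]]
log_mu = [math.log(0.5)] * 2          # uniform target row marginals
log_nu = [math.log(0.5)] * 2          # uniform target column marginals
u, v = [0.0, 0.0], [0.0, 0.0]
for _ in range(50):
    u = [log_mu[i] - logsumexp([Z[i][j] + v[j] for j in range(2)]) for i in range(2)]
    v = [log_nu[j] - logsumexp([Z[i][j] + u[i] for i in range(2)]) for j in range(2)]

# exponentiate the normalized log-coupling; marginals match mu and nu
P = [[math.exp(Z[i][j] + u[i] + v[j]) for j in range(2)] for i in range(2)]
row_sums = [sum(P[i]) for i in range(2)]
col_sums = [P[0][j] + P[1][j] for j in range(2)]
```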


def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
    """ Perform Differentiable Optimal Transport in Log-space for stability"""
    b, m, n = scores.shape
    one = scores.new_tensor(1)
    if ms is None or ns is None:
        ms, ns = (m*one).to(scores), (n*one).to(scores)
    # else:
    #     ms, ns = ms.to(scores)[:, None], ns.to(scores)[:, None]
    # here m,n should be parameters not shape

    # ms, ns: (b, )
    bins0 = alpha.expand(b, m, 1)
    bins1 = alpha.expand(b, 1, n)
    alpha = alpha.expand(b, 1, 1)

    # pad an extra row/column ("dustbin") of scores for unmatched points;
    # alpha is the learned threshold
    couplings = torch.cat([torch.cat([scores, bins0], -1),
                           torch.cat([bins1, alpha], -1)], 1)

    norm = - (ms + ns).log() # (b, )
    # print(scores.min(), flush=True)
    if ms.size()[0] > 0:
        norm = norm[:, None]
        log_mu = torch.cat([norm.expand(b, m), ns.log()[:, None] + norm], dim=-1) # (m + 1)
        log_nu = torch.cat([norm.expand(b, n), ms.log()[:, None] + norm], dim=-1)
        # print(log_nu.min(), log_mu.min(), flush=True)
    else:
        log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) # (m + 1)
        log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
        log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)

    
    Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)

    if ms.size()[0] > 1:
        norm = norm[:, :, None]
    Z = Z - norm  # multiply probabilities by M+N
    return Z
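
# A minimal, self-contained sketch (hypothetical demo, not part of the original
# model) of the log-domain Sinkhorn loop above: iterating the two logsumexp
# updates drives exp(Z) toward a coupling whose row/column sums match the
# target marginals exp(log_mu) / exp(log_nu).
def _sinkhorn_marginals_demo(iters: int = 100):
    import torch
    torch.manual_seed(0)
    b, m, n = 1, 3, 4
    Z = torch.randn(b, m, n)
    log_mu = torch.full((b, m), -torch.log(torch.tensor(float(m))).item())  # uniform 1/m
    log_nu = torch.full((b, n), -torch.log(torch.tensor(float(n))).item())  # uniform 1/n
    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):  # same updates as log_sinkhorn_iterations
        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
    P = (Z + u.unsqueeze(2) + v.unsqueeze(1)).exp()
    return P  # P.sum(2) ~ 1/m per row, P.sum(1) ~ 1/n per column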


def arange_like(x, dim: int):
    return x.new_ones(x.shape[dim]).cumsum(0) - 1  # traceable in 1.1
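
# Toy illustration (hypothetical numbers, for exposition only) of the mutual
# nearest-neighbour check used later in this file: indices0/indices1 are the
# row-/column-wise argmaxes of a score matrix, a pair survives only when both
# argmaxes agree, and non-mutual entries are set to -1.
def _mutual_match_demo():
    import torch
    score = torch.tensor([[[0.9, 0.1, 0.0],
                           [0.2, 0.8, 0.7],
                           [0.1, 0.9, 0.3]]])
    max0, max1 = score.max(2), score.max(1)
    indices0, indices1 = max0.indices, max1.indices
    ar = torch.arange(indices0.shape[1])[None]   # what arange_like computes
    mutual0 = ar == indices1.gather(1, indices0)
    return torch.where(mutual0, indices0, indices0.new_tensor(-1))
    # rows 0 and 2 are mutual (0<->0, 2<->1); row 1 loses column 1 to row 2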


class SuperGlueM(nn.Module):
    """SuperGlue feature matching middle-end

    Given two sets of keypoints and locations, we determine the
    correspondences by:
      1. Keypoint Encoding (normalization + visual feature and location fusion)
      2. Graph Neural Network with multiple self and cross-attention layers
      3. Final projection layer
      4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
      5. Thresholding matrix based on mutual exclusivity and a match_threshold

    The correspondence ids use -1 to indicate non-matching points.

    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
    Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763

    """
    # default_config = {
    #     'descriptor_dim': 128,
    #     'weights': 'indoor',
    #     'keypoint_encoder': [32, 64, 128],
    #     'GNN_layers': ['self', 'cross'] * 9,
    #     'sinkhorn_iterations': 100,
    #     'match_threshold': 0.2,
    # }

    def __init__(self, config=None):
        super().__init__()

        default_config = argparse.Namespace()
        default_config.descriptor_dim = 128
        # default_config.weights = 
        default_config.keypoint_encoder = [32, 64, 128]
        default_config.GNN_layers = ['self', 'cross'] * 9
        default_config.sinkhorn_iterations = 100
        default_config.match_threshold = 0.2
        # self.config = {**self.default_config, **config}

        if config is None:
            self.config = default_config
        else:
            self.config = config   
            self.config.GNN_layers = ['self', 'cross'] * self.config.GNN_layer_num
            # print('WULA!', self.config.GNN_layer_num)

        self.kenc = KeypointEncoder(
            self.config.descriptor_dim, self.config.keypoint_encoder)

        self.gnn = AttentionalGNN(
            self.config.descriptor_dim, self.config.GNN_layers)

        self.final_proj = nn.Conv1d(
            self.config.descriptor_dim, self.config.descriptor_dim,
            kernel_size=1, bias=True)

        bin_score = torch.nn.Parameter(torch.tensor(1.))
        self.register_parameter('bin_score', bin_score)
        self.vertex_desc = VertexDescriptor(self.config.descriptor_dim)

        # assert self.config.weights in ['indoor', 'outdoor']
        # path = Path(__file__).parent
        # path = path / 'weights/superglue_{}.pth'.format(self.config.weights)
        # self.load_state_dict(torch.load(path))
        # print('Loaded SuperGlue model (\"{}\" weights)'.format(
        #     self.config.weights))

    def forward(self, data):
        """Run SuperGlue on a pair of keypoints and descriptors"""
        # print(data['segment0'].size())
        # desc0, desc1 = data['descriptors0'].float()(), data['descriptors1'].float()()
         # print(desc0.size())
        kpts0, kpts1 = data['keypoints0'].float(), data['keypoints1'].float()

        ori_mask0, ori_mask1 = data['mask0'].float(), data['mask1'].float()
        dim_m, dim_n = data['ms'].float(), data['ns'].float()

        mmax = dim_m.int().max()
        nmax = dim_n.int().max()

        mask0 = ori_mask0[:, :mmax]
        mask1 = ori_mask1[:, :nmax]

        kpts0 = kpts0[:, :mmax]
        kpts1 = kpts1[:, :nmax]

        desc0, desc1 = self.vertex_desc(data['image0'], kpts0.float()), self.vertex_desc(data['image1'], kpts1.float())
        
       
        # print(desc0.size(), flush=True)

        mask00 = torch.ones_like(mask0)[:, :, None] * mask0[:, None, :]
        # print(mask00[1], flush=True)
        
        mask11 = torch.ones_like(mask1)[:, :, None] * mask1[:, None, :]
        mask01 = torch.ones_like(mask0)[:, :, None] * mask1[:, None, :]
        mask10 = torch.ones_like(mask1)[:, :, None] * mask0[:, None, :]
        
        # desc0 = desc0.transpose(0,1)
        # desc1 = desc1.transpose(0,1)
        # kpts0 = torch.reshape(kpts0, (1, -1, 2))
        # kpts1 = torch.reshape(kpts1, (1, -1, 2))

        if kpts0.shape[1] < 2 or kpts1.shape[1] < 2:  # too few keypoints to match
            shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
            # print(data['file_name'])
            return {
                'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],
                # 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],
                'matching_scores0': kpts0.new_zeros(shape0)[0],
                # 'matching_scores1': kpts1.new_zeros(shape1)[0],
                'skip_train': True
            }

        # file_name = data['file_name']
        all_matches = data['all_matches'] if 'all_matches' in data else None  # shape = (1, K1)
        # .permute(1,2,0) # shape=torch.Size([1, 87,])
        
        # positional embedding
        # Keypoint normalization.
        kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
        kpts1 = normalize_keypoints(kpts1, data['image1'].shape)

        # Keypoint MLP encoder.
        # print(data['file_name'])
        # print(kpts0.size())
    
        pos0 = self.kenc(kpts0)
        pos1 = self.kenc(kpts1)
        # print(desc0.size(), pos0.size())
        # print(desc0.size(), pos0.size())
        desc0 = desc0 + pos0
        desc1 = desc1 + pos1

        # self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
                                    #  .view(1, 1, config.block_size, config.block_size))
        # mask0 = ...
        # mask1 = ...

        # Multi-layer Transformer network.
        desc0, desc1 = self.gnn(desc0, desc1, mask00, mask11, mask01, mask10)

        # Final MLP projection.
        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)

        # Compute matching descriptor distance.
        # print(mdesc0.size(), mdesc1.size())
        scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
        scores0 = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc0)
        scores1 = torch.einsum('bdn,bdm->bnm', mdesc1, mdesc1)
        # #print('here1!!', scores.size())

        # b k1 k2
        scores = scores / self.config.descriptor_dim**.5
        # print(scores.size(), mask01.size())
        # mask01 = mask0[:, :, None] * mask1[:, None, :]
        # scores = scores.masked_fill(mask01 == 0, float('-inf'))

        # print(scores.size())
        # Run the optimal transport.
        # print(dim_m.size(), dim_m, flush=True)
        scores = log_optimal_transport(
            scores, self.bin_score,
            iters=self.config.sinkhorn_iterations,
            ms=dim_m, ns=dim_n)

        # print(scores)
        # print(scores.sum())
        # print(scores.sum(1))
        # print(scores.sum(0))

        # Get the matches with score above "match_threshold".
        return scores[:, :-1, :-1], scores0, scores1, mdesc0, mdesc1
       

def tensor_erode(bin_img, ksize=5):
    # First pad the input so the eroded image keeps its original size
    B, C, H, W = bin_img.shape
    pad = (ksize - 1) // 2
    bin_img = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)

    # Unfold the image into k x k patches
    patches = bin_img.unfold(dimension=2, size=ksize, step=1)
    patches = patches.unfold(dimension=3, size=ksize, step=1)
    # B x C x H x W x k x k

    # Take the minimum value within each patch (i.e., 0 if any pixel is 0)
    eroded, _ = patches.reshape(B, C, H, W, -1).min(dim=-1)
    return eroded
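
# Quick sanity sketch (hypothetical demo) for tensor_erode: min-pooling over
# k x k patches with zero padding shrinks a binary blob by (ksize - 1) // 2
# pixels on each side, so only pixels whose full neighbourhood is 1 survive.
def _erode_demo():
    import torch
    import torch.nn.functional as F
    img = torch.zeros(1, 1, 7, 7)
    img[:, :, 2:5, 2:5] = 1.0                     # 3x3 white square
    ksize, pad = 3, 1                             # pad = (ksize - 1) // 2
    x = F.pad(img, [pad, pad, pad, pad], mode='constant', value=0)
    patches = x.unfold(2, ksize, 1).unfold(3, ksize, 1)   # B x C x H x W x k x k
    eroded, _ = patches.reshape(1, 1, 7, 7, -1).min(dim=-1)
    return eroded                                 # only centre pixel (3, 3) stays 1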

class InbetweenerM(nn.Module):
    """SuperGlue feature matching middle-end

    Given two sets of keypoints and locations, we determine the
    correspondences by:
      1. Keypoint Encoding (normalization + visual feature and location fusion)
      2. Graph Neural Network with multiple self and cross-attention layers
      3. Final projection layer
      4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
      5. Thresholding matrix based on mutual exclusivity and a match_threshold

    The correspondence ids use -1 to indicate non-matching points.

    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
    Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763

    """
    # default_config = {
    #     'descriptor_dim': 128,
    #     'weights': 'indoor',
    #     'keypoint_encoder': [32, 64, 128],
    #     'GNN_layers': ['self', 'cross'] * 9,
    #     'sinkhorn_iterations': 100,
    #     'match_threshold': 0.2,
    # }

    def __init__(self, config=None):
        super().__init__()
        self.corr = SuperGlueM(config.corr_model)
        self.mask_map = MLP([config.corr_model.descriptor_dim, 32, 1])
        self.pos_weight = config.pos_weight
        # self.motion_propagation = 
        
        # assert self.config.weights in ['indoor', 'outdoor']
        # path = Path(__file__).parent
        # path = path / 'weights/superglue_{}.pth'.format(self.config.weights)
        # self.load_state_dict(torch.load(path))
        # print('Loaded SuperGlue model (\"{}\" weights)'.format(
        #     self.config.weights))

    def forward(self, data):
        if 'gen_vid' in data:
            dim_m, dim_n = data['ms'].float(), data['ns'].float()
            mmax = dim_m.int().max()
            nmax = dim_n.int().max()
            # with torch.no_grad():
            #     self.corr.eval()
            score01, score0, score1, dec0, dec1 = self.corr(data)
            kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2 
          ##  print(kpts0.mean(), kpts1.mean(), flush=True)

            motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0
            motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1

            motion_pred0 = torch.softmax(score0, dim=-1) @ motion_pred0
            motion_pred1 = torch.softmax(score1, dim=-1) @ motion_pred1

            max0, max1 = score01.max(2), score01.max(1)
            indices0, indices1 = max0.indices, max1.indices
            mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
            mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
            zero = score01.new_tensor(0)

            mscores0 = torch.where(mutual0, max0.values.exp(), zero)
            mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
            # valid0 = mutual0 & (mscores0 > self.config.match_threshold)
            # valid1 = mutual1 & valid0.gather(1, indices1)
            
            valid0 = mscores0 > 0.2
            valid1 = valid0.gather(1, indices1)
            indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
            indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))

            adj0, adj1 = data['adj_mat0'].float(), data['adj_mat1'].float()

            motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0
            motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1

            # score0.mask_off()

            motion_pred0 = torch.softmax(score0.masked_fill(adj0==0, float('-inf')), dim=-1) @ motion_pred0
            motion_pred1 = torch.softmax(score1.masked_fill(adj1==0, float('-inf')), dim=-1) @ motion_pred1
            
            vb0 = self.mask_map(dec0)[:, 0]
            vb1 = self.mask_map(dec1)[:, 0]
            vb0[:] = 1
            vb1[:] = 1

            # clone before binarizing: the in-place thresholding below would
            # otherwise overwrite data['image0'] / data['image1']
            im0_erode = data['image0'].clone()
            im1_erode = data['image1'].clone()
            im0_erode[im0_erode > 0] = 1
            im0_erode[im0_erode <= 0] = 0
            im1_erode[im1_erode > 0] = 1
            im1_erode[im1_erode <= 0] = 0
            
            im0_erode = tensor_erode(im0_erode, 3)
            im1_erode = tensor_erode(im1_erode, 3)

            motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone()
          ##  print('>>>>> here', motion_pred0.mean(), motion_pred1.mean(), flush=True)
            kpt0t = kpts0 + motion_output0 * 1
            kpt1t = kpts1 + motion_output1 * 1
            if 'topo0' in data and 'topo1' in data:
              ##  print(len(data['topo0'][0]), len(data['topo1']), flush=True)
                for node, nbs in enumerate(data['topo0'][0]):
                    for nb in nbs:
                        # print(nb, flush=True)
                        # print(kpt0t.size(), 'fDsafdsafds', flush=True)
                        # if vb0[0, nb] and vb0[0, node] and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 3:
                        #     vb0[0, nb] = -1
                        #     vb0[0, node] = -1
                        # print(node.size())
                        center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.5).int()[0]
                        # print(center.size(), flush=True)
                        if vb0[0, nb] and vb0[0, node] and im1_erode[0,:, center[1], center[0]].mean() > 0.8:
                            vb0[0, nb] = -1
                            vb0[0, node] = -1
                        # center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.25).int()[0]
                        # # print(center.size(), flush=True)
                        # if vb0[0, nb] and vb0[0, node] and center[1] < 720 and center[0] < 720 and im1_erode[0,:, center[1], center[0]].mean() > 0.8:
                        #     vb0[0, nb] = -1
                        #     vb0[0, node] = -1
                        # center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.75).int()[0]
                        # # print(center.size(), flush=True)
                        # if vb0[0, nb] and vb0[0, node] and center[1] < 720 and center[0] < 720 and im1_erode[0,:, center[1], center[0]].mean() > 0.8:
                        #     vb0[0, nb] = -1
                        #     vb0[0, node] = -1
                for node, nbs in enumerate(data['topo1'][0]):
                    for nb in nbs:
                        
                        # if vb1[0, nb] and vb1[0, node] and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) >3:
                        #     vb1[0, nb] = -1
                        #     vb1[0, node] = -1
                        center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.5).int()[0]
                        if vb1[0, nb] and vb1[0, node] and im0_erode[0,:, center[1], center[0]].mean() > 0.95:
                            vb1[0, nb] = -1
                            vb1[0, node] = -1
                        # center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.25).int()[0]
                        # if vb1[0, nb] and vb1[0, node] and center[1] < 720 and center[0] < 720 and im0_erode[0,:, center[1], center[0]].mean() > 0.95:
                        #     vb1[0, nb] = -1
                        #     vb1[0, node] = -1
                        # center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.75).int()[0]
                        # if vb1[0, nb] and vb1[0, node] and center[1] < 720 and center[0] < 720 and im0_erode[0,:, center[1], center[0]].mean() > 0.95:
                        #     vb1[0, nb] = -1
                        #     vb1[0, node] = -1
            # print(vb0.mean(), vb1.mean(), flush=True)
            return {'r0': motion_output0, 'r1': motion_output1, 'vb0':(vb0 > 0).float(), 'vb1':(vb1 > 0).float(),}

        dim_m, dim_n = data['ms'].float(), data['ns'].float()
        mmax = dim_m.int().max()
        nmax = dim_n.int().max()
        # with torch.no_grad():
        #     self.corr.eval()
        score01, score0, score1, dec0, dec1 = self.corr(data)


        kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2 


        adj0, adj1 = data['adj_mat0'].float(), data['adj_mat1'].float()

        motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0
        motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1

        # score0.mask_off()

        motion_pred0 = torch.softmax(score0.masked_fill(adj0==0, float('-inf')), dim=-1) @ motion_pred0
        motion_pred1 = torch.softmax(score1.masked_fill(adj1==0, float('-inf')), dim=-1) @ motion_pred1
        
        vb0 = self.mask_map(dec0)[:, 0]
        vb1 = self.mask_map(dec1)[:, 0]

        # motion0_pred, vb0 = pred0[:, :2].permute(0, 2, 1), pred0[:, 2:][:, 0]
        # motion1_pred, vb1 = pred1[:, :2].permute(0, 2, 1), pred1[:, 2:][:, 0]
        
        # delta0, delta1 = motion_delta[:, :, :mmax].permute(0, 2, 1), motion_delta[:, :, mmax:].permute(0, 2, 1)
        # motion_output0, motion_output1 =  motion0 + delta0, motion1 + delta1
        motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone()

        # print(delta0.max(), delta1.max())
        # vb0 = kpts0.new_ones(motion_pred0[:, :, 0].size()) + 1.0
        # vb1 = kpts1.new_ones(motion_pred1[:, :, 0].size()) + 1.0

        # vb0, vb1 = visibility[:, 0, :mmax], visibility[:, 0, mmax:]
        # mask0, mask1 = mask[:, :mmax].bool(), mask[:, mmax:].bool()
        # vb0_output = vb0.clone()
        # vb1_output = vb1.clone()

        # vb1_output[batch, corr01[corr01 != -1]] = 1.0

        # motion_output0[valid0.bool()] = motion0[valid0.bool()]
        # motion_output1[valid1.bool()] = motion1[valid1.bool()]

        # vb0_output[vb0_output >= 0] = 1.0
        # vb0_output[vb0_output < 0] = 0.0
        # vb1_output[vb1_output >= 0] = 1.0
        # vb1_output[vb1_output < 0 ] = 0.0

        

        kpt0t = kpts0 + motion_output0 / 2
        kpt1t = kpts1 + motion_output1 / 2
        # kpt1t[batch, corr01[corr01 != -1]] = kpt0t[corr01 != -1]
        
        
        ##################################################
        ##  Note Here the mini batch size is 1!!!!!!!!  ##
        ##################################################

        if 'topo0' in data and 'topo1' in data:
            # print(len(data['topo0'][0]), len(data['topo1']), flush=True)
            for node, nbs in enumerate(data['topo0'][0]):
                for nb in nbs:
                    if vb0[0, nb] and vb0[0, node] and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 5:
                        vb0[0, nb] = -1
                        vb0[0, node] = -1
            for node, nbs in enumerate(data['topo1'][0]):
                for nb in nbs:
                    if vb1[0, nb] and vb1[0, node] and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) > 5:
                        vb1[0, nb] = -1
                        vb1[0, node] = -1

        if 'motion0' in data and 'motion1' in data:
            # valid_motion0 = motion_output0[mask0[:, :, None].repeat(1, 1, 2)]
            # gt_valid_motion0 = data['motion0'][:, :mmax][mask0[:, :, None].repeat(1, 1, 2)].float()
            # valid_motion1 = motion_output1[mask1[:, :, None].repeat(1, 1, 2)]
            # gt_valid_motion1 = data['motion1'][:, :nmax][mask1[:, :, None].repeat(1, 1, 2)].float()

            loss_motion = torch.nn.functional.l1_loss(motion_pred0, data['motion0'][:, :mmax]) +\
                torch.nn.functional.l1_loss(motion_pred1, data['motion1'][:, :nmax])
            
            # loss_valid0 = ((corr01 == -1) & (mask0 == 1))
            # loss_valid1 = ((corr10 == -1) & (mask1 == 1))
            EPE0 = ((motion_pred0 - data['motion0'][:, :mmax]) ** 2).sum(dim=-1).sqrt()
            EPE1 = ((motion_pred1 - data['motion1'][:, :nmax]) ** 2).sum(dim=-1).sqrt()
            # print(EPE0.size(), 'fdsafdsa')

            EPE = (EPE0.mean() + EPE1.mean()) * 0.5
            # print(len(EPE0[mask0]), len(EPE1[mask1]))
            # print(vb0[:, :mmax][mask0], vb0[:, :mmax][mask0].shape, data['visibility0'][:, :mmax][mask0], data['visibility0'][:, :mmax][mask0].shape)
            # print(.size())
            # print((vb0[:, :mmax] > 0).float().sum(), data['visibility0'][:, :mmax].float().sum())
            # pos_weight=vb0.new_tensor([0.5])
            if 'visibility0' in data and 'visibility1' in data:
                loss_visibility = torch.nn.functional.binary_cross_entropy_with_logits(vb0[:, :mmax].view(-1, 1), data['visibility0'][:, :mmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight])) + \
                torch.nn.functional.binary_cross_entropy_with_logits(vb1[:, :nmax].view(-1, 1), data['visibility1'][:, :nmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight]))
            
                VB_Acc = ((((vb0 > 0).float() == data['visibility0'][:, :mmax]).float().sum() + ((vb1 > 0).float() == data['visibility1'][:, :nmax]).float().sum()) * 1.0 / (mmax + nmax))
            else:
                loss_visibility = 0
                VB_Acc = EPE.new_zeros([1])
            loss = loss_motion + 10 * loss_visibility

            loss_mean = torch.mean(loss)
            # loss_mean = torch.reshape(loss_mean, (1, -1))
            # print(loss_mean, flush=True)

            # print(all_matches[:, :mmax].size(), indices0.size(), mask0.size(), flush=True)
            #print((all_matches[0] == indices0[0]).sum())

            # print(vb1.size(),corr01.size())

            # kpt0t = torch.nn.functional.pad(kpts0 + motion_output0, (0, 0, 0, self.max_len - mmax, 0, 0), mode='constant', value=0)
            # kpt1t = torch.nn.functional.pad(kpts1 + motion_output1, (0, 0, 0, self.max_len - nmax, 0, 0), mode='constant', value=0),

            # kpt1t[:, :nmax][batch, corr01[corr01 != -1]] = kpt0t[:, :mmax][corr01 != -1]

            b, _, _ = motion_pred0.size()
            # batch = torch.arange(b)[:, None].repeat(1, mmax)[corr01 != -1].long()
            # # print(kpts0[corr01 != -1].size(), corr01[corr01 != -1].size())
            # matched_intermediate = (kpts0[(corr01 != -1)] + kpts1[batch, corr01[corr01 != -1].long(), :]) * 0.5
            # motion0[corr01 != -1] = matched_intermediate - kpts0[corr01 != -1]
            # motion1[batch, corr01[corr01 != -1].long(), :] = matched_intermediate - kpts1[batch, corr01[corr01 != -1].long(), :]

            # vb0 = torch.nn.functional.pad(vb0, (0, self.max_len - mmax, 0, 0), mode='constant', value=0),
            # vb1 = torch.nn.functional.pad(vb1, (0, self.max_len - nmax, 0, 0), mode='constant', value=0),

            # self.max_len = 3050
            # VB_Acc = ((((vb0 > 0.5).float() == data['visibility0'][:, :mmax]).float().sum() + ((vb1 > 0.5).float() == data['visibility1'][:, :nmax]).float().sum()) * 1.0 / (mmax + nmax))
                
            return {
                # 'matches0': indices0, # use -1 for invalid match
                # 'matches1': indices1[0], # use -1 for invalid match
                # 'matching_scores0': mscores0,
                # 'matching_scores1': mscores1[0],
                # 'keypointst0': torch.nn.functional.pad(kpts0 + motion_output0, (0, 0, 0, self.max_len - mmax, 0, 0), mode='constant', value=0),
                # 'keypointst1': torch.nn.functional.pad(kpts1 + motion_output1, (0, 0, 0, self.max_len - nmax, 0, 0), mode='constant', value=0),
                # 'vb0': torch.nn.functional.pad(vb0, (0, self.max_len - mmax, 0, 0), mode='constant', value=0),
                # 'vb1': torch.nn.functional.pad(vb1, (0, self.max_len - nmax, 0, 0), mode='constant', value=0),
                'keypoints0t': kpt0t,
                'keypoints1t': kpt1t,
                'vb0': (vb0 > 0).float(),
                'vb1': (vb1 > 0).float(),
                'loss': loss_mean,
                'EPE': EPE,
                'Visibility Acc': VB_Acc
                # ((((vb0[mask0] > 0).float() == data['visibility0'][:, :mmax][mask0]).float().sum() + ((vb1[mask1] > 0).float() == data['visibility1'][:, :nmax][mask1]).float().sum()) * 1.0 / (mask0.float().sum() + mask1.float().sum())),
                # 'skip_train': [False],
                # 'accuracy': (((all_matches[:, :mmax] == indices0) & mask0.bool()).sum() / mask0.sum()).item(),
                # 'valid_accuracy': (((all_matches[:, :mmax] == indices0) & (all_matches[:, :mmax] != -1) & mask0.bool()).float().sum() / ((all_matches[:, :mmax] != -1) & mask0.bool()).float().sum()).item(),
            }
        else:
            return {
                'loss': -1,
                'skip_train': True,
                'keypointst0': kpts0 + motion_output0,
                'keypointst1': kpts1 + motion_output1,
                'vb0': vb0,
                'vb1': vb1,
                # 'accuracy': -1,
                # 'area_accuracy': -1,
                # 'valid_accuracy': -1,
            }
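
# Sketch (hypothetical numbers) of the score-to-motion step used in both
# branches above: a softmax over score01 gives soft correspondences, so
# softmax(score01) @ kpts1 is each source vertex's expected target position,
# and subtracting kpts0 yields the predicted per-vertex displacement.
def _soft_motion_demo():
    import torch
    kpts0 = torch.tensor([[[0., 0.], [10., 0.]]])
    kpts1 = torch.tensor([[[2., 1.], [12., 1.]]])
    # near-one-hot scores: vertex i strongly prefers its counterpart i
    score01 = torch.tensor([[[10., -10.], [-10., 10.]]])
    motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0
    return motion_pred0               # ~ [[2, 1], [2, 1]]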


if __name__ == '__main__':

    args = argparse.Namespace()
    args.batch_size = 2
    args.gap = 5
    args.type = 'train'
    args.model = None
    args.action = None
    # `Refiner` does not exist in this file; build a minimal config (values
    # assumed from SuperGlueM's defaults) and instantiate InbetweenerM instead
    config = argparse.Namespace()
    config.pos_weight = 1.0
    config.corr_model = argparse.Namespace()
    config.corr_model.descriptor_dim = 128
    config.corr_model.keypoint_encoder = [32, 64, 128]
    config.corr_model.GNN_layer_num = 9
    config.corr_model.sinkhorn_iterations = 100
    config.corr_model.match_threshold = 0.2
    ss = InbetweenerM(config)


    loader = fetch_dataloader(args)
    # #print(len(loader))
    for data in loader:
        # p1, p2, s1, s2, mi = data
        dict1 = data

        kp1 = dict1['keypoints0']
        kp2 = dict1['keypoints1']
        p1 = dict1['image0']
        p2 = dict1['image1']  

        # #print(s1)
        # #print(s1.type)
        mi = dict1['m01']
        fname = dict1['file_name'] 
        print(dict1['keypoints0'].size(), dict1['keypoints1'].size(), dict1['m01'].size(), dict1['motion0'].size(), dict1['mask0'].size())
        # print(kp1.shape, p1.shape, mi.shape)  
        # #print(mi.size())  
        # #print(mi)
        # break

        a = ss(data)
        print(dict1['file_name'])
        print(a['loss'])
        print(a['EPE'], a['Visibility Acc'],flush=True)
        a['loss'].backward()

================================================
FILE: models/inbetweener_with_mask_with_spec.py
================================================
from copy import deepcopy
from pathlib import Path
import torch
from torch import nn
# from seg_desc import seg_descriptor
import argparse
import numpy as np
import torch.nn.functional as F
from sknetwork.embedding import Spectral

def MLP(channels: list, do_bn=True):
    """ Multi-layer perceptron """
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(
            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
        if i < (n-1):
            if do_bn:
                # layers.append(nn.BatchNorm1d(channels[i]))
                layers.append(nn.InstanceNorm1d(channels[i]))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)
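
# Shape sketch (hypothetical demo, rebuilding the same stack MLP() produces
# for channels [2, 32, 128]): because every layer is a 1x1 Conv1d, the
# "linear" maps act on the channel dimension of a B x C x N tensor, leaving
# the number of points N untouched.
def _mlp_shape_demo():
    import torch
    from torch import nn
    channels = [2, 32, 128]
    layers = []
    for i in range(1, len(channels)):
        layers.append(nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
        if i < len(channels) - 1:
            layers.append(nn.InstanceNorm1d(channels[i]))
            layers.append(nn.ReLU())
    mlp = nn.Sequential(*layers)
    return mlp(torch.randn(1, 2, 100)).shape      # torch.Size([1, 128, 100])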


def normalize_keypoints(kpts, image_shape):
    """ Normalize keypoints locations based on image image_shape"""
    _, _, height, width = image_shape
    one = kpts.new_tensor(1)
    size = torch.stack([one*width, one*height])[None]
    center = size / 2
    scaling = size.max(1, keepdim=True).values * 0.7
    return (kpts - center[:, None, :]) / scaling[:, None, :]
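
# Worked example (hypothetical 720x720 frame) of normalize_keypoints above:
# coordinates are re-centred at the image midpoint and divided by
# 0.7 * max(H, W), so corners map to roughly +/-0.714 and the centre to 0.
def _normalize_demo():
    import torch
    kpts = torch.tensor([[[0., 0.], [720., 720.], [360., 360.]]])
    h = w = 720
    one = kpts.new_tensor(1)
    size = torch.stack([one * w, one * h])[None]   # (1, 2)
    center = size / 2
    scaling = size.max(1, keepdim=True).values * 0.7   # 0.7 * 720 = 504
    return (kpts - center[:, None, :]) / scaling[:, None, :]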

class ThreeLayerEncoder(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""
    def __init__(self, enc_dim):
        super().__init__()
        # input must be 3 channel (r, g, b)
        self.layer1 = nn.Conv2d(3, enc_dim//4, 7, padding=3)
        self.non_linear1 = nn.ReLU()
        self.layer2 = nn.Conv2d(enc_dim//4, enc_dim//2, 3, padding=1)
        self.non_linear2 = nn.ReLU()
        self.layer3 = nn.Conv2d(enc_dim//2, enc_dim, 3, padding=1)

        self.norm1 = nn.InstanceNorm2d(enc_dim//4)
        self.norm2 = nn.InstanceNorm2d(enc_dim//2)
        self.norm3 = nn.InstanceNorm2d(enc_dim)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0.0)

    def forward(self, img):
        x = self.non_linear1(self.norm1(self.layer1(img)))
        x = self.non_linear2(self.norm2(self.layer2(x)))
        x = self.norm3(self.layer3(x))
        # x = self.non_linear1(self.layer1(img))
        # x = self.non_linear2(self.layer2(x))
        # x = self.layer3(x)
        return x


class VertexDescriptor(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""
    def __init__(self, enc_dim):
        super().__init__()
        self.encoder = ThreeLayerEncoder(enc_dim)
        # self.super_pixel_pooling = 
        # use scatter
        # nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, img, vtx):
        x = self.encoder(img)
        n, c, h, w = x.size()
        assert((h, w) == img.size()[2:4])
        return x[:, :, torch.round(vtx[0, :, 1]).long(), torch.round(vtx[0, :, 0]).long()]
        # return super_pixel_pooling(x.view(n, c, -1), seg.view(-1).long(), reduce='mean')
        # here return size is [1]xCx|Seg|
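
# Minimal check (hypothetical feature map) of the lookup in
# VertexDescriptor.forward: indexing x[:, :, y, x] with rounded vertex
# coordinates gathers one C-dim descriptor per vertex; note vtx stores (x, y)
# pairs, so column 1 indexes rows and column 0 indexes columns.
def _vtx_lookup_demo():
    import torch
    feat = torch.arange(2 * 4 * 4, dtype=torch.float).view(1, 2, 4, 4)
    vtx = torch.tensor([[[1.2, 2.7], [3.0, 0.0]]])   # (x, y) pairs
    rows = torch.round(vtx[0, :, 1]).long()
    cols = torch.round(vtx[0, :, 0]).long()
    return feat[:, :, rows, cols]                    # 1 x C x num_vertices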


class KeypointEncoder(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""
    def __init__(self, feature_dim, layers):
        super().__init__()
        self.encoder = MLP([2] + layers + [feature_dim])
        # for m in self.encoder.modules():
        #     if isinstance(m, nn.Conv2d):
        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        #         nn.init.constant_(m.bias, 0.0)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts):
        inputs = kpts.transpose(1, 2)

        x = self.encoder(inputs)
        return x

class TopoEncoder(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""
    def __init__(self, feature_dim, layers):
        super().__init__()
        self.encoder = MLP([64] + layers + [feature_dim])
        # for m in self.encoder.modules():
        #     if isinstance(m, nn.Conv2d):
        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        #         nn.init.constant_(m.bias, 0.0)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts):
        inputs = kpts.transpose(1, 2)
        x = self.encoder(inputs)
        return x


def attention(query, key, value, mask=None):
    dim = query.shape[1]
    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
    if mask is not None:
        scores = scores.masked_fill(mask==0, float('-inf'))

    prob = torch.nn.functional.softmax(scores, dim=-1)

    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
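
# Shape/mask sanity sketch (hypothetical demo) for attention above: inputs are
# B x D x H x N tensors, and a zero in the mask sends that key's score to
# -inf, so its softmax weight becomes exactly 0.
def _masked_attention_demo():
    import torch
    b, d, h, n = 1, 4, 2, 3
    torch.manual_seed(0)
    q, k, v = (torch.randn(b, d, h, n) for _ in range(3))
    scores = torch.einsum('bdhn,bdhm->bhnm', q, k) / d ** .5
    mask = torch.ones(b, 1, n, n)
    mask[..., -1] = 0                  # hide the last key from every query
    scores = scores.masked_fill(mask == 0, float('-inf'))
    prob = torch.nn.functional.softmax(scores, dim=-1)
    out = torch.einsum('bhnm,bdhm->bdhn', prob, v)
    return prob, out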


class MultiHeadedAttention(nn.Module):
    """ Multi-head attention to increase model expressivitiy """
    def __init__(self, num_heads: int, d_model: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value, mask=None):
        batch_dim = query.size(0)
        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]
        x, prob = attention(query, key, value, mask)
        # self.prob.append(prob)
        return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))


class AttentionalPropagation(nn.Module):
    def __init__(self, feature_dim: int, num_heads: int):
        super().__init__()
        self.attn = MultiHeadedAttention(num_heads, feature_dim)
        self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
        nn.init.constant_(self.mlp[-1].bias, 0.0)

    def forward(self, x, source, mask=None):
        message = self.attn(x, source, source, mask)
        return self.mlp(torch.cat([x, message], dim=1))


class AttentionalGNN(nn.Module):
    def __init__(self, feature_dim: int, layer_names: list):
        super().__init__()
        self.layers = nn.ModuleList([
            AttentionalPropagation(feature_dim, 4)
            for _ in range(len(layer_names))])
        self.names = layer_names

    def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None, mask10=None):
        for layer, name in zip(self.layers, self.names):
            layer.attn.prob = []
            if name == 'cross':
                src0, src1 = desc1, desc0
                mask0, mask1 = mask01[:, None], mask10[:, None] 
            else:  # if name == 'self':
                src0, src1 = desc0, desc1
                mask0, mask1 = mask00[:, None], mask11[:, None]

            delta0, delta1 = layer(desc0, src0, mask0), layer(desc1, src1, mask1)
            desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
        return desc0, desc1


def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
    """ Perform Sinkhorn Normalization in Log-space for stability"""
    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):
        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
    return Z + u.unsqueeze(2) + v.unsqueeze(1)
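`log_sinkhorn_iterations` alternates the log-domain updates u ← log μ − logsumexp(Z + v) and v ← log ν − logsumexp(Z + u), which drives the row and column sums of exp(Z + u + v) toward the target marginals μ and ν. A pure-Python sketch of the same loop on a single 2×2 matrix (helper names are illustrative):

```python
import math

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_sinkhorn(Z, log_mu, log_nu, iters):
    m, n = len(Z), len(Z[0])
    u, v = [0.0] * m, [0.0] * n
    for _ in range(iters):
        u = [log_mu[i] - logsumexp([Z[i][j] + v[j] for j in range(n)]) for i in range(m)]
        v = [log_nu[j] - logsumexp([Z[i][j] + u[i] for i in range(m)]) for j in range(n)]
    # the normalized log-coupling, as in the torch version's return value
    return [[Z[i][j] + u[i] + v[j] for j in range(n)] for i in range(m)]

Z = [[0.0, 1.0], [2.0, 0.5]]
log_mu = [math.log(0.5)] * 2   # uniform row marginals
log_nu = [math.log(0.5)] * 2   # uniform column marginals
P = log_sinkhorn(Z, log_mu, log_nu, 100)
row_sums = [sum(math.exp(x) for x in row) for row in P]
# each row (and column) of exp(P) now sums to ~0.5
```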


def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
    """ Perform Differentiable Optimal Transport in Log-space for stability"""
    b, m, n = scores.shape
    one = scores.new_tensor(1)
    if ms is None or ns is None:
        ms, ns = (m*one).to(scores), (n*one).to(scores)
    # else:
    #     ms, ns = ms.to(scores)[:, None], ns.to(scores)[:, None]
    # here ms, ns should be the true per-sample sizes, not the padded tensor shape

    # ms, ns: (b, )
    bins0 = alpha.expand(b, m, 1)
    bins1 = alpha.expand(b, 1, n)
    alpha = alpha.expand(b, 1, 1)

    # pad additional scores for unmatched points (to -1)
    # alpha is the learned threshold
    couplings = torch.cat([torch.cat([scores, bins0], -1),
                           torch.cat([bins1, alpha], -1)], 1)

    norm = - (ms + ns).log() # (b, )

    if ms.size()[0] > 0:
        norm = norm[:, None]
        log_mu = torch.cat([norm.expand(b, m), ns.log()[:, None] + norm], dim=-1) # (m + 1)
        log_nu = torch.cat([norm.expand(b, n), ms.log()[:, None] + norm], dim=-1)
    else:
        log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) # (m + 1)
        log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
        log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)

    
    Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)

    if ms.size()[0] > 1:
        norm = norm[:, :, None]
    Z = Z - norm  # multiply probabilities by M+N
    return Z


def arange_like(x, dim: int):
    return x.new_ones(x.shape[dim]).cumsum(0) - 1  # traceable in 1.1


class SuperGlueT(nn.Module):
    """SuperGlue feature matching middle-end

    Given two sets of keypoints and locations, we determine the
    correspondences by:
      1. Keypoint Encoding (normalization + visual feature and location fusion)
      2. Graph Neural Network with multiple self and cross-attention layers
      3. Final projection layer
      4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
      5. Thresholding matrix based on mutual exclusivity and a match_threshold

    The correspondence ids use -1 to indicate non-matching points.

    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
    Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763

    """
    # default_config = {
    #     'descriptor_dim': 128,
    #     'weights': 'indoor',
    #     'keypoint_encoder': [32, 64, 128],
    #     'GNN_layers': ['self', 'cross'] * 9,
    #     'sinkhorn_iterations': 100,
    #     'match_threshold': 0.2,
    # }

    def __init__(self, config=None):
        super().__init__()

        default_config = argparse.Namespace()
        default_config.descriptor_dim = 128

        default_config.keypoint_encoder = [32, 64, 128]
        default_config.GNN_layers = ['self', 'cross'] * 9
        default_config.sinkhorn_iterations = 100
        default_config.match_threshold = 0.2
        self.spectral = Spectral(64,  normalized=False)


        if config is None:
            self.config = default_config
        else:
            self.config = config   
            self.config.GNN_layers = ['self', 'cross'] * self.config.GNN_layer_num

        self.kenc = KeypointEncoder(
            self.config.descriptor_dim, self.config.keypoint_encoder)

        self.tenc = TopoEncoder(
            self.config.descriptor_dim, [96])


        self.gnn = AttentionalGNN(
            self.config.descriptor_dim, self.config.GNN_layers)

        self.final_proj = nn.Conv1d(
            self.config.descriptor_dim, self.config.descriptor_dim,
            kernel_size=1, bias=True)

        bin_score = torch.nn.Parameter(torch.tensor(1.))
        self.register_parameter('bin_score', bin_score)
        self.vertex_desc = VertexDescriptor(self.config.descriptor_dim)
       

    def forward(self, data):
        kpts0, kpts1 = data['keypoints0'].float(), data['keypoints1'].float()

        ori_mask0, ori_mask1 = data['mask0'].float(), data['mask1'].float()
        dim_m, dim_n = data['ms'].float(), data['ns'].float()

        # spectral embedding of adj matrices
        # online computation of the spectral embedding is too slow during training,
        # so it is moved to the dataset pipeline, where it can be precomputed
        # during data preparation by multi-processing CPU workers
        spec0, spec1 = data['spec0'], data['spec1']
        # spec0, spec1 = np.abs(self.spectral.fit_transform(adj_mat0[0].cpu().numpy())), np.abs(self.spectral.fit_transform(adj_mat1[0].cpu().numpy()))

        mmax = dim_m.int().max()
        nmax = dim_n.int().max()

        mask0 = ori_mask0[:, :mmax]
        mask1 = ori_mask1[:, :nmax]

        kpts0 = kpts0[:, :mmax]
        kpts1 = kpts1[:, :nmax]

        # image context embedding
        desc0, desc1 = self.vertex_desc(data['image0'], kpts0.float()), self.vertex_desc(data['image1'], kpts1.float())

        # add topological embedding
        desc0 = desc0 + self.tenc(desc0.new_tensor(spec0))
        desc1 = desc1 + self.tenc(desc1.new_tensor(spec1))

        # these masks were prepared for synchronized training with batch size > 1, but that did not work well,
        # so the current framework still uses gradient accumulation
        mask00 = torch.ones_like(mask0)[:, :, None] * mask0[:, None, :]
        
        mask11 = torch.ones_like(mask1)[:, :, None] * mask1[:, None, :]
        mask01 = torch.ones_like(mask0)[:, :, None] * mask1[:, None, :]
        mask10 = torch.ones_like(mask1)[:, :, None] * mask0[:, None, :]
        

        if kpts0.shape[1] < 2 or kpts1.shape[1] < 2:  # no keypoints
            shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
            # print(data['file_name'])
            return {
                'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],
                # 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],
                'matching_scores0': kpts0.new_zeros(shape0)[0],
                # 'matching_scores1': kpts1.new_zeros(shape1)[0],
                'skip_train': True
            }

        all_matches = data['all_matches'] if 'all_matches' in data else None  # shape = (1, K1)

        # positional embedding
        # Keypoint normalization.
        kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
        kpts1 = normalize_keypoints(kpts1, data['image1'].shape)

        # Keypoint MLP encoder.
        pos0 = self.kenc(kpts0)
        pos1 = self.kenc(kpts1)

        desc0 = desc0 + pos0
        desc1 = desc1 + pos1


        # Multi-layer Transformer network.
        desc0, desc1 = self.gnn(desc0, desc1, mask00, mask11, mask01, mask10)

        # Final MLP projection.
        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)

        # Compute matching descriptor distance.
        scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
        scores0 = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc0)
        scores1 = torch.einsum('bdn,bdm->bnm', mdesc1, mdesc1)


        # b k1 k2
        scores = scores / self.config.descriptor_dim**.5


        # Run the optimal transport.
        scores = log_optimal_transport(
            scores, self.bin_score,
            iters=self.config.sinkhorn_iterations,
            ms=dim_m, ns=dim_n)


        # Get the matches with score above "match_threshold".
        return scores[:, :-1, :-1], scores0, scores1, mdesc0, mdesc1
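The comments in `forward` mention that the spectral embedding of each frame's adjacency matrix is precomputed in the data pipeline (via the `Spectral` class instantiated in `__init__`, presumably from scikit-network, which is listed in requirement.txt). As a hand-checked sketch of the underlying idea: a node's embedding coordinates are entries of low-frequency eigenvectors of the graph Laplacian L = D − A. For a 3-node path graph the eigenpairs are known in closed form:

```python
A = [[0, 1, 0],
     [1, 0, 1],
     [0, 1, 0]]
deg = [sum(row) for row in A]
# graph Laplacian L = D - A
L = [[(deg[i] if i == j else 0) - A[i][j] for j in range(3)] for i in range(3)]

def matvec(M, v):
    return [sum(M[i][j] * v[j] for j in range(3)) for i in range(3)]

# closed-form eigenpairs of the 3-node path Laplacian
eigenpairs = [
    (0.0, [1, 1, 1]),    # constant vector: always in the null space of L
    (1.0, [1, 0, -1]),   # first non-trivial mode; its entries order the nodes
    (3.0, [1, -2, 1]),
]
# each pair satisfies L v = lambda v entry-wise
```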
       
def tensor_erode(bin_img, ksize=5):
    B, C, H, W = bin_img.shape
    pad = (ksize - 1) // 2
    bin_img = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)

    patches = bin_img.unfold(dimension=2, size=ksize, step=1)
    patches = patches.unfold(dimension=3, size=ksize, step=1)
    # B x C x H x W x k x k

    eroded, _ = patches.reshape(B, C, H, W, -1).min(dim=-1)
    return eroded
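`tensor_erode` implements binary erosion as a min-pool: zero-pad, unfold k×k patches, and take the per-patch minimum. A pure-Python sketch of the same operation on a tiny binary image (the `erode` helper is illustrative, not the repo's implementation):

```python
def erode(img, ksize=3):
    h, w = len(img), len(img[0])
    pad = (ksize - 1) // 2
    def px(y, x):
        # constant zero padding outside the image, as in tensor_erode's F.pad
        return img[y][x] if 0 <= y < h and 0 <= x < w else 0
    return [[min(px(y + dy, x + dx)
                 for dy in range(-pad, pad + 1)
                 for dx in range(-pad, pad + 1))
             for x in range(w)] for y in range(h)]

img = [[1, 1, 1],
       [1, 1, 1],
       [1, 1, 1]]
out = erode(img, 3)
# only the centre pixel survives: zero padding erodes the whole border
```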

class InbetweenerTM(nn.Module):
    """AnimeInbet
    The whole pipeline includes
    1. vertex correspondence (vertex embedding + correspondence transformer)
    2. repositioning propagation
    3. vis mask

    vertex corr code is modified from SuperGlue

    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
    Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763

    """


    def __init__(self, config=None):
        super().__init__()
        # vertex correspondence
        self.corr = SuperGlueT(config.corr_model)
        self.mask_map = MLP([config.corr_model.descriptor_dim, 32, 1])
        self.pos_weight = config.pos_weight
        

    def forward(self, data):
        # if in the mode of video generating
        if 'gen_vid' in data:
            dim_m, dim_n = data['ms'].float(), data['ns'].float()
            mmax = dim_m.int().max()
            nmax = dim_n.int().max()
            with torch.no_grad():
                self.corr.eval()
                score01, score0, score1, dec0, dec1 = self.corr(data)
                kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2 

                motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0
                motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1

            motion_pred0 = torch.softmax(score0, dim=-1) @ motion_pred0
            motion_pred1 = torch.softmax(score1, dim=-1) @ motion_pred1
            
            self.mask_map.eval()
            vb0 = self.mask_map(dec0)[:, 0]
            vb1 = self.mask_map(dec1)[:, 0]

            motion_output0, motion_output1 =  motion_pred0.clone(), motion_pred1.clone()
            kpt0t = kpts0 + motion_output0 
            kpt1t = kpts1 + motion_output1 
            if 'topo0' in data and 'topo1' in data:
                # print(len(data['topo0'][0]), len(data['topo1']), flush=True)
                for node, nbs in enumerate(data['topo0'][0]):
                    for nb in nbs:
                        if vb0[0, nb] and vb0[0, node] and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 3:
                            vb0[0, nb] = 0
                            vb0[0, node] = 0
                for node, nbs in enumerate(data['topo1'][0]):
                    for nb in nbs:
                        if vb1[0, nb] and vb1[0, node] and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) > 3:
                            vb1[0, nb] = 0
                            vb1[0, node] = 0
            return {'r0': motion_output0, 'r1': motion_output1, 'vb0': vb0, 'vb1': vb1}

        # in the normal train/test mode
        dim_m, dim_n = data['ms'].float(), data['ns'].float()
        mmax = dim_m.int().max()
        nmax = dim_n.int().max()
        # with torch.no_grad():
        #     self.corr.eval()
        score01, score0, score1, dec0, dec1 = self.corr(data)


        kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2 

        motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0
        motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1

        motion_pred0 = torch.softmax(score0, dim=-1) @ motion_pred0
        motion_pred1 = torch.softmax(score1, dim=-1) @ motion_pred1


        max0, max1 = score01.max(2), score01.max(1)
        indices0, indices1 = max0.indices, max1.indices
        mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
        mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
        zero = score01.new_tensor(0)

        mscores0 = torch.where(mutual0, max0.values.exp(), zero)
        mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
        # valid0 = mutual0 & (mscores0 > self.config.match_threshold)
        # valid1 = mutual1 & valid0.gather(1, indices1)
        
        valid0 = mscores0 > 0.2
        valid1 = valid0.gather(1, indices1)
        indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
        indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))




        # motion_pred1[0][indices1[0]==-1] = 0




        
        vb0 = self.mask_map(dec0)[:, 0]
        vb1 = self.mask_map(dec1)[:, 0]
        motion_output0, motion_output1 =  motion_pred0.clone(), motion_pred1.clone()

        if not self.training:
            motion_pred0[0][indices0[0]!=-1] = kpts1[0][indices0[0][indices0[0]!=-1]] - kpts0[0][indices0[0]!=-1]
        # # motion_pred0[0][indices0[0]==-1] = 0
            motion_pred1[0][indices1[0]!=-1] = kpts0[0][indices1[0][indices1[0]!=-1]] - kpts1[0][indices1[0]!=-1]
            vb0[:] = vb0[:] + 0.7
            vb1[:] = vb1[:] + 0.7

            # motion0_pred, vb0 = pred0[:, :2].permute(0, 2, 1), pred0[:, 2:][:, 0]
            # motion1_pred, vb1 = pred1[:, :2].permute(0, 2, 1), pred1[:, 2:][:, 0]
            
            # delta0, delta1 = motion_delta[:, :, :mmax].permute(0, 2, 1), motion_delta[:, :, mmax:].permute(0, 2, 1)
            # motion_output0, motion_output1 =  motion0 + delta0, motion1 + delta1
            motion_output0, motion_output1 =  motion_pred0.clone(), motion_pred1.clone()

            im0_erode =  data['image0']
            im1_erode =  data['image1']
            im0_erode[im0_erode > 0] = 1
            im0_erode[im0_erode <= 0] = 0
            im1_erode[im1_erode > 0] = 1
            im1_erode[im1_erode <= 0] = 0
            
            im0_erode = tensor_erode(im0_erode, 7)
            im1_erode = tensor_erode(im1_erode, 7)


            
            kpt0t = kpts0 + motion_output0 / 2
            kpt1t = kpts1 + motion_output1 / 2
        
        
            ##################################################
            ##  Note Here the mini batch size is 1!!!!!!!!  ##
            ##################################################

            if 'topo0' in data and 'topo1' in data:
                # print(len(data['topo0'][0]), len(data['topo1']), flush=True)
                for node, nbs in enumerate(data['topo0'][0]):
                    for nb in nbs:
                        if vb0[0, nb] > 0 and vb0[0, node] > 0 and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 5:
                            vb0[0, nb] = -1
                            vb0[0, node] = -1
                for node, nbs in enumerate(data['topo1'][0]):
                    for nb in nbs:
                        if vb1[0, nb] > 0 and vb1[0, node] > 0 and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) > 5:
                            vb1[0, nb] = -1
                            vb1[0, node] = -1
            
            
            kpt0t = kpts0 + motion_output0 * 1
            kpt1t = kpts1 + motion_output1 * 1
            if 'topo0' in data and 'topo1' in data:
                ##  print(len(data['topo0'][0]), len(data['topo1']), flush=True)
                for node, nbs in enumerate(data['topo0'][0]):
                    for nb in nbs:

                        center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.5).int()[0]
                        if center[0] >= 720 or center[1] >= 720:
                            continue

                        if vb0[0, nb] > 0 and vb0[0, node] > 0 and im1_erode[0,:, center[1], center[0]].mean() > 0.8:
                            vb0[0, nb] = -1
                            vb0[0, node] = -1

                for node, nbs in enumerate(data['topo1'][0]):
                    for nb in nbs:
                        
                        center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.5).int()[0]
                        if center[0] >= 720 or center[1] >= 720:
                            continue

                        if vb1[0, nb] > 0 and vb1[0, node] > 0 and im0_erode[0, :, center[1], center[0]].mean() > 0.8:
                            vb1[0, nb] = -1
                            vb1[0, node] = -1

        

        kpt0t = kpts0 + motion_output0 / 2
        kpt1t = kpts1 + motion_output1 / 2

        

        if 'motion0' in data and 'motion1' in data:
            loss_motion = torch.nn.functional.l1_loss(motion_pred0, data['motion0'][:, :mmax]) +\
                torch.nn.functional.l1_loss(motion_pred1, data['motion1'][:, :nmax])
            

            EPE0 = ((motion_pred0 - data['motion0'][:, :mmax]) ** 2).sum(dim=-1).sqrt()
            EPE1 = ((motion_pred1 - data['motion1'][:, :nmax]) ** 2).sum(dim=-1).sqrt()
            # print(EPE0.size(), 'fdsafdsa')

            EPE = (EPE0.mean() + EPE1.mean()) * 0.5

            if 'visibility0' in data and 'visibility1' in data:
                loss_visibility = torch.nn.functional.binary_cross_entropy_with_logits(vb0[:, :mmax].view(-1, 1), data['visibility0'][:, :mmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight])) + \
                torch.nn.functional.binary_cross_entropy_with_logits(vb1[:, :nmax].view(-1, 1), data['visibility1'][:, :nmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight]))
            
                VB_Acc = ((((vb0 > 0).float() == data['visibility0'][:, :mmax]).float().sum() + ((vb1 > 0).float() == data['visibility1'][:, :nmax]).float().sum()) * 1.0 / (mmax + nmax))
            else:
                loss_visibility = 0
                VB_Acc = EPE.new_zeros([1])
            loss = loss_motion + 10 * loss_visibility

            loss_mean = torch.mean(loss)

            b, _, _ = motion_pred0.size()

            return {
                'keypoints0t': kpt0t,
                'keypoints1t': kpt1t,
                'vb0': (vb0 > 0).float(),
                'vb1': (vb1 > 0).float(),
                'r0': motion_output0,
                'r1': motion_output1,
                'loss': loss_mean,
                'EPE': EPE,
                'Visibility Acc': VB_Acc
            }
        else:
            return {
                'loss': -1,
                'skip_train': True,
                'keypoints0t': kpts0 + motion_output0,
                'keypoints1t': kpts1 + motion_output1,
                'vb0': vb0,
                'vb1': vb1,
            }
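In the returned dictionaries, `r0`/`r1` are full-range motions from each key frame, and the middle frame is obtained by moving each vertex half-way (`kpts0 + motion_output0 / 2` above). A minimal sketch of that repositioning step, assuming straight-line vertex motion (the `reposition` helper is illustrative, not part of the repo):

```python
def reposition(kpts, motion, t=0.5):
    # move each vertex a fraction t along its predicted full-range motion
    return [(x + t * mx, y + t * my) for (x, y), (mx, my) in zip(kpts, motion)]

kpts0 = [(0.0, 0.0), (10.0, 0.0)]
r0 = [(4.0, 2.0), (-2.0, 6.0)]
mid = reposition(kpts0, r0)   # half-way toward the matched target positions
```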


if __name__ == '__main__':

    args = argparse.Namespace()
    args.batch_size = 2
    args.gap = 5
    args.type = 'train'
    args.model = None
    args.action = None

    # build a minimal config for InbetweenerTM ('Refiner' does not exist in this file);
    # the values below mirror SuperGlueT's defaults
    config = argparse.Namespace()
    config.corr_model = argparse.Namespace(
        descriptor_dim=128, keypoint_encoder=[32, 64, 128],
        GNN_layer_num=9, sinkhorn_iterations=100, match_threshold=0.2)
    config.pos_weight = 1.0
    ss = InbetweenerTM(config)


    loader = fetch_dataloader(args)
    # #print(len(loader))
    for data in loader:
        # p1, p2, s1, s2, mi = data
        dict1 = data

        kp1 = dict1['keypoints0']
        kp2 = dict1['keypoints1']
        p1 = dict1['image0']
        p2 = dict1['image1']  

        # #print(s1)
        # #print(s1.type)
        mi = dict1['m01']
        fname = dict1['file_name'] 
        print(dict1['keypoints0'].size(), dict1['keypoints1'].size(), dict1['m01'].size(), dict1['motion0'].size(), dict1['mask0'].size())
        # print(kp1.shape, p1.shape, mi.shape)  
        # #print(mi.size())  
        # #print(mi)
        # break

        a = ss(data)
        print(dict1['file_name'])
        print(a['loss'])
        print(a['EPE'], a['Visibility Acc'],flush=True)
        a['loss'].backward()

================================================
FILE: requirement.txt
================================================
opencv-python
pyyaml==5.4.1
scikit-network
tqdm
matplotlib
easydict
gdown

================================================
FILE: srun.sh
================================================
#!/bin/sh
currenttime=`date "+%Y%m%d%H%M%S"`
if [ ! -d log ]; then
    mkdir log
fi

echo "[Usage] ./srun.sh config_path [train|eval] partition gpunum"
# check config exists
if [ ! -e $1 ]
then
    echo "[ERROR] configuration file: $1 does not exists!"
    exit
fi


# NOTE: 'expname' and 'config_suffix' were unset in the original script;
# derive them from the config path here (adjust to match your experiment layout)
config_suffix=$(basename $1 .yaml)
expname=experiments/${config_suffix}

if [ ! -d ${expname} ]; then
    mkdir ${expname}
fi

echo "[INFO] saving results to, or loading files from: "$expname

if [ "$3" == "" ]; then
    echo "[ERROR] enter partition name"
    exit
fi
partition_name=$3
echo "[INFO] partition name: $partition_name"

if [ "$4" == "" ]; then
    echo "[ERROR] enter gpu num"
    exit
fi
gpunum=$4
gpunum=$(($gpunum<8?$gpunum:8))
echo "[INFO] GPU num: $gpunum"
((ntask=$gpunum*3))


TOOLS="srun  --partition=$partition_name -x SG-IDC2-10-51-5-44 --cpus-per-task=16 --gres=gpu:$gpunum -N 1 --mem-per-gpu=32G  --job-name=${config_suffix}"
PYTHONCMD="python -u main.py --config $1"

if [ $2 == "train" ];
then
    $TOOLS $PYTHONCMD \
    --train 
elif [ $2 == "eval" ];
then
    $TOOLS $PYTHONCMD \
    --eval 
elif [ $2 == "gen" ];
then
    $TOOLS $PYTHONCMD \
    --gen 
fi
# elif [ $2 == "visgt" ];
# then
#     $TOOLS $PYTHONCMD \
#     --visgt 
# elif [ $2 == "anl" ];
# then
#     $TOOLS $PYTHONCMD \
#     --anl 
# elif [ $2 == "sample" ];
# then
#     $TOOLS $PYTHONCMD \
#     --sample 
# fi



================================================
FILE: utils/chamfer_distance.py
================================================
import os
import numpy as np
from time import time
import cv2
import pdb
import scipy
import scipy.ndimage
import torch
import torchmetrics

black_threshold = 255.0 * 0.99


def batch_edt(img, block=1024):
    # per-image Euclidean distance transform, computed on CPU with scipy
    # ('block' is kept for API compatibility with a GPU implementation)
    bs, h, w = img.shape
    diam2 = h**2 + w**2
    odtype = img.dtype

    sums = img.sum(dim=(1, 2))
    ans = torch.tensor(np.stack([
        scipy.ndimage.distance_transform_edt(i)
        if s != 0 else  # scipy returns 0 for an empty image; use the diameter instead
        np.ones_like(i) * np.sqrt(diam2)
        for i, s in zip(1 - img, sums)
    ]), dtype=odtype)

    return ans
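`batch_edt` relies on scipy's Euclidean distance transform, which gives each pixel its distance to the nearest zero pixel of the input; passing `1 - img` therefore yields distances to the nearest stroke pixel. A brute-force sketch of the same quantity (the `edt_bruteforce` helper is illustrative, not the repo's implementation):

```python
import math

def edt_bruteforce(img):
    # distance from every pixel to the nearest "on" pixel
    h, w = len(img), len(img[0])
    on = [(y, x) for y in range(h) for x in range(w) if img[y][x]]
    return [[min(math.hypot(y - oy, x - ox) for oy, ox in on)
             for x in range(w)] for y in range(h)]

img = [[0, 0, 0],
       [0, 1, 0],
       [0, 0, 0]]
d = edt_bruteforce(img)
# corners are sqrt(2) from the single "on" pixel at the centre
```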


############### DERIVED DISTANCES ###############

# input: (bs,h,w) or (bs,1,h,w)
# returns: (bs,)
# normalized s.t. metric is same across proportional image scales

# average of two asymmetric distances
# normalized by diameter and area
def batch_chamfer_distance(gt, pred, block=1024, return_more=False):
    t = batch_chamfer_distance_t(gt, pred, block=block)
    p = batch_chamfer_distance_p(gt, pred, block=block)
    cd = (t + p) / 2
    return cd
def batch_chamfer_distance_t(gt, pred, block=1024, return_more=False):
    #pdb.set_trace()
    assert gt.device==pred.device and gt.shape==pred.shape
    bs,h,w = gt.shape[0], gt.shape[-2], gt.shape[-1]
    dpred = batch_edt(pred, block=block)
    cd = (gt*dpred).float().mean((-2,-1)) / np.sqrt(h**2+w**2)
    if len(cd.shape)==2:
        assert cd.shape[1]==1
        cd = cd.squeeze(1)
    return cd
def batch_chamfer_distance_p(gt, pred, block=1024, return_more=False):
    assert gt.device==pred.device and gt.shape==pred.shape
    bs,h,w = gt.shape[0], gt.shape[-2], gt.shape[-1]
    dgt = batch_edt(gt, block=block)
    cd = (pred*dgt).float().mean((-2,-1)) / np.sqrt(h**2+w**2)
    if len(cd.shape)==2:
        assert cd.shape[1]==1
        cd = cd.squeeze(1)
    return cd

# normalized by diameter
# always between [0,1]
def batch_hausdorff_distance(gt, pred, block=1024, return_more=False):
    assert gt.device==pred.device and gt.shape==pred.shape
    bs,h,w = gt.shape[0], gt.shape[-2], gt.shape[-1]
    dgt = batch_edt(gt, block=block)
    dpred = batch_edt(pred, block=block)
    hd = torch.stack([
        (dgt*pred).amax(dim=(-2,-1)),
        (dpred*gt).amax(dim=(-2,-1)),
    ]).amax(dim=0).float() / np.sqrt(h**2+w**2)
    if len(hd.shape)==2:
        assert hd.shape[1]==1
        hd = hd.squeeze(1)
    return hd
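On explicit point sets, the Chamfer distance above is the average (and the Hausdorff distance the maximum) of nearest-neighbour distances taken symmetrically in both directions; the image versions realise the nearest-neighbour lookup via the EDT. A small sketch with illustrative helpers:

```python
import math

def nn_dists(A, B):
    # distance from each point of A to its nearest neighbour in B
    return [min(math.dist(a, b) for b in B) for a in A]

def chamfer(A, B):
    da, db = nn_dists(A, B), nn_dists(B, A)
    return 0.5 * (sum(da) / len(da) + sum(db) / len(db))

def hausdorff(A, B):
    return max(max(nn_dists(A, B)), max(nn_dists(B, A)))

A = [(0.0, 0.0), (1.0, 0.0)]
B = [(0.0, 0.0), (3.0, 0.0)]
# chamfer averages the mismatch; hausdorff reports the single worst point
```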


############### TORCHMETRICS ###############

class ChamferDistance2dMetric(torchmetrics.Metric):
    full_state_update=False
    def __init__(
            self, block=1024, convert_dog=True, k=1.6, epsilon=0.01, kernel_factor=4, clip=False,
            **kwargs,
        ):
        super().__init__(**kwargs)
        self.block = block
        self.convert_dog = convert_dog

        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
        return

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        dist = batch_chamfer_distance(target, preds, block=self.block)
        self.running_sum += dist.sum()
        self.running_count += len(dist)
        return
        
    def compute(self):
        return self.running_sum.float() / self.running_count
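`ChamferDistance2dMetric` follows the usual torchmetrics pattern of keeping sum/count states (each reduced with `sum` across processes) so that `compute()` returns the global mean regardless of how updates are batched. The pattern in isolation (the `RunningMean` class is illustrative):

```python
class RunningMean:
    # mirrors the add_state('running_sum') / add_state('running_count') pattern
    def __init__(self):
        self.s, self.n = 0.0, 0
    def update(self, values):
        self.s += sum(values)
        self.n += len(values)
    def compute(self):
        return self.s / self.n

m = RunningMean()
m.update([1.0, 2.0])
m.update([3.0])
avg = m.compute()   # global mean over both updates
```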

class ChamferDistance2dTMetric(ChamferDistance2dMetric):
    # NOTE: 'usketchers' (a DoG sketch-conversion helper) is not imported in this file,
    # and the base class does not define 'dog_params'; pass convert_dog=False unless
    # both are provided externally (the same applies to ChamferDistance2dPMetric).
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        if self.convert_dog:
            preds = (usketchers.batch_dog(preds, **self.dog_params)>0.5).float()
            target = (usketchers.batch_dog(target, **self.dog_params)>0.5).float()
        dist = batch_chamfer_distance_t(target, preds, block=self.block)
        self.running_sum += dist.sum()
        self.running_count += len(dist)
        return
class ChamferDistance2dPMetric(ChamferDistance2dMetric):
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        if self.convert_dog:
            preds = (usketchers.batch_dog(preds, **self.dog_params)>0.5).float()
            target = (usketchers.batch_dog(target, **self.dog_params)>0.5).float()
        dist = batch_chamfer_distance_p(target, preds, block=self.block)
        self.running_sum += dist.sum()
        self.running_count += len(dist)
        return

class HausdorffDistance2dMetric(torchmetrics.Metric):
    def __init__(
            self, block=1024, convert_dog=True,
            t=2.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=False,
            **kwargs,
        ):
        super().__init__(**kwargs)
        self.block = block
        self.convert_dog = convert_dog
        self.dog_params = {
            't': t, 'sigma': sigma, 'k': k, 'epsilon': epsilon,
            'kernel_factor': kernel_factor, 'clip': clip,
        }
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
        return
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        if self.convert_dog:
            preds = (usketchers.batch_dog(preds, **self.dog_params)>0.5).float()
            target = (usketchers.batch_dog(target, **self.dog_params)>0.5).float()
        dist = batch_hausdorff_distance(target, preds, block=self.block)
        self.running_sum += dist.sum()
        self.running_count += len(dist)
        return
    def compute(self):
        return self.running_sum.float() / self.running_count




def rgb2sketch(img, black_threshold):
    #pdb.set_trace()
    img[img < black_threshold] = 1
    img[img >= black_threshold] = 0
    #cv2.imwrite("grey.png",img*255)
    return torch.tensor(img)
def rgb2gray(rgb):
    r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b

    return gray
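The weights in `rgb2gray` are the standard ITU-R BT.601 luma coefficients (0.2989 R + 0.5870 G + 0.1140 B). A quick check on single pixels (the `luma` helper is illustrative):

```python
def luma(r, g, b):
    # ITU-R BT.601 luma coefficients, as used in rgb2gray above
    return 0.2989 * r + 0.5870 * g + 0.1140 * b

g_white = luma(255, 255, 255)   # pure white stays ~255
g_green = luma(0, 255, 0)       # green contributes the most to brightness
```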

def cd_score(img1, img2):


    img1 = rgb2gray(img1.astype(float))
    img2 = rgb2gray(img2.astype(float))
    
    img1_sketch = rgb2sketch(img1, black_threshold)
    img2_sketch = rgb2sketch(img2, black_threshold)

    img1_sketch = img1_sketch.unsqueeze(0)
    img2_sketch = img2_sketch.unsqueeze(0)

    CD = ChamferDistance2dMetric()
    cd = CD(img1_sketch,img2_sketch)
    return cd



================================================
FILE: utils/log.py
================================================
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this open-source project.


""" Define the Logger class to print log"""
import os
import sys
import logging
from datetime import datetime


class Logger:
    def __init__(self, args, output_dir):

        log = logging.getLogger(output_dir)
        if not log.handlers:
            log.setLevel(logging.DEBUG)
            # if not os.path.exists(output_dir):
            #     os.mkdir(args.data.output_dir)
            fh = logging.FileHandler(os.path.join(output_dir,'log.txt'))
            fh.setLevel(logging.INFO)
            ch = ProgressHandler()
            ch.setLevel(logging.DEBUG)
            formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
            fh.setFormatter(formatter)
            ch.setFormatter(formatter)
            log.addHandler(fh)
            log.addHandler(ch)
        self.log = log
        # setup TensorBoard
        # if args.tensorboard:
        #     from tensorboardX import SummaryWriter
        #     self.writer = SummaryWriter(log_dir=args.output_dir)
        # else:
        self.writer = None
        self.log_per_updates = args.log_per_updates

    def set_progress(self, epoch, total):
        self.log.info(f'Epoch: {epoch}')
        self.epoch = epoch
        self.i = 0
        self.total = total
        self.start = datetime.now()

    def update(self, stats):
        self.i += 1
        if self.i % self.log_per_updates == 0:
            remaining = str((datetime.now() - self.start) / self.i * (self.total - self.i))
            remaining = remaining.split('.')[0]
            updates = stats.pop('updates')
            stats_str = ' '.join(f'{key}[{val:.8f}]' for key, val in stats.items())
            
            self.log.info(f'> epoch [{self.epoch}] updates[{updates}] {stats_str} eta[{remaining}]')
            
            if self.writer:
                for key, val in stats.items():
                    self.writer.add_scalar(f'train/{key}', val, updates)
        if self.i == self.total:
            self.log.debug('\n')
            self.log.debug(f'elapsed time: {str(datetime.now() - self.start).split(".")[0]}')

    def log_eval(self, stats, metrics_group=None):
        stats_str = ' '.join(f'{key}: {val:.8f}' for key, val in stats.items())
        self.log.info(f'valid {stats_str}')
        if self.writer:
            for key, val in stats.items():
                self.writer.add_scalar(f'valid/{key}', val, self.epoch)
        # for mode, metrics in metrics_group.items():
        #     self.log.info(f'evaluation scores ({mode}):')
        #     for key, (val, _) in metrics.items():
        #         self.log.info(f'\t{key} {val:.4f}')
        # if self.writer and metrics_group is not None:
        #     for key, val in stats.items():
        #         self.writer.add_scalar(f'valid/{key}', val, self.epoch)
        #     for key in list(metrics_group.values())[0]:
        #         group = {}
        #         for mode, metrics in metrics_group.items():
        #             group[mode] = metrics[key][0]
        #         self.writer.add_scalars(f'valid/{key}', group, self.epoch)

    def __call__(self, msg):
        self.log.info(msg)


class ProgressHandler(logging.Handler):
    def __init__(self, level=logging.NOTSET):
        super().__init__(level)

    def emit(self, record):
        log_entry = self.format(record)
        if record.message.startswith('> '):
            sys.stdout.write('{}\r'.format(log_entry.rstrip()))
            sys.stdout.flush()
        else:
            sys.stdout.write('{}\n'.format(log_entry))
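
The repo's `Logger` wires a `FileHandler` for persistence plus this `ProgressHandler` for the console, relying on the `'> '` message prefix (emitted by `Logger.update`) to mark progress lines that should be rewritten in place with a carriage return. A minimal self-contained sketch of that pattern (names here are illustrative; it uses `record.getMessage()` where the repo reads `record.message`, which `self.format()` has already populated by that point):

```python
import logging
import sys

class ProgressHandlerSketch(logging.Handler):
    """Records whose message starts with '> ' are treated as progress
    updates and rewritten in place via '\r'; everything else gets a
    normal newline. Mirrors the repo's ProgressHandler.emit logic."""
    def emit(self, record):
        entry = self.format(record)
        if record.getMessage().startswith('> '):
            sys.stdout.write('{}\r'.format(entry.rstrip()))
        else:
            sys.stdout.write('{}\n'.format(entry))
        sys.stdout.flush()

log = logging.getLogger('progress_demo')
log.setLevel(logging.DEBUG)
log.addHandler(ProgressHandlerSketch())
log.info('> epoch [1] updates[10] loss[0.1234]')  # overwritten in place on a TTY
log.info('epoch finished')                         # printed on its own line
```

Because only the leading `'> '` convention distinguishes the two cases, every call site that wants in-place updating (like the ETA line in `Logger.update`) must remember to prepend it.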



================================================
FILE: utils/visualize_inbetween.py
================================================
import numpy as np
import torch
import cv2
from .chamfer_distance import cd_score


# def make_inter_graph(v2d1, v2d2, topo1, topo2, match12):
#     valid = (match12 != -1)
#     marked2 = np.zeros(len(v2d2)).astype(bool)
#     # print(match12[valid])
#     marked2[match12[valid]] = True

#     id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))
#     id1toh[valid] = np.arange(np.sum(valid))
#     id2toh[match12[valid]] = np.arange(np.sum(valid))
#     id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)
#     # print(marked2)
#     id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))

#     id1toh = id1toh.astype(int)
#     id2toh = id2toh.astype(int)

#     tot_len = len(v2d1) + np.sum(np.invert(marked2))

#     vin1 = v2d1[valid][:]
#     vin2 = v2d2[match12[valid]][:]
#     vh = 0.5 * (vin1 + vin2)
#     vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)

#     topoh = [[] for ii in range(tot_len)]


#     for node in range(len(topo1)):
        
#         for nb in topo1[node]:
#             if int(id1toh[nb]) not in topoh[id1toh[node]]:
#                 topoh[id1toh[node]].append(int(id1toh[nb]))


#     for node in range(len(topo2)):
#         for nb in topo2[node]:
#             if int(id2toh[nb]) not in topoh[id2toh[node]]:
#                 topoh[id2toh[node]].append(int(id2toh[nb]))

#     return vh, topoh


# def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12):
#     valid = (match12 != -1)
#     marked2 = np.zeros(len(v2d2)).astype(bool)
#     # print(match12[valid])
#     marked2[match12[valid]] = True

#     id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))
#     id1toh[valid] = np.arange(np.sum(valid))
#     id2toh[match12[valid]] = np.arange(np.sum(valid))
#     id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)
#     # print(marked2)
#     id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))

#     id1toh = id1toh.astype(int)
#     id2toh = id2toh.astype(int)

#     tot_len = len(v2d1) + np.sum(np.invert(marked2))

#     vin1 = v2d1[valid][:]
#     vin2 = v2d2[match12[valid]][:]
#     vh = 0.5 * (vin1 + vin2)
#     # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)

#     # topoh = [[] for ii in range(tot_len)]
#     topoh = [[] for ii in range(np.sum(valid))]

#     for node in range(len(topo1)):
#         if not valid[node]:
#             continue
#         for nb in topo1[node]:
#             if int(id1toh[nb]) not in topoh[id1toh[node]]:
#                 if valid[nb]:
#                     topoh[id1toh[node]].append(int(id1toh[nb]))


#     for node in range(len(topo2)):
#         if not marked2[node]:
#             continue
#         for nb in topo2[node]:
#             if int(id2toh[nb]) not in topoh[id2toh[node]]:
#                 if marked2[nb]:
#                     topoh[id2toh[node]].append(int(id2toh[nb]))

#     return vh, topoh



def visualize(dict):
    # print(dict['keypoints0'].size(), flush=True)
    img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()
    original_target = ((dict['imaget'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()
    # img1p = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()
    # img2p = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()

    # img1[:, :, 0] += 255
    # img1[:, :, 1] += 180
    # img1[:, :, 2] += 180
    # img1[img1 > 255] = 255

    # img2[:, :, 0] += 255
    # img2[:, :, 1] += 180
    # img2[:, :, 2] += 180
    # img2[img2 > 255] = 255
    
    # img1p[:, :, 0] += 255
    # img1p[:, :, 1] += 180
    # img1p[:, :, 2] += 180
    # img1p[img1p > 255] = 255
    
    # img2p[:, :, 0] += 255
    # img2p[:, :, 1] += 180
    # img2p[:, :, 2] += 180
    # img2p[img2p > 255] = 255

    # img1, img2, img1p, img2p = img1.astype(np.uint8), img2.astype(np.uint8), img1p.astype(np.uint8), img2p.astype(np.uint8)
    motion01 = dict['motion0'][0].cpu().numpy().astype(int) 
    motion21 = dict['motion1'][0].cpu().numpy().astype(int) 

    source0_warp = dict['keypoints0t'][0].cpu().numpy().astype(int)
    source2_warp = dict['keypoints1t'][0].cpu().numpy().astype(int)
    source0 = dict['keypoints0'][0].cpu().numpy().astype(int)
    source2 = dict['keypoints1'][0].cpu().numpy().astype(int)
    source0_topo = dict['topo0'][0]
    # print(len(dict['topo0']))
    source2_topo = dict['topo1'][0]
    visible01 = dict['vb0'][0].cpu().numpy().astype(int)
    visible21 = dict['vb1'][0].cpu().numpy().astype(int)

    # corr01 = dict['m01'][0].cpu().numpy().astype(int)
    # corr10 = dict['m10'][0].cpu().numpy().astype(int)

    # canvas = np.zeros_like(img1) + 255

    # source0_warp2 = source0 + motion01 // 2
    # source2_warp2 = source2 + motion21 // 2

    # for node, nbs in enumerate(source0_topo):
    #     for nb in nbs:
    #         # print([source0_warp[nb][0], source0_warp[nb][1]])
    #         cv2.line(canvas, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)
    # for node, nbs in enumerate(source2_topo):
    #     for nb in nbs:
    #         cv2.line(canvas, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)


    # canvas6 = np.zeros_like(img1) + 255


    # for node, nbs in enumerate(source0_topo):
    #     for nb in nbs:
    #         # print([source0_warp[nb][0], source0_warp[nb][1]])
    #         cv2.line(canvas6, [source0_warp2[node][0], source0_warp2[node][1]], [source0_warp2[nb][0], source0_warp2[nb][1]], [0, 0, 0], 2)
    # for node, nbs in enumerate(source2_topo):
    #     for nb in nbs:
    #         cv2.line(canvas6, [source2_warp2[node][0], source2_warp2[node][1]], [source2_warp2[nb][0], source2_warp2[nb][1]], [0, 0, 0], 2)

    canvas2 = np.zeros_like(img1) + 255

  ##  print('huala<<<', source0_warp.mean(), source2_warp.mean(), flush=True)

    # source0_warp = source0 + motion01
    # source2_warp = source2 + motion21

    for node, nbs in enumerate(source0_topo):
        for nb in nbs:
            # if visible01[node] and visible01[nb]:
            cv2.line(canvas2, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)
    for node, nbs in enumerate(source2_topo):
        for nb in nbs:
            # if visible21[node] and visible21[nb]:
            cv2.line(canvas2, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)

    

    # canvas2
    # black_threshold = 255 // 2
    # img1_sketch = rgb2sketch(img1, black_threshold)
    # img2_sketch = rgb2sketch(img2, black_threshold)

    # img1_sketch = img1_sketch.unsqueeze(0)
    # img2_sketch = img2_sketch.unsqueeze(0)

    # CD = ChamferDistance2dMetric()
    # cd = CD(img1_sketch,img2_sketch)
    canvas5 = np.zeros_like(img1) + 255

    # source0_warp = source0 + motion01
    # source2_warp = source2 + motion21

  ##  print('gulaa>>>', visible01.mean(), visible21.mean(), flush=True)

    for node, nbs in enumerate(source0_topo):
        for nb in nbs:
            if visible01[node] and visible01[nb]:
                cv2.line(canvas5, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)
    for node, nbs in enumerate(source2_topo):
        for nb in nbs:
            if visible21[node] and visible21[nb]:
                cv2.line(canvas5, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)



    canvas3 = np.zeros_like(img1) + 255
    

    for node, nbs in enumerate(source0_topo):
        for nb in nbs:
            cv2.line(canvas3, [source0[node][0], source0[node][1]], [source0[nb][0], source0[nb][1]], [255, 180, 180], 2)
    for node, nbs in enumerate(source2_topo):
        for nb in nbs:
            cv2.line(canvas3, [source2[node][0], source2[node][1]], [source2[nb][0], source2[nb][1]], [180, 180, 255], 2)
SYMBOL INDEX (196 symbols across 18 files)

FILE: corr/datasets/ml_dataset.py
  function read_json (line 21) | def read_json(file_path):
  function ids_to_mat (line 36) | def ids_to_mat(id1, id2):
  function adj_matrix (line 55) | def adj_matrix(topology):
  class MixamoLineArt (line 68) | class MixamoLineArt(data.Dataset):
    method __init__ (line 69) | def __init__(self, root, gap=0, split='train', model=None, action=None...
    method __getitem__ (line 127) | def __getitem__(self, index):
    method __rmul__ (line 272) | def __rmul__(self, v):
    method __len__ (line 278) | def __len__(self):
  function worker_init_fn (line 282) | def worker_init_fn(worker_id):
  function fetch_dataloader (line 285) | def fetch_dataloader(args, type='train',):

FILE: corr/main.py
  function parse_args (line 10) | def parse_args():
  function main (line 23) | def main():

FILE: corr/models/supergluet.py
  function MLP (line 10) | def MLP(channels: list, do_bn=True):
  function normalize_keypoints (line 24) | def normalize_keypoints(kpts, image_shape):
  class ThreeLayerEncoder (line 33) | class ThreeLayerEncoder(nn.Module):
    method __init__ (line 35) | def __init__(self, enc_dim):
    method forward (line 53) | def forward(self, img):
  class VertexDescriptor (line 61) | class VertexDescriptor(nn.Module):
    method __init__ (line 63) | def __init__(self, enc_dim):
    method forward (line 68) | def forward(self, img, vtx):
  class KeypointEncoder (line 76) | class KeypointEncoder(nn.Module):
    method __init__ (line 78) | def __init__(self, feature_dim, layers):
    method forward (line 87) | def forward(self, kpts):
  class TopoEncoder (line 94) | class TopoEncoder(nn.Module):
    method __init__ (line 96) | def __init__(self, feature_dim, layers):
    method forward (line 105) | def forward(self, kpts):
  function attention (line 113) | def attention(query, key, value, mask=None):
  class MultiHeadedAttention (line 125) | class MultiHeadedAttention(nn.Module):
    method __init__ (line 127) | def __init__(self, num_heads: int, d_model: int):
    method forward (line 135) | def forward(self, query, key, value, mask=None):
  class AttentionalPropagation (line 144) | class AttentionalPropagation(nn.Module):
    method __init__ (line 145) | def __init__(self, feature_dim: int, num_heads: int):
    method forward (line 151) | def forward(self, x, source, mask=None):
  class AttentionalGNN (line 156) | class AttentionalGNN(nn.Module):
    method __init__ (line 157) | def __init__(self, feature_dim: int, layer_names: list):
    method forward (line 164) | def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None,...
  function log_sinkhorn_iterations (line 179) | def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
  function log_optimal_transport (line 188) | def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
  function arange_like (line 229) | def arange_like(x, dim: int):
  class SuperGlueT (line 233) | class SuperGlueT(nn.Module):
    method __init__ (line 235) | def __init__(self, config=None):
    method forward (line 274) | def forward(self, data):

FILE: corr/utils/log.py
  class Logger (line 12) | class Logger:
    method __init__ (line 13) | def __init__(self, args, output_dir):
    method set_progress (line 38) | def set_progress(self, epoch, total):
    method update (line 45) | def update(self, stats):
    method log_eval (line 62) | def log_eval(self, stats, metrics_group=None):
    method __call__ (line 81) | def __call__(self, msg):
  class ProgressHandler (line 85) | class ProgressHandler(logging.Handler):
    method __init__ (line 86) | def __init__(self, level=logging.NOTSET):
    method emit (line 89) | def emit(self, record):

FILE: corr/utils/visualize_vtx_corr.py
  function make_inter_graph (line 6) | def make_inter_graph(v2d1, v2d2, topo1, topo2, match12):
  function make_inter_graph_valid (line 47) | def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12):
  function visualize (line 94) | def visualize(dict):

FILE: corr/vtx_matching.py
  class VtxMat (line 36) | class VtxMat():
    method __init__ (line 37) | def __init__(self, args):
    method train (line 43) | def train(self):
    method eval (line 172) | def eval(self):
    method _build (line 269) | def _build(self):
    method _build_model (line 280) | def _build_model(self):
    method _build_train_loader (line 293) | def _build_train_loader(self):
    method _build_test_loader (line 297) | def _build_test_loader(self):
    method _build_optimizer (line 301) | def _build_optimizer(self):
    method _dir_setting (line 314) | def _dir_setting(self):

FILE: datasets/ml_seq.py
  function read_json (line 23) | def read_json(file_path):
  function matched_motion (line 39) | def matched_motion(v2d1, v2d2, match12, motion_pre=None):
  function unmatched_motion (line 47) | def unmatched_motion(topo1, v2d1, motion12, match12):
  function ids_to_mat (line 76) | def ids_to_mat(id1, id2):
  function adj_matrix (line 97) | def adj_matrix(topology):
  class MixamoLineArtMotionSequence (line 110) | class MixamoLineArtMotionSequence(data.Dataset):
    method __init__ (line 111) | def __init__(self, root, gap=0, split='train', model=None, action=None...
    method __getitem__ (line 184) | def __getitem__(self, index):
    method __rmul__ (line 505) | def __rmul__(self, v):
    method __len__ (line 510) | def __len__(self):
  function worker_init_fn (line 514) | def worker_init_fn(worker_id):
  function fetch_dataloader (line 517) | def fetch_dataloader(args, type='train',):

FILE: datasets/vd_seq.py
  function read_json (line 23) | def read_json(file_path):
  class VideoLinSeq (line 42) | class VideoLinSeq(data.Dataset):
    method __init__ (line 43) | def __init__(self, root, split='train'):
    method __getitem__ (line 78) | def __getitem__(self, index):
    method __rmul__ (line 171) | def __rmul__(self, v):
    method __len__ (line 176) | def __len__(self):
  function worker_init_fn (line 180) | def worker_init_fn(worker_id):
  function fetch_videoloader (line 183) | def fetch_videoloader(args, type='train',):

FILE: inbetween.py
  class DraftRefine (line 40) | class DraftRefine():
    method __init__ (line 41) | def __init__(self, args):
    method train (line 47) | def train(self):
    method eval (line 197) | def eval(self):
    method gen (line 325) | def gen(self):
    method _build (line 393) | def _build(self):
    method _build_model (line 406) | def _build_model(self):
    method _build_train_loader (line 419) | def _build_train_loader(self):
    method _build_test_loader (line 423) | def _build_test_loader(self):
    method _build_video_loader (line 426) | def _build_video_loader(self):
    method _build_optimizer (line 430) | def _build_optimizer(self):
    method _dir_setting (line 443) | def _dir_setting(self):

FILE: main.py
  function parse_args (line 10) | def parse_args():
  function main (line 24) | def main():

FILE: models/inbetweener_with_mask2.py
  function MLP (line 9) | def MLP(channels: list, do_bn=True):
  function normalize_keypoints (line 24) | def normalize_keypoints(kpts, image_shape):
  class ThreeLayerEncoder (line 33) | class ThreeLayerEncoder(nn.Module):
    method __init__ (line 35) | def __init__(self, enc_dim):
    method forward (line 53) | def forward(self, img):
  class VertexDescriptor (line 63) | class VertexDescriptor(nn.Module):
    method __init__ (line 65) | def __init__(self, enc_dim):
    method forward (line 72) | def forward(self, img, vtx):
  class KeypointEncoder (line 81) | class KeypointEncoder(nn.Module):
    method __init__ (line 83) | def __init__(self, feature_dim, layers):
    method forward (line 92) | def forward(self, kpts):
  function attention (line 100) | def attention(query, key, value, mask=None):
  class MultiHeadedAttention (line 119) | class MultiHeadedAttention(nn.Module):
    method __init__ (line 121) | def __init__(self, num_heads: int, d_model: int):
    method forward (line 129) | def forward(self, query, key, value, mask=None):
  class AttentionalPropagation (line 138) | class AttentionalPropagation(nn.Module):
    method __init__ (line 139) | def __init__(self, feature_dim: int, num_heads: int):
    method forward (line 145) | def forward(self, x, source, mask=None):
  class AttentionalGNN (line 150) | class AttentionalGNN(nn.Module):
    method __init__ (line 151) | def __init__(self, feature_dim: int, layer_names: list):
    method forward (line 158) | def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None,...
  function log_sinkhorn_iterations (line 173) | def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
  function log_optimal_transport (line 182) | def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
  function arange_like (line 223) | def arange_like(x, dim: int):
  class SuperGlueM (line 227) | class SuperGlueM(nn.Module):
    method __init__ (line 254) | def __init__(self, config=None):
    method forward (line 294) | def forward(self, data):
  function tensor_erode (line 402) | def tensor_erode(bin_img, ksize=5):
  class InbetweenerM (line 417) | class InbetweenerM(nn.Module):
    method __init__ (line 444) | def __init__(self, config=None):
    method forward (line 458) | def forward(self, data):

FILE: models/inbetweener_with_mask_with_spec.py
  function MLP (line 11) | def MLP(channels: list, do_bn=True):
  function normalize_keypoints (line 26) | def normalize_keypoints(kpts, image_shape):
  class ThreeLayerEncoder (line 35) | class ThreeLayerEncoder(nn.Module):
    method __init__ (line 37) | def __init__(self, enc_dim):
    method forward (line 55) | def forward(self, img):
  class VertexDescriptor (line 65) | class VertexDescriptor(nn.Module):
    method __init__ (line 67) | def __init__(self, enc_dim):
    method forward (line 74) | def forward(self, img, vtx):
  class KeypointEncoder (line 83) | class KeypointEncoder(nn.Module):
    method __init__ (line 85) | def __init__(self, feature_dim, layers):
    method forward (line 94) | def forward(self, kpts):
  class TopoEncoder (line 100) | class TopoEncoder(nn.Module):
    method __init__ (line 102) | def __init__(self, feature_dim, layers):
    method forward (line 111) | def forward(self, kpts):
  function attention (line 117) | def attention(query, key, value, mask=None):
  class MultiHeadedAttention (line 128) | class MultiHeadedAttention(nn.Module):
    method __init__ (line 130) | def __init__(self, num_heads: int, d_model: int):
    method forward (line 138) | def forward(self, query, key, value, mask=None):
  class AttentionalPropagation (line 147) | class AttentionalPropagation(nn.Module):
    method __init__ (line 148) | def __init__(self, feature_dim: int, num_heads: int):
    method forward (line 154) | def forward(self, x, source, mask=None):
  class AttentionalGNN (line 159) | class AttentionalGNN(nn.Module):
    method __init__ (line 160) | def __init__(self, feature_dim: int, layer_names: list):
    method forward (line 167) | def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None,...
  function log_sinkhorn_iterations (line 182) | def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
  function log_optimal_transport (line 191) | def log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):
  function arange_like (line 231) | def arange_like(x, dim: int):
  class SuperGlueT (line 235) | class SuperGlueT(nn.Module):
    method __init__ (line 262) | def __init__(self, config=None):
    method forward (line 300) | def forward(self, data):
  function tensor_erode (line 390) | def tensor_erode(bin_img, ksize=5):
  class InbetweenerTM (line 402) | class InbetweenerTM(nn.Module):
    method __init__ (line 418) | def __init__(self, config=None):
    method forward (line 426) | def forward(self, data):

FILE: utils/chamfer_distance.py
  function batch_edt (line 14) | def batch_edt(img, block=1024):
  function batch_chamfer_distance (line 46) | def batch_chamfer_distance(gt, pred, block=1024, return_more=False):
  function batch_chamfer_distance_t (line 51) | def batch_chamfer_distance_t(gt, pred, block=1024, return_more=False):
  function batch_chamfer_distance_p (line 61) | def batch_chamfer_distance_p(gt, pred, block=1024, return_more=False):
  function batch_hausdorff_distance (line 73) | def batch_hausdorff_distance(gt, pred, block=1024, return_more=False):
  class ChamferDistance2dMetric (line 90) | class ChamferDistance2dMetric(torchmetrics.Metric):
    method __init__ (line 92) | def __init__(
    method update (line 104) | def update(self, preds: torch.Tensor, target: torch.Tensor):
    method compute (line 110) | def compute(self):
  class ChamferDistance2dTMetric (line 113) | class ChamferDistance2dTMetric(ChamferDistance2dMetric):
    method update (line 114) | def update(self, preds: torch.Tensor, target: torch.Tensor):
  class ChamferDistance2dPMetric (line 122) | class ChamferDistance2dPMetric(ChamferDistance2dMetric):
    method update (line 123) | def update(self, preds: torch.Tensor, target: torch.Tensor):
  class HausdorffDistance2dMetric (line 132) | class HausdorffDistance2dMetric(torchmetrics.Metric):
    method __init__ (line 133) | def __init__(
    method update (line 148) | def update(self, preds: torch.Tensor, target: torch.Tensor):
    method compute (line 156) | def compute(self):
  function rgb2sketch (line 162) | def rgb2sketch(img, black_threshold):
  function rgb2gray (line 168) | def rgb2gray(rgb):
  function cd_score (line 174) | def cd_score(img1, img2):

FILE: utils/log.py
  class Logger (line 12) | class Logger:
    method __init__ (line 13) | def __init__(self, args, output_dir):
    method set_progress (line 38) | def set_progress(self, epoch, total):
    method update (line 45) | def update(self, stats):
    method log_eval (line 62) | def log_eval(self, stats, metrics_group=None):
    method __call__ (line 81) | def __call__(self, msg):
  class ProgressHandler (line 85) | class ProgressHandler(logging.Handler):
    method __init__ (line 86) | def __init__(self, level=logging.NOTSET):
    method emit (line 89) | def emit(self, record):

FILE: utils/visualize_inbetween.py
  function visualize (line 95) | def visualize(dict):

FILE: utils/visualize_inbetween2.py
  function visualize (line 95) | def visualize(dict):

FILE: utils/visualize_inbetween3.py
  function visualize (line 95) | def visualize(dict):

FILE: utils/visualize_video.py
  function visvid (line 7) | def visvid(dict, inter_frames=1):
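
The index above lists the evaluation utilities in `utils/chamfer_distance.py` (`batch_edt`, `batch_chamfer_distance`, `cd_score`). As a rough illustration of the symmetric chamfer idea only — the repo's implementation rasterizes line art and uses a blocked Euclidean distance transform, so its numbers will differ — a brute-force numpy sketch over 2D point sets:

```python
import numpy as np

def chamfer_2d(a, b):
    """Symmetric chamfer distance between point sets a (N,2) and b (M,2):
    mean nearest-neighbour distance in both directions. Brute force via
    an (N, M) pairwise distance matrix; illustrative only."""
    d = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)  # (N, M)
    return d.min(axis=1).mean() + d.min(axis=0).mean()

a = np.array([[0.0, 0.0], [1.0, 0.0]])
print(chamfer_2d(a, a))  # identical sets give distance 0.0
```

Lower is better; `cd_score` in the repo additionally converts RGB frames to binary sketches before measuring.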

About this extraction

This page contains the full source code of the lisiyao21/AnimeInbet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 35 files (217.5 KB), approximately 64.5k tokens, and a symbol index with 196 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
