[
  {
    "path": ".gitignore",
    "content": "*/__pycache__/*\n*.pt\n*.jpg\n*.pyc\ndata/ml100_norm/\ndata/ml144*\ndata/*.zip\n"
  },
  {
    "path": "README.md",
    "content": "# AnimeInbet\n\nCode for ICCV 2023 paper \"Deep Geometrized Cartoon Line Inbetweening\"\n\n[[Paper]](https://openaccess.thecvf.com/content/ICCV2023/papers/Siyao_Deep_Geometrized_Cartoon_Line_Inbetweening_ICCV_2023_paper.pdf) | [[Video Demo]](https://youtu.be/iUF-LsqFKpI?si=9FViAZUyFdSfZzS5) | [[Data (Google Drive)]](https://drive.google.com/file/d/1SNRGajIECxNwRp6ZJ0IlY7AEl2mRm2DR/view?usp=sharing) \n\n✨ Do not hesitate to give a star! Thank you! ✨\n\n\n![image](https://github.com/lisiyao21/AnimeInbet/blob/main/figures/inbet_gif.gif)\n\n> We aim to address a significant but understudied problem in the anime industry, namely the inbetweening of cartoon line drawings. Inbetweening involves generating intermediate frames between two black-and-white line drawings and is a time-consuming and expensive process that can benefit from automation. However, existing frame interpolation methods that rely on matching and warping whole raster images are unsuitable for line inbetweening and often produce blurring artifacts that damage the intricate line structures. To preserve the precision and detail of the line drawings, we propose a new approach, AnimeInbet, which geometrizes raster line drawings into graphs of endpoints and reframes the inbetweening task as a graph fusion problem with vertex repositioning. Our method can effectively capture the sparsity and unique structure of line drawings while preserving the details during inbetweening. This is made possible via our novel modules, i.e., vertex geometric embedding, a vertex correspondence Transformer, an effective mechanism for vertex repositioning and a visibility predictor. To train our method, we introduce MixamoLine240, a new dataset of line drawings with ground truth vectorization and matching labels. 
Our experiments demonstrate that AnimeInbet synthesizes high-quality, clean, and complete intermediate line drawings, outperforming existing methods quantitatively and qualitatively, especially in cases with large motions.\n\n# ML240 Data\n\nThe implementation of AnimeInbet depends on the matching of line vertices between two adjacent frames. To supervise the learning of vertex correspondence, we build a large-scale sequential cartoon line dataset, **MixamoLine240** (ML240). ML240 contains a training set (100 sequences), a validation set (44 sequences) and a test set (100 sequences). Each sequence i\n\nTo use the data, please first download it from [link](https://drive.google.com/file/d/1SNRGajIECxNwRp6ZJ0IlY7AEl2mRm2DR/view?usp=sharing) and uncompress it into the **data** folder under this project directory. After decompression, the directory layout will be\n\n        data\n          |_ml100_norm\n          |        |_ all\n          |             |_frames  \n          |             |    |_chip_abe\n          |             |    |     |_Image0001.png\n          |             |    |     |_Image0002.png\n          |             |    |     |\n          |             |    |     ...  \n          |             |    ... \n          |             |\n          |             |_labels\n          |                  |_chip_abe\n          |                  |     |_Line0001.json\n          |                  |     |_Line0002.json\n          |                  |     |\n          |                  |     ...  \n          |                  ...\n          | \n          |_ml144_norm_100_44_split  \n                  |_ test\n                  |    |_frames  \n                  |    |    |_breakdance_1990_police\n                  |    |    |     |_Image0001.png\n                  |    |    |     |_Image0002.png\n                  |    |    |     |\n                  |    |    |     ...  \n                  |    |    ... 
\n                  |    |\n                  |    |_labels\n                  |         |_breakdance_1990_police\n                  |         |     |_Line0001.json\n                  |         |     |_Line0002.json\n                  |         |     |\n                  |         |     ...  \n                  |         ...\n                  |_ train\n                      |_frames  \n                      |    |_breakdance_1990_ganfaul\n                      |    |     |_Image0001.png\n                      |    |     |_Image0002.png\n                      |    |     |\n                      |    |     ...  \n                      |    ... \n                      |\n                      |_labels\n                          |_breakdance_1990_ganfaul\n                          |     |_Line0001.json\n                          |     |_Line0002.json\n                          |     |\n                          |     ...  \n                          ...\n\n\nEach json file in the \"labels\" folder (for example, ml100_norm/all/labels/chip_abe/Line0001.json) contains the vectorization/geometrization labels of the corresponding image in the \"frames\" folder (for example, ml100_norm/all/frames/chip_abe/Image0001.png). Each json file contains three components: (1) **vertex location**: the 2D positions of the line art vertices, (2) **connection**: the adjacency table of the vector graph and (3) **original index**: the index number of each vertex in the original 3D mesh.\n\n\n# Code\n\n## Environment \n\n    conda create -n inbetween python=3.8\n    conda activate inbetween\n    conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=10.1 -c pytorch\n    pip install -r requirement.txt\n\n\n![image](https://github.com/lisiyao21/AnimeInbet/blob/main/figures/pipeline.png)\n\nIn this code, the whole pipeline is separated into two parts: (1) vertex correspondence and (2) inbetweening/synthesis. 
In the first part, the network is trained to match the vertices of two input vector graphs; this part includes the \"vertex embedding\" and \"vertex corr. Transformer\" modules. The \"repositioning propagation\" and \"graph fusion\" steps are done in the second part.\n\nThe first part lives in ./corr, and the second comprises the rest of the repository. We provide a pretrained correspondence network weight ([link](https://drive.google.com/file/d/1Edc-XGyMXqXDdfBYoglDMkBf7_AYZU0p/view?usp=sharing)) and a pretrained whole-pipeline weight ([link](https://drive.google.com/file/d/1cemJCBNdcTvJ9LWCA_5LmDDorwEb-u7M/view?usp=sharing)). For correspondence, please decompress the weight (epoch_50.pt) to ./corr/experiments/vtx_corr/ckpt. For the whole pipeline, please decompress the weight (epoch_20.pt) to ./experiments/inbetweener_full/ckpt/.\n\n\n## Train & test corr.\n\nFor training, please first cd into the ./corr folder and then run\n\n    sh srun.sh configs/vtx_corr.yaml train [your node name] 1\n\nIf you don't use Slurm on your computer/cluster, you can run\n\n    python -u main.py --config configs/vtx_corr.yaml --train \n\nFor testing the correspondence network, please run\n\n    sh srun.sh configs/vtx_corr.yaml test [your node name] 1\n\nor \n\n    python -u main.py --config configs/vtx_corr.yaml --eval\n\nYou may directly run the test code after downloading the weights, without training.\n\n## Train & test the whole inbetweening pipeline\n\nFor training the whole pipeline, please first cd back from ./corr to the project root folder and run\n\n    sh srun.sh configs/cr_inbetweener_full.yaml train [your node name] 1\n\nor\n\n    python -u main.py --config configs/cr_inbetweener_full.yaml --train \n\nFor testing, please run\n\n    sh srun.sh configs/cr_inbetweener_full.yaml test [your node name] 1\n\nor \n\n    python -u main.py --config configs/cr_inbetweener_full.yaml --test\n\nInbetweened results will be stored in the ./inbetween_results folder.\n\n### Compute CD values\n\nThe CD code is under utils/chamfer_distance.py. 
Please run\n\n    python compute_cd.py --gt ./data/ml100_norm/all/frames --generated ./inbetween_results/test_gap=5\n\nIf everything goes right, the score will match the one reported in the paper.\n\n\n# Citation\n\nIf you use our code or data, or find our work inspiring, please kindly cite our paper:\n\n    @inproceedings{siyao2023inbetween,\n\t    title={Deep Geometrized Cartoon Line Inbetweening},\n\t    author={Siyao, Li and Gu, Tianpei and Xiao, Weiye and Ding, Henghui and Liu, Ziwei and Loy, Chen Change},\n\t    booktitle={ICCV},\n\t    year={2023}\n    }\n\n# License\n\nML240 is released under CC BY-NC-SA 4.0. Code is released for non-commercial uses only.\n\n"
  },
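The label format described in the README (one JSON file per frame with "vertex location", "connection", and "original index" fields) can be parsed in a few lines. A minimal sketch assuming exactly the three documented keys; the file written here is a tiny synthetic stand-in for e.g. a Line0001.json:

```python
import json
import numpy as np

def read_label(path):
    """Parse one ML240 label file into (vertices, adjacency table, 3D indices)."""
    with open(path) as f:
        data = json.load(f)
    vertex2d = np.array(data['vertex location'])   # (N, 2) pixel coordinates
    topology = data['connection']                  # list of neighbor-index lists
    index = np.array(data['original index'])       # per-vertex 3D mesh vertex id
    return vertex2d, topology, index

# tiny synthetic label standing in for a real file under .../labels/<clip>/
sample = {'vertex location': [[10, 20], [30, 40]],
          'connection': [[1], [0]],
          'original index': [7, 42]}
with open('Line_demo.json', 'w') as f:
    json.dump(sample, f)

v2d, topo, idx = read_label('Line_demo.json')
print(v2d.shape, topo, idx.tolist())  # → (2, 2) [[1], [0]] [7, 42]
```

The same three keys are what `read_json` in corr/datasets/ml_dataset.py consumes.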
  {
    "path": "compute_cd.py",
"content": "import argparse\nimport os\n\nimport cv2\nimport numpy as np\n\nfrom utils.chamfer_distance import cd_score\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--generated', type=str)\n    parser.add_argument('--gt', type=str)\n    args = parser.parse_args()\n\n    gen_dir = args.generated\n    gt_dir = args.gt\n\n    cds = []\n    print('computing CD...', flush=True)\n\n    for subfolder in os.listdir(gt_dir):\n        for img in os.listdir(os.path.join(gt_dir, subfolder)):\n            if not img.endswith('.png'):\n                continue\n            img_gt = cv2.imread(os.path.join(gt_dir, subfolder, img))\n\n            # generated frames are named <clip>_LineXXXX.png\n            pred_name = subfolder + '_' + img.replace('Image', 'Line')\n            if not os.path.exists(os.path.join(gen_dir, pred_name)):\n                continue\n            img_pred = cv2.imread(os.path.join(gen_dir, pred_name))\n\n            cds.append(cd_score(img_gt, img_pred))\n\n    print('GT: ', gt_dir)\n    print('>>> Gen: ', gen_dir)\n    print('>>> CD (x 1e-5): ', np.mean(cds) / 1e-5, ' over', len(cds), 'image pairs')\n"
  },
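For intuition about what `cd_score` measures, here is an assumed, self-contained re-implementation of a symmetric Chamfer distance between two line images (mean nearest-neighbor distance over dark pixels, averaged in both directions). This is an illustrative sketch only, not the repo's utils/chamfer_distance.py:

```python
import numpy as np

def chamfer_distance(img_a, img_b, thresh=128):
    """Symmetric mean nearest-neighbor distance between dark (line) pixels."""
    pts_a = np.argwhere(img_a < thresh).astype(float)
    pts_b = np.argwhere(img_b < thresh).astype(float)
    # brute-force pairwise Euclidean distances (fine for small images)
    d = np.sqrt(((pts_a[:, None, :] - pts_b[None, :, :]) ** 2).sum(-1))
    return 0.5 * (d.min(axis=1).mean() + d.min(axis=0).mean())

# two 8x8 white images, each with one black line, shifted by one row
a = np.full((8, 8), 255); a[2, 1:6] = 0
b = np.full((8, 8), 255); b[3, 1:6] = 0
print(chamfer_distance(a, b))  # → 1.0
```

The repo's version additionally normalizes the score (hence the `x 1e-5` scaling when printing).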
  {
    "path": "configs/cr_inbetweener_full.yaml",
"content": "model:\n    name: InbetweenerTM\n    corr_model:\n        descriptor_dim: 128\n        keypoint_encoder: [32, 64, 128]\n        GNN_layer_num: 12\n        sinkhorn_iterations: 20\n        match_threshold: 0.2\n    pos_weight: 0.2\n\noptimizer:\n    type: Adam\n    kwargs:\n        lr: 0.0001\n        betas: [0.9, 0.999]\n        weight_decay: 0\n    schedular_kwargs:\n        milestones: [80]\n        gamma: 0.1\n\ndata:\n    train:\n        root: 'data/ml144_norm_100_44_split/'\n        batch_size: 1\n        gap: 5\n        type: 'train'\n        model: None\n        action: None\n        mode: 'train'\n    test:\n        root: 'data/ml100_norm/'\n        batch_size: 1\n        gap: 5\n        type: 'all'\n        model: None\n        action: None\n        mode: 'eval'\n        use_vs: False\n\ntesting:\n    ckpt_epoch: 20\n\nbatch_size: 8\n\ncorr_weights: './corr/experiments/vtx_corr/ckpt/epoch_50.pt'\n\nimwrite_dir: ./inbetween_results/test_gap=5\n\nexpname: inbetweener_full\nepoch: 20\nsave_per_epochs: 1\nlog_per_updates: 1\ntest_freq: 10\nseed: 42\n"
  },
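The `schedular_kwargs` entry above follows PyTorch's MultiStepLR convention: the learning rate is multiplied by `gamma` each time a milestone epoch is passed. A minimal dependency-free sketch of that rule, using this config's values (base lr 1e-4, milestones [80], gamma 0.1) — the helper name is illustrative:

```python
def lr_at_epoch(base_lr, milestones, gamma, epoch):
    """MultiStepLR-style decay: multiply by gamma once per passed milestone."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * (gamma ** passed)

print(lr_at_epoch(1e-4, [80], 0.1, 10))   # before the milestone: 1e-4
print(lr_at_epoch(1e-4, [80], 0.1, 80))   # decayed at epoch 80: 1e-5
```

Since the config trains for only 20 epochs, the milestone at 80 never fires here; the corr config (milestones [50, 150], 50 epochs) behaves the same way.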
  {
    "path": "corr/configs/vtx_corr.yaml",
"content": "model:\n    name: SuperGlueT\n    descriptor_dim: 128\n    keypoint_encoder: [32, 64, 128]\n    GNN_layer_num: 12\n    sinkhorn_iterations: 20\n    match_threshold: 0.2\n\noptimizer:\n    type: Adam\n    kwargs:\n        lr: 0.00001\n        betas: [0.9, 0.999]\n        weight_decay: 0\n    schedular_kwargs:\n        milestones: [50, 150]\n        gamma: 0.1\n\ndata:\n    train:\n        batch_size: 1\n        gap: 5\n        model: None\n        action: None\n        type: 'train'\n        mode: 'train'\n    test:\n        batch_size: 1\n        gap: 5\n        type: 'test'\n        model: None\n        action: None\n        mode: 'eval'\n\ntesting:\n    ckpt_epoch: 50\nbatch_size: 8\n\nexpname: vtx_corr\nepoch: 50\nsave_per_epochs: 1\nlog_per_updates: 1\ntest_freq: 1\nseed: 42\n"
  },
  {
    "path": "corr/datasets/__init__.py",
    "content": "from .ml_dataset import MixamoLineArt\nfrom .ml_dataset import fetch_dataloader\n\n__all__ = ['MixamoLineArt', 'fetch_dataloader']"
  },
  {
    "path": "corr/datasets/ml_dataset.py",
    "content": "import numpy as np\nimport torch\nimport torch.utils.data as data\nimport torch.nn.functional as F\n# import networkx as nx\nimport os\nimport math\nimport random\nfrom glob import glob\nimport os.path as osp\n\nimport sys\nimport argparse\nimport cv2\nfrom collections import Counter\n\nimport json\nimport sknetwork\nfrom sknetwork.embedding import Spectral\n\ndef read_json(file_path):\n    \"\"\"\n        input: json file path\n        output: 2d vertex, connections, and index numbers in original 3D space\n    \"\"\"\n\n    with open(file_path) as file:\n        data = json.load(file)\n        vertex2d = np.array(data['vertex location'])\n        \n        topology = data['connection']\n        index = np.array(data['original index'])\n\n    return vertex2d, topology, index\n\ndef ids_to_mat(id1, id2):\n    \"\"\"\n        inputs are two list of vertex index in original 3D mesh\n    \"\"\"\n    corr1 = np.zeros(len(id1)) - 1.0\n    corr2 = np.zeros(len(id2)) - 1.0\n\n    id1 = np.array(id1).astype(int)[:, None]\n    id2 = np.array(id2).astype(int)\n    \n    mat = (id1 == id2)\n\n    pos12 = np.arange(len(id2))[None].repeat(len(id1), 0)\n    pos21 = np.arange(len(id1))[None].repeat(len(id2), 0)\n    corr1[mat.astype(int).sum(1).astype(bool)] = pos12[mat]\n    corr2[mat.transpose().astype(int).sum(1).astype(bool)] = pos21[mat.transpose()]\n\n    return mat, corr1, corr2\n\ndef adj_matrix(topology):\n    \"\"\"\n        topology is the adj table; returns adj matrix\n    \"\"\"\n    gsize = len(topology)\n    adj = np.zeros((gsize, gsize)).astype(float)\n    for v in range(gsize):\n        adj[v][v] = 1.0\n        for nb in topology[v]:\n            adj[v][nb] = 1.0\n            adj[nb][v] = 1.0\n    return adj\n\nclass MixamoLineArt(data.Dataset):\n    def __init__(self, root, gap=0, split='train', model=None, action=None, mode='train', max_len=3050, use_vs=False):\n        \"\"\"\n            input:\n                root: the root folder of the line 
art data\n                gap: how many frames between two frames\n                split: train or test\n                model: indicate a specific character (default None)\n                action: indicate a specific action (default None)\n        \"\"\"\n        super(MixamoLineArt, self).__init__()\n\n\n        if model == 'None':\n            model = None\n        if action == 'None':\n            action = None\n\n        self.is_train = True if mode == 'train' else False\n        self.is_eval = True if mode == 'eval' else False\n        # self.is_train = False\n        self.max_len = max_len\n\n        self.image_list = []\n        self.label_list = []\n        \n        if use_vs:\n            label_root = osp.join(root, split, 'labels_vs')\n        else:\n            label_root = osp.join(root, split, 'labels')\n        image_root = osp.join(root, split, 'frames')\n        self.spectral = Spectral(64,  normalized=False)\n\n        for clip in os.listdir(image_root):\n            skip = False\n            if model != None:\n                for mm in model:\n                    if mm in clip:\n                        skip = True\n                \n            if action != None:\n                for aa in action:\n                    if aa in clip:\n                        skip = True\n            if skip:\n                continue\n            image_list = sorted(glob(osp.join(image_root, clip, '*.png')))\n            label_list = sorted(glob(osp.join(label_root, clip, '*.json')))\n            if len(image_list) != len(label_list):\n                print(image_root, flush=True)\n                continue\n            for i in range(len(image_list) - (gap+1)):\n                self.image_list += [ [image_list[i], image_list[i+gap+1]] ]\n            for i in range(len(label_list) - (gap+1)):\n                self.label_list += [ [label_list[i], label_list[i+gap+1]] ]\n        # print(clip)\n        print('Len of Frame is ', len(self.image_list))\n        
print('Len of Label is ', len(self.label_list))\n\n    def __getitem__(self, index):\n\n        # load image/label files\n        # image crop to a square, 2d label same operation\n        # index to index matching\n        # spectral embedding\n\n        # test does not need index matching\n        \n        index = index % len(self.image_list)\n        file_name = self.label_list[index][0][:-4]\n  \n        img1 = cv2.imread(self.image_list[index][0])\n        img2 = cv2.imread(self.image_list[index][1])\n        v2d1, topo1, id1 = read_json(self.label_list[index][0])\n        v2d2, topo2, id2 = read_json(self.label_list[index][1])\n        for ii in range(len(topo1)):\n            # if not len(topo1[ii]):\n            topo1[ii].append(ii)\n        for ii in range(len(topo2)):\n            topo2[ii].append(ii)\n\n    \n        m, n = len(v2d1), len(v2d2)\n\n        # img1, v2d1 = crop_img(img1, np.array(v2d1))\n        # img2, v2d2 = crop_img(img2, np.array(v2d2))\n\n        if len(img1.shape) == 2:\n            img1 = np.tile(img1[...,None], (1, 1, 3))\n            img2 = np.tile(img2[...,None], (1, 1, 3))\n        else:\n            img1 = img1[..., :3]\n            img2 = img2[..., :3]\n        \n        img1 = torch.from_numpy(img1).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 \n        img2 = torch.from_numpy(img2).permute(2, 0, 1).float() * 2 / 255.0 - 1.0\n\n        v2d1 = torch.from_numpy(v2d1)\n        v2d2 = torch.from_numpy(v2d2)\n\n        v2d1[v2d1 > 719] = 719\n        v2d1[v2d1 < 0] = 0\n        v2d2[v2d2 > 719] = 719\n        v2d2[v2d2 < 0] = 0\n\n\n        adj1 = sknetwork.data.from_adjacency_list(topo1, matrix_only=True, reindex=False).toarray()\n        adj2 = sknetwork.data.from_adjacency_list(topo2, matrix_only=True, reindex=False).toarray()\n\n        # note here we compute the spectral embedding of adj matrix in data loading period\n        # since it needs cpu computation and is not friendy to our cluster's computation\n        # put them 
here to use multi-cpu pre-computing before network forward flow\n        spec0, spec1 = np.abs(self.spectral.fit_transform(adj1)), np.abs(self.spectral.fit_transform(adj2))\n\n        mat_index, corr1, corr2 = ids_to_mat(id1, id2)\n        mat_index = torch.from_numpy(mat_index).float()\n        corr1 = torch.from_numpy(corr1).float()\n        corr2 = torch.from_numpy(corr2).float()\n        if self.is_train:\n        # if False:\n            v2d1 = torch.nn.functional.pad(v2d1, (0, 0, 0, self.max_len - m), mode='constant', value=0)\n            v2d2 = torch.nn.functional.pad(v2d2, (0, 0, 0, self.max_len - n), mode='constant', value=0)\n            corr1 = torch.nn.functional.pad(corr1, (0, self.max_len - m), mode='constant', value=0)\n            corr2 = torch.nn.functional.pad(corr2, (0, self.max_len - n), mode='constant', value=0)\n\n            mask0, mask1 = torch.zeros(self.max_len).float(), torch.zeros(self.max_len).float()\n            mask0[:m] = 1\n            mask1[:n] = 1\n        else:\n            mask0, mask1 = torch.ones(m).float(), torch.ones(n).float()\n\n        # not return id anymore. 
too slow\n        sample = {\n            'keypoints0': v2d1,\n            'keypoints1': v2d2,\n            'adj_mat0': spec0,\n            'adj_mat1': spec1,\n            'image0': img1,\n            'image1': img2,\n\n            'all_matches': corr1,\n            'm01': corr1,\n            'm10': corr2,\n            'ms': m,\n            'ns': n,\n            'mask0': mask0,\n            'mask1': mask1,\n            'file_name': file_name,\n        }\n        if self.is_eval:\n            # evaluation additionally returns the raw adjacency tables\n            sample['topo0'] = [topo1]\n            sample['topo1'] = [topo2]\n        return sample\n\n    def __rmul__(self, v):\n        # repeat the dataset v times\n        self.image_list = v * self.image_list\n        self.label_list = v * self.label_list\n        return self\n\n    def __len__(self):\n        return len(self.image_list)\n\n\ndef worker_init_fn(worker_id):\n    np.random.seed(np.random.get_state()[1][0] + worker_id)\n\ndef fetch_dataloader(args, type='train'):\n    lineart = MixamoLineArt(root=args.root if hasattr(args, 'root') else '../data/ml144_norm_100_44_split/', gap=args.gap, split=args.type, model=args.model, action=args.action, mode=args.mode if hasattr(args, 'mode') else 'train', use_vs=args.use_vs if hasattr(args, 'use_vs') else False)\n\n    if getattr(args, 'mode', 'train') != 'train':\n        # evaluation: keep sample order and do not drop the tail batch\n        return data.DataLoader(lineart, batch_size=args.batch_size,\n            pin_memory=True, shuffle=False, num_workers=8)\n\n    return data.DataLoader(lineart, batch_size=args.batch_size,\n        pin_memory=True, shuffle=True, num_workers=8, drop_last=True, worker_init_fn=worker_init_fn)\n\n\nif __name__ == '__main__':\n    torch.multiprocessing.set_sharing_strategy('file_system')\n    args = argparse.Namespace()\n    args.batch_size = 1\n    args.gap = 5\n    args.type = 'test'\n    args.model = ['ganfaul', 'firlscout', 'jolleen', 'kachujin', 'knight', 'maria', 'michelle', 'peasant', 'timmy', 'uriel']\n    args.use_vs = False\n    args.action = ['breakdance', 'capoeira', 'chapa-', 'fist_fight', 'flying', 'climb', 'running', 'reaction', 'magic', 'tripping']\n\n    args.mode = 'eval'\n    
args.root='/mnt/lustre/syli/inbetween/data/12by12/ml144_norm_100_44_split/'\n    # args.stage = 'anime'\n    # args.image_size = (368, 368)\n    # lineart = MixamoLineArt(root='/mnt/lustre/syli/inbetween/data/12by12/ml144/', gap=0, split='train')\n    lineart = fetch_dataloader(args)\n    # lineart = MixamoLineArt(root='/mnt/cache/syli/inbetween/data/ml100_norm/', gap=args.gap, split=args.type, model=args.model, action=args.action, mode=args.mode if hasattr(args, 'mode') else 'train')\n    # train_loader = data.DataLoader(lineart, batch_size=args.batch_size, \n\n    percentage = 0.0\n    vertex_num = 0.0\n    vertex_shift = 0.0\n    vertex_max_shift = 0.0\n    edges = 0.0\n    # for data in loader:\n    #     print(data)\n    #     break\n    unmatched_all = []\n    max_node = 0\n    for dict in lineart:\n        # print(dict['file_name'])\n        # print(dict['file_name'][0], flush=True)\n        v2d1 = dict['keypoints0'].numpy().astype(int)[0]\n        v2d2 = dict['keypoints1'].numpy().astype(int)[0]\n\n        ms = dict['ms'][0]\n        ns = dict['ns'][0]\n        # this_edges \n        topo = dict['topo0'][0]\n        for ii in range(len(topo)):\n            edges += len(topo[ii])\n        # print(len(topo), flush=True)\n\n\n        # print(ms, ns, flush=True)\n        # print(dict['keypoints0'], flush=True)\n        # print(dict['image0'].size(), flush=True)\n        v2d1 = v2d1[:ms]\n        v2d2 = v2d2[:ns]\n        m01 = dict['m01'][0][:ms]\n        # print(m01.shape)\n        # print(np.arange(len(m01))[m01 != -1], m01[m01 != -1])\n        # print(v2d2.shape, v2d1.shape)\n        shift = np.sqrt(((v2d2[m01[m01 != -1].int(), :] * 1.0 - v2d1[np.arange(len(m01))[m01 != -1],:]) ** 2).sum(-1))\n        vertex_num += len(v2d1)\n        vertex_shift += shift.mean()\n        vertex_max_shift += shift.max()\n        percentage += ((m01!=-1).float().sum() * 1.0 / len(m01) * 100.0)\n    \n    print('>>>> gap=', args.gap, ' percentage=', percentage / len(lineart), ' 
vertex num=', vertex_num*1.0/len(lineart), 'edges num=', edges*1.0/len(lineart)/2, 'vertex shift=', vertex_shift/len(lineart), ' vertex max shift=', vertex_max_shift/len(lineart), flush=True)\n        \n\n        # if len(v2d1) > max_node:\n        #     max_node = len(v2d1)\n        # if len(v2d2) > max_node:\n        #     max_node = len(v2d2)\n    # print(max_node)\n        # print(v2d1.shape)\n        # img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()\n        # img2 = ((dict['image1'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()\n\n        # # print(v2d1.shape, img1.shape, flush=True)\n\n        # for node, nbs in enumerate(dict['topo0']):\n        #     for nb in nbs:\n        #         cv2.line(img1, [v2d1[node][0], v2d1[node][1]], [v2d1[nb][0], v2d1[nb][1]], [255, 180, 180], 2)\n        # colors1, colors2 = {}, {}\n\n        # id1 = dict['id0'][0].numpy()\n        # id2 = dict['id1'][0].numpy()\n        # for index in id1:\n        #     # print(index)\n        #     color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]\n        #     # for ii in index:\n        #     colors1[index] = color\n        \n        # colors1, colors2 = {}, {}\n\n\n        # for index in id1:\n        #     # print(index)\n        #     color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]\n        #     colors1[index] = color\n\n        # for i, p in enumerate(v2d1):\n        #     ii = id1[i]\n        #     # print(ii)\n        #     cv2.circle(img1, [int(p[0]), int(p[1])], 1, colors1[ii], 2)\n\n        # unmatched = 0\n        # for ii in id2:\n        #     color = [0, 0, 0]\n        #     this_is_umatched = 1\n        #     colors2[ii] = colors1[ii] if ii in colors1 else color\n        #     if ii in colors1:\n        #         this_is_umatched = 0\n        #     # if ii not in colors1:\n        #     unmatched += 
this_is_umatched\n\n        # for i, p in enumerate(v2d2):\n        #     ii = id2[i]\n        #     # print(p)\n        #     cv2.circle(img2, [int(p[0]), int( p[1])], 1, colors2[ii], 2)\n\n        # # print('Unmatched in Img 2: ', , '%')\n        # unmatched_all.append(100 - unmatched * 100.0/len(v2d2))\n\n        # im_h = cv2.hconcat([img1, img2])\n        # print('/mnt/lustre/syli/inbetween/AnimeInbetween/corr/datasets/data_check_norm/' + dict['file_name'][0].replace('/', '_') + '.jpg', flush=True)\n        # cv2.imwrite('/mnt/lustre/syli/inbetween/AnimeInbetween/corr/datasets/data_check_norm/' + dict['file_name'][0].replace('/', '_') + '.jpg', im_h)\n\n    # print(np.mean(unmatched_all))\n \n\n"
  },
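The ground-truth correspondence used by the dataset above comes from `ids_to_mat`: two frames share a vertex exactly when their "original index" values (3D mesh vertex ids) coincide. A condensed sketch of that lookup, with -1 marking unmatched vertices; it assumes, as the original effectively does, that each id appears at most once per frame, and the function name is illustrative:

```python
import numpy as np

def match_by_id(id1, id2):
    """For each vertex in frame 1, index of the same 3D vertex in frame 2 (or -1)."""
    id1 = np.asarray(id1, dtype=int)
    id2 = np.asarray(id2, dtype=int)
    mat = id1[:, None] == id2[None, :]        # (m, n) boolean match matrix
    corr1 = np.full(len(id1), -1.0)
    pos = np.arange(len(id2))[None].repeat(len(id1), 0)
    corr1[mat.any(1)] = pos[mat]              # assumes unique ids per frame
    return mat, corr1

mat, corr1 = match_by_id([5, 9, 11], [11, 5, 8])
print(corr1.tolist())  # → [1.0, -1.0, 0.0]  (vertex 9 is unmatched)
```

The `-1` entries are what the loader later pads and masks out for training.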
  {
    "path": "corr/experiments/vtx_corr/ckpt/.gitkeep",
    "content": ""
  },
  {
    "path": "corr/main.py",
"content": "from vtx_matching import VtxMat\nimport argparse\nimport os\nimport yaml\nfrom pprint import pprint\nfrom easydict import EasyDict\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='Anime vertex matching')\n    parser.add_argument('--config', default='')\n    # exclusive arguments\n    group = parser.add_mutually_exclusive_group(required=True)\n    group.add_argument('--train', action='store_true')\n    group.add_argument('--eval', action='store_true')\n\n    return parser.parse_args()\n\n\ndef main():\n    # parse arguments and load config\n    args = parse_args()\n    with open(args.config) as f:\n        config = yaml.load(f, Loader=yaml.FullLoader)\n\n    for k, v in vars(args).items():\n        config[k] = v\n\n    config = EasyDict(config)\n    pprint(config)\n\n    agent = VtxMat(config)\n\n    if args.train:\n        agent.train()\n    elif args.eval:\n        agent.eval()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "corr/models/__init__.py",
    "content": "from .supergluet import SuperGlueT\n# from .supergluet_wo_OT import SuperGlueTwoOT\n# from .supergluenp import SuperGlue as SuperGlueNP\n# from .supergluei import SuperGlue as SuperGlueI\n# from .supergluet2 import SuperGlueT2\n\n__all__ = ['SuperGlueT']"
  },
  {
    "path": "corr/models/supergluet.py",
    "content": "import numpy as np\nfrom copy import deepcopy\nfrom pathlib import Path\nimport torch\nfrom torch import nn\n\nimport argparse\nfrom sknetwork.embedding import Spectral\n\ndef MLP(channels: list, do_bn=True):\n    \"\"\" Multi-layer perceptron \"\"\"\n    n = len(channels)\n    layers = []\n    for i in range(1, n):\n        layers.append(\n            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))\n        if i < (n-1):\n            if do_bn:\n                layers.append(nn.InstanceNorm1d(channels[i]))\n            layers.append(nn.ReLU())\n    return nn.Sequential(*layers)\n\n\ndef normalize_keypoints(kpts, image_shape):\n    \"\"\" Normalize keypoints locations based on image image_shape\"\"\"\n    _, _, height, width = image_shape\n    one = kpts.new_tensor(1)\n    size = torch.stack([one*width, one*height])[None]\n    center = size / 2\n    scaling = size.max(1, keepdim=True).values * 0.7\n    return (kpts - center[:, None, :]) / scaling[:, None, :]\n\nclass ThreeLayerEncoder(nn.Module):\n    \"\"\" Joint encoding of visual appearance and location using MLPs\"\"\"\n    def __init__(self, enc_dim):\n        super().__init__()\n        # input must be 3 channel (r, g, b)\n        self.layer1 = nn.Conv2d(3, enc_dim//4, 7, padding=3)\n        self.non_linear1 = nn.ReLU()\n        self.layer2 = nn.Conv2d(enc_dim//4, enc_dim//2, 3, padding=1)\n        self.non_linear2 = nn.ReLU()\n        self.layer3 = nn.Conv2d(enc_dim//2, enc_dim, 3, padding=1)\n\n        self.norm1 = nn.InstanceNorm2d(enc_dim//4)\n        self.norm2 = nn.InstanceNorm2d(enc_dim//2)\n        self.norm3 = nn.InstanceNorm2d(enc_dim)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                nn.init.constant_(m.bias, 0.0)\n\n    def forward(self, img):\n        x = self.non_linear1(self.norm1(self.layer1(img)))\n        x = 
self.non_linear2(self.norm2(self.layer2(x)))\n        x = self.norm3(self.layer3(x))\n\n        return x\n\n\nclass VertexDescriptor(nn.Module):\n    \"\"\" Image-based appearance descriptor sampled at vertex locations \"\"\"\n    def __init__(self, enc_dim):\n        super().__init__()\n        self.encoder = ThreeLayerEncoder(enc_dim)\n\n    def forward(self, img, vtx):\n        x = self.encoder(img)\n        n, c, h, w = x.size()\n        assert((h, w) == img.size()[2:4])\n        return x[:, :, torch.round(vtx[0, :, 1]).long(), torch.round(vtx[0, :, 0]).long()]\n\n\nclass KeypointEncoder(nn.Module):\n    \"\"\" Encoding of 2D vertex locations using an MLP \"\"\"\n    def __init__(self, feature_dim, layers):\n        super().__init__()\n        self.encoder = MLP([2] + layers + [feature_dim])\n        nn.init.constant_(self.encoder[-1].bias, 0.0)\n\n    def forward(self, kpts):\n        inputs = kpts.transpose(1, 2)\n        x = self.encoder(inputs)\n        return x\n\nclass TopoEncoder(nn.Module):\n    \"\"\" Encoding of 64-d spectral topology embeddings using an MLP \"\"\"\n    def __init__(self, feature_dim, layers):\n        super().__init__()\n        self.encoder = MLP([64] + layers + [feature_dim])\n        nn.init.constant_(self.encoder[-1].bias, 0.0)\n\n    def forward(self, kpts):\n        inputs = kpts.transpose(1, 2)\n        x = self.encoder(inputs)\n        return x\n\n\ndef attention(query, key, value, mask=None):\n    dim = query.shape[1]\n    scores = 
torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5\n    if mask is not None:\n        scores = scores.masked_fill(mask==0, float('-inf'))\n\n    prob = torch.nn.functional.softmax(scores, dim=-1)\n\n\n    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob\n\n\nclass MultiHeadedAttention(nn.Module):\n    \"\"\" Multi-head attention to increase model expressivity \"\"\"\n    def __init__(self, num_heads: int, d_model: int):\n        super().__init__()\n        assert d_model % num_heads == 0\n        self.dim = d_model // num_heads\n        self.num_heads = num_heads\n        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)\n        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])\n\n    def forward(self, query, key, value, mask=None):\n        batch_dim = query.size(0)\n        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)\n                             for l, x in zip(self.proj, (query, key, value))]\n        x, prob = attention(query, key, value, mask)\n\n        return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))\n\n\nclass AttentionalPropagation(nn.Module):\n    def __init__(self, feature_dim: int, num_heads: int):\n        super().__init__()\n        self.attn = MultiHeadedAttention(num_heads, feature_dim)\n        self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])\n        nn.init.constant_(self.mlp[-1].bias, 0.0)\n\n    def forward(self, x, source, mask=None):\n        message = self.attn(x, source, source, mask)\n        return self.mlp(torch.cat([x, message], dim=1))\n\n\nclass AttentionalGNN(nn.Module):\n    def __init__(self, feature_dim: int, layer_names: list):\n        super().__init__()\n        self.layers = nn.ModuleList([\n            AttentionalPropagation(feature_dim, 4)\n            for _ in range(len(layer_names))])\n        self.names = layer_names\n\n    def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None, mask10=None):\n        for layer, name in zip(self.layers, self.names):\n            layer.attn.prob = []\n            if name == 'cross':\n                src0, src1 = desc1, desc0\n                mask0, mask1 = mask01[:, None], mask10[:, None] \n            else:  # if name == 'self':\n                src0, src1 = desc0, desc1\n                mask0, mask1 = mask00[:, None], mask11[:, None]\n\n            delta0, delta1 = layer(desc0, src0, mask0), layer(desc1, src1, mask1)\n            desc0, desc1 = (desc0 + delta0), (desc1 + delta1)\n        return desc0, desc1\n\n\ndef log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):\n    \"\"\" Perform Sinkhorn Normalization in Log-space for stability\"\"\"\n    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)\n    for _ in range(iters):\n        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)\n        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)\n    return Z + u.unsqueeze(2) + v.unsqueeze(1)\n\n\ndef log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):\n    \"\"\" Perform Differentiable Optimal Transport in Log-space for stability\"\"\"\n    b, m, n = scores.shape\n    one = scores.new_tensor(1)\n    if ms is None or ns is None:\n        ms, ns = (m*one).to(scores), (n*one).to(scores)\n    # else:\n    #     ms, ns = ms.to(scores)[:, None], ns.to(scores)[:, None]\n    # here m,n should be parameters not shape\n\n    # ms, ns: (b, )\n    bins0 = alpha.expand(b, m, 1)\n    bins1 = alpha.expand(b, 1, n)\n    alpha = alpha.expand(b, 1, 1)\n\n    # pad additional scores for unmatched (to -1)\n    # alpha is the learned threshold\n    couplings = torch.cat([torch.cat([scores, bins0], -1),\n                           torch.cat([bins1, alpha], -1)], 1)\n\n    norm = - (ms + ns).log() # (b, )\n    # print(scores.min(), flush=True)\n    if ms.size()[0] > 0:\n        norm = norm[:, None]\n        log_mu = torch.cat([norm.expand(b, m), ns.log()[:, None] + norm], dim=-1) # (m + 1)\n        log_nu = 
torch.cat([norm.expand(b, n), ms.log()[:, None] + norm], dim=-1)\n        # print(log_nu.min(), log_mu.min(), flush=True)\n    else:\n        log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) # (m + 1)\n        log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])\n        log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)\n\n    \n    Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)\n\n    if ms.size()[0] > 1:\n        norm = norm[:, :, None]\n    Z = Z - norm  # multiply probabilities by M+N\n    return Z\n\n\ndef arange_like(x, dim: int):\n    return x.new_ones(x.shape[dim]).cumsum(0) - 1  # traceable in 1.1\n\n\nclass SuperGlueT(nn.Module):\n\n    def __init__(self, config=None):\n        super().__init__()\n\n        default_config = argparse.Namespace()\n        default_config.descriptor_dim = 128\n        # default_config.weights = \n        default_config.keypoint_encoder = [32, 64, 128]\n        default_config.GNN_layers = ['self', 'cross'] * 9\n        default_config.sinkhorn_iterations = 100\n        default_config.match_threshold = 0.2\n        # self.config = {**self.default_config, **config}\n\n        if config is None:\n            self.config = default_config\n        else:\n            self.config = config   \n            self.config.GNN_layers = ['self', 'cross'] * self.config.GNN_layer_num\n            # print('WULA!', self.config.GNN_layer_num)\n\n        self.kenc = KeypointEncoder(\n            self.config.descriptor_dim, self.config.keypoint_encoder)\n\n        self.tenc = TopoEncoder(\n            self.config.descriptor_dim, [96])\n\n\n        self.gnn = AttentionalGNN(\n            self.config.descriptor_dim, self.config.GNN_layers)\n\n        self.final_proj = nn.Conv1d(\n            self.config.descriptor_dim, self.config.descriptor_dim,\n            kernel_size=1, bias=True)\n\n        bin_score = torch.nn.Parameter(torch.tensor(1.))\n        self.register_parameter('bin_score', 
bin_score)\n        self.vertex_desc = VertexDescriptor(self.config.descriptor_dim)\n       \n\n\n    def forward(self, data):\n        \"\"\"Run SuperGlue on a pair of keypoints and descriptors\"\"\"\n\n        kpts0, kpts1 = data['keypoints0'].float(), data['keypoints1'].float()\n\n        ori_mask0, ori_mask1 = data['mask0'].float(), data['mask1'].float()\n        dim_m, dim_n = data['ms'].float(), data['ns'].float()\n\n        spec0, spec1 = data['adj_mat0'], data['adj_mat1']\n\n        mmax = dim_m.int().max()\n        nmax = dim_n.int().max()\n\n        mask0 = ori_mask0[:, :mmax]\n        mask1 = ori_mask1[:, :nmax]\n\n        kpts0 = kpts0[:, :mmax]\n        kpts1 = kpts1[:, :nmax]\n\n        desc0, desc1 = self.vertex_desc(data['image0'], kpts0.float()), self.vertex_desc(data['image1'], kpts1.float())\n        # spec0, spec1 = np.abs(self.spectral.fit_transform(topo0[0].cpu().numpy())), np.abs(self.spectral.fit_transform(topo1[0].cpu().numpy()))\n\n        desc0 = desc0 + self.tenc(desc0.new_tensor(spec0))\n        desc1 = desc1 + self.tenc(desc1.new_tensor(spec1))\n\n        mask00 = torch.ones_like(mask0)[:, :, None] * mask0[:, None, :]\n        \n        mask11 = torch.ones_like(mask1)[:, :, None] * mask1[:, None, :]\n        mask01 = torch.ones_like(mask0)[:, :, None] * mask1[:, None, :]\n        mask10 = torch.ones_like(mask1)[:, :, None] * mask0[:, None, :]\n\n\n        if kpts0.shape[1] < 2 or kpts1.shape[1] < 2:  # no keypoints\n            shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]\n            # print(data['file_name'])\n            return {\n                'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],\n                # 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],\n                'matching_scores0': kpts0.new_zeros(shape0)[0],\n                # 'matching_scores1': kpts1.new_zeros(shape1)[0],\n                'skip_train': True\n            }\n\n        file_name = data['file_name']\n        all_matches = 
data['all_matches'] if 'all_matches' in data else None# shape = (1, K1)\n\n        \n        # positional embedding\n        # Keypoint normalization.\n        kpts0 = normalize_keypoints(kpts0, data['image0'].shape)\n        kpts1 = normalize_keypoints(kpts1, data['image1'].shape)\n\n        # Keypoint MLP encoder.\n    \n        pos0 = self.kenc(kpts0)\n        pos1 = self.kenc(kpts1)\n\n        desc0 = desc0 + pos0\n        desc1 = desc1 + pos1\n\n       \n        # Multi-layer Transformer network.\n        desc0, desc1 = self.gnn(desc0, desc1, mask00, mask11, mask01, mask10)\n\n        # Final MLP projection.\n        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)\n\n        # Compute matching descriptor distance.\n        scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)\n\n        # b k1 k2\n        scores = scores / self.config.descriptor_dim**.5\n\n        mask01 = mask0[:, :, None] * mask1[:, None, :]\n        scores = scores.masked_fill(mask01 == 0, float('-inf'))\n\n\n        # Run the optimal transport.\n        scores = log_optimal_transport(\n            scores, self.bin_score,\n            iters=self.config.sinkhorn_iterations,\n            ms=dim_m, ns=dim_n)\n\n\n        # Get the matches with score above \"match_threshold\".\n        max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)\n        indices0, indices1 = max0.indices, max1.indices\n        mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)\n        mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)\n        zero = scores.new_tensor(0)\n        mscores0 = torch.where(mutual0, max0.values.exp(), zero)\n        mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)\n        valid0 = mutual0 & (mscores0 > self.config.match_threshold)\n        valid1 = mutual1 & valid0.gather(1, indices1)\n        \n        valid0 = mscores0 > self.config.match_threshold\n        valid1 = valid0.gather(1, indices1)\n 
       indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))\n        indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))\n\n        # check if indexed correctly\n\n        loss = []\n\n        \n\n        if all_matches is not None:\n            for b in range(len(dim_m)):\n\n                for i in range(int(dim_m[b])):\n      \n                    x = i\n                    y = all_matches[b][i].long()\n\n                    loss.append(-scores[b][x][y] ) # check batch size == 1 ?\n\n            loss_mean = torch.mean(torch.stack(loss))\n            loss_mean = torch.reshape(loss_mean, (1, -1))\n\n            return {\n                'matches0': indices0, # use -1 for invalid match\n                'matches1': indices1, # use -1 for invalid match\n                'matching_scores0': mscores0,\n                # 'matching_scores1': mscores1[0],\n                'loss': loss_mean,\n                'skip_train': False,\n                'accuracy': (((all_matches[:, :mmax] == indices0) & mask0.bool()).sum() / mask0.sum()).item(),\n                'valid_accuracy': (((all_matches[:, :mmax] == indices0) & (all_matches[:, :mmax] != -1) & mask0.bool()).float().sum() / ((all_matches[:, :mmax] != -1) & mask0.bool()).float().sum()).item(),\n            }\n        else:\n            return {\n                'matches0': indices0[0], # use -1 for invalid match\n                'matching_scores0': mscores0[0],\n                'loss': -1,\n                'skip_train': True,\n                'accuracy': -1,\n                'area_accuracy': -1,\n                'valid_accuracy': -1,\n            }\n\n\nif __name__ == '__main__':\n    from datasets import fetch_dataloader\n\n    args = argparse.Namespace()\n    args.batch_size = 1\n    args.gap = 0\n    args.type = 'train'\n    args.model = 'jolleen' \n    args.action = 'slash'\n    ss = SuperGlueT()\n\n\n    loader = fetch_dataloader(args)\n    # #print(len(loader))\n    for data in loader:\n        # p1, p2, s1, s2, mi = data\n        dict1 
= data\n\n        kp1 = dict1['keypoints0']\n        kp2 = dict1['keypoints1']\n        p1 = dict1['image0']\n        p2 = dict1['image1']  \n\n        # #print(s1)\n        # #print(s1.type)\n        mi = dict1['all_matches']\n        fname = dict1['file_name'] \n        print(kp1.shape, p1.shape, mi.shape)  \n        # #print(mi.size())  \n        # #print(mi)\n        # break\n\n        a = ss(data)\n        print(dict1['file_name'])\n        print(a['loss'])\n        a['loss'].backward()\n        # print(a['matches0'].size())\n        # print(a['accuracy'], a['valid_accuracy'])"
  },
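The matching head above resolves vertex correspondences with log-domain Sinkhorn normalization (`log_sinkhorn_iterations` inside `log_optimal_transport`). Below is a minimal standalone NumPy sketch of the same iteration with uniform marginals; the function and variable names are illustrative, not the repo's API:

```python
import numpy as np

def logsumexp(a, axis):
    # Numerically stable log-sum-exp, the role torch.logsumexp plays above.
    m = a.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis) + np.log(np.exp(a - m).sum(axis=axis))

def log_sinkhorn(Z, log_mu, log_nu, iters=100):
    # Alternately fit row and column log-marginals (log-domain Sinkhorn).
    u, v = np.zeros_like(log_mu), np.zeros_like(log_nu)
    for _ in range(iters):
        u = log_mu - logsumexp(Z + v[None, :], axis=1)
        v = log_nu - logsumexp(Z + u[:, None], axis=0)
    return Z + u[:, None] + v[None, :]

rng = np.random.default_rng(0)
Z = rng.normal(size=(3, 5))            # raw matching scores
log_mu = np.full(3, -np.log(3.0))      # uniform row marginals (sum to 1)
log_nu = np.full(5, -np.log(5.0))      # uniform column marginals (sum to 1)
P = np.exp(log_sinkhorn(Z, log_mu, log_nu, iters=200))
```

After enough iterations `P` is a transport plan whose row sums approach 1/3 and column sums approach 1/5, which is why working in log space keeps the exponentials from under- or overflowing.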
  {
    "path": "corr/srun.sh",
    "content": "#!/bin/sh\ncurrenttime=`date \"+%Y%m%d%H%M%S\"`\nif [ ! -d log ]; then\n    mkdir log\nfi\n\necho \"[Usage] ./srun.sh config_path [train|eval] partition gpunum\"\n# check config exists\nif [ ! -e $1 ]\nthen\n    echo \"[ERROR] configuration file: $1 does not exists!\"\n    exit\nfi\n\n\nif [ ! -d ${expname} ]; then\n    mkdir ${expname}\nfi\n\necho \"[INFO] saving results to, or loading files from: \"$expname\n\nif [ \"$3\" == \"\" ]; then\n    echo \"[ERROR] enter partition name\"\n    exit\nfi\npartition_name=$3\necho \"[INFO] partition name: $partition_name\"\n\nif [ \"$4\" == \"\" ]; then\n    echo \"[ERROR] enter gpu num\"\n    exit\nfi\ngpunum=$4\ngpunum=$(($gpunum<8?$gpunum:8))\necho \"[INFO] GPU num: $gpunum\"\n((ntask=$gpunum*3))\n\n\nTOOLS=\"srun --partition=$partition_name --cpus-per-task=8 --gres=gpu:$gpunum   -N 1 --job-name=${config_suffix}\"\nPYTHONCMD=\"python -u main.py --config $1\"\n\nif [ $2 == \"train\" ];\nthen\n    $TOOLS $PYTHONCMD \\\n    --train \nelif [ $2 == \"eval\" ];\nthen\n    $TOOLS $PYTHONCMD \\\n    --eval \nfi\n"
  },
  {
    "path": "corr/utils/log.py",
    "content": "# This source code is licensed under the MIT license found in the\n# LICENSE file in the root directory of this open-source project.\n\n\n\"\"\" Define the Logger class to print log\"\"\"\nimport os\nimport sys\nimport logging\nfrom datetime import datetime\n\n\nclass Logger:\n    def __init__(self, args, output_dir):\n\n        log = logging.getLogger(output_dir)\n        if not log.handlers:\n            log.setLevel(logging.DEBUG)\n            # if not os.path.exists(output_dir):\n            #     os.mkdir(args.data.output_dir)\n            fh = logging.FileHandler(os.path.join(output_dir,'log.txt'))\n            fh.setLevel(logging.INFO)\n            ch = ProgressHandler()\n            ch.setLevel(logging.DEBUG)\n            formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')\n            fh.setFormatter(formatter)\n            ch.setFormatter(formatter)\n            log.addHandler(fh)\n            log.addHandler(ch)\n        self.log = log\n        # setup TensorBoard\n        # if args.tensorboard:\n        #     from tensorboardX import SummaryWriter\n        #     self.writer = SummaryWriter(log_dir=args.output_dir)\n        # else:\n        self.writer = None\n        self.log_per_updates = args.log_per_updates\n\n    def set_progress(self, epoch, total):\n        self.log.info(f'Epoch: {epoch}')\n        self.epoch = epoch\n        self.i = 0\n        self.total = total\n        self.start = datetime.now()\n\n    def update(self, stats):\n        self.i += 1\n        if self.i % self.log_per_updates == 0:\n            remaining = str((datetime.now() - self.start) / self.i * (self.total - self.i))\n            remaining = remaining.split('.')[0]\n            updates = stats.pop('updates')\n            stats_str = ' '.join(f'{key}[{val:.8f}]' for key, val in stats.items())\n            \n            self.log.info(f'> epoch [{self.epoch}] updates[{updates}] {stats_str} eta[{remaining}]')\n            \n  
          if self.writer:\n                for key, val in stats.items():\n                    self.writer.add_scalar(f'train/{key}', val, updates)\n        if self.i == self.total:\n            self.log.debug('\\n')\n            self.log.debug(f'elapsed time: {str(datetime.now() - self.start).split(\".\")[0]}')\n\n    def log_eval(self, stats, metrics_group=None):\n        stats_str = ' '.join(f'{key}: {val:.8f}' for key, val in stats.items())\n        self.log.info(f'valid {stats_str}')\n        if self.writer:\n            for key, val in stats.items():\n                self.writer.add_scalar(f'valid/{key}', val, self.epoch)\n        # for mode, metrics in metrics_group.items():\n        #     self.log.info(f'evaluation scores ({mode}):')\n        #     for key, (val, _) in metrics.items():\n        #         self.log.info(f'\\t{key} {val:.4f}')\n        # if self.writer and metrics_group is not None:\n        #     for key, val in stats.items():\n        #         self.writer.add_scalar(f'valid/{key}', val, self.epoch)\n        #     for key in list(metrics_group.values())[0]:\n        #         group = {}\n        #         for mode, metrics in metrics_group.items():\n        #             group[mode] = metrics[key][0]\n        #         self.writer.add_scalars(f'valid/{key}', group, self.epoch)\n\n    def __call__(self, msg):\n        self.log.info(msg)\n\n\nclass ProgressHandler(logging.Handler):\n    def __init__(self, level=logging.NOTSET):\n        super().__init__(level)\n\n    def emit(self, record):\n        log_entry = self.format(record)\n        if record.message.startswith('> '):\n            sys.stdout.write('{}\\r'.format(log_entry.rstrip()))\n            sys.stdout.flush()\n        else:\n            sys.stdout.write('{}\\n'.format(log_entry))\n\n"
  },
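`Logger` in `corr/utils/log.py` pairs a `FileHandler` (INFO and above, persisted to `log.txt`) with the carriage-return `ProgressHandler`, so `> `-prefixed progress records overwrite the console line in place while DEBUG output never reaches the file. A self-contained sketch of that pattern follows; `CarriageReturnHandler` and `outdir` are illustrative names, not the repo module:

```python
import logging
import os
import sys
import tempfile

class CarriageReturnHandler(logging.Handler):
    # Like ProgressHandler: '> '-prefixed records are rewritten in place with '\r'.
    def emit(self, record):
        entry = self.format(record)
        if record.getMessage().startswith('> '):
            sys.stdout.write(entry.rstrip() + '\r')
        else:
            sys.stdout.write(entry + '\n')
        sys.stdout.flush()

outdir = tempfile.mkdtemp()
log = logging.getLogger(outdir)        # keyed by the output dir, as in Logger.__init__
log.setLevel(logging.DEBUG)
fh = logging.FileHandler(os.path.join(outdir, 'log.txt'))
fh.setLevel(logging.INFO)              # the file keeps INFO+; DEBUG is console-only
log.addHandler(fh)
log.addHandler(CarriageReturnHandler())

log.info('> epoch [1] updates[10] loss[0.50000000]')
log.debug('console-only progress detail')
```

Keying the logger on the output directory (rather than a fixed name) is what lets several experiments log to separate files within one process without duplicating handlers.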
  {
    "path": "corr/utils/visualize_vtx_corr.py",
    "content": "import numpy as np\nimport torch\nimport cv2\n\n\ndef make_inter_graph(v2d1, v2d2, topo1, topo2, match12):\n    valid = (match12 != -1)\n    marked2 = np.zeros(len(v2d2)).astype(bool)\n    # print(match12[valid])\n    marked2[match12[valid]] = True\n\n    id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))\n    id1toh[valid] = np.arange(np.sum(valid))\n    id2toh[match12[valid]] = np.arange(np.sum(valid))\n    id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)\n    # print(marked2)\n    id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))\n\n    id1toh = id1toh.astype(int)\n    id2toh = id2toh.astype(int)\n\n    tot_len = len(v2d1) + np.sum(np.invert(marked2))\n\n    vin1 = v2d1[valid][:]\n    vin2 = v2d2[match12[valid]][:]\n    vh = 0.5 * (vin1 + vin2)\n    vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)\n\n    topoh = [[] for ii in range(tot_len)]\n\n\n    for node in range(len(topo1)):\n        \n        for nb in topo1[node]:\n            if int(id1toh[nb]) not in topoh[id1toh[node]]:\n                topoh[id1toh[node]].append(int(id1toh[nb]))\n\n\n    for node in range(len(topo2)):\n        for nb in topo2[node]:\n            if int(id2toh[nb]) not in topoh[id2toh[node]]:\n                topoh[id2toh[node]].append(int(id2toh[nb]))\n\n    return vh, topoh\n\n\ndef make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12):\n    valid = (match12 != -1)\n    marked2 = np.zeros(len(v2d2)).astype(bool)\n    # print(match12[valid])\n    marked2[match12[valid]] = True\n\n    id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))\n    id1toh[valid] = np.arange(np.sum(valid))\n    id2toh[match12[valid]] = np.arange(np.sum(valid))\n    id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)\n    # print(marked2)\n    id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))\n\n    id1toh = id1toh.astype(int)\n    id2toh = 
id2toh.astype(int)\n\n    tot_len = len(v2d1) + np.sum(np.invert(marked2))\n\n    vin1 = v2d1[valid][:]\n    vin2 = v2d2[match12[valid]][:]\n    vh = 0.5 * (vin1 + vin2)\n    # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)\n\n    # topoh = [[] for ii in range(tot_len)]\n    topoh = [[] for ii in range(np.sum(valid))]\n\n    for node in range(len(topo1)):\n        if not valid[node]:\n            continue\n        for nb in topo1[node]:\n            if int(id1toh[nb]) not in topoh[id1toh[node]]:\n                if valid[nb]:\n                    topoh[id1toh[node]].append(int(id1toh[nb]))\n\n\n    for node in range(len(topo2)):\n        if not marked2[node]:\n            continue\n        for nb in topo2[node]:\n            if int(id2toh[nb]) not in topoh[id2toh[node]]:\n                if marked2[nb]:\n                    topoh[id2toh[node]].append(int(id2toh[nb]))\n\n    return vh, topoh\n\n\n\ndef visualize(dict):\n    # print(dict['keypoints0'].size(), flush=True)\n    img1 = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n    img2 = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n    img1p = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n    img2p = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n\n    img1[:, :, 0] += 255\n    img1[:, :, 1] += 180\n    img1[:, :, 2] += 180\n    img1[img1 > 255] = 255\n\n    img2[:, :, 0] += 255\n    img2[:, :, 1] += 180\n    img2[:, :, 2] += 180\n    img2[img2 > 255] = 255\n    \n    img1p[:, :, 0] += 255\n    img1p[:, :, 1] += 180\n    img1p[:, :, 2] += 180\n    img1p[img1p > 255] = 255\n    \n    img2p[:, :, 0] += 255\n    img2p[:, :, 1] += 180\n    img2p[:, :, 2] += 180\n    img2p[img2p > 255] = 255\n\n    img1, img2, img1p, img2p = img1.astype(np.uint8), img2.astype(np.uint8), img1p.astype(np.uint8), img2p.astype(np.uint8)\n   
 \n\n    # print(v2d1.shape, img1.shape, flush=True)\n    v2d1 = dict['keypoints0'].numpy().astype(int)\n    v2d2 = dict['keypoints1'].numpy().astype(int)\n    topo1 = dict['topo0']\n    topo2 = dict['topo1']\n    # print(topo1, flush=True)\n    # for node, nbs in enumerate(dict['topo0']):\n    #     for nb in nbs:\n    #         cv2.line(img1, [v2d1[node][0], v2d1[node][1]], [v2d1[nb][0], v2d1[nb][1]], [255, 180, 180], 2)\n\n\n    id1 = np.arange(len(v2d1))\n    id2 = np.arange(len(v2d2))\n    all_matches = dict['all_matches'].cpu().int().data.numpy()\n    predicted = dict['matches0'].cpu().data.numpy()[0]\n    predicted1 = dict['matches1'].cpu().data.numpy()[0]\n    \n    colors1_gt, colors2_gt = {}, {}\n    colors1_pred, colors2_pred = {}, {}\n    cross1_pred, cross2_pred = {}, {}\n\n    for index in id1:\n        color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]\n            # print(predicted.shape, flush=True)\n        if all_matches[index] != -1:\n            colors2_gt[all_matches[index]] = color\n        if predicted[index] != -1:\n            colors2_pred[predicted[index]] = color\n\n        colors1_gt[index] = color if all_matches[index] != -1 else [0, 0, 0]\n        colors1_pred[index] = color if predicted[index] != -1 else [0, 0, 0]\n\n        # if predicted[index] == -1 and colors1_pred[index] != [0, 0, 0]:\n        #     color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]\n        #     colors1_pred[index] = [0, 0, 0]\n        #     colors2_pred.pop(all_matches[index])\n        # whether predicted correctly\n        if predicted[index] != all_matches[index]:\n            cross1_pred[index] = True\n            if predicted[index] != -1:\n                cross2_pred[predicted[index]] = True\n        \n    for i, p in enumerate(v2d1):\n        ii = id1[i]\n        # print(ii)\n        cv2.circle(img1, [int(p[0]), int(p[1])], 1, colors1_gt[i], 2)\n        if ii in cross1_pred 
and cross1_pred[ii]:\n            cv2.rectangle(img1p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors1_pred[i],-1)\n        else:\n            cv2.circle(img1p, [int(p[0]), int(p[1])], 1, colors1_pred[i], 2)\n        \n    for ii in id2:\n        # print(ii)\n        color = [0, 0, 0]\n        this_is_umatched = 1\n        if ii not in colors2_gt:\n            colors2_gt[ii] = color  \n        if ii not in colors2_pred:\n            colors2_pred[ii] = color\n\n    for i, p in enumerate(v2d2):\n        ii = id2[i]\n        # print(p)\n        cv2.circle(img2, [int(p[0]), int( p[1])], 1, colors2_gt[ii], 2)\n        if ii in cross2_pred and cross2_pred[ii]:\n            cv2.rectangle(img2p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors2_pred[i], -1)\n        else:\n            cv2.circle(img2p, [int(p[0]), int(p[1])], 1, colors2_pred[i], 2)\n\n    # print('Unmatched in Img 2: ', , '%')\n    # unmatched_all.append(100 - unmatched * 100.0/len(v2d2))\n    cv2.putText(img2p, str(round(np.sum(all_matches == predicted) * 100.0 / len(predicted), 2)).format('.2f') + '%', \\\n        (500, 100), cv2.FONT_HERSHEY_PLAIN, 3, (0, 0, 255), 2)\n\n\n\n    vh_gt, topoh_gt = make_inter_graph(v2d1, v2d2, topo1, topo2, all_matches)\n    vh_pred, topoh_pred = make_inter_graph(v2d1, v2d2, topo1, topo2, predicted)\n    vh_gt_valid, topoh_gt_valid = make_inter_graph_valid(v2d1, v2d2, topo1, topo2, all_matches)\n    vh_pred_valid, topoh_pred_valid = make_inter_graph_valid(v2d1, v2d2, topo1, topo2, predicted)\n    v2d1t = ((v2d2[predicted] + v2d1) * 0.5).astype(int)\n    v2d2t = ((v2d1[predicted1] + v2d2) * 0.5).astype(int)\n\n    vh_gt = vh_gt.astype(int)\n    vh_gt_valid = vh_gt_valid.astype(int)\n    vh_pred = vh_pred.astype(int)\n    vh_pred_valid = vh_pred_valid.astype(int)\n\n    imgh = np.zeros_like(img1) + 255\n    imghp = np.zeros_like(img1) + 255\n    imgh_valid = np.zeros_like(img1) + 255\n    imghp_valid = np.zeros_like(img1) + 255\n\n    for node, 
nbs in enumerate(topoh_gt):\n        for nb in nbs:\n            cv2.line(imgh, [vh_gt[node][0], vh_gt[node][1]], [vh_gt[nb][0], vh_gt[nb][1]], [0, 0, 0], 2)\n    \n    for node, nbs in enumerate(topoh_pred):\n        for nb in nbs:\n            cv2.line(imghp, [vh_pred[node][0], vh_pred[node][1]], [vh_pred[nb][0], vh_pred[nb][1]], [0, 0, 0], 2)\n    \n    for node, nbs in enumerate(topoh_gt_valid):\n        for nb in nbs:\n            cv2.line(imgh_valid, [vh_gt_valid[node][0], vh_gt_valid[node][1]], [vh_gt_valid[nb][0], vh_gt_valid[nb][1]], [0, 0, 0], 2)\n    \n    for node, nbs in enumerate(topoh_pred_valid):\n        for nb in nbs:\n            cv2.line(imghp_valid, [vh_pred_valid[node][0], vh_pred_valid[node][1]], [vh_pred_valid[nb][0], vh_pred_valid[nb][1]], [0, 0, 0], 2)\n    \n    # for node, nbs in enumerate(topo1):\n    #     for nb in nbs:\n    #         cv2.line(imghp_valid, [v2d1t[node][0], v2d1t[node][1]], [v2d1t[nb][0], v2d1t[nb][1]], [0, 0, 0], 2)\n    \n    # for node, nbs in enumerate(topo2):\n    #     for nb in nbs:\n    #         cv2.line(imghp_valid, [v2d2t[node][0], v2d2t[node][1]], [v2d2t[nb][0], v2d2t[nb][1]], [0, 0, 0], 2)\n    \n\n\n    im_h = cv2.hconcat([img1, img2])\n    im_hp = cv2.hconcat([img1p, img2p])\n    img_inter = cv2.hconcat([imgh, imghp])\n    img_inter_valid = cv2.hconcat([imgh_valid, imghp_valid])\n    im_hv = cv2.vconcat([im_h, im_hp, img_inter, img_inter_valid])\n\n    return im_hv\n"
  },
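`make_inter_graph_valid` above constructs the inbetween frame by averaging each matched vertex pair and reindexing the survivors, while vertices with `match12 == -1` are dropped. A small self-contained sketch of that midpoint step (simplified, with hypothetical names; edge remapping is omitted):

```python
import numpy as np

def midpoint_graph(v1, v2, match12):
    # Matched vertex pairs are averaged, as in make_inter_graph_valid.
    valid = match12 != -1
    vh = 0.5 * (v1[valid] + v2[match12[valid]])
    # Map old frame-1 indices to indices in the inbetween graph (-1 = dropped).
    remap = -np.ones(len(v1), dtype=int)
    remap[valid] = np.arange(valid.sum())
    return vh, remap

v1 = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
v2 = np.array([[2.0, 0.0], [12.0, 2.0]])
match12 = np.array([0, 1, -1])   # vertex 2 of frame 1 has no correspondence
vh, remap = midpoint_graph(v1, v2, match12)
```

The `remap` array plays the role of `id1toh` in the original code: edges of frame 1 can be carried over to the inbetween graph by looking up both endpoints and skipping any edge touching a `-1`.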
  {
    "path": "corr/vtx_matching.py",
    "content": "\"\"\" This script handling the training process. \"\"\"\nimport os\nimport time\nimport random\nimport argparse\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.utils.data\nfrom datasets import fetch_dataloader\nimport random\nfrom utils.log import Logger\n\nfrom torch.optim import *\nimport warnings\nfrom tqdm import tqdm\nimport itertools\nimport pdb\nimport numpy as np\nimport models\nimport datetime\nimport sys\nimport json\nimport cv2\n\nfrom utils.visualize_vtx_corr import visualize\nimport matplotlib.cm as cm\n# from models.utils import make_matching_seg_plot\n\nwarnings.filterwarnings('ignore')\n\n\nimport matplotlib.pyplot as plt\nimport pdb\n\nclass VtxMat():\n    def __init__(self, args):\n        self.config = args\n        torch.backends.cudnn.benchmark = True\n        torch.multiprocessing.set_sharing_strategy('file_system')\n        self._build()\n\n    def train(self):\n        \n        opt = self.config\n        print(opt)\n\n        model = self.model\n\n        if hasattr(self.config, 'init_weight'):\n            checkpoint = torch.load(self.config.init_weight)\n            model.load_state_dict(checkpoint['model'])\n\n        optimizer = self.optimizer\n        schedular = self.schedular\n        mean_loss = []\n        log = Logger(self.config, self.expdir)\n        updates = 0\n        \n        # set seed\n        random.seed(opt.seed)\n        torch.manual_seed(opt.seed)\n        torch.cuda.manual_seed(opt.seed)\n        np.random.seed(opt.seed)\n\n        # start training\n        for epoch in range(1, opt.epoch+1):\n            np.random.seed(opt.seed + epoch)\n            train_loader = self.train_loader\n            log.set_progress(epoch, len(train_loader))\n            batch_loss = 0\n            batch_acc = 0 \n            batch_valid_acc = 0\n            batch_iter = 0\n            model.train()\n            avg_time = 0\n            avg_num = 0\n            # 
torch.cuda.synchronize()\n            \n            for i, pred in enumerate(train_loader):\n                # tstart = time.time()\n                # print(pred['file_name'])\n                data = model(pred)\n\n\n                if not data['skip_train']:\n                    loss = data['loss'] / opt.batch_size\n                    batch_loss += loss.item()\n                    batch_acc += data['accuracy'] \n                    batch_valid_acc += data['valid_accuracy'] \n                    loss.backward()\n                    batch_iter += 1\n                else:\n                    print('Skip!')\n\n                ## Accumulate gradient for batch training\n                if ((i + 1) % opt.batch_size == 0) or (i + 1 == len(train_loader)):\n                    optimizer.step()\n                    optimizer.zero_grad()\n                    batch_iter = 1 if batch_iter == 0 else batch_iter               \n                    stats = {\n                        'updates': updates,\n                        'loss': batch_loss,\n                        'accuracy': batch_acc / batch_iter,\n                        'valid_accuracy': batch_valid_acc / batch_iter\n                    }\n                    log.update(stats)\n                    updates += 1\n                    batch_loss = 0\n                    batch_acc = 0 \n                    batch_valid_acc = 0\n                    batch_iter = 0\n\n            # torch.cuda.synchronize()\n\n            # avg_num += 1\n                # for name, params in model.named_parameters():\n                #     print('-->name:, ', name, '-->grad mean', params.grad.mean())\n            # print(\"All time is \", avg_time, \"AVG time is \", avg_time * 1.0 /avg_num,  \"number is \", avg_num, flush=True)\n\n            # save checkpoint \n            if epoch % opt.save_per_epochs == 0 or epoch == 1:\n                checkpoint = {\n                    'model': model.state_dict(),\n                    'config': opt,\n     
               'epoch': epoch\n                }\n\n                filename = os.path.join(self.ckptdir, f'epoch_{epoch}.pt')\n                torch.save(checkpoint, filename)\n                \n            # validate\n            if epoch % opt.test_freq == 0:\n\n                if not os.path.exists(os.path.join(self.visdir, 'epoch' + str(epoch))):\n                    os.mkdir(os.path.join(self.visdir, 'epoch' + str(epoch)))\n                eval_output_dir = os.path.join(self.visdir, 'epoch' + str(epoch))    \n                \n                test_loader = self.test_loader\n\n                with torch.no_grad():\n                    # Visualize the matches.\n                    mean_acc = []\n                    mean_valid_acc = []\n                    model.eval()\n                    for i_eval, data in enumerate(tqdm(test_loader, desc='Predicting Vtx Corr...')):\n                        pred = model(data)\n                        # for k, v in data.items():\n                        #     pred[k] = v[0]\n                        #     pred = {**pred, **data}\n\n                        mean_acc.append(pred['accuracy'])\n                        mean_valid_acc.append(pred['valid_accuracy'])\n                    log.log_eval({\n                        'updates': opt.epoch,\n                        'Accuracy': np.mean(mean_acc),\n                        'Valid Accuracy': np.mean(mean_valid_acc),\n                        })\n                    print('Epoch [{}/{}], Acc.: {:.4f}, Valid Acc.: {:.4f}' \n                        .format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_valid_acc)) )\n                    sys.stdout.flush()\n                        # make_matching_plot(\n                        #     image0, image1, kpts0, kpts1, mkpts0, mkpts1, color,\n                        #     text, viz_path, stem, stem, True,\n                        #     True, False, 'Matches')\n        \n            self.schedular.step()\n\n            \n\n\n    def 
eval(self):\n        train_action = ['breakdance_1990', 'capoeira', 'chapa-giratoria', 'fist_fight', 'flying_knee', 'freehang_climb', 'running', 'shove', 'magic', 'tripping']\n        test_action = ['great_sword_slash', 'hip_hop_dancing']\n\n        train_model = ['ganfaul', 'girlscout', 'jolleen', 'kachujin', 'knight', 'maria_w_jj', 'michelle', 'peasant_girl', 'timmy', 'uriel_a_plotexia']\n        test_model = ['police', 'warrok']\n\n        log = Logger(self.config, self.expdir)\n        with torch.no_grad():\n            model = self.model.eval()\n            config = self.config\n            epoch_tested = self.config.testing.ckpt_epoch\n            ckpt_path = os.path.join(self.ckptdir, f\"epoch_{epoch_tested}.pt\")\n            # self.device = torch.device('cuda' if config.cuda else 'cpu')\n            print(\"Evaluation...\")\n            checkpoint = torch.load(ckpt_path)\n            model.load_state_dict(checkpoint['model'])\n\n            model.eval()\n\n            if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested))):\n                os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested)))\n            if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons')):\n                os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons'))\n            eval_output_dir = os.path.join(self.evaldir, 'epoch' + str(epoch_tested))    \n                \n            test_loader = self.test_loader\n            print(len(test_loader))\n            mean_acc = []\n            mean_valid_acc = []\n            mean_invalid_acc = []\n\n            # 144 data \n            # 10x10 is for training , 2x10 (unseen model) + 10x2 (unseen action) + 2x2 (unseen model unseen action) is for test\n            # record the accuracy for each\n            mean_model_acc = []\n            mean_model_valid_acc = []\n            mean_action_acc = []\n            mean_action_valid_acc = []\n            \n     
       mean_none_acc = []\n            mean_none_valid_acc = []\n\n            mean_matched = []\n\n            for i_eval, pred in enumerate(tqdm(test_loader, desc='Predicting Vtx Corr...')):\n                data = model(pred)\n                # unbatch the inputs, then merge in the model outputs\n                for k, v in pred.items():\n                    pred[k] = v[0]\n                pred = {**pred, **data}\n\n                mean_acc.append(pred['accuracy'])\n                mean_valid_acc.append(pred['valid_accuracy'])\n                this_pred = (pred['matches0'] != -1).float().cpu().data.numpy().astype(np.float32)\n                mean_matched.append(np.mean(this_pred))\n\n                unmarked = True\n                for model_name in train_model:\n                    if model_name in pred['file_name']:\n                        mean_model_acc.append(pred['accuracy'])\n                        mean_model_valid_acc.append(pred['valid_accuracy'])\n                        unmarked = False\n                        break\n\n                for action_name in train_action:\n                    if action_name in pred['file_name']:\n                        mean_action_acc.append(pred['accuracy'])\n                        mean_action_valid_acc.append(pred['valid_accuracy'])\n                        unmarked = False\n                        break\n\n                if unmarked:\n                    mean_none_acc.append(pred['accuracy'])\n                    mean_none_valid_acc.append(pred['valid_accuracy'])\n\n                if 'invalid_accuracy' in pred and pred['invalid_accuracy'] is not None:\n                    mean_invalid_acc.append(pred['invalid_accuracy'])\n\n                img_vis = visualize(pred)\n                cv2.imwrite(os.path.join(eval_output_dir, pred['file_name'].replace('/', '_') + '.jpg'), img_vis)\n\n            log.log_eval({\n                'updates': self.config.testing.ckpt_epoch,\n                'Accuracy': np.mean(mean_acc),\n                'Accuracy (Matched)': np.mean(mean_valid_acc),\n                'Unseen Action Accuracy': np.mean(mean_model_acc),\n                'Unseen Action Accuracy (Matched)': np.mean(mean_model_valid_acc),\n                'Unseen Model Accuracy': np.mean(mean_action_acc),\n                'Unseen Model Accuracy (Matched)': np.mean(mean_action_valid_acc),\n                'Unseen Both Accuracy': np.mean(mean_none_acc),\n                'Unseen Both Accuracy (Matched)': np.mean(mean_none_valid_acc),\n                'Matching Rate': np.mean(mean_matched)\n                })\n            sys.stdout.flush()\n\n    def _build(self):\n        config = self.config\n        self.start_epoch = 0\n        self._dir_setting()\n        self._build_model()\n        if not (hasattr(config, 'need_not_train_data') and config.need_not_train_data):\n            self._build_train_loader()\n        if not (hasattr(config, 'need_not_test_data') and config.need_not_test_data):\n            self._build_test_loader()\n        self._build_optimizer()\n\n    def _build_model(self):\n        \"\"\" Define Model \"\"\"\n        config = self.config\n        if hasattr(config.model, 'name'):\n            print(f'Experiment Using {config.model.name}')\n            model_class = getattr(models, config.model.name)\n            model = model_class(config.model)\n        else:\n            raise NotImplementedError(\"Wrong Model Selection\")\n\n        model = nn.DataParallel(model)\n        self.model = model.cuda()\n\n    def _build_train_loader(self):\n        config = self.config\n        self.train_loader = fetch_dataloader(config.data.train, type='train')\n\n    def _build_test_loader(self):\n        config = self.config\n        self.test_loader = fetch_dataloader(config.data.test, type='test')\n\n    def 
_build_optimizer(self):\n        config = self.config.optimizer\n        try:\n            optim = getattr(torch.optim, config.type)\n        except AttributeError:\n            raise NotImplementedError('not implemented optim method ' + config.type)\n\n        self.optimizer = optim(self.model.module.parameters(), **config.kwargs)\n        self.schedular = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, **config.schedular_kwargs)\n\n    def _dir_setting(self):\n        self.expname = self.config.expname\n        self.experiment_dir = os.path.join(\".\", \"experiments\")\n        self.expdir = os.path.join(self.experiment_dir, self.expname)\n        os.makedirs(self.expdir, exist_ok=True)\n\n        self.visdir = os.path.join(self.expdir, \"vis\")  # -- imgs, videos, jsons\n        os.makedirs(self.visdir, exist_ok=True)\n\n        self.ckptdir = os.path.join(self.expdir, \"ckpt\")\n        os.makedirs(self.ckptdir, exist_ok=True)\n\n        self.evaldir = os.path.join(self.expdir, \"eval\")\n        os.makedirs(self.evaldir, exist_ok=True)\n"
  },
  {
    "path": "data/README.md",
    "content": ""
  },
  {
    "path": "datasets/__init__.py",
    "content": "\nfrom .ml_seq import fetch_dataloader\nfrom .vd_seq import fetch_videoloader\n\n__all__ = ['fetch_dataloader', 'fetch_videoloader']"
  },
  {
    "path": "datasets/ml_seq.py",
    "content": "import numpy as np\nimport torch\nimport torch.utils.data as data\nimport torch.nn.functional as F\n\nimport os\nimport math\nimport random\nfrom glob import glob\nimport os.path as osp\n\nimport sys\nimport argparse\nimport cv2\nfrom collections import Counter\nimport time\nimport json\nimport sknetwork\nfrom sknetwork.embedding import Spectral\n\nimport scipy\n\ndef read_json(file_path):\n    \"\"\"\n        input: json file path\n        output: 2d vertex locations, connections and vertex indices in the original 3D domain\n    \"\"\"\n\n    with open(file_path) as file:\n        data = json.load(file)\n        vertex2d = np.array(data['vertex location'])\n\n        topology = data['connection']\n        index = np.array(data['original index'])\n\n    return vertex2d, topology, index\n\n\ndef matched_motion(v2d1, v2d2, match12, motion_pre=None):\n    motion = np.zeros_like(v2d1)\n\n    motion[match12 != -1] = v2d2[match12[match12 != -1]] - v2d1[match12 != -1]\n    if motion_pre is not None:\n        motion[match12 != -1] = motion[match12 != -1] + motion_pre[match12[match12 != -1]]\n    return motion\n\ndef unmatched_motion(topo1, v2d1, motion12, match12):\n    pos = np.arange(len(topo1))\n    masked = (match12 == -1)\n\n    former_len = 0\n    while len(pos[masked]) > 0:\n        this_len = len(pos[masked])\n        if former_len == this_len:\n            break\n        former_len = this_len\n        for v in pos[masked]:\n            unmatched = masked[topo1[v]]\n\n            if unmatched.sum() != len(topo1[v]):\n                motion12[v] = np.average(motion12[topo1[v]][np.invert(unmatched)], axis=0)\n                masked[v] = False\n\n    if len(pos[masked]) > 0:\n        # find the nearest labeled point for each unlabeled point\n        index = ((v2d1[pos[masked]][:, None, :] - v2d1[pos[np.invert(masked)]]) ** 2).sum(2).argmin(1)\n        motion12[pos[masked]] = motion12[pos[np.invert(masked)]][index]\n        masked[pos[masked]] = False\n\n    return motion12\n\n\ndef ids_to_mat(id1, id2):\n    \"\"\"\n        inputs are two lists of vertex indices in the original 3D mesh\n    \"\"\"\n    corr1 = np.zeros(len(id1)) - 1.0\n    corr2 = np.zeros(len(id2)) - 1.0\n\n    id1 = np.array(id1).astype(int)[:, None]\n    id2 = np.array(id2).astype(int)\n\n    mat = (id1 == id2)\n\n    pos12 = np.arange(len(id2))[None].repeat(len(id1), 0)\n    pos21 = np.arange(len(id1))[None].repeat(len(id2), 0)\n    corr1[mat.astype(int).sum(1).astype(bool)] = pos12[mat]\n    corr2[mat.transpose().astype(int).sum(1).astype(bool)] = pos21[mat.transpose()]\n\n    return mat, corr1, corr2\n\ndef adj_matrix(topology):\n    \"\"\"\n        topology is the adj table; returns adj matrix\n    \"\"\"\n    gsize = len(topology)\n    adj = np.zeros((gsize, gsize)).astype(float)\n    for v in range(gsize):\n        adj[v][v] = 1.0\n        for nb in topology[v]:\n            adj[v][nb] = 1.0\n            adj[nb][v] = 1.0\n    return adj\n\nclass MixamoLineArtMotionSequence(data.Dataset):\n    def __init__(self, root, gap=0, split='train', model=None, action=None, mode='train', use_vs=False, max_len=3050):\n        \"\"\"\n            input:\n                root: the root folder of the line art data\n                gap: how many frames between the two key frames; gap should be an odd number\n                split: train or test\n                model: list of character names to exclude (default None)\n                action: list of action names to exclude (default None)\n\n            output:\n                image of sources (0, 1) and output (0.5)\n                topo0, topo1\n                v2d0, v2d1\n\n                corr12, corr21\n\n                motion0-->0.5, motion1-->0.5\n                visibility0-->0.5, visibility1-->0.5\n\n        \"\"\"\n        super(MixamoLineArtMotionSequence, self).__init__()\n\n        self.gap = gap\n        if model == 'None':\n            model = None\n        if action == 'None':\n            action = None\n\n        assert gap % 2 != 0\n\n        self.is_train = (mode == 'train')\n        self.is_eval = (mode == 'eval')\n        self.max_len = max_len\n\n        self.image_list = []\n        self.label_list = []\n\n        label_root = osp.join(root, split, 'labels')\n        self.use_vs = False\n        if use_vs:\n            print('>>>>>>>> Using VS labels')\n            self.use_vs = True\n            label_root = osp.join(root, split, 'labels_vs')\n        image_root = osp.join(root, split, 'frames')\n        self.spectral = Spectral(64, normalized=False)\n\n        for clip in os.listdir(image_root):\n            skip = False\n            if model is not None:\n                for mm in model:\n                    if mm in clip:\n                        skip = True\n\n            if action is not None:\n                for aa in action:\n                    if aa in clip:\n                        skip = True\n            if skip:\n                continue\n            image_list = sorted(glob(osp.join(image_root, clip, '*.png')))\n            label_list = sorted(glob(osp.join(label_root, clip, '*.json')))\n            if len(image_list) != len(label_list):\n                print(clip, 
flush=True)\n                continue\n            for i in range(len(image_list) - (gap+1)):\n                self.image_list += [ [image_list[jj] for jj in range(i, i + gap + 2)] ]\n            for i in range(len(label_list) - (gap+1)):\n                self.label_list += [ [label_list[jj] for jj in range(i, i + gap + 2)] ]\n        # print(clip)\n        print('Len of Frame is ', len(self.image_list), flush=True)\n        print('Len of Label is ', len(self.label_list), flush=True)\n\n    def __getitem__(self, index):\n\n\n        # load image/label files\n        # load labels: \n        #   (a) read json (b) load image (c) make pseudo labels\n\n        # image crop to a square (720x720) before input, 2d label same operation\n        # index to index matching\n\n        # test does not need index matching\n        \n        index = index % len(self.image_list)\n        file_name = self.label_list[index][len(self.label_list[index])//2][:-4]\n\n        imgt = [cv2.imread(self.image_list[index][ii]) for ii in range(0, len(self.image_list[index]))]\n\n        labelt = []\n        for ii in range(0, len(self.label_list[index])):\n            v, t, id = read_json(self.label_list[index][ii])\n            v[v > imgt[0].shape[0] - 1] = imgt[0].shape[0] - 1\n            v[v < 0] = 0\n            labelt.append({'keypoints': v.astype(int), 'topo': t, 'id': id})\n\n        # make motion pseudo label\n        motion = None\n        motion01 = None\n\n        start_frame = 0\n        gap = self.gap // 2 + 1\n\n\n        ######### forward direction\n        for ii in reversed(range(start_frame + 1, start_frame + 2*gap + 1)):\n            img1 = imgt[ii - 1]\n            img2 = imgt[ii] \n\n            v2d1 = labelt[ii - 1]['keypoints'].astype(int)\n            v2d2 = labelt[ii]['keypoints'].astype(int)\n\n            topo1 = labelt[ii - 1]['topo']\n            topo2 = labelt[ii ]['topo']\n\n            id1 = labelt[ii - 1]['id']\n            id2 = labelt[ii]['id']\n\n           
 if self.use_vs:\n                id1 = np.arange(len(id1))\n                id2 = np.arange(len(id2))\n\n            _, match12, matc21 = ids_to_mat(id1, id2)\n\n            if ii <= start_frame + gap:\n                motion01 = matched_motion(v2d1, v2d2, match12.astype(int), motion01)\n                motion01 = unmatched_motion(topo1, v2d1, motion01, match12.astype(int))\n\n            motion = matched_motion(v2d1, v2d2, match12.astype(int), motion)\n            motion = unmatched_motion(topo1, v2d1, motion, match12.astype(int))\n        motion0 = motion.copy()\n \n        img2 = imgt[start_frame + gap]\n        \n        v2d1 = labelt[start_frame]['keypoints'].astype(int)\n        source0_topo = labelt[start_frame]['topo']\n\n        target = cv2.erode(img2, np.ones((3, 3), np.uint8), iterations=1)\n\n        shift_plabel = v2d1 + motion01\n        visible = np.ones(len(v2d1)).astype(float)\n        visible[shift_plabel[:, 0] < 0] = 0\n        visible[shift_plabel[:, 0] >= imgt[0].shape[0]] = 0\n        visible[shift_plabel[:, 1] < 0] = 0\n        visible[shift_plabel[:, 1] >= imgt[0].shape[0]] = 0\n\n        # vertex visibility\n        visible[visible == 1] = (target[:, :, 0][shift_plabel[visible == 1][:, 1], shift_plabel[visible == 1][:, 0]] < 255 ).astype(float)\n\n        visible01 = visible.copy()\n        v2d1s = shift_plabel\n\n        # edge visibility\n        for node, nbs in enumerate(source0_topo):\n            for nb in nbs:\n                if visible01[nb] and visible01[node] and ((v2d1s[node] - v2d1s[nb]) ** 2).sum() / (((v2d1[node] - v2d1[nb]) ** 2).sum() + 1e-7) > 25:\n                    visible01[nb] = False\n                    visible01[node] = False\n\n        ######## backward direction\n        motion = None\n        motion21 = None\n\n        for ii in range(start_frame + 1, start_frame + gap + gap + 1):\n            img2 = imgt[ii - 1]\n            img1 = imgt[ii] \n\n            v2d2 = labelt[ii - 1]['keypoints'].astype(int)\n      
      v2d1 = labelt[ii]['keypoints'].astype(int)\n\n            topo2 = labelt[ii - 1]['topo']\n            topo1 = labelt[ii ]['topo']\n\n            \n            id1 = labelt[ii]['id']\n            id2 = labelt[ii - 1]['id']\n            if self.use_vs:\n                id1 = np.arange(len(id1))\n                id2 = np.arange(len(id2))\n            _, match12, _ = ids_to_mat(id1, id2)\n\n            if ii >= start_frame + gap + 1:\n                motion21 = matched_motion(v2d1, v2d2, match12.astype(int), motion21)\n                motion21 = unmatched_motion(topo1, v2d1, motion21, match12.astype(int))\n\n            motion = matched_motion(v2d1, v2d2, match12.astype(int), motion)\n            motion = unmatched_motion(topo1, v2d1, motion, match12.astype(int))\n\n        motion2 = motion.copy()\n        \n        img1 = imgt[start_frame + 2*gap]\n        img2 = imgt[start_frame + gap]\n        \n        v2d1 = labelt[start_frame + 2*gap]['keypoints'].astype(int)\n        source2_topo = labelt[start_frame + 2*gap]['topo']\n\n        shift_plabel = v2d1 + motion21\n        visible = np.ones(len(v2d1)).astype(float)\n        visible[shift_plabel[:, 0] < 0] = 0\n        visible[shift_plabel[:, 0] >= imgt[0].shape[0]] = 0\n        visible[shift_plabel[:, 1] < 0] = 0\n        visible[shift_plabel[:, 1] >= imgt[0].shape[0]] = 0\n\n        visible[visible == 1] = (target[:, :, 0][shift_plabel[visible == 1][:, 1], shift_plabel[visible == 1][:, 0]] < 255 ).astype(float)\n\n        visible21 = visible.copy()\n\n        v2d1s = shift_plabel\n\n        for node, nbs in enumerate(source2_topo):\n            for nb in nbs:\n                if visible21[nb] and visible21[node] and ((v2d1s[node] - v2d1s[nb]) ** 2).sum() / (((v2d1[node] - v2d1[nb]) ** 2).sum() + 1e-7) > 25:\n                    visible21[nb] = False\n                    visible21[node] = False\n\n\n        ###### prepare other data\n        img2 = imgt[-1]\n        img1 = imgt[0] \n\n        v2d2 = 
labelt[-1]['keypoints'].astype(int)\n        v2d1 = labelt[0]['keypoints'].astype(int)\n\n        topo2 = labelt[-1]['topo']\n        topo1 = labelt[0]['topo']\n\n        m, n = len(v2d1), len(v2d2)\n\n        if len(img1.shape) == 2:\n            img1 = np.tile(img1[...,None], (1, 1, 3))\n            img2 = np.tile(img2[...,None], (1, 1, 3))\n        else:\n            img1 = img1[..., :3]\n            img2 = img2[..., :3]\n\n        img1 = torch.from_numpy(img1).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 \n        img2 = torch.from_numpy(img2).permute(2, 0, 1).float() * 2 / 255.0 - 1.0\n        imgt = torch.from_numpy(imgt[start_frame + gap]).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 \n\n        v2d1 = torch.from_numpy(v2d1)\n        v2d2 = torch.from_numpy(v2d2)\n\n        visible01 = torch.from_numpy(visible01)\n        visible21 = torch.from_numpy(visible21)\n        motion0 = torch.from_numpy(motion0)\n        motion2 = torch.from_numpy(motion2)\n\n        v2d1[v2d1 > imgt[0].shape[0] - 1 ] = imgt[0].shape[0] - 1\n        v2d1[v2d1 < 0] = 0\n        v2d2[v2d2 > imgt[0].shape[1] - 1] = imgt[0].shape[1] - 1\n        v2d2[v2d2 < 0] = 0\n\n        \n        id1 = labelt[start_frame]['id']\n        id2 = labelt[-1]['id']\n        if self.use_vs:\n            id1 = np.arange(len(id1))\n            id2 = np.arange(len(id2))\n\n        mat_index, corr1, corr2 = ids_to_mat(id1, id2)\n        mat_index = torch.from_numpy(mat_index).float()\n        corr1 = torch.from_numpy(corr1).float()\n        corr2 = torch.from_numpy(corr2).float()\n\n        if self.is_train:\n            v2d1 = torch.nn.functional.pad(v2d1, (0, 0, 0, self.max_len - m), mode='constant', value=0)\n            v2d2 = torch.nn.functional.pad(v2d2, (0, 0, 0, self.max_len - n), mode='constant', value=0)\n            corr1 = torch.nn.functional.pad(corr1, (0, self.max_len - m), mode='constant', value=0)\n            corr2 = torch.nn.functional.pad(corr2, (0, self.max_len - n), mode='constant', 
value=0)\n            motion0 = torch.nn.functional.pad(motion0, (0, 0, 0, self.max_len - m), mode='constant', value=0)\n            motion2 = torch.nn.functional.pad(motion2, (0, 0, 0, self.max_len - n), mode='constant', value=0)\n            visible01 = torch.nn.functional.pad(visible01, (0, self.max_len - m), mode='constant', value=0)\n            visible21 = torch.nn.functional.pad(visible21, (0, self.max_len - n), mode='constant', value=0)\n\n            mask0, mask1 = torch.zeros(self.max_len).float(), torch.zeros(self.max_len).float()\n            mask0[:m] = 1\n            mask1[:n] = 1\n        else:\n            mask0, mask1 = torch.ones(m).float(), torch.ones(n).float()\n        \n        for ii in range(len(topo1)):\n            # if not len(topo1[ii]):\n            topo1[ii].append(ii)\n        for ii in range(len(topo2)):\n            topo2[ii].append(ii)\n        adj1 = sknetwork.data.from_adjacency_list(topo1, matrix_only=True, reindex=False).toarray()\n        adj2 = sknetwork.data.from_adjacency_list(topo2, matrix_only=True, reindex=False).toarray()\n\n        try:\n            spec0, spec1 = np.abs(self.spectral.fit_transform(adj1)), np.abs(self.spectral.fit_transform(adj2))\n        except:\n            print('>>>>' + file_name, flush=True)\n            spec0, spec1 = np.zeros((len(adj1), 64)), np.zeros((len(adj2), 64))\n        # else:\n        #     print('<<<<' + file_name, flush=True)\n\n        # adj2 = adj2 + np.eye(len(adj2))\n\n        if self.is_eval:\n            return{\n                'keypoints0': v2d1,\n                'keypoints1': v2d2,\n                'topo0': [topo1],\n                'topo1': [topo2],\n                # 'id0': id1,\n                # 'id1': id2,\n                'adj_mat0': adj1,\n                'adj_mat1': adj2,\n                'spec0': spec0,\n                'spec1': spec1,\n                'imaget': imgt,\n                'image0': img1,\n                'image1': img2,\n                'motion0': 
motion0,\n                'motion1': motion2,\n                'visibility0': visible01,\n                'visibility1': visible21,\n\n                'all_matches': corr1,\n                'm01': corr1,\n                'm10': corr2,\n                'ms': m,\n                'ns': n,\n                'mask0': mask0,\n                'mask1': mask1,\n                'file_name': file_name,\n                # 'with_match': True\n            }\n        elif not self.is_train:\n            return{\n                'keypoints0': v2d1,\n                'keypoints1': v2d2,\n                # 'topo0': [topo1],\n                # 'topo1': [topo2],\n                # 'id0': id1,\n                # 'id1': id2,\n                'adj_mat0': adj1,\n                'adj_mat1': adj2,\n                'spec0': spec0,\n                'spec1': spec1,\n                'imaget': imgt,\n                'image0': img1,\n                'image1': img2,\n                'motion0': motion0,\n                'motion1': motion2,\n                'visibility0': visible01,\n                'visibility1': visible21,\n\n                'all_matches': corr1,\n                'm01': corr1,\n                'm10': corr2,\n                'ms': m,\n                'ns': n,\n                'mask0': mask0,\n                'mask1': mask1,\n                'file_name': file_name,\n                # 'with_match': True\n            }\n        \n        else:\n            return{\n                'keypoints0': v2d1,\n                'keypoints1': v2d2,\n                # 'topo0': topo1,\n                # 'topo1': topo2,\n                # 'id0': id1,\n                # 'id1': id2,\n                'adj_mat0': adj1,\n                'adj_mat1': adj2,\n                'spec0': spec0,\n                'spec1': spec1,\n                'imaget': imgt,\n                'motion0': motion0,\n                'motion1': motion2,\n                'visibility0': visible01,\n                'visibility1': 
visible21,\n\n                'image0': img1,\n                'image1': img2,\n\n                'all_matches': corr1,\n                'm01': corr1,\n                'm10': corr2,\n                'ms': m,\n                'ns': n,\n                'mask0': mask0,\n                'mask1': mask1,\n                'file_name': file_name,\n            }\n\n    def __rmul__(self, v):\n        self.label_list = v * self.label_list\n        self.image_list = v * self.image_list\n        return self\n\n    def __len__(self):\n        return len(self.image_list)\n\n\ndef worker_init_fn(worker_id):\n    np.random.seed(np.random.get_state()[1][0] + worker_id)\n\ndef fetch_dataloader(args, type='train'):\n    mode = args.mode if hasattr(args, 'mode') else 'train'\n    use_vs = args.use_vs if hasattr(args, 'use_vs') else False\n    lineart = MixamoLineArtMotionSequence(root=args.root, gap=args.gap, split=args.type, model=args.model, action=args.action, mode=mode, use_vs=use_vs)\n\n    if mode == 'train':\n        # training always uses the original labels (use_vs only affects evaluation)\n        lineart = MixamoLineArtMotionSequence(root=args.root, gap=args.gap, split=args.type, model=args.model, action=args.action, mode=mode)\n        loader = data.DataLoader(lineart, batch_size=args.batch_size,\n            pin_memory=True, shuffle=True, num_workers=16, drop_last=True, worker_init_fn=worker_init_fn)\n    else:\n        loader = data.DataLoader(lineart, batch_size=args.batch_size,\n            pin_memory=True, shuffle=False, num_workers=8)\n    return loader\n\n"
  },
  {
    "path": "datasets/vd_seq.py",
    "content": "import numpy as np\nimport torch\nimport torch.utils.data as data\nimport torch.nn.functional as F\n\nimport os\nimport math\nimport random\nfrom glob import glob\nimport os.path as osp\n\nimport sys\nimport argparse\nimport cv2\nfrom collections import Counter\nimport time\nimport json\nimport sknetwork\nfrom sknetwork.embedding import Spectral\n\nimport scipy\n\ndef read_json(file_path):\n    \"\"\"\n        input: json file path\n        output: 2d vertex locations, connections and original vertex indices\n    \"\"\"\n\n    with open(file_path) as file:\n        data = json.load(file)\n        vertex2d = np.array(data['vertex location'])\n\n        topology = data['connection']\n        index = np.array(data['original index'])\n\n    return vertex2d, topology, index\n\n\nclass VideoLinSeq(data.Dataset):\n    def __init__(self, root, split='train'):\n        \"\"\"\n            input:\n                root: the root folder of the line art data\n                split: split folder\n\n            output:\n                image of sources (0, 1) and output (0.5)\n                topo0, topo1\n                v2d0, v2d1\n\n        \"\"\"\n        super(VideoLinSeq, self).__init__()\n\n        self.image_list = []\n        self.label_list = []\n\n        label_root = osp.join(root, split, 'labels')\n        image_root = osp.join(root, split, 'frames')\n\n        self.spectral = Spectral(64, normalized=False)\n\n        for clip in os.listdir(image_root):\n\n            label_list = sorted(glob(osp.join(label_root, clip, '*.json')))\n\n            for i in range(len(label_list) - 1):\n                self.label_list += [ [label_list[jj] for jj in range(i, i + 2)] ]\n                self.image_list += [ [label_list[jj].replace('labels', 'frames').replace('.json', '.png') for jj in range(i, i + 2)] 
]\n\n        # print(clip)\n        print('Len of Frame is ', len(self.image_list), flush=True)\n        print('Len of Label is ', len(self.label_list), flush=True)\n\n    def __getitem__(self, index):\n        # prepare images\n        index = index % len(self.image_list)\n        file_name0 = self.label_list[index][0][:-5].split('/')[-1]\n        file_name1 = self.label_list[index][-1][:-5].split('/')[-1]\n        folder0 = self.label_list[index][0][:-4].split('/')[-2]\n        folder1 = self.label_list[index][-1][:-4].split('/')[-2]\n\n\n        imgt = [cv2.imread(self.image_list[index][ii]) for ii in range(0, len(self.image_list[index]))]\n\n        labelt = []\n        for ii in range(0, len(self.label_list[index])):\n            v, t, id = read_json(self.label_list[index][ii])\n            v[v > imgt[0].shape[0] - 1] = imgt[0].shape[0] - 1\n            v[v < 0] = 0\n            labelt.append({'keypoints': v.astype(int), 'topo': t, 'id': id})\n\n        # make motion pseudo label\n\n        ###### prepare other data\n        img2 = imgt[-1]\n        img1 = imgt[0] \n\n        v2d2 = labelt[-1]['keypoints'].astype(int)\n        v2d1 = labelt[0]['keypoints'].astype(int)\n\n        topo2 = labelt[-1]['topo']\n        topo1 = labelt[0]['topo']\n\n        m, n = len(v2d1), len(v2d2)\n\n        if len(img1.shape) == 2:\n            img1 = np.tile(img1[...,None], (1, 1, 3))\n            img2 = np.tile(img2[...,None], (1, 1, 3))\n        else:\n            img1 = img1[..., :3]\n            img2 = img2[..., :3]\n\n        img1 = torch.from_numpy(img1).permute(2, 0, 1).float() * 2 / 255.0 - 1.0 \n        img2 = torch.from_numpy(img2).permute(2, 0, 1).float() * 2 / 255.0 - 1.0\n\n        v2d1 = torch.from_numpy(v2d1)\n        v2d2 = torch.from_numpy(v2d2)\n\n        mask0, mask1 = torch.ones(m).float(), torch.ones(n).float()\n\n        v2d1[v2d1 > imgt[0].shape[0] - 1 ] = imgt[0].shape[0] - 1\n        v2d1[v2d1 < 0] = 0\n        v2d2[v2d2 > imgt[0].shape[1] - 1] = 
imgt[0].shape[1] - 1\n        v2d2[v2d2 < 0] = 0\n\n        id1 = np.arange(len(v2d1))\n        id2 = np.arange(len(v2d2))\n\n        for ii in range(len(topo1)):\n            topo1[ii].append(ii)\n        for ii in range(len(topo2)):\n            topo2[ii].append(ii)\n        adj1 = sknetwork.data.from_adjacency_list(topo1, matrix_only=True, reindex=False).toarray()\n        adj2 = sknetwork.data.from_adjacency_list(topo2, matrix_only=True, reindex=False).toarray()\n\n        try:\n            spec0, spec1 = np.abs(self.spectral.fit_transform(adj1)), np.abs(self.spectral.fit_transform(adj2))\n        except Exception:\n            print('>>>>' + file_name0, flush=True)\n            spec0, spec1 = np.zeros((len(adj1), 64)), np.zeros((len(adj2), 64))\n\n        return {\n            'keypoints0': v2d1,\n            'keypoints1': v2d2,\n            'topo0': [topo1],\n            'topo1': [topo2],\n            'adj_mat0': adj1,\n            'adj_mat1': adj2,\n            'spec0': spec0,\n            'spec1': spec1,\n            'image0': img1,\n            'image1': img2,\n            'ms': m,\n            'ns': n,\n            'mask0': mask0,\n            'mask1': mask1,\n            'gen_vid': True,\n            'file_name0': file_name0,\n            'file_name1': file_name1,\n            'folder_name0': folder0,\n            'folder_name1': folder1\n        }\n\n\n    def __rmul__(self, v):\n        self.label_list = v * self.label_list\n        self.image_list = v * self.image_list\n        return self\n\n    def __len__(self):\n        return len(self.image_list)\n\n\ndef worker_init_fn(worker_id):\n    np.random.seed(np.random.get_state()[1][0] + worker_id)\n\ndef fetch_videoloader(args, type='train'):\n    lineart = VideoLinSeq(root=args.root, split=args.type)\n\n    loader = data.DataLoader(lineart, batch_size=args.batch_size, \n            pin_memory=True, shuffle=False, 
num_workers=8)\n    return loader\n\n"
  },
  {
    "path": "download.sh",
    "content": "cd data\ngdown 1SNRGajIECxNwRp6ZJ0IlY7AEl2mRm2DR\nunzip ml240data.zip\n"
  },
  {
    "path": "experiments/inbetweener_full/ckpt/.gitkeep",
    "content": ""
  },
  {
    "path": "inbetween.py",
    "content": "\"\"\" This script handling the training process. \"\"\"\nimport os\nimport time\nimport random\nimport argparse\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.utils.data\nfrom datasets import fetch_dataloader\nfrom datasets import fetch_videoloader\nimport random\nfrom utils.log import Logger\n\nfrom torch.optim import *\nimport warnings\nfrom tqdm import tqdm\nimport itertools\nimport pdb\nimport numpy as np\nimport models\nimport datetime\nimport sys\nimport json\nimport cv2\n\nfrom utils.visualize_inbetween3 import visualize\n# from utils.visualize_inbetween import visualize\nfrom utils.visualize_video import visvid as visgen\nimport matplotlib.cm as cm\n# from models.utils import make_matching_seg_plot\n\nwarnings.filterwarnings('ignore')\n\n# a, b, c, d = check_data_distribution('/mnt/lustre/lisiyao1/dance/dance2/DanceRevolution/data/aistpp_train')\n\nimport matplotlib.pyplot as plt\nimport pdb\n\nclass DraftRefine():\n    def __init__(self, args):\n        self.config = args\n        torch.backends.cudnn.benchmark = True\n        torch.multiprocessing.set_sharing_strategy('file_system')\n        self._build()\n\n    def train(self):\n        \n        opt = self.config\n        print(opt)\n\n        # store viz results\n        # eval_output_dir = Path(self.expdir)\n        # eval_output_dir.mkdir(exist_ok=True, parents=True)\n\n        # print('Will write visualization images to',\n        #     'directory \\\"{}\\\"'.format(eval_output_dir))\n\n        # load training data\n        \n        model = self.model\n\n        checkpoint = torch.load(self.config.corr_weights)\n        dict = {k.replace('module.', ''): checkpoint['model'][k] for k in checkpoint['model']}\n        model.module.corr.load_state_dict(dict)\n\n        if hasattr(self.config, 'init_weight'):\n            checkpoint = torch.load(self.config.init_weight)\n            model.load_state_dict(checkpoint['model'])\n\n        # if 
torch.cuda.is_available():\n        #     model.cuda() # make sure it trains on GPU\n        # else:\n        #     print(\"### CUDA not available ###\")\n            # return\n        optimizer = self.optimizer\n        schedular = self.schedular\n        mean_loss = []\n        log = Logger(self.config, self.expdir)\n        updates = 0\n        \n        # set seed\n        random.seed(opt.seed)\n        torch.manual_seed(opt.seed)\n        torch.cuda.manual_seed(opt.seed)\n        np.random.seed(opt.seed)\n        # print(opt.seed)\n        # start training\n\n        for epoch in range(1, opt.epoch+1):\n            np.random.seed(opt.seed + epoch)\n            train_loader = self.train_loader\n            log.set_progress(epoch, len(train_loader))\n            batch_loss = 0\n            batch_epe = 0 \n            batch_acc = 0\n            batch_iter = 0\n            model.train()\n            avg_time = 0\n            avg_num = 0\n            # torch.cuda.synchronize()\n            \n            for i, data in enumerate(train_loader):\n                pred = model(data)\n                if True:\n                    loss = pred['loss'].mean() \n                    # print(loss.item(), opt.batch_size)\n                    batch_loss += loss.item() / opt.batch_size\n                    batch_acc += pred['Visibility Acc'].mean().item() / opt.batch_size\n                    batch_epe += pred['EPE'].mean().item() / opt.batch_size \n                    loss.backward()\n                    batch_iter += 1\n                else:\n                    print('Skip!')\n\n\n\n                if ((i + 1) % opt.batch_size == 0) or (i + 1 == len(train_loader)):\n                    optimizer.step()\n\n                    optimizer.zero_grad()\n                    batch_iter = 1 if batch_iter == 0 else batch_iter               \n                    stats = {\n                        'updates': updates,\n                        'loss': batch_loss,\n                        
'accuracy': batch_acc,\n                        'EPE': batch_epe\n                    }\n                    log.update(stats)\n                    updates += 1\n                    batch_loss = 0\n                    batch_acc = 0 \n                    batch_epe = 0\n                    batch_iter = 0\n                # tend = time.time()\n                # avg_time = (tend - tstart)\n                # print('Time is ', avg_time)\n\n            # torch.cuda.synchronize()\n\n            # avg_num += 1\n                # for name, params in model.named_parameters():\n                #     print('-->name:, ', name, '-->grad mean', params.grad.mean())\n            # print(\"All time is \", avg_time, \"AVG time is \", avg_time * 1.0 /avg_num,  \"number is \", avg_num, flush=True)\n\n            # save checkpoint \n            if epoch % opt.save_per_epochs == 0 or epoch == 1:\n                checkpoint = {\n                    'model': model.state_dict(),\n                    'config': opt,\n                    'epoch': epoch\n                }\n\n                filename = os.path.join(self.ckptdir, f'epoch_{epoch}.pt')\n                torch.save(checkpoint, filename)\n                \n            # validate\n            if epoch % opt.test_freq == 0:\n\n                if not os.path.exists(os.path.join(self.visdir, 'epoch' + str(epoch))):\n                    os.mkdir(os.path.join(self.visdir, 'epoch' + str(epoch)))\n                eval_output_dir = os.path.join(self.visdir, 'epoch' + str(epoch))    \n                \n                test_loader = self.test_loader\n\n                with torch.no_grad():\n                    # Visualize the matches.\n                    mean_acc = []\n                    mean_epe = []\n                    model.eval()\n                    for i_eval, data in enumerate(tqdm(test_loader, desc='Refining motion and visibility...')):\n                        pred = model(data)\n                        # for k, v in data.items():\n   
                     #     pred[k] = v[0]\n                        #     pred = {**pred, **data}\n\n                        mean_acc.append(pred['Visibility Acc'].mean().item())\n                        mean_epe.append(pred['EPE'].mean().item())\n                    log.log_eval({\n                        'updates': opt.epoch,\n                        'Visibility Accuracy': np.mean(mean_acc),\n                        'EPE': np.mean(mean_epe),\n                        })\n                    print('Epoch [{}/{}], Vis Acc.: {:.4f}, EPE: {:.4f}' \n                        .format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_epe)) )\n                    sys.stdout.flush()\n                        # make_matching_plot(\n                        #     image0, image1, kpts0, kpts1, mkpts0, mkpts1, color,\n                        #     text, viz_path, stem, stem, True,\n                        #     True, False, 'Matches')\n                        \n\n            self.schedular.step()\n            \n\n\n    def eval(self):\n        train_action = ['breakdance_1990', 'capoeira', 'chapa-giratoria', 'fist_fight', 'flying_knee', 'freehang_climb', 'running', 'shove', 'magic', 'tripping']\n        test_action = ['great_sword_slash', 'hip_hop_dancing']\n\n        train_model = ['ganfaul', 'girlscout', 'jolleen', 'kachujin', 'knight', 'maria_w_jj', 'michelle', 'peasant_girl', 'timmy', 'uriel_a_plotexia']\n        test_model = ['police', 'warrok']\n\n        config = self.config\n        if not os.path.exists(config.imwrite_dir):\n            os.mkdir(config.imwrite_dir)\n            \n        log = Logger(self.config, self.expdir)\n        with torch.no_grad():\n            model = self.model.eval()\n            config = self.config\n            epoch_tested = self.config.testing.ckpt_epoch\n            if epoch_tested == 0 or epoch_tested == '0':\n                checkpoint = torch.load(self.config.corr_weights)\n                dict = {k.replace('module.', ''): 
checkpoint['model'][k] for k in checkpoint['model']}\n                model.module.corr.load_state_dict(dict)\n            else:\n                ckpt_path = os.path.join(self.ckptdir, f\"epoch_{epoch_tested}.pt\")\n                # self.device = torch.device('cuda' if config.cuda else 'cpu')\n                print(\"Evaluation...\")\n                checkpoint = torch.load(ckpt_path)\n                model.load_state_dict(checkpoint['model'])\n            model.eval()\n\n            if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested))):\n                os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested)))\n            if not os.path.exists(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons')):\n                os.mkdir(os.path.join(self.evaldir, 'epoch' + str(epoch_tested), 'jsons'))\n            eval_output_dir = os.path.join(self.evaldir, 'epoch' + str(epoch_tested))    \n                \n            test_loader = self.test_loader\n            print(len(test_loader))\n            mean_acc = []\n            mean_valid_acc = []\n            mean_invalid_acc = []\n\n            # 144 data 10x10 is for training , 2x10 (unseen model) + 10x2 (unseen action) + 2x2 (unseen model unseen action) is for test\n            # record the accuracy for \n            mean_model_acc = []\n            mean_model_epe = []\n            mean_action_acc = []\n            mean_action_epe = []\n            \n            mean_none_acc = []\n            mean_none_epe = []\n\n            mean_acc = []\n            mean_epe = []\n\n            mean_cd = []\n            model.eval()\n            # for i_eval, data in enumerate(tqdm(test_loader, desc='Refining motion and visibility...')):\n            #     pred = model(data)\n            #     # for k, v in data.items():\n            #     #     pred[k] = v[0]\n            #     #     pred = {**pred, **data}\n\n            #     mean_acc.append(pred['Visibility Acc'].mean().item())\n         
   #     mean_epe.append(pred['EPE'].mean().item())\n            # log.log_eval({\n            #     'updates': opt.epoch,\n            #     'Visibility Accuracy': np.mean(mean_acc),\n            #     'EPE': np.mean(mean_epe),\n            #     })\n\n            for i_eval, data in enumerate(tqdm(test_loader, desc='Predicting Vtx Corr...')):\n                # if i_eval == 34:\n                #     continue\n                \n                pred = model(data)\n                for k, v in pred.items():\n                    # print(k, flush=True)\n                    pred[k] = v\n                    pred = {**pred, **data}\n            \n                mean_acc.append(pred['Visibility Acc'].mean().item())\n                mean_epe.append(pred['EPE'].mean().item())\n\n                unmarked = True\n                for model_name in train_model:\n                    if model_name in pred['file_name']:\n                        mean_model_acc.append(pred['Visibility Acc'])\n                        mean_model_epe.append(pred['EPE'])\n                        unmarked = False\n                        break\n\n                for action_name in train_action:\n                    if action_name in pred['file_name']:\n                        mean_action_acc.append(pred['Visibility Acc'])\n                        mean_action_epe.append(pred['EPE'])\n                        unmarked = False\n                        break\n                \n                if unmarked:\n                    mean_none_acc.append(pred['Visibility Acc'])\n                    mean_none_epe.append(pred['EPE'])\n\n                # if 'invalid_accuracy' in pred and pred['invalid_accuracy'] is not None:\n                #     mean_invalid_acc.append(pred['invalid_accuracy'])\n                \n                img_vis = visualize(pred)\n                # mean_cd.append(cd.item())\n                file_name = pred['file_name'][0].split('/')\n                
cv2.imwrite(os.path.join(config.imwrite_dir, (file_name[-2] + '_' + file_name[-1]) + '.png'), img_vis)\n\n                # cv2.imwrite(os.path.join(eval_output_dir, pred['file_name'][0].replace('/', '_') + '.jpg'), img_vis)\n                \n            log.log_eval({\n                'updates': self.config.testing.ckpt_epoch,\n                # 'mean CD': np.mean(mean_cd),\n                # 'Visibility Accuracy': np.mean(mean_acc),\n                # 'EPE': np.mean(mean_epe),\n                # 'Unseen Action Accuracy': np.mean(mean_model_acc),\n                # 'Unseen Action EPE': np.mean(mean_model_epe),\n                # 'Unseen Model Accuracy': np.mean(mean_action_acc),\n                # 'Unseen Model EPE': np.mean(mean_action_epe),\n                # 'Unseen Both Accuracy': np.mean(mean_none_acc),\n                # 'Unseen Both Valid Accuracy': np.mean(mean_none_epe)\n                })\n                # print ('Epoch [{}/{}]], Acc.: {:.4f}, Valid Acc.{:.4f}' \n                #     .format(epoch, opt.epoch, np.mean(mean_acc), np.mean(mean_valid_acc)) )\n            sys.stdout.flush()\n\n\n    def gen(self):\n        log = Logger(self.config, self.viddir)\n        with torch.no_grad():\n            model = self.model.eval()\n            config = self.config\n            epoch_tested = self.config.testing.ckpt_epoch\n            if epoch_tested == 0 or epoch_tested == '0':\n                checkpoint = torch.load(self.config.corr_weights)\n                dict = {k.replace('module.', ''): checkpoint['model'][k] for k in checkpoint['model']}\n                model.module.corr.load_state_dict(dict)\n            else:\n                ckpt_path = os.path.join(self.ckptdir, f\"epoch_{epoch_tested}.pt\")\n                # self.device = torch.device('cuda' if config.cuda else 'cpu')\n                print(\"Evaluation...\")\n                checkpoint = torch.load(ckpt_path)\n                model.load_state_dict(checkpoint['model'])\n            
model.eval()\n\n            if not os.path.exists(os.path.join(self.viddir, 'epoch' + str(epoch_tested))):\n                os.mkdir(os.path.join(self.viddir, 'epoch' + str(epoch_tested)))\n            if not os.path.exists(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'frames')):\n                os.mkdir(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'frames'))\n            if not os.path.exists(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'videos')):\n                os.mkdir(os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'videos'))\n\n            gen_frame_dir = os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'frames')  \n            gen_video_dir = os.path.join(self.viddir, 'epoch' + str(epoch_tested), 'videos')    \n                \n            vid_loader = self.vid_loader\n            print(len(vid_loader))\n            mean_acc = []\n            mean_valid_acc = []\n            mean_invalid_acc = []\n\n            model.eval()\n\n            for i_eval, data in enumerate(tqdm(vid_loader, desc='Gen Video...')):\n                \n                pred = model(data)\n                for k, v in pred.items():\n                    pred[k] = v\n                    pred = {**pred, **data}\n            \n\n                img_vis = visgen(pred, config.inter_frames)\n\n                if not os.path.exists(os.path.join(gen_frame_dir, pred['folder_name0'][0])):\n                    os.mkdir(os.path.join(gen_frame_dir, pred['folder_name0'][0]))\n                \n                cv2.imwrite(os.path.join(gen_frame_dir, pred['folder_name0'][0], pred['file_name0'][0] + '_000.jpg'),img_vis[0])\n                for tt in range(config.inter_frames):\n                    cv2.imwrite(os.path.join(gen_frame_dir, pred['folder_name0'][0], pred['file_name0'][0] + '_' + '{:03d}'.format(tt + 1) + '.jpg'), img_vis[tt + 1])\n                cv2.imwrite(os.path.join(gen_frame_dir, pred['folder_name0'][0], pred['file_name1'][0] + 
'_000.jpg'),img_vis[-1])\n            \n            for ff in os.listdir(gen_frame_dir):\n                frame_dir = os.path.join(gen_frame_dir, ff)\n                video_file = os.path.join(gen_video_dir, f\"{ff}.mp4\")\n                cmd = f\"ffmpeg -r {config.fps} -pattern_type glob -i '{frame_dir}/*.jpg' -vb 20M -vcodec mpeg4 -y '{video_file}'\"\n                \n                print(cmd, flush=True)\n                os.system(cmd)\n                \n\n            log.log_eval({\n                'updates': self.config.testing.ckpt_epoch,\n                })\n            sys.stdout.flush()\n\n    def _build(self):\n        config = self.config\n        self.start_epoch = 0\n        self._dir_setting()\n        self._build_model()\n        if not(hasattr(config, 'need_not_train_data') and config.need_not_train_data):\n            self._build_train_loader()\n        if not(hasattr(config, 'need_not_test_data') and config.need_not_test_data):\n            self._build_test_loader()\n        if hasattr(config, 'gen_video') and config.gen_video:\n            self._build_video_loader()\n        self._build_optimizer()\n\n    def _build_model(self):\n        \"\"\" Define Model \"\"\"\n        config = self.config \n        if hasattr(config.model, 'name'):\n            print(f'Experiment Using {config.model.name}')\n            model_class = getattr(models, config.model.name)\n            model = model_class(config.model)\n        else:\n            raise NotImplementedError(\"Wrong Model Selection\")\n        \n        model = nn.DataParallel(model)\n        self.model = model.cuda()\n\n    def _build_train_loader(self):\n        config = self.config\n        self.train_loader = fetch_dataloader(config.data.train, type='train')\n\n    def _build_test_loader(self):\n        config = self.config\n        self.test_loader = fetch_dataloader(config.data.test, type='test')\n\n    def _build_video_loader(self):\n        config = self.config\n        
self.vid_loader = fetch_videoloader(config.video)\n\n    def _build_optimizer(self):\n        #model = nn.DataParallel(model).to(device)\n        config = self.config.optimizer\n        try:\n            optim = getattr(torch.optim, config.type)\n        except Exception:\n            raise NotImplementedError('not implemented optim method ' + config.type)\n\n        self.optimizer = optim(itertools.chain(self.model.module.parameters(),\n                                             ),\n                                             **config.kwargs)\n        self.schedular = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, **config.schedular_kwargs)\n\n    def _dir_setting(self):\n        self.expname = self.config.expname\n        # self.experiment_dir = os.path.join(\"/mnt/cache/syli/inbetween\", \"experiments\")\n\n        self.experiment_dir = 'experiments'\n        self.expdir = os.path.join(self.experiment_dir, self.expname)\n\n        if not os.path.exists(self.expdir):\n            os.mkdir(self.expdir)\n\n        self.visdir = os.path.join(self.expdir, \"vis\")  # -- imgs, videos, jsons\n        if not os.path.exists(self.visdir):\n            os.mkdir(self.visdir)\n\n        self.ckptdir = os.path.join(self.expdir, \"ckpt\")\n        if not os.path.exists(self.ckptdir):\n            os.mkdir(self.ckptdir)\n\n        self.evaldir = os.path.join(self.expdir, \"eval\")\n        if not os.path.exists(self.evaldir):\n            os.mkdir(self.evaldir)\n\n        self.viddir = os.path.join(self.expdir, \"video\")\n        if not os.path.exists(self.viddir):\n            os.mkdir(self.viddir)\n\n        \n\n        # self.ckptdir = os.path.join(self.expdir, \"ckpt\")\n        # if not os.path.exists(self.ckptdir):\n        #     os.mkdir(self.ckptdir)\n\n\n\n        \n\n\n\n\n"
  },
  {
    "path": "inbetween_results/.gitkeep",
    "content": ""
  },
  {
    "path": "main.py",
    "content": "from inbetween import DraftRefine\nimport argparse\nimport os\nimport yaml\nfrom pprint import pprint\nfrom easydict import EasyDict\n\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description='Anime segment matching')\n    parser.add_argument('--config', default='')\n    # exclusive arguments\n    group = parser.add_mutually_exclusive_group(required=True)\n    group.add_argument('--train', action='store_true')\n    group.add_argument('--eval', action='store_true')\n    group.add_argument('--gen', action='store_true')\n\n\n    return parser.parse_args()\n\n\ndef main():\n    # parse arguments and load config\n    args = parse_args()\n    with open(args.config) as f:\n        config = yaml.load(f, Loader=yaml.FullLoader)\n\n    for k, v in vars(args).items():\n        config[k] = v\n    pprint(config)\n\n    config = EasyDict(config)\n    agent = DraftRefine(config)\n    print(config)\n\n    if args.train:\n        agent.train()\n    elif args.eval:\n        agent.eval()\n    elif args.gen:\n        agent.gen()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "models/__init__.py",
    "content": "# from .transformer_refiner import Refiner\n# from .inbetweener import Inbetweener\n# from .inbetweener_with_mask import InbetweenerM\n# from .inbetweener_wo_rp import InbetweenerM as InbetweenerNRP\nfrom .inbetweener_with_mask_with_spec import InbetweenerTM\n# from .inbetweener_with_mask_with_spec_wo_OT import InbetweenerTMwoOT\nfrom .inbetweener_with_mask2 import InbetweenerM as InbetweenerM2\n# from .inbetweener_with_mask_wo_pos import InbetweenerNP\n# from .inbetweener_with_mask_wo_pos_wo_spec import InbetweenerNPS\n# from .transformer_refiner2 import Refiner as Refiner2\n# from .transformer_refiner3 import Refiner as Refiner3\n# from .transformer_refiner4 import Refiner as Refiner4\n# from .transformer_refiner5 import Refiner as Refiner5\n# from .transformer_refiner_norm import Refiner as RefinerN\n\n__all__ = [ 'InbetweenerTM', 'InbetweenerM2']\n"
  },
  {
    "path": "models/inbetweener_with_mask2.py",
    "content": "from copy import deepcopy\nfrom pathlib import Path\nimport torch\nfrom torch import nn\n# from seg_desc import seg_descriptor\nimport argparse\nimport torch.nn.functional as F\n\ndef MLP(channels: list, do_bn=True):\n    \"\"\" Multi-layer perceptron \"\"\"\n    n = len(channels)\n    layers = []\n    for i in range(1, n):\n        layers.append(\n            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))\n        if i < (n-1):\n            if do_bn:\n                # layers.append(nn.BatchNorm1d(channels[i]))\n                layers.append(nn.InstanceNorm1d(channels[i]))\n            layers.append(nn.ReLU())\n    return nn.Sequential(*layers)\n\n\ndef normalize_keypoints(kpts, image_shape):\n    \"\"\" Normalize keypoint locations based on image shape \"\"\"\n    _, _, height, width = image_shape\n    one = kpts.new_tensor(1)\n    size = torch.stack([one*width, one*height])[None]\n    center = size / 2\n    scaling = size.max(1, keepdim=True).values * 0.7\n    return (kpts - center[:, None, :]) / scaling[:, None, :]\n\nclass ThreeLayerEncoder(nn.Module):\n    \"\"\" Three-layer convolutional encoder for per-pixel visual features \"\"\"\n    def __init__(self, enc_dim):\n        super().__init__()\n        # input must be 3 channel (r, g, b)\n        self.layer1 = nn.Conv2d(3, enc_dim//4, 7, padding=3)\n        self.non_linear1 = nn.ReLU()\n        self.layer2 = nn.Conv2d(enc_dim//4, enc_dim//2, 3, padding=1)\n        self.non_linear2 = nn.ReLU()\n        self.layer3 = nn.Conv2d(enc_dim//2, enc_dim, 3, padding=1)\n\n        self.norm1 = nn.InstanceNorm2d(enc_dim//4)\n        self.norm2 = nn.InstanceNorm2d(enc_dim//2)\n        self.norm3 = nn.InstanceNorm2d(enc_dim)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                nn.init.constant_(m.bias, 0.0)\n\n    def forward(self, img):\n        x = 
self.non_linear1(self.norm1(self.layer1(img)))\n        x = self.non_linear2(self.norm2(self.layer2(x)))\n        x = self.norm3(self.layer3(x))\n        # x = self.non_linear1(self.layer1(img))\n        # x = self.non_linear2(self.layer2(x))\n        # x = self.layer3(x)\n        return x\n\n\nclass VertexDescriptor(nn.Module):\n    \"\"\" Sample per-vertex visual descriptors from the encoded feature map \"\"\"\n    def __init__(self, enc_dim):\n        super().__init__()\n        self.encoder = ThreeLayerEncoder(enc_dim)\n        # self.super_pixel_pooling = \n        # use scatter\n        # nn.init.constant_(self.encoder[-1].bias, 0.0)\n\n    def forward(self, img, vtx):\n        x = self.encoder(img)\n        n, c, h, w = x.size()\n        assert((h, w) == img.size()[2:4])\n        return x[:, :, torch.round(vtx[0, :, 1]).long(), torch.round(vtx[0, :, 0]).long()]\n        # return super_pixel_pooling(x.view(n, c, -1), seg.view(-1).long(), reduce='mean')\n        # here return size is [1]xCx|Seg|\n\n\nclass KeypointEncoder(nn.Module):\n    \"\"\" Encode 2D keypoint locations with an MLP \"\"\"\n    def __init__(self, feature_dim, layers):\n        super().__init__()\n        self.encoder = MLP([2] + layers + [feature_dim])\n        # for m in self.encoder.modules():\n        #     if isinstance(m, nn.Conv2d):\n        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n        #         nn.init.constant_(m.bias, 0.0)\n        nn.init.constant_(self.encoder[-1].bias, 0.0)\n\n    def forward(self, kpts):\n        inputs = kpts.transpose(1, 2)\n        # print(inputs.size(), 'wula!')\n        x = self.encoder(inputs)\n        # print(x.size())\n        return x\n\n\ndef attention(query, key, value, mask=None):\n    dim = query.shape[1]\n    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5\n    if mask is not None:\n        # print(mask, flush=True)\n        scores = scores.masked_fill(mask==0, 
float('-inf'))\n\n    # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))\n    # att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))\n    # att = F.softmax(att, dim=-1)\n    prob = torch.nn.functional.softmax(scores, dim=-1)\n\n    # print(scores[1][1], prob[1][1], flush=True)\n    # while True:\n    #     pass \n    # prob = torch.exp(scores) /((torch.sum(torch.exp(scores), dim=-1)[:, :, :, None]) + 1e-7)\n    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob\n\n\nclass MultiHeadedAttention(nn.Module):\n    \"\"\" Multi-head attention to increase model expressivity \"\"\"\n    def __init__(self, num_heads: int, d_model: int):\n        super().__init__()\n        assert d_model % num_heads == 0\n        self.dim = d_model // num_heads\n        self.num_heads = num_heads\n        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)\n        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])\n\n    def forward(self, query, key, value, mask=None):\n        batch_dim = query.size(0)\n        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)\n                             for l, x in zip(self.proj, (query, key, value))]\n        x, prob = attention(query, key, value, mask)\n        # self.prob.append(prob)\n        return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))\n\n\nclass AttentionalPropagation(nn.Module):\n    def __init__(self, feature_dim: int, num_heads: int):\n        super().__init__()\n        self.attn = MultiHeadedAttention(num_heads, feature_dim)\n        self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])\n        nn.init.constant_(self.mlp[-1].bias, 0.0)\n\n    def forward(self, x, source, mask=None):\n        message = self.attn(x, source, source, mask)\n        return self.mlp(torch.cat([x, message], dim=1))\n\n\nclass AttentionalGNN(nn.Module):\n    def __init__(self, feature_dim: int, layer_names: list):\n        super().__init__()\n        
self.layers = nn.ModuleList([\n            AttentionalPropagation(feature_dim, 4)\n            for _ in range(len(layer_names))])\n        self.names = layer_names\n\n    def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None, mask10=None):\n        for layer, name in zip(self.layers, self.names):\n            layer.attn.prob = []\n            if name == 'cross':\n                src0, src1 = desc1, desc0\n                mask0, mask1 = mask01[:, None], mask10[:, None] \n            else:  # if name == 'self':\n                src0, src1 = desc0, desc1\n                mask0, mask1 = mask00[:, None], mask11[:, None]\n\n            delta0, delta1 = layer(desc0, src0, mask0), layer(desc1, src1, mask1)\n            desc0, desc1 = (desc0 + delta0), (desc1 + delta1)\n        return desc0, desc1\n\n\ndef log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):\n    \"\"\" Perform Sinkhorn Normalization in Log-space for stability\"\"\"\n    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)\n    for _ in range(iters):\n        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)\n        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)\n    return Z + u.unsqueeze(2) + v.unsqueeze(1)\n\n\ndef log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):\n    \"\"\" Perform Differentiable Optimal Transport in Log-space for stability\"\"\"\n    b, m, n = scores.shape\n    one = scores.new_tensor(1)\n    if ms is None or ns is None:\n        ms, ns = (m*one).to(scores), (n*one).to(scores)\n    # else:\n    #     ms, ns = ms.to(scores)[:, None], ns.to(scores)[:, None]\n    # here m,n should be parameters not shape\n\n    # ms, ns: (b, )\n    bins0 = alpha.expand(b, m, 1)\n    bins1 = alpha.expand(b, 1, n)\n    alpha = alpha.expand(b, 1, 1)\n\n    # pad additional scores for unmatched (to -1)\n    # alpha is the learned threshold\n    couplings = torch.cat([torch.cat([scores, bins0], -1),\n                           torch.cat([bins1, 
alpha], -1)], 1)\n\n    norm = - (ms + ns).log() # (b, )\n    # print(scores.min(), flush=True)\n    if ms.size()[0] > 0:\n        norm = norm[:, None]\n        log_mu = torch.cat([norm.expand(b, m), ns.log()[:, None] + norm], dim=-1) # (m + 1)\n        log_nu = torch.cat([norm.expand(b, n), ms.log()[:, None] + norm], dim=-1)\n        # print(log_nu.min(), log_mu.min(), flush=True)\n    else:\n        log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) # (m + 1)\n        log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])\n        log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)\n\n    \n    Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)\n\n    if ms.size()[0] > 1:\n        norm = norm[:, :, None]\n    Z = Z - norm  # multiply probabilities by M+N\n    return Z\n\n\ndef arange_like(x, dim: int):\n    return x.new_ones(x.shape[dim]).cumsum(0) - 1  # traceable in 1.1\n\n\nclass SuperGlueM(nn.Module):\n    \"\"\"SuperGlue feature matching middle-end\n\n    Given two sets of keypoints and locations, we determine the\n    correspondences by:\n      1. Keypoint Encoding (normalization + visual feature and location fusion)\n      2. Graph Neural Network with multiple self and cross-attention layers\n      3. Final projection layer\n      4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)\n      5. Thresholding matrix based on mutual exclusivity and a match_threshold\n\n    The correspondence ids use -1 to indicate non-matching points.\n\n    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew\n    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural\n    Networks. In CVPR, 2020. 
https://arxiv.org/abs/1911.11763\n\n    \"\"\"\n    # default_config = {\n    #     'descriptor_dim': 128,\n    #     'weights': 'indoor',\n    #     'keypoint_encoder': [32, 64, 128],\n    #     'GNN_layers': ['self', 'cross'] * 9,\n    #     'sinkhorn_iterations': 100,\n    #     'match_threshold': 0.2,\n    # }\n\n    def __init__(self, config=None):\n        super().__init__()\n\n        default_config = argparse.Namespace()\n        default_config.descriptor_dim = 128\n        # default_config.weights = \n        default_config.keypoint_encoder = [32, 64, 128]\n        default_config.GNN_layers = ['self', 'cross'] * 9\n        default_config.sinkhorn_iterations = 100\n        default_config.match_threshold = 0.2\n        # self.config = {**self.default_config, **config}\n\n        if config is None:\n            self.config = default_config\n        else:\n            self.config = config   \n            self.config.GNN_layers = ['self', 'cross'] * self.config.GNN_layer_num\n            # print('WULA!', self.config.GNN_layer_num)\n\n        self.kenc = KeypointEncoder(\n            self.config.descriptor_dim, self.config.keypoint_encoder)\n\n        self.gnn = AttentionalGNN(\n            self.config.descriptor_dim, self.config.GNN_layers)\n\n        self.final_proj = nn.Conv1d(\n            self.config.descriptor_dim, self.config.descriptor_dim,\n            kernel_size=1, bias=True)\n\n        bin_score = torch.nn.Parameter(torch.tensor(1.))\n        self.register_parameter('bin_score', bin_score)\n        self.vertex_desc = VertexDescriptor(self.config.descriptor_dim)\n\n        # assert self.config.weights in ['indoor', 'outdoor']\n        # path = Path(__file__).parent\n        # path = path / 'weights/superglue_{}.pth'.format(self.config.weights)\n        # self.load_state_dict(torch.load(path))\n        # print('Loaded SuperGlue model (\\\"{}\\\" weights)'.format(\n        #     self.config.weights))\n\n    def forward(self, data):\n        \"\"\"Run 
SuperGlue on a pair of keypoints and descriptors\"\"\"\n        # print(data['segment0'].size())\n        # desc0, desc1 = data['descriptors0'].float()(), data['descriptors1'].float()()\n         # print(desc0.size())\n        kpts0, kpts1 = data['keypoints0'].float(), data['keypoints1'].float()\n\n        ori_mask0, ori_mask1 = data['mask0'].float(), data['mask1'].float()\n        dim_m, dim_n = data['ms'].float(), data['ns'].float()\n\n        mmax = dim_m.int().max()\n        nmax = dim_n.int().max()\n\n        mask0 = ori_mask0[:, :mmax]\n        mask1 = ori_mask1[:, :nmax]\n\n        kpts0 = kpts0[:, :mmax]\n        kpts1 = kpts1[:, :nmax]\n\n        desc0, desc1 = self.vertex_desc(data['image0'], kpts0.float()), self.vertex_desc(data['image1'], kpts1.float())\n        \n       \n        # print(desc0.size(), flush=True)\n\n        mask00 = torch.ones_like(mask0)[:, :, None] * mask0[:, None, :]\n        # print(mask00[1], flush=True)\n        \n        mask11 = torch.ones_like(mask1)[:, :, None] * mask1[:, None, :]\n        mask01 = torch.ones_like(mask0)[:, :, None] * mask1[:, None, :]\n        mask10 = torch.ones_like(mask1)[:, :, None] * mask0[:, None, :]\n        \n        # desc0 = desc0.transpose(0,1)\n        # desc1 = desc1.transpose(0,1)\n        # kpts0 = torch.reshape(kpts0, (1, -1, 2))\n        # kpts1 = torch.reshape(kpts1, (1, -1, 2))\n\n        if kpts0.shape[1] < 2 or kpts1.shape[1] < 2:  # no keypoints\n            shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]\n            # print(data['file_name'])\n            return {\n                'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],\n                # 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],\n                'matching_scores0': kpts0.new_zeros(shape0)[0],\n                # 'matching_scores1': kpts1.new_zeros(shape1)[0],\n                'skip_train': True\n            }\n\n        # file_name = data['file_name']\n        all_matches = data['all_matches'] 
if 'all_matches' in data else None# shape = (1, K1)\n        # .permute(1,2,0) # shape=torch.Size([1, 87,])\n        \n        # positional embedding\n        # Keypoint normalization.\n        kpts0 = normalize_keypoints(kpts0, data['image0'].shape)\n        kpts1 = normalize_keypoints(kpts1, data['image1'].shape)\n\n        # Keypoint MLP encoder.\n        # print(data['file_name'])\n        # print(kpts0.size())\n    \n        pos0 = self.kenc(kpts0)\n        pos1 = self.kenc(kpts1)\n        # print(desc0.size(), pos0.size())\n        # print(desc0.size(), pos0.size())\n        desc0 = desc0 + pos0\n        desc1 = desc1 + pos1\n\n        # self.register_buffer(\"mask\", torch.tril(torch.ones(config.block_size, config.block_size))\n                                    #  .view(1, 1, config.block_size, config.block_size))\n        # mask0 = ...\n        # mask1 = ...\n\n        # Multi-layer Transformer network.\n        desc0, desc1 = self.gnn(desc0, desc1, mask00, mask11, mask01, mask10)\n\n        # Final MLP projection.\n        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)\n\n        # Compute matching descriptor distance.\n        # print(mdesc0.size(), mdesc1.size())\n        scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)\n        scores0 = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc0)\n        scores1 = torch.einsum('bdn,bdm->bnm', mdesc1, mdesc1)\n        # #print('here1!!', scores.size())\n\n        # b k1 k2\n        scores = scores / self.config.descriptor_dim**.5\n        # print(scores.size(), mask01.size())\n        # mask01 = mask0[:, :, None] * mask1[:, None, :]\n        # scores = scores.masked_fill(mask01 == 0, float('-inf'))\n\n        # print(scores.size())\n        # Run the optimal transport.\n        # print(dim_m.size(), dim_m, flush=True)\n        scores = log_optimal_transport(\n            scores, self.bin_score,\n            iters=self.config.sinkhorn_iterations,\n            ms=dim_m, ns=dim_n)\n\n        # 
print(scores)\n\n        # Get the matches with score above \"match_threshold\".\n        return scores[:, :-1, :-1], scores0, scores1, mdesc0, mdesc1\n\n\ndef tensor_erode(bin_img, ksize=5):\n    # Pad the input first so erosion does not shrink the image\n    B, C, H, W = bin_img.shape\n    pad = (ksize - 1) // 2\n    bin_img = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)\n\n    # Unfold the image into ksize x ksize patches\n    patches = bin_img.unfold(dimension=2, size=ksize, step=1)\n    patches = patches.unfold(dimension=3, size=ksize, step=1)\n    # B x C x H x W x k x k\n\n    # Take the minimum value in each patch, i.e., 0 if the patch covers any background pixel\n    eroded, _ = patches.reshape(B, C, H, W, -1).min(dim=-1)\n    return eroded\n\nclass InbetweenerM(nn.Module):\n    \"\"\"Inbetweening model built on SuperGlueM vertex correspondence\n\n    Given the line-drawing graphs of two key frames, we:\n      1. Run SuperGlueM to obtain vertex correspondence scores\n      2. Convert the soft correspondences into per-vertex motion predictions\n      3. Propagate motion along graph edges via adjacency-masked attention\n      4. Predict per-vertex visibility with an MLP head\n\n    The correspondence module follows SuperGlue:\n\n    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew\n    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural\n    Networks. In CVPR, 2020. 
https://arxiv.org/abs/1911.11763\n\n    \"\"\"\n    # default_config = {\n    #     'descriptor_dim': 128,\n    #     'weights': 'indoor',\n    #     'keypoint_encoder': [32, 64, 128],\n    #     'GNN_layers': ['self', 'cross'] * 9,\n    #     'sinkhorn_iterations': 100,\n    #     'match_threshold': 0.2,\n    # }\n\n    def __init__(self, config=None):\n        super().__init__()\n        self.corr = SuperGlueM(config.corr_model)\n        self.mask_map = MLP([config.corr_model.descriptor_dim, 32, 1])\n        self.pos_weight = config.pos_weight\n        # self.motion_propagation = \n        \n        # assert self.config.weights in ['indoor', 'outdoor']\n        # path = Path(__file__).parent\n        # path = path / 'weights/superglue_{}.pth'.format(self.config.weights)\n        # self.load_state_dict(torch.load(path))\n        # print('Loaded SuperGlue model (\\\"{}\\\" weights)'.format(\n        #     self.config.weights))\n\n    def forward(self, data):\n        if 'gen_vid' in data:\n            dim_m, dim_n = data['ms'].float(), data['ns'].float()\n            mmax = dim_m.int().max()\n            nmax = dim_n.int().max()\n            # with torch.no_grad():\n            #     self.corr.eval()\n            score01, score0, score1, dec0, dec1 = self.corr(data)\n            kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2 \n          ##  print(kpts0.mean(), kpts1.mean(), flush=True)\n\n            motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0\n            motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1\n\n            motion_pred0 = torch.softmax(score0, dim=-1) @ motion_pred0\n            motion_pred1 = torch.softmax(score1, dim=-1) @ motion_pred1\n\n            max0, max1 = score01.max(2), score01.max(1)\n            indices0, indices1 = max0.indices, max1.indices\n            mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)\n            
mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)\n            zero = score01.new_tensor(0)\n\n            mscores0 = torch.where(mutual0, max0.values.exp(), zero)\n            mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)\n            # valid0 = mutual0 & (mscores0 > self.config.match_threshold)\n            # valid1 = mutual1 & valid0.gather(1, indices1)\n\n            valid0 = mscores0 > 0.2\n            valid1 = valid0.gather(1, indices1)\n            indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))\n            indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))\n\n            adj0, adj1 = data['adj_mat0'].float(), data['adj_mat1'].float()\n\n            motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0\n            motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1\n\n            motion_pred0 = torch.softmax(score0.masked_fill(adj0==0, float('-inf')), dim=-1) @ motion_pred0\n            motion_pred1 = torch.softmax(score1.masked_fill(adj1==0, float('-inf')), dim=-1) @ motion_pred1\n\n            vb0 = self.mask_map(dec0)[:, 0]\n            vb1 = self.mask_map(dec1)[:, 0]\n            vb0[:] = 1\n            vb1[:] = 1\n\n            # clone before binarizing so the thresholding does not mutate the input images in place\n            im0_erode = data['image0'].clone()\n            im1_erode = data['image1'].clone()\n            im0_erode[im0_erode > 0] = 1\n            im0_erode[im0_erode <= 0] = 0\n            im1_erode[im1_erode > 0] = 1\n            im1_erode[im1_erode <= 0] = 0\n\n            im0_erode = tensor_erode(im0_erode, 3)\n            im1_erode = tensor_erode(im1_erode, 3)\n\n            motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone()\n            kpt0t = kpts0 + motion_output0 * 1\n            kpt1t = kpts1 + motion_output1 * 1\n            if 'topo0' in data 
and 'topo1' in data:\n              ##  print(len(data['topo0'][0]), len(data['topo1']), flush=True)\n                for node, nbs in enumerate(data['topo0'][0]):\n                    for nb in nbs:\n                        # print(nb, flush=True)\n                        # print(kpt0t.size(), 'fDsafdsafds', flush=True)\n                        # if vb0[0, nb] and vb0[0, node] and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 3:\n                        #     vb0[0, nb] = -1\n                        #     vb0[0, node] = -1\n                        # print(node.size())\n                        center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.5).int()[0]\n                        # print(center.size(), flush=True)\n                        if vb0[0, nb] and vb0[0, node] and im1_erode[0,:, center[1], center[0]].mean() > 0.8:\n                            vb0[0, nb] = -1\n                            vb0[0, node] = -1\n                        # center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.25).int()[0]\n                        # # print(center.size(), flush=True)\n                        # if vb0[0, nb] and vb0[0, node] and center[1] < 720 and center[0] < 720 and im1_erode[0,:, center[1], center[0]].mean() > 0.8:\n                        #     vb0[0, nb] = -1\n                        #     vb0[0, node] = -1\n                        # center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.75).int()[0]\n                        # # print(center.size(), flush=True)\n                        # if vb0[0, nb] and vb0[0, node] and center[1] < 720 and center[0] < 720 and im1_erode[0,:, center[1], center[0]].mean() > 0.8:\n                        #     vb0[0, nb] = -1\n                        #     vb0[0, node] = -1\n                for node, nbs in enumerate(data['topo1'][0]):\n                    for nb in nbs:\n                        \n                        # if vb1[0, nb] and vb1[0, node] and ((kpt1t[0, node] - kpt1t[0, nb]) ** 
2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) >3:\n                        #     vb1[0, nb] = -1\n                        #     vb1[0, node] = -1\n                        center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.5).int()[0]\n                        if vb1[0, nb] and vb1[0, node] and im0_erode[0,:, center[1], center[0]].mean() > 0.95:\n                            vb1[0, nb] = -1\n                            vb1[0, node] = -1\n                        # center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.25).int()[0]\n                        # if vb1[0, nb] and vb1[0, node] and center[1] < 720 and center[0] < 720 and im0_erode[0,:, center[1], center[0]].mean() > 0.95:\n                        #     vb1[0, nb] = -1\n                        #     vb1[0, node] = -1\n                        # center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.75).int()[0]\n                        # if vb1[0, nb] and vb1[0, node] and center[1] < 720 and center[0] < 720 and im0_erode[0,:, center[1], center[0]].mean() > 0.95:\n                        #     vb1[0, nb] = -1\n                        #     vb1[0, node] = -1\n            # print(vb0.mean(), vb1.mean(), flush=True)\n            return {'r0': motion_output0, 'r1': motion_output1, 'vb0':(vb0 > 0).float(), 'vb1':(vb1 > 0).float(),}\n\n        dim_m, dim_n = data['ms'].float(), data['ns'].float()\n        mmax = dim_m.int().max()\n        nmax = dim_n.int().max()\n        # with torch.no_grad():\n        #     self.corr.eval()\n        score01, score0, score1, dec0, dec1 = self.corr(data)\n\n\n        kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2 \n\n\n        adj0, adj1 = data['adj_mat0'].float(), data['adj_mat1'].float()\n\n        motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0\n        motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1\n\n        # score0.mask_off()\n\n        motion_pred0 = 
torch.softmax(score0.masked_fill(adj0==0, float('-inf')), dim=-1) @ motion_pred0\n        motion_pred1 = torch.softmax(score1.masked_fill(adj1==0, float('-inf')), dim=-1) @ motion_pred1\n        \n        vb0 = self.mask_map(dec0)[:, 0]\n        vb1 = self.mask_map(dec1)[:, 0]\n\n        # motion0_pred, vb0 = pred0[:, :2].permute(0, 2, 1), pred0[:, 2:][:, 0]\n        # motion1_pred, vb1 = pred1[:, :2].permute(0, 2, 1), pred1[:, 2:][:, 0]\n        \n        # delta0, delta1 = motion_delta[:, :, :mmax].permute(0, 2, 1), motion_delta[:, :, mmax:].permute(0, 2, 1)\n        # motion_output0, motion_output1 =  motion0 + delta0, motion1 + delta1\n        motion_output0, motion_output1 =  motion_pred0.clone(), motion_pred1.clone()\n\n        # print(delta0.max(), delta1.max())\n        # vb0 = kpts0.new_ones(motion_pred0[:, :, 0].size()) + 1.0\n        # vb1 = kpts1.new_ones(motion_pred1[:, :, 0].size()) + 1.0\n\n        # vb0, vb1 = visibility[:, 0, :mmax], visibility[:, 0, mmax:]\n        # mask0, mask1 = mask[:, :mmax].bool(), mask[:, mmax:].bool()\n        # vb0_output = vb0.clone()\n        # vb1_output = vb1.clone()\n\n        # vb1_output[batch, corr01[corr01 != -1]] = 1.0\n\n        # motion_output0[valid0.bool()] = motion0[valid0.bool()]\n        # motion_output1[valid1.bool()] = motion1[valid1.bool()]\n\n        # vb0_output[vb0_output >= 0] = 1.0\n        # vb0_output[vb0_output < 0] = 0.0\n        # vb1_output[vb1_output >= 0] = 1.0\n        # vb1_output[vb1_output < 0 ] = 0.0\n\n        \n\n        kpt0t = kpts0 + motion_output0 / 2\n        kpt1t = kpts1 + motion_output1 / 2\n        # kpt1t[batch, corr01[corr01 != -1]] = kpt0t[corr01 != -1]\n        \n        \n        ##################################################\n        ##  Note Here the mini batch size is 1!!!!!!!!  
##\n        ##################################################\n\n        if 'topo0' in data and 'topo1' in data:\n            # print(len(data['topo0'][0]), len(data['topo1']), flush=True)\n            for node, nbs in enumerate(data['topo0'][0]):\n                for nb in nbs:\n                    if vb0[0, nb] and vb0[0, node] and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 5:\n                        vb0[0, nb] = -1\n                        vb0[0, node] = -1\n            for node, nbs in enumerate(data['topo1'][0]):\n                for nb in nbs:\n                    if vb1[0, nb] and vb1[0, node] and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) > 5:\n                        vb1[0, nb] = -1\n                        vb1[0, node] = -1\n\n        if 'motion0' in data and 'motion1' in data:\n            # valid_motion0 = motion_output0[mask0[:, :, None].repeat(1, 1, 2)]\n            # gt_valid_motion0 = data['motion0'][:, :mmax][mask0[:, :, None].repeat(1, 1, 2)].float()\n            # valid_motion1 = motion_output1[mask1[:, :, None].repeat(1, 1, 2)]\n            # gt_valid_motion1 = data['motion1'][:, :nmax][mask1[:, :, None].repeat(1, 1, 2)].float()\n\n            loss_motion = torch.nn.functional.l1_loss(motion_pred0, data['motion0'][:, :mmax]) +\\\n                torch.nn.functional.l1_loss(motion_pred1, data['motion1'][:, :nmax])\n            \n            # loss_valid0 = ((corr01 == -1) & (mask0 == 1))\n            # loss_valid1 = ((corr10 == -1) & (mask1 == 1))\n            EPE0 = ((motion_pred0 - data['motion0'][:, :mmax]) ** 2).sum(dim=-1).sqrt()\n            EPE1 = ((motion_pred1 - data['motion1'][:, :nmax]) ** 2).sum(dim=-1).sqrt()\n            # print(EPE0.size(), 'fdsafdsa')\n\n            EPE = (EPE0.mean() + EPE1.mean()) * 0.5\n            # print(len(EPE0[mask0]), len(EPE1[mask1]))\n            # print(vb0[:, :mmax][mask0], vb0[:, 
:mmax][mask0].shape, data['visibility0'][:, :mmax][mask0], data['visibility0'][:, :mmax][mask0].shape)\n            # print(.size())\n            # print((vb0[:, :mmax] > 0).float().sum(), data['visibility0'][:, :mmax].float().sum())\n            # pos_weight=vb0.new_tensor([0.5])\n            if 'visibility0' in data and 'visibility1' in data:\n                loss_visibility = torch.nn.functional.binary_cross_entropy_with_logits(vb0[:, :mmax].view(-1, 1), data['visibility0'][:, :mmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight])) + \\\n                torch.nn.functional.binary_cross_entropy_with_logits(vb1[:, :nmax].view(-1, 1), data['visibility1'][:, :nmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight]))\n            \n                VB_Acc = ((((vb0 > 0).float() == data['visibility0'][:, :mmax]).float().sum() + ((vb1 > 0).float() == data['visibility1'][:, :nmax]).float().sum()) * 1.0 / (mmax + nmax))\n            else:\n                loss_visibility = 0\n                VB_Acc = EPE.new_zeros([1])\n            loss = loss_motion + 10 * loss_visibility\n\n            loss_mean = torch.mean(loss)\n            # loss_mean = torch.reshape(loss_mean, (1, -1))\n            # print(loss_mean, flush=True)\n\n            # print(all_matches[:, :mmax].size(), indices0.size(), mask0.size(), flush=True)\n            #print((all_matches[0] == indices0[0]).sum())\n\n            # print(vb1.size(),corr01.size())\n\n            # kpt0t = torch.nn.functional.pad(kpts0 + motion_output0, (0, 0, 0, self.max_len - mmax, 0, 0), mode='constant', value=0)\n            # kpt1t = torch.nn.functional.pad(kpts1 + motion_output1, (0, 0, 0, self.max_len - nmax, 0, 0), mode='constant', value=0),\n\n            # kpt1t[:, :nmax][batch, corr01[corr01 != -1]] = kpt0t[:, :mmax][corr01 != -1]\n\n            b, _, _ = motion_pred0.size()\n            # batch = torch.arange(b)[:, None].repeat(1, mmax)[corr01 != -1].long()\n            # # print(kpts0[corr01 != 
-1].size(), corr01[corr01 != -1].size())\n            # matched_intermediate = (kpts0[(corr01 != -1)] + kpts1[batch, corr01[corr01 != -1].long(), :]) * 0.5\n            # motion0[corr01 != -1] = matched_intermediate - kpts0[corr01 != -1]\n            # motion1[batch, corr01[corr01 != -1].long(), :] = matched_intermediate - kpts1[batch, corr01[corr01 != -1].long(), :]\n\n            # vb0 = torch.nn.functional.pad(vb0, (0, self.max_len - mmax, 0, 0), mode='constant', value=0),\n            # vb1 = torch.nn.functional.pad(vb1, (0, self.max_len - nmax, 0, 0), mode='constant', value=0),\n\n            # self.max_len = 3050\n            # VB_Acc = ((((vb0 > 0.5).float() == data['visibility0'][:, :mmax]).float().sum() + ((vb1 > 0.5).float() == data['visibility1'][:, :nmax]).float().sum()) * 1.0 / (mmax + nmax))\n                \n            return {\n                # 'matches0': indices0, # use -1 for invalid match\n                # 'matches1': indices1[0], # use -1 for invalid match\n                # 'matching_scores0': mscores0,\n                # 'matching_scores1': mscores1[0],\n                # 'keypointst0': torch.nn.functional.pad(kpts0 + motion_output0, (0, 0, 0, self.max_len - mmax, 0, 0), mode='constant', value=0),\n                # 'keypointst1': torch.nn.functional.pad(kpts1 + motion_output1, (0, 0, 0, self.max_len - nmax, 0, 0), mode='constant', value=0),\n                # 'vb0': torch.nn.functional.pad(vb0, (0, self.max_len - mmax, 0, 0), mode='constant', value=0),\n                # 'vb1': torch.nn.functional.pad(vb1, (0, self.max_len - nmax, 0, 0), mode='constant', value=0),\n                'keypoints0t': kpt0t,\n                'keypoints1t': kpt1t,\n                'vb0': (vb0 > 0).float(),\n                'vb1': (vb1 > 0).float(),\n                'loss': loss_mean,\n                'EPE': EPE,\n                'Visibility Acc': VB_Acc\n                # ((((vb0[mask0] > 0).float() == data['visibility0'][:, :mmax][mask0]).float().sum() + 
((vb1[mask1] > 0).float() == data['visibility1'][:, :nmax][mask1]).float().sum()) * 1.0 / (mask0.float().sum() + mask1.float().sum())),\n                # 'skip_train': [False],\n                # 'accuracy': (((all_matches[:, :mmax] == indices0) & mask0.bool()).sum() / mask0.sum()).item(),\n                # 'valid_accuracy': (((all_matches[:, :mmax] == indices0) & (all_matches[:, :mmax] != -1) & mask0.bool()).float().sum() / ((all_matches[:, :mmax] != -1) & mask0.bool()).float().sum()).item(),\n            }\n        else:\n            return {\n                'loss': -1,\n                'skip_train': True,\n                'keypointst0': kpts0 + motion_output0,\n                'keypointst1': kpts1 + motion_output1,\n                'vb0': vb0,\n                'vb1': vb1,\n                # 'accuracy': -1,\n                # 'area_accuracy': -1,\n                # 'valid_accuracy': -1,\n            }\n\n\nif __name__ == '__main__':\n\n    args = argparse.Namespace()\n    args.batch_size = 2\n    args.gap = 5\n    args.type = 'train'\n    args.model = None\n    args.action = None\n    ss = Refiner()\n\n\n    loader = fetch_dataloader(args)\n    # #print(len(loader))\n    for data in loader:\n        # p1, p2, s1, s2, mi = data\n        dict1 = data\n\n        kp1 = dict1['keypoints0']\n        kp2 = dict1['keypoints1']\n        p1 = dict1['image0']\n        p2 = dict1['image1']  \n\n        # #print(s1)\n        # #print(s1.type)\n        mi = dict1['m01']\n        fname = dict1['file_name'] \n        print(dict1['keypoints0'].size(), dict1['keypoints1'].size(), dict1['m01'].size(), dict1['motion0'].size(), dict1['mask0'].size())\n        # print(kp1.shape, p1.shape, mi.shape)  \n        # #print(mi.size())  \n        # #print(mi)\n        # break\n\n        a = ss(data)\n        print(dict1['file_name'])\n        print(a['loss'])\n        print(a['EPE'], a['Visibility Acc'],flush=True)\n        a['loss'].backward()"
  },
  {
    "path": "models/inbetweener_with_mask_with_spec.py",
    "content": "from copy import deepcopy\nfrom pathlib import Path\nimport torch\nfrom torch import nn\n# from seg_desc import seg_descriptor\nimport argparse\nimport numpy as np\nimport torch.nn.functional as F\nfrom sknetwork.embedding import Spectral\n\ndef MLP(channels: list, do_bn=True):\n    \"\"\" Multi-layer perceptron \"\"\"\n    n = len(channels)\n    layers = []\n    for i in range(1, n):\n        layers.append(\n            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))\n        if i < (n-1):\n            if do_bn:\n                # layers.append(nn.BatchNorm1d(channels[i]))\n                layers.append(nn.InstanceNorm1d(channels[i]))\n            layers.append(nn.ReLU())\n    return nn.Sequential(*layers)\n\n\ndef normalize_keypoints(kpts, image_shape):\n    \"\"\" Normalize keypoints locations based on image image_shape\"\"\"\n    _, _, height, width = image_shape\n    one = kpts.new_tensor(1)\n    size = torch.stack([one*width, one*height])[None]\n    center = size / 2\n    scaling = size.max(1, keepdim=True).values * 0.7\n    return (kpts - center[:, None, :]) / scaling[:, None, :]\n\nclass ThreeLayerEncoder(nn.Module):\n    \"\"\" Joint encoding of visual appearance and location using MLPs\"\"\"\n    def __init__(self, enc_dim):\n        super().__init__()\n        # input must be 3 channel (r, g, b)\n        self.layer1 = nn.Conv2d(3, enc_dim//4, 7, padding=3)\n        self.non_linear1 = nn.ReLU()\n        self.layer2 = nn.Conv2d(enc_dim//4, enc_dim//2, 3, padding=1)\n        self.non_linear2 = nn.ReLU()\n        self.layer3 = nn.Conv2d(enc_dim//2, enc_dim, 3, padding=1)\n\n        self.norm1 = nn.InstanceNorm2d(enc_dim//4)\n        self.norm2 = nn.InstanceNorm2d(enc_dim//2)\n        self.norm3 = nn.InstanceNorm2d(enc_dim)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                
nn.init.constant_(m.bias, 0.0)\n\n    def forward(self, img):\n        x = self.non_linear1(self.norm1(self.layer1(img)))\n        x = self.non_linear2(self.norm2(self.layer2(x)))\n        x = self.norm3(self.layer3(x))\n        # x = self.non_linear1(self.layer1(img))\n        # x = self.non_linear2(self.layer2(x))\n        # x = self.layer3(x)\n        return x\n\n\nclass VertexDescriptor(nn.Module):\n    \"\"\" Joint encoding of visual appearance and location using MLPs\"\"\"\n    def __init__(self, enc_dim):\n        super().__init__()\n        self.encoder = ThreeLayerEncoder(enc_dim)\n        # self.super_pixel_pooling = \n        # use scatter\n        # nn.init.constant_(self.encoder[-1].bias, 0.0)\n\n    def forward(self, img, vtx):\n        x = self.encoder(img)\n        n, c, h, w = x.size()\n        assert((h, w) == img.size()[2:4])\n        return x[:, :, torch.round(vtx[0, :, 1]).long(), torch.round(vtx[0, :, 0]).long()]\n        # return super_pixel_pooling(x.view(n, c, -1), seg.view(-1).long(), reduce='mean')\n        # here return size is [1]xCx|Seg|\n\n\nclass KeypointEncoder(nn.Module):\n    \"\"\" Joint encoding of visual appearance and location using MLPs\"\"\"\n    def __init__(self, feature_dim, layers):\n        super().__init__()\n        self.encoder = MLP([2] + layers + [feature_dim])\n        # for m in self.encoder.modules():\n        #     if isinstance(m, nn.Conv2d):\n        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n        #         nn.init.constant_(m.bias, 0.0)\n        nn.init.constant_(self.encoder[-1].bias, 0.0)\n\n    def forward(self, kpts):\n        inputs = kpts.transpose(1, 2)\n\n        x = self.encoder(inputs)\n        return x\n\nclass TopoEncoder(nn.Module):\n    \"\"\" Joint encoding of visual appearance and location using MLPs\"\"\"\n    def __init__(self, feature_dim, layers):\n        super().__init__()\n        self.encoder = MLP([64] + layers + [feature_dim])\n        # for 
m in self.encoder.modules():\n        #     if isinstance(m, nn.Conv2d):\n        #         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n        #         nn.init.constant_(m.bias, 0.0)\n        nn.init.constant_(self.encoder[-1].bias, 0.0)\n\n    def forward(self, kpts):\n        inputs = kpts.transpose(1, 2)\n        x = self.encoder(inputs)\n        return x\n\n\ndef attention(query, key, value, mask=None):\n    dim = query.shape[1]\n    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5\n    if mask is not None:\n        scores = scores.masked_fill(mask==0, float('-inf'))\n\n    prob = torch.nn.functional.softmax(scores, dim=-1)\n\n    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob\n\n\nclass MultiHeadedAttention(nn.Module):\n    \"\"\" Multi-head attention to increase model expressivity \"\"\"\n    def __init__(self, num_heads: int, d_model: int):\n        super().__init__()\n        assert d_model % num_heads == 0\n        self.dim = d_model // num_heads\n        self.num_heads = num_heads\n        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)\n        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])\n\n    def forward(self, query, key, value, mask=None):\n        batch_dim = query.size(0)\n        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)\n                             for l, x in zip(self.proj, (query, key, value))]\n        x, prob = attention(query, key, value, mask)\n        # self.prob.append(prob)\n        return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))\n\n\nclass AttentionalPropagation(nn.Module):\n    def __init__(self, feature_dim: int, num_heads: int):\n        super().__init__()\n        self.attn = MultiHeadedAttention(num_heads, feature_dim)\n        self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])\n        nn.init.constant_(self.mlp[-1].bias, 0.0)\n\n    def forward(self, x, source, mask=None):\n
       message = self.attn(x, source, source, mask)\n        return self.mlp(torch.cat([x, message], dim=1))\n\n\nclass AttentionalGNN(nn.Module):\n    def __init__(self, feature_dim: int, layer_names: list):\n        super().__init__()\n        self.layers = nn.ModuleList([\n            AttentionalPropagation(feature_dim, 4)\n            for _ in range(len(layer_names))])\n        self.names = layer_names\n\n    def forward(self, desc0, desc1, mask00=None, mask11=None, mask01=None, mask10=None):\n        for layer, name in zip(self.layers, self.names):\n            layer.attn.prob = []\n            if name == 'cross':\n                src0, src1 = desc1, desc0\n                mask0, mask1 = mask01[:, None], mask10[:, None] \n            else:  # if name == 'self':\n                src0, src1 = desc0, desc1\n                mask0, mask1 = mask00[:, None], mask11[:, None]\n\n            delta0, delta1 = layer(desc0, src0, mask0), layer(desc1, src1, mask1)\n            desc0, desc1 = (desc0 + delta0), (desc1 + delta1)\n        return desc0, desc1\n\n\ndef log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):\n    \"\"\" Perform Sinkhorn Normalization in Log-space for stability\"\"\"\n    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)\n    for _ in range(iters):\n        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)\n        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)\n    return Z + u.unsqueeze(2) + v.unsqueeze(1)\n\n\ndef log_optimal_transport(scores, alpha, iters: int, ms=None, ns=None):\n    \"\"\" Perform Differentiable Optimal Transport in Log-space for stability\"\"\"\n    b, m, n = scores.shape\n    one = scores.new_tensor(1)\n    if ms is  None or ns is  None:\n        ms, ns = (m*one).to(scores), (n*one).to(scores)\n    # else:\n    #     ms, ns = ms.to(scores)[:, None], ns.to(scores)[:, None]\n    # here m,n should be parameters not shape\n\n    # ms, ns: (b, )\n    bins0 = alpha.expand(b, m, 1)\n    bins1 = 
alpha.expand(b, 1, n)\n    alpha = alpha.expand(b, 1, 1)\n\n    # pad additional scores for unmatched points (to -1)\n    # alpha is the learned threshold\n    couplings = torch.cat([torch.cat([scores, bins0], -1),\n                           torch.cat([bins1, alpha], -1)], 1)\n\n    norm = - (ms + ns).log() # (b, )\n\n    if ms.size()[0] > 0:\n        norm = norm[:, None]\n        log_mu = torch.cat([norm.expand(b, m), ns.log()[:, None] + norm], dim=-1) # (b, m + 1)\n        log_nu = torch.cat([norm.expand(b, n), ms.log()[:, None] + norm], dim=-1) # (b, n + 1)\n    else:\n        log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) # (m + 1)\n        log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])\n        log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)\n\n    Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)\n\n    if ms.size()[0] > 1:\n        norm = norm[:, :, None]\n    Z = Z - norm  # multiply probabilities by M+N\n    return Z\n\n\ndef arange_like(x, dim: int):\n    return x.new_ones(x.shape[dim]).cumsum(0) - 1  # traceable in 1.1\n\n\nclass SuperGlueT(nn.Module):\n    \"\"\"SuperGlue feature matching middle-end\n\n    Given two sets of keypoints and locations, we determine the\n    correspondences by:\n      1. Keypoint Encoding (normalization + visual feature and location fusion)\n      2. Graph Neural Network with multiple self and cross-attention layers\n      3. Final projection layer\n      4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)\n      5. Thresholding matrix based on mutual exclusivity and a match_threshold\n\n    The correspondence ids use -1 to indicate non-matching points.\n\n    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew\n    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural\n    Networks. In CVPR, 2020. 
https://arxiv.org/abs/1911.11763\n\n    \"\"\"\n\n    def __init__(self, config=None):\n        super().__init__()\n\n        default_config = argparse.Namespace()\n        default_config.descriptor_dim = 128\n        default_config.keypoint_encoder = [32, 64, 128]\n        default_config.GNN_layers = ['self', 'cross'] * 9\n        default_config.sinkhorn_iterations = 100\n        default_config.match_threshold = 0.2\n        self.spectral = Spectral(64, normalized=False)\n\n        if config is None:\n            self.config = default_config\n        else:\n            self.config = config\n            self.config.GNN_layers = ['self', 'cross'] * self.config.GNN_layer_num\n\n        self.kenc = KeypointEncoder(\n            self.config.descriptor_dim, self.config.keypoint_encoder)\n\n        self.tenc = TopoEncoder(\n            self.config.descriptor_dim, [96])\n\n        self.gnn = AttentionalGNN(\n            self.config.descriptor_dim, self.config.GNN_layers)\n\n        self.final_proj = nn.Conv1d(\n            self.config.descriptor_dim, self.config.descriptor_dim,\n            kernel_size=1, bias=True)\n\n        bin_score = torch.nn.Parameter(torch.tensor(1.))\n        self.register_parameter('bin_score', bin_score)\n        self.vertex_desc = VertexDescriptor(self.config.descriptor_dim)\n\n    def forward(self, data):\n        kpts0, kpts1 = data['keypoints0'].float(), data['keypoints1'].float()\n\n        ori_mask0, ori_mask1 = data['mask0'].float(), data['mask1'].float()\n        dim_m, dim_n = data['ms'].float(), data['ns'].float()\n\n        # spectral embedding of the adjacency matrices:\n        # computing the spectral embeddings online is too slow during training\n    
    # so the spectral embedding is moved to the dataset pipeline\n        # where it can be precomputed during data preparation by multi-process CPU workers\n        spec0, spec1 = data['spec0'], data['spec1']\n        # spec0, spec1 = np.abs(self.spectral.fit_transform(adj_mat0[0].cpu().numpy())), np.abs(self.spectral.fit_transform(adj_mat1[0].cpu().numpy()))\n\n        mmax = dim_m.int().max()\n        nmax = dim_n.int().max()\n\n        mask0 = ori_mask0[:, :mmax]\n        mask1 = ori_mask1[:, :nmax]\n\n        kpts0 = kpts0[:, :mmax]\n        kpts1 = kpts1[:, :nmax]\n\n        # image context embedding\n        desc0, desc1 = self.vertex_desc(data['image0'], kpts0.float()), self.vertex_desc(data['image1'], kpts1.float())\n\n        # add topological embedding\n        desc0 = desc0 + self.tenc(desc0.new_tensor(spec0))\n        desc1 = desc1 + self.tenc(desc1.new_tensor(spec1))\n\n        # these masks were prepared for synchronized training with batch size > 1,\n        # but that did not work well, so the current framework still uses gradient accumulation\n        mask00 = torch.ones_like(mask0)[:, :, None] * mask0[:, None, :]\n        mask11 = torch.ones_like(mask1)[:, :, None] * mask1[:, None, :]\n        mask01 = torch.ones_like(mask0)[:, :, None] * mask1[:, None, :]\n        mask10 = torch.ones_like(mask1)[:, :, None] * mask0[:, None, :]\n\n        if kpts0.shape[1] < 2 or kpts1.shape[1] < 2:  # no keypoints\n            shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]\n            return {\n                'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],\n                # 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],\n                'matching_scores0': kpts0.new_zeros(shape0)[0],\n                # 'matching_scores1': kpts1.new_zeros(shape1)[0],\n                'skip_train': True\n            }\n\n        all_matches = data['all_matches'] if 'all_matches' in data else None  # shape 
= (1, K1)\n\n        # Keypoint normalization (for the positional embedding).\n        kpts0 = normalize_keypoints(kpts0, data['image0'].shape)\n        kpts1 = normalize_keypoints(kpts1, data['image1'].shape)\n\n        # Keypoint MLP encoder.\n        pos0 = self.kenc(kpts0)\n        pos1 = self.kenc(kpts1)\n\n        desc0 = desc0 + pos0\n        desc1 = desc1 + pos1\n\n        # Multi-layer Transformer network.\n        desc0, desc1 = self.gnn(desc0, desc1, mask00, mask11, mask01, mask10)\n\n        # Final MLP projection.\n        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)\n\n        # Compute matching descriptor similarities.\n        scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)   # cross: (b, k1, k2)\n        scores0 = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc0)  # self-similarity\n        scores1 = torch.einsum('bdn,bdm->bnm', mdesc1, mdesc1)\n\n        scores = scores / self.config.descriptor_dim**.5\n\n        # Run the optimal transport.\n        scores = log_optimal_transport(\n            scores, self.bin_score,\n            iters=self.config.sinkhorn_iterations,\n            ms=dim_m, ns=dim_n)\n\n        # Return the dense matching scores (dustbins stripped), the\n        # self-similarity scores and the projected descriptors.\n        return scores[:, :-1, :-1], scores0, scores1, mdesc0, mdesc1\n\n\ndef tensor_erode(bin_img, ksize=5):\n    # binary erosion via min-pooling over ksize x ksize patches\n    B, C, H, W = bin_img.shape\n    pad = (ksize - 1) // 2\n    bin_img = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)\n\n    patches = bin_img.unfold(dimension=2, size=ksize, step=1)\n    patches = patches.unfold(dimension=3, size=ksize, step=1)\n    # B x C x H x W x k x k\n\n    eroded, _ = patches.reshape(B, C, H, W, -1).min(dim=-1)\n    return eroded\n\n\nclass InbetweenerTM(nn.Module):\n    \"\"\"AnimeInbet\n    The whole pipeline includes\n    1. vertex correspondence (vertex embedding + correspondence transformer)\n    2. repositioning propagation\n    3. 
visibility mask\n\n    The vertex correspondence code is modified from SuperGlue:\n    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew\n    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural\n    Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763\n    \"\"\"\n\n    def __init__(self, config=None):\n        super().__init__()\n        # vertex correspondence\n        self.corr = SuperGlueT(config.corr_model)\n        self.mask_map = MLP([config.corr_model.descriptor_dim, 32, 1])\n        self.pos_weight = config.pos_weight\n\n    def forward(self, data):\n        # video-generation mode\n        if 'gen_vid' in data:\n            dim_m, dim_n = data['ms'].float(), data['ns'].float()\n            mmax = dim_m.int().max()\n            nmax = dim_n.int().max()\n            with torch.no_grad():\n                self.corr.eval()\n                score01, score0, score1, dec0, dec1 = self.corr(data)\n                kpts0, kpts1 = data['keypoints0'][:, :mmax].float(), data['keypoints1'][:, :nmax].float()  # BM2, BN2\n\n                motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0\n                motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1\n\n            motion_pred0 = torch.softmax(score0, dim=-1) @ motion_pred0\n            motion_pred1 = torch.softmax(score1, dim=-1) @ motion_pred1\n\n            self.mask_map.eval()\n            vb0 = self.mask_map(dec0)[:, 0]\n            vb1 = self.mask_map(dec1)[:, 0]\n\n            motion_output0, motion_output1 = motion_pred0.clone(), motion_pred1.clone()\n            kpt0t = kpts0 + motion_output0\n            kpt1t = kpts1 + motion_output1\n            # hide vertices whose incident edges stretch too much after warping\n            if 'topo0' in data and 'topo1' in data:\n                for node, nbs in enumerate(data['topo0'][0]):\n                    for nb in nbs:\n                        if vb0[0, nb] and vb0[0, 
node] and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 3:\n                            vb0[0, nb] = 0\n                            vb0[0, node] = 0\n                for node, nbs in enumerate(data['topo1'][0]):\n                    for nb in nbs:\n                        if vb1[0, nb] and vb1[0, node] and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) > 3:\n                            vb1[0, nb] = 0\n                            vb1[0, node] = 0\n            return {'r0': motion_output0, 'r1': motion_output1, 'vb0':vb0, 'vb1':vb1,}\n\n        # in the normal train/test mode\n        dim_m, dim_n = data['ms'].float(), data['ns'].float()\n        mmax = dim_m.int().max()\n        nmax = dim_n.int().max()\n        # with torch.no_grad():\n        #     self.corr.eval()\n        score01, score0, score1, dec0, dec1 = self.corr(data)\n\n\n        kpts0, kpts1 = data['keypoints0'][:,:mmax].float(), data['keypoints1'][:,:nmax].float() # BM2, BN2 \n\n        motion_pred0 = torch.softmax(score01, dim=-1) @ kpts1 - kpts0\n        motion_pred1 = torch.softmax(score01.transpose(1, 2), dim=-1) @ kpts0 - kpts1\n\n        motion_pred0 = torch.softmax(score0, dim=-1) @ motion_pred0\n        motion_pred1 = torch.softmax(score1, dim=-1) @ motion_pred1\n\n\n        max0, max1 = score01.max(2), score01.max(1)\n        indices0, indices1 = max0.indices, max1.indices\n        mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)\n        mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)\n        zero = score01.new_tensor(0)\n\n        mscores0 = torch.where(mutual0, max0.values.exp(), zero)\n        mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)\n        # valid0 = mutual0 & (mscores0 > self.config.match_threshold)\n        # valid1 = mutual1 & valid0.gather(1, indices1)\n        \n        valid0 = mscores0 > 0.2\n   
     valid1 = valid0.gather(1, indices1)\n        indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))\n        indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))\n\n\n\n\n        # motion_pred1[0][indices1[0]==-1] = 0\n\n\n\n\n        \n        vb0 = self.mask_map(dec0)[:, 0]\n        vb1 = self.mask_map(dec1)[:, 0]\n        motion_output0, motion_output1 =  motion_pred0.clone(), motion_pred1.clone()\n\n        if not self.training:\n            motion_pred0[0][indices0[0]!=-1] = kpts1[0][indices0[0][indices0[0]!=-1]] - kpts0[0][indices0[0]!=-1]\n        # # motion_pred0[0][indices0[0]==-1] = 0\n            motion_pred1[0][indices1[0]!=-1] = kpts0[0][indices1[0][indices1[0]!=-1]] - kpts1[0][indices1[0]!=-1]\n            vb0[:] = vb0[:] + 0.7\n            vb1[:] = vb1[:] + 0.7\n\n            # motion0_pred, vb0 = pred0[:, :2].permute(0, 2, 1), pred0[:, 2:][:, 0]\n            # motion1_pred, vb1 = pred1[:, :2].permute(0, 2, 1), pred1[:, 2:][:, 0]\n            \n            # delta0, delta1 = motion_delta[:, :, :mmax].permute(0, 2, 1), motion_delta[:, :, mmax:].permute(0, 2, 1)\n            # motion_output0, motion_output1 =  motion0 + delta0, motion1 + delta1\n            motion_output0, motion_output1 =  motion_pred0.clone(), motion_pred1.clone()\n\n            im0_erode =  data['image0']\n            im1_erode =  data['image1']\n            im0_erode[im0_erode > 0] = 1\n            im0_erode[im0_erode <= 0] = 0\n            im1_erode[im1_erode > 0] = 1\n            im1_erode[im1_erode <= 0] = 0\n            \n            im0_erode = tensor_erode(im0_erode, 7)\n            im1_erode = tensor_erode(im1_erode, 7)\n\n\n            \n            kpt0t = kpts0 + motion_output0 / 2\n            kpt1t = kpts1 + motion_output1 / 2\n        \n        \n            ##################################################\n            ##  Note Here the mini batch size is 1!!!!!!!!  
##\n            ##################################################\n\n            if 'topo0' in data and 'topo1' in data:\n                # print(len(data['topo0'][0]), len(data['topo1']), flush=True)\n                for node, nbs in enumerate(data['topo0'][0]):\n                    for nb in nbs:\n                        if vb0[0, nb] > 0 and vb0[0, node] > 0 and ((kpt0t[0, node] - kpt0t[0, nb]) ** 2).sum() / (((kpts0[0, node] - kpts0[0, nb]) ** 2).sum() + 1e-7) > 5:\n                            vb0[0, nb] = -1\n                            vb0[0, node] = -1\n                for node, nbs in enumerate(data['topo1'][0]):\n                    for nb in nbs:\n                        if vb1[0, nb] > 0 and vb1[0, node] > 0 and ((kpt1t[0, node] - kpt1t[0, nb]) ** 2).sum() / (((kpts1[0, node] - kpts1[0, nb]) ** 2).sum() + 1e-7) > 5:\n                            vb1[0, nb] = -1\n                            vb1[0, node] = -1\n            \n            \n            kpt0t = kpts0 + motion_output0 * 1\n            kpt1t = kpts1 + motion_output1 * 1\n            if 'topo0' in data and 'topo1' in data:\n                ##  print(len(data['topo0'][0]), len(data['topo1']), flush=True)\n                for node, nbs in enumerate(data['topo0'][0]):\n                    for nb in nbs:\n\n                        center = ((kpt0t[0, node] + kpt0t[0, nb]) * 0.5).int()[0]\n                        if center[0] >= 720 or center[1] >= 720:\n                            continue\n\n                        if vb0[0, nb] > 0 and vb0[0, node] > 0 and im1_erode[0,:, center[1], center[0]].mean() > 0.8:\n                            vb0[0, nb] = -1\n                            vb0[0, node] = -1\n\n                for node, nbs in enumerate(data['topo1'][0]):\n                    for nb in nbs:\n                        \n                        center = ((kpt1t[0, node] + kpt1t[0, nb]) * 0.5).int()[0]\n                        if vb1[0, nb] > 0  and vb1[0, node] > 0 and im0_erode[0,:, center[1], 
center[0]].mean() > 0.8:\n                            vb1[0, nb] = -1\n                            vb1[0, node] = -1\n\n        \n\n        kpt0t = kpts0 + motion_output0 / 2\n        kpt1t = kpts1 + motion_output1 / 2\n\n        \n\n        if 'motion0' in data and 'motion1' in data:\n            loss_motion = torch.nn.functional.l1_loss(motion_pred0, data['motion0'][:, :mmax]) +\\\n                torch.nn.functional.l1_loss(motion_pred1, data['motion1'][:, :nmax])\n            \n\n            EPE0 = ((motion_pred0 - data['motion0'][:, :mmax]) ** 2).sum(dim=-1).sqrt()\n            EPE1 = ((motion_pred1 - data['motion1'][:, :nmax]) ** 2).sum(dim=-1).sqrt()\n            # print(EPE0.size(), 'fdsafdsa')\n\n            EPE = (EPE0.mean() + EPE1.mean()) * 0.5\n\n            if 'visibility0' in data and 'visibility1' in data:\n                loss_visibility = torch.nn.functional.binary_cross_entropy_with_logits(vb0[:, :mmax].view(-1, 1), data['visibility0'][:, :mmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight])) + \\\n                torch.nn.functional.binary_cross_entropy_with_logits(vb1[:, :nmax].view(-1, 1), data['visibility1'][:, :nmax].view(-1, 1), pos_weight=vb0.new_tensor([self.pos_weight]))\n            \n                VB_Acc = ((((vb0 > 0).float() == data['visibility0'][:, :mmax]).float().sum() + ((vb1 > 0).float() == data['visibility1'][:, :nmax]).float().sum()) * 1.0 / (mmax + nmax))\n            else:\n                loss_visibility = 0\n                VB_Acc = EPE.new_zeros([1])\n            loss = loss_motion + 10 * loss_visibility\n\n            loss_mean = torch.mean(loss)\n\n            b, _, _ = motion_pred0.size()\n\n            return {\n                'keypoints0t': kpt0t,\n                'keypoints1t': kpt1t,\n                'vb0': (vb0 > 0).float(),\n                'vb1': (vb1 > 0).float(),\n                'r0': motion_output0,\n                'r1': motion_output1,\n                'loss': loss_mean,\n                
'EPE': EPE,\n                'Visibility Acc': VB_Acc\n            }\n        else:\n            return {\n                'loss': -1,\n                'skip_train': True,\n                'keypointst0': kpts0 + motion_output0,\n                'keypointst1': kpts1 + motion_output1,\n                'vb0': vb0,\n                'vb1': vb1,\n            }\n\n\nif __name__ == '__main__':\n\n    args = argparse.Namespace()\n    args.batch_size = 2\n    args.gap = 5\n    args.type = 'train'\n    args.model = None\n    args.action = None\n    ss = Refiner()\n\n    loader = fetch_dataloader(args)\n    for data in loader:\n        dict1 = data\n\n        kp1 = dict1['keypoints0']\n        kp2 = dict1['keypoints1']\n        p1 = dict1['image0']\n        p2 = dict1['image1']\n\n        mi = dict1['m01']\n        fname = dict1['file_name']\n        print(dict1['keypoints0'].size(), dict1['keypoints1'].size(), dict1['m01'].size(), dict1['motion0'].size(), dict1['mask0'].size())\n\n        a = ss(data)\n        print(dict1['file_name'])\n        print(a['loss'])\n        print(a['EPE'], a['Visibility Acc'], flush=True)\n        a['loss'].backward()"
  },
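The `log_sinkhorn_iterations` routine above alternates row and column normalizations of `exp(Z)` entirely in log-space. A minimal unbatched numpy sketch of the same normalization (the names `logsumexp` and `log_sinkhorn` are hypothetical stand-ins for `torch.logsumexp` and the file's batched routine): after enough iterations, `exp` of the result has row and column sums equal to the target marginals.

```python
import numpy as np

def logsumexp(a, axis):
    # numerically stable log(sum(exp(a))) along one axis
    m = a.max(axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True)), axis=axis)

def log_sinkhorn(Z, log_mu, log_nu, iters=100):
    # alternate row/column normalization of exp(Z) in log-space,
    # mirroring log_sinkhorn_iterations (single instance, no batch dim)
    u, v = np.zeros_like(log_mu), np.zeros_like(log_nu)
    for _ in range(iters):
        u = log_mu - logsumexp(Z + v[None, :], axis=1)
        v = log_nu - logsumexp(Z + u[:, None], axis=0)
    return Z + u[:, None] + v[None, :]

rng = np.random.default_rng(0)
Z = rng.normal(size=(4, 3))           # raw matching scores
log_mu = np.full(4, np.log(1 / 4))    # uniform row marginal
log_nu = np.full(3, np.log(1 / 3))    # uniform column marginal
P = np.exp(log_sinkhorn(Z, log_mu, log_nu))
# rows of P now sum to 1/4 and columns to 1/3
```

`log_optimal_transport` wraps this loop after appending the learned dustbin row/column, so unmatched vertices can send their mass to the dustbins instead of being forced into a match.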
  {
    "path": "requirement.txt",
    "content": "opencv-python\npyyaml==5.4.1\nscikit-network\ntqdm\nmatplotlib\neasydict\ngdown"
  },
  {
    "path": "srun.sh",
    "content": "#!/bin/sh\ncurrenttime=`date \"+%Y%m%d%H%M%S\"`\nif [ ! -d log ]; then\n    mkdir log\nfi\n\necho \"[Usage] ./srun.sh config_path [train|eval] partition gpunum\"\n# check config exists\nif [ ! -e $1 ]\nthen\n    echo \"[ERROR] configuration file: $1 does not exists!\"\n    exit\nfi\n\n\nif [ ! -d ${expname} ]; then\n    mkdir ${expname}\nfi\n\necho \"[INFO] saving results to, or loading files from: \"$expname\n\nif [ \"$3\" == \"\" ]; then\n    echo \"[ERROR] enter partition name\"\n    exit\nfi\npartition_name=$3\necho \"[INFO] partition name: $partition_name\"\n\nif [ \"$4\" == \"\" ]; then\n    echo \"[ERROR] enter gpu num\"\n    exit\nfi\ngpunum=$4\ngpunum=$(($gpunum<8?$gpunum:8))\necho \"[INFO] GPU num: $gpunum\"\n((ntask=$gpunum*3))\n\n\nTOOLS=\"srun  --partition=$partition_name -x SG-IDC2-10-51-5-44 --cpus-per-task=16 --gres=gpu:$gpunum -N 1 --mem-per-gpu=32G  --job-name=${config_suffix}\"\nPYTHONCMD=\"python -u main.py --config $1\"\n\nif [ $2 == \"train\" ];\nthen\n    $TOOLS $PYTHONCMD \\\n    --train \nelif [ $2 == \"eval\" ];\nthen\n    $TOOLS $PYTHONCMD \\\n    --eval \nelif [ $2 == \"gen\" ];\nthen\n    $TOOLS $PYTHONCMD \\\n    --gen \nfi\n# elif [ $2 == \"visgt\" ];\n# then\n#     $TOOLS $PYTHONCMD \\\n#     --visgt \n# elif [ $2 == \"anl\" ];\n# then\n#     $TOOLS $PYTHONCMD \\\n#     --anl \n# elif [ $2 == \"sample\" ];\n# then\n#     $TOOLS $PYTHONCMD \\\n#     --sample \n# fi\n\n"
  },
  {
    "path": "utils/chamfer_distance.py",
    "content": "import os\nimport numpy as np\nfrom time import time\nimport cv2\nimport pdb\nimport scipy\nimport scipy.ndimage\nimport torch\nimport torchmetrics\n\nblack_threshold = 255.0 * 0.99\n\n\ndef batch_edt(img, block=1024):\n    expand = False\n    bs,h,w = img.shape\n    diam2 = h**2 + w**2\n    odtype = img.dtype\n    grid = (img.nelement()+block-1) // block\n\n    # cupy implementation\n\n    # default to scipy cpu implementation\n\n    sums = img.sum(dim=(1,2))\n    ans = torch.tensor(np.stack([\n        scipy.ndimage.morphology.distance_transform_edt(i)\n        if s!=0 else  # change scipy behavior for empty image\n        np.ones_like(i) * np.sqrt(diam2)\n        for i,s in zip(1-img, sums)\n    ]), dtype=odtype)\n\n    if expand:\n        ans = ans.unsqueeze(1)\n    return ans\n\n\n############### DERIVED DISTANCES ###############\n\n# input: (bs,h,w) or (bs,1,h,w)\n# returns: (bs,)\n# normalized s.t. metric is same across proportional image scales\n\n# average of two asymmetric distances\n# normalized by diameter and area\ndef batch_chamfer_distance(gt, pred, block=1024, return_more=False):\n    t = batch_chamfer_distance_t(gt, pred, block=block)\n    p = batch_chamfer_distance_p(gt, pred, block=block)\n    cd = (t + p) / 2\n    return cd\ndef batch_chamfer_distance_t(gt, pred, block=1024, return_more=False):\n    #pdb.set_trace()\n    assert gt.device==pred.device and gt.shape==pred.shape\n    bs,h,w = gt.shape[0], gt.shape[-2], gt.shape[-1]\n    dpred = batch_edt(pred, block=block)\n    cd = (gt*dpred).float().mean((-2,-1)) / np.sqrt(h**2+w**2)\n    if len(cd.shape)==2:\n        assert cd.shape[1]==1\n        cd = cd.squeeze(1)\n    return cd\ndef batch_chamfer_distance_p(gt, pred, block=1024, return_more=False):\n    assert gt.device==pred.device and gt.shape==pred.shape\n    bs,h,w = gt.shape[0], gt.shape[-2], gt.shape[-1]\n    dgt = batch_edt(gt, block=block)\n    cd = (pred*dgt).float().mean((-2,-1)) / np.sqrt(h**2+w**2)\n    if 
len(cd.shape)==2:\n        assert cd.shape[1]==1\n        cd = cd.squeeze(1)\n    return cd\n\n\n# normalized by diameter\n# always between [0, 1]\ndef batch_hausdorff_distance(gt, pred, block=1024, return_more=False):\n    assert gt.device==pred.device and gt.shape==pred.shape\n    bs,h,w = gt.shape[0], gt.shape[-2], gt.shape[-1]\n    dgt = batch_edt(gt, block=block)\n    dpred = batch_edt(pred, block=block)\n    hd = torch.stack([\n        (dgt*pred).amax(dim=(-2,-1)),\n        (dpred*gt).amax(dim=(-2,-1)),\n    ]).amax(dim=0).float() / np.sqrt(h**2+w**2)\n    if len(hd.shape)==2:\n        assert hd.shape[1]==1\n        hd = hd.squeeze(1)\n    return hd\n\n\n############### TORCHMETRICS ###############\n\nclass ChamferDistance2dMetric(torchmetrics.Metric):\n    full_state_update = False\n\n    def __init__(\n            self, block=1024, convert_dog=True,\n            t=2.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=False,\n            **kwargs,\n        ):\n        super().__init__(**kwargs)\n        self.block = block\n        self.convert_dog = convert_dog\n        # DoG parameters used by the subclasses below (the conversion itself\n        # relies on the external usketchers module)\n        self.dog_params = {\n            't': t, 'sigma': sigma, 'k': k, 'epsilon': epsilon,\n            'kernel_factor': kernel_factor, 'clip': clip,\n        }\n        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')\n        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')\n\n    def update(self, preds: torch.Tensor, target: torch.Tensor):\n        dist = batch_chamfer_distance(target, preds, block=self.block)\n        self.running_sum += dist.sum()\n        self.running_count += len(dist)\n\n    def compute(self):\n        return self.running_sum.float() / self.running_count\n\n\nclass ChamferDistance2dTMetric(ChamferDistance2dMetric):\n    def update(self, preds: torch.Tensor, target: torch.Tensor):\n        if self.convert_dog:\n            preds = (usketchers.batch_dog(preds, **self.dog_params)>0.5).float()\n            target = (usketchers.batch_dog(target, **self.dog_params)>0.5).float()\n        dist = batch_chamfer_distance_t(target, preds, block=self.block)\n        
self.running_sum += dist.sum()\n        self.running_count += len(dist)\n\n\nclass ChamferDistance2dPMetric(ChamferDistance2dMetric):\n    def update(self, preds: torch.Tensor, target: torch.Tensor):\n        if self.convert_dog:\n            preds = (usketchers.batch_dog(preds, **self.dog_params)>0.5).float()\n            target = (usketchers.batch_dog(target, **self.dog_params)>0.5).float()\n        dist = batch_chamfer_distance_p(target, preds, block=self.block)\n        self.running_sum += dist.sum()\n        self.running_count += len(dist)\n\n\nclass HausdorffDistance2dMetric(torchmetrics.Metric):\n    def __init__(\n            self, block=1024, convert_dog=True,\n            t=2.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=False,\n            **kwargs,\n        ):\n        super().__init__(**kwargs)\n        self.block = block\n        self.convert_dog = convert_dog\n        self.dog_params = {\n            't': t, 'sigma': sigma, 'k': k, 'epsilon': epsilon,\n            'kernel_factor': kernel_factor, 'clip': clip,\n        }\n        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')\n        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')\n\n    def update(self, preds: torch.Tensor, target: torch.Tensor):\n        if self.convert_dog:\n            preds = (usketchers.batch_dog(preds, **self.dog_params)>0.5).float()\n            target = (usketchers.batch_dog(target, **self.dog_params)>0.5).float()\n        dist = batch_hausdorff_distance(target, preds, block=self.block)\n        self.running_sum += dist.sum()\n        self.running_count += len(dist)\n\n    def compute(self):\n        return self.running_sum.float() / self.running_count\n\n\ndef rgb2sketch(img, black_threshold):\n    # binarize in place: 1 where the pixel is darker than the threshold (line), else 0\n    img[img < black_threshold] = 1\n    img[img >= black_threshold] = 0\n    return 
torch.tensor(img)\n\n\ndef rgb2gray(rgb):\n    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]\n    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b\n    return gray\n\n\ndef cd_score(img1, img2):\n    img1 = rgb2gray(img1.astype(float))\n    img2 = rgb2gray(img2.astype(float))\n\n    img1_sketch = rgb2sketch(img1, black_threshold)\n    img2_sketch = rgb2sketch(img2, black_threshold)\n\n    img1_sketch = img1_sketch.unsqueeze(0)\n    img2_sketch = img2_sketch.unsqueeze(0)\n\n    CD = ChamferDistance2dMetric()\n    cd = CD(img1_sketch, img2_sketch)\n    return cd\n"
  },
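The symmetric Chamfer distance in `batch_chamfer_distance` above averages two asymmetric terms, each a mean of one image's line pixels weighted by the other image's distance transform and normalized by the image diagonal. A minimal single-image numpy sketch (the `edt` helper is a hypothetical brute-force stand-in for the `scipy.ndimage.distance_transform_edt` call in `batch_edt`):

```python
import numpy as np

def edt(mask):
    # brute-force Euclidean distance transform: for every pixel, the
    # distance to the nearest "on" pixel of the binary mask
    ys, xs = np.nonzero(mask)
    pts = np.stack([ys, xs], axis=1)                    # (k, 2) on-pixels
    gy, gx = np.mgrid[:mask.shape[0], :mask.shape[1]]
    grid = np.stack([gy, gx], axis=-1)                  # (h, w, 2)
    return np.linalg.norm(grid[:, :, None, :] - pts[None, None, :, :],
                          axis=-1).min(axis=2)

def chamfer(gt, pred):
    # symmetric Chamfer distance normalized by the image diagonal,
    # mirroring batch_chamfer_distance (single image, no batch dim)
    h, w = gt.shape
    diag = np.sqrt(h ** 2 + w ** 2)
    t = (gt * edt(pred)).mean() / diag   # gt pixels scored against pred
    p = (pred * edt(gt)).mean() / diag   # pred pixels scored against gt
    return (t + p) / 2

img = np.zeros((8, 8))
img[2, 2] = 1.0        # one line pixel
shifted = np.zeros((8, 8))
shifted[2, 4] = 1.0    # same pixel shifted two columns
```

Identical images score exactly zero; any displacement yields a positive distance, which is why the metric is a reasonable proxy for line-drawing fidelity.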
  {
    "path": "utils/log.py",
    "content": "# This source code is licensed under the MIT license found in the\n# LICENSE file in the root directory of this open-source project.\n\n\n\"\"\" Define the Logger class to print log\"\"\"\nimport os\nimport sys\nimport logging\nfrom datetime import datetime\n\n\nclass Logger:\n    def __init__(self, args, output_dir):\n\n        log = logging.getLogger(output_dir)\n        if not log.handlers:\n            log.setLevel(logging.DEBUG)\n            # if not os.path.exists(output_dir):\n            #     os.mkdir(args.data.output_dir)\n            fh = logging.FileHandler(os.path.join(output_dir,'log.txt'))\n            fh.setLevel(logging.INFO)\n            ch = ProgressHandler()\n            ch.setLevel(logging.DEBUG)\n            formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')\n            fh.setFormatter(formatter)\n            ch.setFormatter(formatter)\n            log.addHandler(fh)\n            log.addHandler(ch)\n        self.log = log\n        # setup TensorBoard\n        # if args.tensorboard:\n        #     from tensorboardX import SummaryWriter\n        #     self.writer = SummaryWriter(log_dir=args.output_dir)\n        # else:\n        self.writer = None\n        self.log_per_updates = args.log_per_updates\n\n    def set_progress(self, epoch, total):\n        self.log.info(f'Epoch: {epoch}')\n        self.epoch = epoch\n        self.i = 0\n        self.total = total\n        self.start = datetime.now()\n\n    def update(self, stats):\n        self.i += 1\n        if self.i % self.log_per_updates == 0:\n            remaining = str((datetime.now() - self.start) / self.i * (self.total - self.i))\n            remaining = remaining.split('.')[0]\n            updates = stats.pop('updates')\n            stats_str = ' '.join(f'{key}[{val:.8f}]' for key, val in stats.items())\n            \n            self.log.info(f'> epoch [{self.epoch}] updates[{updates}] {stats_str} eta[{remaining}]')\n            \n  
          if self.writer:\n                for key, val in stats.items():\n                    self.writer.add_scalar(f'train/{key}', val, updates)\n        if self.i == self.total:\n            self.log.debug('\\n')\n            self.log.debug(f'elapsed time: {str(datetime.now() - self.start).split(\".\")[0]}')\n\n    def log_eval(self, stats, metrics_group=None):\n        stats_str = ' '.join(f'{key}: {val:.8f}' for key, val in stats.items())\n        self.log.info(f'valid {stats_str}')\n        if self.writer:\n            for key, val in stats.items():\n                self.writer.add_scalar(f'valid/{key}', val, self.epoch)\n        # for mode, metrics in metrics_group.items():\n        #     self.log.info(f'evaluation scores ({mode}):')\n        #     for key, (val, _) in metrics.items():\n        #         self.log.info(f'\\t{key} {val:.4f}')\n        # if self.writer and metrics_group is not None:\n        #     for key, val in stats.items():\n        #         self.writer.add_scalar(f'valid/{key}', val, self.epoch)\n        #     for key in list(metrics_group.values())[0]:\n        #         group = {}\n        #         for mode, metrics in metrics_group.items():\n        #             group[mode] = metrics[key][0]\n        #         self.writer.add_scalars(f'valid/{key}', group, self.epoch)\n\n    def __call__(self, msg):\n        self.log.info(msg)\n\n\nclass ProgressHandler(logging.Handler):\n    def __init__(self, level=logging.NOTSET):\n        super().__init__(level)\n\n    def emit(self, record):\n        log_entry = self.format(record)\n        if record.message.startswith('> '):\n            sys.stdout.write('{}\\r'.format(log_entry.rstrip()))\n            sys.stdout.flush()\n        else:\n            sys.stdout.write('{}\\n'.format(log_entry))\n\n"
  },
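`Logger.update` above derives the `eta[...]` field it prints from the elapsed time: the average duration per completed step times the number of steps remaining. A minimal sketch of that estimate (the `eta` helper name is hypothetical):

```python
from datetime import datetime, timedelta

def eta(start, done, total):
    # remaining time = average time per finished step * steps left,
    # the same estimate Logger.update formats into its eta[...] field
    return (datetime.now() - start) / done * (total - done)

start = datetime.now() - timedelta(seconds=10)  # pretend 10 s have elapsed
remaining = eta(start, done=10, total=100)      # roughly 90 s to go
```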
  {
    "path": "utils/visualize_inbetween.py",
    "content": "import numpy as np\nimport torch\nimport cv2\nfrom .chamfer_distance import cd_score\n\n\n# def make_inter_graph(v2d1, v2d2, topo1, topo2, match12):\n#     valid = (match12 != -1)\n#     marked2 = np.zeros(len(v2d2)).astype(bool)\n#     # print(match12[valid])\n#     marked2[match12[valid]] = True\n\n#     id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))\n#     id1toh[valid] = np.arange(np.sum(valid))\n#     id2toh[match12[valid]] = np.arange(np.sum(valid))\n#     id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)\n#     # print(marked2)\n#     id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))\n\n#     id1toh = id1toh.astype(int)\n#     id2toh = id2toh.astype(int)\n\n#     tot_len = len(v2d1) + np.sum(np.invert(marked2))\n\n#     vin1 = v2d1[valid][:]\n#     vin2 = v2d2[match12[valid]][:]\n#     vh = 0.5 * (vin1 + vin2)\n#     vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)\n\n#     topoh = [[] for ii in range(tot_len)]\n\n\n#     for node in range(len(topo1)):\n        \n#         for nb in topo1[node]:\n#             if int(id1toh[nb]) not in topoh[id1toh[node]]:\n#                 topoh[id1toh[node]].append(int(id1toh[nb]))\n\n\n#     for node in range(len(topo2)):\n#         for nb in topo2[node]:\n#             if int(id2toh[nb]) not in topoh[id2toh[node]]:\n#                 topoh[id2toh[node]].append(int(id2toh[nb]))\n\n#     return vh, topoh\n\n\n# def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12):\n#     valid = (match12 != -1)\n#     marked2 = np.zeros(len(v2d2)).astype(bool)\n#     # print(match12[valid])\n#     marked2[match12[valid]] = True\n\n#     id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))\n#     id1toh[valid] = np.arange(np.sum(valid))\n#     id2toh[match12[valid]] = np.arange(np.sum(valid))\n#     id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)\n#     # print(marked2)\n#     
id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))\n\n#     id1toh = id1toh.astype(int)\n#     id2toh = id2toh.astype(int)\n\n#     tot_len = len(v2d1) + np.sum(np.invert(marked2))\n\n#     vin1 = v2d1[valid][:]\n#     vin2 = v2d2[match12[valid]][:]\n#     vh = 0.5 * (vin1 + vin2)\n#     # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)\n\n#     # topoh = [[] for ii in range(tot_len)]\n#     topoh = [[] for ii in range(np.sum(valid))]\n\n#     for node in range(len(topo1)):\n#         if not valid[node]:\n#             continue\n#         for nb in topo1[node]:\n#             if int(id1toh[nb]) not in topoh[id1toh[node]]:\n#                 if valid[nb]:\n#                     topoh[id1toh[node]].append(int(id1toh[nb]))\n\n\n#     for node in range(len(topo2)):\n#         if not marked2[node]:\n#             continue\n#         for nb in topo2[node]:\n#             if int(id2toh[nb]) not in topoh[id2toh[node]]:\n#                 if marked2[nb]:\n#                     topoh[id2toh[node]].append(int(id2toh[nb]))\n\n#     return vh, topoh\n\n\n\ndef visualize(dict):\n    # print(dict['keypoints0'].size(), flush=True)\n    img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()\n    original_target = ((dict['imaget'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()\n    # img1p = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n    # img2p = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n\n    # img1[:, :, 0] += 255\n    # img1[:, :, 1] += 180\n    # img1[:, :, 2] += 180\n    # img1[img1 > 255] = 255\n\n    # img2[:, :, 0] += 255\n    # img2[:, :, 1] += 180\n    # img2[:, :, 2] += 180\n    # img2[img2 > 255] = 255\n    \n    # img1p[:, :, 0] += 255\n    # img1p[:, :, 1] += 180\n    # img1p[:, :, 2] += 180\n    # img1p[img1p > 255] = 255\n    \n    # 
img2p[:, :, 0] += 255\n    # img2p[:, :, 1] += 180\n    # img2p[:, :, 2] += 180\n    # img2p[img2p > 255] = 255\n\n    # img1, img2, img1p, img2p = img1.astype(np.uint8), img2.astype(np.uint8), img1p.astype(np.uint8), img2p.astype(np.uint8)\n    motion01 = dict['motion0'][0].cpu().numpy().astype(int) \n    motion21 = dict['motion1'][0].cpu().numpy().astype(int) \n\n    source0_warp = dict['keypoints0t'][0].cpu().numpy().astype(int)\n    source2_warp = dict['keypoints1t'][0].cpu().numpy().astype(int)\n    source0 = dict['keypoints0'][0].cpu().numpy().astype(int)\n    source2 = dict['keypoints1'][0].cpu().numpy().astype(int)\n    source0_topo = dict['topo0'][0]\n    # print(len(dict['topo0']))\n    source2_topo = dict['topo1'][0]\n    visible01 = dict['vb0'][0].cpu().numpy().astype(int)\n    visible21 = dict['vb1'][0].cpu().numpy().astype(int)\n\n    # corr01 = dict['m01'][0].cpu().numpy().astype(int)\n    # corr10 = dict['m10'][0].cpu().numpy().astype(int)\n\n    # canvas = np.zeros_like(img1) + 255\n\n    # source0_warp2 = source0 + motion01 // 2\n    # source2_warp2 = source2 + motion21 // 2\n\n    # for node, nbs in enumerate(source0_topo):\n    #     for nb in nbs:\n    #         # print([source0_warp[nb][0], source0_warp[nb][1]])\n    #         cv2.line(canvas, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)\n    # for node, nbs in enumerate(source2_topo):\n    #     for nb in nbs:\n    #         cv2.line(canvas, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)\n\n\n    # canvas6 = np.zeros_like(img1) + 255\n\n\n    # for node, nbs in enumerate(source0_topo):\n    #     for nb in nbs:\n    #         # print([source0_warp[nb][0], source0_warp[nb][1]])\n    #         cv2.line(canvas6, [source0_warp2[node][0], source0_warp2[node][1]], [source0_warp2[nb][0], source0_warp2[nb][1]], [0, 0, 0], 2)\n    # for node, nbs in enumerate(source2_topo):\n    #    
 for nb in nbs:\n    #         cv2.line(canvas6, [source2_warp2[node][0], source2_warp2[node][1]], [source2_warp2[nb][0], source2_warp2[nb][1]], [0, 0, 0], 2)\n\n    canvas2 = np.zeros_like(img1) + 255\n\n  ##  print('huala<<<', source0_warp.mean(), source2_warp.mean(), flush=True)\n\n    # source0_warp = source0 + motion01\n    # source2_warp = source2 + motion21\n\n    for node, nbs in enumerate(source0_topo):\n        for nb in nbs:\n            # if visible01[node] and visible01[nb]:\n            cv2.line(canvas2, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)\n    for node, nbs in enumerate(source2_topo):\n        for nb in nbs:\n            # if visible21[node] and visible21[nb]:\n            cv2.line(canvas2, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)\n\n    \n\n    # canvas2\n    # black_threshold = 255 // 2\n    # img1_sketch = rgb2sketch(img1, black_threshold)\n    # img2_sketch = rgb2sketch(img2, black_threshold)\n\n    # img1_sketch = img1_sketch.unsqueeze(0)\n    # img2_sketch = img2_sketch.unsqueeze(0)\n\n    # CD = ChamferDistance2dMetric()\n    # cd = CD(img1_sketch,img2_sketch)\n    canvas5 = np.zeros_like(img1) + 255\n\n    # source0_warp = source0 + motion01\n    # source2_warp = source2 + motion21\n\n  ##  print('gulaa>>>', visible01.mean(), visible21.mean(), flush=True)\n\n    for node, nbs in enumerate(source0_topo):\n        for nb in nbs:\n            if visible01[node] and visible01[nb]:\n                cv2.line(canvas5, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)\n    for node, nbs in enumerate(source2_topo):\n        for nb in nbs:\n            if visible21[node] and visible21[nb]:\n                cv2.line(canvas5, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)\n\n\n\n    canvas3 = 
np.zeros_like(img1) + 255\n    \n\n    for node, nbs in enumerate(source0_topo):\n        for nb in nbs:\n            cv2.line(canvas3, [source0[node][0], source0[node][1]], [source0[nb][0], source0[nb][1]], [255, 180, 180], 2)\n    for node, nbs in enumerate(source2_topo):\n        for nb in nbs:\n            cv2.line(canvas3, [source2[node][0], source2[node][1]], [source2[nb][0], source2[nb][1]], [180, 180, 255], 2)\n\n    # canvas_corr1 = np.zeros_like(img1) + 255\n    # canvas_corr2 = np.zeros_like(img1) + 255\n\n    canvas_corr1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n    canvas_corr2 = ((dict['image1'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n\n    canvas_corr1[:, :, 0] += 255\n    canvas_corr1[:, :, 1] += 180\n    canvas_corr1[:, :, 2] += 180\n    canvas_corr1[canvas_corr1 > 255] = 255\n\n    canvas_corr2[:, :, 0] += 255\n    canvas_corr2[:, :, 1] += 180\n    canvas_corr2[:, :, 2] += 180\n    canvas_corr2[canvas_corr2 > 255] = 255\n\n    canvas_corr1 = canvas_corr1.astype(np.uint8)\n    canvas_corr2 = canvas_corr2.astype(np.uint8)\n\n    # colors1_gt, colors2_gt = {}, {}\n    colors1_pred, colors2_pred = {}, {}\n    # cross1_pred, cross2_pred = {}, {}\n    id1 = np.arange(len(source0))\n    id2 = np.arange(len(source2))\n\n    # predicted = dict['matches0'].cpu().data.numpy()[0]\n    # predicted1 = dict['matches1'].cpu().data.numpy()[0]\n    # for index in id1:\n    #     color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]\n    #         # print(predicted.shape, flush=True)\n    #     # if all_matches[index] != -1:\n    #     #     colors2_gt[all_matches[index]] = color\n    #     if predicted[index] != -1:\n    #         colors2_pred[predicted[index]] = color\n    #     # else:\n    #     #     colors2_pred[predicted[index]] = [0, 0, 0]\n\n    #     # colors1_gt[index] = color if all_matches[index] != -1 else [0, 0, 0]\n    #     
colors1_pred[index] = color if predicted[index] != -1 else [0, 0, 0]\n\n    #     # if predicted[index] == -1 and colors1_pred[index] != [0, 0, 0]:\n    #     #     color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]\n    #     #     colors1_pred[index] = [0, 0, 0]\n    #     #     colors2_pred.pop(all_matches[index])\n    #     # whether predicted correctly\n    #     # if predicted[index] != all_matches[index]:\n    #     #     cross1_pred[index] = True\n    #     #     if predicted[index] != -1:\n    #     #         cross2_pred[predicted[index]] = True\n        \n    # for i, p in enumerate(source0):\n    #     ii = id1[i]\n    #     # print(ii)\n    #     cv2.circle(canvas_corr1, [int(p[0]), int(p[1])], 1, colors1_pred[i], 2)\n    #     # if ii in cross1_pred and cross1_pred[ii]:\n    #     #     cv2.rectangle(img1p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors1_pred[i],-1)\n    #     # else:\n    #     #     cv2.circle(img1p, [int(p[0]), int(p[1])], 1, colors1_pred[i], 2)\n        \n    # for ii in id2:\n    #     # print(ii)\n    #     color = [0, 0, 0]\n    #     this_is_umatched = 1\n    #     if ii not in colors2_pred:\n    #         colors2_pred[ii] = color\n\n    # for i, p in enumerate(source2):\n    #     ii = id2[i]\n    #     # print(p)\n    #     # cv2.circle(img2, [int(p[0]), int( p[1])], 1, colors2_gt[ii], 2)\n    #     # if ii in cross2_pred and cross2_pred[ii]:\n    #     #     cv2.rectangle(img2p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors2_pred[i], -1)\n    #     # else:\n    #     cv2.circle(canvas_corr2, [int(p[0]), int(p[1])], 1, colors2_pred[i], 2)\n\n\n\n\n    #canvas6, canvas5, canvas, \n    im_h = cv2.hconcat([canvas3,  original_target, canvas2, canvas5])\n  ##  print('<<<< mean cavans5: ', canvas5.mean())\n    cd = cd_score(canvas5.copy(), original_target.copy()) * 1e5\n\n    cv2.putText(im_h, str(cd), \\\n        (720, 100), cv2.FONT_HERSHEY_PLAIN, 3, (0, 0, 
255), 2)\n\n\n    \n    return im_h, cd\n"
  },
  {
    "path": "utils/visualize_inbetween2.py",
    "content": "import numpy as np\nimport torch\nimport cv2\nfrom .chamfer_distance import cd_score\n\n\n# def make_inter_graph(v2d1, v2d2, topo1, topo2, match12):\n#     valid = (match12 != -1)\n#     marked2 = np.zeros(len(v2d2)).astype(bool)\n#     # print(match12[valid])\n#     marked2[match12[valid]] = True\n\n#     id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))\n#     id1toh[valid] = np.arange(np.sum(valid))\n#     id2toh[match12[valid]] = np.arange(np.sum(valid))\n#     id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)\n#     # print(marked2)\n#     id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))\n\n#     id1toh = id1toh.astype(int)\n#     id2toh = id2toh.astype(int)\n\n#     tot_len = len(v2d1) + np.sum(np.invert(marked2))\n\n#     vin1 = v2d1[valid][:]\n#     vin2 = v2d2[match12[valid]][:]\n#     vh = 0.5 * (vin1 + vin2)\n#     vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)\n\n#     topoh = [[] for ii in range(tot_len)]\n\n\n#     for node in range(len(topo1)):\n        \n#         for nb in topo1[node]:\n#             if int(id1toh[nb]) not in topoh[id1toh[node]]:\n#                 topoh[id1toh[node]].append(int(id1toh[nb]))\n\n\n#     for node in range(len(topo2)):\n#         for nb in topo2[node]:\n#             if int(id2toh[nb]) not in topoh[id2toh[node]]:\n#                 topoh[id2toh[node]].append(int(id2toh[nb]))\n\n#     return vh, topoh\n\n\n# def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12):\n#     valid = (match12 != -1)\n#     marked2 = np.zeros(len(v2d2)).astype(bool)\n#     # print(match12[valid])\n#     marked2[match12[valid]] = True\n\n#     id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))\n#     id1toh[valid] = np.arange(np.sum(valid))\n#     id2toh[match12[valid]] = np.arange(np.sum(valid))\n#     id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)\n#     # print(marked2)\n#     
id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))\n\n#     id1toh = id1toh.astype(int)\n#     id2toh = id2toh.astype(int)\n\n#     tot_len = len(v2d1) + np.sum(np.invert(marked2))\n\n#     vin1 = v2d1[valid][:]\n#     vin2 = v2d2[match12[valid]][:]\n#     vh = 0.5 * (vin1 + vin2)\n#     # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)\n\n#     # topoh = [[] for ii in range(tot_len)]\n#     topoh = [[] for ii in range(np.sum(valid))]\n\n#     for node in range(len(topo1)):\n#         if not valid[node]:\n#             continue\n#         for nb in topo1[node]:\n#             if int(id1toh[nb]) not in topoh[id1toh[node]]:\n#                 if valid[nb]:\n#                     topoh[id1toh[node]].append(int(id1toh[nb]))\n\n\n#     for node in range(len(topo2)):\n#         if not marked2[node]:\n#             continue\n#         for nb in topo2[node]:\n#             if int(id2toh[nb]) not in topoh[id2toh[node]]:\n#                 if marked2[nb]:\n#                     topoh[id2toh[node]].append(int(id2toh[nb]))\n\n#     return vh, topoh\n\n\n\ndef visualize(dict):\n    # print(dict['keypoints0'].size(), flush=True)\n    img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()\n    img2 = ((dict['image1'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()\n    original_target = ((dict['imaget'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()\n    # img1p = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n    # img2p = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n\n    # img1[:, :, 0] += 255\n    # img1[:, :, 1] += 180\n    # img1[:, :, 2] += 180\n    # img1[img1 > 255] = 255\n\n    # img2[:, :, 0] += 255\n    # img2[:, :, 1] += 180\n    # img2[:, :, 2] += 180\n    # img2[img2 > 255] = 255\n    \n    # img1p[:, :, 0] += 
255\n    # img1p[:, :, 1] += 180\n    # img1p[:, :, 2] += 180\n    # img1p[img1p > 255] = 255\n    \n    # img2p[:, :, 0] += 255\n    # img2p[:, :, 1] += 180\n    # img2p[:, :, 2] += 180\n    # img2p[img2p > 255] = 255\n\n    # img1, img2, img1p, img2p = img1.astype(np.uint8), img2.astype(np.uint8), img1p.astype(np.uint8), img2p.astype(np.uint8)\n    r0 = dict['r0'][0].cpu().numpy().astype(int) \n    r1 = dict['r1'][0].cpu().numpy().astype(int) \n\n    source0_warp = dict['keypoints0t'][0].cpu().numpy().astype(int)\n    source2_warp = dict['keypoints1t'][0].cpu().numpy().astype(int)\n    source0 = dict['keypoints0'][0].cpu().numpy().astype(int)\n    source2 = dict['keypoints1'][0].cpu().numpy().astype(int)\n    source0_topo = dict['topo0'][0]\n    # print(len(dict['topo0']))\n    source2_topo = dict['topo1'][0]\n    visible01 = dict['vb0'][0].cpu().numpy().astype(int)\n    visible21 = dict['vb1'][0].cpu().numpy().astype(int)\n\n    # corr01 = dict['m01'][0].cpu().numpy().astype(int)\n    # corr10 = dict['m10'][0].cpu().numpy().astype(int)\n\n    # canvas = np.zeros_like(img1) + 255\n\n    # source0_warp2 = source0 + motion01 // 2\n    # source2_warp2 = source2 + motion21 // 2\n\n    # for node, nbs in enumerate(source0_topo):\n    #     for nb in nbs:\n    #         # print([source0_warp[nb][0], source0_warp[nb][1]])\n    #         cv2.line(canvas, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)\n    # for node, nbs in enumerate(source2_topo):\n    #     for nb in nbs:\n    #         cv2.line(canvas, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)\n\n\n    # canvas6 = np.zeros_like(img1) + 255\n\n\n    # for node, nbs in enumerate(source0_topo):\n    #     for nb in nbs:\n    #         # print([source0_warp[nb][0], source0_warp[nb][1]])\n    #         cv2.line(canvas6, [source0_warp2[node][0], source0_warp2[node][1]], [source0_warp2[nb][0], 
source0_warp2[nb][1]], [0, 0, 0], 2)\n    # for node, nbs in enumerate(source2_topo):\n    #     for nb in nbs:\n    #         cv2.line(canvas6, [source2_warp2[node][0], source2_warp2[node][1]], [source2_warp2[nb][0], source2_warp2[nb][1]], [0, 0, 0], 2)\n\n    canvas2 = np.zeros_like(img1) + 255\n\n    # source0_warp = source0 + motion01\n    # source2_warp = source2 + motion21\n\n    for node, nbs in enumerate(source0_topo):\n        for nb in nbs:\n            # if visible01[node] and visible01[nb]:\n            cv2.line(canvas2, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)\n    for node, nbs in enumerate(source2_topo):\n        for nb in nbs:\n            # if visible21[node] and visible21[nb]:\n            cv2.line(canvas2, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)\n\n    \n\n    # canvas2\n    # black_threshold = 255 // 2\n    # img1_sketch = rgb2sketch(img1, black_threshold)\n    # img2_sketch = rgb2sketch(img2, black_threshold)\n\n    # img1_sketch = img1_sketch.unsqueeze(0)\n    # img2_sketch = img2_sketch.unsqueeze(0)\n\n    # CD = ChamferDistance2dMetric()\n    # cd = CD(img1_sketch,img2_sketch)\n    canvases = [np.zeros_like(img1) + 255, np.zeros_like(img1) + 255, np.zeros_like(img1) + 255, np.zeros_like(img1) + 255]\n    canvas5 = np.zeros_like(img1) + 255\n    # canvas7 = np.zeros_like(img1) + 255\n    # canvas8 = np.zeros_like(img1) + 255\n\n\n    # source0_warp = source0 + motion01\n    # source2_warp = source2 + motion21\n    for ii in range(len(canvases)):\n        source0_warp = (source0 + (ii + 1.0) / (len(canvases) + 1.0) * r0).astype(int)\n        source2_warp = (source2 + (1 - (ii + 1.0) / (len(canvases) + 1.0)) * r1).astype(int)\n        for node, nbs in enumerate(source0_topo):\n            for nb in nbs:\n                if visible01[node] and visible01[nb]:\n                    cv2.line(canvases[ii], 
[source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)\n        for node, nbs in enumerate(source2_topo):\n            for nb in nbs:\n                if visible21[node] and visible21[nb]:\n                    cv2.line(canvases[ii], [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)\n\n\n\n    canvas3 = np.zeros_like(img1) + 255\n\n    for node, nbs in enumerate(source0_topo):\n        for nb in nbs:\n            cv2.line(canvas3, [source0[node][0], source0[node][1]], [source0[nb][0], source0[nb][1]], [255, 180, 180], 2)\n    for node, nbs in enumerate(source2_topo):\n        for nb in nbs:\n            cv2.line(canvas3, [source2[node][0], source2[node][1]], [source2[nb][0], source2[nb][1]], [180, 180, 255], 2)\n\n    #canvas6, canvas5, canvas, \n    # im_h = cv2.hconcat([canvas3, original_target, canvas2, canvas5])\n    im_h = cv2.hconcat([img1] + canvases + [img2])\n    cd = cd_score(canvas5.copy(), original_target.copy()) * 1e5\n\n    # cv2.putText(im_h, str(cd), \\\n    #     (720, 100), cv2.FONT_HERSHEY_PLAIN, 3, (0, 0, 255), 2)\n\n\n    \n    return im_h, cd\n"
  },
  {
    "path": "utils/visualize_inbetween3.py",
    "content": "import numpy as np\nimport torch\nimport cv2\nfrom .chamfer_distance import cd_score\n\n\n# def make_inter_graph(v2d1, v2d2, topo1, topo2, match12):\n#     valid = (match12 != -1)\n#     marked2 = np.zeros(len(v2d2)).astype(bool)\n#     # print(match12[valid])\n#     marked2[match12[valid]] = True\n\n#     id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))\n#     id1toh[valid] = np.arange(np.sum(valid))\n#     id2toh[match12[valid]] = np.arange(np.sum(valid))\n#     id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)\n#     # print(marked2)\n#     id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))\n\n#     id1toh = id1toh.astype(int)\n#     id2toh = id2toh.astype(int)\n\n#     tot_len = len(v2d1) + np.sum(np.invert(marked2))\n\n#     vin1 = v2d1[valid][:]\n#     vin2 = v2d2[match12[valid]][:]\n#     vh = 0.5 * (vin1 + vin2)\n#     vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)\n\n#     topoh = [[] for ii in range(tot_len)]\n\n\n#     for node in range(len(topo1)):\n        \n#         for nb in topo1[node]:\n#             if int(id1toh[nb]) not in topoh[id1toh[node]]:\n#                 topoh[id1toh[node]].append(int(id1toh[nb]))\n\n\n#     for node in range(len(topo2)):\n#         for nb in topo2[node]:\n#             if int(id2toh[nb]) not in topoh[id2toh[node]]:\n#                 topoh[id2toh[node]].append(int(id2toh[nb]))\n\n#     return vh, topoh\n\n\n# def make_inter_graph_valid(v2d1, v2d2, topo1, topo2, match12):\n#     valid = (match12 != -1)\n#     marked2 = np.zeros(len(v2d2)).astype(bool)\n#     # print(match12[valid])\n#     marked2[match12[valid]] = True\n\n#     id1toh, id2toh = np.zeros(len(v2d1)), np.zeros(len(v2d2))\n#     id1toh[valid] = np.arange(np.sum(valid))\n#     id2toh[match12[valid]] = np.arange(np.sum(valid))\n#     id1toh[np.invert(valid)] = np.arange(np.sum(1 - valid)) + np.sum(valid)\n#     # print(marked2)\n#     
id2toh[np.invert(marked2)] = len(v2d1) + np.arange(np.sum(np.invert(marked2)))\n\n#     id1toh = id1toh.astype(int)\n#     id2toh = id2toh.astype(int)\n\n#     tot_len = len(v2d1) + np.sum(np.invert(marked2))\n\n#     vin1 = v2d1[valid][:]\n#     vin2 = v2d2[match12[valid]][:]\n#     vh = 0.5 * (vin1 + vin2)\n#     # vh = np.concatenate((vh, v2d1[np.invert(valid)], v2d2[np.invert(marked2)]), axis=0)\n\n#     # topoh = [[] for ii in range(tot_len)]\n#     topoh = [[] for ii in range(np.sum(valid))]\n\n#     for node in range(len(topo1)):\n#         if not valid[node]:\n#             continue\n#         for nb in topo1[node]:\n#             if int(id1toh[nb]) not in topoh[id1toh[node]]:\n#                 if valid[nb]:\n#                     topoh[id1toh[node]].append(int(id1toh[nb]))\n\n\n#     for node in range(len(topo2)):\n#         if not marked2[node]:\n#             continue\n#         for nb in topo2[node]:\n#             if int(id2toh[nb]) not in topoh[id2toh[node]]:\n#                 if marked2[nb]:\n#                     topoh[id2toh[node]].append(int(id2toh[nb]))\n\n#     return vh, topoh\n\n\n\ndef visualize(dict):\n    # print(dict['keypoints0'].size(), flush=True)\n    img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()\n    original_target = ((dict['imaget'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()\n    # img1p = ((dict['image0'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n    # img2p = ((dict['image1'].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n\n    # img1[:, :, 0] += 255\n    # img1[:, :, 1] += 180\n    # img1[:, :, 2] += 180\n    # img1[img1 > 255] = 255\n\n    # img2[:, :, 0] += 255\n    # img2[:, :, 1] += 180\n    # img2[:, :, 2] += 180\n    # img2[img2 > 255] = 255\n    \n    # img1p[:, :, 0] += 255\n    # img1p[:, :, 1] += 180\n    # img1p[:, :, 2] += 180\n    # img1p[img1p > 255] = 255\n    \n    # 
img2p[:, :, 0] += 255\n    # img2p[:, :, 1] += 180\n    # img2p[:, :, 2] += 180\n    # img2p[img2p > 255] = 255\n\n    # img1, img2, img1p, img2p = img1.astype(np.uint8), img2.astype(np.uint8), img1p.astype(np.uint8), img2p.astype(np.uint8)\n    motion01 = dict['motion0'][0].cpu().numpy().astype(int) \n    motion21 = dict['motion1'][0].cpu().numpy().astype(int) \n\n    source0_warp = dict['keypoints0t'][0].cpu().numpy().astype(int)\n    source2_warp = dict['keypoints1t'][0].cpu().numpy().astype(int)\n    source0 = dict['keypoints0'][0].cpu().numpy().astype(int)\n    source2 = dict['keypoints1'][0].cpu().numpy().astype(int)\n    source0_topo = dict['topo0'][0]\n    # print(len(dict['topo0']))\n    source2_topo = dict['topo1'][0]\n    visible01 = dict['vb0'][0].cpu().numpy().astype(int)\n    visible21 = dict['vb1'][0].cpu().numpy().astype(int)\n\n    # corr01 = dict['m01'][0].cpu().numpy().astype(int)\n    # corr10 = dict['m10'][0].cpu().numpy().astype(int)\n\n    # canvas = np.zeros_like(img1) + 255\n\n    # source0_warp2 = source0 + motion01 // 2\n    # source2_warp2 = source2 + motion21 // 2\n\n    # for node, nbs in enumerate(source0_topo):\n    #     for nb in nbs:\n    #         # print([source0_warp[nb][0], source0_warp[nb][1]])\n    #         cv2.line(canvas, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)\n    # for node, nbs in enumerate(source2_topo):\n    #     for nb in nbs:\n    #         cv2.line(canvas, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)\n\n\n    # canvas6 = np.zeros_like(img1) + 255\n\n\n    # for node, nbs in enumerate(source0_topo):\n    #     for nb in nbs:\n    #         # print([source0_warp[nb][0], source0_warp[nb][1]])\n    #         cv2.line(canvas6, [source0_warp2[node][0], source0_warp2[node][1]], [source0_warp2[nb][0], source0_warp2[nb][1]], [0, 0, 0], 2)\n    # for node, nbs in enumerate(source2_topo):\n    #    
 for nb in nbs:\n    #         cv2.line(canvas6, [source2_warp2[node][0], source2_warp2[node][1]], [source2_warp2[nb][0], source2_warp2[nb][1]], [0, 0, 0], 2)\n\n  #   canvas2 = np.zeros_like(img1) + 255\n\n  # ##  print('huala<<<', source0_warp.mean(), source2_warp.mean(), flush=True)\n\n  #   # source0_warp = source0 + motion01\n  #   # source2_warp = source2 + motion21\n\n  #   for node, nbs in enumerate(source0_topo):\n  #       for nb in nbs:\n  #           # if visible01[node] and visible01[nb]:\n  #           cv2.line(canvas2, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)\n  #   for node, nbs in enumerate(source2_topo):\n  #       for nb in nbs:\n  #           # if visible21[node] and visible21[nb]:\n  #           cv2.line(canvas2, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)\n\n    \n\n    # canvas2\n    # black_threshold = 255 // 2\n    # img1_sketch = rgb2sketch(img1, black_threshold)\n    # img2_sketch = rgb2sketch(img2, black_threshold)\n\n    # img1_sketch = img1_sketch.unsqueeze(0)\n    # img2_sketch = img2_sketch.unsqueeze(0)\n\n    # CD = ChamferDistance2dMetric()\n    # cd = CD(img1_sketch,img2_sketch)\n    canvas5 = np.zeros_like(img1) + 255\n\n    # source0_warp = source0 + motion01\n    # source2_warp = source2 + motion21\n\n  ##  print('gulaa>>>', visible01.mean(), visible21.mean(), flush=True)\n\n    for node, nbs in enumerate(source0_topo):\n        for nb in nbs:\n            if visible01[node] and visible01[nb]:\n                cv2.line(canvas5, [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)\n    for node, nbs in enumerate(source2_topo):\n        for nb in nbs:\n            if visible21[node] and visible21[nb]:\n                cv2.line(canvas5, [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)\n\n\n\n    # 
canvas3 = np.zeros_like(img1) + 255\n    \n\n    # for node, nbs in enumerate(source0_topo):\n    #     for nb in nbs:\n    #         cv2.line(canvas3, [source0[node][0], source0[node][1]], [source0[nb][0], source0[nb][1]], [255, 180, 180], 2)\n    # for node, nbs in enumerate(source2_topo):\n    #     for nb in nbs:\n    #         cv2.line(canvas3, [source2[node][0], source2[node][1]], [source2[nb][0], source2[nb][1]], [180, 180, 255], 2)\n\n    # canvas_corr1 = np.zeros_like(img1) + 255\n    # canvas_corr2 = np.zeros_like(img1) + 255\n\n    # canvas_corr1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n    # canvas_corr2 = ((dict['image1'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(int).copy()\n\n    # canvas_corr1[:, :, 0] += 255\n    # canvas_corr1[:, :, 1] += 180\n    # canvas_corr1[:, :, 2] += 180\n    # canvas_corr1[canvas_corr1 > 255] = 255\n\n    # canvas_corr2[:, :, 0] += 255\n    # canvas_corr2[:, :, 1] += 180\n    # canvas_corr2[:, :, 2] += 180\n    # canvas_corr2[canvas_corr2 > 255] = 255\n\n    # canvas_corr1 = canvas_corr1.astype(np.uint8)\n    # canvas_corr2 = canvas_corr2.astype(np.uint8)\n\n    # # colors1_gt, colors2_gt = {}, {}\n    # colors1_pred, colors2_pred = {}, {}\n    # # cross1_pred, cross2_pred = {}, {}\n    # id1 = np.arange(len(source0))\n    # id2 = np.arange(len(source2))\n\n    # predicted = dict['matches0'].cpu().data.numpy()[0]\n    # predicted1 = dict['matches1'].cpu().data.numpy()[0]\n    # for index in id1:\n    #     color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]\n    #         # print(predicted.shape, flush=True)\n    #     # if all_matches[index] != -1:\n    #     #     colors2_gt[all_matches[index]] = color\n    #     if predicted[index] != -1:\n    #         colors2_pred[predicted[index]] = color\n    #     # else:\n    #     #     colors2_pred[predicted[index]] = [0, 0, 0]\n\n    #     # colors1_gt[index] = color if 
all_matches[index] != -1 else [0, 0, 0]\n    #     colors1_pred[index] = color if predicted[index] != -1 else [0, 0, 0]\n\n    #     # if predicted[index] == -1 and colors1_pred[index] != [0, 0, 0]:\n    #     #     color = [np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)]\n    #     #     colors1_pred[index] = [0, 0, 0]\n    #     #     colors2_pred.pop(all_matches[index])\n    #     # whether predicted correctly\n    #     # if predicted[index] != all_matches[index]:\n    #     #     cross1_pred[index] = True\n    #     #     if predicted[index] != -1:\n    #     #         cross2_pred[predicted[index]] = True\n        \n    # for i, p in enumerate(source0):\n    #     ii = id1[i]\n    #     # print(ii)\n    #     cv2.circle(canvas_corr1, [int(p[0]), int(p[1])], 1, colors1_pred[i], 2)\n    #     # if ii in cross1_pred and cross1_pred[ii]:\n    #     #     cv2.rectangle(img1p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors1_pred[i],-1)\n    #     # else:\n    #     #     cv2.circle(img1p, [int(p[0]), int(p[1])], 1, colors1_pred[i], 2)\n        \n    # for ii in id2:\n    #     # print(ii)\n    #     color = [0, 0, 0]\n    #     this_is_umatched = 1\n    #     if ii not in colors2_pred:\n    #         colors2_pred[ii] = color\n\n    # for i, p in enumerate(source2):\n    #     ii = id2[i]\n    #     # print(p)\n    #     # cv2.circle(img2, [int(p[0]), int( p[1])], 1, colors2_gt[ii], 2)\n    #     # if ii in cross2_pred and cross2_pred[ii]:\n    #     #     cv2.rectangle(img2p, [int(p[0]-1), int(p[1]-1)], [int(p[0]+1), int(p[1]+1)], colors2_pred[i], -1)\n    #     # else:\n    #     cv2.circle(canvas_corr2, [int(p[0]), int(p[1])], 1, colors2_pred[i], 2)\n\n\n\n\n    #canvas6, canvas5, canvas, \n    im_h = cv2.hconcat([canvas5])\n    # im_h = canvas5\n  ##  print('<<<< mean cavans5: ', canvas5.mean())\n    # cd = cd_score(canvas5.copy(), original_target.copy()) * 1e5\n\n    # cv2.putText(im_h, str(cd), \\\n    #     
(720, 100), cv2.FONT_HERSHEY_PLAIN, 3, (0, 0, 255), 2)\n\n\n    \n    return im_h\n"
  },
  {
    "path": "utils/visualize_video.py",
    "content": "import numpy as np\nimport torch\nimport cv2\n\n\n\ndef visvid(dict, inter_frames=1):\n    img1 = ((dict['image0'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()\n    img2 = ((dict['image1'][0].permute(1, 2, 0).float().numpy() + 1.0) * 255 / 2).astype(np.uint8).copy()\n\n    r0 = dict['r0'][0].cpu().numpy()\n    r1 = dict['r1'][0].cpu().numpy()\n\n    source0 = dict['keypoints0'][0].cpu().numpy()\n    source2 = dict['keypoints1'][0].cpu().numpy()\n \n    source0_topo = dict['ntopo0'][0]\n\n    source2_topo = dict['ntopo1'][0]\n    ori_source0_topo = dict['topo0'][0]\n\n    ori_source2_topo = dict['topo1'][0]\n    visible01 = dict['vb0'][0].cpu().numpy().astype(int)\n    visible21 = dict['vb1'][0].cpu().numpy().astype(int)\n\n    canvas1 = np.zeros_like(img1) + 255\n    canvas2 = np.zeros_like(img1) + 255\n\n    for node, nbs in enumerate(ori_source0_topo):\n        for nb in nbs:\n            cv2.line(canvas1, [source0[node][0], source0[node][1]], [source0[nb][0], source0[nb][1]], [0, 0, 0], 2)\n    for node, nbs in enumerate(ori_source2_topo):\n        for nb in nbs:\n            cv2.line(canvas2, [source2[node][0], source2[node][1]], [source2[nb][0], source2[nb][1]], [0, 0, 0], 2)\n\n\n    canvases = [ np.zeros_like(img1).copy() + 255 for jj in range(inter_frames)  ] \n\n    for ii in range(len(canvases)):\n        source0_warp = (source0 + (ii + 1.0) / (len(canvases) + 1.0) * r0).astype(int)\n        source2_warp = (source2 + (1 - (ii + 1.0) / (len(canvases) + 1.0)) * r1).astype(int)\n        for node, nbs in enumerate(source0_topo):\n            for nb in nbs:\n                if visible01[node] and visible01[nb]:\n                    cv2.line(canvases[ii], [source0_warp[node][0], source0_warp[node][1]], [source0_warp[nb][0], source0_warp[nb][1]], [0, 0, 0], 2)\n        for node, nbs in enumerate(source2_topo):\n            for nb in nbs:\n                if visible21[node] and visible21[nb]:\n                    
cv2.line(canvases[ii], [source2_warp[node][0], source2_warp[node][1]], [source2_warp[nb][0], source2_warp[nb][1]], [0, 0, 0], 2)\n        # if ii == 15:\n          ##  print('hulala>>>>', source0_warp.mean(), source2_warp.mean(), (ii + 1.0) / (len(canvases) + 1.0), (1 - (ii + 1.0) / (len(canvases) + 1.0)), flush=True)\n          ##  print(canvases[ii].mean())\n\n    for ii in range(len(canvases)):\n        canvases[ii] =  cv2.hconcat([canvas1, canvases[ii]])\n\n    images = [cv2.hconcat([canvas1, canvas1])] + canvases + [cv2.hconcat([canvas2, canvas2])]\n    \n    return images\n"
  }
]
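All of the visualizers above share one interpolation rule: a vertex of frame 0 is moved to `v0 + t * r0` and a vertex of frame 1 to `v1 + (1 - t) * r1`, with `t = (ii + 1) / (inter_frames + 1)`. A minimal numpy sketch of just that rule, isolated from the drawing code (the helper `interp_positions` and the toy arrays are illustrative, not part of the repo):

```python
import numpy as np

def interp_positions(v, r, t):
    """Move vertices v linearly along repositioning vectors r to time t in [0, 1]."""
    return (v + t * r).astype(int)

# Toy example: two vertices of frame 0 and their predicted displacements.
v0 = np.array([[10, 10], [20, 40]], dtype=float)
r0 = np.array([[8, 0], [0, -8]], dtype=float)

# As in visvid(): inter_frames canvases at t = (ii + 1) / (inter_frames + 1).
inter_frames = 3
steps = [interp_positions(v0, r0, (ii + 1.0) / (inter_frames + 1.0))
         for ii in range(inter_frames)]
```

Rounding with `astype(int)` mirrors how the visualizers quantize positions before passing them to `cv2.line`; the frame-1 side uses the mirrored weight `(1 - t)` so both warped graphs meet at the same intermediate time.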