Repository: khy0809/fewshot-egnn
Branch: master
Commit: 205fa80ec7cb
Files: 21
Total size: 94.9 KB
Directory structure:
gitextract__r5a0unu/
├── .idea/
│ ├── egnn_distribute.iml
│ ├── modules.xml
│ ├── vcs.xml
│ └── workspace.xml
├── LICENSE
├── README.md
├── __init__.py
├── _version.py
├── data.py
├── eval.py
├── model.py
├── torchtools/
│ ├── __init__.py
│ ├── _version.py
│ └── tt/
│ ├── __init__.py
│ ├── arg.py
│ ├── layer.py
│ ├── logger.py
│ ├── stat.py
│ ├── trainer.py
│ └── utils.py
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .idea/egnn_distribute.iml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Remote Python 3.6.8 (sftp://root@instance.cloud.kakaobrain.com:11255/opt/conda/bin/python3)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
</component>
</module>
================================================
FILE: .idea/modules.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/egnn_distribute.iml" filepath="$PROJECT_DIR$/.idea/egnn_distribute.iml" />
</modules>
</component>
</project>
================================================
FILE: .idea/vcs.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>
================================================
FILE: .idea/workspace.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="f20b581c-b8b4-4c9e-9203-c0c2c2f454b5" name="Default Changelist" comment="" />
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="FUSProjectUsageTrigger">
<session id="-1231903785">
<usages-collector id="statistics.lifecycle.project">
<counts>
<entry key="project.open.time.1" value="2" />
<entry key="project.opened" value="2" />
</counts>
</usages-collector>
<usages-collector id="statistics.file.extensions.open">
<counts>
<entry key="md" value="2" />
<entry key="py" value="4" />
</counts>
</usages-collector>
<usages-collector id="statistics.file.types.open">
<counts>
<entry key="Markdown" value="2" />
<entry key="Python" value="4" />
</counts>
</usages-collector>
<usages-collector id="statistics.file.extensions.edit">
<counts>
<entry key="md" value="1812" />
<entry key="py" value="112" />
</counts>
</usages-collector>
<usages-collector id="statistics.file.types.edit">
<counts>
<entry key="Markdown" value="1812" />
<entry key="Python" value="112" />
</counts>
</usages-collector>
</session>
</component>
<component name="FileEditorManager">
<leaf SIDE_TABS_SIZE_LIMIT_KEY="300">
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/data.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="3705">
<caret line="247" column="37" selection-start-line="247" selection-start-column="37" selection-end-line="247" selection-end-column="37" />
<folding>
<element signature="e#0#37#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/model.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="615">
<caret line="41" column="30" selection-start-line="41" selection-start-column="30" selection-end-line="41" selection-end-column="30" />
<folding>
<element signature="e#0#24#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/train.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="5820">
<caret line="388" column="49" selection-start-line="388" selection-start-column="26" selection-end-line="388" selection-end-column="49" />
<folding>
<element signature="e#0#24#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="true">
<entry file="file://$PROJECT_DIR$/README.md">
<provider selected="true" editor-type-id="split-provider[text-editor;markdown-preview-editor]">
<state split_layout="SPLIT">
<first_editor relative-caret-position="239">
<caret line="144" column="47" selection-start-line="144" selection-start-column="47" selection-end-line="144" selection-end-column="47" />
</first_editor>
<second_editor />
</state>
</provider>
</entry>
</file>
</leaf>
</component>
<component name="FindInProjectRecents">
<findStrings>
<find>tt.arg.inter_dea</find>
<find>inter_deactivate</find>
</findStrings>
</component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="IdeDocumentHistory">
<option name="CHANGED_PATHS">
<list>
<option value="$PROJECT_DIR$/model.py" />
<option value="$PROJECT_DIR$/train.py" />
<option value="$PROJECT_DIR$/README.md" />
</list>
</option>
</component>
<component name="JsBuildToolGruntFileManager" detection-done="true" sorting="DEFINITION_ORDER" />
<component name="JsBuildToolPackageJson" detection-done="true" sorting="DEFINITION_ORDER" />
<component name="JsGulpfileManager">
<detection-done>true</detection-done>
<sorting>DEFINITION_ORDER</sorting>
</component>
<component name="ProjectFrameBounds" fullScreen="true">
<option name="y" value="23" />
<option name="width" value="1440" />
<option name="height" value="877" />
</component>
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
<component name="ProjectView">
<navigator proportions="" version="1">
<foldersAlwaysOnTop value="true" />
</navigator>
<panes>
<pane id="ProjectPane">
<subPane>
<expand>
<path>
<item name="egnn_distribute" type="b2602c69:ProjectViewProjectNode" />
<item name="egnn_distribute" type="462c0819:PsiDirectoryNode" />
</path>
</expand>
<select />
</subPane>
</pane>
<pane id="Scope" />
</panes>
</component>
<component name="PropertiesComponent">
<property name="WebServerToolWindowFactoryState" value="true" />
<property name="last_opened_file_path" value="$PROJECT_DIR$" />
<property name="nodejs_interpreter_path.stuck_in_default_project" value="undefined stuck path" />
<property name="nodejs_npm_path_reset_for_default_project" value="true" />
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
</component>
<component name="RunDashboard">
<option name="ruleStates">
<list>
<RuleState>
<option name="name" value="ConfigurationTypeDashboardGroupingRule" />
</RuleState>
<RuleState>
<option name="name" value="StatusDashboardGroupingRule" />
</RuleState>
</list>
</option>
</component>
<component name="SvnConfiguration">
<configuration />
</component>
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="f20b581c-b8b4-4c9e-9203-c0c2c2f454b5" name="Default Changelist" comment="" />
<created>1556855662817</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1556855662817</updated>
</task>
<servers />
</component>
<component name="ToolWindowManager">
<frame x="0" y="0" width="1440" height="900" extended-state="0" />
<editor active="true" />
<layout>
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.19456366" />
<window_info id="Structure" order="1" side_tool="true" weight="0.25" />
<window_info id="Favorites" order="2" side_tool="true" />
<window_info anchor="bottom" id="Message" order="0" />
<window_info anchor="bottom" id="Find" order="1" />
<window_info anchor="bottom" id="Run" order="2" />
<window_info anchor="bottom" id="Debug" order="3" weight="0.4" />
<window_info anchor="bottom" id="Cvs" order="4" weight="0.25" />
<window_info anchor="bottom" id="Inspection" order="5" weight="0.4" />
<window_info anchor="bottom" id="TODO" order="6" />
<window_info anchor="bottom" id="Docker" order="7" show_stripe_button="false" />
<window_info anchor="bottom" id="Version Control" order="8" show_stripe_button="false" />
<window_info anchor="bottom" id="File Transfer" order="9" visible="true" weight="0.32771084" />
<window_info anchor="bottom" id="Database Changes" order="10" show_stripe_button="false" />
<window_info anchor="bottom" id="Terminal" order="11" />
<window_info anchor="bottom" id="Event Log" order="12" side_tool="true" />
<window_info anchor="bottom" id="Python Console" order="13" />
<window_info anchor="right" id="Commander" internal_type="SLIDING" order="0" type="SLIDING" weight="0.4" />
<window_info anchor="right" id="Ant Build" order="1" weight="0.25" />
<window_info anchor="right" content_ui="combo" id="Hierarchy" order="2" weight="0.25" />
<window_info anchor="right" id="Remote Host" order="3" />
<window_info anchor="right" id="SciView" order="4" />
<window_info anchor="right" id="Database" order="5" />
</layout>
</component>
<component name="TypeScriptGeneratedFilesManager">
<option name="version" value="1" />
</component>
<component name="VcsContentAnnotationSettings">
<option name="myLimit" value="2678400000" />
</component>
<component name="editorHistoryManager">
<entry file="file://$PROJECT_DIR$/eval.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="405">
<caret line="27" column="46" lean-forward="true" selection-start-line="27" selection-start-column="46" selection-end-line="27" selection-end-column="46" />
<folding>
<element signature="e#0#24#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/data.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="3705">
<caret line="247" column="37" selection-start-line="247" selection-start-column="37" selection-end-line="247" selection-end-column="37" />
<folding>
<element signature="e#0#37#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/model.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="615">
<caret line="41" column="30" selection-start-line="41" selection-start-column="30" selection-end-line="41" selection-end-column="30" />
<folding>
<element signature="e#0#24#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/train.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="5820">
<caret line="388" column="49" selection-start-line="388" selection-start-column="26" selection-end-line="388" selection-end-column="49" />
<folding>
<element signature="e#0#24#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/README.md">
<provider selected="true" editor-type-id="split-provider[text-editor;markdown-preview-editor]">
<state split_layout="SPLIT">
<first_editor relative-caret-position="239">
<caret line="144" column="47" selection-start-line="144" selection-start-column="47" selection-end-line="144" selection-end-column="47" />
</first_editor>
<second_editor />
</state>
</provider>
</entry>
</component>
</project>
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2019 Jongmin Kim
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# fewshot-egnn
### Introduction
The current project page provides pytorch code that implements the following CVPR2019 paper:
**Title:** "Edge-labeling Graph Neural Network for Few-shot Learning"
**Authors:** Jongmin Kim, Taesup Kim, Sungwoong Kim, Chang D.Yoo
**Institution:** KAIST, KaKaoBrain
**Code:** https://github.com/khy0809/fewshot-egnn
**Arxiv:** https://arxiv.org/abs/1905.01436
**Abstract:**
In this paper, we propose a novel edge-labeling graph
neural network (EGNN), which adapts a deep neural network
on the edge-labeling graph, for few-shot learning.
The previous graph neural network (GNN) approaches in
few-shot learning have been based on the node-labeling
framework, which implicitly models the intra-cluster similarity
and the inter-cluster dissimilarity. In contrast, the
proposed EGNN learns to predict the edge-labels rather
than the node-labels on the graph that enables the evolution
of an explicit clustering by iteratively updating the edge-labels
with direct exploitation of both intra-cluster similarity
and the inter-cluster dissimilarity. It is also well suited
for performing on various numbers of classes without retraining,
and can be easily extended to perform a transductive
inference. The parameters of the EGNN are learned
by episodic training with an edge-labeling loss to obtain a
well-generalizable model for unseen low-data problems. On
both of the supervised and semi-supervised few-shot image
classification tasks with two benchmark datasets, the proposed
EGNN significantly improves the performances over
the existing GNNs.
### Citation
If you find this code useful you can cite us using the following bibTex:
```
@article{kim2019egnn,
title={Edge-labeling Graph Neural Network for Few-shot Learning},
author={Jongmin Kim, Taesup Kim, Sungwoong Kim, Chang D. Yoo},
journal={arXiv preprint arXiv:1905.01436},
year={2019}
}
```
### Platform
This code was developed and tested with pytorch version 1.0.1
### Setting
You can download miniImagenet dataset from [here](https://drive.google.com/open?id=15WuREBvhEbSWo4fTr1r-vMY0C_6QWv4w).
Download 'mini_imagenet_train/val/test.pickle', and put them in the path
'tt.arg.dataset_root/mini-imagenet/compacted_datasets/'
In ```train.py```, replace the dataset root directory with your own:
tt.arg.dataset_root = '/data/private/dataset'
### Training
```
# ************************** miniImagenet, 5way 1shot *****************************
$ python3 train.py --dataset mini --num_ways 5 --num_shots 1 --transductive False
$ python3 train.py --dataset mini --num_ways 5 --num_shots 1 --transductive True
# ************************** miniImagenet, 5way 5shot *****************************
$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --transductive False
$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --transductive True
# ************************** miniImagenet, 10way 5shot *****************************
$ python3 train.py --dataset mini --num_ways 10 --num_shots 5 --meta_batch_size 20 --transductive True
# ************************** tieredImagenet, 5way 5shot *****************************
$ python3 train.py --dataset tiered --num_ways 5 --num_shots 5 --transductive False
$ python3 train.py --dataset tiered --num_ways 5 --num_shots 5 --transductive True
# **************** miniImagenet, 5way 5shot, 20% labeled (semi) *********************
$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4 --transductive False
$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4 --transductive True
```
### Evaluation
The trained models are saved in the path './asset/checkpoints/', with the name of 'D-{dataset}_N-{ways}_K-{shots}_U-{num_unlabeled}_L-{num_layers}_B-{batch size}_T-{transductive}'.
So, for example, if you want to test the trained model of 'miniImagenet, 5way 1shot, transductive' setting, you can give --test_model argument as follow:
```
$ python3 eval.py --test_model D-mini_N-5_K-1_U-0_L-3_B-40_T-True
```
## Result
Here are some experimental results presented in the paper. You should be able to reproduce all the results by using the trained models which can be downloaded from [here](https://drive.google.com/open?id=15WuREBvhEbSWo4fTr1r-vMY0C_6QWv4w).
#### miniImageNet, non-transductive
| Model | 5-way 5-shot acc (%)|
|--------------------------| ------------------: |
| Matching Networks [1] | 55.30 |
| Reptile [2] | 62.74 |
| Prototypical Net [3] | 65.77 |
| GNN [4] | 66.41 |
| **(ours)** EGNN | **66.85** |
#### miniImageNet, transductive
| Model | 5-way 5-shot acc (%)|
|--------------------------| ------------------: |
| MAML [5] | 63.11 |
| Reptile + BN [2] | 65.99 |
| Relation Net [6] | 67.07 |
| MAML + Transduction [5] | 66.19 |
| TPN [7] | 69.43 |
| TPN (Higher K) [7] | 69.86 |
| **(ours)** EGNN | **76.37** |
#### tieredImageNet, non-transductive
| Model | 5-way 5-shot acc (%)|
|--------------------------| ------------------: |
| Reptile [2] | 66.47 |
| Prototypical Net [3] | 69.57 |
| **(ours)** EGNN | **70.98** |
#### tieredImageNet, transductive
| Model | 5-way 5-shot acc (%)|
|--------------------------| ------------------: |
| MAML [5] | 70.30 |
| Reptile + BN [2] | 71.03 |
| Relation Net [6] | 71.31 |
| MAML + Transduction [5] | 70.83 |
| TPN [7] | 72.58 |
| **(ours)** EGNN | **80.15** |
#### miniImageNet, semi-supervised, 5-way 5-shot
| Model | 20% | 40% | 60% | 100% |
|--------------------------| ------------------: | ------------------: | ------------------: | ------------------: |
| GNN-LabeledOnly [4] | 50.33 | 56.91 | - | 66.41 |
| GNN-Semi [4] | 52.45 | 58.76 | - | 66.41 |
| EGNN-LabeledOnly | 52.86 | - | - | 66.85 |
| EGNN-Semi | 61.88 | 62.52 | 63.53 | 66.85 |
| EGNN-LabeledOnly (Transductive) | 59.18 | - | - | 76.37 |
| EGNN-Semi (Transductive) | 63.62 | 64.32 | 66.37 | 76.37 |
#### miniImageNet, cross-way experiment
| Model | train way | test way | Accuracy |
|--------------------------| ------------------: | ------------------: | ------------------: |
| GNN | 5 | 5 | 66.41 |
| GNN | 5 | 10 | N/A |
| GNN | 10 | 10 | 51.75 |
| GNN | 10 | 5 | N/A |
| EGNN | 5 | 5 | 76.37 |
| EGNN | 5 | 10 | 56.35 |
| EGNN | 10 | 10 | 57.61 |
| EGNN | 10 | 5 | 76.27 |
### References
```
[1] O. Vinyals et al. Matching networks for one shot learning.
[2] A Nichol, J Achiam, J Schulman, On first-order meta-learning algorithms.
[3] J. Snell, K. Swersky, and R. S. Zemel. Prototypical networks for few-shot learning.
[4] V Garcia, J Bruna, Few-shot learning with graph neural network.
[5] C. Finn, P. Abbeel, and S. Levine. Model-agnostic meta-learning for fast adaptation of deep networks.
[6] F. Sung et al, Learning to Compare: Relation Network for Few-Shot Learning.
[7] Y Liu, J Lee, M Park, S Kim, Y Yang, Transductive propagation network for few-shot learning.
================================================
FILE: __init__.py
================================================
import numpy as np
import torch
from torch import nn
from torch import optim
from torch import cuda
from torch import utils
from torch.nn import functional as F
from torch.utils.data import *
from torch.distributions import *
from torchtools import tt
__author__ = 'namju.kim@kakaobrain.com'
# initialize seed
if tt.arg.seed:
np.random.seed(tt.arg.seed)
torch.manual_seed(tt.arg.seed)
================================================
FILE: _version.py
================================================
__version__ = '0.4.0' # align version with pytorch
================================================
FILE: data.py
================================================
from __future__ import print_function
from torchtools import *
import torch.utils.data as data
import random
import os
import numpy as np
from PIL import Image as pil_image
import pickle
from itertools import islice
from torchvision import transforms
class MiniImagenetLoader(data.Dataset):
def __init__(self, root, partition='train'):
super(MiniImagenetLoader, self).__init__()
# set dataset information
self.root = root
self.partition = partition
self.data_size = [3, 84, 84]
# set normalizer
mean_pix = [x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]
std_pix = [x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]]
normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
# set transformer
if self.partition == 'train':
self.transform = transforms.Compose([transforms.RandomCrop(84, padding=4),
lambda x: np.asarray(x),
transforms.ToTensor(),
normalize])
else: # 'val' or 'test' ,
self.transform = transforms.Compose([lambda x: np.asarray(x),
transforms.ToTensor(),
normalize])
# load data
self.data = self.load_dataset()
def load_dataset(self):
# load data
dataset_path = os.path.join(self.root, 'mini-imagenet/compacted_datasets', 'mini_imagenet_%s.pickle' % self.partition)
with open(dataset_path, 'rb') as handle:
data = pickle.load(handle)
# for each class
for c_idx in data:
# for each image
for i_idx in range(len(data[c_idx])):
# resize
image_data = pil_image.fromarray(np.uint8(data[c_idx][i_idx]))
image_data = image_data.resize((self.data_size[2], self.data_size[1]))
#image_data = np.array(image_data, dtype='float32')
#image_data = np.transpose(image_data, (2, 0, 1))
# save
data[c_idx][i_idx] = image_data
return data
def get_task_batch(self,
num_tasks=5,
num_ways=20,
num_shots=1,
num_queries=1,
seed=None):
if seed is not None:
random.seed(seed)
# init task batch data
support_data, support_label, query_data, query_label = [], [], [], []
for _ in range(num_ways * num_shots):
data = np.zeros(shape=[num_tasks] + self.data_size,
dtype='float32')
label = np.zeros(shape=[num_tasks],
dtype='float32')
support_data.append(data)
support_label.append(label)
for _ in range(num_ways * num_queries):
data = np.zeros(shape=[num_tasks] + self.data_size,
dtype='float32')
label = np.zeros(shape=[num_tasks],
dtype='float32')
query_data.append(data)
query_label.append(label)
# get full class list in dataset
full_class_list = list(self.data.keys())
# for each task
for t_idx in range(num_tasks):
# define task by sampling classes (num_ways)
task_class_list = random.sample(full_class_list, num_ways)
# for each sampled class in task
for c_idx in range(num_ways):
# sample data for support and query (num_shots + num_queries)
class_data_list = random.sample(self.data[task_class_list[c_idx]], num_shots + num_queries)
# load sample for support set
for i_idx in range(num_shots):
# set data
support_data[i_idx + c_idx * num_shots][t_idx] = self.transform(class_data_list[i_idx])
support_label[i_idx + c_idx * num_shots][t_idx] = c_idx
# load sample for query set
for i_idx in range(num_queries):
query_data[i_idx + c_idx * num_queries][t_idx] = self.transform(class_data_list[num_shots + i_idx])
query_label[i_idx + c_idx * num_queries][t_idx] = c_idx
# convert to tensor (num_tasks x (num_ways * (num_supports + num_queries)) x ...)
support_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in support_data], 1)
support_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in support_label], 1)
query_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in query_data], 1)
query_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in query_label], 1)
return [support_data, support_label, query_data, query_label]
class TieredImagenetLoader(data.Dataset):
    """tieredImagenet loader for few-shot episodes.

    Loads one partition ('train'/'val'/'test') from numbered pre-compacted
    pickle chunks, normalizes images at load time (unlike MiniImagenetLoader,
    no per-sample transform is applied in `get_task_batch`), and serves
    episodic task batches.
    """

    def __init__(self, root, partition='train'):
        """
        Args:
            root: dataset root; pickles are expected under
                <root>/tiered-imagenet/compacted_datasets/
            partition: one of 'train', 'val', 'test'.
        """
        self.root = root
        self.partition = partition  # train/val/test
        #self.preprocess()
        self.data_size = [3, 84, 84]  # C x H x W of each sample

        # load data
        self.data = self.load_dataset()

        # if not self._check_exists_():
        #     self._init_folders_()
        #     if self.check_decompress():
        #         self._decompress_()
        #     self._preprocess_()

    def get_image_paths(self, file):
        """Parse a CSV of 'name,class' rows (header skipped).

        Returns:
            (class_names, images_path): parallel lists of class label strings
            and absolute image paths under <root>/tiered-imagenet/images/.
        """
        images_path, class_names = [], []
        with open(file, 'r') as f:
            f.readline()  # skip CSV header
            for line in f:
                name, class_ = line.split(',')
                class_ = class_[0:(len(class_)-1)]  # strip trailing newline
                path = self.root + '/tiered-imagenet/images/'+name
                images_path.append(path)
                class_names.append(class_)
        return class_names, images_path

    def preprocess(self):
        """One-off conversion of raw images into chunked pickle files.

        Reads train/test/val CSVs, resizes every image to 84x84 float32, groups
        them by an integer class label, and dumps each split in chunks of 20
        classes per pickle (protocol 2 for py2 compatibility).
        """
        print('\nPreprocessing Tiered-Imagenet images...')
        (class_names_train, images_path_train) = self.get_image_paths('%s/tiered-imagenet/train.csv' % self.root)
        (class_names_test, images_path_test) = self.get_image_paths('%s/tiered-imagenet/test.csv' % self.root)
        (class_names_val, images_path_val) = self.get_image_paths('%s/tiered-imagenet/val.csv' % self.root)

        # assign consecutive integer labels: train classes first, then test, then val
        keys_train = list(set(class_names_train))
        keys_test = list(set(class_names_test))
        keys_val = list(set(class_names_val))
        label_encoder = {}
        label_decoder = {}
        for i in range(len(keys_train)):
            label_encoder[keys_train[i]] = i
            label_decoder[i] = keys_train[i]
        for i in range(len(keys_train), len(keys_train)+len(keys_test)):
            label_encoder[keys_test[i-len(keys_train)]] = i
            label_decoder[i] = keys_test[i-len(keys_train)]
        for i in range(len(keys_train)+len(keys_test), len(keys_train)+len(keys_test)+len(keys_val)):
            label_encoder[keys_val[i-len(keys_train) - len(keys_test)]] = i
            label_decoder[i] = keys_val[i-len(keys_train)-len(keys_test)]

        counter = 0
        train_set = {}
        for class_, path in zip(class_names_train, images_path_train):
            img = pil_image.open(path)
            img = img.convert('RGB')
            img = img.resize((84, 84), pil_image.ANTIALIAS)
            img = np.array(img, dtype='float32')
            if label_encoder[class_] not in train_set:
                train_set[label_encoder[class_]] = []
            train_set[label_encoder[class_]].append(img)
            counter += 1
            if counter % 1000 == 0:
                print("Counter "+str(counter) + " from " + str(len(images_path_train)))

        test_set = {}
        for class_, path in zip(class_names_test, images_path_test):
            img = pil_image.open(path)
            img = img.convert('RGB')
            img = img.resize((84, 84), pil_image.ANTIALIAS)
            img = np.array(img, dtype='float32')
            if label_encoder[class_] not in test_set:
                test_set[label_encoder[class_]] = []
            test_set[label_encoder[class_]].append(img)
            counter += 1
            if counter % 1000 == 0:
                print("Counter " + str(counter) + " from "+str(len(class_names_test)))

        val_set = {}
        for class_, path in zip(class_names_val, images_path_val):
            img = pil_image.open(path)
            img = img.convert('RGB')
            img = img.resize((84, 84), pil_image.ANTIALIAS)
            img = np.array(img, dtype='float32')
            if label_encoder[class_] not in val_set:
                val_set[label_encoder[class_]] = []
            val_set[label_encoder[class_]].append(img)
            counter += 1
            if counter % 1000 == 0:
                print("Counter "+str(counter) + " from " + str(len(class_names_val)))

        # dump each split in chunks of 20 classes per pickle file
        partition_count = 0
        for item in self.chunks(train_set, 20):
            partition_count = partition_count + 1
            with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_train_{}.pickle'.format(partition_count)), 'wb') as handle:
                pickle.dump(item, handle, protocol=2)

        partition_count = 0
        for item in self.chunks(test_set, 20):
            partition_count = partition_count + 1
            with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_test_{}.pickle'.format(partition_count)), 'wb') as handle:
                pickle.dump(item, handle, protocol=2)

        partition_count = 0
        for item in self.chunks(val_set, 20):
            partition_count = partition_count + 1
            with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_val_{}.pickle'.format(partition_count)), 'wb') as handle:
                pickle.dump(item, handle, protocol=2)

        # secondary encoder remapping train+test integer labels to a dense range
        label_encoder = {}
        keys = list(train_set.keys()) + list(test_set.keys())
        for id_key, key in enumerate(keys):
            label_encoder[key] = id_key
        with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_label_encoder.pickle'), 'wb') as handle:
            pickle.dump(label_encoder, handle, protocol=2)

        print('Images preprocessed')

    def load_dataset(self):
        """Load all pickle chunks for this partition and normalize the images.

        Returns:
            dict mapping class label -> list of float32 arrays of shape
            (3, 84, 84), mean-subtracted per channel and scaled by 1/127.5.

        Raises:
            ValueError: if self.partition is not 'train'/'val'/'test'.
        """
        print("Loading dataset")
        data = {}
        # chunk counts must match what preprocess() wrote for each split
        if self.partition == 'train':
            num_partition = 18
        elif self.partition == 'val':
            num_partition = 5
        elif self.partition == 'test':
            num_partition = 8
        else:
            # BUG FIX: an unknown partition previously fell through and crashed
            # below with NameError on num_partition; fail fast with a clear error.
            raise ValueError("Unknown partition: %s (expected 'train', 'val' or 'test')" % self.partition)

        partition_count = 0
        for i in range(num_partition):
            partition_count = partition_count +1
            with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_{}_{}.pickle'.format(self.partition, partition_count)), 'rb') as handle:
                data.update(pickle.load(handle))

        # Resize images and normalize
        for class_ in data:
            for i in range(len(data[class_])):
                image2resize = pil_image.fromarray(np.uint8(data[class_][i]))
                image_resized = image2resize.resize((self.data_size[2], self.data_size[1]))
                image_resized = np.array(image_resized, dtype='float32')

                # Normalize: HWC -> CHW, per-channel mean subtraction, scale to ~[-1, 1]
                image_resized = np.transpose(image_resized, (2, 0, 1))
                image_resized[0, :, :] -= 120.45  # R
                image_resized[1, :, :] -= 115.74  # G
                image_resized[2, :, :] -= 104.65  # B
                image_resized /= 127.5

                data[class_][i] = image_resized

        print("Num classes " + str(len(data)))
        num_images = 0
        for class_ in data:
            num_images += len(data[class_])
        print("Num images " + str(num_images))
        return data

    def chunks(self, data, size=10000):
        """Yield sub-dicts of `data` with at most `size` keys each (in key order)."""
        it = iter(data)
        for i in range(0, len(data), size):
            yield {k: data[k] for k in islice(it, size)}

    def get_task_batch(self,
                       num_tasks=5,
                       num_ways=20,
                       num_shots=1,
                       num_queries=1,
                       seed=None):
        """Sample a batch of few-shot episodes.

        Same contract as MiniImagenetLoader.get_task_batch, except samples are
        already-normalized float32 arrays, so no transform is applied here.

        Args:
            num_tasks: number of episodes in the batch.
            num_ways: classes sampled per episode.
            num_shots: support samples per class.
            num_queries: query samples per class.
            seed: if given, reseeds the *global* `random` module.

        Returns:
            [support_data, support_label, query_data, query_label] tensors on
            tt.arg.device; labels are the within-episode class index as float.
        """
        if seed is not None:
            random.seed(seed)

        # init task batch data: one (num_tasks, C, H, W) buffer per slot
        support_data, support_label, query_data, query_label = [], [], [], []
        for _ in range(num_ways * num_shots):
            data = np.zeros(shape=[num_tasks] + self.data_size,
                            dtype='float32')
            label = np.zeros(shape=[num_tasks],
                             dtype='float32')
            support_data.append(data)
            support_label.append(label)
        for _ in range(num_ways * num_queries):
            data = np.zeros(shape=[num_tasks] + self.data_size,
                            dtype='float32')
            label = np.zeros(shape=[num_tasks],
                             dtype='float32')
            query_data.append(data)
            query_label.append(label)

        # get full class list in dataset
        full_class_list = list(self.data.keys())

        # for each task
        for t_idx in range(num_tasks):
            # define task by sampling classes (num_ways)
            task_class_list = random.sample(full_class_list, num_ways)
            # for each sampled class in task
            for c_idx in range(num_ways):
                # sample data for support and query (num_shots + num_queries)
                class_data_list = random.sample(self.data[task_class_list[c_idx]], num_shots + num_queries)
                # load sample for support set
                for i_idx in range(num_shots):
                    # set data
                    support_data[i_idx + c_idx * num_shots][t_idx] = class_data_list[i_idx]
                    support_label[i_idx + c_idx * num_shots][t_idx] = c_idx
                # load sample for query set
                for i_idx in range(num_queries):
                    query_data[i_idx + c_idx * num_queries][t_idx] = class_data_list[num_shots + i_idx]
                    query_label[i_idx + c_idx * num_queries][t_idx] = c_idx

        # convert to tensor (num_tasks x (num_ways * (num_supports + num_queries)) x ...)
        support_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in support_data], 1)
        support_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in support_label], 1)
        query_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in query_data], 1)
        query_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in query_label], 1)

        return [support_data, support_label, query_data, query_label]
================================================
FILE: eval.py
================================================
from torchtools import *
from data import MiniImagenetLoader, TieredImagenetLoader
from model import EmbeddingImagenet, GraphNetwork, ConvNet
import shutil
import os
import random
from train import ModelTrainer
if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Decode experiment settings from the test model name, e.g.
    # 'D-mini_N-5_K-1_U-0_L-3_B-40_T-True'
    # ------------------------------------------------------------------
    tt.arg.test_model = 'D-mini_N-5_K-1_U-0_L-3_B-40_T-True' if tt.arg.test_model is None else tt.arg.test_model
    list1 = tt.arg.test_model.split("_")
    param = {}
    for i in range(len(list1)):
        # each token is '<key>-<value>'; split only on the first '-'
        param[list1[i].split("-", 1)[0]] = list1[i].split("-", 1)[1]
    tt.arg.dataset = param['D']
    tt.arg.num_ways = int(param['N'])
    tt.arg.num_shots = int(param['K'])
    tt.arg.num_unlabeled = int(param['U'])
    tt.arg.num_layers = int(param['L'])
    tt.arg.meta_batch_size = int(param['B'])
    tt.arg.transductive = False if param['T'] == 'False' else True

    ####################
    tt.arg.device = 'cuda:0' if tt.arg.device is None else tt.arg.device
    # replace dataset_root with your own
    tt.arg.dataset_root = '/data/private/dataset'
    tt.arg.dataset = 'mini' if tt.arg.dataset is None else tt.arg.dataset
    tt.arg.num_ways = 5 if tt.arg.num_ways is None else tt.arg.num_ways
    tt.arg.num_shots = 1 if tt.arg.num_shots is None else tt.arg.num_shots
    tt.arg.num_unlabeled = 0 if tt.arg.num_unlabeled is None else tt.arg.num_unlabeled
    tt.arg.num_layers = 3 if tt.arg.num_layers is None else tt.arg.num_layers
    tt.arg.meta_batch_size = 40 if tt.arg.meta_batch_size is None else tt.arg.meta_batch_size
    tt.arg.transductive = False if tt.arg.transductive is None else tt.arg.transductive
    tt.arg.seed = 222 if tt.arg.seed is None else tt.arg.seed
    tt.arg.num_gpus = 1 if tt.arg.num_gpus is None else tt.arg.num_gpus

    # at eval time train/test splits share the same episode configuration
    tt.arg.num_ways_train = tt.arg.num_ways
    tt.arg.num_ways_test = tt.arg.num_ways
    tt.arg.num_shots_train = tt.arg.num_shots
    tt.arg.num_shots_test = tt.arg.num_shots
    tt.arg.train_transductive = tt.arg.transductive
    tt.arg.test_transductive = tt.arg.transductive

    # model parameter related
    tt.arg.num_edge_features = 96
    tt.arg.num_node_features = 96
    tt.arg.emb_size = 128

    # train, test parameters
    tt.arg.train_iteration = 100000 if tt.arg.dataset == 'mini' else 200000
    tt.arg.test_iteration = 10000
    tt.arg.test_interval = 5000
    tt.arg.test_batch_size = 10
    tt.arg.log_step = 1000
    tt.arg.lr = 1e-3
    tt.arg.grad_clip = 5
    tt.arg.weight_decay = 1e-6
    tt.arg.dec_lr = 15000 if tt.arg.dataset == 'mini' else 30000
    tt.arg.dropout = 0.1 if tt.arg.dataset == 'mini' else 0.0

    # set random seed before building the encoder so its init is reproducible
    np.random.seed(tt.arg.seed)
    torch.manual_seed(tt.arg.seed)
    torch.cuda.manual_seed_all(tt.arg.seed)
    random.seed(tt.arg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    enc_module = EmbeddingImagenet(emb_size=tt.arg.emb_size)

    # re-seed after encoder construction (kept deliberately: restores the RNG
    # state so subsequent module construction matches the training script)
    np.random.seed(tt.arg.seed)
    torch.manual_seed(tt.arg.seed)
    torch.cuda.manual_seed_all(tt.arg.seed)
    random.seed(tt.arg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # rebuild the canonical experiment name and verify it matches the requested model
    exp_name = 'D-{}'.format(tt.arg.dataset)
    exp_name += '_N-{}_K-{}_U-{}'.format(tt.arg.num_ways, tt.arg.num_shots, tt.arg.num_unlabeled)
    exp_name += '_L-{}_B-{}'.format(tt.arg.num_layers, tt.arg.meta_batch_size)
    exp_name += '_T-{}'.format(tt.arg.transductive)

    if not exp_name == tt.arg.test_model:
        print(exp_name)
        print(tt.arg.test_model)
        # BUG FIX: AssertionError() was previously constructed but never raised,
        # so a mismatch was silently ignored and the wrong checkpoint could load.
        raise AssertionError('Test model and input arguments are mismatched!')

    # NOTE(review): num_edge_features is passed as node_features and
    # num_node_features as edge_features; both are 96 so this is currently
    # harmless — confirm the intended mapping.
    gnn_module = GraphNetwork(in_features=tt.arg.emb_size,
                              node_features=tt.arg.num_edge_features,
                              edge_features=tt.arg.num_node_features,
                              num_layers=tt.arg.num_layers,
                              dropout=tt.arg.dropout)

    if tt.arg.dataset == 'mini':
        test_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='test')
    elif tt.arg.dataset == 'tiered':
        test_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='test')
    else:
        # BUG FIX: fail fast instead of printing and later crashing with NameError
        raise ValueError('Unknown dataset: {}'.format(tt.arg.dataset))

    data_loader = {'test': test_loader}

    # create trainer (reused here as a tester)
    tester = ModelTrainer(enc_module=enc_module,
                          gnn_module=gnn_module,
                          data_loader=data_loader)

    # checkpoint = torch.load('asset/checkpoints/{}/'.format(exp_name) + 'model_best.pth.tar')
    checkpoint = torch.load('./trained_models/{}/'.format(exp_name) + 'model_best.pth.tar')

    # initialize encoder from the pre-trained checkpoint
    tester.enc_module.load_state_dict(checkpoint['enc_module_state_dict'])
    print("load pre-trained enc_nn done!")

    # initialize gnn pre-trained
    tester.gnn_module.load_state_dict(checkpoint['gnn_module_state_dict'])
    print("load pre-trained egnn done!")

    tester.val_acc = checkpoint['val_acc']
    tester.global_step = checkpoint['iteration']
    print(tester.global_step)

    tester.eval(partition='test')
================================================
FILE: model.py
================================================
from torchtools import *
from collections import OrderedDict
import math
#import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
class ConvBlock(nn.Module):
    """Conv3x3 -> (Batch|Instance)Norm -> optional ReLU -> MaxPool2x2 building block.

    The norm layer is selected by the global tt.arg.normtype at construction
    time ('batch' or 'instance'); any other value registers no norm layer.
    """

    def __init__(self, in_planes, out_planes, userelu=True, momentum=0.1, affine=True, track_running_stats=True):
        super(ConvBlock, self).__init__()
        self.layers = nn.Sequential()
        self.layers.add_module('Conv',
                               nn.Conv2d(in_planes, out_planes, kernel_size=3,
                                         stride=1, padding=1, bias=False))
        if tt.arg.normtype == 'batch':
            norm = nn.BatchNorm2d(out_planes, momentum=momentum, affine=affine,
                                  track_running_stats=track_running_stats)
            self.layers.add_module('Norm', norm)
        elif tt.arg.normtype == 'instance':
            self.layers.add_module('Norm', nn.InstanceNorm2d(out_planes))
        if userelu:
            self.layers.add_module('ReLU', nn.ReLU(inplace=True))
        self.layers.add_module('MaxPool',
                               nn.MaxPool2d(kernel_size=2, stride=2, padding=0))

    def forward(self, x):
        # run the whole pipeline in registration order
        return self.layers(x)
class ConvNet(nn.Module):
    """Stack of `num_stages` ConvBlocks configured from an options dict.

    opt keys: 'in_planes', 'out_planes' (int or per-stage list), 'num_stages',
    optional 'userelu' (applies to the final block only). Conv weights use
    He-style init; BatchNorms start as identity (weight 1, bias 0).
    """

    def __init__(self, opt, momentum=0.1, affine=True, track_running_stats=True):
        super(ConvNet, self).__init__()
        self.in_planes = opt['in_planes']
        self.out_planes = opt['out_planes']
        self.num_stages = opt['num_stages']
        # a single int means "same width at every stage"
        if type(self.out_planes) == int:
            self.out_planes = [self.out_planes] * self.num_stages
        assert type(self.out_planes) == list and len(self.out_planes) == self.num_stages

        num_planes = [self.in_planes] + self.out_planes
        userelu = opt['userelu'] if ('userelu' in opt) else True

        blocks = []
        for stage in range(self.num_stages):
            if stage == self.num_stages - 1:
                # only the final block honours the configured 'userelu' flag
                blocks.append(ConvBlock(num_planes[stage], num_planes[stage + 1], userelu=userelu))
            else:
                blocks.append(ConvBlock(num_planes[stage], num_planes[stage + 1]))
        self.conv_blocks = nn.Sequential(*blocks)

        # weight initialisation
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv_blocks(x)
        # flatten all feature maps per sample
        return out.view(out.size(0), -1)
# encoder for imagenet dataset
class EmbeddingImagenet(nn.Module):
def __init__(self,
emb_size):
super(EmbeddingImagenet, self).__init__()
# set size
self.hidden = 64
self.last_hidden = self.hidden * 25
self.emb_size = emb_size
# set layers
self.conv_1 = nn.Sequential(nn.Conv2d(in_channels=3,
out_channels=self.hidden,
kernel_size=3,
padding=1,
bias=False),
nn.BatchNorm2d(num_features=self.hidden),
nn.MaxPool2d(kernel_size=2),
nn.LeakyReLU(negative_slope=0.2, inplace=True))
self.conv_2 = nn.Sequential(nn.Conv2d(in_channels=self.hidden,
out_channels=int(self.hidden*1.5),
kernel_size=3,
bias=False),
nn.BatchNorm2d(num_features=int(self.hidden*1.5)),
nn.MaxPool2d(kernel_size=2),
nn.LeakyReLU(negative_slope=0.2, inplace=True))
self.conv_3 = nn.Sequential(nn.Conv2d(in_channels=int(self.hidden*1.5),
out_channels=self.hidden*2,
kernel_size=3,
padding=1,
bias=False),
nn.BatchNorm2d(num_features=self.hidden * 2),
nn.MaxPool2d(kernel_size=2),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Dropout2d(0.4))
self.conv_4 = nn.Sequential(nn.Conv2d(in_channels=self.hidden*2,
out_channels=self.hidden*4,
kernel_size=3,
padding=1,
bias=False),
nn.BatchNorm2d(num_features=self.hidden * 4),
nn.MaxPool2d(kernel_size=2),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Dropout2d(0.5))
self.layer_last = nn.Sequential(nn.Linear(in_features=self.last_hidden * 4,
out_features=self.emb_size, bias=True),
nn.BatchNorm1d(self.emb_size))
def forward(self, input_data):
output_data = self.conv_4(self.conv_3(self.conv_2(self.conv_1(input_data))))
return self.layer_last(output_data.view(output_data.size(0), -1))
class NodeUpdateNetwork(nn.Module):
    """Node update of the EGNN layer.

    Aggregates neighbour node features weighted by the (row-normalized)
    similarity and dissimilarity edge maps, concatenates the two aggregates
    with the original node features, and transforms the result with a stack
    of 1x1 conv blocks.
    """

    def __init__(self,
                 in_features,
                 num_features,
                 ratio=[2, 1],
                 dropout=0.0):
        super(NodeUpdateNetwork, self).__init__()
        # set size
        self.in_features = in_features
        # per-stage widths: num_features scaled by each ratio entry
        self.num_features_list = [num_features * r for r in ratio]
        self.dropout = dropout
        # layers: 1x1 conv -> BN -> LeakyReLU per stage; dropout only on the last stage
        layer_list = OrderedDict()
        for l in range(len(self.num_features_list)):
            # input of the first conv is node_feat concatenated with the two
            # edge-weighted aggregates, hence in_features * 3
            layer_list['conv{}'.format(l)] = nn.Conv2d(
                in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features * 3,
                out_channels=self.num_features_list[l],
                kernel_size=1,
                bias=False)
            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],
                                                            )
            layer_list['relu{}'.format(l)] = nn.LeakyReLU()
            if self.dropout > 0 and l == (len(self.num_features_list) - 1):
                layer_list['drop{}'.format(l)] = nn.Dropout2d(p=self.dropout)
        self.network = nn.Sequential(layer_list)

    def forward(self, node_feat, edge_feat):
        """node_feat: (tasks, nodes, feat); edge_feat: (tasks, 2, nodes, nodes).

        Returns updated node features of shape (tasks, nodes, out_feat).
        """
        # get size
        num_tasks = node_feat.size(0)
        num_data = node_feat.size(1)
        # mask that zeroes the diagonal of both edge channels
        diag_mask = 1.0 - torch.eye(num_data).unsqueeze(0).unsqueeze(0).repeat(num_tasks, 2, 1, 1).to(tt.arg.device)
        # set diagonal as zero and L1-normalize each row so it acts as attention
        edge_feat = F.normalize(edge_feat * diag_mask, p=1, dim=-1)
        # stack the 2 edge channels side by side and aggregate neighbour features
        aggr_feat = torch.bmm(torch.cat(torch.split(edge_feat, 1, 1), 2).squeeze(1), node_feat)
        # concat [node, sim-aggregate, dsim-aggregate] then move feat dim to channels
        node_feat = torch.cat([node_feat, torch.cat(aggr_feat.split(num_data, 1), -1)], -1).transpose(1, 2)
        # non-linear transform via the 1x1 conv stack (adds/removes a dummy spatial dim)
        node_feat = self.network(node_feat.unsqueeze(-1)).transpose(1, 2).squeeze(-1)
        return node_feat
class EdgeUpdateNetwork(nn.Module):
    """Edge update of the EGNN layer.

    Re-estimates similarity/dissimilarity edge maps from pairwise absolute
    node-feature differences via a stack of 1x1 conv blocks, then renormalizes
    the edges while keeping each row's total mass and forcing self-similarity.
    """

    def __init__(self,
                 in_features,
                 num_features,
                 ratio=[2, 2, 1, 1],
                 separate_dissimilarity=False,
                 dropout=0.0):
        super(EdgeUpdateNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.num_features_list = [num_features * r for r in ratio]
        self.separate_dissimilarity = separate_dissimilarity
        self.dropout = dropout
        # similarity network: 1x1 conv -> BN -> LeakyReLU (+optional dropout) per stage
        layer_list = OrderedDict()
        for l in range(len(self.num_features_list)):
            layer_list['conv{}'.format(l)] = nn.Conv2d(in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features,
                                                       out_channels=self.num_features_list[l],
                                                       kernel_size=1,
                                                       bias=False)
            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l])
            layer_list['relu{}'.format(l)] = nn.LeakyReLU()
            if self.dropout > 0:
                layer_list['drop{}'.format(l)] = nn.Dropout2d(p=self.dropout)
        # final 1x1 conv collapses to a single-channel similarity logit
        layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],
                                           out_channels=1,
                                           kernel_size=1)
        self.sim_network = nn.Sequential(layer_list)

        if self.separate_dissimilarity:
            # independent network for the dissimilarity channel
            layer_list = OrderedDict()
            for l in range(len(self.num_features_list)):
                layer_list['conv{}'.format(l)] = nn.Conv2d(in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features,
                                                           out_channels=self.num_features_list[l],
                                                           kernel_size=1,
                                                           bias=False)
                layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l])
                layer_list['relu{}'.format(l)] = nn.LeakyReLU()
                if self.dropout > 0:
                    # CONSISTENCY FIX: was nn.Dropout on a 4D feature map;
                    # use Dropout2d to match sim_network's behaviour
                    layer_list['drop{}'.format(l)] = nn.Dropout2d(p=self.dropout)
            layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],
                                               out_channels=1,
                                               kernel_size=1)
            self.dsim_network = nn.Sequential(layer_list)

    def forward(self, node_feat, edge_feat):
        """node_feat: (tasks, nodes, feat); edge_feat: (tasks, 2, nodes, nodes).

        Returns updated edge features of the same shape, normalized so the
        two channels sum to one at every position.
        """
        # compute pairwise |x_i - x_j|, shaped (tasks, feat, nodes, nodes)
        x_i = node_feat.unsqueeze(2)
        x_j = torch.transpose(x_i, 1, 2)
        x_ij = torch.abs(x_i - x_j)
        x_ij = torch.transpose(x_ij, 1, 3)
        # compute similarity/dissimilarity (batch_size x 1 x num_samples x num_samples)
        # (F.sigmoid is deprecated; torch.sigmoid is the documented replacement)
        sim_val = torch.sigmoid(self.sim_network(x_ij))
        if self.separate_dissimilarity:
            dsim_val = torch.sigmoid(self.dsim_network(x_ij))
        else:
            dsim_val = 1.0 - sim_val

        # zero the diagonal of both channels, remembering each row's total mass
        diag_mask = 1.0 - torch.eye(node_feat.size(1)).unsqueeze(0).unsqueeze(0).repeat(node_feat.size(0), 2, 1, 1).to(tt.arg.device)
        edge_feat = edge_feat * diag_mask
        merge_sum = torch.sum(edge_feat, -1, True)
        # reweight old edges by the new (dis)similarities, renormalize rows,
        # and restore the original row mass
        edge_feat = F.normalize(torch.cat([sim_val, dsim_val], 1) * edge_feat, p=1, dim=-1) * merge_sum

        # force self-edges to (sim=1, dsim=0)
        force_edge_feat = torch.cat((torch.eye(node_feat.size(1)).unsqueeze(0), torch.zeros(node_feat.size(1), node_feat.size(1)).unsqueeze(0)), 0).unsqueeze(0).repeat(node_feat.size(0), 1, 1, 1).to(tt.arg.device)
        edge_feat = edge_feat + force_edge_feat
        # epsilon keeps the channel-normalisation below well defined
        edge_feat = edge_feat + 1e-6
        edge_feat = edge_feat / torch.sum(edge_feat, dim=1).unsqueeze(1).repeat(1, 2, 1, 1)

        return edge_feat
class GraphNetwork(nn.Module):
    """EGNN core: alternates node updates and edge updates for num_layers rounds."""

    def __init__(self,
                 in_features,
                 node_features,
                 edge_features,
                 num_layers,
                 dropout=0.0):
        super(GraphNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.node_features = node_features
        self.edge_features = edge_features
        self.num_layers = num_layers
        self.dropout = dropout
        # register one (edge->node, node->edge) pair per layer;
        # the final layer always runs without dropout
        for layer in range(self.num_layers):
            drop = self.dropout if layer < self.num_layers - 1 else 0.0
            node_updater = NodeUpdateNetwork(
                in_features=self.in_features if layer == 0 else self.node_features,
                num_features=self.node_features,
                dropout=drop)
            edge_updater = EdgeUpdateNetwork(
                in_features=self.node_features,
                num_features=self.edge_features,
                separate_dissimilarity=False,
                dropout=drop)
            self.add_module('edge2node_net{}'.format(layer), node_updater)
            self.add_module('node2edge_net{}'.format(layer), edge_updater)

    def forward(self, node_feat, edge_feat):
        """Return the list of edge feature tensors produced after each layer."""
        edge_feat_list = []
        for layer in range(self.num_layers):
            # (1) edge to node
            node_feat = getattr(self, 'edge2node_net{}'.format(layer))(node_feat, edge_feat)
            # (2) node to edge
            edge_feat = getattr(self, 'node2edge_net{}'.format(layer))(node_feat, edge_feat)
            edge_feat_list.append(edge_feat)
        return edge_feat_list
================================================
FILE: torchtools/__init__.py
================================================
import numpy as np
import torch
from torch import nn
from torch import optim
from torch import cuda
from torch import utils
from torch.nn import functional as F
from torch.utils.data import *
from torch.distributions import *
from torchtools import tt
__author__ = 'namju.kim@kakaobrain.com'
# initialize seed: seed numpy and torch RNGs at package-import time when a
# --seed argument was supplied on the command line
if tt.arg.seed:
    np.random.seed(tt.arg.seed)
    torch.manual_seed(tt.arg.seed)
================================================
FILE: torchtools/_version.py
================================================
__version__ = '0.4.0' # align version with pytorch
================================================
FILE: torchtools/tt/__init__.py
================================================
from torchtools.tt.arg import _parse_opts
from torchtools.tt.utils import *
from torchtools.tt.layer import *
from torchtools.tt.logger import *
from torchtools.tt.stat import *
from torchtools.tt.trainer import *
__author__ = 'namju.kim@kakaobrain.com'
# global command line arguments, parsed exactly once at package import
arg = _parse_opts()
================================================
FILE: torchtools/tt/arg.py
================================================
import sys
import configparser
import torch
import threading
import time
import os
__author__ = 'namju.kim@kakaobrain.com'
_config_time_stamp = 0
class _Opt(object):
def __len__(self):
return len(self.__dict__)
def __setitem__(self, key, value):
self.__dict__[key] = value
def __getitem__(self, item):
if item in self.__dict__:
return self.__dict__[item]
else:
return None
def __getattr__(self, item):
return self.__getitem__(item)
def _to_py_obj(x):
# check boolean first
if x.lower() in ['true', 'yes', 'on']:
return True
if x.lower() in ['false', 'no', 'off']:
return False
# from string to python object if possible
try:
obj = eval(x)
if type(obj).__name__ in ['int', 'float', 'tuple', 'list', 'dict', 'NoneType']:
x = obj
except:
pass
return x
def _parse_config(arg, file):
    """Load an INI file into `arg`.

    Keys of a section literally named 'default' (case-insensitive) are hoisted
    to the top level of `arg`; every other section becomes a nested _Opt stored
    under the section name with spaces replaced by underscores.
    """
    parser = configparser.ConfigParser()
    parser.read(file)
    for section in parser.sections():
        section_opt = _Opt()
        for key in parser[section]:
            section_opt[key] = _to_py_obj(parser[section][key])
        if section.lower() == 'default':
            # hoist default-section entries to the top level
            for key, value in section_opt.__dict__.items():
                arg[key] = value
        else:
            arg['_'.join(section.split())] = section_opt
def _parse_config_thread(arg, file):
    """Daemon loop: re-parse the config file whenever its mtime changes (1s poll)."""
    global _config_time_stamp
    while True:
        mtime = os.stat(file).st_mtime
        if mtime != _config_time_stamp:
            # remember the new timestamp, then reload
            _config_time_stamp = mtime
            _parse_config(arg, file)
            # _print_opts(arg, 'CONFIGURATION CHANGE DETECTED')
        time.sleep(1)
def _print_opts(arg, header):
    """Pretty-print every option as key=value under a header (flushed for logs)."""
    divider = '-' * 30
    print(header, flush=True)
    print(divider, flush=True)
    for key, value in arg.__dict__.items():
        print('%s=%s' % (key, value), flush=True)
    print(divider, flush=True)
def _parse_opts():
    """Parse `--key value` pairs from sys.argv into a global option object.

    Also injects defaults (device, directories), loads an INI file when a
    `--config path` argument is present, and starts a daemon thread that
    live-reloads that config file. Returns the populated _Opt.
    """
    global _config_time_stamp
    # get command line arguments
    arg = _Opt()
    argv = sys.argv[1:]
    # check length
    assert len(argv) % 2 == 0, 'arguments should be paired with the format of --key value'
    # parse args
    for i in range(0, len(argv), 2):
        # check format
        assert argv[i].startswith('--'), 'arguments should be paired with the format of --key value'
        # save argument (value is converted to a python object when possible)
        arg[argv[i][2:]] = _to_py_obj(argv[i + 1])
        # check config file
        if argv[i][2:].lower() == 'config':
            _parse_config(arg, argv[i + 1])
            _config_time_stamp = os.stat(argv[i + 1]).st_mtime
    #
    # inject default options
    #
    # device setting: prefer CUDA when available and not overridden
    if arg.device is None:
        arg.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    arg.device = torch.device(arg.device)
    arg.cuda = arg.device.type == 'cuda'
    # default learning rate
    #arg.lr = 1e-3
    # directories (normalized to end with '/')
    arg.log_dir = arg.log_dir or 'asset/log/'
    arg.data_dir = arg.data_dir or 'asset/data/'
    arg.save_dir = arg.save_dir or 'asset/train/'
    arg.log_dir += '' if arg.log_dir.endswith('/') else '/'
    arg.data_dir += '' if arg.data_dir.endswith('/') else '/'
    arg.save_dir += '' if arg.save_dir.endswith('/') else '/'
    # print arg option
    # _print_opts(arg, 'CONFIGURATION')
    # start config file watcher if config is defined
    if arg.config:
        # daemon thread so it never blocks interpreter exit
        t = threading.Thread(target=_parse_config_thread, args=(arg, arg.config))
        t.daemon = True
        t.start()
    return arg
================================================
FILE: torchtools/tt/layer.py
================================================
from torchtools import nn
#
# Reshape layer for Sequential or ModuleList
#
class Reshape(nn.Module):
    """Reshape layer usable inside nn.Sequential / nn.ModuleList."""

    def __init__(self, *shape):
        super(Reshape, self).__init__()
        # target shape captured as a tuple of dims
        self.shape = shape

    def forward(self, x):
        # delegate to Tensor.reshape with the stored target shape
        return x.reshape(self.shape)

    def extra_repr(self):
        return 'shape=' + repr(self.shape)
================================================
FILE: torchtools/tt/logger.py
================================================
import datetime
import time
from tensorboardX import SummaryWriter
from torchtools import tt
__author__ = 'namju.kim@kakaobrain.com'
# tensorboard writer (created lazily by _get_writer)
_writer = None
# per-tag stat buffers, flushed to TensorBoard/console by log_step()
_stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist = {}, {}, {}, {}, {}
# time stamp of the last flush (drives log_interval throttling)
_last_logged = time.time()
# general print wrapper
def log(*args):
    """Print to stdout and, when tt.arg.log_file is set, append the same line to that file."""
    print(*args, flush=True)
    # save to log_file
    if tt.arg.log_file:
        with open(tt.arg.log_dir + tt.arg.log_file, 'a') as f:
            print(*args, flush=True, file=f)
# tensor board writer
def _get_writer():
    """Lazily create the global SummaryWriter under log_dir[/experiment]-timestamp."""
    global _writer
    if _writer is not None:
        return _writer
    # build the logging directory path
    tf_log_dir = tt.arg.log_dir
    if not tf_log_dir.endswith('/'):
        tf_log_dir += '/'
    if tt.arg.experiment:
        tf_log_dir += tt.arg.experiment
    tf_log_dir += datetime.datetime.now().strftime('-%Y%m%d-%H%M%S')
    # create writer
    _writer = SummaryWriter(tf_log_dir)
    return _writer
def log_scalar(tag, value, global_step=None):
    # buffer a scalar (converted to a python value) until the next log_step() flush
    _stats_scalar[tag] = (tt.nvar(value), global_step)

def log_audio(tag, audio, global_step=None):
    # buffer an audio clip for TensorBoard
    _stats_audio[tag] = (tt.nvar(audio), global_step)

def log_image(tag, image, global_step=None):
    # buffer an image for TensorBoard
    _stats_image[tag] = (tt.nvar(image), global_step)

def log_text(tag, text, global_step=None):
    # buffer raw text (no tensor conversion needed)
    _stats_text[tag] = (text, global_step)

def log_hist(tag, values, global_step=None):
    # buffer histogram values for TensorBoard
    _stats_hist[tag] = (tt.nvar(values), global_step)
def log_step(epoch=None, global_step=None, max_epoch=None, max_step=None):
    """Flush buffered stats to TensorBoard and the console at logging boundaries.

    A boundary is reached when neither log_interval nor log_step is configured,
    when log_interval seconds have elapsed since the last flush, or when
    global_step is a multiple of log_step.
    """
    global _last_logged, _last_logged_step, _stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist
    # logging
    if (tt.arg.log_interval is None and tt.arg.log_step is None) or \
            (tt.arg.log_interval and time.time() - _last_logged >= tt.arg.log_interval) or \
            (tt.arg.log_step and global_step % tt.arg.log_step == 0):
        # update logging time stamp
        _last_logged = time.time()
        _last_logged_step = global_step
        # console output string
        console_out = ''
        if epoch:
            console_out += 'ep: %d' % epoch
            if max_epoch:
                console_out += '/%d' % max_epoch
        if global_step:
            if max_step:
                # report step within the current epoch (1..max_step)
                step = global_step % max_step
                step = max_step if step == 0 else step
                console_out += ' step: %d/%d' % (step, max_step)
            else:
                console_out += ' step: %d' % global_step
        # add stats to tensor board
        for k, v in _stats_scalar.items():
            _get_writer().add_scalar(k, *v)
            # add to console output (weight/gradient norms stay TensorBoard-only)
            if not k.startswith('weight/') and not k.startswith('gradient/'):
                console_out += ' %s: %f' % (k, v[0])
        for k, v in _stats_image.items():
            _get_writer().add_image(k, *v)
        for k, v in _stats_audio.items():
            _get_writer().add_audio(k, *v)
        for k, v in _stats_text.items():
            _get_writer().add_text(k, *v)
        for k, v in _stats_hist.items():
            _get_writer().add_histogram(k, *v, 'auto')
        # flush
        _get_writer().file_writer.flush()
        # console out
        if len(console_out) > 0:
            log(console_out)
        # clear stats
        # BUG FIX: _stats_hist was never cleared, so every histogram was
        # re-written to TensorBoard on all subsequent flushes.
        _stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist = {}, {}, {}, {}, {}
def log_weight(model, global_step=None):
    """Buffer L2 norms of all weight parameters (enabled via tt.arg.log_weight)."""
    if not tt.arg.log_weight:
        return
    for name, param in model.named_parameters():
        # only weights are tracked, never biases
        if 'weight' in name:
            log_scalar('weight/' + name, param.norm(), global_step)
def log_gradient(model, global_step=None):
    """Buffer L2 norms of weight gradients (enabled via tt.arg.log_grad)."""
    if not tt.arg.log_grad:
        return
    for name, param in model.named_parameters():
        # only weights with a populated .grad are tracked
        if 'weight' in name and param.grad is not None:
            log_scalar('gradient/' + name, param.grad.norm(), global_step)
================================================
FILE: torchtools/tt/stat.py
================================================
from torchtools import tt
__author__ = 'namju.kim@kakaobrain.com'
================================================
FILE: torchtools/tt/trainer.py
================================================
from torchtools import nn, optim, tt
__author__ = 'namju.kim@kakaobrain.com'
class SupervisedTrainer(object):
    """Generic single-model supervised training loop driven by tt.arg settings."""

    def __init__(self, model, data_loader, optimizer=None, criterion=None):
        # global step counter, restored from checkpoint in run()
        self.global_step = 0
        self.model = model.to(tt.arg.device)
        self.data_loader = data_loader
        # fall back to Adam / cross-entropy when not supplied
        self.optimizer = optimizer or optim.Adam(model.parameters())
        self.criterion = criterion or nn.CrossEntropyLoss()

    def train(self, inputs):
        """One optimisation step on a single (x, y) batch; buffers loss/accuracy stats."""
        # split inputs
        x, y = inputs
        # forward (wrapped in DataParallel on CUDA)
        if tt.arg.cuda:
            z = nn.DataParallel(self.model)(x)
        else:
            z = self.model(x)
        # loss
        loss = self.criterion(z, y)
        # accuracy
        acc = tt.accuracy(z, y)
        # update model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # logging
        tt.log_scalar('loss', loss, self.global_step)
        tt.log_scalar('acc', acc, self.global_step)

    def epoch(self, ep_no=None):
        # hook: override to run per-epoch work (e.g. validation)
        pass

    def run(self):
        """Full training driver: resume from checkpoint, iterate epochs/steps, save model."""
        # experiment name (defaults to the lowercase model class name)
        tt.arg.experiment = tt.arg.experiment or self.model.__class__.__name__.lower()
        # load model; returns the global step stored in the checkpoint (0 if none)
        self.global_step = self.model.load_model()
        epoch, min_step = divmod(self.global_step, len(self.data_loader))
        # epochs
        while epoch < (tt.arg.epoch or 1):
            epoch += 1
            # iterations: resume counting just after the checkpointed step
            for step, inputs in enumerate(self.data_loader, min_step + 1):
                # check step counter
                if step > len(self.data_loader):
                    break
                # increase global step count
                self.global_step += 1
                # update learning rate (tt.arg.lr may change live via the config watcher)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = tt.arg.lr
                # call train func
                if type(inputs) in [list, tuple]:
                    self.train([tt.var(d) for d in inputs])
                else:
                    self.train(tt.var(inputs))
                # logging
                tt.log_weight(self.model, global_step=self.global_step)
                tt.log_gradient(self.model, global_step=self.global_step)
                tt.log_step(epoch=epoch, global_step=self.global_step,
                            max_epoch=(tt.arg.epoch or 1), max_step=len(self.data_loader))
                # save model (throttled internally by save_interval/save_step)
                self.model.save_model(self.global_step)
            # epoch handler
            self.epoch(epoch)
        # save final model
        self.model.save_model(self.global_step, force=True)
================================================
FILE: torchtools/tt/utils.py
================================================
import os
import datetime
import time
import pathlib
from torchtools import torch, nn, tt
__author__ = 'namju.kim@kakaobrain.com'
# time stamps driving tic()/toc() and save/archive throttling in save_model()
_tic_start = _last_saved = _last_archived = time.time()
# best metric seen so far (updated by save_model / refreshed by load_model)
_best = -100000000.
def tic():
    """Reset the module-level stopwatch and return the new start time."""
    global _tic_start
    now = time.time()
    _tic_start = now
    return now
def toc(tic=None):
    """Seconds elapsed since `tic`, or since the last global tic() when omitted."""
    reference = _tic_start if tic is None else tic
    return time.time() - reference
def sleep(seconds):
    # thin convenience wrapper around time.sleep
    time.sleep(seconds)
#
# automatic device-aware torch.tensor
#
def var(data, dtype=None, device=None, requires_grad=False):
    """Device-aware tensor factory: builds on CPU, then moves to the target device.

    The target defaults to tt.arg.device when `device` is not given.
    """
    # NOTE: constructing directly on the target device was reported buggy
    # upstream, hence the build-then-.to() workaround.
    target = device or tt.arg.device
    return torch.tensor(data, dtype=dtype, requires_grad=requires_grad).to(target)
def vars(x_list, dtype=None, device=None, requires_grad=False):
    """Apply var() element-wise to a list of array-likes."""
    return [var(item, dtype, device, requires_grad) for item in x_list]
# for old torchtools compatibility
def cvar(x):
    # legacy alias kept for old torchtools callers: detach from the autograd graph
    return x.detach()
#
# to python or numpy variable(s)
#
def nvar(x):
    """Convert a torch.Tensor to a python scalar (0-dim) or numpy array; pass through otherwise."""
    if not isinstance(x, torch.Tensor):
        return x
    detached = x.detach().cpu()
    if detached.dim() == 0:
        return detached.item()
    return detached.numpy()
def nvars(x_list):
    """Apply nvar() element-wise to a list."""
    return [nvar(item) for item in x_list]
def load_model(model, best=False, postfix=None, experiment=None):
    """Load model weights from the save_dir checkpoint.

    Returns the global step stored in the checkpoint (0 when no file exists).
    When best=True the '<name>.pt.best' archive is loaded instead. Also
    refreshes the module-level _best statistic from the '.best' file.
    """
    global _best
    # model file name
    filename = tt.arg.save_dir + '%s.pt' % (experiment or tt.arg.experiment or model.__class__.__name__.lower())
    if postfix is not None:
        filename = filename + '.%s' % postfix
    # load model
    global_step = 0
    if os.path.exists(filename):
        # NOTE(review): the best branch checks existence of the base file but
        # loads '<base>.best' — confirm this is intended.
        if best:
            global_step, model_state, _best = torch.load(filename + '.best', map_location=lambda storage, loc: storage)
        else:
            global_step, model_state = torch.load(filename, map_location=lambda storage, loc: storage)
        model.load_state_dict(model_state)
    # update best stat from the '.best' file when present
    filename += '.best'
    if os.path.exists(filename):
        _, _, _best = torch.load(filename, map_location=lambda storage, loc: storage)
    return global_step
def save_model(model, global_step, force=False, best=None, postfix=None):
    """Save (and optionally archive) model weights under tt.arg.save_dir.

    - Regular save: when forced, or save_interval seconds elapsed, or
      global_step is a multiple of save_step; written to '<name>.pt'.
    - Archive: when archive_interval/archive_step triggers; written with a
      timestamp or step suffix.
    - Best: when `best` exceeds the best metric seen so far; written to
      '<name>.pt.best' together with the metric value.
    """
    global _last_saved, _last_archived, _best
    # make directory
    pathlib.Path(tt.arg.save_dir).mkdir(parents=True, exist_ok=True)
    # base filename (kept pristine; suffixed copies derive from it)
    base_filename = '%s.pt' % (tt.arg.experiment or model.__class__.__name__.lower())
    if postfix is not None:
        base_filename = base_filename + '.%s' % postfix
    # save model
    if force or (tt.arg.save_interval and time.time() - _last_saved >= tt.arg.save_interval) or \
            (tt.arg.save_step and global_step % tt.arg.save_step == 0):
        torch.save((global_step, model.state_dict()), tt.arg.save_dir + base_filename)
        _last_saved = time.time()
    # archive model
    if (tt.arg.archive_interval and time.time() - _last_archived >= tt.arg.archive_interval) or \
            (tt.arg.archive_step and global_step % tt.arg.archive_step == 0):
        # filename to archive: timestamp for interval mode, step number otherwise
        if tt.arg.archive_interval:
            archive_filename = base_filename + datetime.datetime.now().strftime('.%Y%m%d.%H%M%S')
        else:
            archive_filename = base_filename + '.%d' % global_step
        torch.save((global_step, model.state_dict()), tt.arg.save_dir + archive_filename)
        _last_archived = time.time()
    # save best model
    if best is not None and best > _best:
        _best = best
        # BUG FIX: the '.best' suffix was previously appended to a filename that
        # the archive branch may already have mutated, so the best checkpoint
        # could land under an archive-timestamped name that load_model(best=True)
        # never finds. Always derive it from the pristine base filename.
        torch.save((global_step, model.state_dict(), best), tt.arg.save_dir + base_filename + '.best')
# patch Module: every nn.Module gains load_model()/save_model() convenience methods
nn.Module.load_model = load_model
nn.Module.save_model = save_model
================================================
FILE: train.py
================================================
from torchtools import *
from data import MiniImagenetLoader, TieredImagenetLoader
from model import EmbeddingImagenet, GraphNetwork, ConvNet
import shutil
import os
import random
#import seaborn as sns
class ModelTrainer(object):
    def __init__(self,
                 enc_module,
                 gnn_module,
                 data_loader):
        """Wire up the embedding encoder, graph network, data loaders, optimizer and losses."""
        # set encoder and gnn
        self.enc_module = enc_module.to(tt.arg.device)
        self.gnn_module = gnn_module.to(tt.arg.device)
        if tt.arg.num_gpus > 1:
            print('Construct multi-gpu model ...')
            # NOTE(review): device ids are hard-coded to 4 GPUs — confirm for other setups
            self.enc_module = nn.DataParallel(self.enc_module, device_ids=[0, 1, 2, 3], dim=0)
            self.gnn_module = nn.DataParallel(self.gnn_module, device_ids=[0, 1, 2, 3], dim=0)
            print('done!\n')
        # get data loader
        self.data_loader = data_loader
        # set optimizer: both modules are optimized jointly
        self.module_params = list(self.enc_module.parameters()) + list(self.gnn_module.parameters())
        # set optimizer
        self.optimizer = optim.Adam(params=self.module_params,
                                    lr=tt.arg.lr,
                                    weight_decay=tt.arg.weight_decay)
        # set loss: per-element BCE on edge maps, per-sample CE on node predictions
        self.edge_loss = nn.BCELoss(reduction='none')
        self.node_loss = nn.CrossEntropyLoss(reduction='none')
        # training progress state
        self.global_step = 0
        self.val_acc = 0
        self.test_acc = 0
def train(self):
val_acc = self.val_acc
# set edge mask (to distinguish support and query edges)
num_supports = tt.arg.num_ways_train * tt.arg.num_shots_train
num_queries = tt.arg.num_ways_train * 1
num_samples = num_supports + num_queries
support_edge_mask = torch.zeros(tt.arg.meta_batch_size, num_samples, num_samples).to(tt.arg.device)
support_edge_mask[:, :num_supports, :num_supports] = 1
query_edge_mask = 1 - support_edge_mask
evaluation_mask = torch.ones(tt.arg.meta_batch_size, num_samples, num_samples).to(tt.arg.device)
# for semi-supervised setting, ignore unlabeled support sets for evaluation
for c in range(tt.arg.num_ways_train):
evaluation_mask[:,
((c + 1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_train,
:num_supports] = 0
evaluation_mask[:, :num_supports,
((c + 1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_train] = 0
# for each iteration
for iter in range(self.global_step + 1, tt.arg.train_iteration + 1):
# init grad
self.optimizer.zero_grad()
# set current step
self.global_step = iter
# load task data list
[support_data,
support_label,
query_data,
query_label] = self.data_loader['train'].get_task_batch(num_tasks=tt.arg.meta_batch_size,
num_ways=tt.arg.num_ways_train,
num_shots=tt.arg.num_shots_train,
seed=iter + tt.arg.seed)
# set as single data
full_data = torch.cat([support_data, query_data], 1)
full_label = torch.cat([support_label, query_label], 1)
full_edge = self.label2edge(full_label)
# set init edge
init_edge = full_edge.clone() # batch_size x 2 x num_samples x num_samples
init_edge[:, :, num_supports:, :] = 0.5
init_edge[:, :, :, num_supports:] = 0.5
for i in range(num_queries):
init_edge[:, 0, num_supports + i, num_supports + i] = 1.0
init_edge[:, 1, num_supports + i, num_supports + i] = 0.0
# for semi-supervised setting,
for c in range(tt.arg.num_ways_train):
init_edge[:, :, ((c+1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_train, :num_supports] = 0.5
init_edge[:, :, :num_supports, ((c+1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_train] = 0.5
# set as train mode
self.enc_module.train()
self.gnn_module.train()
# (1) encode data
full_data = [self.enc_module(data.squeeze(1)) for data in full_data.chunk(full_data.size(1), dim=1)]
full_data = torch.stack(full_data, dim=1) # batch_size x num_samples x featdim
# (2) predict edge logit (consider only the last layer logit, num_tasks x 2 x num_samples x num_samples)
if tt.arg.train_transductive:
full_logit_layers = self.gnn_module(node_feat=full_data, edge_feat=init_edge)
else:
evaluation_mask[:, num_supports:, num_supports:] = 0 # ignore query-query edges, since it is non-transductive setting
# input_node_feat: (batch_size x num_queries) x (num_support + 1) x featdim
# input_edge_feat: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
support_data = full_data[:, :num_supports] # batch_size x num_support x featdim
query_data = full_data[:, num_supports:] # batch_size x num_query x featdim
support_data_tiled = support_data.unsqueeze(1).repeat(1, num_queries, 1, 1) # batch_size x num_queries x num_support x featdim
support_data_tiled = support_data_tiled.view(tt.arg.meta_batch_size * num_queries, num_supports, -1) # (batch_size x num_queries) x num_support x featdim
query_data_reshaped = query_data.contiguous().view(tt.arg.meta_batch_size * num_queries, -1).unsqueeze(1) # (batch_size x num_queries) x 1 x featdim
input_node_feat = torch.cat([support_data_tiled, query_data_reshaped], 1) # (batch_size x num_queries) x (num_support + 1) x featdim
input_edge_feat = 0.5 * torch.ones(tt.arg.meta_batch_size, 2, num_supports + 1, num_supports + 1).to(tt.arg.device) # batch_size x 2 x (num_support + 1) x (num_support + 1)
input_edge_feat[:, :, :num_supports, :num_supports] = init_edge[:, :, :num_supports, :num_supports] # batch_size x 2 x (num_support + 1) x (num_support + 1)
input_edge_feat = input_edge_feat.repeat(num_queries, 1, 1, 1) #(batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
# logit: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
logit_layers = self.gnn_module(node_feat=input_node_feat, edge_feat=input_edge_feat)
logit_layers = [logit_layer.view(tt.arg.meta_batch_size, num_queries, 2, num_supports + 1, num_supports + 1) for logit_layer in logit_layers]
# logit --> full_logit (batch_size x 2 x num_samples x num_samples)
full_logit_layers = []
for l in range(tt.arg.num_layers):
full_logit_layers.append(torch.zeros(tt.arg.meta_batch_size, 2, num_samples, num_samples).to(tt.arg.device))
for l in range(tt.arg.num_layers):
full_logit_layers[l][:, :, :num_supports, :num_supports] = logit_layers[l][:, :, :, :num_supports, :num_supports].mean(1)
full_logit_layers[l][:, :, :num_supports, num_supports:] = logit_layers[l][:, :, :, :num_supports, -1].transpose(1, 2).transpose(2, 3)
full_logit_layers[l][:, :, num_supports:, :num_supports] = logit_layers[l][:, :, :, -1, :num_supports].transpose(1, 2)
# (4) compute loss
full_edge_loss_layers = [self.edge_loss((1-full_logit_layer[:, 0]), (1-full_edge[:, 0])) for full_logit_layer in full_logit_layers]
# weighted edge loss for balancing pos/neg
pos_query_edge_loss_layers = [torch.sum(full_edge_loss_layer * query_edge_mask * full_edge[:, 0] * evaluation_mask) / torch.sum(query_edge_mask * full_edge[:, 0] * evaluation_mask) for full_edge_loss_layer in full_edge_loss_layers]
neg_query_edge_loss_layers = [torch.sum(full_edge_loss_layer * query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask) / torch.sum(query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask) for full_edge_loss_layer in full_edge_loss_layers]
query_edge_loss_layers = [pos_query_edge_loss_layer + neg_query_edge_loss_layer for (pos_query_edge_loss_layer, neg_query_edge_loss_layer) in zip(pos_query_edge_loss_layers, neg_query_edge_loss_layers)]
# compute accuracy
full_edge_accr_layers = [self.hit(full_logit_layer, 1-full_edge[:, 0].long()) for full_logit_layer in full_logit_layers]
query_edge_accr_layers = [torch.sum(full_edge_accr_layer * query_edge_mask * evaluation_mask) / torch.sum(query_edge_mask * evaluation_mask) for full_edge_accr_layer in full_edge_accr_layers]
# compute node loss & accuracy (num_tasks x num_quries x num_ways)
query_node_pred_layers = [torch.bmm(full_logit_layer[:, 0, num_supports:, :num_supports], self.one_hot_encode(tt.arg.num_ways_train, support_label.long())) for full_logit_layer in full_logit_layers] # (num_tasks x num_quries x num_supports) * (num_tasks x num_supports x num_ways)
query_node_accr_layers = [torch.eq(torch.max(query_node_pred_layer, -1)[1], query_label.long()).float().mean() for query_node_pred_layer in query_node_pred_layers]
total_loss_layers = query_edge_loss_layers
# update model
total_loss = []
for l in range(tt.arg.num_layers - 1):
total_loss += [total_loss_layers[l].view(-1) * 0.5]
total_loss += [total_loss_layers[-1].view(-1) * 1.0]
total_loss = torch.mean(torch.cat(total_loss, 0))
total_loss.backward()
self.optimizer.step()
# adjust learning rate
self.adjust_learning_rate(optimizers=[self.optimizer],
lr=tt.arg.lr,
iter=self.global_step)
# logging
tt.log_scalar('train/edge_loss', query_edge_loss_layers[-1], self.global_step)
tt.log_scalar('train/edge_accr', query_edge_accr_layers[-1], self.global_step)
tt.log_scalar('train/node_accr', query_node_accr_layers[-1], self.global_step)
# evaluation
if self.global_step % tt.arg.test_interval == 0:
val_acc = self.eval(partition='val')
is_best = 0
if val_acc >= self.val_acc:
self.val_acc = val_acc
is_best = 1
tt.log_scalar('val/best_accr', self.val_acc, self.global_step)
self.save_checkpoint({
'iteration': self.global_step,
'enc_module_state_dict': self.enc_module.state_dict(),
'gnn_module_state_dict': self.gnn_module.state_dict(),
'val_acc': val_acc,
'optimizer': self.optimizer.state_dict(),
}, is_best)
tt.log_step(global_step=self.global_step)
def eval(self, partition='test', log_flag=True):
best_acc = 0
# set edge mask (to distinguish support and query edges)
num_supports = tt.arg.num_ways_test * tt.arg.num_shots_test
num_queries = tt.arg.num_ways_test * 1
num_samples = num_supports + num_queries
support_edge_mask = torch.zeros(tt.arg.test_batch_size, num_samples, num_samples).to(tt.arg.device)
support_edge_mask[:, :num_supports, :num_supports] = 1
query_edge_mask = 1 - support_edge_mask
evaluation_mask = torch.ones(tt.arg.test_batch_size, num_samples, num_samples).to(tt.arg.device)
# for semi-supervised setting, ignore unlabeled support sets for evaluation
for c in range(tt.arg.num_ways_test):
evaluation_mask[:,
((c + 1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_test,
:num_supports] = 0
evaluation_mask[:, :num_supports,
((c + 1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_test] = 0
query_edge_losses = []
query_edge_accrs = []
query_node_accrs = []
# for each iteration
for iter in range(tt.arg.test_iteration//tt.arg.test_batch_size):
# load task data list
[support_data,
support_label,
query_data,
query_label] = self.data_loader[partition].get_task_batch(num_tasks=tt.arg.test_batch_size,
num_ways=tt.arg.num_ways_test,
num_shots=tt.arg.num_shots_test,
seed=iter)
# set as single data
full_data = torch.cat([support_data, query_data], 1)
full_label = torch.cat([support_label, query_label], 1)
full_edge = self.label2edge(full_label)
# set init edge
init_edge = full_edge.clone()
init_edge[:, :, num_supports:, :] = 0.5
init_edge[:, :, :, num_supports:] = 0.5
for i in range(num_queries):
init_edge[:, 0, num_supports + i, num_supports + i] = 1.0
init_edge[:, 1, num_supports + i, num_supports + i] = 0.0
# for semi-supervised setting,
for c in range(tt.arg.num_ways_test):
init_edge[:, :, ((c+1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_test, :num_supports] = 0.5
init_edge[:, :, :num_supports, ((c+1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_test] = 0.5
# set as train mode
self.enc_module.eval()
self.gnn_module.eval()
# (1) encode data
full_data = [self.enc_module(data.squeeze(1)) for data in full_data.chunk(full_data.size(1), dim=1)]
full_data = torch.stack(full_data, dim=1)
# (2) predict edge logit (consider only the last layer logit, num_tasks x 2 x num_samples x num_samples)
if tt.arg.test_transductive:
full_logit_all = self.gnn_module(node_feat=full_data, edge_feat=init_edge)
full_logit = full_logit_all[-1]
else:
evaluation_mask[:, num_supports:, num_supports:] = 0 # ignore query-query edges, since it is non-transductive setting
full_logit = torch.zeros(tt.arg.test_batch_size, 2, num_samples, num_samples).to(tt.arg.device)
# input_node_feat: (batch_size x num_queries) x (num_support + 1) x featdim
# input_edge_feat: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
support_data = full_data[:, :num_supports] # batch_size x num_support x featdim
query_data = full_data[:, num_supports:] # batch_size x num_query x featdim
support_data_tiled = support_data.unsqueeze(1).repeat(1, num_queries, 1, 1) # batch_size x num_queries x num_support x featdim
support_data_tiled = support_data_tiled.view(tt.arg.test_batch_size * num_queries, num_supports, -1) # (batch_size x num_queries) x num_support x featdim
query_data_reshaped = query_data.contiguous().view(tt.arg.test_batch_size * num_queries, -1).unsqueeze(1) # (batch_size x num_queries) x 1 x featdim
input_node_feat = torch.cat([support_data_tiled, query_data_reshaped], 1) # (batch_size x num_queries) x (num_support + 1) x featdim
input_edge_feat = 0.5 * torch.ones(tt.arg.test_batch_size, 2, num_supports + 1, num_supports + 1).to(tt.arg.device) # batch_size x 2 x (num_support + 1) x (num_support + 1)
input_edge_feat[:, :, :num_supports, :num_supports] = init_edge[:, :, :num_supports, :num_supports] # batch_size x 2 x (num_support + 1) x (num_support + 1)
input_edge_feat = input_edge_feat.repeat(num_queries, 1, 1, 1) # (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
# logit: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
logit = self.gnn_module(node_feat=input_node_feat, edge_feat=input_edge_feat)[-1]
logit = logit.view(tt.arg.test_batch_size, num_queries, 2, num_supports + 1, num_supports + 1)
# batch_size x num_queries x 2 x (num_support + 1) x (num_support + 1)
# logit --> full_logit (batch_size x 2 x num_samples x num_samples)
full_logit[:, :, :num_supports, :num_supports] = logit[:, :, :, :num_supports, :num_supports].mean(1)
full_logit[:, :, :num_supports, num_supports:] = logit[:, :, :, :num_supports, -1].transpose(1, 2).transpose(2, 3)
full_logit[:, :, num_supports:, :num_supports] = logit[:, :, :, -1, :num_supports].transpose(1, 2)
# (4) compute loss
full_edge_loss = self.edge_loss(1-full_logit[:, 0], 1-full_edge[:, 0])
query_edge_loss = torch.sum(full_edge_loss * query_edge_mask * evaluation_mask) / torch.sum(query_edge_mask * evaluation_mask)
# weighted loss for balancing pos/neg
pos_query_edge_loss = torch.sum(full_edge_loss * query_edge_mask * full_edge[:, 0] * evaluation_mask) / torch.sum(query_edge_mask * full_edge[:, 0] * evaluation_mask)
neg_query_edge_loss = torch.sum(full_edge_loss * query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask) / torch.sum(query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask)
query_edge_loss = pos_query_edge_loss + neg_query_edge_loss
# compute accuracy
full_edge_accr = self.hit(full_logit, 1-full_edge[:, 0].long())
query_edge_accr = torch.sum(full_edge_accr * query_edge_mask * evaluation_mask) / torch.sum(query_edge_mask * evaluation_mask)
# compute node accuracy (num_tasks x num_quries x num_ways)
query_node_pred = torch.bmm(full_logit[:, 0, num_supports:, :num_supports], self.one_hot_encode(tt.arg.num_ways_test, support_label.long())) # (num_tasks x num_quries x num_supports) * (num_tasks x num_supports x num_ways)
query_node_accr = torch.eq(torch.max(query_node_pred, -1)[1], query_label.long()).float().mean()
query_edge_losses += [query_edge_loss.item()]
query_edge_accrs += [query_edge_accr.item()]
query_node_accrs += [query_node_accr.item()]
# logging
if log_flag:
tt.log('---------------------------')
tt.log_scalar('{}/edge_loss'.format(partition), np.array(query_edge_losses).mean(), self.global_step)
tt.log_scalar('{}/edge_accr'.format(partition), np.array(query_edge_accrs).mean(), self.global_step)
tt.log_scalar('{}/node_accr'.format(partition), np.array(query_node_accrs).mean(), self.global_step)
tt.log('evaluation: total_count=%d, accuracy: mean=%.2f%%, std=%.2f%%, ci95=%.2f%%' %
(iter,
np.array(query_node_accrs).mean() * 100,
np.array(query_node_accrs).std() * 100,
1.96 * np.array(query_node_accrs).std() / np.sqrt(float(len(np.array(query_node_accrs)))) * 100))
tt.log('---------------------------')
return np.array(query_node_accrs).mean()
def adjust_learning_rate(self, optimizers, lr, iter):
new_lr = lr * (0.5 ** (int(iter / tt.arg.dec_lr)))
for optimizer in optimizers:
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
def label2edge(self, label):
# get size
num_samples = label.size(1)
# reshape
label_i = label.unsqueeze(-1).repeat(1, 1, num_samples)
label_j = label_i.transpose(1, 2)
# compute edge
edge = torch.eq(label_i, label_j).float().to(tt.arg.device)
# expand
edge = edge.unsqueeze(1)
edge = torch.cat([edge, 1 - edge], 1)
return edge
def hit(self, logit, label):
pred = logit.max(1)[1]
hit = torch.eq(pred, label).float()
return hit
def one_hot_encode(self, num_classes, class_idx):
return torch.eye(num_classes)[class_idx].to(tt.arg.device)
def save_checkpoint(self, state, is_best):
torch.save(state, 'asset/checkpoints/{}/'.format(tt.arg.experiment) + 'checkpoint.pth.tar')
if is_best:
shutil.copyfile('asset/checkpoints/{}/'.format(tt.arg.experiment) + 'checkpoint.pth.tar',
'asset/checkpoints/{}/'.format(tt.arg.experiment) + 'model_best.pth.tar')
def set_exp_name():
    """Compose the experiment name from the current tt.arg hyper-parameters.

    Encodes dataset (D), ways (N), shots (K), unlabeled count (U), GNN
    layers (L), meta batch size (B), transductive flag (T) and seed.
    """
    segments = [
        'D-{}'.format(tt.arg.dataset),
        '_N-{}_K-{}_U-{}'.format(tt.arg.num_ways, tt.arg.num_shots, tt.arg.num_unlabeled),
        '_L-{}_B-{}'.format(tt.arg.num_layers, tt.arg.meta_batch_size),
        '_T-{}'.format(tt.arg.transductive),
        '_SEED-{}'.format(tt.arg.seed),
    ]
    return ''.join(segments)
if __name__ == '__main__':
    # ---- runtime / experiment defaults (CLI-overridable via tt.arg) ----
    tt.arg.device = 'cuda:0' if tt.arg.device is None else tt.arg.device
    # replace dataset_root with your own
    tt.arg.dataset_root = '/data/private/dataset'
    tt.arg.dataset = 'mini' if tt.arg.dataset is None else tt.arg.dataset
    tt.arg.num_ways = 5 if tt.arg.num_ways is None else tt.arg.num_ways
    tt.arg.num_shots = 1 if tt.arg.num_shots is None else tt.arg.num_shots
    tt.arg.num_unlabeled = 0 if tt.arg.num_unlabeled is None else tt.arg.num_unlabeled
    tt.arg.num_layers = 3 if tt.arg.num_layers is None else tt.arg.num_layers
    tt.arg.meta_batch_size = 40 if tt.arg.meta_batch_size is None else tt.arg.meta_batch_size
    tt.arg.transductive = False if tt.arg.transductive is None else tt.arg.transductive
    tt.arg.seed = 222 if tt.arg.seed is None else tt.arg.seed
    tt.arg.num_gpus = 1 if tt.arg.num_gpus is None else tt.arg.num_gpus

    # train and test share the same episode configuration
    tt.arg.num_ways_train = tt.arg.num_ways
    tt.arg.num_ways_test = tt.arg.num_ways
    tt.arg.num_shots_train = tt.arg.num_shots
    tt.arg.num_shots_test = tt.arg.num_shots
    tt.arg.train_transductive = tt.arg.transductive
    tt.arg.test_transductive = tt.arg.transductive

    # model parameter related
    tt.arg.num_edge_features = 96
    tt.arg.num_node_features = 96
    tt.arg.emb_size = 128

    # train, test parameters
    tt.arg.train_iteration = 100000 if tt.arg.dataset == 'mini' else 200000
    tt.arg.test_iteration = 10000
    tt.arg.test_interval = 5000 if tt.arg.test_interval is None else tt.arg.test_interval
    tt.arg.test_batch_size = 10
    tt.arg.log_step = 1000 if tt.arg.log_step is None else tt.arg.log_step

    tt.arg.lr = 1e-3
    # NOTE(review): grad_clip is configured but never applied anywhere in
    # this file — confirm whether clipping was intended.
    tt.arg.grad_clip = 5
    tt.arg.weight_decay = 1e-6
    tt.arg.dec_lr = 15000 if tt.arg.dataset == 'mini' else 30000
    tt.arg.dropout = 0.1 if tt.arg.dataset == 'mini' else 0.0

    tt.arg.experiment = set_exp_name() if tt.arg.experiment is None else tt.arg.experiment

    print(set_exp_name())

    # set random seed for reproducibility (numpy, torch CPU/GPU, python)
    np.random.seed(tt.arg.seed)
    torch.manual_seed(tt.arg.seed)
    torch.cuda.manual_seed_all(tt.arg.seed)
    random.seed(tt.arg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    tt.arg.log_dir_user = tt.arg.log_dir if tt.arg.log_dir_user is None else tt.arg.log_dir_user
    tt.arg.log_dir = tt.arg.log_dir_user

    # ensure the checkpoint directory exists (exist_ok covers races and
    # replaces the two separate os.path.exists checks)
    os.makedirs('asset/checkpoints/' + tt.arg.experiment, exist_ok=True)

    enc_module = EmbeddingImagenet(emb_size=tt.arg.emb_size)

    gnn_module = GraphNetwork(in_features=tt.arg.emb_size,
                              node_features=tt.arg.num_edge_features,
                              edge_features=tt.arg.num_node_features,
                              num_layers=tt.arg.num_layers,
                              dropout=tt.arg.dropout)

    if tt.arg.dataset == 'mini':
        train_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='train')
        valid_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='val')
    elif tt.arg.dataset == 'tiered':
        train_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='train')
        valid_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='val')
    else:
        # Fail fast: the original code only printed a warning here and then
        # crashed later with a NameError on the undefined loaders.
        raise ValueError('Unknown dataset: {}'.format(tt.arg.dataset))

    data_loader = {'train': train_loader,
                   'val': valid_loader
                   }

    # create trainer and start meta-training
    trainer = ModelTrainer(enc_module=enc_module,
                           gnn_module=gnn_module,
                           data_loader=data_loader)

    trainer.train()
gitextract__r5a0unu/ ├── .idea/ │ ├── egnn_distribute.iml │ ├── modules.xml │ ├── vcs.xml │ └── workspace.xml ├── LICENSE ├── README.md ├── __init__.py ├── _version.py ├── data.py ├── eval.py ├── model.py ├── torchtools/ │ ├── __init__.py │ ├── _version.py │ └── tt/ │ ├── __init__.py │ ├── arg.py │ ├── layer.py │ ├── logger.py │ ├── stat.py │ ├── trainer.py │ └── utils.py └── train.py
SYMBOL INDEX (79 symbols across 9 files)
FILE: data.py
class MiniImagenetLoader (line 13) | class MiniImagenetLoader(data.Dataset):
method __init__ (line 14) | def __init__(self, root, partition='train'):
method load_dataset (line 40) | def load_dataset(self):
method get_task_batch (line 61) | def get_task_batch(self,
class TieredImagenetLoader (line 123) | class TieredImagenetLoader(data.Dataset):
method __init__ (line 124) | def __init__(self, root, partition='train'):
method get_image_paths (line 140) | def get_image_paths(self, file):
method preprocess (line 152) | def preprocess(self):
method load_dataset (line 245) | def load_dataset(self):
method chunks (line 284) | def chunks(self, data, size=10000):
method get_task_batch (line 289) | def get_task_batch(self,
FILE: model.py
class ConvBlock (line 9) | class ConvBlock(nn.Module):
method __init__ (line 10) | def __init__(self, in_planes, out_planes, userelu=True, momentum=0.1, ...
method forward (line 27) | def forward(self, x):
class ConvNet (line 31) | class ConvNet(nn.Module):
method __init__ (line 32) | def __init__(self, opt, momentum=0.1, affine=True, track_running_stats...
method forward (line 62) | def forward(self, x):
class EmbeddingImagenet (line 70) | class EmbeddingImagenet(nn.Module):
method __init__ (line 71) | def __init__(self,
method forward (line 117) | def forward(self, input_data):
class NodeUpdateNetwork (line 124) | class NodeUpdateNetwork(nn.Module):
method __init__ (line 125) | def __init__(self,
method forward (line 154) | def forward(self, node_feat, edge_feat):
class EdgeUpdateNetwork (line 175) | class EdgeUpdateNetwork(nn.Module):
method __init__ (line 176) | def __init__(self,
method forward (line 230) | def forward(self, node_feat, edge_feat):
class GraphNetwork (line 259) | class GraphNetwork(nn.Module):
method __init__ (line 260) | def __init__(self,
method forward (line 291) | def forward(self, node_feat, edge_feat):
FILE: torchtools/tt/arg.py
class _Opt (line 15) | class _Opt(object):
method __len__ (line 17) | def __len__(self):
method __setitem__ (line 20) | def __setitem__(self, key, value):
method __getitem__ (line 23) | def __getitem__(self, item):
method __getattr__ (line 29) | def __getattr__(self, item):
function _to_py_obj (line 33) | def _to_py_obj(x):
function _parse_config (line 49) | def _parse_config(arg, file):
function _parse_config_thread (line 68) | def _parse_config_thread(arg, file):
function _print_opts (line 86) | def _print_opts(arg, header):
function _parse_opts (line 94) | def _parse_opts():
FILE: torchtools/tt/layer.py
class Reshape (line 7) | class Reshape(nn.Module):
method __init__ (line 9) | def __init__(self, *shape):
method forward (line 13) | def forward(self, x):
method extra_repr (line 16) | def extra_repr(self):
FILE: torchtools/tt/logger.py
function log (line 19) | def log(*args):
function _get_writer (line 28) | def _get_writer():
function log_scalar (line 42) | def log_scalar(tag, value, global_step=None):
function log_audio (line 46) | def log_audio(tag, audio, global_step=None):
function log_image (line 50) | def log_image(tag, image, global_step=None):
function log_text (line 54) | def log_text(tag, text, global_step=None):
function log_hist (line 58) | def log_hist(tag, values, global_step=None):
function log_step (line 62) | def log_step(epoch=None, global_step=None, max_epoch=None, max_step=None):
function log_weight (line 115) | def log_weight(model, global_step=None):
function log_gradient (line 123) | def log_gradient(model, global_step=None):
FILE: torchtools/tt/stat.py
function accuracy (line 7) | def accuracy(prob, label, ignore_index=-100):
FILE: torchtools/tt/trainer.py
class SupervisedTrainer (line 7) | class SupervisedTrainer(object):
method __init__ (line 9) | def __init__(self, model, data_loader, optimizer=None, criterion=None):
method train (line 16) | def train(self, inputs):
method epoch (line 42) | def epoch(self, ep_no=None):
method run (line 45) | def run(self):
FILE: torchtools/tt/utils.py
function tic (line 17) | def tic():
function toc (line 23) | def toc(tic=None):
function sleep (line 31) | def sleep(seconds):
function var (line 38) | def var(data, dtype=None, device=None, requires_grad=False):
function vars (line 44) | def vars(x_list, dtype=None, device=None, requires_grad=False):
function cvar (line 49) | def cvar(x):
function nvar (line 56) | def nvar(x):
function nvars (line 63) | def nvars(x_list):
function load_model (line 67) | def load_model(model, best=False, postfix=None, experiment=None):
function save_model (line 92) | def save_model(model, global_step, force=False, best=None, postfix=None):
FILE: train.py
class ModelTrainer (line 10) | class ModelTrainer(object):
method __init__ (line 11) | def __init__(self,
method train (line 46) | def train(self):
method eval (line 204) | def eval(self, partition='test', log_flag=True):
method adjust_learning_rate (line 335) | def adjust_learning_rate(self, optimizers, lr, iter):
method label2edge (line 342) | def label2edge(self, label):
method hit (line 358) | def hit(self, logit, label):
method one_hot_encode (line 363) | def one_hot_encode(self, num_classes, class_idx):
method save_checkpoint (line 366) | def save_checkpoint(self, state, is_best):
function set_exp_name (line 372) | def set_exp_name():
Condensed preview — 21 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (101K chars).
[
{
"path": ".idea/egnn_distribute.iml",
"chars": 512,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n <component name=\"NewModuleRootManager"
},
{
"path": ".idea/modules.xml",
"chars": 282,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"ProjectModuleManager\">\n <modules>\n "
},
{
"path": ".idea/vcs.xml",
"chars": 180,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"VcsDirectoryMappings\">\n <mapping dire"
},
{
"path": ".idea/workspace.xml",
"chars": 11692,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"ChangeListManager\">\n <list default=\"t"
},
{
"path": "LICENSE",
"chars": 1068,
"preview": "MIT License\n\nCopyright (c) 2019 Jongmin Kim\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
},
{
"path": "README.md",
"chars": 8545,
"preview": "# fewshot-egnn\n\n### Introduction\n\nThe current project page provides pytorch code that implements the following CVPR2019 "
},
{
"path": "__init__.py",
"chars": 397,
"preview": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch import optim\nfrom torch import cuda\nfrom torch import ut"
},
{
"path": "_version.py",
"chars": 53,
"preview": "__version__ = '0.4.0' # align version with pytorch\n\n"
},
{
"path": "data.py",
"chars": 14914,
"preview": "from __future__ import print_function\nfrom torchtools import *\nimport torch.utils.data as data\nimport random\nimport os\ni"
},
{
"path": "eval.py",
"chars": 5038,
"preview": "from torchtools import *\nfrom data import MiniImagenetLoader, TieredImagenetLoader\nfrom model import EmbeddingImagenet, "
},
{
"path": "model.py",
"chars": 13890,
"preview": "from torchtools import *\nfrom collections import OrderedDict\nimport math\n#import seaborn as sns\nimport numpy as np\nimpor"
},
{
"path": "torchtools/__init__.py",
"chars": 397,
"preview": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch import optim\nfrom torch import cuda\nfrom torch import ut"
},
{
"path": "torchtools/_version.py",
"chars": 53,
"preview": "__version__ = '0.4.0' # align version with pytorch\n\n"
},
{
"path": "torchtools/tt/__init__.py",
"chars": 310,
"preview": "from torchtools.tt.arg import _parse_opts\nfrom torchtools.tt.utils import *\nfrom torchtools.tt.layer import *\nfrom torch"
},
{
"path": "torchtools/tt/arg.py",
"chars": 3735,
"preview": "import sys\nimport configparser\nimport torch\nimport threading\nimport time\nimport os\n\n\n__author__ = 'namju.kim@kakaobrain."
},
{
"path": "torchtools/tt/layer.py",
"chars": 338,
"preview": "from torchtools import nn\n\n\n#\n# Reshape layer for Sequential or ModuleList\n#\nclass Reshape(nn.Module):\n\n def __init__"
},
{
"path": "torchtools/tt/logger.py",
"chars": 3951,
"preview": "import datetime\nimport time\nfrom tensorboardX import SummaryWriter\nfrom torchtools import tt\n\n\n__author__ = 'namju.kim@k"
},
{
"path": "torchtools/tt/stat.py",
"chars": 407,
"preview": "from torchtools import tt\n\n\n__author__ = 'namju.kim@kakaobrain.com'\n\n\ndef accuracy(prob, label, ignore_index=-100):\n\n "
},
{
"path": "torchtools/tt/trainer.py",
"chars": 2631,
"preview": "from torchtools import nn, optim, tt\n\n\n__author__ = 'namju.kim@kakaobrain.com'\n\n\nclass SupervisedTrainer(object):\n\n d"
},
{
"path": "torchtools/tt/utils.py",
"chars": 3725,
"preview": "import os\nimport datetime\nimport time\nimport pathlib\nfrom torchtools import torch, nn, tt\n\n\n__author__ = 'namju.kim@kaka"
},
{
"path": "train.py",
"chars": 25078,
"preview": "from torchtools import *\nfrom data import MiniImagenetLoader, TieredImagenetLoader\nfrom model import EmbeddingImagenet, "
}
]
About this extraction
This page contains the full source code of the khy0809/fewshot-egnn GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 21 files (94.9 KB), approximately 24.2k tokens, and a symbol index with 79 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.