Full Code of khy0809/fewshot-egnn for AI

master 205fa80ec7cb cached
21 files
94.9 KB
24.2k tokens
79 symbols
1 requests
Download .txt
Repository: khy0809/fewshot-egnn
Branch: master
Commit: 205fa80ec7cb
Files: 21
Total size: 94.9 KB

Directory structure:
gitextract__r5a0unu/

├── .idea/
│   ├── egnn_distribute.iml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── README.md
├── __init__.py
├── _version.py
├── data.py
├── eval.py
├── model.py
├── torchtools/
│   ├── __init__.py
│   ├── _version.py
│   └── tt/
│       ├── __init__.py
│       ├── arg.py
│       ├── layer.py
│       ├── logger.py
│       ├── stat.py
│       ├── trainer.py
│       └── utils.py
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .idea/egnn_distribute.iml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="Remote Python 3.6.8 (sftp://root@instance.cloud.kakaobrain.com:11255/opt/conda/bin/python3)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="TestRunnerService">
    <option name="PROJECT_TEST_RUNNER" value="Unittests" />
  </component>
</module>

================================================
FILE: .idea/modules.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/egnn_distribute.iml" filepath="$PROJECT_DIR$/.idea/egnn_distribute.iml" />
    </modules>
  </component>
</project>

================================================
FILE: .idea/vcs.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
  </component>
</project>

================================================
FILE: .idea/workspace.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ChangeListManager">
    <list default="true" id="f20b581c-b8b4-4c9e-9203-c0c2c2f454b5" name="Default Changelist" comment="" />
    <option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
    <option name="SHOW_DIALOG" value="false" />
    <option name="HIGHLIGHT_CONFLICTS" value="true" />
    <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
    <option name="LAST_RESOLUTION" value="IGNORE" />
  </component>
  <component name="FUSProjectUsageTrigger">
    <session id="-1231903785">
      <usages-collector id="statistics.lifecycle.project">
        <counts>
          <entry key="project.open.time.1" value="2" />
          <entry key="project.opened" value="2" />
        </counts>
      </usages-collector>
      <usages-collector id="statistics.file.extensions.open">
        <counts>
          <entry key="md" value="2" />
          <entry key="py" value="4" />
        </counts>
      </usages-collector>
      <usages-collector id="statistics.file.types.open">
        <counts>
          <entry key="Markdown" value="2" />
          <entry key="Python" value="4" />
        </counts>
      </usages-collector>
      <usages-collector id="statistics.file.extensions.edit">
        <counts>
          <entry key="md" value="1812" />
          <entry key="py" value="112" />
        </counts>
      </usages-collector>
      <usages-collector id="statistics.file.types.edit">
        <counts>
          <entry key="Markdown" value="1812" />
          <entry key="Python" value="112" />
        </counts>
      </usages-collector>
    </session>
  </component>
  <component name="FileEditorManager">
    <leaf SIDE_TABS_SIZE_LIMIT_KEY="300">
      <file pinned="false" current-in-tab="false">
        <entry file="file://$PROJECT_DIR$/data.py">
          <provider selected="true" editor-type-id="text-editor">
            <state relative-caret-position="3705">
              <caret line="247" column="37" selection-start-line="247" selection-start-column="37" selection-end-line="247" selection-end-column="37" />
              <folding>
                <element signature="e#0#37#0" expanded="true" />
              </folding>
            </state>
          </provider>
        </entry>
      </file>
      <file pinned="false" current-in-tab="false">
        <entry file="file://$PROJECT_DIR$/model.py">
          <provider selected="true" editor-type-id="text-editor">
            <state relative-caret-position="615">
              <caret line="41" column="30" selection-start-line="41" selection-start-column="30" selection-end-line="41" selection-end-column="30" />
              <folding>
                <element signature="e#0#24#0" expanded="true" />
              </folding>
            </state>
          </provider>
        </entry>
      </file>
      <file pinned="false" current-in-tab="false">
        <entry file="file://$PROJECT_DIR$/train.py">
          <provider selected="true" editor-type-id="text-editor">
            <state relative-caret-position="5820">
              <caret line="388" column="49" selection-start-line="388" selection-start-column="26" selection-end-line="388" selection-end-column="49" />
              <folding>
                <element signature="e#0#24#0" expanded="true" />
              </folding>
            </state>
          </provider>
        </entry>
      </file>
      <file pinned="false" current-in-tab="true">
        <entry file="file://$PROJECT_DIR$/README.md">
          <provider selected="true" editor-type-id="split-provider[text-editor;markdown-preview-editor]">
            <state split_layout="SPLIT">
              <first_editor relative-caret-position="239">
                <caret line="144" column="47" selection-start-line="144" selection-start-column="47" selection-end-line="144" selection-end-column="47" />
              </first_editor>
              <second_editor />
            </state>
          </provider>
        </entry>
      </file>
    </leaf>
  </component>
  <component name="FindInProjectRecents">
    <findStrings>
      <find>tt.arg.inter_dea</find>
      <find>inter_deactivate</find>
    </findStrings>
  </component>
  <component name="Git.Settings">
    <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
  </component>
  <component name="IdeDocumentHistory">
    <option name="CHANGED_PATHS">
      <list>
        <option value="$PROJECT_DIR$/model.py" />
        <option value="$PROJECT_DIR$/train.py" />
        <option value="$PROJECT_DIR$/README.md" />
      </list>
    </option>
  </component>
  <component name="JsBuildToolGruntFileManager" detection-done="true" sorting="DEFINITION_ORDER" />
  <component name="JsBuildToolPackageJson" detection-done="true" sorting="DEFINITION_ORDER" />
  <component name="JsGulpfileManager">
    <detection-done>true</detection-done>
    <sorting>DEFINITION_ORDER</sorting>
  </component>
  <component name="ProjectFrameBounds" fullScreen="true">
    <option name="y" value="23" />
    <option name="width" value="1440" />
    <option name="height" value="877" />
  </component>
  <component name="ProjectLevelVcsManager" settingsEditedManually="true" />
  <component name="ProjectView">
    <navigator proportions="" version="1">
      <foldersAlwaysOnTop value="true" />
    </navigator>
    <panes>
      <pane id="ProjectPane">
        <subPane>
          <expand>
            <path>
              <item name="egnn_distribute" type="b2602c69:ProjectViewProjectNode" />
              <item name="egnn_distribute" type="462c0819:PsiDirectoryNode" />
            </path>
          </expand>
          <select />
        </subPane>
      </pane>
      <pane id="Scope" />
    </panes>
  </component>
  <component name="PropertiesComponent">
    <property name="WebServerToolWindowFactoryState" value="true" />
    <property name="last_opened_file_path" value="$PROJECT_DIR$" />
    <property name="nodejs_interpreter_path.stuck_in_default_project" value="undefined stuck path" />
    <property name="nodejs_npm_path_reset_for_default_project" value="true" />
    <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
  </component>
  <component name="RunDashboard">
    <option name="ruleStates">
      <list>
        <RuleState>
          <option name="name" value="ConfigurationTypeDashboardGroupingRule" />
        </RuleState>
        <RuleState>
          <option name="name" value="StatusDashboardGroupingRule" />
        </RuleState>
      </list>
    </option>
  </component>
  <component name="SvnConfiguration">
    <configuration />
  </component>
  <component name="TaskManager">
    <task active="true" id="Default" summary="Default task">
      <changelist id="f20b581c-b8b4-4c9e-9203-c0c2c2f454b5" name="Default Changelist" comment="" />
      <created>1556855662817</created>
      <option name="number" value="Default" />
      <option name="presentableId" value="Default" />
      <updated>1556855662817</updated>
    </task>
    <servers />
  </component>
  <component name="ToolWindowManager">
    <frame x="0" y="0" width="1440" height="900" extended-state="0" />
    <editor active="true" />
    <layout>
      <window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.19456366" />
      <window_info id="Structure" order="1" side_tool="true" weight="0.25" />
      <window_info id="Favorites" order="2" side_tool="true" />
      <window_info anchor="bottom" id="Message" order="0" />
      <window_info anchor="bottom" id="Find" order="1" />
      <window_info anchor="bottom" id="Run" order="2" />
      <window_info anchor="bottom" id="Debug" order="3" weight="0.4" />
      <window_info anchor="bottom" id="Cvs" order="4" weight="0.25" />
      <window_info anchor="bottom" id="Inspection" order="5" weight="0.4" />
      <window_info anchor="bottom" id="TODO" order="6" />
      <window_info anchor="bottom" id="Docker" order="7" show_stripe_button="false" />
      <window_info anchor="bottom" id="Version Control" order="8" show_stripe_button="false" />
      <window_info anchor="bottom" id="File Transfer" order="9" visible="true" weight="0.32771084" />
      <window_info anchor="bottom" id="Database Changes" order="10" show_stripe_button="false" />
      <window_info anchor="bottom" id="Terminal" order="11" />
      <window_info anchor="bottom" id="Event Log" order="12" side_tool="true" />
      <window_info anchor="bottom" id="Python Console" order="13" />
      <window_info anchor="right" id="Commander" internal_type="SLIDING" order="0" type="SLIDING" weight="0.4" />
      <window_info anchor="right" id="Ant Build" order="1" weight="0.25" />
      <window_info anchor="right" content_ui="combo" id="Hierarchy" order="2" weight="0.25" />
      <window_info anchor="right" id="Remote Host" order="3" />
      <window_info anchor="right" id="SciView" order="4" />
      <window_info anchor="right" id="Database" order="5" />
    </layout>
  </component>
  <component name="TypeScriptGeneratedFilesManager">
    <option name="version" value="1" />
  </component>
  <component name="VcsContentAnnotationSettings">
    <option name="myLimit" value="2678400000" />
  </component>
  <component name="editorHistoryManager">
    <entry file="file://$PROJECT_DIR$/eval.py">
      <provider selected="true" editor-type-id="text-editor">
        <state relative-caret-position="405">
          <caret line="27" column="46" lean-forward="true" selection-start-line="27" selection-start-column="46" selection-end-line="27" selection-end-column="46" />
          <folding>
            <element signature="e#0#24#0" expanded="true" />
          </folding>
        </state>
      </provider>
    </entry>
    <entry file="file://$PROJECT_DIR$/data.py">
      <provider selected="true" editor-type-id="text-editor">
        <state relative-caret-position="3705">
          <caret line="247" column="37" selection-start-line="247" selection-start-column="37" selection-end-line="247" selection-end-column="37" />
          <folding>
            <element signature="e#0#37#0" expanded="true" />
          </folding>
        </state>
      </provider>
    </entry>
    <entry file="file://$PROJECT_DIR$/model.py">
      <provider selected="true" editor-type-id="text-editor">
        <state relative-caret-position="615">
          <caret line="41" column="30" selection-start-line="41" selection-start-column="30" selection-end-line="41" selection-end-column="30" />
          <folding>
            <element signature="e#0#24#0" expanded="true" />
          </folding>
        </state>
      </provider>
    </entry>
    <entry file="file://$PROJECT_DIR$/train.py">
      <provider selected="true" editor-type-id="text-editor">
        <state relative-caret-position="5820">
          <caret line="388" column="49" selection-start-line="388" selection-start-column="26" selection-end-line="388" selection-end-column="49" />
          <folding>
            <element signature="e#0#24#0" expanded="true" />
          </folding>
        </state>
      </provider>
    </entry>
    <entry file="file://$PROJECT_DIR$/README.md">
      <provider selected="true" editor-type-id="split-provider[text-editor;markdown-preview-editor]">
        <state split_layout="SPLIT">
          <first_editor relative-caret-position="239">
            <caret line="144" column="47" selection-start-line="144" selection-start-column="47" selection-end-line="144" selection-end-column="47" />
          </first_editor>
          <second_editor />
        </state>
      </provider>
    </entry>
  </component>
</project>

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2019 Jongmin Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# fewshot-egnn

### Introduction

The current project page provides pytorch code that implements the following CVPR2019 paper:   
**Title:**      "Edge-labeling Graph Neural Network for Few-shot Learning"    
**Authors:**     Jongmin Kim, Taesup Kim, Sungwoong Kim, Chang D.Yoo

**Institution:** KAIST, KaKaoBrain     
**Code:**        https://github.com/khy0809/fewshot-egnn  
**Arxiv:**       https://arxiv.org/abs/1905.01436

**Abstract:**
In this paper, we propose a novel edge-labeling graph
neural network (EGNN), which adapts a deep neural network
on the edge-labeling graph, for few-shot learning.
The previous graph neural network (GNN) approaches in
few-shot learning have been based on the node-labeling
framework, which implicitly models the intra-cluster similarity
and the inter-cluster dissimilarity. In contrast, the
proposed EGNN learns to predict the edge-labels rather
than the node-labels on the graph that enables the evolution
of an explicit clustering by iteratively updating the edge-labels
with direct exploitation of both intra-cluster similarity
and the inter-cluster dissimilarity. It is also well suited
for performing on various numbers of classes without retraining,
and can be easily extended to perform a transductive
inference. The parameters of the EGNN are learned
by episodic training with an edge-labeling loss to obtain a
well-generalizable model for unseen low-data problem. On
both of the supervised and semi-supervised few-shot image
classification tasks with two benchmark datasets, the proposed
EGNN significantly improves the performances over
the existing GNNs.

### Citation
If you find this code useful you can cite us using the following bibTex:
```
@article{kim2019egnn,
  title={Edge-labeling Graph Neural Network for Few-shot Learning},
  author={Jongmin Kim, Taesup Kim, Sungwoong Kim, Chang D. Yoo},
  journal={arXiv preprint arXiv:1905.01436},
  year={2019}
}
```


### Platform
This code was developed and tested with pytorch version 1.0.1

### Setting

You can download miniImagenet dataset from [here](https://drive.google.com/open?id=15WuREBvhEbSWo4fTr1r-vMY0C_6QWv4w).

Download 'mini_imagenet_train/val/test.pickle', and put them in the path
'tt.arg.dataset_root/mini-imagenet/compacted_datasets/'

In ```train.py```, replace the dataset root directory with your own:
tt.arg.dataset_root = '/data/private/dataset'



### Training

```
# ************************** miniImagenet, 5way 1shot *****************************
$ python3 train.py --dataset mini --num_ways 5 --num_shots 1 --transductive False
$ python3 train.py --dataset mini --num_ways 5 --num_shots 1 --transductive True

# ************************** miniImagenet, 5way 5shot *****************************
$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --transductive False
$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --transductive True

# ************************** miniImagenet, 10way 5shot *****************************
$ python3 train.py --dataset mini --num_ways 10 --num_shots 5 --meta_batch_size 20 --transductive True

# ************************** tieredImagenet, 5way 5shot *****************************
$ python3 train.py --dataset tiered --num_ways 5 --num_shots 5 --transductive False
$ python3 train.py --dataset tiered --num_ways 5 --num_shots 5 --transductive True

# **************** miniImagenet, 5way 5shot, 20% labeled (semi) *********************
$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4 --transductive False
$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4 --transductive True

```

### Evaluation
The trained models are saved in the path './asset/checkpoints/', with the name 'D-{dataset}_N-{ways}_K-{shots}_U-{num_unlabeled}_L-{num_layers}_B-{batch_size}_T-{transductive}'.
So, for example, if you want to test the trained model of 'miniImagenet, 5way 1shot, transductive' setting, you can give --test_model argument as follow:
```
$ python3 eval.py --test_model D-mini_N-5_K-1_U-0_L-3_B-40_T-True
```


## Result
Here are some experimental results presented in the paper. You should be able to reproduce all the results by using the trained models which can be downloaded from [here](https://drive.google.com/open?id=15WuREBvhEbSWo4fTr1r-vMY0C_6QWv4w).
#### miniImageNet, non-transductive

| Model                    |  5-way 5-shot acc (%)| 
|--------------------------|  ------------------: | 
| Matching Networks [1]    |         55.30        | 
| Reptile [2]              |         62.74        | 
| Prototypical Net [3]     |         65.77        | 
| GNN [4]                  |         66.41        | 
| **(ours)** EGNN          |         **66.85**        | 

#### miniImageNet, transductive

| Model                    |  5-way 5-shot acc (%)| 
|--------------------------|  ------------------: | 
| MAML [5]                 |         63.11        | 
| Reptile + BN [2]         |         65.99        | 
| Relation Net [6]         |         67.07        | 
| MAML + Transduction [5]  |         66.19        | 
| TPN [7]                  |         69.43        | 
| TPN (Higher K) [7]       |         69.86        | 
| **(ours)** EGNN          |         **76.37**        | 

#### tieredImageNet, non-transductive

| Model                    |  5-way 5-shot acc (%)| 
|--------------------------|  ------------------: | 
| Reptile [2]              |         66.47        | 
| Prototypical Net [3]     |         69.57        | 
| **(ours)** EGNN          |         **70.98**        | 

#### tieredImageNet, transductive

| Model                    |  5-way 5-shot acc (%)| 
|--------------------------|  ------------------: | 
| MAML [5]                 |         70.30        | 
| Reptile + BN [2]         |         71.03        | 
| Relation Net [6]         |         71.31        | 
| MAML + Transduction [5]  |         70.83        | 
| TPN [7]                  |         72.58        | 
| **(ours)** EGNN          |         **80.15**        | 


#### miniImageNet, semi-supervised, 5-way 5-shot

| Model                    |  20%                 | 40%                 | 60%                 | 100%                 | 
|--------------------------|  ------------------: | ------------------: | ------------------: | ------------------:  | 
| GNN-LabeledOnly [4]       |      50.33                |      56.91               |        -             |        66.41              |
| GNN-Semi [4]             |      52.45                |      58.76               |        -             |        66.41              |
| EGNN-LabeledOnly         |      52.86                |        -             |            -         |            66.85          |
| EGNN-Semi                |      61.88                |        62.52             |        63.53             |    66.85                  |
| EGNN-LabeledOnly (Transductive) |      59.18         |         -            |           -          |           76.37           |
| EGNN-Semi (Transductive)        |      63.62         |        64.32             |        66.37             |   76.37                   |


#### miniImageNet, cross-way experiment
| Model                    |  train way                 | test way                 |  Accuracy |
|--------------------------|  ------------------: | ------------------: | ------------------: |
| GNN       |      5                |      5               |      66.41     |
| GNN       |      5                |      10               |     N/A      |
| GNN       |      10                |     10            |       51.75    |
| GNN       |      10             |      5              |       N/A    |
| EGNN       |      5             |      5              |       76.37    |
| EGNN       |      5             |      10              |       56.35    |
| EGNN       |      10             |      10              |       57.61   |
| EGNN       |      10             |      5              |       76.27   |



### References
```
[1] O. Vinyals et al. Matching networks for one shot learning.
[2] A Nichol, J Achiam, J Schulman, On first-order meta-learning algorithms.
[3] J. Snell, K. Swersky, and R. S. Zemel. Prototypical networks for few-shot learning.
[4] V Garcia, J Bruna, Few-shot learning with graph neural network.
[5] C. Finn, P. Abbeel, and S. Levine. Model-agnostic meta-learning for fast adaptation of deep networks.
[6] F. Sung et al, Learning to Compare: Relation Network for Few-Shot Learning.
[7] Y Liu, J Lee, M Park, S Kim, Y Yang, Transductive propagation network for few-shot learning.
```

================================================
FILE: __init__.py
================================================
import numpy as np
import torch
from torch import nn
from torch import optim
from torch import cuda
from torch import utils
from torch.nn import functional as F
from torch.utils.data import *
from torch.distributions import *
from torchtools import tt


__author__ = 'namju.kim@kakaobrain.com'


# Initialize global RNG seeds for reproducibility.
# NOTE: an explicit `is not None` check is used so that seed=0 -- a valid
# and commonly used seed that happens to be falsy in Python -- still
# triggers deterministic seeding (the previous `if tt.arg.seed:` test
# silently skipped a zero seed).
if tt.arg.seed is not None:
    np.random.seed(tt.arg.seed)
    torch.manual_seed(tt.arg.seed)


================================================
FILE: _version.py
================================================
# Single source of truth for the package version string.
__version__ = '0.4.0'  # align version with pytorch



================================================
FILE: data.py
================================================
from __future__ import print_function
from torchtools import *
import torch.utils.data as data
import random
import os
import numpy as np
from PIL import Image as pil_image
import pickle
from itertools import islice
from torchvision import transforms


class MiniImagenetLoader(data.Dataset):
    """Episodic loader for the miniImagenet few-shot benchmark.

    Loads the pre-compacted pickle for one partition and serves batches of
    few-shot tasks (support/query splits) via get_task_batch().
    """

    def __init__(self, root, partition='train'):
        """
        Args:
            root: dataset root directory; the pickle is expected under
                <root>/mini-imagenet/compacted_datasets/.
            partition: 'train', 'val' or 'test'; selects the pickle file and
                whether train-time augmentation (random crop) is applied.
        """
        super(MiniImagenetLoader, self).__init__()
        # set dataset information
        self.root = root
        self.partition = partition
        self.data_size = [3, 84, 84]  # C x H x W of every served image

        # set normalizer
        # Per-channel mean/std given on the 0-255 scale, rescaled to [0, 1]
        # to match the range produced by transforms.ToTensor().
        mean_pix = [x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]
        std_pix = [x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]]
        normalize = transforms.Normalize(mean=mean_pix, std=std_pix)

        # set transformer
        if self.partition == 'train':
            # Training pipeline: pad-and-random-crop augmentation, then
            # PIL -> ndarray -> tensor -> normalize.
            self.transform = transforms.Compose([transforms.RandomCrop(84, padding=4),
                                                 lambda x: np.asarray(x),
                                                 transforms.ToTensor(),
                                                 normalize])
        else:  # 'val' or 'test' ,
            # Evaluation pipeline: no augmentation.
            self.transform = transforms.Compose([lambda x: np.asarray(x),
                                                 transforms.ToTensor(),
                                                 normalize])

        # load data
        self.data = self.load_dataset()

    def load_dataset(self):
        """Load the partition pickle and resize every image to 84x84.

        Returns:
            dict mapping class id -> list of PIL images (kept as PIL so the
            torchvision transforms can run later in get_task_batch).
        """
        # load data
        dataset_path = os.path.join(self.root, 'mini-imagenet/compacted_datasets', 'mini_imagenet_%s.pickle' % self.partition)
        with open(dataset_path, 'rb') as handle:
            data = pickle.load(handle)

        # for each class
        for c_idx in data:
            # for each image
            for i_idx in range(len(data[c_idx])):
                # resize
                image_data = pil_image.fromarray(np.uint8(data[c_idx][i_idx]))
                # PIL's resize takes (width, height), hence (size[2], size[1]).
                image_data = image_data.resize((self.data_size[2], self.data_size[1]))
                #image_data = np.array(image_data, dtype='float32')

                #image_data = np.transpose(image_data, (2, 0, 1))

                # save
                data[c_idx][i_idx] = image_data
        return data

    def get_task_batch(self,
                       num_tasks=5,
                       num_ways=20,
                       num_shots=1,
                       num_queries=1,
                       seed=None):
        """Sample a batch of `num_tasks` independent few-shot episodes.

        Args:
            num_tasks: number of episodes (meta-batch size).
            num_ways: classes sampled per episode.
            num_shots: support examples per class.
            num_queries: query examples per class.
            seed: optional seed. NOTE(review): this reseeds Python's *global*
                `random` module, which affects any other consumer of it.

        Returns:
            [support_data, support_label, query_data, query_label] where
            support_data is (num_tasks, num_ways * num_shots, 3, 84, 84) and
            support_label is (num_tasks, num_ways * num_shots); the query
            tensors are shaped analogously with num_queries. Labels are
            episode-local class indices in [0, num_ways), stored as floats,
            and samples are ordered class-by-class (class c occupies slots
            [c * num_shots, (c + 1) * num_shots)).
        """

        if seed is not None:
            random.seed(seed)

        # init task batch data
        # One zero-filled buffer per (class, sample) slot; each buffer holds
        # that slot's image/label for every task in the batch.
        support_data, support_label, query_data, query_label = [], [], [], []
        for _ in range(num_ways * num_shots):
            data = np.zeros(shape=[num_tasks] + self.data_size,
                            dtype='float32')
            label = np.zeros(shape=[num_tasks],
                             dtype='float32')
            support_data.append(data)
            support_label.append(label)
        for _ in range(num_ways * num_queries):
            data = np.zeros(shape=[num_tasks] + self.data_size,
                            dtype='float32')
            label = np.zeros(shape=[num_tasks],
                             dtype='float32')
            query_data.append(data)
            query_label.append(label)

        # get full class list in dataset
        full_class_list = list(self.data.keys())

        # for each task
        for t_idx in range(num_tasks):
            # define task by sampling classes (num_ways)
            task_class_list = random.sample(full_class_list, num_ways)

            # for each sampled class in task
            for c_idx in range(num_ways):
                # sample data for support and query (num_shots + num_queries)
                # (sampled without replacement, so support and query never overlap)
                class_data_list = random.sample(self.data[task_class_list[c_idx]], num_shots + num_queries)


                # load sample for support set
                for i_idx in range(num_shots):
                    # set data
                    support_data[i_idx + c_idx * num_shots][t_idx] = self.transform(class_data_list[i_idx])
                    support_label[i_idx + c_idx * num_shots][t_idx] = c_idx

                # load sample for query set
                for i_idx in range(num_queries):
                    query_data[i_idx + c_idx * num_queries][t_idx] = self.transform(class_data_list[num_shots + i_idx])
                    query_label[i_idx + c_idx * num_queries][t_idx] = c_idx

        # convert to tensor (num_tasks x (num_ways * (num_supports + num_queries)) x ...)
        # Stacking along dim=1 turns the per-slot buffers into a single
        # (num_tasks, num_slots, ...) tensor on tt.arg.device.
        support_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in support_data], 1)
        support_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in support_label], 1)
        query_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in query_data], 1)
        query_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in query_label], 1)

        return [support_data, support_label, query_data, query_label]



class TieredImagenetLoader(data.Dataset):
    """Episode sampler for the tiered-ImageNet few-shot benchmark.

    Expects pre-pickled per-partition class dictionaries under
    `<root>/tiered-imagenet/compacted_datasets/` (produced once by
    `preprocess`) and serves meta-learning task batches through
    `get_task_batch`.
    """

    def __init__(self, root, partition='train'):
        # root: dataset root directory
        # partition: which split to load ('train' / 'val' / 'test')
        self.root = root
        self.partition = partition  # train/val/test
        #self.preprocess()
        # served image size, channels-first: C x H x W
        self.data_size = [3, 84, 84]

        # load data
        # dict: encoded class label -> list of normalized float32 CHW arrays
        self.data = self.load_dataset()

        # if not self._check_exists_():
        #     self._init_folders_()
        #     if self.check_decompress():
        #         self._decompress_()
        #     self._preprocess_()


    def get_image_paths(self, file):
        """Parse a `filename,class` CSV (header line skipped).

        Returns the pair (class_names, images_path) as parallel lists.
        """
        images_path, class_names = [], []
        with open(file, 'r') as f:
            f.readline()
            for line in f:
                name, class_ = line.split(',')
                # strip the trailing newline from the class token
                class_ = class_[0:(len(class_)-1)]
                path = self.root + '/tiered-imagenet/images/'+name
                images_path.append(path)
                class_names.append(class_)
        return class_names, images_path

    def preprocess(self):
        """One-off conversion of raw images into pickled class dictionaries.

        Reads the train/test/val CSVs, resizes every image to 84x84 RGB,
        groups images by an integer-encoded class id, and dumps each split in
        20-class pickle chunks under `tiered-imagenet/compacted_datasets`.
        """
        print('\nPreprocessing Tiered-Imagenet images...')
        (class_names_train, images_path_train) = self.get_image_paths('%s/tiered-imagenet/train.csv' % self.root)
        (class_names_test, images_path_test) = self.get_image_paths('%s/tiered-imagenet/test.csv' % self.root)
        (class_names_val, images_path_val) = self.get_image_paths('%s/tiered-imagenet/val.csv' % self.root)

        # assign consecutive integer ids: train classes first, then test, then val
        keys_train = list(set(class_names_train))
        keys_test = list(set(class_names_test))
        keys_val = list(set(class_names_val))
        label_encoder = {}
        label_decoder = {}
        for i in range(len(keys_train)):
            label_encoder[keys_train[i]] = i
            label_decoder[i] = keys_train[i]
        for i in range(len(keys_train), len(keys_train)+len(keys_test)):
            label_encoder[keys_test[i-len(keys_train)]] = i
            label_decoder[i] = keys_test[i-len(keys_train)]
        for i in range(len(keys_train)+len(keys_test), len(keys_train)+len(keys_test)+len(keys_val)):
            label_encoder[keys_val[i-len(keys_train) - len(keys_test)]] = i
            label_decoder[i] = keys_val[i-len(keys_train)-len(keys_test)]

        counter = 0
        train_set = {}

        # NOTE(review): pil_image.ANTIALIAS is removed in recent Pillow
        # releases (use Image.LANCZOS there) — confirm the pinned Pillow version.
        for class_, path in zip(class_names_train, images_path_train):
            img = pil_image.open(path)
            img = img.convert('RGB')
            img = img.resize((84, 84), pil_image.ANTIALIAS)
            img = np.array(img, dtype='float32')
            if label_encoder[class_] not in train_set:
                train_set[label_encoder[class_]] = []
            train_set[label_encoder[class_]].append(img)
            counter += 1
            if counter % 1000 == 0:
                print("Counter "+str(counter) + " from " + str(len(images_path_train)))

        test_set = {}
        for class_, path in zip(class_names_test, images_path_test):
            img = pil_image.open(path)
            img = img.convert('RGB')
            img = img.resize((84, 84), pil_image.ANTIALIAS)
            img = np.array(img, dtype='float32')

            if label_encoder[class_] not in test_set:
                test_set[label_encoder[class_]] = []
            test_set[label_encoder[class_]].append(img)
            counter += 1
            if counter % 1000 == 0:
                print("Counter " + str(counter) + " from "+str(len(class_names_test)))

        val_set = {}
        for class_, path in zip(class_names_val, images_path_val):
            img = pil_image.open(path)
            img = img.convert('RGB')
            img = img.resize((84, 84), pil_image.ANTIALIAS)
            img = np.array(img, dtype='float32')

            if label_encoder[class_] not in val_set:
                val_set[label_encoder[class_]] = []
            val_set[label_encoder[class_]].append(img)
            counter += 1
            if counter % 1000 == 0:
                print("Counter "+str(counter) + " from " + str(len(class_names_val)))

        # dump each split in chunks of 20 classes; load_dataset expects the
        # resulting partition counts (train: 18, val: 5, test: 8 files)
        partition_count = 0
        for item in self.chunks(train_set, 20):
            partition_count = partition_count + 1
            with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_train_{}.pickle'.format(partition_count)), 'wb') as handle:
                pickle.dump(item, handle, protocol=2)

        partition_count = 0
        for item in self.chunks(test_set, 20):
            partition_count = partition_count + 1
            with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_test_{}.pickle'.format(partition_count)), 'wb') as handle:
                pickle.dump(item, handle, protocol=2)

        partition_count = 0
        for item in self.chunks(val_set, 20):
            partition_count = partition_count + 1
            with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_val_{}.pickle'.format(partition_count)), 'wb') as handle:
                pickle.dump(item, handle, protocol=2)



        # NOTE(review): this saved encoder maps only train+test ids to dense
        # indices; val classes are not included — verify that is intentional.
        label_encoder = {}
        keys = list(train_set.keys()) + list(test_set.keys())
        for id_key, key in enumerate(keys):
            label_encoder[key] = id_key
        with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_label_encoder.pickle'), 'wb') as handle:
            pickle.dump(label_encoder, handle, protocol=2)

        print('Images preprocessed')

    def load_dataset(self):
        """Load all pickle chunks of the current partition and normalize.

        Returns a dict mapping class label -> list of float32 CHW arrays,
        per-channel mean-shifted and scaled by 1/127.5.
        """
        print("Loading dataset")
        data = {}
        # hard-coded chunk counts must match what `preprocess` produced;
        # NOTE(review): any other partition value leaves num_partition unbound
        # and raises NameError below.
        if self.partition == 'train':
            num_partition = 18
        elif self.partition == 'val':
            num_partition = 5
        elif self.partition == 'test':
            num_partition = 8

        partition_count = 0
        for i in range(num_partition):
            partition_count = partition_count +1
            with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_{}_{}.pickle'.format(self.partition, partition_count)), 'rb') as handle:
                data.update(pickle.load(handle))

        # Resize images and normalize
        for class_ in data:
            for i in range(len(data[class_])):
                image2resize = pil_image.fromarray(np.uint8(data[class_][i]))
                image_resized = image2resize.resize((self.data_size[2], self.data_size[1]))
                image_resized = np.array(image_resized, dtype='float32')

                # Normalize
                # HWC -> CHW, then subtract per-channel means and scale
                image_resized = np.transpose(image_resized, (2, 0, 1))
                image_resized[0, :, :] -= 120.45  # R
                image_resized[1, :, :] -= 115.74  # G
                image_resized[2, :, :] -= 104.65  # B
                image_resized /= 127.5

                data[class_][i] = image_resized

        print("Num classes " + str(len(data)))
        num_images = 0
        for class_ in data:
            num_images += len(data[class_])
        print("Num images " + str(num_images))
        return data

    def chunks(self, data, size=10000):
        """Yield successive dicts containing at most `size` keys of `data`."""
        it = iter(data)
        for i in range(0, len(data), size):
            yield {k: data[k] for k in islice(it, size)}

    def get_task_batch(self,
                       num_tasks=5,
                       num_ways=20,
                       num_shots=1,
                       num_queries=1,
                       seed=None):
        """Sample a batch of few-shot episodes.

        Returns [support_data, support_label, query_data, query_label]:
        data tensors of shape (num_tasks, num_ways*num_shots, C, H, W) for
        the support set (num_queries in place of num_shots for queries),
        labels of shape (num_tasks, num_ways*num_shots) holding the
        within-task class index (0..num_ways-1) as float32.

        Unlike the mini-ImageNet loader, no per-sample transform is applied
        here; images were already normalized in `load_dataset`.
        """
        if seed is not None:
            random.seed(seed)

        # init task batch data
        # one (num_tasks, C, H, W) slot per support/query position
        support_data, support_label, query_data, query_label = [], [], [], []
        for _ in range(num_ways * num_shots):
            data = np.zeros(shape=[num_tasks] + self.data_size,
                            dtype='float32')
            label = np.zeros(shape=[num_tasks],
                             dtype='float32')
            support_data.append(data)
            support_label.append(label)
        for _ in range(num_ways * num_queries):
            data = np.zeros(shape=[num_tasks] + self.data_size,
                            dtype='float32')
            label = np.zeros(shape=[num_tasks],
                             dtype='float32')
            query_data.append(data)
            query_label.append(label)

        # get full class list in dataset
        full_class_list = list(self.data.keys())

        # for each task
        for t_idx in range(num_tasks):
            # define task by sampling classes (num_ways)
            task_class_list = random.sample(full_class_list, num_ways)

            # for each sampled class in task
            for c_idx in range(num_ways):
                # sample data for support and query (num_shots + num_queries)
                class_data_list = random.sample(self.data[task_class_list[c_idx]], num_shots + num_queries)

                # load sample for support set
                for i_idx in range(num_shots):
                    # set data
                    support_data[i_idx + c_idx * num_shots][t_idx] = class_data_list[i_idx]
                    support_label[i_idx + c_idx * num_shots][t_idx] = c_idx

                # load sample for query set
                for i_idx in range(num_queries):
                    query_data[i_idx + c_idx * num_queries][t_idx] = class_data_list[num_shots + i_idx]
                    query_label[i_idx + c_idx * num_queries][t_idx] = c_idx



        # convert to tensor (num_tasks x (num_ways * (num_supports + num_queries)) x ...)
        support_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in support_data], 1)
        support_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in support_label], 1)
        query_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in query_data], 1)
        query_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in query_label], 1)

        return [support_data, support_label, query_data, query_label]

================================================
FILE: eval.py
================================================
from torchtools import *
from data import MiniImagenetLoader, TieredImagenetLoader
from model import EmbeddingImagenet, GraphNetwork, ConvNet
import shutil
import os
import random
from train import ModelTrainer

if __name__ == '__main__':
    # Evaluation entry point: decodes the experiment configuration from the
    # checkpoint/model name, rebuilds the encoder and EGNN modules, restores
    # the best checkpoint and evaluates on the test split.

    tt.arg.test_model = 'D-mini_N-5_K-1_U-0_L-3_B-40_T-True' if tt.arg.test_model is None else tt.arg.test_model

    # parse 'D-mini_N-5_K-1_U-0_L-3_B-40_T-True' style names into a dict
    list1 = tt.arg.test_model.split("_")
    param = {}
    for i in range(len(list1)):
        param[list1[i].split("-", 1)[0]] = list1[i].split("-", 1)[1]
    tt.arg.dataset = param['D']
    tt.arg.num_ways = int(param['N'])
    tt.arg.num_shots = int(param['K'])
    tt.arg.num_unlabeled = int(param['U'])
    tt.arg.num_layers = int(param['L'])
    tt.arg.meta_batch_size = int(param['B'])
    tt.arg.transductive = False if param['T'] == 'False' else True


    ####################
    tt.arg.device = 'cuda:0' if tt.arg.device is None else tt.arg.device
    # replace dataset_root with your own
    tt.arg.dataset_root = '/data/private/dataset'
    tt.arg.dataset = 'mini' if tt.arg.dataset is None else tt.arg.dataset
    tt.arg.num_ways = 5 if tt.arg.num_ways is None else tt.arg.num_ways
    tt.arg.num_shots = 1 if tt.arg.num_shots is None else tt.arg.num_shots
    tt.arg.num_unlabeled = 0 if tt.arg.num_unlabeled is None else tt.arg.num_unlabeled
    tt.arg.num_layers = 3 if tt.arg.num_layers is None else tt.arg.num_layers
    tt.arg.meta_batch_size = 40 if tt.arg.meta_batch_size is None else tt.arg.meta_batch_size
    tt.arg.transductive = False if tt.arg.transductive is None else tt.arg.transductive
    tt.arg.seed = 222 if tt.arg.seed is None else tt.arg.seed
    tt.arg.num_gpus = 1 if tt.arg.num_gpus is None else tt.arg.num_gpus

    tt.arg.num_ways_train = tt.arg.num_ways
    tt.arg.num_ways_test = tt.arg.num_ways

    tt.arg.num_shots_train = tt.arg.num_shots
    tt.arg.num_shots_test = tt.arg.num_shots

    tt.arg.train_transductive = tt.arg.transductive
    tt.arg.test_transductive = tt.arg.transductive

    # model parameter related
    # NOTE(review): num_edge_features is passed as node_features (and vice
    # versa) in the GraphNetwork call below; both are 96 so this is harmless
    # today, but fix the call if the two values ever diverge.
    tt.arg.num_edge_features = 96
    tt.arg.num_node_features = 96
    tt.arg.emb_size = 128

    # train, test parameters
    tt.arg.train_iteration = 100000 if tt.arg.dataset == 'mini' else 200000
    tt.arg.test_iteration = 10000
    tt.arg.test_interval = 5000
    tt.arg.test_batch_size = 10
    tt.arg.log_step = 1000

    tt.arg.lr = 1e-3
    tt.arg.grad_clip = 5
    tt.arg.weight_decay = 1e-6
    tt.arg.dec_lr = 15000 if tt.arg.dataset == 'mini' else 30000
    tt.arg.dropout = 0.1 if tt.arg.dataset == 'mini' else 0.0

    # set random seed before constructing the encoder (reproducible weights)
    np.random.seed(tt.arg.seed)
    torch.manual_seed(tt.arg.seed)
    torch.cuda.manual_seed_all(tt.arg.seed)
    random.seed(tt.arg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


    enc_module = EmbeddingImagenet(emb_size=tt.arg.emb_size)

    # re-seed so GNN construction starts from the same RNG state regardless
    # of how many random numbers the encoder construction consumed
    np.random.seed(tt.arg.seed)
    torch.manual_seed(tt.arg.seed)
    torch.cuda.manual_seed_all(tt.arg.seed)
    random.seed(tt.arg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # rebuild the canonical experiment name and verify it matches the request
    exp_name = 'D-{}'.format(tt.arg.dataset)
    exp_name += '_N-{}_K-{}_U-{}'.format(tt.arg.num_ways, tt.arg.num_shots, tt.arg.num_unlabeled)
    exp_name += '_L-{}_B-{}'.format(tt.arg.num_layers, tt.arg.meta_batch_size)
    exp_name += '_T-{}'.format(tt.arg.transductive)


    if not exp_name == tt.arg.test_model:
        print(exp_name)
        print(tt.arg.test_model)
        # BUG FIX: the original code instantiated AssertionError() without
        # raising it, so a mismatched configuration was silently ignored.
        raise AssertionError('Test model and input arguments are mismatched!')

    gnn_module = GraphNetwork(in_features=tt.arg.emb_size,
                              node_features=tt.arg.num_edge_features,
                              edge_features=tt.arg.num_node_features,
                              num_layers=tt.arg.num_layers,
                              dropout=tt.arg.dropout)

    if tt.arg.dataset == 'mini':
        test_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='test')
    elif tt.arg.dataset == 'tiered':
        test_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='test')
    else:
        # BUG FIX: the original only printed a message here and then crashed
        # later with NameError on test_loader; fail fast with a clear error.
        raise ValueError('Unknown dataset: {}'.format(tt.arg.dataset))


    data_loader = {'test': test_loader}

    # create trainer (reused here purely for its evaluation loop)
    tester = ModelTrainer(enc_module=enc_module,
                           gnn_module=gnn_module,
                           data_loader=data_loader)


    #checkpoint = torch.load('asset/checkpoints/{}/'.format(exp_name) + 'model_best.pth.tar')
    checkpoint = torch.load('./trained_models/{}/'.format(exp_name) + 'model_best.pth.tar')


    # restore encoder weights
    tester.enc_module.load_state_dict(checkpoint['enc_module_state_dict'])
    print("load pre-trained enc_nn done!")

    # initialize gnn pre-trained
    tester.gnn_module.load_state_dict(checkpoint['gnn_module_state_dict'])
    print("load pre-trained egnn done!")

    # restore bookkeeping state from the checkpoint
    tester.val_acc = checkpoint['val_acc']
    tester.global_step = checkpoint['iteration']

    print(tester.global_step)


    tester.eval(partition='test')







================================================
FILE: model.py
================================================
from torchtools import *
from collections import OrderedDict
import math
#import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt


class ConvBlock(nn.Module):
    """Conv3x3 -> optional normalization -> optional ReLU -> 2x2 max-pool.

    The normalization layer is selected globally via tt.arg.normtype
    ('batch' or 'instance'); any other value skips normalization entirely.
    Submodule names ('Conv', 'Norm', 'ReLU', 'MaxPool') match the original
    layout so state_dict keys stay compatible.
    """

    def __init__(self, in_planes, out_planes, userelu=True, momentum=0.1, affine=True, track_running_stats=True):
        super(ConvBlock, self).__init__()
        modules = OrderedDict()
        modules['Conv'] = nn.Conv2d(in_planes, out_planes,
                                    kernel_size=3, stride=1, padding=1, bias=False)

        # normalization choice is driven by the global tt.arg.normtype flag
        if tt.arg.normtype == 'batch':
            modules['Norm'] = nn.BatchNorm2d(out_planes, momentum=momentum, affine=affine,
                                             track_running_stats=track_running_stats)
        elif tt.arg.normtype == 'instance':
            modules['Norm'] = nn.InstanceNorm2d(out_planes)

        if userelu:
            modules['ReLU'] = nn.ReLU(inplace=True)

        modules['MaxPool'] = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.layers = nn.Sequential(modules)

    def forward(self, x):
        return self.layers(x)

class ConvNet(nn.Module):
    """Stack of ConvBlock stages followed by flattening.

    `opt` is a dict with 'in_planes', 'out_planes' (int or list),
    'num_stages', and optionally 'userelu' (applies to the last stage only).
    Conv weights get He-style init; batch norms get identity init.
    """

    def __init__(self, opt, momentum=0.1, affine=True, track_running_stats=True):
        super(ConvNet, self).__init__()
        self.in_planes = opt['in_planes']
        self.out_planes = opt['out_planes']
        self.num_stages = opt['num_stages']
        # a single int means "same width at every stage"
        if type(self.out_planes) == int:
            self.out_planes = [self.out_planes] * self.num_stages
        assert type(self.out_planes) == list and len(self.out_planes) == self.num_stages

        num_planes = [self.in_planes] + self.out_planes
        userelu = opt['userelu'] if 'userelu' in opt else True

        stages = []
        for i in range(self.num_stages):
            # only the final stage honours the userelu option
            relu_flag = userelu if i == self.num_stages - 1 else True
            stages.append(ConvBlock(num_planes[i], num_planes[i + 1], userelu=relu_flag))
        self.conv_blocks = nn.Sequential(*stages)

        # He initialization for convolutions, identity for batch norms
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        features = self.conv_blocks(x)
        return features.view(features.size(0), -1)



# encoder for imagenet dataset
class EmbeddingImagenet(nn.Module):
    """CNN encoder mapping 3x84x84 images to `emb_size`-dim embeddings.

    Four conv stages (Conv -> BN -> MaxPool -> LeakyReLU, the last two with
    Dropout2d) reduce the input to a 256x5x5 map, which a linear layer with
    BatchNorm1d projects to the embedding. Attribute names conv_1..conv_4 /
    layer_last and the layer order inside each stage are preserved so
    state_dict keys match the reference implementation.
    """

    def __init__(self,
                 emb_size):
        super(EmbeddingImagenet, self).__init__()
        self.hidden = 64
        self.last_hidden = self.hidden * 25  # 5x5 spatial map after four 2x2 pools
        self.emb_size = emb_size

        def conv_stage(in_ch, out_ch, padding, dropout=None):
            # Conv -> BatchNorm -> MaxPool -> LeakyReLU (-> optional Dropout2d),
            # appended in this exact order to keep Sequential indices stable.
            stage = [nn.Conv2d(in_channels=in_ch,
                               out_channels=out_ch,
                               kernel_size=3,
                               padding=padding,
                               bias=False),
                     nn.BatchNorm2d(num_features=out_ch),
                     nn.MaxPool2d(kernel_size=2),
                     nn.LeakyReLU(negative_slope=0.2, inplace=True)]
            if dropout is not None:
                stage.append(nn.Dropout2d(dropout))
            return nn.Sequential(*stage)

        self.conv_1 = conv_stage(3, self.hidden, padding=1)
        self.conv_2 = conv_stage(self.hidden, int(self.hidden * 1.5), padding=0)
        self.conv_3 = conv_stage(int(self.hidden * 1.5), self.hidden * 2, padding=1, dropout=0.4)
        self.conv_4 = conv_stage(self.hidden * 2, self.hidden * 4, padding=1, dropout=0.5)
        self.layer_last = nn.Sequential(nn.Linear(in_features=self.last_hidden * 4,
                                                  out_features=self.emb_size, bias=True),
                                        nn.BatchNorm1d(self.emb_size))

    def forward(self, input_data):
        x = input_data
        for stage in (self.conv_1, self.conv_2, self.conv_3, self.conv_4):
            x = stage(x)
        return self.layer_last(x.view(x.size(0), -1))




class NodeUpdateNetwork(nn.Module):
    """Edge-conditioned node update (EGNN "edge-to-node" step).

    Aggregates neighbor node features weighted by the two edge channels,
    concatenates the aggregates with each node's own features, and applies
    a stack of 1x1 conv + BN + LeakyReLU transforms.
    """

    def __init__(self,
                 in_features,
                 num_features,
                 ratio=[2, 1],
                 dropout=0.0):
        # in_features: dimension of incoming node features
        # num_features: base width; stage l uses num_features * ratio[l]
        # NOTE(review): mutable default `ratio` is shared across calls but
        # never mutated here, so this is safe in practice.
        super(NodeUpdateNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.num_features_list = [num_features * r for r in ratio]
        self.dropout = dropout

        # layers
        layer_list = OrderedDict()
        for l in range(len(self.num_features_list)):

            # first stage sees [node || 2 aggregated views] = 3 * in_features channels
            layer_list['conv{}'.format(l)] = nn.Conv2d(
                in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features * 3,
                out_channels=self.num_features_list[l],
                kernel_size=1,
                bias=False)
            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],
                                                            )
            layer_list['relu{}'.format(l)] = nn.LeakyReLU()

            # dropout only after the final stage
            if self.dropout > 0 and l == (len(self.num_features_list) - 1):
                layer_list['drop{}'.format(l)] = nn.Dropout2d(p=self.dropout)

        self.network = nn.Sequential(layer_list)

    def forward(self, node_feat, edge_feat):
        # node_feat: (num_tasks, num_data, feat); edge_feat has a 2-channel
        # (similarity/dissimilarity) adjacency per task — see diag_mask below.
        # get size
        num_tasks = node_feat.size(0)
        num_data = node_feat.size(1)

        # get eye matrix (batch_size x 2 x node_size x node_size)
        diag_mask = 1.0 - torch.eye(num_data).unsqueeze(0).unsqueeze(0).repeat(num_tasks, 2, 1, 1).to(tt.arg.device)

        # set diagonal as zero and normalize
        edge_feat = F.normalize(edge_feat * diag_mask, p=1, dim=-1)

        # compute attention and aggregate
        # the two edge channels are stacked side-by-side along the row axis
        # before the batched matmul, yielding one aggregate per channel
        aggr_feat = torch.bmm(torch.cat(torch.split(edge_feat, 1, 1), 2).squeeze(1), node_feat)

        # concat [own features || channel-0 aggregate || channel-1 aggregate]
        node_feat = torch.cat([node_feat, torch.cat(aggr_feat.split(num_data, 1), -1)], -1).transpose(1, 2)

        # non-linear transform
        # (1x1 convs treat the node axis as spatial; trailing singleton dim added/removed)
        node_feat = self.network(node_feat.unsqueeze(-1)).transpose(1, 2).squeeze(-1)
        return node_feat


class EdgeUpdateNetwork(nn.Module):
    """Node-conditioned edge update (EGNN "node-to-edge" step).

    Computes pairwise |x_i - x_j| features, maps them to a similarity score
    through a 1x1-conv metric network (optionally a separate dissimilarity
    network), then renormalizes the 2-channel edge tensor while forcing
    self-edges to the (1, 0) pattern.
    """

    def __init__(self,
                 in_features,
                 num_features,
                 ratio=[2, 2, 1, 1],
                 separate_dissimilarity=False,
                 dropout=0.0):
        # in_features: node feature dimension; num_features * ratio[l] gives
        # the width of metric-network stage l.
        super(EdgeUpdateNetwork, self).__init__()
        self.in_features = in_features
        self.num_features_list = [num_features * r for r in ratio]
        self.separate_dissimilarity = separate_dissimilarity
        self.dropout = dropout

        # similarity metric network: 1x1 conv + BN + LeakyReLU stages,
        # finished by a 1-channel scoring conv
        layer_list = OrderedDict()
        for l in range(len(self.num_features_list)):
            layer_list['conv{}'.format(l)] = nn.Conv2d(
                in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features,
                out_channels=self.num_features_list[l],
                kernel_size=1,
                bias=False)
            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l])
            layer_list['relu{}'.format(l)] = nn.LeakyReLU()

            if self.dropout > 0:
                layer_list['drop{}'.format(l)] = nn.Dropout2d(p=self.dropout)

        layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],
                                           out_channels=1,
                                           kernel_size=1)
        self.sim_network = nn.Sequential(layer_list)

        if self.separate_dissimilarity:
            # optional dissimilarity metric network (same architecture)
            layer_list = OrderedDict()
            for l in range(len(self.num_features_list)):
                layer_list['conv{}'.format(l)] = nn.Conv2d(
                    in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features,
                    out_channels=self.num_features_list[l],
                    kernel_size=1,
                    bias=False)
                layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l])
                layer_list['relu{}'.format(l)] = nn.LeakyReLU()

                if self.dropout > 0:
                    # FIX: this branch used nn.Dropout while the sim branch
                    # used nn.Dropout2d; the activations here are 4-D maps,
                    # so use the 2-D variant consistently.
                    layer_list['drop{}'.format(l)] = nn.Dropout2d(p=self.dropout)

            layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],
                                               out_channels=1,
                                               kernel_size=1)
            self.dsim_network = nn.Sequential(layer_list)

    def forward(self, node_feat, edge_feat):
        # compute abs(x_i - x_j) as a (tasks x feat x data x data) map
        x_i = node_feat.unsqueeze(2)
        x_j = torch.transpose(x_i, 1, 2)
        x_ij = torch.abs(x_i - x_j)
        x_ij = torch.transpose(x_ij, 1, 3)

        # similarity score in (0, 1)
        # FIX: F.sigmoid has been deprecated since torch 0.4.1; torch.sigmoid
        # is the supported, behaviorally identical spelling.
        sim_val = torch.sigmoid(self.sim_network(x_ij))

        if self.separate_dissimilarity:
            dsim_val = torch.sigmoid(self.dsim_network(x_ij))
        else:
            dsim_val = 1.0 - sim_val

        # zero the diagonal, remember each row's mass so the rescaled rows
        # keep their previous totals
        diag_mask = 1.0 - torch.eye(node_feat.size(1)).unsqueeze(0).unsqueeze(0).repeat(node_feat.size(0), 2, 1, 1).to(tt.arg.device)
        edge_feat = edge_feat * diag_mask
        merge_sum = torch.sum(edge_feat, -1, True)
        # reweight edges by the new similarity scores and renormalize rows
        edge_feat = F.normalize(torch.cat([sim_val, dsim_val], 1) * edge_feat, p=1, dim=-1) * merge_sum
        # force self-edges to (1, 0) across the two channels
        force_edge_feat = torch.cat((torch.eye(node_feat.size(1)).unsqueeze(0), torch.zeros(node_feat.size(1), node_feat.size(1)).unsqueeze(0)), 0).unsqueeze(0).repeat(node_feat.size(0), 1, 1, 1).to(tt.arg.device)
        edge_feat = edge_feat + force_edge_feat
        edge_feat = edge_feat + 1e-6
        # renormalize so the two channels sum to one at every (i, j)
        edge_feat = edge_feat / torch.sum(edge_feat, dim=1).unsqueeze(1).repeat(1, 2, 1, 1)

        return edge_feat


class GraphNetwork(nn.Module):
    """Alternating node-update / edge-update EGNN stack.

    Each layer runs NodeUpdateNetwork then EdgeUpdateNetwork; the forward
    pass returns the edge features produced after every layer. Submodules
    are registered under the exact names 'edge2node_net{l}' and
    'node2edge_net{l}' so existing checkpoints keep loading.
    """

    def __init__(self,
                 in_features,
                 node_features,
                 edge_features,
                 num_layers,
                 dropout=0.0):
        super(GraphNetwork, self).__init__()
        self.in_features = in_features
        self.node_features = node_features
        self.edge_features = edge_features
        self.num_layers = num_layers
        self.dropout = dropout

        for layer_idx in range(self.num_layers):
            # dropout is disabled on the last layer
            layer_dropout = 0.0 if layer_idx == self.num_layers - 1 else self.dropout
            # only the first layer consumes the raw embedding width
            node_in = self.in_features if layer_idx == 0 else self.node_features

            self.add_module('edge2node_net{}'.format(layer_idx),
                            NodeUpdateNetwork(in_features=node_in,
                                              num_features=self.node_features,
                                              dropout=layer_dropout))
            self.add_module('node2edge_net{}'.format(layer_idx),
                            EdgeUpdateNetwork(in_features=self.node_features,
                                              num_features=self.edge_features,
                                              separate_dissimilarity=False,
                                              dropout=layer_dropout))

    def forward(self, node_feat, edge_feat):
        """Run all layers; returns the per-layer list of edge features."""
        edge_feat_list = []
        for layer_idx in range(self.num_layers):
            # (1) edge to node, (2) node to edge
            node_feat = getattr(self, 'edge2node_net{}'.format(layer_idx))(node_feat, edge_feat)
            edge_feat = getattr(self, 'node2edge_net{}'.format(layer_idx))(node_feat, edge_feat)
            edge_feat_list.append(edge_feat)

        return edge_feat_list



================================================
FILE: torchtools/__init__.py
================================================
import numpy as np
import torch
from torch import nn
from torch import optim
from torch import cuda
from torch import utils
from torch.nn import functional as F
from torch.utils.data import *
from torch.distributions import *
from torchtools import tt


__author__ = 'namju.kim@kakaobrain.com'


# initialize seed
# Seeds the NumPy and torch CPU RNGs once at import time when --seed is set.
# NOTE(review): CUDA RNGs are not seeded here; scripts needing GPU
# determinism call torch.cuda.manual_seed_all themselves.
if tt.arg.seed:
    np.random.seed(tt.arg.seed)
    torch.manual_seed(tt.arg.seed)


================================================
FILE: torchtools/_version.py
================================================
# torchtools package version string, kept in lockstep with the targeted PyTorch release
__version__ = '0.4.0'  # align version with pytorch



================================================
FILE: torchtools/tt/__init__.py
================================================
from torchtools.tt.arg import _parse_opts
from torchtools.tt.utils import *
from torchtools.tt.layer import *
from torchtools.tt.logger import *
from torchtools.tt.stat import *
from torchtools.tt.trainer import *


__author__ = 'namju.kim@kakaobrain.com'


# global command line arguments
# Parsed once at import time; accessed anywhere as tt.arg.<name>.
# Missing keys resolve to None (see _Opt in torchtools/tt/arg.py).
arg = _parse_opts()


================================================
FILE: torchtools/tt/arg.py
================================================
import sys
import configparser
import torch
import threading
import time
import os


__author__ = 'namju.kim@kakaobrain.com'


# mtime of the most recently parsed config file (0 = never parsed);
# shared between _parse_opts and the watcher thread it may start
_config_time_stamp = 0


class _Opt(object):

    def __len__(self):
        return len(self.__dict__)

    def __setitem__(self, key, value):
        self.__dict__[key] = value

    def __getitem__(self, item):
        if item in self.__dict__:
            return self.__dict__[item]
        else:
            return None

    def __getattr__(self, item):
        return self.__getitem__(item)


def _to_py_obj(x):
    # check boolean first
    if x.lower() in ['true', 'yes', 'on']:
        return True
    if x.lower() in ['false', 'no', 'off']:
        return False
    # from string to python object if possible
    try:
        obj = eval(x)
        if type(obj).__name__ in ['int', 'float', 'tuple', 'list', 'dict', 'NoneType']:
            x = obj
    except:
        pass
    return x


def _parse_config(arg, file):
    """Load an INI config file into `arg`.

    Items in a [default] section are flattened into `arg` itself; every
    other section becomes a nested _Opt stored under the section name with
    whitespace replaced by underscores. Values go through _to_py_obj.
    """
    config = configparser.ConfigParser()
    config.read(file)

    for section in config.sections():
        opt = _Opt()
        for key, raw_value in config[section].items():
            opt[key] = _to_py_obj(raw_value)

        if section.lower() == 'default':
            # flatten [default] items into the top-level namespace
            for k, v in opt.__dict__.items():
                arg[k] = v
        else:
            # 'my section' -> arg.my_section
            arg['_'.join(section.split())] = opt


def _parse_config_thread(arg, file):
    """Daemon loop: re-parse `file` into `arg` whenever its mtime changes.

    Polls roughly once per second; intended to run on a daemon thread
    started by _parse_opts.
    """
    global _config_time_stamp

    while True:
        mtime = os.stat(file).st_mtime
        if mtime != _config_time_stamp:
            # remember the new timestamp, then reload the file
            _config_time_stamp = mtime
            _parse_config(arg, file)
            # _print_opts(arg, 'CONFIGURATION CHANGE DETECTED')
        time.sleep(1)


def _print_opts(arg, header):
    print(header, flush=True)
    print('-' * 30, flush=True)
    for k, v in arg.__dict__.items():
        print('%s=%s' % (k, v), flush=True)
    print('-' * 30, flush=True)


def _parse_opts():
    """Build the global option object from ``sys.argv``.

    Arguments must come as ``--key value`` pairs. A ``--config`` pair
    additionally loads an INI file and starts a daemon thread that
    hot-reloads it. Default device and asset directories are injected
    before returning the populated _Opt instance.
    """

    global _config_time_stamp

    # get command line arguments (skip the program name)
    arg = _Opt()
    argv = sys.argv[1:]

    # check length: options are strictly key/value pairs
    assert len(argv) % 2 == 0, 'arguments should be paired with the format of --key value'

    # parse args
    for i in range(0, len(argv), 2):

        # check format
        assert argv[i].startswith('--'), 'arguments should be paired with the format of --key value'

        # save argument (values converted to python objects when possible)
        arg[argv[i][2:]] = _to_py_obj(argv[i + 1])

        # check config file; record its mtime for the watcher thread
        if argv[i][2:].lower() == 'config':
            _parse_config(arg, argv[i + 1])
            _config_time_stamp = os.stat(argv[i + 1]).st_mtime

    #
    # inject default options
    #

    # device setting: auto-pick cuda when available, then normalize to torch.device
    if arg.device is None:
        arg.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    arg.device = torch.device(arg.device)
    arg.cuda = arg.device.type == 'cuda'

    # default learning rate
    #arg.lr = 1e-3

    # asset directories, with a trailing slash enforced
    arg.log_dir = arg.log_dir or 'asset/log/'
    arg.data_dir = arg.data_dir or 'asset/data/'
    arg.save_dir = arg.save_dir or 'asset/train/'
    arg.log_dir += '' if arg.log_dir.endswith('/') else '/'
    arg.data_dir += '' if arg.data_dir.endswith('/') else '/'
    arg.save_dir += '' if arg.save_dir.endswith('/') else '/'

    # print arg option
    # _print_opts(arg, 'CONFIGURATION')

    # start config file watcher (daemon: dies with the main process)
    if arg.config:
        t = threading.Thread(target=_parse_config_thread, args=(arg, arg.config))
        t.daemon = True
        t.start()

    return arg


================================================
FILE: torchtools/tt/layer.py
================================================
from torchtools import nn


#
# Reshape layer for Sequential or ModuleList
#
class Reshape(nn.Module):
    """Reshape layer usable inside Sequential/ModuleList: forwards the
    input through ``Tensor.reshape`` with a fixed target shape."""

    def __init__(self, *shape):
        super(Reshape, self).__init__()
        self.shape = shape  # target shape, captured as a tuple

    def forward(self, x):
        """Return ``x`` reshaped to the configured target shape."""
        return x.reshape(self.shape)

    def extra_repr(self):
        """Shown in the module repr, e.g. ``shape=(2, 3)``."""
        return 'shape={}'.format(self.shape)

================================================
FILE: torchtools/tt/logger.py
================================================
import datetime
import time
from tensorboardX import SummaryWriter
from torchtools import tt


__author__ = 'namju.kim@kakaobrain.com'


# lazily-created tensorboardX SummaryWriter (see _get_writer)
_writer = None
# pending stats keyed by tag; flushed to tensorboard by log_step()
_stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist = {}, {}, {}, {}, {}

# wall-clock time of the last log_step flush (throttles log_interval)
_last_logged = time.time()


# general print wrapper
def log(*args):
    """Print to stdout and, when ``tt.arg.log_file`` is set, append the
    same line to the log file under ``tt.arg.log_dir``."""
    print(*args, flush=True)
    if tt.arg.log_file:
        # mirror console output into the log file
        log_path = tt.arg.log_dir + tt.arg.log_file
        with open(log_path, 'a') as handle:
            print(*args, flush=True, file=handle)


# tensor board writer
def _get_writer():
    """Return the process-wide SummaryWriter, creating it on first use.

    The log directory is ``tt.arg.log_dir`` (plus the experiment name, if
    any) with a timestamp suffix so each run gets its own folder.
    """
    global _writer
    if _writer is not None:
        return _writer
    # build the per-run logging directory
    log_dir = tt.arg.log_dir
    if not log_dir.endswith('/'):
        log_dir += '/'
    if tt.arg.experiment:
        log_dir += tt.arg.experiment
    log_dir += datetime.datetime.now().strftime('-%Y%m%d-%H%M%S')
    _writer = SummaryWriter(log_dir)
    return _writer


def log_scalar(tag, value, global_step=None):
    # queue a scalar (converted via tt.nvar) for the next log_step flush
    _stats_scalar[tag] = (tt.nvar(value), global_step)


def log_audio(tag, audio, global_step=None):
    # queue an audio sample for the next log_step flush
    _stats_audio[tag] = (tt.nvar(audio), global_step)


def log_image(tag, image, global_step=None):
    # queue an image for the next log_step flush
    _stats_image[tag] = (tt.nvar(image), global_step)


def log_text(tag, text, global_step=None):
    # queue raw text (no conversion) for the next log_step flush
    _stats_text[tag] = (text, global_step)


def log_hist(tag, values, global_step=None):
    # queue histogram values for the next log_step flush
    _stats_hist[tag] = (tt.nvar(values), global_step)


def log_step(epoch=None, global_step=None, max_epoch=None, max_step=None):
    """Flush all queued stats to tensorboard and the console.

    A flush happens when no throttle is configured, when ``log_interval``
    seconds have elapsed since the last flush, or when ``global_step`` is
    a multiple of ``log_step``. Queued stats are cleared after flushing.
    """

    global _last_logged, _last_logged_step, _stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist

    # logging
    if (tt.arg.log_interval is None and tt.arg.log_step is None) or \
       (tt.arg.log_interval and time.time() - _last_logged >= tt.arg.log_interval) or \
       (tt.arg.log_step and global_step % tt.arg.log_step == 0):

        # update logging time stamp
        _last_logged = time.time()
        _last_logged_step = global_step

        # console output string: "ep: E/MAX step: S/MAX ..."
        console_out = ''
        if epoch:
            console_out += 'ep: %d' % epoch
            if max_epoch:
                console_out += '/%d' % max_epoch
        if global_step:
            if max_step:
                # report the within-epoch step (max_step instead of 0 at wrap)
                step = global_step % max_step
                step = max_step if step == 0 else step
                console_out += ' step: %d/%d' % (step, max_step)
            else:
                console_out += ' step: %d' % global_step

        # add stats to tensor board
        for k, v in _stats_scalar.items():
            _get_writer().add_scalar(k, *v)
            # add to console output (weight/gradient norms stay tensorboard-only)
            if not k.startswith('weight/') and not k.startswith('gradient/'):
                console_out += ' %s: %f' % (k, v[0])
        for k, v in _stats_image.items():
            _get_writer().add_image(k, *v)
        for k, v in _stats_audio.items():
            _get_writer().add_audio(k, *v)
        for k, v in _stats_text.items():
            _get_writer().add_text(k, *v)
        for k, v in _stats_hist.items():
            _get_writer().add_histogram(k, *v, 'auto')

        # flush
        _get_writer().file_writer.flush()

        # console out
        if len(console_out) > 0:
            log(console_out)

        # clear stats -- including histograms (bug fix: _stats_hist was never
        # reset, so stale histograms were re-written on every flush)
        _stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist = {}, {}, {}, {}, {}


def log_weight(model, global_step=None):
    """Queue the norm of every weight tensor (biases excluded) when
    ``tt.arg.log_weight`` is enabled."""
    if not tt.arg.log_weight:
        return
    for name, param in model.named_parameters():
        if 'weight' in name:  # only weights, not biases
            log_scalar('weight/' + name, param.norm(), global_step)


def log_gradient(model, global_step=None):
    """Queue the norm of each weight gradient (when present) when
    ``tt.arg.log_grad`` is enabled."""
    if not tt.arg.log_grad:
        return
    for name, param in model.named_parameters():
        # only weights, not biases; skip params with no gradient yet
        if 'weight' in name and param.grad is not None:
            log_scalar('gradient/' + name, param.grad.norm(), global_step)


================================================
FILE: torchtools/tt/stat.py
================================================
from torchtools import tt


__author__ = 'namju.kim@kakaobrain.com'


def accuracy(prob, label, ignore_index=-100):
    """Top-1 accuracy of class scores ``prob`` against integer ``label``.

    Positions whose label equals ``ignore_index`` are excluded from the
    count. Returns a python float in [0, 1]; returns 0.0 when every
    position is ignored (previously this raised ZeroDivisionError).
    """

    # argmax over the class dimension (dim=1)
    pred = prob.max(1)[1].type_as(label)

    # drop ignored positions
    mask = label.ne(ignore_index)
    pred = pred.masked_select(mask)
    label = label.masked_select(mask)

    # guard: nothing left to score after masking
    if label.size(0) == 0:
        return 0.0

    # calc accuracy as hits / total
    hit = tt.nvar(pred.eq(label).long().sum())
    acc = hit / label.size(0)
    return acc


================================================
FILE: torchtools/tt/trainer.py
================================================
from torchtools import nn, optim, tt


__author__ = 'namju.kim@kakaobrain.com'


class SupervisedTrainer(object):
    """Minimal supervised training loop driven by the global ``tt.arg``.

    Wires a model, a data loader, an optimizer (Adam by default) and a
    loss (CrossEntropy by default); supports resume via the patched
    ``model.load_model()`` / ``model.save_model()`` methods.
    """

    def __init__(self, model, data_loader, optimizer=None, criterion=None):
        self.global_step = 0
        self.model = model.to(tt.arg.device)
        self.data_loader = data_loader
        self.optimizer = optimizer or optim.Adam(model.parameters())
        self.criterion = criterion or nn.CrossEntropyLoss()

    def train(self, inputs):
        """Run one optimization step on a single (x, y) batch and queue
        loss/accuracy for logging."""

        # split inputs
        x, y = inputs

        # forward (wrapped in DataParallel on CUDA)
        if tt.arg.cuda:
            z = nn.DataParallel(self.model)(x)
        else:
            z = self.model(x)

        # loss
        loss = self.criterion(z, y)

        # accuracy
        acc = tt.accuracy(z, y)

        # update model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # logging
        tt.log_scalar('loss', loss, self.global_step)
        tt.log_scalar('acc', acc, self.global_step)

    def epoch(self, ep_no=None):
        """Hook invoked after every epoch; subclasses may override."""
        pass

    def run(self):
        """Main loop: resume from checkpoint, iterate epochs/steps with
        logging and periodic checkpointing, then save the final model."""

        # experiment name defaults to the lowercase model class name
        tt.arg.experiment = tt.arg.experiment or self.model.__class__.__name__.lower()

        # load model; derive the epoch/step position from the saved step
        self.global_step = self.model.load_model()
        epoch, min_step = divmod(self.global_step, len(self.data_loader))

        # epochs
        # NOTE(review): min_step is never reset after the resumed epoch, so
        # every later epoch starts counting at min_step+1 and breaks early
        # by min_step batches -- confirm this is intended.
        while epoch < (tt.arg.epoch or 1):
            epoch += 1

            # iterations (enumerate starts at min_step + 1 so a resumed
            # epoch reports the correct step number)
            for step, inputs in enumerate(self.data_loader, min_step + 1):

                # check step counter
                if step > len(self.data_loader):
                    break

                # increase global step count
                self.global_step += 1

                # update learning rate from the (possibly hot-reloaded) arg
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = tt.arg.lr

                # call train func with tensors moved to the active device
                if type(inputs) in [list, tuple]:
                    self.train([tt.var(d) for d in inputs])
                else:
                    self.train(tt.var(inputs))

                # logging
                tt.log_weight(self.model, global_step=self.global_step)
                tt.log_gradient(self.model, global_step=self.global_step)
                tt.log_step(epoch=epoch, global_step=self.global_step,
                            max_epoch=(tt.arg.epoch or 1), max_step=len(self.data_loader))

                # save model (throttled inside save_model)
                self.model.save_model(self.global_step)

            # epoch handler
            self.epoch(epoch)

        # save final model unconditionally
        self.model.save_model(self.global_step, force=True)


================================================
FILE: torchtools/tt/utils.py
================================================
import os
import datetime
import time
import pathlib
from torchtools import torch, nn, tt


__author__ = 'namju.kim@kakaobrain.com'


# time stamps for tic()/toc() and for save/archive throttling in save_model
_tic_start = _last_saved = _last_archived = time.time()
# best validation metric seen so far (shared by load_model/save_model)
_best = -100000000.


def tic():
    """Reset the global stopwatch and return the new start time."""
    global _tic_start
    now = time.time()
    _tic_start = now
    return now


def toc(tic=None):
    """Seconds elapsed since ``tic`` (or since the last global tic())."""
    global _tic_start
    start = _tic_start if tic is None else tic
    return time.time() - start


def sleep(seconds):
    """Thin wrapper over ``time.sleep`` for torchtools namespace convenience."""
    time.sleep(seconds)


#
# automatic device-aware torch.tensor
#
def var(data, dtype=None, device=None, requires_grad=False):
    """Device-aware ``torch.tensor``: build on CPU, then move to ``device``
    (default: ``tt.arg.device``).

    NOTE(review): kept as construct-then-``.to()`` because the original
    author reported that passing device= directly to torch.tensor
    misbehaved ("maybe bug").
    """
    target = device or tt.arg.device
    return torch.tensor(data, dtype=dtype, requires_grad=requires_grad).to(target)


def vars(x_list, dtype=None, device=None, requires_grad=False):
    """Apply ``var`` to every element of ``x_list`` and return the list."""
    return [var(item, dtype, device, requires_grad) for item in x_list]


# for old torchtools compatibility
def cvar(x):
    """Compatibility shim for old torchtools: detach a tensor from the
    autograd graph."""
    return x.detach()


#
# to python or numpy variable(s)
#
def nvar(x):
    """Convert a ``torch.Tensor`` into a python scalar (0-dim tensors) or
    a numpy array (everything else); non-tensors pass through unchanged."""
    if not isinstance(x, torch.Tensor):
        return x
    detached = x.detach().cpu()
    if detached.dim() == 0:
        return detached.item()
    return detached.numpy()


def nvars(x_list):
    """Apply ``nvar`` to every element of ``x_list`` and return the list."""
    return [nvar(item) for item in x_list]


def load_model(model, best=False, postfix=None, experiment=None):
    """Load model weights from ``tt.arg.save_dir`` and return the saved
    global step (0 when no checkpoint exists).

    With ``best=True`` the '<name>.pt.best' file is loaded instead of
    '<name>.pt'. Also refreshes the module-level ``_best`` metric from the
    '.best' file when present.
    """
    global _best

    # model file name: "<save_dir><experiment>.pt[.postfix]"
    filename = tt.arg.save_dir + '%s.pt' % (experiment or tt.arg.experiment or model.__class__.__name__.lower())
    if postfix is not None:
        filename = filename + '.%s' % postfix

    # load model
    # NOTE(review): the existence check is on the base file even when
    # best=True, so a lone '.best' file is never loaded -- confirm intended.
    global_step = 0
    if os.path.exists(filename):
        if best:
            # best checkpoints are saved as (step, state_dict, best_metric)
            global_step, model_state, _best = torch.load(filename + '.best', map_location=lambda storage, loc: storage)
        else:
            global_step, model_state = torch.load(filename, map_location=lambda storage, loc: storage)
        model.load_state_dict(model_state)

    # update best stat from the '.best' file if one exists
    filename += '.best'
    if os.path.exists(filename):
        _, _, _best = torch.load(filename, map_location=lambda storage, loc: storage)

    return global_step


def save_model(model, global_step, force=False, best=None, postfix=None):
    """Persist model weights under ``tt.arg.save_dir`` per the configured
    save/archive policy.

    - periodic save to '<name>.pt' (throttled by save_interval/save_step,
      or unconditional with ``force=True``)
    - periodic archive to a timestamped / step-numbered copy
    - best checkpoint to '<name>.pt.best' whenever ``best`` beats the
      best metric seen so far
    """
    global _last_saved, _last_archived, _best

    # make sure the target directory exists
    pathlib.Path(tt.arg.save_dir).mkdir(parents=True, exist_ok=True)

    # base checkpoint name, e.g. "experiment.pt[.postfix]"
    filename = '%s.pt' % (tt.arg.experiment or model.__class__.__name__.lower())
    if postfix is not None:
        filename = filename + '.%s' % postfix

    # periodic save
    if force or (tt.arg.save_interval and time.time() - _last_saved >= tt.arg.save_interval) or \
       (tt.arg.save_step and global_step % tt.arg.save_step == 0):
        torch.save((global_step, model.state_dict()), tt.arg.save_dir + filename)
        _last_saved = time.time()

    # periodic archive under a timestamped / step-numbered copy
    if (tt.arg.archive_interval and time.time() - _last_archived >= tt.arg.archive_interval) or \
       (tt.arg.archive_step and global_step % tt.arg.archive_step == 0):
        # derive the archive name from the base name instead of mutating it
        if tt.arg.archive_interval:
            archive_name = filename + datetime.datetime.now().strftime('.%Y%m%d.%H%M%S')
        else:
            archive_name = filename + '.%d' % global_step
        torch.save((global_step, model.state_dict()), tt.arg.save_dir + archive_name)
        _last_archived = time.time()

    # save best model, always at "<base>.pt.best" (bug fix: previously the
    # archive branch mutated `filename`, so a best checkpoint saved in the
    # same call landed at "<base>.pt.<stamp>.best" where load_model never
    # looks for it)
    if best is not None and best > _best:
        _best = best
        torch.save((global_step, model.state_dict(), best), tt.arg.save_dir + filename + '.best')


# patch Module
# monkeypatch: expose load_model/save_model as methods on every nn.Module
nn.Module.load_model = load_model
nn.Module.save_model = save_model


================================================
FILE: train.py
================================================
from torchtools import *
from data import MiniImagenetLoader, TieredImagenetLoader
from model import EmbeddingImagenet, GraphNetwork, ConvNet
import shutil
import os
import random
#import seaborn as sns


class ModelTrainer(object):
    """Episodic meta-trainer for EGNN few-shot classification.

    Holds the feature encoder and the edge-labeling graph network, runs
    episodic training (``train``) with periodic validation (``eval``),
    driven entirely by the global ``tt.arg`` option object. Supports both
    transductive and non-transductive settings and a semi-supervised mode
    where the last ``tt.arg.num_unlabeled`` shots per class are unlabeled.
    """
    def __init__(self,
                 enc_module,
                 gnn_module,
                 data_loader):
        """Move modules to the configured device, optionally wrap them in
        DataParallel, and set up the optimizer and losses."""
        # set encoder and gnn
        self.enc_module = enc_module.to(tt.arg.device)
        self.gnn_module = gnn_module.to(tt.arg.device)

        if tt.arg.num_gpus > 1:
            print('Construct multi-gpu model ...')
            # NOTE(review): device ids are hard-coded to 4 GPUs here,
            # independent of tt.arg.num_gpus -- confirm intended.
            self.enc_module = nn.DataParallel(self.enc_module, device_ids=[0, 1, 2, 3], dim=0)
            self.gnn_module = nn.DataParallel(self.gnn_module, device_ids=[0, 1, 2, 3], dim=0)

            print('done!\n')

        # get data loader
        self.data_loader = data_loader

        # set optimizer
        self.module_params = list(self.enc_module.parameters()) + list(self.gnn_module.parameters())

        # set optimizer
        self.optimizer = optim.Adam(params=self.module_params,
                                    lr=tt.arg.lr,
                                    weight_decay=tt.arg.weight_decay)

        # set loss
        self.edge_loss = nn.BCELoss(reduction='none')

        self.node_loss = nn.CrossEntropyLoss(reduction='none')

        self.global_step = 0
        self.val_acc = 0
        self.test_acc = 0

    def train(self):
        """Run the episodic training loop for tt.arg.train_iteration steps.

        Each iteration samples a batch of tasks, encodes all samples,
        predicts edge logits with the GNN (full graph if transductive,
        otherwise one query at a time against the support set), and
        optimizes a pos/neg-balanced BCE edge loss over query edges.
        """
        val_acc = self.val_acc

        # set edge mask (to distinguish support and query edges)
        num_supports = tt.arg.num_ways_train * tt.arg.num_shots_train
        num_queries = tt.arg.num_ways_train * 1
        num_samples = num_supports + num_queries
        support_edge_mask = torch.zeros(tt.arg.meta_batch_size, num_samples, num_samples).to(tt.arg.device)
        support_edge_mask[:, :num_supports, :num_supports] = 1
        query_edge_mask = 1 - support_edge_mask

        evaluation_mask = torch.ones(tt.arg.meta_batch_size, num_samples, num_samples).to(tt.arg.device)
        # for semi-supervised setting, ignore unlabeled support sets for evaluation
        # (the last num_unlabeled shots of each class are treated as unlabeled)
        for c in range(tt.arg.num_ways_train):
            evaluation_mask[:,
            ((c + 1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_train,
            :num_supports] = 0
            evaluation_mask[:, :num_supports,
            ((c + 1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_train] = 0

        # for each iteration
        for iter in range(self.global_step + 1, tt.arg.train_iteration + 1):
            # init grad
            self.optimizer.zero_grad()

            # set current step
            self.global_step = iter

            # load task data list (seeded per-iteration for reproducibility)
            [support_data,
             support_label,
             query_data,
             query_label] = self.data_loader['train'].get_task_batch(num_tasks=tt.arg.meta_batch_size,
                                                                     num_ways=tt.arg.num_ways_train,
                                                                     num_shots=tt.arg.num_shots_train,
                                                                     seed=iter + tt.arg.seed)

            # set as single data
            full_data = torch.cat([support_data, query_data], 1)
            full_label = torch.cat([support_label, query_label], 1)
            full_edge = self.label2edge(full_label)

            # set init edge: query-related edges start at 0.5 (unknown);
            # query self-loops are forced to "same class" (1, 0)
            init_edge = full_edge.clone()  # batch_size x 2 x num_samples x num_samples
            init_edge[:, :, num_supports:, :] = 0.5
            init_edge[:, :, :, num_supports:] = 0.5
            for i in range(num_queries):
                init_edge[:, 0, num_supports + i, num_supports + i] = 1.0
                init_edge[:, 1, num_supports + i, num_supports + i] = 0.0

            # for semi-supervised setting, hide edges of unlabeled supports
            for c in range(tt.arg.num_ways_train):
                init_edge[:, :, ((c+1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_train, :num_supports] = 0.5
                init_edge[:, :, :num_supports, ((c+1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_train] = 0.5

            # set as train mode
            self.enc_module.train()
            self.gnn_module.train()

            # (1) encode data
            full_data = [self.enc_module(data.squeeze(1)) for data in full_data.chunk(full_data.size(1), dim=1)]
            full_data = torch.stack(full_data, dim=1) # batch_size x num_samples x featdim

            # (2) predict edge logit (consider only the last layer logit, num_tasks x 2 x num_samples x num_samples)
            if tt.arg.train_transductive:
                full_logit_layers = self.gnn_module(node_feat=full_data, edge_feat=init_edge)
            else:
                evaluation_mask[:, num_supports:, num_supports:] = 0 # ignore query-query edges, since it is non-transductive setting
                # input_node_feat: (batch_size x num_queries) x (num_support + 1) x featdim
                # input_edge_feat: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
                support_data = full_data[:, :num_supports] # batch_size x num_support x featdim
                query_data = full_data[:, num_supports:] # batch_size x num_query x featdim
                support_data_tiled = support_data.unsqueeze(1).repeat(1, num_queries, 1, 1) # batch_size x num_queries x num_support x featdim
                support_data_tiled = support_data_tiled.view(tt.arg.meta_batch_size * num_queries, num_supports, -1) # (batch_size x num_queries) x num_support x featdim
                query_data_reshaped = query_data.contiguous().view(tt.arg.meta_batch_size * num_queries, -1).unsqueeze(1) # (batch_size x num_queries) x 1 x featdim
                input_node_feat = torch.cat([support_data_tiled, query_data_reshaped], 1) # (batch_size x num_queries) x (num_support + 1) x featdim

                input_edge_feat = 0.5 * torch.ones(tt.arg.meta_batch_size, 2, num_supports + 1, num_supports + 1).to(tt.arg.device) # batch_size x 2 x (num_support + 1) x (num_support + 1)

                input_edge_feat[:, :, :num_supports, :num_supports] = init_edge[:, :, :num_supports, :num_supports] # batch_size x 2 x (num_support + 1) x (num_support + 1)
                input_edge_feat = input_edge_feat.repeat(num_queries, 1, 1, 1) #(batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)

                # logit: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
                logit_layers = self.gnn_module(node_feat=input_node_feat, edge_feat=input_edge_feat)

                logit_layers = [logit_layer.view(tt.arg.meta_batch_size, num_queries, 2, num_supports + 1, num_supports + 1) for logit_layer in logit_layers]

                # logit --> full_logit (batch_size x 2 x num_samples x num_samples)
                full_logit_layers = []
                for l in range(tt.arg.num_layers):
                    full_logit_layers.append(torch.zeros(tt.arg.meta_batch_size, 2, num_samples, num_samples).to(tt.arg.device))

                # scatter per-query logits back into a full graph: support-support
                # averaged over queries, support-query / query-support from the
                # last (query) row/column of each per-query graph
                for l in range(tt.arg.num_layers):
                    full_logit_layers[l][:, :, :num_supports, :num_supports] = logit_layers[l][:, :, :, :num_supports, :num_supports].mean(1)
                    full_logit_layers[l][:, :, :num_supports, num_supports:] = logit_layers[l][:, :, :, :num_supports, -1].transpose(1, 2).transpose(2, 3)
                    full_logit_layers[l][:, :, num_supports:, :num_supports] = logit_layers[l][:, :, :, -1, :num_supports].transpose(1, 2)

            # (4) compute loss
            full_edge_loss_layers = [self.edge_loss((1-full_logit_layer[:, 0]), (1-full_edge[:, 0])) for full_logit_layer in full_logit_layers]

            # weighted edge loss for balancing pos/neg
            pos_query_edge_loss_layers = [torch.sum(full_edge_loss_layer * query_edge_mask * full_edge[:, 0] * evaluation_mask) / torch.sum(query_edge_mask * full_edge[:, 0] * evaluation_mask) for full_edge_loss_layer in full_edge_loss_layers]
            neg_query_edge_loss_layers = [torch.sum(full_edge_loss_layer * query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask) / torch.sum(query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask) for full_edge_loss_layer in full_edge_loss_layers]
            query_edge_loss_layers = [pos_query_edge_loss_layer + neg_query_edge_loss_layer for (pos_query_edge_loss_layer, neg_query_edge_loss_layer) in zip(pos_query_edge_loss_layers, neg_query_edge_loss_layers)]

            # compute accuracy
            full_edge_accr_layers = [self.hit(full_logit_layer, 1-full_edge[:, 0].long()) for full_logit_layer in full_logit_layers]
            query_edge_accr_layers = [torch.sum(full_edge_accr_layer * query_edge_mask * evaluation_mask) / torch.sum(query_edge_mask * evaluation_mask) for full_edge_accr_layer in full_edge_accr_layers]

            # compute node loss & accuracy (num_tasks x num_quries x num_ways)
            query_node_pred_layers = [torch.bmm(full_logit_layer[:, 0, num_supports:, :num_supports], self.one_hot_encode(tt.arg.num_ways_train, support_label.long())) for full_logit_layer in full_logit_layers] # (num_tasks x num_quries x num_supports) * (num_tasks x num_supports x num_ways)
            query_node_accr_layers = [torch.eq(torch.max(query_node_pred_layer, -1)[1], query_label.long()).float().mean() for query_node_pred_layer in query_node_pred_layers]

            total_loss_layers = query_edge_loss_layers

            # update model: intermediate layers weighted 0.5, last layer 1.0
            total_loss = []
            for l in range(tt.arg.num_layers - 1):
                total_loss += [total_loss_layers[l].view(-1) * 0.5]
            total_loss += [total_loss_layers[-1].view(-1) * 1.0]
            total_loss = torch.mean(torch.cat(total_loss, 0))

            total_loss.backward()

            self.optimizer.step()

            # adjust learning rate (step decay, see adjust_learning_rate)
            self.adjust_learning_rate(optimizers=[self.optimizer],
                                      lr=tt.arg.lr,
                                      iter=self.global_step)

            # logging
            tt.log_scalar('train/edge_loss', query_edge_loss_layers[-1], self.global_step)
            tt.log_scalar('train/edge_accr', query_edge_accr_layers[-1], self.global_step)
            tt.log_scalar('train/node_accr', query_node_accr_layers[-1], self.global_step)

            # evaluation: validate and checkpoint every test_interval steps
            if self.global_step % tt.arg.test_interval == 0:
                val_acc = self.eval(partition='val')

                is_best = 0

                if val_acc >= self.val_acc:
                    self.val_acc = val_acc
                    is_best = 1

                tt.log_scalar('val/best_accr', self.val_acc, self.global_step)

                self.save_checkpoint({
                    'iteration': self.global_step,
                    'enc_module_state_dict': self.enc_module.state_dict(),
                    'gnn_module_state_dict': self.gnn_module.state_dict(),
                    'val_acc': val_acc,
                    'optimizer': self.optimizer.state_dict(),
                    }, is_best)

            tt.log_step(global_step=self.global_step)

    def eval(self, partition='test', log_flag=True):
        """Evaluate on ``partition`` tasks and return the mean query node
        accuracy. Mirrors the train() graph construction with the test-time
        hyperparameters; optionally logs loss/accuracy statistics."""
        best_acc = 0
        # set edge mask (to distinguish support and query edges)
        num_supports = tt.arg.num_ways_test * tt.arg.num_shots_test
        num_queries = tt.arg.num_ways_test * 1
        num_samples = num_supports + num_queries
        support_edge_mask = torch.zeros(tt.arg.test_batch_size, num_samples, num_samples).to(tt.arg.device)
        support_edge_mask[:, :num_supports, :num_supports] = 1
        query_edge_mask = 1 - support_edge_mask
        evaluation_mask = torch.ones(tt.arg.test_batch_size, num_samples, num_samples).to(tt.arg.device)
        # for semi-supervised setting, ignore unlabeled support sets for evaluation
        for c in range(tt.arg.num_ways_test):
            evaluation_mask[:,
            ((c + 1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_test,
            :num_supports] = 0
            evaluation_mask[:, :num_supports,
            ((c + 1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_test] = 0

        query_edge_losses = []
        query_edge_accrs = []
        query_node_accrs = []

        # for each iteration (fixed seed=iter makes evaluation deterministic)
        for iter in range(tt.arg.test_iteration//tt.arg.test_batch_size):
            # load task data list
            [support_data,
             support_label,
             query_data,
             query_label] = self.data_loader[partition].get_task_batch(num_tasks=tt.arg.test_batch_size,
                                                                       num_ways=tt.arg.num_ways_test,
                                                                       num_shots=tt.arg.num_shots_test,
                                                                       seed=iter)

            # set as single data
            full_data = torch.cat([support_data, query_data], 1)
            full_label = torch.cat([support_label, query_label], 1)
            full_edge = self.label2edge(full_label)

            # set init edge
            init_edge = full_edge.clone()
            init_edge[:, :, num_supports:, :] = 0.5
            init_edge[:, :, :, num_supports:] = 0.5
            for i in range(num_queries):
                init_edge[:, 0, num_supports + i, num_supports + i] = 1.0
                init_edge[:, 1, num_supports + i, num_supports + i] = 0.0

            # for semi-supervised setting,
            for c in range(tt.arg.num_ways_test):
                init_edge[:, :, ((c+1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_test, :num_supports] = 0.5
                init_edge[:, :, :num_supports, ((c+1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_test] = 0.5

            # set as eval mode
            self.enc_module.eval()
            self.gnn_module.eval()

            # (1) encode data
            full_data = [self.enc_module(data.squeeze(1)) for data in full_data.chunk(full_data.size(1), dim=1)]
            full_data = torch.stack(full_data, dim=1)

            # (2) predict edge logit (consider only the last layer logit, num_tasks x 2 x num_samples x num_samples)
            if tt.arg.test_transductive:
                full_logit_all = self.gnn_module(node_feat=full_data, edge_feat=init_edge)
                full_logit = full_logit_all[-1]
            else:
                evaluation_mask[:, num_supports:, num_supports:] = 0  # ignore query-query edges, since it is non-transductive setting

                full_logit = torch.zeros(tt.arg.test_batch_size, 2, num_samples, num_samples).to(tt.arg.device)

                # input_node_feat: (batch_size x num_queries) x (num_support + 1) x featdim
                # input_edge_feat: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
                support_data = full_data[:, :num_supports] # batch_size x num_support x featdim
                query_data = full_data[:, num_supports:] # batch_size x num_query x featdim
                support_data_tiled = support_data.unsqueeze(1).repeat(1, num_queries, 1, 1) # batch_size x num_queries x num_support x featdim
                support_data_tiled = support_data_tiled.view(tt.arg.test_batch_size * num_queries, num_supports, -1) # (batch_size x num_queries) x num_support x featdim
                query_data_reshaped = query_data.contiguous().view(tt.arg.test_batch_size * num_queries, -1).unsqueeze(1) # (batch_size x num_queries) x 1 x featdim
                input_node_feat = torch.cat([support_data_tiled, query_data_reshaped], 1) # (batch_size x num_queries) x (num_support + 1) x featdim

                input_edge_feat = 0.5 * torch.ones(tt.arg.test_batch_size, 2, num_supports + 1, num_supports + 1).to(tt.arg.device)  # batch_size x 2 x (num_support + 1) x (num_support + 1)

                input_edge_feat[:, :, :num_supports, :num_supports] = init_edge[:, :, :num_supports, :num_supports]  # batch_size x 2 x (num_support + 1) x (num_support + 1)
                input_edge_feat = input_edge_feat.repeat(num_queries, 1, 1, 1)  # (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)

                # logit: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
                logit = self.gnn_module(node_feat=input_node_feat, edge_feat=input_edge_feat)[-1]

                logit = logit.view(tt.arg.test_batch_size, num_queries, 2, num_supports + 1, num_supports + 1)

                # batch_size x num_queries x 2 x (num_support + 1) x (num_support + 1)
                # logit --> full_logit (batch_size x 2 x num_samples x num_samples)
                full_logit[:, :, :num_supports, :num_supports] = logit[:, :, :, :num_supports, :num_supports].mean(1)
                full_logit[:, :, :num_supports, num_supports:] = logit[:, :, :, :num_supports, -1].transpose(1, 2).transpose(2, 3)
                full_logit[:, :, num_supports:, :num_supports] = logit[:, :, :, -1, :num_supports].transpose(1, 2)

            # (4) compute loss
            full_edge_loss = self.edge_loss(1-full_logit[:, 0], 1-full_edge[:, 0])

            query_edge_loss =  torch.sum(full_edge_loss * query_edge_mask * evaluation_mask) / torch.sum(query_edge_mask * evaluation_mask)

            # weighted loss for balancing pos/neg (overwrites the unweighted value above)
            pos_query_edge_loss = torch.sum(full_edge_loss * query_edge_mask * full_edge[:, 0] * evaluation_mask) / torch.sum(query_edge_mask * full_edge[:, 0] * evaluation_mask)
            neg_query_edge_loss = torch.sum(full_edge_loss * query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask) / torch.sum(query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask)
            query_edge_loss = pos_query_edge_loss + neg_query_edge_loss

            # compute accuracy
            full_edge_accr = self.hit(full_logit, 1-full_edge[:, 0].long())
            query_edge_accr = torch.sum(full_edge_accr * query_edge_mask * evaluation_mask) / torch.sum(query_edge_mask * evaluation_mask)

            # compute node accuracy (num_tasks x num_quries x num_ways)
            query_node_pred = torch.bmm(full_logit[:, 0, num_supports:, :num_supports], self.one_hot_encode(tt.arg.num_ways_test, support_label.long())) # (num_tasks x num_quries x num_supports) * (num_tasks x num_supports x num_ways)
            query_node_accr = torch.eq(torch.max(query_node_pred, -1)[1], query_label.long()).float().mean()

            query_edge_losses += [query_edge_loss.item()]
            query_edge_accrs += [query_edge_accr.item()]
            query_node_accrs += [query_node_accr.item()]

        # logging
        if log_flag:
            tt.log('---------------------------')
            tt.log_scalar('{}/edge_loss'.format(partition), np.array(query_edge_losses).mean(), self.global_step)
            tt.log_scalar('{}/edge_accr'.format(partition), np.array(query_edge_accrs).mean(), self.global_step)
            tt.log_scalar('{}/node_accr'.format(partition), np.array(query_node_accrs).mean(), self.global_step)

            tt.log('evaluation: total_count=%d, accuracy: mean=%.2f%%, std=%.2f%%, ci95=%.2f%%' %
                   (iter,
                    np.array(query_node_accrs).mean() * 100,
                    np.array(query_node_accrs).std() * 100,
                    1.96 * np.array(query_node_accrs).std() / np.sqrt(float(len(np.array(query_node_accrs)))) * 100))
            tt.log('---------------------------')

        return np.array(query_node_accrs).mean()

    def adjust_learning_rate(self, optimizers, lr, iter):
        """Step-decay the learning rate: halve it every tt.arg.dec_lr iterations."""
        new_lr = lr * (0.5 ** (int(iter / tt.arg.dec_lr)))

        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr

    def label2edge(self, label):
        """Convert labels (batch x num_samples) into a 2-channel edge map:
        channel 0 = same-class indicator, channel 1 = its complement."""
        # get size
        num_samples = label.size(1)

        # reshape into pairwise comparison grids
        label_i = label.unsqueeze(-1).repeat(1, 1, num_samples)
        label_j = label_i.transpose(1, 2)

        # compute edge: 1 where labels match
        edge = torch.eq(label_i, label_j).float().to(tt.arg.device)

        # expand to 2 channels (same / different)
        edge = edge.unsqueeze(1)
        edge = torch.cat([edge, 1 - edge], 1)
        return edge

    def hit(self, logit, label):
        """Element-wise 0/1 correctness of argmax(logit, dim=1) vs label."""
        pred = logit.max(1)[1]
        hit = torch.eq(pred, label).float()
        return hit

    def one_hot_encode(self, num_classes, class_idx):
        """One-hot encode ``class_idx`` into ``num_classes`` columns on the
        configured device."""
        return torch.eye(num_classes)[class_idx].to(tt.arg.device)

    def save_checkpoint(self, state, is_best):
        """Write the checkpoint dict, and copy it to model_best when best.

        NOTE(review): assumes asset/checkpoints/<experiment>/ already
        exists; torch.save does not create directories -- confirm the
        caller creates it.
        """
        torch.save(state, 'asset/checkpoints/{}/'.format(tt.arg.experiment) + 'checkpoint.pth.tar')
        if is_best:
            shutil.copyfile('asset/checkpoints/{}/'.format(tt.arg.experiment) + 'checkpoint.pth.tar',
                            'asset/checkpoints/{}/'.format(tt.arg.experiment) + 'model_best.pth.tar')

def set_exp_name():
    """Compose the experiment name string from the active tt.arg settings."""
    return 'D-{}_N-{}_K-{}_U-{}_L-{}_B-{}_T-{}_SEED-{}'.format(
        tt.arg.dataset,
        tt.arg.num_ways,
        tt.arg.num_shots,
        tt.arg.num_unlabeled,
        tt.arg.num_layers,
        tt.arg.meta_batch_size,
        tt.arg.transductive,
        tt.arg.seed)

if __name__ == '__main__':

    # ---- defaults (only applied when the option was not given on the CLI) ----
    tt.arg.device = 'cuda:0' if tt.arg.device is None else tt.arg.device
    # replace dataset_root with your own
    tt.arg.dataset_root = '/data/private/dataset'
    tt.arg.dataset = 'mini' if tt.arg.dataset is None else tt.arg.dataset
    tt.arg.num_ways = 5 if tt.arg.num_ways is None else tt.arg.num_ways
    tt.arg.num_shots = 1 if tt.arg.num_shots is None else tt.arg.num_shots
    tt.arg.num_unlabeled = 0 if tt.arg.num_unlabeled is None else tt.arg.num_unlabeled
    tt.arg.num_layers = 3 if tt.arg.num_layers is None else tt.arg.num_layers
    tt.arg.meta_batch_size = 40 if tt.arg.meta_batch_size is None else tt.arg.meta_batch_size
    tt.arg.transductive = False if tt.arg.transductive is None else tt.arg.transductive
    tt.arg.seed = 222 if tt.arg.seed is None else tt.arg.seed
    tt.arg.num_gpus = 1 if tt.arg.num_gpus is None else tt.arg.num_gpus

    # train/test share the same episode configuration
    tt.arg.num_ways_train = tt.arg.num_ways
    tt.arg.num_ways_test = tt.arg.num_ways

    tt.arg.num_shots_train = tt.arg.num_shots
    tt.arg.num_shots_test = tt.arg.num_shots

    tt.arg.train_transductive = tt.arg.transductive
    tt.arg.test_transductive = tt.arg.transductive

    # model parameter related
    tt.arg.num_edge_features = 96
    tt.arg.num_node_features = 96
    tt.arg.emb_size = 128

    # train, test parameters
    tt.arg.train_iteration = 100000 if tt.arg.dataset == 'mini' else 200000
    tt.arg.test_iteration = 10000
    tt.arg.test_interval = 5000 if tt.arg.test_interval is None else tt.arg.test_interval
    tt.arg.test_batch_size = 10
    tt.arg.log_step = 1000 if tt.arg.log_step is None else tt.arg.log_step

    tt.arg.lr = 1e-3
    tt.arg.grad_clip = 5
    tt.arg.weight_decay = 1e-6
    tt.arg.dec_lr = 15000 if tt.arg.dataset == 'mini' else 30000
    tt.arg.dropout = 0.1 if tt.arg.dataset == 'mini' else 0.0

    tt.arg.experiment = set_exp_name() if tt.arg.experiment is None else tt.arg.experiment

    print(set_exp_name())

    # set random seed for full reproducibility (numpy, torch CPU/GPU, python)
    np.random.seed(tt.arg.seed)
    torch.manual_seed(tt.arg.seed)
    torch.cuda.manual_seed_all(tt.arg.seed)
    random.seed(tt.arg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    tt.arg.log_dir_user = tt.arg.log_dir if tt.arg.log_dir_user is None else tt.arg.log_dir_user
    tt.arg.log_dir = tt.arg.log_dir_user

    # ensure the checkpoint directory exists (exist_ok avoids the TOCTOU
    # race of the previous exists()-then-makedirs pattern)
    os.makedirs('asset/checkpoints/' + tt.arg.experiment, exist_ok=True)

    enc_module = EmbeddingImagenet(emb_size=tt.arg.emb_size)

    # BUGFIX: node_features/edge_features kwargs were previously swapped
    # (node_features got num_edge_features and vice versa). Harmless only
    # because both are 96; corrected to match the parameter names.
    gnn_module = GraphNetwork(in_features=tt.arg.emb_size,
                              node_features=tt.arg.num_node_features,
                              edge_features=tt.arg.num_edge_features,
                              num_layers=tt.arg.num_layers,
                              dropout=tt.arg.dropout)

    if tt.arg.dataset == 'mini':
        train_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='train')
        valid_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='val')
    elif tt.arg.dataset == 'tiered':
        train_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='train')
        valid_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='val')
    else:
        # fail fast: the old code only printed and then crashed later with
        # a NameError on train_loader
        raise ValueError('Unknown dataset: {!r}'.format(tt.arg.dataset))

    data_loader = {'train': train_loader,
                   'val': valid_loader
                   }

    # create trainer
    trainer = ModelTrainer(enc_module=enc_module,
                           gnn_module=gnn_module,
                           data_loader=data_loader)

    trainer.train()
Download .txt
gitextract__r5a0unu/

├── .idea/
│   ├── egnn_distribute.iml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── README.md
├── __init__.py
├── _version.py
├── data.py
├── eval.py
├── model.py
├── torchtools/
│   ├── __init__.py
│   ├── _version.py
│   └── tt/
│       ├── __init__.py
│       ├── arg.py
│       ├── layer.py
│       ├── logger.py
│       ├── stat.py
│       ├── trainer.py
│       └── utils.py
└── train.py
Download .txt
SYMBOL INDEX (79 symbols across 9 files)

FILE: data.py
  class MiniImagenetLoader (line 13) | class MiniImagenetLoader(data.Dataset):
    method __init__ (line 14) | def __init__(self, root, partition='train'):
    method load_dataset (line 40) | def load_dataset(self):
    method get_task_batch (line 61) | def get_task_batch(self,
  class TieredImagenetLoader (line 123) | class TieredImagenetLoader(data.Dataset):
    method __init__ (line 124) | def __init__(self, root, partition='train'):
    method get_image_paths (line 140) | def get_image_paths(self, file):
    method preprocess (line 152) | def preprocess(self):
    method load_dataset (line 245) | def load_dataset(self):
    method chunks (line 284) | def chunks(self, data, size=10000):
    method get_task_batch (line 289) | def get_task_batch(self,

FILE: model.py
  class ConvBlock (line 9) | class ConvBlock(nn.Module):
    method __init__ (line 10) | def __init__(self, in_planes, out_planes, userelu=True, momentum=0.1, ...
    method forward (line 27) | def forward(self, x):
  class ConvNet (line 31) | class ConvNet(nn.Module):
    method __init__ (line 32) | def __init__(self, opt, momentum=0.1, affine=True, track_running_stats...
    method forward (line 62) | def forward(self, x):
  class EmbeddingImagenet (line 70) | class EmbeddingImagenet(nn.Module):
    method __init__ (line 71) | def __init__(self,
    method forward (line 117) | def forward(self, input_data):
  class NodeUpdateNetwork (line 124) | class NodeUpdateNetwork(nn.Module):
    method __init__ (line 125) | def __init__(self,
    method forward (line 154) | def forward(self, node_feat, edge_feat):
  class EdgeUpdateNetwork (line 175) | class EdgeUpdateNetwork(nn.Module):
    method __init__ (line 176) | def __init__(self,
    method forward (line 230) | def forward(self, node_feat, edge_feat):
  class GraphNetwork (line 259) | class GraphNetwork(nn.Module):
    method __init__ (line 260) | def __init__(self,
    method forward (line 291) | def forward(self, node_feat, edge_feat):

FILE: torchtools/tt/arg.py
  class _Opt (line 15) | class _Opt(object):
    method __len__ (line 17) | def __len__(self):
    method __setitem__ (line 20) | def __setitem__(self, key, value):
    method __getitem__ (line 23) | def __getitem__(self, item):
    method __getattr__ (line 29) | def __getattr__(self, item):
  function _to_py_obj (line 33) | def _to_py_obj(x):
  function _parse_config (line 49) | def _parse_config(arg, file):
  function _parse_config_thread (line 68) | def _parse_config_thread(arg, file):
  function _print_opts (line 86) | def _print_opts(arg, header):
  function _parse_opts (line 94) | def _parse_opts():

FILE: torchtools/tt/layer.py
  class Reshape (line 7) | class Reshape(nn.Module):
    method __init__ (line 9) | def __init__(self, *shape):
    method forward (line 13) | def forward(self, x):
    method extra_repr (line 16) | def extra_repr(self):

FILE: torchtools/tt/logger.py
  function log (line 19) | def log(*args):
  function _get_writer (line 28) | def _get_writer():
  function log_scalar (line 42) | def log_scalar(tag, value, global_step=None):
  function log_audio (line 46) | def log_audio(tag, audio, global_step=None):
  function log_image (line 50) | def log_image(tag, image, global_step=None):
  function log_text (line 54) | def log_text(tag, text, global_step=None):
  function log_hist (line 58) | def log_hist(tag, values, global_step=None):
  function log_step (line 62) | def log_step(epoch=None, global_step=None, max_epoch=None, max_step=None):
  function log_weight (line 115) | def log_weight(model, global_step=None):
  function log_gradient (line 123) | def log_gradient(model, global_step=None):

FILE: torchtools/tt/stat.py
  function accuracy (line 7) | def accuracy(prob, label, ignore_index=-100):

FILE: torchtools/tt/trainer.py
  class SupervisedTrainer (line 7) | class SupervisedTrainer(object):
    method __init__ (line 9) | def __init__(self, model, data_loader, optimizer=None, criterion=None):
    method train (line 16) | def train(self, inputs):
    method epoch (line 42) | def epoch(self, ep_no=None):
    method run (line 45) | def run(self):

FILE: torchtools/tt/utils.py
  function tic (line 17) | def tic():
  function toc (line 23) | def toc(tic=None):
  function sleep (line 31) | def sleep(seconds):
  function var (line 38) | def var(data, dtype=None, device=None, requires_grad=False):
  function vars (line 44) | def vars(x_list, dtype=None, device=None, requires_grad=False):
  function cvar (line 49) | def cvar(x):
  function nvar (line 56) | def nvar(x):
  function nvars (line 63) | def nvars(x_list):
  function load_model (line 67) | def load_model(model, best=False, postfix=None, experiment=None):
  function save_model (line 92) | def save_model(model, global_step, force=False, best=None, postfix=None):

FILE: train.py
  class ModelTrainer (line 10) | class ModelTrainer(object):
    method __init__ (line 11) | def __init__(self,
    method train (line 46) | def train(self):
    method eval (line 204) | def eval(self, partition='test', log_flag=True):
    method adjust_learning_rate (line 335) | def adjust_learning_rate(self, optimizers, lr, iter):
    method label2edge (line 342) | def label2edge(self, label):
    method hit (line 358) | def hit(self, logit, label):
    method one_hot_encode (line 363) | def one_hot_encode(self, num_classes, class_idx):
    method save_checkpoint (line 366) | def save_checkpoint(self, state, is_best):
  function set_exp_name (line 372) | def set_exp_name():
Condensed preview — 21 files, each showing its path, character count, and a content snippet. Download the .json file or copy it to obtain the full structured content (101K chars).
[
  {
    "path": ".idea/egnn_distribute.iml",
    "chars": 512,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n  <component name=\"NewModuleRootManager"
  },
  {
    "path": ".idea/modules.xml",
    "chars": 282,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectModuleManager\">\n    <modules>\n   "
  },
  {
    "path": ".idea/vcs.xml",
    "chars": 180,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"VcsDirectoryMappings\">\n    <mapping dire"
  },
  {
    "path": ".idea/workspace.xml",
    "chars": 11692,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ChangeListManager\">\n    <list default=\"t"
  },
  {
    "path": "LICENSE",
    "chars": 1068,
    "preview": "MIT License\n\nCopyright (c) 2019 Jongmin Kim\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
  },
  {
    "path": "README.md",
    "chars": 8545,
    "preview": "# fewshot-egnn\n\n### Introduction\n\nThe current project page provides pytorch code that implements the following CVPR2019 "
  },
  {
    "path": "__init__.py",
    "chars": 397,
    "preview": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch import optim\nfrom torch import cuda\nfrom torch import ut"
  },
  {
    "path": "_version.py",
    "chars": 53,
    "preview": "__version__ = '0.4.0'  # align version with pytorch\n\n"
  },
  {
    "path": "data.py",
    "chars": 14914,
    "preview": "from __future__ import print_function\nfrom torchtools import *\nimport torch.utils.data as data\nimport random\nimport os\ni"
  },
  {
    "path": "eval.py",
    "chars": 5038,
    "preview": "from torchtools import *\nfrom data import MiniImagenetLoader, TieredImagenetLoader\nfrom model import EmbeddingImagenet, "
  },
  {
    "path": "model.py",
    "chars": 13890,
    "preview": "from torchtools import *\nfrom collections import OrderedDict\nimport math\n#import seaborn as sns\nimport numpy as np\nimpor"
  },
  {
    "path": "torchtools/__init__.py",
    "chars": 397,
    "preview": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch import optim\nfrom torch import cuda\nfrom torch import ut"
  },
  {
    "path": "torchtools/_version.py",
    "chars": 53,
    "preview": "__version__ = '0.4.0'  # align version with pytorch\n\n"
  },
  {
    "path": "torchtools/tt/__init__.py",
    "chars": 310,
    "preview": "from torchtools.tt.arg import _parse_opts\nfrom torchtools.tt.utils import *\nfrom torchtools.tt.layer import *\nfrom torch"
  },
  {
    "path": "torchtools/tt/arg.py",
    "chars": 3735,
    "preview": "import sys\nimport configparser\nimport torch\nimport threading\nimport time\nimport os\n\n\n__author__ = 'namju.kim@kakaobrain."
  },
  {
    "path": "torchtools/tt/layer.py",
    "chars": 338,
    "preview": "from torchtools import nn\n\n\n#\n# Reshape layer for Sequential or ModuleList\n#\nclass Reshape(nn.Module):\n\n    def __init__"
  },
  {
    "path": "torchtools/tt/logger.py",
    "chars": 3951,
    "preview": "import datetime\nimport time\nfrom tensorboardX import SummaryWriter\nfrom torchtools import tt\n\n\n__author__ = 'namju.kim@k"
  },
  {
    "path": "torchtools/tt/stat.py",
    "chars": 407,
    "preview": "from torchtools import tt\n\n\n__author__ = 'namju.kim@kakaobrain.com'\n\n\ndef accuracy(prob, label, ignore_index=-100):\n\n   "
  },
  {
    "path": "torchtools/tt/trainer.py",
    "chars": 2631,
    "preview": "from torchtools import nn, optim, tt\n\n\n__author__ = 'namju.kim@kakaobrain.com'\n\n\nclass SupervisedTrainer(object):\n\n    d"
  },
  {
    "path": "torchtools/tt/utils.py",
    "chars": 3725,
    "preview": "import os\nimport datetime\nimport time\nimport pathlib\nfrom torchtools import torch, nn, tt\n\n\n__author__ = 'namju.kim@kaka"
  },
  {
    "path": "train.py",
    "chars": 25078,
    "preview": "from torchtools import *\nfrom data import MiniImagenetLoader, TieredImagenetLoader\nfrom model import EmbeddingImagenet, "
  }
]

About this extraction

This page contains the full source code of the khy0809/fewshot-egnn GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 21 files (94.9 KB), approximately 24.2k tokens, and a symbol index with 79 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!