[
  {
    "path": ".idea/egnn_distribute.iml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n  <component name=\"NewModuleRootManager\">\n    <content url=\"file://$MODULE_DIR$\" />\n    <orderEntry type=\"jdk\" jdkName=\"Remote Python 3.6.8 (sftp://root@instance.cloud.kakaobrain.com:11255/opt/conda/bin/python3)\" jdkType=\"Python SDK\" />\n    <orderEntry type=\"sourceFolder\" forTests=\"false\" />\n  </component>\n  <component name=\"TestRunnerService\">\n    <option name=\"PROJECT_TEST_RUNNER\" value=\"Unittests\" />\n  </component>\n</module>"
  },
  {
    "path": ".idea/modules.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectModuleManager\">\n    <modules>\n      <module fileurl=\"file://$PROJECT_DIR$/.idea/egnn_distribute.iml\" filepath=\"$PROJECT_DIR$/.idea/egnn_distribute.iml\" />\n    </modules>\n  </component>\n</project>"
  },
  {
    "path": ".idea/vcs.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"VcsDirectoryMappings\">\n    <mapping directory=\"$PROJECT_DIR$\" vcs=\"Git\" />\n  </component>\n</project>"
  },
  {
    "path": ".idea/workspace.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ChangeListManager\">\n    <list default=\"true\" id=\"f20b581c-b8b4-4c9e-9203-c0c2c2f454b5\" name=\"Default Changelist\" comment=\"\" />\n    <option name=\"EXCLUDED_CONVERTED_TO_IGNORED\" value=\"true\" />\n    <option name=\"SHOW_DIALOG\" value=\"false\" />\n    <option name=\"HIGHLIGHT_CONFLICTS\" value=\"true\" />\n    <option name=\"HIGHLIGHT_NON_ACTIVE_CHANGELIST\" value=\"false\" />\n    <option name=\"LAST_RESOLUTION\" value=\"IGNORE\" />\n  </component>\n  <component name=\"FUSProjectUsageTrigger\">\n    <session id=\"-1231903785\">\n      <usages-collector id=\"statistics.lifecycle.project\">\n        <counts>\n          <entry key=\"project.open.time.1\" value=\"2\" />\n          <entry key=\"project.opened\" value=\"2\" />\n        </counts>\n      </usages-collector>\n      <usages-collector id=\"statistics.file.extensions.open\">\n        <counts>\n          <entry key=\"md\" value=\"2\" />\n          <entry key=\"py\" value=\"4\" />\n        </counts>\n      </usages-collector>\n      <usages-collector id=\"statistics.file.types.open\">\n        <counts>\n          <entry key=\"Markdown\" value=\"2\" />\n          <entry key=\"Python\" value=\"4\" />\n        </counts>\n      </usages-collector>\n      <usages-collector id=\"statistics.file.extensions.edit\">\n        <counts>\n          <entry key=\"md\" value=\"1812\" />\n          <entry key=\"py\" value=\"112\" />\n        </counts>\n      </usages-collector>\n      <usages-collector id=\"statistics.file.types.edit\">\n        <counts>\n          <entry key=\"Markdown\" value=\"1812\" />\n          <entry key=\"Python\" value=\"112\" />\n        </counts>\n      </usages-collector>\n    </session>\n  </component>\n  <component name=\"FileEditorManager\">\n    <leaf SIDE_TABS_SIZE_LIMIT_KEY=\"300\">\n      <file pinned=\"false\" current-in-tab=\"false\">\n        <entry 
file=\"file://$PROJECT_DIR$/data.py\">\n          <provider selected=\"true\" editor-type-id=\"text-editor\">\n            <state relative-caret-position=\"3705\">\n              <caret line=\"247\" column=\"37\" selection-start-line=\"247\" selection-start-column=\"37\" selection-end-line=\"247\" selection-end-column=\"37\" />\n              <folding>\n                <element signature=\"e#0#37#0\" expanded=\"true\" />\n              </folding>\n            </state>\n          </provider>\n        </entry>\n      </file>\n      <file pinned=\"false\" current-in-tab=\"false\">\n        <entry file=\"file://$PROJECT_DIR$/model.py\">\n          <provider selected=\"true\" editor-type-id=\"text-editor\">\n            <state relative-caret-position=\"615\">\n              <caret line=\"41\" column=\"30\" selection-start-line=\"41\" selection-start-column=\"30\" selection-end-line=\"41\" selection-end-column=\"30\" />\n              <folding>\n                <element signature=\"e#0#24#0\" expanded=\"true\" />\n              </folding>\n            </state>\n          </provider>\n        </entry>\n      </file>\n      <file pinned=\"false\" current-in-tab=\"false\">\n        <entry file=\"file://$PROJECT_DIR$/train.py\">\n          <provider selected=\"true\" editor-type-id=\"text-editor\">\n            <state relative-caret-position=\"5820\">\n              <caret line=\"388\" column=\"49\" selection-start-line=\"388\" selection-start-column=\"26\" selection-end-line=\"388\" selection-end-column=\"49\" />\n              <folding>\n                <element signature=\"e#0#24#0\" expanded=\"true\" />\n              </folding>\n            </state>\n          </provider>\n        </entry>\n      </file>\n      <file pinned=\"false\" current-in-tab=\"true\">\n        <entry file=\"file://$PROJECT_DIR$/README.md\">\n          <provider selected=\"true\" editor-type-id=\"split-provider[text-editor;markdown-preview-editor]\">\n            <state split_layout=\"SPLIT\">\n   
           <first_editor relative-caret-position=\"239\">\n                <caret line=\"144\" column=\"47\" selection-start-line=\"144\" selection-start-column=\"47\" selection-end-line=\"144\" selection-end-column=\"47\" />\n              </first_editor>\n              <second_editor />\n            </state>\n          </provider>\n        </entry>\n      </file>\n    </leaf>\n  </component>\n  <component name=\"FindInProjectRecents\">\n    <findStrings>\n      <find>tt.arg.inter_dea</find>\n      <find>inter_deactivate</find>\n    </findStrings>\n  </component>\n  <component name=\"Git.Settings\">\n    <option name=\"RECENT_GIT_ROOT_PATH\" value=\"$PROJECT_DIR$\" />\n  </component>\n  <component name=\"IdeDocumentHistory\">\n    <option name=\"CHANGED_PATHS\">\n      <list>\n        <option value=\"$PROJECT_DIR$/model.py\" />\n        <option value=\"$PROJECT_DIR$/train.py\" />\n        <option value=\"$PROJECT_DIR$/README.md\" />\n      </list>\n    </option>\n  </component>\n  <component name=\"JsBuildToolGruntFileManager\" detection-done=\"true\" sorting=\"DEFINITION_ORDER\" />\n  <component name=\"JsBuildToolPackageJson\" detection-done=\"true\" sorting=\"DEFINITION_ORDER\" />\n  <component name=\"JsGulpfileManager\">\n    <detection-done>true</detection-done>\n    <sorting>DEFINITION_ORDER</sorting>\n  </component>\n  <component name=\"ProjectFrameBounds\" fullScreen=\"true\">\n    <option name=\"y\" value=\"23\" />\n    <option name=\"width\" value=\"1440\" />\n    <option name=\"height\" value=\"877\" />\n  </component>\n  <component name=\"ProjectLevelVcsManager\" settingsEditedManually=\"true\" />\n  <component name=\"ProjectView\">\n    <navigator proportions=\"\" version=\"1\">\n      <foldersAlwaysOnTop value=\"true\" />\n    </navigator>\n    <panes>\n      <pane id=\"ProjectPane\">\n        <subPane>\n          <expand>\n            <path>\n              <item name=\"egnn_distribute\" type=\"b2602c69:ProjectViewProjectNode\" />\n              <item 
name=\"egnn_distribute\" type=\"462c0819:PsiDirectoryNode\" />\n            </path>\n          </expand>\n          <select />\n        </subPane>\n      </pane>\n      <pane id=\"Scope\" />\n    </panes>\n  </component>\n  <component name=\"PropertiesComponent\">\n    <property name=\"WebServerToolWindowFactoryState\" value=\"true\" />\n    <property name=\"last_opened_file_path\" value=\"$PROJECT_DIR$\" />\n    <property name=\"nodejs_interpreter_path.stuck_in_default_project\" value=\"undefined stuck path\" />\n    <property name=\"nodejs_npm_path_reset_for_default_project\" value=\"true\" />\n    <property name=\"settings.editor.selected.configurable\" value=\"com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable\" />\n  </component>\n  <component name=\"RunDashboard\">\n    <option name=\"ruleStates\">\n      <list>\n        <RuleState>\n          <option name=\"name\" value=\"ConfigurationTypeDashboardGroupingRule\" />\n        </RuleState>\n        <RuleState>\n          <option name=\"name\" value=\"StatusDashboardGroupingRule\" />\n        </RuleState>\n      </list>\n    </option>\n  </component>\n  <component name=\"SvnConfiguration\">\n    <configuration />\n  </component>\n  <component name=\"TaskManager\">\n    <task active=\"true\" id=\"Default\" summary=\"Default task\">\n      <changelist id=\"f20b581c-b8b4-4c9e-9203-c0c2c2f454b5\" name=\"Default Changelist\" comment=\"\" />\n      <created>1556855662817</created>\n      <option name=\"number\" value=\"Default\" />\n      <option name=\"presentableId\" value=\"Default\" />\n      <updated>1556855662817</updated>\n    </task>\n    <servers />\n  </component>\n  <component name=\"ToolWindowManager\">\n    <frame x=\"0\" y=\"0\" width=\"1440\" height=\"900\" extended-state=\"0\" />\n    <editor active=\"true\" />\n    <layout>\n      <window_info active=\"true\" content_ui=\"combo\" id=\"Project\" order=\"0\" visible=\"true\" weight=\"0.19456366\" />\n      <window_info id=\"Structure\" 
order=\"1\" side_tool=\"true\" weight=\"0.25\" />\n      <window_info id=\"Favorites\" order=\"2\" side_tool=\"true\" />\n      <window_info anchor=\"bottom\" id=\"Message\" order=\"0\" />\n      <window_info anchor=\"bottom\" id=\"Find\" order=\"1\" />\n      <window_info anchor=\"bottom\" id=\"Run\" order=\"2\" />\n      <window_info anchor=\"bottom\" id=\"Debug\" order=\"3\" weight=\"0.4\" />\n      <window_info anchor=\"bottom\" id=\"Cvs\" order=\"4\" weight=\"0.25\" />\n      <window_info anchor=\"bottom\" id=\"Inspection\" order=\"5\" weight=\"0.4\" />\n      <window_info anchor=\"bottom\" id=\"TODO\" order=\"6\" />\n      <window_info anchor=\"bottom\" id=\"Docker\" order=\"7\" show_stripe_button=\"false\" />\n      <window_info anchor=\"bottom\" id=\"Version Control\" order=\"8\" show_stripe_button=\"false\" />\n      <window_info anchor=\"bottom\" id=\"File Transfer\" order=\"9\" visible=\"true\" weight=\"0.32771084\" />\n      <window_info anchor=\"bottom\" id=\"Database Changes\" order=\"10\" show_stripe_button=\"false\" />\n      <window_info anchor=\"bottom\" id=\"Terminal\" order=\"11\" />\n      <window_info anchor=\"bottom\" id=\"Event Log\" order=\"12\" side_tool=\"true\" />\n      <window_info anchor=\"bottom\" id=\"Python Console\" order=\"13\" />\n      <window_info anchor=\"right\" id=\"Commander\" internal_type=\"SLIDING\" order=\"0\" type=\"SLIDING\" weight=\"0.4\" />\n      <window_info anchor=\"right\" id=\"Ant Build\" order=\"1\" weight=\"0.25\" />\n      <window_info anchor=\"right\" content_ui=\"combo\" id=\"Hierarchy\" order=\"2\" weight=\"0.25\" />\n      <window_info anchor=\"right\" id=\"Remote Host\" order=\"3\" />\n      <window_info anchor=\"right\" id=\"SciView\" order=\"4\" />\n      <window_info anchor=\"right\" id=\"Database\" order=\"5\" />\n    </layout>\n  </component>\n  <component name=\"TypeScriptGeneratedFilesManager\">\n    <option name=\"version\" value=\"1\" />\n  </component>\n  <component 
name=\"VcsContentAnnotationSettings\">\n    <option name=\"myLimit\" value=\"2678400000\" />\n  </component>\n  <component name=\"editorHistoryManager\">\n    <entry file=\"file://$PROJECT_DIR$/eval.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"405\">\n          <caret line=\"27\" column=\"46\" lean-forward=\"true\" selection-start-line=\"27\" selection-start-column=\"46\" selection-end-line=\"27\" selection-end-column=\"46\" />\n          <folding>\n            <element signature=\"e#0#24#0\" expanded=\"true\" />\n          </folding>\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/data.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"3705\">\n          <caret line=\"247\" column=\"37\" selection-start-line=\"247\" selection-start-column=\"37\" selection-end-line=\"247\" selection-end-column=\"37\" />\n          <folding>\n            <element signature=\"e#0#37#0\" expanded=\"true\" />\n          </folding>\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/model.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"615\">\n          <caret line=\"41\" column=\"30\" selection-start-line=\"41\" selection-start-column=\"30\" selection-end-line=\"41\" selection-end-column=\"30\" />\n          <folding>\n            <element signature=\"e#0#24#0\" expanded=\"true\" />\n          </folding>\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/train.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"5820\">\n          <caret line=\"388\" column=\"49\" selection-start-line=\"388\" selection-start-column=\"26\" selection-end-line=\"388\" selection-end-column=\"49\" />\n          <folding>\n        
    <element signature=\"e#0#24#0\" expanded=\"true\" />\n          </folding>\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/README.md\">\n      <provider selected=\"true\" editor-type-id=\"split-provider[text-editor;markdown-preview-editor]\">\n        <state split_layout=\"SPLIT\">\n          <first_editor relative-caret-position=\"239\">\n            <caret line=\"144\" column=\"47\" selection-start-line=\"144\" selection-start-column=\"47\" selection-end-line=\"144\" selection-end-column=\"47\" />\n          </first_editor>\n          <second_editor />\n        </state>\n      </provider>\n    </entry>\n  </component>\n</project>"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019 Jongmin Kim\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# fewshot-egnn\n\n### Introduction\n\nThe current project page provides pytorch code that implements the following CVPR2019 paper:   \n**Title:**      \"Edge-labeling Graph Neural Network for Few-shot Learning\"    \n**Authors:**     Jongmin Kim, Taesup Kim, Sungwoong Kim, Chang D.Yoo\n\n**Institution:** KAIST, KaKaoBrain     \n**Code:**        https://github.com/khy0809/fewshot-egnn  \n**Arxiv:**       https://arxiv.org/abs/1905.01436\n\n**Abstract:**\nIn this paper, we propose a novel edge-labeling graph\nneural network (EGNN), which adapts a deep neural network\non the edge-labeling graph, for few-shot learning.\nThe previous graph neural network (GNN) approaches in\nfew-shot learning have been based on the node-labeling\nframework, which implicitly models the intra-cluster similarity\nand the inter-cluster dissimilarity. In contrast, the\nproposed EGNN learns to predict the edge-labels rather\nthan the node-labels on the graph that enables the evolution\nof an explicit clustering by iteratively updating the edgelabels\nwith direct exploitation of both intra-cluster similarity\nand the inter-cluster dissimilarity. It is also well suited\nfor performing on various numbers of classes without retraining,\nand can be easily extended to perform a transductive\ninference. The parameters of the EGNN are learned\nby episodic training with an edge-labeling loss to obtain a\nwell-generalizable model for unseen low-data problem. On\nboth of the supervised and semi-supervised few-shot image\nclassification tasks with two benchmark datasets, the proposed\nEGNN significantly improves the performances over\nthe existing GNNs.\n\n### Citation\nIf you find this code useful you can cite us using the following bibTex:\n```\n@article{kim2019egnn,\n  title={Edge-labeling Graph Neural Network for Few-shot Learning},\n  author={Jongmin Kim, Taesup Kim, Sungwoong Kim, Chang D. 
Yoo},\n  journal={arXiv preprint arXiv:1905.01436},\n  year={2019}\n}\n```\n\n\n### Platform\nThis code was developed and tested with pytorch version 1.0.1\n\n### Setting\n\nYou can download miniImagenet dataset from [here](https://drive.google.com/open?id=15WuREBvhEbSWo4fTr1r-vMY0C_6QWv4w).\n\nDownload 'mini_imagenet_train/val/test.pickle', and put them in the path \n'tt.arg.dataset_root/mini-imagenet/compacted_dataset/'\n\nIn ```train.py```, replace the dataset root directory with your own:\ntt.arg.dataset_root = '/data/private/dataset'\n\n\n\n### Training\n\n```\n# ************************** miniImagenet, 5way 1shot *****************************\n$ python3 train.py --dataset mini --num_ways 5 --num_shots 1 --transductive False\n$ python3 train.py --dataset mini --num_ways 5 --num_shots 1 --transductive True\n\n# ************************** miniImagenet, 5way 5shot *****************************\n$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --transductive False\n$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --transductive True\n\n# ************************** miniImagenet, 10way 5shot *****************************\n$ python3 train.py --dataset mini --num_ways 10 --num_shots 5 --meta_batch_size 20 --transductive True\n\n# ************************** tieredImagenet, 5way 5shot *****************************\n$ python3 train.py --dataset tiered --num_ways 5 --num_shots 5 --transductive False\n$ python3 train.py --dataset tiered --num_ways 5 --num_shots 5 --transductive True\n\n# **************** miniImagenet, 5way 5shot, 20% labeled (semi) *********************\n$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4 --transductive False\n$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4 --transductive True\n\n```\n\n### Evaluation\nThe trained models are saved in the path './asset/checkpoints/', with the name of 'D-{dataset}-N-{ways}-K-{shots}-U-{num_unlabeld}-L-{num_layers}-B-{batch 
size}-T-{transductive}'.\nSo, for example, if you want to test the trained model of 'miniImagenet, 5way 1shot, transductive' setting, you can give --test_model argument as follow:\n```\n$ python3 eval.py --test_model D-mini_N-5_K-1_U-0_L-3_B-40_T-True\n```\n\n\n## Result\nHere are some experimental results presented in the paper. You should be able to reproduce all the results by using the trained models which can be downloaded from [here](https://drive.google.com/open?id=15WuREBvhEbSWo4fTr1r-vMY0C_6QWv4w).\n#### miniImageNet, non-transductive\n\n| Model                    |  5-way 5-shot acc (%)| \n|--------------------------|  ------------------: | \n| Matching Networks [1]    |         55.30        | \n| Reptile [2]              |         62.74        | \n| Prototypical Net [3]     |         65.77        | \n| GNN [4]                  |         66.41        | \n| **(ours)** EGNN          |         **66.85**        | \n\n#### miniImageNet, transductive\n\n| Model                    |  5-way 5-shot acc (%)| \n|--------------------------|  ------------------: | \n| MAML [5]                 |         63.11        | \n| Reptile + BN [2]         |         65.99        | \n| Relation Net [6]         |         67.07        | \n| MAML + Transduction [5]  |         66.19        | \n| TPN [7]                  |         69.43        | \n| TPN (Higher K) [7]       |         69.86        | \n| **(ours)** EGNN          |         **76.37**        | \n\n#### tieredImageNet, non-transductive\n\n| Model                    |  5-way 5-shot acc (%)| \n|--------------------------|  ------------------: | \n| Reptile [2]              |         66.47        | \n| Prototypical Net [3]     |         69.57        | \n| **(ours)** EGNN          |         **70.98**        | \n\n#### tieredImageNet, transductive\n\n| Model                    |  5-way 5-shot acc (%)| \n|--------------------------|  ------------------: | \n| MAML [5]                 |         70.30        | \n| Reptile + BN [2]  
       |         71.03        | \n| Relation Net [6]         |         71.31        | \n| MAML + Transduction [5]  |         70.83        | \n| TPN [7]                  |         72.58        | \n| **(ours)** EGNN          |         **80.15**        | \n\n\n#### miniImageNet, semi-supervised, 5-way 5-shot\n\n| Model                    |  20%                 | 40%                 | 60%                 | 100%                 | \n|--------------------------|  ------------------: | ------------------: | ------------------: | ------------------:  | \n| GNN-LabeledOnly [4]       |      50.33                |      56.91               |        -             |        66.41              |\n| GNN-Semi [4]             |      52.45                |      58.76               |        -             |        66.41              |\n| EGNN-LabeledOnly         |      52.86                |        -             |            -         |            66.85          |\n| EGNN-Semi                |      61.88                |        62.52             |        63.53             |    66.85                  |\n| EGNN-LabeledOnly (Transductive) |      59.18         |         -            |           -          |           76.37           |\n| EGNN-Semi (Transductive)        |      63.62         |        64.32             |        66.37             |   76.37                   |\n\n\n#### miniImageNet, cross-way experiment\n| Model                    |  train way                 | test way                 |  Accuracy |\n|--------------------------|  ------------------: | ------------------: | ------------------: |\n| GNN       |      5                |      5               |      66.41     |\n| GNN       |      5                |      10               |     N/A      |\n| GNN       |      10                |     10            |       51.75    |\n| GNN       |      10             |      5              |       N/A    |\n| EGNN       |      5             |      5              |       76.37    |\n| EGNN 
      |      5             |      10              |       56.35    |\n| EGNN       |      10             |      10              |       57.61   |\n| EGNN       |      10             |      5              |       76.27   |\n\n\n\n### References\n```\n[1] O. Vinyals et al. Matching networks for one shot learning.\n[2] A Nichol, J Achiam, J Schulman, On first-order meta-learning algorithms.\n[3] J. Snell, K. Swersky, and R. S. Zemel. Prototypical networks for few-shot learning.\n[4] V Garcia, J Bruna, Few-shot learning with graph neural network.\n[5] C. Finn, P. Abbeel, and S. Levine. Model-agnostic meta-learning for fast adaptation of deep networks.\n[6] F. Sung et al, Learning to Compare: Relation Network for Few-Shot Learning.\n[7] Y Liu, J Lee, M Park, S Kim, Y Yang, Transductive propagation network for few-shot learning.\n"
  },
  {
    "path": "__init__.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch import optim\nfrom torch import cuda\nfrom torch import utils\nfrom torch.nn import functional as F\nfrom torch.utils.data import *\nfrom torch.distributions import *\nfrom torchtools import tt\n\n\n__author__ = 'namju.kim@kakaobrain.com'\n\n\n# initialize seed\nif tt.arg.seed:\n    np.random.seed(tt.arg.seed)\n    torch.manual_seed(tt.arg.seed)\n"
  },
  {
    "path": "_version.py",
    "content": "__version__ = '0.4.0'  # align version with pytorch\n\n"
  },
  {
    "path": "data.py",
    "content": "from __future__ import print_function\nfrom torchtools import *\nimport torch.utils.data as data\nimport random\nimport os\nimport numpy as np\nfrom PIL import Image as pil_image\nimport pickle\nfrom itertools import islice\nfrom torchvision import transforms\n\n\nclass MiniImagenetLoader(data.Dataset):\n    def __init__(self, root, partition='train'):\n        super(MiniImagenetLoader, self).__init__()\n        # set dataset information\n        self.root = root\n        self.partition = partition\n        self.data_size = [3, 84, 84]\n\n        # set normalizer\n        mean_pix = [x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]\n        std_pix = [x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]]\n        normalize = transforms.Normalize(mean=mean_pix, std=std_pix)\n\n        # set transformer\n        if self.partition == 'train':\n            self.transform = transforms.Compose([transforms.RandomCrop(84, padding=4),\n                                                 lambda x: np.asarray(x),\n                                                 transforms.ToTensor(),\n                                                 normalize])\n        else:  # 'val' or 'test' ,\n            self.transform = transforms.Compose([lambda x: np.asarray(x),\n                                                 transforms.ToTensor(),\n                                                 normalize])\n\n        # load data\n        self.data = self.load_dataset()\n\n    def load_dataset(self):\n        # load data\n        dataset_path = os.path.join(self.root, 'mini-imagenet/compacted_datasets', 'mini_imagenet_%s.pickle' % self.partition)\n        with open(dataset_path, 'rb') as handle:\n            data = pickle.load(handle)\n\n        # for each class\n        for c_idx in data:\n            # for each image\n            for i_idx in range(len(data[c_idx])):\n                # resize\n                image_data = 
pil_image.fromarray(np.uint8(data[c_idx][i_idx]))\n                image_data = image_data.resize((self.data_size[2], self.data_size[1]))\n                #image_data = np.array(image_data, dtype='float32')\n\n                #image_data = np.transpose(image_data, (2, 0, 1))\n\n                # save\n                data[c_idx][i_idx] = image_data\n        return data\n\n    def get_task_batch(self,\n                       num_tasks=5,\n                       num_ways=20,\n                       num_shots=1,\n                       num_queries=1,\n                       seed=None):\n\n        if seed is not None:\n            random.seed(seed)\n\n        # init task batch data\n        support_data, support_label, query_data, query_label = [], [], [], []\n        for _ in range(num_ways * num_shots):\n            data = np.zeros(shape=[num_tasks] + self.data_size,\n                            dtype='float32')\n            label = np.zeros(shape=[num_tasks],\n                             dtype='float32')\n            support_data.append(data)\n            support_label.append(label)\n        for _ in range(num_ways * num_queries):\n            data = np.zeros(shape=[num_tasks] + self.data_size,\n                            dtype='float32')\n            label = np.zeros(shape=[num_tasks],\n                             dtype='float32')\n            query_data.append(data)\n            query_label.append(label)\n\n        # get full class list in dataset\n        full_class_list = list(self.data.keys())\n\n        # for each task\n        for t_idx in range(num_tasks):\n            # define task by sampling classes (num_ways)\n            task_class_list = random.sample(full_class_list, num_ways)\n\n            # for each sampled class in task\n            for c_idx in range(num_ways):\n                # sample data for support and query (num_shots + num_queries)\n                class_data_list = random.sample(self.data[task_class_list[c_idx]], num_shots + 
num_queries)\n\n\n                # load sample for support set\n                for i_idx in range(num_shots):\n                    # set data\n                    support_data[i_idx + c_idx * num_shots][t_idx] = self.transform(class_data_list[i_idx])\n                    support_label[i_idx + c_idx * num_shots][t_idx] = c_idx\n\n                # load sample for query set\n                for i_idx in range(num_queries):\n                    query_data[i_idx + c_idx * num_queries][t_idx] = self.transform(class_data_list[num_shots + i_idx])\n                    query_label[i_idx + c_idx * num_queries][t_idx] = c_idx\n\n        # convert to tensor (num_tasks x (num_ways * (num_supports + num_queries)) x ...)\n        support_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in support_data], 1)\n        support_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in support_label], 1)\n        query_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in query_data], 1)\n        query_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in query_label], 1)\n\n        return [support_data, support_label, query_data, query_label]\n\n\n\nclass TieredImagenetLoader(data.Dataset):\n    def __init__(self, root, partition='train'):\n        self.root = root\n        self.partition = partition  # train/val/test\n        #self.preprocess()\n        self.data_size = [3, 84, 84]\n\n        # load data\n        self.data = self.load_dataset()\n\n        # if not self._check_exists_():\n        #     self._init_folders_()\n        #     if self.check_decompress():\n        #         self._decompress_()\n        #     self._preprocess_()\n\n\n    def get_image_paths(self, file):\n        images_path, class_names = [], []\n        with open(file, 'r') as f:\n            f.readline()\n            for line in f:\n                name, class_ = line.split(',')\n               
 class_ = class_[0:(len(class_)-1)]\n                path = self.root + '/tiered-imagenet/images/'+name\n                images_path.append(path)\n                class_names.append(class_)\n        return class_names, images_path\n\n    def preprocess(self):\n        print('\\nPreprocessing Tiered-Imagenet images...')\n        (class_names_train, images_path_train) = self.get_image_paths('%s/tiered-imagenet/train.csv' % self.root)\n        (class_names_test, images_path_test) = self.get_image_paths('%s/tiered-imagenet/test.csv' % self.root)\n        (class_names_val, images_path_val) = self.get_image_paths('%s/tiered-imagenet/val.csv' % self.root)\n\n        keys_train = list(set(class_names_train))\n        keys_test = list(set(class_names_test))\n        keys_val = list(set(class_names_val))\n        label_encoder = {}\n        label_decoder = {}\n        for i in range(len(keys_train)):\n            label_encoder[keys_train[i]] = i\n            label_decoder[i] = keys_train[i]\n        for i in range(len(keys_train), len(keys_train)+len(keys_test)):\n            label_encoder[keys_test[i-len(keys_train)]] = i\n            label_decoder[i] = keys_test[i-len(keys_train)]\n        for i in range(len(keys_train)+len(keys_test), len(keys_train)+len(keys_test)+len(keys_val)):\n            label_encoder[keys_val[i-len(keys_train) - len(keys_test)]] = i\n            label_decoder[i] = keys_val[i-len(keys_train)-len(keys_test)]\n\n        counter = 0\n        train_set = {}\n\n        for class_, path in zip(class_names_train, images_path_train):\n            img = pil_image.open(path)\n            img = img.convert('RGB')\n            img = img.resize((84, 84), pil_image.ANTIALIAS)\n            img = np.array(img, dtype='float32')\n            if label_encoder[class_] not in train_set:\n                train_set[label_encoder[class_]] = []\n            train_set[label_encoder[class_]].append(img)\n            counter += 1\n            if counter % 1000 == 0:\n          
      print(\"Counter \"+str(counter) + \" from \" + str(len(images_path_train)))\n\n        test_set = {}\n        for class_, path in zip(class_names_test, images_path_test):\n            img = pil_image.open(path)\n            img = img.convert('RGB')\n            img = img.resize((84, 84), pil_image.ANTIALIAS)\n            img = np.array(img, dtype='float32')\n\n            if label_encoder[class_] not in test_set:\n                test_set[label_encoder[class_]] = []\n            test_set[label_encoder[class_]].append(img)\n            counter += 1\n            if counter % 1000 == 0:\n                print(\"Counter \" + str(counter) + \" from \"+str(len(class_names_test)))\n\n        val_set = {}\n        for class_, path in zip(class_names_val, images_path_val):\n            img = pil_image.open(path)\n            img = img.convert('RGB')\n            img = img.resize((84, 84), pil_image.ANTIALIAS)\n            img = np.array(img, dtype='float32')\n\n            if label_encoder[class_] not in val_set:\n                val_set[label_encoder[class_]] = []\n            val_set[label_encoder[class_]].append(img)\n            counter += 1\n            if counter % 1000 == 0:\n                print(\"Counter \"+str(counter) + \" from \" + str(len(class_names_val)))\n\n        partition_count = 0\n        for item in self.chunks(train_set, 20):\n            partition_count = partition_count + 1\n            with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_train_{}.pickle'.format(partition_count)), 'wb') as handle:\n                pickle.dump(item, handle, protocol=2)\n\n        partition_count = 0\n        for item in self.chunks(test_set, 20):\n            partition_count = partition_count + 1\n            with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_test_{}.pickle'.format(partition_count)), 'wb') as handle:\n                pickle.dump(item, handle, protocol=2)\n\n        
partition_count = 0\n        for item in self.chunks(val_set, 20):\n            partition_count = partition_count + 1\n            with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_val_{}.pickle'.format(partition_count)), 'wb') as handle:\n                pickle.dump(item, handle, protocol=2)\n\n\n\n        label_encoder = {}\n        keys = list(train_set.keys()) + list(test_set.keys())\n        for id_key, key in enumerate(keys):\n            label_encoder[key] = id_key\n        with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_label_encoder.pickle'), 'wb') as handle:\n            pickle.dump(label_encoder, handle, protocol=2)\n\n        print('Images preprocessed')\n\n    def load_dataset(self):\n        print(\"Loading dataset\")\n        data = {}\n        if self.partition == 'train':\n            num_partition = 18\n        elif self.partition == 'val':\n            num_partition = 5\n        elif self.partition == 'test':\n            num_partition = 8\n\n        partition_count = 0\n        for i in range(num_partition):\n            partition_count = partition_count +1\n            with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_{}_{}.pickle'.format(self.partition, partition_count)), 'rb') as handle:\n                data.update(pickle.load(handle))\n\n        # Resize images and normalize\n        for class_ in data:\n            for i in range(len(data[class_])):\n                image2resize = pil_image.fromarray(np.uint8(data[class_][i]))\n                image_resized = image2resize.resize((self.data_size[2], self.data_size[1]))\n                image_resized = np.array(image_resized, dtype='float32')\n\n                # Normalize\n                image_resized = np.transpose(image_resized, (2, 0, 1))\n                image_resized[0, :, :] -= 120.45  # R\n                image_resized[1, :, :] -= 115.74  # G\n                
image_resized[2, :, :] -= 104.65  # B\n                image_resized /= 127.5\n\n                data[class_][i] = image_resized\n\n        print(\"Num classes \" + str(len(data)))\n        num_images = 0\n        for class_ in data:\n            num_images += len(data[class_])\n        print(\"Num images \" + str(num_images))\n        return data\n\n    def chunks(self, data, size=10000):\n        it = iter(data)\n        for i in range(0, len(data), size):\n            yield {k: data[k] for k in islice(it, size)}\n\n    def get_task_batch(self,\n                       num_tasks=5,\n                       num_ways=20,\n                       num_shots=1,\n                       num_queries=1,\n                       seed=None):\n        if seed is not None:\n            random.seed(seed)\n\n        # init task batch data\n        support_data, support_label, query_data, query_label = [], [], [], []\n        for _ in range(num_ways * num_shots):\n            data = np.zeros(shape=[num_tasks] + self.data_size,\n                            dtype='float32')\n            label = np.zeros(shape=[num_tasks],\n                             dtype='float32')\n            support_data.append(data)\n            support_label.append(label)\n        for _ in range(num_ways * num_queries):\n            data = np.zeros(shape=[num_tasks] + self.data_size,\n                            dtype='float32')\n            label = np.zeros(shape=[num_tasks],\n                             dtype='float32')\n            query_data.append(data)\n            query_label.append(label)\n\n        # get full class list in dataset\n        full_class_list = list(self.data.keys())\n\n        # for each task\n        for t_idx in range(num_tasks):\n            # define task by sampling classes (num_ways)\n            task_class_list = random.sample(full_class_list, num_ways)\n\n            # for each sampled class in task\n            for c_idx in range(num_ways):\n                # sample data for 
support and query (num_shots + num_queries)\n                class_data_list = random.sample(self.data[task_class_list[c_idx]], num_shots + num_queries)\n\n                # load sample for support set\n                for i_idx in range(num_shots):\n                    # set data\n                    support_data[i_idx + c_idx * num_shots][t_idx] = class_data_list[i_idx]\n                    support_label[i_idx + c_idx * num_shots][t_idx] = c_idx\n\n                # load sample for query set\n                for i_idx in range(num_queries):\n                    query_data[i_idx + c_idx * num_queries][t_idx] = class_data_list[num_shots + i_idx]\n                    query_label[i_idx + c_idx * num_queries][t_idx] = c_idx\n\n\n\n        # convert to tensor (num_tasks x (num_ways * (num_supports + num_queries)) x ...)\n        support_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in support_data], 1)\n        support_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in support_label], 1)\n        query_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in query_data], 1)\n        query_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in query_label], 1)\n\n        return [support_data, support_label, query_data, query_label]"
  },
  {
    "path": "eval.py",
    "content": "from torchtools import *\nfrom data import MiniImagenetLoader, TieredImagenetLoader\nfrom model import EmbeddingImagenet, GraphNetwork, ConvNet\nimport shutil\nimport os\nimport random\nfrom train import ModelTrainer\n\nif __name__ == '__main__':\n\n    tt.arg.test_model = 'D-mini_N-5_K-1_U-0_L-3_B-40_T-True' if tt.arg.test_model is None else tt.arg.test_model\n\n    list1 = tt.arg.test_model.split(\"_\")\n    param = {}\n    for i in range(len(list1)):\n        param[list1[i].split(\"-\", 1)[0]] = list1[i].split(\"-\", 1)[1]\n    tt.arg.dataset = param['D']\n    tt.arg.num_ways = int(param['N'])\n    tt.arg.num_shots = int(param['K'])\n    tt.arg.num_unlabeled = int(param['U'])\n    tt.arg.num_layers = int(param['L'])\n    tt.arg.meta_batch_size = int(param['B'])\n    tt.arg.transductive = False if param['T'] == 'False' else True\n\n\n    ####################\n    tt.arg.device = 'cuda:0' if tt.arg.device is None else tt.arg.device\n    # replace dataset_root with your own\n    tt.arg.dataset_root = '/data/private/dataset'\n    tt.arg.dataset = 'mini' if tt.arg.dataset is None else tt.arg.dataset\n    tt.arg.num_ways = 5 if tt.arg.num_ways is None else tt.arg.num_ways\n    tt.arg.num_shots = 1 if tt.arg.num_shots is None else tt.arg.num_shots\n    tt.arg.num_unlabeled = 0 if tt.arg.num_unlabeled is None else tt.arg.num_unlabeled\n    tt.arg.num_layers = 3 if tt.arg.num_layers is None else tt.arg.num_layers\n    tt.arg.meta_batch_size = 40 if tt.arg.meta_batch_size is None else tt.arg.meta_batch_size\n    tt.arg.transductive = False if tt.arg.transductive is None else tt.arg.transductive\n    tt.arg.seed = 222 if tt.arg.seed is None else tt.arg.seed\n    tt.arg.num_gpus = 1 if tt.arg.num_gpus is None else tt.arg.num_gpus\n\n    tt.arg.num_ways_train = tt.arg.num_ways\n    tt.arg.num_ways_test = tt.arg.num_ways\n\n    tt.arg.num_shots_train = tt.arg.num_shots\n    tt.arg.num_shots_test = tt.arg.num_shots\n\n    tt.arg.train_transductive = 
tt.arg.transductive\n    tt.arg.test_transductive = tt.arg.transductive\n\n    # model parameter related\n    tt.arg.num_edge_features = 96\n    tt.arg.num_node_features = 96\n    tt.arg.emb_size = 128\n\n    # train, test parameters\n    tt.arg.train_iteration = 100000 if tt.arg.dataset == 'mini' else 200000\n    tt.arg.test_iteration = 10000\n    tt.arg.test_interval = 5000\n    tt.arg.test_batch_size = 10\n    tt.arg.log_step = 1000\n\n    tt.arg.lr = 1e-3\n    tt.arg.grad_clip = 5\n    tt.arg.weight_decay = 1e-6\n    tt.arg.dec_lr = 15000 if tt.arg.dataset == 'mini' else 30000\n    tt.arg.dropout = 0.1 if tt.arg.dataset == 'mini' else 0.0\n\n    #set random seed\n    np.random.seed(tt.arg.seed)\n    torch.manual_seed(tt.arg.seed)\n    torch.cuda.manual_seed_all(tt.arg.seed)\n    random.seed(tt.arg.seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n\n    enc_module = EmbeddingImagenet(emb_size=tt.arg.emb_size)\n\n    # set random seed\n    np.random.seed(tt.arg.seed)\n    torch.manual_seed(tt.arg.seed)\n    torch.cuda.manual_seed_all(tt.arg.seed)\n    random.seed(tt.arg.seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n    # to check\n    exp_name = 'D-{}'.format(tt.arg.dataset)\n    exp_name += '_N-{}_K-{}_U-{}'.format(tt.arg.num_ways, tt.arg.num_shots, tt.arg.num_unlabeled)\n    exp_name += '_L-{}_B-{}'.format(tt.arg.num_layers, tt.arg.meta_batch_size)\n    exp_name += '_T-{}'.format(tt.arg.transductive)\n\n\n    if not exp_name == tt.arg.test_model:\n        print(exp_name)\n        print(tt.arg.test_model)\n        print('Test model and input arguments are mismatched!')\n        AssertionError()\n\n    gnn_module = GraphNetwork(in_features=tt.arg.emb_size,\n                              node_features=tt.arg.num_edge_features,\n                              edge_features=tt.arg.num_node_features,\n                              num_layers=tt.arg.num_layers,\n     
                         dropout=tt.arg.dropout)\n\n    if tt.arg.dataset == 'mini':\n        test_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='test')\n    elif tt.arg.dataset == 'tiered':\n        test_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='test')\n    else:\n        print('Unknown dataset!')\n\n\n    data_loader = {'test': test_loader}\n\n    # create trainer\n    tester = ModelTrainer(enc_module=enc_module,\n                           gnn_module=gnn_module,\n                           data_loader=data_loader)\n\n\n    #checkpoint = torch.load('asset/checkpoints/{}/'.format(exp_name) + 'model_best.pth.tar')\n    checkpoint = torch.load('./trained_models/{}/'.format(exp_name) + 'model_best.pth.tar')\n\n\n    tester.enc_module.load_state_dict(checkpoint['enc_module_state_dict'])\n    print(\"load pre-trained enc_nn done!\")\n\n    # initialize gnn pre-trained\n    tester.gnn_module.load_state_dict(checkpoint['gnn_module_state_dict'])\n    print(\"load pre-trained egnn done!\")\n\n    tester.val_acc = checkpoint['val_acc']\n    tester.global_step = checkpoint['iteration']\n\n    print(tester.global_step)\n\n\n    tester.eval(partition='test')\n\n\n\n\n\n"
  },
  {
    "path": "model.py",
    "content": "from torchtools import *\nfrom collections import OrderedDict\nimport math\n#import seaborn as sns\nimport numpy as np\nimport matplotlib.pyplot as plt\n\n\nclass ConvBlock(nn.Module):\n    def __init__(self, in_planes, out_planes, userelu=True, momentum=0.1, affine=True, track_running_stats=True):\n        super(ConvBlock, self).__init__()\n        self.layers = nn.Sequential()\n        self.layers.add_module('Conv', nn.Conv2d(in_planes, out_planes,\n            kernel_size=3, stride=1, padding=1, bias=False))\n\n        if tt.arg.normtype == 'batch':\n            self.layers.add_module('Norm', nn.BatchNorm2d(out_planes, momentum=momentum, affine=affine, track_running_stats=track_running_stats))\n        elif tt.arg.normtype == 'instance':\n            self.layers.add_module('Norm', nn.InstanceNorm2d(out_planes))\n\n        if userelu:\n            self.layers.add_module('ReLU', nn.ReLU(inplace=True))\n\n        self.layers.add_module(\n            'MaxPool', nn.MaxPool2d(kernel_size=2, stride=2, padding=0))\n\n    def forward(self, x):\n        out = self.layers(x)\n        return out\n\nclass ConvNet(nn.Module):\n    def __init__(self, opt, momentum=0.1, affine=True, track_running_stats=True):\n        super(ConvNet, self).__init__()\n        self.in_planes  = opt['in_planes']\n        self.out_planes = opt['out_planes']\n        self.num_stages = opt['num_stages']\n        if type(self.out_planes) == int:\n            self.out_planes = [self.out_planes for i in range(self.num_stages)]\n        assert(type(self.out_planes)==list and len(self.out_planes)==self.num_stages)\n\n        num_planes = [self.in_planes,] + self.out_planes\n        userelu = opt['userelu'] if ('userelu' in opt) else True\n\n        conv_blocks = []\n        for i in range(self.num_stages):\n            if i == (self.num_stages-1):\n                conv_blocks.append(\n                    ConvBlock(num_planes[i], num_planes[i+1], userelu=userelu))\n            else:\n       
         conv_blocks.append(\n                    ConvBlock(num_planes[i], num_planes[i+1]))\n        self.conv_blocks = nn.Sequential(*conv_blocks)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. / n))\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def forward(self, x):\n        out = self.conv_blocks(x)\n        out = out.view(out.size(0),-1)\n        return out\n\n\n\n# encoder for imagenet dataset\nclass EmbeddingImagenet(nn.Module):\n    def __init__(self,\n                 emb_size):\n        super(EmbeddingImagenet, self).__init__()\n        # set size\n        self.hidden = 64\n        self.last_hidden = self.hidden * 25\n        self.emb_size = emb_size\n\n        # set layers\n        self.conv_1 = nn.Sequential(nn.Conv2d(in_channels=3,\n                                              out_channels=self.hidden,\n                                              kernel_size=3,\n                                              padding=1,\n                                              bias=False),\n                                    nn.BatchNorm2d(num_features=self.hidden),\n                                    nn.MaxPool2d(kernel_size=2),\n                                    nn.LeakyReLU(negative_slope=0.2, inplace=True))\n        self.conv_2 = nn.Sequential(nn.Conv2d(in_channels=self.hidden,\n                                              out_channels=int(self.hidden*1.5),\n                                              kernel_size=3,\n                                              bias=False),\n                                    nn.BatchNorm2d(num_features=int(self.hidden*1.5)),\n                                    nn.MaxPool2d(kernel_size=2),\n                                    nn.LeakyReLU(negative_slope=0.2, 
inplace=True))\n        self.conv_3 = nn.Sequential(nn.Conv2d(in_channels=int(self.hidden*1.5),\n                                              out_channels=self.hidden*2,\n                                              kernel_size=3,\n                                              padding=1,\n                                              bias=False),\n                                    nn.BatchNorm2d(num_features=self.hidden * 2),\n                                    nn.MaxPool2d(kernel_size=2),\n                                    nn.LeakyReLU(negative_slope=0.2, inplace=True),\n                                    nn.Dropout2d(0.4))\n        self.conv_4 = nn.Sequential(nn.Conv2d(in_channels=self.hidden*2,\n                                              out_channels=self.hidden*4,\n                                              kernel_size=3,\n                                              padding=1,\n                                              bias=False),\n                                    nn.BatchNorm2d(num_features=self.hidden * 4),\n                                    nn.MaxPool2d(kernel_size=2),\n                                    nn.LeakyReLU(negative_slope=0.2, inplace=True),\n                                    nn.Dropout2d(0.5))\n        self.layer_last = nn.Sequential(nn.Linear(in_features=self.last_hidden * 4,\n                                              out_features=self.emb_size, bias=True),\n                                        nn.BatchNorm1d(self.emb_size))\n\n    def forward(self, input_data):\n        output_data = self.conv_4(self.conv_3(self.conv_2(self.conv_1(input_data))))\n        return self.layer_last(output_data.view(output_data.size(0), -1))\n\n\n\n\nclass NodeUpdateNetwork(nn.Module):\n    def __init__(self,\n                 in_features,\n                 num_features,\n                 ratio=[2, 1],\n                 dropout=0.0):\n        super(NodeUpdateNetwork, self).__init__()\n        # set size\n        self.in_features = 
in_features\n        self.num_features_list = [num_features * r for r in ratio]\n        self.dropout = dropout\n\n        # layers\n        layer_list = OrderedDict()\n        for l in range(len(self.num_features_list)):\n\n            layer_list['conv{}'.format(l)] = nn.Conv2d(\n                in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features * 3,\n                out_channels=self.num_features_list[l],\n                kernel_size=1,\n                bias=False)\n            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],\n                                                            )\n            layer_list['relu{}'.format(l)] = nn.LeakyReLU()\n\n            if self.dropout > 0 and l == (len(self.num_features_list) - 1):\n                layer_list['drop{}'.format(l)] = nn.Dropout2d(p=self.dropout)\n\n        self.network = nn.Sequential(layer_list)\n\n    def forward(self, node_feat, edge_feat):\n        # get size\n        num_tasks = node_feat.size(0)\n        num_data = node_feat.size(1)\n\n        # get eye matrix (batch_size x 2 x node_size x node_size)\n        diag_mask = 1.0 - torch.eye(num_data).unsqueeze(0).unsqueeze(0).repeat(num_tasks, 2, 1, 1).to(tt.arg.device)\n\n        # set diagonal as zero and normalize\n        edge_feat = F.normalize(edge_feat * diag_mask, p=1, dim=-1)\n\n        # compute attention and aggregate\n        aggr_feat = torch.bmm(torch.cat(torch.split(edge_feat, 1, 1), 2).squeeze(1), node_feat)\n\n        node_feat = torch.cat([node_feat, torch.cat(aggr_feat.split(num_data, 1), -1)], -1).transpose(1, 2)\n\n        # non-linear transform\n        node_feat = self.network(node_feat.unsqueeze(-1)).transpose(1, 2).squeeze(-1)\n        return node_feat\n\n\nclass EdgeUpdateNetwork(nn.Module):\n    def __init__(self,\n                 in_features,\n                 num_features,\n                 ratio=[2, 2, 1, 1],\n                 separate_dissimilarity=False,\n       
          dropout=0.0):\n        super(EdgeUpdateNetwork, self).__init__()\n        # set size\n        self.in_features = in_features\n        self.num_features_list = [num_features * r for r in ratio]\n        self.separate_dissimilarity = separate_dissimilarity\n        self.dropout = dropout\n\n        # layers\n        layer_list = OrderedDict()\n        for l in range(len(self.num_features_list)):\n            # set layer\n            layer_list['conv{}'.format(l)] = nn.Conv2d(in_channels=self.num_features_list[l-1] if l > 0 else self.in_features,\n                                                       out_channels=self.num_features_list[l],\n                                                       kernel_size=1,\n                                                       bias=False)\n            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],\n                                                            )\n            layer_list['relu{}'.format(l)] = nn.LeakyReLU()\n\n            if self.dropout > 0:\n                layer_list['drop{}'.format(l)] = nn.Dropout2d(p=self.dropout)\n\n        layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],\n                                           out_channels=1,\n                                           kernel_size=1)\n        self.sim_network = nn.Sequential(layer_list)\n\n        if self.separate_dissimilarity:\n            # layers\n            layer_list = OrderedDict()\n            for l in range(len(self.num_features_list)):\n                # set layer\n                layer_list['conv{}'.format(l)] = nn.Conv2d(in_channels=self.num_features_list[l-1] if l > 0 else self.in_features,\n                                                           out_channels=self.num_features_list[l],\n                                                           kernel_size=1,\n                                                           bias=False)\n                
layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],\n                                                                )\n                layer_list['relu{}'.format(l)] = nn.LeakyReLU()\n\n                if self.dropout > 0:\n                    layer_list['drop{}'.format(l)] = nn.Dropout(p=self.dropout)\n\n            layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],\n                                               out_channels=1,\n                                               kernel_size=1)\n            self.dsim_network = nn.Sequential(layer_list)\n\n    def forward(self, node_feat, edge_feat):\n        # compute abs(x_i, x_j)\n        x_i = node_feat.unsqueeze(2)\n        x_j = torch.transpose(x_i, 1, 2)\n        x_ij = torch.abs(x_i - x_j)\n        x_ij = torch.transpose(x_ij, 1, 3)\n\n        # compute similarity/dissimilarity (batch_size x feat_size x num_samples x num_samples)\n        sim_val = F.sigmoid(self.sim_network(x_ij))\n\n        if self.separate_dissimilarity:\n            dsim_val = F.sigmoid(self.dsim_network(x_ij))\n        else:\n            dsim_val = 1.0 - sim_val\n\n\n        diag_mask = 1.0 - torch.eye(node_feat.size(1)).unsqueeze(0).unsqueeze(0).repeat(node_feat.size(0), 2, 1, 1).to(tt.arg.device)\n        edge_feat = edge_feat * diag_mask\n        merge_sum = torch.sum(edge_feat, -1, True)\n        # set diagonal as zero and normalize\n        edge_feat = F.normalize(torch.cat([sim_val, dsim_val], 1) * edge_feat, p=1, dim=-1) * merge_sum\n        force_edge_feat = torch.cat((torch.eye(node_feat.size(1)).unsqueeze(0), torch.zeros(node_feat.size(1), node_feat.size(1)).unsqueeze(0)), 0).unsqueeze(0).repeat(node_feat.size(0), 1, 1, 1).to(tt.arg.device)\n        edge_feat = edge_feat + force_edge_feat\n        edge_feat = edge_feat + 1e-6\n        edge_feat = edge_feat / torch.sum(edge_feat, dim=1).unsqueeze(1).repeat(1, 2, 1, 1)\n\n        return edge_feat\n\n\nclass 
GraphNetwork(nn.Module):\n    def __init__(self,\n                 in_features,\n                 node_features,\n                 edge_features,\n                 num_layers,\n                 dropout=0.0):\n        super(GraphNetwork, self).__init__()\n        # set size\n        self.in_features = in_features\n        self.node_features = node_features\n        self.edge_features = edge_features\n        self.num_layers = num_layers\n        self.dropout = dropout\n\n        # for each layer\n        for l in range(self.num_layers):\n            # set edge to node\n            edge2node_net = NodeUpdateNetwork(in_features=self.in_features if l == 0 else self.node_features,\n                                              num_features=self.node_features,\n                                              dropout=self.dropout if l < self.num_layers-1 else 0.0)\n\n            # set node to edge\n            node2edge_net = EdgeUpdateNetwork(in_features=self.node_features,\n                                              num_features=self.edge_features,\n                                              separate_dissimilarity=False,\n                                              dropout=self.dropout if l < self.num_layers-1 else 0.0)\n\n            self.add_module('edge2node_net{}'.format(l), edge2node_net)\n            self.add_module('node2edge_net{}'.format(l), node2edge_net)\n\n    # forward\n    def forward(self, node_feat, edge_feat):\n        # for each layer\n        edge_feat_list = []\n        for l in range(self.num_layers):\n            # (1) edge to node\n            node_feat = self._modules['edge2node_net{}'.format(l)](node_feat, edge_feat)\n\n            # (2) node to edge\n            edge_feat = self._modules['node2edge_net{}'.format(l)](node_feat, edge_feat)\n\n            # save edge feature\n            edge_feat_list.append(edge_feat)\n\n        # if tt.arg.visualization:\n        #     for l in range(self.num_layers):\n        #         ax = 
sns.heatmap(tt.nvar(edge_feat_list[l][0, 0, :, :]), xticklabels=False, yticklabels=False, linewidth=0.1,  cmap=\"coolwarm\",  cbar=False, square=True)\n        #         ax.get_figure().savefig('./visualization/edge_feat_layer{}.png'.format(l))\n\n\n        return edge_feat_list\n\n"
  },
  {
    "path": "torchtools/__init__.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch import optim\nfrom torch import cuda\nfrom torch import utils\nfrom torch.nn import functional as F\nfrom torch.utils.data import *\nfrom torch.distributions import *\nfrom torchtools import tt\n\n\n__author__ = 'namju.kim@kakaobrain.com'\n\n\n# initialize seed\nif tt.arg.seed:\n    np.random.seed(tt.arg.seed)\n    torch.manual_seed(tt.arg.seed)\n"
  },
  {
    "path": "torchtools/_version.py",
    "content": "__version__ = '0.4.0'  # align version with pytorch\n\n"
  },
  {
    "path": "torchtools/tt/__init__.py",
    "content": "from torchtools.tt.arg import _parse_opts\nfrom torchtools.tt.utils import *\nfrom torchtools.tt.layer import *\nfrom torchtools.tt.logger import *\nfrom torchtools.tt.stat import *\nfrom torchtools.tt.trainer import *\n\n\n__author__ = 'namju.kim@kakaobrain.com'\n\n\n# global command line arguments\narg = _parse_opts()\n"
  },
  {
    "path": "torchtools/tt/arg.py",
    "content": "import sys\nimport configparser\nimport torch\nimport threading\nimport time\nimport os\n\n\n__author__ = 'namju.kim@kakaobrain.com'\n\n\n_config_time_stamp = 0\n\n\nclass _Opt(object):\n\n    def __len__(self):\n        return len(self.__dict__)\n\n    def __setitem__(self, key, value):\n        self.__dict__[key] = value\n\n    def __getitem__(self, item):\n        if item in self.__dict__:\n            return self.__dict__[item]\n        else:\n            return None\n\n    def __getattr__(self, item):\n        return self.__getitem__(item)\n\n\ndef _to_py_obj(x):\n    # check boolean first\n    if x.lower() in ['true', 'yes', 'on']:\n        return True\n    if x.lower() in ['false', 'no', 'off']:\n        return False\n    # from string to python object if possible\n    try:\n        obj = eval(x)\n        if type(obj).__name__ in ['int', 'float', 'tuple', 'list', 'dict', 'NoneType']:\n            x = obj\n    except:\n        pass\n    return x\n\n\ndef _parse_config(arg, file):\n\n    # read config file\n    config = configparser.ConfigParser()\n    config.read(file)\n    # traverse sections\n    for section in config.sections():\n        # traverse items\n        opt = _Opt()\n        for key in config[section]:\n            opt[key] = _to_py_obj(config[section][key])\n        # if default section, save items to global scope\n        if section.lower() == 'default':\n            for k, v in opt.__dict__.items():\n                arg[k] = v\n        else:\n            arg['_'.join(section.split())] = opt\n\n\ndef _parse_config_thread(arg, file):\n\n    global _config_time_stamp\n\n    while True:\n        # check timestamp\n        stamp = os.stat(file).st_mtime\n        if not stamp == _config_time_stamp:\n            # update timestamp\n            _config_time_stamp = stamp\n            # parse config file\n            _parse_config(arg, file)\n            # print result\n            # _print_opts(arg, 'CONFIGURATION CHANGE DETECTED')\n   
     # sleep\n        time.sleep(1)\n\n\ndef _print_opts(arg, header):\n    print(header, flush=True)\n    print('-' * 30, flush=True)\n    for k, v in arg.__dict__.items():\n        print('%s=%s' % (k, v), flush=True)\n    print('-' * 30, flush=True)\n\n\ndef _parse_opts():\n\n    global _config_time_stamp\n\n    # get command line arguments\n    arg = _Opt()\n    argv = sys.argv[1:]\n\n    # check length\n    assert len(argv) % 2 == 0, 'arguments should be paired with the format of --key value'\n\n    # parse args\n    for i in range(0, len(argv), 2):\n\n        # check format\n        assert argv[i].startswith('--'), 'arguments should be paired with the format of --key value'\n\n        # save argument\n        arg[argv[i][2:]] = _to_py_obj(argv[i + 1])\n\n        # check config file\n        if argv[i][2:].lower() == 'config':\n            _parse_config(arg, argv[i + 1])\n            _config_time_stamp = os.stat(argv[i + 1]).st_mtime\n\n    #\n    # inject default options\n    #\n\n    # device setting\n    if arg.device is None:\n        arg.device = 'cuda' if torch.cuda.is_available() else 'cpu'\n    arg.device = torch.device(arg.device)\n    arg.cuda = arg.device.type == 'cuda'\n\n    # default learning rate\n    #arg.lr = 1e-3\n\n    # directories\n    arg.log_dir = arg.log_dir or 'asset/log/'\n    arg.data_dir = arg.data_dir or 'asset/data/'\n    arg.save_dir = arg.save_dir or 'asset/train/'\n    arg.log_dir += '' if arg.log_dir.endswith('/') else '/'\n    arg.data_dir += '' if arg.data_dir.endswith('/') else '/'\n    arg.save_dir += '' if arg.save_dir.endswith('/') else '/'\n\n    # print arg option\n    # _print_opts(arg, 'CONFIGURATION')\n\n    # start config file watcher if config is defined\n    if arg.config:\n        t = threading.Thread(target=_parse_config_thread, args=(arg, arg.config))\n        t.daemon = True\n        t.start()\n\n    return arg\n"
  },
  {
    "path": "torchtools/tt/layer.py",
    "content": "from torchtools import nn\n\n\n#\n# Reshape layer for Sequential or ModuleList\n#\nclass Reshape(nn.Module):\n\n    def __init__(self, *shape):\n        super(Reshape, self).__init__()\n        self.shape = shape\n\n    def forward(self, x):\n        return x.reshape(self.shape)\n\n    def extra_repr(self):\n        return 'shape={}'.format(self.shape)"
  },
  {
    "path": "torchtools/tt/logger.py",
    "content": "import datetime\nimport time\nfrom tensorboardX import SummaryWriter\nfrom torchtools import tt\n\n\n__author__ = 'namju.kim@kakaobrain.com'\n\n\n# tensorboard writer\n_writer = None\n_stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist = {}, {}, {}, {}, {}\n\n# time stamp\n_last_logged = time.time()\n\n\n# general print wrapper\ndef log(*args):\n    print(*args, flush=True)\n    # save to log_file\n    if tt.arg.log_file:\n        with open(tt.arg.log_dir + tt.arg.log_file, 'a') as f:\n            print(*args, flush=True, file=f)\n\n\n# tensor board writer\ndef _get_writer():\n    global _writer\n    if _writer is None:\n        # logging directory\n        tf_log_dir = tt.arg.log_dir\n        tf_log_dir += '' if tf_log_dir.endswith('/') else '/'\n        if tt.arg.experiment:\n            tf_log_dir += tt.arg.experiment\n        tf_log_dir += datetime.datetime.now().strftime('-%Y%m%d-%H%M%S')\n        # create writer\n        _writer = SummaryWriter(tf_log_dir)\n    return _writer\n\n\ndef log_scalar(tag, value, global_step=None):\n    _stats_scalar[tag] = (tt.nvar(value), global_step)\n\n\ndef log_audio(tag, audio, global_step=None):\n    _stats_audio[tag] = (tt.nvar(audio), global_step)\n\n\ndef log_image(tag, image, global_step=None):\n    _stats_image[tag] = (tt.nvar(image), global_step)\n\n\ndef log_text(tag, text, global_step=None):\n    _stats_text[tag] = (text, global_step)\n\n\ndef log_hist(tag, values, global_step=None):\n    _stats_hist[tag] = (tt.nvar(values), global_step)\n\n\ndef log_step(epoch=None, global_step=None, max_epoch=None, max_step=None):\n\n    global _last_logged, _last_logged_step, _stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist\n\n    # logging\n    if (tt.arg.log_interval is None and tt.arg.log_step is None) or \\\n       (tt.arg.log_interval and time.time() - _last_logged >= tt.arg.log_interval) or \\\n       (tt.arg.log_step and global_step % tt.arg.log_step == 0):\n\n        # 
update logging time stamp\n        _last_logged = time.time()\n        _last_logged_step = global_step\n\n        # console output string\n        console_out = ''\n        if epoch:\n            console_out += 'ep: %d' % epoch\n            if max_epoch:\n                console_out += '/%d' % max_epoch\n        if global_step:\n            if max_step:\n                step = global_step % max_step\n                step = max_step if step == 0 else step\n                console_out += ' step: %d/%d' % (step, max_step)\n            else:\n                console_out += ' step: %d' % global_step\n\n        # add stats to tensor board\n        for k, v in _stats_scalar.items():\n            _get_writer().add_scalar(k, *v)\n            # add to console output\n            if not k.startswith('weight/') and not k.startswith('gradient/'):\n                console_out += ' %s: %f' % (k, v[0])\n        for k, v in _stats_image.items():\n            _get_writer().add_image(k, *v)\n        for k, v in _stats_audio.items():\n            _get_writer().add_audio(k, *v)\n        for k, v in _stats_text.items():\n            _get_writer().add_text(k, *v)\n        for k, v in _stats_hist.items():\n            _get_writer().add_histogram(k, *v, 'auto')\n\n        # flush\n        _get_writer().file_writer.flush()\n\n        # console out\n        if len(console_out) > 0:\n            log(console_out)\n\n        # clear stats\n        _stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist = {}, {}, {}, {}, {}\n\n\ndef log_weight(model, global_step=None):\n    # weight statics\n    if tt.arg.log_weight:\n        for k, v in model.named_parameters():\n            if 'weight' in k:  # only for weight not bias\n                log_scalar('weight/' + k, v.norm(), global_step)\n\n\ndef log_gradient(model, global_step=None):\n    # gradient statics\n    if tt.arg.log_grad:\n        for k, v in model.named_parameters():\n            if 'weight' in k:  # only for weight not bias\n               
 if v.grad is not None:\n                    log_scalar('gradient/' + k, v.grad.norm(), global_step)\n"
  },
  {
    "path": "torchtools/tt/stat.py",
    "content": "from torchtools import tt\n\n\n__author__ = 'namju.kim@kakaobrain.com'\n\n\ndef accuracy(prob, label, ignore_index=-100):\n\n    # argmax\n    pred = prob.max(1)[1].type_as(label)\n\n    # masking\n    mask = label.ne(ignore_index)\n    pred = pred.masked_select(mask)\n    label = label.masked_select(mask)\n\n    # calc accuracy\n    hit = tt.nvar(pred.eq(label).long().sum())\n    acc = hit / label.size(0)\n    return acc\n"
  },
  {
    "path": "torchtools/tt/trainer.py",
    "content": "from torchtools import nn, optim, tt\n\n\n__author__ = 'namju.kim@kakaobrain.com'\n\n\nclass SupervisedTrainer(object):\n\n    def __init__(self, model, data_loader, optimizer=None, criterion=None):\n        self.global_step = 0\n        self.model = model.to(tt.arg.device)\n        self.data_loader = data_loader\n        self.optimizer = optimizer or optim.Adam(model.parameters())\n        self.criterion = criterion or nn.CrossEntropyLoss()\n\n    def train(self, inputs):\n\n        # split inputs\n        x, y = inputs\n\n        # forward\n        if tt.arg.cuda:\n            z = nn.DataParallel(self.model)(x)\n        else:\n            z = self.model(x)\n\n        # loss\n        loss = self.criterion(z, y)\n\n        # accuracy\n        acc = tt.accuracy(z, y)\n\n        # update model\n        self.optimizer.zero_grad()\n        loss.backward()\n        self.optimizer.step()\n\n        # logging\n        tt.log_scalar('loss', loss, self.global_step)\n        tt.log_scalar('acc', acc, self.global_step)\n\n    def epoch(self, ep_no=None):\n        pass\n\n    def run(self):\n\n        # experiment name\n        tt.arg.experiment = tt.arg.experiment or self.model.__class__.__name__.lower()\n\n        # load model\n        self.global_step = self.model.load_model()\n        epoch, min_step = divmod(self.global_step, len(self.data_loader))\n\n        # epochs\n        while epoch < (tt.arg.epoch or 1):\n            epoch += 1\n\n            # iterations\n            for step, inputs in enumerate(self.data_loader, min_step + 1):\n\n                # check step counter\n                if step > len(self.data_loader):\n                    break\n\n                # increase global step count\n                self.global_step += 1\n\n                # update learning rate\n                for param_group in self.optimizer.param_groups:\n                    param_group['lr'] = tt.arg.lr\n\n                # call train func\n                if 
type(inputs) in [list, tuple]:\n                    self.train([tt.var(d) for d in inputs])\n                else:\n                    self.train(tt.var(inputs))\n\n                # logging\n                tt.log_weight(self.model, global_step=self.global_step)\n                tt.log_gradient(self.model, global_step=self.global_step)\n                tt.log_step(epoch=epoch, global_step=self.global_step,\n                            max_epoch=(tt.arg.epoch or 1), max_step=len(self.data_loader))\n\n                # save model\n                self.model.save_model(self.global_step)\n\n            # epoch handler\n            self.epoch(epoch)\n\n        # save final model\n        self.model.save_model(self.global_step, force=True)\n"
  },
  {
    "path": "torchtools/tt/utils.py",
    "content": "import os\nimport datetime\nimport time\nimport pathlib\nfrom torchtools import torch, nn, tt\n\n\n__author__ = 'namju.kim@kakaobrain.com'\n\n\n# time stamp\n_tic_start = _last_saved = _last_archived = time.time()\n# best statics\n_best = -100000000.\n\n\ndef tic():\n    global _tic_start\n    _tic_start = time.time()\n    return _tic_start\n\n\ndef toc(tic=None):\n    global _tic_start\n    if tic is None:\n        return time.time() - _tic_start\n    else:\n        return time.time() - tic\n\n\ndef sleep(seconds):\n    time.sleep(seconds)\n\n\n#\n# automatic device-aware torch.tensor\n#\ndef var(data, dtype=None, device=None, requires_grad=False):\n    # return torch.tensor(data, dtype=dtype, device=(device or tt.arg.device), requires_grad=requires_grad)\n    # the upper code doesn't work, so work around as following. ( maybe bug )\n    return torch.tensor(data, dtype=dtype, requires_grad=requires_grad).to((device or tt.arg.device))\n\n\ndef vars(x_list, dtype=None, device=None, requires_grad=False):\n    return [var(x, dtype, device, requires_grad) for x in x_list]\n\n\n# for old torchtools compatibility\ndef cvar(x):\n    return x.detach()\n\n\n#\n# to python or numpy variable(s)\n#\ndef nvar(x):\n    if isinstance(x, torch.Tensor):\n        x = x.detach().cpu()\n        x = x.item() if x.dim() == 0 else x.numpy()\n    return x\n\n\ndef nvars(x_list):\n    return [nvar(x) for x in x_list]\n\n\ndef load_model(model, best=False, postfix=None, experiment=None):\n    global _best\n\n    # model file name\n    filename = tt.arg.save_dir + '%s.pt' % (experiment or tt.arg.experiment or model.__class__.__name__.lower())\n    if postfix is not None:\n        filename = filename + '.%s' % postfix\n\n    # load model\n    global_step = 0\n    if os.path.exists(filename):\n        if best:\n            global_step, model_state, _best = torch.load(filename + '.best', map_location=lambda storage, loc: storage)\n        else:\n            global_step, 
model_state = torch.load(filename, map_location=lambda storage, loc: storage)\n        model.load_state_dict(model_state)\n\n    # update best stat\n    filename += '.best'\n    if os.path.exists(filename):\n        _, _, _best = torch.load(filename, map_location=lambda storage, loc: storage)\n\n    return global_step\n\n\ndef save_model(model, global_step, force=False, best=None, postfix=None):\n    global _last_saved, _last_archived, _best\n\n    # make directory\n    pathlib.Path(tt.arg.save_dir).mkdir(parents=True, exist_ok=True)\n\n    # filename to save\n    filename = '%s.pt' % (tt.arg.experiment or model.__class__.__name__.lower())\n    if postfix is not None:\n        filename = filename + '.%s' % postfix\n\n    # save model\n    if force or (tt.arg.save_interval and time.time() - _last_saved >= tt.arg.save_interval) or \\\n       (tt.arg.save_step and global_step % tt.arg.save_step == 0):\n        torch.save((global_step, model.state_dict()), tt.arg.save_dir + filename)\n        _last_saved = time.time()\n\n    # archive model\n    if (tt.arg.archive_interval and time.time() - _last_archived >= tt.arg.archive_interval) or \\\n       (tt.arg.archive_step and global_step % tt.arg.archive_step == 0):\n        # filename to archive\n        if tt.arg.archive_interval:\n            filename = filename + datetime.datetime.now().strftime('.%Y%m%d.%H%M%S')\n        else:\n            filename = filename + '.%d' % global_step\n        torch.save((global_step, model.state_dict()), tt.arg.save_dir + filename)\n        _last_archived = time.time()\n\n    # save best model\n    if best is not None and best > _best:\n        _best = best\n        filename = filename + '.best'\n        torch.save((global_step, model.state_dict(), best), tt.arg.save_dir + filename)\n\n\n# patch Module\nnn.Module.load_model = load_model\nnn.Module.save_model = save_model\n"
  },
  {
    "path": "train.py",
    "content": "from torchtools import *\nfrom data import MiniImagenetLoader, TieredImagenetLoader\nfrom model import EmbeddingImagenet, GraphNetwork, ConvNet\nimport shutil\nimport os\nimport random\n#import seaborn as sns\n\n\nclass ModelTrainer(object):\n    def __init__(self,\n                 enc_module,\n                 gnn_module,\n                 data_loader):\n        # set encoder and gnn\n        self.enc_module = enc_module.to(tt.arg.device)\n        self.gnn_module = gnn_module.to(tt.arg.device)\n\n        if tt.arg.num_gpus > 1:\n            print('Construct multi-gpu model ...')\n            self.enc_module = nn.DataParallel(self.enc_module, device_ids=[0, 1, 2, 3], dim=0)\n            self.gnn_module = nn.DataParallel(self.gnn_module, device_ids=[0, 1, 2, 3], dim=0)\n\n            print('done!\\n')\n\n        # get data loader\n        self.data_loader = data_loader\n\n        # set optimizer\n        self.module_params = list(self.enc_module.parameters()) + list(self.gnn_module.parameters())\n\n        # set optimizer\n        self.optimizer = optim.Adam(params=self.module_params,\n                                    lr=tt.arg.lr,\n                                    weight_decay=tt.arg.weight_decay)\n\n        # set loss\n        self.edge_loss = nn.BCELoss(reduction='none')\n\n        self.node_loss = nn.CrossEntropyLoss(reduction='none')\n\n        self.global_step = 0\n        self.val_acc = 0\n        self.test_acc = 0\n\n    def train(self):\n        val_acc = self.val_acc\n\n        # set edge mask (to distinguish support and query edges)\n        num_supports = tt.arg.num_ways_train * tt.arg.num_shots_train\n        num_queries = tt.arg.num_ways_train * 1\n        num_samples = num_supports + num_queries\n        support_edge_mask = torch.zeros(tt.arg.meta_batch_size, num_samples, num_samples).to(tt.arg.device)\n        support_edge_mask[:, :num_supports, :num_supports] = 1\n        query_edge_mask = 1 - support_edge_mask\n\n        
evaluation_mask = torch.ones(tt.arg.meta_batch_size, num_samples, num_samples).to(tt.arg.device)\n        # for semi-supervised setting, ignore unlabeled support sets for evaluation\n        for c in range(tt.arg.num_ways_train):\n            evaluation_mask[:,\n            ((c + 1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_train,\n            :num_supports] = 0\n            evaluation_mask[:, :num_supports,\n            ((c + 1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_train] = 0\n\n        # for each iteration\n        for iter in range(self.global_step + 1, tt.arg.train_iteration + 1):\n            # init grad\n            self.optimizer.zero_grad()\n\n            # set current step\n            self.global_step = iter\n\n            # load task data list\n            [support_data,\n             support_label,\n             query_data,\n             query_label] = self.data_loader['train'].get_task_batch(num_tasks=tt.arg.meta_batch_size,\n                                                                     num_ways=tt.arg.num_ways_train,\n                                                                     num_shots=tt.arg.num_shots_train,\n                                                                     seed=iter + tt.arg.seed)\n\n            # set as single data\n            full_data = torch.cat([support_data, query_data], 1)\n            full_label = torch.cat([support_label, query_label], 1)\n            full_edge = self.label2edge(full_label)\n\n            # set init edge\n            init_edge = full_edge.clone()  # batch_size x 2 x num_samples x num_samples\n            init_edge[:, :, num_supports:, :] = 0.5\n            init_edge[:, :, :, num_supports:] = 0.5\n            for i in range(num_queries):\n                init_edge[:, 0, num_supports + i, num_supports + i] = 1.0\n                init_edge[:, 1, num_supports + i, num_supports + i] = 0.0\n\n            # for 
semi-supervised setting,\n            for c in range(tt.arg.num_ways_train):\n                init_edge[:, :, ((c+1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_train, :num_supports] = 0.5\n                init_edge[:, :, :num_supports, ((c+1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_train] = 0.5\n\n            # set as train mode\n            self.enc_module.train()\n            self.gnn_module.train()\n\n            # (1) encode data\n            full_data = [self.enc_module(data.squeeze(1)) for data in full_data.chunk(full_data.size(1), dim=1)]\n            full_data = torch.stack(full_data, dim=1) # batch_size x num_samples x featdim\n\n            # (2) predict edge logit (consider only the last layer logit, num_tasks x 2 x num_samples x num_samples)\n            if tt.arg.train_transductive:\n                full_logit_layers = self.gnn_module(node_feat=full_data, edge_feat=init_edge)\n            else:\n                evaluation_mask[:, num_supports:, num_supports:] = 0 # ignore query-query edges, since it is non-transductive setting\n                # input_node_feat: (batch_size x num_queries) x (num_support + 1) x featdim\n                # input_edge_feat: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)\n                support_data = full_data[:, :num_supports] # batch_size x num_support x featdim\n                query_data = full_data[:, num_supports:] # batch_size x num_query x featdim\n                support_data_tiled = support_data.unsqueeze(1).repeat(1, num_queries, 1, 1) # batch_size x num_queries x num_support x featdim\n                support_data_tiled = support_data_tiled.view(tt.arg.meta_batch_size * num_queries, num_supports, -1) # (batch_size x num_queries) x num_support x featdim\n                query_data_reshaped = query_data.contiguous().view(tt.arg.meta_batch_size * num_queries, -1).unsqueeze(1) # (batch_size x num_queries) x 1 x featdim\n       
         input_node_feat = torch.cat([support_data_tiled, query_data_reshaped], 1) # (batch_size x num_queries) x (num_support + 1) x featdim\n\n                input_edge_feat = 0.5 * torch.ones(tt.arg.meta_batch_size, 2, num_supports + 1, num_supports + 1).to(tt.arg.device) # batch_size x 2 x (num_support + 1) x (num_support + 1)\n\n                input_edge_feat[:, :, :num_supports, :num_supports] = init_edge[:, :, :num_supports, :num_supports] # batch_size x 2 x (num_support + 1) x (num_support + 1)\n                input_edge_feat = input_edge_feat.repeat(num_queries, 1, 1, 1) #(batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)\n\n                # logit: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)\n                logit_layers = self.gnn_module(node_feat=input_node_feat, edge_feat=input_edge_feat)\n\n                logit_layers = [logit_layer.view(tt.arg.meta_batch_size, num_queries, 2, num_supports + 1, num_supports + 1) for logit_layer in logit_layers]\n\n                # logit --> full_logit (batch_size x 2 x num_samples x num_samples)\n                full_logit_layers = []\n                for l in range(tt.arg.num_layers):\n                    full_logit_layers.append(torch.zeros(tt.arg.meta_batch_size, 2, num_samples, num_samples).to(tt.arg.device))\n\n                for l in range(tt.arg.num_layers):\n                    full_logit_layers[l][:, :, :num_supports, :num_supports] = logit_layers[l][:, :, :, :num_supports, :num_supports].mean(1)\n                    full_logit_layers[l][:, :, :num_supports, num_supports:] = logit_layers[l][:, :, :, :num_supports, -1].transpose(1, 2).transpose(2, 3)\n                    full_logit_layers[l][:, :, num_supports:, :num_supports] = logit_layers[l][:, :, :, -1, :num_supports].transpose(1, 2)\n\n            # (4) compute loss\n            full_edge_loss_layers = [self.edge_loss((1-full_logit_layer[:, 0]), (1-full_edge[:, 0])) for full_logit_layer in 
full_logit_layers]\n\n            # weighted edge loss for balancing pos/neg\n            pos_query_edge_loss_layers = [torch.sum(full_edge_loss_layer * query_edge_mask * full_edge[:, 0] * evaluation_mask) / torch.sum(query_edge_mask * full_edge[:, 0] * evaluation_mask) for full_edge_loss_layer in full_edge_loss_layers]\n            neg_query_edge_loss_layers = [torch.sum(full_edge_loss_layer * query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask) / torch.sum(query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask) for full_edge_loss_layer in full_edge_loss_layers]\n            query_edge_loss_layers = [pos_query_edge_loss_layer + neg_query_edge_loss_layer for (pos_query_edge_loss_layer, neg_query_edge_loss_layer) in zip(pos_query_edge_loss_layers, neg_query_edge_loss_layers)]\n\n            # compute accuracy\n            full_edge_accr_layers = [self.hit(full_logit_layer, 1-full_edge[:, 0].long()) for full_logit_layer in full_logit_layers]\n            query_edge_accr_layers = [torch.sum(full_edge_accr_layer * query_edge_mask * evaluation_mask) / torch.sum(query_edge_mask * evaluation_mask) for full_edge_accr_layer in full_edge_accr_layers]\n\n            # compute node loss & accuracy (num_tasks x num_quries x num_ways)\n            query_node_pred_layers = [torch.bmm(full_logit_layer[:, 0, num_supports:, :num_supports], self.one_hot_encode(tt.arg.num_ways_train, support_label.long())) for full_logit_layer in full_logit_layers] # (num_tasks x num_quries x num_supports) * (num_tasks x num_supports x num_ways)\n            query_node_accr_layers = [torch.eq(torch.max(query_node_pred_layer, -1)[1], query_label.long()).float().mean() for query_node_pred_layer in query_node_pred_layers]\n\n            total_loss_layers = query_edge_loss_layers\n\n            # update model\n            total_loss = []\n            for l in range(tt.arg.num_layers - 1):\n                total_loss += [total_loss_layers[l].view(-1) * 0.5]\n            total_loss += 
[total_loss_layers[-1].view(-1) * 1.0]\n            total_loss = torch.mean(torch.cat(total_loss, 0))\n\n            total_loss.backward()\n\n            self.optimizer.step()\n\n            # adjust learning rate\n            self.adjust_learning_rate(optimizers=[self.optimizer],\n                                      lr=tt.arg.lr,\n                                      iter=self.global_step)\n\n            # logging\n            tt.log_scalar('train/edge_loss', query_edge_loss_layers[-1], self.global_step)\n            tt.log_scalar('train/edge_accr', query_edge_accr_layers[-1], self.global_step)\n            tt.log_scalar('train/node_accr', query_node_accr_layers[-1], self.global_step)\n\n            # evaluation\n            if self.global_step % tt.arg.test_interval == 0:\n                val_acc = self.eval(partition='val')\n\n                is_best = 0\n\n                if val_acc >= self.val_acc:\n                    self.val_acc = val_acc\n                    is_best = 1\n\n                tt.log_scalar('val/best_accr', self.val_acc, self.global_step)\n\n                self.save_checkpoint({\n                    'iteration': self.global_step,\n                    'enc_module_state_dict': self.enc_module.state_dict(),\n                    'gnn_module_state_dict': self.gnn_module.state_dict(),\n                    'val_acc': val_acc,\n                    'optimizer': self.optimizer.state_dict(),\n                    }, is_best)\n\n            tt.log_step(global_step=self.global_step)\n\n    def eval(self, partition='test', log_flag=True):\n        best_acc = 0\n        # set edge mask (to distinguish support and query edges)\n        num_supports = tt.arg.num_ways_test * tt.arg.num_shots_test\n        num_queries = tt.arg.num_ways_test * 1\n        num_samples = num_supports + num_queries\n        support_edge_mask = torch.zeros(tt.arg.test_batch_size, num_samples, num_samples).to(tt.arg.device)\n        support_edge_mask[:, :num_supports, :num_supports] 
= 1\n        query_edge_mask = 1 - support_edge_mask\n        evaluation_mask = torch.ones(tt.arg.test_batch_size, num_samples, num_samples).to(tt.arg.device)\n        # for semi-supervised setting, ignore unlabeled support sets for evaluation\n        for c in range(tt.arg.num_ways_test):\n            evaluation_mask[:,\n            ((c + 1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_test,\n            :num_supports] = 0\n            evaluation_mask[:, :num_supports,\n            ((c + 1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_test] = 0\n\n        query_edge_losses = []\n        query_edge_accrs = []\n        query_node_accrs = []\n\n        # for each iteration\n        for iter in range(tt.arg.test_iteration//tt.arg.test_batch_size):\n            # load task data list\n            [support_data,\n             support_label,\n             query_data,\n             query_label] = self.data_loader[partition].get_task_batch(num_tasks=tt.arg.test_batch_size,\n                                                                       num_ways=tt.arg.num_ways_test,\n                                                                       num_shots=tt.arg.num_shots_test,\n                                                                       seed=iter)\n\n            # set as single data\n            full_data = torch.cat([support_data, query_data], 1)\n            full_label = torch.cat([support_label, query_label], 1)\n            full_edge = self.label2edge(full_label)\n\n            # set init edge\n            init_edge = full_edge.clone()\n            init_edge[:, :, num_supports:, :] = 0.5\n            init_edge[:, :, :, num_supports:] = 0.5\n            for i in range(num_queries):\n                init_edge[:, 0, num_supports + i, num_supports + i] = 1.0\n                init_edge[:, 1, num_supports + i, num_supports + i] = 0.0\n\n            # for semi-supervised setting,\n            for c in 
range(tt.arg.num_ways_test):\n                init_edge[:, :, ((c+1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_test, :num_supports] = 0.5\n                init_edge[:, :, :num_supports, ((c+1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_test] = 0.5\n\n            # set as train mode\n            self.enc_module.eval()\n            self.gnn_module.eval()\n\n            # (1) encode data\n            full_data = [self.enc_module(data.squeeze(1)) for data in full_data.chunk(full_data.size(1), dim=1)]\n            full_data = torch.stack(full_data, dim=1)\n\n            # (2) predict edge logit (consider only the last layer logit, num_tasks x 2 x num_samples x num_samples)\n            if tt.arg.test_transductive:\n                full_logit_all = self.gnn_module(node_feat=full_data, edge_feat=init_edge)\n                full_logit = full_logit_all[-1]\n            else:\n                evaluation_mask[:, num_supports:, num_supports:] = 0  # ignore query-query edges, since it is non-transductive setting\n\n                full_logit = torch.zeros(tt.arg.test_batch_size, 2, num_samples, num_samples).to(tt.arg.device)\n\n                # input_node_feat: (batch_size x num_queries) x (num_support + 1) x featdim\n                # input_edge_feat: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)\n                support_data = full_data[:, :num_supports] # batch_size x num_support x featdim\n                query_data = full_data[:, num_supports:] # batch_size x num_query x featdim\n                support_data_tiled = support_data.unsqueeze(1).repeat(1, num_queries, 1, 1) # batch_size x num_queries x num_support x featdim\n                support_data_tiled = support_data_tiled.view(tt.arg.test_batch_size * num_queries, num_supports, -1) # (batch_size x num_queries) x num_support x featdim\n                query_data_reshaped = query_data.contiguous().view(tt.arg.test_batch_size * 
num_queries, -1).unsqueeze(1) # (batch_size x num_queries) x 1 x featdim\n                input_node_feat = torch.cat([support_data_tiled, query_data_reshaped], 1) # (batch_size x num_queries) x (num_support + 1) x featdim\n\n                input_edge_feat = 0.5 * torch.ones(tt.arg.test_batch_size, 2, num_supports + 1, num_supports + 1).to(tt.arg.device)  # batch_size x 2 x (num_support + 1) x (num_support + 1)\n\n                input_edge_feat[:, :, :num_supports, :num_supports] = init_edge[:, :, :num_supports, :num_supports]  # batch_size x 2 x (num_support + 1) x (num_support + 1)\n                input_edge_feat = input_edge_feat.repeat(num_queries, 1, 1, 1)  # (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)\n\n                # logit: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)\n                logit = self.gnn_module(node_feat=input_node_feat, edge_feat=input_edge_feat)[-1]\n\n                logit = logit.view(tt.arg.test_batch_size, num_queries, 2, num_supports + 1, num_supports + 1)\n\n                # batch_size x num_queries x 2 x (num_support + 1) x (num_support + 1)\n                # logit --> full_logit (batch_size x 2 x num_samples x num_samples)\n                full_logit[:, :, :num_supports, :num_supports] = logit[:, :, :, :num_supports, :num_supports].mean(1)\n                full_logit[:, :, :num_supports, num_supports:] = logit[:, :, :, :num_supports, -1].transpose(1, 2).transpose(2, 3)\n                full_logit[:, :, num_supports:, :num_supports] = logit[:, :, :, -1, :num_supports].transpose(1, 2)\n\n            # (4) compute loss\n            full_edge_loss = self.edge_loss(1-full_logit[:, 0], 1-full_edge[:, 0])\n\n            query_edge_loss =  torch.sum(full_edge_loss * query_edge_mask * evaluation_mask) / torch.sum(query_edge_mask * evaluation_mask)\n\n            # weighted loss for balancing pos/neg\n            pos_query_edge_loss = torch.sum(full_edge_loss * query_edge_mask * 
full_edge[:, 0] * evaluation_mask) / torch.sum(query_edge_mask * full_edge[:, 0] * evaluation_mask)\n            neg_query_edge_loss = torch.sum(full_edge_loss * query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask) / torch.sum(query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask)\n            query_edge_loss = pos_query_edge_loss + neg_query_edge_loss\n\n            # compute accuracy\n            full_edge_accr = self.hit(full_logit, 1-full_edge[:, 0].long())\n            query_edge_accr = torch.sum(full_edge_accr * query_edge_mask * evaluation_mask) / torch.sum(query_edge_mask * evaluation_mask)\n\n            # compute node accuracy (num_tasks x num_quries x num_ways)\n            query_node_pred = torch.bmm(full_logit[:, 0, num_supports:, :num_supports], self.one_hot_encode(tt.arg.num_ways_test, support_label.long())) # (num_tasks x num_quries x num_supports) * (num_tasks x num_supports x num_ways)\n            query_node_accr = torch.eq(torch.max(query_node_pred, -1)[1], query_label.long()).float().mean()\n\n            query_edge_losses += [query_edge_loss.item()]\n            query_edge_accrs += [query_edge_accr.item()]\n            query_node_accrs += [query_node_accr.item()]\n\n        # logging\n        if log_flag:\n            tt.log('---------------------------')\n            tt.log_scalar('{}/edge_loss'.format(partition), np.array(query_edge_losses).mean(), self.global_step)\n            tt.log_scalar('{}/edge_accr'.format(partition), np.array(query_edge_accrs).mean(), self.global_step)\n            tt.log_scalar('{}/node_accr'.format(partition), np.array(query_node_accrs).mean(), self.global_step)\n\n            tt.log('evaluation: total_count=%d, accuracy: mean=%.2f%%, std=%.2f%%, ci95=%.2f%%' %\n                   (iter,\n                    np.array(query_node_accrs).mean() * 100,\n                    np.array(query_node_accrs).std() * 100,\n                    1.96 * np.array(query_node_accrs).std() / 
np.sqrt(float(len(np.array(query_node_accrs)))) * 100))\n            tt.log('---------------------------')\n\n        return np.array(query_node_accrs).mean()\n\n    def adjust_learning_rate(self, optimizers, lr, iter):\n        new_lr = lr * (0.5 ** (int(iter / tt.arg.dec_lr)))\n\n        for optimizer in optimizers:\n            for param_group in optimizer.param_groups:\n                param_group['lr'] = new_lr\n\n    def label2edge(self, label):\n        # get size\n        num_samples = label.size(1)\n\n        # reshape\n        label_i = label.unsqueeze(-1).repeat(1, 1, num_samples)\n        label_j = label_i.transpose(1, 2)\n\n        # compute edge\n        edge = torch.eq(label_i, label_j).float().to(tt.arg.device)\n\n        # expand\n        edge = edge.unsqueeze(1)\n        edge = torch.cat([edge, 1 - edge], 1)\n        return edge\n\n    def hit(self, logit, label):\n        pred = logit.max(1)[1]\n        hit = torch.eq(pred, label).float()\n        return hit\n\n    def one_hot_encode(self, num_classes, class_idx):\n        return torch.eye(num_classes)[class_idx].to(tt.arg.device)\n\n    def save_checkpoint(self, state, is_best):\n        torch.save(state, 'asset/checkpoints/{}/'.format(tt.arg.experiment) + 'checkpoint.pth.tar')\n        if is_best:\n            shutil.copyfile('asset/checkpoints/{}/'.format(tt.arg.experiment) + 'checkpoint.pth.tar',\n                            'asset/checkpoints/{}/'.format(tt.arg.experiment) + 'model_best.pth.tar')\n\ndef set_exp_name():\n    exp_name = 'D-{}'.format(tt.arg.dataset)\n    exp_name += '_N-{}_K-{}_U-{}'.format(tt.arg.num_ways, tt.arg.num_shots, tt.arg.num_unlabeled)\n    exp_name += '_L-{}_B-{}'.format(tt.arg.num_layers, tt.arg.meta_batch_size)\n    exp_name += '_T-{}'.format(tt.arg.transductive)\n    exp_name += '_SEED-{}'.format(tt.arg.seed)\n\n    return exp_name\n\nif __name__ == '__main__':\n\n    tt.arg.device = 'cuda:0' if tt.arg.device is None else tt.arg.device\n    # replace 
dataset_root with your own\n    tt.arg.dataset_root = '/data/private/dataset'\n    tt.arg.dataset = 'mini' if tt.arg.dataset is None else tt.arg.dataset\n    tt.arg.num_ways = 5 if tt.arg.num_ways is None else tt.arg.num_ways\n    tt.arg.num_shots = 1 if tt.arg.num_shots is None else tt.arg.num_shots\n    tt.arg.num_unlabeled = 0 if tt.arg.num_unlabeled is None else tt.arg.num_unlabeled\n    tt.arg.num_layers = 3 if tt.arg.num_layers is None else tt.arg.num_layers\n    tt.arg.meta_batch_size = 40 if tt.arg.meta_batch_size is None else tt.arg.meta_batch_size\n    tt.arg.transductive = False if tt.arg.transductive is None else tt.arg.transductive\n    tt.arg.seed = 222 if tt.arg.seed is None else tt.arg.seed\n    tt.arg.num_gpus = 1 if tt.arg.num_gpus is None else tt.arg.num_gpus\n\n    tt.arg.num_ways_train = tt.arg.num_ways\n    tt.arg.num_ways_test = tt.arg.num_ways\n\n    tt.arg.num_shots_train = tt.arg.num_shots\n    tt.arg.num_shots_test = tt.arg.num_shots\n\n    tt.arg.train_transductive = tt.arg.transductive\n    tt.arg.test_transductive = tt.arg.transductive\n\n    # model parameter related\n    tt.arg.num_edge_features = 96\n    tt.arg.num_node_features = 96\n    tt.arg.emb_size = 128\n\n    # train, test parameters\n    tt.arg.train_iteration = 100000 if tt.arg.dataset == 'mini' else 200000\n    tt.arg.test_iteration = 10000\n    tt.arg.test_interval = 5000 if tt.arg.test_interval is None else tt.arg.test_interval\n    tt.arg.test_batch_size = 10\n    tt.arg.log_step = 1000 if tt.arg.log_step is None else tt.arg.log_step\n\n    tt.arg.lr = 1e-3\n    tt.arg.grad_clip = 5\n    tt.arg.weight_decay = 1e-6\n    tt.arg.dec_lr = 15000 if tt.arg.dataset == 'mini' else 30000\n    tt.arg.dropout = 0.1 if tt.arg.dataset == 'mini' else 0.0\n\n    tt.arg.experiment = set_exp_name() if tt.arg.experiment is None else tt.arg.experiment\n\n    print(set_exp_name())\n\n    #set random seed\n    np.random.seed(tt.arg.seed)\n    torch.manual_seed(tt.arg.seed)\n    
torch.cuda.manual_seed_all(tt.arg.seed)\n    random.seed(tt.arg.seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n    tt.arg.log_dir_user = tt.arg.log_dir if tt.arg.log_dir_user is None else tt.arg.log_dir_user\n    tt.arg.log_dir = tt.arg.log_dir_user\n\n    if not os.path.exists('asset/checkpoints'):\n        os.makedirs('asset/checkpoints')\n    if not os.path.exists('asset/checkpoints/' + tt.arg.experiment):\n        os.makedirs('asset/checkpoints/' + tt.arg.experiment)\n\n\n    enc_module = EmbeddingImagenet(emb_size=tt.arg.emb_size)\n\n    gnn_module = GraphNetwork(in_features=tt.arg.emb_size,\n                              node_features=tt.arg.num_edge_features,\n                              edge_features=tt.arg.num_node_features,\n                              num_layers=tt.arg.num_layers,\n                              dropout=tt.arg.dropout)\n\n    if tt.arg.dataset == 'mini':\n        train_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='train')\n        valid_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='val')\n    elif tt.arg.dataset == 'tiered':\n        train_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='train')\n        valid_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='val')\n    else:\n        print('Unknown dataset!')\n\n    data_loader = {'train': train_loader,\n                   'val': valid_loader\n                   }\n\n    # create trainer\n    trainer = ModelTrainer(enc_module=enc_module,\n                           gnn_module=gnn_module,\n                           data_loader=data_loader)\n\n    trainer.train()\n"
  }
]