Copy disabled (too large)
Download .txt
Showing preview only (15,395K chars total). Download the full file to get everything.
Repository: yue-zhongqi/ifsl
Branch: master
Commit: 3503392a9365
Files: 195
Total size: 14.7 MB
Directory structure:
gitextract_l9qhes0e/
├── LEO/
│ ├── LICENSE
│ ├── config.py
│ ├── data.py
│ ├── ifsl_configs/
│ │ ├── __init__.py
│ │ ├── baseline_config.py
│ │ └── ifsl_config.py
│ ├── model.py
│ ├── model_test.py
│ ├── pretrain/
│ │ ├── miniImagenet_baseline_ResNet10_mean.npy
│ │ ├── miniImagenet_feat_wrn_mean.npy
│ │ ├── miniImagenet_sib_wrn_mean.npy
│ │ ├── miniImagenet_simpleshot_ResNet10_mean.npy
│ │ ├── miniImagenet_simpleshotwide_wideres_mean.npy
│ │ ├── norm_miniImagenet_baseline_ResNet10_mean.npy
│ │ ├── norm_miniImagenet_feat_wrn_mean.npy
│ │ ├── norm_miniImagenet_sib_wrn_mean.npy
│ │ ├── norm_miniImagenet_simpleshot_ResNet10_mean.npy
│ │ ├── norm_miniImagenet_simpleshotwide_wideres_mean.npy
│ │ ├── norm_tiered_simpleshot_ResNet10_mean.npy
│ │ ├── norm_tiered_simpleshotwide_wideres_mean.npy
│ │ ├── tiered_simpleshot_ResNet10_mean.npy
│ │ └── tiered_simpleshotwide_wideres_mean.npy
│ ├── readme.md
│ ├── runner.py
│ └── utils.py
├── MAML_MN_FT/
│ ├── README.md
│ ├── backbone.py
│ ├── configs.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── additional_transforms.py
│ │ ├── datamgr.py
│ │ ├── dataset.py
│ │ └── feature_loader.py
│ ├── filelists/
│ │ ├── CUB/
│ │ │ ├── attributes.txt
│ │ │ ├── base.json
│ │ │ ├── download_CUB.sh
│ │ │ ├── novel.json
│ │ │ ├── val.json
│ │ │ └── write_CUB_filelist.py
│ │ ├── miniImagenet/
│ │ │ ├── all.json
│ │ │ ├── base.json
│ │ │ ├── download_miniImagenet.sh
│ │ │ ├── novel.json
│ │ │ ├── test.csv
│ │ │ ├── train.csv
│ │ │ ├── val.csv
│ │ │ ├── val.json
│ │ │ ├── write_cross_filelist.py
│ │ │ └── write_miniImagenet_filelist.py
│ │ └── tiered/
│ │ └── write_tiered_filelist.py
│ ├── io_utils.py
│ ├── main.py
│ ├── methods/
│ │ ├── DMAML.py
│ │ ├── DMatchingNet.py
│ │ ├── MethodTester.py
│ │ ├── NNEDSplitNew.py
│ │ ├── PretrainedModel.py
│ │ ├── VanillaMAML.py
│ │ ├── VanillaMatchingNet.py
│ │ ├── __init__.py
│ │ ├── meta_template.py
│ │ └── meta_toolkits.py
│ ├── models/
│ │ ├── FeatWRN.py
│ │ ├── SimpleShotResNet.py
│ │ ├── SimpleShotWideResNet.py
│ │ └── __init__.py
│ ├── pretrain/
│ │ ├── miniImagenet_baseline_ResNet10_mean.npy
│ │ ├── miniImagenet_cosine_ResNet10_mean.npy
│ │ ├── miniImagenet_feat_wrn_mean.npy
│ │ ├── miniImagenet_sib_wrn_mean.npy
│ │ ├── miniImagenet_simpleshot_ResNet10_mean.npy
│ │ ├── miniImagenet_simpleshotwide_wideres_mean.npy
│ │ ├── norm_miniImagenet_baseline_ResNet10_mean.npy
│ │ ├── norm_miniImagenet_cosine_ResNet10_mean.npy
│ │ ├── norm_miniImagenet_feat_wrn_mean.npy
│ │ ├── norm_miniImagenet_sib_wrn_mean.npy
│ │ ├── norm_miniImagenet_simpleshot_ResNet10_mean.npy
│ │ ├── norm_miniImagenet_simpleshotwide_wideres_mean.npy
│ │ ├── norm_tiered_simpleshot_ResNet10_mean.npy
│ │ ├── norm_tiered_simpleshotwide_wideres_mean.npy
│ │ ├── tiered_simpleshot_ResNet10_mean.npy
│ │ └── tiered_simpleshotwide_wideres_mean.npy
│ ├── save_features.py
│ ├── tests/
│ │ ├── MetaTrain.py
│ │ └── __init__.py
│ └── utils.py
├── MTL/
│ ├── README.md
│ ├── configs/
│ │ ├── __init__.py
│ │ ├── baseline_config.py
│ │ ├── ifsl_resnet_config.py
│ │ └── ifsl_wrn_config.py
│ ├── dataloader/
│ │ ├── __init__.py
│ │ ├── dataset_loader.py
│ │ └── samplers.py
│ ├── main.py
│ ├── models/
│ │ ├── IFSL.py
│ │ ├── IFSL_modules.py
│ │ ├── IFSL_pretrain.py
│ │ ├── ResNet10.py
│ │ ├── WRN28.py
│ │ ├── __init__.py
│ │ ├── conv2d_mtl.py
│ │ ├── mtl.py
│ │ └── resnet_mtl.py
│ ├── pretrain/
│ │ ├── miniImagenet_baseline_ResNet10_mean.npy
│ │ ├── miniImagenet_feat_wrn_mean.npy
│ │ ├── miniImagenet_sib_wrn_mean.npy
│ │ ├── miniImagenet_simpleshot_ResNet10_mean.npy
│ │ ├── miniImagenet_simpleshotwide_wideres_mean.npy
│ │ ├── norm_miniImagenet_baseline_ResNet10_mean.npy
│ │ ├── norm_miniImagenet_feat_wrn_mean.npy
│ │ ├── norm_miniImagenet_sib_wrn_mean.npy
│ │ ├── norm_miniImagenet_simpleshot_ResNet10_mean.npy
│ │ ├── norm_miniImagenet_simpleshotwide_wideres_mean.npy
│ │ ├── norm_tiered_simpleshot_ResNet10_mean.npy
│ │ ├── norm_tiered_simpleshotwide_wideres_mean.npy
│ │ ├── tiered_simpleshot_ResNet10_mean.npy
│ │ └── tiered_simpleshotwide_wideres_mean.npy
│ ├── run_meta.py
│ ├── run_pre.py
│ ├── run_pre_clfs.py
│ ├── run_test.py
│ ├── setup.cfg
│ ├── trainer/
│ │ ├── __init__.py
│ │ ├── meta.py
│ │ └── pre.py
│ └── utils/
│ ├── __init__.py
│ ├── gpu_tools.py
│ ├── hacc.py
│ └── misc.py
├── SIB/
│ ├── PretrainedModel.py
│ ├── algorithm.py
│ ├── backbone.py
│ ├── config/
│ │ ├── minires_1_baseline.yaml
│ │ ├── minires_1_ifsl.yaml
│ │ ├── minires_5_baseline.yaml
│ │ ├── minires_5_ifsl.yaml
│ │ ├── miniwrn_1_baseline.yaml
│ │ ├── miniwrn_1_ifsl.yaml
│ │ ├── miniwrn_5_baseline.yaml
│ │ ├── miniwrn_5_ifsl.yaml
│ │ ├── tieredres_1_baseline.yaml
│ │ ├── tieredres_1_ifsl.yaml
│ │ ├── tieredres_5_baseline.yaml
│ │ ├── tieredres_5_ifsl.yaml
│ │ ├── tieredwrn_1_baseline.yaml
│ │ ├── tieredwrn_1_ifsl.yaml
│ │ ├── tieredwrn_5_baseline.yaml
│ │ └── tieredwrn_5_ifsl.yaml
│ ├── data/
│ │ ├── __init__.py
│ │ ├── additional_transforms.py
│ │ ├── datamgr.py
│ │ ├── dataset.py
│ │ ├── download_cifarfs.sh
│ │ ├── download_miniimagenet.sh
│ │ ├── feature_loader.py
│ │ └── get_cifarfs.py
│ ├── dataloader.py
│ ├── dataset.py
│ ├── deconfound/
│ │ ├── DSIB.py
│ │ ├── __init__.py
│ │ └── meta_toolkits.py
│ ├── dfsl_configs.py
│ ├── io_utils.py
│ ├── main.py
│ ├── main_feat.py
│ ├── networks.py
│ ├── pretrain/
│ │ ├── miniImagenet_baseline_ResNet10_mean.npy
│ │ ├── miniImagenet_sib_wrn_mean.npy
│ │ ├── miniImagenet_simpleshot_ResNet10_mean.npy
│ │ ├── norm_miniImagenet_baseline_ResNet10_mean.npy
│ │ ├── norm_miniImagenet_feat_wrn_mean.npy
│ │ ├── norm_miniImagenet_sib_wrn_mean.npy
│ │ ├── norm_miniImagenet_simpleshot_ResNet10_mean.npy
│ │ ├── norm_miniImagenet_simpleshotwide_wideres_mean.npy
│ │ ├── norm_tiered_simpleshot_ResNet10_mean.npy
│ │ ├── norm_tiered_simpleshotwide_wideres_mean.npy
│ │ ├── tiered_simpleshot_ResNet10_mean.npy
│ │ └── tiered_simpleshotwide_wideres_mean.npy
│ ├── readme.md
│ ├── requirements.txt
│ ├── setup.cfg
│ ├── sib.py
│ ├── simple_shot_models/
│ │ ├── Conv4.py
│ │ ├── DenseNet.py
│ │ ├── MobileNet.py
│ │ ├── ProtoNet.py
│ │ ├── ResNet.py
│ │ ├── WideResNet.py
│ │ └── __init__.py
│ └── utils/
│ ├── __init__.py
│ ├── config.py
│ ├── outils.py
│ └── utils.py
└── readme.md
================================================
FILE CONTENTS
================================================
================================================
FILE: LEO/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2018 DeepMind Technologies Limited
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: LEO/config.py
================================================
# coding=utf8
# Copyright 2018 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""A module containing just the configs for the different LEO parts."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import os
import shutil
# python runner.py --config=mini_5_resnet_baseline
# python runner.py --config=mini_5_resnet_baseline --evaluation_mode=True
# Global absl flag registry. Every flag below is registered on it at import
# time and read back by the get_*_config() helpers and load_ifsl_config().
# NOTE(review): load_ifsl_config also reads/writes FLAGS.evaluation_mode,
# FLAGS.checkpoint_path and FLAGS.checkpoint_steps, which are not defined in
# this file -- presumably defined by runner.py; verify.
FLAGS = flags.FLAGS

# --- Dataset selection (collected by get_data_config) ------------------------
flags.DEFINE_string("data_path", None, "Path to the dataset.")
flags.DEFINE_string(
    "dataset_name", "miniImageNet", "Name of the dataset to "
    "train on, which will be mapped to data.MetaDataset.")
flags.DEFINE_string(
    "embedding_crop", "center", "Type of the cropping, which "
    "will be mapped to data.EmbeddingCrop.")
flags.DEFINE_boolean("train_on_val", False, "Whether to train on the "
                     "validation data.")

# --- LEO inner model (collected by get_inner_model_config) -------------------
# dropout_rate, kl_weight, encoder_penalty_weight, l2_penalty_weight and
# orthogonality_penalty_weight are overwritten by load_ifsl_config for the
# tuned shot/dataset combinations.
flags.DEFINE_integer(
    "inner_unroll_length", 5, "Number of unroll steps in the "
    "inner loop of leo (number of adaptation steps in the "
    "latent space).")
flags.DEFINE_integer(
    "finetuning_unroll_length", 5, "Number of unroll steps "
    "in the loop performing finetuning (number of adaptation "
    "steps in the parameter space).")
flags.DEFINE_integer("num_latents", 64, "The dimensionality of the latent "
                     "space.")
flags.DEFINE_float(
    "inner_lr_init", 1.0, "The initialization value for the "
    "learning rate of the inner loop of leo.")
flags.DEFINE_float(
    "finetuning_lr_init", 0.001, "The initialization value for "
    "learning rate of the finetuning loop.")
flags.DEFINE_float("dropout_rate", 0.5, "Rate of dropout: probability of "
                   "dropping a given unit.")
flags.DEFINE_float(
    "kl_weight", 1e-3, "The weight measuring importance of the "
    "KL in the final loss. β in the paper.")
flags.DEFINE_float(
    "encoder_penalty_weight", 1e-9, "The weight measuring "
    "importance of the encoder penalty in the final loss. γ in "
    "the paper.")
flags.DEFINE_float("l2_penalty_weight", 1e-8, "The weight measuring the "
                   "importance of the l2 regularization in the final loss. λ₁ "
                   "in the paper.")
flags.DEFINE_float("orthogonality_penalty_weight", 1e-3, "The weight measuring "
                   "the importance of the decoder orthogonality regularization "
                   "in the final loss. λ₂ in the paper.")

# --- N-way K-shot task and outer loop (collected by get_outer_model_config) --
flags.DEFINE_integer(
    "num_classes", 5, "Number of classes, N in N-way classification.")
flags.DEFINE_integer(
    "num_tr_examples_per_class", 1, "Number of training samples per class, "
    "K in K-shot classification.")
flags.DEFINE_integer(
    "num_val_examples_per_class", 15, "Number of validation samples per class "
    "in a task instance.")
flags.DEFINE_integer("metatrain_batch_size", 12, "Number of problem instances "
                     "in a batch.")
flags.DEFINE_integer("metavalid_batch_size", 200, "Number of meta-validation "
                     "problem instances.")
flags.DEFINE_integer("metatest_batch_size", 200, "Number of meta-testing "
                     "problem instances.")
flags.DEFINE_integer("num_steps_limit", int(1e5), "Number of steps to train "
                     "for.")
flags.DEFINE_float("outer_lr", 1e-4, "Outer (metatraining) loop learning "
                   "rate.")
flags.DEFINE_float(
    "gradient_threshold", 0.1, "The cutoff for the gradient "
    "clipping. Gradients will be clipped to "
    "[-gradient_threshold, gradient_threshold]")
flags.DEFINE_float(
    "gradient_norm_threshold", 0.1, "The cutoff for clipping of "
    "the gradient norm. Gradient norm clipping will be applied "
    "after pointwise clipping (described above).")

# --- IFSL / deconfounding options (populated by load_ifsl_config) ------------
flags.DEFINE_string("config", "mini_5_resnet_baseline", "Configuration to use.")
# load_ifsl_config sets this to 512 for ResNet10 and 640 for wideres.
flags.DEFINE_integer("feat_dim", 640, "Feature dimension.")
flags.DEFINE_boolean("deconfound", False, "Whether to deconfound.")
flags.DEFINE_boolean("use_test", False, "Whether to use test.")
flags.DEFINE_boolean("retrain", False, "Whether to discard saved checkpoints and retrain.")
# NOTE: two similarly named class-count flags exist. load_ifsl_config sets
# num_pretrain_classes from the dataset (64 / 351) and pretrain_num_classes
# from config.num_classes (only when deconfounding).
flags.DEFINE_integer("num_pretrain_classes", 64, "Number of classes in pre-train dataset")
flags.DEFINE_string("pretrain_mean_filename", "miniImagenet_simpleshot_ResNet10_mean.npy", "Pretrain mean npy file name")
flags.DEFINE_integer("n_splits", 4, "Number of splits")
flags.DEFINE_boolean("is_cosine_feature", True, "Is it cosine feature")
flags.DEFINE_string("fusion", "concat", "How to fuse feature.")
flags.DEFINE_string("classifier", "single", "Classifier design")
flags.DEFINE_integer("pretrain_num_classes", 64, "Number of classes in pretrain dataset")
flags.DEFINE_string("logit_fusion", "product", "When using bi classifier, the logit fusion function to use")
flags.DEFINE_boolean("use_x_only", False, "Only using X feature")
flags.DEFINE_string("preprocess_before_split", "none", "Preprocessing before split")
flags.DEFINE_string("preprocess_after_split", "none", "Preprocessing after split")
flags.DEFINE_boolean("normalize_before_center", True, "Normalizing feature before centering operation")
flags.DEFINE_boolean("normalize_d", False, "Normalizing d features")
flags.DEFINE_boolean("normalize_ed", False, "Normalizing ed features")
flags.DEFINE_boolean("hacc", False, "Turn on saving hacc for evaluation")
flags.DEFINE_boolean("cross", False, "Evaluating on cross")
def get_data_config():
    """Returns the dataset-related settings as a plain dict.

    Four entries mirror the flags of the same name. ``total_examples_per_class``
    is a fixed constant (600) rather than a flag.
    """
    return {
        "data_path": FLAGS.data_path,
        "dataset_name": FLAGS.dataset_name,
        "embedding_crop": FLAGS.embedding_crop,
        "train_on_val": FLAGS.train_on_val,
        "total_examples_per_class": 600,
    }
def get_inner_model_config():
    """Returns the config used to initialize LEO model.

    Every key is the name of an absl flag, and its value is that flag's current
    value. Keys appear in the order listed below.
    """
    flag_names = (
        "inner_unroll_length",
        "finetuning_unroll_length",
        "num_latents",
        "inner_lr_init",
        "finetuning_lr_init",
        "dropout_rate",
        "kl_weight",
        "encoder_penalty_weight",
        "l2_penalty_weight",
        "orthogonality_penalty_weight",
        "feat_dim",
        "pretrain_mean_filename",
    )
    return {name: getattr(FLAGS, name) for name in flag_names}
def get_outer_model_config():
    """Returns the outer config file for N-way K-shot classification tasks.

    Each key is an absl flag name, mapped to that flag's current value: task
    shape, batch sizes, step limit, outer learning rate and clipping thresholds.
    """
    flag_names = (
        "num_classes",
        "num_tr_examples_per_class",
        "num_val_examples_per_class",
        "metatrain_batch_size",
        "metavalid_batch_size",
        "metatest_batch_size",
        "num_steps_limit",
        "outer_lr",
        "gradient_threshold",
        "gradient_norm_threshold",
    )
    return {name: getattr(FLAGS, name) for name in flag_names}
def load_ifsl_config(config,
                     checkpoint_root="/data2/yuezhongqi/Model/leo/ifsl/",
                     data_root="/data2/yuezhongqi/Model/leo/"):
    """Overwrites the global absl ``FLAGS`` from an IFSL experiment config.

    Args:
      config: experiment config object. Always reads ``dataset``, ``model``,
        ``shot``, ``method``, ``meta_label`` and ``deconfound``. When
        ``deconfound`` is truthy, it also reads the IFSL attributes
        (``n_splits``, ``is_cosine_feature``, ``fusion``, ``classifier``,
        ``num_classes``, ``logit_fusion``, ``use_x_only``,
        ``preprocess_before_split``, ``preprocess_after_split``,
        ``normalize_before_center``, ``normalize_d``, ``normalize_ed``).
        An optional ``outer_lr`` attribute overrides the tuned outer learning
        rate.
      checkpoint_root: directory in which the per-experiment checkpoint
        directory ``<dataset>_<model>_<shot>_<meta_label>`` lives.
      data_root: directory containing the ``<backbone>_noaug_embeddings``
        embedding folders.

    Raises:
      ValueError: if ``config.model`` is neither "ResNet10" nor "wideres".
        The check runs before any flag is modified.

    Side effects:
      Mutates ``FLAGS``. If ``FLAGS.retrain`` is still set after the
      evaluation-mode override (evaluation mode forces it off), the existing
      checkpoint directory is deleted with ``shutil.rmtree``.
    """
    # Backbone name -> (embedding directory prefix, feature dimension).
    backbones = {"ResNet10": ("resnet", 512), "wideres": ("wrn", 640)}
    if config.model not in backbones:
        raise ValueError("Unsupported model {!r}; expected one of {}".format(
            config.model, sorted(backbones)))
    model_abbr, feat_dim = backbones[config.model]

    # dataset_name and number of pretrain classes. Other dataset values leave
    # both flags at their current values.
    if config.dataset == "miniImagenet":
        FLAGS.dataset_name = "miniImageNet"
        FLAGS.num_pretrain_classes = 64
    elif config.dataset == "tiered":
        FLAGS.dataset_name = "tieredImageNet"
        FLAGS.num_pretrain_classes = 351

    # Checkpoint, embedding data and pretrain-mean locations.
    FLAGS.checkpoint_path = os.path.join(
        checkpoint_root,
        config.dataset + "_" + config.model + "_" + str(config.shot) + "_" +
        config.meta_label)
    FLAGS.data_path = os.path.join(data_root, model_abbr + "_noaug_embeddings")
    FLAGS.pretrain_mean_filename = (
        config.dataset + "_" + config.method + "_" + config.model + "_mean.npy")

    FLAGS.num_tr_examples_per_class = config.shot
    # test iter: the default (2000 per the original author) is left unchanged.
    FLAGS.deconfound = config.deconfound
    FLAGS.feat_dim = feat_dim

    # Evaluation mode never retrains, so the checkpoint directory is kept.
    if FLAGS.evaluation_mode:
        FLAGS.checkpoint_steps = 0
        FLAGS.retrain = False

    # Tuned hyperparameters per (shot, dataset). Other combinations keep the
    # flag values already set.
    tuned_hparams = {
        (5, "miniImagenet"): {
            "outer_lr": 4.1024e-4,
            "l2_penalty_weight": 8.54e-9,
            "orthogonality_penalty_weight": 1.523998e-3,
            "dropout_rate": 0.300299,
            "kl_weight": 0.466387,
            "encoder_penalty_weight": 2.661608e-7,
        },
        (1, "miniImagenet"): {
            "outer_lr": 2.739071e-4,
            "l2_penalty_weight": 3.623413e-10,
            "orthogonality_penalty_weight": 0.188103,
            "dropout_rate": 0.307651,
            "kl_weight": 0.756143,
            "encoder_penalty_weight": 5.756821e-6,
        },
        (1, "tiered"): {
            "outer_lr": 8.659053e-4,
            "l2_penalty_weight": 4.148858e-10,
            "orthogonality_penalty_weight": 5.451078e-3,
            "dropout_rate": 0.475126,
            "kl_weight": 2.034189e-3,
            "encoder_penalty_weight": 8.302962e-5,
        },
        (5, "tiered"): {
            "outer_lr": 6.110314e-4,
            "l2_penalty_weight": 1.690399e-10,
            "orthogonality_penalty_weight": 2.481216e-2,
            "dropout_rate": 0.415158,
            "kl_weight": 1.622811,
            "encoder_penalty_weight": 2.672450e-5,
        },
    }
    for name, value in tuned_hparams.get((config.shot, config.dataset), {}).items():
        setattr(FLAGS, name, value)

    # Discard saved checkpoints when retraining.
    if FLAGS.retrain and os.path.isdir(FLAGS.checkpoint_path):
        shutil.rmtree(FLAGS.checkpoint_path)

    # IFSL (deconfounding) parameters are copied only when deconfounding.
    if config.deconfound:
        FLAGS.n_splits = config.n_splits
        FLAGS.is_cosine_feature = config.is_cosine_feature
        FLAGS.fusion = config.fusion
        FLAGS.classifier = config.classifier
        FLAGS.pretrain_num_classes = config.num_classes
        FLAGS.logit_fusion = config.logit_fusion
        FLAGS.use_x_only = config.use_x_only
        FLAGS.preprocess_before_split = config.preprocess_before_split
        FLAGS.preprocess_after_split = config.preprocess_after_split
        FLAGS.normalize_before_center = config.normalize_before_center
        FLAGS.normalize_d = config.normalize_d
        FLAGS.normalize_ed = config.normalize_ed

    # An explicit outer_lr on the config wins over the tuned value above.
    if hasattr(config, "outer_lr"):
        FLAGS.outer_lr = config.outer_lr
================================================
FILE: LEO/data.py
================================================
# Copyright 2018 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Creates problem instances for LEO."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import pickle
import random
import enum
import numpy as np
import six
import tensorflow as tf
# Module-level embedding dimensionality.
# NOTE(review): DataProvider keeps its own per-instance ``self.NDIM`` (set from
# its ``feat_dim`` argument); whether this constant is still read elsewhere is
# not visible here.
NDIM = 512
# One sampled classification problem: training ("tr_*") and validation
# ("val_*") embeddings, labels and file-name info, plus per-image logits
# (presumably the stored "logits" entries indexed by DataProvider -- confirm
# in get_instance). Field order matters for positional construction.
ProblemInstance = collections.namedtuple(
    "ProblemInstance",
    ["tr_input", "tr_output", "tr_info", "val_input", "val_output", "val_info", "tr_logit", "val_logit"])
class StrEnum(enum.Enum):
    """An ``Enum`` whose members print as their underlying value.

    Both ``str(member)`` and ``repr(member)`` return ``member.value``.
    """

    def __str__(self):
        member_value = self.value
        return member_value

    def __repr__(self):
        # Delegate to __str__ so a subclass overriding it changes both forms.
        return self.__str__()
class MetaDataset(StrEnum):
    """Datasets supported by the DataProvider class.

    Values match the ``--dataset_name`` flag. When cross-domain data is not
    used, ``str()`` of the member is also the dataset sub-directory in the
    pickle path.
    """
    MINI = "miniImageNet"
    TIERED = "tieredImageNet"
class EmbeddingCrop(StrEnum):
    """Embedding types supported by the DataProvider class.

    Values match the ``--embedding_crop`` flag. ``str()`` of the member is a
    path component of the pickle location. tieredImageNet only accepts CENTER,
    which ``DataProvider._check_config`` enforces.
    """
    CENTER = "center"
    MULTIVIEW = "multiview"
class MetaSplit(StrEnum):
    """Meta-datasets split supported by the DataProvider class.

    The member's value is formatted into the pickle file name
    ``<value>_embeddings.pkl`` (e.g. ``train_embeddings.pkl``).
    """
    TRAIN = "train"
    VALID = "val"
    TEST = "test"
class DataProvider(object):
"""Creates problem instances from a specific split and dataset."""
def __init__(self, dataset_split, config, verbose=False, feat_dim=640, use_cross=False):
    """Builds a provider over one split of a pre-computed embedding dataset.

    Args:
      dataset_split: split name, converted with ``MetaSplit`` ("train", "val"
        or "test").
      config: dict of data options. ``_check_config`` replaces its
        "dataset_name" and "embedding_crop" entries with enum members in place.
      verbose: whether to emit ``tf.logging`` diagnostics while loading.
      feat_dim: embedding dimensionality, stored as ``self.NDIM``.
      use_cross: if True, pickles are read from the "cross" sub-directory.
    """
    self._dataset_split = MetaSplit(dataset_split)
    self._config = config
    self._verbose = verbose
    self._check_config()
    self.NDIM = feat_dim
    # Must be set before loading: _get_full_pickle_path reads it.
    self.use_cross = use_cross
    loaded = self._load_data()
    self._index_data(loaded)
def _check_config(self):
    """Converts and validates the enum-valued entries of ``self._config``.

    "dataset_name" and "embedding_crop" are replaced in place by their
    ``MetaDataset`` / ``EmbeddingCrop`` members. An unknown value raises
    ``ValueError`` from the enum lookup. For tieredImageNet, anything but the
    center crop fails an assertion.
    """
    cfg = self._config
    cfg["dataset_name"] = MetaDataset(cfg["dataset_name"])
    cfg["embedding_crop"] = EmbeddingCrop(cfg["embedding_crop"])
    if cfg["dataset_name"] != MetaDataset.TIERED:
        return
    # Only center crops are accepted for tieredImageNet.
    error_message = "embedding_crop: {} not supported for {}".format(
        cfg["embedding_crop"], cfg["dataset_name"])
    assert cfg["embedding_crop"] == EmbeddingCrop.CENTER, error_message
def _load_data(self):
    """Loads this split's embedding pickle into memory and returns it.

    When the split is TRAIN and ``config["train_on_val"]`` is set, the
    validation pickle is loaded too. Each of its entries is concatenated along
    axis 0 onto the training entry with the same key.

    Returns:
      The unpickled dict produced by ``_load``, possibly extended with the
      validation data.
    """
    # Context managers ensure the gfile handles are closed once read.
    with tf.gfile.Open(self._get_full_pickle_path(self._dataset_split), "rb") as pickle_file:
        raw_data = self._load(pickle_file)
    if self._dataset_split == MetaSplit.TRAIN and self._config["train_on_val"]:
        with tf.gfile.Open(self._get_full_pickle_path(MetaSplit.VALID), "rb") as pickle_file:
            valid_data = self._load(pickle_file)
        for key in valid_data:
            if self._verbose:
                tf.logging.info(str([key, raw_data[key].shape]))
            raw_data[key] = np.concatenate([raw_data[key],
                                            valid_data[key]], axis=0)
            if self._verbose:
                tf.logging.info(str([key, raw_data[key].shape]))
    if self._verbose:
        tf.logging.info(
            str([(k, np.shape(v)) for k, v in six.iteritems(raw_data)]))
    return raw_data
def _load(self, opened_file):
    """Unpickles and returns the contents of the already-open ``opened_file``.

    On Python 3 the pickle is read with ``encoding="latin1"``, presumably
    because the files were written under Python 2 (NOTE(review): confirm).
    Security: ``pickle.load`` can execute arbitrary code, so only load trusted
    files.
    """
    if not six.PY2:
        return pickle.load(opened_file, encoding="latin1")  # pylint: disable=unexpected-keyword-arg
    return pickle.load(opened_file)
def _index_data(self, raw_data):
"""Builds an index of images embeddings by class."""
self._all_class_images = collections.OrderedDict()
self._image_embedding = collections.OrderedDict()
self._logit_embedding = collections.OrderedDict()
for i, k in enumerate(raw_data["keys"]):
_, class_label, image_file = k.split("-")
image_file_class_label = image_file.split("_")[0]
# assert class_label == image_file_class_label
self._image_embedding[image_file] = raw_data["embeddings"][i]
self._logit_embedding[image_file] = raw_data["logits"][i]
if class_label not in self._all_class_images:
self._all_class_images[class_label] = []
self._all_class_images[class_label].append(image_file)
self._check_data_index(raw_data)
self._all_class_images = collections.OrderedDict([
(k, np.array(v)) for k, v in six.iteritems(self._all_class_images)
])
if self._verbose:
tf.logging.info(str([len(raw_data), len(self._all_class_images),
len(self._image_embedding)]))
def _check_data_index(self, raw_data):
    """Sanity-checks the freshly built image index against raw_data.

    Raises:
        AssertionError: if the number of indexed images or embedding rows
            differs from the number of keys, or if some class has no images.
    """
    expected = raw_data["keys"].shape[0]
    indexed = len(self._image_embedding)
    assert indexed == expected, "{} != {}".format(indexed, expected)
    embedding_rows = raw_data["embeddings"].shape[0]
    assert embedding_rows == expected, "{} != {}".format(embedding_rows, expected)
    class_names = list(self._all_class_images.keys())
    assert len(class_names) == len(set(class_names)), "no duplicate class names"
    per_class_counts = {len(images) for images in self._all_class_images.values()}
    assert len(per_class_counts) >= 1, (
        "len(image_counts) should have at least one element but "
        "is: {}").format(per_class_counts)
    assert min(per_class_counts) > 0
def _get_full_pickle_path(self, split_name):
    """Returns the path of the embeddings pickle for split_name.

    Layout: <data_path>/<folder>/<embedding_crop>/<split_name>_embeddings.pkl,
    where <folder> is "cross" when self.use_cross is set and the configured
    dataset_name otherwise.
    """
    folder = "cross" if self.use_cross else str(self._config["dataset_name"])
    pickle_path = os.path.join(self._config["data_path"],
                               folder,
                               str(self._config["embedding_crop"]),
                               "{}_embeddings.pkl".format(split_name))
    if self._verbose:
        tf.logging.info("get_one_emb_instance: folder_path: {}".format(
            pickle_path))
    return pickle_path
def get_instance(self, num_classes, tr_size, val_size):
    """Samples a random N-way K-shot classification problem instance.

    Args:
        num_classes: N in N-way classification.
        tr_size: K in K-shot; number of training examples per class.
        val_size: number of validation examples per class.

    Returns:
        A tuple with 8 Tensors with the following shapes:
          - tr_input: (num_classes, tr_size, NDIM): training image embeddings.
          - tr_output: (num_classes, tr_size, 1): training image labels.
          - tr_info: (num_classes, tr_size): training image file names.
          - val_input: (num_classes, val_size, NDIM): validation image
            embeddings.
          - val_output: (num_classes, val_size, 1): validation image labels.
          - val_info: (num_classes, val_size): validation image file names.
          - tr_logits: (num_classes, tr_size, #classes in pretrain): training
            image logits.
          - val_logits: (num_classes, val_size, #classes in pretrain):
            validation image logits.
        The values come from tf.py_func, so no static shapes are attached
        here.
    """

    def _build_one_instance_py():
        """Builds a random problem instance using data from specified classes.

        Runs inside tf.py_func. Classes are drawn with random.shuffle, and
        images per class with np.random.choice (without replacement). Labels
        are 0..num_classes-1 in draw order.
        """
        class_list = list(self._all_class_images.keys())
        sample_count = tr_size + val_size
        shuffled_folders = class_list[:]
        random.shuffle(shuffled_folders)
        shuffled_folders = shuffled_folders[:num_classes]
        error_message = "len(shuffled_folders) {} is not num_classes: {}".format(
            len(shuffled_folders), num_classes)
        assert len(shuffled_folders) == num_classes, error_message
        image_paths = []
        class_ids = []
        embeddings = self._image_embedding
        logits = self._logit_embedding
        for class_id, class_name in enumerate(shuffled_folders):
            all_images = self._all_class_images[class_name]
            # np.random.choice raises ValueError when the class has fewer than
            # sample_count images.
            all_images = np.random.choice(all_images, sample_count, replace=False)
            error_message = "{} == {} failed".format(len(all_images), sample_count)
            assert len(all_images) == sample_count, error_message
            image_paths.append(all_images)
            class_ids.append([[class_id]] * sample_count)
        label_array = np.array(class_ids, dtype=np.int32)
        path_array = np.array(image_paths)
        embedding_array = np.array([[embeddings[image_path]
                                     for image_path in class_paths]
                                    for class_paths in path_array])
        logits_array = np.array([[logits[image_path]
                                  for image_path in class_paths]
                                 for class_paths in path_array])
        if self._verbose:
            # Each shape is logged once; the original logged label/path twice.
            tf.logging.info(label_array.shape)
            tf.logging.info(path_array.shape)
            tf.logging.info(embedding_array.shape)
        return embedding_array, logits_array, label_array, path_array

    output_list = tf.py_func(_build_one_instance_py, [],
                             [tf.float32, tf.float32, tf.int32, tf.string])
    instance_input, instance_logits, instance_output, instance_info = output_list
    # Remove NUL characters from the file-name strings (presumably padding from
    # numpy fixed-width string arrays).
    instance_info = tf.regex_replace(instance_info, "\x00*", "")
    if self._verbose:
        tf.logging.info("input_batch: {} ".format(instance_input.shape))
        tf.logging.info("output_batch: {} ".format(instance_output.shape))
        tf.logging.info("info_batch: {} ".format(instance_info.shape))
    # The first tr_size samples of every class form the training set and the
    # remaining val_size samples form the validation set.
    split_sizes = [tr_size, val_size]
    tr_input, val_input = tf.split(instance_input, split_sizes, axis=1)
    tr_logits, val_logits = tf.split(instance_logits, split_sizes, axis=1)
    tr_output, val_output = tf.split(instance_output, split_sizes, axis=1)
    tr_info, val_info = tf.split(instance_info, split_sizes, axis=1)
    if self._verbose:
        tf.logging.info("tr_output: {} ".format(tr_output))
        tf.logging.info("val_output: {}".format(val_output))
    with tf.control_dependencies(
            self._check_labels(num_classes, tr_size, val_size,
                               tr_output, val_output)):
        tr_output = tf.identity(tr_output)
        val_output = tf.identity(val_output)
    return tr_input, tr_output, tr_info, val_input, val_output, val_info, tr_logits, val_logits
def get_batch(self, batch_size, num_classes, tr_size, val_size,
              num_threads=10, num_pretrain_classes=64):
    """Returns a batch of random N-way K-shot classification problem instances.

    Args:
        batch_size: number of problem instances in the batch.
        num_classes: N in N-way classification.
        tr_size: K in K-shot; number of training examples per class.
        val_size: number of validation examples per class.
        num_threads: number of threads used to sample problem instances in
            parallel (forced to 1 in verbose mode).
        num_pretrain_classes: width of the per-image logits vectors; must
            match the pickled logits.

    Returns:
        A ProblemInstance built from 8 batched Tensors, in this order:
          - tr_input: (batch_size, num_classes, tr_size, NDIM)
          - tr_output: (batch_size, num_classes, tr_size, 1)
          - tr_info: (batch_size, num_classes, tr_size)
          - val_input: (batch_size, num_classes, val_size, NDIM)
          - val_output: (batch_size, num_classes, val_size, 1)
          - val_info: (batch_size, num_classes, val_size)
          - tr_logits: (batch_size, num_classes, tr_size, num_pretrain_classes)
          - val_logits: (batch_size, num_classes, val_size, num_pretrain_classes)
    """
    if self._verbose:
        num_threads = 1
    instance = self.get_instance(num_classes, tr_size, val_size)
    tr_dims = (num_classes, tr_size)
    val_dims = (num_classes, val_size)
    # Static per-instance shapes, in the same order as get_instance's outputs.
    instance_shapes = [
        tr_dims + (self.NDIM,),
        tr_dims + (1,),
        tr_dims,
        val_dims + (self.NDIM,),
        val_dims + (1,),
        val_dims,
        tr_dims + (num_pretrain_classes,),
        val_dims + (num_pretrain_classes,),
    ]
    batched = tf.train.shuffle_batch(instance,
                                     batch_size=batch_size,
                                     capacity=1000,
                                     min_after_dequeue=0,
                                     enqueue_many=False,
                                     shapes=instance_shapes,
                                     num_threads=num_threads)
    if self._verbose:
        tf.logging.info(batched)
    return ProblemInstance(*batched)
def _check_labels(self, num_classes, tr_size, val_size,
                  tr_output, val_output):
    """Returns assert ops validating the sampled label tensors.

    For each split, the labels are summed and divided by the per-class
    sample count. The result, truncated to int32, must equal
    0 + 1 + ... + (num_classes - 1), which is what class ids 0..N-1 repeated
    `size` times produce.
    """
    expected_sum = num_classes * (num_classes - 1) // 2
    tr_sum = tf.reduce_sum(tr_output) / tr_size
    val_sum = tf.reduce_sum(val_output) / val_size
    return [tf.assert_equal(tf.to_int32(label_sum), expected_sum)
            for label_sum in (tr_sum, val_sum)]
================================================
FILE: LEO/ifsl_configs/__init__.py
================================================
from .baseline_config import *
from .ifsl_config import *
================================================
FILE: LEO/ifsl_configs/baseline_config.py
================================================
class Config:
    """Attribute bag for one experiment's settings.

    The factory functions in this module create an instance and attach
    settings as plain attributes. Every instance carries is_config = True.
    """

    def __init__(self):
        self.is_config = True
def mini_5_resnet_baseline():
    """Baseline (deconfound=False) config: miniImagenet, 5-shot, ResNet10 "simpleshot" features."""
    settings = dict(
        shot=5,
        test=True,
        debug=False,
        dataset="miniImagenet",
        method="simpleshot",
        model="ResNet10",
        deconfound=False,
        meta_label="baseline",
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def mini_1_resnet_baseline():
    """Baseline (deconfound=False) config: miniImagenet, 1-shot, ResNet10 "simpleshot" features."""
    settings = dict(
        shot=1,
        test=True,
        debug=False,
        dataset="miniImagenet",
        method="simpleshot",
        model="ResNet10",
        deconfound=False,
        meta_label="baseline",
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def mini_5_wrn_baseline():
    """Baseline (deconfound=False) config: miniImagenet, 5-shot, wide-ResNet "simpleshotwide" features."""
    settings = dict(
        shot=5,
        test=True,
        debug=False,
        dataset="miniImagenet",
        method="simpleshotwide",
        model="wideres",
        deconfound=False,
        meta_label="baseline",
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def mini_1_wrn_baseline():
    """Baseline (deconfound=False) config: miniImagenet, 1-shot, wide-ResNet "simpleshotwide" features."""
    settings = dict(
        shot=1,
        test=True,
        debug=False,
        dataset="miniImagenet",
        method="simpleshotwide",
        model="wideres",
        deconfound=False,
        meta_label="baseline",
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def tiered_5_resnet_baseline():
    """Baseline (deconfound=False) config: tiered, 5-shot, ResNet10 "simpleshot" features, 351 pre-training classes."""
    settings = dict(
        shot=5,
        test=True,
        debug=False,
        dataset="tiered",
        method="simpleshot",
        model="ResNet10",
        deconfound=False,
        meta_label="baseline",
        num_classes=351,
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def tiered_1_resnet_baseline():
    """Baseline (deconfound=False) config: tiered, 1-shot, ResNet10 "simpleshot" features.

    Sets num_classes = 351, matching tiered_5_resnet_baseline and every tiered
    IFSL config. It was previously missing here, so this config disagreed
    with its 5-shot sibling on the number of pre-training classes.
    """
    config = Config()
    config.shot = 1
    config.test = True
    config.debug = False
    config.dataset = "tiered"
    config.method = "simpleshot"
    config.model = "ResNet10"
    config.deconfound = False
    config.meta_label = "baseline"
    config.num_classes = 351
    return config
def tiered_5_wrn_baseline():
    """Baseline (deconfound=False) config: tiered, 5-shot, wide-ResNet "simpleshotwide" features, 351 pre-training classes."""
    settings = dict(
        shot=5,
        test=True,
        debug=False,
        dataset="tiered",
        method="simpleshotwide",
        model="wideres",
        deconfound=False,
        meta_label="baseline",
        num_classes=351,
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def tiered_1_wrn_baseline():
    """Baseline (deconfound=False) config: tiered, 1-shot, wide-ResNet "simpleshotwide" features.

    Sets num_classes = 351, matching tiered_5_wrn_baseline and every tiered
    IFSL config. It was previously missing here, so this config disagreed
    with its 5-shot sibling on the number of pre-training classes.
    """
    config = Config()
    config.shot = 1
    config.test = True
    config.debug = False
    config.dataset = "tiered"
    config.method = "simpleshotwide"
    config.model = "wideres"
    config.deconfound = False
    config.meta_label = "baseline"
    config.num_classes = 351
    return config
================================================
FILE: LEO/ifsl_configs/ifsl_config.py
================================================
class Config:
    """Attribute bag for one IFSL experiment's settings.

    The factory functions in this module create an instance and attach
    settings as plain attributes. Every instance carries is_config = True.
    """

    def __init__(self):
        self.is_config = True
def mini_5_resnet_ifsl():
    """IFSL (deconfound=True) config: miniImagenet, 5-shot, ResNet10 "simpleshot" features.

    Uses 8 feature splits, "+" fusion, a single classifier per split and 64
    pre-training classes. It does not override outer_lr.
    """
    settings = dict(
        shot=5,
        test=True,
        debug=False,
        dataset="miniImagenet",
        method="simpleshot",
        model="ResNet10",
        deconfound=True,
        meta_label="ifsl",
        # IFSL parameters
        n_splits=8,
        fusion="+",
        classifier="single",
        num_classes=64,
        logit_fusion="product",
        use_x_only=False,
        preprocess_before_split="cl2n",
        preprocess_after_split="l2n",
        is_cosine_feature=True,
        normalize_before_center=True,
        normalize_d=False,
        normalize_ed=False,
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def mini_1_resnet_ifsl():
    """IFSL (deconfound=True) config: miniImagenet, 1-shot, ResNet10 "simpleshot" features.

    Uses 8 feature splits, "+" fusion, a single classifier per split and 64
    pre-training classes. It does not override outer_lr.
    """
    settings = dict(
        shot=1,
        test=True,
        debug=False,
        dataset="miniImagenet",
        method="simpleshot",
        model="ResNet10",
        deconfound=True,
        meta_label="ifsl",
        # IFSL parameters
        n_splits=8,
        fusion="+",
        classifier="single",
        num_classes=64,
        logit_fusion="product",
        use_x_only=False,
        preprocess_before_split="cl2n",
        preprocess_after_split="l2n",
        is_cosine_feature=True,
        normalize_before_center=True,
        normalize_d=False,
        normalize_ed=False,
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def mini_5_wrn_ifsl():
    """IFSL (deconfound=True) config: miniImagenet, 5-shot, wide-ResNet "simpleshotwide" features.

    Uses 8 feature splits, "+" fusion, a single classifier per split, 64
    pre-training classes and outer_lr = 2.61024e-4.
    """
    settings = dict(
        shot=5,
        test=True,
        debug=False,
        dataset="miniImagenet",
        method="simpleshotwide",
        model="wideres",
        deconfound=True,
        meta_label="ifsl_lr",
        # IFSL parameters
        n_splits=8,
        fusion="+",
        classifier="single",
        num_classes=64,
        logit_fusion="product",
        use_x_only=False,
        preprocess_before_split="cl2n",
        preprocess_after_split="l2n",
        is_cosine_feature=True,
        normalize_before_center=True,
        normalize_d=False,
        normalize_ed=False,
        outer_lr=2.61024e-4,
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def mini_1_wrn_ifsl():
    """IFSL (deconfound=True) config: miniImagenet, 1-shot, wide-ResNet "simpleshotwide" features.

    Uses 8 feature splits, "+" fusion, a single classifier per split, 64
    pre-training classes and outer_lr = 1.51024e-4.
    """
    settings = dict(
        shot=1,
        test=True,
        debug=False,
        dataset="miniImagenet",
        method="simpleshotwide",
        model="wideres",
        deconfound=True,
        meta_label="ifsl_lr",
        # IFSL parameters
        n_splits=8,
        fusion="+",
        classifier="single",
        num_classes=64,
        logit_fusion="product",
        use_x_only=False,
        preprocess_before_split="cl2n",
        preprocess_after_split="l2n",
        is_cosine_feature=True,
        normalize_before_center=True,
        normalize_d=False,
        normalize_ed=False,
        outer_lr=1.51024e-4,
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def tiered_5_resnet_ifsl():
    """IFSL (deconfound=True) config: tiered, 5-shot, ResNet10 "simpleshot" features.

    Uses 8 feature splits, "+" fusion, a single classifier per split, 351
    pre-training classes and outer_lr = 2.669053e-4. The meta_label is
    "split" rather than "ifsl".
    """
    settings = dict(
        shot=5,
        test=True,
        debug=False,
        dataset="tiered",
        method="simpleshot",
        model="ResNet10",
        deconfound=True,
        meta_label="split",
        # IFSL parameters
        n_splits=8,
        fusion="+",
        classifier="single",
        num_classes=351,
        logit_fusion="product",
        use_x_only=False,
        preprocess_before_split="cl2n",
        preprocess_after_split="l2n",
        is_cosine_feature=True,
        normalize_before_center=True,
        normalize_d=False,
        normalize_ed=False,
        outer_lr=2.669053e-4,
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def tiered_1_resnet_ifsl():
    """IFSL (deconfound=True) config: tiered, 1-shot, ResNet10 "simpleshot" features.

    Uses 8 feature splits, "+" fusion, a single classifier per split, 351
    pre-training classes and outer_lr = 2.669053e-4.
    """
    settings = dict(
        shot=1,
        test=True,
        debug=False,
        dataset="tiered",
        method="simpleshot",
        model="ResNet10",
        deconfound=True,
        meta_label="ifsl",
        # IFSL parameters
        n_splits=8,
        fusion="+",
        classifier="single",
        num_classes=351,
        logit_fusion="product",
        use_x_only=False,
        preprocess_before_split="cl2n",
        preprocess_after_split="l2n",
        is_cosine_feature=True,
        normalize_before_center=True,
        normalize_d=False,
        normalize_ed=False,
        outer_lr=2.669053e-4,
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def tiered_5_wrn_ifsl():
    """IFSL (deconfound=True) config: tiered, 5-shot, wide-ResNet "simpleshotwide" features.

    Uses 8 feature splits, "+" fusion, a single classifier per split and 351
    pre-training classes. It does not override outer_lr.
    """
    settings = dict(
        shot=5,
        test=True,
        debug=False,
        dataset="tiered",
        method="simpleshotwide",
        model="wideres",
        deconfound=True,
        meta_label="ifsl",
        # IFSL parameters
        n_splits=8,
        fusion="+",
        classifier="single",
        num_classes=351,
        logit_fusion="product",
        use_x_only=False,
        preprocess_before_split="cl2n",
        preprocess_after_split="l2n",
        is_cosine_feature=True,
        normalize_before_center=True,
        normalize_d=False,
        normalize_ed=False,
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def tiered_1_wrn_ifsl():
    """IFSL (deconfound=True) config: tiered, 1-shot, wide-ResNet "simpleshotwide" features.

    Uses 8 feature splits, "+" fusion, a single classifier per split and 351
    pre-training classes. It does not override outer_lr.
    """
    settings = dict(
        shot=1,
        test=True,
        debug=False,
        dataset="tiered",
        method="simpleshotwide",
        model="wideres",
        deconfound=True,
        meta_label="ifsl",
        # IFSL parameters
        n_splits=8,
        fusion="+",
        classifier="single",
        num_classes=351,
        logit_fusion="product",
        use_x_only=False,
        preprocess_before_split="cl2n",
        preprocess_after_split="l2n",
        is_cosine_feature=True,
        normalize_before_center=True,
        normalize_d=False,
        normalize_ed=False,
    )
    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
================================================
FILE: LEO/model.py
================================================
# Copyright 2018 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Code defining LEO inner loop.
See "Meta-Learning with Latent Embedding Optimization" by Rusu et al.
(https://arxiv.org/pdf/1807.05960.pdf).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import numpy as np
from six.moves import range
from six.moves import zip
import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp
import data as data_module
def get_orthogonality_regularizer(orthogonality_penalty_weight):
    """Returns the orthogonality regularizer.

    The returned callable maps a 2-D weight matrix to a scalar. The scalar is
    orthogonality_penalty_weight times the mean squared difference between
    the matrix of cosine similarities of its rows and the identity matrix.
    """

    def orthogonality(weight):
        """Calculates the layer-wise penalty encouraging orthogonality."""
        with tf.name_scope(None, "orthogonality", [weight]) as scope_name:
            dtype = weight.dtype.base_dtype
            gram = tf.matmul(weight, weight, transpose_b=True)
            # The tiny epsilon keeps the division finite for all-zero rows.
            row_norms = tf.norm(weight, ord=2, axis=1, keepdims=True) + 1e-32
            cosine = gram / tf.matmul(row_norms, row_norms, transpose_b=True)
            size = cosine.get_shape().as_list()[0]
            deviation = tf.reduce_mean(
                tf.squared_difference(cosine, tf.eye(size, dtype=dtype)))
            return tf.multiply(
                tf.cast(orthogonality_penalty_weight, dtype),
                deviation,
                name=scope_name)

    return orthogonality
def run_leo(model, inputs, is_meta_training):
    """Runs model on every problem instance of a batch with tf.map_fn.

    Returns:
        (per_instance_loss, per_instance_accuracy) as float32 tensors with one
        entry per instance. They are NOT averaged here. Backprop through
        map_fn is enabled only when is_meta_training is true.
    """
    instance_fn = functools.partial(
        model.__call__, is_meta_training=is_meta_training)
    losses, accuracies = tf.map_fn(instance_fn,
                                   inputs,
                                   dtype=(tf.float32, tf.float32),
                                   back_prop=is_meta_training)
    return losses, accuracies
class FeatureProcessor():
    """Builds per-split (x, d) features from a problem instance for IFSL.

    x is the image embedding (data[0] support, data[3] query). d is
    softmax(pre-training logits) @ pre-training class means (data[6] support
    logits, data[7] query logits). Both are cut along the last axis into
    n_splits chunks of feat_dim // n_splits dims.
    """

    _PREPROCESS_METHODS = ("none", "l2n", "cl2n")

    def __init__(self, n_splits, pretrain_mean_filename, feat_dim, is_cosine_feature=False, num_classes=64,
                 preprocess_after_split="none", preprocess_before_split="none", normalize_before_center=False,
                 normalize_d=False, normalize_ed=False):
        """Stores options and loads the pre-training class means.

        Args:
            n_splits: number of feature chunks.
            pretrain_mean_filename: .npy file under "pretrain/". It is
                prefixed with "norm_" when is_cosine_feature is set.
            feat_dim: embedding dimensionality.
            is_cosine_feature: select the "norm_" variant of the means file.
            num_classes: number of pre-training classes (stored only).
            preprocess_after_split / preprocess_before_split: one of "none",
                "l2n", "cl2n" (see preprocess).
            normalize_before_center: l2-normalize before centering in "cl2n".
            normalize_d: l2-normalize each class mean row after loading.
            normalize_ed: l2-normalize the d features in get_features.
        """
        self.feat_dim = feat_dim
        self.n_splits = n_splits
        self.num_classes = num_classes
        self.is_cosine_feature = is_cosine_feature
        self.preprocess_after_split = preprocess_after_split
        self.preprocess_before_split = preprocess_before_split
        self.normalize_before_center = normalize_before_center
        self.normalize_d = normalize_d
        self.normalize_ed = normalize_ed
        if self.is_cosine_feature:
            pretrain_mean_filename = "norm_" + pretrain_mean_filename
        pretrain_features = np.load("pretrain/" + pretrain_mean_filename)  # num_classes * feat_dim
        self.pretrain_features = tf.convert_to_tensor(pretrain_features, dtype=tf.float32)
        if self.normalize_d:
            self.pretrain_features = tf.nn.l2_normalize(self.pretrain_features, axis=-1)
        # Mean over classes; used as the centering vector for "cl2n".
        self.pretrain_features_mean = tf.reduce_mean(self.pretrain_features, axis=0)

    def get_d_features(self, logit):
        """Returns softmax(logit) contracted with the pre-training class means."""
        prob = tf.nn.softmax(logit, axis=-1)
        return tf.tensordot(prob, self.pretrain_features, axes=[[-1], [0]])

    def preprocess(self, data, center=None, method="none"):
        """Applies a normalization method to the last axis of data.

        "none": identity. "l2n": l2-normalize. "cl2n": optionally l2-normalize
        (normalize_before_center), subtract center, then l2-normalize.

        Raises:
            ValueError: for any other method. Previously this silently
                returned None and failed later in an unrelated place.
        """
        if method == "none":
            return data
        if method == "l2n":
            return tf.nn.l2_normalize(data, axis=-1)
        if method == "cl2n":
            if self.normalize_before_center:
                data = tf.nn.l2_normalize(data, axis=-1)
            return tf.nn.l2_normalize(data - center, axis=-1)
        raise ValueError("Unknown preprocess method {!r}; expected one of {}".format(
            method, self._PREPROCESS_METHODS))

    def get_split_features(self, data, center, method="none"):
        """Slices data (and center, if given) on the last axis into n_splits chunks.

        Each chunk is preprocessed independently with method. data is indexed
        as data[:, :, a:b], so it must be rank 3.
        """
        split_dim = int(self.feat_dim / self.n_splits)
        split_data = []
        for i in range(self.n_splits):
            start_idx = split_dim * i
            end_idx = start_idx + split_dim
            data_i = data[:, :, start_idx:end_idx]
            center_i = None if center is None else center[:, :, start_idx:end_idx]
            split_data.append(self.preprocess(data_i, center_i, method))
        return split_data

    def get_features(self, data):
        """Returns split support/query x and d features plus a debug dict.

        Args:
            data: problem-instance sequence; indexes 0, 3, 6, 7 are the
                support inputs, query inputs, support logits and query logits.
        """
        support_x = data[0]
        query_x = data[3]
        support_d = self.get_d_features(data[6])
        query_d = self.get_d_features(data[7])
        if self.normalize_ed:
            support_d = tf.nn.l2_normalize(support_d, axis=-1)
            query_d = tf.nn.l2_normalize(query_d, axis=-1)
        n_way = support_x.shape[0]
        n_support = support_x.shape[1]
        n_query = query_x.shape[1]
        # Broadcast the class-mean vector to the (way, shot, dim) layout.
        mean_expanded = tf.expand_dims(tf.expand_dims(self.pretrain_features_mean, axis=0), axis=0)
        support_center = tf.tile(mean_expanded, (n_way, n_support, 1))
        query_center = tf.tile(mean_expanded, (n_way, n_query, 1))
        support_x = self.preprocess(support_x, support_center, self.preprocess_before_split)
        query_x = self.preprocess(query_x, query_center, self.preprocess_before_split)
        split_support_x = self.get_split_features(support_x, support_center, self.preprocess_after_split)
        split_query_x = self.get_split_features(query_x, query_center, self.preprocess_after_split)
        split_support_d = self.get_split_features(support_d, None, "none")
        split_query_d = self.get_split_features(query_d, None, "none")
        debug = {
            "support_logits": data[6],
            "support_prob": tf.nn.softmax(data[6], axis=-1),
            "support_center": support_center,
            "pretrain_features_means": self.pretrain_features_mean,
            "support_d": support_d,
        }
        return split_support_x, split_support_d, split_query_x, split_query_d, debug
class IFSL(snt.AbstractModule):
    """Ensemble of per-feature-split LEO learners for interventional few-shot learning.

    FeatureProcessor cuts x and d features into n_splits chunks. For each
    chunk the (optionally fused) feature is fed to its own LEO module. The
    per-split softmax outputs are summed to form the prediction, and the
    per-split losses are averaged.
    """

    _FUSIONS = ("concat", "+", "-")

    def __init__(self, config=None, use_64bits_dtype=True, n_splits=4, is_cosine_feature=True, fusion="concat", classifier="single",
                 num_classes=64, logit_fusion="product", use_x_only=False, preprocess_before_split="none", preprocess_after_split="none",
                 normalize_before_center=True, normalize_d=False, normalize_ed=False):
        """Builds the feature processor and the per-split LEO modules.

        Args:
            config: dict-like; must provide "feat_dim" and
                "pretrain_mean_filename" plus every key LEO reads.
            classifier: only "single" (one LEO per split) is supported by
                _build. use_x_only forces "single".
            Remaining arguments are forwarded to FeatureProcessor or stored.
        """
        super(IFSL, self).__init__(name="IFSL")
        self.n_splits = n_splits
        self.feat_dim = config["feat_dim"]
        self._int_dtype = tf.int64 if use_64bits_dtype else tf.int32
        self.feature_processor = FeatureProcessor(n_splits, config["pretrain_mean_filename"], self.feat_dim,
                                                  is_cosine_feature, num_classes, preprocess_after_split,
                                                  preprocess_before_split, normalize_before_center, normalize_d, normalize_ed)
        self.num_classes = num_classes
        self.fusion = fusion
        self.logit_fusion = logit_fusion
        self.use_x_only = use_x_only
        self.classifier = "single" if use_x_only else classifier
        if self.classifier == "single":
            self.modules = []
            for i in range(self.n_splits):
                self.modules.append(LEO(config=config, use_64bits_dtype=use_64bits_dtype, name="leo" + str(i), deconfound=True))
        else:
            # NOTE: these modules are created but _build does not use them yet.
            self.x_modules = []
            self.d_modules = []
            for i in range(self.n_splits):
                self.x_modules.append(LEO(config=config, use_64bits_dtype=use_64bits_dtype, name="x_leo" + str(i), deconfound=True))
                self.d_modules.append(LEO(config=config, use_64bits_dtype=use_64bits_dtype, name="d_leo" + str(i), deconfound=True))

    def build_input_data(self, support_features, query_features, data):
        """Returns the 8-field problem instance with inputs replaced by the given features."""
        return [support_features, data[1], data[2],
                query_features, data[4], data[5],
                data[6], data[7]]

    def get_debug_data(self, data):
        """Returns the first batch element of each of the 8 instance fields."""
        return [data[i][0] for i in range(8)]

    def fuse_features(self, x1, x2):
        """Combines x and d features according to self.fusion.

        x1/x2 are the per-split lists from FeatureProcessor. TF converts the
        lists to tensors, so the result is indexed per split by the caller.

        Raises:
            ValueError: for an unknown fusion. Previously this returned None
                and failed later.
        """
        if self.fusion == "concat":
            return tf.concat([x1, x2], axis=-1)
        if self.fusion == "+":
            return tf.math.add(x1, x2)
        if self.fusion == "-":
            return tf.math.subtract(x1, x2)
        raise ValueError("Unknown fusion {!r}; expected one of {}".format(self.fusion, self._FUSIONS))

    def _build(self, data, is_meta_training=True, debug=False, break_down=False):
        """Runs every split's LEO module and ensembles the predictions.

        Returns:
            break_down: (hardness, correct) from calculate_dacc.
            debug: (mean loss, accuracy, debug dict).
            otherwise: (mean loss, accuracy, dacc).

        Raises:
            NotImplementedError: if classifier is not "single". Previously
                this path crashed with an UnboundLocalError.
        """
        if self.classifier != "single":
            raise NotImplementedError(
                "IFSL._build only supports classifier='single', got {!r}".format(self.classifier))
        if debug:
            data = self.get_debug_data(data)
        (split_support_x, split_support_d, split_query_x, split_query_d,
         debug_data) = self.feature_processor.get_features(data)
        if self.use_x_only:
            fused_support = split_support_x
            fused_query = split_query_x
        else:
            fused_support = self.fuse_features(split_support_x, split_support_d)
            fused_query = self.fuse_features(split_query_x, split_query_d)
        prediction = None
        losses = None
        for i in range(self.n_splits):
            input_data = self.build_input_data(fused_support[i], fused_query[i], data)
            loss, _, _, output = self.modules[i](input_data, is_meta_training)
            output = tf.nn.softmax(output, axis=-1)
            if prediction is None:
                prediction = output
                losses = loss
            else:
                prediction += output
                losses += loss
        model_prediction = tf.argmax(prediction, -1, output_type=self._int_dtype)
        accuracy = tf.contrib.metrics.accuracy(model_prediction, tf.squeeze(data[4], axis=-1))
        losses = losses / self.n_splits
        dacc = self.calculate_dacc(data, model_prediction)
        debug_data["fused_support"] = fused_support
        debug_data["output"] = output  # softmax output of the last split
        debug_data["model_prediction"] = model_prediction
        debug_data["dacc"] = dacc
        if break_down:
            return self.calculate_dacc(data, model_prediction, True)
        if debug:
            return losses, accuracy, debug_data
        return losses, accuracy, dacc

    def calculate_dacc(self, data, model_prediction, break_down=False):
        """Hardness-weighted query accuracy.

        A reference classifier is built from pre-training logits only. Support
        logits (data[6]) pass through ReLU, are averaged per class and
        l2-normalized. ReLU'd, l2-normalized query logits (data[7]) are
        scored against them, followed by a softmax. A query's hardness is
        log((1 - p) / p), where p is the reference probability of its true
        label. The number of classes is taken from the probabilities instead
        of being hard-coded to 5.

        Returns:
            break_down: (per-query hardness masked to the true class,
                boolean correctness of model_prediction).
            otherwise: hardness of correctly predicted queries / total hardness.
        """
        query_labels = tf.squeeze(data[4], axis=-1)
        support_logits = tf.nn.relu(data[6])
        query_logits = tf.nn.l2_normalize(tf.nn.relu(data[7]), axis=-1)
        prototypes = tf.nn.l2_normalize(tf.reduce_mean(support_logits, axis=1), axis=-1)
        logits = tf.tensordot(query_logits, tf.transpose(prototypes), axes=[[-1], [0]])
        probs = tf.nn.softmax(logits, axis=-1)
        # One-hot mask of each query's true class over the reference classes.
        true_class_mask = tf.one_hot(query_labels, tf.shape(probs)[-1], dtype=probs.dtype)
        hardness = tf.math.log((1 - probs) / probs) * true_class_mask
        # Cast so int64 predictions compare cleanly with int32 labels.
        correct = tf.math.equal(tf.cast(model_prediction, query_labels.dtype), query_labels)
        self.hardness = hardness
        if break_down:
            return hardness, correct
        correct_mask = tf.expand_dims(tf.cast(correct, hardness.dtype), axis=2)
        total_hardness = tf.reduce_sum(hardness)
        scored_hardness = tf.reduce_sum(hardness * correct_mask)
        return scored_hardness / total_hardness

    def grads_and_vars(self, metatrain_loss):
        """Concatenates gradients/variables of every per-split LEO module.

        Only valid for classifier == "single" (uses self.modules).
        """
        metatrain_gradients_merged = []
        metatrain_variables_merged = []
        for module in self.modules:
            metatrain_gradients, metatrain_variables = module.grads_and_vars(metatrain_loss)
            metatrain_gradients_merged += metatrain_gradients
            metatrain_variables_merged += metatrain_variables
        return metatrain_gradients_merged, metatrain_variables_merged
class LEO(snt.AbstractModule):
"""Sonnet module implementing the inner loop of LEO."""
def __init__(self, config=None, use_64bits_dtype=True, name="leo", deconfound=False):
    """Reads LEO hyper-parameters from config.

    Args:
        config: dict-like with the keys read below.
        use_64bits_dtype: use float64/int64 instead of float32/int32.
        name: Sonnet module name.
        deconfound: if True, _build returns the 4-tuple used by IFSL.
    """
    super(LEO, self).__init__(name=name)
    if use_64bits_dtype:
        self._float_dtype, self._int_dtype = tf.float64, tf.int64
    else:
        self._float_dtype, self._int_dtype = tf.float32, tf.int32
    hyper_parameters = (
        ("_inner_unroll_length", "inner_unroll_length"),
        ("_finetuning_unroll_length", "finetuning_unroll_length"),
        ("_inner_lr_init", "inner_lr_init"),
        ("_finetuning_lr_init", "finetuning_lr_init"),
        ("_num_latents", "num_latents"),
        ("_dropout_rate", "dropout_rate"),
        ("_kl_weight", "kl_weight"),  # beta
        ("_encoder_penalty_weight", "encoder_penalty_weight"),  # gamma
        ("_l2_penalty_weight", "l2_penalty_weight"),  # lambda_1
        ("_orthogonality_penalty_weight", "orthogonality_penalty_weight"),  # lambda_2
    )
    for attribute, key in hyper_parameters:
        setattr(self, attribute, config[key])
    self._deconfound = deconfound
    assert self._inner_unroll_length > 0, ("Positive unroll length is necessary"
                                           " to create the graph")
def _build(self, data, is_meta_training=True):
    """Connects the LEO module to the graph, creating the variables.

    Args:
        data: A data_module.ProblemInstance (or a list, which is converted to
            one) containing Tensors with the following shapes:
              - tr_input: (N, K, dim)
              - tr_output: (N, K, 1)
              - tr_info: (N, K)
              - val_input: (N, K_valid, dim)
              - val_output: (N, K_valid, 1)
              - val_info: (N, K_valid)
            where N is the number of classes (as in N-way), K and K_valid are
            the numbers of training and validation examples per class (as in
            K-shot), and dim is the dimensionality of the embedding.
        is_meta_training: A boolean describing whether we run in the training
            mode.

    Returns:
        If constructed with deconfound=True, a 4-tuple:
            (mean validation loss + regularization_penalty,
             additional_loss = kl term + encoder penalty + regularization,
             mean validation accuracy,
             val_output from the finetuned classifier).
        Otherwise a 2-tuple (mean validation loss + regularization_penalty,
        mean validation accuracy). The validation loss already includes the
        KL and encoder-penalty terms.
    """
    if isinstance(data, list):
        data = data_module.ProblemInstance(*data)
    self.is_meta_training = is_meta_training
    self.save_problem_instance_stats(data.tr_input)
    # Encoder -> latent adaptation -> direct weight finetuning. The call order
    # determines the order in which variables are created.
    latents, kl = self.forward_encoder(data)
    tr_loss, adapted_classifier_weights, encoder_penalty = self.leo_inner_loop(
        data, latents)
    val_loss, val_accuracy, val_output = self.finetuning_inner_loop(
        data, tr_loss, adapted_classifier_weights)
    val_loss += self._kl_weight * kl
    val_loss += self._encoder_penalty_weight * encoder_penalty
    # The l2 regularization is already added to the graph when constructing
    # the snt.Linear modules. We pass the orthogonality regularizer separately,
    # because it is not used in self.grads_and_vars.
    # NOTE(review): _l2_regularization / _decoder_orthogonality_reg are defined
    # outside this view; presumably scalar tensors — confirm.
    regularization_penalty = (
        self._l2_regularization + self._decoder_orthogonality_reg)
    batch_val_loss = tf.reduce_mean(val_loss)
    batch_val_accuracy = tf.reduce_mean(val_accuracy)
    additional_loss = self._kl_weight * kl + self._encoder_penalty_weight * encoder_penalty + regularization_penalty
    if self._deconfound:
        # IFSL needs the raw val_output to ensemble the per-split predictions.
        return batch_val_loss + regularization_penalty, additional_loss, batch_val_accuracy, val_output
    else:
        return batch_val_loss + regularization_penalty, batch_val_accuracy
@snt.reuse_variables
def leo_inner_loop(self, data, latents):
    """Adapts the latent codes by gradient descent on the training loss.

    Runs _inner_unroll_length steps of `latents -= lr * dL_train/dlatents`.
    The learning rate is a trainable per-latent-dimension variable of shape
    [1, 1, num_latents].

    Returns:
        (training loss under the last decoded weights, those classifier
         weights, encoder_penalty). encoder_penalty is the MSE between the
         adapted latents (gradient stopped) and the initial latents while
         meta-training, and 0 otherwise.
    """
    with tf.variable_scope("leo_inner"):
        inner_lr = tf.get_variable(
            "lr", [1, 1, self._num_latents],
            dtype=self._float_dtype,
            initializer=tf.constant_initializer(self._inner_lr_init))
    starting_latents = latents
    loss, _ = self.forward_decoder(data, latents)
    for _ in range(self._inner_unroll_length):
        loss_grad = tf.gradients(loss, latents)  # dLtrain/dz
        latents -= inner_lr * loss_grad[0]
        loss, classifier_weights = self.forward_decoder(data, latents)
    if self.is_meta_training:
        # Pulls the encoder output towards the adapted latents only.
        encoder_penalty = tf.losses.mean_squared_error(
            labels=tf.stop_gradient(latents), predictions=starting_latents)
        encoder_penalty = tf.cast(encoder_penalty, self._float_dtype)
    else:
        encoder_penalty = tf.constant(0., self._float_dtype)
    return loss, classifier_weights, encoder_penalty
@snt.reuse_variables
def finetuning_inner_loop(self, data, leo_loss, classifier_weights):
    """Fine-tunes the decoded classifier weights directly.

    Runs _finetuning_unroll_length gradient steps on the training loss. The
    first step differentiates leo_loss. The learning rate is a trainable
    variable of shape [1, 1, embedding_dim].

    Returns:
        (val_loss, val_accuracy, val_output) from calculate_inner_loss on the
        validation set with the final weights.
    """
    tr_loss = leo_loss
    with tf.variable_scope("finetuning"):
        finetuning_lr = tf.get_variable(
            "lr", [1, 1, self.embedding_dim],
            dtype=self._float_dtype,
            initializer=tf.constant_initializer(self._finetuning_lr_init))
    for _ in range(self._finetuning_unroll_length):
        loss_grad = tf.gradients(tr_loss, classifier_weights)
        classifier_weights -= finetuning_lr * loss_grad[0]
        tr_loss, _, _ = self.calculate_inner_loss(data.tr_input, data.tr_output,
                                                  classifier_weights)
    val_loss, val_accuracy, val_output = self.calculate_inner_loss(
        data.val_input, data.val_output, classifier_weights)
    return val_loss, val_accuracy, val_output
@snt.reuse_variables
def forward_encoder(self, data):
  """Encodes the training inputs into (possibly sampled) latent codes.

  Pipeline: encoder -> relation network -> per-class averaging ->
  `possibly_sample`.

  Returns:
    The `(latents, kl)` pair produced by `possibly_sample`.
  """
  relation_codes = self.relation_network(self.encoder(data.tr_input))
  class_codes = self.average_codes_per_class(relation_codes)
  return self.possibly_sample(class_codes)
@snt.reuse_variables
def forward_decoder(self, data, latents):
  """Decodes latents into classifier weights and scores the training set.

  Returns:
    `(tr_loss, classifier_weights)`.
  """
  weights_dist_params = self.decoder(latents)
  # possibly_sample computes stddev = exp(u) - (1 - offset). With a Glorot
  # offset, a zero unnormalized stddev therefore yields Glorot-scale weights
  # instead of stddev=1.
  fan_in, fan_out = self.embedding_dim.value, self.num_classes.value
  glorot_stddev = np.sqrt(2. / (fan_out + fan_in))
  classifier_weights, _ = self.possibly_sample(
      weights_dist_params, stddev_offset=glorot_stddev)
  tr_loss = self.calculate_inner_loss(
      data.tr_input, data.tr_output, classifier_weights)[0]
  return tr_loss, classifier_weights
@snt.reuse_variables
def encoder(self, inputs):
  """Applies dropout and then a bias-free linear map to `num_latents` units.

  The linear map carries an L2 regularizer on its weights and uses Glorot
  uniform initialization. It is batch-applied over the leading dimensions of
  `inputs`.
  """
  with tf.variable_scope("encoder"):
    dropped = tf.nn.dropout(inputs, rate=self.dropout_rate)
    projection = snt.Linear(
        self._num_latents,
        use_bias=False,
        regularizers={
            "w": tf.contrib.layers.l2_regularizer(self._l2_penalty_weight)},
        initializers={
            "w": tf.initializers.glorot_uniform(dtype=self._float_dtype)},
    )
    return snt.BatchApply(projection)(dropped)
@snt.reuse_variables
def relation_network(self, inputs):
  """Applies a shared pairwise relation MLP to all encoder codes.

  Every code is concatenated with every code (itself included). The
  three-layer, bias-free MLP is applied to each pair, and the results are
  averaged over the partner axis.

  Args:
    inputs: Encoder outputs. They are reshaped below to
      [num_classes * num_examples_per_class, num_latents].

  Returns:
    Tensor of shape [num_classes, num_examples_per_class, 2 * num_latents].
  """
  with tf.variable_scope("relation_network"):
    regularizer = tf.contrib.layers.l2_regularizer(self._l2_penalty_weight)
    initializer = tf.initializers.glorot_uniform(dtype=self._float_dtype)
    relation_network_module = snt.nets.MLP(
        [2 * self._num_latents] * 3,
        use_bias=False,
        regularizers={"w": regularizer},
        initializers={"w": initializer},
    )
    total_num_examples = self.num_examples_per_class*self.num_classes
    inputs = tf.reshape(inputs, [total_num_examples, self._num_latents])

    # All ordered pairs: left[i, j] = inputs[i], right[i, j] = inputs[j].
    left = tf.tile(tf.expand_dims(inputs, 1), [1, total_num_examples, 1])
    right = tf.tile(tf.expand_dims(inputs, 0), [total_num_examples, 1, 1])
    concat_codes = tf.concat([left, right], axis=-1)
    outputs = snt.BatchApply(relation_network_module)(concat_codes)
    # Average over the partner index j.
    outputs = tf.reduce_mean(outputs, axis=1)
    # 2 * latents, because we return the means and the (unnormalized)
    # standard deviations of a Gaussian; possibly_sample splits them.
    outputs = tf.reshape(outputs, [self.num_classes,
                                   self.num_examples_per_class,
                                   2 * self._num_latents])

    return outputs
@snt.reuse_variables
def decoder(self, inputs):
  """Maps latent codes to Gaussian parameters over classifier weights.

  Side effect: stores the orthogonality penalty of the decoder weight matrix
  in `self._orthogonality_reg`. It is read back through the
  `_decoder_orthogonality_reg` property.

  Args:
    inputs: Latent codes. The linear map is batch-applied over the leading
      dimensions.

  Returns:
    Tensor whose last dimension is `2 * self.embedding_dim`.
  """
  with tf.variable_scope("decoder"):
    l2_regularizer = tf.contrib.layers.l2_regularizer(self._l2_penalty_weight)
    orthogonality_reg = get_orthogonality_regularizer(
        self._orthogonality_penalty_weight)
    initializer = tf.initializers.glorot_uniform(dtype=self._float_dtype)
    # 2 * embedding_dim, because we return the means and the (unnormalized)
    # standard deviations; possibly_sample splits them.
    decoder_module = snt.Linear(
        2 * self.embedding_dim,
        use_bias=False,
        regularizers={"w": l2_regularizer},
        initializers={"w": initializer},
    )
    outputs = snt.BatchApply(decoder_module)(inputs)
    # Computed on the weight matrix; it is added to the meta-training loss
    # through the regularization penalty.
    self._orthogonality_reg = orthogonality_reg(decoder_module.w)
    return outputs
def average_codes_per_class(self, codes):
  """Replaces each example's code by the mean code of its class.

  Args:
    codes: Rank-3 tensor laid out as (N classes, K examples per class, D).

  Returns:
    Tensor of the same (N, K, D) shape. Every example of class n holds the
    mean over the K examples of that class.
  """
  # `keepdims` is the supported spelling; `keep_dims` is deprecated.
  codes = tf.reduce_mean(codes, axis=1, keepdims=True)  # K dimension
  # Tile back to keep the shape (N, K, *).
  codes = tf.tile(codes, [1, self.num_examples_per_class, 1])
  return codes
def possibly_sample(self, distribution_params, stddev_offset=0.):
  """Samples from a diagonal Gaussian, or returns its mean at test time.

  The last axis of `distribution_params` is split in half into means and
  unnormalized stddevs `u`, with stddev = max(exp(u) - (1 - stddev_offset),
  1e-10).

  Args:
    distribution_params: Tensor whose last axis holds [means, u].
    stddev_offset: The stddev value obtained when u == 0.

  Returns:
    `(samples, kl)` while meta-training; otherwise `(means, 0.)`.
  """
  means, raw_stddev = tf.split(distribution_params, 2, axis=-1)
  stddev = tf.maximum(tf.exp(raw_stddev) - (1. - stddev_offset), 1e-10)
  distribution = tfp.distributions.Normal(loc=means, scale=stddev)
  if self.is_meta_training:
    samples = distribution.sample()
    return samples, self.kl_divergence(samples, distribution)
  return means, tf.constant(0., dtype=self._float_dtype)
def kl_divergence(self, samples, normal_distribution):
  """Returns the mean of log q(samples) - log p(samples).

  Here `q` is `normal_distribution` and `p` is a standard normal prior. With
  `samples` drawn from `q`, this is a one-sample estimate of KL(q || p).
  """
  standard_normal = tfp.distributions.Normal(
      loc=tf.zeros_like(samples), scale=tf.ones_like(samples))
  log_q = normal_distribution.log_prob(samples)
  log_p = standard_normal.log_prob(samples)
  return tf.reduce_mean(log_q - log_p)
def predict(self, inputs, weights):
  """Scores every input against every class's classifier weights.

  Dropout is applied to `inputs` first. Each input embedding is dotted with
  each of the K weight vectors of each class (einsum "ijk,lmk->ijlm", giving
  [N, K, N, K]). The last axis is then averaged, giving [N, K, N] logits:
  one score per class for each of the N*K inputs.
  """
  dropped = tf.nn.dropout(inputs, rate=self.dropout_rate)
  pairwise_scores = tf.einsum("ijk,lmk->ijlm", dropped, weights)
  return tf.reduce_mean(pairwise_scores, axis=-1)
def calculate_inner_loss(self, inputs, true_outputs, classifier_weights):
  """Returns `(loss, accuracy, logits)` for `inputs` under the given weights.

  `true_outputs` carries a trailing singleton axis, which is squeezed out
  before comparing against the arg-max predictions.
  """
  logits = self.predict(inputs, classifier_weights)
  predicted = tf.argmax(logits, -1, output_type=self._int_dtype)
  labels = tf.squeeze(true_outputs, axis=-1)
  accuracy = tf.contrib.metrics.accuracy(predicted, labels)
  loss = self.loss_fn(logits, true_outputs)
  return loss, accuracy, logits
def save_problem_instance_stats(self, instance):
  """Records N (classes), K (examples per class) and the embedding dimension.

  The values come from the 3-D static shape of `instance`. On repeated calls,
  each value is asserted to match the previously recorded one.

  Args:
    instance: Tensor-like object whose `get_shape()` yields
      (num_classes, num_examples_per_class, embedding_dim).

  Raises:
    AssertionError: If any recorded value differs from the new instance's.
      Tests rely on this exception type.
      NOTE(review): these checks disappear under `python -O`.
  """
  num_classes, num_examples_per_class, embedding_dim = instance.get_shape()
  if hasattr(self, "num_classes"):
    assert self.num_classes == num_classes, (
        "Given different number of classes (N in N-way) in consecutive runs.")
  if hasattr(self, "num_examples_per_class"):
    assert self.num_examples_per_class == num_examples_per_class, (
        "Given different number of examples (K in K-shot) in consecutive "
        "runs.")
  if hasattr(self, "embedding_dim"):
    assert self.embedding_dim == embedding_dim, (
        "Given different embedding dimension in consecutive runs.")
  self.num_classes = num_classes
  self.num_examples_per_class = num_examples_per_class
  self.embedding_dim = embedding_dim
@property
def dropout_rate(self):
  """Dropout rate in effect: the configured rate while meta-training, else 0."""
  if self.is_meta_training:
    return self._dropout_rate
  return 0.0
def loss_fn(self, model_outputs, original_classes):
  """Softmax cross-entropy between logits and integer labels.

  `original_classes` has a trailing singleton axis, which is squeezed out.
  Labels are converted to dense one-hot form because TensorFlow doesn't
  handle second order gradients of a sparse_softmax yet.
  """
  labels = tf.squeeze(original_classes, axis=-1)
  one_hot_labels = tf.one_hot(labels, depth=self.num_classes)
  return tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=one_hot_labels, logits=model_outputs)
def grads_and_vars(self, metatrain_loss):
  """Computes gradients of metatrain_loss, avoiding NaN.

  When metatrain_loss or any of its gradients with respect to the trainable
  variables is NaN, the gradients are replaced by those of a fixed L2 penalty
  (scaled to 1e-4). This enforces only the l2 regularization and does not
  minimize the loss. Variables that the penalty does not touch get zero
  gradients. In practice, this pulls the variables back into a feasible
  region of the space when the loss or its gradients are not defined.

  Args:
    metatrain_loss: A tensor with the LEO meta-training loss.

  Returns:
    A tuple with:
      metatrain_gradients: A list of gradient tensors.
      metatrain_variables: A list of variables for this LEO model.
  """
  metatrain_variables = self.trainable_variables
  loss_gradients = tf.gradients(metatrain_loss, metatrain_variables)

  nan_loss_or_grad = tf.logical_or(
      tf.is_nan(metatrain_loss),
      tf.reduce_any([tf.reduce_any(tf.is_nan(g))
                     for g in loss_gradients]))

  regularization_penalty = (
      1e-4 / self._l2_penalty_weight * self._l2_regularization)
  regularization_gradients = tf.gradients(regularization_penalty,
                                          metatrain_variables)
  # tf.gradients returns one entry per variable, in order, so the zipped
  # pairs are (gradient, variable). The gradient is None for variables the
  # penalty does not depend on; those get zeros.
  zero_or_regularization_gradients = [
      g if g is not None else tf.zeros_like(v)
      for g, v in zip(regularization_gradients, metatrain_variables)]

  metatrain_gradients = tf.cond(nan_loss_or_grad,
                                lambda: zero_or_regularization_gradients,
                                lambda: loss_gradients, strict=True)

  return metatrain_gradients, metatrain_variables
@property
def _l2_regularization(self):
  """Sum of the graph's REGULARIZATION_LOSSES, cast to the model float dtype."""
  regularization_losses = tf.get_collection(
      tf.GraphKeys.REGULARIZATION_LOSSES)
  return tf.cast(tf.reduce_sum(regularization_losses),
                 dtype=self._float_dtype)
@property
def _decoder_orthogonality_reg(self):
return self._orthogonality_reg
================================================
FILE: LEO/model_test.py
================================================
# Copyright 2018 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for ml_leo.model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
import mock
import numpy as np
from six.moves import zip
import sonnet as snt
import tensorflow as tf
import data
import model
# Adding float64 and 32 gives an error in TensorFlow.
def constant_float64(x):
  """Wraps `x` in a float64 TensorFlow constant."""
  return tf.constant(x, dtype=tf.float64)
def get_test_config():
  """Returns the (small) hyperparameter config used to build LEO in tests."""
  return {
      "inner_unroll_length": 3,
      "finetuning_unroll_length": 4,
      "inner_lr_init": 0.1,
      "finetuning_lr_init": 0.2,
      "num_latents": 1,
      "dropout_rate": 0.3,
      "kl_weight": 0.01,
      "encoder_penalty_weight": 0.01,
      "l2_penalty_weight": 0.01,
      "orthogonality_penalty_weight": 0.01,
  }
def mockify_everything(test_function=None,
                       mock_finetuning=True,
                       mock_encdec=True):
  """Mockifies most of the LEO's model functions to behave as identity.

  Usable both bare (`@mockify_everything`) and with arguments
  (`@mockify_everything(mock_finetuning=False)`).

  The patches are created and started each time the decorated test runs, and
  stopped after it returns. If the test raises, the stop loop is skipped;
  LEOTest.setUp registers `mock.patch.stopall` as a cleanup for that case.

  Args:
    test_function: The decorated function when used without arguments.
    mock_finetuning: Also replace `finetuning_inner_loop` with a stub that
      returns `(adapted, 0.)`.
      NOTE(review): the `finetuning_inner_loop` shown in model.py returns
      three values — confirm this stub matches the patched class.
    mock_encdec: Also replace encoder, relation network, decoder and the
      decoder orthogonality penalty with identities/zero. When False,
      `possibly_sample` keeps only the first half of its input's last axis.
  """

  def inner_decorator(f):
    @functools.wraps(f)
    def mockified(*args, **kwargs):
      identity_mapping = lambda unused_self, inp, *args: tf.identity(inp)
      mock_encoder = mock.patch.object(
          model.LEO, "encoder", new=identity_mapping)
      mock_relation_network = mock.patch.object(
          model.LEO, "relation_network", new=identity_mapping)
      mock_decoder = mock.patch.object(
          model.LEO, "decoder", new=identity_mapping)
      mock_average = mock.patch.object(
          model.LEO, "average_codes_per_class", new=identity_mapping)
      mock_loss = mock.patch.object(model.LEO, "loss_fn", new=identity_mapping)

      float64_zero = constant_float64(0.)
      def identity_sample_fn(unused_self, inp, *unused_args, **unused_kwargs):
        return inp, float64_zero

      def mock_sample_with_split(unused_self, inp, *unused_args,
                                 **unused_kwargs):
        out = tf.split(inp, 2, axis=-1)[0]
        return out, float64_zero

      # When not mocking relation net, it will double the latents.
      mock_sample = mock.patch.object(
          model.LEO,
          "possibly_sample",
          new=identity_sample_fn if mock_encdec else mock_sample_with_split)

      # Deterministic stand-in for predict: inputs * weights**2.
      def dummy_predict(unused_self, inputs, classifier_weights):
        return inputs * classifier_weights**2

      mock_predict = mock.patch.object(model.LEO, "predict", new=dummy_predict)

      mock_decoder_regularizer = mock.patch.object(
          model.LEO, "_decoder_orthogonality_reg", new=float64_zero)

      all_mocks = [mock_average, mock_loss, mock_predict, mock_sample]
      if mock_encdec:
        all_mocks.extend([
            mock_encoder,
            mock_relation_network,
            mock_decoder,
            mock_decoder_regularizer,
        ])
      if mock_finetuning:
        mock_finetuning_inner = mock.patch.object(
            model.LEO,
            "finetuning_inner_loop",
            new=lambda unused_self, d, l, adapted: (adapted, float64_zero))
        all_mocks.append(mock_finetuning_inner)

      for m in all_mocks:
        m.start()

      f(*args, **kwargs)

      for m in all_mocks:
        m.stop()

    return mockified

  if test_function:
    # Decorator called with no arguments, so the function is passed
    return inner_decorator(test_function)
  return inner_decorator
def _random_problem_instance(num_classes=7,
                             num_examples_per_class=5,
                             embedding_dim=17, use_64bits_dtype=True):
  """Builds a ProblemInstance filled with random embeddings and labels.

  Inputs are uniform in [0, 1) with shape
  (num_classes, num_examples_per_class, embedding_dim). Labels are random
  class ids with shape (num_classes, num_examples_per_class, 1). The same
  tensors are reused for the tr/val splits and for the info fields.
  """
  if use_64bits_dtype:
    float_dtype, int_dtype = tf.float64, tf.int64
  else:
    float_dtype, int_dtype = tf.float32, tf.int32
  grid = (num_classes, num_examples_per_class)

  inputs = tf.constant(np.random.random(grid + (embedding_dim,)),
                       dtype=float_dtype)
  labels = np.random.randint(low=0, high=num_classes, size=grid + (1,))
  outputs = tf.constant(labels, dtype=int_dtype)

  return data.ProblemInstance(
      tr_input=inputs,
      val_input=inputs,
      tr_info=inputs,
      tr_output=outputs,
      val_output=outputs,
      val_info=inputs)
class LEOTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `model.LEO`, most of them with LEO internals mocked out.

  NOTE(review): several tests unpack return values by arity. For example,
  `finetuning_loss, _ = ...finetuning_inner_loop(...)`, while the
  `finetuning_inner_loop` shown in model.py returns three values. Confirm
  these tests still match the current `model.LEO`.
  """

  def setUp(self):
    """Creates a random 5-way, 7-shot, 4-dim problem and a fresh LEO."""
    super(LEOTest, self).setUp()
    self._problem = _random_problem_instance(5, 7, 4)
    # This doesn't call any function, so doesn't need the mocks to be started.
    self._config = get_test_config()
    self._leo = model.LEO(config=self._config)
    # Stops any patch a test left running (e.g. if it raised mid-way).
    self.addCleanup(mock.patch.stopall)

  @mockify_everything
  def test_instantiate_leo(self):
    """With the encoder mocked to identity, it returns its first input."""
    encoder_output = self._leo.encoder(5, 7)
    with self.session() as sess:
      encoder_output_ev = sess.run(encoder_output)
      self.assertEqual(encoder_output_ev, 5)

  @mockify_everything
  def test_inner_loop_adaptation(self):
    """Checks three latent gradient steps against a hand computation."""
    problem_instance = data.ProblemInstance(
        tr_input=constant_float64([[[4.]]]),
        tr_output=tf.constant([[[0]]], dtype=tf.int64),
        tr_info=[],
        val_input=[],
        val_output=[],
        val_info=[],
    )
    # encoder = decoder = id
    # predict returns classifier_weights**2 * inputs = latents**2 * inputs
    # loss = id(predictions) = inputs * latents**2
    # dl/dlatent = 2 * latent * inputs
    # With inner_lr_init = 0.1, inputs = 4 and inner_unroll_length = 3:
    # 4 -> 4 - 0.1 * 2 * 4 * 4 = 0.8
    # 0.8 -> 0.8 - 0.1 * 2 * 0.8 * 4 = 0.16
    # 0.16 -> 0.16 - 0.1 * 2 * 0.16 * 4 = 0.032

    # is_meta_training=False disables kl and encoder penalties
    adapted_parameters, _ = self._leo(problem_instance, is_meta_training=False)

    with self.session() as sess:
      sess.run(tf.global_variables_initializer())
      self.assertAllClose(sess.run(adapted_parameters), 0.032)

  @mockify_everything
  def test_map_input(self):
    """LEO can be mapped over two stacked problems with tf.map_fn."""
    problem = [
        constant_float64([[[5.]]]),  # tr_input
        tf.constant([[[0]]], dtype=tf.int64),  # tr_output
        constant_float64([[[0]]]),  # tr_info
        constant_float64([[[0.]]]),  # val_input
        tf.constant([[[0]]], dtype=tf.int64),  # val_output
        constant_float64([[[0]]]),  # val_info
    ]
    another_problem = [
        constant_float64([[[4.]]]),
        tf.constant([[[0]]], dtype=tf.int64),
        constant_float64([[[0]]]),
        constant_float64([[[0.]]]),
        tf.constant([[[0]]], dtype=tf.int64),
        constant_float64([[[0]]]),
    ]
    # first dimension (list): different input kind (tr_input, val_output, etc.)
    # second dim: different problems; this has to be a tensor dim for map_fn
    # to split over it.
    # next three: (1, 1, 1)
    # map_fn cannot receive structured inputs (namedtuples).
    ins = [
        tf.stack([in1, in2])
        for in1, in2 in zip(problem, another_problem)
    ]

    two_adapted_params, _ = tf.map_fn(
        self._leo.__call__, ins, dtype=(tf.float64, tf.float64))

    with self.session() as sess:
      sess.run(tf.global_variables_initializer())
      output1, output2 = sess.run(two_adapted_params)
      self.assertGreater(abs(output1 - output2), 1e-3)

  @mockify_everything
  def test_setting_is_meta_training(self):
    """Calling LEO records the `is_meta_training` flag on the module."""
    self._leo(self._problem, is_meta_training=True)
    self.assertTrue(self._leo.is_meta_training)
    self._leo(self._problem, is_meta_training=False)
    self.assertFalse(self._leo.is_meta_training)

  @mockify_everything(mock_finetuning=False)
  def test_finetuning_improves_loss(self):
    """Fine-tuning lowers the loss reached by the latent inner loop."""
    # Create graph
    self._leo(self._problem)

    latents, _ = self._leo.forward_encoder(self._problem)
    leo_loss, adapted_classifier_weights, _ = self._leo.leo_inner_loop(
        self._problem, latents)
    leo_loss = tf.reduce_mean(leo_loss)
    # NOTE(review): the finetuning_inner_loop in model.py returns
    # (val_loss, val_accuracy, val_output); this two-value unpack would
    # fail against that version — confirm.
    finetuning_loss, _ = self._leo.finetuning_inner_loop(
        self._problem, leo_loss, adapted_classifier_weights)
    finetuning_loss = tf.reduce_mean(finetuning_loss)
    with self.session() as sess:
      sess.run(tf.global_variables_initializer())
      leo_loss_ev, finetuning_loss_ev = sess.run([leo_loss, finetuning_loss])
      self.assertGreater(leo_loss_ev - 1e-3, finetuning_loss_ev)

  @mockify_everything
  def test_gradients_dont_flow_through_input(self):
    """tr_input has no gradient path from the encoded latents."""
    # Create graph
    self._leo(self._problem)
    latents, _ = self._leo.forward_encoder(self._problem)
    # NOTE(review): this computes d(tr_input)/d(latents). That is None
    # whenever tr_input does not depend on latents. If the intent is
    # d(latents)/d(tr_input), the arguments look swapped — confirm.
    grads = tf.gradients(self._problem.tr_input, latents)
    self.assertIsNone(grads[0])

  @mockify_everything
  def test_inferring_embedding_dim(self):
    """The embedding dimension is taken from the problem's input shape."""
    self._leo(self._problem)
    self.assertEqual(self._leo.embedding_dim, 4)

  @mockify_everything(mock_encdec=False, mock_finetuning=False)
  def test_variable_creation(self):
    """Each sub-network creates variables; together they are exactly the
    model's trainable variables."""
    self._leo(self._problem)
    encoder_variables = snt.get_variables_in_scope("leo/encoder")
    self.assertNotEmpty(encoder_variables)
    relation_network_variables = snt.get_variables_in_scope(
        "leo/relation_network")
    self.assertNotEmpty(relation_network_variables)
    decoder_variables = snt.get_variables_in_scope("leo/decoder")
    self.assertNotEmpty(decoder_variables)
    inner_lr = snt.get_variables_in_scope("leo/leo_inner")
    self.assertNotEmpty(inner_lr)
    finetuning_lr = snt.get_variables_in_scope("leo/finetuning")
    self.assertNotEmpty(finetuning_lr)
    self.assertSameElements(
        encoder_variables + relation_network_variables + decoder_variables +
        inner_lr + finetuning_lr, self._leo.trainable_variables)

  def test_graph_construction(self):
    """The unmocked graph builds without raising."""
    self._leo(self._problem)

  def test_possibly_sample(self):
    """Training samples are stochastic with nonzero KL; test-time means are
    deterministic with zero KL."""
    # Embedding dimension has to be divisible by 2 here.
    self._leo(self._problem, is_meta_training=True)
    train_samples, train_kl = self._leo.possibly_sample(self._problem.tr_input)
    self._leo(self._problem, is_meta_training=False)
    test_samples, test_kl = self._leo.possibly_sample(self._problem.tr_input)

    with self.session() as sess:
      train_samples_ev1, test_samples_ev1 = sess.run(
          [train_samples, test_samples])
      train_samples_ev2, test_samples_ev2 = sess.run(
          [train_samples, test_samples])

      self.assertAllClose(test_samples_ev1, test_samples_ev2)
      self.assertGreater(abs(np.sum(train_samples_ev1 - train_samples_ev2)), 1.)

      train_kl_ev, test_kl_ev = sess.run([train_kl, test_kl])
      self.assertNotEqual(train_kl_ev, 0.)
      self.assertEqual(test_kl_ev, 0.)

  def test_different_shapes(self):
    """Reusing a model with a differently shaped problem raises."""
    problem_instance2 = _random_problem_instance(5, 6, 13)

    self._leo(self._problem)
    with self.assertRaises(AssertionError):
      self._leo(problem_instance2)

  def test_encoder_penalty(self):
    """The encoder penalty is positive in training and ~0 at test time."""
    self._leo(self._problem)  # Sets is_meta_training
    latents, _ = self._leo.forward_encoder(self._problem)
    _, _, train_encoder_penalty = self._leo.leo_inner_loop(
        self._problem, latents)

    self._leo(self._problem, is_meta_training=False)
    _, _, test_encoder_penalty = self._leo.leo_inner_loop(
        self._problem, latents)

    with self.session() as sess:
      sess.run(tf.initializers.global_variables())
      train_encoder_penalty_ev, test_encoder_penalty_ev = sess.run(
          [train_encoder_penalty, test_encoder_penalty])
      self.assertGreater(train_encoder_penalty_ev, 1e-3)
      self.assertLess(test_encoder_penalty_ev, 1e-7)

  def test_construct_float32_leo_graph(self):
    """A float32/int32 model builds on a float32/int32 problem."""
    leo = model.LEO(use_64bits_dtype=False, config=self._config)
    problem_instance_32_bits = _random_problem_instance(use_64bits_dtype=False)
    leo(problem_instance_32_bits)
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when executed as a script.
  tf.test.main()
================================================
FILE: LEO/readme.md
================================================
# LEO + IFSL
This project is based on the official code base of the paper [Meta-Learning with Latent Embedding Optimization](https://arxiv.org/abs/1807.05960). It includes a TensorFlow implementation of the IFSL classifier in `model.py`. The class-wise mean features are available in the `pretrain` folder.
The backbone used by the official LEO code release is not publicly available. Instead, we use the pre-trained backbone provided [here](https://github.com/mileyan/simple_shot), which we also use to evaluate the other meta-learning methods. Before training LEO, the features produced by the pre-trained network must be saved to disk; they are then loaded during training. A download link for the pre-computed features of our backbone is given below.
## Dependencies
Recommended version:
- TensorFlow 1.13
- TensorFlow Probability 0.5
- Sonnet 0.1.6
- Abseil
## Preparation
- Download the processed features from: https://drive.google.com/drive/folders/1uS3UMhg7v37orxx5f2X_nX8WTJg33EXT?usp=sharing
- In `config.py`, edit the function `load_ifsl_config`:
  - set `checkpoint_path` to the directory where the trained model should be stored;
  - set `data_path` to the directory that contains the processed features.
## Running Experiments
The following commands meta-train a model and then meta-test it.
```
# Training
python runner.py --config=mini_5_resnet_baseline
# Testing
python runner.py --config=mini_5_resnet_baseline --evaluation_mode=True
```
Please refer to the `ifsl_configs` folder for the complete list of configurations.
To evaluate on CUB (the cross-domain setting), add the `--cross` option.
================================================
FILE: LEO/runner.py
================================================
# Copyright 2018 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""A binary building the graph and performing the optimization of LEO."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
import pickle
import logging
# Placed before `import tensorflow` below so the variable is already in the
# environment when TensorFlow loads. '0' presumably keeps all C++ log
# messages (TensorFlow's lowest filtering level) — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
# Show INFO-level records from the Python "tensorflow" logger (presumably the
# logger behind the tf.logging.info calls in this file).
logging.getLogger("tensorflow").setLevel(logging.INFO)
from absl import flags
import ifsl_configs
from six.moves import zip
import tensorflow as tf
import config
import data
import model
import utils
import numpy as np
# Command-line flags defined here. Other FLAGS attributes used in this file
# (e.g. deconfound, cross, hacc, feat_dim, config) are not defined in this
# section — presumably they are defined in config.py; verify.
FLAGS = flags.FLAGS
# NOTE(review): the default is a machine-specific absolute path; pass
# --checkpoint_path explicitly on other machines.
flags.DEFINE_string("checkpoint_path", "/data2/yuezhongqi/Model/leo/ifsl/miniresnet5baselinewoaug", "Path to restore from and "
                    "save to checkpoints.")
# Also used as the meta-validation interval in run_training_loop.
flags.DEFINE_integer(
    "checkpoint_steps", 1000, "The frequency, in number of "
    "steps, of saving the checkpoints.")
flags.DEFINE_boolean("evaluation_mode", False, "Whether to run in an "
                     "evaluation-only mode.")
def _clip_gradients(gradients, gradient_threshold, gradient_norm_threshold):
"""Clips gradients by value and then by norm."""
if gradient_threshold > 0:
gradients = [
tf.clip_by_value(g, -gradient_threshold, gradient_threshold)
for g in gradients
]
if gradient_norm_threshold > 0:
gradients = [
tf.clip_by_norm(g, gradient_norm_threshold) for g in gradients
]
return gradients
def _construct_validation_summaries(metavalid_loss, metavalid_accuracy):
  """Registers meta-validation scalar summaries.

  Nothing is returned; TensorFlow collects the summaries implicitly.
  """
  for tag, value in (("metavalid_loss", metavalid_loss),
                     ("metavalid_valid_accuracy", metavalid_accuracy)):
    tf.summary.scalar(tag, value)
def _construct_training_summaries(metatrain_loss, metatrain_accuracy,
                                  model_grads, model_vars):
  """Adds meta-train scalar summaries plus per-variable histograms.

  For each (gradient, variable) pair:
    * the variable is logged under its name (the part before the first ':');
    * the gradient is logged under "gradient/<name>".
  """
  tf.summary.scalar("metatrain_loss", metatrain_loss)
  tf.summary.scalar("metatrain_valid_accuracy", metatrain_accuracy)
  for grad, var in zip(model_grads, model_vars):
    var_name = var.name.partition(":")[0]
    tf.summary.histogram(var_name, var)
    tf.summary.histogram("gradient/{}".format(var_name), grad)
def _construct_examples_batch(batch_size, split, num_classes,
                              num_tr_examples_per_class,
                              num_val_examples_per_class,
                              use_cross=False):
  """Builds one batch of few-shot problems from `split` and unpacks it.

  Feature dimension and pretrain class count come from FLAGS. `use_cross`
  is forwarded to the data provider (used for the CUB cross-domain test).
  """
  provider = data.DataProvider(split, config.get_data_config(),
                               feat_dim=FLAGS.feat_dim, use_cross=use_cross)
  raw_batch = provider.get_batch(
      batch_size, num_classes, num_tr_examples_per_class,
      num_val_examples_per_class,
      num_pretrain_classes=FLAGS.num_pretrain_classes)
  return utils.unpack_data(raw_batch)
def _construct_loss_and_accuracy(inner_model, inputs, is_meta_training):
  """Returns batched loss, accuracy and dacc of the model ran on the inputs.

  `inner_model` is mapped over the problems in `inputs` with tf.map_fn; each
  call must return three float32 values. Each result is averaged over the
  batch. Back-propagation through the map is enabled only while
  meta-training.
  """
  call_fn = functools.partial(
      inner_model.__call__, is_meta_training=is_meta_training)
  per_instance_results = tf.map_fn(
      call_fn,
      inputs,
      dtype=(tf.float32,) * 3,
      back_prop=is_meta_training)
  loss, accuracy, dacc = (tf.reduce_mean(result)
                          for result in per_instance_results)
  return loss, accuracy, dacc
def construct_debug_graph(outer_model_config):
  """Builds one direct (non-mapped) IFSL call on a meta-train batch.

  Intended for debugging. The model is called once on the whole batch
  instead of through tf.map_fn.

  Returns:
    `(losses, accuracies, outputs, metatrain_batch, global_step)`.
  """
  inner_model_config = config.get_inner_model_config()
  tf.logging.info("inner_model_config: {}".format(inner_model_config))
  debug_model = model.IFSL(inner_model_config, use_64bits_dtype=False,
                           n_splits=4)

  num_classes = outer_model_config["num_classes"]
  num_tr_examples_per_class = outer_model_config["num_tr_examples_per_class"]
  batch = _construct_examples_batch(
      outer_model_config["metatrain_batch_size"], "train", num_classes,
      num_tr_examples_per_class,
      outer_model_config["num_val_examples_per_class"])

  # Second positional argument True (presumably is_meta_training).
  losses, accuracies, outputs = debug_model(batch, True)
  return (losses, accuracies, outputs, batch,
          tf.train.get_or_create_global_step())
def construct_graph(outer_model_config):
  """Constructs the optimization graph.

  Builds:
    * the inner IFSL model;
    * the meta-train loss, NaN-safe clipped gradients and the Adam train op;
    * the meta-validation and meta-test evaluation tensors;
    * a single-problem "hardness" breakdown on the test split.

  Args:
    outer_model_config: Dict read for "num_classes",
      "num_tr_examples_per_class", "num_val_examples_per_class", the
      meta-train/valid/test batch sizes, "gradient_threshold",
      "gradient_norm_threshold" and "outer_lr".

  Returns:
    A 10-tuple `(train_op, global_step, metatrain_accuracy,
    metavalid_accuracy, metatest_accuracy, metatrain_dacc, metavalid_dacc,
    metatest_dacc, hardness, correct)`.
  """
  inner_model_config = config.get_inner_model_config()
  tf.logging.info("inner_model_config: {}".format(inner_model_config))
  if FLAGS.deconfound:
    leo = model.IFSL(inner_model_config, use_64bits_dtype=False, n_splits=FLAGS.n_splits,
                     is_cosine_feature=FLAGS.is_cosine_feature, fusion=FLAGS.fusion,
                     classifier=FLAGS.classifier, num_classes=FLAGS.pretrain_num_classes,
                     logit_fusion=FLAGS.logit_fusion, use_x_only=FLAGS.use_x_only,
                     preprocess_before_split=FLAGS.preprocess_before_split,
                     preprocess_after_split=FLAGS.preprocess_after_split,
                     normalize_before_center=FLAGS.normalize_before_center,
                     normalize_d=FLAGS.normalize_d, normalize_ed=FLAGS.normalize_ed)
  else:
    # Baseline configuration, passed positionally. If IFSL's positional order
    # matches the keyword call above, this is: n_splits=1,
    # is_cosine_feature=False, fusion="concat", classifier="single",
    # num_classes=FLAGS.pretrain_num_classes, logit_fusion="product",
    # use_x_only=True — TODO confirm against model.IFSL.
    # leo = model.LEO(inner_model_config, use_64bits_dtype=False)
    leo = model.IFSL(inner_model_config, False, 1, False, "concat", "single", FLAGS.pretrain_num_classes,
                     "product", True)
  num_classes = outer_model_config["num_classes"]
  num_tr_examples_per_class = outer_model_config["num_tr_examples_per_class"]
  metatrain_batch = _construct_examples_batch(
      outer_model_config["metatrain_batch_size"], "train", num_classes,
      num_tr_examples_per_class,
      outer_model_config["num_val_examples_per_class"])
  metatrain_loss, metatrain_accuracy, metatrain_dacc = _construct_loss_and_accuracy(
      leo, metatrain_batch, True)

  metatrain_gradients, metatrain_variables = leo.grads_and_vars(metatrain_loss)

  # Avoids NaNs in summaries.
  metatrain_loss = tf.cond(tf.is_nan(metatrain_loss),
                           lambda: tf.zeros_like(metatrain_loss),
                           lambda: metatrain_loss)

  metatrain_gradients = _clip_gradients(
      metatrain_gradients, outer_model_config["gradient_threshold"],
      outer_model_config["gradient_norm_threshold"])

  _construct_training_summaries(metatrain_loss, metatrain_accuracy,
                                metatrain_gradients, metatrain_variables)
  optimizer = tf.train.AdamOptimizer(
      learning_rate=outer_model_config["outer_lr"])
  global_step = tf.train.get_or_create_global_step()
  train_op = optimizer.apply_gradients(
      list(zip(metatrain_gradients, metatrain_variables)), global_step)

  data_config = config.get_data_config()
  tf.logging.info("data_config: {}".format(data_config))
  total_examples_per_class = data_config["total_examples_per_class"]
  split = "val"
  # Query examples per class: whatever remains after the support examples.
  metavalid_batch = _construct_examples_batch(
      outer_model_config["metavalid_batch_size"], split, num_classes,
      num_tr_examples_per_class,
      total_examples_per_class - num_tr_examples_per_class)
  metavalid_loss, metavalid_accuracy, metavalid_dacc = _construct_loss_and_accuracy(
      leo, metavalid_batch, False)
  if not FLAGS.cross:
    metatest_batch = _construct_examples_batch(
        outer_model_config["metatest_batch_size"], "test", num_classes,
        num_tr_examples_per_class,
        total_examples_per_class - num_tr_examples_per_class, use_cross=FLAGS.cross)
  else:
    # Cross-domain test: a fixed 15 query examples per class.
    metatest_batch = _construct_examples_batch(
        outer_model_config["metatest_batch_size"], "test", num_classes,
        num_tr_examples_per_class,
        15, use_cross=FLAGS.cross)
  _, metatest_accuracy, metatest_dacc = _construct_loss_and_accuracy(
      leo, metatest_batch, False)
  _construct_validation_summaries(metavalid_loss, metavalid_accuracy)
  # Single-problem test batch for the per-example hardness/correctness
  # breakdown. It is always built, even when only training.
  break_down_batch = _construct_examples_batch(
      1, "test", num_classes,
      num_tr_examples_per_class,
      15)
  # NOTE(review): the meaning of the three positional flags is defined in
  # model.IFSL and is not visible here — confirm.
  hardness, correct = leo(break_down_batch, False, True, True)
  return (train_op, global_step, metatrain_accuracy, metavalid_accuracy,
          metatest_accuracy, metatrain_dacc, metavalid_dacc, metatest_dacc, hardness, correct)
def run_debug_loop(checkpoint_path):
  """Evaluates the debug graph once and returns the results.

  Args:
    checkpoint_path: Directory used by the monitored session for checkpoints
      and summaries.

  Returns:
    `(losses, accuracies, outputs, batch)` as evaluated by `sess.run`, or None
    when `--evaluation_mode` is set (nothing is evaluated in that mode).
  """
  outer_model_config = config.get_outer_model_config()
  tf.logging.info("outer_model_config: {}".format(outer_model_config))
  (losses, accuracies, outputs, metatrain_batch,
   global_step) = construct_debug_graph(outer_model_config)
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=checkpoint_path,
      save_summaries_steps=FLAGS.checkpoint_steps,
      log_step_count_steps=FLAGS.checkpoint_steps,
      save_checkpoint_steps=FLAGS.checkpoint_steps,
      summary_dir=checkpoint_path) as sess:
    if FLAGS.evaluation_mode:
      return None
    tf.logging.info("Debug run at global step %d", sess.run(global_step))
    losses_ev, accuracies_ev, outputs_ev, batch_ev = sess.run(
        [losses, accuracies, outputs, metatrain_batch])
  # Returned rather than discarded so callers (or a debugger) can inspect them.
  return losses_ev, accuracies_ev, outputs_ev, batch_ev
def write_output_message(message, file_name=None):
  """Appends `message` and a newline to `outputs/<file_name>.txt`.

  The `outputs` directory (relative to the working directory) is created if
  it does not exist yet.

  Args:
    message: Text to append.
    file_name: Base name of the output file, without the ".txt" extension.
      Defaults to "results".
  """
  if file_name is None:
    file_name = "results"
  output_dir = "outputs"
  # Create the directory on first use instead of failing on open().
  if not os.path.isdir(output_dir):
    os.makedirs(output_dir)
  output_file = os.path.join(output_dir, file_name + ".txt")
  with open(output_file, "a") as f:
    f.write(message + "\n")
def run_training_loop(checkpoint_path):
  """Runs the training loop, either saving a checkpoint or evaluating it.

  Three modes, selected by flags:
    * Training (`--evaluation_mode` off): runs `train_op` until the global
      step reaches "num_steps_limit". Every `FLAGS.checkpoint_steps` steps,
      meta-validation accuracy and dacc are averaged over 10 evaluations.
      The checkpoint is copied whenever the accuracy improves (early
      stopping).
    * Evaluation (`--evaluation_mode`, no `--hacc`): averages meta-test
      accuracy over `2000 // metatest_batch_size` evaluations. The result is
      pickled to `<checkpoint_path>/test_accuracy`.
    * Hardness breakdown (`--evaluation_mode --hacc`): evaluates the
      hardness/correct tensors 2000 times. Their flattened values are
      pickled to `hacc/<FLAGS.config>`.

  Args:
    checkpoint_path: Directory for checkpoints and summaries.
  """
  outer_model_config = config.get_outer_model_config()
  tf.logging.info("outer_model_config: {}".format(outer_model_config))
  (train_op, global_step, metatrain_accuracy, metavalid_accuracy,
   metatest_accuracy, metatrain_dacc, metavalid_dacc, metatest_dacc, hardness,
   correct) = construct_graph(outer_model_config)

  num_steps_limit = outer_model_config["num_steps_limit"]
  best_metavalid_accuracy = 0.
  best_metavalid_dacc = 0.

  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=checkpoint_path,
      save_summaries_steps=FLAGS.checkpoint_steps,
      log_step_count_steps=FLAGS.checkpoint_steps,
      save_checkpoint_steps=FLAGS.checkpoint_steps,
      summary_dir=checkpoint_path) as sess:
    if not FLAGS.evaluation_mode:
      # Log meta-train accuracy twice per checkpoint interval. Clamp to 1 so
      # that checkpoint_steps == 1 does not cause a modulo by zero.
      train_log_steps = max(FLAGS.checkpoint_steps // 2, 1)
      global_step_ev = sess.run(global_step)
      while global_step_ev < num_steps_limit:
        if global_step_ev % FLAGS.checkpoint_steps == 0:
          # Just after saving checkpoint, calculate accuracy 10 times and save
          # the best checkpoint for early stopping.
          metavalid_accuracy_ev, metavalid_dacc_ev = (
              utils.evaluate_and_average_acc_dacc(
                  sess, metavalid_accuracy, metavalid_dacc, 10))
          tf.logging.info(
              "Step: {} meta-valid accuracy: {}, dacc: {} best acc: {} "
              "best dacc: {}".format(
                  global_step_ev, metavalid_accuracy_ev, metavalid_dacc_ev,
                  best_metavalid_accuracy, best_metavalid_dacc))

          if metavalid_accuracy_ev > best_metavalid_accuracy:
            utils.copy_checkpoint(checkpoint_path, global_step_ev,
                                  metavalid_accuracy_ev)
            best_metavalid_accuracy = metavalid_accuracy_ev
          if metavalid_dacc_ev > best_metavalid_dacc:
            best_metavalid_dacc = metavalid_dacc_ev

        _, global_step_ev, metatrain_accuracy_ev = sess.run(
            [train_op, global_step, metatrain_accuracy])
        if global_step_ev % train_log_steps == 0:
          tf.logging.info("Step: {} meta-train accuracy: {}".format(
              global_step_ev, metatrain_accuracy_ev))
    elif not FLAGS.hacc:
      assert not FLAGS.checkpoint_steps
      num_metatest_estimates = (
          2000 // outer_model_config["metatest_batch_size"])
      # Not changed to dacc yet
      test_accuracy = utils.evaluate_and_average(sess, metatest_accuracy,
                                                 num_metatest_estimates)

      tf.logging.info("Metatest accuracy: %f", test_accuracy)
      with tf.gfile.Open(
          os.path.join(checkpoint_path, "test_accuracy"), "wb") as f:
        pickle.dump(test_accuracy, f)
    else:
      # NOTE(review): entries hardness_ev[c, :, c] are kept for every class c.
      # This assumes the first and last axes of `hardness` are class axes of
      # size num_classes — confirm against model.IFSL.
      num_classes = outer_model_config["num_classes"]
      all_hardness = []
      all_correct = []
      for _ in range(2000):
        hardness_ev, correct_ev = sess.run([hardness, correct])
        per_class_hardness = [hardness_ev[c, :, c] for c in range(num_classes)]
        all_hardness.append(np.array(per_class_hardness).flatten())
        all_correct.append(np.array(correct_ev).flatten())
      all_hardness = np.array(all_hardness).flatten()
      all_correct = np.array(all_correct).flatten()
      save_file = {
          "hardness": all_hardness,
          "correct": all_correct
      }
      print(all_correct.sum() / len(all_correct))
      hacc_dir = "hacc"
      if not os.path.isdir(hacc_dir):
        os.makedirs(hacc_dir)
      # The context manager closes (and flushes) the pickle file.
      with open(os.path.join(hacc_dir, FLAGS.config), "wb") as f:
        pickle.dump(save_file, f)
def main(argv):
    """``tf.app.run`` entry point.

    Instantiates the IFSL config factory named by ``FLAGS.config`` from the
    ``ifsl_configs`` module, loads it into ``config`` and runs the
    training/evaluation loop on ``FLAGS.checkpoint_path``.
    (``run_debug_loop`` can be called here instead when debugging.)
    """
    del argv  # Unused.
    make_ifsl_config = vars(ifsl_configs)[FLAGS.config]
    config.load_ifsl_config(make_ifsl_config())
    run_training_loop(FLAGS.checkpoint_path)


if __name__ == "__main__":
    tf.app.run()
================================================
FILE: LEO/utils.py
================================================
# Copyright 2018 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Short utility functions for LEO."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
from six.moves import range
import tensorflow as tf
import config
import data
def unpack_data(problem_instance):
    """Map data.ProblemInstance to a list of Tensors, to process with map_fn.

    Values that are not a ``data.ProblemInstance`` are returned unchanged.
    """
    if not isinstance(problem_instance, data.ProblemInstance):
        return problem_instance
    return list(problem_instance)
def copy_checkpoint(checkpoint_path, global_step, accuracy):
    """Copies the checkpoint to a separate directory.

    Files matching ``model.ckpt-<global_step>.*`` plus ``graph.pbtxt`` are
    copied from ``checkpoint_path`` into ``<checkpoint_path>/best_checkpoint``
    (via ``_save_files_in_tmp_directory``, which also pickles the inner/outer
    model configs and ``accuracy``), together with a ``checkpoint`` index file
    pointing at the copied step.

    Nothing is copied if ``best_checkpoint`` already records a strictly higher
    accuracy (e.g. one saved before a preemption).

    Args:
      checkpoint_path: Directory containing the training checkpoints.
      global_step: Step number of the checkpoint to copy.
      accuracy: Validation accuracy associated with this checkpoint.
    """
    tmp_checkpoint_path = os.path.join(checkpoint_path, "tmp_best_checkpoint")
    best_checkpoint_path = os.path.join(checkpoint_path, "best_checkpoint")
    if _is_previous_accuracy_better(best_checkpoint_path, accuracy):
        tf.logging.info("Not copying the checkpoint: there is a better one from "
                        "before a preemption.")
        return
    # All files belonging to this step's checkpoint, plus the graph definition.
    checkpoint_regex = os.path.join(checkpoint_path,
                                    "model.ckpt-{}.*".format(global_step))
    checkpoint_files = tf.gfile.Glob(checkpoint_regex)
    graph_file = os.path.join(checkpoint_path, "graph.pbtxt")
    checkpoint_files.append(graph_file)
    _save_files_in_tmp_directory(tmp_checkpoint_path, checkpoint_files, accuracy)
    # Index file whose model_checkpoint_path already points at the final
    # best_checkpoint location rather than the temporary directory.
    new_checkpoint_index_file = os.path.join(tmp_checkpoint_path, "checkpoint")
    with tf.gfile.Open(new_checkpoint_index_file, "w") as f:
        f.write("model_checkpoint_path: \"{}/model.ckpt-{}\"\n".format(
            best_checkpoint_path, global_step))
    # We first copy the better checkpoint to a temporary directory, and only
    # when it's created move it to avoid inconsistent state when job is preempted
    # when copying the checkpoint.
    if tf.gfile.Exists(best_checkpoint_path):
        tf.gfile.DeleteRecursively(best_checkpoint_path)
    tf.gfile.Rename(tmp_checkpoint_path, best_checkpoint_path)
    tf.logging.info("Copied new best checkpoint with accuracy %.5f", accuracy)
def _save_files_in_tmp_directory(tmp_checkpoint_path, checkpoint_files,
                                 accuracy):
    """Saves the checkpoint files and accuracy in a temporary directory.

    A leftover temporary directory (from a job preempted mid-copy) is removed
    first. Every path in ``checkpoint_files`` is copied in under its base name,
    then the current inner/outer model configs and ``accuracy`` are pickled to
    ``inner_config``, ``outer_config`` and ``accuracy`` respectively.
    """
    if tf.gfile.Exists(tmp_checkpoint_path):
        tf.logging.info("The temporary directory exists, because job was preempted "
                        "before it managed to move it. We're removing it.")
        tf.gfile.DeleteRecursively(tmp_checkpoint_path)
    tf.gfile.MkDir(tmp_checkpoint_path)

    for source_path in checkpoint_files:
        base_name = source_path.rsplit("/", 1)[-1]
        destination = os.path.join(tmp_checkpoint_path, base_name)
        tf.gfile.Copy(source_path, destination, overwrite=False)

    def pickle_into_tmp_dir(value, name):
        with tf.gfile.Open(os.path.join(tmp_checkpoint_path, name), "wb") as out:
            pickle.dump(value, out)

    pickle_into_tmp_dir(config.get_inner_model_config(), "inner_config")
    pickle_into_tmp_dir(config.get_outer_model_config(), "outer_config")
    pickle_into_tmp_dir(accuracy, "accuracy")
def _is_previous_accuracy_better(best_checkpoint_path, accuracy):
    """Returns True if ``best_checkpoint_path`` exists and its pickled
    ``accuracy`` file is strictly greater than ``accuracy``."""
    if tf.gfile.Exists(best_checkpoint_path):
        stored_accuracy_path = os.path.join(best_checkpoint_path, "accuracy")
        with tf.gfile.Open(stored_accuracy_path, "rb") as stored:
            stored_accuracy = pickle.load(stored)
        return stored_accuracy > accuracy
    return False
def evaluate_and_average(session, tensor, num_estimates):
    """Runs ``tensor`` ``num_estimates`` times in ``session`` and returns the
    arithmetic mean of the results."""
    total = sum(session.run(tensor) for _ in range(num_estimates))
    return total / num_estimates
def evaluate_and_average_acc_dacc(session, acc, dacc, num_estimates):
    """Runs ``acc`` and ``dacc`` together ``num_estimates`` times in
    ``session`` and returns ``(mean_acc, mean_dacc)``."""
    acc_total = 0
    dacc_total = 0
    for _ in range(num_estimates):
        acc_value, dacc_value = session.run([acc, dacc])
        acc_total = acc_total + acc_value
        dacc_total = dacc_total + dacc_value
    return acc_total / num_estimates, dacc_total / num_estimates
================================================
FILE: MAML_MN_FT/README.md
================================================
# IFSL + Matching Networks, MAML
This project is based on the official code base of the paper [A Closer Look At Few-Shot Classification](https://arxiv.org/abs/1904.04232). The IFSL implementations are added in the *methods* folder. The *pretrain* folder contains pre-saved class-wise feature means used for class-wise adjustment. The *tests* folder contains the run configurations.
## Dependencies
Recommended versions:
- Python 3.7.6
- PyTorch 1.4.0
## Preparation
- Download the pre-trained backbones from https://github.com/mileyan/simple_shot
- Download mini-ImageNet by following the instructions at https://github.com/wyharveychen/CloserLookFewShot
- Download tiered-ImageNet by following the instructions at https://github.com/yaoyao-liu/meta-transfer-learning
Once the datasets are downloaded, open *filelists/miniImagenet/write_miniImagenet_filelist.py* and *filelists/tiered/write_tiered_filelist.py*, set *data_path* in each to the corresponding dataset location, and then run both scripts.
In *configs.py*, set *save_dir* to the desired save path for trained models, *simple_shot_dir* to the directory where the pre-trained weights are stored, and *tiered_dir* to the tiered-ImageNet directory.
## Train and Test
The file *tests/MetaTrain.py* contains all the configurations for running MAML and Matching Networks, both as baselines and with IFSL, with either ResNet10 or WRN-28-10. Run main.py to perform meta-training followed by meta-testing. Two examples are given below:
```
python main.py --method metatrain --train_aug --test maml5_resnet # MAML 5-shot miniImageNet with ResNet10
python main.py --method metatrain --train_aug --test maml5_ifsl_resnet_tiered # MAML + IFSL 5-shot tieredImageNet with ResNet10
```
================================================
FILE: MAML_MN_FT/backbone.py
================================================
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
import torch
from torch.autograd import Variable
import torch.nn.init as init
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F
from torch.nn.utils.weight_norm import WeightNorm
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from numpy import linalg as LA
import os
# Basic ResNet model
def init_layer(L):
    """Initialize a layer's parameters in place.

    * ``nn.Conv2d``: weight ~ N(0, 2 / n) with
      ``n = kernel_h * kernel_w * out_channels``, i.e. the layer's fan-out
      (He/Kaiming-normal initialisation in ``fan_out`` mode). The bias, if
      any, is left untouched.
    * ``nn.BatchNorm2d`` with learnable affine parameters: weight = 1,
      bias = 0. ``affine=False`` layers have no parameters and are skipped.

    Any other module type is left unchanged.
    """
    if isinstance(L, nn.Conv2d):
        # Normalised by fan-out (kernel area * output channels), not fan-in.
        n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
        L.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
    elif isinstance(L, nn.BatchNorm2d) and L.affine:
        # BatchNorm2d(affine=False) sets weight/bias to None.
        L.weight.data.fill_(1)
        L.bias.data.fill_(0)
class NNClassifier():
    """Nearest-centroid classifier over pre-computed feature vectors.

    ``fit`` stores one centroid per class: the plain mean of that class's
    support features, or a softmax(weights)-weighted sum when per-sample
    weights are given. ``predict`` / ``predict_alt`` score queries against the
    centroids and return softmax-normalised scores of shape
    ``(n_query, n_way)``. Centroids live on the same device as the support
    features (previously CUDA was hard-coded).
    """

    def __init__(self, n_way):
        self.n_way = n_way

    def normalize(self, x):
        """Row-wise L2-normalise 2-D ``x`` (epsilon 1e-5 avoids division by 0)."""
        x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x)
        x_normalized = x.div(x_norm + 0.00001)
        return x_normalized

    def preprocess(self, data):
        """Identity; any preprocessing (e.g. L2 normalisation) is done by the caller."""
        return data

    def dist(self, x1, x2):
        """Euclidean distance between two NumPy vectors."""
        return np.linalg.norm(x1 - x2)

    def kl_divergence(self, k1, k2):
        """KL(k1 || k2) summed over dim 2, with 1e-4 added for numerical safety."""
        k1_safe = k1 + 0.0001
        k2_safe = k2 + 0.0001
        t = k1_safe * torch.log(k1_safe / k2_safe)
        return torch.sum(t, dim=2)

    def fit(self, support, support_labels, support_weights=None):
        """Compute one centroid per class.

        Args:
            support: 2-D tensor (N, feat_dim) of support features.
            support_labels: length-N tensor of class ids in ``[0, n_way)``.
            support_weights: optional length-N tensor of logits. When given,
                each centroid is the softmax(weights)-weighted sum of its
                class's features instead of their plain mean.
        """
        self.support = support
        self.labels = support_labels
        self.feat_dim = support.shape[1]
        self.centroids = torch.zeros(self.n_way, self.feat_dim, device=support.device)
        for i in range(self.n_way):
            in_class = support_labels == i
            class_support = support[in_class]
            if support_weights is None:
                self.centroids[i] = class_support.mean(dim=0)
            else:
                class_weights = torch.softmax(support_weights[in_class], dim=0)
                self.centroids[i] = (class_support * class_weights.unsqueeze(1)).sum(dim=0)

    def predict(self, query):
        """Score queries by exp(-squared distance) to each centroid, then softmax.

        Vectorised in torch. The previous version mixed a NumPy query with
        (CUDA) tensor centroids, which fails, and looped in Python.
        """
        processed_query = self.preprocess(query)
        diff = processed_query.unsqueeze(1) - self.centroids.unsqueeze(0)
        sq_dist = diff.pow(2).sum(dim=2)
        return torch.softmax(torch.exp(-sq_dist), dim=1)

    def predict_alt(self, query, measure="euclidean", norm_scores=False, temp=1.0):
        """Score queries against the centroids with the given measure.

        Args:
            query: 2-D tensor (n_query, feat_dim).
            measure: "euclidean" (exp(-d^2)), "cosine" (cosine similarity),
                "kl" (negative KL divergence / 4) or "linear" (-d^2 / temp).
            norm_scores: L2-normalise each row of raw scores before softmax.
            temp: temperature for the "linear" measure.

        Returns:
            Softmax-normalised scores of shape (n_query, n_way).

        Raises:
            ValueError: for an unknown ``measure`` (previously all-zero scores).
        """
        query_size = query.shape[0]
        processed_query = query.unsqueeze(1).expand(query_size, self.n_way, self.feat_dim)
        centroids = self.centroids.unsqueeze(0).expand(query_size, self.n_way, self.feat_dim)
        if measure == "euclidean":
            dist = torch.norm(processed_query - centroids, p=2, dim=2)
            scores = torch.exp(-dist * dist)
        elif measure == "cosine":
            inner_product = (processed_query * centroids).sum(dim=2)
            n2 = torch.norm(centroids, p=2, dim=2)
            n1 = torch.norm(processed_query, p=2, dim=2)
            scores = inner_product / n1 / n2
        elif measure == "kl":
            # Smaller divergence => higher score. This branch previously
            # computed the divergence but never used it.
            scores = -self.kl_divergence(processed_query, centroids) / 4.0
        elif measure == "linear":
            dist = torch.norm(processed_query - centroids, p=2, dim=2)
            scores = -dist * dist / temp
        else:
            raise ValueError("Unknown measure: {}".format(measure))
        if norm_scores:
            scores = scores / torch.norm(scores, p=2, dim=1, keepdim=True)
        return torch.softmax(scores, dim=1)
class MultiNNBiClassifier():
    """Ensemble of paired nearest-centroid classifiers on two feature views.

    Head ``i`` fits one NNClassifier on ``support_x[i]`` and one on
    ``support_d[i]``. At prediction time the two heads' softmax scores are
    combined with ``fusion``, and the fused heads are averaged (or weighted).
    """

    def __init__(self, n_way, n_classifiers, measure="linear", fusion="linear_sum", temp=1.0):
        self.n_way = n_way
        self.n_classifiers = n_classifiers
        self.x_clfs = [NNClassifier(n_way) for _ in range(n_classifiers)]
        self.d_clfs = [NNClassifier(n_way) for _ in range(n_classifiers)]
        self.measure = measure
        self.temp = temp
        self.proba_fusion = fusion

    def fit(self, support_x, support_d, support_labels, support_weights=None):
        """Fit head ``i`` on ``support_x[i]`` / ``support_d[i]``.

        ``support_weights``, if given, is indexed as ``support_weights[:, i]``
        to obtain per-sample weights for head ``i``.
        """
        for i in range(self.n_classifiers):
            if support_weights is None:
                self.x_clfs[i].fit(support_x[i], support_labels)
                self.d_clfs[i].fit(support_d[i], support_labels)
            else:
                self.x_clfs[i].fit(support_x[i], support_labels, support_weights=support_weights[:, i])
                self.d_clfs[i].fit(support_d[i], support_labels, support_weights=support_weights[:, i])

    def fuse_proba(self, p1, p2):
        """Combine two score tensors according to ``self.proba_fusion``.

        Raises:
            ValueError: for an unknown fusion (previously returned None).
        """
        if self.proba_fusion == "linear_sum":
            return p1 + p2
        if self.proba_fusion == "product":
            return torch.log(torch.sigmoid(p1) * torch.sigmoid(p2))
        if self.proba_fusion == "sum":
            return torch.log(torch.sigmoid(p1 + p2))
        if self.proba_fusion == "harmonic":
            p = torch.sigmoid(p1) * torch.sigmoid(p2)
            return torch.log(p / (1 + p))
        raise ValueError("Unknown proba fusion: {}".format(self.proba_fusion))

    def predict(self, query_x, query_d, weights=None, counterfactual=False):
        """Return fused scores of shape (n_query, n_way).

        Args:
            query_x, query_d: indexable per head (``query_x[i]`` is the 2-D
                query for head ``i``).
            weights: optional per-query head weights; must broadcast to
                (n_query, n_classifiers). Heads are averaged when None.
            counterfactual: replace the x-branch scores by the uniform
                distribution 1 / n_way.
        """
        fused_heads = []
        for i in range(self.n_classifiers):
            d_scores = self.d_clfs[i].predict_alt(query_d[i], self.measure, temp=self.temp)
            if counterfactual:
                x_scores = torch.full_like(d_scores, 1.0 / self.n_way)
            else:
                x_scores = self.x_clfs[i].predict_alt(query_x[i], self.measure, temp=self.temp)
            fused_heads.append(self.fuse_proba(x_scores, d_scores))
        # (n_classifiers, n_query, n_way), on the scores' own device.
        scores = torch.stack(fused_heads)
        if weights is None:
            return scores.mean(dim=0)
        scores = scores.permute(1, 0, 2)
        return (weights.unsqueeze(2) * scores).sum(dim=1)
class MultiNNClassifier():
    """Ensemble of nearest-centroid classifiers, one per feature view.

    Head ``i`` is an NNClassifier fitted on ``support[i]``; predictions are
    the mean (or per-query weighted sum) of the heads' softmax scores.
    """

    def __init__(self, n_way, n_classifiers, measure="euclidean", temp=1.0):
        self.n_way = n_way
        self.n_classifiers = n_classifiers
        self.clfs = [NNClassifier(n_way) for _ in range(n_classifiers)]
        self.measure = measure
        self.temp = temp

    def fit(self, support, support_labels, support_weights=None):
        """Fit each head.

        Args:
            support: indexable of ``n_classifiers`` 2-D feature tensors, e.g.
                a tensor of shape (n_classifiers, N, feature_dim).
            support_labels: length-N class ids.
            support_weights: optional; column ``support_weights[:, i]`` holds
                the per-sample weights for head ``i``.
        """
        for i in range(self.n_classifiers):
            if support_weights is None:
                self.clfs[i].fit(support[i], support_labels)
            else:
                self.clfs[i].fit(support[i], support_labels, support_weights=support_weights[:, i])

    def predict(self, query, weights=None):
        """Return combined scores of shape (n_query, n_way).

        Args:
            query: indexable of ``n_classifiers`` 2-D query tensors (or a
                tensor of shape (n_classifiers, N, feature_dim)).
            weights: optional per-query head weights that must broadcast to
                (n_query, n_classifiers) — not (n_classifiers,) as previously
                documented. Heads are averaged when None.

        The stacked per-head scores, shape (n_classifiers, n_query, n_way),
        are also stored in ``self.scores``.
        """
        per_head = [self.clfs[i].predict_alt(query[i], self.measure, temp=self.temp)
                    for i in range(self.n_classifiers)]
        # Stacked on the scores' own device instead of a hard-coded CUDA buffer.
        scores = torch.stack(per_head)
        self.scores = scores
        if weights is None:
            return scores.mean(dim=0)
        scores = scores.permute(1, 0, 2)
        return (weights.unsqueeze(2) * scores).sum(dim=1)
class BidrectionalLSTM(nn.Module):
    def __init__(self, size: int, layers: int):
        """Bidirectional LSTM computing fully conditional embeddings (FCE) of the
        support set, following the Matching Networks paper.

        # Arguments
            size: Input and hidden size. They must be equal so that the skip
                connection of Appendix A.2 (outputs + inputs) is well defined.
            layers: Number of stacked LSTM layers.
        """
        super(BidrectionalLSTM, self).__init__()
        self.num_layers = layers
        self.batch_size = 1  # Stored for compatibility; forward does not use it.
        # input_size == hidden_size so forward/backward outputs can be added to
        # the inputs (Appendix A.1 and A.2 of Matching Networks).
        self.lstm = nn.LSTM(
            input_size=size,
            num_layers=layers,
            hidden_size=size,
            bidirectional=True,
        )

    def forward(self, inputs):
        """Return ``(embeddings, h_n, c_n)``.

        ``embeddings = h_forward + h_backward + inputs`` (Appendix A.2),
        i.e. g(x_i, S) = h_forward_i + h_backward_i + g'(x_i).
        """
        # Passing None lets PyTorch create the initial hidden/cell states.
        lstm_out, (h_n, c_n) = self.lstm(inputs, None)
        half = self.lstm.hidden_size
        forward_part = lstm_out[:, :, :half]
        backward_part = lstm_out[:, :, half:]
        embeddings = forward_part + backward_part + inputs
        return embeddings, h_n, c_n
class AttentionLSTM(nn.Module):
    def __init__(self, size: int, unrolling_steps: int):
        """Attentional LSTM used to generate fully conditional embeddings (FCE) of the query set as described
        in the Matching Networks paper.
        # Arguments
            size: Size of input and hidden layers. These are constrained to be the same in order to implement the skip
                connection described in Appendix A.2
            unrolling_steps: Number of steps of attention over the support set to compute. Analogous to number of
                layers in a regular LSTM
        """
        super(AttentionLSTM, self).__init__()
        self.unrolling_steps = unrolling_steps
        self.lstm_cell = nn.LSTMCell(input_size=size,
                                     hidden_size=size)

    def forward(self, support, queries):
        """Embed ``queries`` (n_query x d) conditioned on ``support`` (n_support x d).

        The recurrent state uses the queries' dtype and device, so the module
        works with float or double inputs on any device; this previously
        required a CUDA float64 model. With ``unrolling_steps == 0`` the
        queries are returned unchanged; ``h`` was previously undefined then.

        Raises:
            ValueError: if support and queries have different embedding sizes.
        """
        if support.shape[-1] != queries.shape[-1]:
            raise ValueError("Support and query set have different embedding dimension!")
        batch_size = queries.shape[0]
        embedding_dim = queries.shape[1]
        h_hat = torch.zeros_like(queries)
        c = queries.new_zeros(batch_size, embedding_dim)
        h = queries
        for _ in range(self.unrolling_steps):
            # Calculate hidden state cf. equation (4) of appendix A.2
            h = h_hat + queries
            # Softmax attentions between hidden states and support set
            # embeddings, cf. equation (6) of appendix A.2
            attentions = torch.mm(h, support.t()).softmax(dim=1)
            # Readouts from support set embeddings cf. equation (5)
            readout = torch.mm(attentions, support)
            # Run LSTM cell cf. equation (3); the hidden input is h + readout
            # (a sum rather than the concatenation [h, readout]).
            h_hat, c = self.lstm_cell(queries, (h + readout, c))
            h = h_hat + queries
        return h
class MultiLinearClassifier(nn.Module):
    """Ensemble of per-view linear (``softmax``) or ``distLinear`` (``dist``) heads.

    Head ``i`` classifies ``X[i]``. The heads' softmax probabilities are
    combined as the mean log-probability (``sum_log=True``) or as the log of
    the mean probability (``sum_log=False``). ``forward`` returns log-scores
    of shape (N, n_way).

    Args:
        n_clf: number of heads.
        feat_dim: input size of every head (ignored when ``permute``).
        n_way: number of classes.
        sum_log: combination rule, see above.
        permute: if True, head ``i`` has input size ``shapes[i]``.
        shapes: per-head input sizes, used when ``permute``.
        loss_type: "softmax" (nn.Linear) or "dist" (distLinear).
        device: where the heads are placed; defaults to "cuda" as before.

    Raises:
        ValueError: for an unknown ``loss_type`` (previously a None head).
    """

    def __init__(self, n_clf, feat_dim, n_way, sum_log=True, permute=False, shapes=None,
                 loss_type="softmax", device="cuda"):
        super(MultiLinearClassifier, self).__init__()
        self.n_clf = n_clf
        self.feat_dim = feat_dim
        self.n_way = n_way
        self.sum_log = sum_log
        self.softmax = nn.Softmax(dim=2)
        self.permute = permute
        self.shapes = shapes
        in_dims = [shapes[i] if permute else feat_dim for i in range(n_clf)]
        self.clfs = nn.ModuleList(
            [self.create_clf(loss_type, in_dim, n_way).to(device) for in_dim in in_dims])

    def create_clf(self, loss_type, in_dim, out_dim):
        """Build one head mapping ``in_dim`` features to ``out_dim`` logits."""
        if loss_type == "softmax":
            return nn.Linear(in_dim, out_dim)
        if loss_type == "dist":
            return distLinear(in_dim, out_dim, True)
        raise ValueError("Unknown loss_type: {}".format(loss_type))

    def forward(self, X):
        """X is indexable per head: ``X[i]`` is the (N, in_dim_i) input of head ``i``."""
        # (n_clf, N, n_way) logits, on the heads' device.
        resp = torch.stack([self.clfs[i](X[i]) for i in range(self.n_clf)])
        if self.sum_log:
            # log_softmax is the numerically stable form of log(softmax(.));
            # the latter yields -inf once a probability underflows to 0.
            return torch.log_softmax(resp, dim=2).mean(dim=0)
        mean_proba = self.softmax(resp).mean(dim=0)
        return torch.log(mean_proba)
class MultiBiLinearClassifier(nn.Module):
    """Ensemble of paired (x-branch, d-branch) linear heads.

    For head ``i`` the logits of ``x_clfs[i](X[i])`` and ``d_clfs[i](D[i])``
    are fused with ``logit_fusion``. The per-head softmax probabilities are
    then combined as the mean log-probability (``sum_log=True``) or as the log
    of the mean probability. ``forward`` returns log-scores of shape
    (N, n_way).

    Args:
        n_clf: number of heads.
        x_feat_dim, d_feat_dim: input sizes of the x and d branches.
        n_way: number of classes.
        sum_log: combination rule, see above.
        loss_type: "softmax" (nn.Linear) or "dist" (distLinear).
        logit_fusion: "linear_sum", "product", "sum" or "harmonic".
        device: where the heads are placed; defaults to "cuda" as before.

    Raises:
        ValueError: for an unknown ``loss_type`` or ``logit_fusion``.
    """

    def __init__(self, n_clf, x_feat_dim, d_feat_dim, n_way, sum_log=True, loss_type="softmax",
                 logit_fusion="linear_sum", device="cuda"):
        super(MultiBiLinearClassifier, self).__init__()
        self.n_clf = n_clf
        self.x_feat_dim = x_feat_dim
        self.d_feat_dim = d_feat_dim
        self.n_way = n_way
        self.sum_log = sum_log
        self.softmax = nn.Softmax(dim=2)
        self.logit_fusion = logit_fusion
        self.x_clfs = nn.ModuleList(
            [self.create_clf(loss_type, x_feat_dim, n_way).to(device) for _ in range(n_clf)])
        self.d_clfs = nn.ModuleList(
            [self.create_clf(loss_type, d_feat_dim, n_way).to(device) for _ in range(n_clf)])

    def fuse_logits(self, p1, p2):
        """Fuse x-branch and d-branch logits (previously None for unknown modes)."""
        if self.logit_fusion == "linear_sum":
            return p1 + p2
        if self.logit_fusion == "product":
            return torch.log(torch.sigmoid(p1) * torch.sigmoid(p2))
        if self.logit_fusion == "sum":
            return torch.log(torch.sigmoid(p1 + p2))
        if self.logit_fusion == "harmonic":
            p = torch.sigmoid(p1) * torch.sigmoid(p2)
            return torch.log(p / (1 + p))
        raise ValueError("Unknown logit_fusion: {}".format(self.logit_fusion))

    def create_clf(self, loss_type, in_dim, out_dim):
        """Build one head mapping ``in_dim`` features to ``out_dim`` logits."""
        if loss_type == "softmax":
            return nn.Linear(in_dim, out_dim)
        if loss_type == "dist":
            return distLinear(in_dim, out_dim, True)
        raise ValueError("Unknown loss_type: {}".format(loss_type))

    def forward(self, X, D, counterfactual=False):
        """``X[i]`` / ``D[i]`` are the (N, dim) inputs of head ``i``.

        With ``counterfactual=True`` the x-branch logits are replaced by ones.
        """
        fused = []
        for i in range(self.n_clf):
            d_logit = self.d_clfs[i](D[i])
            if counterfactual:
                x_logit = torch.ones_like(d_logit)
            else:
                x_logit = self.x_clfs[i](X[i])
            fused.append(self.fuse_logits(x_logit, d_logit))
        resp = torch.stack(fused)
        if self.sum_log:
            # Numerically stable log(softmax(.)).
            return torch.log_softmax(resp, dim=2).mean(dim=0)
        mean_proba = self.softmax(resp).mean(dim=0)
        return torch.log(mean_proba)
class ResNetKernelClusterAgent():
    """Cluster the output kernels of one weight tensor of a pretrained model.

    Each output kernel is averaged over dim 1 and flattened into a feature
    vector; KMeans then groups the kernels into ``n_clusters`` clusters.
    NOTE(review): assumes the selected parameter is a 4-D Conv2d weight
    (out_channels, in_channels, kh, kw) — confirm for other backbones.

    Args:
        pretrain: object exposing ``model.parameters()``.
        n_clusters: number of KMeans clusters.
        pca_dim: stored but currently unused (PCA reduction is disabled).
        cluster_method: stored; only KMeans is implemented.
        param_index: index into ``list(pretrain.model.parameters())`` of the
            weight to cluster. Default 30 is the previously hard-coded value,
            which was chosen for ResNet10.
    """

    def __init__(self, pretrain, n_clusters, pca_dim, cluster_method="kmeans", param_index=30):
        self.pretrain = pretrain
        self.n_clusters = n_clusters
        self.cluster_method = cluster_method
        self.pca_dim = pca_dim
        self.param_index = param_index

    def get_weight_features(self, weights):
        """Mean over dim 1, then flatten dims 1-2, as a NumPy array (one row per kernel)."""
        kernel_features = torch.mean(weights, dim=1)
        kernel_features = torch.flatten(kernel_features, 1, 2)
        return kernel_features.cpu().detach().numpy()

    def fit(self):
        """Return the KMeans label of every output kernel of the selected weight."""
        weights = list(self.pretrain.model.parameters())[self.param_index]
        kernel_features_np = self.get_weight_features(weights)
        kmeans = KMeans(n_clusters=self.n_clusters, random_state=0).fit(kernel_features_np)
        return kmeans.labels_
class ResNetParamClusterModel():
    """Split a pretrained ResNet's features by clustering conv output channels.

    ``fit`` clusters the output channels of two conv weights of
    ``pretrain.model`` with KMeans, giving one label per output channel. The
    ``forward`` methods re-run the end of the backbone trunk while keeping only
    one cluster's channels at a time. They return per-cluster feature tensors
    together with their channel counts.

    NOTE(review): ``forward``/``forward2`` use ``c1_labels`` / ``c2_labels`` as
    channel masks for ``trunk[7].C1`` / ``trunk[7].C2``. ``c1_param_index`` /
    ``c2_param_index`` must therefore select those convs' weights. The
    defaults (27, 30) were hard-coded for ResNet10 — confirm for other
    backbones.
    """

    def __init__(self, pretrain, n_clusters, cluster_method="kmeans",
                 c1_param_index=27, c2_param_index=30):
        self.pretrain = pretrain
        self.cluster_method = cluster_method
        self.n_clusters = n_clusters
        self.c1_param_index = c1_param_index
        self.c2_param_index = c2_param_index

    def cluster(self, features, n_clusters):
        """Return one cluster label per row of ``features``.

        Raises:
            ValueError: for any ``cluster_method`` other than "kmeans"
                (previously None was returned silently).
        """
        if self.cluster_method != "kmeans":
            raise ValueError("Unsupported cluster_method: {}".format(self.cluster_method))
        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(features)
        return kmeans.labels_

    def get_weight_features(self, weights):
        """Mean over dim 1, then flatten dims 1-2, as a NumPy array (one row per kernel)."""
        kernel_features = torch.mean(weights, dim=1)
        kernel_features = torch.flatten(kernel_features, 1, 2)
        kernel_features_np = kernel_features.cpu().detach().numpy()
        return kernel_features_np

    def fit(self):
        """Cluster the output channels of the C1 and C2 conv weights."""
        params = list(self.pretrain.model.parameters())
        c1_features = self.get_weight_features(params[self.c1_param_index])
        c2_features = self.get_weight_features(params[self.c2_param_index])
        self.c1_labels = self.cluster(c1_features, self.n_clusters)
        self.c2_labels = self.cluster(c2_features, self.n_clusters)

    def conv_forward(self, inputs, labels, n_clusters, original_conv):
        """Apply ``original_conv`` and split its output channels by cluster.

        Returns a tensor of shape (n_clusters, *conv_output.shape). Slice ``i``
        equals the conv output on channels where ``labels == i`` and is zero
        elsewhere. The buffer is allocated on the conv output's device rather
        than on hard-coded CUDA.
        """
        output = original_conv(inputs)
        new_output = torch.zeros((n_clusters,) + tuple(output.shape), device=output.device)
        for i in range(n_clusters):
            in_cluster = labels == i
            # Set specific channels to cluster output value
            new_output[i, :, in_cluster, :, :] = output[:, in_cluster, :, :]
        return new_output

    def forward(self, imgs):
        """Return ``(output_features, shapes)`` with ``n_clusters ** 2`` entries.

        Entries are ordered by C1 cluster, then C2 cluster. Each feature is the
        detached output of ``trunk[8:]``, restricted along dim 1 to the
        channels of the current C2 cluster. ``shapes`` holds those channel
        counts.
        """
        # imgs are N * 3 * 224 * 224
        model = self.pretrain.model.feature.trunk
        # 1. Get output before trunk 7
        trunk6_out = model[:7](imgs)
        # 2. Get output of C1. trunk[7] is used piecewise in place of its own
        # forward, which computes:
        #   out = relu1(BN1(C1(x))); out = BN2(C2(out))
        #   out = relu2(out + (x if identity shortcut else BNshortcut(shortcut(x))))
        t7 = model[7]
        c1_outputs = self.conv_forward(trunk6_out, self.c1_labels, self.n_clusters, t7.C1)
        output_features = []
        shapes = []
        for i in range(self.n_clusters):
            out = c1_outputs[i]
            out = t7.BN1(out)
            out = t7.relu1(out)
            # Keep only C1 cluster i's channels after the activation.
            out[:, self.c1_labels != i, :, :] = 0
            c2_outputs = self.conv_forward(out, self.c2_labels, self.n_clusters, t7.C2)
            for j in range(self.n_clusters):
                out = c2_outputs[j]
                out = t7.BN2(out)
                short_out = t7.BNshortcut(t7.shortcut(trunk6_out))
                out = out + short_out
                out = t7.relu2(out)
                out = model[8:](out).detach()
                out = out[:, self.c2_labels == j]
                output_features.append(out)
                shapes.append(out.shape[1])
        return output_features, shapes

    def forward2(self, imgs):
        """Like ``forward`` but splits only by the C2 clusters (``n_clusters`` entries)."""
        # imgs are N * 3 * 224 * 224
        model = self.pretrain.model.feature.trunk
        # 1. Get output before trunk 7
        trunk6_out = model[:7](imgs)
        # 2. Get output of C1 (unsplit)
        t7 = model[7]
        out = t7.C1(trunk6_out)
        out = t7.BN1(out)
        out = t7.relu1(out)
        c2_outputs = self.conv_forward(out, self.c2_labels, self.n_clusters, t7.C2)
        output_features = []
        shapes = []
        for i in range(self.n_clusters):
            out = c2_outputs[i]
            out = t7.BN2(out)
            short_out = t7.BNshortcut(t7.shortcut(trunk6_out))
            out = out + short_out
            out = t7.relu2(out)
            out = model[8:](out).detach()
            out = out[:, self.c2_labels == i]
            output_features.append(out)
            shapes.append(out.shape[1])
        return output_features, shapes

    def forward3(self, imgs):
        """Placeholder for an unfinished variant.

        Raises:
            NotImplementedError: always. The stub used to run the trunk and
                then implicitly return None.
        """
        raise NotImplementedError("ResNetParamClusterModel.forward3 is not implemented")
class BasisTransformer():
    """Project features onto per-cluster bases fitted on the base split.

    ``fit`` loads the base-split features and labels from
    ``pretrain.get_pretrain_dataset('base')`` and groups them. With
    ``recluster``, the groups come from KMeans on (optionally PCA-reduced)
    features; otherwise the dataset labels are used. It then fits one
    transform per group: a PCA projection (``mode='project'``) or a kernel map
    onto randomly chosen anchors (``mode='kernel'``). ``transform`` applies
    every group's transform and returns an array of shape
    (n_clusters, N, feat_dim).

    Raises (from ``fit``):
        NotImplementedError: for ``cluster_method`` other than "kmeans". The
            former "hdbscan" placeholder produced no labels and crashed later.
        ValueError: for an unknown ``mode``.
    """

    def __init__(self, pretrain, recluster=False, cluster_method="kmeans", mode='project', kernel='rbf'):
        self.pretrain = pretrain
        self.recluster = recluster
        self.mode = mode
        self.kernel = kernel
        self.cluster_method = cluster_method

    def fit(self, n_clusters, feat_dim, pca_dim=50):
        """Fit one basis transform per cluster.

        Args:
            n_clusters: number of groups; labels are expected in [0, n_clusters).
            feat_dim: output dimensionality of every basis transform.
            pca_dim: PCA size applied before KMeans (<= 0 disables it).
        """
        self.feat_dim = feat_dim
        self.n_clusters = n_clusters
        features, labels = self.pretrain.get_pretrain_dataset('base')
        if self.recluster:
            new_labels = self._kmeans_labels(features, n_clusters, pca_dim)
        else:
            new_labels = labels
        self.basis_transform_models = []
        for i in range(n_clusters):
            self.basis_transform_models.append(
                self._fit_basis_model(features[new_labels == i]))

    def _kmeans_labels(self, features, n_clusters, pca_dim):
        """KMeans cluster labels for ``features``, cached as .npy under ``kmeans/``.

        NOTE(review): the cache key does not include ``pca_dim``, so a cached
        file computed with a different ``pca_dim`` is reused as-is.
        """
        if self.cluster_method != "kmeans":
            raise NotImplementedError(
                "cluster_method {!r} is not supported".format(self.cluster_method))
        new_labels_file = "kmeans/new_labels_%s_%s_%s_%s.npy" % (str(self.n_clusters), self.pretrain.method,
                                                                 self.pretrain.model_name, self.pretrain.dataset)
        if os.path.isfile(new_labels_file):
            return np.load(new_labels_file)
        # PCA dimension reduction before k-means; only computed when needed.
        if pca_dim > 0:
            features_reduced = PCA(n_components=pca_dim).fit_transform(features)
        else:
            features_reduced = features
        kmeans_model = KMeans(n_clusters=n_clusters, random_state=0).fit(features_reduced)
        self.kmeans_model = kmeans_model
        new_labels = kmeans_model.labels_
        labels_dir = os.path.dirname(new_labels_file)
        if labels_dir and not os.path.isdir(labels_dir):
            os.makedirs(labels_dir)
        np.save(new_labels_file, new_labels)
        return new_labels

    def _fit_basis_model(self, cluster_features):
        """Fit the basis transform for one cluster according to ``self.mode``."""
        if self.mode == 'project':
            model = PCA(n_components=self.feat_dim)
        elif self.mode == 'kernel':
            model = KernelTransformer(self.feat_dim, self.kernel)
        else:
            raise ValueError("Unknown basis mode: {}".format(self.mode))
        model.fit(cluster_features)
        return model

    def transform(self, X):
        """Apply every cluster's transform to ``X`` (N x original_feat_dim).

        Returns a float64 array of shape (n_clusters, N, feat_dim).
        """
        N = X.shape[0]
        transformed = np.zeros((self.n_clusters, N, self.feat_dim))
        for i, model in enumerate(self.basis_transform_models):
            transformed[i] = model.transform(X)
        return transformed
class KernelTransformer():
    """Explicit kernel feature map onto ``feat_dim`` randomly chosen anchors.

    ``fit`` picks ``feat_dim`` distinct rows of the training features as
    anchors (centroids). ``transform`` maps each input row ``x`` to
    ``[k(x, c_1), ..., k(x, c_feat_dim)]``. Here ``k`` is the RBF kernel
    ``exp(-||x - c||^2 / (2 * rbf_variance))`` or the linear kernel ``<x, c>``.

    Args:
        feat_dim: number of anchors, i.e. the output dimensionality.
        kernel: "rbf" or "linear".
        rbf_variance: RBF scale; 10 reproduces the previously hard-coded value.
    """

    def __init__(self, feat_dim, kernel, rbf_variance=10.0):
        self.feat_dim = feat_dim
        self.kernel = kernel
        self.rbf_variance = rbf_variance

    def fit(self, features):
        """Randomly sample ``feat_dim`` rows of ``features`` as anchors.

        Raises:
            ValueError: if there are fewer rows than ``feat_dim``. This
                previously surfaced later as an IndexError in ``transform``.
        """
        N = features.shape[0]
        if N < self.feat_dim:
            raise ValueError("Need at least feat_dim={} samples, got {}".format(self.feat_dim, N))
        rand_id = np.random.permutation(N)
        selected_id = rand_id[:self.feat_dim]
        self.centroids = features[selected_id]

    def transform(self, X):
        """Return the (N, feat_dim) float64 kernel matrix between ``X`` and the anchors.

        Vectorised with NumPy instead of an O(N * feat_dim) Python loop.
        """
        X = np.asarray(X, dtype=np.float64)
        centroids = np.asarray(self.centroids, dtype=np.float64)
        if self.kernel == "rbf":
            # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 <x, c>, clipped at 0 against round-off.
            sq_dist = (np.sum(X ** 2, axis=1)[:, None]
                       + np.sum(centroids ** 2, axis=1)[None, :]
                       - 2.0 * np.dot(X, centroids.T))
            np.maximum(sq_dist, 0.0, out=sq_dist)
            return np.exp(-0.5 * sq_dist / self.rbf_variance)
        if self.kernel == "linear":
            return np.dot(X, centroids.T)
        raise ValueError("Unknown kernel: {}".format(self.kernel))

    def kernel_f(self, x1, x2):
        """Kernel value between two vectors (previously UnboundLocalError for unknown kernels)."""
        if self.kernel == "rbf":
            norm = LA.norm(x1 - x2, ord=2)
            return np.exp(-0.5 * norm * norm / self.rbf_variance)
        if self.kernel == "linear":
            return np.sum(x1 * x2)
        raise ValueError("Unknown kernel: {}".format(self.kernel))
class ChannelwiseClassifier(nn.Module):
    """Per-channel classifier: every feature channel votes for a class.

    For class ``c`` and channel ``j`` the score is ``-(W[c, j] - x_j)^2``
    (no bias) or ``W[c, j] * x_j + B[c, j]`` (with bias). Scores are
    softmax-normalised over classes per channel and then averaged over
    channels.

    Args:
        feat_dim: number of feature channels.
        n_way: number of classes.
        weight: initial value copied into W, shape (n_way, feat_dim).
        bias: use the affine form with a learnable bias B.
    """

    def __init__(self, feat_dim, n_way, weight, bias=False):
        super(ChannelwiseClassifier, self).__init__()
        self.n_way = n_way
        self.feat_dim = feat_dim
        self.bias = bias
        self.W = nn.Parameter(torch.Tensor(n_way, feat_dim))
        if self.bias:
            self.B = nn.Parameter(torch.Tensor(n_way, feat_dim))
        with torch.no_grad():
            self.W.copy_(weight)
            if self.bias:
                # torch.Tensor(...) is uninitialised memory; give B the same
                # fan-in uniform init that reset_parameters uses.
                fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
                bound = 1 / math.sqrt(fan_in)
                init.uniform_(self.B, -bound, bound)

    def reset_parameters(self):
        """Re-initialise W ~ U(-0.75, 0.75) and, if present, B ~ U(+-1/sqrt(fan_in))."""
        init.uniform_(self.W, -0.75, 0.75)
        if self.bias:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.B, -bound, bound)

    def forward(self, X):
        """Return ``(probability, softmax_scores)``.

        ``softmax_scores`` (n, n_way, feat_dim) holds the per-channel class
        probabilities. ``probability`` (n, n_way) is their mean over channels.
        Vectorised over the batch; it previously used a per-row loop into a
        CUDA-only buffer.
        """
        n = X.shape[0]
        X_expanded = X.unsqueeze(1)  # (n, 1, feat_dim), broadcast over classes
        if self.bias:
            scores = self.W * X_expanded + self.B
        else:
            scores = -(self.W - X_expanded).pow(2)
        softmax_scores = torch.softmax(scores, dim=1)
        probability = torch.mean(softmax_scores, dim=2).view(n, self.n_way)
        return probability, softmax_scores
class UnbiasedClassifier(nn.Module):
    '''
    Classifier over fused [X, Z] (or [X, Z, D]) features with an optional
    X-only branch whose logits are fused with the main logits.

    Input:
        architecture
            softmax (nn.Linear) or dist (distLinear)
        feature_fusion
            concat
        has_x_branch_classifier
            If True, will have a separate x classifier and will produce self.branch_clf_resp after forward
        logit_fusion
            Required if has_x_branch_classifier; otherwise ignored
            product, sum, harmonic
    Unsupported option values raise ValueError. Previously they raised
    UnboundLocalError, or silently returned None from get_fused_feature.
    '''

    def __init__(self, n_way, x_feature_dim, z_feature_dim, d_feature_dim=0,
                 has_d_branch=False, has_x_branch_classifier=False, architecture="softmax",
                 feature_fusion="concat", logit_fusion="product"):
        super(UnbiasedClassifier, self).__init__()
        self.n_way = n_way
        self.architecture = architecture
        self.feature_fusion = feature_fusion
        self.logit_fusion = logit_fusion
        self.has_x_branch_classifier = has_x_branch_classifier
        self.has_d_branch = has_d_branch
        self.x_feature_dim = x_feature_dim
        self.z_feature_dim = z_feature_dim
        self.d_feature_dim = d_feature_dim
        # The main classifier sees the concatenation of all enabled features.
        self.main_feature_dim = x_feature_dim + z_feature_dim
        if self.has_d_branch:
            self.main_feature_dim += d_feature_dim
        if self.has_x_branch_classifier:
            self.branch_clf = self.create_clf(x_feature_dim)
            self.logit_fusion_fn = self.create_logit_fusion_fn()
        self.main_clf = self.create_clf(self.main_feature_dim)

    def create_clf(self, feat_dim):
        """Build a (CUDA) classifier mapping ``feat_dim`` features to ``n_way`` logits."""
        if self.architecture == "softmax":
            return nn.Linear(feat_dim, self.n_way).cuda()
        if self.architecture == "dist":
            return distLinear(feat_dim, self.n_way, True).cuda()
        raise ValueError("Unknown architecture: {}".format(self.architecture))

    def create_logit_fusion_fn(self):
        """Build the (CUDA) gate module that fuses main and branch logits."""
        if self.logit_fusion == "product":
            return ProductGate().cuda()
        if self.logit_fusion == "harmonic":
            return HarmonicGate().cuda()
        if self.logit_fusion == "sum":
            return SumGate().cuda()
        raise ValueError("Unknown logit_fusion: {}".format(self.logit_fusion))

    def get_fused_feature(self, feature_array):
        """Concatenate the features along dim 1."""
        if self.feature_fusion == "concat":
            return torch.cat(feature_array, 1)
        raise ValueError("Unknown feature_fusion: {}".format(self.feature_fusion))

    def forward(self, X, Z, D=None):
        """Return (batch, n_way) logits; ``D`` is used only when ``has_d_branch``."""
        batch_size = X.shape[0]
        if self.has_d_branch:
            fused_feature = self.get_fused_feature((X, Z, D))
        else:
            fused_feature = self.get_fused_feature((X, Z))
        main_clf_resp = self.main_clf(fused_feature)
        if not self.has_x_branch_classifier:
            return main_clf_resp
        branch_clf_resp = self.branch_clf(X)
        # (batch, n_way, 2): main and branch logits side by side for the gate.
        cat_resp = torch.cat((main_clf_resp.view(batch_size, self.n_way, 1),
                              branch_clf_resp.view(batch_size, self.n_way, 1)), dim=2)
        combined_resp = self.logit_fusion_fn(cat_resp).view(batch_size, self.n_way)
        self.branch_clf_resp = branch_clf_resp
        return combined_resp
class XDBiClassifier(nn.Module):
    """Two-branch classifier: logits from X and from D are fused.

    Args:
        n_way: number of classes.
        x_feature_dim, d_feature_dim: input sizes of the two branches.
        architecture: "softmax" (nn.Linear) or "dist" (distLinear) for the x
            branch, and for the d branch unless ``d_clf_is_linear``.
        fusion: "product", "harmonic", "sum" (gate modules applied to the
            stacked logits) or "linear_sum" (``x_resp + d_resp``).
        d_clf_is_linear: force the d branch to be a plain nn.Linear.
        sigmoid_d_resp: with "linear_sum", map d logits to (-1, 1) via
            ``2 * (sigmoid(d) - 0.5)`` before adding.

    Unsupported ``architecture`` / ``fusion`` values raise ValueError
    (previously UnboundLocalError).
    """

    def __init__(self, n_way, x_feature_dim, d_feature_dim, architecture="softmax",
                 fusion="product", d_clf_is_linear=True, sigmoid_d_resp=False):
        super(XDBiClassifier, self).__init__()
        self.n_way = n_way
        self.architecture = architecture
        self.x_feature_dim = x_feature_dim
        self.d_feature_dim = d_feature_dim
        self.x_clf = self.create_clf(self.x_feature_dim)
        if d_clf_is_linear:
            self.d_clf = nn.Linear(self.d_feature_dim, self.n_way).cuda()
        else:
            self.d_clf = self.create_clf(self.d_feature_dim)
        self.logit_fusion = fusion
        self.fusion_fn = self.create_logit_fusion_fn()
        self.sigmoid = nn.Sigmoid().cuda()
        self.sigmoid_d_resp = sigmoid_d_resp

    def create_clf(self, feat_dim):
        """Build a (CUDA) classifier mapping ``feat_dim`` features to ``n_way`` logits."""
        if self.architecture == "softmax":
            return nn.Linear(feat_dim, self.n_way).cuda()
        if self.architecture == "dist":
            return distLinear(feat_dim, self.n_way, True).cuda()
        raise ValueError("Unknown architecture: {}".format(self.architecture))

    def create_logit_fusion_fn(self):
        """Build the (CUDA) fusion module for ``self.logit_fusion``.

        For "linear_sum" an nn.Linear(2, 1) is created, but ``forward`` does
        not use it; it only contributes parameters to the module.
        """
        if self.logit_fusion == "product":
            return ProductGate().cuda()
        if self.logit_fusion == "harmonic":
            return HarmonicGate().cuda()
        if self.logit_fusion == "sum":
            return SumGate().cuda()
        if self.logit_fusion == "linear_sum":
            return nn.Linear(2, 1, bias=False).cuda()
        raise ValueError("Unknown logit_fusion: {}".format(self.logit_fusion))

    def cat_for_logit_fusion(self, A, B):
        """Stack two (batch, n_way) logit tensors into (batch, n_way, 2)."""
        batch_size = A.shape[0]
        return torch.cat((A.view(batch_size, self.n_way, 1),
                          B.view(batch_size, self.n_way, 1)), dim=2)

    def forward(self, X, D):
        """Return fused (batch, n_way) logits; branch logits are kept in
        ``self.x_resp`` / ``self.d_resp`` (d logits before any sigmoid mapping)."""
        actual_batch_size = X.shape[0]
        x_resp = self.x_clf(X)
        d_resp = self.d_clf(D)
        self.d_resp = d_resp
        self.x_resp = x_resp
        if self.logit_fusion != "linear_sum":
            cat_resp = self.cat_for_logit_fusion(x_resp, d_resp)
            return self.fusion_fn(cat_resp).view(actual_batch_size, self.n_way)
        if self.sigmoid_d_resp:
            # Map d logits into (-1, 1) before adding.
            d_resp = (self.sigmoid(d_resp) - 0.5) * 2
        return x_resp + d_resp
class XDClassifier(nn.Module):
    '''
    Single classification head over the fused pair (X, D).

    Input:
        architecture
            softmax (nn.Linear) or dist (distLinear)
        feature_fusion
            concat : torch.cat((X, D), 1); width x_feature_dim + d_feature_dim
            sum, + : X + D ("sum" asserts equal widths)
            -      : X - D
            gate   : accepted when sizing, but no fusion is implemented, so any use of D raises
        transform_d
            If True, the fused feature goes through Linear(feat_dim, hidden_nodes) and
            LeakyReLU before the classifier (D is then always used)
        hidden_nodes
            Width of that hidden layer; only used when transform_d is True
        use_d
            If False (and transform_d is False), the classifier sees X only
    '''
    def __init__(self, n_way, x_feature_dim, d_feature_dim, architecture="softmax",
                 feature_fusion="concat", transform_d=False, hidden_nodes=50, use_d=True):
        super(XDClassifier, self).__init__()
        self.n_way = n_way
        self.architecture = architecture
        self.feature_fusion = feature_fusion
        self.x_feature_dim = x_feature_dim
        self.d_feature_dim = d_feature_dim
        self.feat_dim = self.get_feature_dim()
        self.transform_d = transform_d
        self.use_d = use_d
        if self.transform_d:
            self.hidden_nodes = hidden_nodes
            self.transform_linear = nn.Linear(self.feat_dim, self.hidden_nodes).cuda()
            self.transform_activation = nn.LeakyReLU().cuda()
            self.clf = self.create_clf(self.hidden_nodes)
        elif use_d:
            self.clf = self.create_clf(self.feat_dim)
        else:
            self.clf = self.create_clf(self.x_feature_dim)

    def get_feature_dim(self):
        """Width of the fused feature for ``self.feature_fusion`` (None if unrecognised)."""
        if self.feature_fusion == "concat":
            return self.x_feature_dim + self.d_feature_dim
        if self.feature_fusion in ("sum", "gate"):
            assert self.x_feature_dim == self.d_feature_dim
            return self.x_feature_dim
        if self.feature_fusion in ("-", "+"):
            return self.x_feature_dim
        # Left as None so X-only configurations (use_d=False) still construct.
        return None

    def create_clf(self, feat_dim):
        """Return a CUDA head mapping ``feat_dim`` features to ``n_way`` logits.

        Raises:
            ValueError: if ``self.architecture`` is neither "softmax" nor "dist".
        """
        if self.architecture == "softmax":
            return nn.Linear(feat_dim, self.n_way).cuda()
        if self.architecture == "dist":
            return distLinear(feat_dim, self.n_way, True).cuda()
        raise ValueError("Unknown classifier architecture: %r" % (self.architecture,))

    def get_fused_feature(self, feature_array):
        """Combine ``(X, D)`` according to ``self.feature_fusion``.

        Raises:
            ValueError: for modes without an implementation here (e.g. "gate"). The old
                code returned None for them, which crashed later inside the classifier.
        """
        if self.feature_fusion == "concat":
            return torch.cat(feature_array, 1)
        if self.feature_fusion == "-":
            return feature_array[0] - feature_array[1]
        if self.feature_fusion in ("+", "sum"):
            return feature_array[0] + feature_array[1]
        raise ValueError("Unsupported feature_fusion for fusing features: %r" % (self.feature_fusion,))

    def forward(self, X, D):
        """Return (batch, n_way) logits for inputs X and D."""
        if self.transform_d:
            fused_feature = self.get_fused_feature((X, D))
            hidden_resp = self.transform_linear(fused_feature)
            activated_resp = self.transform_activation(hidden_resp)
            clf_resp = self.clf(activated_resp)
        else:
            if self.use_d:
                fused_feature = self.get_fused_feature((X, D))
                clf_resp = self.clf(fused_feature)
            else:
                clf_resp = self.clf(X)
        return clf_resp
class ProductGate(nn.Module):
    """Fuse two logit channels as log(sigmoid(a) * sigmoid(b)).

    Callers in this file pass ``x`` of shape (batch, n_way, 2) and get (batch, n_way).
    The value is computed as logsigmoid(a) + logsigmoid(b). That is the same quantity,
    but it stays finite where sigmoid underflows to 0 and log(0) would give -inf.
    """
    def __init__(self):
        super(ProductGate, self).__init__()

    def forward(self, x):
        return F.logsigmoid(x[..., 0]) + F.logsigmoid(x[..., 1])
class HarmonicGate(nn.Module):
    """Fuse two logit channels as log(p / (1 + p)), where p = sigmoid(a) * sigmoid(b).

    Callers in this file pass ``x`` of shape (batch, n_way, 2) and get (batch, n_way).
    Since p / (1 + p) = sigmoid(log p), the value equals
    logsigmoid(logsigmoid(a) + logsigmoid(b)). That form never forms p explicitly, so
    it does not collapse to log(0) = -inf when p underflows.
    """
    def __init__(self):
        super(HarmonicGate, self).__init__()

    def forward(self, x):
        log_p = F.logsigmoid(x[..., 0]) + F.logsigmoid(x[..., 1])
        return F.logsigmoid(log_p)
class SumGate(nn.Module):
    """Fuse two logit channels as log(sigmoid(a + b)).

    Callers in this file pass ``x`` of shape (batch, n_way, 2) and get (batch, n_way).
    F.logsigmoid computes this without the -inf that log(sigmoid(.)) yields once
    sigmoid underflows.
    """
    def __init__(self):
        super(SumGate, self).__init__()

    def forward(self, x):
        return F.logsigmoid(x[..., 0] + x[..., 1])
class distLinear(nn.Module):
    """Cosine classifier: each logit is ``scale_factor * cos(x, w_c)``.

    With ``class_wise_learnable_norm`` the weight is reparameterised by WeightNorm
    (dim=0), so each class also learns a norm that multiplies its cosine (see issues
    #4 and #8 of the upstream CloserLookFewShot repo). Otherwise the weight rows are
    re-normalised in place of ``L.weight.data`` on every forward call.
    """
    def __init__(self, indim, outdim, class_wise_learnable_norm=True):
        super(distLinear, self).__init__()
        self.L = nn.Linear(indim, outdim, bias=False)
        self.class_wise_learnable_norm = class_wise_learnable_norm
        if class_wise_learnable_norm:
            # Split the weight into direction and per-class norm.
            WeightNorm.apply(self.L, 'weight', dim=0)
        # Fixed multiplier that spreads cosines into a usable softmax range. Outputs with
        # more than 200 classes (e.g. omniglot) need the larger value; upstream issue #31
        # notes CUB + ResNet10 used 4.
        self.scale_factor = 2 if outdim <= 200 else 10

    def forward(self, x):
        eps = 0.00001
        unit_x = x.div(torch.norm(x, p=2, dim=1, keepdim=True) + eps)
        if not self.class_wise_learnable_norm:
            raw_w = self.L.weight.data
            self.L.weight.data = raw_w.div(torch.norm(raw_w, p=2, dim=1, keepdim=True) + eps)
        # With WeightNorm this product also carries the learnable class-wise norm.
        return self.scale_factor * self.L(unit_x)
class Flatten(nn.Module):
    """Collapse every dimension after the leading batch dimension into one."""
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
class Linear_fw(nn.Linear): #used in MAML to forward input with fast weight
    """nn.Linear that can run with temporarily adapted ("fast") parameters.

    MAML attaches tensors to ``self.weight.fast`` and ``self.bias.fast``. Only when both
    are set are they used instead of the module's own parameters.
    """
    def __init__(self, in_features, out_features):
        super(Linear_fw, self).__init__(in_features, out_features)
        # Attributes hung on the Parameters themselves; None means "use slow weights".
        self.weight.fast = None
        self.bias.fast = None

    def forward(self, x):
        fast_w, fast_b = self.weight.fast, self.bias.fast
        if fast_w is None or fast_b is None:
            return super(Linear_fw, self).forward(x)
        return F.linear(x, fast_w, fast_b)
class Conv2d_fw(nn.Conv2d): #used in MAML to forward input with fast weight
    """nn.Conv2d that can run with temporarily adapted ("fast") parameters.

    Without a bias, ``weight.fast`` alone switches to the fast path. With a bias, both
    ``weight.fast`` and ``bias.fast`` must be set. The fast path reuses this layer's
    stride and padding.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,padding=0, bias = True):
        super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias)
        self.weight.fast = None
        if self.bias is not None:
            self.bias.fast = None

    def forward(self, x):
        fast_w = self.weight.fast
        if self.bias is None:
            fast_b = None
            use_fast = fast_w is not None
        else:
            fast_b = self.bias.fast
            use_fast = fast_w is not None and fast_b is not None
        if not use_fast:
            return super(Conv2d_fw, self).forward(x)
        return F.conv2d(x, fast_w, fast_b, stride=self.stride, padding=self.padding)
class BatchNorm2d_fw(nn.BatchNorm2d): #used in MAML to forward input with fast weight
    """BatchNorm2d that supports fast weights and always normalises with batch statistics.

    This follows the momentum hack of Kate Rakelly's pytorch-maml (src/layers.py). Every
    call passes fresh zero/one running buffers with ``training=True`` and
    ``momentum=1``, so the module's own running statistics are never read or updated.
    ``weight.fast`` / ``bias.fast`` replace the affine parameters when both are set.
    """
    def __init__(self, num_features):
        super(BatchNorm2d_fw, self).__init__(num_features)
        self.weight.fast = None
        self.bias.fast = None

    def forward(self, x):
        n_channels = x.size(1)
        # Scratch buffers for F.batch_norm's running update; they are discarded after the call.
        # They are allocated on x's device rather than via .cuda(), so CPU inputs and
        # inputs on a non-default GPU work too.
        running_mean = torch.zeros(n_channels, device=x.device)
        running_var = torch.ones(n_channels, device=x.device)
        if self.weight.fast is not None and self.bias.fast is not None:
            weight, bias = self.weight.fast, self.bias.fast
        else:
            weight, bias = self.weight, self.bias
        return F.batch_norm(x, running_mean, running_var, weight, bias, training=True, momentum=1)
# Simple Conv Block
class ConvBlock(nn.Module):
    """3x3 conv -> batch norm -> ReLU, optionally followed by 2x2 max-pooling.

    Args:
        indim: input channels.
        outdim: output channels.
        pool: if True, append nn.MaxPool2d(2).
        padding: padding of the 3x3 convolution.

    The layers are reachable both as attributes (C, BN, relu, pool) and through
    ``self.trunk``, which is what forward() runs.
    """
    maml = False #Default
    # When the class attribute ``maml`` is True, fast-weight layers (Conv2d_fw /
    # BatchNorm2d_fw) are built instead of the plain torch layers.
    def __init__(self, indim, outdim, pool = True, padding = 1):
        super(ConvBlock, self).__init__()
        self.indim = indim
        self.outdim = outdim
        if self.maml:
            self.C = Conv2d_fw(indim, outdim, 3, padding = padding)
            self.BN = BatchNorm2d_fw(outdim)
        else:
            self.C = nn.Conv2d(indim, outdim, 3, padding= padding)
            self.BN = nn.BatchNorm2d(outdim)
        self.relu = nn.ReLU(inplace=True)
        self.parametrized_layers = [self.C, self.BN, self.relu]
        if pool:
            self.pool = nn.MaxPool2d(2)
            self.parametrized_layers.append(self.pool)
        # init_layer is defined elsewhere in this file; presumably it only initialises
        # conv/BN layers and ignores ReLU/pooling (TODO confirm). Layer order here fixes
        # which random draws each layer receives.
        for layer in self.parametrized_layers:
            init_layer(layer)
        self.trunk = nn.Sequential(*self.parametrized_layers)
    def forward(self,x):
        out = self.trunk(x)
        return out
# Simple ResNet Block
class SimpleBlock(nn.Module):
    """Basic ResNet block: two 3x3 conv+BN layers plus a shortcut, summed, then ReLU.

    Args:
        indim: input channels.
        outdim: output channels.
        half_res: if True, the first conv (and the 1x1 shortcut, when present) uses
            stride 2, halving the spatial resolution.

    NOTE(review): with indim == outdim the shortcut is the identity, so half_res=True
    would make the two branches' spatial sizes disagree. ResNet in this file only halves
    resolution at stage starts where the channel count also changes.
    """
    maml = False #Default
    def __init__(self, indim, outdim, half_res):
        super(SimpleBlock, self).__init__()
        self.indim = indim
        self.outdim = outdim
        # The class attribute ``maml`` selects fast-weight layers for MAML.
        if self.maml:
            self.C1 = Conv2d_fw(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
            self.BN1 = BatchNorm2d_fw(outdim)
            self.C2 = Conv2d_fw(outdim, outdim,kernel_size=3, padding=1,bias=False)
            self.BN2 = BatchNorm2d_fw(outdim)
        else:
            self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
            self.BN1 = nn.BatchNorm2d(outdim)
            self.C2 = nn.Conv2d(outdim, outdim,kernel_size=3, padding=1,bias=False)
            self.BN2 = nn.BatchNorm2d(outdim)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        # Layers handed to init_layer below; the ReLUs are not included.
        self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2]
        self.half_res = half_res
        # if the input number of channels is not equal to the output, then need a 1x1 convolution
        if indim!=outdim:
            if self.maml:
                self.shortcut = Conv2d_fw(indim, outdim, 1, 2 if half_res else 1, bias=False)
                self.BNshortcut = BatchNorm2d_fw(outdim)
            else:
                self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False)
                self.BNshortcut = nn.BatchNorm2d(outdim)
            self.parametrized_layers.append(self.shortcut)
            self.parametrized_layers.append(self.BNshortcut)
            self.shortcut_type = '1x1'
        else:
            self.shortcut_type = 'identity'
        # init_layer is defined elsewhere in this file; the order above fixes init RNG draws.
        for layer in self.parametrized_layers:
            init_layer(layer)
    def forward(self, x):
        out = self.C1(x)
        out = self.BN1(out)
        out = self.relu1(out)
        out = self.C2(out)
        out = self.BN2(out)
        # Projection shortcut (1x1 conv + BN) only when the channel count changes.
        short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x))
        out = out + short_out
        out = self.relu2(out)
        return out
# Bottleneck block
class BottleneckBlock(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a shortcut.

    Args:
        indim: input channels.
        outdim: output channels; the inner width is ``int(outdim / 4)``.
        half_res: if True, the 3x3 conv (and the 1x1 shortcut, when present) use stride 2.

    Unlike C1/C3, the 3x3 conv C2 keeps its bias, and the projection shortcut has no
    batch norm.
    """
    maml = False #Default
    def __init__(self, indim, outdim, half_res):
        super(BottleneckBlock, self).__init__()
        bottleneckdim = int(outdim/4)
        self.indim = indim
        self.outdim = outdim
        # The class attribute ``maml`` selects fast-weight layers for MAML.
        if self.maml:
            self.C1 = Conv2d_fw(indim, bottleneckdim, kernel_size=1, bias=False)
            self.BN1 = BatchNorm2d_fw(bottleneckdim)
            self.C2 = Conv2d_fw(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 1,padding=1)
            self.BN2 = BatchNorm2d_fw(bottleneckdim)
            self.C3 = Conv2d_fw(bottleneckdim, outdim, kernel_size=1, bias=False)
            self.BN3 = BatchNorm2d_fw(outdim)
        else:
            self.C1 = nn.Conv2d(indim, bottleneckdim, kernel_size=1, bias=False)
            self.BN1 = nn.BatchNorm2d(bottleneckdim)
            self.C2 = nn.Conv2d(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 1,padding=1)
            self.BN2 = nn.BatchNorm2d(bottleneckdim)
            self.C3 = nn.Conv2d(bottleneckdim, outdim, kernel_size=1, bias=False)
            self.BN3 = nn.BatchNorm2d(outdim)
        # A single non-inplace ReLU module reused after every stage of forward().
        self.relu = nn.ReLU()
        self.parametrized_layers = [self.C1, self.BN1, self.C2, self.BN2, self.C3, self.BN3]
        self.half_res = half_res
        # if the input number of channels is not equal to the output, then need a 1x1 convolution
        if indim!=outdim:
            if self.maml:
                self.shortcut = Conv2d_fw(indim, outdim, 1, stride=2 if half_res else 1, bias=False)
            else:
                self.shortcut = nn.Conv2d(indim, outdim, 1, stride=2 if half_res else 1, bias=False)
            self.parametrized_layers.append(self.shortcut)
            self.shortcut_type = '1x1'
        else:
            self.shortcut_type = 'identity'
        # init_layer is defined elsewhere in this file; the order above fixes init RNG draws.
        for layer in self.parametrized_layers:
            init_layer(layer)
    def forward(self, x):
        short_out = x if self.shortcut_type == 'identity' else self.shortcut(x)
        out = self.C1(x)
        out = self.BN1(out)
        out = self.relu(out)
        out = self.C2(out)
        out = self.BN2(out)
        out = self.relu(out)
        out = self.C3(out)
        out = self.BN3(out)
        out = out + short_out
        out = self.relu(out)
        return out
class ConvNet(nn.Module):
    """Stack of ``depth`` 64-channel ConvBlocks on 3-channel input (Conv4 / Conv6).

    Only the first four blocks max-pool. ``final_feat_dim`` is fixed at 1600
    (= 64 * 5 * 5), which presumably assumes 84x84 inputs (TODO confirm for other sizes).
    """
    def __init__(self, depth, flatten=True):
        super(ConvNet, self).__init__()
        blocks = [
            ConvBlock(3 if i == 0 else 64, 64, pool=(i < 4))  # only the first 4 blocks pool
            for i in range(depth)
        ]
        if flatten:
            blocks.append(Flatten())
        self.trunk = nn.Sequential(*blocks)
        self.final_feat_dim = 1600

    def forward(self, x):
        return self.trunk(x)
class ConvNetNopool(nn.Module): #Relation net use a 4 layer conv with pooling in only first two layers, else no pooling
    """Relation-Net style ConvNet on 3-channel input, keeping a spatial output map.

    Blocks 0 and 1 pool and use no padding; later blocks use padding 1 and no pooling.
    ``final_feat_dim`` is [64, 19, 19], which presumably assumes 84x84 inputs (TODO confirm).
    """
    def __init__(self, depth):
        super(ConvNetNopool, self).__init__()
        blocks = []
        for i in range(depth):
            early = i in [0, 1]
            blocks.append(ConvBlock(3 if i == 0 else 64, 64, pool=early, padding=0 if early else 1))
        self.trunk = nn.Sequential(*blocks)
        self.final_feat_dim = [64, 19, 19]

    def forward(self, x):
        return self.trunk(x)
class ConvNetS(nn.Module): #For omniglot, only 1 input channel, output dim is 64
    """Single-channel ConvNet: only channel 0 of the input is used.

    The first four blocks max-pool. ``final_feat_dim`` is 64, which presumably assumes
    28x28 inputs (TODO confirm).
    """
    def __init__(self, depth, flatten=True):
        super(ConvNetS, self).__init__()
        blocks = [
            ConvBlock(1 if i == 0 else 64, 64, pool=(i < 4))  # only the first 4 blocks pool
            for i in range(depth)
        ]
        if flatten:
            blocks.append(Flatten())
        self.trunk = nn.Sequential(*blocks)
        self.final_feat_dim = 64

    def forward(self, x):
        first_channel = x[:, 0:1, :, :]
        return self.trunk(first_channel)
class ConvNetSNopool(nn.Module): #Relation net use a 4 layer conv with pooling in only first two layers, else no pooling. For omniglot, only 1 input channel, output dim is [64,5,5]
    """Single-channel Relation-Net style ConvNet keeping a spatial output map.

    Only channel 0 of the input is used. Blocks 0 and 1 pool and use no padding; later
    blocks use padding 1 and no pooling. ``final_feat_dim`` is [64, 5, 5], which
    presumably assumes 28x28 inputs (TODO confirm).
    """
    def __init__(self, depth):
        super(ConvNetSNopool, self).__init__()
        blocks = []
        for i in range(depth):
            early = i in [0, 1]
            blocks.append(ConvBlock(1 if i == 0 else 64, 64, pool=early, padding=0 if early else 1))
        self.trunk = nn.Sequential(*blocks)
        self.final_feat_dim = [64, 5, 5]

    def forward(self, x):
        first_channel = x[:, 0:1, :, :]
        return self.trunk(first_channel)
class ResNet(nn.Module):
    """Four-stage ResNet backbone built from ``block`` (SimpleBlock or BottleneckBlock).

    Stem: 7x7 stride-2 conv (64 channels) -> BN -> ReLU -> 3x3 stride-2 max-pool. The
    first block of stages 2-4 halves the resolution (``half_res``).

    Args:
        block: block class called as ``block(indim, outdim, half_res)``.
        list_of_num_layers: number of blocks in each of the 4 stages.
        list_of_out_dims: output channels of each stage.
        flatten: if True, append AvgPool2d(7) + Flatten, and ``final_feat_dim`` is the last
            stage's width. Otherwise ``final_feat_dim`` is [width, 7, 7]. The fixed 7
            presumably assumes 224x224 inputs (TODO confirm for other sizes).
    """
    maml = False #Default
    def __init__(self,block,list_of_num_layers, list_of_out_dims, flatten = True):
        # list_of_num_layers specifies number of layers in each stage
        # list_of_out_dims specifies number of output channel for each stage
        super(ResNet,self).__init__()
        assert len(list_of_num_layers)==4, 'Can have only four stages'
        # The class attribute ``maml`` selects fast-weight stem layers for MAML.
        if self.maml:
            conv1 = Conv2d_fw(3, 64, kernel_size=7, stride=2, padding=3,
                                               bias=False)
            bn1 = BatchNorm2d_fw(64)
        else:
            conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                               bias=False)
            bn1 = nn.BatchNorm2d(64)
        relu = nn.ReLU()
        pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # init_layer is defined elsewhere in this file; construction order fixes init RNG draws.
        init_layer(conv1)
        init_layer(bn1)
        trunk = [conv1, bn1, relu, pool1]
        indim = 64
        for i in range(4):
            for j in range(list_of_num_layers[i]):
                half_res = (i>=1) and (j==0)
                B = block(indim, list_of_out_dims[i], half_res)
                trunk.append(B)
                # Later blocks in the stage take this stage's width as input.
                indim = list_of_out_dims[i]
        if flatten:
            avgpool = nn.AvgPool2d(7)
            trunk.append(avgpool)
            trunk.append(Flatten())
            self.final_feat_dim = indim
        else:
            self.final_feat_dim = [ indim, 7, 7]
        self.trunk = nn.Sequential(*trunk)
    def forward(self,x):
        out = self.trunk(x)
        return out
def Conv4():
    """Four-block ConvNet; every block pools."""
    return ConvNet(depth=4)


def Conv6():
    """Six-block ConvNet; only the first four blocks pool."""
    return ConvNet(depth=6)


def Conv4NP():
    """Four-block ConvNet that pools only in its first two blocks."""
    return ConvNetNopool(depth=4)


def Conv6NP():
    """Six-block ConvNet that pools only in its first two blocks."""
    return ConvNetNopool(depth=6)


def Conv4S():
    """Four-block single-channel ConvNet."""
    return ConvNetS(depth=4)


def Conv4SNP():
    """Four-block single-channel ConvNet that pools only in its first two blocks."""
    return ConvNetSNopool(depth=4)


def ResNet10( flatten = True):
    """ResNet-10: one SimpleBlock per stage."""
    return ResNet(SimpleBlock, list_of_num_layers=[1, 1, 1, 1],
                  list_of_out_dims=[64, 128, 256, 512], flatten=flatten)


def ResNet18( flatten = True):
    """ResNet-18: two SimpleBlocks per stage."""
    return ResNet(SimpleBlock, list_of_num_layers=[2, 2, 2, 2],
                  list_of_out_dims=[64, 128, 256, 512], flatten=flatten)


def ResNet34( flatten = True):
    """ResNet-34: SimpleBlocks in a 3/4/6/3 layout."""
    return ResNet(SimpleBlock, list_of_num_layers=[3, 4, 6, 3],
                  list_of_out_dims=[64, 128, 256, 512], flatten=flatten)


def ResNet50( flatten = True):
    """ResNet-50: BottleneckBlocks in a 3/4/6/3 layout."""
    return ResNet(BottleneckBlock, list_of_num_layers=[3, 4, 6, 3],
                  list_of_out_dims=[256, 512, 1024, 2048], flatten=flatten)


def ResNet101( flatten = True):
    """ResNet-101: BottleneckBlocks in a 3/4/23/3 layout."""
    return ResNet(BottleneckBlock, list_of_num_layers=[3, 4, 23, 3],
                  list_of_out_dims=[256, 512, 1024, 2048], flatten=flatten)
================================================
FILE: MAML_MN_FT/configs.py
================================================
# Machine-specific locations; edit these before running on another machine.
save_dir = '/data2/yuezhongqi/Model/CloserLookFSL/' # Change to desired saving dir
# Directory holding the split lists of each dataset.
data_dir = {
    'CUB': './filelists/CUB/',
    'miniImagenet': './filelists/miniImagenet/',
    'tiered': './filelists/tiered/',
}
simple_shot_dir = "/data2/yuezhongqi/Model/simple_shot/" # Location of the downloaded pretrained model
feat_dir = "/data2/yuezhongqi/Model/feat/" # Location of the downloaded pretrained model
tiered_dir = "/data2/yuezhongqi/Dataset/tiered" # Location of the downloaded tieredImageNet
================================================
FILE: MAML_MN_FT/data/__init__.py
================================================
from . import datamgr
from . import dataset
from . import additional_transforms
from . import feature_loader
================================================
FILE: MAML_MN_FT/data/additional_transforms.py
================================================
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from PIL import ImageEnhance
# PIL enhancement classes selectable by name in ImageJitter's transformdict.
transformtypedict = {
    'Brightness': ImageEnhance.Brightness,
    'Contrast': ImageEnhance.Contrast,
    'Sharpness': ImageEnhance.Sharpness,
    'Color': ImageEnhance.Color,
}


class ImageJitter(object):
    """Randomly perturb PIL enhancement factors (brightness, contrast, ...).

    ``transformdict`` maps keys of ``transformtypedict`` to a strength ``alpha``. Each
    call draws one factor per transform, uniformly in [1 - alpha, 1 + alpha), and
    applies the transforms in the dict's order, converting to RGB after each.
    """
    def __init__(self, transformdict):
        self.transforms = [(transformtypedict[name], transformdict[name]) for name in transformdict]

    def __call__(self, img):
        draws = torch.rand(len(self.transforms))
        for (enhancer, alpha), u in zip(self.transforms, draws):
            factor = alpha * (u * 2.0 - 1.0) + 1
            img = enhancer(img).enhance(factor).convert('RGB')
        return img
================================================
FILE: MAML_MN_FT/data/datamgr.py
================================================
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import data.additional_transforms as add_transforms
from data.dataset import SimpleDataset, SetDataset, EpisodicBatchSampler, SimpleTieredDataset
from abc import abstractmethod
class TransformLoader:
    """Build torchvision preprocessing pipelines for a fixed square ``image_size``.

    Args:
        image_size: target crop size.
        normalize_param: kwargs for transforms.Normalize. ``None`` means ImageNet
            mean/std.
        jitter_param: strengths for ImageJitter. ``None`` means Brightness, Contrast and
            Color at 0.4.

    Fresh default dicts are built per instance, so mutating one loader's parameters no
    longer leaks into other loaders, as it did with the shared mutable defaults.
    """
    # Deprecated aliases removed from newer torchvision, mapped to their replacements.
    _RENAMED_TRANSFORMS = {'RandomSizedCrop': 'RandomResizedCrop', 'Scale': 'Resize'}

    def __init__(self, image_size, normalize_param=None, jitter_param=None):
        if normalize_param is None:
            normalize_param = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        if jitter_param is None:
            jitter_param = dict(Brightness=0.4, Contrast=0.4, Color=0.4)
        self.image_size = image_size
        self.normalize_param = normalize_param
        self.jitter_param = jitter_param

    def parse_transform(self, transform_type):
        """Instantiate one transform from its torchvision class name (or 'ImageJitter').

        Raises:
            AttributeError: if torchvision has no transform of that name or its replacement.
        """
        if transform_type == 'ImageJitter':
            method = add_transforms.ImageJitter(self.jitter_param)
            return method
        attr_name = transform_type
        if not hasattr(transforms, attr_name):
            attr_name = self._RENAMED_TRANSFORMS.get(attr_name, attr_name)
        method = getattr(transforms, attr_name)
        if transform_type in ('RandomSizedCrop', 'CenterCrop'):
            return method(self.image_size)
        elif transform_type == 'Scale':
            # Resize to 1.15x so the following CenterCrop has a margin to cut.
            return method([int(self.image_size * 1.15), int(self.image_size * 1.15)])
        elif transform_type == 'Normalize':
            return method(**self.normalize_param)
        else:
            return method()

    def get_composed_transform(self, aug=False):
        """Return the training pipeline (random crop/jitter/flip) if ``aug``, else the eval one."""
        if aug:
            transform_list = ['RandomSizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize']
        else:
            transform_list = ['Scale', 'CenterCrop', 'ToTensor', 'Normalize']
        transform_funcs = [self.parse_transform(x) for x in transform_list]
        transform = transforms.Compose(transform_funcs)
        return transform
class DataManager:
    """Common interface for the data managers below."""
    # NOTE: the class does not use ABCMeta, so @abstractmethod is not enforced here and
    # the base implementation simply returns None.
    @abstractmethod
    def get_data_loader(self, data_file, aug):
        """Return a DataLoader over ``data_file``; ``aug`` toggles augmentation."""
class SimpleDataManager(DataManager):
    """Shuffled mini-batch loader over one whole split."""
    def __init__(self, image_size, batch_size):
        super(SimpleDataManager, self).__init__()
        self.batch_size = batch_size
        self.trans_loader = TransformLoader(image_size)

    def get_data_loader(self, data_file, aug, num_workers=12, tiered_mini=False): # parameters that would change on train/val set
        """Return a shuffled, pinned-memory DataLoader.

        With ``tiered_mini``, ``data_file`` is handed to SimpleTieredDataset as a split
        name. Otherwise it is a JSON file list for SimpleDataset.
        """
        transform = self.trans_loader.get_composed_transform(aug)
        dataset_cls = SimpleTieredDataset if tiered_mini else SimpleDataset
        dataset = dataset_cls(data_file, transform)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
        )
class SetDataManager(DataManager):
    """Episodic loader: each batch holds ``n_way`` classes with ``n_support + n_query`` images each."""
    def __init__(self, image_size, n_way, n_support, n_query, n_eposide=100):
        super(SetDataManager, self).__init__()
        self.image_size = image_size
        self.n_way = n_way
        # Images drawn per class in each episode (support + query).
        self.batch_size = n_support + n_query
        self.n_eposide = n_eposide
        self.trans_loader = TransformLoader(image_size)

    def get_data_loader(self, data_file, aug, debug=False): # parameters that would change on train/val set
        """Return a DataLoader yielding ``n_eposide`` episodes; ``debug`` loads in the main process."""
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SetDataset(data_file, self.batch_size, transform)
        sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide)
        return torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            num_workers=0 if debug else 12,
            pin_memory=True,
        )
================================================
FILE: MAML_MN_FT/data/dataset.py
================================================
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
import torch
from PIL import Image
import json
import numpy as np
import torchvision.transforms as transforms
import os
import os.path as osp
import configs
def identity(x):
    """Return ``x`` unchanged (the default ``target_transform``)."""
    return x
class SimpleDataset:
    """Image dataset described by a JSON file with parallel 'image_names'/'image_labels' lists.

    Items are ``(transformed_image, target_transform(label), image_path)``.
    """
    def __init__(self, data_file, transform, target_transform=identity):
        with open(data_file, 'r') as handle:
            self.meta = json.load(handle)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.meta['image_names'])

    def __getitem__(self, i):
        image_path = os.path.join(self.meta['image_names'][i])
        image = self.transform(Image.open(image_path).convert('RGB'))
        return image, self.target_transform(self.meta['image_labels'][i]), image_path
class SimpleTieredDataset:
    """tieredImageNet split read from ``configs.tiered_dir/<split>/<class>/<image>``.

    ``setname`` 'base', 'novel' and 'val' select the 'train', 'test' and 'val'
    sub-directories; anything else raises ValueError('Wrong setname.').
    Every sub-directory is one class, numbered in listing order. Items are
    ``(transformed_image, label, path)``.
    """
    # (setname, sub-directory of configs.tiered_dir)
    _SPLITS = (('base', 'train'), ('novel', 'test'), ('val', 'val'))

    def __init__(self, setname, transform):
        self.transform = transform
        split_dir = next((sub for name, sub in self._SPLITS if setname == name), None)
        if split_dir is None:
            raise ValueError('Wrong setname.')
        root = osp.join(configs.tiered_dir, split_dir)
        # NOTE(review): os.listdir order is filesystem-dependent, so class indices and
        # image order are not guaranteed to match across machines. sorted() would fix
        # that but would renumber classes wherever the current order differs.
        class_dirs = [osp.join(root, name) for name in os.listdir(root)
                      if os.path.isdir(osp.join(root, name))]
        paths, labels = [], []
        for class_idx, class_dir in enumerate(class_dirs):
            for file_name in os.listdir(class_dir):
                paths.append(osp.join(class_dir, file_name))
                labels.append(class_idx)
        self.data = paths
        self.label = labels
        self.num_class = len(set(labels))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path = self.data[i]
        image = self.transform(Image.open(path).convert('RGB'))
        return image, self.label[i], path
class SetDataset:
    """Episode source with one inner DataLoader per class.

    Index ``i`` returns one freshly shuffled batch ``(images, targets)`` of
    ``batch_size`` images from class ``self.cl_list[i]``.
    """
    def __init__(self, data_file, batch_size, transform):
        with open(data_file, 'r') as handle:
            self.meta = json.load(handle)
        self.cl_list = np.unique(self.meta['image_labels']).tolist()
        self.sub_meta = {cl: [] for cl in self.cl_list}
        for name, cl in zip(self.meta['image_names'], self.meta['image_labels']):
            self.sub_meta[cl].append(name)
        loader_kwargs = dict(
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,  # use main thread only or may receive multiple batches
            pin_memory=False,
        )
        self.sub_dataloader = [
            torch.utils.data.DataLoader(SubDataset(self.sub_meta[cl], cl, transform=transform), **loader_kwargs)
            for cl in self.cl_list
        ]

    def __len__(self):
        return len(self.cl_list)

    def __getitem__(self, i):
        return next(iter(self.sub_dataloader[i]))
class SubDataset:
    """All images of a single class ``cl``.

    Items are ``(transformed_image, target_transform(cl))``. The default ``transform``
    is a single ToTensor instance created once when the class is defined.
    """
    def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity):
        self.sub_meta = sub_meta
        self.cl = cl
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.sub_meta)

    def __getitem__(self, i):
        image_path = os.path.join(self.sub_meta[i])
        image = self.transform(Image.open(image_path).convert('RGB'))
        return image, self.target_transform(self.cl)
class EpisodicBatchSampler(object):
    """Batch sampler yielding ``n_episodes`` random sets of ``n_way`` distinct class indices."""
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        # One independent random permutation per episode, truncated to n_way classes.
        yield from (torch.randperm(self.n_classes)[:self.n_way] for _ in range(self.n_episodes))
================================================
FILE: MAML_MN_FT/data/feature_loader.py
================================================
import torch
import numpy as np
import h5py
class SimpleHDF5Dataset:
    """Features and labels read from a file handle with 'all_feats', 'all_labels', 'count'.

    With no handle, the dataset is empty. Otherwise both arrays are read with ``[...]``,
    which for h5py datasets materialises numpy arrays. ``len()`` is ``count[0]``, which
    may be smaller than the number of stored rows.
    """
    def __init__(self, file_handle = None):
        if file_handle is None:
            self.f = ''
            self.all_feats_dset = []
            self.all_labels = []
            self.total = 0
        else:
            self.f = file_handle
            self.all_feats_dset = self.f['all_feats'][...]
            self.all_labels = self.f['all_labels'][...]
            self.total = self.f['count'][0]

    def __getitem__(self, i):
        return torch.Tensor(self.all_feats_dset[i,:]), int(self.all_labels[i])

    def __len__(self):
        return self.total
def init_loader(filename, get_path=False, path_file=None):
    """Load saved features from an HDF5 file and group them by class label.

    Args:
        filename: HDF5 file readable by SimpleHDF5Dataset.
        get_path: if True, also load ``path_file`` and group its entries the same way.
        path_file: ``.npy`` file whose i-th entry belongs to the i-th feature row.

    Returns:
        ``{label: [feature, ...]}``, or ``({label: [feature, ...]}, {label: [path, ...]})``
        when ``get_path`` is True.

    Trailing rows whose features sum to zero are dropped first. These are presumably
    unused pre-allocated slots (TODO confirm against the feature writer).
    """
    with h5py.File(filename, 'r') as f:
        fileset = SimpleHDF5Dataset(f)
    feats = fileset.all_feats_dset
    labels = fileset.all_labels
    paths = np.load(path_file) if get_path else None

    # Trim trailing zero rows with one slice instead of repeated np.delete copies. The
    # bound check also stops an empty or all-zero file from raising IndexError.
    n_rows = len(feats)
    while n_rows > 0 and np.sum(feats[n_rows - 1]) == 0:
        n_rows -= 1
    feats = feats[:n_rows]
    labels = labels[:n_rows]

    class_list = np.unique(np.array(labels)).tolist()
    cl_data_file = {cl: [] for cl in class_list}
    path_data_file = {cl: [] for cl in class_list}
    for ind in range(len(labels)):
        cl_data_file[labels[ind]].append(feats[ind])
        if get_path:
            path_data_file[labels[ind]].append(paths[ind])
    if get_path:
        return cl_data_file, path_data_file
    return cl_data_file
================================================
FILE: MAML_MN_FT/filelists/CUB/attributes.txt
================================================
1 has_bill_shape::curved_(up_or_down)
2 has_bill_shape::dagger
3 has_bill_shape::hooked
4 has_bill_shape::needle
5 has_bill_shape::hooked_seabird
6 has_bill_shape::spatulate
7 has_bill_shape::all-purpose
8 has_bill_shape::cone
9 has_bill_shape::specialized
10 has_wing_color::blue
11 has_wing_color::brown
12 has_wing_color::iridescent
13 has_wing_color::purple
14 has_wing_color::rufous
15 has_wing_color::grey
16 has_wing_color::yellow
17 has_wing_color::olive
18 has_wing_color::green
19 has_wing_color::pink
20 has_wing_color::orange
21 has_wing_color::black
22 has_wing_color::white
23 has_wing_color::red
24 has_wing_color::buff
25 has_upperparts_color::blue
26 has_upperparts_color::brown
27 has_upperparts_color::iridescent
28 has_upperparts_color::purple
29 has_upperparts_color::rufous
30 has_upperparts_color::grey
31 has_upperparts_color::yellow
32 has_upperparts_color::olive
33 has_upperparts_color::green
34 has_upperparts_color::pink
35 has_upperparts_color::orange
36 has_upperparts_color::black
37 has_upperparts_color::white
38 has_upperparts_color::red
39 has_upperparts_color::buff
40 has_underparts_color::blue
41 has_underparts_color::brown
42 has_underparts_color::iridescent
43 has_underparts_color::purple
44 has_underparts_color::rufous
45 has_underparts_color::grey
46 has_underparts_color::yellow
47 has_underparts_color::olive
48 has_underparts_color::green
49 has_underparts_color::pink
50 has_underparts_color::orange
51 has_underparts_color::black
52 has_underparts_color::white
53 has_underparts_color::red
54 has_underparts_color::buff
55 has_breast_pattern::solid
56 has_breast_pattern::spotted
57 has_breast_pattern::striped
58 has_breast_pattern::multi-colored
59 has_back_color::blue
60 has_back_color::brown
61 has_back_color::iridescent
62 has_back_color::purple
63 has_back_color::rufous
64 has_back_color::grey
65 has_back_color::yellow
66 has_back_color::olive
67 has_back_color::green
68 has_back_color::pink
69 has_back_color::orange
70 has_back_color::black
71 has_back_color::white
72 has_back_color::red
73 has_back_color::buff
74 has_tail_shape::forked_tail
75 has_tail_shape::rounded_tail
76 has_tail_shape::notched_tail
77 has_tail_shape::fan-shaped_tail
78 has_tail_shape::pointed_tail
79 has_tail_shape::squared_tail
80 has_upper_tail_color::blue
81 has_upper_tail_color::brown
82 has_upper_tail_color::iridescent
83 has_upper_tail_color::purple
84 has_upper_tail_color::rufous
85 has_upper_tail_color::grey
86 has_upper_tail_color::yellow
87 has_upper_tail_color::olive
88 has_upper_tail_color::green
89 has_upper_tail_color::pink
90 has_upper_tail_color::orange
91 has_upper_tail_color::black
92 has_upper_tail_color::white
93 has_upper_tail_color::red
94 has_upper_tail_color::buff
95 has_head_pattern::spotted
96 has_head_pattern::malar
97 has_head_pattern::crested
98 has_head_pattern::masked
99 has_head_pattern::unique_pattern
100 has_head_pattern::eyebrow
101 has_head_pattern::eyering
102 has_head_pattern::plain
103 has_head_pattern::eyeline
104 has_head_pattern::striped
105 has_head_pattern::capped
106 has_breast_color::blue
107 has_breast_color::brown
108 has_breast_color::iridescent
109 has_breast_color::purple
110 has_breast_color::rufous
111 has_breast_color::grey
112 has_breast_color::yellow
113 has_breast_color::olive
114 has_breast_color::green
115 has_breast_color::pink
116 has_breast_color::orange
117 has_breast_color::black
118 has_breast_color::white
119 has_breast_color::red
120 has_breast_color::buff
121 has_throat_color::blue
122 has_throat_color::brown
123 has_throat_color::iridescent
124 has_throat_color::purple
125 has_throat_color::rufous
126 has_throat_color::grey
127 has_throat_color::yellow
128 has_throat_color::olive
129 has_throat_color::green
130 has_throat_color::pink
131 has_throat_color::orange
132 has_throat_color::black
133 has_throat_color::white
134 has_throat_color::red
135 has_throat_color::buff
136 has_eye_color::blue
137 has_eye_color::brown
138 has_eye_color::purple
139 has_eye_color::rufous
140 has_eye_color::grey
141 has_eye_color::yellow
142 has_eye_color::olive
143 has_eye_color::green
144 has_eye_color::pink
145 has_eye_color::orange
146 has_eye_color::black
147 has_eye_color::white
148 has_eye_color::red
149 has_eye_color::buff
150 has_bill_length::about_the_same_as_head
151 has_bill_length::longer_than_head
152 has_bill_length::shorter_than_head
153 has_forehead_color::blue
154 has_forehead_color::brown
155 has_forehead_color::iridescent
156 has_forehead_color::purple
157 has_forehead_color::rufous
158 has_forehead_color::grey
159 has_forehead_color::yellow
160 has_forehead_color::olive
161 has_forehead_color::green
162 has_forehead_color::pink
163 has_forehead_color::orange
164 has_forehead_color::black
165 has_forehead_color::white
166 has_forehead_color::red
167 has_forehead_color::buff
168 has_under_tail_color::blue
169 has_under_tail_color::brown
170 has_under_tail_color::iridescent
171 has_under_tail_color::purple
172 has_under_tail_color::rufous
173 has_under_tail_color::grey
174 has_under_tail_color::yellow
175 has_under_tail_color::olive
176 has_under_tail_color::green
177 has_under_tail_color::pink
178 has_under_tail_color::orange
179 has_under_tail_color::black
180 has_under_tail_color::white
181 has_under_tail_color::red
182 has_under_tail_color::buff
183 has_nape_color::blue
184 has_nape_color::brown
185 has_nape_color::iridescent
186 has_nape_color::purple
187 has_nape_color::rufous
188 has_nape_color::grey
189 has_nape_color::yellow
190 has_nape_color::olive
191 has_nape_color::green
192 has_nape_color::pink
193 has_nape_color::orange
194 has_nape_color::black
195 has_nape_color::white
196 has_nape_color::red
197 has_nape_color::buff
198 has_belly_color::blue
199 has_belly_color::brown
200 has_belly_color::iridescent
201 has_belly_color::purple
202 has_belly_color::rufous
203 has_belly_color::grey
204 has_belly_color::yellow
205 has_belly_color::olive
206 has_belly_color::green
207 has_belly_color::pink
208 has_belly_color::orange
209 has_belly_color::black
210 has_belly_color::white
211 has_belly_color::red
212 has_belly_color::buff
213 has_wing_shape::rounded-wings
214 has_wing_shape::pointed-wings
215 has_wing_shape::broad-wings
216 has_wing_shape::tapered-wings
217 has_wing_shape::long-wings
218 has_size::large_(16_-_32_in)
219 has_size::small_(5_-_9_in)
220 has_size::very_large_(32_-_72_in)
221 has_size::medium_(9_-_16_in)
222 has_size::very_small_(3_-_5_in)
223 has_shape::upright-perching_water-like
224 has_shape::chicken-like-marsh
225 has_shape::long-legged-like
226 has_shape::duck-like
227 has_shape::owl-like
228 has_shape::gull-like
229 has_shape::hummingbird-like
230 has_shape::pigeon-like
231 has_shape::tree-clinging-like
232 has_shape::hawk-like
233 has_shape::sandpiper-like
234 has_shape::upland-ground-like
235 has_shape::swallow-like
236 has_shape::perching-like
237 has_back_pattern::solid
238 has_back_pattern::spotted
239 has_back_pattern::striped
240 has_back_pattern::multi-colored
241 has_tail_pattern::solid
242 has_tail_pattern::spotted
243 has_tail_pattern::striped
244 has_tail_pattern::multi-colored
245 has_belly_pattern::solid
246 has_belly_pattern::spotted
247 has_belly_pattern::striped
248 has_belly_pattern::multi-colored
249 has_primary_color::blue
250 has_primary_color::brown
251 has_primary_color::iridescent
252 has_primary_color::purple
253 has_primary_color::rufous
254 has_primary_color::grey
255 has_primary_color::yellow
256 has_primary_color::olive
257 has_primary_color::green
258 has_primary_color::pink
259 has_primary_color::orange
260 has_primary_color::black
261 has_primary_color::white
262 has_primary_color::red
263 has_primary_color::buff
264 has_leg_color::blue
265 has_leg_color::brown
266 has_leg_color::iridescent
267 has_leg_color::purple
268 has_leg_color::rufous
269 has_leg_color::grey
270 has_leg_color::yellow
271 has_leg_color::olive
272 has_leg_color::green
273 has_leg_color::pink
274 has_leg_color::orange
275 has_leg_color::black
276 has_leg_color::white
277 has_leg_color::red
278 has_leg_color::buff
279 has_bill_color::blue
280 has_bill_color::brown
281 has_bill_color::iridescent
282 has_bill_color::purple
283 has_bill_color::rufous
284 has_bill_color::grey
285 has_bill_color::yellow
286 has_bill_color::olive
287 has_bill_color::green
288 has_bill_color::pink
289 has_bill_color::orange
290 has_bill_color::black
291 has_bill_color::white
292 has_bill_color::red
293 has_bill_color::buff
294 has_crown_color::blue
295 has_crown_color::brown
296 has_crown_color::iridescent
297 has_crown_color::purple
298 has_crown_color::rufous
299 has_crown_color::grey
300 has_crown_color::yellow
301 has_crown_color::olive
302 has_crown_color::green
303 has_crown_color::pink
304 has_crown_color::orange
305 has_crown_color::black
306 has_crown_color::white
307 has_crown_color::red
308 has_crown_color::buff
309 has_wing_pattern::solid
310 has_wing_pattern::spotted
311 has_wing_pattern::striped
312 has_wing_pattern::multi-colored
================================================
FILE: MAML_MN_FT/filelists/CUB/base.json
================================================
{"label_names": ["001.Black_footed_Albatross","002.Laysan_Albatross","003.Sooty_Albatross","004.Groove_billed_Ani","005.Crested_Auklet","006.Least_Auklet","007.Parakeet_Auklet","008.Rhinoceros_Auklet","009.Brewer_Blackbird","010.Red_winged_Blackbird","011.Rusty_Blackbird","012.Yellow_headed_Blackbird","013.Bobolink","014.Indigo_Bunting","015.Lazuli_Bunting","016.Painted_Bunting","017.Cardinal","018.Spotted_Catbird","019.Gray_Catbird","020.Yellow_breasted_Chat","021.Eastern_Towhee","022.Chuck_will_Widow","023.Brandt_Cormorant","024.Red_faced_Cormorant","025.Pelagic_Cormorant","026.Bronzed_Cowbird","027.Shiny_Cowbird","028.Brown_Creeper","029.American_Crow","030.Fish_Crow","031.Black_billed_Cuckoo","032.Mangrove_Cuckoo","033.Yellow_billed_Cuckoo","034.Gray_crowned_Rosy_Finch","035.Purple_Finch","036.Northern_Flicker","037.Acadian_Flycatcher","038.Great_Crested_Flycatcher","039.Least_Flycatcher","040.Olive_sided_Flycatcher","041.Scissor_tailed_Flycatcher","042.Vermilion_Flycatcher","043.Yellow_bellied_Flycatcher","044.Frigatebird","045.Northern_Fulmar","046.Gadwall","047.American_Goldfinch","048.European_Goldfinch","049.Boat_tailed_Grackle","050.Eared_Grebe","051.Horned_Grebe","052.Pied_billed_Grebe","053.Western_Grebe","054.Blue_Grosbeak","055.Evening_Grosbeak","056.Pine_Grosbeak","057.Rose_breasted_Grosbeak","058.Pigeon_Guillemot","059.California_Gull","060.Glaucous_winged_Gull","061.Heermann_Gull","062.Herring_Gull","063.Ivory_Gull","064.Ring_billed_Gull","065.Slaty_backed_Gull","066.Western_Gull","067.Anna_Hummingbird","068.Ruby_throated_Hummingbird","069.Rufous_Hummingbird","070.Green_Violetear","071.Long_tailed_Jaeger","072.Pomarine_Jaeger","073.Blue_Jay","074.Florida_Jay","075.Green_Jay","076.Dark_eyed_Junco","077.Tropical_Kingbird","078.Gray_Kingbird","079.Belted_Kingfisher","080.Green_Kingfisher","081.Pied_Kingfisher","082.Ringed_Kingfisher","083.White_breasted_Kingfisher","084.Red_legged_Kittiwake","085.Horned_Lark","086.Pacific_Loon","087.Mallard","088.Weste
rn_Meadowlark","089.Hooded_Merganser","090.Red_breasted_Merganser","091.Mockingbird","092.Nighthawk","093.Clark_Nutcracker","094.White_breasted_Nuthatch","095.Baltimore_Oriole","096.Hooded_Oriole","097.Orchard_Oriole","098.Scott_Oriole","099.Ovenbird","100.Brown_Pelican","101.White_Pelican","102.Western_Wood_Pewee","103.Sayornis","104.American_Pipit","105.Whip_poor_Will","106.Horned_Puffin","107.Common_Raven","108.White_necked_Raven","109.American_Redstart","110.Geococcyx","111.Loggerhead_Shrike","112.Great_Grey_Shrike","113.Baird_Sparrow","114.Black_throated_Sparrow","115.Brewer_Sparrow","116.Chipping_Sparrow","117.Clay_colored_Sparrow","118.House_Sparrow","119.Field_Sparrow","120.Fox_Sparrow","121.Grasshopper_Sparrow","122.Harris_Sparrow","123.Henslow_Sparrow","124.Le_Conte_Sparrow","125.Lincoln_Sparrow","126.Nelson_Sharp_tailed_Sparrow","127.Savannah_Sparrow","128.Seaside_Sparrow","129.Song_Sparrow","130.Tree_Sparrow","131.Vesper_Sparrow","132.White_crowned_Sparrow","133.White_throated_Sparrow","134.Cape_Glossy_Starling","135.Bank_Swallow","136.Barn_Swallow","137.Cliff_Swallow","138.Tree_Swallow","139.Scarlet_Tanager","140.Summer_Tanager","141.Artic_Tern","142.Black_Tern","143.Caspian_Tern","144.Common_Tern","145.Elegant_Tern","146.Forsters_Tern","147.Least_Tern","148.Green_tailed_Towhee","149.Brown_Thrasher","150.Sage_Thrasher","151.Black_capped_Vireo","152.Blue_headed_Vireo","153.Philadelphia_Vireo","154.Red_eyed_Vireo","155.Warbling_Vireo","156.White_eyed_Vireo","157.Yellow_throated_Vireo","158.Bay_breasted_Warbler","159.Black_and_white_Warbler","160.Black_throated_Blue_Warbler","161.Blue_winged_Warbler","162.Canada_Warbler","163.Cape_May_Warbler","164.Cerulean_Warbler","165.Chestnut_sided_Warbler","166.Golden_winged_Warbler","167.Hooded_Warbler","168.Kentucky_Warbler","169.Magnolia_Warbler","170.Mourning_Warbler","171.Myrtle_Warbler","172.Nashville_Warbler","173.Orange_crowned_Warbler","174.Palm_Warbler","175.Pine_Warbler","176.Prairie_Warbler","177.Prothonot
ary_Warbler","178.Swainson_Warbler","179.Tennessee_Warbler","180.Wilson_Warbler","181.Worm_eating_Warbler","182.Yellow_Warbler","183.Northern_Waterthrush","184.Louisiana_Waterthrush","185.Bohemian_Waxwing","186.Cedar_Waxwing","187.American_Three_toed_Woodpecker","188.Pileated_Woodpecker","189.Red_bellied_Woodpecker","190.Red_cockaded_Woodpecker","191.Red_headed_Woodpecker","192.Downy_Woodpecker","193.Bewick_Wren","194.Cactus_Wren","195.Carolina_Wren","196.House_Wren","197.Marsh_Wren","198.Rock_Wren","199.Winter_Wren","200.Common_Yellowthroat"],"image_names": ["/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0049_796063.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0008_796083.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0065_796068.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0060_796076.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0086_796062.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0042_796071.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0050_796125.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0040_796066.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0057_796106.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0069_796139
.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0002_55.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0035_796140.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0058_796074.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0014_89.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0080_796096.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0026_796095.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0090_796077.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0017_796098.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0067_170.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0031_100.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0077_796114.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0016_796067.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0082_7961
21.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0005_796090.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0068_796135.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0009_34.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0036_796127.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0041_796108.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0089_796069.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0081_426.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0038_212.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0039_796132.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0007_796138.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0023_796059.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0033_796086.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0010_796097.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_003
7_796120.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0045_796129.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0061_796082.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0032_796115.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0079_796122.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0071_796113.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0056_796078.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0019_796104.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0003_796136.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0074_59.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0051_796103.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0078_796126.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0076_417.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0085_92.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatros
s_0088_796133.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0047_796064.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0063_796141.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0006_796065.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0024_796089.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0053_796109.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0025_796057.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/001.Black_footed_Albatross/Black_Footed_Albatross_0064_796101.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0019_796391.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0021_796339.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0025_796361.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0067_796376.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0071_1116.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0049_796350.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CU
B/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0048_1130.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0004_796366.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0017_796349.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0030_1122.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0046_1211.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0023_796401.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0044_1105.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0057_796354.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0058_796360.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0038_1065.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0020_796359.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0055_1160.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0034_1154.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0051_796374.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0074_1221.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0076_796
365.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0069_796358.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0016_1075.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0014_796373.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0042_1210.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0022_796398.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0001_1071.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0054_796347.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0043_1076.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0065_796367.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0072_796371.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0045_1162.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0079_796389.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0077_1080.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0002_796395.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0064_796343.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/
CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0073_1171.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0006_796390.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0033_1128.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0041_796364.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0007_796372.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0036_796387.jpg","/home/yuezhongqi/Coding/CloserLookFewShot/filelists/CUB/CUB_200_2011/images/003.Sooty_Albatross/Sooty_Albatross_0075_796352.jpg","/home/yuezhongqi/Coding/CloserLookF
gitextract_l9qhes0e/ ├── LEO/ │ ├── LICENSE │ ├── config.py │ ├── data.py │ ├── ifsl_configs/ │ │ ├── __init__.py │ │ ├── baseline_config.py │ │ └── ifsl_config.py │ ├── model.py │ ├── model_test.py │ ├── pretrain/ │ │ ├── miniImagenet_baseline_ResNet10_mean.npy │ │ ├── miniImagenet_feat_wrn_mean.npy │ │ ├── miniImagenet_sib_wrn_mean.npy │ │ ├── miniImagenet_simpleshot_ResNet10_mean.npy │ │ ├── miniImagenet_simpleshotwide_wideres_mean.npy │ │ ├── norm_miniImagenet_baseline_ResNet10_mean.npy │ │ ├── norm_miniImagenet_feat_wrn_mean.npy │ │ ├── norm_miniImagenet_sib_wrn_mean.npy │ │ ├── norm_miniImagenet_simpleshot_ResNet10_mean.npy │ │ ├── norm_miniImagenet_simpleshotwide_wideres_mean.npy │ │ ├── norm_tiered_simpleshot_ResNet10_mean.npy │ │ ├── norm_tiered_simpleshotwide_wideres_mean.npy │ │ ├── tiered_simpleshot_ResNet10_mean.npy │ │ └── tiered_simpleshotwide_wideres_mean.npy │ ├── readme.md │ ├── runner.py │ └── utils.py ├── MAML_MN_FT/ │ ├── README.md │ ├── backbone.py │ ├── configs.py │ ├── data/ │ │ ├── __init__.py │ │ ├── additional_transforms.py │ │ ├── datamgr.py │ │ ├── dataset.py │ │ └── feature_loader.py │ ├── filelists/ │ │ ├── CUB/ │ │ │ ├── attributes.txt │ │ │ ├── base.json │ │ │ ├── download_CUB.sh │ │ │ ├── novel.json │ │ │ ├── val.json │ │ │ └── write_CUB_filelist.py │ │ ├── miniImagenet/ │ │ │ ├── all.json │ │ │ ├── base.json │ │ │ ├── download_miniImagenet.sh │ │ │ ├── novel.json │ │ │ ├── test.csv │ │ │ ├── train.csv │ │ │ ├── val.csv │ │ │ ├── val.json │ │ │ ├── write_cross_filelist.py │ │ │ └── write_miniImagenet_filelist.py │ │ └── tiered/ │ │ └── write_tiered_filelist.py │ ├── io_utils.py │ ├── main.py │ ├── methods/ │ │ ├── DMAML.py │ │ ├── DMatchingNet.py │ │ ├── MethodTester.py │ │ ├── NNEDSplitNew.py │ │ ├── PretrainedModel.py │ │ ├── VanillaMAML.py │ │ ├── VanillaMatchingNet.py │ │ ├── __init__.py │ │ ├── meta_template.py │ │ └── meta_toolkits.py │ ├── models/ │ │ ├── FeatWRN.py │ │ ├── SimpleShotResNet.py │ │ ├── SimpleShotWideResNet.py 
│ │ └── __init__.py │ ├── pretrain/ │ │ ├── miniImagenet_baseline_ResNet10_mean.npy │ │ ├── miniImagenet_cosine_ResNet10_mean.npy │ │ ├── miniImagenet_feat_wrn_mean.npy │ │ ├── miniImagenet_sib_wrn_mean.npy │ │ ├── miniImagenet_simpleshot_ResNet10_mean.npy │ │ ├── miniImagenet_simpleshotwide_wideres_mean.npy │ │ ├── norm_miniImagenet_baseline_ResNet10_mean.npy │ │ ├── norm_miniImagenet_cosine_ResNet10_mean.npy │ │ ├── norm_miniImagenet_feat_wrn_mean.npy │ │ ├── norm_miniImagenet_sib_wrn_mean.npy │ │ ├── norm_miniImagenet_simpleshot_ResNet10_mean.npy │ │ ├── norm_miniImagenet_simpleshotwide_wideres_mean.npy │ │ ├── norm_tiered_simpleshot_ResNet10_mean.npy │ │ ├── norm_tiered_simpleshotwide_wideres_mean.npy │ │ ├── tiered_simpleshot_ResNet10_mean.npy │ │ └── tiered_simpleshotwide_wideres_mean.npy │ ├── save_features.py │ ├── tests/ │ │ ├── MetaTrain.py │ │ └── __init__.py │ └── utils.py ├── MTL/ │ ├── README.md │ ├── configs/ │ │ ├── __init__.py │ │ ├── baseline_config.py │ │ ├── ifsl_resnet_config.py │ │ └── ifsl_wrn_config.py │ ├── dataloader/ │ │ ├── __init__.py │ │ ├── dataset_loader.py │ │ └── samplers.py │ ├── main.py │ ├── models/ │ │ ├── IFSL.py │ │ ├── IFSL_modules.py │ │ ├── IFSL_pretrain.py │ │ ├── ResNet10.py │ │ ├── WRN28.py │ │ ├── __init__.py │ │ ├── conv2d_mtl.py │ │ ├── mtl.py │ │ └── resnet_mtl.py │ ├── pretrain/ │ │ ├── miniImagenet_baseline_ResNet10_mean.npy │ │ ├── miniImagenet_feat_wrn_mean.npy │ │ ├── miniImagenet_sib_wrn_mean.npy │ │ ├── miniImagenet_simpleshot_ResNet10_mean.npy │ │ ├── miniImagenet_simpleshotwide_wideres_mean.npy │ │ ├── norm_miniImagenet_baseline_ResNet10_mean.npy │ │ ├── norm_miniImagenet_feat_wrn_mean.npy │ │ ├── norm_miniImagenet_sib_wrn_mean.npy │ │ ├── norm_miniImagenet_simpleshot_ResNet10_mean.npy │ │ ├── norm_miniImagenet_simpleshotwide_wideres_mean.npy │ │ ├── norm_tiered_simpleshot_ResNet10_mean.npy │ │ ├── norm_tiered_simpleshotwide_wideres_mean.npy │ │ ├── tiered_simpleshot_ResNet10_mean.npy │ │ └── 
tiered_simpleshotwide_wideres_mean.npy │ ├── run_meta.py │ ├── run_pre.py │ ├── run_pre_clfs.py │ ├── run_test.py │ ├── setup.cfg │ ├── trainer/ │ │ ├── __init__.py │ │ ├── meta.py │ │ └── pre.py │ └── utils/ │ ├── __init__.py │ ├── gpu_tools.py │ ├── hacc.py │ └── misc.py ├── SIB/ │ ├── PretrainedModel.py │ ├── algorithm.py │ ├── backbone.py │ ├── config/ │ │ ├── minires_1_baseline.yaml │ │ ├── minires_1_ifsl.yaml │ │ ├── minires_5_baseline.yaml │ │ ├── minires_5_ifsl.yaml │ │ ├── miniwrn_1_baseline.yaml │ │ ├── miniwrn_1_ifsl.yaml │ │ ├── miniwrn_5_baseline.yaml │ │ ├── miniwrn_5_ifsl.yaml │ │ ├── tieredres_1_baseline.yaml │ │ ├── tieredres_1_ifsl.yaml │ │ ├── tieredres_5_baseline.yaml │ │ ├── tieredres_5_ifsl.yaml │ │ ├── tieredwrn_1_baseline.yaml │ │ ├── tieredwrn_1_ifsl.yaml │ │ ├── tieredwrn_5_baseline.yaml │ │ └── tieredwrn_5_ifsl.yaml │ ├── data/ │ │ ├── __init__.py │ │ ├── additional_transforms.py │ │ ├── datamgr.py │ │ ├── dataset.py │ │ ├── download_cifarfs.sh │ │ ├── download_miniimagenet.sh │ │ ├── feature_loader.py │ │ └── get_cifarfs.py │ ├── dataloader.py │ ├── dataset.py │ ├── deconfound/ │ │ ├── DSIB.py │ │ ├── __init__.py │ │ └── meta_toolkits.py │ ├── dfsl_configs.py │ ├── io_utils.py │ ├── main.py │ ├── main_feat.py │ ├── networks.py │ ├── pretrain/ │ │ ├── miniImagenet_baseline_ResNet10_mean.npy │ │ ├── miniImagenet_sib_wrn_mean.npy │ │ ├── miniImagenet_simpleshot_ResNet10_mean.npy │ │ ├── norm_miniImagenet_baseline_ResNet10_mean.npy │ │ ├── norm_miniImagenet_feat_wrn_mean.npy │ │ ├── norm_miniImagenet_sib_wrn_mean.npy │ │ ├── norm_miniImagenet_simpleshot_ResNet10_mean.npy │ │ ├── norm_miniImagenet_simpleshotwide_wideres_mean.npy │ │ ├── norm_tiered_simpleshot_ResNet10_mean.npy │ │ ├── norm_tiered_simpleshotwide_wideres_mean.npy │ │ ├── tiered_simpleshot_ResNet10_mean.npy │ │ └── tiered_simpleshotwide_wideres_mean.npy │ ├── readme.md │ ├── requirements.txt │ ├── setup.cfg │ ├── sib.py │ ├── simple_shot_models/ │ │ ├── Conv4.py │ │ ├── 
DenseNet.py │ │ ├── MobileNet.py │ │ ├── ProtoNet.py │ │ ├── ResNet.py │ │ ├── WideResNet.py │ │ └── __init__.py │ └── utils/ │ ├── __init__.py │ ├── config.py │ ├── outils.py │ └── utils.py └── readme.md
SYMBOL INDEX (1092 symbols across 78 files)
FILE: LEO/config.py
function get_data_config (line 121) | def get_data_config():
function get_inner_model_config (line 131) | def get_inner_model_config():
function get_outer_model_config (line 149) | def get_outer_model_config():
function load_ifsl_config (line 165) | def load_ifsl_config(config):
FILE: LEO/data.py
class StrEnum (line 39) | class StrEnum(enum.Enum):
method __str__ (line 42) | def __str__(self):
method __repr__ (line 45) | def __repr__(self):
class MetaDataset (line 49) | class MetaDataset(StrEnum):
class EmbeddingCrop (line 55) | class EmbeddingCrop(StrEnum):
class MetaSplit (line 61) | class MetaSplit(StrEnum):
class DataProvider (line 68) | class DataProvider(object):
method __init__ (line 71) | def __init__(self, dataset_split, config, verbose=False, feat_dim=640,...
method _check_config (line 80) | def _check_config(self):
method _load_data (line 91) | def _load_data(self):
method _load (line 112) | def _load(self, opened_file):
method _index_data (line 119) | def _index_data(self, raw_data):
method _check_data_index (line 143) | def _check_data_index(self, raw_data):
method _get_full_pickle_path (line 161) | def _get_full_pickle_path(self, split_name):
method get_instance (line 179) | def get_instance(self, num_classes, tr_size, val_size):
method get_batch (line 269) | def get_batch(self, batch_size, num_classes, tr_size, val_size,
method _check_labels (line 320) | def _check_labels(self, num_classes, tr_size, val_size,
FILE: LEO/ifsl_configs/baseline_config.py
class Config (line 1) | class Config():
method __init__ (line 2) | def __init__(self):
function mini_5_resnet_baseline (line 6) | def mini_5_resnet_baseline():
function mini_1_resnet_baseline (line 18) | def mini_1_resnet_baseline():
function mini_5_wrn_baseline (line 30) | def mini_5_wrn_baseline():
function mini_1_wrn_baseline (line 42) | def mini_1_wrn_baseline():
function tiered_5_resnet_baseline (line 54) | def tiered_5_resnet_baseline():
function tiered_1_resnet_baseline (line 67) | def tiered_1_resnet_baseline():
function tiered_5_wrn_baseline (line 79) | def tiered_5_wrn_baseline():
function tiered_1_wrn_baseline (line 92) | def tiered_1_wrn_baseline():
FILE: LEO/ifsl_configs/ifsl_config.py
class Config (line 1) | class Config():
method __init__ (line 2) | def __init__(self):
function mini_5_resnet_ifsl (line 5) | def mini_5_resnet_ifsl():
function mini_1_resnet_ifsl (line 31) | def mini_1_resnet_ifsl():
function mini_5_wrn_ifsl (line 56) | def mini_5_wrn_ifsl():
function mini_1_wrn_ifsl (line 82) | def mini_1_wrn_ifsl():
function tiered_5_resnet_ifsl (line 108) | def tiered_5_resnet_ifsl():
function tiered_1_resnet_ifsl (line 134) | def tiered_1_resnet_ifsl():
function tiered_5_wrn_ifsl (line 160) | def tiered_5_wrn_ifsl():
function tiered_1_wrn_ifsl (line 185) | def tiered_1_wrn_ifsl():
FILE: LEO/model.py
function get_orthogonality_regularizer (line 36) | def get_orthogonality_regularizer(orthogonality_penalty_weight):
function run_leo (line 60) | def run_leo(model, inputs, is_meta_training):
class FeatureProcessor (line 74) | class FeatureProcessor():
method __init__ (line 75) | def __init__(self, n_splits, pretrain_mean_filename, feat_dim, is_cosi...
method get_d_features (line 96) | def get_d_features(self, logit):
method preprocess (line 101) | def preprocess(self, data, center=None, method="none"):
method get_split_features (line 111) | def get_split_features(self, data, center, method="none"):
method get_features (line 126) | def get_features(self, data):
class IFSL (line 158) | class IFSL(snt.AbstractModule):
method __init__ (line 159) | def __init__(self, config=None, use_64bits_dtype=True, n_splits=4, is_...
method build_input_data (line 191) | def build_input_data(self, support_features, query_features, data):
method get_debug_data (line 203) | def get_debug_data(self, data):
method fuse_features (line 210) | def fuse_features(self, x1, x2):
method _build (line 219) | def _build(self, data, is_meta_training=True, debug=False, break_down=...
method calculate_dacc (line 291) | def calculate_dacc(self, data, model_prediction, break_down=False):
method grads_and_vars (line 330) | def grads_and_vars(self, metatrain_loss):
class LEO (line 339) | class LEO(snt.AbstractModule):
method __init__ (line 342) | def __init__(self, config=None, use_64bits_dtype=True, name="leo", dec...
method _build (line 364) | def _build(self, data, is_meta_training=True):
method leo_inner_loop (line 419) | def leo_inner_loop(self, data, latents):
method finetuning_inner_loop (line 442) | def finetuning_inner_loop(self, data, leo_loss, classifier_weights):
method forward_encoder (line 460) | def forward_encoder(self, data):
method forward_decoder (line 468) | def forward_decoder(self, data, latents):
method encoder (line 483) | def encoder(self, inputs):
method relation_network (line 498) | def relation_network(self, inputs):
method decoder (line 524) | def decoder(self, inputs):
method average_codes_per_class (line 541) | def average_codes_per_class(self, codes):
method possibly_sample (line 547) | def possibly_sample(self, distribution_params, stddev_offset=0.):
method kl_divergence (line 562) | def kl_divergence(self, samples, normal_distribution):
method predict (line 569) | def predict(self, inputs, weights):
method calculate_inner_loss (line 580) | def calculate_inner_loss(self, inputs, true_outputs, classifier_weights):
method save_problem_instance_stats (line 589) | def save_problem_instance_stats(self, instance):
method dropout_rate (line 607) | def dropout_rate(self):
method loss_fn (line 610) | def loss_fn(self, model_outputs, original_classes):
method grads_and_vars (line 617) | def grads_and_vars(self, metatrain_loss):
method _l2_regularization (line 656) | def _l2_regularization(self):
method _decoder_orthogonality_reg (line 662) | def _decoder_orthogonality_reg(self):
FILE: LEO/model_test.py
function get_test_config (line 36) | def get_test_config():
function mockify_everything (line 53) | def mockify_everything(test_function=None,
function _random_problem_instance (line 126) | def _random_problem_instance(num_classes=7,
class LEOTest (line 149) | class LEOTest(tf.test.TestCase, parameterized.TestCase):
method setUp (line 151) | def setUp(self):
method test_instantiate_leo (line 160) | def test_instantiate_leo(self):
method test_inner_loop_adaptation (line 168) | def test_inner_loop_adaptation(self):
method test_map_input (line 193) | def test_map_input(self):
method test_setting_is_meta_training (line 230) | def test_setting_is_meta_training(self):
method test_finetuning_improves_loss (line 237) | def test_finetuning_improves_loss(self):
method test_gradients_dont_flow_through_input (line 254) | def test_gradients_dont_flow_through_input(self):
method test_inferring_embedding_dim (line 262) | def test_inferring_embedding_dim(self):
method test_variable_creation (line 267) | def test_variable_creation(self):
method test_graph_construction (line 284) | def test_graph_construction(self):
method test_possibly_sample (line 287) | def test_possibly_sample(self):
method test_different_shapes (line 308) | def test_different_shapes(self):
method test_encoder_penalty (line 315) | def test_encoder_penalty(self):
method test_construct_float32_leo_graph (line 332) | def test_construct_float32_leo_graph(self):
FILE: LEO/runner.py
function _clip_gradients (line 49) | def _clip_gradients(gradients, gradient_threshold, gradient_norm_thresho...
function _construct_validation_summaries (line 63) | def _construct_validation_summaries(metavalid_loss, metavalid_accuracy):
function _construct_training_summaries (line 69) | def _construct_training_summaries(metatrain_loss, metatrain_accuracy,
function _construct_examples_batch (line 80) | def _construct_examples_batch(batch_size, split, num_classes,
function _construct_loss_and_accuracy (line 92) | def _construct_loss_and_accuracy(inner_model, inputs, is_meta_training):
function construct_debug_graph (line 107) | def construct_debug_graph(outer_model_config):
function construct_graph (line 126) | def construct_graph(outer_model_config):
function run_debug_loop (line 208) | def run_debug_loop(checkpoint_path):
function write_output_message (line 227) | def write_output_message(message, file_name=None):
function run_training_loop (line 235) | def run_training_loop(checkpoint_path):
function main (line 309) | def main(argv):
FILE: LEO/utils.py
function unpack_data (line 31) | def unpack_data(problem_instance):
function copy_checkpoint (line 38) | def copy_checkpoint(checkpoint_path, global_step, accuracy):
function _save_files_in_tmp_directory (line 69) | def _save_files_in_tmp_directory(tmp_checkpoint_path, checkpoint_files,
function _is_previous_accuracy_better (line 95) | def _is_previous_accuracy_better(best_checkpoint_path, accuracy):
function evaluate_and_average (line 106) | def evaluate_and_average(session, tensor, num_estimates):
function evaluate_and_average_acc_dacc (line 111) | def evaluate_and_average_acc_dacc(session, acc, dacc, num_estimates):
FILE: MAML_MN_FT/backbone.py
function init_layer (line 18) | def init_layer(L):
class NNClassifier (line 28) | class NNClassifier():
method __init__ (line 29) | def __init__(self, n_way):
method normalize (line 32) | def normalize(self, x):
method preprocess (line 37) | def preprocess(self, data):
method dist (line 46) | def dist(self, x1, x2):
method kl_divergence (line 49) | def kl_divergence(self, k1, k2):
method fit (line 55) | def fit(self, support, support_labels, support_weights=None):
method predict (line 73) | def predict(self, query):
method predict_alt (line 85) | def predict_alt(self, query, measure="euclidean", norm_scores=False, t...
class MultiNNBiClassifier (line 114) | class MultiNNBiClassifier():
method __init__ (line 115) | def __init__(self, n_way, n_classifiers, measure="linear", fusion="lin...
method fit (line 124) | def fit(self, support_x, support_d, support_labels, support_weights=No...
method fuse_proba (line 133) | def fuse_proba(self, p1, p2):
method predict (line 145) | def predict(self, query_x, query_d, weights=None, counterfactual=False):
class MultiNNClassifier (line 167) | class MultiNNClassifier():
method __init__ (line 168) | def __init__(self, n_way, n_classifiers, measure="euclidean", temp=1.0):
method fit (line 178) | def fit(self, support, support_labels, support_weights=None):
method predict (line 189) | def predict(self, query, weights=None):
class BidrectionalLSTM (line 209) | class BidrectionalLSTM(nn.Module):
method __init__ (line 210) | def __init__(self, size: int, layers: int):
method forward (line 229) | def forward(self, inputs):
class AttentionLSTM (line 242) | class AttentionLSTM(nn.Module):
method __init__ (line 243) | def __init__(self, size: int, unrolling_steps: int):
method forward (line 258) | def forward(self, support, queries):
class MultiLinearClassifier (line 288) | class MultiLinearClassifier(nn.Module):
method __init__ (line 289) | def __init__(self, n_clf, feat_dim, n_way, sum_log=True, permute=False...
method create_clf (line 303) | def create_clf(self, loss_type, in_dim, out_dim):
method forward (line 309) | def forward(self, X):
class MultiBiLinearClassifier (line 330) | class MultiBiLinearClassifier(nn.Module):
method __init__ (line 331) | def __init__(self, n_clf, x_feat_dim, d_feat_dim, n_way, sum_log=True,...
method fuse_logits (line 343) | def fuse_logits(self, p1, p2):
method create_clf (line 355) | def create_clf(self, loss_type, in_dim, out_dim):
method forward (line 361) | def forward(self, X, D, counterfactual=False):
class ResNetKernelClusterAgent (line 384) | class ResNetKernelClusterAgent():
method __init__ (line 385) | def __init__(self, pretrain, n_clusters, pca_dim, cluster_method="kmea...
method fit (line 391) | def fit(self):
class ResNetParamClusterModel (line 407) | class ResNetParamClusterModel():
method __init__ (line 408) | def __init__(self, pretrain, n_clusters, cluster_method="kmeans"):
method cluster (line 413) | def cluster(self, features, n_clusters):
method get_weight_features (line 418) | def get_weight_features(self, weights):
method fit (line 424) | def fit(self):
method conv_forward (line 432) | def conv_forward(self, inputs, labels, n_clusters, original_conv):
method forward (line 452) | def forward(self, imgs):
method forward2 (line 491) | def forward2(self, imgs):
method forward3 (line 517) | def forward3(self, imgs):
class BasisTransformer (line 524) | class BasisTransformer():
method __init__ (line 525) | def __init__(self, pretrain, recluster=False, cluster_method="kmeans",...
method fit (line 532) | def fit(self, n_clusters, feat_dim, pca_dim=50):
method transform (line 573) | def transform(self, X):
class KernelTransformer (line 582) | class KernelTransformer():
method __init__ (line 583) | def __init__(self, feat_dim, kernel):
method fit (line 587) | def fit(self, features):
method transform (line 594) | def transform(self, X):
method kernel_f (line 602) | def kernel_f(self, x1, x2):
class ChannelwiseClassifier (line 613) | class ChannelwiseClassifier(nn.Module):
method __init__ (line 614) | def __init__(self, feat_dim, n_way, weight, bias=False):
method reset_parameters (line 626) | def reset_parameters(self):
method forward (line 635) | def forward(self, X):
class UnbiasedClassifier (line 660) | class UnbiasedClassifier(nn.Module):
method __init__ (line 673) | def __init__(self, n_way, x_feature_dim, z_feature_dim, d_feature_dim=0,
method create_clf (line 694) | def create_clf(self, feat_dim):
method create_logit_fusion_fn (line 701) | def create_logit_fusion_fn(self):
method get_fused_feature (line 710) | def get_fused_feature(self, feature_array):
method forward (line 714) | def forward(self, X, Z, D=None):
class XDBiClassifier (line 733) | class XDBiClassifier(nn.Module):
method __init__ (line 734) | def __init__(self, n_way, x_feature_dim, d_feature_dim, architecture="...
method create_clf (line 751) | def create_clf(self, feat_dim):
method create_logit_fusion_fn (line 758) | def create_logit_fusion_fn(self):
method cat_for_logit_fusion (line 769) | def cat_for_logit_fusion(self, A, B):
method forward (line 774) | def forward(self, X, D):
class XDClassifier (line 791) | class XDClassifier(nn.Module):
method __init__ (line 804) | def __init__(self, n_way, x_feature_dim, d_feature_dim, architecture="...
method get_feature_dim (line 826) | def get_feature_dim(self):
method create_clf (line 840) | def create_clf(self, feat_dim):
method get_fused_feature (line 847) | def get_fused_feature(self, feature_array):
method forward (line 855) | def forward(self, X, D):
class ProductGate (line 870) | class ProductGate(nn.Module):
method __init__ (line 871) | def __init__(self):
method forward (line 874) | def forward(self, x):
class HarmonicGate (line 882) | class HarmonicGate(nn.Module):
method __init__ (line 883) | def __init__(self):
method forward (line 886) | def forward(self, x):
class SumGate (line 895) | class SumGate(nn.Module):
method __init__ (line 896) | def __init__(self):
method forward (line 899) | def forward(self, x):
class distLinear (line 906) | class distLinear(nn.Module):
method __init__ (line 907) | def __init__(self, indim, outdim, class_wise_learnable_norm=True):
method forward (line 919) | def forward(self, x):
class Flatten (line 930) | class Flatten(nn.Module):
method __init__ (line 931) | def __init__(self):
method forward (line 934) | def forward(self, x):
class Linear_fw (line 938) | class Linear_fw(nn.Linear): #used in MAML to forward input with fast weight
method __init__ (line 939) | def __init__(self, in_features, out_features):
method forward (line 944) | def forward(self, x):
class Conv2d_fw (line 951) | class Conv2d_fw(nn.Conv2d): #used in MAML to forward input with fast weight
method __init__ (line 952) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,pa...
method forward (line 958) | def forward(self, x):
class BatchNorm2d_fw (line 972) | class BatchNorm2d_fw(nn.BatchNorm2d): #used in MAML to forward input wit...
method __init__ (line 973) | def __init__(self, num_features):
method forward (line 978) | def forward(self, x):
class ConvBlock (line 989) | class ConvBlock(nn.Module):
method __init__ (line 991) | def __init__(self, indim, outdim, pool = True, padding = 1):
method forward (line 1014) | def forward(self,x):
class SimpleBlock (line 1019) | class SimpleBlock(nn.Module):
method __init__ (line 1021) | def __init__(self, indim, outdim, half_res):
method forward (line 1060) | def forward(self, x):
class BottleneckBlock (line 1074) | class BottleneckBlock(nn.Module):
method __init__ (line 1076) | def __init__(self, indim, outdim, half_res):
method forward (line 1117) | def forward(self, x):
class ConvNet (line 1134) | class ConvNet(nn.Module):
method __init__ (line 1135) | def __init__(self, depth, flatten = True):
method forward (line 1150) | def forward(self,x):
class ConvNetNopool (line 1154) | class ConvNetNopool(nn.Module): #Relation net use a 4 layer conv with po...
method __init__ (line 1155) | def __init__(self, depth):
method forward (line 1167) | def forward(self,x):
class ConvNetS (line 1171) | class ConvNetS(nn.Module): #For omniglot, only 1 input channel, output d...
method __init__ (line 1172) | def __init__(self, depth, flatten = True):
method forward (line 1187) | def forward(self,x):
class ConvNetSNopool (line 1192) | class ConvNetSNopool(nn.Module): #Relation net use a 4 layer conv with p...
method __init__ (line 1193) | def __init__(self, depth):
method forward (line 1205) | def forward(self,x):
class ResNet (line 1210) | class ResNet(nn.Module):
method __init__ (line 1212) | def __init__(self,block,list_of_num_layers, list_of_out_dims, flatten ...
method forward (line 1254) | def forward(self,x):
function Conv4 (line 1258) | def Conv4():
function Conv6 (line 1261) | def Conv6():
function Conv4NP (line 1264) | def Conv4NP():
function Conv6NP (line 1267) | def Conv6NP():
function Conv4S (line 1270) | def Conv4S():
function Conv4SNP (line 1273) | def Conv4SNP():
function ResNet10 (line 1276) | def ResNet10( flatten = True):
function ResNet18 (line 1279) | def ResNet18( flatten = True):
function ResNet34 (line 1282) | def ResNet34( flatten = True):
function ResNet50 (line 1285) | def ResNet50( flatten = True):
function ResNet101 (line 1288) | def ResNet101( flatten = True):
FILE: MAML_MN_FT/data/additional_transforms.py
class ImageJitter (line 15) | class ImageJitter(object):
method __init__ (line 16) | def __init__(self, transformdict):
method __call__ (line 20) | def __call__(self, img):
FILE: MAML_MN_FT/data/datamgr.py
class TransformLoader (line 12) | class TransformLoader:
method __init__ (line 13) | def __init__(self, image_size,
method parse_transform (line 20) | def parse_transform(self, transform_type):
method get_composed_transform (line 36) | def get_composed_transform(self, aug=False):
class DataManager (line 47) | class DataManager:
method get_data_loader (line 49) | def get_data_loader(self, data_file, aug):
class SimpleDataManager (line 53) | class SimpleDataManager(DataManager):
method __init__ (line 54) | def __init__(self, image_size, batch_size):
method get_data_loader (line 59) | def get_data_loader(self, data_file, aug, num_workers=12, tiered_mini=...
class SetDataManager (line 70) | class SetDataManager(DataManager):
method __init__ (line 71) | def __init__(self, image_size, n_way, n_support, n_query, n_eposide=100):
method get_data_loader (line 80) | def get_data_loader(self, data_file, aug, debug=False): # parameters ...
FILE: MAML_MN_FT/data/dataset.py
class SimpleDataset (line 14) | class SimpleDataset:
method __init__ (line 15) | def __init__(self, data_file, transform, target_transform=identity):
method __getitem__ (line 21) | def __getitem__(self, i):
method __len__ (line 28) | def __len__(self):
class SimpleTieredDataset (line 32) | class SimpleTieredDataset:
method __init__ (line 33) | def __init__(self, setname, transform):
method __len__ (line 66) | def __len__(self):
method __getitem__ (line 69) | def __getitem__(self, i):
class SetDataset (line 75) | class SetDataset:
method __init__ (line 76) | def __init__(self, data_file, batch_size, transform):
method __getitem__ (line 98) | def __getitem__(self, i):
method __len__ (line 101) | def __len__(self):
class SubDataset (line 105) | class SubDataset:
method __init__ (line 106) | def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), targ...
method __getitem__ (line 112) | def __getitem__(self, i):
method __len__ (line 120) | def __len__(self):
class EpisodicBatchSampler (line 124) | class EpisodicBatchSampler(object):
method __init__ (line 125) | def __init__(self, n_classes, n_way, n_episodes):
method __len__ (line 130) | def __len__(self):
method __iter__ (line 133) | def __iter__(self):
FILE: MAML_MN_FT/data/feature_loader.py
class SimpleHDF5Dataset (line 5) | class SimpleHDF5Dataset:
method __init__ (line 6) | def __init__(self, file_handle = None):
method __getitem__ (line 18) | def __getitem__(self, i):
method __len__ (line 21) | def __len__(self):
function init_loader (line 24) | def init_loader(filename, get_path=False, path_file=None):
FILE: MAML_MN_FT/io_utils.py
function parse_args (line 20) | def parse_args(script):
function get_assigned_file (line 53) | def get_assigned_file(checkpoint_dir,num):
function get_resume_file (line 57) | def get_resume_file(checkpoint_dir):
function get_best_file (line 68) | def get_best_file(checkpoint_dir):
function print_accuracy (line 75) | def print_accuracy(acc):
function print_with_carriage_return (line 79) | def print_with_carriage_return(line):
function end_carriage_return_print (line 83) | def end_carriage_return_print():
function append_to_file (line 86) | def append_to_file(file, line):
function get_result_file (line 91) | def get_result_file(test_name, method_name):
function calc_recall_precision (line 101) | def calc_recall_precision(y, pred):
FILE: MAML_MN_FT/main.py
function func_not_found (line 4) | def func_not_found(): # just in case we dont have the function
function main (line 8) | def main():
FILE: MAML_MN_FT/methods/DMAML.py
class DMAML (line 14) | class DMAML(MetaTemplate):
method __init__ (line 15) | def __init__(self, model_func, n_way, n_support, pretrain, n_splits, i...
method get_feat_dim (line 65) | def get_feat_dim(self):
method fuse_features (line 75) | def fuse_features(self, x1, x2):
method normalize (line 83) | def normalize(self, x, dim=1):
method fuse_proba (line 88) | def fuse_proba(self, p1, p2):
method set_forward (line 100) | def set_forward(self, x, is_feature=False):
method set_forward_adaptation (line 148) | def set_forward_adaptation(self, x, is_feature=False):
method set_forward_loss (line 151) | def set_forward_loss(self, x):
method train_loop (line 158) | def train_loop(self, epoch, train_loader, optimizer):
method test_loop (line 183) | def test_loop(self, test_loader, return_std=False, metric="acc"):
FILE: MAML_MN_FT/methods/DMatchingNet.py
class DMatchingNet (line 12) | class DMatchingNet(MetaTemplate):
method __init__ (line 13) | def __init__(self, model_func, n_way, n_support, pretrain, n_splits, i...
method get_feat_dim (line 58) | def get_feat_dim(self):
method fuse_features (line 68) | def fuse_features(self, x1, x2):
method normalize (line 76) | def normalize(self, x, dim=1):
method fuse_proba (line 81) | def fuse_proba(self, p1, p2):
method set_forward (line 93) | def set_forward(self, x, is_feature=False):
method predict (line 143) | def predict(self, support, query):
method set_forward_loss (line 158) | def set_forward_loss(self, x):
FILE: MAML_MN_FT/methods/MethodTester.py
class MethodTester (line 21) | class MethodTester():
method __init__ (line 22) | def __init__(self, params):
method get_backbone (line 29) | def get_backbone(self):
method baseline_s2m2_initialize (line 44) | def baseline_s2m2_initialize(self, params, provide_original_image):
method simpleshot_initialize (line 82) | def simpleshot_initialize(self, params, provide_original_image):
method feat_initialize (line 92) | def feat_initialize(self, params, provide_original_image):
method sib_initialize (line 102) | def sib_initialize(self, params, provide_original_image):
method initialize (line 111) | def initialize(self, params, provide_original_image=False):
method get_task (line 126) | def get_task(self, all_from_same_class=False, provide_original_image=F...
method get_task_special (line 149) | def get_task_special(self, sampling="sim"):
method set_experiment_config (line 241) | def set_experiment_config(self, config):
method set_conditional_config (line 244) | def set_conditional_config(self, config_func):
method set_experiment_method (line 247) | def set_experiment_method(self, method):
method add_early_stop_criteria (line 250) | def add_early_stop_criteria(self, iter, acc):
method _should_reinitialize (line 257) | def _should_reinitialize(self, config):
method start_experiment (line 260) | def start_experiment(self, method, config, test_name="No name", condit...
method _get_config_list (line 341) | def _get_config_list(self, config_dict):
method _increment_config_counter (line 356) | def _increment_config_counter(self, config_dict, config_key_counter):
method _generate_test_case_name (line 364) | def _generate_test_case_name(self, config):
method _run_model (line 371) | def _run_model(self, model, show_current_accuracy=True, provide_origin...
method _feature_evaluation (line 403) | def _feature_evaluation(self, model, provide_original_image, sampling=...
method cosine_similarity (line 432) | def cosine_similarity(self, a, b):
method _evaluate_hardness (line 435) | def _evaluate_hardness(self, z_all):
method normalize (line 468) | def normalize(self, x):
method _evaluate_hardness_logodd (line 473) | def _evaluate_hardness_logodd(self, z_all):
method _should_early_stop (line 510) | def _should_early_stop(self, iter, acc):
method meta_train (line 517) | def meta_train(self, config, method, descriptor_str, debug=True, use_t...
method meta_test (line 610) | def meta_test(self, config, method, descriptor_str, debug=True, requir...
FILE: MAML_MN_FT/methods/NNEDSplitNew.py
class NNEDSplitNew (line 9) | class NNEDSplitNew(MetaTemplate):
method __init__ (line 10) | def __init__(self, model_func, n_way, n_support, n_query, pretrain, n_...
method set_forward (line 46) | def set_forward(self, x, is_feature=True):
method calc_pd (line 49) | def calc_pd(self, x, clf_idx):
method calc_ed (line 61) | def calc_ed(self, x):
method temp (line 72) | def temp(self, x):
method normalize (line 81) | def normalize(self, x, dim=1):
method fuse_proba (line 86) | def fuse_proba(self, p1, p2):
method fuse_features (line 98) | def fuse_features(self, x1, x2):
method nn_preprocess (line 106) | def nn_preprocess(self, data, center=None, preprocessing="l2n"):
method get_split_features (line 117) | def get_split_features(self, x, preprocess=False, center=None, preproc...
method set_forward_adaptation (line 136) | def set_forward_adaptation(self, x, image_paths=None, is_feature=True):
method set_forward_loss (line 222) | def set_forward_loss(self, x):
FILE: MAML_MN_FT/methods/PretrainedModel.py
class PretrainedModel (line 17) | class PretrainedModel():
method __init__ (line 18) | def __init__(self, params):
method baseline_s2m2_init (line 40) | def baseline_s2m2_init(self, params):
method simpleshot_init (line 87) | def simpleshot_init(self, params):
method feat_init (line 133) | def feat_init(self, params):
method sib_init (line 152) | def sib_init(self, params):
method get_features (line 177) | def get_features(self, x):
method classify (line 191) | def classify(self, x, normalize_prob=True):
method load_d_specific_classifiers (line 212) | def load_d_specific_classifiers(self, n_clf):
method train_d_specific_classifiers (line 227) | def train_d_specific_classifiers(self, n_clf):
method test_d_specific_classifiers (line 268) | def test_d_specific_classifiers(self, n_clf):
method save_pretrain_dataset (line 295) | def save_pretrain_dataset(self, split):
method get_pretrain_dataset (line 339) | def get_pretrain_dataset(self, split):
method normalize (line 350) | def normalize(self, x):
method _calc_pretrained_class_mean (line 355) | def _calc_pretrained_class_mean(self, normalize=False):
method get_kmeans_pca_model (line 399) | def get_kmeans_pca_model(self, k=8, n_clusters=10, normalize=False):
method get_pretrained_class_mean (line 451) | def get_pretrained_class_mean(self, normalize=False):
FILE: MAML_MN_FT/methods/VanillaMAML.py
class VanillaMAML (line 13) | class VanillaMAML(MetaTemplate):
method __init__ (line 14) | def __init__(self, model_func, n_way, n_support, approx=False, update_...
method set_forward (line 25) | def set_forward(self, x, is_feature=False):
method set_forward_adaptation (line 33) | def set_forward_adaptation(self, x, is_feature=False):
method set_forward_loss (line 36) | def set_forward_loss(self, x):
method train_loop (line 42) | def train_loop(self, epoch, train_loader, optimizer):
method test_loop (line 67) | def test_loop(self, test_loader, return_std=False, metric="acc"):
FILE: MAML_MN_FT/methods/VanillaMatchingNet.py
class VanillaMatchingNet (line 11) | class VanillaMatchingNet(MetaTemplate):
method __init__ (line 12) | def __init__(self, model_func, n_way, n_support):
method normalize (line 19) | def normalize(self, x, dim=1):
method set_forward (line 24) | def set_forward(self, x, is_feature=False):
method set_forward_loss (line 39) | def set_forward_loss(self, x):
FILE: MAML_MN_FT/methods/meta_template.py
class MetaTemplate (line 12) | class MetaTemplate(nn.Module):
method __init__ (line 13) | def __init__(self, model_func, n_way, n_support, change_way = True, im...
method set_forward (line 27) | def set_forward(self,x,is_feature):
method set_forward_loss (line 31) | def set_forward_loss(self, x):
method forward (line 34) | def forward(self,x):
method parse_feature (line 38) | def parse_feature(self,x,is_feature):
method parse_images (line 55) | def parse_images(self, image_paths):
method correct (line 75) | def correct(self, x, metric="acc"):
method normalize (line 88) | def normalize(self, x):
method _evaluate_hardness_logodd (line 93) | def _evaluate_hardness_logodd(self, z_all):
method calc_correct (line 130) | def calc_correct(self, scores):
method train_loop (line 137) | def train_loop(self, epoch, train_loader, optimizer ):
method test_loop (line 158) | def test_loop(self, test_loader, record = None, metric="acc"):
method set_forward_adaptation (line 182) | def set_forward_adaptation(self, x, is_feature = True): #further adapt...
FILE: MAML_MN_FT/methods/meta_toolkits.py
class MatchingNetModule (line 9) | class MatchingNetModule(nn.Module):
method __init__ (line 10) | def __init__(self, feat_dim):
method forward (line 16) | def forward(self, support, query):
method cuda (line 25) | def cuda(self):
class FullyContextualEmbedding (line 32) | class FullyContextualEmbedding(nn.Module):
method __init__ (line 33) | def __init__(self, feat_dim):
method forward (line 40) | def forward(self, f, G):
method cuda (line 54) | def cuda(self):
class FeatureProcessor (line 61) | class FeatureProcessor():
method __init__ (line 62) | def __init__(self, pretrain, n_splits, is_cosine_feature=False, d_feat...
method get_split_features (line 88) | def get_split_features(self, x, preprocess=False, center=None, preproc...
method nn_preprocess (line 106) | def nn_preprocess(self, data, center=None, preprocessing="l2n"):
method calc_pd (line 117) | def calc_pd(self, x, clf_idx):
method normalize (line 122) | def normalize(self, x, dim=1):
method get_d_feature (line 127) | def get_d_feature(self, x):
method get_features (line 144) | def get_features(self, support, query):
class MAMLBlock (line 163) | class MAMLBlock(nn.Module):
method __init__ (line 164) | def __init__(self, feat_dim, n_way, update_step, approx=True, lr=0.01):
method forward (line 183) | def forward(self, x):
method fit (line 188) | def fit(self, support, labels):
method predict (line 215) | def predict(self, query):
class LEOBlock (line 219) | class LEOBlock(nn.Module):
method __init__ (line 220) | def __init__(self, feat_dim, latent_dim, n_way, drop_rate):
method prepare_for_rel_net (line 243) | def prepare_for_rel_net(self, embeddings):
method get_sampled_weights (line 250) | def get_sampled_weights(self, weight_stats):
method finetune_z (line 259) | def finetune_z(self, support, initial_z, initial_w, labels):
method finetune_w (line 277) | def finetune_w(self, support, initial_w, labels):
method forward_relation_net (line 293) | def forward_relation_net(self, embeddings):
method average_codes_per_class (line 301) | def average_codes_per_class(self, relation_net_outputs):
method possibly_sample (line 307) | def possibly_sample(self, distribution_params, stddev_offset=0.):
method forward_decoder (line 322) | def forward_decoder(self, latents):
method calc_kl_penalty (line 332) | def calc_kl_penalty(self, latent_samples, latent_distributions):
method calc_encoder_penalty (line 339) | def calc_encoder_penalty(self, z_f, z):
method calc_l2_penalty (line 342) | def calc_l2_penalty(self):
method calc_orthogonality_penalty (line 351) | def calc_orthogonality_penalty(self):
method fit (line 361) | def fit(self, support, labels):
method predict (line 377) | def predict(self, query, w):
FILE: MAML_MN_FT/models/FeatWRN.py
function conv3x3 (line 10) | def conv3x3(in_planes, out_planes, stride=1):
function conv_init (line 13) | def conv_init(m):
class wide_basic (line 22) | class wide_basic(nn.Module):
method __init__ (line 23) | def __init__(self, in_planes, planes, dropout_rate, stride=1):
method forward (line 37) | def forward(self, x):
class Wide_ResNet (line 44) | class Wide_ResNet(nn.Module):
method __init__ (line 45) | def __init__(self, depth, widen_factor, dropout_rate):
method _wide_layer (line 62) | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
method forward (line 72) | def forward(self, x):
class FEATWRN (line 84) | class FEATWRN(nn.Module):
method __init__ (line 85) | def __init__(self, num_classes=64):
method forward (line 92) | def forward(self, x):
method forward_feature (line 96) | def forward_feature(self, x):
FILE: MAML_MN_FT/models/SimpleShotResNet.py
function conv3x3 (line 7) | def conv3x3(in_planes, out_planes, stride=1):
function conv1x1 (line 13) | def conv1x1(in_planes, out_planes, stride=1):
class BasicBlock (line 18) | class BasicBlock(nn.Module):
method __init__ (line 21) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 31) | def forward(self, x):
class Bottleneck (line 50) | class Bottleneck(nn.Module):
method __init__ (line 53) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 65) | def forward(self, x):
class ResNet (line 88) | class ResNet(nn.Module):
method __init__ (line 90) | def __init__(self, block, layers, num_classes=1000, zero_init_residual...
method _make_layer (line 125) | def _make_layer(self, block, planes, blocks, stride=1):
method forward (line 141) | def forward(self, x, feature=False):
function resnet10 (line 165) | def resnet10(**kwargs):
function resnet18 (line 172) | def resnet18(**kwargs):
function resnet34 (line 179) | def resnet34(**kwargs):
function resnet50 (line 186) | def resnet50(**kwargs):
function resnet101 (line 193) | def resnet101(**kwargs):
function resnet152 (line 200) | def resnet152(**kwargs):
FILE: MAML_MN_FT/models/SimpleShotWideResNet.py
function conv3x3 (line 9) | def conv3x3(in_planes, out_planes, stride=1):
function conv_init (line 13) | def conv_init(m):
class wide_basic (line 23) | class wide_basic(nn.Module):
method __init__ (line 24) | def __init__(self, in_planes, planes, dropout_rate, stride=1):
method forward (line 38) | def forward(self, x):
class Wide_ResNet (line 46) | class Wide_ResNet(nn.Module):
method __init__ (line 47) | def __init__(self, depth, widen_factor, dropout_rate, num_classes, rem...
method _wide_layer (line 75) | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
method forward (line 85) | def forward(self, x, feature=False):
function wideres (line 105) | def wideres(num_classes=64, remove_linear=False):
FILE: MAML_MN_FT/save_features.py
function save_features (line 25) | def save_features(model, data_loader, outfile, params):
function baseline_s2m2_init (line 57) | def baseline_s2m2_init(params):
function remove_module_from_param_name (line 173) | def remove_module_from_param_name(params_name_str):
function simple_shot_init (line 179) | def simple_shot_init(params, split):
function feat_init (line 218) | def feat_init(params, split):
function sib_init (line 246) | def sib_init(params, split):
function cosine_init (line 273) | def cosine_init(params, split):
function initialize_and_save (line 307) | def initialize_and_save(params, split):
FILE: MAML_MN_FT/tests/MetaTrain.py
class MetaTrain (line 8) | class MetaTrain(MethodTester):
method __init__ (line 9) | def __init__(self, params):
method maml5_resnet (line 16) | def maml5_resnet(self):
method maml5_wrn (line 37) | def maml5_wrn(self):
method maml1_resnet (line 58) | def maml1_resnet(self):
method maml1_wrn (line 79) | def maml1_wrn(self):
method maml5_resnet_tiered (line 100) | def maml5_resnet_tiered(self):
method maml5_wrn_tiered (line 121) | def maml5_wrn_tiered(self):
method maml1_resnet_tiered (line 142) | def maml1_resnet_tiered(self):
method maml1_wrn_tiered (line 163) | def maml1_wrn_tiered(self):
method maml5_ifsl_resnet (line 186) | def maml5_ifsl_resnet(self):
method maml5_ifsl_wrn (line 224) | def maml5_ifsl_wrn(self):
method maml1_ifsl_resnet (line 262) | def maml1_ifsl_resnet(self):
method maml1_ifsl_wrn (line 300) | def maml1_ifsl_wrn(self):
method maml5_ifsl_resnet_tiered (line 338) | def maml5_ifsl_resnet_tiered(self):
method maml5_ifsl_wrn_tiered (line 376) | def maml5_ifsl_wrn_tiered(self):
method maml1_ifsl_resnet_tiered (line 414) | def maml1_ifsl_resnet_tiered(self):
method maml1_ifsl_wrn_tiered (line 452) | def maml1_ifsl_wrn_tiered(self):
method mn5_resnet (line 492) | def mn5_resnet(self):
method mn5_wrn (line 507) | def mn5_wrn(self):
method mn1_resnet (line 522) | def mn1_resnet(self):
method mn1_wrn (line 537) | def mn1_wrn(self):
method mn5_resnet_tiered (line 552) | def mn5_resnet_tiered(self):
method mn5_wrn_tiered (line 567) | def mn5_wrn_tiered(self):
method mn1_resnet_tiered (line 582) | def mn1_resnet_tiered(self):
method mn1_wrn_tiered (line 597) | def mn1_wrn_tiered(self):
method mn5_ifsl_resnet (line 613) | def mn5_ifsl_resnet(self):
method mn5_ifsl_wrn (line 646) | def mn5_ifsl_wrn(self):
method mn1_ifsl_resnet (line 679) | def mn1_ifsl_resnet(self):
method mn1_ifsl_wrn (line 712) | def mn1_ifsl_wrn(self):
method mn5_ifsl_resnet_tiered (line 745) | def mn5_ifsl_resnet_tiered(self):
method mn5_ifsl_wrn_tiered (line 779) | def mn5_ifsl_wrn_tiered(self):
method mn1_ifsl_resnet_tiered (line 813) | def mn1_ifsl_resnet_tiered(self):
method mn1_ifsl_wrn_tiered (line 847) | def mn1_ifsl_wrn_tiered(self):
FILE: MAML_MN_FT/utils.py
function one_hot (line 4) | def one_hot(y, num_class):
function DBindex (line 7) | def DBindex(cl_data_file):
function sparsity (line 25) | def sparsity(cl_data_file):
FILE: MTL/configs/baseline_config.py
class Params (line 1) | class Params():
method __init__ (line 2) | def __init__(self):
function mini_5_resnet_baseline (line 6) | def mini_5_resnet_baseline():
function mini_1_resnet_baseline (line 21) | def mini_1_resnet_baseline():
function tiered_5_resnet_baseline (line 36) | def tiered_5_resnet_baseline():
function tiered_1_resnet_baseline (line 51) | def tiered_1_resnet_baseline():
function mini_5_wrn_baseline (line 66) | def mini_5_wrn_baseline():
function mini_1_wrn_baseline (line 81) | def mini_1_wrn_baseline():
function tiered_5_wrn_baseline (line 96) | def tiered_5_wrn_baseline():
function tiered_1_wrn_baseline (line 111) | def tiered_1_wrn_baseline():
FILE: MTL/configs/ifsl_resnet_config.py
class Params (line 1) | class Params():
method __init__ (line 2) | def __init__(self):
function mini_5_resnet_d (line 6) | def mini_5_resnet_d():
function mini_1_resnet_d (line 44) | def mini_1_resnet_d():
function tiered_5_resnet_d (line 82) | def tiered_5_resnet_d():
function tiered_1_resnet_d (line 120) | def tiered_1_resnet_d():
FILE: MTL/configs/ifsl_wrn_config.py
class Params (line 1) | class Params():
method __init__ (line 2) | def __init__(self):
function mini_5_wrn_d (line 6) | def mini_5_wrn_d():
function mini_1_wrn_d (line 44) | def mini_1_wrn_d():
function tiered_5_wrn_d (line 82) | def tiered_5_wrn_d():
function tiered_1_wrn_d (line 120) | def tiered_1_wrn_d():
FILE: MTL/dataloader/dataset_loader.py
class DatasetLoader (line 20) | class DatasetLoader(Dataset):
method __init__ (line 22) | def __init__(self, setname, args, dataset="miniImagenet", train_aug=Fa...
method __len__ (line 99) | def __len__(self):
method __getitem__ (line 102) | def __getitem__(self, i):
FILE: MTL/dataloader/samplers.py
class CategoriesSampler (line 15) | class CategoriesSampler():
method __init__ (line 17) | def __init__(self, label, n_batch, n_cls, n_per):
method __len__ (line 29) | def __len__(self):
method __iter__ (line 31) | def __iter__(self):
FILE: MTL/models/IFSL.py
class PretrainNet (line 14) | class PretrainNet():
method __init__ (line 15) | def __init__(self, args):
method load_classifier (line 36) | def load_classifier(self, n_splits, epoch=22, num_classes=64):
method train_classifier (line 50) | def train_classifier(self, n_splits, num_classes=64):
method normalize (line 94) | def normalize(self, x):
method get_base_means (line 99) | def get_base_means(self, num_classes=64, is_cosine_feature=False):
method save_base_means (line 107) | def save_base_means(self, num_classes=64, is_cosine_feature=False):
class distLinear (line 137) | class distLinear(nn.Module):
method __init__ (line 138) | def __init__(self, indim, outdim, class_wise_learnable_norm=True):
method forward (line 150) | def forward(self, x):
class MultiLinearClassifier (line 162) | class MultiLinearClassifier(nn.Module):
method __init__ (line 163) | def __init__(self, n_clf, feat_dim, n_way, sum_log=True, permute=False...
method create_clf (line 177) | def create_clf(self, loss_type, in_dim, out_dim):
method forward (line 183) | def forward(self, X):
class MultiBiLinearClassifier (line 204) | class MultiBiLinearClassifier(nn.Module):
method __init__ (line 205) | def __init__(self, n_clf, x_feat_dim, d_feat_dim, n_way, sum_log=True,...
method fuse_logits (line 217) | def fuse_logits(self, p1, p2):
method create_clf (line 229) | def create_clf(self, loss_type, in_dim, out_dim):
method forward (line 235) | def forward(self, X, D, counterfactual=False):
class BaseLearner (line 258) | class BaseLearner(nn.Module):
method __init__ (line 260) | def __init__(self, args, z_dim):
method forward (line 271) | def forward(self, input_x, the_vars=None):
method parameters (line 279) | def parameters(self):
method initialize (line 282) | def initialize(self):
class DeconfoundedLearner (line 287) | class DeconfoundedLearner():
method __init__ (line 288) | def __init__(self, pretrain, classifier="bi", logit_fusion="product", ...
method calc_pd (line 355) | def calc_pd(self, x, clf_idx):
method get_pd_features (line 359) | def get_pd_features(self, x):
method get_ed_features (line 369) | def get_ed_features(self, x):
method get_split_features (line 379) | def get_split_features(self, x):
method fuse_feature (line 389) | def fuse_feature(self, a, b, dim=2):
method fuse_logits (line 397) | def fuse_logits(self, p1, p2):
method backward_loss_and_step (line 409) | def backward_loss_and_step(self, loss, optimizer):
method fit (line 421) | def fit(self, support, query, support_labels, support_embedding, query...
method fit_no_split (line 427) | def fit_no_split(self, support, query, support_labels, support_embeddi...
method no_split_update (line 451) | def no_split_update(self, x, d, label):
method calc_no_split_logit (line 471) | def calc_no_split_logit(self, x, d, fast_weight_x=None, fast_weight_d=...
method fit_multi_splits (line 481) | def fit_multi_splits(self, support, query, support_labels, support_emb...
method predict (line 549) | def predict(self, support_labels, support_embedding, query_embedding):
FILE: MTL/models/IFSL_modules.py
class Linear_fw (line 7) | class Linear_fw(nn.Linear): #used in MAML to forward input with fast weight
method __init__ (line 8) | def __init__(self, in_features, out_features):
method forward (line 13) | def forward(self, x):
class IFSLBaseLearner (line 21) | class IFSLBaseLearner(nn.Module):
method __init__ (line 22) | def __init__(self, feat_dim, n_way, update_step, approx=True, lr=0.01):
method forward (line 33) | def forward(self, x):
method fit (line 37) | def fit(self, support, labels):
method predict (line 64) | def predict(self, query):
class FeatureProcessor (line 68) | class FeatureProcessor():
method __init__ (line 69) | def __init__(self, pretrain, n_splits, is_cosine_feature=False, d_feat...
method get_split_features (line 95) | def get_split_features(self, x, preprocess=False, center=None, preproc...
method nn_preprocess (line 120) | def nn_preprocess(self, data, center=None, preprocessing="l2n"):
method calc_pd (line 131) | def calc_pd(self, x, clf_idx):
method normalize (line 135) | def normalize(self, x, dim=1):
method get_d_feature (line 140) | def get_d_feature(self, x, x_ori):
method get_features (line 160) | def get_features(self, support, query, support_ori, query_ori):
FILE: MTL/models/IFSL_pretrain.py
class Pretrain (line 11) | class Pretrain():
method __init__ (line 12) | def __init__(self, dataset, method, model, init_model=True):
method simpleshot_init (line 22) | def simpleshot_init(self):
method classify (line 69) | def classify(self, x, normalize_prob=True):
method get_features (line 89) | def get_features(self, x):
method get_pretrained_class_mean (line 103) | def get_pretrained_class_mean(self, normalize=False):
method normalize (line 117) | def normalize(self, x):
method get_base_means (line 122) | def get_base_means(self, normalize=False):
FILE: MTL/models/ResNet10.py
class ResNet (line 7) | class ResNet(nn.Module):
method __init__ (line 8) | def __init__(self, block, conv, layers, num_classes=1000, zero_init_re...
method _make_layer (line 46) | def _make_layer(self, block, planes, blocks, stride=1):
method forward (line 63) | def forward(self, x, feature=False):
function ResNet10MTL (line 95) | def ResNet10MTL(**kwargs):
function ResNet10 (line 100) | def ResNet10(**kwargs):
FILE: MTL/models/WRN28.py
class wide_basic (line 6) | class wide_basic(nn.Module):
method __init__ (line 7) | def __init__(self, in_planes, planes, dropout_rate, stride=1):
method forward (line 21) | def forward(self, x):
class wide_basic_mtl (line 28) | class wide_basic_mtl(nn.Module):
method __init__ (line 29) | def __init__(self, in_planes, planes, dropout_rate, stride=1):
method forward (line 43) | def forward(self, x):
function conv3x3 (line 50) | def conv3x3(in_planes, out_planes, stride=1):
function conv3x3mtl (line 54) | def conv3x3mtl(in_planes, out_planes, stride=1):
class Wide_ResNet (line 58) | class Wide_ResNet(nn.Module):
method __init__ (line 59) | def __init__(self, block, conv3x3proto, depth, widen_factor, dropout_r...
method _wide_layer (line 87) | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
method forward (line 97) | def forward(self, x, feature=False):
function WideRes28 (line 117) | def WideRes28(num_classes=64, remove_linear=False):
function WideRes28Mtl (line 123) | def WideRes28Mtl(num_classes=64, remove_linear=False):
FILE: MTL/models/conv2d_mtl.py
class _ConvNdMtl (line 19) | class _ConvNdMtl(Module):
method __init__ (line 21) | def __init__(self, in_channels, out_channels, kernel_size, stride,
method reset_parameters (line 55) | def reset_parameters(self):
method extra_repr (line 66) | def extra_repr(self):
class Conv2dMtl (line 81) | class Conv2dMtl(_ConvNdMtl):
method __init__ (line 83) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
method forward (line 93) | def forward(self, inp):
FILE: MTL/models/mtl.py
class BaseLearner (line 24) | class BaseLearner(nn.Module):
method __init__ (line 26) | def __init__(self, args, z_dim):
method forward (line 37) | def forward(self, input_x, the_vars=None):
method parameters (line 45) | def parameters(self):
method initialize (line 48) | def initialize(self):
class MtlLearner (line 53) | class MtlLearner(nn.Module):
method __init__ (line 55) | def __init__(self, args, mode='meta', num_cls=64):
method load_pretrain_weight (line 85) | def load_pretrain_weight(self, model_dir):
method encode (line 101) | def encode(self, x):
method forward (line 105) | def forward(self, inp):
method pretrain_forward (line 123) | def pretrain_forward(self, inp):
method meta_forward (line 132) | def meta_forward(self, data_shot, label_shot, data_query, val=False):
method predict (line 174) | def predict(self, embedding_shot, label_shot, embedding_query):
method backward_loss_and_step (line 180) | def backward_loss_and_step(self, loss, optimizer=None):
method preval_forward (line 187) | def preval_forward(self, data_shot, label_shot, data_query):
FILE: MTL/models/resnet_mtl.py
function conv3x3 (line 15) | def conv3x3(in_planes, out_planes, stride=1):
class BasicBlock (line 19) | class BasicBlock(nn.Module):
method __init__ (line 22) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 32) | def forward(self, x):
class Bottleneck (line 50) | class Bottleneck(nn.Module):
method __init__ (line 53) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 66) | def forward(self, x):
function conv3x3mtl (line 88) | def conv3x3mtl(in_planes, out_planes, stride=1):
class BasicBlockMtl (line 93) | class BasicBlockMtl(nn.Module):
method __init__ (line 96) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 106) | def forward(self, x):
class BottleneckMtl (line 125) | class BottleneckMtl(nn.Module):
method __init__ (line 128) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 141) | def forward(self, x):
class ResNetMtl (line 164) | class ResNetMtl(nn.Module):
method __init__ (line 165) | def __init__(self, layers=[4, 4, 4], mtl=True):
method _make_layer (line 190) | def _make_layer(self, block, planes, blocks, stride=1):
method forward (line 206) | def forward(self, x):
FILE: MTL/run_meta.py
function run_exp (line 13) | def run_exp(num_batch=1000, shot=1, query=15, lr1=0.0001, lr2=0.001, bas...
FILE: MTL/run_pre.py
function run_exp (line 13) | def run_exp(lr=0.1, gamma=0.2, step_size=30):
FILE: MTL/run_pre_clfs.py
function run_exp (line 13) | def run_exp(lr=0.1, gamma=0.2, step_size=30, n_clfs=10):
FILE: MTL/run_test.py
function run_exp (line 13) | def run_exp(num_batch=1000, shot=1, query=15, lr1=0.0001, lr2=0.001, bas...
FILE: MTL/trainer/meta.py
class MetaTrainer (line 31) | class MetaTrainer(object):
method __init__ (line 33) | def __init__(self, args):
method write_output_message (line 123) | def write_output_message(self, message, file_name=None):
method save_model (line 131) | def save_model(self, name):
method train (line 138) | def train(self):
method eval (line 301) | def eval(self):
FILE: MTL/trainer/pre.py
class PreTrainer (line 23) | class PreTrainer(object):
method __init__ (line 25) | def __init__(self, args):
method save_model (line 70) | def save_model(self, name):
method train (line 77) | def train(self):
FILE: MTL/utils/gpu_tools.py
function set_gpu (line 14) | def set_gpu(cuda_device):
FILE: MTL/utils/hacc.py
class Hacc (line 4) | class Hacc():
method __init__ (line 5) | def __init__(self, splits=10, topk=10):
method add_data (line 11) | def add_data(self, hardness, correct_prediction):
method get_splits_hacc (line 15) | def get_splits_hacc(self):
method get_topk_hacc (line 31) | def get_topk_hacc(self):
method get_topk_hard_acc (line 42) | def get_topk_hard_acc(self):
method get_acc_in_range (line 50) | def get_acc_in_range(self, start, end):
method get_plot (line 60) | def get_plot(self, splits):
FILE: MTL/utils/misc.py
function ensure_path (line 22) | def ensure_path(path):
class Averager (line 32) | class Averager():
method __init__ (line 34) | def __init__(self):
method add (line 38) | def add(self, x):
method item (line 42) | def item(self):
function count_acc (line 45) | def count_acc(logits, label):
function normalize (line 58) | def normalize(x):
function count_dacc (line 63) | def count_dacc(pred_logits, support_labels, query_labels, support_imgs, ...
function get_hardness_correct (line 93) | def get_hardness_correct(pred_logits, support_labels, query_labels, supp...
class Timer (line 121) | class Timer():
method __init__ (line 123) | def __init__(self):
method measure (line 126) | def measure(self, p=1):
function pprint (line 137) | def pprint(x):
function format_time (line 141) | def format_time(seconds):
function compute_confidence_interval (line 174) | def compute_confidence_interval(data):
function progress_bar (line 197) | def progress_bar(current, total, msg=None):
FILE: SIB/PretrainedModel.py
class PretrainedModel (line 18) | class PretrainedModel():
method __init__ (line 19) | def __init__(self, params):
method simpleshot_init (line 40) | def simpleshot_init(self, params):
method sib_init (line 78) | def sib_init(self, params):
method get_features (line 84) | def get_features(self, x):
method classify (line 100) | def classify(self, x, normalize_prob=True):
method load_d_specific_classifiers (line 122) | def load_d_specific_classifiers(self, n_clf):
method load_classifier_weights (line 139) | def load_classifier_weights(self, n_clf, idx):
method train_d_specific_classifiers (line 148) | def train_d_specific_classifiers(self, n_clf):
method test_d_specific_classifiers (line 188) | def test_d_specific_classifiers(self, n_clf):
method save_pretrain_dataset (line 215) | def save_pretrain_dataset(self, split):
method get_pretrain_dataset (line 258) | def get_pretrain_dataset(self, split):
method normalize (line 269) | def normalize(self, x):
method _calc_pretrained_class_mean (line 274) | def _calc_pretrained_class_mean(self, normalize=False):
method get_kmeans_pca_model (line 305) | def get_kmeans_pca_model(self, k=8, n_clusters=10, normalize=False):
method get_pretrained_class_mean (line 357) | def get_pretrained_class_mean(self, normalize=False):
FILE: SIB/algorithm.py
class Params (line 26) | class Params:
method __init__ (line 27) | def __init__(self):
class Algorithm (line 30) | class Algorithm:
method __init__ (line 46) | def __init__(self, args, logger, netFeat, netSIB, optimizer, criterion...
method load_ckpt (line 84) | def load_ckpt(self, ckptPth):
method compute_grad_loss (line 104) | def compute_grad_loss(self, clsScore, QueryLabel):
method cosine_similarity (line 127) | def cosine_similarity(self, a, b):
method calc_diff_scores (line 130) | def calc_diff_scores(self, pretrain, support, query, support_labels, q...
method normalize (line 145) | def normalize(self, x):
method _evaluate_hardness_logodd (line 150) | def _evaluate_hardness_logodd(self, pretrain, support, query, support_...
method validate (line 181) | def validate(self, valLoader, lr=None, mode='val'):
method write_output_message (line 255) | def write_output_message(self, message):
method train (line 260) | def train(self, trainLoader, valLoader, lr=None, coeffGrad=0.0) :
FILE: SIB/backbone.py
function init_layer (line 18) | def init_layer(L):
class NNClassifier (line 28) | class NNClassifier():
method __init__ (line 29) | def __init__(self, n_way):
method normalize (line 32) | def normalize(self, x):
method preprocess (line 37) | def preprocess(self, data):
method dist (line 46) | def dist(self, x1, x2):
method kl_divergence (line 49) | def kl_divergence(self, k1, k2):
method fit (line 55) | def fit(self, support, support_labels, support_weights=None):
method predict (line 73) | def predict(self, query):
method predict_alt (line 85) | def predict_alt(self, query, measure="euclidean", norm_scores=False, t...
class MultiNNBiClassifier (line 114) | class MultiNNBiClassifier():
method __init__ (line 115) | def __init__(self, n_way, n_classifiers, measure="linear", fusion="lin...
method fit (line 124) | def fit(self, support_x, support_d, support_labels, support_weights=No...
method fuse_proba (line 133) | def fuse_proba(self, p1, p2):
method predict (line 145) | def predict(self, query_x, query_d, weights=None, counterfactual=False):
class MultiNNClassifier (line 167) | class MultiNNClassifier():
method __init__ (line 168) | def __init__(self, n_way, n_classifiers, measure="euclidean", temp=1.0):
method fit (line 178) | def fit(self, support, support_labels, support_weights=None):
method predict (line 189) | def predict(self, query, weights=None):
class BidrectionalLSTM (line 209) | class BidrectionalLSTM(nn.Module):
method __init__ (line 210) | def __init__(self, size: int, layers: int):
method forward (line 229) | def forward(self, inputs):
class AttentionLSTM (line 242) | class AttentionLSTM(nn.Module):
method __init__ (line 243) | def __init__(self, size: int, unrolling_steps: int):
method forward (line 258) | def forward(self, support, queries):
class MultiLinearClassifier (line 288) | class MultiLinearClassifier(nn.Module):
method __init__ (line 289) | def __init__(self, n_clf, feat_dim, n_way, sum_log=True, permute=False...
method create_clf (line 303) | def create_clf(self, loss_type, in_dim, out_dim):
method forward (line 309) | def forward(self, X):
class MultiBiLinearClassifier (line 330) | class MultiBiLinearClassifier(nn.Module):
method __init__ (line 331) | def __init__(self, n_clf, x_feat_dim, d_feat_dim, n_way, sum_log=True,...
method fuse_logits (line 343) | def fuse_logits(self, p1, p2):
method create_clf (line 355) | def create_clf(self, loss_type, in_dim, out_dim):
method forward (line 361) | def forward(self, X, D, counterfactual=False):
class ResNetKernelClusterAgent (line 384) | class ResNetKernelClusterAgent():
method __init__ (line 385) | def __init__(self, pretrain, n_clusters, pca_dim, cluster_method="kmea...
method fit (line 391) | def fit(self):
class ResNetParamClusterModel (line 407) | class ResNetParamClusterModel():
method __init__ (line 408) | def __init__(self, pretrain, n_clusters, cluster_method="kmeans"):
method cluster (line 413) | def cluster(self, features, n_clusters):
method get_weight_features (line 418) | def get_weight_features(self, weights):
method fit (line 424) | def fit(self):
method conv_forward (line 432) | def conv_forward(self, inputs, labels, n_clusters, original_conv):
method forward (line 452) | def forward(self, imgs):
method forward2 (line 491) | def forward2(self, imgs):
method forward3 (line 517) | def forward3(self, imgs):
class BasisTransformer (line 524) | class BasisTransformer():
method __init__ (line 525) | def __init__(self, pretrain, recluster=False, cluster_method="kmeans",...
method fit (line 532) | def fit(self, n_clusters, feat_dim, pca_dim=50):
method transform (line 573) | def transform(self, X):
class KernelTransformer (line 582) | class KernelTransformer():
method __init__ (line 583) | def __init__(self, feat_dim, kernel):
method fit (line 587) | def fit(self, features):
method transform (line 594) | def transform(self, X):
method kernel_f (line 602) | def kernel_f(self, x1, x2):
class ChannelwiseClassifier (line 613) | class ChannelwiseClassifier(nn.Module):
method __init__ (line 614) | def __init__(self, feat_dim, n_way, weight, bias=False):
method reset_parameters (line 626) | def reset_parameters(self):
method forward (line 635) | def forward(self, X):
class UnbiasedClassifier (line 660) | class UnbiasedClassifier(nn.Module):
method __init__ (line 673) | def __init__(self, n_way, x_feature_dim, z_feature_dim, d_feature_dim=0,
method create_clf (line 694) | def create_clf(self, feat_dim):
method create_logit_fusion_fn (line 701) | def create_logit_fusion_fn(self):
method get_fused_feature (line 710) | def get_fused_feature(self, feature_array):
method forward (line 714) | def forward(self, X, Z, D=None):
class XDBiClassifier (line 733) | class XDBiClassifier(nn.Module):
method __init__ (line 734) | def __init__(self, n_way, x_feature_dim, d_feature_dim, architecture="...
method create_clf (line 751) | def create_clf(self, feat_dim):
method create_logit_fusion_fn (line 758) | def create_logit_fusion_fn(self):
method cat_for_logit_fusion (line 769) | def cat_for_logit_fusion(self, A, B):
method forward (line 774) | def forward(self, X, D):
class XDClassifier (line 791) | class XDClassifier(nn.Module):
method __init__ (line 804) | def __init__(self, n_way, x_feature_dim, d_feature_dim, architecture="...
method get_feature_dim (line 826) | def get_feature_dim(self):
method create_clf (line 840) | def create_clf(self, feat_dim):
method get_fused_feature (line 847) | def get_fused_feature(self, feature_array):
method forward (line 855) | def forward(self, X, D):
class ProductGate (line 870) | class ProductGate(nn.Module):
method __init__ (line 871) | def __init__(self):
method forward (line 874) | def forward(self, x):
class HarmonicGate (line 882) | class HarmonicGate(nn.Module):
method __init__ (line 883) | def __init__(self):
method forward (line 886) | def forward(self, x):
class SumGate (line 895) | class SumGate(nn.Module):
method __init__ (line 896) | def __init__(self):
method forward (line 899) | def forward(self, x):
class distLinear (line 906) | class distLinear(nn.Module):
method __init__ (line 907) | def __init__(self, indim, outdim, class_wise_learnable_norm=True):
method forward (line 919) | def forward(self, x):
class Flatten (line 930) | class Flatten(nn.Module):
method __init__ (line 931) | def __init__(self):
method forward (line 934) | def forward(self, x):
class Linear_fw (line 938) | class Linear_fw(nn.Linear): #used in MAML to forward input with fast weight
method __init__ (line 939) | def __init__(self, in_features, out_features):
method forward (line 944) | def forward(self, x):
class Conv2d_fw (line 951) | class Conv2d_fw(nn.Conv2d): #used in MAML to forward input with fast weight
method __init__ (line 952) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,pa...
method forward (line 958) | def forward(self, x):
class BatchNorm2d_fw (line 972) | class BatchNorm2d_fw(nn.BatchNorm2d): #used in MAML to forward input wit...
method __init__ (line 973) | def __init__(self, num_features):
method forward (line 978) | def forward(self, x):
class ConvBlock (line 989) | class ConvBlock(nn.Module):
method __init__ (line 991) | def __init__(self, indim, outdim, pool = True, padding = 1):
method forward (line 1014) | def forward(self,x):
class SimpleBlock (line 1019) | class SimpleBlock(nn.Module):
method __init__ (line 1021) | def __init__(self, indim, outdim, half_res):
method forward (line 1060) | def forward(self, x):
class BottleneckBlock (line 1074) | class BottleneckBlock(nn.Module):
method __init__ (line 1076) | def __init__(self, indim, outdim, half_res):
method forward (line 1117) | def forward(self, x):
class ConvNet (line 1134) | class ConvNet(nn.Module):
method __init__ (line 1135) | def __init__(self, depth, flatten = True):
method forward (line 1150) | def forward(self,x):
class ConvNetNopool (line 1154) | class ConvNetNopool(nn.Module): #Relation net use a 4 layer conv with po...
method __init__ (line 1155) | def __init__(self, depth):
method forward (line 1167) | def forward(self,x):
class ConvNetS (line 1171) | class ConvNetS(nn.Module): #For omniglot, only 1 input channel, output d...
method __init__ (line 1172) | def __init__(self, depth, flatten = True):
method forward (line 1187) | def forward(self,x):
class ConvNetSNopool (line 1192) | class ConvNetSNopool(nn.Module): #Relation net use a 4 layer conv with p...
method __init__ (line 1193) | def __init__(self, depth):
method forward (line 1205) | def forward(self,x):
class ResNet (line 1210) | class ResNet(nn.Module):
method __init__ (line 1212) | def __init__(self,block,list_of_num_layers, list_of_out_dims, flatten ...
method forward (line 1254) | def forward(self,x):
function Conv4 (line 1258) | def Conv4():
function Conv6 (line 1261) | def Conv6():
function Conv4NP (line 1264) | def Conv4NP():
function Conv6NP (line 1267) | def Conv6NP():
function Conv4S (line 1270) | def Conv4S():
function Conv4SNP (line 1273) | def Conv4SNP():
function ResNet10 (line 1276) | def ResNet10( flatten = True):
function ResNet18 (line 1279) | def ResNet18( flatten = True):
function ResNet34 (line 1282) | def ResNet34( flatten = True):
function ResNet50 (line 1285) | def ResNet50( flatten = True):
function ResNet101 (line 1288) | def ResNet101( flatten = True):
FILE: SIB/data/additional_transforms.py
class ImageJitter (line 15) | class ImageJitter(object):
method __init__ (line 16) | def __init__(self, transformdict):
method __call__ (line 20) | def __call__(self, img):
FILE: SIB/data/datamgr.py
class TransformLoader (line 12) | class TransformLoader:
method __init__ (line 13) | def __init__(self, image_size,
method parse_transform (line 20) | def parse_transform(self, transform_type):
method get_composed_transform (line 36) | def get_composed_transform(self, aug=False):
class DataManager (line 47) | class DataManager:
method get_data_loader (line 49) | def get_data_loader(self, data_file, aug):
class SimpleDataManager (line 53) | class SimpleDataManager(DataManager):
method __init__ (line 54) | def __init__(self, image_size, batch_size):
method get_data_loader (line 59) | def get_data_loader(self, data_file, aug, num_workers=12, tiered_mini=...
class SetDataManager (line 70) | class SetDataManager(DataManager):
method __init__ (line 71) | def __init__(self, image_size, n_way, n_support, n_query, n_eposide=100):
method get_data_loader (line 80) | def get_data_loader(self, data_file, aug, debug=False): # parameters ...
FILE: SIB/data/dataset.py
class SimpleDataset (line 14) | class SimpleDataset:
method __init__ (line 15) | def __init__(self, data_file, transform, target_transform=identity):
method __getitem__ (line 21) | def __getitem__(self, i):
method __len__ (line 28) | def __len__(self):
class SimpleTieredDataset (line 32) | class SimpleTieredDataset:
method __init__ (line 33) | def __init__(self, setname, transform):
method __len__ (line 66) | def __len__(self):
method __getitem__ (line 69) | def __getitem__(self, i):
class SetDataset (line 75) | class SetDataset:
method __init__ (line 76) | def __init__(self, data_file, batch_size, transform):
method __getitem__ (line 98) | def __getitem__(self, i):
method __len__ (line 101) | def __len__(self):
class SubDataset (line 105) | class SubDataset:
method __init__ (line 106) | def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), targ...
method __getitem__ (line 112) | def __getitem__(self, i):
method __len__ (line 120) | def __len__(self):
class EpisodicBatchSampler (line 124) | class EpisodicBatchSampler(object):
method __init__ (line 125) | def __init__(self, n_classes, n_way, n_episodes):
method __len__ (line 130) | def __len__(self):
method __iter__ (line 133) | def __iter__(self):
FILE: SIB/data/feature_loader.py
class SimpleHDF5Dataset (line 5) | class SimpleHDF5Dataset:
method __init__ (line 6) | def __init__(self, file_handle = None):
method __getitem__ (line 18) | def __getitem__(self, i):
method __len__ (line 21) | def __len__(self):
function init_loader (line 24) | def init_loader(filename, get_path=False, path_file=None):
FILE: SIB/data/get_cifarfs.py
function download_file (line 24) | def download_file(url, filename):
FILE: SIB/dataloader.py
function PilLoaderRGB (line 26) | def PilLoaderRGB(imgPath) :
class EpisodeSampler (line 30) | class EpisodeSampler():
method __init__ (line 44) | def __init__(self, imgDir, nClsEpisode, nSupport, nQuery, transform, u...
method getEpisode (line 61) | def getEpisode(self):
class BatchSampler (line 107) | class BatchSampler():
method __init__ (line 122) | def __init__(self, imgDir, nClsEpisode, nSupport, nQuery, transform, u...
method getBatch (line 136) | def getBatch(self):
class ValImageFolder (line 159) | class ValImageFolder(data.Dataset):
method __init__ (line 170) | def __init__(self, episodeJson, imgDir, inputW, inputH, valTransform, ...
method __getitem__ (line 194) | def __getitem__(self, index):
method __len__ (line 221) | def __len__(self):
function ValLoader (line 228) | def ValLoader(episodeJson, imgDir, inputW, inputH, valTransform, useGPU) :
function TrainLoader (line 235) | def TrainLoader(batchSize, imgDir, trainTransform) :
FILE: SIB/dataset.py
function dataset_setting (line 19) | def dataset_setting(dataset, nSupport, image_size=80):
FILE: SIB/deconfound/DSIB.py
class DeconfoundedSIB (line 13) | class DeconfoundedSIB(nn.Module):
method __init__ (line 14) | def __init__(self, n_way, pretrain, n_splits, is_cosine_feature, d_fea...
method get_feat_dim (line 62) | def get_feat_dim(self):
method fuse_features (line 72) | def fuse_features(self, x1, x2):
method normalize (line 80) | def normalize(self, x, dim=1):
method fuse_proba (line 85) | def fuse_proba(self, p1, p2):
method forward (line 97) | def forward(self, support, labels, query, _):
FILE: SIB/deconfound/meta_toolkits.py
class FeatureProcessor (line 9) | class FeatureProcessor():
method __init__ (line 10) | def __init__(self, pretrain, n_splits, is_cosine_feature=False, d_feat...
method get_split_features (line 36) | def get_split_features(self, x, preprocess=False, center=None, preproc...
method nn_preprocess (line 54) | def nn_preprocess(self, data, center=None, preprocessing="l2n"):
method calc_pd (line 65) | def calc_pd(self, x, clf_idx):
method normalize (line 70) | def normalize(self, x, dim=1):
method get_d_feature (line 75) | def get_d_feature(self, x):
method get_features (line 93) | def get_features(self, support, query):
FILE: SIB/io_utils.py
function parse_args (line 20) | def parse_args(script):
function get_assigned_file (line 53) | def get_assigned_file(checkpoint_dir,num):
function get_resume_file (line 57) | def get_resume_file(checkpoint_dir):
function get_best_file (line 68) | def get_best_file(checkpoint_dir):
function print_accuracy (line 75) | def print_accuracy(acc):
function print_with_carriage_return (line 79) | def print_with_carriage_return(line):
function end_carriage_return_print (line 83) | def end_carriage_return_print():
function append_to_file (line 86) | def append_to_file(file, line):
function get_result_file (line 91) | def get_result_file(test_name, method_name):
function calc_recall_precision (line 101) | def calc_recall_precision(y, pred):
FILE: SIB/main.py
class Params (line 35) | class Params:
method __init__ (line 36) | def __init__(self):
FILE: SIB/main_feat.py
class ClassifierEval (line 42) | class ClassifierEval(nn.Module):
method __init__ (line 47) | def __init__(self, nKnovel, nFeat):
method apply_classification_weights (line 57) | def apply_classification_weights(self, features, cls_weights):
method forward (line 67) | def forward(self, features_supp, features_query):
class ClassifierTrain (line 80) | class ClassifierTrain(nn.Module):
method __init__ (line 81) | def __init__(self, nCls, nFeat=640, scaleCls = 10.):
method getWeight (line 101) | def getWeight(self):
method applyWeightCosine (line 104) | def applyWeightCosine(self, feature, weight, bias, scaleCls):
method forward (line 113) | def forward(self, feature):
class BaseTrainer (line 119) | class BaseTrainer:
method __init__ (line 120) | def __init__(self, trainLoader, valLoader, nbCls, nClsEpisode, nFeat,
method LrWarmUp (line 145) | def LrWarmUp(self, totalIter, lr):
method train (line 209) | def train(self, epoch):
method test (line 242) | def test(self, epoch):
FILE: SIB/networks.py
class ConvBlock (line 21) | class ConvBlock(nn.Module):
method __init__ (line 22) | def __init__(self, in_planes, out_planes):
method forward (line 34) | def forward(self, x):
class ConvNet_4_64 (line 38) | class ConvNet_4_64(nn.Module):
method __init__ (line 39) | def __init__(self, inputW=80, inputH=80):
method forward (line 58) | def forward(self, x):
class BasicBlock (line 64) | class BasicBlock(nn.Module):
method __init__ (line 65) | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
method forward (line 82) | def forward(self, x):
class NetworkBlock (line 101) | class NetworkBlock(nn.Module):
method __init__ (line 102) | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dr...
method _make_layer (line 106) | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride,...
method forward (line 114) | def forward(self, x):
class WideResNet (line 118) | class WideResNet(nn.Module):
method __init__ (line 119) | def __init__(self, depth=28, widen_factor=10, dropRate=0.0, userelu=Tr...
method forward (line 150) | def forward(self, x):
function label_to_1hot (line 166) | def label_to_1hot(label, K):
class dni_linear (line 174) | class dni_linear(nn.Module):
method __init__ (line 175) | def __init__(self, input_dims, dni_hidden_size=1024):
method forward (line 192) | def forward(self, x):
class LinearDiag (line 199) | class LinearDiag(nn.Module):
method __init__ (line 200) | def __init__(self, num_features, bias=False):
method forward (line 211) | def forward(self, X):
class FeatExemplarAvgBlock (line 219) | class FeatExemplarAvgBlock(nn.Module):
method __init__ (line 220) | def __init__(self, nFeat):
method forward (line 223) | def forward(self, features_train, labels_train):
function get_featnet (line 232) | def get_featnet(architecture, inputW=80, inputH=80):
FILE: SIB/sib.py
class ClassifierSIB (line 21) | class ClassifierSIB(nn.Module):
method __init__ (line 36) | def __init__(self, nKnovel, nFeat, q_steps):
method apply_classification_weights (line 54) | def apply_classification_weights(self, features, cls_weights):
method init_theta (line 73) | def init_theta(self, features_supp, labels_supp_1hot):
method refine_theta (line 89) | def refine_theta(self, theta, features_query, lr=1e-3):
method get_classification_weights (line 118) | def get_classification_weights(self, features_supp, labels_supp_1hot, ...
method forward (line 142) | def forward(self, features_supp, labels_supp, features_query, lr):
FILE: SIB/simple_shot_models/Conv4.py
function conv_block (line 6) | def conv_block(in_channels: int, out_channels: int) -> nn.Module:
class Conv4 (line 15) | class Conv4(nn.Module):
method __init__ (line 16) | def __init__(self, num_classes, remove_linear=False):
method forward (line 27) | def forward(self, x, feature=False):
FILE: SIB/simple_shot_models/DenseNet.py
class _DenseLayer (line 10) | class _DenseLayer(nn.Sequential):
method __init__ (line 11) | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
method forward (line 23) | def forward(self, x):
class _DenseBlock (line 30) | class _DenseBlock(nn.Sequential):
method __init__ (line 31) | def __init__(self, num_layers, num_input_features, bn_size, growth_rat...
class _Transition (line 38) | class _Transition(nn.Sequential):
method __init__ (line 39) | def __init__(self, num_input_features, num_output_features):
class DenseNet (line 48) | class DenseNet(nn.Module):
method __init__ (line 62) | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
method forward (line 103) | def forward(self, x, feature=False):
function densenet121 (line 120) | def densenet121(**kwargs):
function densenet169 (line 129) | def densenet169(**kwargs):
function densenet201 (line 138) | def densenet201(**kwargs):
function densenet161 (line 147) | def densenet161(**kwargs):
FILE: SIB/simple_shot_models/MobileNet.py
class Block (line 12) | class Block(nn.Module):
method __init__ (line 15) | def __init__(self, in_planes, out_planes, stride=1):
method forward (line 23) | def forward(self, x):
class MobileNet (line 29) | class MobileNet(nn.Module):
method __init__ (line 33) | def __init__(self, num_classes=10, remove_linear=False):
method _make_layers (line 44) | def _make_layers(self, in_planes):
method forward (line 53) | def forward(self, x, feature=False):
FILE: SIB/simple_shot_models/ProtoNet.py
function get_metric (line 6) | def get_metric(metric_type):
class ProtoNet (line 16) | class ProtoNet(nn.Module):
method __init__ (line 18) | def __init__(self, feature_net, args=None):
method forward (line 28) | def forward(self, data, _=False):
FILE: SIB/simple_shot_models/ResNet.py
function conv3x3 (line 7) | def conv3x3(in_planes, out_planes, stride=1):
function conv1x1 (line 13) | def conv1x1(in_planes, out_planes, stride=1):
class BasicBlock (line 18) | class BasicBlock(nn.Module):
method __init__ (line 21) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 31) | def forward(self, x):
class Bottleneck (line 50) | class Bottleneck(nn.Module):
method __init__ (line 53) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 65) | def forward(self, x):
class ResNet (line 88) | class ResNet(nn.Module):
method __init__ (line 90) | def __init__(self, block, layers, num_classes=1000, zero_init_residual...
method _make_layer (line 125) | def _make_layer(self, block, planes, blocks, stride=1):
method forward (line 141) | def forward(self, x, feature=False):
function resnet10 (line 165) | def resnet10(**kwargs):
function resnet18 (line 172) | def resnet18(**kwargs):
function resnet34 (line 179) | def resnet34(**kwargs):
function resnet50 (line 186) | def resnet50(**kwargs):
function resnet101 (line 193) | def resnet101(**kwargs):
function resnet152 (line 200) | def resnet152(**kwargs):
FILE: SIB/simple_shot_models/WideResNet.py
function conv3x3 (line 9) | def conv3x3(in_planes, out_planes, stride=1):
function conv_init (line 13) | def conv_init(m):
class wide_basic (line 23) | class wide_basic(nn.Module):
method __init__ (line 24) | def __init__(self, in_planes, planes, dropout_rate, stride=1):
method forward (line 38) | def forward(self, x):
class Wide_ResNet (line 46) | class Wide_ResNet(nn.Module):
method __init__ (line 47) | def __init__(self, depth, widen_factor, dropout_rate, num_classes, rem...
method _wide_layer (line 75) | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
method forward (line 85) | def forward(self, x, feature=False):
function wideres (line 105) | def wideres(num_classes=64, remove_linear=False):
FILE: SIB/utils/config.py
function create_dirs (line 21) | def create_dirs(dirs):
function get_config_from_json (line 38) | def get_config_from_json(json_file):
function get_config_from_yaml (line 54) | def get_config_from_yaml(yaml_file):
function get_args (line 69) | def get_args():
function get_config (line 104) | def get_config():
FILE: SIB/utils/outils.py
class AverageMeter (line 30) | class AverageMeter(object):
method __init__ (line 32) | def __init__(self):
method reset (line 35) | def reset(self):
method update (line 41) | def update(self, val, n=1):
function getCi (line 48) | def getCi(accLog):
function accuracy (line 57) | def accuracy(output, target, topk=(1,), diff_scores=None):
function get_mean_and_std (line 79) | def get_mean_and_std(dataset):
function init_params (line 94) | def init_params(net):
function progress_bar (line 118) | def progress_bar(current, total, msg=None):
function format_time (line 162) | def format_time(seconds):
FILE: SIB/utils/utils.py
function set_random_seed (line 26) | def set_random_seed(seed=3):
function to_device (line 33) | def to_device(input, device):
function fast_hist (line 46) | def fast_hist(label_pred, label_true, n_class):
function convert_state_dict (line 53) | def convert_state_dict(state_dict):
function get_logger (line 67) | def get_logger(logdir, name):
Copy disabled (too large)
Download .json
Condensed preview — 195 files, each showing path, character count, and a content snippet. Download the .json file for the full structured content (15,725K chars).
[
{
"path": "LEO/LICENSE",
"chars": 11360,
"preview": "\n Apache License\n Version 2.0, January 2004\n "
},
{
"path": "LEO/config.py",
"chars": 11408,
"preview": "# coding=utf8\n# Copyright 2018 DeepMind Technologies Limited\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
},
{
"path": "LEO/data.py",
"chars": 13483,
"preview": "# Copyright 2018 DeepMind Technologies Limited\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you"
},
{
"path": "LEO/ifsl_configs/__init__.py",
"chars": 57,
"preview": "from .baseline_config import *\nfrom .ifsl_config import *"
},
{
"path": "LEO/ifsl_configs/baseline_config.py",
"chars": 2536,
"preview": "class Config():\n def __init__(self):\n self.is_config = True\n\n\ndef mini_5_resnet_baseline():\n config = Confi"
},
{
"path": "LEO/ifsl_configs/ifsl_config.py",
"chars": 5977,
"preview": "class Config():\n def __init__(self):\n self.is_config = True\n\ndef mini_5_resnet_ifsl():\n config = Config()\n "
},
{
"path": "LEO/model.py",
"chars": 28576,
"preview": "# Copyright 2018 DeepMind Technologies Limited\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you"
},
{
"path": "LEO/model_test.py",
"chars": 12157,
"preview": "# Copyright 2018 DeepMind Technologies Limited\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you"
},
{
"path": "LEO/readme.md",
"chars": 1549,
"preview": "# LEO + IFSL\n\nThis project is based on the official code base of the paper [Meta-Learning with Latent Embedding Optimiza"
},
{
"path": "LEO/runner.py",
"chars": 14062,
"preview": "# Copyright 2018 DeepMind Technologies Limited\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you"
},
{
"path": "LEO/utils.py",
"chars": 4639,
"preview": "# Copyright 2018 DeepMind Technologies Limited\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you"
},
{
"path": "MAML_MN_FT/README.md",
"chars": 1656,
"preview": "# IFSL + Matching Networks, MAML\r\n\r\nThis project is based on the official code base of the paper [A Closer Look At Few-S"
},
{
"path": "MAML_MN_FT/backbone.py",
"chars": 50613,
"preview": "# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate\n\nimport torch\nfrom torch.au"
},
{
"path": "MAML_MN_FT/configs.py",
"chars": 563,
"preview": "save_dir = '/data2/yuezhongqi/Model/CloserLookFSL/' # Change to desired saving dir\ndata_dir = {}"
},
{
"path": "MAML_MN_FT/data/__init__.py",
"chars": 109,
"preview": "from . import datamgr\nfrom . import dataset\nfrom . import additional_transforms\nfrom . import feature_loader\n"
},
{
"path": "MAML_MN_FT/data/additional_transforms.py",
"chars": 850,
"preview": "# Copyright 2017-present, Facebook, Inc.\n# All rights reserved.\n#\n# This source code is licensed under the license found"
},
{
"path": "MAML_MN_FT/data/datamgr.py",
"chars": 3717,
"preview": "# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate\n\nimport torch\nfrom PIL impo"
},
{
"path": "MAML_MN_FT/data/dataset.py",
"chars": 4491,
"preview": "# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate\n\nimport torch\nfrom PIL impo"
},
{
"path": "MAML_MN_FT/data/feature_loader.py",
"chars": 1587,
"preview": "import torch\nimport numpy as np\nimport h5py\n\nclass SimpleHDF5Dataset:\n def __init__(self, file_handle = None):\n "
},
{
"path": "MAML_MN_FT/filelists/CUB/attributes.txt",
"chars": 8992,
"preview": "1 has_bill_shape::curved_(up_or_down)\n2 has_bill_shape::dagger\n3 has_bill_shape::hooked\n4 has_bill_shape::needle\n5 has_b"
},
{
"path": "MAML_MN_FT/filelists/CUB/base.json",
"chars": 792863,
"preview": "{\"label_names\": [\"001.Black_footed_Albatross\",\"002.Laysan_Albatross\",\"003.Sooty_Albatross\",\"004.Groove_billed_Ani\",\"005."
},
{
"path": "MAML_MN_FT/filelists/CUB/download_CUB.sh",
"chars": 156,
"preview": "#!/usr/bin/env bash\nwget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz\ntar -zxvf CUB_200_20"
},
{
"path": "MAML_MN_FT/filelists/CUB/novel.json",
"chars": 405825,
"preview": "{\"label_names\": [\"001.Black_footed_Albatross\",\"002.Laysan_Albatross\",\"003.Sooty_Albatross\",\"004.Groove_billed_Ani\",\"005."
},
{
"path": "MAML_MN_FT/filelists/CUB/val.json",
"chars": 397941,
"preview": "{\"label_names\": [\"001.Black_footed_Albatross\",\"002.Laysan_Albatross\",\"003.Sooty_Albatross\",\"004.Groove_billed_Ani\",\"005."
},
{
"path": "MAML_MN_FT/filelists/CUB/write_CUB_filelist.py",
"chars": 2096,
"preview": "import numpy as np\nfrom os import listdir\nfrom os.path import isfile, isdir, join\nimport os\nimport json\nimport random\n\nc"
},
{
"path": "MAML_MN_FT/filelists/miniImagenet/all.json",
"chars": 5537137,
"preview": "{\"label_names\": [\"n01532829\",\"n01558993\",\"n01704323\",\"n01749939\",\"n01770081\",\"n01843383\",\"n01910747\",\"n02074367\",\"n02089"
},
{
"path": "MAML_MN_FT/filelists/miniImagenet/base.json",
"chars": 3681221,
"preview": "{\"label_names\": [\"n02687172\",\"n04258138\",\"n03347037\",\"n03888605\",\"n03062245\",\"n01910747\",\"n02120079\",\"n02457408\",\"n04251"
},
{
"path": "MAML_MN_FT/filelists/miniImagenet/download_miniImagenet.sh",
"chars": 497,
"preview": "#!/usr/bin/env bash\nwget https://raw.githubusercontent.com/twitter/meta-learning-lstm/master/data/miniImagenet/train.csv"
},
{
"path": "MAML_MN_FT/filelists/miniImagenet/novel.json",
"chars": 1062293,
"preview": "{\"label_names\": [\"n04149813\",\"n01981276\",\"n02219486\",\"n02443484\",\"n02871525\",\"n02110063\",\"n04418357\",\"n03146219\",\"n02129"
},
{
"path": "MAML_MN_FT/filelists/miniImagenet/test.csv",
"chars": 384015,
"preview": "filename,label\nn0193011200000001.jpg,n01930112\nn0193011200000004.jpg,n01930112\nn0193011200000005.jpg,n01930112\nn01930112"
},
{
"path": "MAML_MN_FT/filelists/miniImagenet/train.csv",
"chars": 1228815,
"preview": "filename,label\nn0153282900000005.jpg,n01532829\nn0153282900000006.jpg,n01532829\nn0153282900000007.jpg,n01532829\nn01532829"
},
{
"path": "MAML_MN_FT/filelists/miniImagenet/val.csv",
"chars": 307215,
"preview": "filename,label\nn0185567200000003.jpg,n01855672\nn0185567200000004.jpg,n01855672\nn0185567200000010.jpg,n01855672\nn01855672"
},
{
"path": "MAML_MN_FT/filelists/miniImagenet/val.json",
"chars": 839045,
"preview": "{\"label_names\": [\"n03535780\",\"n02114548\",\"n02971356\",\"n02138441\",\"n03980874\",\"n09256479\",\"n02981792\",\"n03417042\",\"n03075"
},
{
"path": "MAML_MN_FT/filelists/miniImagenet/write_cross_filelist.py",
"chars": 2485,
"preview": "import numpy as np\nfrom os import listdir\nfrom os.path import isfile, isdir, join\nimport os\nimport json\nimport random\nim"
},
{
"path": "MAML_MN_FT/filelists/miniImagenet/write_miniImagenet_filelist.py",
"chars": 2347,
"preview": "import numpy as np\nfrom os import listdir\nfrom os.path import isfile, isdir, join\nimport os\nimport json\nimport random\nim"
},
{
"path": "MAML_MN_FT/filelists/tiered/write_tiered_filelist.py",
"chars": 1522,
"preview": "import numpy as np\nfrom os import listdir\nfrom os.path import isfile, isdir, join\nimport os\nimport json\nimport random\nim"
},
{
"path": "MAML_MN_FT/io_utils.py",
"chars": 6166,
"preview": "import numpy as np\nimport os\nimport glob\nimport argparse\nimport backbone\nimport sys\nimport configs\nfrom datetime import "
},
{
"path": "MAML_MN_FT/main.py",
"chars": 379,
"preview": "from io_utils import parse_args\r\nfrom tests.MetaTrain import MetaTrain\r\n\r\ndef func_not_found(): # just in case we dont "
},
{
"path": "MAML_MN_FT/methods/DMAML.py",
"chars": 9001,
"preview": "# This code is modified from https://github.com/dragen1860/MAML-Pytorch and https://github.com/katerakelly/pytorch-maml "
},
{
"path": "MAML_MN_FT/methods/DMatchingNet.py",
"chars": 8245,
"preview": "# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate\nimport torch\nimport torch.n"
},
{
"path": "MAML_MN_FT/methods/MethodTester.py",
"chars": 30075,
"preview": "import os\nimport data.feature_loader as feat_loader\nimport numpy as np\nimport random\nimport torch\nimport configs\nfrom io"
},
{
"path": "MAML_MN_FT/methods/NNEDSplitNew.py",
"chars": 10586,
"preview": "import backbone\nimport torch\nfrom torch.autograd import Variable\nimport numpy as np\nfrom methods.meta_template import Me"
},
{
"path": "MAML_MN_FT/methods/PretrainedModel.py",
"chars": 19990,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom io_utils import model_dict, get_best_file, get_assigned_file,"
},
{
"path": "MAML_MN_FT/methods/VanillaMAML.py",
"chars": 3432,
"preview": "# This code is modified from https://github.com/dragen1860/MAML-Pytorch and https://github.com/katerakelly/pytorch-maml "
},
{
"path": "MAML_MN_FT/methods/VanillaMatchingNet.py",
"chars": 1829,
"preview": "# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate\nimport torch\nimport torch.n"
},
{
"path": "MAML_MN_FT/methods/__init__.py",
"chars": 163,
"preview": "from . import meta_template\r\nfrom . import MethodTester\r\nfrom . import PretrainedModel\r\nfrom . import DMAML\r\nfrom . impo"
},
{
"path": "MAML_MN_FT/methods/meta_template.py",
"chars": 8859,
"preview": "import backbone\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\nfrom data.datamgr import Transfor"
},
{
"path": "MAML_MN_FT/methods/meta_toolkits.py",
"chars": 16719,
"preview": "import torch\nimport torch.nn as nn\nimport backbone\nfrom torch.autograd import Variable\nimport numpy as np\nimport math\n\n\n"
},
{
"path": "MAML_MN_FT/models/FeatWRN.py",
"chars": 3321,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.init as init\nimport torch.nn.functional as F\nfrom torch.autograd impo"
},
{
"path": "MAML_MN_FT/models/SimpleShotResNet.py",
"chars": 6037,
"preview": "import torch.nn as nn\n\n__all__ = ['resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',\n 'resnet152']\n\n"
},
{
"path": "MAML_MN_FT/models/SimpleShotWideResNet.py",
"chars": 3941,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.nn.init as init\nfrom "
},
{
"path": "MAML_MN_FT/models/__init__.py",
"chars": 89,
"preview": "from . import FeatWRN\r\nfrom . import SimpleShotResNet\r\nfrom . import SimpleShotWideResNet"
},
{
"path": "MAML_MN_FT/save_features.py",
"chars": 13348,
"preview": "import numpy as np\nimport torch\nfrom torch.autograd import Variable\nimport os\nimport glob\nimport h5py\n\nimport configs\nim"
},
{
"path": "MAML_MN_FT/tests/MetaTrain.py",
"chars": 30083,
"preview": "from methods.MethodTester import MethodTester\nfrom methods.VanillaMAML import VanillaMAML\nfrom methods.DMAML import DMAM"
},
{
"path": "MAML_MN_FT/tests/__init__.py",
"chars": 23,
"preview": "from . import MetaTrain"
},
{
"path": "MAML_MN_FT/utils.py",
"chars": 1052,
"preview": "import torch\nimport numpy as np\n\ndef one_hot(y, num_class): \n return torch.zeros((len(y), num_class)).scatter"
},
{
"path": "MTL/README.md",
"chars": 1588,
"preview": "# MTL + IFSL\n\nThis project is based on the official code base of the paper [Meta Transfer Learning for Few-Shot Learning"
},
{
"path": "MTL/configs/__init__.py",
"chars": 95,
"preview": "from .baseline_config import *\nfrom .ifsl_resnet_config import *\nfrom .ifsl_wrn_config import *"
},
{
"path": "MTL/configs/baseline_config.py",
"chars": 3654,
"preview": "class Params():\n def __init__(self):\n self.is_param = True\n\n# python main.py --config=mini_5_resnet_baseline -"
},
{
"path": "MTL/configs/ifsl_resnet_config.py",
"chars": 4405,
"preview": "class Params():\n def __init__(self):\n self.is_param = True\n\n# python main.py --config=mini_5_resnet_d --gpu=\nd"
},
{
"path": "MTL/configs/ifsl_wrn_config.py",
"chars": 4389,
"preview": "class Params():\n def __init__(self):\n self.is_param = True\n\n# python main.py --config=mini_5_wrn_d --gpu=\ndef "
},
{
"path": "MTL/dataloader/__init__.py",
"chars": 389,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\r\n## Created by: Yaoyao Liu\r\n## Tianjin Unive"
},
{
"path": "MTL/dataloader/dataset_loader.py",
"chars": 4161,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Modified from: "
},
{
"path": "MTL/dataloader/samplers.py",
"chars": 1381,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Modified from: "
},
{
"path": "MTL/main.py",
"chars": 6020,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Tianjin Univers"
},
{
"path": "MTL/models/IFSL.py",
"chars": 25829,
"preview": "from models.resnet_mtl import ResNetMtl\nimport torch\nimport torch.nn as nn\nimport os.path as osp\nimport tqdm\nfrom datalo"
},
{
"path": "MTL/models/IFSL_modules.py",
"chars": 7921,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\n\n\nclass Linear_fw"
},
{
"path": "MTL/models/IFSL_pretrain.py",
"chars": 6150,
"preview": "import os\nfrom models.ResNet10 import ResNet10\nfrom models.WRN28 import WideRes28\nimport torch\nimport numpy as np\nimport"
},
{
"path": "MTL/models/ResNet10.py",
"chars": 3566,
"preview": "import torch.nn as nn\nfrom models.conv2d_mtl import Conv2dMtl\nfrom models.resnet_mtl import BasicBlockMtl\nfrom models.re"
},
{
"path": "MTL/models/WRN28.py",
"chars": 4796,
"preview": "import torch.nn as nn\nfrom models.conv2d_mtl import Conv2dMtl\nimport torch.nn.functional as F\n\n\nclass wide_basic(nn.Modu"
},
{
"path": "MTL/models/__init__.py",
"chars": 553,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\r\n## Created by: Yaoyao Liu\r\n## Tianjin Unive"
},
{
"path": "MTL/models/conv2d_mtl.py",
"chars": 4195,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Modified from: "
},
{
"path": "MTL/models/mtl.py",
"chars": 8728,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Tianjin Univers"
},
{
"path": "MTL/models/resnet_mtl.py",
"chars": 6841,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Modified from: "
},
{
"path": "MTL/run_meta.py",
"chars": 1732,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Tianjin Univers"
},
{
"path": "MTL/run_pre.py",
"chars": 1079,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Tianjin Univers"
},
{
"path": "MTL/run_pre_clfs.py",
"chars": 1222,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Tianjin Univers"
},
{
"path": "MTL/run_test.py",
"chars": 1577,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Tianjin Univers"
},
{
"path": "MTL/setup.cfg",
"chars": 118,
"preview": "[pep8]\nignore = E202\nmax-line-length = 160\n[flake8]\nignore = E202, E712, W293, W391, W292, E266\nmax-line-length = 160\n"
},
{
"path": "MTL/trainer/__init__.py",
"chars": 426,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\r\n## Created by: Yaoyao Liu\r\n## Tianjin Unive"
},
{
"path": "MTL/trainer/meta.py",
"chars": 17629,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Modified from: "
},
{
"path": "MTL/trainer/pre.py",
"chars": 9314,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Tianjin Univers"
},
{
"path": "MTL/utils/__init__.py",
"chars": 432,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\r\n## Created by: Yaoyao Liu\r\n## Tianjin Unive"
},
{
"path": "MTL/utils/gpu_tools.py",
"chars": 547,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Tianjin Univers"
},
{
"path": "MTL/utils/hacc.py",
"chars": 2907,
"preview": "import numpy as np\n\n\nclass Hacc():\n def __init__(self, splits=10, topk=10):\n self.hardness = []\n self.c"
},
{
"path": "MTL/utils/misc.py",
"chars": 7224,
"preview": "##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n## Created by: Yaoyao Liu\n## Modified from: "
},
{
"path": "SIB/PretrainedModel.py",
"chars": 16171,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom io_utils import get_best_file, get_assigned_file, print_with_"
},
{
"path": "SIB/algorithm.py",
"chars": 14944,
"preview": "# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Versi"
},
{
"path": "SIB/backbone.py",
"chars": 50613,
"preview": "# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate\n\nimport torch\nfrom torch.au"
},
{
"path": "SIB/config/minires_1_baseline.yaml",
"chars": 1398,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 1 # number of samples per category in"
},
{
"path": "SIB/config/minires_1_ifsl.yaml",
"chars": 1393,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 1 # number of samples per category in"
},
{
"path": "SIB/config/minires_5_baseline.yaml",
"chars": 1398,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 5 # number of samples per category in"
},
{
"path": "SIB/config/minires_5_ifsl.yaml",
"chars": 1393,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 5 # number of samples per category in"
},
{
"path": "SIB/config/miniwrn_1_baseline.yaml",
"chars": 1382,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 1 # number of samples per category in"
},
{
"path": "SIB/config/miniwrn_1_ifsl.yaml",
"chars": 1376,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 1 # number of samples per category in"
},
{
"path": "SIB/config/miniwrn_5_baseline.yaml",
"chars": 1382,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 5 # number of samples per category in"
},
{
"path": "SIB/config/miniwrn_5_ifsl.yaml",
"chars": 1374,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 5 # number of samples per category in"
},
{
"path": "SIB/config/tieredres_1_baseline.yaml",
"chars": 1382,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 1 # number of samples per category in"
},
{
"path": "SIB/config/tieredres_1_ifsl.yaml",
"chars": 1377,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 1 # number of samples per category in"
},
{
"path": "SIB/config/tieredres_5_baseline.yaml",
"chars": 1382,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 5 # number of samples per category in"
},
{
"path": "SIB/config/tieredres_5_ifsl.yaml",
"chars": 1377,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 5 # number of samples per category in"
},
{
"path": "SIB/config/tieredwrn_1_baseline.yaml",
"chars": 1385,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 1 # number of samples per category in"
},
{
"path": "SIB/config/tieredwrn_1_ifsl.yaml",
"chars": 1380,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 1 # number of samples per category in"
},
{
"path": "SIB/config/tieredwrn_5_baseline.yaml",
"chars": 1454,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 5 # number of samples per category in"
},
{
"path": "SIB/config/tieredwrn_5_ifsl.yaml",
"chars": 1381,
"preview": "# Few-shot dataset\nnClsEpisode: 5 # number of categories in each episode\nnSupport: 5 # number of samples per category in"
},
{
"path": "SIB/data/__init__.py",
"chars": 108,
"preview": "from . import additional_transforms\nfrom . import datamgr\nfrom . import dataset\nfrom . import feature_loader"
},
{
"path": "SIB/data/additional_transforms.py",
"chars": 850,
"preview": "# Copyright 2017-present, Facebook, Inc.\n# All rights reserved.\n#\n# This source code is licensed under the license found"
},
{
"path": "SIB/data/datamgr.py",
"chars": 3717,
"preview": "# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate\n\nimport torch\nfrom PIL impo"
},
{
"path": "SIB/data/dataset.py",
"chars": 4507,
"preview": "# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate\n\nimport torch\nfrom PIL impo"
},
{
"path": "SIB/data/download_cifarfs.sh",
"chars": 440,
"preview": "wget https://www.dropbox.com/s/wuxb1wlahado3nq/cifar-fs-splits.zip?dl=0\nmv cifar-fs-splits.zip?dl=0 cifar-fs-splits.zip\n"
},
{
"path": "SIB/data/download_miniimagenet.sh",
"chars": 909,
"preview": "wget https://www.dropbox.com/s/a2a0bll17f5dvhr/Mini-ImageNet.zip?dl=0\nmv Mini-ImageNet.zip?dl=0.1 Mini-ImageNet.zip\nunzi"
},
{
"path": "SIB/data/feature_loader.py",
"chars": 1587,
"preview": "import torch\nimport numpy as np\nimport h5py\n\nclass SimpleHDF5Dataset:\n def __init__(self, file_handle = None):\n "
},
{
"path": "SIB/data/get_cifarfs.py",
"chars": 4010,
"preview": "\"\"\"\n@author: Arnout Devos\n2018/12/06\nMIT License\n\nScript for downloading, and reorganizing CIFAR few shot from CIFAR-100"
},
{
"path": "SIB/dataloader.py",
"chars": 11396,
"preview": "# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Versi"
},
{
"path": "SIB/dataset.py",
"chars": 5690,
"preview": "# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Versi"
},
{
"path": "SIB/deconfound/DSIB.py",
"chars": 6295,
"preview": "# This code is modified from https://github.com/dragen1860/MAML-Pytorch and https://github.com/katerakelly/pytorch-maml "
},
{
"path": "SIB/deconfound/__init__.py",
"chars": 46,
"preview": "from . import DSIB\nfrom . import meta_toolkits"
},
{
"path": "SIB/deconfound/meta_toolkits.py",
"chars": 5197,
"preview": "import torch\nimport torch.nn as nn\nimport backbone\nfrom torch.autograd import Variable\nimport numpy as np\nimport math\n\n\n"
},
{
"path": "SIB/dfsl_configs.py",
"chars": 152,
"preview": "save_dir = '/data2/yuezhongqi/Model/CloserLookFSL/'\nsimple_shot_dir = \"/data2/yuezhongqi/Model/simple"
},
{
"path": "SIB/io_utils.py",
"chars": 6182,
"preview": "import numpy as np\nimport os\nimport glob\nimport argparse\nimport backbone\nimport sys\nimport dfsl_configs as configs\nfrom "
},
{
"path": "SIB/main.py",
"chars": 6707,
"preview": "# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Versi"
},
{
"path": "SIB/main_feat.py",
"chars": 15298,
"preview": "# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Versi"
},
{
"path": "SIB/networks.py",
"chars": 8900,
"preview": "# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Versi"
},
{
"path": "SIB/readme.md",
"chars": 1415,
"preview": "# SIB + IFSL\n\nThis project is based on the official code base of the paper [Empirical Bayes Transductive Meta-Learning w"
},
{
"path": "SIB/requirements.txt",
"chars": 32,
"preview": "tensorboardX\neasydict\ntqdm\nbypy\n"
},
{
"path": "SIB/setup.cfg",
"chars": 118,
"preview": "[pep8]\nignore = E202\nmax-line-length = 160\n[flake8]\nignore = E202, E712, W293, W391, W292, E266\nmax-line-length = 160\n"
},
{
"path": "SIB/sib.py",
"chars": 7370,
"preview": "# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Versi"
},
{
"path": "SIB/simple_shot_models/Conv4.py",
"chars": 1131,
"preview": "from torch import nn\n\n__all__ = ['Conv4']\n\n\ndef conv_block(in_channels: int, out_channels: int) -> nn.Module:\n return"
},
{
"path": "SIB/simple_shot_models/DenseNet.py",
"chars": 6239,
"preview": "from collections import OrderedDict\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n__all__ = ['den"
},
{
"path": "SIB/simple_shot_models/MobileNet.py",
"chars": 2340,
"preview": "'''MobileNet in PyTorch.\n\nSee the paper \"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applicati"
},
{
"path": "SIB/simple_shot_models/ProtoNet.py",
"chars": 1501,
"preview": "import torch.nn as nn\nimport torch\nimport torch.nn.functional as F\n\n\ndef get_metric(metric_type):\n METRICS = {\n "
},
{
"path": "SIB/simple_shot_models/ResNet.py",
"chars": 6037,
"preview": "import torch.nn as nn\n\n__all__ = ['resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',\n 'resnet152']\n\n"
},
{
"path": "SIB/simple_shot_models/WideResNet.py",
"chars": 3941,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.nn.init as init\nfrom "
},
{
"path": "SIB/simple_shot_models/__init__.py",
"chars": 158,
"preview": "from .ResNet import *\nfrom .DenseNet import *\nfrom .Conv4 import Conv4 as conv4\nfrom .MobileNet import MobileNet as mobi"
},
{
"path": "SIB/utils/__init__.py",
"chars": 61,
"preview": "from . import config\nfrom . import outils\nfrom . import utils"
},
{
"path": "SIB/utils/config.py",
"chars": 3918,
"preview": "# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Versi"
},
{
"path": "SIB/utils/outils.py",
"chars": 5455,
"preview": "# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Versi"
},
{
"path": "SIB/utils/utils.py",
"chars": 2753,
"preview": "# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Versi"
},
{
"path": "readme.md",
"chars": 2548,
"preview": "# Interventional Few-Shot Learning\n\nThis project provides a strong Baseline with WRN28-10 and ResNet-10 backbone for the"
}
]
// ... and 56 more files (download for full content)
About this extraction
This page contains the full source code of the yue-zhongqi/ifsl GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 195 files (14.7 MB), approximately 3.8M tokens, and a symbol index with 1092 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.