[
  {
    "path": "README.md",
    "content": "# Rethinking the Role of Demonstrations: What Makes In-Context Learning Work?\n\nThis includes an original implementation of \"[Rethinking the Role of Demonstrations: What Makes In-Context Learning Work?][paper]\" by [Sewon Min][sewon], [Xinxi Lyu][xinxi], [Ari Holtzman][ari], [Mikel Artetxe][mikel], [Mike Lewis][mike], [Hannaneh Hajishirzi][hanna], and [Luke Zettlemoyer][luke].\n\nThis code provides:\n- Codes for creating the variants of the demonstrations used in the experiments.\n- Commands to run the models and get numbers reported in the paper, based on the [MetaICL][metaicl] codebase.\n\nPlease leave issues for any questions about the paper or the code.\n\nIf you find our code or paper useful, please cite the paper:\n```\n@inproceedings{ min2022rethinking,\n    title={ Rethinking the Role of Demonstrations: What makes In-context Learning Work? },\n    author={ Min, Sewon and Lyu, Xinxi and Holtzman, Ari and Artetxe, Mikel and Lewis, Mike and Hajishirzi, Hannaneh and Zettlemoyer, Luke },\n    booktitle={ EMNLP },\n    year={ 2022 }\n}\n```\n\n### Announcements\n* 07/25/2022: The code also supports running GPT-3 now.\n* 02/25/2022: The code supports running GPT-2, MetaICL and GPT-J for now. Please contact authors for running other models.\n\n## Content\n\n1. [Preparation](#preparation)\n2. [Reproducing Main Experiments](#reproducing-main-experiments) (Section 4.1 of the paper)\n    * [No Demonstrations](#no-demonstrations)\n    * [Demonstrations with gold labels](#demonstrations-with-gold-labels)\n    * [Demonstrations with random labels](#demonstrations-with-random-labels)\n3. [Reproducing Ablations](#reproducing-ablations) (Section 4.2 of the paper)\n    * [Number of correct labels](#number-of-correct-labels)\n    * [Number of input-label pairs in the demonstrations](#number-of-input-label-pairs-in-the-demonstrations)\n    * [Using manual templates](#using-manual-templates)\n4. [Reproducing Analysis](#reproducing-analysis) (Section 5 of the paper)\n    * [Demonstrations with OOD input text](#demonstrations-with-ood-input-text)\n    * [Demonstrations with random english words](#demonstrations-with-random-english-words)\n    * [Demonstrations with random labels only (no inputs)](#demonstrations-with-random-labels-only-no-inputs)\n    * [Demonstrations with no labels (inputs only)](#demonstrations-with-no-labels-inputs-only)\n\n## Preparation\n\nThe code is tested with python 3.8.\n\nThe data and the code are based on the MetaICL codebase.\n```bash\ngit remote add metaicl https://github.com/facebookresearch/MetaICL.git\ngit pull metaicl main --allow-unrelated-histories -X ours\n```\n\nInstall the data dependencies and download the data.\n```bash\nconda conda create --name metaicl-data python=3.8\nconda activate metaicl-data\npip install datasets==1.4.0 wget\ncd preprocess\npython _build_gym.py --build --n_proc=40 --do_test\n```\n\nThis uses `k=16` by default. If you want to run ablations with varying `k`, please also run the following.\n```bash\npython _build_gym.py --build --n_proc=40 --do_test --test_k {4|8|32}\n```\n\nAfter preprocesisng is done, come back to the main directory.\n```bash\ncd ../\nconda deactivate\n```\n\nNow, install the model dependencies to run the model. 
Now, install the model dependencies to run the model. Please note that the Transformers version is not compatible with the `datasets` version used to download the data, so make sure to use a separate environment.\n```bash\nconda create --name metaicl python=3.8\nconda activate metaicl\npip install torch==1.9.0\npip install git+https://github.com/huggingface/transformers.git@c37573806ab3526dd805c49cbe2489ad4d68a9d7\n```\n\n(Optional) Install the OpenAI Python library for running GPT-3.\n```bash\npip install openai\n```\n\n## Reproducing Main Experiments\n\nThis is for reproducing the experiments in Section 4.1 of the paper.\nEvaluation datasets are:\n* Classification (16 datasets): `financial_phrasebank`,`poem_sentiment`,`glue-wnli`,`climate_fever`,`glue-rte`,`superglue-cb`,`sick`,`medical_questions_pairs`,`glue-mrpc`,`hate_speech18`,`ethos-national_origin`,`ethos-race`,`ethos-religion`,`tweet_eval-hate`,`tweet_eval-stance_atheism`,`tweet_eval-stance_feminist`\n* Multi-choice (10 datasets): `quarel`,`openbookqa`,`qasc`,`commonsense_qa`,`ai2_arc`,`codah`,`superglue-copa`,`dream`,`quartz-with_knowledge`,`quartz-no_knowledge`\n\n#### No Demonstrations\n\nTo run the evaluation with no demonstrations:\n\n```bash\n# Direct GPT-2 Large\npython test.py --dataset {dataset} --gpt2 gpt2-large --method direct --out_dir out/gpt2-large --do_zeroshot\n# Channel GPT-2 Large\npython test.py --dataset {dataset} --gpt2 gpt2-large --method channel --out_dir out/gpt2-large --do_zeroshot\n# Direct MetaICL\npython test.py --dataset {dataset} --gpt2 metaicl --method direct --out_dir out/metaicl --do_zeroshot\n# Channel MetaICL\npython test.py --dataset {dataset} --gpt2 channel-metaicl --method channel --out_dir out/channel-metaicl --do_zeroshot\n# Direct GPT-J\npython test.py --dataset {dataset} --gpt2 gpt-j-6B --method direct --out_dir out/gpt-j --do_zeroshot\n# Channel GPT-J\npython test.py --dataset {dataset} --gpt2 gpt-j-6B --method channel --out_dir out/gpt-j --do_zeroshot\n# GPT-3\npython test_gpt3.py --dataset {dataset} --gpt3 {ada|babbage|curie|davinci} --method {direct|channel} --out_dir out/gpt3 --do_zeroshot --api {API key}\n```\nNote that `test.py` and `test_gpt3.py` do not support multi-GPU inference.\n\nOther useful flags:\n* `--test_batch_size`: can be adjusted based on your GPU memory availability. With a 32GB GPU, you can use 64 for GPT-2 Large & MetaICL, and 16 for GPT-J **with no demonstrations**. Later, when you run the code **with demonstrations**, dividing the batch size by 4 typically works, e.g., 16 (GPT-2 Large & MetaICL) and 4 (GPT-J) with a 32GB GPU.\n* `--log_file`: if you want to save logs in a file, you can specify the path to the log file.\n\nNotes for running GPT-3:\n* You can create/check your OpenAI API keys by visiting [this link](https://beta.openai.com/account/api-keys).\n* Running with GPT-3 can be expensive, and different GPT-3 models come with different costs. Please check [this link](https://openai.com/api/pricing/) to estimate the cost before running each experiment.\n* The responses from the GPT-3 API are cached in the `out_dir`.\n\n
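All of the commands above share one decision rule at evaluation time: each test input is paired with every label option, the model assigns a (length-normalized) loss to each pairing, and the option with the lowest loss is predicted. A minimal sketch of that rule, mirroring `do_predict` in `gpt3.py`:\n```python\nimport numpy as np\n\ndef predict(losses, options):\n    # losses[i] is the model's loss for pairing the test input with options[i];\n    # the cheapest option wins\n    return options[int(np.argmin(losses))]\n\nprint(predict([2.3, 0.9, 1.7], [\"entailment\", \"not_entailment\", \"neutral\"]))  # not_entailment\n```\n\n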
From now on, we will use the above commands as the default and tell you which flags you need to add.\n\n\n#### Demonstrations with gold labels\n\nRun the same commands as the [default commands](#no-demonstrations), adding `--use_demonstrations --k 16 --seed 100,13,21,42,87`.\n\n#### Demonstrations with random labels\n\nCreate the demonstrations with random labels via:\n```bash\npython create_data.py --variant random --dataset {dataset}\n```\nThen run the same commands as the [default commands](#no-demonstrations), adding `--use_demonstrations --k 16 --seed 100,13,21,42,87 --dataset {dataset}_random`.\n\n## Reproducing Ablations\n\nThis is for reproducing the experiments in Section 4.2 of the paper.\nEvaluation datasets are:\n* Classification (5 datasets): `poem_sentiment`,`glue-rte`,`sick`,`glue-mrpc`,`tweet_eval-hate`\n* Multi-choice (4 datasets): `openbookqa`,`commonsense_qa`,`ai2_arc`,`superglue-copa`\n\n#### Number of correct labels\n\nCreate the demonstrations with a varying number of correct labels via:\n```bash\npython create_data.py --variant {75|50|25|0}_correct --dataset {dataset}\n```\nThen run the same commands as the [default commands](#no-demonstrations), adding `--use_demonstrations --k 16 --seed 100,13,21,42,87 --dataset {dataset}_{75|50|25|0}_correct`.\n\n#### Number of input-label pairs in the demonstrations\n\n(Note that you should have run preprocessing with varying `k` to run this ablation. If you have not done this, please revisit the [Preparation](#preparation) section.)\n\nCreate the demonstrations with varying `k` via:\n```bash\npython create_data.py --variant random --dataset {dataset} --k {4|8|16|32}\n```\nThen run the same commands as the [default commands](#no-demonstrations), adding `--use_demonstrations --k {4|8|16|32} --seed 100,13,21,42,87 --dataset {dataset}_random`.\n\n#### Using manual templates\n\nCreate the demonstrations with varying label types and inference methods via:\n```bash\npython create_data.py --variant {gold|random}_w_template --dataset {dataset} --method {direct|channel}\n```\nThen run the same commands as the [default commands](#no-demonstrations), adding `--use_demonstrations --k 16 --seed 100,13,21,42,87 --dataset {dataset}_{gold|random}_w_template_{direct|channel}`.\n\n## Reproducing Analysis\n\nThis is for reproducing the experiments in Section 5 of the paper.\nEvaluation datasets are:\n* Classification (5 datasets): `poem_sentiment`,`glue-rte`,`sick`,`glue-mrpc`,`tweet_eval-hate`\n* Multi-choice (4 datasets): `openbookqa`,`commonsense_qa`,`ai2_arc`,`superglue-copa`\n\n#### Demonstrations with OOD input text\n\nFirst, you need a corpus file in `.txt` format, where each line is a sentence in plain text.\nIn the paper, we used samples from the English portion of CC News, which we are unable to release here.\nPlease visit [this link](https://commoncrawl.org/2016/10/news-dataset-available/) to learn more about how to download the CC News corpus.\n\nCreate the demonstrations with OOD input text via (see the sketch below for how the OOD inputs are sampled):\n```bash\npython create_data.py --variant ood_inputs --dataset {dataset} --corpus_path {corpus_path}\n```\nThen run the same commands as the [default commands](#no-demonstrations), adding `--use_demonstrations --k 16 --seed 100,13,21,42,87 --dataset {dataset}_ood_inputs`.\n\n
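Note that `create_data.py` does not sample corpus sentences uniformly: each sentence is weighted by a Gaussian-shaped score on the gap between its word count and that of the test input, so the sampled OOD inputs roughly match the original input lengths. A minimal sketch of that weighting, mirroring the `ood_inputs` branch of `create_data.py`:\n```python\nimport numpy as np\n\ndef sample_ood_inputs(random_texts, random_text_lens, test_input, k):\n    # weight each corpus sentence by exp(-(length gap)^2 / 50),\n    # then draw k sentences without replacement\n    l = len(test_input.split())\n    prob = np.exp(-np.power(np.asarray(random_text_lens) - l, 2) / 50)\n    prob /= prob.sum()\n    return np.random.choice(random_texts, size=k, replace=False, p=prob)\n```\n\n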
#### Demonstrations with random English words\n\nCreate the demonstrations with random English words as labels via:\n```bash\npython create_data.py --variant random_english_words --dataset {dataset}\n```\nThen run the same commands as the [default commands](#no-demonstrations), adding `--use_demonstrations --k 16 --seed {seed} --dataset {dataset}_random_english_words_seed={seed}`, where `seed` can be one of 100, 13, 21, 42, and 87.\n\n#### Demonstrations with random labels only (no inputs)\n\nCreate the demonstrations with random labels only via:\n```bash\npython create_data.py --variant random_labels_only --dataset {dataset}\n```\nThen run the same commands as the [default commands](#no-demonstrations), adding `--use_demonstrations --k 16 --seed 100,13,21,42,87 --dataset {dataset}_random_labels_only`.\n\n#### Demonstrations with no labels (inputs only)\n\nCreate the demonstrations with no labels via:\n```bash\npython create_data.py --variant no_labels --dataset {dataset}\n```\nThen run the same commands as the [default commands](#no-demonstrations), adding `--use_demonstrations --k 16 --seed 100,13,21,42,87 --dataset {dataset}_no_labels`.\n\n\n[paper]: https://arxiv.org/abs/2202.12837\n[sewon]: http://shmsw25.github.io/\n[xinxi]: https://alrope123.github.io/\n[ari]: https://ari-holtzman.github.io/\n[mikel]: https://scholar.google.com/citations?user=N5InzP8AAAAJ&hl=en\n[mike]: https://ai.facebook.com/people/mike-lewis/\n[hanna]: https://homes.cs.washington.edu/~hannaneh/index.html\n[luke]: https://www.cs.washington.edu/people/faculty/lsz\n\n[metaicl]: https://github.com/facebookresearch/MetaICL\n"
  },
  {
    "path": "create_data.py",
    "content": "import os\nimport argparse\nimport random\nimport json\nimport numpy as np\n\nfrom collections import defaultdict, Counter\n\nfrom templates import apply_template\n\ndef main(args):\n    assert args.variant in [\n        \"gold\", \"random\", # main experiments in Section 4\n        \"75_correct\", \"50_correct\", \"25_correct\", \"0_correct\", # ablations in Section 4\n        \"gold_w_template\", \"random_w_template\", # ablations in Section 4\n        \"ood_inputs\", \"random_english_words\", \"random_labels_only\", \"no_labels\", # Section 5\n        \"random_english_words_gold_labels\", \"permutated_labels\", \"random_true_distribution\"\n    ]\n    if args.variant in [\"gold_w_template\", \"random_w_template\"]:\n        assert args.method is not None, \"Please specify `--method` with the inference method (`direct` or `channel`) for using the template.\"\n        assert args.method in [\"direct\", \"channel\"], \"Please make sure to use either `direct` or `channel`.\"\n\n    if args.variant==\"gold\":\n        print (\"No need to run `create_data.py` --- you can use the original data as it is.\")\n        return\n\n    if args.variant==\"ood_inputs\":\n        # load sources for OOD inputs\n        assert args.corpus_path is not None, \\\n        \"\"\"\n        Please note that you need to specify the path to the corpus from which the OOD inputs will be sampled.\n        It should be a .txt file where each line contains a sentence (plain text).\n        \"\"\"\n        grouped_samples = defaultdict(list)\n        with open(args.corpus_path, \"r\") as f:\n            random_texts = []\n            random_text_lens = []\n            for line in f:\n                line = line.strip()\n                random_texts.append(line)\n                random_text_lens.append(len(line.split()))\n            random_text_lens = np.array(random_text_lens)\n\n    elif args.variant in [\"random_english_words\", \"random_english_words_gold_labels\"]:\n        from english_words import english_words_set\n        english_words_set = sorted(english_words_set)\n\n    datasets = args.dataset.split(',')\n    new_datasets = [dataset + \"_\" + args.variant + ((\"_\" + args.method) if args.method is not None else \"\") for dataset in datasets]\n    assert len(datasets) == len(new_datasets)\n\n    ################################################################################################################\n\n    seeds = args.seed.split(',')\n    perfs = []\n    for dataset_idx, (dataset, new_dataset) in enumerate(zip(datasets, new_datasets)):\n\n        # contruct and save a new config file and data directory\n        config_file = os.path.join(args.config_dir, \"tasks\")\n        assert os.path.exists(config_file), config_file\n        with open(os.path.join(config_file, \"{}.json\".format(dataset)), \"r\") as f:\n            config = json.load(f)\n\n        # in case of random English words, we will create a config file and data directory\n        # for each random seed later on (since the data is different across seeds)\n        if args.variant not in [\"random_english_words\", \"random_english_words_gold_labels\"]:\n            with open(os.path.join(config_file, \"{}.json\".format(new_dataset)), \"w\") as f:\n                json.dump(config, f)\n\n            new_dataset_dir = os.path.join(args.data_dir, new_dataset)\n            if not os.path.exists(new_dataset_dir):\n                os.mkdir(new_dataset_dir)\n        \n        # load full training data to get the distribution of the 
labels\n        if args.variant==\"random_true_distribution\":\n            full_train_data_path = os.path.join(args.data_dir, dataset, \"{}_16384_100_train.jsonl\".format(dataset))\n            assert os.path.exists(full_train_data_path), \"Please generate full training data first by running _build_gym.py with k=16384.\"\n            full_train_data_labels = []\n            with open(full_train_data_path, \"r\") as f:\n                for line in f:\n                    dp = json.loads(line)\n                    assert dp[\"task\"]==dataset\n                    full_train_data_labels.append(dp[\"output\"])\n            train_label_counter = Counter(full_train_data_labels)\n            train_label_distribution = {label : train_label_counter[label] / len(full_train_data_labels) for label in train_label_counter}\n\n        for seed in seeds:\n            # random seed\n            np.random.seed(int(seed))\n\n            if args.variant in [\"random_english_words\", \"random_english_words_gold_labels\"]:\n                new_dataset = new_datasets[dataset_idx] + \"_seed={}\".format(seed)\n\n            # read the original training and test data\n            # note that we are modifying the training data only,\n            # and the test data will always be the same\n            # (we are creating duplicates only for convenience)\n            train_data = []\n            train_data_path = os.path.join(args.data_dir, dataset, \"{}_{}_{}_{}.jsonl\".format(dataset, args.k, seed, \"train\"))\n            with open(train_data_path, \"r\") as f:\n                for line in f:\n                    dp = json.loads(line)\n                    assert dp[\"task\"]==dataset\n                    dp[\"task\"] = new_dataset\n                    train_data.append(dp)\n\n            test_data = []\n            test_data_path = os.path.join(args.data_dir, dataset, \"{}_{}_{}_{}.jsonl\".format(dataset, args.k, seed, \"test\"))\n            with open(test_data_path, \"r\") as f:\n                for line in f:\n                    dp = json.loads(line)\n                    assert dp[\"task\"]==dataset\n                    dp[\"task\"] = new_dataset\n                    test_data.append(dp)\n\n            # apply templates to inputs and labels\n            if args.variant in [\"gold_w_template\", \"random_w_template\"]:\n                for dp in train_data:\n                    apply_template(dp, dataset, args.method)\n                for dp in test_data:\n                    apply_template(dp, dataset, args.method)\n\n            # now, for random_english_words, create a config file and data directory\n            if args.variant in [\"random_english_words\", \"random_english_words_gold_labels\"]:\n                new_dataset_dir = os.path.join(args.data_dir, new_dataset)\n                if not os.path.exists(new_dataset_dir):\n                    os.mkdir(new_dataset_dir)\n\n                if config[\"task_type\"]==\"classification\":\n                    new_options = list(np.random.choice(english_words_set, size=len(config[\"options\"]), replace=False))\n                    new_mapping = {option: new_option for option, new_option in zip(config[\"options\"], new_options)}\n                    new_config = config.copy()\n                    new_config[\"options\"] = new_options\n\n                    with open(os.path.join(config_file, \"{}.json\".format(new_dataset)), \"w\") as f:\n                        json.dump(new_config, f)\n\n                    for i, dp in enumerate(train_data):\n                   
     train_data[i][\"output\"] = new_mapping[dp[\"output\"]]\n                        train_data[i][\"options\"] = [new_mapping[option] for option in dp[\"options\"]]\n\n                    if args.variant == \"random_english_words_gold_labels\":\n                        # also modify the test data for classification tasks\n                        for i, dp in enumerate(test_data):\n                            test_data[i][\"output\"] = new_mapping[dp[\"output\"]]\n                            test_data[i][\"options\"] = [new_mapping[option] for option in dp[\"options\"]]\n\n                elif config[\"task_type\"]==\"multi-choice\":\n                    with open(os.path.join(config_file, \"{}.json\".format(new_dataset)), \"w\") as f:\n                        json.dump(config, f)\n\n                    shuffled_indices = np.random.permutation(range(len(english_words_set)))\n                    shuffled_options = [english_words_set[i] for i in shuffled_indices]\n                    offset = 0\n                    for i, dp in enumerate(train_data):\n                        new_options = shuffled_options[offset:offset+len(dp[\"options\"])]\n                        offset += len(dp[\"options\"])\n                        train_data[i][\"output\"] = new_options[dp[\"options\"].index(dp[\"output\"])]\n                        train_data[i][\"options\"] = new_options\n\n                else:\n                    raise NotImplementedError()\n\n            # modify both train input and test input for permutated_labels with classification tasks\n            if args.variant == \"permutated_labels\" and config[\"task_type\"]==\"classification\":\n                old_options = config[\"options\"]\n                new_options = [old_options[(i+1)%len(old_options)] for i in range(len(old_options))]\n                new_mapping = {old_option: new_option for old_option, new_option in zip(old_options, new_options)}\n\n                for i, dp in enumerate(train_data):\n                    train_data[i][\"output\"] = new_mapping[dp[\"output\"]]                    \n                for i, dp in enumerate(test_data):\n                    test_data[i][\"output\"] = new_mapping[dp[\"output\"]]\n                    \n\n            ## modify labels in the training data\n\n            if args.variant in [\"75_correct\", \"50_correct\", \"25_correct\"]:\n                num_correct = args.k * int(args.variant.split(\"_\")[0]) // 100\n                indices_correct = np.random.permutation(range(args.k))[:num_correct]\n\n            for dp_idx, dp in enumerate(train_data):\n                if args.variant in [\"gold\", \"gold_w_template\", \"permutated_labels\", \"random_english_words_gold_labels\"] or \\\n                        (args.variant in [\"75_correct\", \"50_correct\", \"25_correct\"] and dp_idx in indices_correct):\n                    # assign correct label\n                    pass\n                elif args.variant.endswith(\"_correct\"):\n                    # assign incorrect label\n                    dp[\"output\"] = dp[\"options\"][np.random.choice([i for i in range(len(dp[\"options\"])) if dp[\"options\"][i] != dp[\"output\"]])]\n                elif args.variant==\"no_labels\":\n                    # assign empty label\n                    dp[\"output\"] = \"\"\n                    dp[\"options\"] = [\"\"]\n                elif args.variant==\"random_true_distribution\":\n                    # assign random labels according to the distribution in the training data\n                    
dp[\"output\"] = np.random.choice(list(train_label_distribution.keys()), p=list(train_label_distribution.values()))\n                else:\n                    # assign random label\n                    dp[\"output\"] = np.random.choice(dp[\"options\"])\n\n            ## modify inputs in the training data\n\n            if args.variant==\"random_labels_only\":\n                for dp in train_data:\n                    dp[\"input\"] = \"\"\n\n            elif args.variant==\"ood_inputs\":\n                new_train_data = []\n                for dp in test_data:\n                    l = len(dp[\"input\"].split())\n                    prob = np.exp(-np.power(random_text_lens-l, 2)/50)\n                    prob /= np.sum(prob)\n                    samples = np.random.choice(random_texts, size=args.k, replace=False, p=prob)\n                    assert len(samples)==len(train_data)\n                    new_train_data.append([])\n                    for train_dp, sample in zip(train_data, samples):\n                        new_train_data[-1].append({\"task\": train_dp[\"task\"],\n                                                    \"input\": sample,\n                                                    \"output\": train_dp[\"output\"],\n                                                    \"options\": train_dp[\"options\"]})\n                train_data = new_train_data\n\n            # write the modified data\n            with open(os.path.join(new_dataset_dir, \"{}_{}_{}_{}.jsonl\".format(new_dataset, args.k, seed, \"train\")), \"w\") as f:\n                for dp in train_data:\n                    f.write(json.dumps(dp))\n                    f.write(\"\\n\")\n\n            with open(os.path.join(new_dataset_dir, \"{}_{}_{}_{}.jsonl\".format(new_dataset, args.k, seed, \"test\")), \"w\") as f:\n                for dp in test_data:\n                    f.write(json.dumps(dp))\n                    f.write(\"\\n\")\n\n            print (\"Done for %s seed=%s\" % (new_dataset, seed))\n\n\nif __name__=='__main__':\n\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\"--dataset\", type=str, default=None)\n    parser.add_argument(\"--k\", type=int, default=16)\n    parser.add_argument(\"--seed\", type=str, default=\"100,13,21,42,87\")\n    parser.add_argument(\"--variant\", type=str, default=\"random\", required=True)\n    parser.add_argument(\"--method\", type=str, default=None)\n\n    parser.add_argument(\"--data_dir\", type=str, default=\"data\")\n    parser.add_argument(\"--config_dir\", type=str, default=\"config\")\n    parser.add_argument(\"--corpus_path\", type=str, default=None)\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "gpt3.py",
    "content": "import time\nimport sys\nimport numpy as np\nimport torch\nimport json\nimport openai\nfrom torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n\nfrom transformers import GPT2Tokenizer\n\nclass GPT3Model(object):\n\n    def __init__(self, model_name, api_key, logger=None):\n        self.model_name = model_name\n        try:\n            openai.api_key = api_key\n        except Exception:\n            pass\n        self.tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2-xl\")\n        self.logger=logger\n\n\n    def prepare_data(self, train_data, test_data, method, batch_size=10, dp_sep=\"\\n\", max_length=1024):\n        # format demonstrations\n        demonstrations = \"\"\n        for dp in train_data:\n            if method==\"direct\":\n                demonstrations += \"{}{}{}\\n\\n\\n\".format(dp[\"input\"], dp_sep, dp[\"output\"])\n            elif method==\"channel\":\n                demonstrations += \"{}{}{}\\n\\n\\n\".format(dp[\"output\"], dp_sep, dp[\"input\"])\n            else:\n                raise NotImplementedError()\n\n        # append demonstrations and separate options\n        inputs = []\n        outputs = []\n        metadata = []\n        for dp in test_data:\n            prompt = dp[\"input\"]\n            options = dp[\"options\"]\n\n            indices = [i for i in range(len(inputs), len(inputs) + len(options))]\n            metadata.append({\"indices\": indices, \"options\": options})\n\n            if method==\"direct\":\n                inputs += [demonstrations + prompt + dp_sep for option in options]\n                outputs += [option for option in options]\n            elif method==\"channel\":\n                inputs += [demonstrations + option + dp_sep for option in options]\n                outputs += [prompt for option in options]\n            else:\n                raise NotImplementedError()\n\n        # truncate inputs\n        for i, (inp, out) in enumerate(zip(inputs, outputs)):\n            input_ids = self.tokenizer.encode(inp)\n            output_ids = self.tokenizer.encode(out)\n            if (len(input_ids) + len(output_ids) > max_length):\n                input_ids = input_ids[len(input_ids)+len(output_ids) - max_length:]\n                assert len(input_ids)+len(output_ids) == max_length\n            inputs[i] = self.tokenizer.decode(input_ids)\n\n        if self.logger is not None:\n            self.logger.info(\"Checking the first example...\")\n            self.logger.info(inputs[0] + \"\" + outputs[0])\n\n        # construct a dataloader\n        dataset = zip(inputs, outputs)\n        input_chunks = [inputs[i : i + batch_size] for i in range(0, len(inputs), batch_size)]\n        output_chunks = [outputs[i : i + batch_size] for i in range(0, len(outputs), batch_size)]\n        dataloader = [(input_chunks[i], output_chunks[i]) for i in range(0, len(input_chunks))]\n\n        return dataloader, metadata\n\n\n    def do_inference(self, dataloader):\n        losses = []\n        cache = []\n        cost = 0\n        for inputs, outputs in dataloader:\n            data = [inp + out for inp, out in zip(inputs, outputs)]\n            response = self.gpt3(data)\n            for choice in response[\"choices\"]:\n                cost += len(choice[\"logprobs\"][\"tokens\"]) * 0.00006\n            # print(\"current cost = \" + str(cost))\n            cache.append((data, response))\n            # get the beginning of the target from the response (based on tokenization)\n            for inp, outp, out in 
            for inp, outp, out in zip(inputs, outputs, response[\"choices\"]):\n                assert inp+outp==out[\"text\"]\n                i = 0\n                while out['logprobs']['text_offset'][i] < len(inp):\n                    i += 1\n                loss = -sum(out['logprobs'][\"token_logprobs\"][i:])\n                losses.append(loss / (len(out['logprobs']['text_offset']) - i))\n        return losses, cache\n\n\n    def do_predict(self, losses, metadata):\n        predictions = []\n        for dp in metadata:\n            curr_label_losses = [losses[index] for index in dp[\"indices\"]]\n            prediction_idx = sorted(enumerate(curr_label_losses), key=lambda x: x[1])[0][0]\n            prediction = dp[\"options\"][prediction_idx]\n            predictions.append(prediction.strip())\n        return predictions\n\n\n    def gpt3(self, prompt, max_len=0, temp=0, num_log_probs=0, echo=True, n=None):\n        # call the GPT-3 API until a result is provided, then return it\n        response = None\n        received = False\n        while not received:\n            try:\n                response = openai.Completion.create(engine=self.model_name,\n                                                    prompt=prompt,\n                                                    max_tokens=max_len,\n                                                    temperature=temp,\n                                                    logprobs=num_log_probs,\n                                                    echo=echo,\n                                                    stop='\\n',\n                                                    n=n)\n                received = True\n            except openai.error.InvalidRequestError:\n                # something is wrong with the request itself (e.g. the prompt is too long); do not retry\n                print(f\"InvalidRequestError\\nPrompt passed in:\\n\\n{prompt}\\n\\n\")\n                raise\n            except Exception as error:\n                # transient API error: wait and retry\n                print(\"API error:\", error)\n                time.sleep(1)\n        return response\n"
  },
  {
    "path": "templates.py",
    "content": "import string\n\nTEMPLATES = {\n    \"financial_phrasebank\": {\n        \"direct\" : (\"{}\", \"The sentiment is: {}\"),\n        \"channel\": (\"{}\", \"The sentiment is: {}\")\n    },\n    \"poem_sentiment\": {\n        \"direct\" : (\"{}\", \"The sentiment is: {}\"),\n        \"channel\": (\"{}\", \"The sentiment is: {}\")\n    },\n    \"glue-mrpc\": {\n        \"direct\" : (\"{}\\nThe question is: {} True or False?\", \"The answer is: {}\"),\n        \"channel\": (\"The question is: {} True or False?\\n{}\", \"The answer is: {}\")\n    },\n    \"glue-rte\": {\n        \"direct\" : (\"{}\\nThe question is: {} True or False?\", \"The answer is: {}\"),\n        \"channel\": (\"The question is: {} True or False?\\n{}\", \"The answer is: {}\")\n    },\n    \"sick\": {\n        \"direct\" : (\"{}\\nThe question is: {} True or False?\", \"The answer is: {}\"),\n        \"channel\": (\"The question is: {} True or False?\\n{}\", \"The answer is: {}\")\n    },\n    \"tweet_eval-hate\": {\n        \"direct\" : (\"Tweet: {}\", \"Sentiment: {}\"),\n        \"channel\": (\"Tweet: {}\", \"Sentiment: {}\"),\n    },\n    \"openbookqa\": {\n        \"direct\" : (\"The question is: {}\", \"The answer is: {}\"),\n        \"channel\": (\"The question is: {}\", \"The answer is: {}\")\n    },\n    \"ai2_arc\": {\n        \"direct\" : (\"The question is: {}\", \"The answer is: {}\"),\n        \"channel\": (\"The question is: {}\", \"The answer is: {}\")\n    },\n    \"codah\": {\n        \"direct\" : (\"The question is: {}\", \"The answer is: {}\"),\n        \"channel\": (\"The question is: {}\", \"The answer is: {}\")\n    },\n    \"commonsense_qa\": {\n        \"direct\" : (\"The question is: {}\", \"The answer is: {}\"),\n        \"channel\": (\"The question is: {}\", \"The answer is: {}\")\n    }\n}\n\ndef apply_template(dp, dataset, method):\n    if dataset.startswith(\"superglue-copa\"):\n        if method == \"direct\":\n            if dp[\"input\"].startswith(\"Cause: \"):\n                dp[\"input\"] = dp[\"input\"][7:-1] + \" so\"\n                dp[\"output\"] = dp[\"output\"][8].lower() + dp[\"output\"][9:]\n                for i, options in enumerate(dp[\"options\"]):\n                    dp[\"options\"][i] = dp[\"options\"][i][8].lower() + dp[\"options\"][i][9:]\n            elif dp[\"input\"].startswith(\"Effect: \"):\n                dp[\"input\"] = dp[\"input\"][8:-1] + \" because\"\n                dp[\"output\"] = dp[\"output\"][7].lower() + dp[\"output\"][8:]\n                for i, options in enumerate(dp[\"options\"]):\n                    dp[\"options\"][i] = dp[\"options\"][i][7].lower() + dp[\"options\"][i][8:]\n            else:\n                raise NotImplementedError()\n        elif method == \"channel\":\n            if dp[\"output\"].startswith(\"Cause: \"):\n                dp[\"output\"] = dp[\"output\"][7:-1] + \" so\"\n                dp[\"input\"] = dp[\"input\"][8].lower() + dp[\"input\"][9:]\n                for i, options in enumerate(dp[\"options\"]):\n                    dp[\"options\"][i] = dp[\"options\"][i][7:-1] + \" so\"\n            elif dp[\"output\"].startswith(\"Effect: \"):\n                dp[\"output\"] = dp[\"output\"][8:-1] + \" because\"\n                dp[\"input\"] = dp[\"input\"][7].lower() + dp[\"input\"][8:]\n                for i, options in enumerate(dp[\"options\"]):\n                    dp[\"options\"][i] =  dp[\"options\"][i][8:-1] + \" because\"\n        else:\n            raise NotImplementedError(o)\n    elif 
dataset.startswith(\"glue\") or dataset.startswith(\"sick\"):\n        def map_option(option):\n            if option in [\"equivalent\", \"entailment\"]:\n                return \"True\"\n            if option in [\"not_equivalent\", \"not_entailment\", \"contradiction\"]:\n                return \"False\"\n            if option in [\"neutral\"]:\n                return \"Not sure\"\n            raise NotImplementedError(option)\n        dp[\"input\"] = dp[\"input\"].replace(\"sentence 1: \", \"\").replace(\"sentence 2: \", \"\")\n        splits = dp[\"input\"].split(\" [SEP] \")\n        if method==\"channel\":\n            splits = [splits[1], splits[0]]\n        splits = [split if split[-1] in string.punctuation else split+\".\" for split in splits]\n        dp[\"input\"] = TEMPLATES[dataset][method][0].format(splits[0], splits[1])\n        dp[\"output\"] = TEMPLATES[dataset][method][1].format(map_option(dp[\"output\"]))\n        for i, options in enumerate(dp[\"options\"]):\n            dp[\"options\"][i] =TEMPLATES[dataset][method][1].format(map_option(dp[\"options\"][i]))\n    else:\n        def map_option(option):\n            if dataset==\"tweet_eval-hate\":\n                return {\"hate\": \"against\", \"non-hate\": \"favor\"}[option]\n            return option\n        dp[\"input\"] = TEMPLATES[dataset][method][0].format(dp[\"input\"])\n        dp[\"output\"] = TEMPLATES[dataset][method][1].format(map_option(dp[\"output\"]))\n        for i, options in enumerate(dp[\"options\"]):\n            dp[\"options\"][i] =TEMPLATES[dataset][method][1].format(map_option(dp[\"options\"][i]))\n\n\n\n\n\n"
  },
  {
    "path": "test_gpt3.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport os\nimport argparse\nimport pickle as pkl\nimport random\nimport torch\nimport math\nimport json\nimport string\nimport logging\nimport numpy as np\n\nfrom tqdm import tqdm\nfrom collections import Counter, defaultdict\n\nfrom torch.utils.data import TensorDataset, DataLoader, SequentialSampler\nfrom transformers import GPT2Tokenizer, AutoTokenizer\n\nfrom metaicl.data import MetaICLData\nfrom metaicl.model import MetaICLModel\n\nfrom utils.data import load_data\n\nfrom gpt3 import GPT3Model\n\ndef main(logger, args):\n    assert (args.dataset is not None and args.task is None) or (args.dataset is None and args.task is not None)\n\n    tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n    add_newlines = True\n\n    ### checkpoint ...\n    if not args.do_zeroshot:\n        if args.checkpoint is not None:\n            checkpoint = args.checkpoint\n            assert args.global_step is None\n        else:\n            assert args.global_step is not None\n            checkpoint = os.path.join(args.out_dir, \"model-{}.pt\".format(args.global_step))\n        assert os.path.exists(checkpoint)\n    else:\n        add_newlines = False\n        checkpoint = None\t\n        \t\n    metaicl_model = GPT3Model(args.gpt3, args.api, logger)\n\n    if not os.path.exists(args.out_dir):\n        os.makedirs(args.out_dir)\n\n    # setup hyperparams for data\n\n    max_length_per_example = 256\n    max_length = 256\n    if args.use_demonstrations:\n        orig_max_length = max_length\n        if args.do_zeroshot:\n            max_length = min(max_length * args.k, 1024)\n        else:\n            max_length = min(max_length * args.k, 1024)\n\n    logger.info(\"batch_size=%d\\tmax_length=%d\\tmax_length_per_example=%d\" % (\n        args.test_batch_size, max_length, max_length_per_example))\n\n    metaicl_data = MetaICLData(logger, tokenizer, args.method,args.use_demonstrations, args.k,\n                               max_length, max_length_per_example)\n\n    results = []\n    errors = []\n    seeds = args.seed.split(\",\")\n    config_split = \"unseen_domain_test\" if args.unseen_domain_only else \"test\"\n\n    for seed in seeds:\n\n        ### data ...\n        train_data = load_data(args.task, \"train\", args.k, seed=seed, config_split=config_split,\n                               datasets=None if args.dataset is None else args.dataset.split(\",\"))\n        dev_data = load_data(args.task, args.split, args.k, seed=seed, config_split=config_split,\n                             datasets=None if args.dataset is None else args.dataset.split(\",\"), is_null=args.is_null)\n\n        train_counter = Counter()\n        dev_counter = Counter()\n        for dp in train_data:\n            train_counter[dp[\"task\"]] += 1\n        for dp in dev_data:\n            dev_counter[dp[\"task\"]] += 1\n        for k, v in train_counter.items():\n            logger.info(\"[Train] %s\\t%d\" % (k, v))\n        for k, v in dev_counter.items():\n            logger.info(\"[Dev] %s\\t%d\" % (k, v))\n\n        logger.info(\"%s on %s (%d train, %d dev)\" % (args.method, args.task, len(train_counter), len(dev_counter)))\n\n        for test_task in dev_counter:\n            curr_dev_data = [dp for dp in dev_data if dp[\"task\"]==test_task]\n            curr_train_data = [dp for dp in train_data if 
dp[\"task\"]==test_task]\n            assert len(curr_dev_data)>0\n            assert not args.use_demonstrations or len(curr_train_data)==args.k, \\\n                    (args.use_demonstrations, len(curr_train_data), args.k)\n\n            config_file = \"config/tasks/{}.json\".format(test_task)\n            assert os.path.exists(config_file), config_file\n            with open(config_file, \"r\") as f:\n                config = json.load(f)\n            is_classification = config[\"task_type\"]==\"classification\"\n            if is_classification:\n                options = curr_dev_data[0][\"options\"]\n                assert np.all([d[\"options\"]==options for d in curr_dev_data+curr_train_data])\n\n            result = run(logger, test_task, metaicl_data, metaicl_model,\n                         curr_train_data, curr_dev_data, seed, checkpoint, is_classification, add_newlines)\n\n            if result is None:\n                errors.append(\"%s/%s\" % (test_task, seed))\n            else:\n                results.append(result)\n\n    if args.is_null:\n        return\n\n    logger.info(\"Macro-F1 of %s over %d target tasks: %.1f\" % (args.task, len(results) // len(seeds), 100*np.mean(results)))\n\n    if len(errors)>0:\n        logger.info(\"You had errors with datasets:\", \",\".join(errors))\n        logger.info(\"Please see the error messages\")\n\n\ndef run(logger, task, metaicl_data, metaicl_model, train_data, dev_data, seed,\n        checkpoint, is_classification, add_newlines):\n\n    if args.do_zeroshot:\n        split_name = args.split\n        if args.is_null:\n            split_name += \"-null\"\n        cache_path = os.path.join(args.out_dir,\n                                  \"{}-{}-{}{}{}{}.pkl\".format(\n                                      task,\n                                      split_name,\n                                      metaicl_data.method,\n                                      \"-k={}\".format(args.k) if args.use_demonstrations else \"\",\n                                      \"-s={}\".format(seed) if args.use_demonstrations else \"\",\n                                      \"\" if add_newlines else \"-no-newlines\"))\n        gpt3_cache_path = os.path.join(args.out_dir,\n                                  \"{}-{}-{}{}{}{}.json\".format(\n                                      task,\n                                      split_name,\n                                      metaicl_data.method,\n                                      \"-k={}\".format(args.k) if args.use_demonstrations else \"\",\n                                      \"-s={}\".format(seed) if args.use_demonstrations else \"\",\n                                      \"\" if add_newlines else \"-no-newlines\"))\n    else:\n        assert add_newlines\n        cache_path = os.path.join(args.out_dir, \"{}-{}-{}{}{}.pkl\".format(\n                        task,\n                        args.split,\n                        metaicl_data.method,\n                        \"-k={}\".format(args.k) if args.use_demonstrations else \"\",\n                        \"-s={}\".format(seed) if args.use_demonstrations else \"\"\n                      ))\n        gp3_cache_path = os.path.join(args.out_dir, \"{}-{}-{}{}{}.json\".format(\n                        task,\n                        args.split,\n                        metaicl_data.method,\n                        \"-k={}\".format(args.k) if args.use_demonstrations else \"\",\n                        \"-s={}\".format(seed) if args.use_demonstrations 
else \"\"\n                      ))\n\n    metaicl_data.tensorize(train_data, dev_data, add_newlines=add_newlines)\n    gpt3_dataloader, gpt3_metadata = metaicl_model.prepare_data(train_data if args.use_demonstrations else [],\n                                dev_data, args.method, batch_size=args.test_batch_size)\n    # metaicl_data.print_tensorized_example()\n    logger.info(cache_path)\n\n    if os.path.exists(cache_path):\n        with open(cache_path, \"rb\") as f:\n            losses = pkl.load(f)\n    else:\n        losses, gpt3cache = metaicl_model.do_inference(gpt3_dataloader)\t\n        with open(cache_path, \"wb\") as f:\n            pkl.dump(losses, f)\n        with open(gpt3_cache_path, \"w\") as f:\t\n            json.dump(gpt3cache, f)\n\n    if args.is_null:\n        return None\n\n    if args.use_calibration:\n        assert args.do_zeroshot\n        bias_path = cache_path.replace(\"/\"+task+\"-\"+args.split, \"/\"+task+\"-\"+args.split+\"-null\")\n        assert os.path.exists(bias_path), bias_path\n        with open(bias_path, \"rb\") as f:\n            bias_losses = pkl.load(f)\n\n        losses = np.array(losses)\n        bias_losses = np.array(bias_losses)\n        assert losses.shape == bias_losses.shape\n        losses -= bias_losses\n\n    predictions = metaicl_model.do_predict(losses=losses, metadata=gpt3_metadata)\n    groundtruths = [dp[\"output\"] for dp in dev_data]\n    perf = metaicl_data.evaluate(predictions, groundtruths, is_classification)\n    logger.info(\"Accuracy=%s\" % perf)\n\n    prediction_path = cache_path.replace(\".pkl\", \".txt\")\n    if args.use_calibration:\n        prediction_path = prediction_path.replace(\".txt\", \"-calibrated.txt\")\n\n    with open(prediction_path, \"w\") as f:\n        for prediction in predictions:\n            f.write(prediction)\n            f.write(\"\\n\")\n\n    return perf\n\nif __name__=='__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--do_zeroshot\", default=False, action=\"store_true\")\n    parser.add_argument(\"--use_demonstrations\", default=False, action=\"store_true\")\n    parser.add_argument(\"--use_calibration\", default=False, action=\"store_true\")\n    parser.add_argument(\"--unseen_domain_only\", default=False, action=\"store_true\")\n\n    parser.add_argument(\"--log_file\", default=None, type=str)\n\n    parser.add_argument(\"--task\", type=str, default=None)\n    parser.add_argument(\"--dataset\", type=str, default=None)\n    parser.add_argument(\"--k\", type=int, default=16)\n    parser.add_argument(\"--seed\", type=str, default=\"100\")\n\n    parser.add_argument(\"--test_batch_size\", type=int, default=64)\n    parser.add_argument(\"--global_step\", type=str, default=None)\n    parser.add_argument(\"--checkpoint\", type=str, default=None)\n\n    parser.add_argument(\"--out_dir\", type=str, required=True)\n\n    parser.add_argument(\"--split\", type=str, default=\"test\")\n    parser.add_argument(\"--is_null\", default=False, action=\"store_true\")\n    parser.add_argument(\"--method\", type=str, default=\"direct\", choices=[\"direct\", \"channel\"])\n    parser.add_argument(\"--gpt3\", type=str, default=\"davinci\", choices=[\"ada\", \"babbage\", \"curie\", \"davinci\"])\n    parser.add_argument(\"--api\", type=str, required=True)\n\n    args = parser.parse_args()\n\n    handlers = [logging.StreamHandler()]\n    if args.log_file is not None:\n        handlers.append(logging.FileHandler(args.log_file))\n    logging.basicConfig(format='%(asctime)s - 
%(levelname)s - %(name)s - %(message)s',\n                        datefmt='%m/%d/%Y %H:%M:%S',\n                        level=logging.INFO,\n                        handlers=handlers)\n    logger = logging.getLogger(__name__)\n    logger.info(args)\n\n    main(logger, args)\n"
  }
]