Repository: timoschick/dino
Branch: main
Commit: 0e4a0525f05d
Files: 17
Total size: 86.2 KB
Directory structure:
gitextract_0h3fq7ni/
├── .gitignore
├── LICENSE
├── README.md
├── dino.py
├── generation.py
├── modeling.py
├── requirements.txt
├── scripts/
│   ├── imdb/
│   │   ├── __init__.py
│   │   ├── run_supervised.py
│   │   └── run_unsupervised.py
│   └── sts/
│       ├── postprocess_dataset.py
│       └── run_training.py
├── task_specs/
│   ├── imdb-movies.json
│   ├── imdb-reviews.json
│   ├── sts-x1.json
│   └── sts-x2.json
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pickle
.idea
*.iml
*.pyc
*.ipynb
.ipynb_checkpoints
.pytest_cache
venv/*
data/*
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Datasets from Instructions (DINO 🦕)
This repository contains the code for [Generating Datasets with Pretrained Language Models](https://arxiv.org/abs/2104.07540). The paper introduces a method called *Datasets from Instructions* (DINO 🦕) that enables pretrained language models to generate entire datasets from scratch.
## 🔧 Setup
All requirements for DINO can be found in ``requirements.txt``. You can install all required packages in a new environment with ``pip install -r requirements.txt``.
## 💬 CLI Usage
#### Single Texts
To generate datasets for (single) text classification, you can use DINO as follows:
````
python3 dino.py \
 --output_dir <output_dir> \
 --task_file <task_file> \
 --num_entries_per_label <num_entries_per_label> \
 --batch_size 1
````
where `<output_dir>` is a directory to which the generated dataset is written, `<task_file>` is a JSON file containing a *task specification* (see [Task Specs](#-task-specs)), and `<num_entries_per_label>` is the number of examples to generate per label. To get an overview of additional parameters, run ``python3 dino.py --help``.
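For instance, you could generate a set of raw STS input texts with the bundled ``task_specs/sts-x1.json`` specification (the output directory and entry count below are arbitrary example values):
````
python3 dino.py \
 --output_dir ./dino-sts-x1 \
 --task_file task_specs/sts-x1.json \
 --num_entries_per_label 50 \
 --batch_size 1
````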
#### Text Pairs
To generate datasets for text pair classification, you first need a dataset of raw input texts (which you can also generate using DINO). You can then run
````
python3 dino.py \
 --output_dir <output_dir> \
 --task_file <task_file> \
 --input_file <input_file> \
 --input_file_type <input_file_type> \
 --num_entries_per_input_and_label <num_entries_per_input_and_label>
````
with `<output_dir>` and `<task_file>` as before. `<input_file>` refers to the file containing raw input texts and `<input_file_type>` specifies its type, which should be one of
- ``plain``: for a plain text file with one input text per line
- ``jsonl``: for a dataset file generated by DINO in a previous step
and `<num_entries_per_input_and_label>` is the number of examples to generate per label and input text.
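For instance, assuming the task name in the first step was ``sts-x1`` (so that its dataset was written to ``./dino-sts-x1/sts-x1-dataset.jsonl``), the following would generate two x2 sentences per label for every x1:
````
python3 dino.py \
 --output_dir ./dino-sts-x2 \
 --task_file task_specs/sts-x2.json \
 --input_file ./dino-sts-x1/sts-x1-dataset.jsonl \
 --input_file_type jsonl \
 --num_entries_per_input_and_label 2
````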
## 📋 Task Specs
🚨 *Before you write custom task specifications, please note that this is still a very early release and we have not tested DINO on other tasks than semantic textual similarity yet. Please let us know if you see something strange.* 🚨
To generate a dataset for a task, you need to provide a *task specification* file that contains (among other things) the instructions given to the pretrained language model. A task specification is a single JSON object that looks like this:
```
{
    "task_name": "<task_name>",
    "labels": {
        "<label_1>": {
            "instruction": "<instruction_1>",
            "counter_labels": [<counter_labels_1>]
        },
        ...,
        "<label_n>": {
            "instruction": "<instruction_n>",
            "counter_labels": [<counter_labels_n>]
        }
    }
}
```
Here, `<task_name>` is the name of the task and `<label_1>`, ..., `<label_n>` are the task's labels. For each label `<label_i>`, `<instruction_i>` is the instruction provided to the language model for generating examples with label `<label_i>` (see [Writing Instructions](#writing-instructions)). You can additionally specify a list of counter labels `<counter_labels_i>` for each label. This tells the model to generate outputs that are not only likely given the current label, but also unlikely given all counter labels (see [the paper](https://arxiv.org/abs/2104.07540) for details).
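As a purely illustrative example (not one of the shipped specs), a filled-in task specification for binary restaurant review sentiment could look as follows:
```
{
    "task_name": "restaurant-reviews",
    "labels": {
        "positive": {
            "instruction": "Task: Write a review for a really great restaurant.\nReview: \"",
            "counter_labels": ["negative"]
        },
        "negative": {
            "instruction": "Task: Write a review for a really terrible restaurant.\nReview: \"",
            "counter_labels": ["positive"]
        }
    }
}
```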
#### Examples
You can find several example task specifications in ``task_specs/``, including:
- ``sts-x2.json`` is a task specification for generating a semantic textual similarity dataset if *a set of raw input texts is already given*.
- ``sts-x1.json`` is a task specification for generating a set of raw input texts. This set can then be used in a subsequent step to generate a full STS dataset using ``sts-x2.json``.
#### Writing Instructions
When writing instructions for a new task, you should consider the following things:
- Always end your instructions with an (opening) **quotation mark** (`"`). This is required because it allows us to interpret the next quotation mark generated by the language model as a signal that it is done generating an example.
- For good results, keep the instructions as **short and simple** as possible as this makes it easier for a pretrained language model to understand them.
- If you are writing instructions for a text **pair** classification task, make sure that each instruction contains the placeholder ``<X1>`` exactly once. At this position, the provided raw input sentences are inserted during generation.
An example of an instruction that prompts the model to generate a positive review for a restaurant would be:
````
Task: Write a review for a really great restaurant.
Review: "
````
An example of an instruction that prompts the model to generate a sentence that has the same meaning as another given sentence would be:
````
Task: Write two sentences that mean the same thing.
Sentence 1: ""
Sentence 2: "
````
## 🦕 Generated DINOs
This section lists datasets that we have generated using DINO.
| Dataset | Description | Link |
| :------ | :---------- | :--- |
| STS‑🦕‑x2 (pp) | A postprocessed version of STS-🦕-x2. Postprocessing includes label smoothing, data augmentation and selecting at most two x2's for each (x1, y); it is performed using [scripts/sts/postprocess_dataset.py](scripts/sts/postprocess_dataset.py). | [📥 Download](https://www.cis.uni-muenchen.de/~schickt/dino/sts-dino-x2-postprocessed.jsonl) |
| STS‑🦕‑x1x2 (pp) | A postprocessed version of STS-🦕-x1x2. Postprocessing includes label smoothing, data augmentation and selecting at most two x2's for each (x1, y); it is performed using [scripts/sts/postprocess_dataset.py](scripts/sts/postprocess_dataset.py). | [📥 Download](https://www.cis.uni-muenchen.de/~schickt/dino/sts-dino-x1x2-postprocessed.jsonl) |
| STS‑🦕‑x2 (raw) | A semantic textual similarity dataset generated with DINO, where the first text for each pair (x1, x2) is from the STS benchmark. For almost all use cases, you probably want to use the postprocessed (pp) version of this dataset. | [📥 Download](https://www.cis.uni-muenchen.de/~schickt/dino/sts-dino-x2.jsonl) |
| STS‑🦕‑x1x2 (raw) | A semantic textual similarity dataset generated with DINO, where each pair (x1, x2) is generated from scratch. For almost all use cases, you probably want to use the postprocessed (pp) version of this dataset. | [📥 Download](https://www.cis.uni-muenchen.de/~schickt/dino/sts-dino-x1x2.jsonl) |
## 📕 Citation
If you make use of the code in this repository or of any DINO-based dataset, please cite the following paper:
````
@article{schick2020generating,
title={Generating Datasets with Pretrained Language Models},
author={Timo Schick and Hinrich Schütze},
journal={Computing Research Repository},
volume={arXiv:2104.07540},
url={https://arxiv.org/abs/2104.07540},
year={2021}
}
````
================================================
FILE: dino.py
================================================
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script can be used to generate datasets with DINO (Datasets from Instructions).
"""
import argparse
import json
import os
import re
from datetime import datetime
from typing import Any, Dict
from modeling import GPT2Wrapper, DinoGenerator, PLACEHOLDER_STR
from utils import set_seed, read_inputs, DatasetEntry
def validate_args(args) -> None:
"""Validate the given command line arguments"""
if args.input_file is not None:
assert args.num_entries_per_input_and_label is not None, "If 'input_file' is set, 'num_entries_per_input_and_label' must be set"
assert args.num_entries_per_label is None, "If 'input_file' is set, 'num_entries_per_label' must not be set"
assert args.batch_size is None, "If 'input_file' is set, batch_size must not be set as 'num_entries_per_input_and_label' also " \
"serves as batch size in this case"
else:
assert args.num_entries_per_input_and_label is None, "If 'input_file' is not set, 'num_entries_per_input_and_label' must not be set"
assert args.num_entries_per_label is not None, "If 'input_file' is not set, 'num_entries_per_label' must be set"
assert args.batch_size is not None, "If 'input_file' is not set, 'batch_size' must be set"
def validate_task_spec(task_spec: Dict[str, Any], with_inputs: bool) -> None:
"""Validate the given task specification"""
error_prefix = "Invalid task specification:"
assert 'task_name' in task_spec, f"{error_prefix} missing field 'task_name'"
assert isinstance(task_spec['task_name'], str) and re.match(r"^[A-Za-z0-9\-_.]+$", task_spec['task_name']), \
f"{error_prefix} 'task_name' must be a string consisting only of [A-Za-z0-9\\-_.]"
assert 'labels' in task_spec, f"{error_prefix} missing field 'labels'"
assert isinstance(task_spec['labels'], dict), f"{error_prefix} 'labels' must be a dictionary"
all_labels = task_spec['labels'].keys()
for label, label_dict in task_spec['labels'].items():
assert isinstance(label_dict, dict), f"{error_prefix} label '{label}' is not mapped to a dictionary"
assert not label_dict.keys() - {'instruction', 'counter_labels'}, \
f"{error_prefix} invalid keys for label '{label}', only 'instruction' and 'counter_labels' are allowed"
assert 'instruction' in label_dict.keys(), f"{error_prefix} missing field 'instruction' for label '{label}'"
assert isinstance(label_dict['instruction'], str), f"{error_prefix} 'instruction' not a string for label '{label}'"
assert label_dict['instruction'][-1] == '"', \
f"{error_prefix} each instruction should end with an opening quotation mark (\") so that the next quotation mark generated " \
f"by the model can be interpreted as a signal that it is done."
if with_inputs:
assert label_dict['instruction'].count(PLACEHOLDER_STR) == 1, \
f"{error_prefix} The instruction for label '{label}' does not contain exactly one placeholder token ({PLACEHOLDER_STR}). " \
f"If an input file is specified, each instruction must contain this placeholder to indicate where the input should be " \
f"inserted."
else:
assert label_dict['instruction'].count(PLACEHOLDER_STR) == 0, \
f"{error_prefix} The instruction for label '{label}' contains a placeholder token ({PLACEHOLDER_STR}). If no input file " \
f"is specified, instructions must not contain this placeholder as there is no input to replace it with."
if 'counter_labels' in label_dict.keys():
assert isinstance(label_dict['counter_labels'], list), f"{error_prefix} 'counter_labels' not a list for label '{label}'"
for counter_label in label_dict['counter_labels']:
assert counter_label in all_labels, f"{error_prefix} counter_label '{counter_label}' for label '{label}' is not a label"
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--output_dir", type=str, required=True,
help="The output directory to which the generated dataset is saved")
parser.add_argument("--task_file", type=str, required=True,
help="A json file providing the instructions and other information required for dataset generation. "
"See the 'task_specs' directory for examples and 'README.md' for more details on how to create this file.")
# Text generation and sampling parameters
parser.add_argument("--model_name", type=str, default="gpt2-xl",
help="The pretrained model to use for dataset generation. Currently, only variants of GPT2 are supported.")
parser.add_argument("--openai_api_key", type=str, default=None)
parser.add_argument("--max_output_length", type=int, default=40,
help="The maximum output length for each generated text.")
parser.add_argument("--decay_constant", type=float, default=100,
help="The decay constant for self-debiasing")
parser.add_argument("--top_p", type=float, default=0.9,
help="p value for top-p sampling (set to 0 to perform no top-p sampling)")
parser.add_argument("--top_k", type=int, default=5,
help="k value for top-k sampling (set to 0 to perform no top-k sampling)")
# Dataset parameters
parser.add_argument("--input_file", type=str,
help="An optional input file containing raw texts. This is required for generating text pair datasets.")
parser.add_argument("--input_file_type", choices=["plain", "jsonl", "stsb"], default="plain",
help="The type of the input file. Choices are 'plain' (a raw text file with one input per line), 'jsonl' (a jsonl "
"file as produced by DINO) and 'stsb' (a TSV file in the STS Benchmark format)")
parser.add_argument("--num_entries_per_input_and_label", type=int, default=None,
help="The number of entries to generate for each pair of input text and label (only if --input_file is set)")
parser.add_argument("--num_entries_per_label", type=int, default=None,
help="The number of entries to generate for each label (only if --input_file is not set)")
parser.add_argument("--batch_size", type=int, default=None,
help="The batch size for generation (only if --input_file is not set)")
parser.add_argument("--remove_duplicates", action='store_true',
help="Whether duplicates should be removed from the generated dataset")
parser.add_argument("--remove_identical_pairs", action='store_true',
help="Whether text pairs with text_a == text_b should be removed from the dataset (only for text pair datasets)")
parser.add_argument("--keep_outputs_without_eos", action='store_true',
help="If set to true, examples where the language model does not output a quotation mark (which is interpreted as "
"a signal that it has completed its output) are not removed from the dataset.")
parser.add_argument("--allow_newlines_in_outputs", action='store_true',
help="If set to true, model outputs that contain a newline character before the end-of-sequence token (a quotation "
"mark) are not removed from the dataset.")
parser.add_argument("--min_num_words", type=int, default=-1,
help="The minimum number of (whitespace-separated) words for each dataset entry. Entries with fewer words are "
"removed.")
parser.add_argument("--min_num_tokens", type=int, default=-1,
help="The minimum number of tokens for each dataset entry. Entries with fewer tokens are removed.")
# Miscellaneous further parameters
parser.add_argument("--no_cuda", action='store_true')
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
validate_args(args)
set_seed(args.seed)
args.date = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
print(f"Parameters: {args}")
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
with open(args.task_file, 'r', encoding='utf8') as fh:
task_specification = json.load(fh)
validate_task_spec(task_specification, with_inputs=args.input_file is not None)
args_file = os.path.join(args.output_dir, f'{task_specification["task_name"]}-args.json')
with open(args_file, 'w', encoding='utf8') as fh:
fh.write(json.dumps(vars(args), indent=4))
inputs = read_inputs(args.input_file, args.input_file_type) if args.input_file else None
if args.openai_api_key:
print(f"Using OpenAI's GPT3 ({args.model_name}) as generator. The following parameters are ignored: ['decay_constant', 'top_k']")
model = GPT2Wrapper(model_name=args.model_name, use_cuda=not args.no_cuda) if not args.openai_api_key else args.model_name
generator = DinoGenerator(
task_spec=task_specification, model=model, openai_api_key=args.openai_api_key, max_output_length=args.max_output_length,
decay_constant=args.decay_constant, top_p=args.top_p, top_k=args.top_k, remove_duplicates=args.remove_duplicates,
remove_identical_pairs=args.remove_identical_pairs, min_num_words=args.min_num_words, min_num_tokens=args.min_num_tokens,
keep_outputs_without_eos=args.keep_outputs_without_eos, allow_newlines_in_outputs=args.allow_newlines_in_outputs
)
print("Starting dataset generation with DINO...")
outputs = generator.generate_dataset(inputs, num_entries_per_input_and_label=args.num_entries_per_input_and_label,
num_entries_per_label=args.num_entries_per_label, batch_size=args.batch_size)
print(f"Dataset generation complete, dataset contains {len(outputs)} entries")
dataset_path = os.path.join(args.output_dir, f'{task_specification["task_name"]}-dataset.jsonl')
DatasetEntry.save_list(outputs, dataset_path)
print(f"Done saving dataset to file '{dataset_path}'")
================================================
FILE: generation.py
================================================
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains various classes and functions required for text generation with self-debiasing.
"""
from typing import List, Optional, Union, Tuple
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, LogitsProcessorList, LogitsProcessor, PreTrainedTokenizer
from transformers.generation_utils import GenerationMixin, SampleOutput, SampleEncoderDecoderOutput, SampleDecoderOnlyOutput
class SelfDebiasingLogitsProcessor(LogitsProcessor):
"""This class represents a logits processor that applies self-debiasing."""
def __init__(self, num_debiasing_prefixes: int, decay_constant: float = 100, epsilon: float = 0.01, debug: bool = False,
tokenizer: Optional[PreTrainedTokenizer] = None):
"""
:param num_debiasing_prefixes: the number of debiasing prefixes used
:param decay_constant: the decay constant (lambda in the paper)
:param epsilon: the minimum factor by which each probability is multiplied
:param debug: whether to print additional debugging output
:param tokenizer: a tokenizer used to print debugging output
"""
assert not debug or tokenizer, "If debug=True, a tokenizer must be passed to SelfDebiasingLogitsProcessor()"
self.num_debiasing_prefixes = num_debiasing_prefixes
self.decay_constant = decay_constant
self.epsilon = epsilon
self.debug = debug
self.tokenizer = tokenizer
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
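        # scores stacks the regular inputs first, followed by one block of batch_size rows per debiasing prefix,
        # so its first dimension is batch_size * (1 + num_debiasing_prefixes)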
batch_size = scores.shape[0] // (1 + self.num_debiasing_prefixes)
regular_sentence_indices = range(batch_size)
for regular_sentence_idx in regular_sentence_indices:
bias_indices = self._get_bias_indices(regular_sentence_idx, batch_size)
if bias_indices:
self._debias_scores(scores, regular_sentence_idx, bias_indices)
return scores
def _get_bias_indices(self, regular_sentence_idx: int, batch_size: int) -> List[int]:
"""Returns the indices of all self-debiasing inputs for a regular input"""
return [regular_sentence_idx + (prefix_idx + 1) * batch_size for prefix_idx in range(self.num_debiasing_prefixes)]
def _debias_scores(self, scores: torch.FloatTensor, regular_sent_idx: int, bias_indices: List[int]) -> None:
"""Partially debiases the given scores considering a single sentence and the corresponding self-debiasing inputs"""
logits_biased = [scores[bias_idx] for bias_idx in bias_indices]
mask = self._generate_decay_mask(scores[regular_sent_idx], logits_biased)
scores[regular_sent_idx] = torch.log(self._apply_decay_mask(scores[regular_sent_idx], mask))
for debiasing_sent_idx in bias_indices:
scores[debiasing_sent_idx] = scores[regular_sent_idx]
def _apply_decay_mask(self, logits: torch.Tensor, decay_mask: torch.Tensor) -> torch.Tensor:
"""Applies exponential decay to a tensor of logits"""
probabilities = logits.softmax(dim=-1)
decay_mask = torch.exp(- decay_mask * self.decay_constant)
decay_mask = torch.max(decay_mask, torch.tensor([self.epsilon], device=decay_mask.device))
probabilities = probabilities * decay_mask
probabilities = probabilities / probabilities.sum(dim=-1)
return probabilities
def _generate_decay_mask(self, logits_regular: torch.FloatTensor, logits_biased_list: List[torch.FloatTensor]) -> torch.Tensor:
"""Computes the alpha values (see paper) for each token and stores them in a mask tensor"""
p_regular = logits_regular.softmax(dim=-1)
p_biased = None
for logits_biased in logits_biased_list:
if p_biased is None:
p_biased = logits_biased.softmax(dim=-1)
else:
p_biased = torch.max(p_biased, logits_biased.softmax(dim=-1))
if self.debug:
print(f'== Before Debiasing ==\n'
f'Top 5 predictions (regular): {self._get_most_likely_tokens(p_regular, k=5)}\n'
f'Top 5 predictions (biased): {self._get_most_likely_tokens(p_biased, k=5)}')
mask = torch.max(p_biased - p_regular, torch.tensor([0.], device=p_regular.device))
if self.debug:
p_regular = self._apply_decay_mask(logits_regular, mask)
print(f'== After Debiasing ==\n'
f'Top 5 predictions (regular): {self._get_most_likely_tokens(p_regular, k=5)}')
return mask
def _get_most_likely_tokens(self, probabilities_tensor: torch.Tensor, k: int) -> List[Tuple[str, float]]:
"""Returns the most likely tokens according to a tensor of probabilities"""
assert len(probabilities_tensor.shape) == 1
values, indices = torch.topk(probabilities_tensor, k=k, dim=-1)
tokens = self.tokenizer.convert_ids_to_tokens(indices)
return list(zip(tokens, [pv.item() for pv in values]))
class SelfDebiasingGPT2LMHeadModel(GPT2LMHeadModel, GenerationMixin):
"""
This class represents a regular GPT2LMHeadModel that additionally has the capacity to perform self-debiasing. For self-debiasing, the
init_logits_processor function must be called. Otherwise, this model just performs regular language modeling.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.logits_processor = None # type: Optional[SelfDebiasingLogitsProcessor]
def init_logits_processor(self, *args, **kwargs):
"""Initialize the logits processor. For a list of arguments, see the self-debiasing logit processor's init function."""
self.logits_processor = SelfDebiasingLogitsProcessor(*args, **kwargs)
def _get_logits_processor(self, *args, **kwargs) -> LogitsProcessorList:
logits_processor = super()._get_logits_processor(*args, **kwargs)
if self.logits_processor is not None:
logits_processor.append(self.logits_processor)
return logits_processor
def beam_sample(self, *args, **kwargs):
raise NotImplementedError("Beam sampling is not implemented for self-debiasing models")
def sample(self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None,
logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, **model_kwargs) -> Union[
SampleOutput, torch.LongTensor]:
"""
        This is a verbatim copy of the original implementation by huggingface, with a single modification to ensure that a text and all
        corresponding self-debiasing inputs always choose the same token to generate next. This modification is enclosed by the comments
        "BEGIN MODIFICATIONS" and "END MODIFICATIONS".
"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# init sequence length tensors
sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
input_ids, max_length
)
# auto-regressive generation
while cur_len < max_length:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# sample
probs = F.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# =========================
# BEGIN MODIFICATIONS
# the following modification to the sample method is necessary to ensure that each debiasing sentence is continued in the same
# way as the original sentence
if self.logits_processor is not None:
batch_size = next_tokens.shape[0] // (1 + self.logits_processor.num_debiasing_prefixes)
regular_sentence_indices = range(batch_size)
for regular_sentence_idx in regular_sentence_indices:
debiasing_sentence_indices = self.logits_processor._get_bias_indices(regular_sentence_idx, batch_size)
for debiasing_sentence_idx in debiasing_sentence_indices:
next_tokens[debiasing_sentence_idx] = next_tokens[regular_sentence_idx]
# END MODIFICATIONS
# =========================
            # for sequences that are already finished, replace the sampled token with the pad token
if eos_token_id is not None:
assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences)
# add token and increase length by one
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
cur_len = cur_len + 1
# update sequence length
if eos_token_id is not None:
sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
)
            # stop when there is an end-of-sequence token in each sentence, or if we exceed the maximum length
if unfinished_sequences.max() == 0:
break
# update model kwargs
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if return_dict_in_generate:
if self.config.is_encoder_decoder:
return SampleEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return SampleDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return input_ids
================================================
FILE: modeling.py
================================================
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the core classes used by DINO.
"""
import math
from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Any, Union
import openai
import torch
from tqdm import tqdm
from transformers import GPT2Tokenizer, PreTrainedTokenizer, PreTrainedModel
from generation import SelfDebiasingGPT2LMHeadModel
from utils import DatasetEntry
PLACEHOLDER_STR = ""
class DinoGenerator:
"""
This class represents a generative language model which can be used to generate datasets from instructions.
"""
    def __init__(self, task_spec: Dict[str, Any], model: Optional[Union[str, 'ModelWrapper']] = None, openai_api_key: Optional[str] = None,
max_output_length: int = 40, decay_constant: float = 100, top_p: float = 0.9, top_k: int = 5,
remove_duplicates: bool = True, remove_identical_pairs: bool = False, min_num_words: int = -1, min_num_tokens: int = -1,
keep_outputs_without_eos: bool = False, allow_newlines_in_outputs: bool = False):
"""
:param task_spec: the task specification
:param model: a wrapper around the underlying language model.
If GPT-3 is used, this should instead be the name of the GPT-3 model (e.g., "davinci")
:param openai_api_key: an optional API key for GPT-3. If given, GPT-3 is used as a language model
:param max_output_length: the maximum output length for each generated text
:param decay_constant: the decay constant for self-debiasing
:param top_p: p value for top-p sampling (set to 0 to perform no top-p sampling)
:param top_k: k value for top-k sampling (set to 0 to perform no top-k sampling)
:param remove_duplicates: whether duplicates should be removed from the generated dataset
:param remove_identical_pairs: whether text pairs with identical texts should be removed (only for text pair datasets)
:param min_num_words: the minimum number of (whitespace-separated) words for each dataset entry
:param min_num_tokens: the minimum number of tokens for each dataset entry
:param keep_outputs_without_eos: if set to true, examples where the language model does not output a quotation mark (which is
interpreted as a signal that it has completed its output) are not removed from the dataset.
:param allow_newlines_in_outputs: if set to true, model outputs that contain a newline character before the end-of-sequence token
(a quotation mark) are not removed from the dataset
"""
self.model = model
self.openai_api_key = openai_api_key
self.max_output_length = max_output_length
self.decay_constant = decay_constant
self.top_p = top_p
self.top_k = top_k
self.remove_duplicates = remove_duplicates
self.remove_identical_pairs = remove_identical_pairs
self.min_num_words = min_num_words
self.min_num_tokens = min_num_tokens
self.keep_outputs_without_eos = keep_outputs_without_eos
self.allow_newlines_in_outputs = allow_newlines_in_outputs
self.labels = list(task_spec['labels'].keys())
self.instructions = {label: task_spec['labels'][label]['instruction'] for label in self.labels}
self.counter_labels = {label: task_spec['labels'][label].get('counter_labels', []) for label in self.labels}
def generate_dataset(self, input_texts: Optional[List[str]], num_entries_per_input_and_label: Optional[int] = None,
num_entries_per_label: Optional[int] = None, batch_size: Optional[int] = None) -> List[DatasetEntry]:
"""
Generate a new dataset.
:param input_texts: an optional list of raw texts; this is required for generating text pair datasets
:param num_entries_per_input_and_label: the number of entries to generate for each pair of input text and label
:param num_entries_per_label: the number of entries to generate for each label
:param batch_size: the number of entries to generate simultaneously
:return: the generated dataset
"""
generate_with_inputs = input_texts is not None
if not generate_with_inputs:
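            # with no real inputs, iterate over ceil(num_entries_per_label / batch_size) dummy ids so that each
            # iteration generates one batch of batch_size entries per label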
input_texts = list(range(math.ceil(num_entries_per_label / batch_size)))
num_entries_per_input_and_label = batch_size
input_iterator = tqdm(input_texts, desc="Dataset Entries")
dataset = []
for input_text_or_id in input_iterator:
for label in self.labels:
dataset += self._generate_dataset_entries(input_text_or_id, label=label, num_entries=num_entries_per_input_and_label,
generate_with_inputs=generate_with_inputs)
dataset = self._postprocess_dataset(dataset, generate_with_inputs)
return dataset
def _generate_dataset_entries(self, input_text_or_id: Union[str, int], label: str, num_entries: int,
generate_with_inputs: bool) -> List[DatasetEntry]:
instruction = self._build_instruction(label, input_text_or_id, generate_with_inputs)
if self.openai_api_key is not None:
try:
model_responses = [openai.Completion.create(
engine=self.model, prompt=instruction, max_tokens=self.max_output_length, top_p=self.top_p, stop=['"']
) for _ in range(num_entries)]
model_outputs = [model_response["choices"][0]["text"] for model_response in model_responses]
except openai.error.RateLimitError as e:
print(e)
return []
else:
counter_instructions = [
self._build_instruction(other_label, input_text_or_id, generate_with_inputs) for other_label in self.counter_labels[label]
]
model_outputs = self.model.generate_self_debiasing(
input_text=instruction, debiasing_texts=counter_instructions, num_samples=num_entries, decay_constant=self.decay_constant,
do_sample=True, min_length=self.max_output_length, max_length=self.max_output_length, top_k=self.top_k, top_p=self.top_p
)
model_outputs = [
self._process_output(input_text=input_text_or_id, output_text=output, label=label, generate_with_inputs=generate_with_inputs)
for output in model_outputs
]
model_outputs = [output for output in model_outputs if output is not None]
return model_outputs
def _build_instruction(self, label: str, text: str, generate_with_inputs: bool) -> str:
instruction_template = self.instructions[label]
if generate_with_inputs:
assert instruction_template.count(PLACEHOLDER_STR) == 1, \
f"An input text was provided, but the instruction for label '{label}' does not contain exactly one placeholder"
return instruction_template.replace(PLACEHOLDER_STR, text)
else:
assert instruction_template.count(PLACEHOLDER_STR) == 0, \
f"No input text was provided, but the instruction for label '{label}' contains a placeholder"
return instruction_template
def _process_output(self, input_text: Union[str, int], output_text: str, label: str, generate_with_inputs: bool) \
-> Optional[DatasetEntry]:
output_text = output_text.split('"')[0] if '"' in output_text else (output_text if self.keep_outputs_without_eos else None)
if output_text and ('\n' not in output_text or self.allow_newlines_in_outputs):
text_a = input_text if generate_with_inputs else output_text
text_b = output_text if generate_with_inputs else None
return DatasetEntry(text_a=text_a, text_b=text_b, label=label)
return None
def _postprocess_dataset(self, dataset: List[DatasetEntry], generate_with_inputs: bool) -> List[DatasetEntry]:
if self.remove_duplicates:
dataset = list(set(dataset))
if self.min_num_words > 0:
if generate_with_inputs:
dataset = [entry for entry in dataset if len(entry.text_b.split()) >= self.min_num_words]
else:
dataset = [entry for entry in dataset if len(entry.text_a.split()) >= self.min_num_words]
if self.min_num_tokens > 0:
if generate_with_inputs:
dataset = [entry for entry in dataset if len(self.model._tokenizer.tokenize(entry.text_b)) >= self.min_num_tokens]
else:
dataset = [entry for entry in dataset if len(self.model._tokenizer.tokenize(entry.text_a)) >= self.min_num_tokens]
if generate_with_inputs and self.remove_identical_pairs:
dataset = [entry for entry in dataset if entry.text_a != entry.text_b]
return dataset
class ModelWrapper(ABC):
"""
This class represents a wrapper for a pretrained language model that provides high-level functions for the generation of texts with
the self-debiasing method described in https://arxiv.org/abs/2103.00453.
"""
def __init__(self, use_cuda: bool = True):
"""
:param use_cuda: whether to use CUDA
"""
self._device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
self._tokenizer = None # type: Optional[PreTrainedTokenizer]
self._model = None # type: Optional[PreTrainedModel]
def query_model(self, input_text: str) -> torch.FloatTensor:
"""For a given input text, returns the probability distribution over possible next tokens."""
return self.query_model_batch([input_text])[0]
@abstractmethod
def query_model_batch(self, input_texts: List[str]) -> torch.FloatTensor:
"""For a batch of input texts, returns the probability distribution over possible next tokens."""
pass
@abstractmethod
def generate(self, input_text: str, **kwargs) -> str:
"""Generates a continuation for a given input text."""
pass
@abstractmethod
def generate_self_debiasing(self, input_text: str, debiasing_texts: List[str], num_samples: int = 1, decay_constant: float = 100,
epsilon: float = 0.01, debug: bool = False, **kwargs) -> List[str]:
"""
        Generates continuations for the given input text with self-debiasing.
        :param input_text: the input text to generate continuations for
        :param debiasing_texts: the debiasing prefixes to be used
        :param num_samples: the number of continuations to generate per input
:param decay_constant: the decay constant (lambda in the paper)
:param epsilon: the minimum factor by which each probability is multiplied
:param debug: whether to print additional debugging output
:param kwargs: further arguments are passed on to the original generate function
:return: the list of generated continuations
"""
pass
class GPT2Wrapper(ModelWrapper):
def __init__(self, model_name: str = "gpt2-xl", use_cuda: bool = True):
"""
:param model_name: the name of the pretrained GPT2 model (default: "gpt2-xl")
:param use_cuda: whether to use CUDA
"""
super().__init__(use_cuda=use_cuda)
self._tokenizer = GPT2Tokenizer.from_pretrained(model_name)
self._model = SelfDebiasingGPT2LMHeadModel.from_pretrained(model_name) # type: SelfDebiasingGPT2LMHeadModel
if use_cuda:
self._model.parallelize()
self._tokenizer.pad_token = self._tokenizer.eos_token
self._model.config.pad_token_id = self._tokenizer.eos_token_id
def query_model_batch(self, input_texts: List[str]):
inputs = self._tokenizer.batch_encode_plus(input_texts, padding=True, max_length=512, return_tensors='pt')
inputs = {key: val.to(self._device) for key, val in inputs.items()}
output_indices = inputs['attention_mask'].sum(dim=1) - 1
output = self._model(**inputs)['logits']
return torch.stack([output[example_idx, last_word_idx, :] for example_idx, last_word_idx in enumerate(output_indices)])
def generate(self, input_text: str, **kwargs):
input_ids = self._tokenizer.encode(input_text, return_tensors='pt').to(self._device)
output_ids = self._model.generate(input_ids, **kwargs)[0]
return self._tokenizer.decode(output_ids)
def generate_self_debiasing(self, input_text: str, debiasing_texts: List[str], num_samples: int = 1, decay_constant: float = 100,
epsilon: float = 0.01, debug: bool = False, min_length: int = None, max_length: int = None,
**kwargs) -> List[str]:
self._model.init_logits_processor(num_debiasing_prefixes=len(debiasing_texts), decay_constant=decay_constant, epsilon=epsilon,
debug=debug, tokenizer=self._tokenizer)
inputs = [input_text] * num_samples
for debiasing_text in debiasing_texts:
inputs += [debiasing_text] * num_samples
inputs = self._tokenizer.batch_encode_plus(inputs, padding=True, return_tensors='pt')
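        # emulate left padding: flip the (right-padded) attention mask and roll each sequence right by its
        # number of padding tokens so that all prompts end at the last position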
inputs['attention_mask'] = torch.flip(inputs['attention_mask'], dims=[1])
shifts = inputs['attention_mask'].shape[-1] - inputs['attention_mask'].sum(dim=-1)
for batch_idx in range(inputs['input_ids'].shape[0]):
inputs['input_ids'][batch_idx] = inputs['input_ids'][batch_idx].roll(shifts[batch_idx].item())
inputs = {k: v.to(self._device) for k, v in inputs.items()}
input_length = inputs['input_ids'].shape[1]
if min_length is not None:
min_length = min_length + input_length
if max_length is not None:
            max_length = min(self._model.config.max_position_embeddings, max_length + input_length)
output_ids = self._model.generate(**inputs, min_length=min_length, max_length=max_length, **kwargs)
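        # only the first batch_size sequences belong to the regular input; drop the self-debiasing copies and
        # strip the prompt tokens so that only the generated continuations remain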
batch_size = output_ids.shape[0] // (1 + len(debiasing_texts))
output_ids = output_ids[:batch_size, inputs['input_ids'].shape[1]:]
return self._tokenizer.batch_decode(output_ids)
================================================
FILE: requirements.txt
================================================
-f https://download.pytorch.org/whl/torch_stable.html
numpy==1.19
torch==1.5.0
torchvision==0.6.0
transformers==4.2.1
tqdm==4.49.0
scikit-learn==0.24.1
datasets==1.6.0
openai==0.6.3
scipy==1.6.2
================================================
FILE: scripts/imdb/__init__.py
================================================
================================================
FILE: scripts/imdb/run_supervised.py
================================================
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script trains a regular supervised model on a DINO-generated dataset and evaluates it on the IMDb dataset.
"""
import argparse
import os
from typing import Dict
import torch
import torch.utils.data
from sklearn.model_selection import train_test_split
from datasets import load_dataset, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, EvaluationStrategy
from utils import DatasetEntry
class IMDbDataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
item['labels'] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.labels)
def compute_metrics(eval_pred):
predictions, labels = eval_pred
predictions = predictions.argmax(axis=-1)
return metric.compute(predictions=predictions, references=labels)
def load_datasets(args) -> Dict[str, IMDbDataset]:
dino_ds = DatasetEntry.read_list(args.input_file)
    imdb_ds = load_dataset("imdb")['test']
imdb500_ds = DatasetEntry.read_list(args.imdb_500_file)
train_texts, train_labels = [x.text_b for x in dino_ds], [int(x.label) for x in dino_ds]
    test_texts, test_labels = [x.replace('<br />', '\n') for x in imdb_ds['text']], imdb_ds['label']
imdb500_texts, imdb500_labels = [x.text_a for x in imdb500_ds], [int(x.label) for x in imdb500_ds]
train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.1, random_state=42)
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)
imdb500_encodings = tokenizer(imdb500_texts, truncation=True, padding=True)
return {
'train': IMDbDataset(train_encodings, train_labels),
'val': IMDbDataset(val_encodings, val_labels),
'test': IMDbDataset(test_encodings, test_labels),
'imdb500': IMDbDataset(imdb500_encodings, imdb500_labels)
}
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", type=str, required=True,
help="Path to an output directory were the finetuned model and results are saved")
parser.add_argument("--input_file", type=str, required=True,
help="Path to a jsonl file generated with DINO containing the training dataset")
parser.add_argument("--imdb_500_file", type=str, required=True,
help="Path to the IMDb-500 dataset for evaluation")
parser.add_argument("--model", type=str, default="roberta-base",
help="Name of the pretrained model to use for finetuning")
parser.add_argument("--per_device_train_batch_size", type=int, default=8,
help="The training batch size per GPU")
parser.add_argument("--per_device_eval_batch_size", type=int, default=32,
help="The eval batch size per GPU")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
help="The number of gradient accumulation steps to perform")
parser.add_argument("--learning_rate", type=float, default=1e-5,
help="The maximum learning rate")
parser.add_argument("--warmup_steps", type=int, default=100,
help="The number of initial warmup steps")
parser.add_argument("--max_steps", type=int, default=-1,
help="The maximum number of training steps")
parser.add_argument("--num_train_epochs", type=int, default=1,
help="The maximum number of training epochs")
parser.add_argument("--seeds", type=int, nargs='+', default=[42, 100, 123],
help="The seeds to use. If multiple are given, the entire finetuning process is repeated multiple times.")
args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = AutoModelForSequenceClassification.from_pretrained(args.model)
metric = load_metric("accuracy")
datasets = load_datasets(args)
for seed in args.seeds:
output_dir = os.path.join(args.output_dir, str(seed))
training_args = TrainingArguments(
output_dir=output_dir, num_train_epochs=args.num_train_epochs, max_steps=args.max_steps,
per_device_train_batch_size=args.per_device_train_batch_size, per_device_eval_batch_size=args.per_device_eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, warmup_steps=args.warmup_steps,
weight_decay=0.01, logging_dir='./logs', logging_steps=100, evaluation_strategy=EvaluationStrategy.STEPS,
load_best_model_at_end=True, metric_for_best_model="accuracy", seed=seed
)
trainer = Trainer(
model=model, args=training_args, train_dataset=datasets['train'], eval_dataset=datasets['val'], compute_metrics=compute_metrics
)
trainer.train()
with open(os.path.join(args.output_dir, f'results-{seed}.txt'), 'w', encoding='utf8') as fh:
print("Evaluating on IMDb500")
result_imdb500 = trainer.evaluate(eval_dataset=datasets['imdb500'])
print(result_imdb500)
fh.write("=== IMDb500 ===\n")
fh.write(str(result_imdb500) + '\n\n')
print("Evaluating on IMDb test")
result_test = trainer.evaluate(eval_dataset=datasets['test'])
print(result_test)
fh.write("=== IMDb test ===\n")
fh.write(str(result_test) + '\n')
================================================
FILE: scripts/imdb/run_unsupervised.py
================================================
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script can be used to evaluate unsupervised models on the IMDb dataset using prompts.
"""
import argparse
import math
import openai
import torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForCausalLM
from utils import DatasetEntry
class CausalLMWrapper:
"""A wrapper for a causal language model (like GPT-2)"""
def __init__(self, model_name: str, use_cuda: bool = True):
self._device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
self._tokenizer = AutoTokenizer.from_pretrained(model_name)
self._model = AutoModelForCausalLM.from_pretrained(model_name)
        # Only parallelize across GPUs when CUDA is actually available.
        if self._device == "cuda":
            self._model.parallelize()
def get_token_probabilities(self, input_text: str, prompt: str) -> torch.Tensor:
input_text = input_text + prompt
inputs = self._tokenizer.batch_encode_plus([input_text], truncation=True, return_tensors='pt')
inputs = {key: val.to(self._device) for key, val in inputs.items()}
output = self._model(**inputs)['logits']
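        # The logits at the last position give the model's distribution over the next token,
        # from which the caller reads the scores of the verbalizers "good" and "bad".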
return output[:, -1, :]
class MaskedLMWrapper:
"""A wrapper for a masked language model (like BERT)"""
def __init__(self, model_name: str, use_cuda: bool = True):
self._device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
self._tokenizer = AutoTokenizer.from_pretrained(model_name)
self._model = AutoModelForMaskedLM.from_pretrained(model_name).to(self._device)
def get_token_probabilities(self, input_text: str, prompt: str) -> torch.Tensor:
text_ids = self._tokenizer.encode(input_text, truncation=True, add_special_tokens=False)
prompt_ids = self._tokenizer.encode(prompt, truncation=False, add_special_tokens=False)
max_len = self._tokenizer.model_max_length
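        # Truncate the input text so that text, prompt and special tokens together fit the model's maximum length.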
max_len_for_text_ids = max_len - len(prompt_ids) - self._tokenizer.num_special_tokens_to_add(False)
text_ids = text_ids[:max_len_for_text_ids]
input_ids = text_ids + prompt_ids
input_ids = torch.tensor([self._tokenizer.build_inputs_with_special_tokens(input_ids)], device=self._device)
assert sum(1 for id_ in input_ids[0] if id_ == self._tokenizer.mask_token_id) == 1, \
f"Input text must contain exactly one mask token ('{self._tokenizer.mask_token}'). Got '{input_text}'."
scores = self._model(input_ids)['logits']
mask_positions = (input_ids == self._tokenizer.mask_token_id)
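        # Return the logits at the single mask position (shape: [1, vocab_size]).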
return scores[mask_positions]
class GPT3Wrapper:
"""A wrapper around OpenAI's GPT-3 API"""
def __init__(self, engine: str):
self.engine = engine
def get_scores(self, prompt: str):
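        # Request a single completion token together with its top-100 log probabilities;
        # the verbalizers " good" and " bad" are then looked up in that distribution.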
response = openai.Completion.create(engine=self.engine, prompt=prompt, max_tokens=1, logprobs=100)
top_logprobs = response['choices'][0]['logprobs']['top_logprobs'][0]
        positive_score = top_logprobs.get(" good", -math.inf)
        negative_score = top_logprobs.get(" bad", -math.inf)
return positive_score, negative_score
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", choices=["causal", "masked", "gpt3"], required=True,
help="The type of the model to evaluate. One of 'causal' (for causal language models like GPT-2), 'masked' (for "
"masked language models like BERT), and 'gpt3' (for GPT-3 models accessed via OpenAI's API)")
parser.add_argument("--model_name", type=str, required=True,
help="The name of the pretrained model to use (e.g., 'roberta-large')")
parser.add_argument("--openai_api_key", type=str,
help="An optional key for OpenAI's API (only if --model_type is gpt3)")
parser.add_argument("--test_file", type=str,
help="An optional path to a jsonl file of dataset entries. If not given, the entire IMDb dataset is used.")
parser.add_argument("--no_cuda", action='store_true',
help="If set to true, inference is done on CPU only")
args = parser.parse_args()
if args.test_file:
print(f"Evaluating on entries from '{args.test_file}'")
dataset = DatasetEntry.read_list(args.test_file)
print(f"Done loading {len(dataset)} examples from '{args.test_file}'")
else:
print("Evaluating on the entire IMDb test set")
dataset = load_dataset('imdb')['test']
dataset = [DatasetEntry(text_a=text, text_b=None, label=label) for text, label in zip(dataset['text'], dataset['label'])]
print(f"Done loading {len(dataset)} examples")
if args.openai_api_key:
openai.api_key = args.openai_api_key
predictions, labels = [], []
if args.model_type == "causal":
model = CausalLMWrapper(args.model_name, use_cuda=not args.no_cuda)
prompt = "\nQuestion: Is this movie good or bad?\nAnswer: It is"
elif args.model_type == "masked":
model = MaskedLMWrapper(args.model_name, use_cuda=not args.no_cuda)
prompt = "\nQuestion: Is this movie good or bad?\nAnswer: It is ."
elif args.model_type == "gpt3":
model = GPT3Wrapper(args.model_name)
prompt = "\nQuestion: Is this movie good or bad?\nAnswer: It is"
else:
        raise ValueError(f"Unknown model type: '{args.model_type}'")
dataset_iterator = tqdm(dataset)
for ds_entry in dataset_iterator:
if args.model_type == "gpt3":
instance_prompt = ds_entry.text_a + prompt
positive_score, negative_score = model.get_scores(instance_prompt)
else:
token_probabilities = model.get_token_probabilities(input_text=ds_entry.text_a, prompt=prompt)[0].detach()
positive_score = token_probabilities[model._tokenizer.convert_tokens_to_ids("Ġgood")]
negative_score = token_probabilities[model._tokenizer.convert_tokens_to_ids("Ġbad")]
labels.append(int(ds_entry.label))
predictions.append(1 if positive_score > negative_score else 0)
dataset_iterator.set_description(f"Texts (acc={100 * sum(1 for x, y in zip(labels, predictions) if x == y) / len(labels):5.2f})")
dataset_iterator.refresh()
print(f"Final accuracy: {sum(1 for x, y in zip(labels, predictions) if x == y) / len(labels)} (total: {len(labels)})")
================================================
FILE: scripts/sts/postprocess_dataset.py
================================================
import argparse
import random
from collections import defaultdict
from typing import List
from utils import DatasetEntry
def postprocess_dataset(
dataset: List[DatasetEntry],
remove_identical_pairs: bool = True,
remove_duplicates: bool = True,
add_sampled_pairs: bool = True,
max_num_text_b_for_text_a_and_label: int = 2,
label_smoothing: float = 0.2,
seed: int = 42
) -> List[DatasetEntry]:
"""
Apply various postprocessing steps to a STS dataset.
:param dataset: The dataset to postprocess.
:param remove_identical_pairs: If set to true, we remove all pairs (x1, x2, y) where x1 == x2 as a bi-encoder cannot learn from them.
    :param remove_duplicates: If set to true and there are duplicate pairs (x1, x2, y) and (x1', x2', y') with x1 == x1',
           x2 == x2' and y == y', we only keep one of them.
:param add_sampled_pairs: If set to true, we add pairs of randomly sampled x1's and x2's and similarity 0 to the dataset.
:param max_num_text_b_for_text_a_and_label: We keep at most this many examples for each pair of text_a and similarity label.
:param label_smoothing: The amount of label smoothing to apply.
:param seed: The seed for the random number generator used to shuffle the dataset and for adding sampled pairs.
:return: The postprocessed dataset.
"""
postprocessed_dataset = []
num_text_b_for_text_a_and_label = defaultdict(int)
rng = random.Random(seed)
rng.shuffle(dataset)
if remove_duplicates:
dataset = list(set(dataset))
for example in dataset:
if remove_identical_pairs and example.text_a == example.text_b:
continue
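        # Label smoothing over the three labels {0, 0.5, 1}, whose mean is 0.5: the smoothed label
        # is (1 - eps) * y + eps * 0.5 (note that eps / 3 * 1.5 == eps * 0.5).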
example.label = example.label * (1 - label_smoothing) + (label_smoothing / 3 * 1.5)
if max_num_text_b_for_text_a_and_label > 0:
if num_text_b_for_text_a_and_label[(example.text_a, example.label)] >= max_num_text_b_for_text_a_and_label:
continue
postprocessed_dataset.append(example)
num_text_b_for_text_a_and_label[(example.text_a, example.label)] += 1
if add_sampled_pairs:
sampled_dataset = []
for text_a in set(x.text_a for x in postprocessed_dataset):
for _ in range(max_num_text_b_for_text_a_and_label):
text_b = rng.choice(postprocessed_dataset).text_b
sampled_dataset.append(DatasetEntry(text_a=text_a, text_b=text_b, label=0))
postprocessed_dataset += sampled_dataset
return postprocessed_dataset
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--input_file", type=str, required=True,
help="The input file which contains the STS dataset")
parser.add_argument("--output_file", type=str, required=True,
help="The output file to which the postprocessed STS dataset is saved")
args = parser.parse_args()
ds = DatasetEntry.read_list(args.input_file)
ds_pp = postprocess_dataset(ds)
DatasetEntry.save_list(ds_pp, args.output_file)
================================================
FILE: scripts/sts/run_training.py
================================================
import argparse
import json
import logging
import random
import os
import gzip
import csv
import math
from collections import defaultdict, OrderedDict
from datetime import datetime
from typing import List, Dict
import numpy as np
import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses, models, util, LoggingHandler
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
from utils import DatasetEntry
def download_sts_dataset(sts_dataset_path: str) -> None:
"""Download the STS dataset if it isn't already present."""
if not os.path.exists(sts_dataset_path):
util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)
def set_seed(seed: int) -> None:
"""Set RNG seeds for python's `random` module, numpy and torch."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def build_sentence_transformer(model_name: str) -> SentenceTransformer:
"""Build the Sentence Transformer model."""
try:
word_embedding_model = models.Transformer(model_name)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True,
pooling_mode_cls_token=False, pooling_mode_max_tokens=False)
return SentenceTransformer(modules=[word_embedding_model, pooling_model])
except OSError:
return SentenceTransformer(model_name)
def split_dataset(ds: List[DatasetEntry], dev_size: float = 0.1, seed: int = 42) -> Dict[str, List[DatasetEntry]]:
"""Split a dataset into a train and dev set.
The split is performed such that the distribution of labels is identical for the training and development set.
:param ds: The dataset to split.
:param dev_size: The relative size of the development set, in the range (0,1).
:param seed: The seed used to initialize the random number generator.
:return: A dictionary with keys "train" and "dev", whose values are the corresponding datasets.
"""
train, dev = [], []
rng = random.Random(seed)
ds_grouped_by_label = defaultdict(list)
for x in ds:
ds_grouped_by_label[x.label].append(x)
for label_list in ds_grouped_by_label.values():
rng.shuffle(label_list)
num_dev_examples = int(len(label_list) * dev_size)
train += label_list[num_dev_examples:]
dev += label_list[:num_dev_examples]
return {'train': train, 'dev': dev}
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--input_file", type=str, required=True,
help="The JSONL file that contains the DINO-generated dataset to train on.")
parser.add_argument("--output_dir", type=str, required=True,
help="The output directory for storing the trained model and evaluation results.")
# Model and training parameters
parser.add_argument("--model_name", type=str, default='roberta-base',
help="The pretrained Transformer language model to use.")
parser.add_argument("--train_batch_size", type=int, default=32,
help="The batch size used for training.")
parser.add_argument("--num_epochs", type=int, default=1,
help="The number of epochs to train for.")
parser.add_argument("--seed", type=int, default=42,
help="The seed used to initialize all random number generators.")
# Evaluation parameters
parser.add_argument("--sts_dataset_path", type=str, default="datasets/stsbenchmark.tsv.gz",
help="The path to the STSb dataset. The STSb dataset is downloaded and saved at this path if it does not exist.")
args = parser.parse_args()
logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO, handlers=[LoggingHandler()])
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
input_filename = os.path.basename(args.input_file)
set_seed(args.seed)
args.date = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
print(f"Parameters: {args}")
# We write all arguments to a file for better reproducibility.
args_file = os.path.join(args.output_dir, f'args-{input_filename}.jsonl')
with open(args_file, 'w', encoding='utf8') as fh:
        fh.write(json.dumps(vars(args)))
# If the STSb dataset does not exist, we download it.
download_sts_dataset(args.sts_dataset_path)
model = build_sentence_transformer(args.model_name)
model_save_name = '_'.join([input_filename, args.model_name.replace("/", "-"), args.date.replace("/", "-").replace(" ", "_")])
model_save_path = os.path.join(args.output_dir, model_save_name)
# Load and split the (postprocessed) STS-DINO dataset.
dataset = DatasetEntry.read_list(args.input_file)
dataset = split_dataset(dataset, dev_size=0.1, seed=args.seed)
train_samples = [InputExample(texts=[x.text_a, x.text_b], label=x.label) for x in dataset['train']]
dev_samples = [InputExample(texts=[x.text_a, x.text_b], label=x.label) for x in dataset['dev']]
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=args.train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dino-dev')
# We use 10% of the training data for warm-up.
warmup_steps = math.ceil(len(train_dataloader) * args.num_epochs * 0.1)
# Train the model.
model.fit(train_objectives=[(train_dataloader, train_loss)],
evaluator=evaluator,
epochs=args.num_epochs,
evaluation_steps=100,
warmup_steps=warmup_steps,
output_path=model_save_path)
# Load the trained model and perform evaluation.
if args.num_epochs > 0:
model = SentenceTransformer(model_save_path)
results = OrderedDict()
stsb_samples = []
with gzip.open(args.sts_dataset_path, 'rt', encoding='utf8') as fIn:
reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
for row in reader:
score = float(row['score']) / 5.0 # Normalize score to range 0 ... 1.
inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)
if row['split'] == 'test':
stsb_samples.append(inp_example)
stsb_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(stsb_samples, name='stsb-test')
results['stsb'] = stsb_evaluator(model, output_path='.')
print(results)
with open(os.path.join(args.output_dir, f'{input_filename}-results.txt'), 'w', encoding='utf8') as fh:
for task, result in results.items():
fh.write(f'{task}: {result}\n')
================================================
FILE: task_specs/imdb-movies.json
================================================
{
"task_name": "imdb",
"labels": {
"0": {
"instruction": "Task: Write a review for a bad movie.\nMovie: \"",
"counter_labels": []
},
"1": {
"instruction": "Task: Write a review for a good movie.\nMovie: \"",
"counter_labels": []
}
}
}
================================================
FILE: task_specs/imdb-reviews.json
================================================
{
"task_name": "imdb",
"labels": {
"0": {
"instruction": "Task: Write a review for a bad movie.\nMovie: \"\"\nReview: \"",
"counter_labels": [
"1"
]
},
"1": {
"instruction": "Task: Write a review for a good movie.\nMovie: \"\"\nReview: \"",
"counter_labels": [
"0"
]
}
}
}
================================================
FILE: task_specs/sts-x1.json
================================================
{
"task_name": "sts-x1",
"labels": {
"1": {
"instruction": "Task: Write two sentences that mean the same thing.\nSentence 1: \"",
"counter_labels": []
},
"0.5": {
"instruction": "Task: Write two sentences that are somewhat similar.\nSentence 1: \"",
"counter_labels": []
},
"0": {
"instruction": "Task: Write two sentences that are on completely different topics.\nSentence 1: \"",
"counter_labels": []
}
}
}
================================================
FILE: task_specs/sts-x2.json
================================================
{
"task_name": "sts",
"labels": {
"1": {
"instruction": "Task: Write two sentences that mean the same thing.\nSentence 1: \"\"\nSentence 2: \"",
"counter_labels": []
},
"0.5": {
"instruction": "Task: Write two sentences that are somewhat similar.\nSentence 1: \"\"\nSentence 2: \"",
"counter_labels": [
"1"
]
},
"0": {
"instruction": "Task: Write two sentences that are on completely different topics.\nSentence 1: \"\"\nSentence 2: \"",
"counter_labels": [
"1",
"0.5"
]
}
}
}
================================================
FILE: utils.py
================================================
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains some utility functions.
"""
import csv
import json
import random
from typing import List, Optional, Any
import numpy as np
import torch
def set_seed(seed: int) -> None:
"""Set RNG seeds for python's `random` module, numpy and torch"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def read_inputs(input_file: str, input_file_type: str) -> List[str]:
"""
Read a list of input texts from a text file.
:param input_file: the path to the input file
:param input_file_type: the file type, one of 'plain', 'jsonl' and 'stsb':
- 'plain': a plain text file where each line corresponds to one input
- 'jsonl': a jsonl file, where each line is one json object and input texts are stored in the field 'text_a'
- 'stsb': a tsv file, formatted like the official STS benchmark
:return: the list of extracted input texts
"""
valid_types = ['plain', 'jsonl', 'stsb']
assert input_file_type in valid_types, f"Invalid input file type: '{input_file_type}'. Valid types: {valid_types}"
if input_file_type == "plain":
return read_plaintext_inputs(input_file)
elif input_file_type == "jsonl":
return read_jsonl_inputs(input_file)
elif input_file_type == "stsb":
return read_sts_inputs(input_file)
def read_plaintext_inputs(path: str) -> List[str]:
"""Read input texts from a plain text file where each line corresponds to one input"""
with open(path, 'r', encoding='utf8') as fh:
inputs = fh.read().splitlines()
print(f"Done loading {len(inputs)} inputs from file '{path}'")
return inputs
def read_jsonl_inputs(path: str) -> List[str]:
"""Read input texts from a jsonl file, where each line is one json object and input texts are stored in the field 'text_a'"""
ds_entries = DatasetEntry.read_list(path)
print(f"Done loading {len(ds_entries)} inputs from file '{path}'")
return [entry.text_a for entry in ds_entries]
def read_sts_inputs(path: str) -> List[str]:
"""Read input texts from a tsv file, formatted like the official STS benchmark"""
inputs = []
with open(path, 'r', encoding='utf8') as fh:
reader = csv.reader(fh, delimiter='\t', quoting=csv.QUOTE_NONE)
for row in reader:
try:
sent_a, sent_b = row[5], row[6]
inputs.append(sent_a)
inputs.append(sent_b)
except IndexError:
print(f"Cannot parse line {row}")
print(f"Done loading {len(inputs)} inputs from file '{path}'")
return inputs
class DatasetEntry:
"""This class represents a dataset entry for text (pair) classification"""
def __init__(self, text_a: str, text_b: Optional[str], label: Any):
self.text_a = text_a
self.text_b = text_b
self.label = label
def __repr__(self):
if self.text_b is not None:
return f'DatasetEntry(text_a="{self.text_a}", text_b="{self.text_b}", label={self.label})'
else:
return f'DatasetEntry(text_a="{self.text_a}", label={self.label})'
def __key(self):
return self.text_a, self.text_b, self.label
def __hash__(self):
return hash(self.__key())
def __eq__(self, other):
if isinstance(other, DatasetEntry):
return self.__key() == other.__key()
return False
@staticmethod
def save_list(entries: List['DatasetEntry'], path: str):
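        # Each entry is written as one JSON object per line (jsonl), mirroring what read_list expects.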
with open(path, 'w', encoding='utf8') as fh:
for entry in entries:
fh.write(f'{json.dumps(entry.__dict__)}\n')
@staticmethod
def read_list(path: str) -> List['DatasetEntry']:
pairs = []
with open(path, 'r', encoding='utf8') as fh:
for line in fh:
pairs.append(DatasetEntry(**json.loads(line)))
return pairs