[
  {
    "path": ".gitignore",
    "content": "*.pyc\n*.so\n*.jsonlines\nlogs\nconll-2012\nchar_vocab*.txt\nglove*.txt\nglove*.txt.filtered\n*.v*_*_conll\n*.hdf5"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"{}\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2017 Kenton Lee\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# Higher-order Coreference Resolution with Coarse-to-fine Inference\n\n## Introduction\nThis repository contains the code for replicating results from\n\n* [Higher-order Coreference Resolution with Coarse-to-fine Inference](https://arxiv.org/abs/1804.05392)\n* [Kenton Lee](http://kentonl.com/), [Luheng He](https://homes.cs.washington.edu/~luheng), and [Luke Zettlemoyer](https://www.cs.washington.edu/people/faculty/lsz)\n* In NAACL 2018\n\n## Getting Started\n\n* Install python (either 2 or 3) requirements: `pip install -r requirements.txt`\n* Download pretrained models at https://drive.google.com/file/d/1fkifqZzdzsOEo0DXMzCFjiNXqsKG_cHi\n  * Move the downloaded file to the root of the repo and extract: `tar -xzvf e2e-coref.tgz`\n* Download GloVe embeddings and build custom kernels by running `setup_all.sh`.\n  * There are 3 platform-dependent ways to build custom TensorFlow kernels. Please comment/uncomment the appropriate lines in the script.\n* To train your own models, run `setup_training.sh`\n  * This assumes access to OntoNotes 5.0. Please edit the `ontonotes_path` variable.\n\n## Training Instructions\n\n* Experiment configurations are found in `experiments.conf`\n* Choose an experiment that you would like to run, e.g. 
`best`\n* Training: `python train.py <experiment>`\n* Results are stored in the `logs` directory and can be viewed via TensorBoard.\n* Evaluation: `python evaluate.py <experiment>`\n\n## Demo Instructions\n\n* Command-line demo: `python demo.py final`\n* To run the demo with other experiments, replace `final` with your configuration name.\n\n## Batched Prediction Instructions\n\n* Create a file where each line is in the following json format (make sure to strip the newlines so each line is well-formed json):\n```\n{\n  \"clusters\": [],\n  \"doc_key\": \"nw\",\n  \"sentences\": [[\"This\", \"is\", \"the\", \"first\", \"sentence\", \".\"], [\"This\", \"is\", \"the\", \"second\", \".\"]],\n  \"speakers\": [[\"spk1\", \"spk1\", \"spk1\", \"spk1\", \"spk1\", \"spk1\"], [\"spk2\", \"spk2\", \"spk2\", \"spk2\", \"spk2\"]]\n}\n```\n  * `clusters` should be left empty and is only used for evaluation purposes.\n  * `doc_key` indicates the genre, which can be one of the following: `\"bc\", \"bn\", \"mz\", \"nw\", \"pt\", \"tc\", \"wb\"`\n  * `speakers` indicates the speaker of each word. These can be all empty strings if there is only one known speaker.\n* Run `python predict.py <experiment> <input_file> <output_file>`, which outputs the input jsonlines with predicted clusters.\n\n## Other Quirks\n\n* It does not use GPUs by default. Instead, it looks for the `GPU` environment variable, which the code treats as shorthand for `CUDA_VISIBLE_DEVICES`.\n* The training runs indefinitely and needs to be terminated manually. The model generally converges at about 400k steps.\n"
  },
  {
    "path": "cache_elmo.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport tensorflow as tf\nimport tensorflow_hub as hub\nimport h5py\nimport json\nimport sys\n\ndef build_elmo():\n  token_ph = tf.placeholder(tf.string, [None, None])\n  len_ph = tf.placeholder(tf.int32, [None])\n  elmo_module = hub.Module(\"https://tfhub.dev/google/elmo/2\")\n  lm_embeddings = elmo_module(\n      inputs={\"tokens\": token_ph, \"sequence_len\": len_ph},\n      signature=\"tokens\", as_dict=True)\n  word_emb = lm_embeddings[\"word_emb\"]\n  lm_emb = tf.stack([tf.concat([word_emb, word_emb], -1),\n                     lm_embeddings[\"lstm_outputs1\"],\n                     lm_embeddings[\"lstm_outputs2\"]], -1)\n  return token_ph, len_ph, lm_emb\n\ndef cache_dataset(data_path, session, token_ph, len_ph, lm_emb, out_file):\n  with open(data_path) as in_file:\n    for doc_num, line in enumerate(in_file.readlines()):\n      example = json.loads(line)\n      sentences = example[\"sentences\"]\n      max_sentence_length = max(len(s) for s in sentences)\n      tokens = [[\"\"] * max_sentence_length for _ in sentences]\n      text_len = np.array([len(s) for s in sentences])\n      for i, sentence in enumerate(sentences):\n        for j, word in enumerate(sentence):\n          tokens[i][j] = word\n      tokens = np.array(tokens)\n      tf_lm_emb = session.run(lm_emb, feed_dict={\n          token_ph: tokens,\n          len_ph: text_len\n      })\n      file_key = example[\"doc_key\"].replace(\"/\", \":\")\n      group = out_file.create_group(file_key)\n      for i, (e, l) in enumerate(zip(tf_lm_emb, text_len)):\n        e = e[:l, :, :]\n        group[str(i)] = e\n      if doc_num % 10 == 0:\n        print(\"Cached {} documents in {}\".format(doc_num + 1, data_path))\n\nif __name__ == \"__main__\":\n  token_ph, len_ph, lm_emb = build_elmo()\n  with tf.Session() as session:\n    
session.run(tf.global_variables_initializer())\n    with h5py.File(\"elmo_cache.hdf5\", \"w\") as out_file:\n      for json_filename in sys.argv[1:]:\n        cache_dataset(json_filename, session, token_ph, len_ph, lm_emb, out_file)\n"
  },
  {
    "path": "conll.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport re\nimport os\nimport sys\nimport json\nimport tempfile\nimport subprocess\nimport operator\nimport collections\n\nBEGIN_DOCUMENT_REGEX = re.compile(r\"#begin document \\((.*)\\); part (\\d+)\")\nCOREF_RESULTS_REGEX = re.compile(r\".*Coreference: Recall: \\([0-9.]+ / [0-9.]+\\) ([0-9.]+)%\\tPrecision: \\([0-9.]+ / [0-9.]+\\) ([0-9.]+)%\\tF1: ([0-9.]+)%.*\", re.DOTALL)\n\ndef get_doc_key(doc_id, part):\n  return \"{}_{}\".format(doc_id, int(part))\n\ndef output_conll(input_file, output_file, predictions):\n  prediction_map = {}\n  for doc_key, clusters in predictions.items():\n    start_map = collections.defaultdict(list)\n    end_map = collections.defaultdict(list)\n    word_map = collections.defaultdict(list)\n    for cluster_id, mentions in enumerate(clusters):\n      for start, end in mentions:\n        if start == end:\n          word_map[start].append(cluster_id)\n        else:\n          start_map[start].append((cluster_id, end))\n          end_map[end].append((cluster_id, start))\n    for k,v in start_map.items():\n      start_map[k] = [cluster_id for cluster_id, end in sorted(v, key=operator.itemgetter(1), reverse=True)]\n    for k,v in end_map.items():\n      end_map[k] = [cluster_id for cluster_id, start in sorted(v, key=operator.itemgetter(1), reverse=True)]\n    prediction_map[doc_key] = (start_map, end_map, word_map)\n\n  word_index = 0\n  for line in input_file.readlines():\n    row = line.split()\n    if len(row) == 0:\n      output_file.write(\"\\n\")\n    elif row[0].startswith(\"#\"):\n      begin_match = re.match(BEGIN_DOCUMENT_REGEX, line)\n      if begin_match:\n        doc_key = get_doc_key(begin_match.group(1), begin_match.group(2))\n        start_map, end_map, word_map = prediction_map[doc_key]\n        word_index = 0\n      output_file.write(line)\n      output_file.write(\"\\n\")\n    else:\n      assert 
get_doc_key(row[0], row[1]) == doc_key\n      coref_list = []\n      if word_index in end_map:\n        for cluster_id in end_map[word_index]:\n          coref_list.append(\"{})\".format(cluster_id))\n      if word_index in word_map:\n        for cluster_id in word_map[word_index]:\n          coref_list.append(\"({})\".format(cluster_id))\n      if word_index in start_map:\n        for cluster_id in start_map[word_index]:\n          coref_list.append(\"({}\".format(cluster_id))\n\n      if len(coref_list) == 0:\n        row[-1] = \"-\"\n      else:\n        row[-1] = \"|\".join(coref_list)\n\n      output_file.write(\"   \".join(row))\n      output_file.write(\"\\n\")\n      word_index += 1\n\ndef official_conll_eval(gold_path, predicted_path, metric, official_stdout=False):\n  cmd = [\"conll-2012/scorer/v8.01/scorer.pl\", metric, gold_path, predicted_path, \"none\"]\n  process = subprocess.Popen(cmd, stdout=subprocess.PIPE)\n  stdout, stderr = process.communicate()\n  process.wait()\n\n  stdout = stdout.decode(\"utf-8\")\n  if stderr is not None:\n    print(stderr)\n\n  if official_stdout:\n    print(\"Official result for {}\".format(metric))\n    print(stdout)\n\n  coref_results_match = re.match(COREF_RESULTS_REGEX, stdout)\n  recall = float(coref_results_match.group(1))\n  precision = float(coref_results_match.group(2))\n  f1 = float(coref_results_match.group(3))\n  return { \"r\": recall, \"p\": precision, \"f\": f1 }\n\ndef evaluate_conll(gold_path, predictions, official_stdout=False):\n  with tempfile.NamedTemporaryFile(delete=False, mode=\"w\") as prediction_file:\n    with open(gold_path, \"r\") as gold_file:\n      output_conll(gold_file, prediction_file, predictions)\n    print(\"Predicted conll file: {}\".format(prediction_file.name))\n  return { m: official_conll_eval(gold_file.name, prediction_file.name, m, official_stdout) for m in (\"muc\", \"bcub\", \"ceafe\") }\n"
  },
  {
    "path": "continuous_evaluate.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport re\nimport time\nimport shutil\n\nimport tensorflow as tf\nimport coref_model as cm\nimport util\n\ndef copy_checkpoint(source, target):\n  for ext in (\".index\", \".data-00000-of-00001\"):\n    shutil.copyfile(source + ext, target + ext)\n\nif __name__ == \"__main__\":\n  config = util.initialize_from_env()\n  model = cm.CorefModel(config)\n\n  saver = tf.train.Saver()\n  log_dir = config[\"log_dir\"]\n  writer = tf.summary.FileWriter(log_dir, flush_secs=20)\n  evaluated_checkpoints = set()\n  max_f1 = 0\n  checkpoint_pattern = re.compile(\".*model.ckpt-([0-9]*)\\Z\")\n\n  with tf.Session() as session:\n    while True:\n      ckpt = tf.train.get_checkpoint_state(log_dir)\n      if ckpt and ckpt.model_checkpoint_path and ckpt.model_checkpoint_path not in evaluated_checkpoints:\n        print(\"Evaluating {}\".format(ckpt.model_checkpoint_path))\n\n        # Move it to a temporary location to avoid being deleted by the training supervisor.\n        tmp_checkpoint_path = os.path.join(log_dir, \"model.tmp.ckpt\")\n        copy_checkpoint(ckpt.model_checkpoint_path, tmp_checkpoint_path)\n\n        global_step = int(checkpoint_pattern.match(ckpt.model_checkpoint_path).group(1))\n        saver.restore(session, ckpt.model_checkpoint_path)\n\n        eval_summary, f1 = model.evaluate(session)\n\n        if f1 > max_f1:\n          max_f1 = f1\n          copy_checkpoint(tmp_checkpoint_path, os.path.join(log_dir, \"model.max.ckpt\"))\n\n        print(\"Current max F1: {:.2f}\".format(max_f1))\n\n        writer.add_summary(eval_summary, global_step)\n        print(\"Evaluation written to {} at step {}\".format(log_dir, global_step))\n\n        evaluated_checkpoints.add(ckpt.model_checkpoint_path)\n        sleep_time = 60\n      else:\n        sleep_time = 10\n      print(\"Waiting for {} seconds before 
looking for next checkpoint.\".format(sleep_time))\n      time.sleep(sleep_time)\n"
  },
  {
    "path": "coref_kernels.cc",
    "content": "#include <map>\n\n#include \"tensorflow/core/framework/op.h\"\n#include \"tensorflow/core/framework/shape_inference.h\"\n#include \"tensorflow/core/framework/op_kernel.h\"\n\nusing namespace tensorflow;\n\nREGISTER_OP(\"ExtractSpans\")\n.Input(\"span_scores: float32\")\n.Input(\"candidate_starts: int32\")\n.Input(\"candidate_ends: int32\")\n.Input(\"num_output_spans: int32\")\n.Input(\"max_sentence_length: int32\")\n.Attr(\"sort_spans: bool\")\n.Output(\"output_span_indices: int32\");\n\nclass ExtractSpansOp : public OpKernel {\npublic:\n  explicit ExtractSpansOp(OpKernelConstruction* context) : OpKernel(context) {\n    OP_REQUIRES_OK(context, context->GetAttr(\"sort_spans\", &_sort_spans));\n  }\n\n  void Compute(OpKernelContext* context) override {\n    TTypes<float>::ConstMatrix span_scores = context->input(0).matrix<float>();\n    TTypes<int32>::ConstMatrix candidate_starts = context->input(1).matrix<int32>();\n    TTypes<int32>::ConstMatrix candidate_ends = context->input(2).matrix<int32>();\n    TTypes<int32>::ConstVec num_output_spans = context->input(3).vec<int32>();\n    int max_sentence_length = context->input(4).scalar<int32>()();\n\n    int num_sentences = span_scores.dimension(0);\n    int num_input_spans = span_scores.dimension(1);\n    int max_num_output_spans = 0;\n    for (int i = 0; i < num_sentences; i++) {\n      if (num_output_spans(i) > max_num_output_spans) {\n        max_num_output_spans = num_output_spans(i);\n      }\n    }\n\n    Tensor* output_span_indices_tensor = nullptr;\n    TensorShape output_span_indices_shape({num_sentences, max_num_output_spans});\n    OP_REQUIRES_OK(context, context->allocate_output(0, output_span_indices_shape,\n                                                     &output_span_indices_tensor));\n    TTypes<int32>::Matrix output_span_indices = output_span_indices_tensor->matrix<int32>();\n\n    std::vector<std::vector<int>> sorted_input_span_indices(num_sentences,\n                                
                            std::vector<int>(num_input_spans));\n    for (int i = 0; i < num_sentences; i++) {\n      std::iota(sorted_input_span_indices[i].begin(), sorted_input_span_indices[i].end(), 0);\n      std::sort(sorted_input_span_indices[i].begin(), sorted_input_span_indices[i].end(),\n                [&span_scores, &i](int j1, int j2) {\n                  return span_scores(i, j2) < span_scores(i, j1);\n                });\n    }\n\n    for (int l = 0; l < num_sentences; l++) {\n      std::vector<int> top_span_indices;\n      std::unordered_map<int, int> end_to_earliest_start;\n      std::unordered_map<int, int> start_to_latest_end;\n\n      int current_span_index = 0,\n          num_selected_spans = 0;\n      while (num_selected_spans < num_output_spans(l) && current_span_index < num_input_spans) {\n        int i = sorted_input_span_indices[l][current_span_index];\n        bool any_crossing = false;\n        const int start = candidate_starts(l, i);\n        const int end = candidate_ends(l, i);\n        for (int j = start; j <= end; ++j) {\n          auto latest_end_iter = start_to_latest_end.find(j);\n          if (latest_end_iter != start_to_latest_end.end() && j > start && latest_end_iter->second > end) {\n            // Given (), exists [], such that ( [ ) ]\n            any_crossing = true;\n            break;\n          }\n          auto earliest_start_iter = end_to_earliest_start.find(j);\n          if (earliest_start_iter != end_to_earliest_start.end() && j < end && earliest_start_iter->second < start) {\n            // Given (), exists [], such that [ ( ] )\n            any_crossing = true;\n            break;\n          }\n        }\n        if (!any_crossing) {\n          if (_sort_spans) {\n            top_span_indices.push_back(i);\n          } else {\n            output_span_indices(l, num_selected_spans) = i;\n          }\n          ++num_selected_spans;\n          // Update data struct.\n          auto latest_end_iter = 
start_to_latest_end.find(start);\n          if (latest_end_iter == start_to_latest_end.end() || end > latest_end_iter->second) {\n            start_to_latest_end[start] = end;\n          }\n          auto earliest_start_iter = end_to_earliest_start.find(end);\n          if (earliest_start_iter == end_to_earliest_start.end() || start < earliest_start_iter->second) {\n            end_to_earliest_start[end] = start;\n          }\n        }\n        ++current_span_index;\n      }\n      // Sort and populate selected span indices.\n      if (_sort_spans) {\n        std::sort(top_span_indices.begin(), top_span_indices.end(),\n                  [&candidate_starts, &candidate_ends, &l] (int i1, int i2) {\n                    if (candidate_starts(l, i1) < candidate_starts(l, i2)) {\n                      return true;\n                    } else if (candidate_starts(l, i1) > candidate_starts(l, i2)) {\n                      return false;\n                    } else if (candidate_ends(l, i1) < candidate_ends(l, i2)) {\n                      return true;\n                    } else if (candidate_ends(l, i1) > candidate_ends(l, i2)) {\n                      return false;\n                    } else {\n                      return i1 < i2;\n                    }\n                  });\n        for (int i = 0; i < num_output_spans(l); ++i) {\n          output_span_indices(l, i) = top_span_indices[i];\n        }\n      }\n      // Pad with the first span index.\n      for (int i = num_selected_spans; i < max_num_output_spans; ++i) {\n        output_span_indices(l, i) = output_span_indices(l, 0);\n      }\n    }\n  }\nprivate:\n  bool _sort_spans;\n};\n\nREGISTER_KERNEL_BUILDER(Name(\"ExtractSpans\").Device(DEVICE_CPU), ExtractSpansOp);\n"
  },
  {
    "path": "coref_model.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport operator\nimport random\nimport math\nimport json\nimport threading\nimport numpy as np\nimport tensorflow as tf\nimport tensorflow_hub as hub\nimport h5py\n\nimport util\nimport coref_ops\nimport conll\nimport metrics\n\nclass CorefModel(object):\n  def __init__(self, config):\n    self.config = config\n    self.context_embeddings = util.EmbeddingDictionary(config[\"context_embeddings\"])\n    self.head_embeddings = util.EmbeddingDictionary(config[\"head_embeddings\"], maybe_cache=self.context_embeddings)\n    self.char_embedding_size = config[\"char_embedding_size\"]\n    self.char_dict = util.load_char_dict(config[\"char_vocab_path\"])\n    self.max_span_width = config[\"max_span_width\"]\n    self.genres = { g:i for i,g in enumerate(config[\"genres\"]) }\n    if config[\"lm_path\"]:\n      self.lm_file = h5py.File(self.config[\"lm_path\"], \"r\")\n    else:\n      self.lm_file = None\n    self.lm_layers = self.config[\"lm_layers\"]\n    self.lm_size = self.config[\"lm_size\"]\n    self.eval_data = None # Load eval data lazily.\n\n    input_props = []\n    input_props.append((tf.string, [None, None])) # Tokens.\n    input_props.append((tf.float32, [None, None, self.context_embeddings.size])) # Context embeddings.\n    input_props.append((tf.float32, [None, None, self.head_embeddings.size])) # Head embeddings.\n    input_props.append((tf.float32, [None, None, self.lm_size, self.lm_layers])) # LM embeddings.\n    input_props.append((tf.int32, [None, None, None])) # Character indices.\n    input_props.append((tf.int32, [None])) # Text lengths.\n    input_props.append((tf.int32, [None])) # Speaker IDs.\n    input_props.append((tf.int32, [])) # Genre.\n    input_props.append((tf.bool, [])) # Is training.\n    input_props.append((tf.int32, [None])) # Gold starts.\n    input_props.append((tf.int32, [None])) # Gold ends.\n   
 input_props.append((tf.int32, [None])) # Cluster ids.\n\n    self.queue_input_tensors = [tf.placeholder(dtype, shape) for dtype, shape in input_props]\n    dtypes, shapes = zip(*input_props)\n    queue = tf.PaddingFIFOQueue(capacity=10, dtypes=dtypes, shapes=shapes)\n    self.enqueue_op = queue.enqueue(self.queue_input_tensors)\n    self.input_tensors = queue.dequeue()\n\n    self.predictions, self.loss = self.get_predictions_and_loss(*self.input_tensors)\n    self.global_step = tf.Variable(0, name=\"global_step\", trainable=False)\n    self.reset_global_step = tf.assign(self.global_step, 0)\n    learning_rate = tf.train.exponential_decay(self.config[\"learning_rate\"], self.global_step,\n                                               self.config[\"decay_frequency\"], self.config[\"decay_rate\"], staircase=True)\n    trainable_params = tf.trainable_variables()\n    gradients = tf.gradients(self.loss, trainable_params)\n    gradients, _ = tf.clip_by_global_norm(gradients, self.config[\"max_gradient_norm\"])\n    optimizers = {\n      \"adam\" : tf.train.AdamOptimizer,\n      \"sgd\" : tf.train.GradientDescentOptimizer\n    }\n    optimizer = optimizers[self.config[\"optimizer\"]](learning_rate)\n    self.train_op = optimizer.apply_gradients(zip(gradients, trainable_params), global_step=self.global_step)\n\n  def start_enqueue_thread(self, session):\n    with open(self.config[\"train_path\"]) as f:\n      train_examples = [json.loads(jsonline) for jsonline in f.readlines()]\n    def _enqueue_loop():\n      while True:\n        random.shuffle(train_examples)\n        for example in train_examples:\n          tensorized_example = self.tensorize_example(example, is_training=True)\n          feed_dict = dict(zip(self.queue_input_tensors, tensorized_example))\n          session.run(self.enqueue_op, feed_dict=feed_dict)\n    enqueue_thread = threading.Thread(target=_enqueue_loop)\n    enqueue_thread.daemon = True\n    enqueue_thread.start()\n\n  def restore(self, 
session):\n    # Don't try to restore unused variables from the TF-Hub ELMo module.\n    vars_to_restore = [v for v in tf.global_variables() if \"module/\" not in v.name]\n    saver = tf.train.Saver(vars_to_restore)\n    checkpoint_path = os.path.join(self.config[\"log_dir\"], \"model.max.ckpt\")\n    print(\"Restoring from {}\".format(checkpoint_path))\n    session.run(tf.global_variables_initializer())\n    saver.restore(session, checkpoint_path)\n\n  def load_lm_embeddings(self, doc_key):\n    if self.lm_file is None:\n      return np.zeros([0, 0, self.lm_size, self.lm_layers])\n    file_key = doc_key.replace(\"/\", \":\")\n    group = self.lm_file[file_key]\n    num_sentences = len(list(group.keys()))\n    sentences = [group[str(i)][...] for i in range(num_sentences)]\n    lm_emb = np.zeros([num_sentences, max(s.shape[0] for s in sentences), self.lm_size, self.lm_layers])\n    for i, s in enumerate(sentences):\n      lm_emb[i, :s.shape[0], :, :] = s\n    return lm_emb\n\n  def tensorize_mentions(self, mentions):\n    if len(mentions) > 0:\n      starts, ends = zip(*mentions)\n    else:\n      starts, ends = [], []\n    return np.array(starts), np.array(ends)\n\n  def tensorize_span_labels(self, tuples, label_dict):\n    if len(tuples) > 0:\n      starts, ends, labels = zip(*tuples)\n    else:\n      starts, ends, labels = [], [], []\n    return np.array(starts), np.array(ends), np.array([label_dict[c] for c in labels])\n\n  def tensorize_example(self, example, is_training):\n    clusters = example[\"clusters\"]\n\n    gold_mentions = sorted(tuple(m) for m in util.flatten(clusters))\n    gold_mention_map = {m:i for i,m in enumerate(gold_mentions)}\n    cluster_ids = np.zeros(len(gold_mentions))\n    for cluster_id, cluster in enumerate(clusters):\n      for mention in cluster:\n        cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id + 1\n\n    sentences = example[\"sentences\"]\n    num_words = sum(len(s) for s in sentences)\n    speakers = 
util.flatten(example[\"speakers\"])\n\n    assert num_words == len(speakers)\n\n    max_sentence_length = max(len(s) for s in sentences)\n    max_word_length = max(max(max(len(w) for w in s) for s in sentences), max(self.config[\"filter_widths\"]))\n    text_len = np.array([len(s) for s in sentences])\n    tokens = [[\"\"] * max_sentence_length for _ in sentences]\n    context_word_emb = np.zeros([len(sentences), max_sentence_length, self.context_embeddings.size])\n    head_word_emb = np.zeros([len(sentences), max_sentence_length, self.head_embeddings.size])\n    char_index = np.zeros([len(sentences), max_sentence_length, max_word_length])\n    for i, sentence in enumerate(sentences):\n      for j, word in enumerate(sentence):\n        tokens[i][j] = word\n        context_word_emb[i, j] = self.context_embeddings[word]\n        head_word_emb[i, j] = self.head_embeddings[word]\n        char_index[i, j, :len(word)] = [self.char_dict[c] for c in word]\n    tokens = np.array(tokens)\n\n    speaker_dict = { s:i for i,s in enumerate(set(speakers)) }\n    speaker_ids = np.array([speaker_dict[s] for s in speakers])\n\n    doc_key = example[\"doc_key\"]\n    genre = self.genres[doc_key[:2]]\n\n    gold_starts, gold_ends = self.tensorize_mentions(gold_mentions)\n\n    lm_emb = self.load_lm_embeddings(doc_key)\n\n    example_tensors = (tokens, context_word_emb, head_word_emb, lm_emb, char_index, text_len, speaker_ids, genre, is_training, gold_starts, gold_ends, cluster_ids)\n\n    if is_training and len(sentences) > self.config[\"max_training_sentences\"]:\n      return self.truncate_example(*example_tensors)\n    else:\n      return example_tensors\n\n  def truncate_example(self, tokens, context_word_emb, head_word_emb, lm_emb, char_index, text_len, speaker_ids, genre, is_training, gold_starts, gold_ends, cluster_ids):\n    max_training_sentences = self.config[\"max_training_sentences\"]\n    num_sentences = context_word_emb.shape[0]\n    assert num_sentences > 
max_training_sentences\n\n    sentence_offset = random.randint(0, num_sentences - max_training_sentences)\n    word_offset = text_len[:sentence_offset].sum()\n    num_words = text_len[sentence_offset:sentence_offset + max_training_sentences].sum()\n    tokens = tokens[sentence_offset:sentence_offset + max_training_sentences, :]\n    context_word_emb = context_word_emb[sentence_offset:sentence_offset + max_training_sentences, :, :]\n    head_word_emb = head_word_emb[sentence_offset:sentence_offset + max_training_sentences, :, :]\n    lm_emb = lm_emb[sentence_offset:sentence_offset + max_training_sentences, :, :, :]\n    char_index = char_index[sentence_offset:sentence_offset + max_training_sentences, :, :]\n    text_len = text_len[sentence_offset:sentence_offset + max_training_sentences]\n\n    speaker_ids = speaker_ids[word_offset: word_offset + num_words]\n    gold_spans = np.logical_and(gold_ends >= word_offset, gold_starts < word_offset + num_words)\n    gold_starts = gold_starts[gold_spans] - word_offset\n    gold_ends = gold_ends[gold_spans] - word_offset\n    cluster_ids = cluster_ids[gold_spans]\n\n    return tokens, context_word_emb, head_word_emb, lm_emb, char_index, text_len, speaker_ids, genre, is_training, gold_starts, gold_ends, cluster_ids\n\n  def get_candidate_labels(self, candidate_starts, candidate_ends, labeled_starts, labeled_ends, labels):\n    same_start = tf.equal(tf.expand_dims(labeled_starts, 1), tf.expand_dims(candidate_starts, 0)) # [num_labeled, num_candidates]\n    same_end = tf.equal(tf.expand_dims(labeled_ends, 1), tf.expand_dims(candidate_ends, 0)) # [num_labeled, num_candidates]\n    same_span = tf.logical_and(same_start, same_end) # [num_labeled, num_candidates]\n    candidate_labels = tf.matmul(tf.expand_dims(labels, 0), tf.to_int32(same_span)) # [1, num_candidates]\n    candidate_labels = tf.squeeze(candidate_labels, 0) # [num_candidates]\n    return candidate_labels\n\n  def get_dropout(self, dropout_rate, is_training):\n    
return 1 - (tf.to_float(is_training) * dropout_rate)\n\n  def coarse_to_fine_pruning(self, top_span_emb, top_span_mention_scores, c):\n    k = util.shape(top_span_emb, 0)\n    top_span_range = tf.range(k) # [k]\n    antecedent_offsets = tf.expand_dims(top_span_range, 1) - tf.expand_dims(top_span_range, 0) # [k, k]\n    antecedents_mask = antecedent_offsets >= 1 # [k, k]\n    fast_antecedent_scores = tf.expand_dims(top_span_mention_scores, 1) + tf.expand_dims(top_span_mention_scores, 0) # [k, k]\n    fast_antecedent_scores += tf.log(tf.to_float(antecedents_mask)) # [k, k]\n    fast_antecedent_scores += self.get_fast_antecedent_scores(top_span_emb) # [k, k]\n\n    _, top_antecedents = tf.nn.top_k(fast_antecedent_scores, c, sorted=False) # [k, c]\n    top_antecedents_mask = util.batch_gather(antecedents_mask, top_antecedents) # [k, c]\n    top_fast_antecedent_scores = util.batch_gather(fast_antecedent_scores, top_antecedents) # [k, c]\n    top_antecedent_offsets = util.batch_gather(antecedent_offsets, top_antecedents) # [k, c]\n    return top_antecedents, top_antecedents_mask, top_fast_antecedent_scores, top_antecedent_offsets\n\n  def distance_pruning(self, top_span_emb, top_span_mention_scores, c):\n    k = util.shape(top_span_emb, 0)\n    top_antecedent_offsets = tf.tile(tf.expand_dims(tf.range(c) + 1, 0), [k, 1]) # [k, c]\n    raw_top_antecedents = tf.expand_dims(tf.range(k), 1) - top_antecedent_offsets # [k, c]\n    top_antecedents_mask = raw_top_antecedents >= 0 # [k, c]\n    top_antecedents = tf.maximum(raw_top_antecedents, 0) # [k, c]\n\n    top_fast_antecedent_scores = tf.expand_dims(top_span_mention_scores, 1) + tf.gather(top_span_mention_scores, top_antecedents) # [k, c]\n    top_fast_antecedent_scores += tf.log(tf.to_float(top_antecedents_mask)) # [k, c]\n    return top_antecedents, top_antecedents_mask, top_fast_antecedent_scores, top_antecedent_offsets\n\n  def get_predictions_and_loss(self, tokens, context_word_emb, head_word_emb, lm_emb, char_index, 
text_len, speaker_ids, genre, is_training, gold_starts, gold_ends, cluster_ids):\n    self.dropout = self.get_dropout(self.config[\"dropout_rate\"], is_training)\n    self.lexical_dropout = self.get_dropout(self.config[\"lexical_dropout_rate\"], is_training)\n    self.lstm_dropout = self.get_dropout(self.config[\"lstm_dropout_rate\"], is_training)\n\n    num_sentences = tf.shape(context_word_emb)[0]\n    max_sentence_length = tf.shape(context_word_emb)[1]\n\n    context_emb_list = [context_word_emb]\n    head_emb_list = [head_word_emb]\n\n    if self.config[\"char_embedding_size\"] > 0:\n      char_emb = tf.gather(tf.get_variable(\"char_embeddings\", [len(self.char_dict), self.config[\"char_embedding_size\"]]), char_index) # [num_sentences, max_sentence_length, max_word_length, emb]\n      flattened_char_emb = tf.reshape(char_emb, [num_sentences * max_sentence_length, util.shape(char_emb, 2), util.shape(char_emb, 3)]) # [num_sentences * max_sentence_length, max_word_length, emb]\n      flattened_aggregated_char_emb = util.cnn(flattened_char_emb, self.config[\"filter_widths\"], self.config[\"filter_size\"]) # [num_sentences * max_sentence_length, emb]\n      aggregated_char_emb = tf.reshape(flattened_aggregated_char_emb, [num_sentences, max_sentence_length, util.shape(flattened_aggregated_char_emb, 1)]) # [num_sentences, max_sentence_length, emb]\n      context_emb_list.append(aggregated_char_emb)\n      head_emb_list.append(aggregated_char_emb)\n\n    if not self.lm_file:\n      elmo_module = hub.Module(\"https://tfhub.dev/google/elmo/2\")\n      lm_embeddings = elmo_module(\n          inputs={\"tokens\": tokens, \"sequence_len\": text_len},\n          signature=\"tokens\", as_dict=True)\n      word_emb = lm_embeddings[\"word_emb\"]  # [num_sentences, max_sentence_length, 512]\n      lm_emb = tf.stack([tf.concat([word_emb, word_emb], -1),\n                         lm_embeddings[\"lstm_outputs1\"],\n                         lm_embeddings[\"lstm_outputs2\"]], -1)  # 
[num_sentences, max_sentence_length, 1024, 3]\n    lm_emb_size = util.shape(lm_emb, 2)\n    lm_num_layers = util.shape(lm_emb, 3)\n    with tf.variable_scope(\"lm_aggregation\"):\n      self.lm_weights = tf.nn.softmax(tf.get_variable(\"lm_scores\", [lm_num_layers], initializer=tf.constant_initializer(0.0)))\n      self.lm_scaling = tf.get_variable(\"lm_scaling\", [], initializer=tf.constant_initializer(1.0))\n    flattened_lm_emb = tf.reshape(lm_emb, [num_sentences * max_sentence_length * lm_emb_size, lm_num_layers])\n    flattened_aggregated_lm_emb = tf.matmul(flattened_lm_emb, tf.expand_dims(self.lm_weights, 1)) # [num_sentences * max_sentence_length * emb, 1]\n    aggregated_lm_emb = tf.reshape(flattened_aggregated_lm_emb, [num_sentences, max_sentence_length, lm_emb_size])\n    aggregated_lm_emb *= self.lm_scaling\n    context_emb_list.append(aggregated_lm_emb)\n\n    context_emb = tf.concat(context_emb_list, 2) # [num_sentences, max_sentence_length, emb]\n    head_emb = tf.concat(head_emb_list, 2) # [num_sentences, max_sentence_length, emb]\n    context_emb = tf.nn.dropout(context_emb, self.lexical_dropout) # [num_sentences, max_sentence_length, emb]\n    head_emb = tf.nn.dropout(head_emb, self.lexical_dropout) # [num_sentences, max_sentence_length, emb]\n\n    text_len_mask = tf.sequence_mask(text_len, maxlen=max_sentence_length) # [num_sentence, max_sentence_length]\n\n    context_outputs = self.lstm_contextualize(context_emb, text_len, text_len_mask) # [num_words, emb]\n    num_words = util.shape(context_outputs, 0)\n\n    genre_emb = tf.gather(tf.get_variable(\"genre_embeddings\", [len(self.genres), self.config[\"feature_size\"]]), genre) # [emb]\n\n    sentence_indices = tf.tile(tf.expand_dims(tf.range(num_sentences), 1), [1, max_sentence_length]) # [num_sentences, max_sentence_length]\n    flattened_sentence_indices = self.flatten_emb_by_sentence(sentence_indices, text_len_mask) # [num_words]\n    flattened_head_emb = 
self.flatten_emb_by_sentence(head_emb, text_len_mask) # [num_words]\n\n    candidate_starts = tf.tile(tf.expand_dims(tf.range(num_words), 1), [1, self.max_span_width]) # [num_words, max_span_width]\n    candidate_ends = candidate_starts + tf.expand_dims(tf.range(self.max_span_width), 0) # [num_words, max_span_width]\n    candidate_start_sentence_indices = tf.gather(flattened_sentence_indices, candidate_starts) # [num_words, max_span_width]\n    candidate_end_sentence_indices = tf.gather(flattened_sentence_indices, tf.minimum(candidate_ends, num_words - 1)) # [num_words, max_span_width]\n    candidate_mask = tf.logical_and(candidate_ends < num_words, tf.equal(candidate_start_sentence_indices, candidate_end_sentence_indices)) # [num_words, max_span_width]\n    flattened_candidate_mask = tf.reshape(candidate_mask, [-1]) # [num_words * max_span_width]\n    candidate_starts = tf.boolean_mask(tf.reshape(candidate_starts, [-1]), flattened_candidate_mask) # [num_candidates]\n    candidate_ends = tf.boolean_mask(tf.reshape(candidate_ends, [-1]), flattened_candidate_mask) # [num_candidates]\n    candidate_sentence_indices = tf.boolean_mask(tf.reshape(candidate_start_sentence_indices, [-1]), flattened_candidate_mask) # [num_candidates]\n\n    candidate_cluster_ids = self.get_candidate_labels(candidate_starts, candidate_ends, gold_starts, gold_ends, cluster_ids) # [num_candidates]\n\n    candidate_span_emb = self.get_span_emb(flattened_head_emb, context_outputs, candidate_starts, candidate_ends) # [num_candidates, emb]\n    candidate_mention_scores =  self.get_mention_scores(candidate_span_emb) # [k, 1]\n    candidate_mention_scores = tf.squeeze(candidate_mention_scores, 1) # [k]\n\n    k = tf.to_int32(tf.floor(tf.to_float(tf.shape(context_outputs)[0]) * self.config[\"top_span_ratio\"]))\n    top_span_indices = coref_ops.extract_spans(tf.expand_dims(candidate_mention_scores, 0),\n                                               tf.expand_dims(candidate_starts, 0),\n              
                                 tf.expand_dims(candidate_ends, 0),\n                                               tf.expand_dims(k, 0),\n                                               util.shape(context_outputs, 0),\n                                               True) # [1, k]\n    top_span_indices.set_shape([1, None])\n    top_span_indices = tf.squeeze(top_span_indices, 0) # [k]\n\n    top_span_starts = tf.gather(candidate_starts, top_span_indices) # [k]\n    top_span_ends = tf.gather(candidate_ends, top_span_indices) # [k]\n    top_span_emb = tf.gather(candidate_span_emb, top_span_indices) # [k, emb]\n    top_span_cluster_ids = tf.gather(candidate_cluster_ids, top_span_indices) # [k]\n    top_span_mention_scores = tf.gather(candidate_mention_scores, top_span_indices) # [k]\n    top_span_sentence_indices = tf.gather(candidate_sentence_indices, top_span_indices) # [k]\n    top_span_speaker_ids = tf.gather(speaker_ids, top_span_starts) # [k]\n\n    c = tf.minimum(self.config[\"max_top_antecedents\"], k)\n\n    if self.config[\"coarse_to_fine\"]:\n      top_antecedents, top_antecedents_mask, top_fast_antecedent_scores, top_antecedent_offsets = self.coarse_to_fine_pruning(top_span_emb, top_span_mention_scores, c)\n    else:\n      top_antecedents, top_antecedents_mask, top_fast_antecedent_scores, top_antecedent_offsets = self.distance_pruning(top_span_emb, top_span_mention_scores, c)\n\n    dummy_scores = tf.zeros([k, 1]) # [k, 1]\n    for i in range(self.config[\"coref_depth\"]):\n      with tf.variable_scope(\"coref_layer\", reuse=(i > 0)):\n        top_antecedent_emb = tf.gather(top_span_emb, top_antecedents) # [k, c, emb]\n        top_antecedent_scores = top_fast_antecedent_scores + self.get_slow_antecedent_scores(top_span_emb, top_antecedents, top_antecedent_emb, top_antecedent_offsets, top_span_speaker_ids, genre_emb) # [k, c]\n        top_antecedent_weights = tf.nn.softmax(tf.concat([dummy_scores, top_antecedent_scores], 1)) # [k, c + 1]\n        
top_antecedent_emb = tf.concat([tf.expand_dims(top_span_emb, 1), top_antecedent_emb], 1) # [k, c + 1, emb]\n        attended_span_emb = tf.reduce_sum(tf.expand_dims(top_antecedent_weights, 2) * top_antecedent_emb, 1) # [k, emb]\n        with tf.variable_scope(\"f\"):\n          f = tf.sigmoid(util.projection(tf.concat([top_span_emb, attended_span_emb], 1), util.shape(top_span_emb, -1))) # [k, emb]\n          top_span_emb = f * attended_span_emb + (1 - f) * top_span_emb # [k, emb]\n\n    top_antecedent_scores = tf.concat([dummy_scores, top_antecedent_scores], 1) # [k, c + 1]\n\n    top_antecedent_cluster_ids = tf.gather(top_span_cluster_ids, top_antecedents) # [k, c]\n    top_antecedent_cluster_ids += tf.to_int32(tf.log(tf.to_float(top_antecedents_mask))) # [k, c]\n    same_cluster_indicator = tf.equal(top_antecedent_cluster_ids, tf.expand_dims(top_span_cluster_ids, 1)) # [k, c]\n    non_dummy_indicator = tf.expand_dims(top_span_cluster_ids > 0, 1) # [k, 1]\n    pairwise_labels = tf.logical_and(same_cluster_indicator, non_dummy_indicator) # [k, c]\n    dummy_labels = tf.logical_not(tf.reduce_any(pairwise_labels, 1, keepdims=True)) # [k, 1]\n    top_antecedent_labels = tf.concat([dummy_labels, pairwise_labels], 1) # [k, c + 1]\n    loss = self.softmax_loss(top_antecedent_scores, top_antecedent_labels) # [k]\n    loss = tf.reduce_sum(loss) # []\n\n    return [candidate_starts, candidate_ends, candidate_mention_scores, top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores], loss\n\n  def get_span_emb(self, head_emb, context_outputs, span_starts, span_ends):\n    span_emb_list = []\n\n    span_start_emb = tf.gather(context_outputs, span_starts) # [k, emb]\n    span_emb_list.append(span_start_emb)\n\n    span_end_emb = tf.gather(context_outputs, span_ends) # [k, emb]\n    span_emb_list.append(span_end_emb)\n\n    span_width = 1 + span_ends - span_starts # [k]\n\n    if self.config[\"use_features\"]:\n      span_width_index = span_width - 1 # [k]\n      
span_width_emb = tf.gather(tf.get_variable(\"span_width_embeddings\", [self.config[\"max_span_width\"], self.config[\"feature_size\"]]), span_width_index) # [k, emb]\n      span_width_emb = tf.nn.dropout(span_width_emb, self.dropout)\n      span_emb_list.append(span_width_emb)\n\n    if self.config[\"model_heads\"]:\n      span_indices = tf.expand_dims(tf.range(self.config[\"max_span_width\"]), 0) + tf.expand_dims(span_starts, 1) # [k, max_span_width]\n      span_indices = tf.minimum(util.shape(context_outputs, 0) - 1, span_indices) # [k, max_span_width]\n      span_text_emb = tf.gather(head_emb, span_indices) # [k, max_span_width, emb]\n      with tf.variable_scope(\"head_scores\"):\n        self.head_scores = util.projection(context_outputs, 1) # [num_words, 1]\n      span_head_scores = tf.gather(self.head_scores, span_indices) # [k, max_span_width, 1]\n      span_mask = tf.expand_dims(tf.sequence_mask(span_width, self.config[\"max_span_width\"], dtype=tf.float32), 2) # [k, max_span_width, 1]\n      span_head_scores += tf.log(span_mask) # [k, max_span_width, 1]\n      span_attention = tf.nn.softmax(span_head_scores, 1) # [k, max_span_width, 1]\n      span_head_emb = tf.reduce_sum(span_attention * span_text_emb, 1) # [k, emb]\n      span_emb_list.append(span_head_emb)\n\n    span_emb = tf.concat(span_emb_list, 1) # [k, emb]\n    return span_emb # [k, emb]\n\n  def get_mention_scores(self, span_emb):\n    with tf.variable_scope(\"mention_scores\"):\n      return util.ffnn(span_emb, self.config[\"ffnn_depth\"], self.config[\"ffnn_size\"], 1, self.dropout) # [k, 1]\n\n  def softmax_loss(self, antecedent_scores, antecedent_labels):\n    gold_scores = antecedent_scores + tf.log(tf.to_float(antecedent_labels)) # [k, max_ant + 1]\n    marginalized_gold_scores = tf.reduce_logsumexp(gold_scores, [1]) # [k]\n    log_norm = tf.reduce_logsumexp(antecedent_scores, [1]) # [k]\n    return log_norm - marginalized_gold_scores # [k]\n\n  def bucket_distance(self, distances):\n    
\"\"\"\n    Places the given values (designed for distances) into 10 semi-logscale buckets:\n    [0, 1, 2, 3, 4, 5-7, 8-15, 16-31, 32-63, 64+].\n    \"\"\"\n    logspace_idx = tf.to_int32(tf.floor(tf.log(tf.to_float(distances))/math.log(2))) + 3\n    use_identity = tf.to_int32(distances <= 4)\n    combined_idx = use_identity * distances + (1 - use_identity) * logspace_idx\n    return tf.clip_by_value(combined_idx, 0, 9)\n\n  def get_slow_antecedent_scores(self, top_span_emb, top_antecedents, top_antecedent_emb, top_antecedent_offsets, top_span_speaker_ids, genre_emb):\n    k = util.shape(top_span_emb, 0)\n    c = util.shape(top_antecedents, 1)\n\n    feature_emb_list = []\n\n    if self.config[\"use_metadata\"]:\n      top_antecedent_speaker_ids = tf.gather(top_span_speaker_ids, top_antecedents) # [k, c]\n      same_speaker = tf.equal(tf.expand_dims(top_span_speaker_ids, 1), top_antecedent_speaker_ids) # [k, c]\n      speaker_pair_emb = tf.gather(tf.get_variable(\"same_speaker_emb\", [2, self.config[\"feature_size\"]]), tf.to_int32(same_speaker)) # [k, c, emb]\n      feature_emb_list.append(speaker_pair_emb)\n\n      tiled_genre_emb = tf.tile(tf.expand_dims(tf.expand_dims(genre_emb, 0), 0), [k, c, 1]) # [k, c, emb]\n      feature_emb_list.append(tiled_genre_emb)\n\n    if self.config[\"use_features\"]:\n      antecedent_distance_buckets = self.bucket_distance(top_antecedent_offsets) # [k, c]\n      antecedent_distance_emb = tf.gather(tf.get_variable(\"antecedent_distance_emb\", [10, self.config[\"feature_size\"]]), antecedent_distance_buckets) # [k, c]\n      feature_emb_list.append(antecedent_distance_emb)\n\n    feature_emb = tf.concat(feature_emb_list, 2) # [k, c, emb]\n    feature_emb = tf.nn.dropout(feature_emb, self.dropout) # [k, c, emb]\n\n    target_emb = tf.expand_dims(top_span_emb, 1) # [k, 1, emb]\n    similarity_emb = top_antecedent_emb * target_emb # [k, c, emb]\n    target_emb = tf.tile(target_emb, [1, c, 1]) # [k, c, emb]\n\n    pair_emb = 
tf.concat([target_emb, top_antecedent_emb, similarity_emb, feature_emb], 2) # [k, c, emb]\n\n    with tf.variable_scope(\"slow_antecedent_scores\"):\n      slow_antecedent_scores = util.ffnn(pair_emb, self.config[\"ffnn_depth\"], self.config[\"ffnn_size\"], 1, self.dropout) # [k, c, 1]\n    slow_antecedent_scores = tf.squeeze(slow_antecedent_scores, 2) # [k, c]\n    return slow_antecedent_scores # [k, c]\n\n  def get_fast_antecedent_scores(self, top_span_emb):\n    with tf.variable_scope(\"src_projection\"):\n      source_top_span_emb = tf.nn.dropout(util.projection(top_span_emb, util.shape(top_span_emb, -1)), self.dropout) # [k, emb]\n    target_top_span_emb = tf.nn.dropout(top_span_emb, self.dropout) # [k, emb]\n    return tf.matmul(source_top_span_emb, target_top_span_emb, transpose_b=True) # [k, k]\n\n  def flatten_emb_by_sentence(self, emb, text_len_mask):\n    num_sentences = tf.shape(emb)[0]\n    max_sentence_length = tf.shape(emb)[1]\n\n    emb_rank = len(emb.get_shape())\n    if emb_rank  == 2:\n      flattened_emb = tf.reshape(emb, [num_sentences * max_sentence_length])\n    elif emb_rank == 3:\n      flattened_emb = tf.reshape(emb, [num_sentences * max_sentence_length, util.shape(emb, 2)])\n    else:\n      raise ValueError(\"Unsupported rank: {}\".format(emb_rank))\n    return tf.boolean_mask(flattened_emb, tf.reshape(text_len_mask, [num_sentences * max_sentence_length]))\n\n  def lstm_contextualize(self, text_emb, text_len, text_len_mask):\n    num_sentences = tf.shape(text_emb)[0]\n\n    current_inputs = text_emb # [num_sentences, max_sentence_length, emb]\n\n    for layer in range(self.config[\"contextualization_layers\"]):\n      with tf.variable_scope(\"layer_{}\".format(layer)):\n        with tf.variable_scope(\"fw_cell\"):\n          cell_fw = util.CustomLSTMCell(self.config[\"contextualization_size\"], num_sentences, self.lstm_dropout)\n        with tf.variable_scope(\"bw_cell\"):\n          cell_bw = 
util.CustomLSTMCell(self.config[\"contextualization_size\"], num_sentences, self.lstm_dropout)\n        state_fw = tf.contrib.rnn.LSTMStateTuple(tf.tile(cell_fw.initial_state.c, [num_sentences, 1]), tf.tile(cell_fw.initial_state.h, [num_sentences, 1]))\n        state_bw = tf.contrib.rnn.LSTMStateTuple(tf.tile(cell_bw.initial_state.c, [num_sentences, 1]), tf.tile(cell_bw.initial_state.h, [num_sentences, 1]))\n\n        (fw_outputs, bw_outputs), _ = tf.nn.bidirectional_dynamic_rnn(\n          cell_fw=cell_fw,\n          cell_bw=cell_bw,\n          inputs=current_inputs,\n          sequence_length=text_len,\n          initial_state_fw=state_fw,\n          initial_state_bw=state_bw)\n\n        text_outputs = tf.concat([fw_outputs, bw_outputs], 2) # [num_sentences, max_sentence_length, emb]\n        text_outputs = tf.nn.dropout(text_outputs, self.lstm_dropout)\n        if layer > 0:\n          highway_gates = tf.sigmoid(util.projection(text_outputs, util.shape(text_outputs, 2))) # [num_sentences, max_sentence_length, emb]\n          text_outputs = highway_gates * text_outputs + (1 - highway_gates) * current_inputs\n        current_inputs = text_outputs\n\n    return self.flatten_emb_by_sentence(text_outputs, text_len_mask)\n\n  def get_predicted_antecedents(self, antecedents, antecedent_scores):\n    predicted_antecedents = []\n    for i, index in enumerate(np.argmax(antecedent_scores, axis=1) - 1):\n      if index < 0:\n        predicted_antecedents.append(-1)\n      else:\n        predicted_antecedents.append(antecedents[i, index])\n    return predicted_antecedents\n\n  def get_predicted_clusters(self, top_span_starts, top_span_ends, predicted_antecedents):\n    mention_to_predicted = {}\n    predicted_clusters = []\n    for i, predicted_index in enumerate(predicted_antecedents):\n      if predicted_index < 0:\n        continue\n      assert i > predicted_index\n      predicted_antecedent = (int(top_span_starts[predicted_index]), int(top_span_ends[predicted_index]))\n 
     if predicted_antecedent in mention_to_predicted:\n        predicted_cluster = mention_to_predicted[predicted_antecedent]\n      else:\n        predicted_cluster = len(predicted_clusters)\n        predicted_clusters.append([predicted_antecedent])\n        mention_to_predicted[predicted_antecedent] = predicted_cluster\n\n      mention = (int(top_span_starts[i]), int(top_span_ends[i]))\n      predicted_clusters[predicted_cluster].append(mention)\n      mention_to_predicted[mention] = predicted_cluster\n\n    predicted_clusters = [tuple(pc) for pc in predicted_clusters]\n    mention_to_predicted = { m:predicted_clusters[i] for m,i in mention_to_predicted.items() }\n\n    return predicted_clusters, mention_to_predicted\n\n  def evaluate_coref(self, top_span_starts, top_span_ends, predicted_antecedents, gold_clusters, evaluator):\n    gold_clusters = [tuple(tuple(m) for m in gc) for gc in gold_clusters]\n    mention_to_gold = {}\n    for gc in gold_clusters:\n      for mention in gc:\n        mention_to_gold[mention] = gc\n\n    predicted_clusters, mention_to_predicted = self.get_predicted_clusters(top_span_starts, top_span_ends, predicted_antecedents)\n    evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold)\n    return predicted_clusters\n\n  def load_eval_data(self):\n    if self.eval_data is None:\n      def load_line(line):\n        example = json.loads(line)\n        return self.tensorize_example(example, is_training=False), example\n      with open(self.config[\"eval_path\"]) as f:\n        self.eval_data = [load_line(l) for l in f.readlines()]\n      num_words = sum(tensorized_example[2].sum() for tensorized_example, _ in self.eval_data)\n      print(\"Loaded {} eval examples.\".format(len(self.eval_data)))\n\n  def evaluate(self, session, official_stdout=False):\n    self.load_eval_data()\n\n    coref_predictions = {}\n    coref_evaluator = metrics.CorefEvaluator()\n\n    for example_num, (tensorized_example, example) 
in enumerate(self.eval_data):\n      _, _, _, _, _, _, _, _, _, gold_starts, gold_ends, _ = tensorized_example\n      feed_dict = {i:t for i,t in zip(self.input_tensors, tensorized_example)}\n      candidate_starts, candidate_ends, candidate_mention_scores, top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores = session.run(self.predictions, feed_dict=feed_dict)\n      predicted_antecedents = self.get_predicted_antecedents(top_antecedents, top_antecedent_scores)\n      coref_predictions[example[\"doc_key\"]] = self.evaluate_coref(top_span_starts, top_span_ends, predicted_antecedents, example[\"clusters\"], coref_evaluator)\n      if example_num % 10 == 0:\n        print(\"Evaluated {}/{} examples.\".format(example_num + 1, len(self.eval_data)))\n\n    summary_dict = {}\n    conll_results = conll.evaluate_conll(self.config[\"conll_eval_path\"], coref_predictions, official_stdout)\n    average_f1 = sum(results[\"f\"] for results in conll_results.values()) / len(conll_results)\n    summary_dict[\"Average F1 (conll)\"] = average_f1\n    print(\"Average F1 (conll): {:.2f}%\".format(average_f1))\n\n    p,r,f = coref_evaluator.get_prf()\n    summary_dict[\"Average F1 (py)\"] = f\n    print(\"Average F1 (py): {:.2f}%\".format(f * 100))\n    summary_dict[\"Average precision (py)\"] = p\n    print(\"Average precision (py): {:.2f}%\".format(p * 100))\n    summary_dict[\"Average recall (py)\"] = r\n    print(\"Average recall (py): {:.2f}%\".format(r * 100))\n\n    return util.make_summary(summary_dict), average_f1\n"
  },
  {
    "path": "coref_ops.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow.python import pywrap_tensorflow\n\ncoref_op_library = tf.load_op_library(\"./coref_kernels.so\")\n\nextract_spans = coref_op_library.extract_spans\ntf.NotDifferentiable(\"ExtractSpans\")\n"
  },
  {
    "path": "demo.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom six.moves import input\nimport tensorflow as tf\nimport coref_model as cm\nimport util\n\nimport nltk\nnltk.download(\"punkt\")\nfrom nltk.tokenize import sent_tokenize, word_tokenize\n\ndef create_example(text):\n  raw_sentences = sent_tokenize(text)\n  sentences = [word_tokenize(s) for s in raw_sentences]\n  speakers = [[\"\" for _ in sentence] for sentence in sentences]\n  return {\n    \"doc_key\": \"nw\",\n    \"clusters\": [],\n    \"sentences\": sentences,\n    \"speakers\": speakers,\n  }\n\ndef print_predictions(example):\n  words = util.flatten(example[\"sentences\"])\n  for cluster in example[\"predicted_clusters\"]:\n    print(u\"Predicted cluster: {}\".format([\" \".join(words[m[0]:m[1]+1]) for m in cluster]))\n\ndef make_predictions(text, model):\n  example = create_example(text)\n  tensorized_example = model.tensorize_example(example, is_training=False)\n  feed_dict = {i:t for i,t in zip(model.input_tensors, tensorized_example)}\n  _, _, _, mention_starts, mention_ends, antecedents, antecedent_scores, head_scores = session.run(model.predictions + [model.head_scores], feed_dict=feed_dict)\n\n  predicted_antecedents = model.get_predicted_antecedents(antecedents, antecedent_scores)\n\n  example[\"predicted_clusters\"], _ = model.get_predicted_clusters(mention_starts, mention_ends, predicted_antecedents)\n  example[\"top_spans\"] = zip((int(i) for i in mention_starts), (int(i) for i in mention_ends))\n  example[\"head_scores\"] = head_scores.tolist()\n  return example\n\nif __name__ == \"__main__\":\n  config = util.initialize_from_env()\n  model = cm.CorefModel(config)\n  with tf.Session() as session:\n    model.restore(session)\n    while True:\n      text = input(\"Document text: \")\n      if len(text) > 0:\n        print_predictions(make_predictions(text, model))\n"
  },
  {
    "path": "evaluate.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\n\nimport tensorflow as tf\nimport coref_model as cm\nimport util\n\nif __name__ == \"__main__\":\n  config = util.initialize_from_env()\n  model = cm.CorefModel(config)\n  with tf.Session() as session:\n    model.restore(session)\n    model.evaluate(session, official_stdout=True)\n"
  },
  {
    "path": "experiments.conf",
    "content": "# Word embeddings.\nglove_300d {\n  path = glove.840B.300d.txt\n  size = 300\n}\nglove_300d_filtered {\n  path = glove.840B.300d.txt.filtered\n  size = 300\n}\nglove_300d_2w {\n  path = glove_50_300_2.txt\n  size = 300\n}\n\n# Distributed training configurations.\ntwo_local_gpus {\n  addresses {\n    ps = [localhost:2222]\n    worker = [localhost:2223, localhost:2224]\n  }\n  gpus = [0, 1]\n}\n\n# Main configuration.\nbest {\n  # Computation limits.\n  max_top_antecedents = 50\n  max_training_sentences = 50\n  top_span_ratio = 0.4\n\n  # Model hyperparameters.\n  filter_widths = [3, 4, 5]\n  filter_size = 50\n  char_embedding_size = 8\n  char_vocab_path = \"char_vocab.english.txt\"\n  context_embeddings = ${glove_300d_filtered}\n  head_embeddings = ${glove_300d_2w}\n  contextualization_size = 200\n  contextualization_layers = 3\n  ffnn_size = 150\n  ffnn_depth = 2\n  feature_size = 20\n  max_span_width = 30\n  use_metadata = true\n  use_features = true\n  model_heads = true\n  coref_depth = 2\n  lm_layers = 3\n  lm_size = 1024\n  coarse_to_fine = true\n\n  # Learning hyperparameters.\n  max_gradient_norm = 5.0\n  lstm_dropout_rate = 0.4\n  lexical_dropout_rate = 0.5\n  dropout_rate = 0.2\n  optimizer = adam\n  learning_rate = 0.001\n  decay_rate = 0.999\n  decay_frequency = 100\n\n  # Other.\n  train_path = train.english.jsonlines\n  eval_path = dev.english.jsonlines\n  conll_eval_path = dev.english.v4_gold_conll\n  lm_path = elmo_cache.hdf5\n  genres = [\"bc\", \"bn\", \"mz\", \"nw\", \"pt\", \"tc\", \"wb\"]\n  eval_frequency = 5000\n  report_frequency = 100\n  log_root = logs\n  cluster = ${two_local_gpus}\n}\n\n# For evaluation. Do not use for training (i.e. only for predict.py, evaluate.py, and demo.py). 
Rename `best` directory to `final`.\nfinal = ${best} {\n  context_embeddings = ${glove_300d}\n  head_embeddings = ${glove_300d_2w}\n  lm_path = \"\"\n  eval_path = test.english.jsonlines\n  conll_eval_path = test.english.v4_gold_conll\n}\n\n# Baselines.\nc2f_100_ant = ${best} {\n  max_top_antecedents = 100\n}\nc2f_250_ant = ${best} {\n  max_top_antecedents = 250\n}\nc2f_1_layer = ${best} {\n  coref_depth = 1\n}\nc2f_3_layer = ${best} {\n  coref_depth = 3\n}\ndistance_50_ant = ${best} {\n  max_top_antecedents = 50\n  coarse_to_fine = false\n  coref_depth = 1\n}\ndistance_100_ant = ${distance_50_ant} {\n  max_top_antecedents = 100\n}\ndistance_250_ant = ${distance_50_ant} {\n  max_top_antecedents = 250\n}\n"
  },
  {
    "path": "filter_embeddings.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport sys\nimport json\n\nif __name__ == \"__main__\":\n  if len(sys.argv) < 3:\n    sys.exit(\"Usage: {} <embeddings> <json1> <json2> ...\".format(sys.argv[0]))\n\n  words_to_keep = set()\n  for json_filename in sys.argv[2:]:\n    with open(json_filename) as json_file:\n      for line in json_file.readlines():\n        for sentence in json.loads(line)[\"sentences\"]:\n          words_to_keep.update(sentence)\n\n  print(\"Found {} words in {} dataset(s).\".format(len(words_to_keep), len(sys.argv) - 2))\n\n  total_lines = 0\n  kept_lines = 0\n  out_filename = \"{}.filtered\".format(sys.argv[1])\n  with open(sys.argv[1]) as in_file:\n    with open(out_filename, \"w\") as out_file:\n      for line in in_file.readlines():\n        total_lines += 1\n        word = line.split()[0]\n        if word in words_to_keep:\n          kept_lines += 1\n          out_file.write(line)\n\n  print(\"Kept {} out of {} lines.\".format(kept_lines, total_lines))\n  print(\"Wrote result to {}.\".format(out_filename))\n"
  },
  {
    "path": "get_char_vocab.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport sys\nimport json\nimport io\n\ndef get_char_vocab(input_filenames, output_filename):\n  vocab = set()\n  for filename in input_filenames:\n    with open(filename) as f:\n      for line in f.readlines():\n        for sentence in json.loads(line)[\"sentences\"]:\n          for word in sentence:\n            vocab.update(word)\n  vocab = sorted(list(vocab))\n  with io.open(output_filename, mode=\"w\", encoding=\"utf8\") as f:\n    for char in vocab:\n      f.write(char)\n      f.write(u\"\\n\")\n  print(\"Wrote {} characters to {}\".format(len(vocab), output_filename))\n\ndef get_char_vocab_language(language):\n  get_char_vocab([\"{}.{}.jsonlines\".format(partition, language) for partition in (\"train\", \"dev\", \"test\")], \"char_vocab.{}.txt\".format(language))\n\nget_char_vocab_language(\"english\")\nget_char_vocab_language(\"chinese\")\nget_char_vocab_language(\"arabic\")\n"
  },
  {
    "path": "metrics.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nfrom collections import Counter\nfrom sklearn.utils.linear_assignment_ import linear_assignment\n\n\ndef f1(p_num, p_den, r_num, r_den, beta=1):\n    p = 0 if p_den == 0 else p_num / float(p_den)\n    r = 0 if r_den == 0 else r_num / float(r_den)\n    return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r)\n\nclass CorefEvaluator(object):\n    def __init__(self):\n        self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)]\n\n    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):\n        for e in self.evaluators:\n            e.update(predicted, gold, mention_to_predicted, mention_to_gold)\n\n    def get_f1(self):\n        return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators)\n\n    def get_recall(self):\n        return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators)\n\n    def get_precision(self):\n        return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators)\n\n    def get_prf(self):\n        return self.get_precision(), self.get_recall(), self.get_f1()\n\nclass Evaluator(object):\n    def __init__(self, metric, beta=1):\n        self.p_num = 0\n        self.p_den = 0\n        self.r_num = 0\n        self.r_den = 0\n        self.metric = metric\n        self.beta = beta\n\n    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):\n        if self.metric == ceafe:\n            pn, pd, rn, rd = self.metric(predicted, gold)\n        else:\n            pn, pd = self.metric(predicted, mention_to_gold)\n            rn, rd = self.metric(gold, mention_to_predicted)\n        self.p_num += pn\n        self.p_den += pd\n        self.r_num += rn\n        self.r_den += rd\n\n    def get_f1(self):\n        return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)\n\n    def 
get_recall(self):\n        return 0 if self.r_num == 0 else self.r_num / float(self.r_den)\n\n    def get_precision(self):\n        return 0 if self.p_num == 0 else self.p_num / float(self.p_den)\n\n    def get_prf(self):\n        return self.get_precision(), self.get_recall(), self.get_f1()\n\n    def get_counts(self):\n        return self.p_num, self.p_den, self.r_num, self.r_den\n\n\ndef evaluate_documents(documents, metric, beta=1):\n    evaluator = Evaluator(metric, beta=beta)\n    for document in documents:\n        evaluator.update(document)\n    return evaluator.get_precision(), evaluator.get_recall(), evaluator.get_f1()\n\n\ndef b_cubed(clusters, mention_to_gold):\n    num, dem = 0, 0\n\n    for c in clusters:\n        if len(c) == 1:\n            continue\n\n        gold_counts = Counter()\n        correct = 0\n        for m in c:\n            if m in mention_to_gold:\n                gold_counts[tuple(mention_to_gold[m])] += 1\n        for c2, count in gold_counts.items():\n            if len(c2) != 1:\n                correct += count * count\n\n        num += correct / float(len(c))\n        dem += len(c)\n\n    return num, dem\n\n\ndef muc(clusters, mention_to_gold):\n    tp, p = 0, 0\n    for c in clusters:\n        p += len(c) - 1\n        tp += len(c)\n        linked = set()\n        for m in c:\n            if m in mention_to_gold:\n                linked.add(mention_to_gold[m])\n            else:\n                tp -= 1\n        tp -= len(linked)\n    return tp, p\n\n\ndef phi4(c1, c2):\n    return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))\n\n\ndef ceafe(clusters, gold_clusters):\n    clusters = [c for c in clusters if len(c) != 1]\n    scores = np.zeros((len(gold_clusters), len(clusters)))\n    for i in range(len(gold_clusters)):\n        for j in range(len(clusters)):\n            scores[i, j] = phi4(gold_clusters[i], clusters[j])\n    matching = linear_assignment(-scores)\n    similarity = sum(scores[matching[:, 0], 
matching[:, 1]])\n    return similarity, len(clusters), similarity, len(gold_clusters)\n\n\ndef lea(clusters, mention_to_gold):\n    num, dem = 0, 0\n\n    for c in clusters:\n        if len(c) == 1:\n            continue\n\n        common_links = 0\n        all_links = len(c) * (len(c) - 1) / 2.0\n        for i, m in enumerate(c):\n            if m in mention_to_gold:\n                for m2 in c[i + 1:]:\n                    if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]:\n                        common_links += 1\n\n        num += len(c) * common_links / float(all_links)\n        dem += len(c)\n\n    return num, dem\n"
  },
  {
    "path": "minimize.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport re\nimport os\nimport sys\nimport json\nimport tempfile\nimport subprocess\nimport collections\n\nimport util\nimport conll\n\nclass DocumentState(object):\n  def __init__(self):\n    self.doc_key = None\n    self.text = []\n    self.text_speakers = []\n    self.speakers = []\n    self.sentences = []\n    self.constituents = {}\n    self.const_stack = []\n    self.ner = {}\n    self.ner_stack = []\n    self.clusters = collections.defaultdict(list)\n    self.coref_stacks = collections.defaultdict(list)\n\n  def assert_empty(self):\n    assert self.doc_key is None\n    assert len(self.text) == 0\n    assert len(self.text_speakers) == 0\n    assert len(self.speakers) == 0\n    assert len(self.sentences) == 0\n    assert len(self.constituents) == 0\n    assert len(self.const_stack) == 0\n    assert len(self.ner) == 0\n    assert len(self.ner_stack) == 0\n    assert len(self.coref_stacks) == 0\n    assert len(self.clusters) == 0\n\n  def assert_finalizable(self):\n    assert self.doc_key is not None\n    assert len(self.text) == 0\n    assert len(self.text_speakers) == 0\n    assert len(self.speakers) > 0\n    assert len(self.sentences) > 0\n    assert len(self.constituents) > 0\n    assert len(self.const_stack) == 0\n    assert len(self.ner_stack) == 0\n    assert all(len(s) == 0 for s in self.coref_stacks.values())\n\n  def span_dict_to_list(self, span_dict):\n    return [(s,e,l) for (s,e),l in span_dict.items()]\n\n  def finalize(self):\n    merged_clusters = []\n    for c1 in self.clusters.values():\n      existing = None\n      for m in c1:\n        for c2 in merged_clusters:\n          if m in c2:\n            existing = c2\n            break\n        if existing is not None:\n          break\n      if existing is not None:\n        print(\"Merging clusters (shouldn't happen very often.)\")\n        existing.update(c1)\n      
else:\n        merged_clusters.append(set(c1))\n    merged_clusters = [list(c) for c in merged_clusters]\n    all_mentions = util.flatten(merged_clusters)\n    assert len(all_mentions) == len(set(all_mentions))\n\n    return {\n      \"doc_key\": self.doc_key,\n      \"sentences\": self.sentences,\n      \"speakers\": self.speakers,\n      \"constituents\": self.span_dict_to_list(self.constituents),\n      \"ner\": self.span_dict_to_list(self.ner),\n      \"clusters\": merged_clusters\n    }\n\ndef normalize_word(word, language):\n  if language == \"arabic\":\n    word = word[:word.find(\"#\")]\n  if word == \"/.\" or word == \"/?\":\n    return word[1:]\n  else:\n    return word\n\ndef handle_bit(word_index, bit, stack, spans):\n  asterisk_idx = bit.find(\"*\")\n  if asterisk_idx >= 0:\n    open_parens = bit[:asterisk_idx]\n    close_parens = bit[asterisk_idx + 1:]\n  else:\n    open_parens = bit[:-1]\n    close_parens = bit[-1]\n\n  current_idx = open_parens.find(\"(\")\n  while current_idx >= 0:\n    next_idx = open_parens.find(\"(\", current_idx + 1)\n    if next_idx >= 0:\n      label = open_parens[current_idx + 1:next_idx]\n    else:\n      label = open_parens[current_idx + 1:]\n    stack.append((word_index, label))\n    current_idx = next_idx\n\n  for c in close_parens:\n    assert c == \")\"\n    open_index, label = stack.pop()\n    current_span = (open_index, word_index)\n    \"\"\"\n    if current_span in spans:\n      spans[current_span] += \"_\" + label\n    else:\n      spans[current_span] = label\n    \"\"\"\n    spans[current_span] = label\n\ndef handle_line(line, document_state, language, labels, stats):\n  begin_document_match = re.match(conll.BEGIN_DOCUMENT_REGEX, line)\n  if begin_document_match:\n    document_state.assert_empty()\n    document_state.doc_key = conll.get_doc_key(begin_document_match.group(1), begin_document_match.group(2))\n    return None\n  elif line.startswith(\"#end document\"):\n    document_state.assert_finalizable()\n    
finalized_state = document_state.finalize()\n    stats[\"num_clusters\"] += len(finalized_state[\"clusters\"])\n    stats[\"num_mentions\"] += sum(len(c) for c in finalized_state[\"clusters\"])\n    labels[\"{}_const_labels\".format(language)].update(l for _, _, l in finalized_state[\"constituents\"])\n    labels[\"ner\"].update(l for _, _, l in finalized_state[\"ner\"])\n    return finalized_state\n  else:\n    row = line.split()\n    if len(row) == 0:\n      stats[\"max_sent_len_{}\".format(language)] = max(len(document_state.text), stats[\"max_sent_len_{}\".format(language)])\n      stats[\"num_sents_{}\".format(language)] += 1\n      document_state.sentences.append(tuple(document_state.text))\n      del document_state.text[:]\n      document_state.speakers.append(tuple(document_state.text_speakers))\n      del document_state.text_speakers[:]\n      return None\n    assert len(row) >= 12\n\n    doc_key = conll.get_doc_key(row[0], row[1])\n    word = normalize_word(row[3], language)\n    parse = row[5]\n    speaker = row[9]\n    ner = row[10]\n    coref = row[-1]\n\n    word_index = len(document_state.text) + sum(len(s) for s in document_state.sentences)\n    document_state.text.append(word)\n    document_state.text_speakers.append(speaker)\n\n    handle_bit(word_index, parse, document_state.const_stack, document_state.constituents)\n    handle_bit(word_index, ner, document_state.ner_stack, document_state.ner)\n\n    if coref != \"-\":\n      for segment in coref.split(\"|\"):\n        if segment[0] == \"(\":\n          if segment[-1] == \")\":\n            cluster_id = int(segment[1:-1])\n            document_state.clusters[cluster_id].append((word_index, word_index))\n          else:\n            cluster_id = int(segment[1:])\n            document_state.coref_stacks[cluster_id].append(word_index)\n        else:\n          cluster_id = int(segment[:-1])\n          start = document_state.coref_stacks[cluster_id].pop()\n          
document_state.clusters[cluster_id].append((start, word_index))\n    return None\n\ndef minimize_partition(name, language, extension, labels, stats):\n  input_path = \"{}.{}.{}\".format(name, language, extension)\n  output_path = \"{}.{}.jsonlines\".format(name, language)\n  count = 0\n  print(\"Minimizing {}\".format(input_path))\n  with open(input_path, \"r\") as input_file:\n    with open(output_path, \"w\") as output_file:\n      document_state = DocumentState()\n      for line in input_file.readlines():\n        document = handle_line(line, document_state, language, labels, stats)\n        if document is not None:\n          output_file.write(json.dumps(document))\n          output_file.write(\"\\n\")\n          count += 1\n          document_state = DocumentState()\n  print(\"Wrote {} documents to {}\".format(count, output_path))\n\ndef minimize_language(language, labels, stats):\n  minimize_partition(\"dev\", language, \"v4_gold_conll\", labels, stats)\n  minimize_partition(\"train\", language, \"v4_gold_conll\", labels, stats)\n  minimize_partition(\"test\", language, \"v4_gold_conll\", labels, stats)\n\nif __name__ == \"__main__\":\n  labels = collections.defaultdict(set)\n  stats = collections.defaultdict(int)\n  minimize_language(\"english\", labels, stats)\n  minimize_language(\"chinese\", labels, stats)\n  minimize_language(\"arabic\", labels, stats)\n  for k, v in labels.items():\n    print(\"{} = [{}]\".format(k, \", \".join(\"\\\"{}\\\"\".format(label) for label in v)))\n  for k, v in stats.items():\n    print(\"{} = {}\".format(k, v))\n"
  },
  {
    "path": "predict.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport sys\nimport json\n\nimport tensorflow as tf\nimport coref_model as cm\nimport util\n\nif __name__ == \"__main__\":\n  config = util.initialize_from_env()\n\n  # Input file in .jsonlines format.\n  input_filename = sys.argv[2]\n\n  # Predictions will be written to this file in .jsonlines format.\n  output_filename = sys.argv[3]\n\n  model = cm.CorefModel(config)\n\n  with tf.Session() as session:\n    model.restore(session)\n\n    with open(output_filename, \"w\") as output_file:\n      with open(input_filename) as input_file:\n        for example_num, line in enumerate(input_file.readlines()):\n          example = json.loads(line)\n          tensorized_example = model.tensorize_example(example, is_training=False)\n          feed_dict = {i:t for i,t in zip(model.input_tensors, tensorized_example)}\n          _, _, _, top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores = session.run(model.predictions, feed_dict=feed_dict)\n          predicted_antecedents = model.get_predicted_antecedents(top_antecedents, top_antecedent_scores)\n          example[\"predicted_clusters\"], _ = model.get_predicted_clusters(top_span_starts, top_span_ends, predicted_antecedents)\n\n          output_file.write(json.dumps(example))\n          output_file.write(\"\\n\")\n          if example_num % 100 == 0:\n            print(\"Decoded {} examples.\".format(example_num + 1))\n"
  },
  {
    "path": "ps.py",
    "content": "#!/usr/bin/env python\n\nimport os\n\nimport tensorflow as tf\nimport util\n\nif __name__ == \"__main__\":\n  config = util.initialize_from_env()\n  report_frequency = config[\"report_frequency\"]\n  cluster_config = config[\"cluster\"]\n  util.set_gpus()\n  cluster = tf.train.ClusterSpec(cluster_config[\"addresses\"])\n  server = tf.train.Server(cluster, job_name=\"ps\", task_index=0)\n  server.join()\n"
  },
  {
    "path": "requirements.txt",
    "content": "tensorflow-gpu>=1.13.1\ntensorflow-hub>=0.4.0\nh5py\nnltk\npyhocon\nscipy\n# metrics.py imports sklearn.utils.linear_assignment_, removed in scikit-learn 0.23\nscikit-learn<0.23\n"
  },
  {
    "path": "setup_all.sh",
    "content": "#!/bin/bash\n\n# Download pretrained embeddings.\ncurl -O http://downloads.cs.stanford.edu/nlp/data/glove.840B.300d.zip\nunzip glove.840B.300d.zip\nrm glove.840B.300d.zip\n\n# Build custom kernels.\nTF_CFLAGS=( $(python -c 'import tensorflow as tf; print(\" \".join(tf.sysconfig.get_compile_flags()))') )\nTF_LFLAGS=( $(python -c 'import tensorflow as tf; print(\" \".join(tf.sysconfig.get_link_flags()))') )\n\n# Linux (pip)\ng++ -std=c++11 -shared coref_kernels.cc -o coref_kernels.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 -D_GLIBCXX_USE_CXX11_ABI=0\n\n# Linux (build from source)\n#g++ -std=c++11 -shared coref_kernels.cc -o coref_kernels.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2\n\n# Mac\n#g++ -std=c++11 -shared coref_kernels.cc -o coref_kernels.so -I -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 -D_GLIBCXX_USE_CXX11_ABI=0 -undefined dynamic_lookup\n"
  },
  {
    "path": "setup_training.sh",
    "content": "#!/bin/bash\n\ndlx() {\n  wget $1/$2\n  tar -xvzf $2\n  rm $2\n}\n\nconll_url=http://conll.cemantix.org/2012/download\ndlx $conll_url conll-2012-train.v4.tar.gz\ndlx $conll_url conll-2012-development.v4.tar.gz\ndlx $conll_url/test conll-2012-test-key.tar.gz\ndlx $conll_url/test conll-2012-test-official.v9.tar.gz\n\ndlx $conll_url conll-2012-scripts.v3.tar.gz\n\ndlx http://conll.cemantix.org/download reference-coreference-scorers.v8.01.tar.gz\nmv reference-coreference-scorers conll-2012/scorer\n\nontonotes_path=/projects/WebWare6/ontonotes-release-5.0\nbash conll-2012/v3/scripts/skeleton2conll.sh -D $ontonotes_path/data/files/data conll-2012\n\nfunction compile_partition() {\n    rm -f $2.$5.$3$4\n    cat conll-2012/$3/data/$1/data/$5/annotations/*/*/*/*.$3$4 >> $2.$5.$3$4\n}\n\nfunction compile_language() {\n    compile_partition development dev v4 _gold_conll $1\n    compile_partition train train v4 _gold_conll $1\n    compile_partition test test v4 _gold_conll $1\n}\n\ncompile_language english\ncompile_language chinese\ncompile_language arabic\n\npython minimize.py\npython get_char_vocab.py\n\npython filter_embeddings.py glove.840B.300d.txt train.english.jsonlines dev.english.jsonlines\npython cache_elmo.py train.english.jsonlines dev.english.jsonlines\n"
  },
  {
    "path": "train.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport time\n\nimport tensorflow as tf\nimport coref_model as cm\nimport util\n\nif __name__ == \"__main__\":\n  config = util.initialize_from_env()\n\n  report_frequency = config[\"report_frequency\"]\n  eval_frequency = config[\"eval_frequency\"]\n\n  model = cm.CorefModel(config)\n  saver = tf.train.Saver()\n\n  log_dir = config[\"log_dir\"]\n  writer = tf.summary.FileWriter(log_dir, flush_secs=20)\n\n  max_f1 = 0\n\n  with tf.Session() as session:\n    session.run(tf.global_variables_initializer())\n    model.start_enqueue_thread(session)\n    accumulated_loss = 0.0\n\n    ckpt = tf.train.get_checkpoint_state(log_dir)\n    if ckpt and ckpt.model_checkpoint_path:\n      print(\"Restoring from: {}\".format(ckpt.model_checkpoint_path))\n      saver.restore(session, ckpt.model_checkpoint_path)\n\n    initial_time = time.time()\n    while True:\n      tf_loss, tf_global_step, _ = session.run([model.loss, model.global_step, model.train_op])\n      accumulated_loss += tf_loss\n\n      if tf_global_step % report_frequency == 0:\n        total_time = time.time() - initial_time\n        steps_per_second = tf_global_step / total_time\n\n        average_loss = accumulated_loss / report_frequency\n        print(\"[{}] loss={:.2f}, steps/s={:.2f}\".format(tf_global_step, average_loss, steps_per_second))\n        writer.add_summary(util.make_summary({\"loss\": average_loss}), tf_global_step)\n        accumulated_loss = 0.0\n\n      if tf_global_step % eval_frequency == 0:\n        saver.save(session, os.path.join(log_dir, \"model\"), global_step=tf_global_step)\n        eval_summary, eval_f1 = model.evaluate(session)\n\n        if eval_f1 > max_f1:\n          max_f1 = eval_f1\n          util.copy_checkpoint(os.path.join(log_dir, \"model-{}\".format(tf_global_step)), os.path.join(log_dir, \"model.max.ckpt\"))\n\n    
    writer.add_summary(eval_summary, tf_global_step)\n        writer.add_summary(util.make_summary({\"max_eval_f1\": max_f1}), tf_global_step)\n\n        print(\"[{}] eval_f1={:.2f}, max_f1={:.2f}\".format(tf_global_step, eval_f1, max_f1))\n"
  },
  {
    "path": "util.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport errno\nimport codecs\nimport collections\nimport json\nimport math\nimport shutil\nimport sys\n\nimport numpy as np\nimport tensorflow as tf\nimport pyhocon\n\n\ndef initialize_from_env():\n  if \"GPU\" in os.environ:\n    set_gpus(int(os.environ[\"GPU\"]))\n  else:\n    set_gpus()\n\n  name = sys.argv[1]\n  print(\"Running experiment: {}\".format(name))\n\n  config = pyhocon.ConfigFactory.parse_file(\"experiments.conf\")[name]\n  config[\"log_dir\"] = mkdirs(os.path.join(config[\"log_root\"], name))\n\n  print(pyhocon.HOCONConverter.convert(config, \"hocon\"))\n  return config\n\ndef copy_checkpoint(source, target):\n  for ext in (\".index\", \".data-00000-of-00001\"):\n    shutil.copyfile(source + ext, target + ext)\n\ndef make_summary(value_dict):\n  return tf.Summary(value=[tf.Summary.Value(tag=k, simple_value=v) for k,v in value_dict.items()])\n\ndef flatten(l):\n  return [item for sublist in l for item in sublist]\n\ndef set_gpus(*gpus):\n  os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join(str(g) for g in gpus)\n  print(\"Setting CUDA_VISIBLE_DEVICES to: {}\".format(os.environ[\"CUDA_VISIBLE_DEVICES\"]))\n\ndef mkdirs(path):\n  try:\n    os.makedirs(path)\n  except OSError as exception:\n    if exception.errno != errno.EEXIST:\n      raise\n  return path\n\ndef load_char_dict(char_vocab_path):\n  vocab = [u\"<unk>\"]\n  with codecs.open(char_vocab_path, encoding=\"utf-8\") as f:\n    vocab.extend(l.strip() for l in f.readlines())\n  char_dict = collections.defaultdict(int)\n  char_dict.update({c:i for i, c in enumerate(vocab)})\n  return char_dict\n\ndef maybe_divide(x, y):\n  return 0 if y == 0 else x / float(y)\n\ndef projection(inputs, output_size, initializer=None):\n  return ffnn(inputs, 0, -1, output_size, dropout=None, output_weights_initializer=initializer)\n\ndef highway(inputs, num_layers, dropout):\n  
for i in range(num_layers):\n    with tf.variable_scope(\"highway_{}\".format(i)):\n      j, f = tf.split(projection(inputs, 2 * shape(inputs, -1)), 2, -1)\n      f = tf.sigmoid(f)\n      j = tf.nn.relu(j)\n      if dropout is not None:\n        j = tf.nn.dropout(j, dropout)\n      inputs = f * j + (1 - f) * inputs\n  return inputs\n\ndef shape(x, dim):\n  return x.get_shape()[dim].value or tf.shape(x)[dim]\n\ndef ffnn(inputs, num_hidden_layers, hidden_size, output_size, dropout, output_weights_initializer=None):\n  if len(inputs.get_shape()) > 3:\n    raise ValueError(\"FFNN with rank {} not supported\".format(len(inputs.get_shape())))\n\n  if len(inputs.get_shape()) == 3:\n    batch_size = shape(inputs, 0)\n    seqlen = shape(inputs, 1)\n    emb_size = shape(inputs, 2)\n    current_inputs = tf.reshape(inputs, [batch_size * seqlen, emb_size])\n  else:\n    current_inputs = inputs\n\n  for i in range(num_hidden_layers):\n    hidden_weights = tf.get_variable(\"hidden_weights_{}\".format(i), [shape(current_inputs, 1), hidden_size])\n    hidden_bias = tf.get_variable(\"hidden_bias_{}\".format(i), [hidden_size])\n    current_outputs = tf.nn.relu(tf.nn.xw_plus_b(current_inputs, hidden_weights, hidden_bias))\n\n    if dropout is not None:\n      current_outputs = tf.nn.dropout(current_outputs, dropout)\n    current_inputs = current_outputs\n\n  output_weights = tf.get_variable(\"output_weights\", [shape(current_inputs, 1), output_size], initializer=output_weights_initializer)\n  output_bias = tf.get_variable(\"output_bias\", [output_size])\n  outputs = tf.nn.xw_plus_b(current_inputs, output_weights, output_bias)\n\n  if len(inputs.get_shape()) == 3:\n    outputs = tf.reshape(outputs, [batch_size, seqlen, output_size])\n  return outputs\n\ndef cnn(inputs, filter_sizes, num_filters):\n  num_words = shape(inputs, 0)\n  num_chars = shape(inputs, 1)\n  input_size = shape(inputs, 2)\n  outputs = []\n  for i, filter_size in enumerate(filter_sizes):\n    with 
tf.variable_scope(\"conv_{}\".format(i)):\n      w = tf.get_variable(\"w\", [filter_size, input_size, num_filters])\n      b = tf.get_variable(\"b\", [num_filters])\n    conv = tf.nn.conv1d(inputs, w, stride=1, padding=\"VALID\") # [num_words, num_chars - filter_size, num_filters]\n    h = tf.nn.relu(tf.nn.bias_add(conv, b)) # [num_words, num_chars - filter_size, num_filters]\n    pooled = tf.reduce_max(h, 1) # [num_words, num_filters]\n    outputs.append(pooled)\n  return tf.concat(outputs, 1) # [num_words, num_filters * len(filter_sizes)]\n\ndef batch_gather(emb, indices):\n  batch_size = shape(emb, 0)\n  seqlen = shape(emb, 1)\n  if len(emb.get_shape()) > 2:\n    emb_size = shape(emb, 2)\n  else:\n    emb_size = 1\n  flattened_emb = tf.reshape(emb, [batch_size * seqlen, emb_size])  # [batch_size * seqlen, emb]\n  offset = tf.expand_dims(tf.range(batch_size) * seqlen, 1)  # [batch_size, 1]\n  gathered = tf.gather(flattened_emb, indices + offset) # [batch_size, num_indices, emb]\n  if len(emb.get_shape()) == 2:\n    gathered = tf.squeeze(gathered, 2) # [batch_size, num_indices]\n  return gathered\n\nclass RetrievalEvaluator(object):\n  def __init__(self):\n    self._num_correct = 0\n    self._num_gold = 0\n    self._num_predicted = 0\n\n  def update(self, gold_set, predicted_set):\n    self._num_correct += len(gold_set & predicted_set)\n    self._num_gold += len(gold_set)\n    self._num_predicted += len(predicted_set)\n\n  def recall(self):\n    return maybe_divide(self._num_correct, self._num_gold)\n\n  def precision(self):\n    return maybe_divide(self._num_correct, self._num_predicted)\n\n  def metrics(self):\n    recall = self.recall()\n    precision = self.precision()\n    f1 = maybe_divide(2 * recall * precision, precision + recall)\n    return recall, precision, f1\n\nclass EmbeddingDictionary(object):\n  def __init__(self, info, normalize=True, maybe_cache=None):\n    self._size = info[\"size\"]\n    self._normalize = normalize\n    self._path = 
info[\"path\"]\n    if maybe_cache is not None and maybe_cache._path == self._path:\n      assert self._size == maybe_cache._size\n      self._embeddings = maybe_cache._embeddings\n    else:\n      self._embeddings = self.load_embedding_dict(self._path)\n\n  @property\n  def size(self):\n    return self._size\n\n  def load_embedding_dict(self, path):\n    print(\"Loading word embeddings from {}...\".format(path))\n    default_embedding = np.zeros(self.size)\n    embedding_dict = collections.defaultdict(lambda:default_embedding)\n    if len(path) > 0:\n      vocab_size = None\n      with open(path) as f:\n        for i, line in enumerate(f.readlines()):\n          word_end = line.find(\" \")\n          word = line[:word_end]\n          embedding = np.fromstring(line[word_end + 1:], np.float32, sep=\" \")\n          assert len(embedding) == self.size\n          embedding_dict[word] = embedding\n      if vocab_size is not None:\n        assert vocab_size == len(embedding_dict)\n      print(\"Done loading word embeddings.\")\n    return embedding_dict\n\n  def __getitem__(self, key):\n    embedding = self._embeddings[key]\n    if self._normalize:\n      embedding = self.normalize(embedding)\n    return embedding\n\n  def normalize(self, v):\n    norm = np.linalg.norm(v)\n    if norm > 0:\n      return v / norm\n    else:\n      return v\n\nclass CustomLSTMCell(tf.contrib.rnn.RNNCell):\n  def __init__(self, num_units, batch_size, dropout):\n    self._num_units = num_units\n    self._dropout = dropout\n    self._dropout_mask = tf.nn.dropout(tf.ones([batch_size, self.output_size]), dropout)\n    self._initializer = self._block_orthonormal_initializer([self.output_size] * 3)\n    initial_cell_state = tf.get_variable(\"lstm_initial_cell_state\", [1, self.output_size])\n    initial_hidden_state = tf.get_variable(\"lstm_initial_hidden_state\", [1, self.output_size])\n    self._initial_state = tf.contrib.rnn.LSTMStateTuple(initial_cell_state, initial_hidden_state)\n\n  
@property\n  def state_size(self):\n    return tf.contrib.rnn.LSTMStateTuple(self.output_size, self.output_size)\n\n  @property\n  def output_size(self):\n    return self._num_units\n\n  @property\n  def initial_state(self):\n    return self._initial_state\n\n  def __call__(self, inputs, state, scope=None):\n    \"\"\"Long short-term memory cell (LSTM).\"\"\"\n    with tf.variable_scope(scope or type(self).__name__):  # \"CustomLSTMCell\"\n      c, h = state\n      h *= self._dropout_mask\n      concat = projection(tf.concat([inputs, h], 1), 3 * self.output_size, initializer=self._initializer)\n      i, j, o = tf.split(concat, num_or_size_splits=3, axis=1)\n      i = tf.sigmoid(i)\n      new_c = (1 - i) * c  + i * tf.tanh(j)\n      new_h = tf.tanh(new_c) * tf.sigmoid(o)\n      new_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)\n      return new_h, new_state\n\n  def _orthonormal_initializer(self, scale=1.0):\n    def _initializer(shape, dtype=tf.float32, partition_info=None):\n      M1 = np.random.randn(shape[0], shape[0]).astype(np.float32)\n      M2 = np.random.randn(shape[1], shape[1]).astype(np.float32)\n      Q1, R1 = np.linalg.qr(M1)\n      Q2, R2 = np.linalg.qr(M2)\n      Q1 = Q1 * np.sign(np.diag(R1))\n      Q2 = Q2 * np.sign(np.diag(R2))\n      n_min = min(shape[0], shape[1])\n      params = np.dot(Q1[:, :n_min], Q2[:n_min, :]) * scale\n      return params\n    return _initializer\n\n  def _block_orthonormal_initializer(self, output_sizes):\n    def _initializer(shape, dtype=np.float32, partition_info=None):\n      assert len(shape) == 2\n      assert sum(output_sizes) == shape[1]\n      initializer = self._orthonormal_initializer()\n      params = np.concatenate([initializer([shape[0], o], dtype, partition_info) for o in output_sizes], 1)\n      return params\n    return _initializer\n"
  },
  {
    "path": "worker.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport sys\nimport time\n\nimport tensorflow as tf\nimport coref_model as cm\nimport util\n\nif __name__ == \"__main__\":\n  config = util.initialize_from_env()\n  task_index = int(os.environ[\"TASK\"])\n\n  report_frequency = config[\"report_frequency\"]\n  cluster_config = config[\"cluster\"]\n\n  util.set_gpus(cluster_config[\"gpus\"][task_index])\n\n  cluster = tf.train.ClusterSpec(cluster_config[\"addresses\"])\n  server = tf.train.Server(cluster,\n                           job_name=\"worker\",\n                           task_index=task_index)\n\n  # Assigns ops to the local worker by default.\n  with tf.device(tf.train.replica_device_setter(worker_device=\"/job:worker/task:%d\" % task_index, cluster=cluster)):\n    model = cm.CorefModel(config)\n    saver = tf.train.Saver()\n    init_op = tf.global_variables_initializer()\n\n  log_dir = config[\"log_dir\"]\n  writer = tf.summary.FileWriter(os.path.join(log_dir, \"w{}\".format(task_index)), flush_secs=20)\n\n  is_chief = (task_index == 0)\n\n  # Create a \"supervisor\", which oversees the training process.\n  sv = tf.train.Supervisor(is_chief=is_chief,\n                           logdir=log_dir,\n                           init_op=init_op,\n                           saver=saver,\n                           global_step=model.global_step,\n                           save_model_secs=120)\n\n  # The supervisor takes care of session initialization, restoring from\n  # a checkpoint, and closing when done or an error occurs.\n  with sv.managed_session(server.target) as session:\n    model.start_enqueue_thread(session)\n    accumulated_loss = 0.0\n    initial_time = time.time()\n    while not sv.should_stop():\n      tf_loss, tf_global_step, _ = session.run([model.loss, model.global_step, model.train_op])\n      accumulated_loss += tf_loss\n\n      if 
tf_global_step % report_frequency == 0:\n        total_time = time.time() - initial_time\n        steps_per_second = tf_global_step / total_time\n\n        average_loss = accumulated_loss / report_frequency\n        print(\"[{}] loss={:.2f}, steps/s={:.2f}\".format(tf_global_step, tf_loss, steps_per_second))\n        accumulated_loss = 0.0\n        writer.add_summary(util.make_summary({\n          \"Train Loss\": average_loss,\n          \"Steps per second\": steps_per_second\n        }))\n\n  # Ask for all the services to stop.\n  sv.stop()\n"
  }
]