[
  {
    "path": "chat.py",
    "content": "import tensorflow as tf\nimport importlib\nimport random\nfrom preprocess.data_utils import utter_preprocess, is_reach_goal\n\nclass Target_Chat():\n    def __init__(self, agent):\n        self.agent = agent\n        self.start_utter = config_data._start_corpus\n        with tf.Session(config=self.agent.gpu_config) as sess:\n            self.agent.retrieve_init(sess)\n            for i in range(int(FLAGS.times)):\n                print('--------Session {} --------'.format(i))\n                self.chat(sess)\n\n    def chat(self, sess):\n        history = []\n        history.append(random.sample(self.start_utter, 1)[0])\n        target_kw = random.sample(target_set,1)[0]\n        self.agent.target = target_kw\n        self.agent.score = 0.\n        self.agent.reply_list = []\n        print('START: ' + history[0])\n        for i in range(config_data._max_turns):\n            history.append(input('HUMAN: '))\n            source = utter_preprocess(history, self.agent.data_config._max_seq_len)\n            reply = self.agent.retrieve(source, sess)\n            print('AGENT: ', reply)\n#             print('Keyword: {}, Similarity: {:.2f}'.format(self.agent.next_kw, self.agent.score))\n            history.append(reply)\n            if is_reach_goal(history[-2] + history[-1], target_kw):\n                print('Successfully chat to the target \\'{}\\'.'.format(target_kw))\n                return\n        print('Failed by reaching the maximum turn, target: \\'{}\\'.'.format(target_kw))\n\nif __name__ == '__main__':\n    flags = tf.flags\n    # supports kernel / matrix / neural / retrieval / retrieval-stg\n    flags.DEFINE_string('agent', 'kernel', 'The agent type')\n    flags.DEFINE_string('times', '100', 'Conversation times')\n    FLAGS = flags.FLAGS\n\n    config_data = importlib.import_module('config.data_config')\n    config_model = importlib.import_module('config.' + FLAGS.agent)\n    model = importlib.import_module('model.' + FLAGS.agent)\n    predictor = model.Predictor(config_model, config_data, 'test')\n    \n    target_set = []\n    for line in open('tx_data/test/keywords.txt', 'r').readlines():\n        target_set = target_set + line.strip().split(' ')\n\n    Target_Chat(predictor)\n"
  },
  {
    "path": "config/data_config.py",
    "content": "import os\ndata_root = './tx_data'\n_corpus = [x.strip() for x in open('tx_data/corpus.txt', 'r').readlines()]\n_start_corpus = [x.strip() for x in open('tx_data/start_corpus.txt', 'r').readlines()]\n_max_seq_len = 30\n_num_neg = 20\n_max_turns = 8\n_batch_size = 64\n_retrieval_candidates = 1000\n\ndata_hparams = {\n    stage: {\n        \"num_epochs\": 1,\n        \"shuffle\": stage != 'test',\n        \"batch_size\": _batch_size,\n        \"datasets\": [\n            {  # dialogue history\n                \"variable_utterance\": True,\n                \"max_utterance_cnt\": 9,\n                \"max_seq_length\": _max_seq_len,\n                \"files\": [os.path.join(data_root, '{}/source.txt'.format(stage))],\n                \"vocab_file\": os.path.join(data_root, 'vocab.txt'),\n                \"embedding_init\": {\n                    \"file\": os.path.join(data_root, 'embedding.txt'),\n                    \"dim\": 200,\n                    \"read_fn\": \"load_glove\"\n                },\n                \"data_name\": \"source\"\n            },\n            {  # candidate response\n                \"variable_utterance\": True,\n                \"max_utterance_cnt\": 20,\n                \"max_seq_length\": _max_seq_len,\n                \"files\": [os.path.join(data_root, '{}/target.txt'.format(stage))],\n                \"vocab_share_with\": 0,\n                \"embedding_init_share_with\" : 0,\n                \"data_name\": \"target\"\n            },\n            {  # context (source keywords)\n                \"files\": [os.path.join(data_root, '{}/context.txt'.format(stage))],\n                \"vocab_share_with\": 0,\n                \"embedding_init_share_with\": 0,\n                \"data_name\": \"context\",\n                \"bos_token\": '',\n                \"eos_token\": '',\n            },\n            {  # target keywords\n                \"files\": [os.path.join(data_root, '{}/keywords.txt'.format(stage))],\n                \"vocab_share_with\": 0,\n                \"embedding_init_share_with\": 0,\n                \"data_name\": \"keywords\",\n                \"bos_token\": '',\n                \"eos_token\": '',\n            },\n            {  # label\n                \"files\": [os.path.join(data_root, '{}/label.txt'.format(stage))],\n                \"data_type\": \"int\",\n                \"data_name\": \"label\"\n            }\n        ]\n    }\n    for stage in ['train','valid','test']\n}\n\n\ncorpus_hparams = {\n    \"batch_size\": _batch_size*2,\n    \"shuffle\": False,\n    \"dataset\":{\n        \"max_seq_length\": _max_seq_len,\n        \"files\": [os.path.join(data_root, 'corpus.txt')],\n        \"vocab_file\": os.path.join(data_root, 'vocab.txt'),\n        \"data_name\": \"corpus\"\n    }\n}\n\n\n_keywords_path = 'tx_data/test/keywords_vocab.txt'\n_keywords_candi = [x.strip() for x in open(_keywords_path, 'r').readlines()]\n_keywords_num = len(_keywords_candi)\n_keywords_dict = {}\nfor i in range(_keywords_num):\n    _keywords_dict[_keywords_candi[i]] = i\n"
  },
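  {
    "path": "examples/data_fields_note.py",
    "content": "\"\"\"Illustrative note, not part of the original project: the batch fields the\nmodels read from the data configuration above. Texar's MultiAlignedData names\nbatch entries after each dataset's data_name; the fields listed here are\nexactly the ones consumed in model/kernel.py, model/matrix.py and\nmodel/neural.py.\"\"\"\n\nBATCH_FIELDS = {\n    'source': ['source_text_ids', 'source_length', 'source_utterance_cnt'],\n    'target': ['target_text_ids', 'target_length'],\n    'context': ['context_text_ids', 'context_length'],\n    'keywords': ['keywords_text', 'keywords_text_ids'],\n    'label': ['label'],\n}\n\nif __name__ == '__main__':\n    for name, fields in BATCH_FIELDS.items():\n        print('{} -> {}'.format(name, ', '.join(fields)))\n"
  },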
  {
    "path": "config/kernel.py",
    "content": "_hidden_size = 200\n_code_len = 800\n_save_path = 'save/kernel/model_1'\n_kernel_save_path = 'save/kernel/keyword_1'\n_kernel_mu = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.]\n_kernel_sigma = 0.1\n_max_epoch = 10\n_early_stopping = 2\n\nkernel_opt_hparams = {\n    \"optimizer\": {\n        \"type\": \"AdamOptimizer\",\n        \"kwargs\": {\n            \"learning_rate\": 0.001,\n        }\n    },\n    \"learning_rate_decay\": {\n        \"type\": \"inverse_time_decay\",\n        \"kwargs\": {\n            \"decay_steps\": 1600,\n            \"decay_rate\": 0.8\n        },\n        \"start_decay_step\": 0,\n        \"end_decay_step\": 16000,\n    },\n}\n\nsource_encoder_hparams = {\n    \"encoder_minor_type\": \"BidirectionalRNNEncoder\",\n    \"encoder_minor_hparams\": {\n        \"rnn_cell_fw\": {\n            \"type\": \"GRUCell\",\n            \"kwargs\": {\n                \"num_units\": _hidden_size,\n            },\n        },\n        \"rnn_cell_share_config\": True\n    },\n    \"encoder_major_type\": \"UnidirectionalRNNEncoder\",\n    \"encoder_major_hparams\": {\n        \"rnn_cell\": {\n            \"type\": \"GRUCell\",\n            \"kwargs\": {\n                \"num_units\": _hidden_size*2,\n            },\n        }\n    }\n}\n\ntarget_encoder_hparams = {\n    \"rnn_cell_fw\": {\n        \"type\": \"GRUCell\",\n        \"kwargs\": {\n            \"num_units\": _hidden_size,\n        },\n    },\n    \"rnn_cell_share_config\": True\n}\n\ntarget_kwencoder_hparams = {\n    \"rnn_cell_fw\": {\n        \"type\": \"GRUCell\",\n        \"kwargs\": {\n            \"num_units\": _hidden_size,\n        },\n    },\n    \"rnn_cell_share_config\": True\n}\n\ncontext_encoder_hparams = {\n    \"rnn_cell\": {\n        \"type\": \"GRUCell\",\n        \"kwargs\": {\n            \"num_units\": _hidden_size,\n        },\n    }\n}\n\nopt_hparams = {\n    \"optimizer\": {\n        \"type\": \"AdamOptimizer\",\n        \"kwargs\": {\n            \"learning_rate\": 0.001,\n        }\n    },\n}"
  },
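  {
    "path": "examples/kernel_pooling_demo.py",
    "content": "\"\"\"Illustrative sketch, not part of the original project: the Gaussian\nkernel pooling that model/kernel.py:forward_kernel applies with _kernel_mu\nand _kernel_sigma from this config (a KNRM-style matcher). The toy\nembeddings and sizes below are made up; in the model, a learned linear layer\nthen maps the pooled feature vector to a single keyword score.\"\"\"\nimport numpy as np\n\nmu = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])  # _kernel_mu\nsigma = 0.1  # _kernel_sigma\n\ndef kernel_features(kw_embed, context_embed, mask):\n    # kw_embed: (dim,) l2-normalized keyword embedding\n    # context_embed: (ctx_len, dim) l2-normalized context-word embeddings\n    # mask: (ctx_len,) 1.0 for real tokens, 0.0 for special/padding ids\n    sim = context_embed @ kw_embed                         # cosine similarities\n    feat = np.exp(-(sim[:, None] - mu) ** 2 / sigma ** 2)  # (ctx_len, n_kernels)\n    return (feat * mask[:, None]).sum(axis=0)              # pool over the context\n\nif __name__ == '__main__':\n    rng = np.random.RandomState(0)\n    ctx = rng.randn(5, 200)\n    ctx /= np.linalg.norm(ctx, axis=1, keepdims=True)\n    kw = rng.randn(200)\n    kw /= np.linalg.norm(kw)\n    print(kernel_features(kw, ctx, np.ones(5)))  # one value per kernel center\n"
  },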
  {
    "path": "config/matrix.py",
    "content": "_hidden_size = 200\n_code_len = 800\n_save_path = 'save/matrix/model_1'\n_matrix_save_path = 'save/matrix/matrix_1.pk'\n_max_epoch = 10\n\n_vocab_path = 'tx_data/vocab.txt'\n_vocab = [x.strip() for x in open(_vocab_path, 'r').readlines()]\n_vocab_size = len(_vocab)\n\nsource_encoder_hparams = {\n    \"encoder_minor_type\": \"BidirectionalRNNEncoder\",\n    \"encoder_minor_hparams\": {\n        \"rnn_cell_fw\": {\n            \"type\": \"GRUCell\",\n            \"kwargs\": {\n                \"num_units\": _hidden_size,\n            },\n        },\n        \"rnn_cell_share_config\": True\n    },\n    \"encoder_major_type\": \"UnidirectionalRNNEncoder\",\n    \"encoder_major_hparams\": {\n        \"rnn_cell\": {\n            \"type\": \"GRUCell\",\n            \"kwargs\": {\n                \"num_units\": _hidden_size*2,\n            },\n        }\n    }\n}\n\ntarget_encoder_hparams = {\n    \"rnn_cell_fw\": {\n        \"type\": \"GRUCell\",\n        \"kwargs\": {\n            \"num_units\": _hidden_size,\n        },\n    },\n    \"rnn_cell_share_config\": True\n}\n\ntarget_kwencoder_hparams = {\n    \"rnn_cell_fw\": {\n        \"type\": \"GRUCell\",\n        \"kwargs\": {\n            \"num_units\": _hidden_size,\n        },\n    },\n    \"rnn_cell_share_config\": True\n}\n\nopt_hparams = {\n    \"optimizer\": {\n        \"type\": \"AdamOptimizer\",\n        \"kwargs\": {\n            \"learning_rate\": 0.001,\n        }\n    }\n}"
  },
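  {
    "path": "examples/pmi_matrix_demo.py",
    "content": "\"\"\"Illustrative sketch, not part of the original project: the co-occurrence\nscoring used by model/matrix.py, in NumPy. Counts of (context word,\nnext-turn keyword) pairs are smoothed and column-normalized exactly as in\ntrain_keywords(); scoring a context is the log-sum of per-word keyword\nprobabilities, as in forward_matrix(). The toy counts and sizes are made up.\"\"\"\nimport numpy as np\n\nvocab_size, n_keywords = 8, 3\ncounts = np.zeros((vocab_size, n_keywords))\n# toy (source-word id, target-keyword id) co-occurrence pairs\nfor s, t in [(4, 0), (4, 0), (5, 1), (6, 2), (4, 1)]:\n    counts[s][t] += 1\n\ncounts += 0.5                             # additive smoothing\nprob = counts / (counts.sum(axis=0) + 1)  # column-normalize, as in matrix.py\n\ndef score(context_ids):\n    # log-sum of per-word keyword probabilities == forward_matrix()\n    return np.log(prob[context_ids]).sum(axis=0)\n\nif __name__ == '__main__':\n    print(score([4, 5]))  # higher score -> more likely next keyword\n"
  },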
  {
    "path": "config/neural.py",
    "content": "_hidden_size = 200\n_code_len = 800\n_save_path = 'save/neural/model_1'\n_neural_save_path = 'save/neural/keyword_1'\n_max_epoch = 10\n\nneural_opt_hparams = {\n    \"optimizer\": {\n        \"type\": \"AdamOptimizer\",\n        \"kwargs\": {\n            \"learning_rate\": 0.005,\n        }\n    },\n    \"learning_rate_decay\": {\n        \"type\": \"inverse_time_decay\",\n        \"kwargs\": {\n            \"decay_steps\": 1600,\n            \"decay_rate\": 0.8\n        },\n        \"start_decay_step\": 0,\n        \"end_decay_step\": 16000,\n    },\n}\n\nsource_encoder_hparams = {\n    \"encoder_minor_type\": \"BidirectionalRNNEncoder\",\n    \"encoder_minor_hparams\": {\n        \"rnn_cell_fw\": {\n            \"type\": \"GRUCell\",\n            \"kwargs\": {\n                \"num_units\": _hidden_size,\n            },\n        },\n        \"rnn_cell_share_config\": True\n    },\n    \"encoder_major_type\": \"UnidirectionalRNNEncoder\",\n    \"encoder_major_hparams\": {\n        \"rnn_cell\": {\n            \"type\": \"GRUCell\",\n            \"kwargs\": {\n                \"num_units\": _hidden_size*2,\n            },\n        }\n    }\n}\n\ntarget_encoder_hparams = {\n    \"rnn_cell_fw\": {\n        \"type\": \"GRUCell\",\n        \"kwargs\": {\n            \"num_units\": _hidden_size,\n        },\n    },\n    \"rnn_cell_share_config\": True\n}\n\ntarget_kwencoder_hparams = {\n    \"rnn_cell_fw\": {\n        \"type\": \"GRUCell\",\n        \"kwargs\": {\n            \"num_units\": _hidden_size,\n        },\n    },\n    \"rnn_cell_share_config\": True\n}\n\ncontext_encoder_hparams = {\n    \"rnn_cell\": {\n        \"type\": \"GRUCell\",\n        \"kwargs\": {\n            \"num_units\": _hidden_size,\n        },\n    }\n}\n\nopt_hparams = {\n    \"optimizer\": {\n        \"type\": \"AdamOptimizer\",\n        \"kwargs\": {\n            \"learning_rate\": 0.001,\n        }\n    }\n}\n"
  },
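  {
    "path": "examples/lr_decay_demo.py",
    "content": "\"\"\"Illustrative sketch, not part of the original project: the inverse-time\nlearning-rate schedule configured in neural_opt_hparams above (and similarly\nin config/kernel.py). Plain-Python evaluation of tf.train.inverse_time_decay,\nlr(t) = lr0 / (1 + decay_rate * t / decay_steps), assuming the global step is\nclamped to the [start_decay_step, end_decay_step] window.\"\"\"\n\ndef decayed_lr(step, lr0=0.005, decay_steps=1600, decay_rate=0.8,\n               start=0, end=16000):\n    t = min(max(step, start), end) - start  # clamp to the decay window\n    return lr0 / (1. + decay_rate * t / decay_steps)\n\nif __name__ == '__main__':\n    for step in (0, 1600, 8000, 16000, 32000):\n        print('step {:>6}: lr={:.6f}'.format(step, decayed_lr(step)))\n"
  },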
  {
    "path": "config/retrieval.py",
    "content": "_hidden_size = 200\n_code_len = 200\n_save_path = 'save/retrieval/model_1'\n_max_epoch = 10\n\nsource_encoder_hparams = {\n    \"encoder_minor_type\": \"UnidirectionalRNNEncoder\",\n    \"encoder_minor_hparams\": {\n        \"rnn_cell\": {\n            \"type\": \"GRUCell\",\n            \"kwargs\": {\n                \"num_units\": _hidden_size,\n            },\n        },\n    },\n    \"encoder_major_type\": \"UnidirectionalRNNEncoder\",\n    \"encoder_major_hparams\": {\n        \"rnn_cell\": {\n            \"type\": \"GRUCell\",\n            \"kwargs\": {\n                \"num_units\": _hidden_size,\n            },\n        }\n    }\n}\n\ntarget_encoder_hparams = {\n    \"rnn_cell\": {\n        \"type\": \"GRUCell\",\n        \"kwargs\": {\n            \"num_units\": _hidden_size,\n        },\n    }\n}\n\nopt_hparams = {\n    \"optimizer\": {\n        \"type\": \"AdamOptimizer\",\n        \"kwargs\": {\n            \"learning_rate\": 0.001,\n        }\n    },\n}"
  },
  {
    "path": "config/retrieval_stgy.py",
    "content": "_hidden_size = 200\n_code_len = 200\n_save_path = 'save/retrieval/model_1'\n_max_epoch = 10\n\nsource_encoder_hparams = {\n    \"encoder_minor_type\": \"UnidirectionalRNNEncoder\",\n    \"encoder_minor_hparams\": {\n        \"rnn_cell\": {\n            \"type\": \"GRUCell\",\n            \"kwargs\": {\n                \"num_units\": _hidden_size,\n            },\n        },\n    },\n    \"encoder_major_type\": \"UnidirectionalRNNEncoder\",\n    \"encoder_major_hparams\": {\n        \"rnn_cell\": {\n            \"type\": \"GRUCell\",\n            \"kwargs\": {\n                \"num_units\": _hidden_size,\n            },\n        }\n    }\n}\n\ntarget_encoder_hparams = {\n    \"rnn_cell\": {\n        \"type\": \"GRUCell\",\n        \"kwargs\": {\n            \"num_units\": _hidden_size,\n        },\n    }\n}\n\nopt_hparams = {\n    \"optimizer\": {\n        \"type\": \"AdamOptimizer\",\n        \"kwargs\": {\n            \"learning_rate\": 0.001,\n        }\n    },\n}"
  },
  {
    "path": "model/kernel.py",
    "content": "import texar as tx\nimport tensorflow as tf\nimport numpy as np\nfrom preprocess.data_utils import kw_tokenize\n\n\nclass Predictor():\n    def __init__(self, config_model, config_data, mode=None):\n        self.config = config_model\n        self.data_config = config_data\n        self.gpu_config = tf.ConfigProto()\n        self.gpu_config.gpu_options.allow_growth = True\n        self.build_model()\n\n    def build_model(self):\n        self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train'])\n        self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid'])\n        self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test'])\n        self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data)\n        self.vocab = self.train_data.vocab(0)\n        self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs)\n        self.kw_embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs)\n        self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams)\n        self.target_encoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams)\n        self.target_kwencoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_kwencoder_hparams)\n        self.linear_transform = tx.modules.MLPTransformConnector(self.config._code_len // 2)\n        self.linear_matcher = tx.modules.MLPTransformConnector(1)\n        self.linear_kernel = tx.modules.MLPTransformConnector(1)\n        self.kw_list = self.vocab.map_tokens_to_ids(tf.convert_to_tensor(self.data_config._keywords_candi))\n        self.kw_vocab = tx.data.Vocab(self.data_config._keywords_path)\n        self.keywords_embed = tf.nn.l2_normalize(self.kw_embedder(self.kw_list), axis=1)\n\n    def forward_kernel(self, kw_embed, context_ids):\n        kernel_sigma = self.config._kernel_sigma\n        mu = tf.convert_to_tensor(self.config._kernel_mu)\n        mask = tf.cast(context_ids > 3, dtype=tf.float32)\n        context_embed = self.kw_embedder(context_ids)\n        context_embed = tf.nn.l2_normalize(context_embed, axis=2)\n        similarity_matrix = tf.reduce_sum(kw_embed * context_embed, axis=2)\n        similarity_matrix = tf.tile(tf.expand_dims(similarity_matrix, 2), [1, 1, len(self.config._kernel_mu)])\n        matching_feature = tf.exp(-(similarity_matrix - mu) ** 2 / (kernel_sigma ** 2))\n        matching_feature = matching_feature * tf.tile(tf.expand_dims(mask, 2), [1, 1, len(self.config._kernel_mu)])\n        matching_feature = tf.reduce_sum(matching_feature, axis=1)\n        matching_score = self.linear_kernel(matching_feature)\n        matching_score = tf.squeeze(matching_score, 1)\n        return matching_score\n\n    def predict_keywords(self, batch):\n        keywords_ids = self.kw_vocab.map_tokens_to_ids(batch['keywords_text'])\n        matching_score = tf.map_fn(lambda kw_embed: self.forward_kernel(kw_embed, batch['context_text_ids']),\n            self.keywords_embed, dtype=tf.float32, parallel_iterations=True)\n        matching_score = tf.transpose(matching_score)\n        matching_score = tf.nn.softmax(matching_score)\n        kw_labels = tf.map_fn(lambda x: tf.sparse_to_dense(x, [self.kw_vocab.size], 1., 0., False),\n            keywords_ids, dtype=tf.float32, parallel_iterations=True)[:, 4:]\n        loss = 
tf.reduce_sum(-tf.log(matching_score) * kw_labels) / tf.reduce_sum(kw_labels)\n        kw_ans = tf.arg_max(matching_score, -1)\n        acc_label = tf.map_fn(lambda x: tf.gather(x[0], x[1]), (kw_labels, kw_ans), dtype=tf.float32)\n        acc = tf.reduce_mean(acc_label)\n        kws = tf.nn.top_k(matching_score, k=5)[1]\n        kws = tf.reshape(kws,[-1])\n        kws = tf.map_fn(lambda x: self.kw_list[x], kws, dtype=tf.int64)\n        kws = tf.reshape(kws,[-1, 5])\n        return loss, acc, kws\n\n    def train_keywords(self):\n        batch = self.iterator.get_next()\n        loss, acc, _ = self.predict_keywords(batch)\n        op_step = tf.Variable(0, name='op_step')\n        train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.kernel_opt_hparams)\n        max_val_acc, stopping_flag = 0, 0\n        self.saver = tf.train.Saver()\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.global_variables_initializer())\n            sess.run(tf.local_variables_initializer())\n            sess.run(tf.tables_initializer())\n            for epoch_id in range(self.config._max_epoch):\n                self.iterator.switch_to_train_data(sess)\n                cur_step = 0\n                cnt_acc = []\n                while True:\n                    try:\n                        cur_step += 1\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}\n                        loss_, acc_ = sess.run([train_op, acc], feed_dict=feed)\n                        cnt_acc.append(acc_)\n                        if cur_step % 100 == 0:\n                            print('batch {}, loss={}, acc1={}'.format(cur_step, loss_, np.mean(cnt_acc[-100:])))\n                    except tf.errors.OutOfRangeError:\n                        break\n\n                self.iterator.switch_to_val_data(sess)\n                cnt_acc = []\n                while True:\n                    try:\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}\n                        acc_ = sess.run(acc, feed_dict=feed)\n                        cnt_acc.append(acc_)\n                    except tf.errors.OutOfRangeError:\n                        mean_acc = np.mean(cnt_acc)\n                        if mean_acc > max_val_acc:\n                            max_val_acc = mean_acc\n                            self.saver.save(sess, self.config._kernel_save_path)\n                        else:\n                            stopping_flag += 1\n                        print('epoch_id {}, valid acc1={}'.format(epoch_id+1, mean_acc))\n                        break\n                if stopping_flag >= self.config._early_stopping:\n                    break\n\n    def test_keywords(self):\n        batch = self.iterator.get_next()\n        loss, acc, kws = self.predict_keywords(batch)\n        saver = tf.train.Saver()\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.global_variables_initializer())\n            sess.run(tf.local_variables_initializer())\n            sess.run(tf.tables_initializer())\n            saver.restore(sess, self.config._kernel_save_path)\n            self.iterator.switch_to_test_data(sess)\n            cnt_acc, cnt_rec1, cnt_rec3, cnt_rec5 = [], [], [], []\n            while True:\n                try:\n                    feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n                    acc_, kw_ans, kw_labels = sess.run([acc, kws, batch['keywords_text_ids']], feed_dict=feed)\n                   
 cnt_acc.append(acc_)\n                    rec = [0,0,0,0,0]\n                    sum_kws = 0\n                    for i in range(len(kw_ans)):\n                        sum_kws += sum(kw_labels[i] > 3)\n                        for j in range(5):\n                            if kw_ans[i][j] in kw_labels[i]:\n                                for k in range(j, 5):\n                                    rec[k] += 1\n                    cnt_rec1.append(rec[0]/sum_kws)\n                    cnt_rec3.append(rec[2]/sum_kws)\n                    cnt_rec5.append(rec[4]/sum_kws)\n\n                except tf.errors.OutOfRangeError:\n                    print('test_kw acc@1={:.4f}, rec@1={:.4f}, rec@3={:.4f}, rec@5={:.4f}'.format(\n                        np.mean(cnt_acc), np.mean(cnt_rec1), np.mean(cnt_rec3), np.mean(cnt_rec5)))\n                    break\n\n    def forward(self, batch):\n        matching_score = tf.map_fn(lambda kw_embed: self.forward_kernel(kw_embed, batch['context_text_ids']),\n            self.keywords_embed, dtype=tf.float32, parallel_iterations=True)\n        matching_score = tf.transpose(matching_score)\n\n        kw_weight, predict_kw = tf.nn.top_k(matching_score, k=3)\n        predict_kw = tf.reshape(predict_kw,[-1])\n        predict_kw = tf.map_fn(lambda x: self.kw_list[x], predict_kw, dtype=tf.int64)\n        predict_kw = tf.reshape(predict_kw,[-1,3])\n        embed_code = self.embedder(predict_kw)\n        embed_code = tf.reduce_sum(embed_code, axis=1)\n        embed_code = self.linear_transform(embed_code)\n\n        source_embed = self.embedder(batch['source_text_ids'])\n        target_embed = self.embedder(batch['target_text_ids']) # bs * 20 * 32 * 200\n        target_embed = tf.reshape(target_embed,[-1, self.data_config._max_seq_len+2, self.embedder.dim]) # (bs * 20) * 32 * 200\n        target_length = tf.reshape(batch['target_length'],[-1]) # (bs * 20) * 32 * 200\n        source_code = self.source_encoder(\n            source_embed,\n            sequence_length_minor=batch['source_length'],\n            sequence_length_major=batch['source_utterance_cnt'])[1]\n        target_code = self.target_encoder(\n            target_embed,\n            sequence_length=target_length)[1]\n        target_kwcode = self.target_kwencoder(\n            target_embed,\n            sequence_length=target_length)[1]\n        target_code = tf.concat([target_code[0], target_code[1], target_kwcode[0], target_kwcode[1]], -1)\n        target_code = tf.reshape(target_code, [-1,20,self.config._code_len])\n\n        source_code = tf.concat([source_code,embed_code], -1)\n        source_code = tf.expand_dims(source_code, 1)\n        source_code = tf.tile(source_code, [1,20,1])\n        feature_code = target_code * source_code\n        feature_code = tf.reshape(feature_code,[-1,self.config._code_len])\n        logits = self.linear_matcher(feature_code)\n        logits = tf.reshape(logits,[-1,20])\n        labels = tf.one_hot(batch['label'], 20)\n        loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))\n        ans = tf.arg_max(logits, -1)\n        acc = tx.evals.accuracy(batch['label'], ans)\n        rank = tf.nn.top_k(logits, k=20)[1]\n        return loss, acc, rank\n\n    def train(self):\n        batch = self.iterator.get_next()\n        loss_t, acc_t, _ = self.predict_keywords(batch)\n        kw_saver = tf.train.Saver()\n        loss, acc, _ = self.forward(batch)\n        retrieval_step = tf.Variable(0, name='retrieval_step')\n        train_op = 
tx.core.get_train_op(loss, global_step=retrieval_step, hparams=self.config.opt_hparams)\n        max_val_acc, stopping_flag = 0, 0\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.tables_initializer())\n            sess.run(tf.global_variables_initializer())\n            sess.run(tf.local_variables_initializer())\n            kw_saver.restore(sess, self.config._kernel_save_path)\n            saver = tf.train.Saver()\n            for epoch_id in range(self.config._max_epoch):\n                self.iterator.switch_to_train_data(sess)\n                cur_step = 0\n                cnt_acc, cnt_kwacc = [],[]\n                while True:\n                    try:\n                        cur_step += 1\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}\n                        loss, acc_, acc_kw = sess.run([train_op, acc, acc_t], feed_dict=feed)\n                        cnt_acc.append(acc_)\n                        cnt_kwacc.append(acc_kw)\n                        if cur_step % 200 == 0:\n                            print('batch {}, loss={}, acc1={}, kw_acc1={}'.format(cur_step, loss,\n                                                np.mean(cnt_acc[-200:]) ,np.mean(cnt_kwacc[-200:])))\n                    except tf.errors.OutOfRangeError:\n                        break\n                self.iterator.switch_to_val_data(sess)\n                cnt_acc, cnt_kwacc = [],[]\n                while True:\n                    try:\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}\n                        acc_, acc_kw = sess.run([acc, acc_t], feed_dict=feed)\n                        cnt_acc.append(acc_)\n                        cnt_kwacc.append(acc_kw)\n                    except tf.errors.OutOfRangeError:\n                        mean_acc = np.mean(cnt_acc)\n                        print('valid acc1={}, kw_acc1={}'.format(mean_acc, np.mean(cnt_kwacc)))\n                        if mean_acc > max_val_acc:\n                            max_val_acc = mean_acc\n                            saver.save(sess, self.config._save_path)\n                        else:\n                            stopping_flag += 1\n                        break\n                if stopping_flag >= self.config._early_stopping:\n                    break\n\n    def test(self):\n        batch = self.iterator.get_next()\n        loss, acc, rank = self.forward(batch)\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.tables_initializer())\n            self.saver = tf.train.Saver()\n            self.saver.restore(sess, self.config._save_path)\n            self.iterator.switch_to_test_data(sess)\n            rank_cnt = []\n            while True:\n                try:\n                    feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n                    ranks, labels = sess.run([rank, batch['label']], feed_dict=feed)\n                    for i in range(len(ranks)):\n                        rank_cnt.append(np.where(ranks[i]==labels[i])[0][0])\n                except tf.errors.OutOfRangeError:\n                    rec = [0,0,0,0,0]\n                    MRR = 0\n                    for rank in rank_cnt:\n                        for i in range(5):\n                            rec[i] += (rank <= i)\n                        MRR += 1 / (rank+1)\n                    print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format(\n                        rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), 
rec[4]/len(rank_cnt), MRR/len(rank_cnt)))\n                    break\n\n    def retrieve_init(self, sess):\n        data_batch = self.iterator.get_next()\n        loss, acc, _ = self.forward(data_batch)\n        self.corpus = self.data_config._corpus\n        self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams)\n        corpus_iterator = tx.data.DataIterator(self.corpus_data)\n        batch = corpus_iterator.get_next()\n        corpus_embed = self.embedder(batch['corpus_text_ids'])\n        utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1]\n        utter_kwcode = self.target_kwencoder(corpus_embed, sequence_length=batch['corpus_length'])[1]\n        utter_code = tf.concat([utter_code[0], utter_code[1], utter_kwcode[0], utter_kwcode[1]], -1)\n        self.corpus_code = np.zeros([0, self.config._code_len])\n        corpus_iterator.switch_to_dataset(sess)\n        sess.run(tf.tables_initializer())\n        saver = tf.train.Saver()\n        saver.restore(sess, self.config._save_path)\n        feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n        while True:\n            try:\n                utter_code_ = sess.run(utter_code, feed_dict=feed)\n                self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0)\n            except tf.errors.OutOfRangeError:\n                break\n        self.kw_embedding = sess.run(self.keywords_embed)\n\n        # predict keyword\n        self.context_input = tf.placeholder(dtype=object)\n        context_ids = tf.expand_dims(self.vocab.map_tokens_to_ids(self.context_input), 0)\n        matching_score = tf.map_fn(lambda kw_embed: self.forward_kernel(kw_embed, context_ids),\n                                   self.keywords_embed, dtype=tf.float32, parallel_iterations=True)\n        self.candi_output = tf.nn.top_k(tf.squeeze(matching_score, 1), self.data_config._keywords_num)[1]\n\n        # retrieve\n        self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))\n        self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))\n        self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2))\n        self.kw_input = tf.placeholder(dtype=tf.int32)\n        history_ids = self.vocab.map_tokens_to_ids(self.history_input)\n        history_embed = self.embedder(history_ids)\n        history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0),\n                                           sequence_length_minor=self.minor_length_input,\n                                           sequence_length_major=self.major_length_input)[1]\n        self.next_kw_ids = self.kw_list[self.kw_input]\n        embed_code = tf.expand_dims(self.embedder(self.next_kw_ids), 0)\n        embed_code = self.linear_transform(embed_code)\n        history_code = tf.concat([history_code, embed_code], 1)\n        select_corpus = tf.cast(self.corpus_code, dtype=tf.float32)\n        feature_code = self.linear_matcher(select_corpus * history_code)\n        self.ans_output = tf.nn.top_k(tf.squeeze(feature_code,1), k=self.data_config._retrieval_candidates)[1]\n\n    def retrieve(self, history_all, sess):\n        history, seq_len, turns, context, context_len = history_all\n        kw_candi = sess.run(self.candi_output, feed_dict={self.context_input: context[:context_len]})\n        for kw in kw_candi:\n            tmp_score = sum(self.kw_embedding[kw] * self.kw_embedding[self.data_config._keywords_dict[self.target]])\n            if 
tmp_score > self.score:\n                self.score = tmp_score\n                self.next_kw = self.data_config._keywords_candi[kw]\n                break\n        ans = sess.run(self.ans_output, feed_dict={self.history_input: history,\n                                                   self.minor_length_input: [seq_len], self.major_length_input: [turns],\n                                                   self.kw_input: self.data_config._keywords_dict[self.next_kw]})\n        flag = 0\n        reply = self.corpus[ans[0]]\n        for i in ans:\n            if i in self.reply_list:  # avoid repeat\n                continue\n            for wd in kw_tokenize(self.corpus[i]):\n                if wd in self.data_config._keywords_candi:\n                    tmp_score = sum(self.kw_embedding[self.data_config._keywords_dict[wd]] *\n                                    self.kw_embedding[self.data_config._keywords_dict[self.target]])\n                    if tmp_score > self.score:\n                        reply = self.corpus[i]\n                        self.score = tmp_score\n                        self.next_kw = wd\n                        flag = 1\n                        break\n            if flag == 0:\n                continue\n            break\n        return reply\n"
  },
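  {
    "path": "examples/candidate_scoring_demo.py",
    "content": "\"\"\"Illustrative sketch, not part of the original project: the candidate\nscoring shared by the forward() methods in model/*.py, in NumPy. The encoded\ndialogue history is concatenated with a predicted-keyword code, tiled against\nthe 20 candidate-response codes, and an element-wise product feeds a linear\n'matcher' producing one logit per candidate. All sizes and weights below are\nmade up.\"\"\"\nimport numpy as np\n\nrng = np.random.RandomState(0)\ncode_len = 8                            # stands in for config._code_len\nsource_code = rng.randn(code_len)       # [history ; keyword] code, one dialogue\ntarget_codes = rng.randn(20, code_len)  # codes of the 20 candidate responses\nw, b = rng.randn(code_len), 0.0         # stands in for MLPTransformConnector(1)\n\nlogits = (target_codes * source_code) @ w + b  # (20,), one score per candidate\nprint('selected candidate:', int(np.argmax(logits)))\n"
  },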
  {
    "path": "model/matrix.py",
    "content": "import texar as tx\nimport tensorflow as tf\nimport numpy as np\nimport pickle\n\nclass Predictor():\n    def __init__(self, config_model, config_data, mode=None):\n        self.config = config_model\n        self.data_config = config_data\n        self.gpu_config = tf.ConfigProto()\n        self.gpu_config.gpu_options.allow_growth = True\n        self.build_model(mode)\n\n    def build_model(self, mode):\n        self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train'])\n        self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid'])\n        self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test'])\n        self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data)\n        self.vocab = self.train_data.vocab(0)\n        self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams)\n        self.target_encoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams)\n        self.target_kwencoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_kwencoder_hparams)\n        self.linear_transform = tx.modules.MLPTransformConnector(self.config._code_len // 2)\n        self.linear_matcher = tx.modules.MLPTransformConnector(1)\n        self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs)\n        self.kw_list = self.vocab.map_tokens_to_ids(tf.convert_to_tensor(self.data_config._keywords_candi))\n        self.kw_vocab = tx.data.Vocab(self.data_config._keywords_path)\n\n        if mode == 'train_kw':\n            self.pmi_matrix = np.zeros([self.config._vocab_size+4, self.data_config._keywords_num])\n        else:\n            with open(self.config._matrix_save_path, 'rb') as f:\n                matrix = pickle.load(f)\n                self.pmi_matrix = tf.convert_to_tensor(matrix,dtype=tf.float32)\n\n    def forward_matrix(self, context_ids):\n        matching_score = tf.gather(self.pmi_matrix, context_ids)\n        return tf.reduce_sum(tf.log(matching_score), axis=0)\n\n    def predict_keywords(self, batch):\n        keywords_ids = self.kw_vocab.map_tokens_to_ids(batch['keywords_text'])\n        matching_score = tf.map_fn(lambda x: self.forward_matrix(x), batch['context_text_ids'],\n             dtype=tf.float32, parallel_iterations=True)\n        kw_labels = tf.map_fn(lambda x: tf.sparse_to_dense(x, [self.kw_vocab.size], 1., 0., False),\n                              keywords_ids, dtype=tf.float32, parallel_iterations=True)[:, 4:]\n        kw_ans = tf.arg_max(matching_score, -1)\n        acc_label = tf.map_fn(lambda x: tf.gather(x[0], x[1]), (kw_labels, kw_ans), dtype=tf.float32)\n        acc = tf.reduce_mean(acc_label)\n        kws = tf.nn.top_k(matching_score, k=5)[1]\n        kws = tf.reshape(kws,[-1])\n        kws = tf.map_fn(lambda x: self.kw_list[x], kws, dtype=tf.int64)\n        kws = tf.reshape(kws,[-1, 5])\n        return acc, kws\n\n    def train_keywords(self):\n        batch = self.iterator.get_next()\n        acc, _ = self.predict_keywords(batch)\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.global_variables_initializer())\n            sess.run(tf.local_variables_initializer())\n            sess.run(tf.tables_initializer())\n            self.iterator.switch_to_train_data(sess)\n\n            batchid = 0\n            while True:\n                try:\n                    batchid += 
1\n                    feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}\n                    source_keywords, target_keywords = sess.run([batch['context_text_ids'],\n                                                                   batch['keywords_text_ids']], feed_dict=feed)\n                    for i in range(len(source_keywords)):\n                        for skw_id in source_keywords[i]:\n                            if skw_id == 0:\n                                break\n                            for tkw_id in target_keywords[i]:\n                                if skw_id >= 3 and tkw_id >= 3:\n                                    tkw = self.config._vocab[tkw_id-4]\n                                    if tkw in self.data_config._keywords_candi:\n                                        tkw_id = self.data_config._keywords_dict[tkw]\n                                        self.pmi_matrix[skw_id][tkw_id] += 1\n\n                except tf.errors.OutOfRangeError:\n                    break\n            self.pmi_matrix += 0.5\n            self.pmi_matrix = self.pmi_matrix / (np.sum(self.pmi_matrix, axis=0) + 1)\n            with open(self.config._matrix_save_path,'wb') as f:\n                pickle.dump(self.pmi_matrix, f)\n\n    def test_keywords(self):\n        batch = self.iterator.get_next()\n        acc, kws = self.predict_keywords(batch)\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.global_variables_initializer())\n            sess.run(tf.local_variables_initializer())\n            sess.run(tf.tables_initializer())\n            self.iterator.switch_to_test_data(sess)\n            cnt_acc, cnt_rec1, cnt_rec3, cnt_rec5 = [], [], [], []\n            while True:\n                try:\n                    feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n                    acc_, kw_ans, kw_labels = sess.run([acc, kws, batch['keywords_text_ids']], feed_dict=feed)\n                    cnt_acc.append(acc_)\n                    rec = [0,0,0,0,0]\n                    sum_kws = 0\n                    for i in range(len(kw_ans)):\n                        sum_kws += sum(kw_labels[i] > 3)\n                        for j in range(5):\n                            if kw_ans[i][j] in kw_labels[i]:\n                                for k in range(j, 5):\n                                    rec[k] += 1\n                    cnt_rec1.append(rec[0]/sum_kws)\n                    cnt_rec3.append(rec[2]/sum_kws)\n                    cnt_rec5.append(rec[4]/sum_kws)\n\n                except tf.errors.OutOfRangeError:\n                    print('test_kw acc@1={:.4f}, rec@1={:.4f}, rec@3={:.4f}, rec@5={:.4f}'.format(\n                        np.mean(cnt_acc), np.mean(cnt_rec1), np.mean(cnt_rec3), np.mean(cnt_rec5)))\n                    break\n\n\n    def forward(self, batch):\n        matching_score = tf.map_fn(lambda x: self.forward_matrix(x), batch['context_text_ids'],\n             dtype=tf.float32, parallel_iterations=True)\n        kw_weight, predict_kw = tf.nn.top_k(matching_score, k=3)\n        predict_kw = tf.reshape(predict_kw, [-1])\n        predict_kw = tf.map_fn(lambda x: self.kw_list[x], predict_kw, dtype=tf.int64)\n        predict_kw = tf.reshape(predict_kw, [-1, 3])\n        embed_code = self.embedder(predict_kw)\n        embed_code = tf.reduce_sum(embed_code, axis=1)\n        embed_code = self.linear_transform(embed_code)\n\n        source_embed = self.embedder(batch['source_text_ids'])\n        target_embed = 
self.embedder(batch['target_text_ids'])\n        target_embed = tf.reshape(target_embed, [-1, self.data_config._max_seq_len + 2, self.embedder.dim])\n        target_length = tf.reshape(batch['target_length'], [-1])\n        source_code = self.source_encoder(\n            source_embed,\n            sequence_length_minor=batch['source_length'],\n            sequence_length_major=batch['source_utterance_cnt'])[1]\n        target_code = self.target_encoder(\n            target_embed,\n            sequence_length=target_length)[1]\n        target_kwcode = self.target_kwencoder(\n            target_embed,\n            sequence_length=target_length)[1]\n        target_code = tf.concat([target_code[0], target_code[1], target_kwcode[0], target_kwcode[1]], -1)\n        target_code = tf.reshape(target_code, [-1, 20, self.config._code_len])\n\n        source_code = tf.concat([source_code, embed_code], -1)\n        source_code = tf.expand_dims(source_code, 1)\n        source_code = tf.tile(source_code, [1, 20, 1])\n        feature_code = target_code * source_code\n        feature_code = tf.reshape(feature_code, [-1, self.config._code_len])\n\n        logits = self.linear_matcher(feature_code)\n        logits = tf.reshape(logits, [-1, 20])\n        labels = tf.one_hot(batch['label'], 20)\n        loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))\n        ans = tf.arg_max(logits, -1)\n        acc = tx.evals.accuracy(batch['label'], ans)\n        rank = tf.nn.top_k(logits, k=20)[1]\n        return loss, acc, rank\n\n    def train(self):\n        batch = self.iterator.get_next()\n        loss, acc, _ = self.forward(batch)\n        op_step = tf.Variable(0, name='retrieval_step')\n        train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.opt_hparams)\n        max_val_acc = 0.\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.tables_initializer())\n            sess.run(tf.global_variables_initializer())\n            sess.run(tf.local_variables_initializer())\n            saver = tf.train.Saver()\n            for epoch_id in range(self.config._max_epoch):\n                self.iterator.switch_to_train_data(sess)\n                cur_step = 0\n                cnt_acc = []\n                while True:\n                    try:\n                        cur_step += 1\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}\n                        loss, acc_ = sess.run([train_op, acc], feed_dict=feed)\n                        cnt_acc.append(acc_)\n                        if cur_step % 200 == 0:\n                            print('batch {}, loss={}, acc1={}'.format(cur_step, loss, np.mean(cnt_acc[-200:])))\n                    except tf.errors.OutOfRangeError:\n                        break\n                self.iterator.switch_to_val_data(sess)\n\n                cnt_acc= []\n                while True:\n                    try:\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}\n                        acc_ = sess.run(acc, feed_dict=feed)\n                        cnt_acc.append(acc_)\n                    except tf.errors.OutOfRangeError:\n                        mean_acc = np.mean(cnt_acc)\n                        print('valid acc1={}'.format(mean_acc))\n                        if mean_acc > max_val_acc:\n                            max_val_acc = mean_acc\n                            saver.save(sess, self.config._save_path)\n                        break\n\n   
 def test(self):\n        batch = self.iterator.get_next()\n        loss, acc, rank = self.forward(batch)\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.tables_initializer())\n            self.saver = tf.train.Saver()\n            self.saver.restore(sess, self.config._save_path)\n            self.iterator.switch_to_test_data(sess)\n            rank_cnt = []\n            while True:\n                try:\n                    feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n                    ranks, labels = sess.run([rank, batch['label']], feed_dict=feed)\n                    for i in range(len(ranks)):\n                        rank_cnt.append(np.where(ranks[i]==labels[i])[0][0])\n                except tf.errors.OutOfRangeError:\n                    rec = [0,0,0,0,0]\n                    MRR = 0\n                    for rank in rank_cnt:\n                        for i in range(5):\n                            rec[i] += (rank <= i)\n                        MRR += 1 / (rank+1)\n                    print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format(\n                        rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt)))\n                    break\n\n    def retrieve_init(self, sess):\n        data_batch = self.iterator.get_next()\n        loss, acc, _ = self.forward(data_batch)\n        self.corpus = self.data_config._corpus\n        self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams)\n        corpus_iterator = tx.data.DataIterator(self.corpus_data)\n        batch = corpus_iterator.get_next()\n        corpus_embed = self.embedder(batch['corpus_text_ids'])\n        utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1]\n        utter_kwcode = self.target_kwencoder(corpus_embed, sequence_length=batch['corpus_length'])[1]\n        utter_code = tf.concat([utter_code[0], utter_code[1], utter_kwcode[0], utter_kwcode[1]], -1)\n        self.corpus_code = np.zeros([0, self.config._code_len])\n\n        corpus_iterator.switch_to_dataset(sess)\n        sess.run(tf.tables_initializer())\n        saver = tf.train.Saver()\n        saver.restore(sess, self.config._save_path)\n        feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n        while True:\n            try:\n                utter_code_ = sess.run(utter_code, feed_dict=feed)\n                self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0)\n            except tf.errors.OutOfRangeError:\n                break\n        self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))\n        self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))\n        self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2))\n        self.keywords_embed = tf.nn.l2_normalize(self.embedder(self.kw_list), axis=1)\n        self.kw_embedding = sess.run(self.keywords_embed)\n\n        # predict keyword\n        self.context_input = tf.placeholder(dtype=object)\n        context_ids = self.vocab.map_tokens_to_ids(self.context_input)\n        matching_score = self.forward_matrix(context_ids)\n        self.candi_output =tf.nn.top_k(matching_score, self.data_config._keywords_num)[1]\n\n        # retrieve\n        self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))\n        self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))\n        self.history_input = tf.placeholder(dtype=object, 
shape=(9, self.data_config._max_seq_len + 2))\n        self.kw_input = tf.placeholder(dtype=tf.int32)\n        history_ids = self.vocab.map_tokens_to_ids(self.history_input)\n        history_embed = self.embedder(history_ids)\n        history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0),\n                                           sequence_length_minor=self.minor_length_input,\n                                           sequence_length_major=self.major_length_input)[1]\n        self.next_kw_ids = self.kw_list[self.kw_input]\n        embed_code = tf.expand_dims(self.embedder(self.next_kw_ids), 0)\n        embed_code = self.linear_transform(embed_code)\n        history_code = tf.concat([history_code, embed_code], 1)\n        select_corpus = tf.cast(self.corpus_code, dtype=tf.float32)\n        feature_code = self.linear_matcher(select_corpus * history_code)\n        self.ans_output = tf.nn.top_k(tf.squeeze(feature_code,1), k=self.data_config._retrieval_candidates)[1]\n\n    def retrieve(self, history_all, sess):\n        history, seq_len, turns, context, context_len = history_all\n        kw_candi = sess.run(self.candi_output, feed_dict={self.context_input: context[:context_len]})\n        for kw in kw_candi:\n            tmp_score = sum(self.kw_embedding[kw] * self.kw_embedding[self.data_config._keywords_dict[self.target]])\n            if tmp_score > self.score:\n                self.score = tmp_score\n                self.next_kw = self.data_config._keywords_candi[kw]\n                break\n        ans = sess.run(self.ans_output, feed_dict={self.history_input: history,\n                                                   self.minor_length_input: [seq_len], self.major_length_input: [turns],\n                                                   self.kw_input: self.data_config._keywords_dict[self.next_kw]})\n        for i in range(self.data_config._max_turns + 1):\n            if ans[i] not in self.reply_list:\n                self.reply_list.append(ans[i])\n                reply = self.corpus[ans[i]]\n                break\n        return reply\n"
  },
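  {
    "path": "examples/target_guided_keyword_demo.py",
    "content": "\"\"\"Illustrative sketch, not part of the original project: the greedy\ntarget-guided keyword choice shared by the retrieve() methods in model/*.py.\nCandidate keywords arrive sorted by predicted relevance to the context; the\nfirst one whose (l2-normalized) embedding is closer to the target keyword\nthan the current score is adopted, so each turn moves strictly toward the\ntarget. The toy embeddings below are made up.\"\"\"\nimport numpy as np\n\nrng = np.random.RandomState(0)\nkw_embedding = rng.randn(6, 4)\nkw_embedding /= np.linalg.norm(kw_embedding, axis=1, keepdims=True)\n\ndef pick_next_kw(candidates, target_id, score):\n    # candidates: keyword ids sorted by predicted relevance to the context\n    for kw in candidates:\n        sim = float(kw_embedding[kw] @ kw_embedding[target_id])\n        if sim > score:   # strictly closer to the target than anything so far\n            return kw, sim\n    return None, score    # no improvement: keep the previous keyword\n\nif __name__ == '__main__':\n    print(pick_next_kw([2, 0, 5], target_id=1, score=0.1))\n"
  },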
  {
    "path": "model/neural.py",
    "content": "import texar as tx\nimport tensorflow as tf\nimport numpy as np\nfrom preprocess.data_utils import kw_tokenize\n\n\nclass Predictor():\n    def __init__(self, config_model, config_data, mode=None):\n        self.config = config_model\n        self.data_config = config_data\n        self.gpu_config = tf.ConfigProto()\n        self.gpu_config.gpu_options.allow_growth = True\n        self.build_model()\n\n    def build_model(self):\n        self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train'])\n        self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid'])\n        self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test'])\n        self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data)\n        self.vocab = self.train_data.vocab(0)\n        self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams)\n        self.target_encoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams)\n        self.target_kwencoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_kwencoder_hparams)\n        self.linear_transform = tx.modules.MLPTransformConnector(self.config._code_len // 2)\n        self.linear_matcher = tx.modules.MLPTransformConnector(1)\n        self.context_encoder = tx.modules.UnidirectionalRNNEncoder(hparams=self.config.context_encoder_hparams)\n        self.predict_layer = tx.modules.MLPTransformConnector(self.data_config._keywords_num)\n        self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs)\n        self.kw_list = self.vocab.map_tokens_to_ids(tf.convert_to_tensor(self.data_config._keywords_candi))\n        self.kw_vocab = tx.data.Vocab(self.data_config._keywords_path)\n\n    def forward_neural(self, context_ids, context_length):\n        context_embed = self.embedder(context_ids)\n        context_code = self.context_encoder(context_embed, sequence_length=context_length)[1]\n        keyword_score = self.predict_layer(context_code)\n        return keyword_score\n\n    def predict_keywords(self, batch):\n        matching_score = self.forward_neural(batch['context_text_ids'], batch['context_length'])\n        keywords_ids = self.kw_vocab.map_tokens_to_ids(batch['keywords_text'])\n        kw_labels = tf.map_fn(lambda x: tf.sparse_to_dense(x, [self.kw_vocab.size], 1., 0., False),\n                              keywords_ids, dtype=tf.float32, parallel_iterations=True)[:, 4:]\n        loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=kw_labels, logits=matching_score)\n        loss = tf.reduce_mean(loss)\n        kw_ans = tf.arg_max(matching_score, -1)\n        acc_label = tf.map_fn(lambda x: tf.gather(x[0], x[1]), (kw_labels, kw_ans), dtype=tf.float32)\n        acc = tf.reduce_mean(acc_label)\n        kws = tf.nn.top_k(matching_score, k=5)[1]\n        kws = tf.reshape(kws,[-1])\n        kws = tf.map_fn(lambda x: self.kw_list[x], kws, dtype=tf.int64)\n        kws = tf.reshape(kws,[-1, 5])\n        return loss, acc, kws\n\n    def train_keywords(self):\n        batch = self.iterator.get_next()\n        loss, acc, _ = self.predict_keywords(batch)\n        op_step = tf.Variable(0, name='op_step')\n        train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.neural_opt_hparams)\n        max_val_acc = 0.\n        self.saver = tf.train.Saver()\n        with tf.Session(config=self.gpu_config) 
as sess:\n            sess.run(tf.global_variables_initializer())\n            sess.run(tf.local_variables_initializer())\n            sess.run(tf.tables_initializer())\n            for epoch_id in range(self.config._max_epoch):\n                self.iterator.switch_to_train_data(sess)\n                cur_step = 0\n                cnt_acc = []\n                while True:\n                    try:\n                        cur_step += 1\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}\n                        loss_, acc_ = sess.run([train_op, acc], feed_dict=feed)\n                        cnt_acc.append(acc_)\n                        if cur_step % 200 == 0:\n                            print('batch {}, loss={}, acc1={}'.format(cur_step, loss_, np.mean(cnt_acc[-200:])))\n                    except tf.errors.OutOfRangeError:\n                        break\n                self.iterator.switch_to_val_data(sess)\n                cnt_acc = []\n                while True:\n                    try:\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}\n                        acc_ = sess.run(acc, feed_dict=feed)\n                        cnt_acc.append(acc_)\n                    except tf.errors.OutOfRangeError:\n                        mean_acc = np.mean(cnt_acc)\n                        if mean_acc > max_val_acc:\n                            max_val_acc = mean_acc\n                            self.saver.save(sess, self.config._neural_save_path)\n                        print('epoch_id {}, valid acc1={}'.format(epoch_id+1, mean_acc))\n                        break\n\n    def test_keywords(self):\n        batch = self.iterator.get_next()\n        loss, acc, kws = self.predict_keywords(batch)\n        saver = tf.train.Saver()\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.global_variables_initializer())\n            sess.run(tf.local_variables_initializer())\n            sess.run(tf.tables_initializer())\n            saver.restore(sess, self.config._neural_save_path)\n            self.iterator.switch_to_test_data(sess)\n            cnt_acc, cnt_rec1, cnt_rec3, cnt_rec5 = [], [], [], []\n            while True:\n                try:\n                    feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n                    acc_, kw_ans, kw_labels = sess.run([acc, kws, batch['keywords_text_ids']], feed_dict=feed)\n                    cnt_acc.append(acc_)\n                    rec = [0,0,0,0,0]\n                    sum_kws = 0\n                    for i in range(len(kw_ans)):\n                        sum_kws += sum(kw_labels[i] > 3)\n                        for j in range(5):\n                            if kw_ans[i][j] in kw_labels[i]:\n                                for k in range(j, 5):\n                                    rec[k] += 1\n                    cnt_rec1.append(rec[0]/sum_kws)\n                    cnt_rec3.append(rec[2]/sum_kws)\n                    cnt_rec5.append(rec[4]/sum_kws)\n\n                except tf.errors.OutOfRangeError:\n                    print('test_kw acc@1={:.4f}, rec@1={:.4f}, rec@3={:.4f}, rec@5={:.4f}'.format(\n                        np.mean(cnt_acc), np.mean(cnt_rec1), np.mean(cnt_rec3), np.mean(cnt_rec5)))\n                    break\n\n    def forward(self, batch):\n        matching_score = self.forward_neural(batch['context_text_ids'], batch['context_length'])\n        kw_weight, predict_kw = tf.nn.top_k(matching_score, k=3)\n        predict_kw = 
tf.reshape(predict_kw, [-1])\n        predict_kw = tf.map_fn(lambda x: self.kw_list[x], predict_kw, dtype=tf.int64)\n        predict_kw = tf.reshape(predict_kw, [-1, 3])\n        embed_code = self.embedder(predict_kw)\n        embed_code = tf.reduce_sum(embed_code, axis=1)\n        embed_code = self.linear_transform(embed_code)\n\n        source_embed = self.embedder(batch['source_text_ids'])\n        target_embed = self.embedder(batch['target_text_ids'])\n        target_embed = tf.reshape(target_embed, [-1, self.data_config._max_seq_len + 2, self.embedder.dim])\n        target_length = tf.reshape(batch['target_length'], [-1])\n        source_code = self.source_encoder(\n            source_embed,\n            sequence_length_minor=batch['source_length'],\n            sequence_length_major=batch['source_utterance_cnt'])[1]  #\n        target_code = self.target_encoder(\n            target_embed,\n            sequence_length=target_length)[1]\n        target_kwcode = self.target_kwencoder(\n            target_embed,\n            sequence_length=target_length)[1]\n        target_code = tf.concat([target_code[0], target_code[1], target_kwcode[0], target_kwcode[1]], -1)\n        target_code = tf.reshape(target_code, [-1, 20, self.config._code_len])\n\n        source_code = tf.concat([source_code, embed_code], -1)\n        source_code = tf.expand_dims(source_code, 1)\n        source_code = tf.tile(source_code, [1, 20, 1])\n        feature_code = target_code * source_code\n        feature_code = tf.reshape(feature_code, [-1, self.config._code_len])\n\n        logits = self.linear_matcher(feature_code)\n        logits = tf.reshape(logits, [-1, 20])\n        labels = tf.one_hot(batch['label'], 20)\n        loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))\n        ans = tf.arg_max(logits, -1)\n        acc = tx.evals.accuracy(batch['label'], ans)\n        rank = tf.nn.top_k(logits, k=20)[1]\n        return loss, acc, rank\n\n    def train(self):\n        batch = self.iterator.get_next()\n        kw_loss, kw_acc, _ = self.predict_keywords(batch)\n        kw_saver = tf.train.Saver()\n        loss, acc, _ = self.forward(batch)\n        op_step = tf.Variable(0, name='retrieval_step')\n        train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.opt_hparams)\n        max_val_acc = 0.\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.tables_initializer())\n            sess.run(tf.global_variables_initializer())\n            sess.run(tf.local_variables_initializer())\n            kw_saver.restore(sess, self.config._neural_save_path)\n            saver = tf.train.Saver()\n            for epoch_id in range(self.config._max_epoch):\n                self.iterator.switch_to_train_data(sess)\n                cur_step = 0\n                cnt_acc = []\n                while True:\n                    try:\n                        cur_step += 1\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}\n                        loss, acc_ = sess.run([train_op, acc], feed_dict=feed)\n                        cnt_acc.append(acc_)\n                        if cur_step % 200 == 0:\n                            print('batch {}, loss={}, acc1={}'.format(cur_step, loss, np.mean(cnt_acc[-200:])))\n                    except tf.errors.OutOfRangeError:\n                        break\n                        \n                self.iterator.switch_to_val_data(sess)\n                cnt_acc, cnt_kwacc = [], 
[]\n                while True:\n                    try:\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}\n                        acc_, kw_acc_ = sess.run([acc, kw_acc], feed_dict=feed)\n                        cnt_acc.append(acc_)\n                        cnt_kwacc.append(kw_acc_)\n                    except tf.errors.OutOfRangeError:\n                        mean_acc = np.mean(cnt_acc)\n                        print('valid acc1={}, kw_acc1={}'.format(mean_acc, np.mean(cnt_kwacc)))\n                        if mean_acc > max_val_acc:\n                            max_val_acc = mean_acc\n                            saver.save(sess, self.config._save_path)\n                        break\n\n    def test(self):\n        batch = self.iterator.get_next()\n        loss, acc, rank = self.forward(batch)\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.tables_initializer())\n            self.saver = tf.train.Saver()\n            self.saver.restore(sess, self.config._save_path)\n            self.iterator.switch_to_test_data(sess)\n            rank_cnt = []\n            while True:\n                try:\n                    feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n                    ranks, labels = sess.run([rank, batch['label']], feed_dict=feed)\n                    for i in range(len(ranks)):\n                        rank_cnt.append(np.where(ranks[i]==labels[i])[0][0])\n                except tf.errors.OutOfRangeError:\n                    rec = [0,0,0,0,0]\n                    MRR = 0\n                    for rank in rank_cnt:\n                        for i in range(5):\n                            rec[i] += (rank <= i)\n                        MRR += 1 / (rank+1)\n                    print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format(\n                        rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt)))\n                    break\n\n\n    def retrieve_init(self, sess):\n        data_batch = self.iterator.get_next()\n        loss, acc, _ = self.forward(data_batch)\n        self.corpus = self.data_config._corpus\n        self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams)\n        corpus_iterator = tx.data.DataIterator(self.corpus_data)\n        batch = corpus_iterator.get_next()\n        corpus_embed = self.embedder(batch['corpus_text_ids'])\n        utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1]\n        utter_kwcode = self.target_kwencoder(corpus_embed, sequence_length=batch['corpus_length'])[1]\n        utter_code = tf.concat([utter_code[0], utter_code[1], utter_kwcode[0], utter_kwcode[1]], -1)\n        self.corpus_code = np.zeros([0, self.config._code_len])\n\n        corpus_iterator.switch_to_dataset(sess)\n        sess.run(tf.tables_initializer())\n        saver = tf.train.Saver()\n        saver.restore(sess, self.config._save_path)\n        feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n        while True:\n            try:\n                utter_code_ = sess.run(utter_code, feed_dict=feed)\n                self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0)\n            except tf.errors.OutOfRangeError:\n                break\n        self.keywords_embed = tf.nn.l2_normalize(self.embedder(self.kw_list), axis=1)\n        self.kw_embedding = sess.run(self.keywords_embed)\n\n        # predict keyword\n        self.context_input = 
tf.placeholder(dtype=object, shape=(20))\n        self.context_length_input = tf.placeholder(dtype=tf.int32, shape=(1))\n        context_ids = tf.expand_dims(self.vocab.map_tokens_to_ids(self.context_input), 0)\n        context_embed = self.embedder(context_ids)\n        context_code = self.context_encoder(context_embed, sequence_length=self.context_length_input)[1]\n        matching_score = self.predict_layer(context_code)\n        self.candi_output = tf.nn.top_k(tf.squeeze(matching_score, 0), self.data_config._keywords_num)[1]\n\n        # retrieve\n        self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))\n        self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))\n        self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2))\n        self.kw_input = tf.placeholder(dtype=tf.int32)\n        history_ids = self.vocab.map_tokens_to_ids(self.history_input)\n        history_embed = self.embedder(history_ids)\n        history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0),\n                                           sequence_length_minor=self.minor_length_input,\n                                           sequence_length_major=self.major_length_input)[1]\n        self.next_kw_ids = self.kw_list[self.kw_input]\n        embed_code = tf.expand_dims(self.embedder(self.next_kw_ids), 0)\n        embed_code = self.linear_transform(embed_code)\n        history_code = tf.concat([history_code, embed_code], 1)\n        select_corpus = tf.cast(self.corpus_code, dtype=tf.float32)\n        feature_code = self.linear_matcher(select_corpus * history_code)\n        self.ans_output = tf.nn.top_k(tf.squeeze(feature_code, 1), k=self.data_config._retrieval_candidates)[1]\n\n    def retrieve(self, history_all, sess):\n        history, seq_len, turns, context, context_len = history_all\n        kw_candi = sess.run(self.candi_output, feed_dict={self.context_input: context,\n                                                          self.context_length_input: [context_len]})\n        for kw in kw_candi:\n            tmp_score = sum(self.kw_embedding[kw] * self.kw_embedding[self.data_config._keywords_dict[self.target]])\n            if tmp_score > self.score:\n                self.score = tmp_score\n                self.next_kw = self.data_config._keywords_candi[kw]\n                break\n        ans = sess.run(self.ans_output, feed_dict={self.history_input: history,\n                                                   self.minor_length_input: [seq_len], self.major_length_input: [turns],\n                                                   self.kw_input: self.data_config._keywords_dict[self.next_kw]})\n        flag = 0\n        reply = self.corpus[ans[0]]\n        for i in ans:\n            if i in self.reply_list:  # avoid repeat\n                continue\n            for wd in kw_tokenize(self.corpus[i]):\n                if wd in self.data_config._keywords_candi:\n                    tmp_score = sum(self.kw_embedding[self.data_config._keywords_dict[wd]] *\n                                    self.kw_embedding[self.data_config._keywords_dict[self.target]])\n                    if tmp_score > self.score:\n                        reply = self.corpus[i]\n                        self.score = tmp_score\n                        self.next_kw = wd\n                        flag = 1\n                        break\n            if flag == 0:\n                continue\n            break\n        return reply\n"
  },
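  {
    "path": "examples/eval_metrics_sketch.py",
    "content": "# NOTE: an illustrative sketch added for documentation; not an original model file.\n# It mirrors how the test() methods turn candidate rankings into rec@k and MRR:\n# rank_cnt stores the 0-based position of the ground-truth response among the 20 candidates.\n# The file name and the toy 'ranks' values are assumptions made for illustration.\nimport numpy as np\n\n\ndef recall_and_mrr(ranks, ks=(1, 3, 5)):\n    ranks = np.asarray(ranks)\n    rec = {k: float(np.mean(ranks < k)) for k in ks}  # rank 0 counts towards rec@1\n    mrr = float(np.mean(1.0 / (ranks + 1)))  # mean reciprocal of the 1-based rank\n    return rec, mrr\n\n\nif __name__ == '__main__':\n    ranks = [0, 2, 5, 0, 1]  # toy rank positions, one per test instance\n    rec, mrr = recall_and_mrr(ranks)\n    print('rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format(rec[1], rec[3], rec[5], mrr))\n"
  },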
  {
    "path": "model/retrieval.py",
    "content": "import texar as tx\nimport tensorflow as tf\nimport numpy as np\n\n\nclass Predictor():\n    def __init__(self, config_model, config_data, mode=None):\n        self.config = config_model\n        self.data_config = config_data\n        self.build_model()\n        self.gpu_config = tf.ConfigProto()\n        self.gpu_config.gpu_options.allow_growth = True\n\n    def build_model(self):\n        self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train'])\n        self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid'])\n        self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test'])\n        self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data)\n        self.vocab = self.train_data.vocab(0)\n        self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs)\n        self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams)\n        self.target_encoder = tx.modules.UnidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams)\n        self.linear_matcher = tx.modules.MLPTransformConnector(1)\n\n    def forward(self, batch):\n        source_embed = self.embedder(batch['source_text_ids'])\n        target_embed = self.embedder(batch['target_text_ids'])\n        target_embed = tf.reshape(target_embed, [-1, self.data_config._max_seq_len + 2, self.embedder.dim])\n        source_code = self.source_encoder(source_embed,\n                                          sequence_length_minor=batch['source_length'],\n                                          sequence_length_major=batch['source_utterance_cnt'])[1]\n        target_length = tf.reshape(batch['target_length'], [-1])\n        target_code = self.target_encoder(target_embed, sequence_length=target_length)[1]\n        target_code = tf.reshape(target_code, [-1, 20, self.config._code_len])\n        source_code = tf.expand_dims(source_code, 1)\n        source_code = tf.tile(source_code, [1, 20, 1])\n        feature_code = target_code * source_code\n        feature_code = tf.reshape(feature_code, [-1, self.config._code_len])\n        logits = self.linear_matcher(feature_code)\n        logits = tf.reshape(logits, [-1, 20])\n        labels = tf.one_hot(batch['label'], 20)\n        loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))\n        ans = tf.arg_max(logits, -1)\n        acc = tx.evals.accuracy(batch['label'], ans)\n        rank = tf.nn.top_k(logits, k=20)[1]\n        return loss, acc, rank\n\n    def train(self):\n        batch = self.iterator.get_next()\n        loss, acc, _ = self.forward(batch)\n        op_step = tf.Variable(0, name='op_step')\n        train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.opt_hparams)\n        max_val_acc = 0.\n        self.saver = tf.train.Saver()\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.global_variables_initializer())\n            sess.run(tf.local_variables_initializer())\n            sess.run(tf.tables_initializer())\n            for epoch_id in range(self.config._max_epoch):\n                self.iterator.switch_to_train_data(sess)\n                cur_step = 0\n                cnt_acc = []\n                while True:\n                    try:\n                        cur_step += 1\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}\n        
                loss_, acc_ = sess.run([train_op, acc], feed_dict=feed)\n                        cnt_acc.append(acc_)\n                        if cur_step % 200 == 0:\n                            print('batch {}, loss={}, acc1={}'.format(cur_step, loss_, np.mean(cnt_acc[-200:])))\n                    except tf.errors.OutOfRangeError:\n                        break\n                self.iterator.switch_to_val_data(sess)\n                cnt_acc = []\n                while True:\n                    try:\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}\n                        acc_ = sess.run(acc, feed_dict=feed)\n                        cnt_acc.append(acc_)\n                    except tf.errors.OutOfRangeError:\n                        mean_acc = np.mean(cnt_acc)\n                        print('valid acc1={}'.format(mean_acc))\n                        if mean_acc > max_val_acc:\n                            max_val_acc = mean_acc\n                            self.saver.save(sess, self.config._save_path)\n                        break\n\n    def test(self):\n        batch = self.iterator.get_next()\n        loss, acc, rank = self.forward(batch)\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.tables_initializer())\n            self.saver = tf.train.Saver()\n            self.saver.restore(sess, self.config._save_path)\n            self.iterator.switch_to_test_data(sess)\n            rank_cnt = []\n            while True:\n                try:\n                    feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n                    ranks, labels = sess.run([rank, batch['label']], feed_dict=feed)\n                    for i in range(len(ranks)):\n                        rank_cnt.append(np.where(ranks[i]==labels[i])[0][0])\n                except tf.errors.OutOfRangeError:\n                    rec = [0,0,0,0,0]\n                    MRR = 0\n                    for rank in rank_cnt:\n                        for i in range(5):\n                            rec[i] += (rank <= i)\n                        MRR += 1 / (rank+1)\n                    print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format(\n                        rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt)))\n                    break\n\n    def retrieve_init(self, sess):\n        data_batch = self.iterator.get_next()\n        loss, acc, _ = self.forward(data_batch)\n        self.corpus = self.data_config._corpus\n        self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams)\n        corpus_iterator = tx.data.DataIterator(self.corpus_data)\n        batch = corpus_iterator.get_next()\n        corpus_embed = self.embedder(batch['corpus_text_ids'])\n        utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1]\n        self.corpus_code = np.zeros([0, self.config._code_len])\n        corpus_iterator.switch_to_dataset(sess)\n        sess.run(tf.tables_initializer())\n        saver = tf.train.Saver()\n        saver.restore(sess, self.config._save_path)\n        feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n        while True:\n            try:\n                utter_code_ = sess.run(utter_code, feed_dict=feed)\n                self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0)\n            except tf.errors.OutOfRangeError:\n                break\n\n        self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))\n        self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))\n        self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2))\n\n        history_ids = self.vocab.map_tokens_to_ids(self.history_input)\n        history_embed = self.embedder(history_ids)\n        history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0),\n                                           sequence_length_minor=self.minor_length_input,\n                                           sequence_length_major=self.major_length_input)[1]\n        select_corpus = tf.cast(self.corpus_code, dtype=tf.float32)\n        feature_code = self.linear_matcher(select_corpus * history_code)\n        self.ans_output = tf.nn.top_k(tf.squeeze(feature_code, 1), k=self.data_config._retrieval_candidates)[1]\n\n    def retrieve(self, source, sess):\n        history, seq_len, turns, context, context_len = source\n        ans = sess.run(self.ans_output, feed_dict={self.history_input: history,\n                                                   self.minor_length_input: [seq_len],\n                                                   self.major_length_input: [turns]})\n        reply = self.corpus[ans[0]]  # fall back to the top-ranked reply if every candidate was already used\n        for i in range(self.data_config._max_turns + 1):\n            if ans[i] not in self.reply_list:  # avoid repeat\n                self.reply_list.append(ans[i])\n                reply = self.corpus[ans[i]]\n                break\n        return reply\n"
  },
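  {
    "path": "examples/matching_sketch.py",
    "content": "# NOTE: an illustrative numpy sketch added for documentation; not an original repo file.\n# It shows the shape logic of Predictor.forward() in model/retrieval.py: the encoded history is\n# broadcast against the 20 candidate encodings via an elementwise product, and a linear matcher\n# yields one logit per candidate. The random weights below stand in for the trained encoders and\n# the MLPTransformConnector.\nimport numpy as np\n\nrng = np.random.RandomState(0)\ncode_len, num_candi = 8, 20\n\nsource_code = rng.randn(code_len)  # stand-in for the encoded dialogue history\ntarget_code = rng.randn(num_candi, code_len)  # stand-ins for the 20 encoded candidate responses\nw = rng.randn(code_len, 1)  # stand-in for the linear matcher weights\n\nfeature_code = target_code * source_code  # broadcasted elementwise product, as in forward()\nlogits = (feature_code @ w).squeeze(-1)  # one matching score per candidate\nprint('predicted response index:', int(np.argmax(logits)))\n"
  },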
  {
    "path": "model/retrieval_stgy.py",
    "content": "import texar as tx\nimport tensorflow as tf\nimport numpy as np\nfrom preprocess.data_utils import kw_tokenize\n\nclass Predictor():\n    def __init__(self, config_model, config_data, mode=None):\n        self.config = config_model\n        self.data_config = config_data\n        self.build_model()\n        self.gpu_config = tf.ConfigProto()\n        self.gpu_config.gpu_options.allow_growth = True\n\n    def build_model(self):\n        self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train'])\n        self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid'])\n        self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test'])\n        self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data)\n        self.vocab = self.train_data.vocab(0)\n        self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs)\n        self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams)\n        self.target_encoder = tx.modules.UnidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams)\n        self.linear_matcher = tx.modules.MLPTransformConnector(1)\n        self.kw_list = self.vocab.map_tokens_to_ids(tf.convert_to_tensor(self.data_config._keywords_candi))\n\n    def forward(self, batch):\n        source_embed = self.embedder(batch['source_text_ids'])\n        target_embed = self.embedder(batch['target_text_ids'])\n        target_embed = tf.reshape(target_embed, [-1, self.data_config._max_seq_len + 2, self.embedder.dim])\n        source_code = self.source_encoder(source_embed,\n                                          sequence_length_minor=batch['source_length'],\n                                          sequence_length_major=batch['source_utterance_cnt'])[1]\n        target_length = tf.reshape(batch['target_length'], [-1])\n        target_code = self.target_encoder(target_embed, sequence_length=target_length)[1]\n        target_code = tf.reshape(target_code, [-1, 20, self.config._code_len])\n        source_code = tf.expand_dims(source_code, 1)\n        source_code = tf.tile(source_code, [1, 20, 1])\n        feature_code = target_code * source_code\n        feature_code = tf.reshape(feature_code, [-1, self.config._code_len])\n        logits = self.linear_matcher(feature_code)\n        logits = tf.reshape(logits, [-1, 20])\n        labels = tf.one_hot(batch['label'], 20)\n        loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))\n        ans = tf.arg_max(logits, -1)\n        acc = tx.evals.accuracy(batch['label'], ans)\n        rank = tf.nn.top_k(logits, k=20)[1]\n        return loss, acc, rank\n\n    def train(self):\n        batch = self.iterator.get_next()\n        loss, acc, _ = self.forward(batch)\n        op_step = tf.Variable(0, name='op_step')\n        train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.opt_hparams)\n        max_val_acc = 0.\n        self.saver = tf.train.Saver()\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.global_variables_initializer())\n            sess.run(tf.local_variables_initializer())\n            sess.run(tf.tables_initializer())\n            for epoch_id in range(self.config._max_epoch):\n                self.iterator.switch_to_train_data(sess)\n                cur_step = 0\n                cnt_acc = []\n                while 
True:\n                    try:\n                        cur_step += 1\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}\n                        loss_, acc_ = sess.run([train_op, acc], feed_dict=feed)\n                        cnt_acc.append(acc_)\n                        if cur_step % 200 == 0:\n                            print('batch {}, loss={}, acc1={}'.format(cur_step, loss_, np.mean(cnt_acc[-200:])))\n                    except tf.errors.OutOfRangeError:\n                        break\n                self.iterator.switch_to_val_data(sess)\n                cnt_acc = []\n                while True:\n                    try:\n                        feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}\n                        acc_ = sess.run(acc, feed_dict=feed)\n                        cnt_acc.append(acc_)\n                    except tf.errors.OutOfRangeError:\n                        mean_acc = np.mean(cnt_acc)\n                        print('valid acc1={}'.format(mean_acc))\n                        if mean_acc > max_val_acc:\n                            max_val_acc = mean_acc\n                            self.saver.save(sess, self.config._save_path)\n                        break\n\n    def test(self):\n        batch = self.iterator.get_next()\n        loss, acc, rank = self.forward(batch)\n        with tf.Session(config=self.gpu_config) as sess:\n            sess.run(tf.tables_initializer())\n            self.saver = tf.train.Saver()\n            self.saver.restore(sess, self.config._save_path)\n            self.iterator.switch_to_test_data(sess)\n            rank_cnt = []\n            while True:\n                try:\n                    feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n                    ranks, labels = sess.run([rank, batch['label']], feed_dict=feed)\n                    for i in range(len(ranks)):\n                        rank_cnt.append(np.where(ranks[i]==labels[i])[0][0])\n                except tf.errors.OutOfRangeError:\n                    rec = [0,0,0,0,0]\n                    MRR = 0\n                    for rank in rank_cnt:\n                        for i in range(5):\n                            rec[i] += (rank <= i)\n                        MRR += 1 / (rank+1)\n                    print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format(\n                        rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt)))\n                    break\n\n    def retrieve_init(self, sess):\n        data_batch = self.iterator.get_next()\n        loss, acc, _ = self.forward(data_batch)\n        self.corpus = self.data_config._corpus\n        self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams)\n        corpus_iterator = tx.data.DataIterator(self.corpus_data)\n        batch = corpus_iterator.get_next()\n        corpus_embed = self.embedder(batch['corpus_text_ids'])\n        utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1]\n        self.corpus_code = np.zeros([0, self.config._code_len])\n\n        corpus_iterator.switch_to_dataset(sess)\n        sess.run(tf.tables_initializer())\n        saver = tf.train.Saver()\n        saver.restore(sess, self.config._save_path)\n        feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}\n        while True:\n            try:\n                utter_code_ = sess.run(utter_code, feed_dict=feed)\n                self.corpus_code = 
np.concatenate([self.corpus_code, utter_code_], axis=0)\n            except tf.errors.OutOfRangeError:\n                break\n\n        self.keywords_embed = tf.nn.l2_normalize(self.embedder(self.kw_list), axis=1)\n        self.kw_embedding = sess.run(self.keywords_embed)\n\n        self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))\n        self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))\n        self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2))\n\n        history_ids = self.vocab.map_tokens_to_ids(self.history_input)\n        history_embed = self.embedder(history_ids)\n        history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0),\n                                           sequence_length_minor=self.minor_length_input,\n                                           sequence_length_major=self.major_length_input)[1]\n        select_corpus = tf.cast(self.corpus_code, dtype=tf.float32)\n        feature_code = self.linear_matcher(select_corpus * history_code)\n        self.ans_output = tf.nn.top_k(tf.squeeze(feature_code, 1), k=1000)[1]\n\n    def retrieve(self, source, sess):\n        history, seq_len, turns, context, context_len = source\n        ans = sess.run(self.ans_output, feed_dict={self.history_input: history,\n                                                   self.minor_length_input: [seq_len],\n                                                   self.major_length_input: [turns]})\n        flag = 0\n        reply = self.corpus[ans[0]]\n        for i in ans:\n            if i in self.reply_list:  # avoid repeat\n                continue\n            for wd in kw_tokenize(self.corpus[i]):\n                if wd in self.data_config._keywords_candi:\n                    tmp_score = sum(self.kw_embedding[self.data_config._keywords_dict[wd]] *\n                                    self.kw_embedding[self.data_config._keywords_dict[self.target]])\n                    if tmp_score > self.score:\n                        reply = self.corpus[i]\n                        self.score = tmp_score\n                        self.next_kw = wd\n                        flag = 1\n                        break\n            if flag == 0:\n                continue\n            break\n        return reply\n\n"
  },
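  {
    "path": "examples/strategy_sketch.py",
    "content": "# NOTE: an illustrative sketch added for documentation; not an original repo file.\n# It isolates the discourse-level rule in retrieval_stgy's retrieve(): keyword embeddings are\n# l2-normalized, so a dot product is a cosine similarity, and a candidate reply is accepted only\n# when one of its keywords is strictly closer to the target than the best similarity reached so\n# far. All values below are toy assumptions.\nimport numpy as np\n\nrng = np.random.RandomState(1)\nkw_embedding = rng.randn(5, 4)\nkw_embedding /= np.linalg.norm(kw_embedding, axis=1, keepdims=True)\n\ntarget, score = 0, 0.0  # target keyword id and the best similarity reached so far\ncandidate_kws = [[3, 1], [2], [4, 0]]  # keyword ids found in each ranked candidate reply\n\nchosen, found = 0, False  # default to the top-ranked reply, as retrieve() does\nfor reply_id, kws in enumerate(candidate_kws):\n    for kw in kws:\n        sim = float(kw_embedding[kw] @ kw_embedding[target])\n        if sim > score:  # move monotonically closer to the target keyword\n            score, chosen, found = sim, reply_id, True\n            break\n    if found:\n        break\nprint('chosen reply: {}, similarity to target: {:.3f}'.format(chosen, score))\n"
  },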
  {
    "path": "preprocess/convai2/__init__.py",
    "content": "from .api import *\n"
  },
  {
    "path": "preprocess/convai2/api.py",
    "content": "import os\ndata_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'source')\n\nclass dts_ConvAI2(object):\n    def __init__(self, path=data_path):\n        self.path = path\n\n    def _txt_to_json(self, txt_path, mode, cands):\n        def pop_one_sample(lines):\n            self_persona = []\n            other_persona = []\n            dialog = []\n            candidates = []\n\n            started = False\n            while len(lines) > 0:\n                line = lines.pop()\n                id, context = line.split(' ', 1)\n                id = int(id)\n                context = context.strip()\n\n                if started == False: # not started\n                    assert id == 1\n                    started = True\n                elif id == 1: # break for next\n                    lines.append(line)\n                    break\n\n                if context.startswith('partner\\'s persona: '): # partner\n                    assert mode in ['both', 'other']\n                    other_persona.append(context[19:])\n\n                elif context.startswith('your persona: '): # self\n                    assert mode in ['both', 'self']\n                    self_persona.append(context[13:])\n\n                elif cands == False: # no cands \n                    try:\n                        uttr, response = context.split('\\t', 2)[:2]\n                        dialog.append(uttr)\n                        dialog.append(response)\n                    except:\n                        uttr = context \n                        dialog.append(uttr)\n                else:\n                    uttr, response, _, negs = context.split('\\t', 4)[:4]\n                    dialog.append(uttr)\n                    dialog.append(response)                    \n                    candidates.append(negs.split('|'))\n                    candidates.append(None)\n\n            return {\n                'self_persona': self_persona,\n                'other_persona': other_persona,\n                'dialog': dialog,\n                'candidates': candidates\n            }\n\n        lines = open(txt_path, 'r').readlines()[::-1]\n\n        samples = []\n        while len(lines) > 0:\n            samples.append(pop_one_sample(lines))\n\n        return samples\n\n    def get_data(self, mode='train', revised=False, cands=False):\n        txt_path = os.path.join(self.path, '{}_{}_{}{}.txt'.format(\n            mode,\n            'none',\n            'revised' if revised is True else 'original',\n            '' if cands is True else '_no_cands'))\n        assert mode in ['train', 'valid', 'test', 'all']\n        print(\"Get dialog from \", txt_path)\n        assert os.path.exists(txt_path)\n        return self._txt_to_json(txt_path, mode, cands)\n\n    def get_dialogs(self, mode='all'):\n        dialogs = [sample['dialog'] for sample in self.get_data(mode, False, False)]\n        return dialogs"
  },
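  {
    "path": "examples/convai2_format_sketch.py",
    "content": "# NOTE: an illustrative sketch added for documentation; not an original repo file.\n# It shows the 'no_cands' ConvAI2 text format that dts_ConvAI2._txt_to_json parses: every line\n# starts with a 1-based index that resets at a new episode, persona lines carry a prefix, and the\n# remaining lines hold a tab-separated utterance/response pair. The sample lines are made up.\nsample = [\n    '1 your persona: i like dogs.',\n    '2 hi , how are you ?\\thi , i am great , just walked my dog .',\n    '3 what breed is she ?\\ta golden retriever , she is three .',\n]\n\npersona, dialog = [], []\nfor line in sample:\n    _, text = line.split(' ', 1)\n    if text.startswith('your persona: '):\n        persona.append(text[len('your persona: '):])\n    else:\n        uttr, response = text.split('\\t')\n        dialog.append(uttr)\n        dialog.append(response)\nprint(persona)  # ['i like dogs.']\nprint(dialog)  # four alternating utterances\n"
  },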
  {
    "path": "preprocess/convai2/candi_keyword.txt",
    "content": "favorite\nsound\nplay\ndog\nmusic\nkid\neat\nschool\nenjoy\njob\nwatch\nread\nfood\ncat\nfriend\nfamily\nhobby\npeople\npet\ncar\ngame\nhear\nmovie\ntravel\nbook\ncook\nlisten\nanimal\nlife\ncolor\ndrive\ncollege\nhope\nliving\nparent\nteach\nbad\nsport\nhard\ndad\nfeel\nchild\nhair\nband\ncountry\nmoney\npizza\nmarry\nbet\nhate\nwalk\nstay\nstudy\nwrite\nhusband\nfish\nstart\nguess\nbrother\nblue\nnight\ndance\nbusy\nred\nwife\nlearn\ntalk\nsing\nmeet\nbeach\nspend\ndrink\nrock\ncity\nvideo\ngrow\nperson\nhouse\ntv\ntrue\nshop\nteacher\nbuy\ngirl\nweekend\nprefer\nhike\nmeat\nfootball\nfree\nvisit\nsister\ncare\nplan\nart\nswim\nsweet\npaint\nstuff\npretty\nvegan\nstore\nclass\nfarm\ntype\nfunny\nunderstand\nson\nbusiness\nhappy\nmother\nride\ncoffee\nleave\nfan\ngarden\nweek\nboy\nbake\ngreen\npay\nbeautiful\ndraw\nstudent\nhorse\nsell\nsad\nsummer\nwear\ntruck\nblack\nhot\nwait\nyea\nsleep\ncold\nagree\nsingle\nguy\nsibling\nreal\ncrazy\nhealthy\nfine\ntall\nguitar\nlose\ncute\nhour\ncake\npurple\nmonth\nrelax\nfinish\nidea\nbreak\ncompany\ndaughter\nhappen\ndream\nteam\nmind\npark\nwoman\nreading\nrestaurant\nbike\nshort\nitalian\nexercise\nhang\nchocolate\nwonderful\nbasketball\nspeak\nsoccer\nweather\ngraduate\nretire\nwinter\nmorning\nnurse\nbring\ndie\neasy\nparty\ngrade\nfather\nready\nice\nwin\nspare\ndoctor\nonline\nsong\nlocal\ndegree\nchicken\nconcert\nflorida\nglad\nfall\nbaseball\neye\noffice\nvolunteer\nwater\ngirlfriend\nbear\nfrench\nshopping\nage\nluck\nfishing\nartist\nweird\nfoot\nsave\nsurf\ndinner\nyoga\nlucky\nstory\nhunt\ncooking\npink\nsuck\ncream\nboyfriend\nlanguage\nbeer\noutdoors\nchange\ndiet\npassion\nyork\nmajor\ncheese\ncollect\nimagine\nsuper\nenglish\nrap\npractice\nchat\nhead\ncalifornia\ntrain\nlake\nclothes\npass\nnature\nbaby\nfry\nseason\napple\npiano\nsit\nhop\nscary\ntaco\nlaw\nsteak\nhockey\ncomic\nbrown\nbar\ngym\nsmart\nhuge\nclean\nfav\nclose\njazz\nvegetarian\ncanada\nscience\ncareer\npicture\ntea\nexcite\ntough\ndeal\npick\nallergic\nneat\nchurch\nsocial\nsick\nshoe\ntrip\nfly\nvacation\ncatch\nbed\nraise\nboring\nrid\nocean\nclub\ntown\ninstrument\ncheck\nrace\ndragon\nfast\nvegetable\nwrong\nboat\nterrible\ntennis\ncandy\nrain\nworry\nveggie\nfruit\nmetal\nperfect\ntwin\ntattoo\nmountain\ntomorrow\ngod\nstand\nhospital\napartment\nbuild\nstick\ndesign\njapan\ntexas\nmeal\ngrocery\ncamp\nhit\norange\nmexican\nallergy\nta\nphone\nfamous\nplayer\nflower\nbore\narmy\nstar\nshare\npool\ndelicious\nrelationship\nexciting\nmath\negg\nmuseum\nclassic\ndress\nsushi\ntaste\nmarried\namaze\nclassical\nlady\nshelter\nsense\njoke\npop\npie\nyellow\namerican\nlawyer\nexpensive\nstress\nhand\ncut\npasta\nremember\nprofessional\nchoice\nimpressive\nyoutube\nkinda\nyum\nchicago\nbirthday\ncooky\nsunday\nfollow\ndivorce\ngon\nmoment\nfresh\ntire\nhurt\nwedding\nfit\nweight\nhealth\nplant\ncount\nchef\nheard\nball\nscar\nsort\nsmell\ndead\nspecial\ncomedy\ncouple\nrich\nhiking\nfave\ncreate\naccountant\nbird\nrelaxing\nhistory\nblonde\nfilm\neveryday\nglass\nheart\nvoice\nballet\nvet\nmilitary\nhorror\nfield\nfight\nmile\nsalad\nextra\nafraid\nmarket\nchristmas\nreason\nmexico\nattend\ncop\nchance\npotato\nsnow\nhalloween\nfolk\nsnake\nski\ncharacter\ncard\nadopt\nfigure\ntho\nhat\nnail\nford\nfat\nwarm\nswimming\ndifficult\nsew\nsuppose\ndye\nsafe\nwine\ntend\nstyle\nafford\nrecipe\nwriter\nshower\nlunch\nremind\npainting\nnews\ncongrats\nphotography\nrough\nroad\nholiday\njoin\nbeat\nwhite\nthrone\nmichigan\nawful\nactor\nbreakfast\nlibrary\
ncongratulation\ngoal\nshrimp\nsinger\nmiddle\nhorrible\ncommon\nplane\nprofession\nbacon\nsmoke\ncoast\nkill\nfuture\nword\nfemale\nsouth\ncraft\nneighbor\nviolin\nfair\nparis\nfashion\nquiet\nservice\nbank\nworkout\ndish\nshape\ntour\ntheater\ngenre\nforget\nfiction\nmakeup\nmodel\nclown\nauthor\nexperience\nrest\nlibrarian\nking\npuppy\ngerman\ntree\ntrust\nstreet\nyummy\nmarketing\nquit\nhungry\ncreative\npoor\ncali\nwild\nalright\naccident\nfantasy\ncartoon\nair\npepper\nhurricane\nton\nlab\nanime\nuniversity\ndrop\nfactory\niphone\nmystery\nactive\nnetflix\nstressful\nspace\ntrouble\nsecret\nmcdonalds\nlaugh\ngun\nsupport\nspanish\npain\ncasino\nfrance\nset\nspaghetti\nsale\nmagic\ngrill\nburger\nclimb\ninternet\npig\nhold\ndecide\nmustang\nprogram\ndrum\nhmm\nfocus\nseafood\nmaster\nmedium\nstrange\nlesson\narm\nwar\nserve\nvan\nnorth\ncow\ntoe\ndriver\npaper\neducation\ntable\nstation\ndancer\nsoda\nplaying\nmedical\ndark\nblog\nusa\npower\nwan\nbob\nsimple\nzoo\nshoot\nfancy\nmechanic\ncollection\nnational\nmotorcycle\nstrong\nparrot\nnursing\namerica\nharry\nblood\nahahah\nbody\nbible\npolice\nenergy\nalabama\nlazy\ndeath\nrose\nmusician\nitaly\nadventure\nworker\nwriting\npair\nactivity\nbread\nthrow\nissue\neurope\nlondon\ngross\njealous\nmale\nheight\ndisney\ntreat\nmarathon\ndoor\nbean\nbos\nboot\naww\ncancer\nlover\nmilk\nskill\nplenty\nbeard\nlight\npoetry\nsubject\nsea\nromantic\nrelate\nspending\nfirm\nsend\nseattle\ncouch\nquestion\nflavor\nphoto\nchannel\nfavourite\nmatch\ncontact\nchoose\njapanese\nsinging\nlift\nboston\nwood\ngrandmother\nfarmer\nradio\nnavy\nroommate\nbus\npolitics\ngrandchild\njam\nchinese\nviking\nmatter\nseries\ncompetition\ndesigner\ngas\nsan\nfeeling\nknit\nlizard\nslow\npublish\nfantastic\ncrochet\ndonate\nadd\ntelevision\ninsurance\ntalent\nbowl\nhuman\nland\nrise\nhmmm\ntest\nalive\ninvolve\ndangerous\npro\nyesterday\nwind\nobsess\ngolden\nstone\npost\nwest\nnephew\ndepend\nnervous\nbreed\nbakery\nhip\ngoodness\nbaker\nsurgery\nmall\ndon\nhonda\ncolorado\ndancing\ndr\ndude\nlifestyle\ncalm\nwake\ncrime\nathlete\nskin\nbeauty\nlie\nfacebook\nmar\nreader\nlead\nmess\nsnack\nshift\nemploy\nspring\nsize\nhandle\nshy\niron\ncorvette\nevening\nsuperman\nboard\ncheer\ntraffic\ngrand\nsun\nfee\nmusical\nbase\npublic\nchip\nmad\ncareful\npound\nbrand\npotter\npunk\nsmile\nadvice\njustin\nmanager\nbass\npeaceful\ngolf\nclothing\njohn\naction\nbuddy\nsunny\nbox\nwaitress\ngardening\npopular\nwing\nwall\npersonal\ncoach\ncover\nelementary\nmix\npray\nodd\nstink\nminute\nkey\ninspire\nevent\nadorable\ncommunity\njuice\nengineer\nkarate\nskate\nsaturday\nliterature\nreward\nbaking\nghost\nohio\nvega\nkitchen\nsugar\nchill\ntomato\nchase\nchevy\nroll\nla\norganize\nblame\nhomework\nbee\ngrandma\nbeatles\npassionate\nsoul\nknee\nanxiety\nnut\nspot\nticket\ngrandparent\nhonest\ngoogle\njunk\nproduct\ntech\ntrack\nnormal\nrole\nclient\nveterinarian\nspicy\nbuilding\nscare\nhawaii\nbummer\nreality\niguana\nniece\nopera\ncharlie\nworth\npack\ntrail\nsenior\ntoyota\nbeet\nireland\nfinance\nchili\nofficer\npickle\nkinds\nmanage\ntired\nchristian\ndaily\nnose\ncheap\nspider\ngig\nindian\nlocate\npleasure\nleague\nhunting\nretired\ngeorge\ntuna\nteaching\nattention\nstep\nawe\nstock\nlist\nportland\nsign\navoid\nbug\nbrain\nscientist\ndessert\nexcellent\nfinger\nreligious\nsuperhero\nhighschool\nkick\nrule\nsalesman\ndrawing\nhandful\npilot\nescape\nschedule\ntasty\ngosh\nburrito\nkansa\ncommercial\nrescue\nmary\nfi\nretail\nretriever\nskydive\nhusky\ngamble\ngeorgia
\nindustry\nmeeting\nnyc\nwalmart\nannoy\ncruise\ndoubt\nforest\nbunch\ncashier\nactress\nkayak\npartner\nmac\ntiny\nfriday\nspell\nculture\nkindergarten\nlay\ntiger\ndeaf\nmention\ndrinking\naccounting\ncup\nsubway\npancake\nhabit\nprius\nblind\nfamiliar\nstarbucks\nassistant\npumpkin\nburn\nanswer\npoodle\nmetallica\ncircus\njewelry\nsock\nrecommend\nsandwich\njournalist\nfear\npretend\ngrandson\nnfl\nbrownie\ncupcake\ndoll\nstephen\nbother\ndesert\nsci\nmoon\nconstruction\nreach\ntechnology\noption\nmarriage\npony\nhell\nromance\ncuisine\nstrawberry\nshirt\nbbq\npurse\nleg\nkitten\nunemployed\nproject\nrunning\ndeliver\nnatural\ndang\ncontrol\nhero\nbeef\ntrade\nsunset\ncarrot\nmain\nchess\ngrandkids\ntalented\nexplore\nnasty\neast\nhamburger\nprivate\nyup\nguard\nchoir\naustralia\npen\nstamp\nstruggle\ndive\nbro\nperry\nomg\nthinking\ncandle\nfell\nink\nwheel\nranch\nowner\ncarry\nbiology\nearn\nregular\nlion\npeanut\nbroccoli\nsuit\nmedicine\nspeaking\nchew\ninspiration\nsauce\nhotel\ncamera\nengland\nrepair\nturtle\nathletic\nunique\nray\nsky\nskateboard\nriver\ntune\nfreak\nestate\ncheesecake\nyuck\nsurfing\nrent\nooh\ngrey\nmemory\nperform\npopcorn\ndislike\nchildhood\nadore\nquick\ncomedian\nsweater\nantique\nlottery\nhows\nmcdonald\njump\ntutor\ncarolina\nwalking\npregnant\nmike\nvampire\ndecorate\nbieber\nalcohol\ncompete\nmansion\nowl\ngotcha\njack\nengineering\nretirement\npot\nairplane\nferrari\ndry\ndentist\nrussian\npiece\nsecurity\nspirit\noffer\ndorm\nrecord\nsettle\nlobster\nforeign\nsoftware\nhoney\nrice\nprincess\nexcited\namazon\nbaltimore\nisland\nskiing\ncenter\nalien\nbutter\ncorn\ncivic\njane\nview\nnap\npit\nbulldog\nlovely\nprince\nloud\nphotograph\ngift\ncoke\norg\nbelt\nfestival\nugh\nahh\nvintage\npug\nbirth\nunderstandable\nsweetheart\nirma\ndeep\nindia\nfeed\nring\nitem\nscene\nspear\nmushroom\nposition\nafternoon\nhire\ntrainer\ndistract\ntouch\nunfortunate\nalaska\nexpect\nkaty\nscratch\ncost\nsituation\ngay\nloss\norganic\nwashington\njoy\ngummy\nsurvive\nmed\nbless\nprepare\ncharity\nsight\nrare\nheavy\nrural\nrussia\nnewspaper\nkaraoke\ndriving\ncustomer\nwatching\nrequire\ngraphic\nmood\nmaine\nfreelance\nfitness\ndiner\npepsi\ncondo\nmiami\ncross\nrunner\naccept\npanda\nbunny\nengage\nuncle\ncommute\nindie\ncooler\nstraight\nhollywood\nbagel\nterrify\ntraining\nboxer\nleft\nprotein\nbull\njog\ntom\nshame\nouch\ncurrent\nprogrammer\nnerd\nmagazine\nartistic\nyikes\neating\nskittle\nfurniture\neagle\nform\nfabulous\nlegal\nagency\ninternship\ncabin\ndrama\npositive\naddict\nsurprise\nrewarding\ntax\nbarista\nfake\nspain\ncrash\nrandom\nkale\nbright\nshark\nstudio\nbow\nboys\nbell\nbrace\ntrick\nwheelchair\ncloud\nsouthern\nforce\nchair\nspouse\nthumb\nfrank\nrapper\nvirginia\nphysical\nbye\ngrad\nsoup\nfiance\nelvis\nmeatloaf\nqueen\nunited\nincome\nsalon\nvolleyball\ntarget\nchihuahua\nlimit\nscholarship\ndirection\nsupply\ncanadian\ndaddy\ntoy\nkentucky\nrespect\nearth\npolitical\nbinge\nhilarious\nblast\nunhealthy\npant\ncheeseburger\ncomplete\nunicorn\nreligion\ndairy\ndrug\nfl\nwhiskey\nreject\niced\naverage\neater\ndirty\nrat\nangry\nheck\nhide\ncompetitive\ngum\nwebsite\nlaptop\nexhaust\nrobot\nchallenge\noutdoor\nraw\naudition\ncafe\nonion\nassume\nopinion\nhorseback\nzebra\nphilosophy\npsychology\nexact\nspice\ndebt\nreside\nheat\nhobbies\nleather\nrude\nchina\nstorm\ngrown\ninvite\ninstructor\nsteal\ncurly\nwash\nworm\nbf\ncredit\ngreek\noregon\nshot\nriding\npride\nspeed\nfloor\nintense\ncourt\ndental\nbone\nfriendly\ndiving\nlame\nkitty\nalt
ernative\nhubby\nstrict\nham\ncamping\npond\nmarine\ngrandpa\nsecretary\ntheme\njudge\nshade\ncomplain\nherb\nadvertising\ncelebrity\nfascinate\nlasagna\nenvironment\npainter\ncomfortable\nbeagle\nrecycle\nbruno\ninvest\nsearch\nsociety\nhalo\npursue\nhm\nsam\nconnect\nangeles\nweed\ngrab\nkiss\nbald\nem\nbritney\noil\nconversation\nmmm\nyard\ncash\njean\nship\nexotic\nvanilla\nwaste\nphotographer\nadult\nrolling\ndig\ncookie\ntennessee\nbalance\ncleaning\nocd\neggplant\narchery\nma\nmeatball\ndust\ndeer\ngluten\nagent\ngaming\ncelebrate\nhelpful\ntofu\ncampus\nlord\nevil\ner\nlos\nshake\nmartial\ndrummer\noutfit\ngrass\nwrestle\nnote\nconvertible\nbiking\ntaller\ngorgeous\nfile\nhectic\nsalsa\nrush\nawhile\ndistance\nsoft\nhomeless\ndaycare\nprocess\npatient\nhouston\ncrowd\nstew\nduty\nbookstore\ntie\nneighborhood\nprofessor\norleans\nelectric\noriginal\nopportunity\nfrancisco\nangel\nfund\nsite\npeace\npicky\nwisconsin\nmadonna\nadam\niceland\nblow\nuh\ntotal\nurban\ncoincidence\nentire\njello\ncompare\ndollar\nheadache\nblond\nguilty\ncure\nelectronic\ngrader\ngreece\ncorrect\ntwilight\nenjoyable\nbenefit\ndamn\ncoupon\ncheat\nthrift\nprint\nscout\nsuccessful\nbartender\nstranger\ncattle\nmommy\nbath\nnascar\nhonor\nsuggestion\nappalachian\nginger\nhook\naquarium\njames\nscared\ngamer\ncollie\ncreepy\nupstate\ntherapist\nbicycle\ntrophy\ndepartment\ncheerleader\nvariety\nsuburb\nexplain\ncuddle\nflip\nshepherd\ntupac\nwhoa\niq\nlifeguard\nsunshine\njersey\nrainy\nvehicle\ncloset\nrainbow\nross\nspecialty\nattorney\ninterior\ngf\ncrab\npodcasts\nfart\ndecent\ninsane\nshepard\ngossip\njimmy\nblackjack\naustin\nnike\nlouisiana\nbrat\ntherapy\nnoble\nskunk\nlistening\nawww\npleasant\npittsburgh\nbarbie\nspecific\npas\num\navid\noccasion\nmustache\nautograph\nnoise\ndiego\nruin\nadmit\nfail\nscotch\ncreature\ncostume\njeopardy\nswift\nafrica\ncd\naccount\nedit\ntopic\nhandy\nwindow\nsteelers\naccent\nactivist\npreacher\naffect\napply\nloose\nrefuse\nbum\nminnesota\ngender\nconfuse\nlandscape\nbroadway\nbmw\nargue\nfoodie\ntrek\nrelieve\ntodd\npath\nfreedom\npearl\ntap\nmedication\nvera\nrapid\nohh\ngovernment\nenvironmental\nfault\nft\ncope\ncarbs\nstandard\nplanet\nmcqueen\nnugget\npull\ndifference\nbabysit\nteller\ndisappoint\nmermaid\npageant\nsoap\nmidwest\ngiant\npuerto\nbowling\nasian\narizona\nhappiness\nprovide\nfond\nbud\ncell\nentertain\nbutterfly\ngenius\nscream\narchitect\njar\nmonkey\ntheatre\nhah\npurchase\nyay\nvote\nlevel\ncab\ncombo\ntool\nfluent\nduck\ncreek\nload\naffair\nhumor\nslave\npublishing\npainful\nredhead\njerry\nthursday\ndungeon\ncape\nmessy\napprove\ncousin\nweekly\nterrier\npaddle\ndavid\ntypical\nwaffle\ndesk\ncatholic\ngermany\ninstagram\nadmire\nimage\nupset\nmuscle\nmonday\npic\nbasement\nground\nwave\nshellfish\ntechnician\nepisode\njim\njacob\nballerina\nloan\nimprove\ngarage\nibm\ntooth\nbachelor\nthrill\nhippie\nnotice\nreturn\ncrush\njesus\nstomach\ntechno\nraven\nnerve\ndenver\nfoster\nthriller\nrugby\ndaydream\noreo\ndiscover\ndetroit\ncult\natlanta\nincredible\nstable\npoem\nyr\noooh\nwide\nmango\nsnowboard\nweapon\ncountryside\nalcoholic\nbrunch\nfisherman\naunt\ntoronto\nsarah\naddiction\nsurround\nlactose\nclinic\nuniverse\npretzel\ntoto\nblah\nsir\nsausage\ncosplay\ntext\nquality\nmillionaire\nclerk\nskinny\nhendrix\ndisabled\npuzzle\npepperoni\ncivil\nsurfer\nlarp\npackage\nroof\npa\nsuspense\ndrone\nmail\njason\nexam\nmark\nfinancial\nencyclopedia\ncheetos\ndemand\nshell\nplanning\nstupid\nyell\ngrateful\nbingo\nsource\ncompanion\ndirector\nb
ite\ndetective\nbiography\ngospel\nsilly\npudding\npork\nteeth\nautobiography\nsalt\nfootstep\ndeserve\nproduce\nswimmer\nbarbecue\nmaryland\nbtw\ndefense\nfallon\nteen\ncontinue\ncart\nwizard\nmeditate\nshelf\nzombie\nirish\npecan\nbubble\ndiscount\nscooter\npush\ntutorial\nscuba\nhomemade\nweakness\ntranslator\ngymnastics\nbackground\nsoftball\nkidding\nmistake\nrealize\nironic\nfloyd\nflight\nsurgeon\ncrack\ndesire\nintrovert\nted\nknife\nmba\npottery\norphan\nrecover\nmini\nbucket\nperk\nautism\nmoped\ncycle\nyouth\nspoil\nattitude\ninjury\npennsylvania\naspire\nloyal\nattack\nmurder\nprice\ntank\nsafety\nolympics\nrome\ndj\ncarpenter\nsesame\nconsume\nprotect\nbritish\nsword\ncheetah\nfloat\nasia\nmate\ncreed\nxbox\nlean\nfreckle\ncaffeine\nhunter\ntrump\nprogramming\npicnic\near\nfridge\neaster\nasparagus\noldie\ndarn\ndisagree\nmirror\nhehe\new\nbiscuit\nfreshman\nwhistle\nusual\ninch\ndeli\neclipse\ncycling\nmodern\ntease\nreview\npattern\naward\nbelong\nnickname\nbuff\nsweden\ncuz\nselfish\npersonality\ngraduation\ndolphin\nalbum\npup\ntaylor\nlease\neducate\ndepress\nparamedic\ngila\npurpose\nprison\nworried\nvermont\nblock\nbuffalo\nolive\nidentical\nscore\nstring\nactual\nintelligent\nepilepsy\nsand\nplate\nsubtitle\nborder\ncable\nsmooth\nlexus\ndevelop\nvienna\nbrave\nwelfare\ndoberman\nwealthy\nswitch\nsneak\nrobert\nhill\nnacho\nmugger\nsnap\ndumb\ncoworker\npeta\nhr\nchick\nfur\ngoalie\nrange\nintroduce\ncostco\nrailroad\nsuffer\nmenu\nsoldier\nasthma\nsex\nlindsey\nutah\ngrandfather\ndocumentary\nadmirable\ntraveling\ndane\nemployee\narticle\ncaramel\nharley\nequipment\nheal\nbasic\ndabble\ndepression\nattract\ngaga\ntale\ntube\nfried\ndoggy\nmash\ntaught\nreceptionist\nempire\nintern\npiercings\ntackle\nease\nblanket\nparticipate\ncheesy\ngray\ndaisy\ntoddler\nlink\ndiamond\npropose\ndallas\npresident\nranger\nwolf\ngain\nchai\nannoying\nearring\nversion\nbasket\nlens\nsalary\ncorner\nchampion\nfirefighter\nferret\nachieve\nsears\nmia\nidol\njoe\ndecade\nemotion\nkoala\nmanagement\npharmacist\napps\ndepends\nkiller\nfellow\nuniform\ngourmet\ncleveland\nmotivate\nhummus\nentertainment\nmmmm\nbag\nnashville\nflirt\nowen\ntrumpet\nnevada\nstage\njerky\nresponsibility\ndrake\nbentley\ngold\ntx\narcade\nankle\nvegas\nkj\nbatman\nunwind\nkeyboard\ncombination\nleaf\nkoi\ncello\nminimum\nadventurous\nloving\nkiddos\nspill\ndiabetic\ncentral\nrob\ntrend\nbubblegum\nindoors\nmonster\ndrunk\nconfidence\npyramid\ngrunge\nbanker\nbreath\ngrasshopper\nhoop\nencourage\ngutter\nmacaroni\nrobotics\ndouble\nseat\ndew\nuncomfortable\nflash\nbench\nbomb\ninfo\nsemi\ncomfort\nwildlife\ngeology\nlonely\ncoz\nedge\nanniversary\nvitamin\nmaterial\nhotdog\ngathering\nsocialize\nmachine\neditor\ndroopy\nbrew\nliberal\nmercedes\nquarterback\nrn\nplanner\nfortunate\nur\ngreenhouse\nsi\npsychologist\npromotion\ndef\ndiscovery\ncarb\nhumane\nbroken\nchanel\nfluffy\nchain\nvision\nstanford\npipe\nability\noverweight\npromote\nlabrador\nveteran\npreference\nsymphony\nalpaca\nve\napp\nteenager\nanne\npromise\ndoo\npublisher\ncurious\ntiki\nporsche\nmixed\nmaid\nlegend\nmichael\nsupportive\npineapple\nariel\ndiabetes\nconsulting\nstarve\ngal\ncollar\ngable\nbattle\njacket\nsexy\nsleeve\nfelix\npastime\njamaica\nmortal\nweak\nscrub\nrabbit\ngodfather\nsinatra\nvalley\ndespise\nregret\ngoodwill\nheaven\nbuddhist\nsmoking\noops\npitbulls\nsalmon\nrick\nbitcoin\ndip\ntrout\npill\nfarming\nthankful\ntokyo\nhousewife\nprayer\nimpala\nvaledictorian\nplain\nmessage\ntemper\nflintstone\nleprechaun\nsucker\nbreath
e\ncsi\ncriminal\nrip\nmaiden\nfascinating\nrico\nalgeria\nreport\numm\npatience\nleader\ncurl\nmotivation\nclimbing\ntahoe\nymca\nrelief\nglacier\nbreast\nenter\nclutter\ndull\nfighter\ntat\nawake\nbrewery\nvictorian\nvolcano\nfriends\nmount\npillage\nmagical\ngeneration\nclue\nconscious\nstare\nsilver\nwrestling\nlevine\njoint\nrestore\neverest\ndope\nstray\ninternational\nparking\nhampshire\nhearse\nwarehouse\npitbull\nnyu\noutdoorsy\ndevelopment\nemployment\ndrinker\nzumba\npaul\nbudget\ndaniel\neyesight\nsour\nmouth\nstain\nblogger\nexist\nrib\nbrush\ninterview\nbff\ncustom\nsnuggle\nvancouver\nmario\nferraris\nmural\npoet\noriole\nperiod\nkarma\ndamage\nwarmer\ncrossword\nchildrens\npomeranian\nimaginary\ndave\nanatomy\ntone\ncode\nvideogames\nwoodstock\nconvention\njanitor\npreschool\nscreen\nprejudice\ncrystal\nrage\ntradition\nchatting\ntraditional\nparakeet\nramen\ncombat\nmultiple\ncrave\nsyrup\nracing\nhighlight\ncommunist\nconcentrate\nwaiter\nebooks\ndodge\nhp\nboil\nattic\nmedal\ncommitment\nrelease\ndowntown\nalligator\nstatement\ndebate\nagreed\nmaga\nhomeschooled\nstrength\nplumber\nhippy\nwindy\ncondition\nsmoothie\nstair\ncontent\ndepressed\nferrell\nketo\nremodel\ndonut\nwinner\nplaylist\nwayne\nnation\nkpop\nmap\ncoon\njunior\nmum\ntape\nquake\nsmithsonian\nwasher\nabigail\nradiohead\nhumble\nunicycle\nadministration\nontario\nperformance\ntruth\nfred\ningredient\ncucumber\nbeastie\norchestra\nsewing\nknock\nculinary\nsweat\nseashell\nimpression\nnetwork\nlanguages\ntailgate\ncelebration\nthomas\nembarrass\nborn\nmama\nfreeze\ncrap\nfortune\nfigurine\nconfident\nhomebody\nchemistry\ncollector\nmerna\narrive\ntitanic\nmeditation\nbout\nmanta\nannouncer\nsolo\ncircle\nmd\nfuneral\nengine\nbutt\ndelivery\nultimate\nspecialize\nweb\npalm\nabsolute\ninvestment\nharsh\npistachio\nloner\nexperiment\ngut\nausten\nfuel\ncramp\ntrauma\nsleepy\nceltic\npress\ndraft\nauto\nsprite\nobsession\nsip\nfifty\nvinyl\nswing\nfool\nhbu\nharvey\ncopperfield\nplayoff\nkite\nlesbian\njerk\nowe\ndemocrat\nmass\nhamilton\nga\nuk\nluis\nimpress\nslice\npita\nhobbie\napologize\nsanta\ntacos\nlanding\nhometown\ntelecom\nmater\nmutt\ndeploy\ndel\nsore\nnancy\nbarbies\nfam\nclay\nethnic\npastry\nhostage\ntight\nbackyard\nconvince\nmaker\ncurry\nandroid\npc\njessica\nignore\nflow\nsickness\nelderly\nchore\nupholstery\nsweetie\nlettuce\ncuba\ngadget\nanimation\ntrooper\nfaith\ntongue\nsuccess\ngentle\nportrait\nsheeran\nchevrolet\npacker\nrisk\nspark\nfrustrate\nmouse\npitch\nweld\neyebrow\nbella\nlinebacker\nbully\nroutine\nspelling\nbc\ncoat\nsaudi\narabia\ntampa\nemmy\nsamsung\nmop\nkevin\nchecker\nteapot\nweigh\nsuv\nmiserable\nsevenfold\nf150\nlit\nposse\nthai\ncurator\nsteve\npoop\nhistorical\nmorty\ncane\nmiley\nwise\npetition\ntear\npenn\nastronaut\ncod\ncolour\nacting\nprecious\nbuck\nlucy\nmuse\ncosmetic\noccupation\nnba\nate\nflexible\nideal\nsuspender\nbang\ndirect\ngotti\nagitate\nhairdresser\ndealership\ninfluence\ncursive\nsunfish\nsnorkel\nshallow\nroot\npediatrician\ncompost\ncoaster\nnearby\nforeman\ndeadbeat\npenny\njay\njasper\ntarot\npressure\nclarinet\nsupper\nexpress\nai\nmartini\nfavor\nchop\nlutefisk\ncharge\ndakota\nhitchhike\nformal\nivy\nraptor\nbattlestar\ncaptain\ndisgust\ntask\nsitcom\nyorkie\ncoco\nunderstood\nnaw\nant\nstinky\nspeckle\ntitle\ncorporate\nwednesday\ngambler\nwage\nmulti\nmma\ncookbook\ncitizen\nhazel\naspiration\ngoat\nstuck\nlumberjack\nflag\nwet\nufc\nlearning\nstirling\ndealer\ngrisham\nacre\n"
  },
  {
    "path": "preprocess/data_utils.py",
    "content": "import nltk\nimport os\nfrom nltk.stem import WordNetLemmatizer\n\n_lemmatizer = WordNetLemmatizer()\n\n\ndef tokenize(example, ppln):\n    for fn in ppln:\n        example = fn(example)\n    return example\n\n\ndef kw_tokenize(string):\n    return tokenize(string, [nltk_tokenize, lower, pos_tag, to_basic_form])\n\n\ndef simp_tokenize(string):\n    return tokenize(string, [nltk_tokenize, lower])\n\n\ndef nltk_tokenize(string):\n    return nltk.word_tokenize(string)\n\n\ndef lower(tokens):\n    if not isinstance(tokens, str):\n        return [lower(token) for token in tokens]\n    return tokens.lower()\n\n\ndef pos_tag(tokens):\n    return nltk.pos_tag(tokens)\n\n\ndef to_basic_form(tokens):\n    if not isinstance(tokens, tuple):\n        return [to_basic_form(token) for token in tokens]\n    word, tag = tokens\n    if tag.startswith('NN'):\n        pos = 'n'\n    elif tag.startswith('VB'):\n        pos = 'v'\n    elif tag.startswith('JJ'):\n        pos = 'a'\n    else:\n        return word\n    return _lemmatizer.lemmatize(word, pos)\n\n\ndef truecasing(tokens):\n    ret = []\n    is_start = True\n    for word, tag in tokens:\n        if word == 'i':\n            ret.append('I')\n        elif tag[0].isalpha():\n            if is_start:\n                ret.append(word[0].upper() + word[1:])\n            else:\n                ret.append(word)\n            is_start = False\n        else:\n            if tag != ',':\n                is_start = True\n            ret.append(word)\n    return ret\n\n\ncandi_keyword_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'convai2/candi_keyword.txt')\n_candiwords = [x.strip() for x in open(candi_keyword_path).readlines()]\n\n\ndef is_candiword(a):\n    if a in _candiwords:\n        return True\n    return False\n\n\nfrom nltk.corpus import wordnet as wn\nfrom nltk.corpus import wordnet_ic\n\nbrown_ic = wordnet_ic.ic('ic-brown.dat')\n\n\ndef calculate_linsim(a, b):\n    linsim = -1\n    syna = wn.synsets(a)\n    synb = wn.synsets(b)\n    for sa in syna:\n        for sb in synb:\n            try:\n                linsim = max(linsim, sa.lin_similarity(sb, brown_ic))\n            except:\n                pass\n    return linsim\n\n\ndef is_reach_goal(context, goal):\n    context = kw_tokenize(context)\n    if goal in context:\n        return True\n    for wd in context:\n        if is_candiword(wd):\n            rela = calculate_linsim(wd, goal)\n            if rela > 0.9:\n                return True\n    return False\n\n\ndef make_context(string):\n    string = kw_tokenize(string)\n    context = []\n    for word in string:\n        if is_candiword(word):\n            context.append(word)\n    return context\n\n\ndef utter_preprocess(string_list, max_length):\n    source, minor_length = [], []\n    string_list = string_list[-9:]\n    major_length = len(string_list)\n    if major_length == 1:\n        context = make_context(string_list[-1])\n    else:\n        context = make_context(string_list[-2] + string_list[-1])\n    context_len = len(context)\n    while len(context) < 20:\n        context.append('<PAD>')\n    for string in string_list:\n        string = simp_tokenize(string)\n        if len(string) > max_length:\n            string = string[:max_length]\n        string = ['<BOS>'] + string + ['<EOS>']\n        minor_length.append(len(string))\n        while len(string) < max_length + 2:\n            string.append('<PAD>')\n        source.append(string)\n    while len(source) < 9:\n        source.append(['<PAD>'] * 
(max_length + 2))\n        minor_length.append(0)\n    return (source, minor_length, major_length, context, context_len)\n"
  },
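  {
    "path": "examples/preprocess_sketch.py",
    "content": "# NOTE: an illustrative usage sketch added for documentation; not an original repo file.\n# It assumes the repo root is on PYTHONPATH and that the nltk data used by\n# preprocess/data_utils.py (punkt, wordnet, averaged_perceptron_tagger, wordnet_ic) is already\n# downloaded, since that module loads them at import time.\nfrom preprocess.data_utils import utter_preprocess\n\nhistory = ['hello , how are you doing ?', 'pretty good , just walked my dog .']\nsource, minor_len, major_len, context, context_len = utter_preprocess(history, 30)\n\nprint(len(source), len(source[0]))  # 9 padded utterances, each of length 30 + 2 (<BOS>/<EOS>)\nprint(minor_len, major_len)  # true token counts per utterance and the number of real turns\nprint(context[:context_len])  # keywords extracted from the last two utterances\n"
  },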
  {
    "path": "preprocess/dataset.py",
    "content": "import numpy as np\nimport collections\nimport random\nimport pickle\nfrom convai2 import dts_ConvAI2\nfrom extraction import KeywordExtractor\nfrom data_utils import *\n\nclass dts_Target(dts_ConvAI2):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def get_vocab(self):\n        counter = collections.Counter()\n        dialogs = self.get_dialogs()\n        for dialog in dialogs:\n            for uttr in dialog:\n                counter.update(simp_tokenize(uttr))\n        print('total vocab count: ', len(counter.items()))\n        vocab = [token for token, times in sorted(list(counter.items()), key=lambda x: (-x[1], x[0]))]\n        with open('../tx_data/vocab.txt','w') as f:\n            for word in vocab:\n                f.write(word + '\\n')\n        print('save vocab in vocab.txt')\n        return vocab\n\n    def get_kwsess(self, vocab, mode='all'):\n        keyword_extractor = KeywordExtractor(vocab)\n        corpus = self.get_data(mode = mode, cands=False)\n        sess_set = []\n        for sess in corpus:\n            data = {}\n            data['history'] = ''\n            data['dialog'] = []\n            for dialog in sess['dialog']:\n                data['dialog'].append(dialog)\n                data['history'] = data['history'] + ' ' + dialog\n            data['kws'] = keyword_extractor.extract(data['history'])\n            sess_set.append(data)\n        return sess_set\n\n    def cal_idf(self):\n        counter = collections.Counter()\n        dialogs = self.get_dialogs()\n        total = 0.\n        for dialog in dialogs:\n            for uttr in dialog:\n                total += 1\n                counter.update(set(kw_tokenize(uttr)))\n        idf_dict = {}\n        for k,v in counter.items():\n            idf_dict[k] = np.log10(total / (v+1.))\n        return idf_dict\n\n    def make_dataset(self):\n        vocab = self.get_vocab()\n        idf_dict = self.cal_idf()\n        kw_counter = collections.Counter()\n        sess_set = self.get_kwsess(vocab)\n        for data in sess_set:\n            kw_counter.update(data['kws'])\n        kw_freq = {}\n        kw_sum = sum(kw_counter.values())\n        for k, v in kw_counter.most_common():\n            kw_freq[k] = v / kw_sum\n        for data in sess_set:\n            data['score'] = 0.\n            for kw in set(data['kws']):\n                data['score'] += kw_freq[kw]\n            data['score'] /= len(set(data['kws']))\n        sess_set.sort(key=lambda x: x['score'], reverse=True)\n\n        all_data = {'train':[], 'valid':[], 'test':[]}\n        keyword_extractor = KeywordExtractor(idf_dict)\n        for id, sess in enumerate(sess_set):\n            type = 'train'\n            if id < 500:\n                type = 'test'\n            elif random.random() < 0.05:\n                type = 'valid'\n            sample = {'dialog':sess['dialog'], 'kwlist':[]}\n            for i in range(len(sess['dialog'])):\n                sample['kwlist'].append(keyword_extractor.idf_extract(sess['dialog'][i]))\n            all_data[type].append(sample)\n        pickle.dump(all_data, open('source_data.pk','wb'))\n        return all_data"
  },
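  {
    "path": "examples/idf_sketch.py",
    "content": "# NOTE: an illustrative sketch added for documentation; not an original repo file.\n# It recomputes the smoothed IDF used by dts_Target.cal_idf(): idf(w) = log10(N / (df(w) + 1)),\n# where N is the number of utterances and df(w) the number of utterances containing w.\n# The toy utterances below are made up.\nimport collections\nimport numpy as np\n\nutterances = [['i', 'like', 'dogs'], ['dogs', 'are', 'great'], ['i', 'agree']]\ncounter = collections.Counter()\nfor uttr in utterances:\n    counter.update(set(uttr))  # document frequency, not raw term frequency\nidf_dict = {w: np.log10(len(utterances) / (df + 1.0)) for w, df in counter.items()}\nprint(sorted(idf_dict.items(), key=lambda kv: -kv[1])[:3])  # rarest tokens score highest\n"
  },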
  {
    "path": "preprocess/extraction.py",
    "content": "from data_utils import *\n\nclass KeywordExtractor():\n    def __init__(self, idf_dict = None):\n        self.idf_dict = idf_dict\n\n    @staticmethod\n    def is_keyword_tag(tag):\n        return tag.startswith('VB') or tag.startswith('NN') or tag.startswith('JJ')\n\n    @staticmethod\n    def cal_tag_score(tag):\n        if tag.startswith('VB'):\n            return 1.\n        if tag.startswith('NN'):\n            return 2.\n        if tag.startswith('JJ'):\n            return 0.5\n        return 0.\n\n    def idf_extract(self, string, con_kw = None):\n        tokens = simp_tokenize(string)\n        seq_len = len(tokens)\n        tokens = pos_tag(tokens)\n        source = kw_tokenize(string)\n        candi = []\n        result = []\n        for i, (word, tag) in enumerate(tokens):\n            score = self.cal_tag_score(tag)\n            if not is_candiword(source[i]) or score == 0.:\n                continue\n            if con_kw is not None and source[i] in con_kw:\n                continue\n            score *= source.count(source[i])\n            score *= 1 / seq_len\n            score *= self.idf_dict[source[i]]\n            candi.append((source[i], score))\n            if score > 0.15:\n                result.append(source[i])\n        return list(set(result))\n\n\n    def extract(self, string):\n        tokens = simp_tokenize(string)\n        tokens = pos_tag(tokens)\n        source = kw_tokenize(string)\n        kwpos_alters = []\n        for i, (word, tag) in enumerate(tokens):\n            if source[i] and self.is_keyword_tag(tag):\n                kwpos_alters.append(i)\n        kwpos, keywords = [], []\n        for id in kwpos_alters:\n            if is_candiword(source[id]):\n                keywords.append(source[id])\n        return list(set(keywords))"
  },
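  {
    "path": "preprocess/example_keyword_extraction.py",
    "content": "# Illustrative sketch only, NOT part of the original pipeline: it shows how\n# KeywordExtractor is typically wired up, mirroring dts_Target.cal_idf() in\n# dataset.py. Assumes data_utils provides kw_tokenize as used by extraction.py\n# and that the nltk data required by pos_tag is installed.\nimport collections\nimport numpy as np\nfrom extraction import KeywordExtractor\nfrom data_utils import kw_tokenize\n\ntoy_dialogs = [\n    ['i love cooking steak for dinner', 'nice , i prefer pizza myself'],\n    ['do you have any pets ?', 'yes , a cat and a dog'],\n]\n\n# Build a toy idf table over single utterances, exactly as cal_idf() does.\ncounter = collections.Counter()\ntotal = 0.\nfor dialog in toy_dialogs:\n    for uttr in dialog:\n        total += 1\n        counter.update(set(kw_tokenize(uttr)))\nidf_dict = {k: np.log10(total / (v + 1.)) for k, v in counter.items()}\n\n# Extract weighted keywords from each utterance.\nextractor = KeywordExtractor(idf_dict)\nfor dialog in toy_dialogs:\n    for uttr in dialog:\n        print(uttr, '->', extractor.idf_extract(uttr))\n"
  },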
  {
    "path": "preprocess/prepare_data.py",
    "content": "from dataset import dts_Target\nfrom collections import Counter\nimport pickle\nimport random\nimport os\nimport shutil\nif not os.path.exists('../tx_data'):\n    os.mkdir('../tx_data')\n    os.mkdir('../tx_data/train')\n    os.mkdir('../tx_data/valid')\n    os.mkdir('../tx_data/test')\n\n# import texar\n# if not os.path.exists('convai2/source'):\n#     print('Downloading source ConvAI2 data')\n#     texar.data.maybe_download('https://drive.google.com/file/d/1LPxNIVO52hZOwbV3Zply_ITi2Uacit-V/view?usp=sharing'\n#                                 ,'convai2', extract=True)\n\nshutil.copy('convai2/source/embedding.txt', '../tx_data/embedding.txt')\ndataset = dts_Target()\ndataset.make_dataset()\n\ndata = pickle.load(open(\"source_data.pk\",\"rb\"))\nmax_utter = 9\ncandidate_num = 20\nstart_corpus_file = open(\"../tx_data/start_corpus.txt\", \"w\")\ncorpus_file = open(\"../tx_data/corpus.txt\", \"w\")\n\nfor stage in ['train', 'valid', 'test']:\n    source_file = open(\"../tx_data/{}/source.txt\".format(stage), \"w\")\n    target_file = open(\"../tx_data/{}/target.txt\".format(stage), \"w\")\n    context_file = open(\"../tx_data/{}/context.txt\".format(stage), \"w\")\n    keywords_file = open(\"../tx_data/{}/keywords.txt\".format(stage), \"w\")\n    label_file = open(\"../tx_data/{}/label.txt\".format(stage), \"w\")\n    keywords_vocab_file = open(\"../tx_data/{}/keywords_vocab.txt\".format(stage), \"w\")\n    corpus = []\n    keywords_counter = Counter()\n    for sample in data[stage]:\n        corpus += sample['dialog'][1:]\n        start_corpus_file.write(sample['dialog'][0]+ '\\n')\n        for kws in sample['kwlist']:\n            keywords_counter.update(kws)\n    for kw, _ in keywords_counter.most_common():\n        keywords_vocab_file.write(kw + '\\n')\n    for sample in data[stage]:\n        for i in range(2, len(sample['dialog'])):\n            if len(sample['kwlist'][i]) > 0:\n                source_list = sample['dialog'][max(0, i - max_utter):i]\n                source_str = '|||'.join(source_list)\n                while True:\n                    random_corpus = random.sample(corpus, candidate_num - 1)\n                    if sample['dialog'][i] not in random_corpus:\n                        break\n                corpus_file.write(sample['dialog'][i] + '\\n')\n                target_list = [sample['dialog'][i]] + random_corpus\n                target_str = '|||'.join(target_list)\n                source_file.write(source_str + '\\n')\n                target_file.write(target_str + '\\n')\n                context_file.write(' '.join(sample['kwlist'][i-2] +\n                    sample['kwlist'][i-1]) + '\\n')\n                keywords_file.write(' '.join(sample['kwlist'][i]) + '\\n')\n                label_file.write('0\\n')\n                \n    source_file.close()\n    target_file.close()\n    label_file.close()\n    keywords_vocab_file.close()\n    context_file.close()\n\nstart_corpus_file.close()\ncorpus_file.close()\n"
  },
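  {
    "path": "preprocess/example_data_layout.py",
    "content": "# Illustrative sketch only, NOT part of the original pipeline: previews the\n# files that prepare_data.py writes under tx_data/, i.e. the format consumed\n# by config/data_config.py. Run it from preprocess/ after prepare_data.py.\nimport os\n\n# source.txt   : dialogue history, utterances joined by '|||'\n# target.txt   : 20 candidate responses joined by '|||', ground truth first\n# context.txt  : keywords of the two utterances preceding the response\n# keywords.txt : keywords of the ground-truth response\n# label.txt    : index of the ground truth among the candidates (always 0)\nstage_files = ['source.txt', 'target.txt', 'context.txt', 'keywords.txt',\n               'label.txt', 'keywords_vocab.txt']\nfor stage in ['train', 'valid', 'test']:\n    for name in stage_files:\n        path = os.path.join('../tx_data', stage, name)\n        with open(path, 'r') as f:\n            print(path, '->', f.readline().strip())\n"
  },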
  {
    "path": "readme.md",
    "content": "# Target-Guided Open-Domain Conversation\r\n\r\nThis is the code for the following paper:\r\n\r\n[Target-Guided Open-Domain Conversation](http://arxiv.org/abs/1905.11553)  \r\n*Jianheng Tang, Tiancheng Zhao, Chenyan Xiong, Xiaodan Liang, Eric Xing, Zhiting Hu; ACL 2019*\r\n\r\n### Requirement\r\n\r\n- `nltk==3.4`  \r\n- `tensoflow==1.12`   \r\n- `texar>=0.2.1` ([Texar](https://github.com/asyml/texar))\r\n\r\n### Usage\r\n\r\n#### Data Preparation\r\nThe dataset developed in the paper is on [google drive](https://drive.google.com/file/d/1oTjOQjm7iiUitOPLCmlkXOCbEPoSWDPX/view?usp=sharing). Download \r\nand unzip it into `preprocess/convai2`. Then run the following command:\r\n```shell\r\ncd preprocess\r\npython prepare_data.py\r\n```\r\nBy default, the processed data will be put in the `tx_data` directory.\r\n\r\n#### Turn-level Supervised Learning\r\nIn this project there are 5 different types of agents, including the kernel/neural/matrix/retrieval/retrieval_stgy agent,\r\n which are all discribed in the paper. You can modify the configration of each agent in the `config` directory.\r\n\r\nTo train the kernel/neural/matrix agent, you need to first train/test the keyword prediction module, \r\nand then train/test the retrieval module of each agent specified by the `--agent` parameter.\r\n\r\n```shell\r\npython train.py --mode train_kw --agent kernel\r\npython train.py --mode train --agent kernel\r\npython train.py --mode test --agent kernel\r\n```\r\n\r\nThe retrieval agent and the retrieval_stgy agent share the same retrival module. You only need to train one of them:\r\n\r\n```shell\r\npython train.py --mode train --agent retrieval\r\npython train.py --mode test --agent retrieval\r\n```\r\n\r\n#### Target-guided Conversation\r\n\r\nAfter turn-level training, you can start target-guided conversation (human evaluation) with \r\nthe kernel/neural/matrix/retrieval/retrieval_stgy  agent specified by the `--agent` parameter.\r\n\r\n```shell\r\npython chat.py --agent kernel\r\n```\r\nA demo of target-guided conversation (user is required to input the response after `HUMAN:`)\r\n\r\n```\r\n--------Session 0 --------\r\nSTART:  hello, how are you doing\r\nHUMAN:  hi i am great just finishing up some homework. how about you\r\nAGENT:  just got done eating some chicken and rice after i got done with my gym session\r\nHUMAN:  i am gonna eat some vegetable for dinner\r\nAGENT:  so you are vegetarian ? i am from los angeles so thereare lots of vegetarian restaurants\r\nHUMAN:  i am not a vegetarian . fried chicken is my favorite.\r\nAGENT:  oh boy , when i go to mcdonald i always get fries for my 2 dogs\r\nSuccessfully chat to the target 'mcdonald'.\r\n```\r\n\r\nYou can also watch the simulation of the target-guided conversation \r\nbetween the retrieval agent pretending the user and the kernel/neural/matrix/retrieval_stgy agent specified by the `--agent` parameter. Simulation times are defined by the `--times` parameter. The program will calculate the success times and average turns in the end.\r\n\r\n```shell\r\npython simulate.py --agent kernel --times 100\r\n```\r\n\r\nA demo of simulation:\r\n```\r\n--------Session 0 --------\r\nSTART: hey hey good morning to you\r\nretrieval_agent:  fine thanks . what do you do for a living ?\r\nkernel_agent:  rewarding job , i work at a hospital\r\nKeyword: job, Similarity: 0.58\r\nSuccessfully chat to the target 'hospital'.\r\n\r\n...\r\n--------Session 99 --------\r\nSTART: hey hows it going ? 
i'm just cooking a steak\r\nretrieval_agent:  i'm thinking of a bbq sandwich for lunch\r\nkernel_agent:  nice i love to cook but now its just me and the fur babies\r\nKeyword: baby, Similarity: 0.45\r\nretrieval_agent:  i love bagels however i own a dry cleaners\r\nkernel_agent:  i love animals felix my cat and my dog emmy\r\nKeyword: cat, Similarity: 0.56\r\nretrieval_agent:  sounds awesome i have all kind of pets my family own a farm\r\nkernel_agent:  i love blue as well even my hair is blue\r\nKeyword: blue, Similarity: 1.00\r\nSuccessfully chat to the target 'blue'.\r\n\r\nsuccess time 83, average turns 4.28\r\n```\r\n"
  },
  {
    "path": "simulate.py",
    "content": "import tensorflow as tf\nimport importlib\nimport random\nfrom preprocess.data_utils import utter_preprocess, is_reach_goal\nfrom model import retrieval\n\nclass Target_Simulation():\n    def __init__(self, config_model, config_data, config_retrieval):\n        g1 = tf.Graph()\n        with g1.as_default():\n            self.retrieval_agent = retrieval.Predictor(config_retrieval, config_data)\n            sess1 = tf.Session(graph=g1, config=self.retrieval_agent.gpu_config)\n            self.retrieval_agent.retrieve_init(sess1)\n        g2 = tf.Graph()\n        with g2.as_default():\n            self.target_agent = model.Predictor(config_model, config_data)\n            sess2 = tf.Session(graph=g2, config=self.target_agent.gpu_config)\n            self.target_agent.retrieve_init(sess2)\n        self.start_utter = config_data._start_corpus\n        success_cnt, turns_cnt = 0, 0\n        for i in range(int(FLAGS.times)):\n            print('--------Session {} --------'.format(i))\n            success, turns = self.simulate(sess1, sess2)\n            success_cnt += success\n            turns_cnt += turns\n        print('success time {}, average turns {:.2f}'.format(success_cnt, turns_cnt / success_cnt))\n\n    def simulate(self, sess1, sess2):\n        history = []\n        history.append(random.sample(self.start_utter,1)[0])\n        target_kw = random.sample(target_set,1)[0]\n        self.target_agent.target = target_kw\n        self.target_agent.score = 0.\n        self.target_agent.reply_list = []\n        self.retrieval_agent.reply_list = []\n\n        print('START: ' + history[0])\n        for i in range(config_data._max_turns):\n            source = utter_preprocess(history, config_data._max_seq_len)\n            reply = self.retrieval_agent.retrieve(source, sess1)\n            print('retrieval_agent: ', reply)\n            history.append(reply)\n            source = utter_preprocess(history, config_data._max_seq_len)\n            reply = self.target_agent.retrieve(source, sess2)\n            print('{}_agent: '.format(FLAGS.agent), reply)\n            print('Keyword: {}, Similarity: {:.2f}'.format(self.target_agent.next_kw, self.target_agent.score))\n            history.append(reply)\n            if is_reach_goal(history[-2] + history[-1], target_kw):\n                print('Successfully chat to the target \\'{}\\'.'.format(target_kw))\n                return (True, (len(history)+1)//2)\n\n        print('Failed by reaching the maximum turn, target: \\'{}\\'.'.format(target_kw))\n        return (False, 0)\n\nif __name__ == '__main__':\n    flags = tf.flags\n    flags.DEFINE_string('agent', 'kernel', 'The agent type, supports kernel / matrix / neural / retrieval.')\n    flags.DEFINE_string('times', '100', 'Simulation times.')\n\n    FLAGS = flags.FLAGS\n    config_data = importlib.import_module('config.data_config')\n    config_model = importlib.import_module('config.' + FLAGS.agent)\n    config_retrieval = importlib.import_module('config.retrieval')\n    model = importlib.import_module('model.' + FLAGS.agent)\n\n    target_set = []\n    for line in open('tx_data/test/keywords.txt', 'r').readlines():\n        target_set = target_set + line.strip().split(' ')\n\n    Target_Simulation(config_model,config_data,config_retrieval)"
  },
  {
    "path": "train.py",
    "content": "import tensorflow as tf\nimport importlib\nimport os\nif __name__ == '__main__':\n    flags = tf.flags\n    flags.DEFINE_string('data', 'data_config', 'The data config')\n    flags.DEFINE_string('agent', 'kernel', 'The predictor type')\n    flags.DEFINE_string('mode', 'train', 'The mode')\n\n    FLAGS = flags.FLAGS\n    config_data = importlib.import_module('config.' + FLAGS.data)\n    config_model = importlib.import_module('config.' + FLAGS.agent)\n    model = importlib.import_module('model.' + FLAGS.agent)\n    predictor = model.Predictor(config_model, config_data, FLAGS.mode)\n    if not os.path.exists('save/'+FLAGS.agent):\n        os.makedirs('save/'+FLAGS.agent)\n\n    if FLAGS.mode == 'train_kw':\n        predictor.train_keywords()\n    if FLAGS.mode == 'test_kw':\n        predictor.test_keywords()\n    if FLAGS.mode == 'train':\n        predictor.train()\n        predictor.test()\n    if FLAGS.mode == 'test':\n        predictor.test()\n"
  }
]