Repository: squareRoot3/Target-Guided-Conversation
Branch: master
Commit: 4bbfaece1da3
Files: 22
Total size: 119.2 KB

Directory structure:
gitextract_1diizlg0/
├── chat.py
├── config/
│   ├── data_config.py
│   ├── kernel.py
│   ├── matrix.py
│   ├── neural.py
│   ├── retrieval.py
│   └── retrieval_stgy.py
├── model/
│   ├── kernel.py
│   ├── matrix.py
│   ├── neural.py
│   ├── retrieval.py
│   └── retrieval_stgy.py
├── preprocess/
│   ├── convai2/
│   │   ├── __init__.py
│   │   ├── api.py
│   │   └── candi_keyword.txt
│   ├── data_utils.py
│   ├── dataset.py
│   ├── extraction.py
│   └── prepare_data.py
├── readme.md
├── simulate.py
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: chat.py
================================================
import tensorflow as tf
import importlib
import random
from preprocess.data_utils import utter_preprocess, is_reach_goal


class Target_Chat():
    def __init__(self, agent):
        self.agent = agent
        self.start_utter = config_data._start_corpus
        with tf.Session(config=self.agent.gpu_config) as sess:
            self.agent.retrieve_init(sess)
            for i in range(int(FLAGS.times)):
                print('--------Session {} --------'.format(i))
                self.chat(sess)

    def chat(self, sess):
        history = []
        history.append(random.sample(self.start_utter, 1)[0])
        target_kw = random.sample(target_set, 1)[0]
        self.agent.target = target_kw
        self.agent.score = 0.
        self.agent.reply_list = []
        print('START: ' + history[0])
        for i in range(config_data._max_turns):
            history.append(input('HUMAN: '))
            source = utter_preprocess(history, self.agent.data_config._max_seq_len)
            reply = self.agent.retrieve(source, sess)
            print('AGENT: ', reply)
            # print('Keyword: {}, Similarity: {:.2f}'.format(self.agent.next_kw, self.agent.score))
            history.append(reply)
            if is_reach_goal(history[-2] + history[-1], target_kw):
                print('Successfully chat to the target \'{}\'.'.format(target_kw))
                return
        print('Failed by reaching the maximum turn, target: \'{}\'.'.format(target_kw))


if __name__ == '__main__':
    flags = tf.flags
    # supports kernel / matrix / neural / retrieval / retrieval_stgy
    flags.DEFINE_string('agent', 'kernel', 'The agent type')
    flags.DEFINE_string('times', '100', 'Conversation times')
    FLAGS = flags.FLAGS

    config_data = importlib.import_module('config.data_config')
    config_model = importlib.import_module('config.' + FLAGS.agent)
    model = importlib.import_module('model.' + FLAGS.agent)
    predictor = model.Predictor(config_model, config_data, 'test')

    target_set = []
    for line in open('tx_data/test/keywords.txt', 'r').readlines():
        target_set = target_set + line.strip().split(' ')

    Target_Chat(predictor)


================================================
FILE: config/data_config.py
================================================
import os

data_root = './tx_data'
_corpus = [x.strip() for x in open('tx_data/corpus.txt', 'r').readlines()]
_start_corpus = [x.strip() for x in open('tx_data/start_corpus.txt', 'r').readlines()]
_max_seq_len = 30
_num_neg = 20
_max_turns = 8
_batch_size = 64
_retrieval_candidates = 1000

data_hparams = {
    stage: {
        "num_epochs": 1,
        "shuffle": stage != 'test',
        "batch_size": _batch_size,
        "datasets": [
            {  # dialogue history
                "variable_utterance": True,
                "max_utterance_cnt": 9,
                "max_seq_length": _max_seq_len,
                "files": [os.path.join(data_root, '{}/source.txt'.format(stage))],
                "vocab_file": os.path.join(data_root, 'vocab.txt'),
                "embedding_init": {
                    "file": os.path.join(data_root, 'embedding.txt'),
                    "dim": 200,
                    "read_fn": "load_glove"
                },
                "data_name": "source"
            },
            {  # candidate response
                "variable_utterance": True,
                "max_utterance_cnt": 20,
                "max_seq_length": _max_seq_len,
                "files": [os.path.join(data_root, '{}/target.txt'.format(stage))],
                "vocab_share_with": 0,
                "embedding_init_share_with": 0,
                "data_name": "target"
            },
            {  # context (source keywords)
                "files": [os.path.join(data_root, '{}/context.txt'.format(stage))],
                "vocab_share_with": 0,
                "embedding_init_share_with": 0,
                "data_name": "context",
                "bos_token": '',
                "eos_token": '',
            },
            {  # target keywords
                "files": [os.path.join(data_root, '{}/keywords.txt'.format(stage))],
                "vocab_share_with": 0,
                "embedding_init_share_with": 0,
                "data_name": "keywords",
                "bos_token": '',
                "eos_token": '',
            },
            {  # label
                "files": [os.path.join(data_root, '{}/label.txt'.format(stage))],
                "data_type": "int",
                "data_name": "label"
            }
        ]
    } for stage in ['train', 'valid', 'test']
}

corpus_hparams = {
    "batch_size": _batch_size * 2,
    "shuffle": False,
    "dataset": {
        "max_seq_length": _max_seq_len,
        "files": [os.path.join(data_root, 'corpus.txt')],
        "vocab_file": os.path.join(data_root, 'vocab.txt'),
        "data_name": "corpus"
    }
}

_keywords_path = 'tx_data/test/keywords_vocab.txt'
_keywords_candi = [x.strip() for x in open(_keywords_path, 'r').readlines()]
_keywords_num = len(_keywords_candi)
_keywords_dict = {}
for i in range(_keywords_num):
    _keywords_dict[_keywords_candi[i]] = i


================================================
FILE: config/kernel.py
================================================
_hidden_size = 200
_code_len = 800
_save_path = 'save/kernel/model_1'
_kernel_save_path = 'save/kernel/keyword_1'
_kernel_mu = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.]
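# NOTE (added for illustration; not part of the original config): _kernel_mu
# (above) and _kernel_sigma (defined just below) parameterize Gaussian kernels
# for KNRM-style kernel pooling. In model/kernel.py, forward_kernel() computes
# the cosine similarity between a candidate keyword embedding and every context
# word embedding, then soft-counts those similarities into one bin per kernel
# centre. A hypothetical pure-Python sketch of that pooling step:
def _kernel_features_sketch(similarities, mu=_kernel_mu, sigma=0.1):
    """Soft-count cosine similarities into one Gaussian bin per centre in `mu`."""
    import math
    return [sum(math.exp(-(s - m) ** 2 / sigma ** 2) for s in similarities)
            for m in mu]
# Example: _kernel_features_sketch([0.95, 0.40]) is largest in the bins around
# 0.9-1.0 and 0.4, mirroring the matching_feature tensor in forward_kernel().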
_kernel_sigma = 0.1 _max_epoch = 10 _early_stopping = 2 kernel_opt_hparams = { "optimizer": { "type": "AdamOptimizer", "kwargs": { "learning_rate": 0.001, } }, "learning_rate_decay": { "type": "inverse_time_decay", "kwargs": { "decay_steps": 1600, "decay_rate": 0.8 }, "start_decay_step": 0, "end_decay_step": 16000, }, } source_encoder_hparams = { "encoder_minor_type": "BidirectionalRNNEncoder", "encoder_minor_hparams": { "rnn_cell_fw": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, }, "rnn_cell_share_config": True }, "encoder_major_type": "UnidirectionalRNNEncoder", "encoder_major_hparams": { "rnn_cell": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size*2, }, } } } target_encoder_hparams = { "rnn_cell_fw": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, }, "rnn_cell_share_config": True } target_kwencoder_hparams = { "rnn_cell_fw": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, }, "rnn_cell_share_config": True } context_encoder_hparams = { "rnn_cell": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, } } opt_hparams = { "optimizer": { "type": "AdamOptimizer", "kwargs": { "learning_rate": 0.001, } }, } ================================================ FILE: config/matrix.py ================================================ _hidden_size = 200 _code_len = 800 _save_path = 'save/matrix/model_1' _matrix_save_path = 'save/matrix/matrix_1.pk' _max_epoch = 10 _vocab_path = 'tx_data/vocab.txt' _vocab = [x.strip() for x in open(_vocab_path, 'r').readlines()] _vocab_size = len(_vocab) source_encoder_hparams = { "encoder_minor_type": "BidirectionalRNNEncoder", "encoder_minor_hparams": { "rnn_cell_fw": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, }, "rnn_cell_share_config": True }, "encoder_major_type": "UnidirectionalRNNEncoder", "encoder_major_hparams": { "rnn_cell": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size*2, }, } } } target_encoder_hparams = { "rnn_cell_fw": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, }, "rnn_cell_share_config": True } target_kwencoder_hparams = { "rnn_cell_fw": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, }, "rnn_cell_share_config": True } opt_hparams = { "optimizer": { "type": "AdamOptimizer", "kwargs": { "learning_rate": 0.001, } } } ================================================ FILE: config/neural.py ================================================ _hidden_size = 200 _code_len = 800 _save_path = 'save/neural/model_1' _neural_save_path = 'save/neural/keyword_1' _max_epoch = 10 neural_opt_hparams = { "optimizer": { "type": "AdamOptimizer", "kwargs": { "learning_rate": 0.005, } }, "learning_rate_decay": { "type": "inverse_time_decay", "kwargs": { "decay_steps": 1600, "decay_rate": 0.8 }, "start_decay_step": 0, "end_decay_step": 16000, }, } source_encoder_hparams = { "encoder_minor_type": "BidirectionalRNNEncoder", "encoder_minor_hparams": { "rnn_cell_fw": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, }, "rnn_cell_share_config": True }, "encoder_major_type": "UnidirectionalRNNEncoder", "encoder_major_hparams": { "rnn_cell": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size*2, }, } } } target_encoder_hparams = { "rnn_cell_fw": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, }, "rnn_cell_share_config": True } target_kwencoder_hparams = { "rnn_cell_fw": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, }, "rnn_cell_share_config": True } context_encoder_hparams = { "rnn_cell": { "type": 
"GRUCell", "kwargs": { "num_units": _hidden_size, }, } } opt_hparams = { "optimizer": { "type": "AdamOptimizer", "kwargs": { "learning_rate": 0.001, } } } ================================================ FILE: config/retrieval.py ================================================ _hidden_size = 200 _code_len = 200 _save_path = 'save/retrieval/model_1' _max_epoch = 10 source_encoder_hparams = { "encoder_minor_type": "UnidirectionalRNNEncoder", "encoder_minor_hparams": { "rnn_cell": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, }, }, "encoder_major_type": "UnidirectionalRNNEncoder", "encoder_major_hparams": { "rnn_cell": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, } } } target_encoder_hparams = { "rnn_cell": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, } } opt_hparams = { "optimizer": { "type": "AdamOptimizer", "kwargs": { "learning_rate": 0.001, } }, } ================================================ FILE: config/retrieval_stgy.py ================================================ _hidden_size = 200 _code_len = 200 _save_path = 'save/retrieval/model_1' _max_epoch = 10 source_encoder_hparams = { "encoder_minor_type": "UnidirectionalRNNEncoder", "encoder_minor_hparams": { "rnn_cell": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, }, }, "encoder_major_type": "UnidirectionalRNNEncoder", "encoder_major_hparams": { "rnn_cell": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, } } } target_encoder_hparams = { "rnn_cell": { "type": "GRUCell", "kwargs": { "num_units": _hidden_size, }, } } opt_hparams = { "optimizer": { "type": "AdamOptimizer", "kwargs": { "learning_rate": 0.001, } }, } ================================================ FILE: model/kernel.py ================================================ import texar as tx import tensorflow as tf import numpy as np from preprocess.data_utils import kw_tokenize class Predictor(): def __init__(self, config_model, config_data, mode=None): self.config = config_model self.data_config = config_data self.gpu_config = tf.ConfigProto() self.gpu_config.gpu_options.allow_growth = True self.build_model() def build_model(self): self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train']) self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid']) self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test']) self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data) self.vocab = self.train_data.vocab(0) self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs) self.kw_embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs) self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams) self.target_encoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams) self.target_kwencoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_kwencoder_hparams) self.linear_transform = tx.modules.MLPTransformConnector(self.config._code_len // 2) self.linear_matcher = tx.modules.MLPTransformConnector(1) self.linear_kernel = tx.modules.MLPTransformConnector(1) self.kw_list = self.vocab.map_tokens_to_ids(tf.convert_to_tensor(self.data_config._keywords_candi)) self.kw_vocab = tx.data.Vocab(self.data_config._keywords_path) self.keywords_embed = tf.nn.l2_normalize(self.kw_embedder(self.kw_list), axis=1) def forward_kernel(self, 
kw_embed, context_ids): kernel_sigma = self.config._kernel_sigma mu = tf.convert_to_tensor(self.config._kernel_mu) mask = tf.cast(context_ids > 3, dtype=tf.float32) context_embed = self.kw_embedder(context_ids) context_embed = tf.nn.l2_normalize(context_embed, axis=2) similarity_matrix = tf.reduce_sum(kw_embed * context_embed, axis=2) similarity_matrix = tf.tile(tf.expand_dims(similarity_matrix, 2), [1, 1, len(self.config._kernel_mu)]) matching_feature = tf.exp(-(similarity_matrix - mu) ** 2 / (kernel_sigma ** 2)) matching_feature = matching_feature * tf.tile(tf.expand_dims(mask, 2), [1, 1, len(self.config._kernel_mu)]) matching_feature = tf.reduce_sum(matching_feature, axis=1) matching_score = self.linear_kernel(matching_feature) matching_score = tf.squeeze(matching_score, 1) return matching_score def predict_keywords(self, batch): keywords_ids = self.kw_vocab.map_tokens_to_ids(batch['keywords_text']) matching_score = tf.map_fn(lambda kw_embed: self.forward_kernel(kw_embed, batch['context_text_ids']), self.keywords_embed, dtype=tf.float32, parallel_iterations=True) matching_score = tf.transpose(matching_score) matching_score = tf.nn.softmax(matching_score) kw_labels = tf.map_fn(lambda x: tf.sparse_to_dense(x, [self.kw_vocab.size], 1., 0., False), keywords_ids, dtype=tf.float32, parallel_iterations=True)[:, 4:] loss = tf.reduce_sum(-tf.log(matching_score) * kw_labels) / tf.reduce_sum(kw_labels) kw_ans = tf.arg_max(matching_score, -1) acc_label = tf.map_fn(lambda x: tf.gather(x[0], x[1]), (kw_labels, kw_ans), dtype=tf.float32) acc = tf.reduce_mean(acc_label) kws = tf.nn.top_k(matching_score, k=5)[1] kws = tf.reshape(kws,[-1]) kws = tf.map_fn(lambda x: self.kw_list[x], kws, dtype=tf.int64) kws = tf.reshape(kws,[-1, 5]) return loss, acc, kws def train_keywords(self): batch = self.iterator.get_next() loss, acc, _ = self.predict_keywords(batch) op_step = tf.Variable(0, name='op_step') train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.kernel_opt_hparams) max_val_acc, stopping_flag = 0, 0 self.saver = tf.train.Saver() with tf.Session(config=self.gpu_config) as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) for epoch_id in range(self.config._max_epoch): self.iterator.switch_to_train_data(sess) cur_step = 0 cnt_acc = [] while True: try: cur_step += 1 feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN} loss_, acc_ = sess.run([train_op, acc], feed_dict=feed) cnt_acc.append(acc_) if cur_step % 100 == 0: print('batch {}, loss={}, acc1={}'.format(cur_step, loss_, np.mean(cnt_acc[-100:]))) except tf.errors.OutOfRangeError: break self.iterator.switch_to_val_data(sess) cnt_acc = [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL} acc_ = sess.run(acc, feed_dict=feed) cnt_acc.append(acc_) except tf.errors.OutOfRangeError: mean_acc = np.mean(cnt_acc) if mean_acc > max_val_acc: max_val_acc = mean_acc self.saver.save(sess, self.config._kernel_save_path) else: stopping_flag += 1 print('epoch_id {}, valid acc1={}'.format(epoch_id+1, mean_acc)) break if stopping_flag >= self.config._early_stopping: break def test_keywords(self): batch = self.iterator.get_next() loss, acc, kws = self.predict_keywords(batch) saver = tf.train.Saver() with tf.Session(config=self.gpu_config) as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) saver.restore(sess, self.config._kernel_save_path) 
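# NOTE (comment added for clarity; not in the original source): the loop below
# accumulates keyword recall@{1,3,5} for each test batch. sum_kws counts the
# gold keywords (token ids > 3, i.e. non-special tokens) and rec[k] counts how
# many of the top-(k+1) predicted keywords appear in the gold set, so
# rec@n = rec[n-1] / sum_kws.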
self.iterator.switch_to_test_data(sess) cnt_acc, cnt_rec1, cnt_rec3, cnt_rec5 = [], [], [], [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} acc_, kw_ans, kw_labels = sess.run([acc, kws, batch['keywords_text_ids']], feed_dict=feed) cnt_acc.append(acc_) rec = [0,0,0,0,0] sum_kws = 0 for i in range(len(kw_ans)): sum_kws += sum(kw_labels[i] > 3) for j in range(5): if kw_ans[i][j] in kw_labels[i]: for k in range(j, 5): rec[k] += 1 cnt_rec1.append(rec[0]/sum_kws) cnt_rec3.append(rec[2]/sum_kws) cnt_rec5.append(rec[4]/sum_kws) except tf.errors.OutOfRangeError: print('test_kw acc@1={:.4f}, rec@1={:.4f}, rec@3={:.4f}, rec@5={:.4f}'.format( np.mean(cnt_acc), np.mean(cnt_rec1), np.mean(cnt_rec3), np.mean(cnt_rec5))) break def forward(self, batch): matching_score = tf.map_fn(lambda kw_embed: self.forward_kernel(kw_embed, batch['context_text_ids']), self.keywords_embed, dtype=tf.float32, parallel_iterations=True) matching_score = tf.transpose(matching_score) kw_weight, predict_kw = tf.nn.top_k(matching_score, k=3) predict_kw = tf.reshape(predict_kw,[-1]) predict_kw = tf.map_fn(lambda x: self.kw_list[x], predict_kw, dtype=tf.int64) predict_kw = tf.reshape(predict_kw,[-1,3]) embed_code = self.embedder(predict_kw) embed_code = tf.reduce_sum(embed_code, axis=1) embed_code = self.linear_transform(embed_code) source_embed = self.embedder(batch['source_text_ids']) target_embed = self.embedder(batch['target_text_ids']) # bs * 20 * 32 * 200 target_embed = tf.reshape(target_embed,[-1, self.data_config._max_seq_len+2, self.embedder.dim]) # (bs * 20) * 32 * 200 target_length = tf.reshape(batch['target_length'],[-1]) # (bs * 20) * 32 * 200 source_code = self.source_encoder( source_embed, sequence_length_minor=batch['source_length'], sequence_length_major=batch['source_utterance_cnt'])[1] target_code = self.target_encoder( target_embed, sequence_length=target_length)[1] target_kwcode = self.target_kwencoder( target_embed, sequence_length=target_length)[1] target_code = tf.concat([target_code[0], target_code[1], target_kwcode[0], target_kwcode[1]], -1) target_code = tf.reshape(target_code, [-1,20,self.config._code_len]) source_code = tf.concat([source_code,embed_code], -1) source_code = tf.expand_dims(source_code, 1) source_code = tf.tile(source_code, [1,20,1]) feature_code = target_code * source_code feature_code = tf.reshape(feature_code,[-1,self.config._code_len]) logits = self.linear_matcher(feature_code) logits = tf.reshape(logits,[-1,20]) labels = tf.one_hot(batch['label'], 20) loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)) ans = tf.arg_max(logits, -1) acc = tx.evals.accuracy(batch['label'], ans) rank = tf.nn.top_k(logits, k=20)[1] return loss, acc, rank def train(self): batch = self.iterator.get_next() loss_t, acc_t, _ = self.predict_keywords(batch) kw_saver = tf.train.Saver() loss, acc, _ = self.forward(batch) retrieval_step = tf.Variable(0, name='retrieval_step') train_op = tx.core.get_train_op(loss, global_step=retrieval_step, hparams=self.config.opt_hparams) max_val_acc, stopping_flag = 0, 0 with tf.Session(config=self.gpu_config) as sess: sess.run(tf.tables_initializer()) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) kw_saver.restore(sess, self.config._kernel_save_path) saver = tf.train.Saver() for epoch_id in range(self.config._max_epoch): self.iterator.switch_to_train_data(sess) cur_step = 0 cnt_acc, cnt_kwacc = [],[] while True: try: cur_step += 1 feed = {tx.global_mode(): 
tf.estimator.ModeKeys.TRAIN} loss, acc_, acc_kw = sess.run([train_op, acc, acc_t], feed_dict=feed) cnt_acc.append(acc_) cnt_kwacc.append(acc_kw) if cur_step % 200 == 0: print('batch {}, loss={}, acc1={}, kw_acc1={}'.format(cur_step, loss, np.mean(cnt_acc[-200:]) ,np.mean(cnt_kwacc[-200:]))) except tf.errors.OutOfRangeError: break self.iterator.switch_to_val_data(sess) cnt_acc, cnt_kwacc = [],[] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL} acc_, acc_kw = sess.run([acc, acc_t], feed_dict=feed) cnt_acc.append(acc_) cnt_kwacc.append(acc_kw) except tf.errors.OutOfRangeError: mean_acc = np.mean(cnt_acc) print('valid acc1={}, kw_acc1={}'.format(mean_acc, np.mean(cnt_kwacc))) if mean_acc > max_val_acc: max_val_acc = mean_acc saver.save(sess, self.config._save_path) else: stopping_flag += 1 break if stopping_flag >= self.config._early_stopping: break def test(self): batch = self.iterator.get_next() loss, acc, rank = self.forward(batch) with tf.Session(config=self.gpu_config) as sess: sess.run(tf.tables_initializer()) self.saver = tf.train.Saver() self.saver.restore(sess, self.config._save_path) self.iterator.switch_to_test_data(sess) rank_cnt = [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} ranks, labels = sess.run([rank, batch['label']], feed_dict=feed) for i in range(len(ranks)): rank_cnt.append(np.where(ranks[i]==labels[i])[0][0]) except tf.errors.OutOfRangeError: rec = [0,0,0,0,0] MRR = 0 for rank in rank_cnt: for i in range(5): rec[i] += (rank <= i) MRR += 1 / (rank+1) print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format( rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt))) break def retrieve_init(self, sess): data_batch = self.iterator.get_next() loss, acc, _ = self.forward(data_batch) self.corpus = self.data_config._corpus self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams) corpus_iterator = tx.data.DataIterator(self.corpus_data) batch = corpus_iterator.get_next() corpus_embed = self.embedder(batch['corpus_text_ids']) utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1] utter_kwcode = self.target_kwencoder(corpus_embed, sequence_length=batch['corpus_length'])[1] utter_code = tf.concat([utter_code[0], utter_code[1], utter_kwcode[0], utter_kwcode[1]], -1) self.corpus_code = np.zeros([0, self.config._code_len]) corpus_iterator.switch_to_dataset(sess) sess.run(tf.tables_initializer()) saver = tf.train.Saver() saver.restore(sess, self.config._save_path) feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} while True: try: utter_code_ = sess.run(utter_code, feed_dict=feed) self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0) except tf.errors.OutOfRangeError: break self.kw_embedding = sess.run(self.keywords_embed) # predict keyword self.context_input = tf.placeholder(dtype=object) context_ids = tf.expand_dims(self.vocab.map_tokens_to_ids(self.context_input), 0) matching_score = tf.map_fn(lambda kw_embed: self.forward_kernel(kw_embed, context_ids), self.keywords_embed, dtype=tf.float32, parallel_iterations=True) self.candi_output = tf.nn.top_k(tf.squeeze(matching_score, 1), self.data_config._keywords_num)[1] # retrieve self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9)) self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1)) self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2)) self.kw_input = tf.placeholder(dtype=tf.int32) 
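# NOTE (comment added for clarity; not in the original source): the placeholders
# above feed the online retrieval graph built below. The padded dialogue history
# is encoded by the hierarchical source encoder, the chosen next keyword
# (kw_input) is embedded and projected by linear_transform, and their
# concatenation is matched against every precomputed corpus utterance code;
# ans_output returns the indices of the highest-scoring candidate replies, which
# retrieve() then scans for a reply whose keyword moves closer to the target.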
history_ids = self.vocab.map_tokens_to_ids(self.history_input) history_embed = self.embedder(history_ids) history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0), sequence_length_minor=self.minor_length_input, sequence_length_major=self.major_length_input)[1] self.next_kw_ids = self.kw_list[self.kw_input] embed_code = tf.expand_dims(self.embedder(self.next_kw_ids), 0) embed_code = self.linear_transform(embed_code) history_code = tf.concat([history_code, embed_code], 1) select_corpus = tf.cast(self.corpus_code, dtype=tf.float32) feature_code = self.linear_matcher(select_corpus * history_code) self.ans_output = tf.nn.top_k(tf.squeeze(feature_code,1), k=self.data_config._retrieval_candidates)[1] def retrieve(self, history_all, sess): history, seq_len, turns, context, context_len = history_all kw_candi = sess.run(self.candi_output, feed_dict={self.context_input: context[:context_len]}) for kw in kw_candi: tmp_score = sum(self.kw_embedding[kw] * self.kw_embedding[self.data_config._keywords_dict[self.target]]) if tmp_score > self.score: self.score = tmp_score self.next_kw = self.data_config._keywords_candi[kw] break ans = sess.run(self.ans_output, feed_dict={self.history_input: history, self.minor_length_input: [seq_len], self.major_length_input: [turns], self.kw_input: self.data_config._keywords_dict[self.next_kw]}) flag = 0 reply = self.corpus[ans[0]] for i in ans: if i in self.reply_list: # avoid repeat continue for wd in kw_tokenize(self.corpus[i]): if wd in self.data_config._keywords_candi: tmp_score = sum(self.kw_embedding[self.data_config._keywords_dict[wd]] * self.kw_embedding[self.data_config._keywords_dict[self.target]]) if tmp_score > self.score: reply = self.corpus[i] self.score = tmp_score self.next_kw = wd flag = 1 break if flag == 0: continue break return reply ================================================ FILE: model/matrix.py ================================================ import texar as tx import tensorflow as tf import numpy as np import pickle class Predictor(): def __init__(self, config_model, config_data, mode=None): self.config = config_model self.data_config = config_data self.gpu_config = tf.ConfigProto() self.gpu_config.gpu_options.allow_growth = True self.build_model(mode) def build_model(self, mode): self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train']) self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid']) self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test']) self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data) self.vocab = self.train_data.vocab(0) self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams) self.target_encoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams) self.target_kwencoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_kwencoder_hparams) self.linear_transform = tx.modules.MLPTransformConnector(self.config._code_len // 2) self.linear_matcher = tx.modules.MLPTransformConnector(1) self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs) self.kw_list = self.vocab.map_tokens_to_ids(tf.convert_to_tensor(self.data_config._keywords_candi)) self.kw_vocab = tx.data.Vocab(self.data_config._keywords_path) if mode == 'train_kw': self.pmi_matrix = np.zeros([self.config._vocab_size+4, self.data_config._keywords_num]) else: with open(self.config._matrix_save_path, 
'rb') as f: matrix = pickle.load(f) self.pmi_matrix = tf.convert_to_tensor(matrix,dtype=tf.float32) def forward_matrix(self, context_ids): matching_score = tf.gather(self.pmi_matrix, context_ids) return tf.reduce_sum(tf.log(matching_score), axis=0) def predict_keywords(self, batch): keywords_ids = self.kw_vocab.map_tokens_to_ids(batch['keywords_text']) matching_score = tf.map_fn(lambda x: self.forward_matrix(x), batch['context_text_ids'], dtype=tf.float32, parallel_iterations=True) kw_labels = tf.map_fn(lambda x: tf.sparse_to_dense(x, [self.kw_vocab.size], 1., 0., False), keywords_ids, dtype=tf.float32, parallel_iterations=True)[:, 4:] kw_ans = tf.arg_max(matching_score, -1) acc_label = tf.map_fn(lambda x: tf.gather(x[0], x[1]), (kw_labels, kw_ans), dtype=tf.float32) acc = tf.reduce_mean(acc_label) kws = tf.nn.top_k(matching_score, k=5)[1] kws = tf.reshape(kws,[-1]) kws = tf.map_fn(lambda x: self.kw_list[x], kws, dtype=tf.int64) kws = tf.reshape(kws,[-1, 5]) return acc, kws def train_keywords(self): batch = self.iterator.get_next() acc, _ = self.predict_keywords(batch) with tf.Session(config=self.gpu_config) as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) self.iterator.switch_to_train_data(sess) batchid = 0 while True: try: batchid += 1 feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN} source_keywords, target_keywords = sess.run([batch['context_text_ids'], batch['keywords_text_ids']], feed_dict=feed) for i in range(len(source_keywords)): for skw_id in source_keywords[i]: if skw_id == 0: break for tkw_id in target_keywords[i]: if skw_id >= 3 and tkw_id >= 3: tkw = self.config._vocab[tkw_id-4] if tkw in self.data_config._keywords_candi: tkw_id = self.data_config._keywords_dict[tkw] self.pmi_matrix[skw_id][tkw_id] += 1 except tf.errors.OutOfRangeError: break self.pmi_matrix += 0.5 self.pmi_matrix = self.pmi_matrix / (np.sum(self.pmi_matrix, axis=0) + 1) with open(self.config._matrix_save_path,'wb') as f: pickle.dump(self.pmi_matrix, f) def test_keywords(self): batch = self.iterator.get_next() acc, kws = self.predict_keywords(batch) with tf.Session(config=self.gpu_config) as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) self.iterator.switch_to_test_data(sess) cnt_acc, cnt_rec1, cnt_rec3, cnt_rec5 = [], [], [], [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} acc_, kw_ans, kw_labels = sess.run([acc, kws, batch['keywords_text_ids']], feed_dict=feed) cnt_acc.append(acc_) rec = [0,0,0,0,0] sum_kws = 0 for i in range(len(kw_ans)): sum_kws += sum(kw_labels[i] > 3) for j in range(5): if kw_ans[i][j] in kw_labels[i]: for k in range(j, 5): rec[k] += 1 cnt_rec1.append(rec[0]/sum_kws) cnt_rec3.append(rec[2]/sum_kws) cnt_rec5.append(rec[4]/sum_kws) except tf.errors.OutOfRangeError: print('test_kw acc@1={:.4f}, rec@1={:.4f}, rec@3={:.4f}, rec@5={:.4f}'.format( np.mean(cnt_acc), np.mean(cnt_rec1), np.mean(cnt_rec3), np.mean(cnt_rec5))) break def forward(self, batch): matching_score = tf.map_fn(lambda x: self.forward_matrix(x), batch['context_text_ids'], dtype=tf.float32, parallel_iterations=True) kw_weight, predict_kw = tf.nn.top_k(matching_score, k=3) predict_kw = tf.reshape(predict_kw, [-1]) predict_kw = tf.map_fn(lambda x: self.kw_list[x], predict_kw, dtype=tf.int64) predict_kw = tf.reshape(predict_kw, [-1, 3]) embed_code = self.embedder(predict_kw) embed_code = tf.reduce_sum(embed_code, axis=1) embed_code = 
self.linear_transform(embed_code) source_embed = self.embedder(batch['source_text_ids']) target_embed = self.embedder(batch['target_text_ids']) target_embed = tf.reshape(target_embed, [-1, self.data_config._max_seq_len + 2, self.embedder.dim]) target_length = tf.reshape(batch['target_length'], [-1]) source_code = self.source_encoder( source_embed, sequence_length_minor=batch['source_length'], sequence_length_major=batch['source_utterance_cnt'])[1] target_code = self.target_encoder( target_embed, sequence_length=target_length)[1] target_kwcode = self.target_kwencoder( target_embed, sequence_length=target_length)[1] target_code = tf.concat([target_code[0], target_code[1], target_kwcode[0], target_kwcode[1]], -1) target_code = tf.reshape(target_code, [-1, 20, self.config._code_len]) source_code = tf.concat([source_code, embed_code], -1) source_code = tf.expand_dims(source_code, 1) source_code = tf.tile(source_code, [1, 20, 1]) feature_code = target_code * source_code feature_code = tf.reshape(feature_code, [-1, self.config._code_len]) logits = self.linear_matcher(feature_code) logits = tf.reshape(logits, [-1, 20]) labels = tf.one_hot(batch['label'], 20) loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)) ans = tf.arg_max(logits, -1) acc = tx.evals.accuracy(batch['label'], ans) rank = tf.nn.top_k(logits, k=20)[1] return loss, acc, rank def train(self): batch = self.iterator.get_next() loss, acc, _ = self.forward(batch) op_step = tf.Variable(0, name='retrieval_step') train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.opt_hparams) max_val_acc = 0. with tf.Session(config=self.gpu_config) as sess: sess.run(tf.tables_initializer()) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) saver = tf.train.Saver() for epoch_id in range(self.config._max_epoch): self.iterator.switch_to_train_data(sess) cur_step = 0 cnt_acc = [] while True: try: cur_step += 1 feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN} loss, acc_ = sess.run([train_op, acc], feed_dict=feed) cnt_acc.append(acc_) if cur_step % 200 == 0: print('batch {}, loss={}, acc1={}'.format(cur_step, loss, np.mean(cnt_acc[-200:]))) except tf.errors.OutOfRangeError: break self.iterator.switch_to_val_data(sess) cnt_acc= [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL} acc_ = sess.run(acc, feed_dict=feed) cnt_acc.append(acc_) except tf.errors.OutOfRangeError: mean_acc = np.mean(cnt_acc) print('valid acc1={}'.format(mean_acc)) if mean_acc > max_val_acc: max_val_acc = mean_acc saver.save(sess, self.config._save_path) break def test(self): batch = self.iterator.get_next() loss, acc, rank = self.forward(batch) with tf.Session(config=self.gpu_config) as sess: sess.run(tf.tables_initializer()) self.saver = tf.train.Saver() self.saver.restore(sess, self.config._save_path) self.iterator.switch_to_test_data(sess) rank_cnt = [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} ranks, labels = sess.run([rank, batch['label']], feed_dict=feed) for i in range(len(ranks)): rank_cnt.append(np.where(ranks[i]==labels[i])[0][0]) except tf.errors.OutOfRangeError: rec = [0,0,0,0,0] MRR = 0 for rank in rank_cnt: for i in range(5): rec[i] += (rank <= i) MRR += 1 / (rank+1) print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format( rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt))) break def retrieve_init(self, sess): data_batch = self.iterator.get_next() loss, acc, _ 
= self.forward(data_batch) self.corpus = self.data_config._corpus self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams) corpus_iterator = tx.data.DataIterator(self.corpus_data) batch = corpus_iterator.get_next() corpus_embed = self.embedder(batch['corpus_text_ids']) utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1] utter_kwcode = self.target_kwencoder(corpus_embed, sequence_length=batch['corpus_length'])[1] utter_code = tf.concat([utter_code[0], utter_code[1], utter_kwcode[0], utter_kwcode[1]], -1) self.corpus_code = np.zeros([0, self.config._code_len]) corpus_iterator.switch_to_dataset(sess) sess.run(tf.tables_initializer()) saver = tf.train.Saver() saver.restore(sess, self.config._save_path) feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} while True: try: utter_code_ = sess.run(utter_code, feed_dict=feed) self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0) except tf.errors.OutOfRangeError: break self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9)) self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1)) self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2)) self.keywords_embed = tf.nn.l2_normalize(self.embedder(self.kw_list), axis=1) self.kw_embedding = sess.run(self.keywords_embed) # predict keyword self.context_input = tf.placeholder(dtype=object) context_ids = self.vocab.map_tokens_to_ids(self.context_input) matching_score = self.forward_matrix(context_ids) self.candi_output =tf.nn.top_k(matching_score, self.data_config._keywords_num)[1] # retrieve self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9)) self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1)) self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2)) self.kw_input = tf.placeholder(dtype=tf.int32) history_ids = self.vocab.map_tokens_to_ids(self.history_input) history_embed = self.embedder(history_ids) history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0), sequence_length_minor=self.minor_length_input, sequence_length_major=self.major_length_input)[1] self.next_kw_ids = self.kw_list[self.kw_input] embed_code = tf.expand_dims(self.embedder(self.next_kw_ids), 0) embed_code = self.linear_transform(embed_code) history_code = tf.concat([history_code, embed_code], 1) select_corpus = tf.cast(self.corpus_code, dtype=tf.float32) feature_code = self.linear_matcher(select_corpus * history_code) self.ans_output = tf.nn.top_k(tf.squeeze(feature_code,1), k=self.data_config._retrieval_candidates)[1] def retrieve(self, history_all, sess): history, seq_len, turns, context, context_len = history_all kw_candi = sess.run(self.candi_output, feed_dict={self.context_input: context[:context_len]}) for kw in kw_candi: tmp_score = sum(self.kw_embedding[kw] * self.kw_embedding[self.data_config._keywords_dict[self.target]]) if tmp_score > self.score: self.score = tmp_score self.next_kw = self.data_config._keywords_candi[kw] break ans = sess.run(self.ans_output, feed_dict={self.history_input: history, self.minor_length_input: [seq_len], self.major_length_input: [turns], self.kw_input: self.data_config._keywords_dict[self.next_kw]}) for i in range(self.data_config._max_turns + 1): if ans[i] not in self.reply_list: self.reply_list.append(ans[i]) reply = self.corpus[ans[i]] break return reply ================================================ FILE: model/neural.py ================================================ 
import texar as tx import tensorflow as tf import numpy as np from preprocess.data_utils import kw_tokenize class Predictor(): def __init__(self, config_model, config_data, mode=None): self.config = config_model self.data_config = config_data self.gpu_config = tf.ConfigProto() self.gpu_config.gpu_options.allow_growth = True self.build_model() def build_model(self): self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train']) self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid']) self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test']) self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data) self.vocab = self.train_data.vocab(0) self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams) self.target_encoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams) self.target_kwencoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_kwencoder_hparams) self.linear_transform = tx.modules.MLPTransformConnector(self.config._code_len // 2) self.linear_matcher = tx.modules.MLPTransformConnector(1) self.context_encoder = tx.modules.UnidirectionalRNNEncoder(hparams=self.config.context_encoder_hparams) self.predict_layer = tx.modules.MLPTransformConnector(self.data_config._keywords_num) self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs) self.kw_list = self.vocab.map_tokens_to_ids(tf.convert_to_tensor(self.data_config._keywords_candi)) self.kw_vocab = tx.data.Vocab(self.data_config._keywords_path) def forward_neural(self, context_ids, context_length): context_embed = self.embedder(context_ids) context_code = self.context_encoder(context_embed, sequence_length=context_length)[1] keyword_score = self.predict_layer(context_code) return keyword_score def predict_keywords(self, batch): matching_score = self.forward_neural(batch['context_text_ids'], batch['context_length']) keywords_ids = self.kw_vocab.map_tokens_to_ids(batch['keywords_text']) kw_labels = tf.map_fn(lambda x: tf.sparse_to_dense(x, [self.kw_vocab.size], 1., 0., False), keywords_ids, dtype=tf.float32, parallel_iterations=True)[:, 4:] loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=kw_labels, logits=matching_score) loss = tf.reduce_mean(loss) kw_ans = tf.arg_max(matching_score, -1) acc_label = tf.map_fn(lambda x: tf.gather(x[0], x[1]), (kw_labels, kw_ans), dtype=tf.float32) acc = tf.reduce_mean(acc_label) kws = tf.nn.top_k(matching_score, k=5)[1] kws = tf.reshape(kws,[-1]) kws = tf.map_fn(lambda x: self.kw_list[x], kws, dtype=tf.int64) kws = tf.reshape(kws,[-1, 5]) return loss, acc, kws def train_keywords(self): batch = self.iterator.get_next() loss, acc, _ = self.predict_keywords(batch) op_step = tf.Variable(0, name='op_step') train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.neural_opt_hparams) max_val_acc = 0. 
self.saver = tf.train.Saver() with tf.Session(config=self.gpu_config) as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) for epoch_id in range(self.config._max_epoch): self.iterator.switch_to_train_data(sess) cur_step = 0 cnt_acc = [] while True: try: cur_step += 1 feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN} loss_, acc_ = sess.run([train_op, acc], feed_dict=feed) cnt_acc.append(acc_) if cur_step % 200 == 0: print('batch {}, loss={}, acc1={}'.format(cur_step, loss_, np.mean(cnt_acc[-200:]))) except tf.errors.OutOfRangeError: break self.iterator.switch_to_val_data(sess) cnt_acc = [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL} acc_ = sess.run(acc, feed_dict=feed) cnt_acc.append(acc_) except tf.errors.OutOfRangeError: mean_acc = np.mean(cnt_acc) if mean_acc > max_val_acc: max_val_acc = mean_acc self.saver.save(sess, self.config._neural_save_path) print('epoch_id {}, valid acc1={}'.format(epoch_id+1, mean_acc)) break def test_keywords(self): batch = self.iterator.get_next() loss, acc, kws = self.predict_keywords(batch) saver = tf.train.Saver() with tf.Session(config=self.gpu_config) as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) saver.restore(sess, self.config._neural_save_path) self.iterator.switch_to_test_data(sess) cnt_acc, cnt_rec1, cnt_rec3, cnt_rec5 = [], [], [], [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} acc_, kw_ans, kw_labels = sess.run([acc, kws, batch['keywords_text_ids']], feed_dict=feed) cnt_acc.append(acc_) rec = [0,0,0,0,0] sum_kws = 0 for i in range(len(kw_ans)): sum_kws += sum(kw_labels[i] > 3) for j in range(5): if kw_ans[i][j] in kw_labels[i]: for k in range(j, 5): rec[k] += 1 cnt_rec1.append(rec[0]/sum_kws) cnt_rec3.append(rec[2]/sum_kws) cnt_rec5.append(rec[4]/sum_kws) except tf.errors.OutOfRangeError: print('test_kw acc@1={:.4f}, rec@1={:.4f}, rec@3={:.4f}, rec@5={:.4f}'.format( np.mean(cnt_acc), np.mean(cnt_rec1), np.mean(cnt_rec3), np.mean(cnt_rec5))) break def forward(self, batch): matching_score = self.forward_neural(batch['context_text_ids'], batch['context_length']) kw_weight, predict_kw = tf.nn.top_k(matching_score, k=3) predict_kw = tf.reshape(predict_kw, [-1]) predict_kw = tf.map_fn(lambda x: self.kw_list[x], predict_kw, dtype=tf.int64) predict_kw = tf.reshape(predict_kw, [-1, 3]) embed_code = self.embedder(predict_kw) embed_code = tf.reduce_sum(embed_code, axis=1) embed_code = self.linear_transform(embed_code) source_embed = self.embedder(batch['source_text_ids']) target_embed = self.embedder(batch['target_text_ids']) target_embed = tf.reshape(target_embed, [-1, self.data_config._max_seq_len + 2, self.embedder.dim]) target_length = tf.reshape(batch['target_length'], [-1]) source_code = self.source_encoder( source_embed, sequence_length_minor=batch['source_length'], sequence_length_major=batch['source_utterance_cnt'])[1] # target_code = self.target_encoder( target_embed, sequence_length=target_length)[1] target_kwcode = self.target_kwencoder( target_embed, sequence_length=target_length)[1] target_code = tf.concat([target_code[0], target_code[1], target_kwcode[0], target_kwcode[1]], -1) target_code = tf.reshape(target_code, [-1, 20, self.config._code_len]) source_code = tf.concat([source_code, embed_code], -1) source_code = tf.expand_dims(source_code, 1) source_code = tf.tile(source_code, [1, 20, 1]) feature_code = target_code * 
source_code feature_code = tf.reshape(feature_code, [-1, self.config._code_len]) logits = self.linear_matcher(feature_code) logits = tf.reshape(logits, [-1, 20]) labels = tf.one_hot(batch['label'], 20) loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)) ans = tf.arg_max(logits, -1) acc = tx.evals.accuracy(batch['label'], ans) rank = tf.nn.top_k(logits, k=20)[1] return loss, acc, rank def train(self): batch = self.iterator.get_next() kw_loss, kw_acc, _ = self.predict_keywords(batch) kw_saver = tf.train.Saver() loss, acc, _ = self.forward(batch) op_step = tf.Variable(0, name='retrieval_step') train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.opt_hparams) max_val_acc = 0. with tf.Session(config=self.gpu_config) as sess: sess.run(tf.tables_initializer()) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) kw_saver.restore(sess, self.config._neural_save_path) saver = tf.train.Saver() for epoch_id in range(self.config._max_epoch): self.iterator.switch_to_train_data(sess) cur_step = 0 cnt_acc = [] while True: try: cur_step += 1 feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN} loss, acc_ = sess.run([train_op, acc], feed_dict=feed) cnt_acc.append(acc_) if cur_step % 200 == 0: print('batch {}, loss={}, acc1={}'.format(cur_step, loss, np.mean(cnt_acc[-200:]))) except tf.errors.OutOfRangeError: break self.iterator.switch_to_val_data(sess) cnt_acc, cnt_kwacc = [], [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL} acc_, kw_acc_ = sess.run([acc, kw_acc], feed_dict=feed) cnt_acc.append(acc_) cnt_kwacc.append(kw_acc_) except tf.errors.OutOfRangeError: mean_acc = np.mean(cnt_acc) print('valid acc1={}, kw_acc1={}'.format(mean_acc, np.mean(cnt_kwacc))) if mean_acc > max_val_acc: max_val_acc = mean_acc saver.save(sess, self.config._save_path) break def test(self): batch = self.iterator.get_next() loss, acc, rank = self.forward(batch) with tf.Session(config=self.gpu_config) as sess: sess.run(tf.tables_initializer()) self.saver = tf.train.Saver() self.saver.restore(sess, self.config._save_path) self.iterator.switch_to_test_data(sess) rank_cnt = [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} ranks, labels = sess.run([rank, batch['label']], feed_dict=feed) for i in range(len(ranks)): rank_cnt.append(np.where(ranks[i]==labels[i])[0][0]) except tf.errors.OutOfRangeError: rec = [0,0,0,0,0] MRR = 0 for rank in rank_cnt: for i in range(5): rec[i] += (rank <= i) MRR += 1 / (rank+1) print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format( rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt))) break def retrieve_init(self, sess): data_batch = self.iterator.get_next() loss, acc, _ = self.forward(data_batch) self.corpus = self.data_config._corpus self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams) corpus_iterator = tx.data.DataIterator(self.corpus_data) batch = corpus_iterator.get_next() corpus_embed = self.embedder(batch['corpus_text_ids']) utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1] utter_kwcode = self.target_kwencoder(corpus_embed, sequence_length=batch['corpus_length'])[1] utter_code = tf.concat([utter_code[0], utter_code[1], utter_kwcode[0], utter_kwcode[1]], -1) self.corpus_code = np.zeros([0, self.config._code_len]) corpus_iterator.switch_to_dataset(sess) sess.run(tf.tables_initializer()) saver = tf.train.Saver() saver.restore(sess, 
self.config._save_path) feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} while True: try: utter_code_ = sess.run(utter_code, feed_dict=feed) self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0) except tf.errors.OutOfRangeError: break self.keywords_embed = tf.nn.l2_normalize(self.embedder(self.kw_list), axis=1) self.kw_embedding = sess.run(self.keywords_embed) # predict keyword self.context_input = tf.placeholder(dtype=object, shape=(20)) self.context_length_input = tf.placeholder(dtype=tf.int32, shape=(1)) context_ids = tf.expand_dims(self.vocab.map_tokens_to_ids(self.context_input), 0) context_embed = self.embedder(context_ids) context_code = self.context_encoder(context_embed, sequence_length=self.context_length_input)[1] matching_score = self.predict_layer(context_code) self.candi_output =tf.nn.top_k(tf.squeeze(matching_score, 0), self.data_config._keywords_num)[1] # retrieve self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9)) self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1)) self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2)) self.kw_input = tf.placeholder(dtype=tf.int32) history_ids = self.vocab.map_tokens_to_ids(self.history_input) history_embed = self.embedder(history_ids) history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0), sequence_length_minor=self.minor_length_input, sequence_length_major=self.major_length_input)[1] self.next_kw_ids = self.kw_list[self.kw_input] embed_code = tf.expand_dims(self.embedder(self.next_kw_ids), 0) embed_code = self.linear_transform(embed_code) history_code = tf.concat([history_code, embed_code], 1) select_corpus = tf.cast(self.corpus_code, dtype=tf.float32) feature_code = self.linear_matcher(select_corpus * history_code) self.ans_output = tf.nn.top_k(tf.squeeze(feature_code,1), k=self.data_config._retrieval_candidates)[1] def retrieve(self, history_all, sess): history, seq_len, turns, context, context_len = history_all kw_candi = sess.run(self.candi_output, feed_dict={self.context_input: context, self.context_length_input: [context_len]}) for kw in kw_candi: tmp_score = sum(self.kw_embedding[kw] * self.kw_embedding[self.data_config._keywords_dict[self.target]]) if tmp_score > self.score: self.score = tmp_score self.next_kw = self.data_config._keywords_candi[kw] break ans = sess.run(self.ans_output, feed_dict={self.history_input: history, self.minor_length_input: [seq_len], self.major_length_input: [turns], self.kw_input: self.data_config._keywords_dict[self.next_kw]}) flag = 0 reply = self.corpus[ans[0]] for i in ans: if i in self.reply_list: # avoid repeat continue for wd in kw_tokenize(self.corpus[i]): if wd in self.data_config._keywords_candi: tmp_score = sum(self.kw_embedding[self.data_config._keywords_dict[wd]] * self.kw_embedding[self.data_config._keywords_dict[self.target]]) if tmp_score > self.score: reply = self.corpus[i] self.score = tmp_score self.next_kw = wd flag = 1 break if flag == 0: continue break return reply ================================================ FILE: model/retrieval.py ================================================ import texar as tx import tensorflow as tf import numpy as np class Predictor(): def __init__(self, config_model, config_data, mode=None): self.config = config_model self.data_config = config_data self.build_model() self.gpu_config = tf.ConfigProto() self.gpu_config.gpu_options.allow_growth = True def build_model(self): self.train_data = 
tx.data.MultiAlignedData(self.data_config.data_hparams['train']) self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid']) self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test']) self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data) self.vocab = self.train_data.vocab(0) self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs) self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams) self.target_encoder = tx.modules.UnidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams) self.linear_matcher = tx.modules.MLPTransformConnector(1) def forward(self, batch): source_embed = self.embedder(batch['source_text_ids']) target_embed = self.embedder(batch['target_text_ids']) target_embed = tf.reshape(target_embed, [-1, self.data_config._max_seq_len + 2, self.embedder.dim]) source_code = self.source_encoder(source_embed, sequence_length_minor=batch['source_length'], sequence_length_major=batch['source_utterance_cnt'])[1] target_length = tf.reshape(batch['target_length'], [-1]) target_code = self.target_encoder(target_embed, sequence_length=target_length)[1] target_code = tf.reshape(target_code, [-1, 20, self.config._code_len]) source_code = tf.expand_dims(source_code, 1) source_code = tf.tile(source_code, [1, 20, 1]) feature_code = target_code * source_code feature_code = tf.reshape(feature_code, [-1, self.config._code_len]) logits = self.linear_matcher(feature_code) logits = tf.reshape(logits, [-1, 20]) labels = tf.one_hot(batch['label'], 20) loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)) ans = tf.arg_max(logits, -1) acc = tx.evals.accuracy(batch['label'], ans) rank = tf.nn.top_k(logits, k=20)[1] return loss, acc, rank def train(self): batch = self.iterator.get_next() loss, acc, _ = self.forward(batch) op_step = tf.Variable(0, name='op_step') train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.opt_hparams) max_val_acc = 0. 
self.saver = tf.train.Saver() with tf.Session(config=self.gpu_config) as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) for epoch_id in range(self.config._max_epoch): self.iterator.switch_to_train_data(sess) cur_step = 0 cnt_acc = [] while True: try: cur_step += 1 feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN} loss, acc_ = sess.run([train_op, acc], feed_dict=feed) cnt_acc.append(acc_) if cur_step % 200 == 0: print('batch {}, loss={}, acc1={}'.format(cur_step, loss, np.mean(cnt_acc[-200:]))) except tf.errors.OutOfRangeError: break op_step = op_step + 1 self.iterator.switch_to_val_data(sess) cnt_acc = [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL} acc_ = sess.run([acc], feed_dict=feed) cnt_acc.append(acc_) except tf.errors.OutOfRangeError: mean_acc = np.mean(cnt_acc) print('valid acc1={}'.format(mean_acc)) if mean_acc > max_val_acc: max_val_acc = mean_acc self.saver.save(sess, self.config._save_path) break def test(self): batch = self.iterator.get_next() loss, acc, rank = self.forward(batch) with tf.Session(config=self.gpu_config) as sess: sess.run(tf.tables_initializer()) self.saver = tf.train.Saver() self.saver.restore(sess, self.config._save_path) self.iterator.switch_to_test_data(sess) rank_cnt = [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} ranks, labels = sess.run([rank, batch['label']], feed_dict=feed) for i in range(len(ranks)): rank_cnt.append(np.where(ranks[i]==labels[i])[0][0]) except tf.errors.OutOfRangeError: rec = [0,0,0,0,0] MRR = 0 for rank in rank_cnt: for i in range(5): rec[i] += (rank <= i) MRR += 1 / (rank+1) print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format( rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt))) break def retrieve_init(self, sess): data_batch = self.iterator.get_next() loss, acc, _ = self.forward(data_batch) self.corpus = self.data_config._corpus self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams) corpus_iterator = tx.data.DataIterator(self.corpus_data) batch = corpus_iterator.get_next() corpus_embed = self.embedder(batch['corpus_text_ids']) utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1] self.corpus_code = np.zeros([0, self.config._code_len]) corpus_iterator.switch_to_dataset(sess) sess.run(tf.tables_initializer()) saver = tf.train.Saver() saver.restore(sess, self.config._save_path) feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} while True: try: utter_code_ = sess.run(utter_code, feed_dict=feed) self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0) except tf.errors.OutOfRangeError: break self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9)) self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1)) self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2)) history_ids = self.vocab.map_tokens_to_ids(self.history_input) history_embed = self.embedder(history_ids) history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0), sequence_length_minor=self.minor_length_input, sequence_length_major=self.major_length_input)[1] select_corpus = tf.cast(self.corpus_code, dtype=tf.float32) feature_code = self.linear_matcher(select_corpus * history_code) self.ans_output = tf.nn.top_k(tf.squeeze(feature_code, 1), k=self.data_config._retrieval_candidates)[1] def retrieve(self, source, sess): history, 
seq_len, turns, context, context_len = source ans = sess.run(self.ans_output, feed_dict={self.history_input: history, self.minor_length_input: [seq_len], self.major_length_input: [turns]}) for i in range(self.data_config._max_turns + 1): if ans[i] not in self.reply_list: # avoid repeat self.reply_list.append(ans[i]) reply = self.corpus[ans[i]] break return reply ================================================ FILE: model/retrieval_stgy.py ================================================ import texar as tx import tensorflow as tf import numpy as np from preprocess.data_utils import kw_tokenize class Predictor(): def __init__(self, config_model, config_data, mode=None): self.config = config_model self.data_config = config_data self.build_model() self.gpu_config = tf.ConfigProto() self.gpu_config.gpu_options.allow_growth = True def build_model(self): self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train']) self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid']) self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test']) self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data) self.vocab = self.train_data.vocab(0) self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs) self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams) self.target_encoder = tx.modules.UnidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams) self.linear_matcher = tx.modules.MLPTransformConnector(1) self.kw_list = self.vocab.map_tokens_to_ids(tf.convert_to_tensor(self.data_config._keywords_candi)) def forward(self, batch): source_embed = self.embedder(batch['source_text_ids']) target_embed = self.embedder(batch['target_text_ids']) target_embed = tf.reshape(target_embed, [-1, self.data_config._max_seq_len + 2, self.embedder.dim]) source_code = self.source_encoder(source_embed, sequence_length_minor=batch['source_length'], sequence_length_major=batch['source_utterance_cnt'])[1] target_length = tf.reshape(batch['target_length'], [-1]) target_code = self.target_encoder(target_embed, sequence_length=target_length)[1] target_code = tf.reshape(target_code, [-1, 20, self.config._code_len]) source_code = tf.expand_dims(source_code, 1) source_code = tf.tile(source_code, [1, 20, 1]) feature_code = target_code * source_code feature_code = tf.reshape(feature_code, [-1, self.config._code_len]) logits = self.linear_matcher(feature_code) logits = tf.reshape(logits, [-1, 20]) labels = tf.one_hot(batch['label'], 20) loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)) ans = tf.arg_max(logits, -1) acc = tx.evals.accuracy(batch['label'], ans) rank = tf.nn.top_k(logits, k=20)[1] return loss, acc, rank def train(self): batch = self.iterator.get_next() loss, acc, _ = self.forward(batch) op_step = tf.Variable(0, name='op_step') train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.opt_hparams) max_val_acc = 0. 
self.saver = tf.train.Saver() with tf.Session(config=self.gpu_config) as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) for epoch_id in range(self.config._max_epoch): self.iterator.switch_to_train_data(sess) cur_step = 0 cnt_acc = [] while True: try: cur_step += 1 feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN} loss, acc_ = sess.run([train_op, acc], feed_dict=feed) cnt_acc.append(acc_) if cur_step % 200 == 0: print('batch {}, loss={}, acc1={}'.format(cur_step, loss, np.mean(cnt_acc[-200:]))) except tf.errors.OutOfRangeError: break op_step = op_step + 1 self.iterator.switch_to_val_data(sess) cnt_acc = [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL} acc_ = sess.run([acc], feed_dict=feed) cnt_acc.append(acc_) except tf.errors.OutOfRangeError: mean_acc = np.mean(cnt_acc) print('valid acc1={}'.format(mean_acc)) if mean_acc > max_val_acc: max_val_acc = mean_acc self.saver.save(sess, self.config._save_path) break def test(self): batch = self.iterator.get_next() loss, acc, rank = self.forward(batch) with tf.Session(config=self.gpu_config) as sess: sess.run(tf.tables_initializer()) self.saver = tf.train.Saver() self.saver.restore(sess, self.config._save_path) self.iterator.switch_to_test_data(sess) rank_cnt = [] while True: try: feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} ranks, labels = sess.run([rank, batch['label']], feed_dict=feed) for i in range(len(ranks)): rank_cnt.append(np.where(ranks[i]==labels[i])[0][0]) except tf.errors.OutOfRangeError: rec = [0,0,0,0,0] MRR = 0 for rank in rank_cnt: for i in range(5): rec[i] += (rank <= i) MRR += 1 / (rank+1) print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format( rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt))) break def retrieve_init(self, sess): data_batch = self.iterator.get_next() loss, acc, _ = self.forward(data_batch) self.corpus = self.data_config._corpus self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams) corpus_iterator = tx.data.DataIterator(self.corpus_data) batch = corpus_iterator.get_next() corpus_embed = self.embedder(batch['corpus_text_ids']) utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1] self.corpus_code = np.zeros([0, self.config._code_len]) corpus_iterator.switch_to_dataset(sess) sess.run(tf.tables_initializer()) saver = tf.train.Saver() saver.restore(sess, self.config._save_path) feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT} while True: try: utter_code_ = sess.run(utter_code, feed_dict=feed) self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0) except tf.errors.OutOfRangeError: break self.keywords_embed = tf.nn.l2_normalize(self.embedder(self.kw_list), axis=1) self.kw_embedding = sess.run(self.keywords_embed) self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9)) self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1)) self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2)) history_ids = self.vocab.map_tokens_to_ids(self.history_input) history_embed = self.embedder(history_ids) history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0), sequence_length_minor=self.minor_length_input, sequence_length_major=self.major_length_input)[1] select_corpus = tf.cast(self.corpus_code, dtype=tf.float32) feature_code = self.linear_matcher(select_corpus * history_code) self.ans_output = 
tf.nn.top_k(tf.squeeze(feature_code, 1), k=1000)[1] def retrieve(self, source, sess): history, seq_len, turns, context, context_len = source ans = sess.run(self.ans_output, feed_dict={self.history_input: history, self.minor_length_input: [seq_len], self.major_length_input: [turns]}) flag = 0 reply = self.corpus[ans[0]] for i in ans: if i in self.reply_list: # avoid repeat continue for wd in kw_tokenize(self.corpus[i]): if wd in self.data_config._keywords_candi: tmp_score = sum(self.kw_embedding[self.data_config._keywords_dict[wd]] * self.kw_embedding[self.data_config._keywords_dict[self.target]]) if tmp_score > self.score: reply = self.corpus[i] self.score = tmp_score self.next_kw = wd flag = 1 break if flag == 0: continue break return reply ================================================ FILE: preprocess/convai2/__init__.py ================================================ from .api import * ================================================ FILE: preprocess/convai2/api.py ================================================ import os data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'source') class dts_ConvAI2(object): def __init__(self, path=data_path): self.path = path def _txt_to_json(self, txt_path, mode, cands): def pop_one_sample(lines): self_persona = [] other_persona = [] dialog = [] candidates = [] started = False while len(lines) > 0: line = lines.pop() id, context = line.split(' ', 1) id = int(id) context = context.strip() if started == False: # not started assert id == 1 started = True elif id == 1: # break for next lines.append(line) break if context.startswith('partner\'s persona: '): # partner assert mode in ['both', 'other'] other_persona.append(context[19:]) elif context.startswith('your persona: '): # self assert mode in ['both', 'self'] self_persona.append(context[13:]) elif cands == False: # no cands try: uttr, response = context.split('\t', 2)[:2] dialog.append(uttr) dialog.append(response) except: uttr = context dialog.append(uttr) else: uttr, response, _, negs = context.split('\t', 4)[:4] dialog.append(uttr) dialog.append(response) candidates.append(negs.split('|')) candidates.append(None) return { 'self_persona': self_persona, 'other_persona': other_persona, 'dialog': dialog, 'candidates': candidates } lines = open(txt_path, 'r').readlines()[::-1] samples = [] while len(lines) > 0: samples.append(pop_one_sample(lines)) return samples def get_data(self, mode='train', revised=False, cands=False): txt_path = os.path.join(self.path, '{}_{}_{}{}.txt'.format( mode, 'none', 'revised' if revised is True else 'original', '' if cands is True else '_no_cands')) assert mode in ['train', 'valid', 'test', 'all'] print("Get dialog from ", txt_path) assert os.path.exists(txt_path) return self._txt_to_json(txt_path, mode, cands) def get_dialogs(self, mode='all'): dialogs = [sample['dialog'] for sample in self.get_data(mode, False, False)] return dialogs ================================================ FILE: preprocess/convai2/candi_keyword.txt ================================================ favorite sound play dog music kid eat school enjoy job watch read food cat friend family hobby people pet car game hear movie travel book cook listen animal life color drive college hope living parent teach bad sport hard dad feel child hair band country money pizza marry bet hate walk stay study write husband fish start guess brother blue night dance busy red wife learn talk sing meet beach spend drink rock city video grow person house tv true shop teacher buy girl weekend 
prefer hike meat football free visit sister care plan art swim sweet paint stuff pretty vegan store class farm type funny understand son business happy mother ride coffee leave fan garden week boy bake green pay beautiful draw student horse sell sad summer wear truck black hot wait yea sleep cold agree single guy sibling real crazy healthy fine tall guitar lose cute hour cake purple month relax finish idea break company daughter happen dream team mind park woman reading restaurant bike short italian exercise hang chocolate wonderful basketball speak soccer weather graduate retire winter morning nurse bring die easy party grade father ready ice win spare doctor online song local degree chicken concert florida glad fall baseball eye office volunteer water girlfriend bear french shopping age luck fishing artist weird foot save surf dinner yoga lucky story hunt cooking pink suck cream boyfriend language beer outdoors change diet passion york major cheese collect imagine super english rap practice chat head california train lake clothes pass nature baby fry season apple piano sit hop scary taco law steak hockey comic brown bar gym smart huge clean fav close jazz vegetarian canada science career picture tea excite tough deal pick allergic neat church social sick shoe trip fly vacation catch bed raise boring rid ocean club town instrument check race dragon fast vegetable wrong boat terrible tennis candy rain worry veggie fruit metal perfect twin tattoo mountain tomorrow god stand hospital apartment build stick design japan texas meal grocery camp hit orange mexican allergy ta phone famous player flower bore army star share pool delicious relationship exciting math egg museum classic dress sushi taste married amaze classical lady shelter sense joke pop pie yellow american lawyer expensive stress hand cut pasta remember professional choice impressive youtube kinda yum chicago birthday cooky sunday follow divorce gon moment fresh tire hurt wedding fit weight health plant count chef heard ball scar sort smell dead special comedy couple rich hiking fave create accountant bird relaxing history blonde film everyday glass heart voice ballet vet military horror field fight mile salad extra afraid market christmas reason mexico attend cop chance potato snow halloween folk snake ski character card adopt figure tho hat nail ford fat warm swimming difficult sew suppose dye safe wine tend style afford recipe writer shower lunch remind painting news congrats photography rough road holiday join beat white throne michigan awful actor breakfast library congratulation goal shrimp singer middle horrible common plane profession bacon smoke coast kill future word female south craft neighbor violin fair paris fashion quiet service bank workout dish shape tour theater genre forget fiction makeup model clown author experience rest librarian king puppy german tree trust street yummy marketing quit hungry creative poor cali wild alright accident fantasy cartoon air pepper hurricane ton lab anime university drop factory iphone mystery active netflix stressful space trouble secret mcdonalds laugh gun support spanish pain casino france set spaghetti sale magic grill burger climb internet pig hold decide mustang program drum hmm focus seafood master medium strange lesson arm war serve van north cow toe driver paper education table station dancer soda playing medical dark blog usa power wan bob simple zoo shoot fancy mechanic collection national motorcycle strong parrot nursing america harry blood ahahah body bible police 
energy alabama lazy death rose musician italy adventure worker writing pair activity bread throw issue europe london gross jealous male height disney treat marathon door bean bos boot aww cancer lover milk skill plenty beard light poetry subject sea romantic relate spending firm send seattle couch question flavor photo channel favourite match contact choose japanese singing lift boston wood grandmother farmer radio navy roommate bus politics grandchild jam chinese viking matter series competition designer gas san feeling knit lizard slow publish fantastic crochet donate add television insurance talent bowl human land rise hmmm test alive involve dangerous pro yesterday wind obsess golden stone post west nephew depend nervous breed bakery hip goodness baker surgery mall don honda colorado dancing dr dude lifestyle calm wake crime athlete skin beauty lie facebook mar reader lead mess snack shift employ spring size handle shy iron corvette evening superman board cheer traffic grand sun fee musical base public chip mad careful pound brand potter punk smile advice justin manager bass peaceful golf clothing john action buddy sunny box waitress gardening popular wing wall personal coach cover elementary mix pray odd stink minute key inspire event adorable community juice engineer karate skate saturday literature reward baking ghost ohio vega kitchen sugar chill tomato chase chevy roll la organize blame homework bee grandma beatles passionate soul knee anxiety nut spot ticket grandparent honest google junk product tech track normal role client veterinarian spicy building scare hawaii bummer reality iguana niece opera charlie worth pack trail senior toyota beet ireland finance chili officer pickle kinds manage tired christian daily nose cheap spider gig indian locate pleasure league hunting retired george tuna teaching attention step awe stock list portland sign avoid bug brain scientist dessert excellent finger religious superhero highschool kick rule salesman drawing handful pilot escape schedule tasty gosh burrito kansa commercial rescue mary fi retail retriever skydive husky gamble georgia industry meeting nyc walmart annoy cruise doubt forest bunch cashier actress kayak partner mac tiny friday spell culture kindergarten lay tiger deaf mention drinking accounting cup subway pancake habit prius blind familiar starbucks assistant pumpkin burn answer poodle metallica circus jewelry sock recommend sandwich journalist fear pretend grandson nfl brownie cupcake doll stephen bother desert sci moon construction reach technology option marriage pony hell romance cuisine strawberry shirt bbq purse leg kitten unemployed project running deliver natural dang control hero beef trade sunset carrot main chess grandkids talented explore nasty east hamburger private yup guard choir australia pen stamp struggle dive bro perry omg thinking candle fell ink wheel ranch owner carry biology earn regular lion peanut broccoli suit medicine speaking chew inspiration sauce hotel camera england repair turtle athletic unique ray sky skateboard river tune freak estate cheesecake yuck surfing rent ooh grey memory perform popcorn dislike childhood adore quick comedian sweater antique lottery hows mcdonald jump tutor carolina walking pregnant mike vampire decorate bieber alcohol compete mansion owl gotcha jack engineering retirement pot airplane ferrari dry dentist russian piece security spirit offer dorm record settle lobster foreign software honey rice princess excited amazon baltimore island skiing center alien butter corn 
civic jane view nap pit bulldog lovely prince loud photograph gift coke org belt festival ugh ahh vintage pug birth understandable sweetheart irma deep india feed ring item scene spear mushroom position afternoon hire trainer distract touch unfortunate alaska expect katy scratch cost situation gay loss organic washington joy gummy survive med bless prepare charity sight rare heavy rural russia newspaper karaoke driving customer watching require graphic mood maine freelance fitness diner pepsi condo miami cross runner accept panda bunny engage uncle commute indie cooler straight hollywood bagel terrify training boxer left protein bull jog tom shame ouch current programmer nerd magazine artistic yikes eating skittle furniture eagle form fabulous legal agency internship cabin drama positive addict surprise rewarding tax barista fake spain crash random kale bright shark studio bow boys bell brace trick wheelchair cloud southern force chair spouse thumb frank rapper virginia physical bye grad soup fiance elvis meatloaf queen united income salon volleyball target chihuahua limit scholarship direction supply canadian daddy toy kentucky respect earth political binge hilarious blast unhealthy pant cheeseburger complete unicorn religion dairy drug fl whiskey reject iced average eater dirty rat angry heck hide competitive gum website laptop exhaust robot challenge outdoor raw audition cafe onion assume opinion horseback zebra philosophy psychology exact spice debt reside heat hobbies leather rude china storm grown invite instructor steal curly wash worm bf credit greek oregon shot riding pride speed floor intense court dental bone friendly diving lame kitty alternative hubby strict ham camping pond marine grandpa secretary theme judge shade complain herb advertising celebrity fascinate lasagna environment painter comfortable beagle recycle bruno invest search society halo pursue hm sam connect angeles weed grab kiss bald em britney oil conversation mmm yard cash jean ship exotic vanilla waste photographer adult rolling dig cookie tennessee balance cleaning ocd eggplant archery ma meatball dust deer gluten agent gaming celebrate helpful tofu campus lord evil er los shake martial drummer outfit grass wrestle note convertible biking taller gorgeous file hectic salsa rush awhile distance soft homeless daycare process patient houston crowd stew duty bookstore tie neighborhood professor orleans electric original opportunity francisco angel fund site peace picky wisconsin madonna adam iceland blow uh total urban coincidence entire jello compare dollar headache blond guilty cure electronic grader greece correct twilight enjoyable benefit damn coupon cheat thrift print scout successful bartender stranger cattle mommy bath nascar honor suggestion appalachian ginger hook aquarium james scared gamer collie creepy upstate therapist bicycle trophy department cheerleader variety suburb explain cuddle flip shepherd tupac whoa iq lifeguard sunshine jersey rainy vehicle closet rainbow ross specialty attorney interior gf crab podcasts fart decent insane shepard gossip jimmy blackjack austin nike louisiana brat therapy noble skunk listening awww pleasant pittsburgh barbie specific pas um avid occasion mustache autograph noise diego ruin admit fail scotch creature costume jeopardy swift africa cd account edit topic handy window steelers accent activist preacher affect apply loose refuse bum minnesota gender confuse landscape broadway bmw argue foodie trek relieve todd path freedom pearl tap medication vera rapid ohh 
government environmental fault ft cope carbs standard planet mcqueen nugget pull difference babysit teller disappoint mermaid pageant soap midwest giant puerto bowling asian arizona happiness provide fond bud cell entertain butterfly genius scream architect jar monkey theatre hah purchase yay vote level cab combo tool fluent duck creek load affair humor slave publishing painful redhead jerry thursday dungeon cape messy approve cousin weekly terrier paddle david typical waffle desk catholic germany instagram admire image upset muscle monday pic basement ground wave shellfish technician episode jim jacob ballerina loan improve garage ibm tooth bachelor thrill hippie notice return crush jesus stomach techno raven nerve denver foster thriller rugby daydream oreo discover detroit cult atlanta incredible stable poem yr oooh wide mango snowboard weapon countryside alcoholic brunch fisherman aunt toronto sarah addiction surround lactose clinic universe pretzel toto blah sir sausage cosplay text quality millionaire clerk skinny hendrix disabled puzzle pepperoni civil surfer larp package roof pa suspense drone mail jason exam mark financial encyclopedia cheetos demand shell planning stupid yell grateful bingo source companion director bite detective biography gospel silly pudding pork teeth autobiography salt footstep deserve produce swimmer barbecue maryland btw defense fallon teen continue cart wizard meditate shelf zombie irish pecan bubble discount scooter push tutorial scuba homemade weakness translator gymnastics background softball kidding mistake realize ironic floyd flight surgeon crack desire introvert ted knife mba pottery orphan recover mini bucket perk autism moped cycle youth spoil attitude injury pennsylvania aspire loyal attack murder price tank safety olympics rome dj carpenter sesame consume protect british sword cheetah float asia mate creed xbox lean freckle caffeine hunter trump programming picnic ear fridge easter asparagus oldie darn disagree mirror hehe ew biscuit freshman whistle usual inch deli eclipse cycling modern tease review pattern award belong nickname buff sweden cuz selfish personality graduation dolphin album pup taylor lease educate depress paramedic gila purpose prison worried vermont block buffalo olive identical score string actual intelligent epilepsy sand plate subtitle border cable smooth lexus develop vienna brave welfare doberman wealthy switch sneak robert hill nacho mugger snap dumb coworker peta hr chick fur goalie range introduce costco railroad suffer menu soldier asthma sex lindsey utah grandfather documentary admirable traveling dane employee article caramel harley equipment heal basic dabble depression attract gaga tale tube fried doggy mash taught receptionist empire intern piercings tackle ease blanket participate cheesy gray daisy toddler link diamond propose dallas president ranger wolf gain chai annoying earring version basket lens salary corner champion firefighter ferret achieve sears mia idol joe decade emotion koala management pharmacist apps depends killer fellow uniform gourmet cleveland motivate hummus entertainment mmmm bag nashville flirt owen trumpet nevada stage jerky responsibility drake bentley gold tx arcade ankle vegas kj batman unwind keyboard combination leaf koi cello minimum adventurous loving kiddos spill diabetic central rob trend bubblegum indoors monster drunk confidence pyramid grunge banker breath grasshopper hoop encourage gutter macaroni robotics double seat dew uncomfortable flash bench bomb info semi comfort 
wildlife geology lonely coz edge anniversary vitamin material hotdog gathering socialize machine editor droopy brew liberal mercedes quarterback rn planner fortunate ur greenhouse si psychologist promotion def discovery carb humane broken chanel fluffy chain vision stanford pipe ability overweight promote labrador veteran preference symphony alpaca ve app teenager anne promise doo publisher curious tiki porsche mixed maid legend michael supportive pineapple ariel diabetes consulting starve gal collar gable battle jacket sexy sleeve felix pastime jamaica mortal weak scrub rabbit godfather sinatra valley despise regret goodwill heaven buddhist smoking oops pitbulls salmon rick bitcoin dip trout pill farming thankful tokyo housewife prayer impala valedictorian plain message temper flintstone leprechaun sucker breathe csi criminal rip maiden fascinating rico algeria report umm patience leader curl motivation climbing tahoe ymca relief glacier breast enter clutter dull fighter tat awake brewery victorian volcano friends mount pillage magical generation clue conscious stare silver wrestling levine joint restore everest dope stray international parking hampshire hearse warehouse pitbull nyu outdoorsy development employment drinker zumba paul budget daniel eyesight sour mouth stain blogger exist rib brush interview bff custom snuggle vancouver mario ferraris mural poet oriole period karma damage warmer crossword childrens pomeranian imaginary dave anatomy tone code videogames woodstock convention janitor preschool screen prejudice crystal rage tradition chatting traditional parakeet ramen combat multiple crave syrup racing highlight communist concentrate waiter ebooks dodge hp boil attic medal commitment release downtown alligator statement debate agreed maga homeschooled strength plumber hippy windy condition smoothie stair content depressed ferrell keto remodel donut winner playlist wayne nation kpop map coon junior mum tape quake smithsonian washer abigail radiohead humble unicycle administration ontario performance truth fred ingredient cucumber beastie orchestra sewing knock culinary sweat seashell impression network languages tailgate celebration thomas embarrass born mama freeze crap fortune figurine confident homebody chemistry collector merna arrive titanic meditation bout manta announcer solo circle md funeral engine butt delivery ultimate specialize web palm absolute investment harsh pistachio loner experiment gut austen fuel cramp trauma sleepy celtic press draft auto sprite obsession sip fifty vinyl swing fool hbu harvey copperfield playoff kite lesbian jerk owe democrat mass hamilton ga uk luis impress slice pita hobbie apologize santa tacos landing hometown telecom mater mutt deploy del sore nancy barbies fam clay ethnic pastry hostage tight backyard convince maker curry android pc jessica ignore flow sickness elderly chore upholstery sweetie lettuce cuba gadget animation trooper faith tongue success gentle portrait sheeran chevrolet packer risk spark frustrate mouse pitch weld eyebrow bella linebacker bully routine spelling bc coat saudi arabia tampa emmy samsung mop kevin checker teapot weigh suv miserable sevenfold f150 lit posse thai curator steve poop historical morty cane miley wise petition tear penn astronaut cod colour acting precious buck lucy muse cosmetic occupation nba ate flexible ideal suspender bang direct gotti agitate hairdresser dealership influence cursive sunfish snorkel shallow root pediatrician compost coaster nearby foreman deadbeat penny jay jasper tarot 
pressure clarinet supper express ai martini favor chop lutefisk charge dakota hitchhike formal ivy raptor battlestar captain disgust task sitcom yorkie coco understood naw ant stinky speckle title corporate wednesday gambler wage multi mma cookbook citizen hazel aspiration goat stuck lumberjack flag wet ufc learning stirling dealer grisham acre ================================================ FILE: preprocess/data_utils.py ================================================ import nltk import os from nltk.stem import WordNetLemmatizer _lemmatizer = WordNetLemmatizer() def tokenize(example, ppln): for fn in ppln: example = fn(example) return example def kw_tokenize(string): return tokenize(string, [nltk_tokenize, lower, pos_tag, to_basic_form]) def simp_tokenize(string): return tokenize(string, [nltk_tokenize, lower]) def nltk_tokenize(string): return nltk.word_tokenize(string) def lower(tokens): if not isinstance(tokens, str): return [lower(token) for token in tokens] return tokens.lower() def pos_tag(tokens): return nltk.pos_tag(tokens) def to_basic_form(tokens): if not isinstance(tokens, tuple): return [to_basic_form(token) for token in tokens] word, tag = tokens if tag.startswith('NN'): pos = 'n' elif tag.startswith('VB'): pos = 'v' elif tag.startswith('JJ'): pos = 'a' else: return word return _lemmatizer.lemmatize(word, pos) def truecasing(tokens): ret = [] is_start = True for word, tag in tokens: if word == 'i': ret.append('I') elif tag[0].isalpha(): if is_start: ret.append(word[0].upper() + word[1:]) else: ret.append(word) is_start = False else: if tag != ',': is_start = True ret.append(word) return ret candi_keyword_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'convai2/candi_keyword.txt') _candiwords = [x.strip() for x in open(candi_keyword_path).readlines()] def is_candiword(a): if a in _candiwords: return True return False from nltk.corpus import wordnet as wn from nltk.corpus import wordnet_ic brown_ic = wordnet_ic.ic('ic-brown.dat') def calculate_linsim(a, b): linsim = -1 syna = wn.synsets(a) synb = wn.synsets(b) for sa in syna: for sb in synb: try: linsim = max(linsim, sa.lin_similarity(sb, brown_ic)) except: pass return linsim def is_reach_goal(context, goal): context = kw_tokenize(context) if goal in context: return True for wd in context: if is_candiword(wd): rela = calculate_linsim(wd, goal) if rela > 0.9: return True return False def make_context(string): string = kw_tokenize(string) context = [] for word in string: if is_candiword(word): context.append(word) return context def utter_preprocess(string_list, max_length): source, minor_length = [], [] string_list = string_list[-9:] major_length = len(string_list) if major_length == 1: context = make_context(string_list[-1]) else: context = make_context(string_list[-2] + string_list[-1]) context_len = len(context) while len(context) < 20: context.append('') for string in string_list: string = simp_tokenize(string) if len(string) > max_length: string = string[:max_length] string = [''] + string + [''] minor_length.append(len(string)) while len(string) < max_length + 2: string.append('') source.append(string) while len(source) < 9: source.append([''] * (max_length + 2)) minor_length.append(0) return (source, minor_length, major_length, context, context_len) ================================================ FILE: preprocess/dataset.py ================================================ import numpy as np import collections import random import pickle from convai2 import dts_ConvAI2 from extraction import 
KeywordExtractor from data_utils import * class dts_Target(dts_ConvAI2): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def get_vocab(self): counter = collections.Counter() dialogs = self.get_dialogs() for dialog in dialogs: for uttr in dialog: counter.update(simp_tokenize(uttr)) print('total vocab count: ', len(counter.items())) vocab = [token for token, times in sorted(list(counter.items()), key=lambda x: (-x[1], x[0]))] with open('../tx_data/vocab.txt','w') as f: for word in vocab: f.write(word + '\n') print('save vocab in vocab.txt') return vocab def get_kwsess(self, vocab, mode='all'): keyword_extractor = KeywordExtractor(vocab) corpus = self.get_data(mode = mode, cands=False) sess_set = [] for sess in corpus: data = {} data['history'] = '' data['dialog'] = [] for dialog in sess['dialog']: data['dialog'].append(dialog) data['history'] = data['history'] + ' ' + dialog data['kws'] = keyword_extractor.extract(data['history']) sess_set.append(data) return sess_set def cal_idf(self): counter = collections.Counter() dialogs = self.get_dialogs() total = 0. for dialog in dialogs: for uttr in dialog: total += 1 counter.update(set(kw_tokenize(uttr))) idf_dict = {} for k,v in counter.items(): idf_dict[k] = np.log10(total / (v+1.)) return idf_dict def make_dataset(self): vocab = self.get_vocab() idf_dict = self.cal_idf() kw_counter = collections.Counter() sess_set = self.get_kwsess(vocab) for data in sess_set: kw_counter.update(data['kws']) kw_freq = {} kw_sum = sum(kw_counter.values()) for k, v in kw_counter.most_common(): kw_freq[k] = v / kw_sum for data in sess_set: data['score'] = 0. for kw in set(data['kws']): data['score'] += kw_freq[kw] data['score'] /= len(set(data['kws'])) sess_set.sort(key=lambda x: x['score'], reverse=True) all_data = {'train':[], 'valid':[], 'test':[]} keyword_extractor = KeywordExtractor(idf_dict) for id, sess in enumerate(sess_set): type = 'train' if id < 500: type = 'test' elif random.random() < 0.05: type = 'valid' sample = {'dialog':sess['dialog'], 'kwlist':[]} for i in range(len(sess['dialog'])): sample['kwlist'].append(keyword_extractor.idf_extract(sess['dialog'][i])) all_data[type].append(sample) pickle.dump(all_data, open('source_data.pk','wb')) return all_data ================================================ FILE: preprocess/extraction.py ================================================ from data_utils import * class KeywordExtractor(): def __init__(self, idf_dict = None): self.idf_dict = idf_dict @staticmethod def is_keyword_tag(tag): return tag.startswith('VB') or tag.startswith('NN') or tag.startswith('JJ') @staticmethod def cal_tag_score(tag): if tag.startswith('VB'): return 1. if tag.startswith('NN'): return 2. if tag.startswith('JJ'): return 0.5 return 0. 
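# idf_extract (descriptive note): each candidate keyword is scored roughly as
#   cal_tag_score(POS) * count(word in utterance) * (1 / utterance_length) * idf(word),
# skipping non-candidate words and words already in con_kw; words scoring above 0.15 are kept
# and returned with duplicates removed. Illustrative example (word and idf value hypothetical):
# for a 4-token utterance containing the noun 'guitar' once, its score is 2.0 * 1 * 1/4 * idf('guitar').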
def idf_extract(self, string, con_kw = None): tokens = simp_tokenize(string) seq_len = len(tokens) tokens = pos_tag(tokens) source = kw_tokenize(string) candi = [] result = [] for i, (word, tag) in enumerate(tokens): score = self.cal_tag_score(tag) if not is_candiword(source[i]) or score == 0.: continue if con_kw is not None and source[i] in con_kw: continue score *= source.count(source[i]) score *= 1 / seq_len score *= self.idf_dict[source[i]] candi.append((source[i], score)) if score > 0.15: result.append(source[i]) return list(set(result)) def extract(self, string): tokens = simp_tokenize(string) tokens = pos_tag(tokens) source = kw_tokenize(string) kwpos_alters = [] for i, (word, tag) in enumerate(tokens): if source[i] and self.is_keyword_tag(tag): kwpos_alters.append(i) kwpos, keywords = [], [] for id in kwpos_alters: if is_candiword(source[id]): keywords.append(source[id]) return list(set(keywords)) ================================================ FILE: preprocess/prepare_data.py ================================================ from dataset import dts_Target from collections import Counter import pickle import random import os import shutil if not os.path.exists('../tx_data'): os.mkdir('../tx_data') os.mkdir('../tx_data/train') os.mkdir('../tx_data/valid') os.mkdir('../tx_data/test') # import texar # if not os.path.exists('convai2/source'): # print('Downloading source ConvAI2 data') # texar.data.maybe_download('https://drive.google.com/file/d/1LPxNIVO52hZOwbV3Zply_ITi2Uacit-V/view?usp=sharing' # ,'convai2', extract=True) shutil.copy('convai2/source/embedding.txt', '../tx_data/embedding.txt') dataset = dts_Target() dataset.make_dataset() data = pickle.load(open("source_data.pk","rb")) max_utter = 9 candidate_num = 20 start_corpus_file = open("../tx_data/start_corpus.txt", "w") corpus_file = open("../tx_data/corpus.txt", "w") for stage in ['train', 'valid', 'test']: source_file = open("../tx_data/{}/source.txt".format(stage), "w") target_file = open("../tx_data/{}/target.txt".format(stage), "w") context_file = open("../tx_data/{}/context.txt".format(stage), "w") keywords_file = open("../tx_data/{}/keywords.txt".format(stage), "w") label_file = open("../tx_data/{}/label.txt".format(stage), "w") keywords_vocab_file = open("../tx_data/{}/keywords_vocab.txt".format(stage), "w") corpus = [] keywords_counter = Counter() for sample in data[stage]: corpus += sample['dialog'][1:] start_corpus_file.write(sample['dialog'][0]+ '\n') for kws in sample['kwlist']: keywords_counter.update(kws) for kw, _ in keywords_counter.most_common(): keywords_vocab_file.write(kw + '\n') for sample in data[stage]: for i in range(2, len(sample['dialog'])): if len(sample['kwlist'][i]) > 0: source_list = sample['dialog'][max(0, i - max_utter):i] source_str = '|||'.join(source_list) while True: random_corpus = random.sample(corpus, candidate_num - 1) if sample['dialog'][i] not in random_corpus: break corpus_file.write(sample['dialog'][i] + '\n') target_list = [sample['dialog'][i]] + random_corpus target_str = '|||'.join(target_list) source_file.write(source_str + '\n') target_file.write(target_str + '\n') context_file.write(' '.join(sample['kwlist'][i-2] + sample['kwlist'][i-1]) + '\n') keywords_file.write(' '.join(sample['kwlist'][i]) + '\n') label_file.write('0\n') source_file.close() target_file.close() label_file.close() keywords_vocab_file.close() context_file.close() start_corpus_file.close() corpus_file.close() ================================================ FILE: readme.md 
================================================ # Target-Guided Open-Domain Conversation This is the code for the following paper: [Target-Guided Open-Domain Conversation](http://arxiv.org/abs/1905.11553) *Jianheng Tang, Tiancheng Zhao, Chenyan Xiong, Xiaodan Liang, Eric Xing, Zhiting Hu; ACL 2019* ### Requirements - `nltk==3.4` - `tensorflow==1.12` - `texar>=0.2.1` ([Texar](https://github.com/asyml/texar)) ### Usage #### Data Preparation The dataset developed in the paper is available on [Google Drive](https://drive.google.com/file/d/1oTjOQjm7iiUitOPLCmlkXOCbEPoSWDPX/view?usp=sharing). Download and unzip it into `preprocess/convai2`. Then run the following commands: ```shell cd preprocess python prepare_data.py ``` By default, the processed data will be placed in the `tx_data` directory. #### Turn-level Supervised Learning This project includes 5 types of agents (kernel / neural / matrix / retrieval / retrieval_stgy), all of which are described in the paper. You can modify the configuration of each agent in the `config` directory. To train the kernel/neural/matrix agent, first train/test its keyword prediction module, and then train/test its retrieval module, with the agent specified by the `--agent` parameter. ```shell python train.py --mode train_kw --agent kernel python train.py --mode train --agent kernel python train.py --mode test --agent kernel ``` The retrieval agent and the retrieval_stgy agent share the same retrieval module. You only need to train one of them: ```shell python train.py --mode train --agent retrieval python train.py --mode test --agent retrieval ``` #### Target-guided Conversation After turn-level training, you can start a target-guided conversation (human evaluation) with the kernel/neural/matrix/retrieval/retrieval_stgy agent specified by the `--agent` parameter. ```shell python chat.py --agent kernel ``` A demo of target-guided conversation (the user is required to input the response after `HUMAN:`): ``` --------Session 0 -------- START: hello, how are you doing HUMAN: hi i am great just finishing up some homework. how about you AGENT: just got done eating some chicken and rice after i got done with my gym session HUMAN: i am gonna eat some vegetable for dinner AGENT: so you are vegetarian ? i am from los angeles so thereare lots of vegetarian restaurants HUMAN: i am not a vegetarian . fried chicken is my favorite. AGENT: oh boy , when i go to mcdonald i always get fries for my 2 dogs Successfully chat to the target 'mcdonald'. ``` You can also watch a simulated target-guided conversation between the retrieval agent, which plays the role of the user, and the kernel/neural/matrix/retrieval_stgy agent specified by the `--agent` parameter. The number of simulated sessions is set by the `--times` parameter. The program reports the number of successful sessions and the average number of turns at the end. ```shell python simulate.py --agent kernel --times 100 ``` A demo of simulation: ``` --------Session 0 -------- START: hey hey good morning to you retrieval_agent: fine thanks . what do you do for a living ? kernel_agent: rewarding job , i work at a hospital Keyword: job, Similarity: 0.58 Successfully chat to the target 'hospital'. ... --------Session 99 -------- START: hey hows it going ?
i'm just cooking a steak retrieval_agent: i'm thinking of a bbq sandwich for lunch kernel_agent: nice i love to cook but now its just me and the fur babies Keyword: baby, Similarity: 0.45 retrieval_agent: i love bagels however i own a dry cleaners kernel_agent: i love animals felix my cat and my dog emmy Keyword: cat, Similarity: 0.56 retrieval_agent: sounds awesome i have all kind of pets my family own a farm kernel_agent: i love blue as well even my hair is blue Keyword: blue, Similarity: 1.00 Successfully chat to the target 'blue'. success time 83, average turns 4.28 ``` ================================================ FILE: simulate.py ================================================ import tensorflow as tf import importlib import random from preprocess.data_utils import utter_preprocess, is_reach_goal from model import retrieval class Target_Simulation(): def __init__(self, config_model, config_data, config_retrieval): g1 = tf.Graph() with g1.as_default(): self.retrieval_agent = retrieval.Predictor(config_retrieval, config_data) sess1 = tf.Session(graph=g1, config=self.retrieval_agent.gpu_config) self.retrieval_agent.retrieve_init(sess1) g2 = tf.Graph() with g2.as_default(): self.target_agent = model.Predictor(config_model, config_data) sess2 = tf.Session(graph=g2, config=self.target_agent.gpu_config) self.target_agent.retrieve_init(sess2) self.start_utter = config_data._start_corpus success_cnt, turns_cnt = 0, 0 for i in range(int(FLAGS.times)): print('--------Session {} --------'.format(i)) success, turns = self.simulate(sess1, sess2) success_cnt += success turns_cnt += turns print('success time {}, average turns {:.2f}'.format(success_cnt, turns_cnt / success_cnt)) def simulate(self, sess1, sess2): history = [] history.append(random.sample(self.start_utter,1)[0]) target_kw = random.sample(target_set,1)[0] self.target_agent.target = target_kw self.target_agent.score = 0. self.target_agent.reply_list = [] self.retrieval_agent.reply_list = [] print('START: ' + history[0]) for i in range(config_data._max_turns): source = utter_preprocess(history, config_data._max_seq_len) reply = self.retrieval_agent.retrieve(source, sess1) print('retrieval_agent: ', reply) history.append(reply) source = utter_preprocess(history, config_data._max_seq_len) reply = self.target_agent.retrieve(source, sess2) print('{}_agent: '.format(FLAGS.agent), reply) print('Keyword: {}, Similarity: {:.2f}'.format(self.target_agent.next_kw, self.target_agent.score)) history.append(reply) if is_reach_goal(history[-2] + history[-1], target_kw): print('Successfully chat to the target \'{}\'.'.format(target_kw)) return (True, (len(history)+1)//2) print('Failed by reaching the maximum turn, target: \'{}\'.'.format(target_kw)) return (False, 0) if __name__ == '__main__': flags = tf.flags flags.DEFINE_string('agent', 'kernel', 'The agent type, supports kernel / matrix / neural / retrieval.') flags.DEFINE_string('times', '100', 'Simulation times.') FLAGS = flags.FLAGS config_data = importlib.import_module('config.data_config') config_model = importlib.import_module('config.' + FLAGS.agent) config_retrieval = importlib.import_module('config.retrieval') model = importlib.import_module('model.' 
+ FLAGS.agent) target_set = [] for line in open('tx_data/test/keywords.txt', 'r').readlines(): target_set = target_set + line.strip().split(' ') Target_Simulation(config_model,config_data,config_retrieval) ================================================ FILE: train.py ================================================ import tensorflow as tf import importlib import os if __name__ == '__main__': flags = tf.flags flags.DEFINE_string('data', 'data_config', 'The data config') flags.DEFINE_string('agent', 'kernel', 'The predictor type') flags.DEFINE_string('mode', 'train', 'The mode') FLAGS = flags.FLAGS config_data = importlib.import_module('config.' + FLAGS.data) config_model = importlib.import_module('config.' + FLAGS.agent) model = importlib.import_module('model.' + FLAGS.agent) predictor = model.Predictor(config_model, config_data, FLAGS.mode) if not os.path.exists('save/'+FLAGS.agent): os.makedirs('save/'+FLAGS.agent) if FLAGS.mode == 'train_kw': predictor.train_keywords() if FLAGS.mode == 'test_kw': predictor.test_keywords() if FLAGS.mode == 'train': predictor.train() predictor.test() if FLAGS.mode == 'test': predictor.test()