Repository: squareRoot3/Target-Guided-Conversation
Branch: master
Commit: 4bbfaece1da3
Files: 22
Total size: 119.2 KB
Directory structure:
gitextract_1diizlg0/
├── chat.py
├── config/
│ ├── data_config.py
│ ├── kernel.py
│ ├── matrix.py
│ ├── neural.py
│ ├── retrieval.py
│ └── retrieval_stgy.py
├── model/
│ ├── kernel.py
│ ├── matrix.py
│ ├── neural.py
│ ├── retrieval.py
│ └── retrieval_stgy.py
├── preprocess/
│ ├── convai2/
│ │ ├── __init__.py
│ │ ├── api.py
│ │ └── candi_keyword.txt
│ ├── data_utils.py
│ ├── dataset.py
│ ├── extraction.py
│ └── prepare_data.py
├── readme.md
├── simulate.py
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: chat.py
================================================
import tensorflow as tf
import importlib
import random
from preprocess.data_utils import utter_preprocess, is_reach_goal
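# Interactive demo: a human chats with a trained agent that tries to steer the
# conversation toward a randomly sampled target keyword. A session ends when
# the target is reached in the last human/agent exchange (is_reach_goal) or
# after config_data._max_turns turns.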
class Target_Chat():
def __init__(self, agent):
self.agent = agent
self.start_utter = config_data._start_corpus
with tf.Session(config=self.agent.gpu_config) as sess:
self.agent.retrieve_init(sess)
for i in range(FLAGS.times):
print('--------Session {} --------'.format(i))
self.chat(sess)
def chat(self, sess):
history = []
history.append(random.sample(self.start_utter, 1)[0])
target_kw = random.sample(target_set, 1)[0]
self.agent.target = target_kw
self.agent.score = 0.
self.agent.reply_list = []
print('START: ' + history[0])
for i in range(config_data._max_turns):
history.append(input('HUMAN: '))
source = utter_preprocess(history, self.agent.data_config._max_seq_len)
reply = self.agent.retrieve(source, sess)
print('AGENT: ', reply)
# print('Keyword: {}, Similarity: {:.2f}'.format(self.agent.next_kw, self.agent.score))
history.append(reply)
if is_reach_goal(history[-2] + ' ' + history[-1], target_kw):  # join with a space so words don't fuse
print('Successfully chat to the target \'{}\'.'.format(target_kw))
return
print('Failed by reaching the maximum turn, target: \'{}\'.'.format(target_kw))
if __name__ == '__main__':
flags = tf.flags
# supports kernel / matrix / neural / retrieval / retrieval_stgy
flags.DEFINE_string('agent', 'kernel', 'The agent type')
flags.DEFINE_integer('times', 100, 'Number of chat sessions to run')
FLAGS = flags.FLAGS
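# Example invocation (a sketch, assuming checkpoints produced by train.py and
# a prepared tx_data/ directory):
#   python chat.py --agent kernel --times 10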
config_data = importlib.import_module('config.data_config')
config_model = importlib.import_module('config.' + FLAGS.agent)
model = importlib.import_module('model.' + FLAGS.agent)
predictor = model.Predictor(config_model, config_data, 'test')
target_set = []
for line in open('tx_data/test/keywords.txt', 'r').readlines():
target_set = target_set + line.strip().split(' ')
Target_Chat(predictor)
================================================
FILE: config/data_config.py
================================================
import os
data_root = './tx_data'
_corpus = [x.strip() for x in open('tx_data/corpus.txt', 'r').readlines()]
_start_corpus = [x.strip() for x in open('tx_data/start_corpus.txt', 'r').readlines()]
_max_seq_len = 30
_num_neg = 20
_max_turns = 8
_batch_size = 64
_retrieval_candidates = 1000
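# data_hparams builds one Texar MultiAlignedData config per stage. Each stage
# aligns five files: the dialogue history (source), 20 candidate responses per
# example (target), the keywords extracted from the history (context), the
# target keywords used as prediction labels (keywords), and the index of the
# correct candidate (label).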
data_hparams = {
stage: {
"num_epochs": 1,
"shuffle": stage != 'test',
"batch_size": _batch_size,
"datasets": [
{ # dialogue history
"variable_utterance": True,
"max_utterance_cnt": 9,
"max_seq_length": _max_seq_len,
"files": [os.path.join(data_root, '{}/source.txt'.format(stage))],
"vocab_file": os.path.join(data_root, 'vocab.txt'),
"embedding_init": {
"file": os.path.join(data_root, 'embedding.txt'),
"dim": 200,
"read_fn": "load_glove"
},
"data_name": "source"
},
{ # candidate response
"variable_utterance": True,
"max_utterance_cnt": 20,
"max_seq_length": _max_seq_len,
"files": [os.path.join(data_root, '{}/target.txt'.format(stage))],
"vocab_share_with": 0,
"embedding_init_share_with" : 0,
"data_name": "target"
},
{ # context (source keywords)
"files": [os.path.join(data_root, '{}/context.txt'.format(stage))],
"vocab_share_with": 0,
"embedding_init_share_with": 0,
"data_name": "context",
"bos_token": '',
"eos_token": '',
},
{ # target keywords
"files": [os.path.join(data_root, '{}/keywords.txt'.format(stage))],
"vocab_share_with": 0,
"embedding_init_share_with": 0,
"data_name": "keywords",
"bos_token": '',
"eos_token": '',
},
{ # label
"files": [os.path.join(data_root, '{}/label.txt'.format(stage))],
"data_type": "int",
"data_name": "label"
}
]
}
for stage in ['train','valid','test']
}
corpus_hparams = {
"batch_size": _batch_size*2,
"shuffle": False,
"dataset":{
"max_seq_length": _max_seq_len,
"files": [os.path.join(data_root, 'corpus.txt')],
"vocab_file": os.path.join(data_root, 'vocab.txt'),
"data_name": "corpus"
}
}
_keywords_path = 'tx_data/test/keywords_vocab.txt'
_keywords_candi = [x.strip() for x in open(_keywords_path, 'r').readlines()]
_keywords_num = len(_keywords_candi)
_keywords_dict = {}
for i in range(_keywords_num):
_keywords_dict[_keywords_candi[i]] = i
================================================
FILE: config/kernel.py
================================================
_hidden_size = 200
_code_len = 800
_save_path = 'save/kernel/model_1'
_kernel_save_path = 'save/kernel/keyword_1'
_kernel_mu = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.]
_kernel_sigma = 0.1
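# _kernel_mu / _kernel_sigma define 10 Gaussian (RBF) kernels over cosine
# similarity, in the style of kernel-based matching networks: a similarity s
# contributes exp(-(s - mu_k) ** 2 / sigma ** 2) to the k-th pooled feature.
# A numeric sketch (not part of the model):
#   >>> import numpy as np
#   >>> np.exp(-(0.75 - np.array(_kernel_mu)) ** 2 / _kernel_sigma ** 2)
# peaks at the kernels centered near 0.7 and 0.8.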
_max_epoch = 10
_early_stopping = 2
kernel_opt_hparams = {
"optimizer": {
"type": "AdamOptimizer",
"kwargs": {
"learning_rate": 0.001,
}
},
"learning_rate_decay": {
"type": "inverse_time_decay",
"kwargs": {
"decay_steps": 1600,
"decay_rate": 0.8
},
"start_decay_step": 0,
"end_decay_step": 16000,
},
}
source_encoder_hparams = {
"encoder_minor_type": "BidirectionalRNNEncoder",
"encoder_minor_hparams": {
"rnn_cell_fw": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
},
"rnn_cell_share_config": True
},
"encoder_major_type": "UnidirectionalRNNEncoder",
"encoder_major_hparams": {
"rnn_cell": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size*2,
},
}
}
}
target_encoder_hparams = {
"rnn_cell_fw": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
},
"rnn_cell_share_config": True
}
target_kwencoder_hparams = {
"rnn_cell_fw": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
},
"rnn_cell_share_config": True
}
context_encoder_hparams = {
"rnn_cell": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
}
}
opt_hparams = {
"optimizer": {
"type": "AdamOptimizer",
"kwargs": {
"learning_rate": 0.001,
}
},
}
================================================
FILE: config/matrix.py
================================================
_hidden_size = 200
_code_len = 800
_save_path = 'save/matrix/model_1'
_matrix_save_path = 'save/matrix/matrix_1.pk'
_max_epoch = 10
_vocab_path = 'tx_data/vocab.txt'
_vocab = [x.strip() for x in open(_vocab_path, 'r').readlines()]
_vocab_size = len(_vocab)
source_encoder_hparams = {
"encoder_minor_type": "BidirectionalRNNEncoder",
"encoder_minor_hparams": {
"rnn_cell_fw": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
},
"rnn_cell_share_config": True
},
"encoder_major_type": "UnidirectionalRNNEncoder",
"encoder_major_hparams": {
"rnn_cell": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size*2,
},
}
}
}
target_encoder_hparams = {
"rnn_cell_fw": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
},
"rnn_cell_share_config": True
}
target_kwencoder_hparams = {
"rnn_cell_fw": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
},
"rnn_cell_share_config": True
}
opt_hparams = {
"optimizer": {
"type": "AdamOptimizer",
"kwargs": {
"learning_rate": 0.001,
}
}
}
================================================
FILE: config/neural.py
================================================
_hidden_size = 200
_code_len = 800
_save_path = 'save/neural/model_1'
_neural_save_path = 'save/neural/keyword_1'
_max_epoch = 10
neural_opt_hparams = {
"optimizer": {
"type": "AdamOptimizer",
"kwargs": {
"learning_rate": 0.005,
}
},
"learning_rate_decay": {
"type": "inverse_time_decay",
"kwargs": {
"decay_steps": 1600,
"decay_rate": 0.8
},
"start_decay_step": 0,
"end_decay_step": 16000,
},
}
source_encoder_hparams = {
"encoder_minor_type": "BidirectionalRNNEncoder",
"encoder_minor_hparams": {
"rnn_cell_fw": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
},
"rnn_cell_share_config": True
},
"encoder_major_type": "UnidirectionalRNNEncoder",
"encoder_major_hparams": {
"rnn_cell": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size*2,
},
}
}
}
target_encoder_hparams = {
"rnn_cell_fw": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
},
"rnn_cell_share_config": True
}
target_kwencoder_hparams = {
"rnn_cell_fw": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
},
"rnn_cell_share_config": True
}
context_encoder_hparams = {
"rnn_cell": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
}
}
opt_hparams = {
"optimizer": {
"type": "AdamOptimizer",
"kwargs": {
"learning_rate": 0.001,
}
}
}
================================================
FILE: config/retrieval.py
================================================
_hidden_size = 200
_code_len = 200
_save_path = 'save/retrieval/model_1'
_max_epoch = 10
source_encoder_hparams = {
"encoder_minor_type": "UnidirectionalRNNEncoder",
"encoder_minor_hparams": {
"rnn_cell": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
},
},
"encoder_major_type": "UnidirectionalRNNEncoder",
"encoder_major_hparams": {
"rnn_cell": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
}
}
}
target_encoder_hparams = {
"rnn_cell": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
}
}
opt_hparams = {
"optimizer": {
"type": "AdamOptimizer",
"kwargs": {
"learning_rate": 0.001,
}
},
}
================================================
FILE: config/retrieval_stgy.py
================================================
_hidden_size = 200
_code_len = 200
_save_path = 'save/retrieval/model_1'
_max_epoch = 10
source_encoder_hparams = {
"encoder_minor_type": "UnidirectionalRNNEncoder",
"encoder_minor_hparams": {
"rnn_cell": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
},
},
"encoder_major_type": "UnidirectionalRNNEncoder",
"encoder_major_hparams": {
"rnn_cell": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
}
}
}
target_encoder_hparams = {
"rnn_cell": {
"type": "GRUCell",
"kwargs": {
"num_units": _hidden_size,
},
}
}
opt_hparams = {
"optimizer": {
"type": "AdamOptimizer",
"kwargs": {
"learning_rate": 0.001,
}
},
}
================================================
FILE: model/kernel.py
================================================
import texar as tx
import tensorflow as tf
import numpy as np
from preprocess.data_utils import kw_tokenize
class Predictor():
def __init__(self, config_model, config_data, mode=None):
self.config = config_model
self.data_config = config_data
self.gpu_config = tf.ConfigProto()
self.gpu_config.gpu_options.allow_growth = True
self.build_model()
def build_model(self):
self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train'])
self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid'])
self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test'])
self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data)
self.vocab = self.train_data.vocab(0)
self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs)
self.kw_embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs)
self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams)
self.target_encoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams)
self.target_kwencoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_kwencoder_hparams)
self.linear_transform = tx.modules.MLPTransformConnector(self.config._code_len // 2)
self.linear_matcher = tx.modules.MLPTransformConnector(1)
self.linear_kernel = tx.modules.MLPTransformConnector(1)
self.kw_list = self.vocab.map_tokens_to_ids(tf.convert_to_tensor(self.data_config._keywords_candi))
self.kw_vocab = tx.data.Vocab(self.data_config._keywords_path)
self.keywords_embed = tf.nn.l2_normalize(self.kw_embedder(self.kw_list), axis=1)
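# forward_kernel scores one candidate keyword against the keywords observed in
# the dialogue context: cosine similarities (both sides are l2-normalized) are
# pooled through the Gaussian kernels defined in config/kernel.py, positions
# with token id <= 3 (special tokens) are masked out, and the pooled features
# are mapped to a scalar matching score by linear_kernel.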
def forward_kernel(self, kw_embed, context_ids):
kernel_sigma = self.config._kernel_sigma
mu = tf.convert_to_tensor(self.config._kernel_mu)
mask = tf.cast(context_ids > 3, dtype=tf.float32)
context_embed = self.kw_embedder(context_ids)
context_embed = tf.nn.l2_normalize(context_embed, axis=2)
similarity_matrix = tf.reduce_sum(kw_embed * context_embed, axis=2)
similarity_matrix = tf.tile(tf.expand_dims(similarity_matrix, 2), [1, 1, len(self.config._kernel_mu)])
matching_feature = tf.exp(-(similarity_matrix - mu) ** 2 / (kernel_sigma ** 2))
matching_feature = matching_feature * tf.tile(tf.expand_dims(mask, 2), [1, 1, len(self.config._kernel_mu)])
matching_feature = tf.reduce_sum(matching_feature, axis=1)
matching_score = self.linear_kernel(matching_feature)
matching_score = tf.squeeze(matching_score, 1)
return matching_score
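# predict_keywords applies forward_kernel to every candidate keyword
# (tf.map_fn over self.keywords_embed), softmaxes the scores over the keyword
# vocabulary, and trains with a multi-label negative log-likelihood against
# the target keywords; acc is the frequency with which the top-1 predicted
# keyword is one of the true keywords.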
def predict_keywords(self, batch):
keywords_ids = self.kw_vocab.map_tokens_to_ids(batch['keywords_text'])
matching_score = tf.map_fn(lambda kw_embed: self.forward_kernel(kw_embed, batch['context_text_ids']),
self.keywords_embed, dtype=tf.float32, parallel_iterations=True)
matching_score = tf.transpose(matching_score)
matching_score = tf.nn.softmax(matching_score)
kw_labels = tf.map_fn(lambda x: tf.sparse_to_dense(x, [self.kw_vocab.size], 1., 0., False),
keywords_ids, dtype=tf.float32, parallel_iterations=True)[:, 4:]
loss = tf.reduce_sum(-tf.log(matching_score) * kw_labels) / tf.reduce_sum(kw_labels)
kw_ans = tf.arg_max(matching_score, -1)
acc_label = tf.map_fn(lambda x: tf.gather(x[0], x[1]), (kw_labels, kw_ans), dtype=tf.float32)
acc = tf.reduce_mean(acc_label)
kws = tf.nn.top_k(matching_score, k=5)[1]
kws = tf.reshape(kws,[-1])
kws = tf.map_fn(lambda x: self.kw_list[x], kws, dtype=tf.int64)
kws = tf.reshape(kws,[-1, 5])
return loss, acc, kws
def train_keywords(self):
batch = self.iterator.get_next()
loss, acc, _ = self.predict_keywords(batch)
op_step = tf.Variable(0, name='op_step')
train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.kernel_opt_hparams)
max_val_acc, stopping_flag = 0, 0
self.saver = tf.train.Saver()
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())
for epoch_id in range(self.config._max_epoch):
self.iterator.switch_to_train_data(sess)
cur_step = 0
cnt_acc = []
while True:
try:
cur_step += 1
feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
loss_, acc_ = sess.run([train_op, acc], feed_dict=feed)
cnt_acc.append(acc_)
if cur_step % 100 == 0:
print('batch {}, loss={}, acc1={}'.format(cur_step, loss_, np.mean(cnt_acc[-100:])))
except tf.errors.OutOfRangeError:
break
self.iterator.switch_to_val_data(sess)
cnt_acc = []
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
acc_ = sess.run(acc, feed_dict=feed)
cnt_acc.append(acc_)
except tf.errors.OutOfRangeError:
mean_acc = np.mean(cnt_acc)
if mean_acc > max_val_acc:
max_val_acc = mean_acc
self.saver.save(sess, self.config._kernel_save_path)
else:
stopping_flag += 1
print('epoch_id {}, valid acc1={}'.format(epoch_id+1, mean_acc))
break
if stopping_flag >= self.config._early_stopping:
break
def test_keywords(self):
batch = self.iterator.get_next()
loss, acc, kws = self.predict_keywords(batch)
saver = tf.train.Saver()
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())
saver.restore(sess, self.config._kernel_save_path)
self.iterator.switch_to_test_data(sess)
cnt_acc, cnt_rec1, cnt_rec3, cnt_rec5 = [], [], [], []
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
acc_, kw_ans, kw_labels = sess.run([acc, kws, batch['keywords_text_ids']], feed_dict=feed)
cnt_acc.append(acc_)
rec = [0,0,0,0,0]
sum_kws = 0
for i in range(len(kw_ans)):
sum_kws += sum(kw_labels[i] > 3)
for j in range(5):
if kw_ans[i][j] in kw_labels[i]:
for k in range(j, 5):
rec[k] += 1
cnt_rec1.append(rec[0]/sum_kws)
cnt_rec3.append(rec[2]/sum_kws)
cnt_rec5.append(rec[4]/sum_kws)
except tf.errors.OutOfRangeError:
print('test_kw acc@1={:.4f}, rec@1={:.4f}, rec@3={:.4f}, rec@5={:.4f}'.format(
np.mean(cnt_acc), np.mean(cnt_rec1), np.mean(cnt_rec3), np.mean(cnt_rec5)))
break
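# forward is the keyword-augmented response matcher: the top-3 predicted
# keywords are embedded, summed, and projected by linear_transform to
# _code_len // 2, then concatenated with the hierarchical encoding of the
# history so the combined code matches _code_len. Each of the 20 candidate
# responses is encoded by two BiGRUs (target_encoder / target_kwencoder);
# candidate and history codes are matched elementwise and scored by
# linear_matcher, trained with sigmoid cross-entropy against the one-hot label.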
def forward(self, batch):
matching_score = tf.map_fn(lambda kw_embed: self.forward_kernel(kw_embed, batch['context_text_ids']),
self.keywords_embed, dtype=tf.float32, parallel_iterations=True)
matching_score = tf.transpose(matching_score)
kw_weight, predict_kw = tf.nn.top_k(matching_score, k=3)
predict_kw = tf.reshape(predict_kw,[-1])
predict_kw = tf.map_fn(lambda x: self.kw_list[x], predict_kw, dtype=tf.int64)
predict_kw = tf.reshape(predict_kw,[-1,3])
embed_code = self.embedder(predict_kw)
embed_code = tf.reduce_sum(embed_code, axis=1)
embed_code = self.linear_transform(embed_code)
source_embed = self.embedder(batch['source_text_ids'])
target_embed = self.embedder(batch['target_text_ids']) # bs * 20 * 32 * 200
target_embed = tf.reshape(target_embed,[-1, self.data_config._max_seq_len+2, self.embedder.dim]) # (bs * 20) * 32 * 200
target_length = tf.reshape(batch['target_length'],[-1]) # (bs * 20)
source_code = self.source_encoder(
source_embed,
sequence_length_minor=batch['source_length'],
sequence_length_major=batch['source_utterance_cnt'])[1]
target_code = self.target_encoder(
target_embed,
sequence_length=target_length)[1]
target_kwcode = self.target_kwencoder(
target_embed,
sequence_length=target_length)[1]
target_code = tf.concat([target_code[0], target_code[1], target_kwcode[0], target_kwcode[1]], -1)
target_code = tf.reshape(target_code, [-1,20,self.config._code_len])
source_code = tf.concat([source_code,embed_code], -1)
source_code = tf.expand_dims(source_code, 1)
source_code = tf.tile(source_code, [1,20,1])
feature_code = target_code * source_code
feature_code = tf.reshape(feature_code,[-1,self.config._code_len])
logits = self.linear_matcher(feature_code)
logits = tf.reshape(logits,[-1,20])
labels = tf.one_hot(batch['label'], 20)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
ans = tf.arg_max(logits, -1)
acc = tx.evals.accuracy(batch['label'], ans)
rank = tf.nn.top_k(logits, k=20)[1]
return loss, acc, rank
def train(self):
batch = self.iterator.get_next()
loss_t, acc_t, _ = self.predict_keywords(batch)
kw_saver = tf.train.Saver()
loss, acc, _ = self.forward(batch)
retrieval_step = tf.Variable(0, name='retrieval_step')
train_op = tx.core.get_train_op(loss, global_step=retrieval_step, hparams=self.config.opt_hparams)
max_val_acc, stopping_flag = 0, 0
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.tables_initializer())
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
kw_saver.restore(sess, self.config._kernel_save_path)
saver = tf.train.Saver()
for epoch_id in range(self.config._max_epoch):
self.iterator.switch_to_train_data(sess)
cur_step = 0
cnt_acc, cnt_kwacc = [],[]
while True:
try:
cur_step += 1
feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
loss, acc_, acc_kw = sess.run([train_op, acc, acc_t], feed_dict=feed)
cnt_acc.append(acc_)
cnt_kwacc.append(acc_kw)
if cur_step % 200 == 0:
print('batch {}, loss={}, acc1={}, kw_acc1={}'.format(cur_step, loss,
np.mean(cnt_acc[-200:]) ,np.mean(cnt_kwacc[-200:])))
except tf.errors.OutOfRangeError:
break
self.iterator.switch_to_val_data(sess)
cnt_acc, cnt_kwacc = [],[]
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
acc_, acc_kw = sess.run([acc, acc_t], feed_dict=feed)
cnt_acc.append(acc_)
cnt_kwacc.append(acc_kw)
except tf.errors.OutOfRangeError:
mean_acc = np.mean(cnt_acc)
print('valid acc1={}, kw_acc1={}'.format(mean_acc, np.mean(cnt_kwacc)))
if mean_acc > max_val_acc:
max_val_acc = mean_acc
saver.save(sess, self.config._save_path)
else:
stopping_flag += 1
break
if stopping_flag >= self.config._early_stopping:
break
def test(self):
batch = self.iterator.get_next()
loss, acc, rank = self.forward(batch)
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.tables_initializer())
self.saver = tf.train.Saver()
self.saver.restore(sess, self.config._save_path)
self.iterator.switch_to_test_data(sess)
rank_cnt = []
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
ranks, labels = sess.run([rank, batch['label']], feed_dict=feed)
for i in range(len(ranks)):
rank_cnt.append(np.where(ranks[i]==labels[i])[0][0])
except tf.errors.OutOfRangeError:
rec = [0,0,0,0,0]
MRR = 0
for rank in rank_cnt:
for i in range(5):
rec[i] += (rank <= i)
MRR += 1 / (rank+1)
print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format(
rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt)))
break
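# retrieve_init precomputes, in one pass over corpus.txt, the target-side
# encodings of every corpus utterance (self.corpus_code), then builds the
# inference graph: one head ranks all candidate keywords for a given context
# (candi_output), the other ranks corpus utterances for a given history and
# chosen keyword (ans_output). The string placeholders use dtype=object,
# which TF1 treats as tf.string.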
def retrieve_init(self, sess):
data_batch = self.iterator.get_next()
loss, acc, _ = self.forward(data_batch)
self.corpus = self.data_config._corpus
self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams)
corpus_iterator = tx.data.DataIterator(self.corpus_data)
batch = corpus_iterator.get_next()
corpus_embed = self.embedder(batch['corpus_text_ids'])
utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1]
utter_kwcode = self.target_kwencoder(corpus_embed, sequence_length=batch['corpus_length'])[1]
utter_code = tf.concat([utter_code[0], utter_code[1], utter_kwcode[0], utter_kwcode[1]], -1)
self.corpus_code = np.zeros([0, self.config._code_len])
corpus_iterator.switch_to_dataset(sess)
sess.run(tf.tables_initializer())
saver = tf.train.Saver()
saver.restore(sess, self.config._save_path)
feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
while True:
try:
utter_code_ = sess.run(utter_code, feed_dict=feed)
self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0)
except tf.errors.OutOfRangeError:
break
self.kw_embedding = sess.run(self.keywords_embed)
# predict keyword
self.context_input = tf.placeholder(dtype=object)
context_ids = tf.expand_dims(self.vocab.map_tokens_to_ids(self.context_input), 0)
matching_score = tf.map_fn(lambda kw_embed: self.forward_kernel(kw_embed, context_ids),
self.keywords_embed, dtype=tf.float32, parallel_iterations=True)
self.candi_output = tf.nn.top_k(tf.squeeze(matching_score, 1), self.data_config._keywords_num)[1]
# retrieve
self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))
self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))
self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2))
self.kw_input = tf.placeholder(dtype=tf.int32)
history_ids = self.vocab.map_tokens_to_ids(self.history_input)
history_embed = self.embedder(history_ids)
history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0),
sequence_length_minor=self.minor_length_input,
sequence_length_major=self.major_length_input)[1]
self.next_kw_ids = self.kw_list[self.kw_input]
embed_code = tf.expand_dims(self.embedder(self.next_kw_ids), 0)
embed_code = self.linear_transform(embed_code)
history_code = tf.concat([history_code, embed_code], 1)
select_corpus = tf.cast(self.corpus_code, dtype=tf.float32)
feature_code = self.linear_matcher(select_corpus * history_code)
self.ans_output = tf.nn.top_k(tf.squeeze(feature_code,1), k=self.data_config._retrieval_candidates)[1]
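# retrieve implements the target-guided strategy: among the ranked candidate
# keywords, pick the first one whose embedding similarity to the target beats
# the best similarity reached so far (self.score), so each turn moves
# monotonically closer to the target; otherwise the previous keyword is kept.
# The reply is then the highest-ranked unused utterance containing a keyword
# that also improves on self.score, falling back to the top candidate.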
def retrieve(self, history_all, sess):
history, seq_len, turns, context, context_len = history_all
kw_candi = sess.run(self.candi_output, feed_dict={self.context_input: context[:context_len]})
for kw in kw_candi:
tmp_score = sum(self.kw_embedding[kw] * self.kw_embedding[self.data_config._keywords_dict[self.target]])
if tmp_score > self.score:
self.score = tmp_score
self.next_kw = self.data_config._keywords_candi[kw]
break
ans = sess.run(self.ans_output, feed_dict={self.history_input: history,
self.minor_length_input: [seq_len], self.major_length_input: [turns],
self.kw_input: self.data_config._keywords_dict[self.next_kw]})
flag = 0
reply = self.corpus[ans[0]]
for i in ans:
if i in self.reply_list: # avoid repeat
continue
for wd in kw_tokenize(self.corpus[i]):
if wd in self.data_config._keywords_candi:
tmp_score = sum(self.kw_embedding[self.data_config._keywords_dict[wd]] *
self.kw_embedding[self.data_config._keywords_dict[self.target]])
if tmp_score > self.score:
reply = self.corpus[i]
self.reply_list.append(i) # remember used replies so the repeat check above works
self.score = tmp_score
self.next_kw = wd
flag = 1
break
if flag == 0:
continue
break
return reply
================================================
FILE: model/matrix.py
================================================
import texar as tx
import tensorflow as tf
import numpy as np
import pickle
class Predictor():
def __init__(self, config_model, config_data, mode=None):
self.config = config_model
self.data_config = config_data
self.gpu_config = tf.ConfigProto()
self.gpu_config.gpu_options.allow_growth = True
self.build_model(mode)
def build_model(self, mode):
self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train'])
self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid'])
self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test'])
self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data)
self.vocab = self.train_data.vocab(0)
self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams)
self.target_encoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams)
self.target_kwencoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_kwencoder_hparams)
self.linear_transform = tx.modules.MLPTransformConnector(self.config._code_len // 2)
self.linear_matcher = tx.modules.MLPTransformConnector(1)
self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs)
self.kw_list = self.vocab.map_tokens_to_ids(tf.convert_to_tensor(self.data_config._keywords_candi))
self.kw_vocab = tx.data.Vocab(self.data_config._keywords_path)
if mode == 'train_kw':
self.pmi_matrix = np.zeros([self.config._vocab_size+4, self.data_config._keywords_num])
else:
with open(self.config._matrix_save_path, 'rb') as f:
matrix = pickle.load(f)
self.pmi_matrix = tf.convert_to_tensor(matrix,dtype=tf.float32)
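# forward_matrix looks up, for each context word id, its row of smoothed
# co-occurrence probabilities over the keyword vocabulary and sums the log
# probabilities, i.e. it scores keywords by a naive-Bayes-style product of
# per-word conditional probabilities estimated in train_keywords.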
def forward_matrix(self, context_ids):
matching_score = tf.gather(self.pmi_matrix, context_ids)
return tf.reduce_sum(tf.log(matching_score), axis=0)
def predict_keywords(self, batch):
keywords_ids = self.kw_vocab.map_tokens_to_ids(batch['keywords_text'])
matching_score = tf.map_fn(lambda x: self.forward_matrix(x), batch['context_text_ids'],
dtype=tf.float32, parallel_iterations=True)
kw_labels = tf.map_fn(lambda x: tf.sparse_to_dense(x, [self.kw_vocab.size], 1., 0., False),
keywords_ids, dtype=tf.float32, parallel_iterations=True)[:, 4:]
kw_ans = tf.arg_max(matching_score, -1)
acc_label = tf.map_fn(lambda x: tf.gather(x[0], x[1]), (kw_labels, kw_ans), dtype=tf.float32)
acc = tf.reduce_mean(acc_label)
kws = tf.nn.top_k(matching_score, k=5)[1]
kws = tf.reshape(kws,[-1])
kws = tf.map_fn(lambda x: self.kw_list[x], kws, dtype=tf.int64)
kws = tf.reshape(kws,[-1, 5])
return acc, kws
def train_keywords(self):
batch = self.iterator.get_next()
acc, _ = self.predict_keywords(batch)
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())
self.iterator.switch_to_train_data(sess)
batchid = 0
while True:
try:
batchid += 1
feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
source_keywords, target_keywords = sess.run([batch['context_text_ids'],
batch['keywords_text_ids']], feed_dict=feed)
for i in range(len(source_keywords)):
for skw_id in source_keywords[i]:
if skw_id == 0:
break
for tkw_id in target_keywords[i]:
if skw_id >= 3 and tkw_id >= 3:
tkw = self.config._vocab[tkw_id-4]
if tkw in self.data_config._keywords_candi:
tkw_id = self.data_config._keywords_dict[tkw]
self.pmi_matrix[skw_id][tkw_id] += 1
except tf.errors.OutOfRangeError:
break
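# Add-0.5 smoothing, then column-normalize the raw counts so each entry
# approximates P(source word | target keyword); forward_matrix sums the log
# of these probabilities over the context.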
self.pmi_matrix += 0.5
self.pmi_matrix = self.pmi_matrix / (np.sum(self.pmi_matrix, axis=0) + 1)
with open(self.config._matrix_save_path,'wb') as f:
pickle.dump(self.pmi_matrix, f)
def test_keywords(self):
batch = self.iterator.get_next()
acc, kws = self.predict_keywords(batch)
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())
self.iterator.switch_to_test_data(sess)
cnt_acc, cnt_rec1, cnt_rec3, cnt_rec5 = [], [], [], []
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
acc_, kw_ans, kw_labels = sess.run([acc, kws, batch['keywords_text_ids']], feed_dict=feed)
cnt_acc.append(acc_)
rec = [0,0,0,0,0]
sum_kws = 0
for i in range(len(kw_ans)):
sum_kws += sum(kw_labels[i] > 3)
for j in range(5):
if kw_ans[i][j] in kw_labels[i]:
for k in range(j, 5):
rec[k] += 1
cnt_rec1.append(rec[0]/sum_kws)
cnt_rec3.append(rec[2]/sum_kws)
cnt_rec5.append(rec[4]/sum_kws)
except tf.errors.OutOfRangeError:
print('test_kw acc@1={:.4f}, rec@1={:.4f}, rec@3={:.4f}, rec@5={:.4f}'.format(
np.mean(cnt_acc), np.mean(cnt_rec1), np.mean(cnt_rec3), np.mean(cnt_rec5)))
break
def forward(self, batch):
matching_score = tf.map_fn(lambda x: self.forward_matrix(x), batch['context_text_ids'],
dtype=tf.float32, parallel_iterations=True)
kw_weight, predict_kw = tf.nn.top_k(matching_score, k=3)
predict_kw = tf.reshape(predict_kw, [-1])
predict_kw = tf.map_fn(lambda x: self.kw_list[x], predict_kw, dtype=tf.int64)
predict_kw = tf.reshape(predict_kw, [-1, 3])
embed_code = self.embedder(predict_kw)
embed_code = tf.reduce_sum(embed_code, axis=1)
embed_code = self.linear_transform(embed_code)
source_embed = self.embedder(batch['source_text_ids'])
target_embed = self.embedder(batch['target_text_ids'])
target_embed = tf.reshape(target_embed, [-1, self.data_config._max_seq_len + 2, self.embedder.dim])
target_length = tf.reshape(batch['target_length'], [-1])
source_code = self.source_encoder(
source_embed,
sequence_length_minor=batch['source_length'],
sequence_length_major=batch['source_utterance_cnt'])[1]
target_code = self.target_encoder(
target_embed,
sequence_length=target_length)[1]
target_kwcode = self.target_kwencoder(
target_embed,
sequence_length=target_length)[1]
target_code = tf.concat([target_code[0], target_code[1], target_kwcode[0], target_kwcode[1]], -1)
target_code = tf.reshape(target_code, [-1, 20, self.config._code_len])
source_code = tf.concat([source_code, embed_code], -1)
source_code = tf.expand_dims(source_code, 1)
source_code = tf.tile(source_code, [1, 20, 1])
feature_code = target_code * source_code
feature_code = tf.reshape(feature_code, [-1, self.config._code_len])
logits = self.linear_matcher(feature_code)
logits = tf.reshape(logits, [-1, 20])
labels = tf.one_hot(batch['label'], 20)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
ans = tf.arg_max(logits, -1)
acc = tx.evals.accuracy(batch['label'], ans)
rank = tf.nn.top_k(logits, k=20)[1]
return loss, acc, rank
def train(self):
batch = self.iterator.get_next()
loss, acc, _ = self.forward(batch)
op_step = tf.Variable(0, name='retrieval_step')
train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.opt_hparams)
max_val_acc = 0.
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.tables_initializer())
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
saver = tf.train.Saver()
for epoch_id in range(self.config._max_epoch):
self.iterator.switch_to_train_data(sess)
cur_step = 0
cnt_acc = []
while True:
try:
cur_step += 1
feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
loss, acc_ = sess.run([train_op, acc], feed_dict=feed)
cnt_acc.append(acc_)
if cur_step % 200 == 0:
print('batch {}, loss={}, acc1={}'.format(cur_step, loss, np.mean(cnt_acc[-200:])))
except tf.errors.OutOfRangeError:
break
self.iterator.switch_to_val_data(sess)
cnt_acc= []
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
acc_ = sess.run(acc, feed_dict=feed)
cnt_acc.append(acc_)
except tf.errors.OutOfRangeError:
mean_acc = np.mean(cnt_acc)
print('valid acc1={}'.format(mean_acc))
if mean_acc > max_val_acc:
max_val_acc = mean_acc
saver.save(sess, self.config._save_path)
break
def test(self):
batch = self.iterator.get_next()
loss, acc, rank = self.forward(batch)
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.tables_initializer())
self.saver = tf.train.Saver()
self.saver.restore(sess, self.config._save_path)
self.iterator.switch_to_test_data(sess)
rank_cnt = []
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
ranks, labels = sess.run([rank, batch['label']], feed_dict=feed)
for i in range(len(ranks)):
rank_cnt.append(np.where(ranks[i]==labels[i])[0][0])
except tf.errors.OutOfRangeError:
rec = [0,0,0,0,0]
MRR = 0
for rank in rank_cnt:
for i in range(5):
rec[i] += (rank <= i)
MRR += 1 / (rank+1)
print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format(
rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt)))
break
def retrieve_init(self, sess):
data_batch = self.iterator.get_next()
loss, acc, _ = self.forward(data_batch)
self.corpus = self.data_config._corpus
self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams)
corpus_iterator = tx.data.DataIterator(self.corpus_data)
batch = corpus_iterator.get_next()
corpus_embed = self.embedder(batch['corpus_text_ids'])
utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1]
utter_kwcode = self.target_kwencoder(corpus_embed, sequence_length=batch['corpus_length'])[1]
utter_code = tf.concat([utter_code[0], utter_code[1], utter_kwcode[0], utter_kwcode[1]], -1)
self.corpus_code = np.zeros([0, self.config._code_len])
corpus_iterator.switch_to_dataset(sess)
sess.run(tf.tables_initializer())
saver = tf.train.Saver()
saver.restore(sess, self.config._save_path)
feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
while True:
try:
utter_code_ = sess.run(utter_code, feed_dict=feed)
self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0)
except tf.errors.OutOfRangeError:
break
self.keywords_embed = tf.nn.l2_normalize(self.embedder(self.kw_list), axis=1)
self.kw_embedding = sess.run(self.keywords_embed)
# predict keyword
self.context_input = tf.placeholder(dtype=object)
context_ids = self.vocab.map_tokens_to_ids(self.context_input)
matching_score = self.forward_matrix(context_ids)
self.candi_output = tf.nn.top_k(matching_score, self.data_config._keywords_num)[1]
# retrieve
self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))
self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))
self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2))
self.kw_input = tf.placeholder(dtype=tf.int32)
history_ids = self.vocab.map_tokens_to_ids(self.history_input)
history_embed = self.embedder(history_ids)
history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0),
sequence_length_minor=self.minor_length_input,
sequence_length_major=self.major_length_input)[1]
self.next_kw_ids = self.kw_list[self.kw_input]
embed_code = tf.expand_dims(self.embedder(self.next_kw_ids), 0)
embed_code = self.linear_transform(embed_code)
history_code = tf.concat([history_code, embed_code], 1)
select_corpus = tf.cast(self.corpus_code, dtype=tf.float32)
feature_code = self.linear_matcher(select_corpus * history_code)
self.ans_output = tf.nn.top_k(tf.squeeze(feature_code,1), k=self.data_config._retrieval_candidates)[1]
def retrieve(self, history_all, sess):
history, seq_len, turns, context, context_len = history_all
kw_candi = sess.run(self.candi_output, feed_dict={self.context_input: context[:context_len]})
for kw in kw_candi:
tmp_score = sum(self.kw_embedding[kw] * self.kw_embedding[self.data_config._keywords_dict[self.target]])
if tmp_score > self.score:
self.score = tmp_score
self.next_kw = self.data_config._keywords_candi[kw]
break
ans = sess.run(self.ans_output, feed_dict={self.history_input: history,
self.minor_length_input: [seq_len], self.major_length_input: [turns],
self.kw_input: self.data_config._keywords_dict[self.next_kw]})
reply = self.corpus[ans[0]] # fall back to the top candidate if all have been used
for i in range(self.data_config._max_turns + 1):
if ans[i] not in self.reply_list:
self.reply_list.append(ans[i])
reply = self.corpus[ans[i]]
break
return reply
================================================
FILE: model/neural.py
================================================
import texar as tx
import tensorflow as tf
import numpy as np
from preprocess.data_utils import kw_tokenize
class Predictor():
def __init__(self, config_model, config_data, mode=None):
self.config = config_model
self.data_config = config_data
self.gpu_config = tf.ConfigProto()
self.gpu_config.gpu_options.allow_growth = True
self.build_model()
def build_model(self):
self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train'])
self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid'])
self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test'])
self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data)
self.vocab = self.train_data.vocab(0)
self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams)
self.target_encoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams)
self.target_kwencoder = tx.modules.BidirectionalRNNEncoder(hparams=self.config.target_kwencoder_hparams)
self.linear_transform = tx.modules.MLPTransformConnector(self.config._code_len // 2)
self.linear_matcher = tx.modules.MLPTransformConnector(1)
self.context_encoder = tx.modules.UnidirectionalRNNEncoder(hparams=self.config.context_encoder_hparams)
self.predict_layer = tx.modules.MLPTransformConnector(self.data_config._keywords_num)
self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs)
self.kw_list = self.vocab.map_tokens_to_ids(tf.convert_to_tensor(self.data_config._keywords_candi))
self.kw_vocab = tx.data.Vocab(self.data_config._keywords_path)
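# The neural predictor replaces kernel pooling with a GRU: context keywords
# are embedded, encoded by context_encoder, and a single dense layer
# (predict_layer) emits one logit per keyword in the keyword vocabulary.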
def forward_neural(self, context_ids, context_length):
context_embed = self.embedder(context_ids)
context_code = self.context_encoder(context_embed, sequence_length=context_length)[1]
keyword_score = self.predict_layer(context_code)
return keyword_score
def predict_keywords(self, batch):
matching_score = self.forward_neural(batch['context_text_ids'], batch['context_length'])
keywords_ids = self.kw_vocab.map_tokens_to_ids(batch['keywords_text'])
kw_labels = tf.map_fn(lambda x: tf.sparse_to_dense(x, [self.kw_vocab.size], 1., 0., False),
keywords_ids, dtype=tf.float32, parallel_iterations=True)[:, 4:]
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=kw_labels, logits=matching_score)
loss = tf.reduce_mean(loss)
kw_ans = tf.arg_max(matching_score, -1)
acc_label = tf.map_fn(lambda x: tf.gather(x[0], x[1]), (kw_labels, kw_ans), dtype=tf.float32)
acc = tf.reduce_mean(acc_label)
kws = tf.nn.top_k(matching_score, k=5)[1]
kws = tf.reshape(kws,[-1])
kws = tf.map_fn(lambda x: self.kw_list[x], kws, dtype=tf.int64)
kws = tf.reshape(kws,[-1, 5])
return loss, acc, kws
def train_keywords(self):
batch = self.iterator.get_next()
loss, acc, _ = self.predict_keywords(batch)
op_step = tf.Variable(0, name='op_step')
train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.neural_opt_hparams)
max_val_acc = 0.
self.saver = tf.train.Saver()
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())
for epoch_id in range(self.config._max_epoch):
self.iterator.switch_to_train_data(sess)
cur_step = 0
cnt_acc = []
while True:
try:
cur_step += 1
feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
loss_, acc_ = sess.run([train_op, acc], feed_dict=feed)
cnt_acc.append(acc_)
if cur_step % 200 == 0:
print('batch {}, loss={}, acc1={}'.format(cur_step, loss_, np.mean(cnt_acc[-200:])))
except tf.errors.OutOfRangeError:
break
self.iterator.switch_to_val_data(sess)
cnt_acc = []
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
acc_ = sess.run(acc, feed_dict=feed)
cnt_acc.append(acc_)
except tf.errors.OutOfRangeError:
mean_acc = np.mean(cnt_acc)
if mean_acc > max_val_acc:
max_val_acc = mean_acc
self.saver.save(sess, self.config._neural_save_path)
print('epoch_id {}, valid acc1={}'.format(epoch_id+1, mean_acc))
break
def test_keywords(self):
batch = self.iterator.get_next()
loss, acc, kws = self.predict_keywords(batch)
saver = tf.train.Saver()
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())
saver.restore(sess, self.config._neural_save_path)
self.iterator.switch_to_test_data(sess)
cnt_acc, cnt_rec1, cnt_rec3, cnt_rec5 = [], [], [], []
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
acc_, kw_ans, kw_labels = sess.run([acc, kws, batch['keywords_text_ids']], feed_dict=feed)
cnt_acc.append(acc_)
rec = [0,0,0,0,0]
sum_kws = 0
for i in range(len(kw_ans)):
sum_kws += sum(kw_labels[i] > 3)
for j in range(5):
if kw_ans[i][j] in kw_labels[i]:
for k in range(j, 5):
rec[k] += 1
cnt_rec1.append(rec[0]/sum_kws)
cnt_rec3.append(rec[2]/sum_kws)
cnt_rec5.append(rec[4]/sum_kws)
except tf.errors.OutOfRangeError:
print('test_kw acc@1={:.4f}, rec@1={:.4f}, rec@3={:.4f}, rec@5={:.4f}'.format(
np.mean(cnt_acc), np.mean(cnt_rec1), np.mean(cnt_rec3), np.mean(cnt_rec5)))
break
def forward(self, batch):
matching_score = self.forward_neural(batch['context_text_ids'], batch['context_length'])
kw_weight, predict_kw = tf.nn.top_k(matching_score, k=3)
predict_kw = tf.reshape(predict_kw, [-1])
predict_kw = tf.map_fn(lambda x: self.kw_list[x], predict_kw, dtype=tf.int64)
predict_kw = tf.reshape(predict_kw, [-1, 3])
embed_code = self.embedder(predict_kw)
embed_code = tf.reduce_sum(embed_code, axis=1)
embed_code = self.linear_transform(embed_code)
source_embed = self.embedder(batch['source_text_ids'])
target_embed = self.embedder(batch['target_text_ids'])
target_embed = tf.reshape(target_embed, [-1, self.data_config._max_seq_len + 2, self.embedder.dim])
target_length = tf.reshape(batch['target_length'], [-1])
source_code = self.source_encoder(
source_embed,
sequence_length_minor=batch['source_length'],
sequence_length_major=batch['source_utterance_cnt'])[1] #
target_code = self.target_encoder(
target_embed,
sequence_length=target_length)[1]
target_kwcode = self.target_kwencoder(
target_embed,
sequence_length=target_length)[1]
target_code = tf.concat([target_code[0], target_code[1], target_kwcode[0], target_kwcode[1]], -1)
target_code = tf.reshape(target_code, [-1, 20, self.config._code_len])
source_code = tf.concat([source_code, embed_code], -1)
source_code = tf.expand_dims(source_code, 1)
source_code = tf.tile(source_code, [1, 20, 1])
feature_code = target_code * source_code
feature_code = tf.reshape(feature_code, [-1, self.config._code_len])
logits = self.linear_matcher(feature_code)
logits = tf.reshape(logits, [-1, 20])
labels = tf.one_hot(batch['label'], 20)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
ans = tf.arg_max(logits, -1)
acc = tx.evals.accuracy(batch['label'], ans)
rank = tf.nn.top_k(logits, k=20)[1]
return loss, acc, rank
def train(self):
batch = self.iterator.get_next()
kw_loss, kw_acc, _ = self.predict_keywords(batch)
kw_saver = tf.train.Saver()
loss, acc, _ = self.forward(batch)
op_step = tf.Variable(0, name='retrieval_step')
train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.opt_hparams)
max_val_acc = 0.
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.tables_initializer())
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
kw_saver.restore(sess, self.config._neural_save_path)
saver = tf.train.Saver()
for epoch_id in range(self.config._max_epoch):
self.iterator.switch_to_train_data(sess)
cur_step = 0
cnt_acc = []
while True:
try:
cur_step += 1
feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
loss, acc_ = sess.run([train_op, acc], feed_dict=feed)
cnt_acc.append(acc_)
if cur_step % 200 == 0:
print('batch {}, loss={}, acc1={}'.format(cur_step, loss, np.mean(cnt_acc[-200:])))
except tf.errors.OutOfRangeError:
break
self.iterator.switch_to_val_data(sess)
cnt_acc, cnt_kwacc = [], []
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
acc_, kw_acc_ = sess.run([acc, kw_acc], feed_dict=feed)
cnt_acc.append(acc_)
cnt_kwacc.append(kw_acc_)
except tf.errors.OutOfRangeError:
mean_acc = np.mean(cnt_acc)
print('valid acc1={}, kw_acc1={}'.format(mean_acc, np.mean(cnt_kwacc)))
if mean_acc > max_val_acc:
max_val_acc = mean_acc
saver.save(sess, self.config._save_path)
break
def test(self):
batch = self.iterator.get_next()
loss, acc, rank = self.forward(batch)
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.tables_initializer())
self.saver = tf.train.Saver()
self.saver.restore(sess, self.config._save_path)
self.iterator.switch_to_test_data(sess)
rank_cnt = []
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
ranks, labels = sess.run([rank, batch['label']], feed_dict=feed)
for i in range(len(ranks)):
rank_cnt.append(np.where(ranks[i]==labels[i])[0][0])
except tf.errors.OutOfRangeError:
rec = [0,0,0,0,0]
MRR = 0
for rank in rank_cnt:
for i in range(5):
rec[i] += (rank <= i)
MRR += 1 / (rank+1)
print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format(
rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt)))
break
def retrieve_init(self, sess):
data_batch = self.iterator.get_next()
loss, acc, _ = self.forward(data_batch)
self.corpus = self.data_config._corpus
self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams)
corpus_iterator = tx.data.DataIterator(self.corpus_data)
batch = corpus_iterator.get_next()
corpus_embed = self.embedder(batch['corpus_text_ids'])
utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1]
utter_kwcode = self.target_kwencoder(corpus_embed, sequence_length=batch['corpus_length'])[1]
utter_code = tf.concat([utter_code[0], utter_code[1], utter_kwcode[0], utter_kwcode[1]], -1)
self.corpus_code = np.zeros([0, self.config._code_len])
corpus_iterator.switch_to_dataset(sess)
sess.run(tf.tables_initializer())
saver = tf.train.Saver()
saver.restore(sess, self.config._save_path)
feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
while True:
try:
utter_code_ = sess.run(utter_code, feed_dict=feed)
self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0)
except tf.errors.OutOfRangeError:
break
self.keywords_embed = tf.nn.l2_normalize(self.embedder(self.kw_list), axis=1)
self.kw_embedding = sess.run(self.keywords_embed)
# predict keyword
self.context_input = tf.placeholder(dtype=object, shape=(20))
self.context_length_input = tf.placeholder(dtype=tf.int32, shape=(1))
context_ids = tf.expand_dims(self.vocab.map_tokens_to_ids(self.context_input), 0)
context_embed = self.embedder(context_ids)
context_code = self.context_encoder(context_embed, sequence_length=self.context_length_input)[1]
matching_score = self.predict_layer(context_code)
self.candi_output = tf.nn.top_k(tf.squeeze(matching_score, 0), self.data_config._keywords_num)[1]
# retrieve
self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))
self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))
self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2))
self.kw_input = tf.placeholder(dtype=tf.int32)
history_ids = self.vocab.map_tokens_to_ids(self.history_input)
history_embed = self.embedder(history_ids)
history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0),
sequence_length_minor=self.minor_length_input,
sequence_length_major=self.major_length_input)[1]
self.next_kw_ids = self.kw_list[self.kw_input]
embed_code = tf.expand_dims(self.embedder(self.next_kw_ids), 0)
embed_code = self.linear_transform(embed_code)
history_code = tf.concat([history_code, embed_code], 1)
select_corpus = tf.cast(self.corpus_code, dtype=tf.float32)
feature_code = self.linear_matcher(select_corpus * history_code)
self.ans_output = tf.nn.top_k(tf.squeeze(feature_code,1), k=self.data_config._retrieval_candidates)[1]
def retrieve(self, history_all, sess):
history, seq_len, turns, context, context_len = history_all
kw_candi = sess.run(self.candi_output, feed_dict={self.context_input: context,
self.context_length_input: [context_len]})
for kw in kw_candi:
tmp_score = sum(self.kw_embedding[kw] * self.kw_embedding[self.data_config._keywords_dict[self.target]])
if tmp_score > self.score:
self.score = tmp_score
self.next_kw = self.data_config._keywords_candi[kw]
break
ans = sess.run(self.ans_output, feed_dict={self.history_input: history,
self.minor_length_input: [seq_len], self.major_length_input: [turns],
self.kw_input: self.data_config._keywords_dict[self.next_kw]})
flag = 0
reply = self.corpus[ans[0]]
for i in ans:
if i in self.reply_list: # avoid repeat
continue
for wd in kw_tokenize(self.corpus[i]):
if wd in self.data_config._keywords_candi:
tmp_score = sum(self.kw_embedding[self.data_config._keywords_dict[wd]] *
self.kw_embedding[self.data_config._keywords_dict[self.target]])
if tmp_score > self.score:
reply = self.corpus[i]
self.reply_list.append(i) # remember used replies so the repeat check above works
self.score = tmp_score
self.next_kw = wd
flag = 1
break
if flag == 0:
continue
break
return reply
================================================
FILE: model/retrieval.py
================================================
import texar as tx
import tensorflow as tf
import numpy as np
class Predictor():
def __init__(self, config_model, config_data, mode=None):
self.config = config_model
self.data_config = config_data
self.build_model()
self.gpu_config = tf.ConfigProto()
self.gpu_config.gpu_options.allow_growth = True
def build_model(self):
self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train'])
self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid'])
self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test'])
self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data)
self.vocab = self.train_data.vocab(0)
self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs)
self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams)
self.target_encoder = tx.modules.UnidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams)
self.linear_matcher = tx.modules.MLPTransformConnector(1)
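# Plain retrieval baseline: no keyword prediction at all. The dialogue history
# is encoded hierarchically, each of the 20 candidate responses by a
# unidirectional GRU, and candidates are scored by an elementwise match
# through linear_matcher, as in the keyword-augmented models.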
def forward(self, batch):
source_embed = self.embedder(batch['source_text_ids'])
target_embed = self.embedder(batch['target_text_ids'])
target_embed = tf.reshape(target_embed, [-1, self.data_config._max_seq_len + 2, self.embedder.dim])
source_code = self.source_encoder(source_embed,
sequence_length_minor=batch['source_length'],
sequence_length_major=batch['source_utterance_cnt'])[1]
target_length = tf.reshape(batch['target_length'], [-1])
target_code = self.target_encoder(target_embed, sequence_length=target_length)[1]
target_code = tf.reshape(target_code, [-1, 20, self.config._code_len])
source_code = tf.expand_dims(source_code, 1)
source_code = tf.tile(source_code, [1, 20, 1])
feature_code = target_code * source_code
feature_code = tf.reshape(feature_code, [-1, self.config._code_len])
logits = self.linear_matcher(feature_code)
logits = tf.reshape(logits, [-1, 20])
labels = tf.one_hot(batch['label'], 20)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
ans = tf.arg_max(logits, -1)
acc = tx.evals.accuracy(batch['label'], ans)
rank = tf.nn.top_k(logits, k=20)[1]
return loss, acc, rank
def train(self):
batch = self.iterator.get_next()
loss, acc, _ = self.forward(batch)
op_step = tf.Variable(0, name='op_step')
train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.opt_hparams)
max_val_acc = 0.
self.saver = tf.train.Saver()
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())
for epoch_id in range(self.config._max_epoch):
self.iterator.switch_to_train_data(sess)
cur_step = 0
cnt_acc = []
while True:
try:
cur_step += 1
feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
loss, acc_ = sess.run([train_op, acc], feed_dict=feed)
cnt_acc.append(acc_)
if cur_step % 200 == 0:
print('batch {}, loss={}, acc1={}'.format(cur_step, loss, np.mean(cnt_acc[-200:])))
except tf.errors.OutOfRangeError:
break
self.iterator.switch_to_val_data(sess)
cnt_acc = []
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
acc_ = sess.run(acc, feed_dict=feed)
cnt_acc.append(acc_)
except tf.errors.OutOfRangeError:
mean_acc = np.mean(cnt_acc)
print('valid acc1={}'.format(mean_acc))
if mean_acc > max_val_acc:
max_val_acc = mean_acc
self.saver.save(sess, self.config._save_path)
break
def test(self):
batch = self.iterator.get_next()
loss, acc, rank = self.forward(batch)
with tf.Session(config=self.gpu_config) as sess:
sess.run(tf.tables_initializer())
self.saver = tf.train.Saver()
self.saver.restore(sess, self.config._save_path)
self.iterator.switch_to_test_data(sess)
rank_cnt = []
while True:
try:
feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
ranks, labels = sess.run([rank, batch['label']], feed_dict=feed)
for i in range(len(ranks)):
rank_cnt.append(np.where(ranks[i]==labels[i])[0][0])
except tf.errors.OutOfRangeError:
rec = [0,0,0,0,0]
MRR = 0
for rank in rank_cnt:
for i in range(5):
rec[i] += (rank <= i)
MRR += 1 / (rank+1)
print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format(
rec[0]/len(rank_cnt), rec[2]/len(rank_cnt), rec[4]/len(rank_cnt), MRR/len(rank_cnt)))
break
def retrieve_init(self, sess):
data_batch = self.iterator.get_next()
loss, acc, _ = self.forward(data_batch)
self.corpus = self.data_config._corpus
self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams)
corpus_iterator = tx.data.DataIterator(self.corpus_data)
batch = corpus_iterator.get_next()
corpus_embed = self.embedder(batch['corpus_text_ids'])
utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1]
self.corpus_code = np.zeros([0, self.config._code_len])
corpus_iterator.switch_to_dataset(sess)
sess.run(tf.tables_initializer())
saver = tf.train.Saver()
saver.restore(sess, self.config._save_path)
feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
while True:
try:
utter_code_ = sess.run(utter_code, feed_dict=feed)
self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0)
except tf.errors.OutOfRangeError:
break
self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))
self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))
self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2))
history_ids = self.vocab.map_tokens_to_ids(self.history_input)
history_embed = self.embedder(history_ids)
history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0),
sequence_length_minor=self.minor_length_input,
sequence_length_major=self.major_length_input)[1]
select_corpus = tf.cast(self.corpus_code, dtype=tf.float32)
feature_code = self.linear_matcher(select_corpus * history_code)
self.ans_output = tf.nn.top_k(tf.squeeze(feature_code, 1), k=self.data_config._retrieval_candidates)[1]
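# retrieve returns the highest-ranked corpus utterance not yet used in this
# session; there is no keyword-based steering here (model/retrieval_stgy.py
# appears to layer that strategy on top of the same matcher).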
def retrieve(self, source, sess):
history, seq_len, turns, context, context_len = source
ans = sess.run(self.ans_output, feed_dict={self.history_input: history,
self.minor_length_input: [seq_len],
self.major_length_input: [turns]})
reply = self.corpus[ans[0]] # fall back to the top candidate if all have been used
for i in range(self.data_config._max_turns + 1):
if ans[i] not in self.reply_list: # avoid repeat
self.reply_list.append(ans[i])
reply = self.corpus[ans[i]]
break
return reply
================================================
FILE: model/retrieval_stgy.py
================================================
import texar as tx
import tensorflow as tf
import numpy as np
from preprocess.data_utils import kw_tokenize


class Predictor():
    def __init__(self, config_model, config_data, mode=None):
        self.config = config_model
        self.data_config = config_data
        self.build_model()
        self.gpu_config = tf.ConfigProto()
        self.gpu_config.gpu_options.allow_growth = True

    def build_model(self):
        self.train_data = tx.data.MultiAlignedData(self.data_config.data_hparams['train'])
        self.valid_data = tx.data.MultiAlignedData(self.data_config.data_hparams['valid'])
        self.test_data = tx.data.MultiAlignedData(self.data_config.data_hparams['test'])
        self.iterator = tx.data.TrainTestDataIterator(train=self.train_data, val=self.valid_data, test=self.test_data)
        self.vocab = self.train_data.vocab(0)
        self.embedder = tx.modules.WordEmbedder(init_value=self.train_data.embedding_init_value(0).word_vecs)
        self.source_encoder = tx.modules.HierarchicalRNNEncoder(hparams=self.config.source_encoder_hparams)
        self.target_encoder = tx.modules.UnidirectionalRNNEncoder(hparams=self.config.target_encoder_hparams)
        self.linear_matcher = tx.modules.MLPTransformConnector(1)
        self.kw_list = self.vocab.map_tokens_to_ids(tf.convert_to_tensor(self.data_config._keywords_candi))

    def forward(self, batch):
        # encode the dialogue history once and each of the 20 candidate
        # responses separately, then score every (history, candidate) pair
        source_embed = self.embedder(batch['source_text_ids'])
        target_embed = self.embedder(batch['target_text_ids'])
        target_embed = tf.reshape(target_embed, [-1, self.data_config._max_seq_len + 2, self.embedder.dim])
        source_code = self.source_encoder(source_embed,
                                          sequence_length_minor=batch['source_length'],
                                          sequence_length_major=batch['source_utterance_cnt'])[1]
        target_length = tf.reshape(batch['target_length'], [-1])
        target_code = self.target_encoder(target_embed, sequence_length=target_length)[1]
        target_code = tf.reshape(target_code, [-1, 20, self.config._code_len])
        source_code = tf.expand_dims(source_code, 1)
        source_code = tf.tile(source_code, [1, 20, 1])
        feature_code = target_code * source_code
        feature_code = tf.reshape(feature_code, [-1, self.config._code_len])
        logits = self.linear_matcher(feature_code)
        logits = tf.reshape(logits, [-1, 20])
        labels = tf.one_hot(batch['label'], 20)
        loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
        ans = tf.argmax(logits, -1)
        acc = tx.evals.accuracy(batch['label'], ans)
        rank = tf.nn.top_k(logits, k=20)[1]
        return loss, acc, rank
    def train(self):
        batch = self.iterator.get_next()
        loss, acc, _ = self.forward(batch)
        op_step = tf.Variable(0, name='op_step')
        train_op = tx.core.get_train_op(loss, global_step=op_step, hparams=self.config.opt_hparams)
        max_val_acc = 0.
        self.saver = tf.train.Saver()
        with tf.Session(config=self.gpu_config) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())
            for epoch_id in range(self.config._max_epoch):
                self.iterator.switch_to_train_data(sess)
                cur_step = 0
                cnt_acc = []
                while True:
                    try:
                        cur_step += 1
                        feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
                        loss, acc_ = sess.run([train_op, acc], feed_dict=feed)
                        cnt_acc.append(acc_)
                        if cur_step % 200 == 0:
                            print('batch {}, loss={}, acc1={}'.format(cur_step, loss, np.mean(cnt_acc[-200:])))
                    except tf.errors.OutOfRangeError:
                        break
                op_step = op_step + 1
                self.iterator.switch_to_val_data(sess)
                cnt_acc = []
                while True:
                    try:
                        feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
                        acc_ = sess.run([acc], feed_dict=feed)
                        cnt_acc.append(acc_)
                    except tf.errors.OutOfRangeError:
                        mean_acc = np.mean(cnt_acc)
                        print('valid acc1={}'.format(mean_acc))
                        if mean_acc > max_val_acc:
                            # keep only the checkpoint with the best validation accuracy
                            max_val_acc = mean_acc
                            self.saver.save(sess, self.config._save_path)
                        break

    def test(self):
        batch = self.iterator.get_next()
        loss, acc, rank = self.forward(batch)
        with tf.Session(config=self.gpu_config) as sess:
            sess.run(tf.tables_initializer())
            self.saver = tf.train.Saver()
            self.saver.restore(sess, self.config._save_path)
            self.iterator.switch_to_test_data(sess)
            rank_cnt = []
            while True:
                try:
                    feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
                    ranks, labels = sess.run([rank, batch['label']], feed_dict=feed)
                    for i in range(len(ranks)):
                        # position of the ground-truth response in the ranked candidates
                        rank_cnt.append(np.where(ranks[i] == labels[i])[0][0])
                except tf.errors.OutOfRangeError:
                    rec = [0, 0, 0, 0, 0]
                    MRR = 0
                    for rank in rank_cnt:
                        for i in range(5):
                            rec[i] += (rank <= i)
                        MRR += 1 / (rank + 1)
                    print('test rec1@20={:.4f}, rec3@20={:.4f}, rec5@20={:.4f}, MRR={:.4f}'.format(
                        rec[0] / len(rank_cnt), rec[2] / len(rank_cnt), rec[4] / len(rank_cnt), MRR / len(rank_cnt)))
                    break
    def retrieve_init(self, sess):
        data_batch = self.iterator.get_next()
        loss, acc, _ = self.forward(data_batch)
        self.corpus = self.data_config._corpus
        self.corpus_data = tx.data.MonoTextData(self.data_config.corpus_hparams)
        corpus_iterator = tx.data.DataIterator(self.corpus_data)
        batch = corpus_iterator.get_next()
        corpus_embed = self.embedder(batch['corpus_text_ids'])
        utter_code = self.target_encoder(corpus_embed, sequence_length=batch['corpus_length'])[1]
        self.corpus_code = np.zeros([0, self.config._code_len])
        corpus_iterator.switch_to_dataset(sess)
        sess.run(tf.tables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, self.config._save_path)
        feed = {tx.global_mode(): tf.estimator.ModeKeys.PREDICT}
        # pre-encode every corpus utterance once for fast retrieval at chat time
        while True:
            try:
                utter_code_ = sess.run(utter_code, feed_dict=feed)
                self.corpus_code = np.concatenate([self.corpus_code, utter_code_], axis=0)
            except tf.errors.OutOfRangeError:
                break
        # unit-normalized keyword embeddings, so a dot product is a cosine similarity
        self.keywords_embed = tf.nn.l2_normalize(self.embedder(self.kw_list), axis=1)
        self.kw_embedding = sess.run(self.keywords_embed)
        self.minor_length_input = tf.placeholder(dtype=tf.int32, shape=(1, 9))
        self.major_length_input = tf.placeholder(dtype=tf.int32, shape=(1))
        # dtype=object maps to tf.string: the placeholder takes raw tokens
        self.history_input = tf.placeholder(dtype=object, shape=(9, self.data_config._max_seq_len + 2))
        history_ids = self.vocab.map_tokens_to_ids(self.history_input)
        history_embed = self.embedder(history_ids)
        history_code = self.source_encoder(tf.expand_dims(history_embed, axis=0),
                                           sequence_length_minor=self.minor_length_input,
                                           sequence_length_major=self.major_length_input)[1]
        select_corpus = tf.cast(self.corpus_code, dtype=tf.float32)
        feature_code = self.linear_matcher(select_corpus * history_code)
        self.ans_output = tf.nn.top_k(tf.squeeze(feature_code, 1), k=1000)[1]  # = _retrieval_candidates

    def retrieve(self, source, sess):
        history, seq_len, turns, context, context_len = source
        ans = sess.run(self.ans_output, feed_dict={self.history_input: history,
                                                   self.minor_length_input: [seq_len],
                                                   self.major_length_input: [turns]})
        flag = 0
        # fall back to the best-matching candidate if no keyword improves on
        # the current similarity to the target
        reply = self.corpus[ans[0]]
        for i in ans:
            if i in self.reply_list:  # avoid repeats
                continue
            for wd in kw_tokenize(self.corpus[i]):
                if wd in self.data_config._keywords_candi:
                    tmp_score = sum(self.kw_embedding[self.data_config._keywords_dict[wd]] *
                                    self.kw_embedding[self.data_config._keywords_dict[self.target]])
                    if tmp_score > self.score:
                        reply = self.corpus[i]
                        self.score = tmp_score
                        self.next_kw = wd
                        # remember the accepted reply so later turns skip it
                        self.reply_list.append(i)
                        flag = 1
                        break
            if flag == 0:
                continue
            break
        return reply
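
# Hypothetical sketch (not part of the original file): the strategy above only
# accepts a candidate reply that contains a keyword whose embedding is closer
# to the target keyword than anything said so far, so the tracked similarity
# increases monotonically across turns. The fragment below replays that rule
# with made-up unit-normalized embeddings; `kw_emb`, `target` and `candidates`
# are illustrative names.
if __name__ == '__main__':
    import numpy as np
    rng = np.random.RandomState(0)
    kw_emb = rng.rand(5, 200)
    kw_emb /= np.linalg.norm(kw_emb, axis=1, keepdims=True)  # as in retrieve_init
    target, score = 4, 0.
    candidates = [[0, 2], [1], [3, 4]]   # keyword ids found in each candidate reply
    for cand in candidates:
        best = max(cand, key=lambda k: float(kw_emb[k] @ kw_emb[target]))
        sim = float(kw_emb[best] @ kw_emb[target])
        if sim > score:                  # only move strictly closer to the target
            score = sim
            print('accept keyword {} with similarity {:.2f}'.format(best, sim))
            break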
================================================
FILE: preprocess/convai2/__init__.py
================================================
from .api import *
================================================
FILE: preprocess/convai2/api.py
================================================
import os

data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'source')


class dts_ConvAI2(object):
    def __init__(self, path=data_path):
        self.path = path

    def _txt_to_json(self, txt_path, mode, cands):
        def pop_one_sample(lines):
            self_persona = []
            other_persona = []
            dialog = []
            candidates = []
            started = False
            while len(lines) > 0:
                line = lines.pop()
                id, context = line.split(' ', 1)
                id = int(id)
                context = context.strip()
                if not started:  # the first line of a sample must be turn 1
                    assert id == 1
                    started = True
                elif id == 1:  # the next sample starts here; push the line back
                    lines.append(line)
                    break
                if context.startswith('partner\'s persona: '):
                    assert mode in ['both', 'other']
                    other_persona.append(context[len('partner\'s persona: '):])
                elif context.startswith('your persona: '):
                    assert mode in ['both', 'self']
                    self_persona.append(context[len('your persona: '):])
                elif not cands:  # utterance/response pair without candidates
                    try:
                        uttr, response = context.split('\t', 2)[:2]
                        dialog.append(uttr)
                        dialog.append(response)
                    except ValueError:  # the line holds only a single utterance
                        uttr = context
                        dialog.append(uttr)
                else:
                    uttr, response, _, negs = context.split('\t', 4)[:4]
                    dialog.append(uttr)
                    dialog.append(response)
                    candidates.append(negs.split('|'))
                    candidates.append(None)  # keep candidates aligned with dialog turns
            return {
                'self_persona': self_persona,
                'other_persona': other_persona,
                'dialog': dialog,
                'candidates': candidates
            }

        lines = open(txt_path, 'r').readlines()[::-1]
        samples = []
        while len(lines) > 0:
            samples.append(pop_one_sample(lines))
        return samples

    def get_data(self, mode='train', revised=False, cands=False):
        assert mode in ['train', 'valid', 'test', 'all']
        txt_path = os.path.join(self.path, '{}_{}_{}{}.txt'.format(
            mode,
            'none',
            'revised' if revised else 'original',
            '' if cands else '_no_cands'))
        print("Get dialog from ", txt_path)
        assert os.path.exists(txt_path)
        return self._txt_to_json(txt_path, mode, cands)

    def get_dialogs(self, mode='all'):
        dialogs = [sample['dialog'] for sample in self.get_data(mode, False, False)]
        return dialogs
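
# Hypothetical sketch (not part of the original file): each line of a ConvAI2
# text file starts with a 1-based turn id, persona lines carry a
# "your persona: " prefix, and dialogue lines are tab-separated
# utterance/response pairs. The fragment below writes two such lines to a
# temporary file and parses them with the loader above.
if __name__ == '__main__':
    import tempfile
    sample = "1 your persona: i love dogs.\n2 hi there!\thello, how are you?\n"
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
        f.write(sample)
    parsed = dts_ConvAI2()._txt_to_json(f.name, mode='self', cands=False)
    print(parsed[0]['self_persona'])   # ['i love dogs.']
    print(parsed[0]['dialog'])         # ['hi there!', 'hello, how are you?']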
================================================
FILE: preprocess/convai2/candi_keyword.txt
================================================
favorite
sound
play
dog
music
kid
eat
school
enjoy
job
watch
read
food
cat
friend
family
hobby
people
pet
car
game
hear
movie
travel
book
cook
listen
animal
life
color
drive
college
hope
living
parent
teach
bad
sport
hard
dad
feel
child
hair
band
country
money
pizza
marry
bet
hate
walk
stay
study
write
husband
fish
start
guess
brother
blue
night
dance
busy
red
wife
learn
talk
sing
meet
beach
spend
drink
rock
city
video
grow
person
house
tv
true
shop
teacher
buy
girl
weekend
prefer
hike
meat
football
free
visit
sister
care
plan
art
swim
sweet
paint
stuff
pretty
vegan
store
class
farm
type
funny
understand
son
business
happy
mother
ride
coffee
leave
fan
garden
week
boy
bake
green
pay
beautiful
draw
student
horse
sell
sad
summer
wear
truck
black
hot
wait
yea
sleep
cold
agree
single
guy
sibling
real
crazy
healthy
fine
tall
guitar
lose
cute
hour
cake
purple
month
relax
finish
idea
break
company
daughter
happen
dream
team
mind
park
woman
reading
restaurant
bike
short
italian
exercise
hang
chocolate
wonderful
basketball
speak
soccer
weather
graduate
retire
winter
morning
nurse
bring
die
easy
party
grade
father
ready
ice
win
spare
doctor
online
song
local
degree
chicken
concert
florida
glad
fall
baseball
eye
office
volunteer
water
girlfriend
bear
french
shopping
age
luck
fishing
artist
weird
foot
save
surf
dinner
yoga
lucky
story
hunt
cooking
pink
suck
cream
boyfriend
language
beer
outdoors
change
diet
passion
york
major
cheese
collect
imagine
super
english
rap
practice
chat
head
california
train
lake
clothes
pass
nature
baby
fry
season
apple
piano
sit
hop
scary
taco
law
steak
hockey
comic
brown
bar
gym
smart
huge
clean
fav
close
jazz
vegetarian
canada
science
career
picture
tea
excite
tough
deal
pick
allergic
neat
church
social
sick
shoe
trip
fly
vacation
catch
bed
raise
boring
rid
ocean
club
town
instrument
check
race
dragon
fast
vegetable
wrong
boat
terrible
tennis
candy
rain
worry
veggie
fruit
metal
perfect
twin
tattoo
mountain
tomorrow
god
stand
hospital
apartment
build
stick
design
japan
texas
meal
grocery
camp
hit
orange
mexican
allergy
ta
phone
famous
player
flower
bore
army
star
share
pool
delicious
relationship
exciting
math
egg
museum
classic
dress
sushi
taste
married
amaze
classical
lady
shelter
sense
joke
pop
pie
yellow
american
lawyer
expensive
stress
hand
cut
pasta
remember
professional
choice
impressive
youtube
kinda
yum
chicago
birthday
cooky
sunday
follow
divorce
gon
moment
fresh
tire
hurt
wedding
fit
weight
health
plant
count
chef
heard
ball
scar
sort
smell
dead
special
comedy
couple
rich
hiking
fave
create
accountant
bird
relaxing
history
blonde
film
everyday
glass
heart
voice
ballet
vet
military
horror
field
fight
mile
salad
extra
afraid
market
christmas
reason
mexico
attend
cop
chance
potato
snow
halloween
folk
snake
ski
character
card
adopt
figure
tho
hat
nail
ford
fat
warm
swimming
difficult
sew
suppose
dye
safe
wine
tend
style
afford
recipe
writer
shower
lunch
remind
painting
news
congrats
photography
rough
road
holiday
join
beat
white
throne
michigan
awful
actor
breakfast
library
congratulation
goal
shrimp
singer
middle
horrible
common
plane
profession
bacon
smoke
coast
kill
future
word
female
south
craft
neighbor
violin
fair
paris
fashion
quiet
service
bank
workout
dish
shape
tour
theater
genre
forget
fiction
makeup
model
clown
author
experience
rest
librarian
king
puppy
german
tree
trust
street
yummy
marketing
quit
hungry
creative
poor
cali
wild
alright
accident
fantasy
cartoon
air
pepper
hurricane
ton
lab
anime
university
drop
factory
iphone
mystery
active
netflix
stressful
space
trouble
secret
mcdonalds
laugh
gun
support
spanish
pain
casino
france
set
spaghetti
sale
magic
grill
burger
climb
internet
pig
hold
decide
mustang
program
drum
hmm
focus
seafood
master
medium
strange
lesson
arm
war
serve
van
north
cow
toe
driver
paper
education
table
station
dancer
soda
playing
medical
dark
blog
usa
power
wan
bob
simple
zoo
shoot
fancy
mechanic
collection
national
motorcycle
strong
parrot
nursing
america
harry
blood
ahahah
body
bible
police
energy
alabama
lazy
death
rose
musician
italy
adventure
worker
writing
pair
activity
bread
throw
issue
europe
london
gross
jealous
male
height
disney
treat
marathon
door
bean
bos
boot
aww
cancer
lover
milk
skill
plenty
beard
light
poetry
subject
sea
romantic
relate
spending
firm
send
seattle
couch
question
flavor
photo
channel
favourite
match
contact
choose
japanese
singing
lift
boston
wood
grandmother
farmer
radio
navy
roommate
bus
politics
grandchild
jam
chinese
viking
matter
series
competition
designer
gas
san
feeling
knit
lizard
slow
publish
fantastic
crochet
donate
add
television
insurance
talent
bowl
human
land
rise
hmmm
test
alive
involve
dangerous
pro
yesterday
wind
obsess
golden
stone
post
west
nephew
depend
nervous
breed
bakery
hip
goodness
baker
surgery
mall
don
honda
colorado
dancing
dr
dude
lifestyle
calm
wake
crime
athlete
skin
beauty
lie
facebook
mar
reader
lead
mess
snack
shift
employ
spring
size
handle
shy
iron
corvette
evening
superman
board
cheer
traffic
grand
sun
fee
musical
base
public
chip
mad
careful
pound
brand
potter
punk
smile
advice
justin
manager
bass
peaceful
golf
clothing
john
action
buddy
sunny
box
waitress
gardening
popular
wing
wall
personal
coach
cover
elementary
mix
pray
odd
stink
minute
key
inspire
event
adorable
community
juice
engineer
karate
skate
saturday
literature
reward
baking
ghost
ohio
vega
kitchen
sugar
chill
tomato
chase
chevy
roll
la
organize
blame
homework
bee
grandma
beatles
passionate
soul
knee
anxiety
nut
spot
ticket
grandparent
honest
google
junk
product
tech
track
normal
role
client
veterinarian
spicy
building
scare
hawaii
bummer
reality
iguana
niece
opera
charlie
worth
pack
trail
senior
toyota
beet
ireland
finance
chili
officer
pickle
kinds
manage
tired
christian
daily
nose
cheap
spider
gig
indian
locate
pleasure
league
hunting
retired
george
tuna
teaching
attention
step
awe
stock
list
portland
sign
avoid
bug
brain
scientist
dessert
excellent
finger
religious
superhero
highschool
kick
rule
salesman
drawing
handful
pilot
escape
schedule
tasty
gosh
burrito
kansa
commercial
rescue
mary
fi
retail
retriever
skydive
husky
gamble
georgia
industry
meeting
nyc
walmart
annoy
cruise
doubt
forest
bunch
cashier
actress
kayak
partner
mac
tiny
friday
spell
culture
kindergarten
lay
tiger
deaf
mention
drinking
accounting
cup
subway
pancake
habit
prius
blind
familiar
starbucks
assistant
pumpkin
burn
answer
poodle
metallica
circus
jewelry
sock
recommend
sandwich
journalist
fear
pretend
grandson
nfl
brownie
cupcake
doll
stephen
bother
desert
sci
moon
construction
reach
technology
option
marriage
pony
hell
romance
cuisine
strawberry
shirt
bbq
purse
leg
kitten
unemployed
project
running
deliver
natural
dang
control
hero
beef
trade
sunset
carrot
main
chess
grandkids
talented
explore
nasty
east
hamburger
private
yup
guard
choir
australia
pen
stamp
struggle
dive
bro
perry
omg
thinking
candle
fell
ink
wheel
ranch
owner
carry
biology
earn
regular
lion
peanut
broccoli
suit
medicine
speaking
chew
inspiration
sauce
hotel
camera
england
repair
turtle
athletic
unique
ray
sky
skateboard
river
tune
freak
estate
cheesecake
yuck
surfing
rent
ooh
grey
memory
perform
popcorn
dislike
childhood
adore
quick
comedian
sweater
antique
lottery
hows
mcdonald
jump
tutor
carolina
walking
pregnant
mike
vampire
decorate
bieber
alcohol
compete
mansion
owl
gotcha
jack
engineering
retirement
pot
airplane
ferrari
dry
dentist
russian
piece
security
spirit
offer
dorm
record
settle
lobster
foreign
software
honey
rice
princess
excited
amazon
baltimore
island
skiing
center
alien
butter
corn
civic
jane
view
nap
pit
bulldog
lovely
prince
loud
photograph
gift
coke
org
belt
festival
ugh
ahh
vintage
pug
birth
understandable
sweetheart
irma
deep
india
feed
ring
item
scene
spear
mushroom
position
afternoon
hire
trainer
distract
touch
unfortunate
alaska
expect
katy
scratch
cost
situation
gay
loss
organic
washington
joy
gummy
survive
med
bless
prepare
charity
sight
rare
heavy
rural
russia
newspaper
karaoke
driving
customer
watching
require
graphic
mood
maine
freelance
fitness
diner
pepsi
condo
miami
cross
runner
accept
panda
bunny
engage
uncle
commute
indie
cooler
straight
hollywood
bagel
terrify
training
boxer
left
protein
bull
jog
tom
shame
ouch
current
programmer
nerd
magazine
artistic
yikes
eating
skittle
furniture
eagle
form
fabulous
legal
agency
internship
cabin
drama
positive
addict
surprise
rewarding
tax
barista
fake
spain
crash
random
kale
bright
shark
studio
bow
boys
bell
brace
trick
wheelchair
cloud
southern
force
chair
spouse
thumb
frank
rapper
virginia
physical
bye
grad
soup
fiance
elvis
meatloaf
queen
united
income
salon
volleyball
target
chihuahua
limit
scholarship
direction
supply
canadian
daddy
toy
kentucky
respect
earth
political
binge
hilarious
blast
unhealthy
pant
cheeseburger
complete
unicorn
religion
dairy
drug
fl
whiskey
reject
iced
average
eater
dirty
rat
angry
heck
hide
competitive
gum
website
laptop
exhaust
robot
challenge
outdoor
raw
audition
cafe
onion
assume
opinion
horseback
zebra
philosophy
psychology
exact
spice
debt
reside
heat
hobbies
leather
rude
china
storm
grown
invite
instructor
steal
curly
wash
worm
bf
credit
greek
oregon
shot
riding
pride
speed
floor
intense
court
dental
bone
friendly
diving
lame
kitty
alternative
hubby
strict
ham
camping
pond
marine
grandpa
secretary
theme
judge
shade
complain
herb
advertising
celebrity
fascinate
lasagna
environment
painter
comfortable
beagle
recycle
bruno
invest
search
society
halo
pursue
hm
sam
connect
angeles
weed
grab
kiss
bald
em
britney
oil
conversation
mmm
yard
cash
jean
ship
exotic
vanilla
waste
photographer
adult
rolling
dig
cookie
tennessee
balance
cleaning
ocd
eggplant
archery
ma
meatball
dust
deer
gluten
agent
gaming
celebrate
helpful
tofu
campus
lord
evil
er
los
shake
martial
drummer
outfit
grass
wrestle
note
convertible
biking
taller
gorgeous
file
hectic
salsa
rush
awhile
distance
soft
homeless
daycare
process
patient
houston
crowd
stew
duty
bookstore
tie
neighborhood
professor
orleans
electric
original
opportunity
francisco
angel
fund
site
peace
picky
wisconsin
madonna
adam
iceland
blow
uh
total
urban
coincidence
entire
jello
compare
dollar
headache
blond
guilty
cure
electronic
grader
greece
correct
twilight
enjoyable
benefit
damn
coupon
cheat
thrift
print
scout
successful
bartender
stranger
cattle
mommy
bath
nascar
honor
suggestion
appalachian
ginger
hook
aquarium
james
scared
gamer
collie
creepy
upstate
therapist
bicycle
trophy
department
cheerleader
variety
suburb
explain
cuddle
flip
shepherd
tupac
whoa
iq
lifeguard
sunshine
jersey
rainy
vehicle
closet
rainbow
ross
specialty
attorney
interior
gf
crab
podcasts
fart
decent
insane
shepard
gossip
jimmy
blackjack
austin
nike
louisiana
brat
therapy
noble
skunk
listening
awww
pleasant
pittsburgh
barbie
specific
pas
um
avid
occasion
mustache
autograph
noise
diego
ruin
admit
fail
scotch
creature
costume
jeopardy
swift
africa
cd
account
edit
topic
handy
window
steelers
accent
activist
preacher
affect
apply
loose
refuse
bum
minnesota
gender
confuse
landscape
broadway
bmw
argue
foodie
trek
relieve
todd
path
freedom
pearl
tap
medication
vera
rapid
ohh
government
environmental
fault
ft
cope
carbs
standard
planet
mcqueen
nugget
pull
difference
babysit
teller
disappoint
mermaid
pageant
soap
midwest
giant
puerto
bowling
asian
arizona
happiness
provide
fond
bud
cell
entertain
butterfly
genius
scream
architect
jar
monkey
theatre
hah
purchase
yay
vote
level
cab
combo
tool
fluent
duck
creek
load
affair
humor
slave
publishing
painful
redhead
jerry
thursday
dungeon
cape
messy
approve
cousin
weekly
terrier
paddle
david
typical
waffle
desk
catholic
germany
instagram
admire
image
upset
muscle
monday
pic
basement
ground
wave
shellfish
technician
episode
jim
jacob
ballerina
loan
improve
garage
ibm
tooth
bachelor
thrill
hippie
notice
return
crush
jesus
stomach
techno
raven
nerve
denver
foster
thriller
rugby
daydream
oreo
discover
detroit
cult
atlanta
incredible
stable
poem
yr
oooh
wide
mango
snowboard
weapon
countryside
alcoholic
brunch
fisherman
aunt
toronto
sarah
addiction
surround
lactose
clinic
universe
pretzel
toto
blah
sir
sausage
cosplay
text
quality
millionaire
clerk
skinny
hendrix
disabled
puzzle
pepperoni
civil
surfer
larp
package
roof
pa
suspense
drone
mail
jason
exam
mark
financial
encyclopedia
cheetos
demand
shell
planning
stupid
yell
grateful
bingo
source
companion
director
bite
detective
biography
gospel
silly
pudding
pork
teeth
autobiography
salt
footstep
deserve
produce
swimmer
barbecue
maryland
btw
defense
fallon
teen
continue
cart
wizard
meditate
shelf
zombie
irish
pecan
bubble
discount
scooter
push
tutorial
scuba
homemade
weakness
translator
gymnastics
background
softball
kidding
mistake
realize
ironic
floyd
flight
surgeon
crack
desire
introvert
ted
knife
mba
pottery
orphan
recover
mini
bucket
perk
autism
moped
cycle
youth
spoil
attitude
injury
pennsylvania
aspire
loyal
attack
murder
price
tank
safety
olympics
rome
dj
carpenter
sesame
consume
protect
british
sword
cheetah
float
asia
mate
creed
xbox
lean
freckle
caffeine
hunter
trump
programming
picnic
ear
fridge
easter
asparagus
oldie
darn
disagree
mirror
hehe
ew
biscuit
freshman
whistle
usual
inch
deli
eclipse
cycling
modern
tease
review
pattern
award
belong
nickname
buff
sweden
cuz
selfish
personality
graduation
dolphin
album
pup
taylor
lease
educate
depress
paramedic
gila
purpose
prison
worried
vermont
block
buffalo
olive
identical
score
string
actual
intelligent
epilepsy
sand
plate
subtitle
border
cable
smooth
lexus
develop
vienna
brave
welfare
doberman
wealthy
switch
sneak
robert
hill
nacho
mugger
snap
dumb
coworker
peta
hr
chick
fur
goalie
range
introduce
costco
railroad
suffer
menu
soldier
asthma
sex
lindsey
utah
grandfather
documentary
admirable
traveling
dane
employee
article
caramel
harley
equipment
heal
basic
dabble
depression
attract
gaga
tale
tube
fried
doggy
mash
taught
receptionist
empire
intern
piercings
tackle
ease
blanket
participate
cheesy
gray
daisy
toddler
link
diamond
propose
dallas
president
ranger
wolf
gain
chai
annoying
earring
version
basket
lens
salary
corner
champion
firefighter
ferret
achieve
sears
mia
idol
joe
decade
emotion
koala
management
pharmacist
apps
depends
killer
fellow
uniform
gourmet
cleveland
motivate
hummus
entertainment
mmmm
bag
nashville
flirt
owen
trumpet
nevada
stage
jerky
responsibility
drake
bentley
gold
tx
arcade
ankle
vegas
kj
batman
unwind
keyboard
combination
leaf
koi
cello
minimum
adventurous
loving
kiddos
spill
diabetic
central
rob
trend
bubblegum
indoors
monster
drunk
confidence
pyramid
grunge
banker
breath
grasshopper
hoop
encourage
gutter
macaroni
robotics
double
seat
dew
uncomfortable
flash
bench
bomb
info
semi
comfort
wildlife
geology
lonely
coz
edge
anniversary
vitamin
material
hotdog
gathering
socialize
machine
editor
droopy
brew
liberal
mercedes
quarterback
rn
planner
fortunate
ur
greenhouse
si
psychologist
promotion
def
discovery
carb
humane
broken
chanel
fluffy
chain
vision
stanford
pipe
ability
overweight
promote
labrador
veteran
preference
symphony
alpaca
ve
app
teenager
anne
promise
doo
publisher
curious
tiki
porsche
mixed
maid
legend
michael
supportive
pineapple
ariel
diabetes
consulting
starve
gal
collar
gable
battle
jacket
sexy
sleeve
felix
pastime
jamaica
mortal
weak
scrub
rabbit
godfather
sinatra
valley
despise
regret
goodwill
heaven
buddhist
smoking
oops
pitbulls
salmon
rick
bitcoin
dip
trout
pill
farming
thankful
tokyo
housewife
prayer
impala
valedictorian
plain
message
temper
flintstone
leprechaun
sucker
breathe
csi
criminal
rip
maiden
fascinating
rico
algeria
report
umm
patience
leader
curl
motivation
climbing
tahoe
ymca
relief
glacier
breast
enter
clutter
dull
fighter
tat
awake
brewery
victorian
volcano
friends
mount
pillage
magical
generation
clue
conscious
stare
silver
wrestling
levine
joint
restore
everest
dope
stray
international
parking
hampshire
hearse
warehouse
pitbull
nyu
outdoorsy
development
employment
drinker
zumba
paul
budget
daniel
eyesight
sour
mouth
stain
blogger
exist
rib
brush
interview
bff
custom
snuggle
vancouver
mario
ferraris
mural
poet
oriole
period
karma
damage
warmer
crossword
childrens
pomeranian
imaginary
dave
anatomy
tone
code
videogames
woodstock
convention
janitor
preschool
screen
prejudice
crystal
rage
tradition
chatting
traditional
parakeet
ramen
combat
multiple
crave
syrup
racing
highlight
communist
concentrate
waiter
ebooks
dodge
hp
boil
attic
medal
commitment
release
downtown
alligator
statement
debate
agreed
maga
homeschooled
strength
plumber
hippy
windy
condition
smoothie
stair
content
depressed
ferrell
keto
remodel
donut
winner
playlist
wayne
nation
kpop
map
coon
junior
mum
tape
quake
smithsonian
washer
abigail
radiohead
humble
unicycle
administration
ontario
performance
truth
fred
ingredient
cucumber
beastie
orchestra
sewing
knock
culinary
sweat
seashell
impression
network
languages
tailgate
celebration
thomas
embarrass
born
mama
freeze
crap
fortune
figurine
confident
homebody
chemistry
collector
merna
arrive
titanic
meditation
bout
manta
announcer
solo
circle
md
funeral
engine
butt
delivery
ultimate
specialize
web
palm
absolute
investment
harsh
pistachio
loner
experiment
gut
austen
fuel
cramp
trauma
sleepy
celtic
press
draft
auto
sprite
obsession
sip
fifty
vinyl
swing
fool
hbu
harvey
copperfield
playoff
kite
lesbian
jerk
owe
democrat
mass
hamilton
ga
uk
luis
impress
slice
pita
hobbie
apologize
santa
tacos
landing
hometown
telecom
mater
mutt
deploy
del
sore
nancy
barbies
fam
clay
ethnic
pastry
hostage
tight
backyard
convince
maker
curry
android
pc
jessica
ignore
flow
sickness
elderly
chore
upholstery
sweetie
lettuce
cuba
gadget
animation
trooper
faith
tongue
success
gentle
portrait
sheeran
chevrolet
packer
risk
spark
frustrate
mouse
pitch
weld
eyebrow
bella
linebacker
bully
routine
spelling
bc
coat
saudi
arabia
tampa
emmy
samsung
mop
kevin
checker
teapot
weigh
suv
miserable
sevenfold
f150
lit
posse
thai
curator
steve
poop
historical
morty
cane
miley
wise
petition
tear
penn
astronaut
cod
colour
acting
precious
buck
lucy
muse
cosmetic
occupation
nba
ate
flexible
ideal
suspender
bang
direct
gotti
agitate
hairdresser
dealership
influence
cursive
sunfish
snorkel
shallow
root
pediatrician
compost
coaster
nearby
foreman
deadbeat
penny
jay
jasper
tarot
pressure
clarinet
supper
express
ai
martini
favor
chop
lutefisk
charge
dakota
hitchhike
formal
ivy
raptor
battlestar
captain
disgust
task
sitcom
yorkie
coco
understood
naw
ant
stinky
speckle
title
corporate
wednesday
gambler
wage
multi
mma
cookbook
citizen
hazel
aspiration
goat
stuck
lumberjack
flag
wet
ufc
learning
stirling
dealer
grisham
acre
================================================
FILE: preprocess/data_utils.py
================================================
import os

import nltk
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet as wn
from nltk.corpus import wordnet_ic

_lemmatizer = WordNetLemmatizer()
# information-content file used by the Lin similarity below
brown_ic = wordnet_ic.ic('ic-brown.dat')


def tokenize(example, ppln):
    # apply a pipeline of processing functions in order
    for fn in ppln:
        example = fn(example)
    return example


def kw_tokenize(string):
    return tokenize(string, [nltk_tokenize, lower, pos_tag, to_basic_form])


def simp_tokenize(string):
    return tokenize(string, [nltk_tokenize, lower])


def nltk_tokenize(string):
    return nltk.word_tokenize(string)


def lower(tokens):
    if not isinstance(tokens, str):
        return [lower(token) for token in tokens]
    return tokens.lower()


def pos_tag(tokens):
    return nltk.pos_tag(tokens)


def to_basic_form(tokens):
    # lemmatize nouns, verbs and adjectives; leave other tokens unchanged
    if not isinstance(tokens, tuple):
        return [to_basic_form(token) for token in tokens]
    word, tag = tokens
    if tag.startswith('NN'):
        pos = 'n'
    elif tag.startswith('VB'):
        pos = 'v'
    elif tag.startswith('JJ'):
        pos = 'a'
    else:
        return word
    return _lemmatizer.lemmatize(word, pos)


def truecasing(tokens):
    ret = []
    is_start = True
    for word, tag in tokens:
        if word == 'i':
            ret.append('I')
        elif tag[0].isalpha():
            if is_start:
                ret.append(word[0].upper() + word[1:])
            else:
                ret.append(word)
            is_start = False
        else:
            if tag != ',':
                is_start = True
            ret.append(word)
    return ret


candi_keyword_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'convai2/candi_keyword.txt')
# a set gives O(1) membership tests for the frequent is_candiword calls
_candiwords = set(x.strip() for x in open(candi_keyword_path).readlines())


def is_candiword(a):
    return a in _candiwords


def calculate_linsim(a, b):
    # maximum Lin similarity over all synset pairs of the two words
    linsim = -1
    syna = wn.synsets(a)
    synb = wn.synsets(b)
    for sa in syna:
        for sb in synb:
            try:
                linsim = max(linsim, sa.lin_similarity(sb, brown_ic))
            except Exception:  # the similarity is undefined for some POS pairs
                pass
    return linsim


def is_reach_goal(context, goal):
    # the target is reached if the goal keyword (or a close WordNet
    # neighbour of it) appears in the given text
    context = kw_tokenize(context)
    if goal in context:
        return True
    for wd in context:
        if is_candiword(wd):
            rela = calculate_linsim(wd, goal)
            if rela > 0.9:
                return True
    return False


def make_context(string):
    string = kw_tokenize(string)
    return [word for word in string if is_candiword(word)]


def utter_preprocess(string_list, max_length):
    source, minor_length = [], []
    string_list = string_list[-9:]          # keep at most the last 9 utterances
    major_length = len(string_list)
    if major_length == 1:
        context = make_context(string_list[-1])
    else:
        context = make_context(string_list[-2] + string_list[-1])
    context_len = len(context)
    while len(context) < 20:                # pad the keyword context to 20 slots
        context.append('<PAD>')
    for string in string_list:
        string = simp_tokenize(string)
        if len(string) > max_length:
            string = string[:max_length]
        string = ['<BOS>'] + string + ['<EOS>']
        minor_length.append(len(string))
        while len(string) < max_length + 2:
            string.append('<PAD>')
        source.append(string)
    while len(source) < 9:                  # pad the dialogue to 9 utterances
        source.append(['<PAD>'] * (max_length + 2))
        minor_length.append(0)
    return (source, minor_length, major_length, context, context_len)
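
# Hypothetical usage sketch (not part of the original file): requires the NLTK
# punkt, averaged_perceptron_tagger, wordnet and wordnet_ic data to be
# downloaded. It shows the fixed shapes produced by utter_preprocess: 9 padded
# utterances of max_length + 2 tokens each, plus a 20-slot keyword context.
if __name__ == '__main__':
    history = ['hello , how are you ?', 'i am fine , just ate some pizza .']
    source, minor_len, major_len, context, context_len = utter_preprocess(history, max_length=30)
    print(len(source), len(source[0]))   # 9 32  (padded turns x padded tokens)
    print(major_len, minor_len[:2])      # 2, followed by the two true lengths
    print(context[:context_len])         # keywords found in the last two turns
    print(is_reach_goal('i ate some pizza', 'pizza'))   # True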
================================================
FILE: preprocess/dataset.py
================================================
import collections
import pickle
import random

import numpy as np

from convai2 import dts_ConvAI2
from extraction import KeywordExtractor
from data_utils import *


class dts_Target(dts_ConvAI2):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_vocab(self):
        counter = collections.Counter()
        dialogs = self.get_dialogs()
        for dialog in dialogs:
            for uttr in dialog:
                counter.update(simp_tokenize(uttr))
        print('total vocab count: ', len(counter.items()))
        # sort by descending frequency, then alphabetically for ties
        vocab = [token for token, times in sorted(list(counter.items()), key=lambda x: (-x[1], x[0]))]
        with open('../tx_data/vocab.txt', 'w') as f:
            for word in vocab:
                f.write(word + '\n')
        print('saved vocab to ../tx_data/vocab.txt')
        return vocab

    def get_kwsess(self, vocab, mode='all'):
        # extract() ignores the idf_dict argument, so any value works here
        keyword_extractor = KeywordExtractor(vocab)
        corpus = self.get_data(mode=mode, cands=False)
        sess_set = []
        for sess in corpus:
            data = {}
            data['history'] = ''
            data['dialog'] = []
            for dialog in sess['dialog']:
                data['dialog'].append(dialog)
                data['history'] = data['history'] + ' ' + dialog
            data['kws'] = keyword_extractor.extract(data['history'])
            sess_set.append(data)
        return sess_set

    def cal_idf(self):
        # smoothed inverse document frequency, treating each utterance
        # as one document
        counter = collections.Counter()
        dialogs = self.get_dialogs()
        total = 0.
        for dialog in dialogs:
            for uttr in dialog:
                total += 1
                counter.update(set(kw_tokenize(uttr)))
        idf_dict = {}
        for k, v in counter.items():
            idf_dict[k] = np.log10(total / (v + 1.))
        return idf_dict

    def make_dataset(self):
        vocab = self.get_vocab()
        idf_dict = self.cal_idf()
        kw_counter = collections.Counter()
        sess_set = self.get_kwsess(vocab)
        for data in sess_set:
            kw_counter.update(data['kws'])
        kw_freq = {}
        kw_sum = sum(kw_counter.values())
        for k, v in kw_counter.most_common():
            kw_freq[k] = v / kw_sum
        # score a session by the average corpus frequency of its distinct keywords
        for data in sess_set:
            kws = set(data['kws'])
            data['score'] = sum(kw_freq[kw] for kw in kws) / len(kws) if kws else 0.
        sess_set.sort(key=lambda x: x['score'], reverse=True)
        all_data = {'train': [], 'valid': [], 'test': []}
        keyword_extractor = KeywordExtractor(idf_dict)
        # the 500 sessions with the most frequent keywords become the test set,
        # ~5% of the rest become validation, everything else is training
        for idx, sess in enumerate(sess_set):
            stage = 'train'
            if idx < 500:
                stage = 'test'
            elif random.random() < 0.05:
                stage = 'valid'
            sample = {'dialog': sess['dialog'], 'kwlist': []}
            for i in range(len(sess['dialog'])):
                sample['kwlist'].append(keyword_extractor.idf_extract(sess['dialog'][i]))
            all_data[stage].append(sample)
        pickle.dump(all_data, open('source_data.pk', 'wb'))
        return all_data
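
# Hypothetical sketch (not part of the original file): make_dataset ranks
# whole sessions by the average corpus frequency of their distinct keywords
# before splitting. The fragment below replays the scoring on toy counts;
# `kw_counter` and `sessions` are illustrative names.
if __name__ == '__main__':
    from collections import Counter
    kw_counter = Counter({'food': 50, 'music': 30, 'archery': 2})
    kw_sum = sum(kw_counter.values())
    kw_freq = {k: v / kw_sum for k, v in kw_counter.items()}
    sessions = [{'kws': ['food', 'music']}, {'kws': ['archery']}]
    for s in sessions:
        kws = set(s['kws'])
        s['score'] = sum(kw_freq[k] for k in kws) / len(kws)
    sessions.sort(key=lambda s: s['score'], reverse=True)
    print([round(s['score'], 3) for s in sessions])   # common-keyword session first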
================================================
FILE: preprocess/extraction.py
================================================
from data_utils import *


class KeywordExtractor():
    def __init__(self, idf_dict=None):
        self.idf_dict = idf_dict

    @staticmethod
    def is_keyword_tag(tag):
        return tag.startswith('VB') or tag.startswith('NN') or tag.startswith('JJ')

    @staticmethod
    def cal_tag_score(tag):
        # nouns count double, verbs once, adjectives half
        if tag.startswith('VB'):
            return 1.
        if tag.startswith('NN'):
            return 2.
        if tag.startswith('JJ'):
            return 0.5
        return 0.

    def idf_extract(self, string, con_kw=None):
        # score = tag weight * term frequency / utterance length * IDF;
        # keep candidate keywords scoring above 0.15
        tokens = simp_tokenize(string)
        seq_len = len(tokens)
        tokens = pos_tag(tokens)
        source = kw_tokenize(string)
        result = []
        for i, (word, tag) in enumerate(tokens):
            score = self.cal_tag_score(tag)
            if not is_candiword(source[i]) or score == 0.:
                continue
            if con_kw is not None and source[i] in con_kw:
                continue
            score *= source.count(source[i])
            score *= 1 / seq_len
            score *= self.idf_dict[source[i]]
            if score > 0.15:
                result.append(source[i])
        return list(set(result))

    def extract(self, string):
        # keep every candidate keyword tagged as a verb, noun or adjective
        tokens = simp_tokenize(string)
        tokens = pos_tag(tokens)
        source = kw_tokenize(string)
        keywords = []
        for i, (word, tag) in enumerate(tokens):
            if source[i] and self.is_keyword_tag(tag) and is_candiword(source[i]):
                keywords.append(source[i])
        return list(set(keywords))
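
# Hypothetical usage sketch (not part of the original file; needs the same
# NLTK data as data_utils): idf_extract weights each candidate word by its
# POS tag, its count in the utterance, the inverse utterance length and its
# IDF, and keeps words scoring above 0.15. The toy idf_dict below is
# illustrative; real values come from dts_Target.cal_idf().
if __name__ == '__main__':
    toy_idf = {'pizza': 2.0, 'eat': 0.8, 'love': 0.5}
    extractor = KeywordExtractor(toy_idf)
    print(extractor.idf_extract('i love to eat pizza'))   # ['eat', 'pizza'] in some order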
================================================
FILE: preprocess/prepare_data.py
================================================
from dataset import dts_Target
from collections import Counter
import pickle
import random
import os
import shutil

if not os.path.exists('../tx_data'):
    os.mkdir('../tx_data')
    os.mkdir('../tx_data/train')
    os.mkdir('../tx_data/valid')
    os.mkdir('../tx_data/test')

# import texar
# if not os.path.exists('convai2/source'):
#     print('Downloading source ConvAI2 data')
#     texar.data.maybe_download('https://drive.google.com/file/d/1LPxNIVO52hZOwbV3Zply_ITi2Uacit-V/view?usp=sharing'
#                               , 'convai2', extract=True)
shutil.copy('convai2/source/embedding.txt', '../tx_data/embedding.txt')

dataset = dts_Target()
dataset.make_dataset()
data = pickle.load(open("source_data.pk", "rb"))
max_utter = 9        # dialogue-history window
candidate_num = 20   # 1 true response + 19 random negatives
start_corpus_file = open("../tx_data/start_corpus.txt", "w")
corpus_file = open("../tx_data/corpus.txt", "w")
for stage in ['train', 'valid', 'test']:
    source_file = open("../tx_data/{}/source.txt".format(stage), "w")
    target_file = open("../tx_data/{}/target.txt".format(stage), "w")
    context_file = open("../tx_data/{}/context.txt".format(stage), "w")
    keywords_file = open("../tx_data/{}/keywords.txt".format(stage), "w")
    label_file = open("../tx_data/{}/label.txt".format(stage), "w")
    keywords_vocab_file = open("../tx_data/{}/keywords_vocab.txt".format(stage), "w")
    corpus = []
    keywords_counter = Counter()
    for sample in data[stage]:
        corpus += sample['dialog'][1:]
        start_corpus_file.write(sample['dialog'][0] + '\n')
        for kws in sample['kwlist']:
            keywords_counter.update(kws)
    for kw, _ in keywords_counter.most_common():
        keywords_vocab_file.write(kw + '\n')
    for sample in data[stage]:
        for i in range(2, len(sample['dialog'])):
            if len(sample['kwlist'][i]) > 0:
                source_list = sample['dialog'][max(0, i - max_utter):i]
                source_str = '|||'.join(source_list)
                # resample until the true response is not among the negatives
                while True:
                    random_corpus = random.sample(corpus, candidate_num - 1)
                    if sample['dialog'][i] not in random_corpus:
                        break
                corpus_file.write(sample['dialog'][i] + '\n')
                # the true response always sits at index 0, hence the '0' label
                target_list = [sample['dialog'][i]] + random_corpus
                target_str = '|||'.join(target_list)
                source_file.write(source_str + '\n')
                target_file.write(target_str + '\n')
                context_file.write(' '.join(sample['kwlist'][i - 2] +
                                            sample['kwlist'][i - 1]) + '\n')
                keywords_file.write(' '.join(sample['kwlist'][i]) + '\n')
                label_file.write('0\n')
    source_file.close()
    target_file.close()
    label_file.close()
    keywords_file.close()
    keywords_vocab_file.close()
    context_file.close()
start_corpus_file.close()
corpus_file.close()
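
# Hypothetical sketch (not part of the original file): every training example
# is a '|||'-joined dialogue history plus 20 '|||'-joined candidate responses,
# with the true response always at index 0 (hence the constant '0' label).
# The fragment below shows the row layout with made-up utterances.
if __name__ == '__main__':
    source_row = '|||'.join(['hi there', 'hello , how are you ?'])
    target_row = '|||'.join(['i am fine'] + ['negative {}'.format(i) for i in range(19)])
    print(source_row)
    print(target_row.count('|||') + 1)   # 20 candidates per row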
================================================
FILE: readme.md
================================================
# Target-Guided Open-Domain Conversation
This is the code for the following paper:
[Target-Guided Open-Domain Conversation](http://arxiv.org/abs/1905.11553)
*Jianheng Tang, Tiancheng Zhao, Chenyan Xiong, Xiaodan Liang, Eric Xing, Zhiting Hu; ACL 2019*
### Requirement
- `nltk==3.4`
- `tensorflow==1.12`
- `texar>=0.2.1` ([Texar](https://github.com/asyml/texar))
### Usage
#### Data Preparation
The dataset developed in the paper is on [Google Drive](https://drive.google.com/file/d/1oTjOQjm7iiUitOPLCmlkXOCbEPoSWDPX/view?usp=sharing). Download
and unzip it into `preprocess/convai2`, then run the following commands:
```shell
cd preprocess
python prepare_data.py
```
By default, the processed data will be put in the `tx_data` directory.
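The processed directory should look roughly like this (the file roles are inferred from the preprocessing scripts, and the exact set may vary with the config):
```
tx_data/
├── train/              # source.txt, target.txt, context.txt, keywords.txt, label.txt, keywords_vocab.txt
├── valid/
├── test/
├── corpus.txt          # candidate responses for retrieval
├── start_corpus.txt    # conversation starters
├── embedding.txt       # pre-trained word embeddings (GloVe format)
└── vocab.txt
```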
#### Turn-level Supervised Learning
This project contains 5 different types of agents (kernel, neural, matrix, retrieval, and retrieval_stgy),
all of which are described in the paper. You can modify the configuration of each agent in the `config` directory.
To train the kernel/neural/matrix agents, you first need to train/test the keyword prediction module,
and then train/test the retrieval module of the agent specified by the `--agent` parameter:
```shell
python train.py --mode train_kw --agent kernel
python train.py --mode train --agent kernel
python train.py --mode test --agent kernel
```
The retrieval agent and the retrieval_stgy agent share the same retrieval module, so you only need to train one of them:
```shell
python train.py --mode train --agent retrieval
python train.py --mode test --agent retrieval
```
#### Target-guided Conversation
After turn-level training, you can start target-guided conversation (human evaluation) with
the kernel/neural/matrix/retrieval/retrieval_stgy agent specified by the `--agent` parameter.
```shell
python chat.py --agent kernel
```
A demo of a target-guided conversation (the user types a response after `HUMAN:`):
```
--------Session 0 --------
START: hello, how are you doing
HUMAN: hi i am great just finishing up some homework. how about you
AGENT: just got done eating some chicken and rice after i got done with my gym session
HUMAN: i am gonna eat some vegetable for dinner
AGENT: so you are vegetarian ? i am from los angeles so thereare lots of vegetarian restaurants
HUMAN: i am not a vegetarian . fried chicken is my favorite.
AGENT: oh boy , when i go to mcdonald i always get fries for my 2 dogs
Successfully chat to the target 'mcdonald'.
```
You can also watch a simulated target-guided conversation between the retrieval agent,
which plays the role of the user, and the kernel/neural/matrix/retrieval_stgy agent specified by the `--agent` parameter.
The number of simulations is set by the `--times` parameter. At the end, the program reports the number of successful sessions and the average number of turns.
```shell
python simulate.py --agent kernel --times 100
```
A demo of a simulation:
```
--------Session 0 --------
START: hey hey good morning to you
retrieval_agent: fine thanks . what do you do for a living ?
kernel_agent: rewarding job , i work at a hospital
Keyword: job, Similarity: 0.58
Successfully chat to the target 'hospital'.
...
--------Session 99 --------
START: hey hows it going ? i'm just cooking a steak
retrieval_agent: i'm thinking of a bbq sandwich for lunch
kernel_agent: nice i love to cook but now its just me and the fur babies
Keyword: baby, Similarity: 0.45
retrieval_agent: i love bagels however i own a dry cleaners
kernel_agent: i love animals felix my cat and my dog emmy
Keyword: cat, Similarity: 0.56
retrieval_agent: sounds awesome i have all kind of pets my family own a farm
kernel_agent: i love blue as well even my hair is blue
Keyword: blue, Similarity: 1.00
Successfully chat to the target 'blue'.
success time 83, average turns 4.28
```
================================================
FILE: simulate.py
================================================
import tensorflow as tf
import importlib
import random
from preprocess.data_utils import utter_preprocess, is_reach_goal
from model import retrieval


class Target_Simulation():
    def __init__(self, config_model, config_data, config_retrieval):
        # build the user simulator (a plain retrieval agent) and the
        # target-guided agent in two separate graphs and sessions
        g1 = tf.Graph()
        with g1.as_default():
            self.retrieval_agent = retrieval.Predictor(config_retrieval, config_data)
            sess1 = tf.Session(graph=g1, config=self.retrieval_agent.gpu_config)
            self.retrieval_agent.retrieve_init(sess1)
        g2 = tf.Graph()
        with g2.as_default():
            self.target_agent = model.Predictor(config_model, config_data)
            sess2 = tf.Session(graph=g2, config=self.target_agent.gpu_config)
            self.target_agent.retrieve_init(sess2)
        self.start_utter = config_data._start_corpus
        success_cnt, turns_cnt = 0, 0
        for i in range(int(FLAGS.times)):
            print('--------Session {} --------'.format(i))
            success, turns = self.simulate(sess1, sess2)
            success_cnt += success
            turns_cnt += turns
        # max(..., 1) avoids a zero division when no session succeeds
        print('success time {}, average turns {:.2f}'.format(success_cnt, turns_cnt / max(success_cnt, 1)))

    def simulate(self, sess1, sess2):
        history = []
        history.append(random.sample(self.start_utter, 1)[0])
        target_kw = random.sample(target_set, 1)[0]
        self.target_agent.target = target_kw
        self.target_agent.score = 0.
        self.target_agent.reply_list = []
        self.retrieval_agent.reply_list = []
        print('START: ' + history[0])
        for i in range(config_data._max_turns):
            source = utter_preprocess(history, config_data._max_seq_len)
            reply = self.retrieval_agent.retrieve(source, sess1)
            print('retrieval_agent: ', reply)
            history.append(reply)
            source = utter_preprocess(history, config_data._max_seq_len)
            reply = self.target_agent.retrieve(source, sess2)
            print('{}_agent: '.format(FLAGS.agent), reply)
            print('Keyword: {}, Similarity: {:.2f}'.format(self.target_agent.next_kw, self.target_agent.score))
            history.append(reply)
            if is_reach_goal(history[-2] + history[-1], target_kw):
                print('Successfully chat to the target \'{}\'.'.format(target_kw))
                return (True, (len(history) + 1) // 2)
        print('Failed by reaching the maximum turn, target: \'{}\'.'.format(target_kw))
        return (False, 0)


if __name__ == '__main__':
    flags = tf.flags
    flags.DEFINE_string('agent', 'kernel', 'The agent type: kernel / matrix / neural / retrieval / retrieval_stgy.')
    flags.DEFINE_string('times', '100', 'Simulation times.')
    FLAGS = flags.FLAGS
    config_data = importlib.import_module('config.data_config')
    config_model = importlib.import_module('config.' + FLAGS.agent)
    config_retrieval = importlib.import_module('config.retrieval')
    model = importlib.import_module('model.' + FLAGS.agent)
    target_set = []
    for line in open('tx_data/test/keywords.txt', 'r').readlines():
        target_set = target_set + line.strip().split(' ')
    Target_Simulation(config_model, config_data, config_retrieval)
================================================
FILE: train.py
================================================
import tensorflow as tf
import importlib
import os

if __name__ == '__main__':
    flags = tf.flags
    flags.DEFINE_string('data', 'data_config', 'The data config')
    flags.DEFINE_string('agent', 'kernel', 'The predictor type')
    flags.DEFINE_string('mode', 'train', 'The mode')
    FLAGS = flags.FLAGS
    config_data = importlib.import_module('config.' + FLAGS.data)
    config_model = importlib.import_module('config.' + FLAGS.agent)
    model = importlib.import_module('model.' + FLAGS.agent)
    predictor = model.Predictor(config_model, config_data, FLAGS.mode)
    if not os.path.exists('save/' + FLAGS.agent):
        os.makedirs('save/' + FLAGS.agent)
    if FLAGS.mode == 'train_kw':
        predictor.train_keywords()
    elif FLAGS.mode == 'test_kw':
        predictor.test_keywords()
    elif FLAGS.mode == 'train':
        predictor.train()
        predictor.test()
    elif FLAGS.mode == 'test':
        predictor.test()
SYMBOL INDEX (88 symbols across 11 files)
FILE: chat.py
class Target_Chat (line 6) | class Target_Chat():
method __init__ (line 7) | def __init__(self, agent):
method chat (line 16) | def chat(self, sess):
FILE: model/kernel.py
class Predictor (line 7) | class Predictor():
method __init__ (line 8) | def __init__(self, config_model, config_data, mode=None):
method build_model (line 15) | def build_model(self):
method forward_kernel (line 33) | def forward_kernel(self, kw_embed, context_ids):
method predict_keywords (line 48) | def predict_keywords(self, batch):
method train_keywords (line 66) | def train_keywords(self):
method test_keywords (line 111) | def test_keywords(self):
method forward (line 144) | def forward(self, batch):
method train (line 188) | def train(self):
method test (line 238) | def test(self):
method retrieve_init (line 264) | def retrieve_init(self, sess):
method retrieve (line 314) | def retrieve(self, history_all, sess):
FILE: model/matrix.py
class Predictor (line 6) | class Predictor():
method __init__ (line 7) | def __init__(self, config_model, config_data, mode=None):
method build_model (line 14) | def build_model(self, mode):
method forward_matrix (line 36) | def forward_matrix(self, context_ids):
method predict_keywords (line 40) | def predict_keywords(self, batch):
method train_keywords (line 55) | def train_keywords(self):
method test_keywords (line 89) | def test_keywords(self):
method forward (line 121) | def forward(self, batch):
method train (line 164) | def train(self):
method test (line 205) | def test(self):
method retrieve_init (line 231) | def retrieve_init(self, sess):
method retrieve (line 285) | def retrieve(self, history_all, sess):
FILE: model/neural.py
class Predictor (line 7) | class Predictor():
method __init__ (line 8) | def __init__(self, config_model, config_data, mode=None):
method build_model (line 15) | def build_model(self):
method forward_neural (line 32) | def forward_neural(self, context_ids, context_length):
method predict_keywords (line 38) | def predict_keywords(self, batch):
method train_keywords (line 54) | def train_keywords(self):
method test_keywords (line 94) | def test_keywords(self):
method forward (line 127) | def forward(self, batch):
method train (line 169) | def train(self):
method test (line 214) | def test(self):
method retrieve_init (line 241) | def retrieve_init(self, sess):
method retrieve (line 295) | def retrieve(self, history_all, sess):
FILE: model/retrieval.py
class Predictor (line 6) | class Predictor():
method __init__ (line 7) | def __init__(self, config_model, config_data, mode=None):
method build_model (line 14) | def build_model(self):
method forward (line 25) | def forward(self, batch):
method train (line 48) | def train(self):
method test (line 89) | def test(self):
method retrieve_init (line 115) | def retrieve_init(self, sess):
method retrieve (line 150) | def retrieve(self, source, sess):
FILE: model/retrieval_stgy.py
class Predictor (line 6) | class Predictor():
method __init__ (line 7) | def __init__(self, config_model, config_data, mode=None):
method build_model (line 14) | def build_model(self):
method forward (line 26) | def forward(self, batch):
method train (line 49) | def train(self):
method test (line 90) | def test(self):
method retrieve_init (line 116) | def retrieve_init(self, sess):
method retrieve (line 155) | def retrieve(self, source, sess):
FILE: preprocess/convai2/api.py
class dts_ConvAI2 (line 4) | class dts_ConvAI2(object):
method __init__ (line 5) | def __init__(self, path=data_path):
method _txt_to_json (line 8) | def _txt_to_json(self, txt_path, mode, cands):
method get_data (line 67) | def get_data(self, mode='train', revised=False, cands=False):
method get_dialogs (line 78) | def get_dialogs(self, mode='all'):
FILE: preprocess/data_utils.py
function tokenize (line 8) | def tokenize(example, ppln):
function kw_tokenize (line 14) | def kw_tokenize(string):
function simp_tokenize (line 18) | def simp_tokenize(string):
function nltk_tokenize (line 22) | def nltk_tokenize(string):
function lower (line 26) | def lower(tokens):
function pos_tag (line 32) | def pos_tag(tokens):
function to_basic_form (line 36) | def to_basic_form(tokens):
function truecasing (line 51) | def truecasing(tokens):
function is_candiword (line 74) | def is_candiword(a):
function calculate_linsim (line 86) | def calculate_linsim(a, b):
function is_reach_goal (line 99) | def is_reach_goal(context, goal):
function make_context (line 111) | def make_context(string):
function utter_preprocess (line 120) | def utter_preprocess(string_list, max_length):
FILE: preprocess/dataset.py
class dts_Target (line 9) | class dts_Target(dts_ConvAI2):
method __init__ (line 10) | def __init__(self, *args, **kwargs):
method get_vocab (line 13) | def get_vocab(self):
method get_kwsess (line 27) | def get_kwsess(self, vocab, mode='all'):
method cal_idf (line 42) | def cal_idf(self):
method make_dataset (line 55) | def make_dataset(self):
FILE: preprocess/extraction.py
class KeywordExtractor (line 3) | class KeywordExtractor():
method __init__ (line 4) | def __init__(self, idf_dict = None):
method is_keyword_tag (line 8) | def is_keyword_tag(tag):
method cal_tag_score (line 12) | def cal_tag_score(tag):
method idf_extract (line 21) | def idf_extract(self, string, con_kw = None):
method extract (line 43) | def extract(self, string):
FILE: simulate.py
class Target_Simulation (line 7) | class Target_Simulation():
method __init__ (line 8) | def __init__(self, config_model, config_data, config_retrieval):
method simulate (line 28) | def simulate(self, sess1, sess2):