
Commit 05d4654

complete word hash & quora demo
1 parent 133e93f commit 05d4654

11 files changed: +443 -123 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,7 @@
 # Personal ignore
 dataset/
+*.pkl
+result/
 
 # Compiled source #
 ###################

DSSM/dssm.py

Lines changed: 129 additions & 26 deletions
@@ -4,15 +4,18 @@
 import pandas as pd
 import numpy as np
 import time
+import sklearn
 
 import utils
+import tools
 
 class DSSM(object):
     '''
     Impletement DSSM Model in the Paper: Learning Deep Structured Semantic Models for Web Search using Clickthrough Data
     '''
     def __init__(self, hash_tokens_nums=3000, dnn_layer_nums=1, dnn_hidden_node_nums=50, feature_nums=50,
-                 batch_size=10, neg_nums=4, learning_rate=0.5, max_epochs=200):
+                 batch_size=10, neg_nums=4, learning_rate=0.5, max_epochs=200, loss_kind='mcl', w_init=0.1, \
+                 save_model_path='./', mlp_hidden_node_nums=32, mlp_layer_nums=2):
         '''
         paras:
         hash_tokens_nums: number of tokens after word hashing
@@ -23,7 +26,13 @@ def __init__(self, hash_tokens_nums=3000, dnn_layer_nums=1, dnn_hidden_node_nums
         neg_nums: number of negative samples
         learning_rate: learning rate
         max_epoch: number of training epochs
+        loss_kind: 'mcl': maximize the conditional likelihood of the positive sample; 'log_loss': compute the loss as cross entropy
+        w_init: weight initialization range
+        save_model_path: file path for saving the model that performs best on the validation set
+        mlp_hidden_node_nums: number of nodes in the MLP layers added after concatenating the learned latent vectors
+        mlp_layer_nums: number of MLP layers
         '''
+
         self.hash_token_nums = hash_tokens_nums
         self.dnn_layer_nums = dnn_layer_nums
         self.dnn_hidden_node_nums = dnn_hidden_node_nums
@@ -32,18 +41,36 @@ def __init__(self, hash_tokens_nums=3000, dnn_layer_nums=1, dnn_hidden_node_nums
         self.neg_nums = neg_nums
         self.learning_rate = learning_rate
         self.max_epochs = max_epochs
+        self.loss_kind = loss_kind
+        self.positive_weights = 1
+        self.w_init = w_init
+        self.save_model_path = save_model_path
+        self.mlp_hidden_node_nums = mlp_hidden_node_nums
+        self.mlp_layer_nums = mlp_layer_nums
 
         '''
         query and doc use different network structures, as described in the paper
         '''
-        self.input_q = tf.placeholder(tf.float32, shape=[batch_size, self.hash_token_nums]) # sample_nums, word_nums, hash_tokens_nums
-        self.input_doc = tf.placeholder(tf.float32, shape=[batch_size, self.hash_token_nums]) # sample_nums, word_nums, hash_tokens_nums
-        self.label = tf.placeholder(tf.float32, shape=[batch_size])
+        self.input_q = tf.placeholder(tf.float32, shape=[None, self.hash_token_nums]) # sample_nums, word_nums, hash_tokens_nums
+        self.input_doc = tf.placeholder(tf.float32, shape=[None, self.hash_token_nums]) # sample_nums, word_nums, hash_tokens_nums
+        self.label = tf.placeholder(tf.float32, shape=[None])
+
+        self.predict_doc = None
+        self.predict_query = None
+
+        self.relevance = self.create_model_op()
+
+        if self.loss_kind == 'mlc':
+            self.loss = self.create_loss_max_condition_lh_op()
+        elif self.loss_kind == 'log_loss':
+            self.loss = self.create_log_loss_op()
+        else:
+            pass
 
-        self.predict_op = self.create_model_op()
-        self.loss_op = self.create_loss_op()
-        self.train_op = self.create_train_op()
+        self.train = self.create_train_op()
 
+    def set_positive_weights(self, positive_weights):
+        self.positive_weights = positive_weights
 
     def create_model_op(self):
 
@@ -65,7 +92,7 @@ def create_model_op(self):
             result = input_dict[one_structrue]
             for i in range(len(node_nums)-1):
                 w = tf.Variable(
-                    tf.random_uniform([node_nums[i], node_nums[i+1]], -0.001, 0.001), name='weights'+str(i)
+                    tf.random_uniform([node_nums[i], node_nums[i+1]], -self.w_init, self.w_init), name='weights'+str(i)
                 )
                 # when the network is deep and has many parameters, keep the initial weights small and increase the learning rate accordingly
                 b = tf.Variable(tf.zeros([node_nums[i+1]]), name="bias"+str(i))
@@ -75,31 +102,66 @@ def create_model_op(self):
 
         self.predict_query = features[0]
         self.predict_doc = features[1]
-        norms1 = tf.sqrt(tf.reduce_sum(tf.square(features[0]), 1, keep_dims=False))
-        norms2 = tf.sqrt(tf.reduce_sum(tf.square(features[1]), 1, keep_dims=False))
-        self.relevance = tf.reduce_sum(features[0] * features[1], 1) / norms1 / norms2
-        return self.relevance
 
+        '''
+        to score the similarity of the two learned vectors, add MLP layers; the last layer is fully connected
 
-    def create_loss_op(self):
         '''
+        result = tf.concat(features, 1)
+        print result
+
+        with tf.variable_scope('mlp'):
+            node_nums = [self.feature_nums*2] + [self.mlp_hidden_node_nums] * self.mlp_layer_nums + [1]
+            for i in range(len(node_nums) - 1):
+                w = tf.Variable(
+                    tf.random_uniform([node_nums[i], node_nums[i + 1]], -self.w_init, self.w_init),
+                    name='weights' + str(i)
+                )
+                b = tf.Variable(tf.zeros([node_nums[i + 1]]), name="bias" + str(i))
+                result = tf.matmul(result, w) + b
+                result = tf.nn.sigmoid(result)
+
+
+        # norms1 = tf.sqrt(tf.reduce_sum(tf.square(features[0]), 1, keep_dims=False))
+        # norms2 = tf.sqrt(tf.reduce_sum(tf.square(features[1]), 1, keep_dims=False))
+        # relevance = tf.reduce_sum(features[0] * features[1], 1) / norms1 / norms2
+
+        # w_r = tf.Variable(tf.random_uniform([1], -self.w_init, self.w_init), name="weight-of-relevance")
+        # b_r = tf.Variable(tf.zeros([1]), name="bais-of-relevance")
+        # relevance = relevance * w_r + b_r
+        # relevance = tf.nn.softmax(relevance)
+
+        return result
+
+
+    def create_loss_max_condition_lh_op(self):
+        '''
+        use maximum likelihood to compute the conditional probability of the positive sample
         compute the loss of the relevant documents; the empirical value gama is also learned
         :return:
         '''
         gama = tf.Variable(tf.random_uniform([1]), name="gama")
-        ret = self.predict_op * gama
+        ret = self.relevance * gama
         ret = tf.reshape(ret, [-1, self.neg_nums+1])
         ret = tf.log(tf.nn.softmax(ret))
-        ret = tf.reduce_sum(ret, 0)
-        return -tf.gather(ret, 0)
+        ret = tf.reduce_sum(ret, 0) # sum across rows (axis 0)
+        return -tf.gather(ret, 0) # take the first entry, i.e. the loss of the positive samples
+
+
+    def create_log_loss_op(self):
+        '''
+        compute the log loss, i.e. cross entropy
+        :return:
+        '''
+        return tf.reduce_sum(tf.contrib.losses.log_loss(self.relevance, self.label))
 
 
     def create_train_op(self):
         '''
         learn with gradient descent
         :return:
         '''
-        return tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss_op)
+        return tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)
 
 
     def creat_feed_dict(self, query_batch, doc_batch, label_batch):
@@ -116,7 +178,7 @@ def creat_feed_dict(self, query_batch, doc_batch, label_batch):
         }
 
 
-    def run_epoch(self, sess, query_input, doc_input, labels):
+    def run_epoch(self, sess, query_input, doc_input, labels, is_valid=False):
         '''
         run one training epoch
         :param sess:
@@ -126,20 +188,31 @@ def run_epoch(self, sess, query_input, doc_input, labels):
         :return:
         '''
         average_loss = 0
+        step = 0
         for step, (query, doc, label) in enumerate(
             utils.data_iterator(query_input, doc_input, labels, self.batch_size)
         ):
             # print query[1, 1], doc[1, 1], label[1]
             self.creat_feed_dict(query, doc, label)
-            _, loss_value, predict_query, predict_doc, relevance = sess.run([self.train_op, self.loss_op, self.predict_query\
-                , self.predict_doc, self.relevance], feed_dict=self.feed_dict)
+            self.set_positive_weights(len(query))
+
+            if not is_valid:
+                # the weights W are only updated when the train op is run
+                _, loss_value, predict_query, predict_doc, relevance = sess.run([self.train, self.loss, self.predict_query\
+                    , self.predict_doc, self.relevance], feed_dict=self.feed_dict)
+            else:
+
+                loss_value, relevance = sess.run([self.loss, self.relevance], feed_dict=self.feed_dict)
+                # print 'Chcek ', sklearn.metrics.log_loss(label, relevance), loss_value
+
             average_loss += loss_value
             #print 'step ', step, loss_value
             #print 'predict ', predict_query[0], predict_doc[0], relevance[0]
-        return average_loss / step
+        return average_loss / (step+1), relevance
 
 
-    def fit(self, sess, query_input, doc_input, labels):
+    def fit(self, sess, query_input, doc_input, labels, valid_q_input=None, valid_d_input=None, valid_labels=None, \
+            load_model=False):
         '''
         entry point of the model
         :param sess:
@@ -149,14 +222,40 @@ def fit(self, sess, query_input, doc_input, labels):
         :return:
         '''
         losses = []
+        best_loss = 99999
+        saver = tf.train.Saver()
+        if load_model:
+            saver.restore(sess, self.save_model_path)
+            start_time = time.time()
+            valid_loss, _ = self.run_epoch(sess, valid_q_input, valid_d_input, valid_labels, is_valid=True)
+            duration = time.time() - start_time
+            print('valid loss = %.5f (%.3f sec)'
+                  % (valid_loss, duration))
+            losses.append(valid_loss)
+            return losses
+
         for epoch in range(self.max_epochs):
             start_time = time.time()
-            average_loss = self.run_epoch(sess, query_input, doc_input, labels)
+            average_loss, _ = self.run_epoch(sess, query_input, doc_input, labels)
             duration = time.time() - start_time
 
-            print('Epoch %d: loss = %.5f (%.3f sec)'
-                  % (epoch, average_loss, duration))
+            if (epoch+1) % 100 == 0:
+                if valid_labels is None:
+                    print('Epoch %d: loss = %.5f (%.3f sec)'
+                          % (epoch+1, average_loss, duration))
+                else:
+                    valid_loss, _ = self.run_epoch(sess, valid_q_input, valid_d_input, valid_labels, is_valid=True)
+                    if valid_loss < best_loss:
+                        best_loss = valid_loss
+                        saver.save(sess, self.save_model_path)
+                    duration = time.time() - start_time
+                    print('Epoch %d: loss = %.5f valid loss = %.5f (%.3f sec)'
+                          % (epoch+1, average_loss, valid_loss, duration))
+
             losses.append(average_loss)
+
+        if not valid_labels is None:
+            print 'Final valid loss: ', best_loss
         return losses
 
     def predict(self, sess, query, doc, labels):
@@ -170,6 +269,8 @@ def predict(self, sess, query, doc, labels):
         '''
         self.creat_feed_dict(query, doc, labels)
         predict = sess.run(self.relevance, feed_dict=self.feed_dict)
+        return predict
+
 
 
 def test_dssm():
@@ -179,12 +280,14 @@ def test_dssm():
     '''
     with tf.Graph().as_default():
         tf.set_random_seed(1)
+
         model = DSSM(hash_tokens_nums=30000, dnn_layer_nums=2, dnn_hidden_node_nums=300, feature_nums=128,
                      batch_size=10, neg_nums=4, learning_rate=0.02, max_epochs=500)
         sess = tf.Session()
         init = tf.initialize_all_variables()
         sess.run(init)
         np.random.seed(1)
+
         query = np.random.rand(500, 30000)
         doc = np.random.rand(500, 30000)
         label = np.array([1, 0, 0, 0, 0] * 100)
@@ -195,7 +298,7 @@ def test_dssm():
 
         losses = model.fit(sess, query, doc, label)
 
-        print losses[-1]
+        #print losses[-1]
 
 
 if __name__ == '__main__':
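
Taken together, the new constructor arguments (loss_kind, w_init, save_model_path, mlp_hidden_node_nums, mlp_layer_nums) and the extended fit() signature support a train / validate / checkpoint workflow that test_dssm() does not yet exercise. A minimal usage sketch, assuming the same TensorFlow 0.x-era API used in dssm.py, that dssm.py and its utils/tools modules are importable, and a hypothetical checkpoint path; loss_kind='log_loss' selects the new cross-entropy branch:

```python
# Hedged usage sketch, not part of this commit. Assumes TensorFlow 0.x/1.x
# (tf.placeholder, tf.initialize_all_variables) and that dssm.py is on the path.
import numpy as np
import tensorflow as tf

from dssm import DSSM

with tf.Graph().as_default():
    tf.set_random_seed(1)
    # 'log_loss' picks the new cross-entropy loss; './best_model.ckpt' is a
    # hypothetical path passed through save_model_path to tf.train.Saver.
    model = DSSM(hash_tokens_nums=30000, dnn_layer_nums=2, dnn_hidden_node_nums=300,
                 feature_nums=128, batch_size=10, neg_nums=4, learning_rate=0.02,
                 max_epochs=500, loss_kind='log_loss', w_init=0.1,
                 save_model_path='./best_model.ckpt')

    sess = tf.Session()
    sess.run(tf.initialize_all_variables())

    # Toy data shaped like test_dssm(): each group of neg_nums+1 rows is one
    # positive query/doc pair followed by its negatives.
    query = np.random.rand(500, 30000)
    doc = np.random.rand(500, 30000)
    label = np.array([1, 0, 0, 0, 0] * 100)

    # Train on the first 400 rows; every 100 epochs fit() scores the held-out
    # rows and checkpoints the model whenever the validation loss improves.
    losses = model.fit(sess, query[:400], doc[:400], label[:400],
                       valid_q_input=query[400:], valid_d_input=doc[400:],
                       valid_labels=label[400:])

    # predict() now returns the relevance scores for unseen query/doc pairs.
    scores = model.predict(sess, query[400:], doc[400:], label[400:])
```

On a later run, passing load_model=True to fit() restores the checkpoint from save_model_path and reports the validation loss without further training.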

DSSM/utils.py

Lines changed: 10 additions & 0 deletions
@@ -5,6 +5,16 @@
 
 
 def data_iterator(orig_X, orig_y=None, orig_label=None, batch_size=10, shuffle=False):
+    '''
+
+    :param orig_X:
+    :param orig_y:
+    :param orig_label:
+    :param batch_size:
+    :param shuffle:
+    :return:
+    '''
+
     # Optionally shuffle the data before training
     if shuffle:
         indices = np.random.permutation(len(orig_X))
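
For context, run_epoch() in dssm.py consumes data_iterator as a generator of aligned (query, doc, label) mini-batches. A small sketch of that contract, with illustrative shapes only and the same positive-then-negatives row layout used by test_dssm():

```python
# Hedged illustration of the data_iterator contract, not part of this commit.
import numpy as np

import utils

# Toy inputs: 20 query/doc pairs hashed into a 50-dimensional vocabulary.
query = np.random.rand(20, 50)
doc = np.random.rand(20, 50)
label = np.array([1, 0, 0, 0, 0] * 4)  # one positive, then neg_nums=4 negatives

# data_iterator yields aligned mini-batches; the default shuffle=False keeps
# each positive grouped with its negatives, which the 'mcl' loss reshape
# (to [-1, neg_nums+1]) relies on.
for step, (q_batch, d_batch, l_batch) in enumerate(
        utils.data_iterator(query, doc, label, batch_size=5)):
    print step, q_batch.shape, d_batch.shape, l_batch.shape
```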

README.md

Lines changed: 26 additions & 9 deletions
@@ -1,5 +1,16 @@
 ## For what
-Understanding the Application about Deep Learning in Text Matching Area & Implement Codes about the Classical Methods
+Understanding the Methods in Text Matching Area Including Key-words based Matching Model & Latent Semantic Matching Model.
+Implement the Classical Methods.
+
+## Categories
+- Key-words based methods
+  - tf-idf model
+  - words common rate model
+  - find the most important word with adding syntax information
+- Semantic methods
+  - term bag models
+  - structure considered models
+- Features based methods
 
 ## People in these area
 - [Po-Sen Huang](https://posenhuang.github.io/full_publication.html)
@@ -8,11 +19,11 @@ Understanding the Application about Deep Learning in Text Matching Area & Implem
 - [Hang Li](http://www.hangli-hl.com/index.html)
 
 ## Survey
-> [深度文本匹配综述](http://kns.cnki.net/KCMS/detail/detail.aspx?dbcode=CJFQ&dbname=CAPJLAST&filename=JSJX20160920002&uid=WEEvREcwSlJHSldRa1FhdXNXYXJvK0FZMlhXUDZsYnBMQjhHTElMeE1jRT0=$9A4hF_YAuvQ5obgVAqNKPCYcEjKensW4ggI8Fm4gTkoUKaID8j8gFw!!&v=MzA2OTFscVdNMENMTDdSN3FlWU9ac0ZDcmxWYnZPSTFzPUx6N0Jkckc0SDlmTXBvMUZaT3NOWXc5TXptUm42ajU3VDNm)
+> [深度文本匹配综述(A Survey on Deep Text Matching)](http://kns.cnki.net/KCMS/detail/detail.aspx?dbcode=CJFQ&dbname=CAPJLAST&filename=JSJX20160920002&uid=WEEvREcwSlJHSldRa1FhdXNXYXJvK0FZMlhXUDZsYnBMQjhHTElMeE1jRT0=$9A4hF_YAuvQ5obgVAqNKPCYcEjKensW4ggI8Fm4gTkoUKaID8j8gFw!!&v=MzA2OTFscVdNMENMTDdSN3FlWU9ac0ZDcmxWYnZPSTFzPUx6N0Jkckc0SDlmTXBvMUZaT3NOWXc5TXptUm42ajU3VDNm)
 <br>
 
 
-## Methods & Papers
+## Methods & Papers about Semantic Methods
 
 > [**DSSM**](./DSSM/dssm.py)
 <br> [Learning Deep Structured Semantic Models for Web Search using Clickthrough Data](https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf)
@@ -64,14 +75,19 @@ Architecture for Community-based Question Answering](https://ijcai.org/Proceedin
 > [**DeepMatch_tree**]()
 <br> [Syntax-based Deep Matching of Short Texts](https://arxiv.org/pdf/1503.02427.pdf)
 
+## Methods & Papers about Key Words Based Methods
+> [****]()
+<br> []()
 
 ## Related talks and books
-[Deep Learning for Web Search and
+* [Deep Learning for Web Search and
 Natural Language Processing](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/wsdm2015.v3.pdf)
-[Deep Learning for Information Retrieval(Sigir 2016 Tutorial)](http://www.hangli-hl.com/uploads/3/4/4/6/34465961/deep_learning_for_information_retrieval.pdf)
-[Semantic Matching in Search (Sigir 2014 Workshop)](http://www.hangli-hl.com/uploads/3/4/4/6/34465961/semantic_matching_in_search.pdf)
-[Semantic Matching in Search (Book 2014)](http://www.bigdatalab.ac.cn/~junxu/publications/SemanticMatchingInSearch_2014.pdf)
-
+* [Deep Learning for Information Retrieval(Sigir 2016 Tutorial)](http://www.hangli-hl.com/uploads/3/4/4/6/34465961/deep_learning_for_information_retrieval.pdf)
+* [Semantic Matching in Search (Sigir 2014 Workshop)](http://www.hangli-hl.com/uploads/3/4/4/6/34465961/semantic_matching_in_search.pdf)
+* [Semantic Matching in Search (Book 2014)](http://www.bigdatalab.ac.cn/~junxu/publications/SemanticMatchingInSearch_2014.pdf)
+* [gensim notebook](https://github.com/RaRe-Technologies/gensim/tree/develop/docs/notebooks)
+
+
 ## Downloads
 > [DSSM/Sent2Vec Release Version](https://www.microsoft.com/en-us/download/details.aspx?id=52365)
 <br> The Sent2Vec release published by MSRA
@@ -86,12 +102,13 @@ Natural Language Processing](https://www.microsoft.com/en-us/research/wp-content
 * [Stack Exchange Data Dump](https://archive.org/details/stackexchange "Stack Exchange")
 * [Europarl: A Parallel Corpus for Statistical Machine Translation](http://www.iccs.inf.ed.ac.uk/~pkoehn/publications/europarl-mtsummit05.pdf "Philipp Koehn") ([www.statmt.org/europarl/](http://www.statmt.org/europarl/))
 * [RTE Knowledge Resources](http://aclweb.org/aclwiki/index.php?title=RTE_Knowledge_Resources)
-* [Kaggle Quora Question Pairs]()
+* [**Kaggle Quora Question Pairs**]()
 
 
 ## Competition
 * [Kaggle Quora Question Pairs](https://www.kaggle.com/c/quora-question-pairs)
 
+
 ## Pretrained Models
 * [Model Zoo](https://github.com/BVLC/caffe/wiki/Model-Zoo "Berkeley Vision and Learning Center")
 * [word2vec](https://code.google.com/p/word2vec/ "Tomas Mikolov")
