forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmq2007.py
More file actions
336 lines (286 loc) · 10.4 KB
/
Copy pathmq2007.py
File metadata and controls
336 lines (286 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MQ2007 dataset
MQ2007 is a query set from Million Query track of TREC 2007. There are about 1700 queries in it with labeled documents. In MQ2007, the 5-fold cross
validation strategy is adopted and the 5-fold partitions are included in the package. In each fold, there are three subsets for learning: training set,
validation set and testing set.
This module downloads the MQ2007 dataset from the website
http://research.microsoft.com/en-us/um/beijing/projects/letor/LETOR4.0/Data/MQ2007.rar and parses the training set and test set into paddle reader creators
"""
from __future__ import print_function
import os
import functools
from .common import download
import numpy as np
# Commented-out alternative location (the one named in the module docstring):
# URL = "http://research.microsoft.com/en-us/um/beijing/projects/letor/LETOR4.0/Data/MQ2007.rar"
# Location of the MQ2007 rar archive; passed to `download` by fetch().
URL = "http://www.bigdatalab.ac.cn/benchmark/upload/download_source/7b6dbbe2-842c-11e4-a536-bcaec51b9163_MQ2007.rar"
# MD5 string passed to `download` together with URL.
MD5 = "7be1640ae95c6408dab0ae7207bdc706"
def __initialize_meta_info__():
    """
    download and extract the MQ2007 dataset

    The archive path is obtained from fetch(); the archive is extracted into
    the directory that contains it.

    return :
    ------
        dirpath : string, directory that holds the archive and the extracted files
    """
    # local import: rarfile is only needed when this function runs
    import rarfile
    archive_path = fetch()
    dirpath = os.path.dirname(archive_path)
    # the context manager closes the archive handle once extraction is done
    with rarfile.RarFile(archive_path) as rar:
        rar.extractall(path=dirpath)
    return dirpath
class Query(object):
    """
    queries used for learning to rank algorithms. It is created from relevance scores, query-document feature vectors

    Parameters:
    ----------
      query_id : int
        query_id in dataset, mapping from query to relevance documents
      relevance_score : int
        relevance score of query and document pair
      feature_vector : array, dense feature
        feature in vector format
      description : string
        comment section in query doc pair data
    """

    def __init__(self,
                 query_id=-1,
                 relevance_score=-1,
                 feature_vector=None,
                 description=""):
        self.query_id = query_id
        self.relevance_score = relevance_score
        # a fresh list per instance; a caller-supplied list is stored as-is
        if feature_vector is None:
            self.feature_vector = []
        else:
            self.feature_vector = feature_vector
        self.description = description

    def __str__(self):
        string = "%s %s %s" % (str(self.relevance_score), str(self.query_id),
                               " ".join(str(f) for f in self.feature_vector))
        return string

    def _parse_(self, text):
        """
        parse line into Query

        The expected line format is
            0 qid:10 1:0.000272 2:0.000000 .... # optional comment
        i.e. 48 whitespace separated fields (label, qid, 46 features),
        optionally followed by '#' and a free text comment.

        Parameters:
        ----------
          text : string
            one line of the data file

        return :
        ------
          self with relevance_score, query_id, description filled in and the
          parsed features appended to feature_vector, or None when the line
          does not have 48 fields (a message is written to stdout)
        """
        # sys is not imported at module level, so bring it in here
        import sys

        # partition() keeps the whole line as data when there is no '#',
        # instead of cutting off its last character
        line, _, comment = text.partition('#')
        self.description = comment.strip()
        parts = line.split()
        if len(parts) != 48:
            sys.stdout.write("expect 48 space split parts, get %d\n" %
                             (len(parts)))
            return None
        # format : 0 qid:10 1:0.000272 2:0.000000 ....
        self.relevance_score = int(parts[0])
        self.query_id = int(parts[1].split(':')[1])
        for p in parts[2:]:
            pair = p.split(':')
            self.feature_vector.append(float(pair[1]))
        return self
class QueryList(object):
    """
    container of Query items that all belong to one query_id

    query_id stays -1 until the first item is seen; every later item must
    carry the same query_id, otherwise ValueError is raised.
    """

    def __init__(self, querylist=None):
        self.query_id = -1
        # a caller-supplied list is kept as-is (not copied)
        self.querylist = [] if querylist is None else querylist
        for item in self.querylist:
            self._adopt_query_id(item)

    def _adopt_query_id(self, item):
        # take over the id of the first item, reject any item with another id
        if self.query_id == -1:
            self.query_id = item.query_id
        elif self.query_id != item.query_id:
            raise ValueError("query in list must be same query_id")

    def __iter__(self):
        for item in self.querylist:
            yield item

    def __len__(self):
        return len(self.querylist)

    def __getitem__(self, i):
        return self.querylist[i]

    def _correct_ranking_(self):
        # in-place sort, highest relevance_score first
        if self.querylist is not None:
            self.querylist.sort(
                key=lambda item: item.relevance_score, reverse=True)

    def _add_query(self, query):
        # the id check happens before the item is stored
        self._adopt_query_id(query)
        self.querylist.append(query)
def gen_plain_txt(querylist):
    """
    gen plain samples, one per document of the query, for other usage

    Parameters:
    --------
      querylist : QueryList, or a list of Query that is wrapped into a QueryList

    yield :
    ------
      query_id : query_id of the QueryList
      label : relevance_score of the document
      feature : np.array built from the document feature_vector

    Documents are produced in descending relevance_score order.
    """
    if isinstance(querylist, QueryList):
        ranked = querylist
    else:
        ranked = QueryList(querylist)
    ranked._correct_ranking_()
    for doc in ranked:
        features = np.array(doc.feature_vector)
        yield ranked.query_id, doc.relevance_score, features
def gen_point(querylist):
    """
    gen samples, one per document, for point-wise learning to rank algorithm

    Parameters:
    --------
      querylist : QueryList, or a list of Query that is wrapped into a QueryList

    yield :
    ------
      label : relevance_score of the document
      feature : np.array built from the document feature_vector

    Documents are produced in descending relevance_score order.
    """
    if isinstance(querylist, QueryList):
        ranked = querylist
    else:
        ranked = QueryList(querylist)
    ranked._correct_ranking_()
    for doc in ranked:
        features = np.array(doc.feature_vector)
        yield doc.relevance_score, features
def gen_pair(querylist, partial_order="full"):
    """
    gen pair for pair-wise learning to rank algorithm

    Parameters:
    --------
      querylist : QueryList, or a list of Query that is wrapped into a QueryList
      partial_order : "full" or "neighbour"
        NOTE: this argument is not read by the implementation; every pair of
        documents with different relevance scores is generated whatever its
        value is.

    yield :
    ------
      label : np.array([1])
      query_left : np.array, feature vector of the more relevant document
      query_right : np.array, feature vector of the less relevant document

    Pairs are produced one at a time instead of being collected first, so
    memory use does not grow with the number of pairs.
    """
    if not isinstance(querylist, QueryList):
        querylist = QueryList(querylist)
    querylist._correct_ranking_()
    doc_num = len(querylist)
    # C(n,2)
    for i in range(doc_num):
        query_left = querylist[i]
        for j in range(i + 1, doc_num):
            query_right = querylist[j]
            if query_left.relevance_score > query_right.relevance_score:
                better, worse = query_left, query_right
            elif query_left.relevance_score < query_right.relevance_score:
                # not expected after the descending sort above; kept so the
                # more relevant document is always emitted first
                better, worse = query_right, query_left
            else:
                # equal scores carry no preference, no pair is produced
                continue
            yield (np.array([1]), np.array(better.feature_vector),
                   np.array(worse.feature_vector))
def gen_list(querylist):
    """
    gen one item holding all documents for list-wise learning to rank algorithm

    Parameters:
    --------
      querylist : QueryList, or a list of Query that is wrapped into a QueryList

    yield (exactly once) :
    ------
      label : np.array built from one [relevance_score] row per document
      querylist : np.array built from one feature_vector row per document

    Rows are in descending relevance_score order.
    """
    if isinstance(querylist, QueryList):
        ranked = querylist
    else:
        ranked = QueryList(querylist)
    ranked._correct_ranking_()
    label_rows = []
    feature_rows = []
    for doc in ranked:
        label_rows.append([doc.relevance_score])
        feature_rows.append(doc.feature_vector)
    yield np.array(label_rows), np.array(feature_rows)
def query_filter(querylists):
    """
    keep only the query lists whose relevance scores do not sum to zero,
    i.e. drop a query list when the scores of its documents add up to 0.
    label 0, 1, 2 means the relevance score document with query

    parameters :
      querylists : list of QueryList
    return :
      list of the QueryList items that were kept, in the original order
    """
    return [
        candidate for candidate in querylists
        if sum(doc.relevance_score for doc in candidate) != 0.0
    ]
def load_from_text(filepath, shuffle=False, fill_missing=-1):
    """
    parse data file into queries

    Parameters:
    --------
      filepath : string
        data file path, relative to the directory returned by
        __initialize_meta_info__()
      shuffle : accepted, but not used by this function
      fill_missing : accepted, but not used by this function

    return :
    ------
      querylists : list of QueryList, one per run of consecutive lines that
        share the same query_id; lines that Query._parse_ rejects are skipped
    """
    prev_query_id = -1
    querylists = []
    querylist = None
    data_dir = __initialize_meta_info__()
    with open(os.path.join(data_dir, filepath)) as f:
        for line in f:
            query = Query()._parse_(line)
            if query is None:
                continue
            # `querylist is None` also opens a group for the very first
            # record, even if its id happens to equal the initial -1
            if querylist is None or query.query_id != prev_query_id:
                if querylist is not None:
                    querylists.append(querylist)
                querylist = QueryList()
                prev_query_id = query.query_id
            querylist._add_query(query)
    # flush the group that was still open at end of file
    if querylist is not None:
        querylists.append(querylist)
    return querylists
def __reader__(filepath, format="pairwise", shuffle=False, fill_missing=-1):
    """
    reader over one MQ2007 data file

    Parameters
    --------
      filepath : string
        data file path, forwarded to load_from_text
      format : string
        "plain_txt", "pointwise", "pairwise" or "listwise"; any other value
        makes the reader yield nothing
      shuffle : forwarded to load_from_text
      fill_missing : fill the missing value. default in MQ2007 is -1;
        forwarded to load_from_text

    Returns
    ------
      yield
        query_id, label, feature       # format = "plain_txt", one per document
        label, feature                 # format = "pointwise", one per document
        label query_left, query_right  # format = "pairwise"
        label querylist                # format = "listwise"
    """
    querylists = query_filter(
        load_from_text(
            filepath, shuffle=shuffle, fill_missing=fill_missing))
    if format == "plain_txt":
        gen_samples = gen_plain_txt
    elif format == "pointwise":
        gen_samples = gen_point
    elif format == "pairwise":
        gen_samples = gen_pair
    elif format == "listwise":
        gen_samples = gen_list
    else:
        return
    for querylist in querylists:
        # drain the per-query generator: taking only its first item would
        # drop every document but the top ranked one for the per-document
        # formats ("plain_txt", "pointwise")
        for sample in gen_samples(querylist):
            yield sample
# Reader creators bound to the Fold1 train / test files. The remaining
# __reader__ keyword arguments (format, shuffle, fill_missing) can still be
# supplied when calling them.
train = functools.partial(__reader__, filepath="MQ2007/MQ2007/Fold1/train.txt")
test = functools.partial(__reader__, filepath="MQ2007/MQ2007/Fold1/test.txt")
def fetch():
    """
    download the MQ2007 archive

    return :
    ------
      the value returned by download(URL, "MQ2007", MD5);
      __initialize_meta_info__ uses it as the path of the rar archive
    """
    archive = download(URL, "MQ2007", MD5)
    return archive
if __name__ == "__main__":
    # fetch the archive first, then print every list-wise sample of the
    # Fold1 sample file
    fetch()
    sample_reader = functools.partial(
        __reader__, filepath="MQ2007/MQ2007/Fold1/sample", format="listwise")
    for label, query in sample_reader():
        print(label, query)