Skip to content

Commit 8122d94

Browse files
authored
Add Scorer option to Query (#113)
* add scorer option to query * add query scorer test * add test deps directory to gitignore * add all scorer types in tests withscore
1 parent f3f11a4 commit 8122d94

File tree

3 files changed

+70
-25
lines changed

3 files changed

+70
-25
lines changed

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -150,4 +150,7 @@ fabric.properties
150150
.idea/httpRequests
151151

152152
# Android studio 3.1+ serialized cache file
153-
.idea/caches/build_file_checksums.ser
153+
.idea/caches/build_file_checksums.ser
154+
155+
# Dependencies
156+
deps/*

redisearch/query.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ class Query(object):
55
Query is used to build complex queries that have more parameters than just the query string.
66
The query string is set in the constructor, and other options have setter functions.
77
8-
The setter functions return the query object, so they can be chained,
8+
The setter functions return the query object, so they can be chained,
99
i.e. `Query("foo").verbatim().filter(...)` etc.
1010
"""
1111

1212
def __init__(self, query_string):
1313
"""
14-
Create a new query object.
14+
Create a new query object.
1515
The query string is set in the constructor, and other options have setter functions.
1616
"""
1717

@@ -24,6 +24,7 @@ def __init__(self, query_string):
2424
self._verbatim = False
2525
self._with_payloads = False
2626
self._with_scores = False
27+
self._scorer = False
2728
self._filters = list()
2829
self._ids = None
2930
self._slop = -1
@@ -129,6 +130,14 @@ def in_order(self):
129130
self._in_order = True
130131
return self
131132

133+
def scorer(self, scorer):
134+
"""
135+
Use a different scoring function to evaluate document relevance. Default is `TFIDF`
136+
:param scorer: The scoring function to use (e.g. `TFIDF.DOCNORM` or `BM25`)
137+
"""
138+
self._scorer = scorer
139+
return self
140+
132141
def get_args(self):
133142
"""
134143
Format the redis arguments for this query and return them
@@ -144,7 +153,7 @@ def get_args(self):
144153
args.append('INFIELDS')
145154
args.append(len(self._fields))
146155
args += self._fields
147-
156+
148157
if self._verbatim:
149158
args.append('VERBATIM')
150159

@@ -159,9 +168,12 @@ def get_args(self):
159168
if self._with_payloads:
160169
args.append('WITHPAYLOADS')
161170

171+
if self._scorer:
172+
args += ['SCORER', self._scorer]
173+
162174
if self._with_scores:
163175
args.append('WITHSCORES')
164-
176+
165177
if self._ids:
166178
args.append('INKEYS')
167179
args.append(len(self._ids))
@@ -217,7 +229,7 @@ def no_content(self):
217229

218230
def no_stopwords(self):
219231
"""
220-
Prevent the query from being filtered for stopwords.
232+
Prevent the query from being filtered for stopwords.
221233
Only useful in very big queries that you are certain contain no stopwords.
222234
"""
223235
self._no_stopwords = True
@@ -236,7 +248,7 @@ def with_scores(self):
236248
"""
237249
self._with_scores = True
238250
return self
239-
251+
240252
def limit_fields(self, *fields):
241253
"""
242254
Limit the search to specific TEXT fields only
@@ -248,7 +260,7 @@ def limit_fields(self, *fields):
248260

249261
def add_filter(self, flt):
250262
"""
251-
Add a numeric or geo filter to the query.
263+
Add a numeric or geo filter to the query.
252264
**Currently only one of each filter is supported by the engine**
253265
254266
- **flt**: A NumericFilter or GeoFilter object, used on a corresponding field
@@ -273,7 +285,7 @@ class Filter(object):
273285
def __init__(self, keyword, field, *args):
274286

275287
self.args = [keyword, field] + list(args)
276-
288+
277289
class NumericFilter(Filter):
278290

279291
INF = '+inf'
@@ -303,4 +315,4 @@ class SortbyField(object):
303315

304316
def __init__(self, field, asc=True):
305317

306-
self.args = [field, 'ASC' if asc else 'DESC']
318+
self.args = [field, 'ASC' if asc else 'DESC']

test/test.py

+45-15
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def createIndex(self, client, num_docs = 100, definition=None):
5353

5454
assert isinstance(client, Client)
5555
try:
56-
client.create_index((TextField('play', weight=5.0),
57-
TextField('txt'),
56+
client.create_index((TextField('play', weight=5.0),
57+
TextField('txt'),
5858
NumericField('chapter')), definition=definition)
5959
except redis.ResponseError:
6060
client.dropindex(delete_documents=True)
@@ -161,7 +161,7 @@ def testClient(self):
161161
self.assertEqual(len(subset), docs.total)
162162
ids = [x.id for x in docs.docs]
163163
self.assertEqual(set(ids), set(subset))
164-
164+
165165
# self.assertRaises(redis.ResponseError, client.search, Query('henry king').return_fields('play', 'nonexist'))
166166

167167
# test slop and in order
@@ -272,7 +272,7 @@ def testScores(self):
272272
#self.assertEqual(0.2, res.docs[1].score)
273273

274274
def testReplace(self):
275-
275+
276276
conn = self.redis()
277277

278278
with conn as r:
@@ -296,7 +296,7 @@ def testReplace(self):
296296
self.assertEqual(1, res.total)
297297
self.assertEqual('doc1', res.docs[0].id)
298298

299-
def testStopwords(self):
299+
def testStopwords(self):
300300
# Creating a client with a given index name
301301
client = self.getCleanClient('idx')
302302

@@ -324,7 +324,7 @@ def testFilters(self):
324324

325325
for i in r.retry_with_rdb_reload():
326326
waitForIndex(r, 'idx')
327-
# Test numerical filter
327+
# Test numerical filter
328328
q1 = Query("foo").add_filter(NumericFilter('num', 0, 2)).no_content()
329329
q2 = Query("foo").add_filter(NumericFilter('num', 2, NumericFilter.INF, minExclusive=True)).no_content()
330330
res1, res2 = client.search(q1), client.search(q2)
@@ -338,11 +338,11 @@ def testFilters(self):
338338
q1 = Query("foo").add_filter(GeoFilter('loc', -0.44, 51.45, 10)).no_content()
339339
q2 = Query("foo").add_filter(GeoFilter('loc', -0.44, 51.45, 100)).no_content()
340340
res1, res2 = client.search(q1), client.search(q2)
341-
341+
342342
self.assertEqual(1, res1.total)
343343
self.assertEqual(2, res2.total)
344344
self.assertEqual('doc1', res1.docs[0].id)
345-
345+
346346
# Sort results, after RDB reload order may change
347347
list = [res2.docs[0].id, res2.docs[1].id]
348348
list.sort()
@@ -371,7 +371,7 @@ def testSortby(self):
371371
# Creating a client with a given index name
372372
client = Client('idx', port=conn.port)
373373
client.redis.flushdb()
374-
374+
375375
client.create_index((TextField('txt'), NumericField('num', sortable=True)))
376376
client.add_document('doc1', txt = 'foo bar', num = 1)
377377
client.add_document('doc2', txt = 'foo baz', num = 2)
@@ -381,7 +381,7 @@ def testSortby(self):
381381
q1 = Query("foo").sort_by('num', asc=True).no_content()
382382
q2 = Query("foo").sort_by('num', asc=False).no_content()
383383
res1, res2 = client.search(q1), client.search(q2)
384-
384+
385385
self.assertEqual(3, res1.total)
386386
self.assertEqual('doc1', res1.docs[0].id)
387387
self.assertEqual('doc2', res1.docs[1].id)
@@ -417,7 +417,7 @@ def testExample(self):
417417
# Creating a client with a given index name
418418
client = Client('myIndex', port=conn.port)
419419
client.redis.flushdb()
420-
420+
421421
# Creating the index definition and schema
422422
client.create_index((TextField('title', weight=5.0), TextField('body')))
423423

@@ -552,7 +552,7 @@ def testNoCreate(self):
552552
# values
553553
res = client.search('@f3:f3_val @f2:f2_val @f1:f1_val')
554554
self.assertEqual(1, res.total)
555-
555+
556556
with self.assertRaises(redis.ResponseError) as error:
557557
client.add_document('doc3', f2='f2_val', f3='f3_val', no_create=True)
558558

@@ -578,7 +578,7 @@ def testSummarize(self):
578578
doc.txt)
579579

580580
q = Query('king henry').paging(0, 1).summarize().highlight()
581-
581+
582582
doc = sorted(client.search(q).docs)[0]
583583
self.assertEqual('<b>Henry</b> ... ', doc.play)
584584
self.assertEqual('ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... ',
@@ -600,7 +600,7 @@ def testAlias(self):
600600

601601
def1 =IndexDefinition(prefix=['index1:'],score_field='name')
602602
def2 =IndexDefinition(prefix=['index2:'],score_field='name')
603-
603+
604604
index1.create_index((TextField('name'),),definition=def1)
605605
index2.create_index((TextField('name'),),definition=def2)
606606

@@ -628,7 +628,7 @@ def testAlias(self):
628628
with self.assertRaises(Exception) as context:
629629
alias_client2.search('*').docs[0]
630630
self.assertEqual('spaceballs: no such index', str(context.exception))
631-
631+
632632
else:
633633

634634
# Creating a client with one index
@@ -808,6 +808,36 @@ def testPhoneticMatcher(self):
808808
self.assertEqual(2, len(res.docs))
809809
self.assertEqual(['John', 'Jon'], sorted([d.name for d in res.docs]))
810810

811+
def testScorer(self):
812+
# Creating a client with a given index name
813+
client = self.getCleanClient('idx')
814+
815+
client.create_index((TextField('description'),))
816+
817+
client.add_document('doc1', description='The quick brown fox jumps over the lazy dog')
818+
client.add_document('doc2', description='Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.')
819+
820+
# default scorer is TFIDF
821+
res = client.search(Query('quick').with_scores())
822+
self.assertEqual(1.0, res.docs[0].score)
823+
res = client.search(Query('quick').scorer('TFIDF').with_scores())
824+
self.assertEqual(1.0, res.docs[0].score)
825+
826+
res = client.search(Query('quick').scorer('TFIDF.DOCNORM').with_scores())
827+
self.assertEqual(0.1111111111111111, res.docs[0].score)
828+
829+
res = client.search(Query('quick').scorer('BM25').with_scores())
830+
self.assertEqual(0.17699114465425977, res.docs[0].score)
831+
832+
res = client.search(Query('quick').scorer('DISMAX').with_scores())
833+
self.assertEqual(2.0, res.docs[0].score)
834+
835+
res = client.search(Query('quick').scorer('DOCSCORE').with_scores())
836+
self.assertEqual(1.0, res.docs[0].score)
837+
838+
res = client.search(Query('quick').scorer('HAMMING').with_scores())
839+
self.assertEqual(0.0, res.docs[0].score)
840+
811841
def testGet(self):
812842
client = self.getCleanClient('idx')
813843
client.create_index((TextField('f1'), TextField('f2')))

0 commit comments

Comments
 (0)