Skip to content

Commit fb74aa2

Browse files
authored
Added support for ADDSCORES modifier (#3329)
* Added support for ADDSCORES modifier * Fixed codestyle issues * More codestyle fixes * Updated test cases and testing image to represent latest * Codestyle issues * Added handling for dict responses
1 parent fd0b0d3 commit fb74aa2

File tree

4 files changed

+64
-1
lines changed

4 files changed

+64
-1
lines changed

.github/workflows/integration.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ env:
2828
# this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665
2929
COVERAGE_CORE: sysmon
3030
REDIS_IMAGE: redis:7.4-rc2
31-
REDIS_STACK_IMAGE: redis/redis-stack-server:7.4.0-rc2
31+
REDIS_STACK_IMAGE: redis/redis-stack-server:latest
3232

3333
jobs:
3434
dependency-audit:

redis/commands/search/aggregation.py

+11
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(self, query: str = "*") -> None:
111111
self._verbatim = False
112112
self._cursor = []
113113
self._dialect = None
114+
self._add_scores = False
114115

115116
def load(self, *fields: List[str]) -> "AggregateRequest":
116117
"""
@@ -292,6 +293,13 @@ def with_schema(self) -> "AggregateRequest":
292293
self._with_schema = True
293294
return self
294295

296+
def add_scores(self) -> "AggregateRequest":
297+
"""
298+
If set, includes the score as an ordinary field of the row.
299+
"""
300+
self._add_scores = True
301+
return self
302+
295303
def verbatim(self) -> "AggregateRequest":
296304
self._verbatim = True
297305
return self
@@ -315,6 +323,9 @@ def build_args(self) -> List[str]:
315323
if self._verbatim:
316324
ret.append("VERBATIM")
317325

326+
if self._add_scores:
327+
ret.append("ADDSCORES")
328+
318329
if self._cursor:
319330
ret += self._cursor
320331

tests/test_asyncio/test_search.py

+26
Original file line numberDiff line numberDiff line change
@@ -1530,6 +1530,32 @@ async def test_withsuffixtrie(decoded_r: redis.Redis):
15301530
assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"]
15311531

15321532

1533+
@pytest.mark.redismod
1534+
@skip_ifmodversion_lt("2.10.05", "search")
1535+
async def test_aggregations_add_scores(decoded_r: redis.Redis):
1536+
assert await decoded_r.ft().create_index(
1537+
(
1538+
TextField("name", sortable=True, weight=5.0),
1539+
NumericField("age", sortable=True),
1540+
)
1541+
)
1542+
1543+
assert await decoded_r.hset("doc1", mapping={"name": "bar", "age": "25"})
1544+
assert await decoded_r.hset("doc2", mapping={"name": "foo", "age": "19"})
1545+
1546+
req = aggregations.AggregateRequest("*").add_scores()
1547+
res = await decoded_r.ft().aggregate(req)
1548+
1549+
if isinstance(res, dict):
1550+
assert len(res["results"]) == 2
1551+
assert res["results"][0]["extra_attributes"] == {"__score": "0.2"}
1552+
assert res["results"][1]["extra_attributes"] == {"__score": "0.2"}
1553+
else:
1554+
assert len(res.rows) == 2
1555+
assert res.rows[0] == ["__score", "0.2"]
1556+
assert res.rows[1] == ["__score", "0.2"]
1557+
1558+
15331559
@pytest.mark.redismod
15341560
@skip_if_redis_enterprise()
15351561
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):

tests/test_search.py

+26
Original file line numberDiff line numberDiff line change
@@ -1440,6 +1440,32 @@ def test_aggregations_filter(client):
14401440
assert res["results"][1]["extra_attributes"] == {"age": "25"}
14411441

14421442

1443+
@pytest.mark.redismod
1444+
@skip_ifmodversion_lt("2.10.05", "search")
1445+
def test_aggregations_add_scores(client):
1446+
client.ft().create_index(
1447+
(
1448+
TextField("name", sortable=True, weight=5.0),
1449+
NumericField("age", sortable=True),
1450+
)
1451+
)
1452+
1453+
client.hset("doc1", mapping={"name": "bar", "age": "25"})
1454+
client.hset("doc2", mapping={"name": "foo", "age": "19"})
1455+
1456+
req = aggregations.AggregateRequest("*").add_scores()
1457+
res = client.ft().aggregate(req)
1458+
1459+
if isinstance(res, dict):
1460+
assert len(res["results"]) == 2
1461+
assert res["results"][0]["extra_attributes"] == {"__score": "0.2"}
1462+
assert res["results"][1]["extra_attributes"] == {"__score": "0.2"}
1463+
else:
1464+
assert len(res.rows) == 2
1465+
assert res.rows[0] == ["__score", "0.2"]
1466+
assert res.rows[1] == ["__score", "0.2"]
1467+
1468+
14431469
@pytest.mark.redismod
14441470
@skip_ifmodversion_lt("2.0.0", "search")
14451471
def test_index_definition(client):

0 commit comments

Comments
 (0)