Skip to content

Commit f6fa886

Browse files
authored
Filter an existing django queryset (#168)
* Add queryset filtering * Remove _get_queryset * Add get_queryset --------- Co-authored-by: Andrey Laguta <[email protected]>
1 parent 7e959bb commit f6fa886

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

django_elasticsearch_dsl/search.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,17 @@ def _clone(self):
1414
s._model = self._model
1515
return s
1616

17-
def to_queryset(self, keep_order=True):
17+
def filter_queryset(self, queryset, keep_search_order=True):
1818
"""
19-
This method return a django queryset from the an elasticsearch result.
20-
It cost a query to the sql db.
19+
Filter an existing django queryset using the elasticsearch result.
20+
It costs a query to the sql db.
2121
"""
2222
s = self
23+
if s._model is not queryset.model:
24+
raise TypeError(
25+
'Unexpected queryset model '
26+
'(should be: %s, got: %s)' % (s._model, queryset.model)
27+
)
2328

2429
# Do not query again if the es result is already cached
2530
if not hasattr(self, '_response'):
@@ -28,14 +33,27 @@ def to_queryset(self, keep_order=True):
2833
s = s.execute()
2934

3035
pks = [result.meta.id for result in s]
36+
queryset = queryset.filter(pk__in=pks)
3137

32-
qs = self._model.objects.filter(pk__in=pks)
33-
34-
if keep_order:
38+
if keep_search_order:
3539
preserved_order = Case(
3640
*[When(pk=pk, then=pos) for pos, pk in enumerate(pks)],
3741
output_field=IntegerField()
3842
)
39-
qs = qs.order_by(preserved_order)
43+
queryset = queryset.order_by(preserved_order)
44+
45+
return queryset
46+
47+
def _get_queryset(self):
48+
"""
49+
Return a django queryset that will be filtered by to_queryset method.
50+
"""
51+
return self._model._default_manager.all()
4052

41-
return qs
53+
def to_queryset(self, keep_order=True):
54+
"""
55+
Return a django queryset from the elasticsearch result.
56+
It costs a query to the sql db.
57+
"""
58+
qs = self._get_queryset()
59+
return self.filter_queryset(qs, keep_order)

tests/test_integration.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,24 @@ def test_rebuild_command(self):
331331
result = AdDocument().search().execute()
332332
self.assertEqual(len(result), 3)
333333

334+
def test_filter_queryset(self):
335+
Ad(title="Nothing that match", car=self.car1).save()
336+
337+
qs = AdDocument().search().query(
338+
'match', title="Ad number 2").filter_queryset(Ad.objects)
339+
self.assertEqual(qs.count(), 2)
340+
self.assertEqual(list(qs), [self.ad2, self.ad1])
341+
342+
qs = AdDocument().search().query(
343+
'match', title="Ad number 2"
344+
).filter_queryset(Ad.objects.filter(url="www.ad2.com"))
345+
self.assertEqual(qs.count(), 1)
346+
self.assertEqual(list(qs), [self.ad2])
347+
348+
with self.assertRaisesMessage(TypeError, 'Unexpected queryset model'):
349+
AdDocument().search().query(
350+
'match', title="Ad number 2").filter_queryset(Category.objects)
351+
334352
def test_to_queryset(self):
335353
Ad(title="Nothing that match", car=self.car1).save()
336354
qs = AdDocument().search().query(

0 commit comments

Comments
 (0)