Skip to content

Commit e323e2b

Browse files
tcleonardThomas Leonard
and
Thomas Leonard
authored
Add enum support to filters and fix filter typing (v2) (#1114)
* - Add filtering support for choice fields converted to graphql Enum (or not) - Fix type of various filters (used to default to String) - Fix bug with contains introduced in previous PR - Fix bug with declared filters being overridden (see PR #1108) - Fix support for ArrayField and add documentation * Fix tests Co-authored-by: Thomas Leonard <[email protected]>
1 parent e0a5d1c commit e323e2b

15 files changed

+834
-110
lines changed

docs/filtering.rst

+43
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,46 @@ with this set up, you can now order the users under group:
228228
}
229229
}
230230
}
231+
232+
233+
PostgreSQL `ArrayField`
234+
-----------------------
235+
236+
Graphene provides an easy to implement filters on `ArrayField` as they are not natively supported by django_filters:
237+
238+
.. code:: python
239+
240+
from django.db import models
241+
from django_filters import FilterSet, OrderingFilter
242+
from graphene_django.filter import ArrayFilter
243+
244+
class Event(models.Model):
245+
name = models.CharField(max_length=50)
246+
tags = ArrayField(models.CharField(max_length=50))
247+
248+
class EventFilterSet(FilterSet):
249+
class Meta:
250+
model = Event
251+
fields = {
252+
"name": ["exact", "contains"],
253+
}
254+
255+
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
256+
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
257+
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
258+
259+
class EventType(DjangoObjectType):
260+
class Meta:
261+
model = Event
262+
interfaces = (Node,)
263+
filterset_class = EventFilterSet
264+
265+
with this set up, you can now filter events by tags:
266+
267+
.. code::
268+
269+
query {
270+
events(tags_Overlap: ["concert", "festival"]) {
271+
name
272+
}
273+
}

graphene_django/filter/__init__.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,19 @@
99
)
1010
else:
1111
from .fields import DjangoFilterConnectionField
12-
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
12+
from .filters import (
13+
ArrayFilter,
14+
GlobalIDFilter,
15+
GlobalIDMultipleChoiceFilter,
16+
ListFilter,
17+
RangeFilter,
18+
)
1319

1420
__all__ = [
1521
"DjangoFilterConnectionField",
1622
"GlobalIDFilter",
1723
"GlobalIDMultipleChoiceFilter",
24+
"ArrayFilter",
25+
"ListFilter",
26+
"RangeFilter",
1827
]

graphene_django/filter/fields.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def filterset_class(self):
4343
if self._extra_filter_meta:
4444
meta.update(self._extra_filter_meta)
4545

46-
filterset_class = self._provided_filterset_class or (
47-
self.node_type._meta.filterset_class
46+
filterset_class = (
47+
self._provided_filterset_class or self.node_type._meta.filterset_class
4848
)
4949
self._filterset_class = get_filterset_class(filterset_class, **meta)
5050

graphene_django/filter/filters.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from django.forms import Field
33

44
from django_filters import Filter, MultipleChoiceFilter
5+
from django_filters.constants import EMPTY_VALUES
56

67
from graphql_relay.node.node import from_global_id
78

@@ -31,14 +32,15 @@ def filter(self, qs, value):
3132
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
3233

3334

34-
class InFilter(Filter):
35+
class ListFilter(Filter):
3536
"""
36-
Filter for a list of value using the `__in` Django filter.
37+
Filter that takes a list of value as input.
38+
It is for example used for `__in` filters.
3739
"""
3840

3941
def filter(self, qs, value):
4042
"""
41-
Override the default filter class to check first weather the list is
43+
Override the default filter class to check first whether the list is
4244
empty or not.
4345
This needs to be done as in this case we expect to get an empty output
4446
(if not an exclude filter) but django_filter consider an empty list
@@ -52,7 +54,7 @@ def filter(self, qs, value):
5254
else:
5355
return qs.none()
5456
else:
55-
return super(InFilter, self).filter(qs, value)
57+
return super(ListFilter, self).filter(qs, value)
5658

5759

5860
def validate_range(value):
@@ -73,3 +75,27 @@ class RangeField(Field):
7375

7476
class RangeFilter(Filter):
7577
field_class = RangeField
78+
79+
80+
class ArrayFilter(Filter):
81+
"""
82+
Filter made for PostgreSQL ArrayField.
83+
"""
84+
85+
def filter(self, qs, value):
86+
"""
87+
Override the default filter class to check first whether the list is
88+
empty or not.
89+
This needs to be done as in this case we expect to get the filter applied with
90+
an empty list since it's a valid value but django_filter consider an empty list
91+
to be an empty input value (see `EMPTY_VALUES`) meaning that
92+
the filter does not need to be applied (hence returning the original
93+
queryset).
94+
"""
95+
if value in EMPTY_VALUES and value != []:
96+
return qs
97+
if self.distinct:
98+
qs = qs.distinct()
99+
lookup = "%s__%s" % (self.field_name, self.lookup_expr)
100+
qs = self.get_method(qs)(**{lookup: value})
101+
return qs

graphene_django/filter/tests/conftest.py

+34-8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from graphene.relay import Node
1010
from graphene_django import DjangoObjectType
1111
from graphene_django.utils import DJANGO_FILTER_INSTALLED
12+
from graphene_django.filter import ArrayFilter, ListFilter
1213

1314
from ...compat import ArrayField
1415

@@ -32,27 +33,37 @@ def Event():
3233
class Event(models.Model):
3334
name = models.CharField(max_length=50)
3435
tags = ArrayField(models.CharField(max_length=50))
36+
tag_ids = ArrayField(models.IntegerField())
37+
random_field = ArrayField(models.BooleanField())
3538

3639
return Event
3740

3841

3942
@pytest.fixture
4043
def EventFilterSet(Event):
41-
42-
from django.contrib.postgres.forms import SimpleArrayField
43-
44-
class ArrayFilter(filters.Filter):
45-
base_field_class = SimpleArrayField
46-
4744
class EventFilterSet(FilterSet):
4845
class Meta:
4946
model = Event
5047
fields = {
51-
"name": ["exact"],
48+
"name": ["exact", "contains"],
5249
}
5350

51+
# Those are actually usable with our Query fixture bellow
5452
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
5553
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
54+
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
55+
56+
# Those are actually not usable and only to check type declarations
57+
tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains")
58+
tags_ids__overlap = ArrayFilter(field_name="tag_ids", lookup_expr="overlap")
59+
tags_ids = ArrayFilter(field_name="tag_ids", lookup_expr="exact")
60+
random_field__contains = ArrayFilter(
61+
field_name="random_field", lookup_expr="contains"
62+
)
63+
random_field__overlap = ArrayFilter(
64+
field_name="random_field", lookup_expr="overlap"
65+
)
66+
random_field = ArrayFilter(field_name="random_field", lookup_expr="exact")
5667

5768
return EventFilterSet
5869

@@ -70,6 +81,11 @@ class Meta:
7081

7182
@pytest.fixture
7283
def Query(Event, EventType):
84+
"""
85+
Note that we have to use a custom resolver to replicate the arrayfield filter behavior as
86+
we are running unit tests in sqlite which does not have ArrayFields.
87+
"""
88+
7389
class Query(graphene.ObjectType):
7490
events = DjangoFilterConnectionField(EventType)
7591

@@ -79,6 +95,7 @@ def resolve_events(self, info, **kwargs):
7995
Event(name="Live Show", tags=["concert", "music", "rock"],),
8096
Event(name="Musical", tags=["movie", "music"],),
8197
Event(name="Ballet", tags=["concert", "dance"],),
98+
Event(name="Speech", tags=[],),
8299
]
83100

84101
STORE["events"] = events
@@ -105,6 +122,13 @@ def filter_events(**kwargs):
105122
STORE["events"],
106123
)
107124
)
125+
if "tags__exact" in kwargs:
126+
STORE["events"] = list(
127+
filter(
128+
lambda e: set(kwargs["tags__exact"]) == set(e.tags),
129+
STORE["events"],
130+
)
131+
)
108132

109133
def mock_queryset_filter(*args, **kwargs):
110134
filter_events(**kwargs)
@@ -121,7 +145,9 @@ def mock_queryset_count(*args, **kwargs):
121145
m_queryset.filter.side_effect = mock_queryset_filter
122146
m_queryset.none.side_effect = mock_queryset_none
123147
m_queryset.count.side_effect = mock_queryset_count
124-
m_queryset.__getitem__.side_effect = STORE["events"].__getitem__
148+
m_queryset.__getitem__.side_effect = lambda index: STORE[
149+
"events"
150+
].__getitem__(index)
125151

126152
return m_queryset
127153

graphene_django/filter/tests/filters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class Meta:
1010
fields = {
1111
"headline": ["exact", "icontains"],
1212
"pub_date": ["gt", "lt", "exact"],
13-
"reporter": ["exact"],
13+
"reporter": ["exact", "in"],
1414
}
1515

1616
order_by = OrderingFilter(fields=("pub_date",))

graphene_django/filter/tests/test_contains_filter.py renamed to graphene_django/filter/tests/test_array_field_contains_filter.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77

88
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
9-
def test_string_contains_multiple(Query):
9+
def test_array_field_contains_multiple(Query):
1010
"""
11-
Test contains filter on a string field.
11+
Test contains filter on a array field of string.
1212
"""
1313

1414
schema = Schema(query=Query)
@@ -32,9 +32,9 @@ def test_string_contains_multiple(Query):
3232

3333

3434
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
35-
def test_string_contains_one(Query):
35+
def test_array_field_contains_one(Query):
3636
"""
37-
Test contains filter on a string field.
37+
Test contains filter on a array field of string.
3838
"""
3939

4040
schema = Schema(query=Query)
@@ -59,9 +59,9 @@ def test_string_contains_one(Query):
5959

6060

6161
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
62-
def test_string_contains_none(Query):
62+
def test_array_field_contains_empty_list(Query):
6363
"""
64-
Test contains filter on a string field.
64+
Test contains filter on a array field of string.
6565
"""
6666

6767
schema = Schema(query=Query)
@@ -79,4 +79,9 @@ def test_string_contains_none(Query):
7979
"""
8080
result = schema.execute(query)
8181
assert not result.errors
82-
assert result.data["events"]["edges"] == []
82+
assert result.data["events"]["edges"] == [
83+
{"node": {"name": "Live Show"}},
84+
{"node": {"name": "Musical"}},
85+
{"node": {"name": "Ballet"}},
86+
{"node": {"name": "Speech"}},
87+
]

0 commit comments

Comments
 (0)