Skip to content

Commit e0a5d1c

Browse files
Support "contains" and "overlap" filtering (v2) (#1100)
* Fix project setup * Support contains/overlap filters * Add Python 2.7 support * Adjust docstrings * Remove unused fixtures
1 parent 66c8901 commit e0a5d1c

File tree

6 files changed

+307
-12
lines changed

6 files changed

+307
-12
lines changed

graphene_django/compat.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@ class MissingType(object):
66
# Postgres fields are only available in Django with psycopg2 installed
77
# and we cannot have psycopg2 on PyPy
88
from django.contrib.postgres.fields import (
9+
IntegerRangeField,
910
ArrayField,
1011
HStoreField,
1112
JSONField as PGJSONField,
1213
RangeField,
1314
)
1415
except ImportError:
15-
ArrayField, HStoreField, PGJSONField, RangeField = (MissingType,) * 4
16+
IntegerRangeField, ArrayField, HStoreField, PGJSONField, RangeField = (
17+
MissingType,
18+
) * 5
1619

1720
try:
1821
# JSONField is only available from Django 3.1
+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from mock import MagicMock
2+
import pytest
3+
4+
from django.db import models
5+
from django.db.models.query import QuerySet
6+
from django_filters import filters
7+
from django_filters import FilterSet
8+
import graphene
9+
from graphene.relay import Node
10+
from graphene_django import DjangoObjectType
11+
from graphene_django.utils import DJANGO_FILTER_INSTALLED
12+
13+
from ...compat import ArrayField
14+
15+
pytestmark = []
16+
17+
if DJANGO_FILTER_INSTALLED:
18+
from graphene_django.filter import DjangoFilterConnectionField
19+
else:
20+
pytestmark.append(
21+
pytest.mark.skipif(
22+
True, reason="django_filters not installed or not compatible"
23+
)
24+
)
25+
26+
27+
STORE = {"events": []}
28+
29+
30+
@pytest.fixture
31+
def Event():
32+
class Event(models.Model):
33+
name = models.CharField(max_length=50)
34+
tags = ArrayField(models.CharField(max_length=50))
35+
36+
return Event
37+
38+
39+
@pytest.fixture
40+
def EventFilterSet(Event):
41+
42+
from django.contrib.postgres.forms import SimpleArrayField
43+
44+
class ArrayFilter(filters.Filter):
45+
base_field_class = SimpleArrayField
46+
47+
class EventFilterSet(FilterSet):
48+
class Meta:
49+
model = Event
50+
fields = {
51+
"name": ["exact"],
52+
}
53+
54+
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
55+
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
56+
57+
return EventFilterSet
58+
59+
60+
@pytest.fixture
61+
def EventType(Event, EventFilterSet):
62+
class EventType(DjangoObjectType):
63+
class Meta:
64+
model = Event
65+
interfaces = (Node,)
66+
filterset_class = EventFilterSet
67+
68+
return EventType
69+
70+
71+
@pytest.fixture
72+
def Query(Event, EventType):
73+
class Query(graphene.ObjectType):
74+
events = DjangoFilterConnectionField(EventType)
75+
76+
def resolve_events(self, info, **kwargs):
77+
78+
events = [
79+
Event(name="Live Show", tags=["concert", "music", "rock"],),
80+
Event(name="Musical", tags=["movie", "music"],),
81+
Event(name="Ballet", tags=["concert", "dance"],),
82+
]
83+
84+
STORE["events"] = events
85+
86+
m_queryset = MagicMock(spec=QuerySet)
87+
m_queryset.model = Event
88+
89+
def filter_events(**kwargs):
90+
if "tags__contains" in kwargs:
91+
STORE["events"] = list(
92+
filter(
93+
lambda e: set(kwargs["tags__contains"]).issubset(
94+
set(e.tags)
95+
),
96+
STORE["events"],
97+
)
98+
)
99+
if "tags__overlap" in kwargs:
100+
STORE["events"] = list(
101+
filter(
102+
lambda e: not set(kwargs["tags__overlap"]).isdisjoint(
103+
set(e.tags)
104+
),
105+
STORE["events"],
106+
)
107+
)
108+
109+
def mock_queryset_filter(*args, **kwargs):
110+
filter_events(**kwargs)
111+
return m_queryset
112+
113+
def mock_queryset_none(*args, **kwargs):
114+
STORE["events"] = []
115+
return m_queryset
116+
117+
def mock_queryset_count(*args, **kwargs):
118+
return len(STORE["events"])
119+
120+
m_queryset.all.return_value = m_queryset
121+
m_queryset.filter.side_effect = mock_queryset_filter
122+
m_queryset.none.side_effect = mock_queryset_none
123+
m_queryset.count.side_effect = mock_queryset_count
124+
m_queryset.__getitem__.side_effect = STORE["events"].__getitem__
125+
126+
return m_queryset
127+
128+
return Query
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import pytest
2+
3+
from graphene import Schema
4+
5+
from ...compat import ArrayField, MissingType
6+
7+
8+
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
9+
def test_string_contains_multiple(Query):
10+
"""
11+
Test contains filter on a string field.
12+
"""
13+
14+
schema = Schema(query=Query)
15+
16+
query = """
17+
query {
18+
events (tags_Contains: ["concert", "music"]) {
19+
edges {
20+
node {
21+
name
22+
}
23+
}
24+
}
25+
}
26+
"""
27+
result = schema.execute(query)
28+
assert not result.errors
29+
assert result.data["events"]["edges"] == [
30+
{"node": {"name": "Live Show"}},
31+
]
32+
33+
34+
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
35+
def test_string_contains_one(Query):
36+
"""
37+
Test contains filter on a string field.
38+
"""
39+
40+
schema = Schema(query=Query)
41+
42+
query = """
43+
query {
44+
events (tags_Contains: ["music"]) {
45+
edges {
46+
node {
47+
name
48+
}
49+
}
50+
}
51+
}
52+
"""
53+
result = schema.execute(query)
54+
assert not result.errors
55+
assert result.data["events"]["edges"] == [
56+
{"node": {"name": "Live Show"}},
57+
{"node": {"name": "Musical"}},
58+
]
59+
60+
61+
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
62+
def test_string_contains_none(Query):
63+
"""
64+
Test contains filter on a string field.
65+
"""
66+
67+
schema = Schema(query=Query)
68+
69+
query = """
70+
query {
71+
events (tags_Contains: []) {
72+
edges {
73+
node {
74+
name
75+
}
76+
}
77+
}
78+
}
79+
"""
80+
result = schema.execute(query)
81+
assert not result.errors
82+
assert result.data["events"]["edges"] == []
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import pytest
2+
3+
from graphene import Schema
4+
5+
from ...compat import ArrayField, MissingType
6+
7+
8+
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
9+
def test_string_overlap_multiple(Query):
10+
"""
11+
Test overlap filter on a string field.
12+
"""
13+
14+
schema = Schema(query=Query)
15+
16+
query = """
17+
query {
18+
events (tags_Overlap: ["concert", "music"]) {
19+
edges {
20+
node {
21+
name
22+
}
23+
}
24+
}
25+
}
26+
"""
27+
result = schema.execute(query)
28+
assert not result.errors
29+
assert result.data["events"]["edges"] == [
30+
{"node": {"name": "Live Show"}},
31+
{"node": {"name": "Musical"}},
32+
{"node": {"name": "Ballet"}},
33+
]
34+
35+
36+
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
37+
def test_string_overlap_one(Query):
38+
"""
39+
Test overlap filter on a string field.
40+
"""
41+
42+
schema = Schema(query=Query)
43+
44+
query = """
45+
query {
46+
events (tags_Overlap: ["music"]) {
47+
edges {
48+
node {
49+
name
50+
}
51+
}
52+
}
53+
}
54+
"""
55+
result = schema.execute(query)
56+
assert not result.errors
57+
assert result.data["events"]["edges"] == [
58+
{"node": {"name": "Live Show"}},
59+
{"node": {"name": "Musical"}},
60+
]
61+
62+
63+
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
64+
def test_string_overlap_none(Query):
65+
"""
66+
Test overlap filter on a string field.
67+
"""
68+
69+
schema = Schema(query=Query)
70+
71+
query = """
72+
query {
73+
events (tags_Overlap: []) {
74+
edges {
75+
node {
76+
name
77+
}
78+
}
79+
}
80+
}
81+
"""
82+
result = schema.execute(query)
83+
assert not result.errors
84+
assert result.data["events"]["edges"] == []

graphene_django/filter/utils.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import six
22

3-
from graphene import List
3+
import graphene
44

55
from django_filters.utils import get_model_field
66
from django_filters.filters import Filter, BaseCSVFilter
@@ -41,11 +41,11 @@ def get_filtering_args_from_filterset(filterset_class, type):
4141

4242
field = convert_form_field(form_field)
4343

44-
if filter_type in ["in", "range"]:
45-
# Replace CSV filters (`in`, `range`) argument type to be a list of
44+
if filter_type in {"in", "range", "contains", "overlap"}:
45+
# Replace CSV filters (`in`, `range`, `contains`, `overlap`) argument type to be a list of
4646
# the same type as the field. See comments in
4747
# `replace_csv_filters` method for more details.
48-
field = List(field.get_type())
48+
field = graphene.List(field.get_type())
4949

5050
field_type = field.Argument()
5151
field_type.description = filter_field.label
@@ -71,7 +71,7 @@ def get_filterset_class(filterset_class, **meta):
7171

7272
def replace_csv_filters(filterset_class):
7373
"""
74-
Replace the "in" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
74+
Replace the "in", "contains", "overlap" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
7575
but regular Filter objects that simply use the input value as filter argument on the queryset.
7676
7777
This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we
@@ -81,8 +81,7 @@ def replace_csv_filters(filterset_class):
8181
"""
8282
for name, filter_field in six.iteritems(filterset_class.base_filters):
8383
filter_type = filter_field.lookup_expr
84-
if filter_type == "in":
85-
assert isinstance(filter_field, BaseCSVFilter)
84+
if filter_type in {"in", "contains", "overlap"}:
8685
filterset_class.base_filters[name] = InFilter(
8786
field_name=filter_field.field_name,
8887
lookup_expr=filter_field.lookup_expr,
@@ -92,8 +91,7 @@ def replace_csv_filters(filterset_class):
9291
**filter_field.extra
9392
)
9493

95-
if filter_type == "range":
96-
assert isinstance(filter_field, BaseCSVFilter)
94+
elif filter_type == "range":
9795
filterset_class.base_filters[name] = RangeFilter(
9896
field_name=filter_field.field_name,
9997
lookup_expr=filter_field.lookup_expr,

graphene_django/tests/test_query.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import graphene
1212
from graphene.relay import Node
1313

14-
from ..compat import JSONField, MissingType
14+
from ..compat import IntegerRangeField, MissingType
1515
from ..fields import DjangoConnectionField
1616
from ..types import DjangoObjectType
1717
from ..utils import DJANGO_FILTER_INSTALLED
@@ -113,7 +113,7 @@ def resolve_reporter(self, info):
113113
assert result.data == expected
114114

115115

116-
@pytest.mark.skipif(JSONField is MissingType, reason="RangeField should exist")
116+
@pytest.mark.skipif(IntegerRangeField is MissingType, reason="RangeField should exist")
117117
def test_should_query_postgres_fields():
118118
from django.contrib.postgres.fields import (
119119
IntegerRangeField,

0 commit comments

Comments
 (0)