Skip to content

Commit 849217a

Browse files
authored
Add support for Non-Null SQLAlchemyConnectionField (#261)
* Add support for Non-Null SQLAlchemyConnectionField * Remove implicit ORDER BY clause to fix tests with SQLAlchemy 1.3.16
1 parent 421f8e4 commit 849217a

File tree

3 files changed

+56
-18
lines changed

3 files changed

+56
-18
lines changed

graphene_sqlalchemy/fields.py

+38-12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from promise import Promise, is_thenable
66
from sqlalchemy.orm.query import Query
77

8+
from graphene import NonNull
89
from graphene.relay import Connection, ConnectionField
910
from graphene.relay.connection import PageInfo
1011
from graphql_relay.connection.arrayconnection import connection_from_list_slice
@@ -19,19 +20,26 @@ def type(self):
1920
from .types import SQLAlchemyObjectType
2021

2122
_type = super(ConnectionField, self).type
22-
if issubclass(_type, Connection):
23+
nullable_type = get_nullable_type(_type)
24+
if issubclass(nullable_type, Connection):
2325
return _type
24-
assert issubclass(_type, SQLAlchemyObjectType), (
26+
assert issubclass(nullable_type, SQLAlchemyObjectType), (
2527
"SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
26-
).format(_type.__name__)
27-
assert _type.connection, "The type {} doesn't have a connection".format(
28-
_type.__name__
28+
).format(nullable_type.__name__)
29+
assert (
30+
nullable_type.connection
31+
), "The type {} doesn't have a connection".format(
32+
nullable_type.__name__
2933
)
30-
return _type.connection
34+
assert _type == nullable_type, (
35+
"Passing a SQLAlchemyObjectType instance is deprecated. "
36+
"Pass the connection type instead accessible via SQLAlchemyObjectType.connection"
37+
)
38+
return nullable_type.connection
3139

3240
@property
3341
def model(self):
34-
return self.type._meta.node._meta.model
42+
return get_nullable_type(self.type)._meta.node._meta.model
3543

3644
@classmethod
3745
def get_query(cls, model, info, **args):
@@ -70,21 +78,27 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg
7078
return on_resolve(resolved)
7179

7280
def get_resolver(self, parent_resolver):
73-
return partial(self.connection_resolver, parent_resolver, self.type, self.model)
81+
return partial(
82+
self.connection_resolver,
83+
parent_resolver,
84+
get_nullable_type(self.type),
85+
self.model,
86+
)
7487

7588

7689
# TODO Rename this to SortableSQLAlchemyConnectionField
7790
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
7891
def __init__(self, type, *args, **kwargs):
79-
if "sort" not in kwargs and issubclass(type, Connection):
92+
nullable_type = get_nullable_type(type)
93+
if "sort" not in kwargs and issubclass(nullable_type, Connection):
8094
# Let super class raise if type is not a Connection
8195
try:
82-
kwargs.setdefault("sort", type.Edge.node._type.sort_argument())
96+
kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument())
8397
except (AttributeError, TypeError):
8498
raise TypeError(
8599
'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
86100
" to None to disabling the creation of the sort query argument".format(
87-
type.__name__
101+
nullable_type.__name__
88102
)
89103
)
90104
elif "sort" in kwargs and kwargs["sort"] is None:
@@ -108,8 +122,14 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
108122
The API and behavior may change in future versions.
109123
Use at your own risk.
110124
"""
125+
111126
def get_resolver(self, parent_resolver):
112-
return partial(self.connection_resolver, self.resolver, self.type, self.model)
127+
return partial(
128+
self.connection_resolver,
129+
self.resolver,
130+
get_nullable_type(self.type),
131+
self.model,
132+
)
113133

114134
@classmethod
115135
def from_relationship(cls, relationship, registry, **field_kwargs):
@@ -155,3 +175,9 @@ def unregisterConnectionFieldFactory():
155175
)
156176
global __connectionFactory
157177
__connectionFactory = UnsortedSQLAlchemyConnectionField
178+
179+
180+
def get_nullable_type(_type):
181+
if isinstance(_type, NonNull):
182+
return _type.of_type
183+
return _type

graphene_sqlalchemy/tests/test_batching.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,7 @@ def test_one_to_one(session_factory):
233233
'articles.headline AS articles_headline, '
234234
'articles.pub_date AS articles_pub_date \n'
235235
'FROM articles \n'
236-
'WHERE articles.reporter_id IN (?, ?) '
237-
'ORDER BY articles.reporter_id',
236+
'WHERE articles.reporter_id IN (?, ?)',
238237
'(1, 2)'
239238
]
240239

@@ -337,8 +336,7 @@ def test_one_to_many(session_factory):
337336
'articles.headline AS articles_headline, '
338337
'articles.pub_date AS articles_pub_date \n'
339338
'FROM articles \n'
340-
'WHERE articles.reporter_id IN (?, ?) '
341-
'ORDER BY articles.reporter_id',
339+
'WHERE articles.reporter_id IN (?, ?)',
342340
'(1, 2)'
343341
]
344342

@@ -470,7 +468,7 @@ def test_many_to_many(session_factory):
470468
'JOIN association AS association_1 ON reporters_1.id = association_1.reporter_id '
471469
'JOIN pets ON pets.id = association_1.pet_id \n'
472470
'WHERE reporters_1.id IN (?, ?) '
473-
'ORDER BY reporters_1.id, pets.id',
471+
'ORDER BY pets.id',
474472
'(1, 2)'
475473
]
476474

graphene_sqlalchemy/tests/test_fields.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
from promise import Promise
33

4-
from graphene import ObjectType
4+
from graphene import NonNull, ObjectType
55
from graphene.relay import Connection, Node
66

77
from ..fields import (SQLAlchemyConnectionField,
@@ -26,6 +26,20 @@ class Meta:
2626
##
2727

2828

29+
def test_nonnull_sqlalachemy_connection():
30+
field = SQLAlchemyConnectionField(NonNull(Pet.connection))
31+
assert isinstance(field.type, NonNull)
32+
assert issubclass(field.type.of_type, Connection)
33+
assert field.type.of_type._meta.node is Pet
34+
35+
36+
def test_required_sqlalachemy_connection():
37+
field = SQLAlchemyConnectionField(Pet.connection, required=True)
38+
assert isinstance(field.type, NonNull)
39+
assert issubclass(field.type.of_type, Connection)
40+
assert field.type.of_type._meta.node is Pet
41+
42+
2943
def test_promise_connection_resolver():
3044
def resolver(_obj, _info):
3145
return Promise.resolve([])

0 commit comments

Comments
 (0)