5
5
from promise import Promise , is_thenable
6
6
from sqlalchemy .orm .query import Query
7
7
8
+ from graphene import NonNull
8
9
from graphene .relay import Connection , ConnectionField
9
10
from graphene .relay .connection import PageInfo
10
11
from graphql_relay .connection .arrayconnection import connection_from_list_slice
@@ -19,19 +20,26 @@ def type(self):
19
20
from .types import SQLAlchemyObjectType
20
21
21
22
_type = super (ConnectionField , self ).type
22
- if issubclass (_type , Connection ):
23
+ nullable_type = get_nullable_type (_type )
24
+ if issubclass (nullable_type , Connection ):
23
25
return _type
24
- assert issubclass (_type , SQLAlchemyObjectType ), (
26
+ assert issubclass (nullable_type , SQLAlchemyObjectType ), (
25
27
"SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
26
- ).format (_type .__name__ )
27
- assert _type .connection , "The type {} doesn't have a connection" .format (
28
- _type .__name__
28
+ ).format (nullable_type .__name__ )
29
+ assert (
30
+ nullable_type .connection
31
+ ), "The type {} doesn't have a connection" .format (
32
+ nullable_type .__name__
29
33
)
30
- return _type .connection
34
+ assert _type == nullable_type , (
35
+ "Passing a SQLAlchemyObjectType instance is deprecated. "
36
+ "Pass the connection type instead accessible via SQLAlchemyObjectType.connection"
37
+ )
38
+ return nullable_type .connection
31
39
32
40
@property
33
41
def model (self ):
34
- return self .type ._meta .node ._meta .model
42
+ return get_nullable_type ( self .type ) ._meta .node ._meta .model
35
43
36
44
@classmethod
37
45
def get_query (cls , model , info , ** args ):
@@ -70,21 +78,27 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg
70
78
return on_resolve (resolved )
71
79
72
80
def get_resolver (self , parent_resolver ):
73
- return partial (self .connection_resolver , parent_resolver , self .type , self .model )
81
+ return partial (
82
+ self .connection_resolver ,
83
+ parent_resolver ,
84
+ get_nullable_type (self .type ),
85
+ self .model ,
86
+ )
74
87
75
88
76
89
# TODO Rename this to SortableSQLAlchemyConnectionField
77
90
class SQLAlchemyConnectionField (UnsortedSQLAlchemyConnectionField ):
78
91
def __init__ (self , type , * args , ** kwargs ):
79
- if "sort" not in kwargs and issubclass (type , Connection ):
92
+ nullable_type = get_nullable_type (type )
93
+ if "sort" not in kwargs and issubclass (nullable_type , Connection ):
80
94
# Let super class raise if type is not a Connection
81
95
try :
82
- kwargs .setdefault ("sort" , type .Edge .node ._type .sort_argument ())
96
+ kwargs .setdefault ("sort" , nullable_type .Edge .node ._type .sort_argument ())
83
97
except (AttributeError , TypeError ):
84
98
raise TypeError (
85
99
'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
86
100
" to None to disabling the creation of the sort query argument" .format (
87
- type .__name__
101
+ nullable_type .__name__
88
102
)
89
103
)
90
104
elif "sort" in kwargs and kwargs ["sort" ] is None :
@@ -108,8 +122,14 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
108
122
The API and behavior may change in future versions.
109
123
Use at your own risk.
110
124
"""
125
+
111
126
def get_resolver (self , parent_resolver ):
112
- return partial (self .connection_resolver , self .resolver , self .type , self .model )
127
+ return partial (
128
+ self .connection_resolver ,
129
+ self .resolver ,
130
+ get_nullable_type (self .type ),
131
+ self .model ,
132
+ )
113
133
114
134
@classmethod
115
135
def from_relationship (cls , relationship , registry , ** field_kwargs ):
@@ -155,3 +175,9 @@ def unregisterConnectionFieldFactory():
155
175
)
156
176
global __connectionFactory
157
177
__connectionFactory = UnsortedSQLAlchemyConnectionField
178
+
179
+
180
+ def get_nullable_type (_type ):
181
+ if isinstance (_type , NonNull ):
182
+ return _type .of_type
183
+ return _type
0 commit comments