-
Notifications
You must be signed in to change notification settings - Fork 227
/
Copy pathfields.py
292 lines (248 loc) · 10.3 KB
/
fields.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import enum
import warnings
from functools import partial
from promise import Promise, is_thenable
from sqlalchemy.orm.query import Query
from graphene.relay import Connection, ConnectionField
from graphene.relay.connection import connection_adapter, page_info_adapter
from graphql_relay import connection_from_array_slice
from .batching import get_batch_resolver
from .filters import BaseTypeFilter
from .utils import (
SQL_VERSION_HIGHER_EQUAL_THAN_1_4,
EnumValue,
get_nullable_type,
get_query,
get_session,
)
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
from sqlalchemy.ext.asyncio import AsyncSession
class SQLAlchemyConnectionField(ConnectionField):
    """Relay ``ConnectionField`` that resolves SQLAlchemy-backed object types.

    When the field type is a ``Connection``, ``sort`` and ``filter`` arguments
    are derived from its node type automatically (unless the node type offers
    no filter argument). Passing ``sort=None`` or ``filter=None`` suppresses
    the corresponding argument. Queries are executed either synchronously or,
    on SQLAlchemy >= 1.4 with an ``AsyncSession`` in the context, via a
    coroutine.
    """

    @property
    def type(self):
        """Return the connection type this field resolves to.

        A ``Connection`` subclass is returned as given (``type_`` is returned,
        so any wrapper that ``get_nullable_type`` stripped is kept). A bare
        ``SQLAlchemyBase`` subclass is replaced by its ``.connection`` type.

        Raises:
            AssertionError: if the unwrapped type is neither a ``Connection``
                nor a ``SQLAlchemyBase`` subclass, if it has no connection,
                or if a wrapped ``SQLAlchemyBase`` type was passed.
        """
        from .types import SQLAlchemyBase

        # Starts the MRO lookup *after* ConnectionField, i.e. skips
        # ConnectionField.type -- presumably to bypass its Connection-only
        # check so SQLAlchemyBase types can be accepted below; confirm.
        type_ = super(ConnectionField, self).type
        nullable_type = get_nullable_type(type_)
        if issubclass(nullable_type, Connection):
            return type_
        assert issubclass(nullable_type, SQLAlchemyBase), (
            "SQLALchemyConnectionField only accepts SQLAlchemyBase types, not {}"
        ).format(nullable_type.__name__)
        assert nullable_type.connection, "The type {} doesn't have a connection".format(
            nullable_type.__name__
        )
        assert type_ == nullable_type, (
            "Passing a SQLAlchemyBase instance is deprecated. "
            "Pass the connection type instead accessible via SQLAlchemyBase.connection"
        )
        return nullable_type.connection

    def __init__(self, type_, *args, **kwargs):
        """Create the field, adding default ``sort``/``filter`` arguments.

        Args:
            type_: the connection type (or a ``SQLAlchemyBase`` type, see
                :attr:`type`).
            *args, **kwargs: forwarded to ``ConnectionField``. ``sort=None``
                and ``filter=None`` are removed instead of forwarded, which
                disables the automatically generated argument.

        Raises:
            TypeError: if no ``sort`` argument was given and the node type
                cannot produce one.
        """
        nullable_type = get_nullable_type(type_)

        # Handle Sorting and Filtering
        if (
            "sort" not in kwargs
            and nullable_type
            and issubclass(nullable_type, Connection)
        ):
            # Let super class raise if type is not a Connection
            try:
                kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument())
            except (AttributeError, TypeError) as err:
                raise TypeError(
                    'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
                    " to None to disabling the creation of the sort query argument".format(
                        nullable_type.__name__
                    )
                ) from err
        elif "sort" in kwargs and kwargs["sort"] is None:
            del kwargs["sort"]

        if (
            "filter" not in kwargs
            and nullable_type
            and issubclass(nullable_type, Connection)
        ):
            # Only add filtering if a filter argument exists on the object type
            filter_argument = nullable_type.Edge.node._type.get_filter_argument()
            if filter_argument:
                kwargs.setdefault("filter", filter_argument)
        elif "filter" in kwargs and kwargs["filter"] is None:
            del kwargs["filter"]

        super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)

    @property
    def model(self):
        """The SQLAlchemy model class behind this connection's node type."""
        return get_nullable_type(self.type)._meta.node._meta.model

    @classmethod
    def get_query(cls, model, info, sort=None, filter=None, **args):
        """Build the query for ``model``, applying optional sort and filter.

        Args:
            model: the model class handed to the ``get_query`` utility.
            info: GraphQL resolve info; ``info.context`` is passed on.
            sort: one sort item or a list of them. A Python ``enum.Enum``
                member contributes ``item.value.value``, an ``EnumValue``
                contributes ``item.value``, and anything else is passed to
                ``order_by`` unchanged.
            filter: a ``dict`` instance whose type provides
                ``execute_filters(query, filter)``.
            **args: remaining field arguments (e.g. pagination); ignored here.

        Returns:
            The query with ``order_by`` and ``filter`` clauses applied.
        """
        query = get_query(model, info.context)
        if sort is not None:
            if not isinstance(sort, list):
                sort = [sort]
            sort_args = []
            # ensure consistent handling of graphene Enums, enum values and
            # plain strings
            for item in sort:
                if isinstance(item, enum.Enum):
                    sort_args.append(item.value.value)
                elif isinstance(item, EnumValue):
                    sort_args.append(item.value)
                else:
                    sort_args.append(item)
            query = query.order_by(*sort_args)

        if filter is not None:
            assert isinstance(filter, dict)
            filter_type: BaseTypeFilter = type(filter)
            query, clauses = filter_type.execute_filters(query, filter)
            query = query.filter(*clauses)

        return query

    @classmethod
    def _build_connection(cls, connection_type, args, resolved):
        """Slice ``resolved`` into a ``connection_type`` instance using ``args``.

        ``resolved`` is either a SQLAlchemy ``Query`` (sized with ``count()``)
        or a sized sequence (sized with ``len()``). The returned connection
        also carries ``iterable`` (the unsliced data) and ``length``.
        """
        if isinstance(resolved, Query):
            total = resolved.count()
        else:
            total = len(resolved)

        def adjusted_connection_adapter(edges, pageInfo):
            return connection_adapter(connection_type, edges, pageInfo)

        connection = connection_from_array_slice(
            array_slice=resolved,
            args=args,
            slice_start=0,
            array_length=total,
            array_slice_length=total,
            connection_type=adjusted_connection_adapter,
            edge_type=connection_type.Edge,
            page_info_type=page_info_adapter,
        )
        connection.iterable = resolved
        connection.length = total
        return connection

    @classmethod
    def resolve_connection(cls, connection_type, model, info, args, resolved):
        """Turn a resolver result into a connection.

        If ``resolved`` is ``None`` the query is built from ``model`` and
        ``args``. With an ``AsyncSession`` (SQLAlchemy >= 1.4) a coroutine
        producing the connection is returned instead of the connection itself.
        """
        if resolved is None:
            # The session only matters when this method has to run the query.
            session = get_session(info.context)
            if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):

                async def get_result():
                    return await cls.resolve_connection_async(
                        connection_type, model, info, args, resolved
                    )

                return get_result()

            resolved = cls.get_query(model, info, **args)

        return cls._build_connection(connection_type, args, resolved)

    @classmethod
    async def resolve_connection_async(
        cls, connection_type, model, info, args, resolved
    ):
        """Async counterpart of :meth:`resolve_connection`.

        If ``resolved`` is ``None`` the query is executed through the
        context's session with ``await session.scalars(query)`` and the
        results are materialized into a list before slicing.
        """
        if resolved is None:
            session = get_session(info.context)
            query = cls.get_query(model, info, **args)
            resolved = (await session.scalars(query)).all()

        return cls._build_connection(connection_type, args, resolved)

    @classmethod
    def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
        """Call ``resolver`` and convert its (possibly promised) result.

        A thenable result is chained through ``Promise.resolve(...).then``;
        otherwise :meth:`resolve_connection` is applied directly.
        """
        resolved = resolver(root, info, **args)

        on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
        if is_thenable(resolved):
            return Promise.resolve(resolved).then(on_resolve)

        return on_resolve(resolved)

    def wrap_resolve(self, parent_resolver):
        """Bind ``parent_resolver``, the unwrapped field type and the model."""
        return partial(
            self.connection_resolver,
            parent_resolver,
            get_nullable_type(self.type),
            self.model,
        )
# TODO Remove in next major version
class UnsortedSQLAlchemyConnectionField(SQLAlchemyConnectionField):
    """Deprecated connection field that never exposes a ``sort`` argument."""

    def __init__(self, type_, *args, **kwargs):
        """Create the field with sorting disabled.

        Any non-``None`` ``sort`` argument triggers a warning and is discarded.
        A ``DeprecationWarning`` is always emitted. ``stacklevel=2`` attributes
        both warnings to the caller constructing the field, not to this module.
        """
        if "sort" in kwargs and kwargs["sort"] is not None:
            warnings.warn(
                "UnsortedSQLAlchemyConnectionField does not support sorting. "
                "All sorting arguments will be ignored.",
                stacklevel=2,
            )
        # sort=None makes the parent __init__ delete the argument entirely.
        kwargs["sort"] = None
        warnings.warn(
            "UnsortedSQLAlchemyConnectionField is deprecated and will be removed in the next "
            "major version. Use SQLAlchemyConnectionField instead and either don't "
            "provide the `sort` argument or set it to None if you do not want sorting.",
            DeprecationWarning,
            stacklevel=2,
        )
        super(UnsortedSQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)
class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField):
    """
    This is currently experimental.
    The API and behavior may change in future versions.
    Use at your own risk.
    """

    @classmethod
    def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
        """Resolve the connection, loading nested relationships in batches.

        At the top level (``root is None``) the given ``resolver`` is used.
        Otherwise the relationship on ``root``'s mapper whose target class is
        ``model`` is loaded through ``get_batch_resolver``.

        Raises:
            ValueError: if ``root``'s mapper has no relationship targeting
                ``model``.
        """
        if root is None:
            resolved = resolver(root, info, **args)
        else:
            # NOTE(review): takes the first matching relationship; ambiguous
            # when root has several relationships to the same model.
            relationship_prop = next(
                (
                    relationship
                    for relationship in root.__class__.__mapper__.relationships
                    if relationship.mapper.class_ == model
                ),
                None,
            )
            if relationship_prop is None:
                raise ValueError(
                    "{} has no relationship to {}; cannot batch-load this connection".format(
                        root.__class__.__name__, getattr(model, "__name__", model)
                    )
                )
            resolved = get_batch_resolver(relationship_prop)(root, info, **args)

        # Pass the model class (not the root instance): resolve_connection
        # forwards this argument to get_query when it has to build the query.
        on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
        if is_thenable(resolved):
            return Promise.resolve(resolved).then(on_resolve)

        return on_resolve(resolved)

    @classmethod
    def from_relationship(cls, relationship, registry, **field_kwargs):
        """Create a field for ``relationship`` using its batch resolver.

        The field type is the ``connection`` of the graphene type registered
        for the relationship's target entity. The resolver is
        ``get_batch_resolver(relationship)``.
        """
        model = relationship.mapper.entity
        model_type = registry.get_type_for_model(model)
        return cls(
            model_type.connection,
            resolver=get_batch_resolver(relationship),
            **field_kwargs,
        )
def default_connection_field_factory(relationship, registry, **field_kwargs):
    """Create the default connection field for a SQLAlchemy relationship.

    The graphene type registered for the relationship's target entity is
    passed, together with ``field_kwargs``, to the module-wide connection
    field factory.
    """
    target_type = registry.get_type_for_model(relationship.mapper.entity)
    factory = __connectionFactory
    return factory(target_type, **field_kwargs)
# TODO Remove in next major version
# Module-wide connection field factory, called as factory(type_, **field_kwargs)
# by default_connection_field_factory and createConnectionField.
# registerConnectionFieldFactory replaces it and
# unregisterConnectionFieldFactory restores this default.
__connectionFactory = UnsortedSQLAlchemyConnectionField
def createConnectionField(type_, **field_kwargs):
    """Deprecated: build a connection field with the module-wide factory.

    Emits a ``DeprecationWarning`` attributed to the caller (``stacklevel=2``),
    then returns the result of calling the module-wide connection field factory
    with ``type_`` and ``field_kwargs``.
    """
    warnings.warn(
        "createConnectionField is deprecated and will be removed in the next "
        "major version. Use SQLAlchemyBase.Meta.connection_field_factory instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return __connectionFactory(type_, **field_kwargs)
def registerConnectionFieldFactory(factoryMethod):
    """Deprecated: replace the module-wide connection field factory.

    ``factoryMethod`` is afterwards called as
    ``factoryMethod(type_, **field_kwargs)`` by ``createConnectionField`` and
    ``default_connection_field_factory``. The ``DeprecationWarning`` is
    attributed to the caller (``stacklevel=2``).
    """
    warnings.warn(
        "registerConnectionFieldFactory is deprecated and will be removed in the next "
        "major version. Use SQLAlchemyBase.Meta.connection_field_factory instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    global __connectionFactory
    __connectionFactory = factoryMethod
def unregisterConnectionFieldFactory():
    """Deprecated: restore the default module-wide connection field factory.

    Resets the factory to ``UnsortedSQLAlchemyConnectionField``. The
    ``DeprecationWarning`` names this function (the previous message named
    ``registerConnectionFieldFactory`` by mistake) and is attributed to the
    caller (``stacklevel=2``).
    """
    warnings.warn(
        "unregisterConnectionFieldFactory is deprecated and will be removed in the next "
        "major version. Use SQLAlchemyBase.Meta.connection_field_factory instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    global __connectionFactory
    __connectionFactory = UnsortedSQLAlchemyConnectionField