|
| 1 | +import inspect |
1 | 2 | from collections import OrderedDict
|
2 |
| -from functools import singledispatch, wraps |
| 3 | +from functools import partial, singledispatch, wraps |
3 | 4 |
|
4 | 5 | from django.db import models
|
5 | 6 | from django.utils.encoding import force_str
|
|
25 | 26 | )
|
26 | 27 | from graphene.types.json import JSONString
|
27 | 28 | from graphene.types.scalars import BigInt
|
| 29 | +from graphene.types.resolver import get_default_resolver |
28 | 30 | from graphene.utils.str_converters import to_camel_case
|
29 | 31 | from graphql import GraphQLError
|
30 | 32 |
|
@@ -258,14 +260,62 @@ def convert_time_to_string(field, registry=None):
|
258 | 260 |
|
259 | 261 | @convert_django_field.register(models.OneToOneRel)
|
260 | 262 | def convert_onetoone_field_to_djangomodel(field, registry=None):
|
| 263 | + from graphene.utils.str_converters import to_snake_case |
| 264 | + from .types import DjangoObjectType |
| 265 | + |
261 | 266 | model = field.related_model
|
262 | 267 |
|
263 | 268 | def dynamic_type():
|
264 | 269 | _type = registry.get_type_for_model(model)
|
265 | 270 | if not _type:
|
266 | 271 | return
|
267 | 272 |
|
268 |
| - return Field(_type, required=not field.null) |
| 273 | + class CustomField(Field): |
| 274 | + def wrap_resolve(self, parent_resolver): |
| 275 | + """ |
| 276 | + Implements a custom resolver which goes through the `get_node` method to ensure that |
| 277 | + it goes through the `get_queryset` method of the DjangoObjectType. |
| 278 | + """ |
| 279 | + resolver = super().wrap_resolve(parent_resolver) |
| 280 | + |
| 281 | + # If `get_queryset` was not overridden in the DjangoObjectType |
| 282 | + # or if we explicitly bypass the `get_queryset` method, |
| 283 | + # we can just return the default resolver. |
| 284 | + if ( |
| 285 | + _type.get_queryset.__func__ |
| 286 | + is DjangoObjectType.get_queryset.__func__ |
| 287 | + or getattr(resolver, "_bypass_get_queryset", False) |
| 288 | + ): |
| 289 | + return resolver |
| 290 | + |
| 291 | + def custom_resolver(root, info, **args): |
| 292 | + # Note: this function is used to resolve 1:1 relation fields |
| 293 | + |
| 294 | + is_resolver_awaitable = inspect.iscoroutinefunction(resolver) |
| 295 | + |
| 296 | + if is_resolver_awaitable: |
| 297 | + fk_obj = resolver(root, info, **args) |
| 298 | + # In case the resolver is a custom awaitable resolver that overwrites |
| 299 | + # the default Django resolver |
| 300 | + return fk_obj |
| 301 | + |
| 302 | + field_name = to_snake_case(info.field_name) |
| 303 | + reversed_field_name = root.__class__._meta.get_field( |
| 304 | + field_name |
| 305 | + ).remote_field.name |
| 306 | + return _type.get_queryset( |
| 307 | + _type._meta.model.objects.filter( |
| 308 | + **{reversed_field_name: root.pk} |
| 309 | + ), |
| 310 | + info, |
| 311 | + ).get() |
| 312 | + |
| 313 | + return custom_resolver |
| 314 | + |
| 315 | + return CustomField( |
| 316 | + _type, |
| 317 | + required=not field.null, |
| 318 | + ) |
269 | 319 |
|
270 | 320 | return Dynamic(dynamic_type)
|
271 | 321 |
|
@@ -313,14 +363,89 @@ def dynamic_type():
|
313 | 363 | @convert_django_field.register(models.OneToOneField)
|
314 | 364 | @convert_django_field.register(models.ForeignKey)
|
315 | 365 | def convert_field_to_djangomodel(field, registry=None):
|
| 366 | + from graphene.utils.str_converters import to_snake_case |
| 367 | + from .types import DjangoObjectType |
| 368 | + |
316 | 369 | model = field.related_model
|
317 | 370 |
|
318 | 371 | def dynamic_type():
|
319 | 372 | _type = registry.get_type_for_model(model)
|
320 | 373 | if not _type:
|
321 | 374 | return
|
322 | 375 |
|
323 |
| - return Field( |
| 376 | + class CustomField(Field): |
| 377 | + def wrap_resolve(self, parent_resolver): |
| 378 | + """ |
| 379 | + Implements a custom resolver which goes through the `get_node` method to ensure that |
| 380 | + it goes through the `get_queryset` method of the DjangoObjectType. |
| 381 | + """ |
| 382 | + resolver = super().wrap_resolve(parent_resolver) |
| 383 | + |
| 384 | + # If `get_queryset` was not overridden in the DjangoObjectType |
| 385 | + # or if we explicitly bypass the `get_queryset` method, |
| 386 | + # we can just return the default resolver. |
| 387 | + if ( |
| 388 | + _type.get_queryset.__func__ |
| 389 | + is DjangoObjectType.get_queryset.__func__ |
| 390 | + or getattr(resolver, "_bypass_get_queryset", False) |
| 391 | + ): |
| 392 | + return resolver |
| 393 | + |
| 394 | + def custom_resolver(root, info, **args): |
| 395 | + # Note: this function is used to resolve FK or 1:1 fields |
| 396 | + # it does not differentiate between custom-resolved fields |
| 397 | + # and default resolved fields. |
| 398 | + |
| 399 | + # because this is a django foreign key or one-to-one field, the primary-key for |
| 400 | + # this node can be accessed from the root node. |
| 401 | + # ex: article.reporter_id |
| 402 | + |
| 403 | + # get the name of the id field from the root's model |
| 404 | + field_name = to_snake_case(info.field_name) |
| 405 | + db_field_key = root.__class__._meta.get_field(field_name).attname |
| 406 | + if hasattr(root, db_field_key): |
| 407 | + # get the object's primary-key from root |
| 408 | + object_pk = getattr(root, db_field_key) |
| 409 | + else: |
| 410 | + return None |
| 411 | + |
| 412 | + is_resolver_awaitable = inspect.iscoroutinefunction(resolver) |
| 413 | + |
| 414 | + if is_resolver_awaitable: |
| 415 | + fk_obj = resolver(root, info, **args) |
| 416 | + # In case the resolver is a custom awaitable resolver that overwrites |
| 417 | + # the default Django resolver |
| 418 | + return fk_obj |
| 419 | + |
| 420 | + instance_from_get_node = _type.get_node(info, object_pk) |
| 421 | + |
| 422 | + if instance_from_get_node is None: |
| 423 | + # no instance to return |
| 424 | + return |
| 425 | + elif ( |
| 426 | + isinstance(resolver, partial) |
| 427 | + and resolver.func is get_default_resolver() |
| 428 | + ): |
| 429 | + return instance_from_get_node |
| 430 | + elif resolver is not get_default_resolver(): |
| 431 | + # Default resolver is overridden |
| 432 | + # For optimization, add the instance to the resolver |
| 433 | + setattr(root, field_name, instance_from_get_node) |
| 434 | + # Explanation: |
| 435 | + # previously, _type.get_node` is called which results in at least one hit to the database. |
| 436 | + # But, if we did not pass the instance to the root, calling the resolver will result in |
| 437 | + # another call to get the instance which results in at least two database queries in total |
| 438 | + # to resolve this node only. |
| 439 | + # That's why the value of the object is set in the root so when the object is accessed |
| 440 | + # in the resolver (root.field_name) it does not access the database unless queried explicitly. |
| 441 | + fk_obj = resolver(root, info, **args) |
| 442 | + return fk_obj |
| 443 | + else: |
| 444 | + return instance_from_get_node |
| 445 | + |
| 446 | + return custom_resolver |
| 447 | + |
| 448 | + return CustomField( |
324 | 449 | _type,
|
325 | 450 | description=get_django_field_description(field),
|
326 | 451 | required=not field.null,
|
|
0 commit comments