Skip to content

Commit 17d535e

Browse files
authored
Add batching params (#260)
Add parameters to toggle batching on or off. This can be configured at 2 levels: - we can configure all the fields of a type at once via SQLAlchemyObjectType.meta.batching - or we can specify it for a specific field via ORMfield.batching. This trumps SQLAlchemyObjectType.meta.batching.
1 parent 7a48d3d commit 17d535e

File tree

6 files changed

+325
-120
lines changed

6 files changed

+325
-120
lines changed

graphene_sqlalchemy/converter.py

+76-20
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,28 @@
33
from singledispatch import singledispatch
44
from sqlalchemy import types
55
from sqlalchemy.dialects import postgresql
6-
from sqlalchemy.orm import interfaces
6+
from sqlalchemy.orm import interfaces, strategies
77

88
from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
99
String)
1010
from graphene.types.json import JSONString
1111

12+
from .batching import get_batch_resolver
1213
from .enums import enum_for_sa_enum
14+
from .fields import (BatchSQLAlchemyConnectionField,
15+
default_connection_field_factory)
1316
from .registry import get_global_registry
17+
from .resolvers import get_attr_resolver, get_custom_resolver
1418

1519
try:
1620
from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType
1721
except ImportError:
1822
ChoiceType = JSONType = ScalarListType = TSVectorType = object
1923

2024

25+
is_selectin_available = getattr(strategies, 'SelectInLoader', None)
26+
27+
2128
def get_column_doc(column):
2229
return getattr(column, "doc", None)
2330

@@ -26,33 +33,82 @@ def is_column_nullable(column):
2633
return bool(getattr(column, "nullable", True))
2734

2835

29-
def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, resolver, **field_kwargs):
30-
direction = relationship_prop.direction
31-
model = relationship_prop.mapper.entity
32-
36+
def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_field_factory, batching,
37+
orm_field_name, **field_kwargs):
38+
"""
39+
:param sqlalchemy.RelationshipProperty relationship_prop:
40+
:param SQLAlchemyObjectType obj_type:
41+
:param function|None connection_field_factory:
42+
:param bool batching:
43+
:param str orm_field_name:
44+
:param dict field_kwargs:
45+
:rtype: Dynamic
46+
"""
3347
def dynamic_type():
34-
_type = registry.get_type_for_model(model)
48+
""":rtype: Field|None"""
49+
direction = relationship_prop.direction
50+
child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity)
51+
batching_ = batching if is_selectin_available else False
3552

36-
if not _type:
53+
if not child_type:
3754
return None
55+
3856
if direction == interfaces.MANYTOONE or not relationship_prop.uselist:
39-
return Field(
40-
_type,
41-
resolver=resolver,
42-
**field_kwargs
43-
)
44-
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
45-
if _type.connection:
46-
# TODO Add a way to override connection_field_factory
47-
return connection_field_factory(relationship_prop, registry, **field_kwargs)
48-
return Field(
49-
List(_type),
50-
**field_kwargs
51-
)
57+
return _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching_, orm_field_name,
58+
**field_kwargs)
59+
60+
if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
61+
return _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching_,
62+
connection_field_factory, **field_kwargs)
5263

5364
return Dynamic(dynamic_type)
5465

5566

67+
def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_field_name, **field_kwargs):
68+
"""
69+
Convert one-to-one or many-to-one relationshsip. Return an object field.
70+
71+
:param sqlalchemy.RelationshipProperty relationship_prop:
72+
:param SQLAlchemyObjectType obj_type:
73+
:param bool batching:
74+
:param str orm_field_name:
75+
:param dict field_kwargs:
76+
:rtype: Field
77+
"""
78+
child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity)
79+
80+
resolver = get_custom_resolver(obj_type, orm_field_name)
81+
if resolver is None:
82+
resolver = get_batch_resolver(relationship_prop) if batching else \
83+
get_attr_resolver(obj_type, relationship_prop.key)
84+
85+
return Field(child_type, resolver=resolver, **field_kwargs)
86+
87+
88+
def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs):
89+
"""
90+
Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field.
91+
92+
:param sqlalchemy.RelationshipProperty relationship_prop:
93+
:param SQLAlchemyObjectType obj_type:
94+
:param bool batching:
95+
:param function|None connection_field_factory:
96+
:param dict field_kwargs:
97+
:rtype: Field
98+
"""
99+
child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity)
100+
101+
if not child_type._meta.connection:
102+
return Field(List(child_type), **field_kwargs)
103+
104+
# TODO Allow override of connection_field_factory and resolver via ORMField
105+
if connection_field_factory is None:
106+
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship if batching else \
107+
default_connection_field_factory
108+
109+
return connection_field_factory(relationship_prop, obj_type._meta.registry, **field_kwargs)
110+
111+
56112
def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs):
57113
if 'type' not in field_kwargs:
58114
# TODO The default type should be dependent on the type of the property propety.

graphene_sqlalchemy/resolver.py

Whitespace-only changes.

graphene_sqlalchemy/resolvers.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from graphene.utils.get_unbound_function import get_unbound_function
2+
3+
4+
def get_custom_resolver(obj_type, orm_field_name):
5+
"""
6+
Since `graphene` will call `resolve_<field_name>` on a field only if it
7+
does not have a `resolver`, we need to re-implement that logic here so
8+
users are able to override the default resolvers that we provide.
9+
"""
10+
resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None)
11+
if resolver:
12+
return get_unbound_function(resolver)
13+
14+
return None
15+
16+
17+
def get_attr_resolver(obj_type, model_attr):
18+
"""
19+
In order to support field renaming via `ORMField.model_attr`,
20+
we need to define resolver functions for each field.
21+
22+
:param SQLAlchemyObjectType obj_type:
23+
:param str model_attr: the name of the SQLAlchemy attribute
24+
:rtype: Callable
25+
"""
26+
return lambda root, _info: getattr(root, model_attr, None)

graphene_sqlalchemy/tests/test_batching.py

+190-5
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import graphene
77
from graphene import relay
88

9-
from ..fields import BatchSQLAlchemyConnectionField
10-
from ..types import SQLAlchemyObjectType
9+
from ..fields import (BatchSQLAlchemyConnectionField,
10+
default_connection_field_factory)
11+
from ..types import ORMField, SQLAlchemyObjectType
1112
from .models import Article, HairKind, Pet, Reporter
1213
from .utils import is_sqlalchemy_version_less_than, to_std_dicts
1314

@@ -43,19 +44,19 @@ class ReporterType(SQLAlchemyObjectType):
4344
class Meta:
4445
model = Reporter
4546
interfaces = (relay.Node,)
46-
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
47+
batching = True
4748

4849
class ArticleType(SQLAlchemyObjectType):
4950
class Meta:
5051
model = Article
5152
interfaces = (relay.Node,)
52-
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
53+
batching = True
5354

5455
class PetType(SQLAlchemyObjectType):
5556
class Meta:
5657
model = Pet
5758
interfaces = (relay.Node,)
58-
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
59+
batching = True
5960

6061
class Query(graphene.ObjectType):
6162
articles = graphene.Field(graphene.List(ArticleType))
@@ -513,3 +514,187 @@ def test_many_to_many(session_factory):
513514
},
514515
],
515516
}
517+
518+
519+
def test_disable_batching_via_ormfield(session_factory):
520+
session = session_factory()
521+
reporter_1 = Reporter(first_name='Reporter_1')
522+
session.add(reporter_1)
523+
reporter_2 = Reporter(first_name='Reporter_2')
524+
session.add(reporter_2)
525+
session.commit()
526+
session.close()
527+
528+
class ReporterType(SQLAlchemyObjectType):
529+
class Meta:
530+
model = Reporter
531+
interfaces = (relay.Node,)
532+
batching = True
533+
534+
favorite_article = ORMField(batching=False)
535+
articles = ORMField(batching=False)
536+
537+
class ArticleType(SQLAlchemyObjectType):
538+
class Meta:
539+
model = Article
540+
interfaces = (relay.Node,)
541+
542+
class Query(graphene.ObjectType):
543+
reporters = graphene.Field(graphene.List(ReporterType))
544+
545+
def resolve_reporters(self, info):
546+
return info.context.get('session').query(Reporter).all()
547+
548+
schema = graphene.Schema(query=Query)
549+
550+
# Test one-to-one and many-to-one relationships
551+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
552+
# Starts new session to fully reset the engine / connection logging level
553+
session = session_factory()
554+
schema.execute("""
555+
query {
556+
reporters {
557+
favoriteArticle {
558+
headline
559+
}
560+
}
561+
}
562+
""", context_value={"session": session})
563+
messages = sqlalchemy_logging_handler.messages
564+
565+
select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
566+
assert len(select_statements) == 2
567+
568+
# Test one-to-many and many-to-many relationships
569+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
570+
# Starts new session to fully reset the engine / connection logging level
571+
session = session_factory()
572+
schema.execute("""
573+
query {
574+
reporters {
575+
articles {
576+
edges {
577+
node {
578+
headline
579+
}
580+
}
581+
}
582+
}
583+
}
584+
""", context_value={"session": session})
585+
messages = sqlalchemy_logging_handler.messages
586+
587+
select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
588+
assert len(select_statements) == 2
589+
590+
591+
def test_connection_factory_field_overrides_batching_is_false(session_factory):
592+
session = session_factory()
593+
reporter_1 = Reporter(first_name='Reporter_1')
594+
session.add(reporter_1)
595+
reporter_2 = Reporter(first_name='Reporter_2')
596+
session.add(reporter_2)
597+
session.commit()
598+
session.close()
599+
600+
class ReporterType(SQLAlchemyObjectType):
601+
class Meta:
602+
model = Reporter
603+
interfaces = (relay.Node,)
604+
batching = False
605+
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
606+
607+
articles = ORMField(batching=False)
608+
609+
class ArticleType(SQLAlchemyObjectType):
610+
class Meta:
611+
model = Article
612+
interfaces = (relay.Node,)
613+
614+
class Query(graphene.ObjectType):
615+
reporters = graphene.Field(graphene.List(ReporterType))
616+
617+
def resolve_reporters(self, info):
618+
return info.context.get('session').query(Reporter).all()
619+
620+
schema = graphene.Schema(query=Query)
621+
622+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
623+
# Starts new session to fully reset the engine / connection logging level
624+
session = session_factory()
625+
schema.execute("""
626+
query {
627+
reporters {
628+
articles {
629+
edges {
630+
node {
631+
headline
632+
}
633+
}
634+
}
635+
}
636+
}
637+
""", context_value={"session": session})
638+
messages = sqlalchemy_logging_handler.messages
639+
640+
if is_sqlalchemy_version_less_than('1.3'):
641+
# The batched SQL statement generated is different in 1.2.x
642+
# SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
643+
# See https://git.io/JewQu
644+
select_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message]
645+
else:
646+
select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
647+
assert len(select_statements) == 1
648+
649+
650+
def test_connection_factory_field_overrides_batching_is_true(session_factory):
651+
session = session_factory()
652+
reporter_1 = Reporter(first_name='Reporter_1')
653+
session.add(reporter_1)
654+
reporter_2 = Reporter(first_name='Reporter_2')
655+
session.add(reporter_2)
656+
session.commit()
657+
session.close()
658+
659+
class ReporterType(SQLAlchemyObjectType):
660+
class Meta:
661+
model = Reporter
662+
interfaces = (relay.Node,)
663+
batching = True
664+
connection_field_factory = default_connection_field_factory
665+
666+
articles = ORMField(batching=True)
667+
668+
class ArticleType(SQLAlchemyObjectType):
669+
class Meta:
670+
model = Article
671+
interfaces = (relay.Node,)
672+
673+
class Query(graphene.ObjectType):
674+
reporters = graphene.Field(graphene.List(ReporterType))
675+
676+
def resolve_reporters(self, info):
677+
return info.context.get('session').query(Reporter).all()
678+
679+
schema = graphene.Schema(query=Query)
680+
681+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
682+
# Starts new session to fully reset the engine / connection logging level
683+
session = session_factory()
684+
schema.execute("""
685+
query {
686+
reporters {
687+
articles {
688+
edges {
689+
node {
690+
headline
691+
}
692+
}
693+
}
694+
}
695+
}
696+
""", context_value={"session": session})
697+
messages = sqlalchemy_logging_handler.messages
698+
699+
select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
700+
assert len(select_statements) == 2

0 commit comments

Comments
 (0)