Skip to content

Commit d90de4a

Browse files
authored
Fix N+1 problem for one-to-many and many-to-many relationships (#254)
This optimization batches what used to be multiple SQL statements into a single SQL statement. For now, you'll have to enable the optimization via the `SQLAlchemyObjectType.Meta.connection_field_factory` (see `test_batching.py`).
1 parent 98e6fe7 commit d90de4a

File tree

8 files changed

+458
-120
lines changed

8 files changed

+458
-120
lines changed

Diff for: .gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ var/
2626
*.egg-info/
2727
.installed.cfg
2828
*.egg
29+
.python-version
2930

3031
# PyInstaller
3132
# Usually these files are written by a python script from a template
@@ -47,6 +48,7 @@ nosetests.xml
4748
coverage.xml
4849
*,cover
4950
.pytest_cache/
51+
.benchmarks/
5052

5153
# Translations
5254
*.mo

Diff for: graphene_sqlalchemy/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .fields import SQLAlchemyConnectionField
33
from .utils import get_query, get_session
44

5-
__version__ = "2.2.2"
5+
__version__ = "2.3.0.dev0"
66

77
__all__ = [
88
"__version__",

Diff for: graphene_sqlalchemy/batching.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import sqlalchemy
2+
from promise import dataloader, promise
3+
from sqlalchemy.orm import Session, strategies
4+
from sqlalchemy.orm.query import QueryContext
5+
6+
7+
def get_batch_resolver(relationship_prop):
8+
class RelationshipLoader(dataloader.DataLoader):
9+
cache = False
10+
11+
def batch_load_fn(self, parents): # pylint: disable=method-hidden
12+
"""
13+
Batch loads the relationships of all the parents as one SQL statement.
14+
15+
There is no way to do this out-of-the-box with SQLAlchemy but
16+
we can piggyback on some internal APIs of the `selectin`
17+
eager loading strategy. It's a bit hacky but it's preferable
18+
than re-implementing and maintainnig a big chunk of the `selectin`
19+
loader logic ourselves.
20+
21+
The approach here is to build a regular query that
22+
selects the parent and `selectin` load the relationship.
23+
But instead of having the query emits 2 `SELECT` statements
24+
when callling `all()`, we skip the first `SELECT` statement
25+
and jump right before the `selectin` loader is called.
26+
To accomplish this, we have to construct objects that are
27+
normally built in the first part of the query in order
28+
to call directly `SelectInLoader._load_for_path`.
29+
30+
TODO Move this logic to a util in the SQLAlchemy repo as per
31+
SQLAlchemy's main maitainer suggestion.
32+
See https://git.io/JewQ7
33+
"""
34+
child_mapper = relationship_prop.mapper
35+
parent_mapper = relationship_prop.parent
36+
session = Session.object_session(parents[0])
37+
38+
# These issues are very unlikely to happen in practice...
39+
for parent in parents:
40+
# assert parent.__mapper__ is parent_mapper
41+
# All instances must share the same session
42+
assert session is Session.object_session(parent)
43+
# The behavior of `selectin` is undefined if the parent is dirty
44+
assert parent not in session.dirty
45+
46+
loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),))
47+
48+
# Should the boolean be set to False? Does it matter for our purposes?
49+
states = [(sqlalchemy.inspect(parent), True) for parent in parents]
50+
51+
# For our purposes, the query_context will only used to get the session
52+
query_context = QueryContext(session.query(parent_mapper.entity))
53+
54+
loader._load_for_path(
55+
query_context,
56+
parent_mapper._path_registry,
57+
states,
58+
None,
59+
child_mapper,
60+
)
61+
62+
return promise.Promise.resolve([getattr(parent, relationship_prop.key) for parent in parents])
63+
64+
loader = RelationshipLoader()
65+
66+
def resolve(root, info, **args):
67+
return loader.load(root)
68+
69+
return resolve

Diff for: graphene_sqlalchemy/fields.py

+30-8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from graphene.relay.connection import PageInfo
1010
from graphql_relay.connection.arrayconnection import connection_from_list_slice
1111

12+
from .batching import get_batch_resolver
1213
from .utils import get_query
1314

1415

@@ -33,14 +34,8 @@ def model(self):
3334
return self.type._meta.node._meta.model
3435

3536
@classmethod
36-
def get_query(cls, model, info, sort=None, **args):
37-
query = get_query(model, info.context)
38-
if sort is not None:
39-
if isinstance(sort, six.string_types):
40-
query = query.order_by(sort.value)
41-
else:
42-
query = query.order_by(*(col.value for col in sort))
43-
return query
37+
def get_query(cls, model, info, **args):
38+
return get_query(model, info.context)
4439

4540
@classmethod
4641
def resolve_connection(cls, connection_type, model, info, args, resolved):
@@ -78,6 +73,7 @@ def get_resolver(self, parent_resolver):
7873
return partial(self.connection_resolver, parent_resolver, self.type, self.model)
7974

8075

76+
# TODO Rename this to SortableSQLAlchemyConnectionField
8177
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
8278
def __init__(self, type, *args, **kwargs):
8379
if "sort" not in kwargs and issubclass(type, Connection):
@@ -95,6 +91,32 @@ def __init__(self, type, *args, **kwargs):
9591
del kwargs["sort"]
9692
super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)
9793

94+
@classmethod
95+
def get_query(cls, model, info, sort=None, **args):
96+
query = get_query(model, info.context)
97+
if sort is not None:
98+
if isinstance(sort, six.string_types):
99+
query = query.order_by(sort.value)
100+
else:
101+
query = query.order_by(*(col.value for col in sort))
102+
return query
103+
104+
105+
class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
106+
"""
107+
This is currently experimental.
108+
The API and behavior may change in future versions.
109+
Use at your own risk.
110+
"""
111+
def get_resolver(self, parent_resolver):
112+
return partial(self.connection_resolver, self.resolver, self.type, self.model)
113+
114+
@classmethod
115+
def from_relationship(cls, relationship, registry, **field_kwargs):
116+
model = relationship.mapper.entity
117+
model_type = registry.get_type_for_model(model)
118+
return cls(model_type._meta.connection, resolver=get_batch_resolver(relationship), **field_kwargs)
119+
98120

99121
def default_connection_field_factory(relationship, registry, **field_kwargs):
100122
model = relationship.mapper.entity

Diff for: graphene_sqlalchemy/tests/models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class Reporter(Base):
6161
last_name = Column(String(30), doc="Last name")
6262
email = Column(String(), doc="Email")
6363
favorite_pet_kind = Column(PetKind)
64-
pets = relationship("Pet", secondary=association_table, backref="reporters")
64+
pets = relationship("Pet", secondary=association_table, backref="reporters", order_by="Pet.id")
6565
articles = relationship("Article", backref="reporter")
6666
favorite_article = relationship("Article", uselist=False)
6767

0 commit comments

Comments
 (0)