Skip to content

Commit 98e6fe7

Browse files
authored
Fix N+1 problem for one-to-one and many-to-one relationships (#253)
1 parent 89c3726 commit 98e6fe7

File tree

8 files changed

+379
-47
lines changed

8 files changed

+379
-47
lines changed

graphene_sqlalchemy/resolver.py

Whitespace-only changes.

graphene_sqlalchemy/tests/conftest.py

+15-17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from sqlalchemy import create_engine
3-
from sqlalchemy.orm import scoped_session, sessionmaker
3+
from sqlalchemy.orm import sessionmaker
44

55
import graphene
66

@@ -23,19 +23,17 @@ def convert_composite_class(composite, registry):
2323

2424

2525
@pytest.yield_fixture(scope="function")
26-
def session():
27-
db = create_engine(test_db_url)
28-
connection = db.engine.connect()
29-
transaction = connection.begin()
30-
Base.metadata.create_all(connection)
31-
32-
# options = dict(bind=connection, binds={})
33-
session_factory = sessionmaker(bind=connection)
34-
session = scoped_session(session_factory)
35-
36-
yield session
37-
38-
# Finalize test here
39-
transaction.rollback()
40-
connection.close()
41-
session.remove()
26+
def session_factory():
27+
engine = create_engine(test_db_url)
28+
Base.metadata.create_all(engine)
29+
30+
yield sessionmaker(bind=engine)
31+
32+
# SQLite in-memory db is deleted when its connection is closed.
33+
# https://www.sqlite.org/inmemorydb.html
34+
engine.dispose()
35+
36+
37+
@pytest.fixture(scope="function")
38+
def session(session_factory):
39+
return session_factory()
+228
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
import contextlib
2+
import logging
3+
4+
import pkg_resources
5+
import pytest
6+
7+
import graphene
8+
9+
from ..types import SQLAlchemyObjectType
10+
from .models import Article, Reporter
11+
from .utils import to_std_dicts
12+
13+
14+
class MockLoggingHandler(logging.Handler):
15+
"""Intercept and store log messages in a list."""
16+
def __init__(self, *args, **kwargs):
17+
self.messages = []
18+
logging.Handler.__init__(self, *args, **kwargs)
19+
20+
def emit(self, record):
21+
self.messages.append(record.getMessage())
22+
23+
24+
@contextlib.contextmanager
25+
def mock_sqlalchemy_logging_handler():
26+
logging.basicConfig()
27+
sql_logger = logging.getLogger('sqlalchemy.engine')
28+
previous_level = sql_logger.level
29+
30+
sql_logger.setLevel(logging.INFO)
31+
mock_logging_handler = MockLoggingHandler()
32+
mock_logging_handler.setLevel(logging.INFO)
33+
sql_logger.addHandler(mock_logging_handler)
34+
35+
yield mock_logging_handler
36+
37+
sql_logger.setLevel(previous_level)
38+
39+
40+
def make_fixture(session):
41+
reporter_1 = Reporter(
42+
first_name='Reporter_1',
43+
)
44+
session.add(reporter_1)
45+
reporter_2 = Reporter(
46+
first_name='Reporter_2',
47+
)
48+
session.add(reporter_2)
49+
50+
article_1 = Article(headline='Article_1')
51+
article_1.reporter = reporter_1
52+
session.add(article_1)
53+
54+
article_2 = Article(headline='Article_2')
55+
article_2.reporter = reporter_2
56+
session.add(article_2)
57+
58+
session.commit()
59+
session.close()
60+
61+
62+
def get_schema(session):
63+
class ReporterType(SQLAlchemyObjectType):
64+
class Meta:
65+
model = Reporter
66+
67+
class ArticleType(SQLAlchemyObjectType):
68+
class Meta:
69+
model = Article
70+
71+
class Query(graphene.ObjectType):
72+
articles = graphene.Field(graphene.List(ArticleType))
73+
reporters = graphene.Field(graphene.List(ReporterType))
74+
75+
def resolve_articles(self, _info):
76+
return session.query(Article).all()
77+
78+
def resolve_reporters(self, _info):
79+
return session.query(Reporter).all()
80+
81+
return graphene.Schema(query=Query)
82+
83+
84+
def is_sqlalchemy_version_less_than(version_string):
85+
return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string)
86+
87+
88+
if is_sqlalchemy_version_less_than('1.2'):
89+
pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True)
90+
91+
92+
def test_many_to_one(session_factory):
93+
session = session_factory()
94+
make_fixture(session)
95+
schema = get_schema(session)
96+
97+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
98+
# Starts new session to fully reset the engine / connection logging level
99+
session = session_factory()
100+
result = schema.execute("""
101+
query {
102+
articles {
103+
headline
104+
reporter {
105+
firstName
106+
}
107+
}
108+
}
109+
""", context_value={"session": session})
110+
messages = sqlalchemy_logging_handler.messages
111+
112+
assert len(messages) == 5
113+
114+
if is_sqlalchemy_version_less_than('1.3'):
115+
# The batched SQL statement generated is different in 1.2.x
116+
# SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
117+
# See https://git.io/JewQu
118+
return
119+
120+
assert messages == [
121+
'BEGIN (implicit)',
122+
123+
'SELECT articles.id AS articles_id, '
124+
'articles.headline AS articles_headline, '
125+
'articles.pub_date AS articles_pub_date, '
126+
'articles.reporter_id AS articles_reporter_id \n'
127+
'FROM articles',
128+
'()',
129+
130+
'SELECT reporters.id AS reporters_id, '
131+
'(SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, '
132+
'reporters.first_name AS reporters_first_name, '
133+
'reporters.last_name AS reporters_last_name, '
134+
'reporters.email AS reporters_email, '
135+
'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n'
136+
'FROM reporters \n'
137+
'WHERE reporters.id IN (?, ?)',
138+
'(1, 2)',
139+
]
140+
141+
assert not result.errors
142+
result = to_std_dicts(result.data)
143+
assert result == {
144+
"articles": [
145+
{
146+
"headline": "Article_1",
147+
"reporter": {
148+
"firstName": "Reporter_1",
149+
},
150+
},
151+
{
152+
"headline": "Article_2",
153+
"reporter": {
154+
"firstName": "Reporter_2",
155+
},
156+
},
157+
],
158+
}
159+
160+
161+
def test_one_to_one(session_factory):
162+
session = session_factory()
163+
make_fixture(session)
164+
schema = get_schema(session)
165+
166+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
167+
# Starts new session to fully reset the engine / connection logging level
168+
session = session_factory()
169+
result = schema.execute("""
170+
query {
171+
reporters {
172+
firstName
173+
favoriteArticle {
174+
headline
175+
}
176+
}
177+
}
178+
""", context_value={"session": session})
179+
messages = sqlalchemy_logging_handler.messages
180+
181+
assert len(messages) == 5
182+
183+
if is_sqlalchemy_version_less_than('1.3'):
184+
# The batched SQL statement generated is different in 1.2.x
185+
# SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
186+
# See https://git.io/JewQu
187+
return
188+
189+
assert messages == [
190+
'BEGIN (implicit)',
191+
192+
'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, '
193+
'reporters.id AS reporters_id, '
194+
'reporters.first_name AS reporters_first_name, '
195+
'reporters.last_name AS reporters_last_name, '
196+
'reporters.email AS reporters_email, '
197+
'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n'
198+
'FROM reporters',
199+
'()',
200+
201+
'SELECT articles.reporter_id AS articles_reporter_id, '
202+
'articles.id AS articles_id, '
203+
'articles.headline AS articles_headline, '
204+
'articles.pub_date AS articles_pub_date \n'
205+
'FROM articles \n'
206+
'WHERE articles.reporter_id IN (?, ?) '
207+
'ORDER BY articles.reporter_id',
208+
'(1, 2)'
209+
]
210+
211+
assert not result.errors
212+
result = to_std_dicts(result.data)
213+
assert result == {
214+
"reporters": [
215+
{
216+
"firstName": "Reporter_1",
217+
"favoriteArticle": {
218+
"headline": "Article_1",
219+
},
220+
},
221+
{
222+
"firstName": "Reporter_2",
223+
"favoriteArticle": {
224+
"headline": "Article_2",
225+
},
226+
},
227+
],
228+
}

graphene_sqlalchemy/tests/test_query.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,7 @@
55
from ..fields import SQLAlchemyConnectionField
66
from ..types import ORMField, SQLAlchemyObjectType
77
from .models import Article, CompositeFullName, Editor, HairKind, Pet, Reporter
8-
9-
10-
def to_std_dicts(value):
11-
"""Convert nested ordered dicts to normal dicts for better comparison."""
12-
if isinstance(value, dict):
13-
return {k: to_std_dicts(v) for k, v in value.items()}
14-
elif isinstance(value, list):
15-
return [to_std_dicts(v) for v in value]
16-
else:
17-
return value
8+
from .utils import to_std_dicts
189

1910

2011
def add_test_data(session):

graphene_sqlalchemy/tests/utils.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
def to_std_dicts(value):
2+
"""Convert nested ordered dicts to normal dicts for better comparison."""
3+
if isinstance(value, dict):
4+
return {k: to_std_dicts(v) for k, v in value.items()}
5+
elif isinstance(value, list):
6+
return [to_std_dicts(v) for v in value]
7+
else:
8+
return value

0 commit comments

Comments
 (0)