Skip to content

Commit 6456782

Browse files
committed
Cancel remaining fields on exceptions
gather() returns when the first exception is raised, but does not cancel any remaining tasks. These continue to run which is inefficient, and can also cause problems if they access shared resources like database connections. Fixes: #236
1 parent 0107e30 commit 6456782

File tree

2 files changed

+62
-7
lines changed

2 files changed

+62
-7
lines changed

src/graphql/execution/execute.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@
22

33
from __future__ import annotations
44

5-
from asyncio import ensure_future, gather, shield, wait_for
5+
from asyncio import (
6+
FIRST_EXCEPTION,
7+
CancelledError,
8+
create_task,
9+
ensure_future,
10+
gather,
11+
shield,
12+
wait,
13+
wait_for,
14+
)
615
from contextlib import suppress
716
from copy import copy
817
from typing import (
@@ -459,12 +468,16 @@ async def get_results() -> dict[str, Any]:
459468
field = awaitable_fields[0]
460469
results[field] = await results[field]
461470
else:
462-
results.update(
463-
zip(
464-
awaitable_fields,
465-
await gather(*(results[field] for field in awaitable_fields)),
466-
)
467-
)
471+
tasks = {}
472+
for field in awaitable_fields:
473+
tasks[create_task(results[field])] = field # type: ignore[arg-type]
474+
475+
done, pending = await wait(tasks, return_when=FIRST_EXCEPTION)
476+
for task in pending:
477+
task.cancel()
478+
479+
results.update((tasks[task], task.result()) for task in done)
480+
468481
return results
469482

470483
return get_results()
@@ -538,6 +551,10 @@ async def await_completed() -> Any:
538551
try:
539552
return await completed
540553
except Exception as raw_error:
554+
# Before Python 3.8 CancelledError inherits Exception and
555+
# so gets caught here.
556+
if isinstance(raw_error, CancelledError):
557+
raise
541558
self.handle_field_error(
542559
raw_error,
543560
return_type,

tests/execution/test_parallel.py

+38
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
GraphQLInt,
1212
GraphQLInterfaceType,
1313
GraphQLList,
14+
GraphQLNonNull,
1415
GraphQLObjectType,
1516
GraphQLSchema,
1617
GraphQLString,
@@ -193,3 +194,40 @@ async def is_type_of_baz(obj, *_args):
193194
{"foo": [{"foo": "bar", "foobar": 1}, {"foo": "baz", "foobaz": 2}]},
194195
None,
195196
)
197+
198+
@pytest.mark.asyncio
199+
async def cancel_on_exception():
200+
barrier = Barrier(2)
201+
completed = False
202+
203+
async def succeed(*_args):
204+
nonlocal completed
205+
await barrier.wait()
206+
completed = True
207+
208+
async def fail(*_args):
209+
raise Exception
210+
211+
schema = GraphQLSchema(
212+
GraphQLObjectType(
213+
"Query",
214+
{
215+
"foo": GraphQLField(GraphQLNonNull(GraphQLBoolean), resolve=fail),
216+
"bar": GraphQLField(GraphQLBoolean, resolve=succeed),
217+
},
218+
)
219+
)
220+
221+
ast = parse("{foo, bar}")
222+
223+
awaitable_result = execute(schema, ast)
224+
assert isinstance(awaitable_result, Awaitable)
225+
result = await asyncio.wait_for(awaitable_result, 1.0)
226+
227+
assert result.errors
228+
assert not result.data
229+
230+
# Unblock succeed() and check that it does not complete
231+
await barrier.wait()
232+
await asyncio.sleep(0)
233+
assert not completed

0 commit comments

Comments
 (0)