@@ -22,6 +22,10 @@ def default(self, o):
22
22
23
23
24
24
class GraphQLSubscriptionConsumer (AsyncJsonWebsocketConsumer ):
25
+ def __init__ (self , * args , ** kwargs ):
26
+ super ().__init__ (* args , ** kwargs )
27
+ self .futures = []
28
+
25
29
async def connect (self ):
26
30
self .connection_context = None
27
31
if WS_PROTOCOL in self .scope ["subprotocols" ]:
@@ -33,14 +37,22 @@ async def connect(self):
33
37
await self .close ()
34
38
35
39
async def disconnect (self , code ):
40
+ for future in self .futures :
41
+ # Ensure any running message tasks are cancelled.
42
+ future .cancel ()
36
43
if self .connection_context :
37
44
self .connection_context .socket_closed = True
38
- await subscription_server .on_close (self .connection_context )
45
+ close_future = subscription_server .on_close (self .connection_context )
46
+ await asyncio .gather (close_future , * self .futures )
39
47
40
48
async def receive_json (self , content ):
41
- asyncio .ensure_future (
42
- subscription_server .on_message (self .connection_context , content )
49
+ self .futures .append (
50
+ asyncio .ensure_future (
51
+ subscription_server .on_message (self .connection_context , content )
52
+ )
43
53
)
54
+ # Clean up any completed futures.
55
+ self .futures = [future for future in self .futures if not future .done ()]
44
56
45
57
@classmethod
46
58
async def encode_json (cls , content ):
0 commit comments