diff --git a/channels_graphql_ws/graphql_ws_consumer.py b/channels_graphql_ws/graphql_ws_consumer.py index 81bd817..6210154 100644 --- a/channels_graphql_ws/graphql_ws_consumer.py +++ b/channels_graphql_ws/graphql_ws_consumer.py @@ -1,43 +1,3 @@ -# Copyright (C) DATADVANCE, 2010-2021 -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files (the -# "Software"), to deal in the Software without restriction, including -# without limitation the rights to use, copy, modify, merge, publish, -# distribute, sublicense, and/or sell copies of the Software, and to -# permit persons to whom the Software is furnished to do so, subject to -# the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -"""Channels consumer which implements GraphQL WebSocket protocol. - -The `GraphqlWsConsumer` is a Channels WebSocket consumer which maintains -WebSocket connection with the client. - -Implementation assumes that client uses the protocol implemented by the -library `subscription-transport-ws` (which is used by Apollo). - -NOTE: Links based on which this functionality is implemented: -- Protocol description: - https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md - https://github.com/apollographql/subscriptions-transport-ws/blob/master/src/message-types.ts -- ASGI specification for WebSockets: - https://github.com/django/asgiref/blob/master/specs/www.rst#websocket -- GitHubGist with the root of inspiration: - https://gist.github.com/tricoder42/af3d0337c1b33d82c1b32d12bd0265ec -""" - - import asyncio import concurrent import dataclasses @@ -47,7 +7,6 @@ import types import weakref from typing import Any, Callable, Dict, List, Optional, Sequence - import asgiref.sync import channels.generic.websocket as ch_websocket import django.core.serializers @@ -64,9 +23,9 @@ # Module logger. LOG = logging.getLogger(__name__) -# WebSocket subprotocol used for the GraphQL. +# WebSocket subprotocols used for the GraphQL. GRAPHQL_WS_SUBPROTOCOL = "graphql-ws" - +TRANSPORT_WS_SUBPROTOCOL = "graphql-transport-ws" class GraphqlWsConsumer(ch_websocket.AsyncJsonWebsocketConsumer): """Channels consumer for the WebSocket GraphQL backend. @@ -219,23 +178,20 @@ async def connect(self): """Handle new WebSocket connection.""" # Assert we run in a proper thread. self._assert_thread() - + self.connection_context = None + found_protocol = None # Check the subprotocol told by the client. # # NOTE: In Python 3.6 `scope["subprotocols"]` was a string, but # starting with Python 3.7 it is a bytes. This can be a proper # change or just a bug in the Channels to be fixed. So let's # accept both variants until it becomes clear. - assert GRAPHQL_WS_SUBPROTOCOL in ( - (sp.decode() if isinstance(sp, bytes) else sp) - for sp in self.scope["subprotocols"] - ), ( - f"WebSocket client does not request for the subprotocol " - f"{GRAPHQL_WS_SUBPROTOCOL}!" - ) + for protocol in [GRAPHQL_WS_SUBPROTOCOL, TRANSPORT_WS_SUBPROTOCOL]: + if protocol in self.scope["subprotocols"]: + found_protocol = protocol + break - # Accept connection with the GraphQL-specific subprotocol. - await self.accept(subprotocol=GRAPHQL_WS_SUBPROTOCOL) + await self.accept(subprotocol=found_protocol) async def disconnect(self, code): """Handle WebSocket disconnect. @@ -284,6 +240,7 @@ async def disconnect(self, code): self._background_tasks.clear() async def receive_json(self, content): # pylint: disable=arguments-differ + print(content['type']) """Process WebSocket message received from the client. NOTE: We force 'STOP' message processing to wait until 'START' @@ -306,11 +263,36 @@ async def receive_json(self, content): # pylint: disable=arguments-differ msg_type = content["type"].upper() if msg_type == "CONNECTION_INIT": - task = self._on_gql_connection_init(payload=content["payload"]) + + task = self._on_gql_connection_init(payload={}) elif msg_type == "CONNECTION_TERMINATE": task = self._on_gql_connection_terminate() + elif msg_type == "SUBSCRIBE": + op_id = content["id"] + # Create and lock a mutex for this particular operation id, + # so STOP processing for the same operation id will wail + # until START processing finishes. Locks are stored in a + # weak collection so we do not have to manually clean it up. + if op_id in self._operation_locks: + raise graphql.error.GraphQLError( + f"Operation with msg_id={op_id} is already running!" + ) + op_lock = asyncio.Lock() + self._operation_locks[op_id] = op_lock + await op_lock.acquire() + + async def on_start(): + try: + await self._on_gql_start( + operation_id=op_id, payload=content["payload"] + ) + finally: + op_lock.release() + + task = on_start() + elif msg_type == "START": op_id = content["id"] # Create and lock a mutex for this particular operation id, @@ -344,7 +326,15 @@ async def on_stop(): await self._on_gql_stop(operation_id=op_id) task = on_stop() + elif msg_type == "COMPLETE": + op_id = content["id"] + async def on_stop(): + # Will until START message processing finishes, if any. + async with self._operation_locks.setdefault(op_id, asyncio.Lock()): + await self._on_gql_stop(operation_id=op_id) + + task = on_stop() else: error_msg = f"Message of unknown type '{msg_type}' received!" task = self._send_gql_error( @@ -363,6 +353,7 @@ async def on_stop(): self._spawn_background_task(task) async def broadcast(self, message): + print(message) """The broadcast message handler. Method is called when new `broadcast` message received from the @@ -464,7 +455,7 @@ async def unsubscribe(self, message): # ---------------------------------------------------------- GRAPHQL PROTOCOL EVENTS async def _on_gql_connection_init(self, payload): - """Process the CONNECTION_INIT message. + """Process the CONNECTION_INIT or SUBSCRIBE message. Start sending keepalive messages if `send_keepalive_every` set. Respond with either CONNECTION_ACK or CONNECTION_ERROR message. @@ -505,13 +496,14 @@ async def keepalive_sender(): await asyncio.sleep(self.send_keepalive_every) await self._send_gql_connection_keep_alive() - self._keepalive_task = asyncio.ensure_future(keepalive_sender()) + self._keepalive_task = asyncio.ensure_future( + keepalive_sender()) # Immediately send keepalive message cause it is # required by the protocol description. await self._send_gql_connection_keep_alive() async def _on_gql_connection_terminate(self): - """Process the CONNECTION_TERMINATE message. + """Process the CONNECTION_TERMINATE OR message. NOTE: Depending on the value of the `strict_ordering` setting this method is either awaited directly or offloaded to an async @@ -754,7 +746,8 @@ async def notifier(): waitlist = [] for group in groups: self._sids_by_group.setdefault(group, []).append(operation_id) - waitlist.append(self._channel_layer.group_add(group, self.channel_name)) + waitlist.append(self._channel_layer.group_add( + group, self.channel_name)) notifier_task = self._spawn_background_task(notifier()) self._subscriptions[operation_id] = self._SubInf( groups=groups, @@ -866,7 +859,7 @@ async def _send_gql_data( await self.send_json( { - "type": "data", + "type": "next", "id": operation_id, "payload": { "data": data, @@ -900,7 +893,8 @@ async def _send_gql_error(self, operation_id, error: str): self._assert_thread() LOG.error("GraphQL query processing error: %s", error) await self.send_json( - {"type": "error", "id": operation_id, "payload": {"errors": [error]}} + {"type": "error", "id": operation_id, + "payload": {"errors": [error]}} ) async def _send_gql_complete(self, operation_id): diff --git a/channels_graphql_ws/templates/graphene/graphiql.html b/channels_graphql_ws/templates/graphene/graphiql.html new file mode 100644 index 0000000..d8066fb --- /dev/null +++ b/channels_graphql_ws/templates/graphene/graphiql.html @@ -0,0 +1,135 @@ + + + +
+ + + + + + + + + + + + + diff --git a/channels_graphql_ws/transport.py b/channels_graphql_ws/transport.py index ff031d9..0ebc704 100644 --- a/channels_graphql_ws/transport.py +++ b/channels_graphql_ws/transport.py @@ -169,13 +169,13 @@ async def _process_messages(self, connected, timeout): async with session as session: connection = session.ws_connect( self._url, - protocols=[graphql_ws_consumer.GRAPHQL_WS_SUBPROTOCOL], + protocols=[graphql_ws_consumer.GRAPHQL_WS_SUBPROTOCOL,graphql_ws_consumer.TRANSPORT_WS_SUBPROTOCOL], timeout=timeout, ) async with connection as self._connection: if ( self._connection.protocol - != graphql_ws_consumer.GRAPHQL_WS_SUBPROTOCOL + not in [graphql_ws_consumer.GRAPHQL_WS_SUBPROTOCOL,graphql_ws_consumer.TRANSPORT_WS_SUBPROTOCOL] ): raise RuntimeError( f"Server uses wrong subprotocol: {self._connection.protocol}!" diff --git a/poetry.lock b/poetry.lock index c584627..340bdef 100644 --- a/poetry.lock +++ b/poetry.lock @@ -367,11 +367,11 @@ python-versions = "*" [[package]] name = "django" -version = "3.2.6" +version = "4.0" description = "A high-level Python Web framework that encourages rapid development and clean, pragmatic design." category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.10" [package.dependencies] asgiref = ">=3.3.2,<4"