Skip to content

add support for graphql-core 3 to graphql_ws.aiohttp #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -18,15 +18,14 @@ deploy:
MnY1TzdSMDJ0d2xyS3pXWlB6SGxLWHVLc0dwemxCS3RLZ0RVQW02aEFBaFZwdDBVbUhTWnVtQ0I1
cnFkWndJWFdTcEQ4SU8rRHMzTTdwbGMzMThPT2ZkOFo2MXU1dVlRZkFlUklURkRpNjVLUHp4Y1U9
on:
python: 2.7
python: 3.6
repo: graphql-python/graphql-ws
tags: true
install: pip install -U tox-travis
install: pip install -U tox-travis codecov
language: python
python:
- 3.8
- 3.7
- 3.6
- 3.5
- 2.7
script: tox
script: tox -- --cov-branch --cov-report=term-missing --cov=graphql_ws
after_success: codecov
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -7,6 +7,10 @@ Currently supports:
* [Gevent](https://github.com/graphql-python/graphql-ws#gevent)
* Sanic (uses [websockets](https://github.com/aaugustin/websockets/) library)

[![PyPI version](https://badge.fury.io/py/graphql-ws.svg)](https://badge.fury.io/py/graphql-ws)
[![TravisCI Build Status](https://travis-ci.org/graphql-python/graphql-ws.svg?branch=master)](https://travis-ci.org/graphql-python/graphql-ws)
[![codecov](https://codecov.io/gh/graphql-python/graphql-ws/branch/master/graph/badge.svg)](https://codecov.io/gh/graphql-python/graphql-ws)

# Installation instructions

For instaling graphql-ws, just run this command in your shell
@@ -167,8 +171,8 @@ class Subscription(graphene.ObjectType):


def resolve_count_seconds(
root,
info,
root,
info,
up_to=5
):
return Observable.interval(1000)\
@@ -202,4 +206,4 @@ from graphql_ws.django_channels import GraphQLSubscriptionConsumer
channel_routing = [
route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"),
]
```
```
49 changes: 18 additions & 31 deletions graphql_ws/aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
from inspect import isawaitable
from asyncio import ensure_future, wait, shield
from asyncio import ensure_future, shield, wait

from aiohttp import WSMsgType
from graphql.execution.executors.asyncio import AsyncioExecutor
from graphql import subscribe
from graphql.language import parse

from .base import (
ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer)
from .observable_aiter import setup_observable_extension

from .constants import (
GQL_CONNECTION_ACK,
GQL_CONNECTION_ERROR,
GQL_COMPLETE
BaseConnectionContext,
BaseSubscriptionServer,
ConnectionClosedException,
)

setup_observable_extension()
from .constants import GQL_COMPLETE, GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR


class AiohttpConnectionContext(BaseConnectionContext):
@@ -47,12 +42,6 @@ def __init__(self, schema, keep_alive=True, loop=None):
self.loop = loop
super().__init__(schema, keep_alive)

def get_graphql_params(self, *args, **kwargs):
params = super(AiohttpSubscriptionServer,
self).get_graphql_params(*args, **kwargs)
return dict(params, return_promise=True,
executor=AsyncioExecutor(loop=self.loop))

async def _handle(self, ws, request_context=None):
connection_context = AiohttpConnectionContext(ws, request_context)
await self.on_open(connection_context)
@@ -69,7 +58,8 @@ async def _handle(self, ws, request_context=None):
(_, pending) = await wait(pending, timeout=0, loop=self.loop)

task = ensure_future(
self.on_message(connection_context, message), loop=self.loop)
self.on_message(connection_context, message), loop=self.loop
)
pending.add(task)

self.on_close(connection_context)
@@ -99,23 +89,20 @@ async def on_connection_init(self, connection_context, op_id, payload):
await connection_context.close(1011)

async def on_start(self, connection_context, op_id, params):
execution_result = self.execute(
connection_context.request_context, params)

if isawaitable(execution_result):
execution_result = await execution_result
request_string = params.pop("request_string")
query = parse(request_string)
result = await subscribe(self.schema, query, **params)

if not hasattr(execution_result, '__aiter__'):
await self.send_execution_result(
connection_context, op_id, execution_result)
if not hasattr(result, "__aiter__"):
await self.send_execution_result(connection_context, op_id, result)
else:
iterator = await execution_result.__aiter__()
connection_context.register_operation(op_id, iterator)
async for single_result in iterator:
connection_context.register_operation(op_id, result)
async for single_result in result:
if not connection_context.has_operation(op_id):
break
await self.send_execution_result(
connection_context, op_id, single_result)
connection_context, op_id, single_result
)
await self.send_message(connection_context, op_id, GQL_COMPLETE)

async def on_stop(self, connection_context, op_id):
41 changes: 18 additions & 23 deletions graphql_ws/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import json
from collections import OrderedDict

from graphql import graphql, format_error
from graphql import format_error, graphql

from .constants import (
GQL_CONNECTION_ERROR,
GQL_CONNECTION_INIT,
GQL_CONNECTION_TERMINATE,
GQL_DATA,
GQL_ERROR,
GQL_START,
GQL_STOP,
GQL_ERROR,
GQL_CONNECTION_ERROR,
GQL_DATA
)


@@ -51,7 +51,6 @@ def close(self, code):


class BaseSubscriptionServer(object):

def __init__(self, schema, keep_alive=True):
self.schema = schema
self.keep_alive = keep_alive
@@ -92,7 +91,8 @@ def process_message(self, connection_context, parsed_message):
if not isinstance(params, dict):
error = Exception(
"Invalid params returned from get_graphql_params!"
" Return values must be a dict.")
" Return values must be a dict."
)
return self.send_error(connection_context, op_id, error)

# If we already have a subscription with this id, unsubscribe from
@@ -106,8 +106,11 @@ def process_message(self, connection_context, parsed_message):
return self.on_stop(connection_context, op_id)

else:
return self.send_error(connection_context, op_id, Exception(
"Invalid message type: {}.".format(op_type)))
return self.send_error(
connection_context,
op_id,
Exception("Invalid message type: {}.".format(op_type)),
)

def send_execution_result(self, connection_context, op_id, execution_result):
result = self.execution_result_to_dict(execution_result)
@@ -118,8 +121,9 @@ def execution_result_to_dict(self, execution_result):
if execution_result.data:
result['data'] = execution_result.data
if execution_result.errors:
result['errors'] = [format_error(error)
for error in execution_result.errors]
result['errors'] = [
format_error(error) for error in execution_result.errors
]
return result

def send_message(self, connection_context, op_id=None, op_type=None, payload=None):
@@ -144,16 +148,9 @@ def send_error(self, connection_context, op_id, error, error_type=None):
' GQL_CONNECTION_ERROR or GQL_ERROR'
)

error_payload = {
'message': str(error)
}
error_payload = {'message': str(error)}

return self.send_message(
connection_context,
op_id,
error_type,
error_payload
)
return self.send_message(connection_context, op_id, error_type, error_payload)

def unsubscribe(self, connection_context, op_id):
if connection_context.has_operation(op_id):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the connection_context.get_operation(op_id).dispose() seems to depend on the old rxpy-dependency that has been removed from graphql-core and is effectively removed in this branch by removing setup_observable_extension.

this causes an unhandled exception on unsubscribe

@@ -170,8 +167,7 @@ def on_connection_terminate(self, connection_context, op_id):
return connection_context.close(1011)

def execute(self, request_context, params):
return graphql(
self.schema, **dict(params, allow_subscriptions=True))
return graphql(self.schema, params)

def handle(self, ws, request_context=None):
raise NotImplementedError("handle method not implemented")
@@ -180,8 +176,7 @@ def on_message(self, connection_context, message):
try:
if not isinstance(message, dict):
parsed_message = json.loads(message)
assert isinstance(
parsed_message, dict), "Payload must be an object."
assert isinstance(parsed_message, dict), "Payload must be an object."
else:
parsed_message = message
except Exception as e:
5 changes: 4 additions & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -6,9 +6,12 @@ tox>=3,<4
coverage>=5.0,<6
Sphinx>=1.8,<2
PyYAML>=5.3,<6
pytest==3.2.5
pytest<=3.6,<6
pytest-runner>=5.2,<6
gevent>=1.1,<2
graphene>=2.0,<3
django>=1.5,<3
channels>=1.0,<2
pytest-aiohttp
asyncmock; python_version<"3.8"
pytest-cov
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@

"""The setup script."""

from setuptools import setup, find_packages
from setuptools import find_packages, setup

with open("README.rst") as readme_file:
readme = readme_file.read()
@@ -14,7 +14,7 @@
history = history_file.read()

requirements = [
"graphql-core>=2.0,<3",
"graphql-core>=3.0.0",
# TODO: put package requirements here
]

@@ -26,7 +26,8 @@

test_requirements = [
"pytest",
# TODO: put package test requirements here
"pytest-aiohttp",
'asyncmock; python_version<"3.8"',
]

setup(
@@ -57,7 +58,7 @@
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8"
"Programming Language :: Python :: 3.8",
],
test_suite="tests",
tests_require=test_requirements,
135 changes: 135 additions & 0 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import asyncio
import sys
from typing import Awaitable, Callable

import pytest
from aiohttp import WSMsgType
from aiohttp.client import ClientWebSocketResponse
from aiohttp.test_utils import TestClient
from aiohttp.web import Application, WebSocketResponse
from graphql import GraphQLSchema, build_schema
from graphql_ws.aiohttp import AiohttpSubscriptionServer

if sys.version_info >= (3, 8):
from unittest.mock import AsyncMock
else:
from asyncmock import AsyncMock


AiohttpClientFactory = Callable[[Application], Awaitable[TestClient]]


def schema() -> GraphQLSchema:
spec = """
type Query {
dummy: String
}
type Subscription {
messages: String
error: String
}
schema {
query: Query
subscription: Subscription
}
"""

async def messages_subscribe(root, _info):
await asyncio.sleep(0.1)
yield "foo"
await asyncio.sleep(0.1)
yield "bar"

async def error_subscribe(root, _info):
raise RuntimeError("baz")

schema = build_schema(spec)
schema.subscription_type.fields["messages"].subscribe = messages_subscribe
schema.subscription_type.fields["messages"].resolve = lambda evt, _info: evt
schema.subscription_type.fields["error"].subscribe = error_subscribe
schema.subscription_type.fields["error"].resolve = lambda evt, _info: evt
return schema


@pytest.fixture
def client(
loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClientFactory
) -> TestClient:
subscription_server = AiohttpSubscriptionServer(schema())

async def subscriptions(request):
conn = WebSocketResponse(protocols=('graphql-ws',))
await conn.prepare(request)
await subscription_server.handle(conn)
return conn

app = Application()
app["subscription_server"] = subscription_server
app.router.add_get('/subscriptions', subscriptions)
return loop.run_until_complete(aiohttp_client(app))


@pytest.fixture
async def connection(client: TestClient) -> ClientWebSocketResponse:
conn = await client.ws_connect("/subscriptions")
yield conn
await conn.close()


async def test_connection_closed_on_error(connection: ClientWebSocketResponse):
connection._writer.transport.write(b'0' * 500)
response = await connection.receive()
assert response.type == WSMsgType.CLOSE


async def test_connection_init(connection: ClientWebSocketResponse):
await connection.send_str('{"type":"connection_init","payload":{}}')
response = await connection.receive()
assert response.type == WSMsgType.TEXT
assert response.data == '{"type": "connection_ack"}'


async def test_connection_init_rejected_on_error(
monkeypatch, client: TestClient, connection: ClientWebSocketResponse
):
# raise exception in AiohttpSubscriptionServer.on_connect
monkeypatch.setattr(
client.app["subscription_server"],
"on_connect",
AsyncMock(side_effect=RuntimeError()),
)
await connection.send_str('{"type":"connection_init", "payload": {}}')
response = await connection.receive()
assert response.type == WSMsgType.TEXT
assert response.json()['type'] == 'connection_error'


async def test_messages_subscription(connection: ClientWebSocketResponse):
await connection.send_str('{"type":"connection_init","payload":{}}')
await connection.receive()
await connection.send_str(
'{"id":"1","type":"start","payload":{"query":"subscription MySub { messages }"}}'
)
first = await connection.receive_str()
assert (
first == '{"id": "1", "type": "data", "payload": {"data": {"messages": "foo"}}}'
)
second = await connection.receive_str()
assert (
second
== '{"id": "1", "type": "data", "payload": {"data": {"messages": "bar"}}}'
)
resolve_message = await connection.receive_str()
assert resolve_message == '{"id": "1", "type": "complete"}'


async def test_subscription_resolve_error(connection: ClientWebSocketResponse):
await connection.send_str('{"type":"connection_init","payload":{}}')
await connection.receive()
await connection.send_str(
'{"id":"2","type":"start","payload":{"query":"subscription MySub { error }"}}'
)
error = await connection.receive_json()
assert error["payload"]["errors"][0]["message"] == "baz"
11 changes: 2 additions & 9 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
[tox]
envlist = py27, py35, py36, py37, py38, flake8
envlist = py36, py37, py38, flake8

[travis]
python =
3.8: py38
3.7: py37
3.6: py36
3.5: py35
2.7: py27

[testenv:flake8]
basepython=python
@@ -21,12 +19,7 @@ deps =
-r{toxinidir}/requirements_dev.txt
commands =
pip install -U pip
pytest --basetemp={envtmpdir}

[testenv:py35]
deps =
-r{toxinidir}/requirements_dev.txt
aiohttp>=3.6,<4
pytest --basetemp={envtmpdir} {posargs}

[testenv:py36]
deps =