Skip to content

Refactor code base to add ability for multiple concurrency models / executors #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion graphql_subscriptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@
from .subscription_transport_ws import SubscriptionServer

__all__ = ['RedisPubsub', 'SubscriptionManager', 'SubscriptionServer']

Empty file.
82 changes: 82 additions & 0 deletions graphql_subscriptions/executors/asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import absolute_import

import asyncio
from websockets import ConnectionClosed

try:
from asyncio import ensure_future
except ImportError:
# ensure_future is only implemented in Python 3.4.4+
# Reference: https://github.com/graphql-python/graphql-core/blob/master/graphql/execution/executors/asyncio.py
def ensure_future(coro_or_future, loop=None):
"""Wrap a coroutine or an awaitable in a future.
If the argument is a Future, it is returned directly.
"""
if isinstance(coro_or_future, asyncio.Future):
if loop is not None and loop is not coro_or_future._loop:
raise ValueError('loop argument must agree with Future')
return coro_or_future
elif asyncio.iscoroutine(coro_or_future):
if loop is None:
loop = asyncio.get_event_loop()
task = loop.create_task(coro_or_future)
if task._source_traceback:
del task._source_traceback[-1]
return task
else:
raise TypeError(
'A Future, a coroutine or an awaitable is required')


class AsyncioExecutor(object):
error = ConnectionClosed
task_cancel_error = asyncio.CancelledError

def __init__(self, loop=None):
if loop is None:
loop = asyncio.get_event_loop()
self.loop = loop
self.futures = []

def ws_close(self, code):
return self.ws.close(code)

def ws_protocol(self):
return self.ws.subprotocol

def ws_isopen(self):
if self.ws.open:
return True
else:
return False

def ws_send(self, msg):
return self.ws.send(msg)

def ws_recv(self):
return self.ws.recv()

def sleep(self, time):
if self.loop.is_running():
return asyncio.sleep(time)
return self.loop.run_until_complete(asyncio.sleep(time))

@staticmethod
def kill(future):
future.cancel()

def join(self, future=None, timeout=None):
if not isinstance(future, asyncio.Future):
return
if self.loop.is_running():
return asyncio.wait_for(future, timeout=timeout)
return self.loop.run_until_complete(
asyncio.wait_for(future, timeout=timeout))

def execute(self, fn, *args, **kwargs):
result = fn(*args, **kwargs)
if isinstance(result, asyncio.Future) or asyncio.iscoroutine(result):
future = ensure_future(result, loop=self.loop)
self.futures.append(future)
return future
return result
Empty file.
52 changes: 52 additions & 0 deletions graphql_subscriptions/executors/gevent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import absolute_import

from geventwebsocket.exceptions import WebSocketError
import gevent


class GeventExecutor(object):
# used to patch socket library so it doesn't block
socket = gevent.socket
error = WebSocketError

def __init__(self):
self.greenlets = []

def ws_close(self, code):
self.ws.close(code)

def ws_protocol(self):
return self.ws.protocol

def ws_isopen(self):
if self.ws.closed:
return False
else:
return True

def ws_send(self, msg, **kwargs):
self.ws.send(msg, **kwargs)

def ws_recv(self):
return self.ws.receive()

@staticmethod
def sleep(time):
gevent.sleep(time)

@staticmethod
def kill(greenlet):
gevent.kill(greenlet)

@staticmethod
def join(greenlet, timeout=None):
greenlet.join(timeout)

def join_all(self):
gevent.joinall(self.greenlets)
self.greenlets = []

def execute(self, fn, *args, **kwargs):
greenlet = gevent.spawn(fn, *args, **kwargs)
self.greenlets.append(greenlet)
return greenlet
4 changes: 4 additions & 0 deletions graphql_subscriptions/subscription_manager/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .manager import SubscriptionManager
from .pubsub import RedisPubsub

__all__ = ['SubscriptionManager', 'RedisPubsub']
Original file line number Diff line number Diff line change
Expand Up @@ -2,71 +2,15 @@
standard_library.install_aliases()
from builtins import object
from types import FunctionType
import pickle

from graphql import parse, validate, specified_rules, value_from_ast, execute
from graphql.language.ast import OperationDefinition
from promise import Promise
import gevent
import redis

from .utils import to_snake_case
from .validation import SubscriptionHasSingleRootField


class RedisPubsub(object):
def __init__(self, host='localhost', port=6379, *args, **kwargs):
redis.connection.socket = gevent.socket
self.redis = redis.StrictRedis(host, port, *args, **kwargs)
self.pubsub = self.redis.pubsub()
self.subscriptions = {}
self.sub_id_counter = 0
self.greenlet = None

def publish(self, trigger_name, message):
self.redis.publish(trigger_name, pickle.dumps(message))
return True

def subscribe(self, trigger_name, on_message_handler, options):
self.sub_id_counter += 1
try:
if trigger_name not in list(self.subscriptions.values())[0]:
self.pubsub.subscribe(trigger_name)
except IndexError:
self.pubsub.subscribe(trigger_name)
self.subscriptions[self.sub_id_counter] = [
trigger_name, on_message_handler
]
if not self.greenlet:
self.greenlet = gevent.spawn(self.wait_and_get_message)
return Promise.resolve(self.sub_id_counter)

def unsubscribe(self, sub_id):
trigger_name, on_message_handler = self.subscriptions[sub_id]
del self.subscriptions[sub_id]
try:
if trigger_name not in list(self.subscriptions.values())[0]:
self.pubsub.unsubscribe(trigger_name)
except IndexError:
self.pubsub.unsubscribe(trigger_name)
if not self.subscriptions:
self.greenlet = self.greenlet.kill()

def wait_and_get_message(self):
while True:
message = self.pubsub.get_message(ignore_subscribe_messages=True)
if message:
self.handle_message(message)
gevent.sleep(.001)

def handle_message(self, message):
if isinstance(message['channel'], bytes):
channel = message['channel'].decode()
for sub_id, trigger_map in self.subscriptions.items():
if trigger_map[0] == channel:
trigger_map[1](pickle.loads(message['data']))


class ValidationError(Exception):
def __init__(self, errors):
self.errors = errors
Expand All @@ -79,7 +23,7 @@ def __init__(self, schema, pubsub, setup_funcs={}):
self.pubsub = pubsub
self.setup_funcs = setup_funcs
self.subscriptions = {}
self.max_subscription_id = 0
self.max_subscription_id = 1

def publish(self, trigger_name, payload):
self.pubsub.publish(trigger_name, payload)
Expand Down Expand Up @@ -145,11 +89,6 @@ def subscribe(self, query, operation_name, callback, variables, context,
except AttributeError:
channel_options = {}

# TODO: Think about this some more...the Apollo library
# let's all messages through by default, even if
# the users incorrectly uses the setup_funcs (does not
# use 'filter' or 'channel_options' keys); I think it
# would be better to raise an exception here
def filter(arg1, arg2):
return True

Expand Down Expand Up @@ -181,7 +120,8 @@ def context_do_execute_handler(result):
subscription_promises.append(
self.pubsub.
subscribe(trigger_name, on_message, channel_options).then(
lambda id: self.subscriptions[external_subscription_id].append(id)
lambda id: self.subscriptions[external_subscription_id].
append(id)
))

return Promise.all(subscription_promises).then(
Expand Down
113 changes: 113 additions & 0 deletions graphql_subscriptions/subscription_manager/pubsub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from future import standard_library
standard_library.install_aliases()
from builtins import object
import pickle
import sys

from promise import Promise
import redis

from ..executors.gevent import GeventExecutor
from ..executors.asyncio import AsyncioExecutor

PY3 = sys.version_info[0] == 3


class RedisPubsub(object):
def __init__(self,
host='localhost',
port=6379,
executor=GeventExecutor,
*args,
**kwargs):

if executor == AsyncioExecutor:
try:
import aredis
except:
raise ImportError(
'You need the redis client "aredis" installed for use w/ '
'asyncio')

redis_client = aredis
else:
redis_client = redis

# patch redis socket library so it doesn't block if using gevent
if executor == GeventExecutor:
redis_client.connection.socket = executor.socket

self.redis = redis_client.StrictRedis(host, port, *args, **kwargs)
self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True)

self.executor = executor()
self.backgrd_task = None

self.subscriptions = {}
self.sub_id_counter = 0

def publish(self, trigger_name, message):
self.executor.execute(self.redis.publish, trigger_name,
pickle.dumps(message))
return True

def subscribe(self, trigger_name, on_message_handler, options):
self.sub_id_counter += 1

self.subscriptions[self.sub_id_counter] = [
trigger_name, on_message_handler]

if PY3:
trigger_name = trigger_name.encode()

if trigger_name not in list(self.pubsub.channels.keys()):
self.executor.join(self.executor.execute(self.pubsub.subscribe,
trigger_name))
if not self.backgrd_task:
self.backgrd_task = self.executor.execute(
self.wait_and_get_message)

return Promise.resolve(self.sub_id_counter)

def unsubscribe(self, sub_id):
trigger_name, on_message_handler = self.subscriptions[sub_id]
del self.subscriptions[sub_id]

if PY3:
trigger_name = trigger_name.encode()

if trigger_name not in list(self.pubsub.channels.keys()):
self.executor.execute(self.pubsub.unsubscribe, trigger_name)

if not self.subscriptions:
self.backgrd_task = self.executor.kill(self.backgrd_task)

async def _wait_and_get_message_async(self):
try:
while True:
message = await self.pubsub.get_message()
if message:
self.handle_message(message)
await self.executor.sleep(.001)
except self.executor.task_cancel_error:
return

def _wait_and_get_message_sync(self):
while True:
message = self.pubsub.get_message()
if message:
self.handle_message(message)
self.executor.sleep(.001)

def wait_and_get_message(self):
if hasattr(self.executor, 'loop'):
return self._wait_and_get_message_async()
return self._wait_and_get_message_sync()

def handle_message(self, message):

channel = message['channel'].decode() if PY3 else message['channel']

for sub_id, trigger_map in self.subscriptions.items():
if trigger_map[0] == channel:
trigger_map[1](pickle.loads(message['data']))
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

FIELD = 'Field'

# XXX from Apollo pacakge: Temporarily use this validation
# rule to make our life a bit easier.
# Temporarily use this validation rule to make our life a bit easier.


class SubscriptionHasSingleRootField(ValidationRule):
Expand All @@ -27,8 +26,8 @@ def enter_OperationDefinition(self, node, key, parent, path, ancestors):
else:
self.context.report_error(
GraphQLError(
'Apollo subscriptions do not support fragments on\
the root field', [node]))
'Subscriptions do not support fragments on '
'the root field', [node]))
if num_fields > 1:
self.context.report_error(
GraphQLError(
Expand All @@ -38,5 +37,5 @@ def enter_OperationDefinition(self, node, key, parent, path, ancestors):

@staticmethod
def too_many_subscription_fields_error(subscription_name):
return 'Subscription "{0}" must have only one\
field.'.format(subscription_name)
return ('Subscription "{0}" must have only one '
'field.'.format(subscription_name))
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .server import SubscriptionServer
Loading