Skip to content
This repository was archived by the owner on Mar 5, 2020. It is now read-only.

Commit 7f3b80a

Browse files
committed
Rewrite IPC layer to just use two shared memory segments
1 parent f1f4fbf commit 7f3b80a

File tree

2 files changed

+65
-76
lines changed

2 files changed

+65
-76
lines changed

asgi_ipc.py

Lines changed: 41 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
__version__ = pkg_resources.require('asgi_ipc')[0].version
17+
MB = 1024 * 1024
1718

1819

1920
class IPCChannelLayer(BaseChannelLayer):
@@ -39,9 +40,9 @@ def __init__(self, prefix="asgi", expiry=60, group_expiry=86400, capacity=10, ch
3940
)
4041
self.thread_lock = threading.Lock()
4142
self.prefix = prefix
42-
self.channel_set = MemorySet("/%s-channelset" % self.prefix)
43+
self.channel_store = MemoryDict("/%s-channel-dict" % self.prefix, size=100 * MB)
4344
# Set containing all groups to flush
44-
self.group_set = MemorySet("/%s-groupset" % self.prefix)
45+
self.group_store = MemoryDict("/%s-group-dict" % self.prefix, size=20 * MB)
4546

4647
### ASGI API ###
4748

@@ -52,12 +53,13 @@ def send(self, channel, message):
5253
assert isinstance(message, dict), "message is not a dict"
5354
assert self.valid_channel_name(channel), "channel name not valid"
5455
# Write message into the correct message queue
55-
channel_list = self._channel_list(channel)
5656
with self.thread_lock:
57+
channel_list = self.channel_store.get(channel, [])
5758
if len(channel_list) >= self.get_capacity(channel):
5859
raise self.ChannelFull
5960
else:
6061
channel_list.append([message, time.time() + self.expiry])
62+
self.channel_store[channel] = channel_list
6163

6264
def receive_many(self, channels, block=False):
6365
if not channels:
@@ -68,16 +70,22 @@ def receive_many(self, channels, block=False):
6870
# Try to pop off all of the named channels
6971
with self.thread_lock:
7072
for channel in channels:
71-
channel_list = self._channel_list(channel)
73+
channel_list = self.channel_store.get(channel, [])
7274
# Keep looping on the channel until we hit no messages or an unexpired one
7375
while True:
7476
try:
75-
message, expires = channel_list.popleft()
77+
# Popleft equivalent
78+
message, expires = channel_list[0]
79+
channel_list = channel_list[1:]
80+
self.channel_store[channel] = channel_list
7681
if expires <= time.time():
7782
continue
7883
return channel, message
7984
except IndexError:
8085
break
86+
# If the channel is now empty, delete its key
87+
if not channel_list and channel in self.channel_store:
88+
del self.channel_store[channel]
8189
return None, None
8290

8391
def new_channel(self, pattern):
@@ -88,10 +96,11 @@ def new_channel(self, pattern):
8896
assert pattern.endswith("!") or pattern.endswith("?")
8997
new_name = pattern + random_string
9098
# To see if it's present we open the queue without O_CREAT
91-
if not MemoryList.exists(self._channel_path(new_name)):
92-
return new_name
93-
else:
94-
continue
99+
with self.thread_lock:
100+
if new_name not in self.channel_store:
101+
return new_name
102+
else:
103+
continue
95104

96105
### Groups extension ###
97106

@@ -100,32 +109,38 @@ def group_add(self, group, channel):
100109
Adds the channel to the named group
101110
"""
102111
assert self.valid_group_name(group), "Invalid group name"
103-
group_dict = self._group_dict(group)
104112
with self.thread_lock:
113+
group_dict = self.group_store.get(group, {})
105114
group_dict[channel] = time.time() + self.group_expiry
115+
self.group_store[group] = group_dict
106116

107117
def group_discard(self, group, channel):
108118
"""
109119
Removes the channel from the named group if it is in the group;
110120
does nothing otherwise (does not error)
111121
"""
112122
assert self.valid_group_name(group), "Invalid group name"
113-
group_dict = self._group_dict(group)
114123
with self.thread_lock:
115-
group_dict.discard(channel)
124+
group_dict = self.group_store.get(group, {})
125+
if channel in group_dict:
126+
del group_dict[channel]
127+
if not group_dict:
128+
del self.group_store[group]
129+
else:
130+
self.group_store[group] = group_dict
116131

117132
def send_group(self, group, message):
118133
"""
119134
Sends a message to the entire group.
120135
"""
121136
assert self.valid_group_name(group), "Invalid group name"
122-
group_dict = self._group_dict(group)
123137
with self.thread_lock:
124-
items = list(group_dict.items())
125-
for channel, expires in items:
138+
group_dict = self.group_store.get(group, {})
139+
for channel, expires in list(group_dict.items()):
126140
if expires <= time.time():
141+
del group_dict[channel]
127142
with self.thread_lock:
128-
group_dict.discard(channel)
143+
self.group_store[group] = group_dict
129144
else:
130145
try:
131146
self.send(channel, message)
@@ -139,34 +154,8 @@ def flush(self):
139154
Deletes all messages and groups.
140155
"""
141156
with self.thread_lock:
142-
for path in self.channel_set:
143-
MemoryList(path).flush()
144-
for path in self.group_set:
145-
MemoryDict(path).flush()
146-
147-
### Internal functions ###
148-
149-
def _channel_path(self, channel):
150-
assert isinstance(channel, six.text_type)
151-
return "/%s-channel-%s" % (self.prefix, channel.encode("ascii"))
152-
153-
def _group_path(self, group):
154-
assert isinstance(group, six.text_type)
155-
return "/%s-group-%s" % (self.prefix, group.encode("ascii"))
156-
157-
def _channel_list(self, channel):
158-
"""
159-
Returns a MemoryList object for the channel
160-
"""
161-
self.channel_set.add(self._channel_path(channel))
162-
return MemoryList(self._channel_path(channel), size=1024*1024*self.capacity)
163-
164-
def _group_dict(self, group):
165-
"""
166-
Returns a MemoryDict object for the named group
167-
"""
168-
self.group_set.add(self._group_path(group))
169-
return MemoryDict(self._group_path(group), size=1024*1024*10)
157+
self.channel_store.flush()
158+
self.group_store.flush()
170159

171160
def __str__(self):
172161
return "%s(hosts=%s)" % (self.__class__.__name__, self.hosts)
@@ -309,6 +298,11 @@ def __setitem__(self, key, value):
309298
d[key] = value
310299
self._set_value(d)
311300

301+
def __delitem__(self, key):
302+
d = self._get_value()
303+
del d[key]
304+
self._set_value(d)
305+
312306
def __len__(self):
313307
return len(self._get_value())
314308

@@ -329,39 +323,11 @@ def keys(self):
329323
def values(self):
330324
return self._get_value().values()
331325

326+
def get(self, key, default):
327+
return self._get_value().get(key, default)
328+
332329
def discard(self, item):
333330
value = self._get_value()
334331
if item in value:
335332
del value[item]
336333
self._set_value(value)
337-
338-
339-
class MemorySet(MemoryDict):
340-
"""
341-
Like MemoryDict but just presents a set interface (using dict keys)
342-
"""
343-
344-
def add(self, item):
345-
value = self._get_value()
346-
value[item] = None
347-
self._set_value(value)
348-
349-
350-
class MemoryList(MemoryDatastructure):
351-
"""
352-
Memory-backed list. Used for channels.
353-
"""
354-
355-
signature = b"ASGL0001"
356-
357-
datatype = list
358-
359-
def append(self, item):
360-
value = self._get_value()
361-
value.append(item)
362-
self._set_value(value)
363-
364-
def popleft(self):
365-
value = self._get_value()
366-
self._set_value(value[1:])
367-
return value[0]

test_asgi_ipc.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import unicode_literals
22

3-
from asgi_ipc import IPCChannelLayer
3+
import unittest
4+
from asgi_ipc import IPCChannelLayer, MemoryDict
45
from asgiref.conformance import ConformanceTestCase
56

67

@@ -10,3 +11,25 @@ class IPCLayerTests(ConformanceTestCase):
1011
channel_layer = IPCChannelLayer(expiry=1, group_expiry=2, capacity=5)
1112
expiry_delay = 1.1
1213
capacity_limit = 5
14+
15+
16+
# MemoryDict unit tests
17+
class MemoryDictTests(unittest.TestCase):
18+
19+
def setUp(self):
20+
self.instance = MemoryDict("/test-md")
21+
self.instance.flush()
22+
23+
def test_item_access(self):
24+
# Make sure the key is not there to start
25+
with self.assertRaises(KeyError):
26+
self.instance["test"]
27+
# Set it and check it twice
28+
self.instance["test"] = "foo"
29+
self.assertEqual(self.instance["test"], "foo")
30+
self.instance["test"] = "bar"
31+
self.assertEqual(self.instance["test"], "bar")
32+
# Delete it and make sure it's gone
33+
del self.instance["test"]
34+
with self.assertRaises(KeyError):
35+
self.instance["test"]

0 commit comments

Comments
 (0)