14
14
15
15
16
16
__version__ = pkg_resources .require ('asgi_ipc' )[0 ].version
17
+ MB = 1024 * 1024
17
18
18
19
19
20
class IPCChannelLayer (BaseChannelLayer ):
@@ -39,9 +40,9 @@ def __init__(self, prefix="asgi", expiry=60, group_expiry=86400, capacity=10, ch
39
40
)
40
41
self .thread_lock = threading .Lock ()
41
42
self .prefix = prefix
42
- self .channel_set = MemorySet ("/%s-channelset " % self .prefix )
43
+ self .channel_store = MemoryDict ("/%s-channel-dict " % self .prefix , size = 100 * MB )
43
44
# Set containing all groups to flush
44
- self .group_set = MemorySet ("/%s-groupset " % self .prefix )
45
+ self .group_store = MemoryDict ("/%s-group-dict " % self .prefix , size = 20 * MB )
45
46
46
47
### ASGI API ###
47
48
@@ -52,12 +53,13 @@ def send(self, channel, message):
52
53
assert isinstance (message , dict ), "message is not a dict"
53
54
assert self .valid_channel_name (channel ), "channel name not valid"
54
55
# Write message into the correct message queue
55
- channel_list = self ._channel_list (channel )
56
56
with self .thread_lock :
57
+ channel_list = self .channel_store .get (channel , [])
57
58
if len (channel_list ) >= self .get_capacity (channel ):
58
59
raise self .ChannelFull
59
60
else :
60
61
channel_list .append ([message , time .time () + self .expiry ])
62
+ self .channel_store [channel ] = channel_list
61
63
62
64
def receive_many (self , channels , block = False ):
63
65
if not channels :
@@ -68,16 +70,22 @@ def receive_many(self, channels, block=False):
68
70
# Try to pop off all of the named channels
69
71
with self .thread_lock :
70
72
for channel in channels :
71
- channel_list = self ._channel_list (channel )
73
+ channel_list = self .channel_store . get (channel , [] )
72
74
# Keep looping on the channel until we hit no messages or an unexpired one
73
75
while True :
74
76
try :
75
- message , expires = channel_list .popleft ()
77
+ # Popleft equivalent
78
+ message , expires = channel_list [0 ]
79
+ channel_list = channel_list [1 :]
80
+ self .channel_store [channel ] = channel_list
76
81
if expires <= time .time ():
77
82
continue
78
83
return channel , message
79
84
except IndexError :
80
85
break
86
+ # If the channel is now empty, delete its key
87
+ if not channel_list and channel in self .channel_store :
88
+ del self .channel_store [channel ]
81
89
return None , None
82
90
83
91
def new_channel (self , pattern ):
@@ -88,10 +96,11 @@ def new_channel(self, pattern):
88
96
assert pattern .endswith ("!" ) or pattern .endswith ("?" )
89
97
new_name = pattern + random_string
90
98
# To see if it's present we open the queue without O_CREAT
91
- if not MemoryList .exists (self ._channel_path (new_name )):
92
- return new_name
93
- else :
94
- continue
99
+ with self .thread_lock :
100
+ if new_name not in self .channel_store :
101
+ return new_name
102
+ else :
103
+ continue
95
104
96
105
### Groups extension ###
97
106
@@ -100,32 +109,38 @@ def group_add(self, group, channel):
100
109
Adds the channel to the named group
101
110
"""
102
111
assert self .valid_group_name (group ), "Invalid group name"
103
- group_dict = self ._group_dict (group )
104
112
with self .thread_lock :
113
+ group_dict = self .group_store .get (group , {})
105
114
group_dict [channel ] = time .time () + self .group_expiry
115
+ self .group_store [group ] = group_dict
106
116
107
117
def group_discard (self , group , channel ):
108
118
"""
109
119
Removes the channel from the named group if it is in the group;
110
120
does nothing otherwise (does not error)
111
121
"""
112
122
assert self .valid_group_name (group ), "Invalid group name"
113
- group_dict = self ._group_dict (group )
114
123
with self .thread_lock :
115
- group_dict .discard (channel )
124
+ group_dict = self .group_store .get (group , {})
125
+ if channel in group_dict :
126
+ del group_dict [channel ]
127
+ if not group_dict :
128
+ del self .group_store [group ]
129
+ else :
130
+ self .group_store [group ] = group_dict
116
131
117
132
def send_group (self , group , message ):
118
133
"""
119
134
Sends a message to the entire group.
120
135
"""
121
136
assert self .valid_group_name (group ), "Invalid group name"
122
- group_dict = self ._group_dict (group )
123
137
with self .thread_lock :
124
- items = list ( group_dict . items () )
125
- for channel , expires in items :
138
+ group_dict = self . group_store . get ( group , {} )
139
+ for channel , expires in list ( group_dict . items ()) :
126
140
if expires <= time .time ():
141
+ del group_dict [channel ]
127
142
with self .thread_lock :
128
- group_dict . discard ( channel )
143
+ self . group_store [ group ] = group_dict
129
144
else :
130
145
try :
131
146
self .send (channel , message )
@@ -139,34 +154,8 @@ def flush(self):
139
154
Deletes all messages and groups.
140
155
"""
141
156
with self .thread_lock :
142
- for path in self .channel_set :
143
- MemoryList (path ).flush ()
144
- for path in self .group_set :
145
- MemoryDict (path ).flush ()
146
-
147
- ### Internal functions ###
148
-
149
- def _channel_path (self , channel ):
150
- assert isinstance (channel , six .text_type )
151
- return "/%s-channel-%s" % (self .prefix , channel .encode ("ascii" ))
152
-
153
- def _group_path (self , group ):
154
- assert isinstance (group , six .text_type )
155
- return "/%s-group-%s" % (self .prefix , group .encode ("ascii" ))
156
-
157
- def _channel_list (self , channel ):
158
- """
159
- Returns a MemoryList object for the channel
160
- """
161
- self .channel_set .add (self ._channel_path (channel ))
162
- return MemoryList (self ._channel_path (channel ), size = 1024 * 1024 * self .capacity )
163
-
164
- def _group_dict (self , group ):
165
- """
166
- Returns a MemoryDict object for the named group
167
- """
168
- self .group_set .add (self ._group_path (group ))
169
- return MemoryDict (self ._group_path (group ), size = 1024 * 1024 * 10 )
157
+ self .channel_store .flush ()
158
+ self .group_store .flush ()
170
159
171
160
def __str__ (self ):
172
161
return "%s(hosts=%s)" % (self .__class__ .__name__ , self .hosts )
@@ -309,6 +298,11 @@ def __setitem__(self, key, value):
309
298
d [key ] = value
310
299
self ._set_value (d )
311
300
301
+ def __delitem__ (self , key ):
302
+ d = self ._get_value ()
303
+ del d [key ]
304
+ self ._set_value (d )
305
+
312
306
def __len__ (self ):
313
307
return len (self ._get_value ())
314
308
@@ -329,39 +323,11 @@ def keys(self):
329
323
def values (self ):
330
324
return self ._get_value ().values ()
331
325
326
+ def get (self , key , default ):
327
+ return self ._get_value ().get (key , default )
328
+
332
329
def discard (self , item ):
333
330
value = self ._get_value ()
334
331
if item in value :
335
332
del value [item ]
336
333
self ._set_value (value )
337
-
338
-
339
- class MemorySet (MemoryDict ):
340
- """
341
- Like MemoryDict but just presents a set interface (using dict keys)
342
- """
343
-
344
- def add (self , item ):
345
- value = self ._get_value ()
346
- value [item ] = None
347
- self ._set_value (value )
348
-
349
-
350
- class MemoryList (MemoryDatastructure ):
351
- """
352
- Memory-backed list. Used for channels.
353
- """
354
-
355
- signature = b"ASGL0001"
356
-
357
- datatype = list
358
-
359
- def append (self , item ):
360
- value = self ._get_value ()
361
- value .append (item )
362
- self ._set_value (value )
363
-
364
- def popleft (self ):
365
- value = self ._get_value ()
366
- self ._set_value (value [1 :])
367
- return value [0 ]
0 commit comments