@@ -50,7 +50,7 @@ def __init__(self, prefix="asgi", expiry=60, group_expiry=86400, capacity=10, ch
50
50
def send (self , channel , message ):
51
51
# Typecheck
52
52
assert isinstance (message , dict ), "message is not a dict"
53
- assert isinstance (channel , six . text_type ), "%s is not unicode" % channel
53
+ assert self . valid_channel_name (channel ), "channel name not valid"
54
54
# Write message into the correct message queue
55
55
channel_list = self ._channel_list (channel )
56
56
with self .thread_lock :
@@ -63,7 +63,7 @@ def receive_many(self, channels, block=False):
63
63
if not channels :
64
64
return None , None
65
65
channels = list (channels )
66
- assert all (isinstance (channel , six . text_type ) for channel in channels )
66
+ assert all (self . valid_channel_name (channel ) for channel in channels ), "one or more channel names invalid"
67
67
random .shuffle (channels )
68
68
# Try to pop off all of the named channels
69
69
with self .thread_lock :
@@ -85,7 +85,7 @@ def new_channel(self, pattern):
85
85
# Keep making channel names till one isn't present.
86
86
while True :
87
87
random_string = "" .join (random .choice (string .ascii_letters ) for i in range (12 ))
88
- assert pattern .endswith ("!" )
88
+ assert pattern .endswith ("!" ) or pattern . endswith ( "?" )
89
89
new_name = pattern + random_string
90
90
# To see if it's present we open the queue without O_CREAT
91
91
if not MemoryList .exists (self ._channel_path (new_name )):
@@ -99,6 +99,7 @@ def group_add(self, group, channel):
99
99
"""
100
100
Adds the channel to the named group
101
101
"""
102
+ assert self .valid_group_name (group ), "Invalid group name"
102
103
group_dict = self ._group_dict (group )
103
104
with self .thread_lock :
104
105
group_dict [channel ] = time .time () + self .group_expiry
@@ -108,6 +109,7 @@ def group_discard(self, group, channel):
108
109
Removes the channel from the named group if it is in the group;
109
110
does nothing otherwise (does not error)
110
111
"""
112
+ assert self .valid_group_name (group ), "Invalid group name"
111
113
group_dict = self ._group_dict (group )
112
114
with self .thread_lock :
113
115
group_dict .discard (channel )
@@ -116,6 +118,7 @@ def send_group(self, group, message):
116
118
"""
117
119
Sends a message to the entire group.
118
120
"""
121
+ assert self .valid_group_name (group ), "Invalid group name"
119
122
group_dict = self ._group_dict (group )
120
123
with self .thread_lock :
121
124
items = list (group_dict .items ())
0 commit comments