Skip to content
This repository was archived by the owner on Nov 23, 2017. It is now read-only.

Commit eb4c3ff

Browse files
committed
Make loop methods reject socket kinds they do not support.
More specifically: * loop.create_connection() and loop.create_server() can accept AF_INET or AF_INET6 SOCK_STREAM sockets; * loop.create_datagram_endpoint() can accept only SOCK_DGRAM sockets; * loop.connect_accepted_socket() can accept only SOCK_STREAM sockets; * fixed a bug in create_unix_server() and create_unix_connection() to properly check for SOCK_STREAM sockets on Linux; * fixed static DNS resolution to decline socket types that aren't strictly equal to SOCK_STREAM or SOCK_DGRAM. On Linux socket type can be a bit mask, and we should let system getaddrinfo() to deal with it.
1 parent a3bb643 commit eb4c3ff

File tree

5 files changed

+137
-24
lines changed

5 files changed

+137
-24
lines changed

asyncio/base_events.py

+45-11
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,26 @@ def _set_reuseport(sock):
8484
'SO_REUSEPORT defined but not implemented.')
8585

8686

87-
# Linux's sock.type is a bitmask that can include extra info about socket.
88-
_SOCKET_TYPE_MASK = 0
89-
if hasattr(socket, 'SOCK_NONBLOCK'):
90-
_SOCKET_TYPE_MASK |= socket.SOCK_NONBLOCK
91-
if hasattr(socket, 'SOCK_CLOEXEC'):
92-
_SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC
87+
def _is_stream_socket(sock):
88+
# Linux's socket.type is a bitmask that can include extra info
89+
# about socket, therefore we can't do simple
90+
# `sock_type == socket.SOCK_STREAM`.
91+
return (sock.type & socket.SOCK_STREAM) == socket.SOCK_STREAM
92+
93+
94+
def _is_dgram_socket(sock):
95+
# Linux's socket.type is a bitmask that can include extra info
96+
# about socket, therefore we can't do simple
97+
# `sock_type == socket.SOCK_DGRAM`.
98+
return (sock.type & socket.SOCK_DGRAM) == socket.SOCK_DGRAM
99+
100+
101+
def _is_ip_socket(sock):
102+
if sock.family == socket.AF_INET:
103+
return True
104+
if hasattr(socket, 'AF_INET6') and sock.family == socket.AF_INET6:
105+
return True
106+
return False
93107

94108

95109
def _ipaddr_info(host, port, family, type, proto):
@@ -102,8 +116,12 @@ def _ipaddr_info(host, port, family, type, proto):
102116
host is None:
103117
return None
104118

105-
type &= ~_SOCKET_TYPE_MASK
106119
if type == socket.SOCK_STREAM:
120+
# Linux only:
121+
# getaddrinfo() can raise when socket.type is a bit mask.
122+
# So if socket.type is a bit mask of SOCK_STREAM, and say
123+
# SOCK_NONBLOCK, we simply return None, which will trigger
124+
# a call to getaddrinfo() letting it process this request.
107125
proto = socket.IPPROTO_TCP
108126
elif type == socket.SOCK_DGRAM:
109127
proto = socket.IPPROTO_UDP
@@ -124,7 +142,9 @@ def _ipaddr_info(host, port, family, type, proto):
124142
return None
125143

126144
if family == socket.AF_UNSPEC:
127-
afs = [socket.AF_INET, socket.AF_INET6]
145+
afs = [socket.AF_INET]
146+
if hasattr(socket, 'AF_INET6'):
147+
afs.append(socket.AF_INET6)
128148
else:
129149
afs = [family]
130150

@@ -771,9 +791,13 @@ def create_connection(self, protocol_factory, host=None, port=None, *,
771791
raise OSError('Multiple exceptions: {}'.format(
772792
', '.join(str(exc) for exc in exceptions)))
773793

774-
elif sock is None:
775-
raise ValueError(
776-
'host and port was not specified and no sock specified')
794+
else:
795+
if sock is None:
796+
raise ValueError(
797+
'host and port was not specified and no sock specified')
798+
if not _is_stream_socket(sock) or not _is_ip_socket(sock):
799+
raise ValueError(
800+
'A TCP Stream Socket was expected, got {!r}'.format(sock))
777801

778802
transport, protocol = yield from self._create_connection_transport(
779803
sock, protocol_factory, ssl, server_hostname)
@@ -817,6 +841,9 @@ def create_datagram_endpoint(self, protocol_factory,
817841
allow_broadcast=None, sock=None):
818842
"""Create datagram connection."""
819843
if sock is not None:
844+
if not _is_dgram_socket(sock):
845+
raise ValueError(
846+
'A UDP Socket was expected, got {!r}'.format(sock))
820847
if (local_addr or remote_addr or
821848
family or proto or flags or
822849
reuse_address or reuse_port or allow_broadcast):
@@ -1027,6 +1054,9 @@ def create_server(self, protocol_factory, host=None, port=None,
10271054
else:
10281055
if sock is None:
10291056
raise ValueError('Neither host/port nor sock were specified')
1057+
if not _is_stream_socket(sock) or not _is_ip_socket(sock):
1058+
raise ValueError(
1059+
'A TCP Stream Socket was expected, got {!r}'.format(sock))
10301060
sockets = [sock]
10311061

10321062
server = Server(self, sockets)
@@ -1048,6 +1078,10 @@ def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None):
10481078
This method is a coroutine. When completed, the coroutine
10491079
returns a (transport, protocol) pair.
10501080
"""
1081+
if not _is_stream_socket(sock):
1082+
raise ValueError(
1083+
'A Stream Socket was expected, got {!r}'.format(sock))
1084+
10511085
transport, protocol = yield from self._create_connection_transport(
10521086
sock, protocol_factory, ssl, '', server_side=True)
10531087
if self._debug:

asyncio/unix_events.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def create_unix_connection(self, protocol_factory, path, *,
235235
if sock is None:
236236
raise ValueError('no path and sock were specified')
237237
if (sock.family != socket.AF_UNIX or
238-
sock.type != socket.SOCK_STREAM):
238+
not base_events._is_stream_socket(sock)):
239239
raise ValueError(
240240
'A UNIX Domain Stream Socket was expected, got {!r}'
241241
.format(sock))
@@ -289,7 +289,7 @@ def create_unix_server(self, protocol_factory, path=None, *,
289289
'path was not specified, and no sock specified')
290290

291291
if (sock.family != socket.AF_UNIX or
292-
sock.type != socket.SOCK_STREAM):
292+
not base_events._is_stream_socket(sock)):
293293
raise ValueError(
294294
'A UNIX Domain Stream Socket was expected, got {!r}'
295295
.format(sock))

tests/test_base_events.py

+55-8
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,13 @@ def test_ipaddr_info(self):
116116
self.assertIsNone(
117117
base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP))
118118

119+
if hasattr(socket, 'SOCK_NONBLOCK'):
120+
self.assertEqual(
121+
None,
122+
base_events._ipaddr_info(
123+
'1.2.3.4', 1, INET, STREAM | socket.SOCK_NONBLOCK, TCP))
124+
125+
119126
def test_port_parameter_types(self):
120127
# Test obscure kinds of arguments for "port".
121128
INET = socket.AF_INET
@@ -1040,6 +1047,43 @@ def test_create_connection_host_port_sock(self):
10401047
MyProto, 'example.com', 80, sock=object())
10411048
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
10421049

1050+
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets')
1051+
def test_create_connection_wrong_sock(self):
1052+
sock = socket.socket(socket.AF_UNIX)
1053+
with sock:
1054+
coro = self.loop.create_connection(MyProto, sock=sock)
1055+
with self.assertRaisesRegex(ValueError,
1056+
'A TCP Stream Socket was expected'):
1057+
self.loop.run_until_complete(coro)
1058+
1059+
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets')
1060+
def test_create_server_wrong_sock(self):
1061+
sock = socket.socket(socket.AF_UNIX)
1062+
with sock:
1063+
coro = self.loop.create_server(MyProto, sock=sock)
1064+
with self.assertRaisesRegex(ValueError,
1065+
'A TCP Stream Socket was expected'):
1066+
self.loop.run_until_complete(coro)
1067+
1068+
@unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
1069+
'no socket.SOCK_NONBLOCK (linux only)')
1070+
def test_create_server_stream_bittype(self):
1071+
sock = socket.socket(
1072+
socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
1073+
with sock:
1074+
coro = self.loop.create_server(lambda: None, sock=sock)
1075+
srv = self.loop.run_until_complete(coro)
1076+
srv.close()
1077+
self.loop.run_until_complete(srv.wait_closed())
1078+
1079+
def test_create_datagram_endpoint_wrong_sock(self):
1080+
sock = socket.socket(socket.AF_INET)
1081+
with sock:
1082+
coro = self.loop.create_datagram_endpoint(MyProto, sock=sock)
1083+
with self.assertRaisesRegex(ValueError,
1084+
'A UDP Socket was expected'):
1085+
self.loop.run_until_complete(coro)
1086+
10431087
def test_create_connection_no_host_port_sock(self):
10441088
coro = self.loop.create_connection(MyProto)
10451089
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
@@ -1487,36 +1531,39 @@ def test_create_datagram_endpoint_sock(self):
14871531
self.assertEqual('CLOSED', protocol.state)
14881532

14891533
def test_create_datagram_endpoint_sock_sockopts(self):
1534+
class FakeSock:
1535+
type = socket.SOCK_DGRAM
1536+
14901537
fut = self.loop.create_datagram_endpoint(
1491-
MyDatagramProto, local_addr=('127.0.0.1', 0), sock=object())
1538+
MyDatagramProto, local_addr=('127.0.0.1', 0), sock=FakeSock())
14921539
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
14931540

14941541
fut = self.loop.create_datagram_endpoint(
1495-
MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=object())
1542+
MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=FakeSock())
14961543
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
14971544

14981545
fut = self.loop.create_datagram_endpoint(
1499-
MyDatagramProto, family=1, sock=object())
1546+
MyDatagramProto, family=1, sock=FakeSock())
15001547
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
15011548

15021549
fut = self.loop.create_datagram_endpoint(
1503-
MyDatagramProto, proto=1, sock=object())
1550+
MyDatagramProto, proto=1, sock=FakeSock())
15041551
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
15051552

15061553
fut = self.loop.create_datagram_endpoint(
1507-
MyDatagramProto, flags=1, sock=object())
1554+
MyDatagramProto, flags=1, sock=FakeSock())
15081555
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
15091556

15101557
fut = self.loop.create_datagram_endpoint(
1511-
MyDatagramProto, reuse_address=True, sock=object())
1558+
MyDatagramProto, reuse_address=True, sock=FakeSock())
15121559
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
15131560

15141561
fut = self.loop.create_datagram_endpoint(
1515-
MyDatagramProto, reuse_port=True, sock=object())
1562+
MyDatagramProto, reuse_port=True, sock=FakeSock())
15161563
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
15171564

15181565
fut = self.loop.create_datagram_endpoint(
1519-
MyDatagramProto, allow_broadcast=True, sock=object())
1566+
MyDatagramProto, allow_broadcast=True, sock=FakeSock())
15201567
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
15211568

15221569
def test_create_datagram_endpoint_sockopts(self):

tests/test_events.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -791,9 +791,9 @@ def client():
791791
conn, _ = lsock.accept()
792792
proto = MyProto(loop=loop)
793793
proto.loop = loop
794-
f = loop.create_task(
794+
loop.run_until_complete(
795795
loop.connect_accepted_socket(
796-
(lambda : proto), conn, ssl=server_ssl))
796+
(lambda: proto), conn, ssl=server_ssl))
797797
loop.run_forever()
798798
proto.transport.close()
799799
lsock.close()
@@ -1377,6 +1377,11 @@ def datagram_received(self, data, addr):
13771377
server.transport.close()
13781378

13791379
def test_create_datagram_endpoint_sock(self):
1380+
if (sys.platform == 'win32' and
1381+
isinstance(self.loop, proactor_events.BaseProactorEventLoop)):
1382+
raise unittest.SkipTest(
1383+
'UDP is not supported with proactor event loops')
1384+
13801385
sock = None
13811386
local_address = ('127.0.0.1', 0)
13821387
infos = self.loop.run_until_complete(
@@ -1394,7 +1399,7 @@ def test_create_datagram_endpoint_sock(self):
13941399
else:
13951400
assert False, 'Can not create socket.'
13961401

1397-
f = self.loop.create_connection(
1402+
f = self.loop.create_datagram_endpoint(
13981403
lambda: MyDatagramProto(loop=self.loop), sock=sock)
13991404
tr, pr = self.loop.run_until_complete(f)
14001405
self.assertIsInstance(tr, asyncio.Transport)

tests/test_unix_events.py

+27
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,33 @@ def test_create_unix_server_path_inetsock(self):
280280
'A UNIX Domain Stream.*was expected'):
281281
self.loop.run_until_complete(coro)
282282

283+
def test_create_unix_server_path_dgram(self):
284+
sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
285+
with sock:
286+
coro = self.loop.create_unix_server(lambda: None, path=None,
287+
sock=sock)
288+
with self.assertRaisesRegex(ValueError,
289+
'A UNIX Domain Stream.*was expected'):
290+
self.loop.run_until_complete(coro)
291+
292+
@unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
293+
'no socket.SOCK_NONBLOCK (linux only)')
294+
def test_create_unix_server_path_stream_bittype(self):
295+
sock = socket.socket(
296+
socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
297+
with tempfile.NamedTemporaryFile() as file:
298+
fn = file.name
299+
try:
300+
with sock:
301+
sock.bind(fn)
302+
coro = self.loop.create_unix_server(lambda: None, path=None,
303+
sock=sock)
304+
srv = self.loop.run_until_complete(coro)
305+
srv.close()
306+
self.loop.run_until_complete(srv.wait_closed())
307+
finally:
308+
os.unlink(fn)
309+
283310
def test_create_unix_connection_path_inetsock(self):
284311
sock = socket.socket()
285312
with sock:

0 commit comments

Comments
 (0)