diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 648b9b9b..640fdc6f 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -87,6 +87,12 @@ def _set_reuseport(sock): 'SO_REUSEPORT defined but not implemented.') +def _copy_and_detach_socket(sock): + fd = sock.detach() + new_sock = socket.socket(sock.family, sock.type, sock.proto, fd) + return new_sock + + # Linux's sock.type is a bitmask that can include extra info about socket. _SOCKET_TYPE_MASK = 0 if hasattr(socket, 'SOCK_NONBLOCK'): @@ -768,9 +774,11 @@ def create_connection(self, protocol_factory, host=None, port=None, *, raise OSError('Multiple exceptions: {}'.format( ', '.join(str(exc) for exc in exceptions))) - elif sock is None: - raise ValueError( - 'host and port was not specified and no sock specified') + else: + if sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + sock = _copy_and_detach_socket(sock) transport, protocol = yield from self._create_connection_transport( sock, protocol_factory, ssl, server_hostname) @@ -828,6 +836,7 @@ def create_datagram_endpoint(self, protocol_factory, 'socket modifier keyword arguments can not be used ' 'when sock is specified. ({})'.format(problems)) sock.setblocking(False) + sock = _copy_and_detach_socket(sock) r_addr = None else: if not (local_addr or remote_addr): @@ -1024,6 +1033,7 @@ def create_server(self, protocol_factory, host=None, port=None, else: if sock is None: raise ValueError('Neither host/port nor sock were specified') + sock = _copy_and_detach_socket(sock) sockets = [sock] server = Server(self, sockets) @@ -1045,6 +1055,7 @@ def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None): This method is a coroutine. When completed, the coroutine returns a (transport, protocol) pair. """ + transport, protocol = yield from self._create_connection_transport( sock, protocol_factory, ssl, '', server_side=True) if self._debug: diff --git a/tests/test_events.py b/tests/test_events.py index 7df926f1..40d71102 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1240,11 +1240,13 @@ def connection_made(self, transport): sock_ob = socket.socket(type=socket.SOCK_STREAM) sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock_ob.bind(('0.0.0.0', 0)) + sock_ob_fd = sock_ob.fileno() f = self.loop.create_server(TestMyProto, sock=sock_ob) server = self.loop.run_until_complete(f) sock = server.sockets[0] - self.assertIs(sock, sock_ob) + self.assertEqual(sock.fileno(), sock_ob_fd) + self.assertEqual(sock_ob.fileno(), -1) host, port = sock.getsockname() self.assertEqual(host, '0.0.0.0') diff --git a/tests/test_streams.py b/tests/test_streams.py index 35557c3c..f45d47b7 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -582,7 +582,7 @@ def start(self): asyncio.start_server(self.handle_client, sock=sock, loop=self.loop)) - return sock.getsockname() + return self.server.sockets[0].getsockname() def handle_client_callback(self, client_reader, client_writer): self.loop.create_task(self.handle_client(client_reader,