diff --git a/hyper/contrib.py b/hyper/contrib.py index 5a580f29..ff4f8ff8 100644 --- a/hyper/contrib.py +++ b/hyper/contrib.py @@ -196,3 +196,8 @@ def getheaders(self, name): orig.msg = FakeOriginalResponse(resp.headers.iter_raw()) return response + + def close(self): + for connection in self.connections.values(): + connection.close() + self.connections.clear() diff --git a/test/test_integration.py b/test/test_integration.py index bde7d393..6a9ece42 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -1717,3 +1717,103 @@ def socket_handler(listener): timeout=(10, 0.5)) self.tear_down() + + def test_adapter_close(self): + self.set_up(secure=False) + + def socket_handler(listener): + sock = listener.accept()[0] + + # We should get the initial request. + data = b'' + while not data.endswith(b'\r\n\r\n'): + data += sock.recv(65535) + + # We need to send back a response. + resp = ( + b'HTTP/1.1 201 No Content\r\n' + b'Server: socket-level-server\r\n' + b'Content-Length: 0\r\n' + b'Connection: close\r\n' + b'\r\n' + ) + sock.send(resp) + sock.close() + + self._start_server(socket_handler) + + a = HTTP20Adapter() + s = requests.Session() + s.mount('http://', a) + r = s.get('http://%s:%s' % (self.host, self.port)) + connections_before_close = list(a.connections.values()) + + # ensure that we have at least 1 connection + assert connections_before_close + + s.close() + + # check that connections cache is empty + assert not a.connections + + # check that all connections are actually closed + assert all(conn._sock is None for conn in connections_before_close) + + assert r.status_code == 201 + assert len(r.headers) == 3 + assert r.headers['server'] == 'socket-level-server' + assert r.headers['content-length'] == '0' + assert r.headers['connection'] == 'close' + + assert r.content == b'' + + self.tear_down() + + def test_adapter_close_context_manager(self): + self.set_up(secure=False) + + def socket_handler(listener): + sock = listener.accept()[0] + + # We should get the initial request. + data = b'' + while not data.endswith(b'\r\n\r\n'): + data += sock.recv(65535) + + # We need to send back a response. + resp = ( + b'HTTP/1.1 201 No Content\r\n' + b'Server: socket-level-server\r\n' + b'Content-Length: 0\r\n' + b'Connection: close\r\n' + b'\r\n' + ) + sock.send(resp) + sock.close() + + self._start_server(socket_handler) + + with requests.Session() as s: + a = HTTP20Adapter() + s.mount('http://', a) + r = s.get('http://%s:%s' % (self.host, self.port)) + connections_before_close = list(a.connections.values()) + + # ensure that we have at least 1 connection + assert connections_before_close + + # check that connections cache is empty + assert not a.connections + + # check that all connections are actually closed + assert all(conn._sock is None for conn in connections_before_close) + + assert r.status_code == 201 + assert len(r.headers) == 3 + assert r.headers['server'] == 'socket-level-server' + assert r.headers['content-length'] == '0' + assert r.headers['connection'] == 'close' + + assert r.content == b'' + + self.tear_down()