diff --git a/.builder/actions/aws_crt_python.py b/.builder/actions/aws_crt_python.py index 51af165a6..4959da139 100644 --- a/.builder/actions/aws_crt_python.py +++ b/.builder/actions/aws_crt_python.py @@ -29,6 +29,7 @@ def run(self, env): # Enable S3 tests env.shell.setenv('AWS_TEST_S3', '1') + env.shell.setenv('AWS_TEST_LOCALHOST', '1') actions = [ [self.python, '--version'], diff --git a/awscrt/http.py b/awscrt/http.py index b3a89b868..1050210c5 100644 --- a/awscrt/http.py +++ b/awscrt/http.py @@ -112,6 +112,26 @@ def new(cls, If successful, the Future will contain a new :class:`HttpClientConnection`. Otherwise, it will contain an exception. """ + return HttpClientConnection._new_common( + host_name, + port, + bootstrap, + socket_options, + tls_connection_options, + proxy_options) + + @staticmethod + def _new_common( + host_name, + port, + bootstrap=None, + socket_options=None, + tls_connection_options=None, + proxy_options=None, + expected_version=None): + """ + Initialize the generic part of the HttpClientConnection class. + """ assert isinstance(bootstrap, ClientBootstrap) or bootstrap is None assert isinstance(host_name, str) assert isinstance(port, int) @@ -120,6 +140,7 @@ def new(cls, assert isinstance(proxy_options, HttpProxyOptions) or proxy_options is None future = Future() + try: if not socket_options: socket_options = SocketOptions() @@ -127,37 +148,22 @@ def new(cls, if not bootstrap: bootstrap = ClientBootstrap.get_or_create_static_default() - connection = cls() - connection._host_name = host_name - connection._port = port - - def on_connection_setup(binding, error_code, http_version): - if error_code == 0: - connection._binding = binding - connection._version = HttpVersion(http_version) - future.set_result(connection) - else: - future.set_exception(awscrt.exceptions.from_code(error_code)) - - # on_shutdown MUST NOT reference the connection itself, just the shutdown_future within it. - # Otherwise we create a circular reference that prevents the connection from getting GC'd. - shutdown_future = connection.shutdown_future - - def on_shutdown(error_code): - if error_code: - shutdown_future.set_exception(awscrt.exceptions.from_code(error_code)) - else: - shutdown_future.set_result(None) + connection_core = _HttpClientConnectionCore( + host_name, + port, + bootstrap=bootstrap, + tls_connection_options=tls_connection_options, + connect_future=future, + expected_version=expected_version) _awscrt.http_client_connection_new( bootstrap, - on_connection_setup, - on_shutdown, host_name, port, socket_options, tls_connection_options, - proxy_options) + proxy_options, + connection_core) except Exception as e: future.set_exception(e) @@ -219,6 +225,33 @@ def request(self, request, on_response=None, on_body=None): return HttpClientStream(self, request, on_response, on_body) +class Http2ClientConnection(HttpClientConnection): + """ + HTTP/2 client connection. + + This class extends HttpClientConnection with HTTP/2 specific functionality. + """ + @classmethod + def new(cls, + host_name, + port, + bootstrap=None, + socket_options=None, + tls_connection_options=None, + proxy_options=None): + return HttpClientConnection._new_common( + host_name, + port, + bootstrap, + socket_options, + tls_connection_options, + proxy_options, + HttpVersion.Http2) + + def request(self, request, on_response=None, on_body=None, manual_write=False): + return Http2ClientStream(self, request, on_response, on_body, manual_write) + + class HttpStreamBase(NativeResource): """Base for HTTP stream classes""" __slots__ = ('_connection', '_completion_future', '_on_body_cb') @@ -258,9 +291,12 @@ class HttpClientStream(HttpStreamBase): completes. If the exchange fails to complete, the Future will contain an exception indicating why it failed. """ - __slots__ = ('_response_status_code', '_on_response_cb', '_on_body_cb', '_request') + __slots__ = ('_response_status_code', '_on_response_cb', '_on_body_cb', '_request', '_version') def __init__(self, connection, request, on_response=None, on_body=None): + self._generic_init(connection, request, on_response, on_body) + + def _generic_init(self, connection, request, on_response=None, on_body=None, http2_manual_write=False): assert isinstance(connection, HttpClientConnection) assert isinstance(request, HttpRequest) assert callable(on_response) or on_response is None @@ -273,8 +309,14 @@ def __init__(self, connection, request, on_response=None, on_body=None): # keep HttpRequest alive until stream completes self._request = request + self._version = connection.version - self._binding = _awscrt.http_client_stream_new(self, connection, request) + self._binding = _awscrt.http_client_stream_new(self, connection, request, http2_manual_write) + + @property + def version(self): + """HttpVersion: Protocol used by this stream""" + return self._version @property def response_status_code(self): @@ -307,6 +349,24 @@ def _on_complete(self, error_code): self._completion_future.set_exception(awscrt.exceptions.from_code(error_code)) +class Http2ClientStream(HttpClientStream): + def __init__(self, connection, request, on_response=None, on_body=None, manual_write=False): + super()._generic_init(connection, request, on_response, on_body, manual_write) + + def write_data(self, data_stream, end_stream=False): + future = Future() + body_stream = InputStream.wrap(data_stream, allow_none=True) + + def on_write_complete(error_code): + if error_code: + future.set_exception(awscrt.exceptions.from_code(error_code)) + else: + future.set_result(None) + + _awscrt.http2_client_stream_write_data(self, body_stream, end_stream, on_write_complete) + return future + + class HttpMessageBase(NativeResource): """ Base for HttpRequest and HttpResponse classes. @@ -625,3 +685,58 @@ def __init__(self, self.auth_username = auth_username self.auth_password = auth_password self.connection_type = connection_type + + +class _HttpClientConnectionCore: + ''' + Private class to keep all the related Python object alive until C land clean up for HttpClientConnection + ''' + + def __init__( + self, + host_name, + port, + bootstrap=None, + tls_connection_options=None, + connect_future=None, + expected_version=None): + self._shutdown_future = None + self._host_name = host_name + self._port = port + self._bootstrap = bootstrap + self._tls_connection_options = tls_connection_options + self._connect_future = connect_future + self._expected_version = expected_version + + def _on_connection_setup(self, binding, error_code, http_version): + if self._expected_version and self._expected_version != http_version: + # unexpected protocol version + # AWS_ERROR_HTTP_UNSUPPORTED_PROTOCOL + self._connect_future.set_exception(awscrt.exceptions.from_code(2060)) + return + if error_code != 0: + self._connect_future.set_exception(awscrt.exceptions.from_code(error_code)) + return + if http_version == HttpVersion.Http2: + connection = Http2ClientConnection() + else: + connection = HttpClientConnection() + + connection._host_name = self._host_name + connection._port = self._port + + connection._binding = binding + connection._version = HttpVersion(http_version) + self._shutdown_future = connection.shutdown_future + self._connect_future.set_result(connection) + # release reference to the future, as it points to connection which creates a cycle reference. + self._connect_future = None + + def _on_shutdown(self, error_code): + if self._shutdown_future is None: + # connection failed, ignore shutdown + return + if error_code: + self._shutdown_future.set_exception(awscrt.exceptions.from_code(error_code)) + else: + self._shutdown_future.set_result(None) diff --git a/crt/aws-c-cal b/crt/aws-c-cal index ff8801488..fa108de52 160000 --- a/crt/aws-c-cal +++ b/crt/aws-c-cal @@ -1 +1 @@ -Subproject commit ff8801488d588067d021d131193681b591699477 +Subproject commit fa108de5280afd71018e0a0534edb36b33f030f6 diff --git a/crt/aws-c-http b/crt/aws-c-http index e526ac338..6586c80ed 160000 --- a/crt/aws-c-http +++ b/crt/aws-c-http @@ -1 +1 @@ -Subproject commit e526ac338ca414c01d3fc037da1c418c935808bc +Subproject commit 6586c80edc09a07d3e6db6bf82c4b53aefdfe895 diff --git a/crt/aws-checksums b/crt/aws-checksums index 66b447c07..9978ba2c3 160000 --- a/crt/aws-checksums +++ b/crt/aws-checksums @@ -1 +1 @@ -Subproject commit 66b447c0765a2caff2d806111e6ec1db2383e4d2 +Subproject commit 9978ba2c33a7a259c1a6bd0f62abe26827d03b85 diff --git a/pyproject.toml b/pyproject.toml index 2ad2358a5..14cf1d81f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,4 +35,5 @@ dev = [ "build>=1.2.2", # for building wheels "sphinx>=7.2.6,<7.3; python_version >= '3.9'", # for building docs "websockets>=13.1", # for tests + "h2", # for tests ] diff --git a/source/http.h b/source/http.h index 549706eb3..2cfbf91e2 100644 --- a/source/http.h +++ b/source/http.h @@ -37,9 +37,10 @@ PyObject *aws_py_http_connection_is_open(PyObject *self, PyObject *args); PyObject *aws_py_http_client_connection_new(PyObject *self, PyObject *args); PyObject *aws_py_http_client_stream_new(PyObject *self, PyObject *args); - PyObject *aws_py_http_client_stream_activate(PyObject *self, PyObject *args); +PyObject *aws_py_http2_client_stream_write_data(PyObject *self, PyObject *args); + /* Create capsule around new request-style aws_http_message struct */ PyObject *aws_py_http_message_new_request(PyObject *self, PyObject *args); diff --git a/source/http_connection.c b/source/http_connection.c index 73ebac254..bf27047f5 100644 --- a/source/http_connection.c +++ b/source/http_connection.c @@ -22,26 +22,15 @@ static const char *s_capsule_name_http_connection = "aws_http_connection"; */ struct http_connection_binding { struct aws_http_connection *native; + /* Reference to python object that reference to other related python object to keep it alive */ + PyObject *py_core; bool release_called; bool shutdown_called; - - /* Setup callback, reference cleared after invoking */ - PyObject *on_setup; - - /* Shutdown callback, reference cleared after setting result */ - PyObject *on_shutdown; - - /* Dependencies that must outlive this */ - PyObject *bootstrap; - PyObject *tls_ctx; }; static void s_connection_destroy(struct http_connection_binding *connection) { - Py_XDECREF(connection->on_setup); - Py_XDECREF(connection->on_shutdown); - Py_XDECREF(connection->bootstrap); - Py_XDECREF(connection->tls_ctx); + Py_XDECREF(connection->py_core); aws_mem_release(aws_py_get_allocator(), connection); } @@ -84,14 +73,14 @@ static void s_on_connection_shutdown(struct aws_http_connection *native_connecti bool destroy_after_shutdown = connection->release_called; /* Invoke on_shutdown, then clear our reference to it */ - PyObject *result = PyObject_CallFunction(connection->on_shutdown, "(i)", error_code); + PyObject *result = PyObject_CallMethod(connection->py_core, "_on_shutdown", "(i)", error_code); + if (result) { Py_DECREF(result); } else { /* Callback might fail during application shutdown */ PyErr_WriteUnraisable(PyErr_Occurred()); } - Py_CLEAR(connection->on_shutdown); if (destroy_after_shutdown) { s_connection_destroy(connection); @@ -107,7 +96,6 @@ static void s_on_client_connection_setup( struct http_connection_binding *connection = user_data; AWS_FATAL_ASSERT((native_connection != NULL) ^ error_code); - AWS_FATAL_ASSERT(connection->on_setup); connection->native = native_connection; @@ -126,9 +114,8 @@ static void s_on_client_connection_setup( http_version = aws_http_connection_get_version(native_connection); } - /* Invoke on_setup, then clear our reference to it */ - PyObject *result = - PyObject_CallFunction(connection->on_setup, "(Oii)", capsule ? capsule : Py_None, error_code, http_version); + PyObject *result = PyObject_CallMethod( + connection->py_core, "_on_connection_setup", "(Oii)", capsule ? capsule : Py_None, error_code, http_version); if (result) { Py_DECREF(result); @@ -137,8 +124,6 @@ static void s_on_client_connection_setup( PyErr_WriteUnraisable(PyErr_Occurred()); } - Py_CLEAR(connection->on_setup); - if (native_connection) { /* Connection exists, but failed to create capsule. Release connection, which eventually destroys binding */ if (!capsule) { @@ -159,27 +144,25 @@ PyObject *aws_py_http_client_connection_new(PyObject *self, PyObject *args) { struct aws_allocator *allocator = aws_py_get_allocator(); PyObject *bootstrap_py; - PyObject *on_connection_setup_py; - PyObject *on_shutdown_py; const char *host_name; Py_ssize_t host_name_len; uint32_t port_number; PyObject *socket_options_py; PyObject *tls_options_py; PyObject *proxy_options_py; + PyObject *py_core; if (!PyArg_ParseTuple( args, - "OOOs#IOOO", + "Os#IOOOO", &bootstrap_py, - &on_connection_setup_py, - &on_shutdown_py, &host_name, &host_name_len, &port_number, &socket_options_py, &tls_options_py, - &proxy_options_py)) { + &proxy_options_py, + &py_core)) { return NULL; } @@ -201,12 +184,6 @@ PyObject *aws_py_http_client_connection_new(PyObject *self, PyObject *args) { if (!tls_options) { goto error; } - - connection->tls_ctx = PyObject_GetAttrString(tls_options_py, "tls_ctx"); /* Creates new reference */ - if (!connection->tls_ctx || connection->tls_ctx == Py_None) { - PyErr_SetString(PyExc_TypeError, "tls_connection_options.tls_ctx is invalid"); - goto error; - } } struct aws_socket_options socket_options; @@ -239,12 +216,8 @@ PyObject *aws_py_http_client_connection_new(PyObject *self, PyObject *args) { .on_shutdown = s_on_connection_shutdown, }; - connection->on_setup = on_connection_setup_py; - Py_INCREF(connection->on_setup); - connection->on_shutdown = on_shutdown_py; - Py_INCREF(connection->on_shutdown); - connection->bootstrap = bootstrap_py; - Py_INCREF(connection->bootstrap); + connection->py_core = py_core; + Py_INCREF(connection->py_core); if (aws_http_client_connect(&http_options)) { PyErr_SetAwsLastError(); diff --git a/source/http_stream.c b/source/http_stream.c index 6843e0ea8..cb93af9b0 100644 --- a/source/http_stream.c +++ b/source/http_stream.c @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0. */ #include "http.h" +#include "io.h" #include @@ -235,7 +236,8 @@ PyObject *aws_py_http_client_stream_new(PyObject *self, PyObject *args) { PyObject *py_stream = NULL; PyObject *py_connection = NULL; PyObject *py_request = NULL; - if (!PyArg_ParseTuple(args, "OOO", &py_stream, &py_connection, &py_request)) { + int http2_manual_write = 0; + if (!PyArg_ParseTuple(args, "OOOp", &py_stream, &py_connection, &py_request, &http2_manual_write)) { return NULL; } @@ -282,6 +284,7 @@ PyObject *aws_py_http_client_stream_new(PyObject *self, PyObject *args) { .on_response_body = s_on_incoming_body, .on_complete = s_on_stream_complete, .user_data = stream, + .http2_use_manual_data_writes = http2_manual_write, }; stream->native = aws_http_connection_make_request(native_connection, &request_options); @@ -323,3 +326,68 @@ PyObject *aws_py_http_client_stream_activate(PyObject *self, PyObject *args) { Py_RETURN_NONE; } + +static void s_on_http2_write_data_complete(struct aws_http_stream *stream, int error_code, void *user_data) { + (void)stream; + PyObject *py_on_write_complete = (PyObject *)user_data; + AWS_FATAL_ASSERT(py_on_write_complete); + PyGILState_STATE state; + if (aws_py_gilstate_ensure(&state)) { + return; /* Python has shut down. Nothing matters anymore, but don't crash */ + } + + /* Invoke on_setup, then clear our reference to it */ + PyObject *result = PyObject_CallFunction(py_on_write_complete, "(i)", error_code); + if (result) { + Py_DECREF(result); + } else { + /* Callback might fail during application shutdown */ + PyErr_WriteUnraisable(PyErr_Occurred()); + } + Py_DECREF(py_on_write_complete); + PyGILState_Release(state); +} + +PyObject *aws_py_http2_client_stream_write_data(PyObject *self, PyObject *args) { + (void)self; + + PyObject *py_stream = NULL; + PyObject *py_body_stream = NULL; + int end_stream = false; + PyObject *py_on_write_complete = NULL; + if (!PyArg_ParseTuple(args, "OOpO", &py_stream, &py_body_stream, &end_stream, &py_on_write_complete)) { + return NULL; + } + + struct aws_http_stream *http_stream = aws_py_get_http_stream(py_stream); + if (!http_stream) { + return NULL; + } + + struct aws_input_stream *body_stream = NULL; + // Write an empty stream is allowed. + if (py_body_stream != Py_None) { + /* The py_body_stream has the same lifetime as the C stream, no need to keep it alive from this binding. */ + body_stream = aws_py_get_input_stream(py_body_stream); + if (!body_stream) { + return PyErr_AwsLastError(); + } + } + + /* Make sure the python callback live long enough for C to call. */ + Py_INCREF(py_on_write_complete); + + struct aws_http2_stream_write_data_options write_options = { + .data = body_stream, + .end_stream = end_stream, + .on_complete = s_on_http2_write_data_complete, + .user_data = py_on_write_complete, + }; + + int error = aws_http2_stream_write_data(http_stream, &write_options); + if (error) { + Py_DECREF(py_on_write_complete); + return PyErr_AwsLastError(); + } + Py_RETURN_NONE; +} diff --git a/source/module.c b/source/module.c index e60fc0449..4ad2bd002 100644 --- a/source/module.c +++ b/source/module.c @@ -487,7 +487,8 @@ int aws_py_translate_py_error(void) { } /* Print standard traceback to sys.stderr and clear the error indicator. */ - PyErr_Print(); + /* Handles the exception in C, do not set the last vars for python. */ + PyErr_PrintEx(0 /*set_sys_last_vars*/); fprintf(stderr, "Treating Python exception as error %d(%s)\n", aws_error_code, aws_error_name(aws_error_code)); return aws_error_code; @@ -820,6 +821,7 @@ static PyMethodDef s_module_methods[] = { AWS_PY_METHOD_DEF(http_client_connection_new, METH_VARARGS), AWS_PY_METHOD_DEF(http_client_stream_new, METH_VARARGS), AWS_PY_METHOD_DEF(http_client_stream_activate, METH_VARARGS), + AWS_PY_METHOD_DEF(http2_client_stream_write_data, METH_VARARGS), AWS_PY_METHOD_DEF(http_message_new_request, METH_VARARGS), AWS_PY_METHOD_DEF(http_message_get_request_method, METH_VARARGS), AWS_PY_METHOD_DEF(http_message_set_request_method, METH_VARARGS), diff --git a/test/test_http_client.py b/test/test_http_client.py index dd95d85aa..4bfd8f256 100644 --- a/test/test_http_client.py +++ b/test/test_http_client.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0. import awscrt.exceptions -from awscrt.http import HttpClientConnection, HttpClientStream, HttpHeaders, HttpProxyOptions, HttpRequest, HttpVersion +from awscrt.http import HttpClientConnection, HttpClientStream, HttpHeaders, HttpProxyOptions, HttpRequest, HttpVersion, Http2ClientConnection from awscrt.io import ClientBootstrap, ClientTlsContext, DefaultHostResolver, EventLoopGroup, TlsConnectionOptions, TlsContextOptions, TlsCipherPref -from concurrent.futures import Future +from concurrent.futures import Future, thread from http.server import HTTPServer, SimpleHTTPRequestHandler from io import BytesIO import os @@ -13,6 +13,10 @@ import threading import unittest from urllib.parse import urlparse +import subprocess +import sys +import socket +import time class Response: @@ -45,7 +49,7 @@ def do_PUT(self): class TestClient(NativeResourceTest): hostname = 'localhost' - timeout = 10 # seconds + timeout = 5 # seconds def _start_server(self, secure, http_1_0=False): # HTTP/1.0 closes the connection at the end of each request @@ -332,18 +336,21 @@ def _new_h2_client_connection(self, url): host_resolver = DefaultHostResolver(event_loop_group) bootstrap = ClientBootstrap(event_loop_group, host_resolver) - port = 443 - scheme = 'https' + port = url.port + # only test https + if port is None: + port = 443 tls_ctx_options = TlsContextOptions() + tls_ctx_options.verify_peer = False # allow localhost tls_ctx = ClientTlsContext(tls_ctx_options) tls_conn_opt = tls_ctx.new_connection_options() tls_conn_opt.set_server_name(url.hostname) tls_conn_opt.set_alpn_list(["h2"]) - connection_future = HttpClientConnection.new(host_name=url.hostname, - port=port, - bootstrap=bootstrap, - tls_connection_options=tls_conn_opt) + connection_future = Http2ClientConnection.new(host_name=url.hostname, + port=port, + bootstrap=bootstrap, + tls_connection_options=tls_conn_opt) return connection_future.result(self.timeout) def test_h2_client(self): @@ -368,10 +375,202 @@ def test_h2_client(self): self.assertEqual(None, connection.close().exception(self.timeout)) + def test_h2_manual_write_exception(self): + url = urlparse("https://d1cz66xoahf9cl.cloudfront.net/http_test_doc.txt") + connection = self._new_h2_client_connection(url) + # check we set an h2 connection + self.assertEqual(connection.version, HttpVersion.Http2) + + request = HttpRequest('GET', url.path) + request.headers.add('host', url.hostname) + response = Response() + stream = connection.request(request, response.on_response, response.on_body) + stream.activate() + exception = None + try: + # If the stream is not configured to allow manual writes, this should throw an exception directly + stream.write_data(BytesIO(b'hello'), False) + except RuntimeError as e: + exception = e + self.assertIsNotNone(exception) + + self.assertEqual(None, connection.close().exception(self.timeout)) + @unittest.skipIf(not TlsCipherPref.PQ_DEFAULT.is_supported(), "Cipher pref not supported") def test_connect_pq_default(self): self._test_connect(secure=True, cipher_pref=TlsCipherPref.PQ_DEFAULT) +@unittest.skipUnless(os.environ.get('AWS_TEST_LOCALHOST'), 'set env var to run test: AWS_TEST_LOCALHOST') +class TestClientMockServer(NativeResourceTest): + + timeout = 5 # seconds + p_server = None + mock_server_url = None + + def setUp(self): + super().setUp() + # Start the mock server from the aws-c-http. + server_path = os.path.join( + os.path.dirname(__file__), + '..', + 'crt', + 'aws-c-http', + 'tests', + 'py_localhost', + 'server.py') + python_path = sys.executable + self.mock_server_url = urlparse("https://localhost:3443/upload_test") + self.p_server = subprocess.Popen([python_path, server_path]) + # Wait for server to be ready + self._wait_for_server_ready() + + def _wait_for_server_ready(self): + """Wait until server is accepting connections.""" + max_attempts = 20 + + for attempt in range(max_attempts): + try: + with socket.create_connection(("127.0.0.1", self.mock_server_url.port), timeout=1): + return # Server is ready + except (ConnectionRefusedError, socket.timeout): + time.sleep(0.5) + + # If we get here, server failed to start + stdout, stderr = self.p_server.communicate(timeout=0.5) + raise RuntimeError(f"Server failed to start after {max_attempts} attempts.\n" + f"STDOUT: {stdout.decode()}\nSTDERR: {stderr.decode()}") + + def tearDown(self): + self.p_server.terminate() + try: + self.p_server.wait(timeout=5) + except subprocess.TimeoutExpired: + self.p_server.kill() + super().tearDown() + + def _new_mock_connection(self): + + event_loop_group = EventLoopGroup() + host_resolver = DefaultHostResolver(event_loop_group) + bootstrap = ClientBootstrap(event_loop_group, host_resolver) + + port = self.mock_server_url.port + # only test https + if port is None: + port = 443 + tls_ctx_options = TlsContextOptions() + tls_ctx_options.verify_peer = False # allow localhost + tls_ctx = ClientTlsContext(tls_ctx_options) + tls_conn_opt = tls_ctx.new_connection_options() + tls_conn_opt.set_server_name(self.mock_server_url.hostname) + tls_conn_opt.set_alpn_list(["h2"]) + + connection_future = Http2ClientConnection.new(host_name=self.mock_server_url.hostname, + port=port, + bootstrap=bootstrap, + tls_connection_options=tls_conn_opt) + return connection_future.result(self.timeout) + + def test_h2_mock_server_manual_write(self): + connection = self._new_mock_connection() + # check we set an h2 connection + self.assertEqual(connection.version, HttpVersion.Http2) + + request = HttpRequest('POST', self.mock_server_url.path) + request.headers.add('host', self.mock_server_url.hostname) + response = Response() + stream = connection.request(request, response.on_response, response.on_body, manual_write=True) + stream.activate() + exception = None + try: + # If the stream is not configured to allow manual writes, this should throw an exception directly + f = stream.write_data(BytesIO(b'hello'), False) + f.result(self.timeout) + stream.write_data(BytesIO(b'he123123'), False) + stream.write_data(None, False) + stream.write_data(BytesIO(b'hello'), True) + except RuntimeError as e: + exception = e + self.assertIsNone(exception) + stream_completion_result = stream.completion_future.result(80) + # check result + self.assertEqual(200, response.status_code) + self.assertEqual(200, stream_completion_result) + print(response.body) + + self.assertEqual(None, connection.close().exception(self.timeout)) + + class DelayStream: + def __init__(self, bad_read=False): + self._read = False + self.bad_read = bad_read + + def read(self, _len): + if self.bad_read: + # simulate a bad read that raises an exception + # this will cause the stream to fail + raise RuntimeError("bad read exception") + if self._read: + # return empty as EOS + return b'' + else: + self._read = True + return b'hello' + + def test_h2_mock_server_manual_write_read_exception(self): + connection = self._new_mock_connection() + # check we set an h2 connection + self.assertEqual(connection.version, HttpVersion.Http2) + + request = HttpRequest('POST', self.mock_server_url.path) + request.headers.add('host', self.mock_server_url.hostname) + response = Response() + stream = connection.request(request, response.on_response, response.on_body, manual_write=True) + stream.activate() + exception = None + data = self.DelayStream(bad_read=True) + try: + f = stream.write_data(data, False) + f.result(self.timeout) + except Exception as e: + # future will raise the exception from the write_data call. + exception = e + self.assertIsNotNone(exception) + # stream will complete with same exception. + stream_completion_exception = stream.completion_future.exception() + self.assertIsNotNone(stream_completion_exception) + # assert that the exception is the same as the one we got from write_data. + self.assertEqual(str(exception), str(stream_completion_exception)) + self.assertEqual(None, connection.close().exception(self.timeout)) + + def test_h2_mock_server_manual_write_lifetime(self): + connection = self._new_mock_connection() + # check we set an h2 connection + self.assertEqual(connection.version, HttpVersion.Http2) + + request = HttpRequest('POST', self.mock_server_url.path) + request.headers.add('host', self.mock_server_url.hostname) + response = Response() + stream = connection.request(request, response.on_response, response.on_body, manual_write=True) + stream.activate() + exception = None + data = self.DelayStream(bad_read=False) + try: + f = stream.write_data(data, False) + # make sure when the python object was dropped, things are still ok + del data + f.result(self.timeout) + f = stream.write_data(None, True) + f.result(self.timeout) + except Exception as e: + # future will raise the exception from the write_data call. + exception = e + self.assertIsNone(exception) + # stream will complete with another exception. + stream.completion_future.result() + self.assertEqual(None, connection.close().exception(self.timeout)) + + if __name__ == '__main__': unittest.main() diff --git a/test/test_s3.py b/test/test_s3.py index 5550645e0..4e8785eed 100644 --- a/test/test_s3.py +++ b/test/test_s3.py @@ -691,7 +691,7 @@ def test_fork_workaround(self): mp.set_start_method('fork', force=True) process = Process(target=self.fork_s3_client) process.start() - process.join(10) + process.join() self.assertEqual(0, process.exitcode) self.upload_with_global_client() del CRT_S3_CLIENT