diff --git a/pycurl_requests/adapters/pycurl.py b/pycurl_requests/adapters/pycurl.py index 8257572..694c35d 100644 --- a/pycurl_requests/adapters/pycurl.py +++ b/pycurl_requests/adapters/pycurl.py @@ -83,6 +83,21 @@ def send(self, request, stream=False, timeout=None, verify=True, cert=None, prox return pycurl_request.send() +class ChunkIterableReader: + def __init__(self, iterator: Iterator[bytes]): + self._iterator = iterator + + def read(self, ignored) -> bytes: + return bytes(next(self._iterator, b"")) + + def close(self): # TODO + try: + self._iterator.close() + except AttributeError: + pass + + + class PyCurlRequest: def __init__(self, prepared, *, curl=None, timeout=None, allow_redirects=True, max_redirects=-1): self.prepared = prepared @@ -167,13 +182,31 @@ def send(self): if self.prepared.body is not None: if isinstance(self.prepared.body, str): body = io.BytesIO(self.prepared.body.encode('iso-8859-1')) + self.curl.setopt(pycurl.READDATA, body) elif isinstance(self.prepared.body, bytes): body = io.BytesIO(self.prepared.body) + self.curl.setopt(pycurl.READDATA, body) + elif isinstance(self.prepared.body, (io.RawIOBase, io.BufferedIOBase)): + self.curl.setopt(pycurl.READFUNCTION, self.prepared.body.read) + self.curl.setopt(pycurl.TRANSFER_ENCODING, 1) + elif hasattr(self.prepared.body, "__iter__"): # TODO: call iter instead of checking (e.g. to support delegates) + try: + n_bytes = len(self.prepared.body) + except TypeError: + # "(Since 7.66.0, libcurl will automatically use chunked encoding for POSTs if the size is unknown.)" + self.curl.setopt(pycurl.TRANSFER_ENCODING, 1) + else: + self.curl.setopt(pycurl.TRANSFER_ENCODING, 0) + self.curl.setopt(pycurl.INFILESIZE_LARGE, n_bytes) + reader = ChunkIterableReader(iter(self.prepared.body)) + self.curl.setopt(pycurl.READFUNCTION, reader.read) + # TODO: throw exceptions to the iterator (requests doesn't do this but would facilitate error handling) else: body = self.prepared.body + self.curl.setopt(pycurl.READDATA, body) self.curl.setopt(pycurl.UPLOAD, 1) - self.curl.setopt(pycurl.READDATA, body) + content_length = self.prepared.headers.get('Content-Length') if content_length is not None: diff --git a/pycurl_requests/tests/test_streaming.py b/pycurl_requests/tests/test_streaming.py new file mode 100644 index 0000000..8417405 --- /dev/null +++ b/pycurl_requests/tests/test_streaming.py @@ -0,0 +1,38 @@ +import io +from pycurl_requests import requests +from pycurl_requests.tests.utils import * + + +def test_streaming_upload_from_file(http_server): + f = io.BytesIO(test_data) + response = requests.post(http_server.base_url + '/stream', data=f) + assert response.status_code == 200 + + +def data_generator(data: bytes, chunk_size: int): + i = 0 + while True: + chunk = data[chunk_size * i: chunk_size * (i + 1)] + if len(chunk) == 0: + break + yield chunk + i += 1 + + +def test_streaming_upload_form_iterable(http_server): + response = requests.post(http_server.base_url + '/stream', data=data_generator(test_data, 123)) + assert response.status_code == 200 + + +def test_streaming_upload_form_iterable_with_known_length(http_server): + class FixedLengthIterable: + data = test_data + + def __len__(self): + return len(self.data) + + def __iter__(self): + return data_generator(data=self.data, chunk_size=123) + + response = requests.post(http_server.base_url + '/stream_no_chunked', data=FixedLengthIterable()) + assert response.status_code == 200 \ No newline at end of file diff --git a/pycurl_requests/tests/utils.py b/pycurl_requests/tests/utils.py index 73bec3d..c8203a4 100644 --- a/pycurl_requests/tests/utils.py +++ b/pycurl_requests/tests/utils.py @@ -3,6 +3,7 @@ """ import json +import random import threading import time from http import cookies @@ -13,13 +14,16 @@ from pycurl_requests import requests -__all__ = ['IS_PYCURL_REQUESTS', 'http_server'] +__all__ = ['IS_PYCURL_REQUESTS', 'http_server', 'test_data'] #: Is this _really_ PyCurl-Requests? #: Should be used when testing for PyCurl-Requests extensions. IS_PYCURL_REQUESTS = requests.__name__ == 'pycurl_requests' +test_data = bytes(random.getrandbits(8) for _ in range(123456)) + + @pytest.fixture(scope='module') def http_server(): httpd = HTTPServer(('127.0.0.1', 0), HTTPRequestHandler) @@ -93,6 +97,44 @@ def do_GET_response_headers(self): def do_HTTP_404(self): self.send_error(404, 'Not Found') + def do_POST(self): + path = self.url.path[1:].replace('/', '_') + getattr(self, f'do_POST_{path}', self.do_HTTP_404)() + + def do_POST_stream(self): + self.POST_stream_helper(allow_chunked=True) + + def do_POST_stream_no_chunked(self): + self.POST_stream_helper(allow_chunked=False) + + def POST_stream_helper(self, allow_chunked: bool): + if "Content-Length" in self.headers: + content_length = int(self.headers["Content-Length"]) + body = self.rfile.read(content_length) + elif "Transfer-Encoding" in self.headers and "chunked" in self.headers["Transfer-Encoding"]: + if not allow_chunked: + self.response('This endpoint has chunked transfer deactivated.', status=(400, "Bad Request")) + return + body = b"" + while True: + line = self.rfile.readline() + chunk_length = int(line, 16) + if chunk_length != 0: + chunk = self.rfile.read(chunk_length) + body += chunk + self.rfile.readline() + if chunk_length == 0: + break + else: + self.response('Missing Content-Length or Transfer-Encoding header.', status=(400, "Bad Request")) + return + + if body == test_data: + self.response('Upload succeeded.') + else: + self.response('Upload failed.', status=(400, "Bad Request")) + + @property def url(self): if not hasattr(self, '_url'):