From 61dee1e60ace80da419d052072d740272a4b0d77 Mon Sep 17 00:00:00 2001 From: Dylan Anthony Date: Sat, 20 Mar 2021 12:14:35 -0600 Subject: [PATCH 1/2] fix: Allow passing data with files in multipart. (Fixes #351) --- .../tests/upload_file_tests_upload_post.py | 13 ++++++++++-- .../body_upload_file_tests_upload_post.py | 12 +++++++++-- end_to_end_tests/openapi.json | 5 +++++ .../templates/endpoint_module.py.jinja | 21 +++++++++++++------ 4 files changed, 41 insertions(+), 10 deletions(-) diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_file_tests_upload_post.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_file_tests_upload_post.py index 2ef1278bc..705443d95 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_file_tests_upload_post.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_file_tests_upload_post.py @@ -5,7 +5,7 @@ from ...client import Client from ...models.body_upload_file_tests_upload_post import BodyUploadFileTestsUploadPost from ...models.http_validation_error import HTTPValidationError -from ...types import UNSET, Response, Unset +from ...types import UNSET, File, Response, Unset def _get_kwargs( @@ -22,12 +22,21 @@ def _get_kwargs( if keep_alive is not UNSET: headers["keep-alive"] = keep_alive + files = {} + data = {} + for key, value in multipart_data.to_dict().items(): + if isinstance(value, File): + files[key] = value + else: + data[key] = value + return { "url": url, "headers": headers, "cookies": cookies, "timeout": client.get_timeout(), - "files": multipart_data.to_dict(), + "files": files, + "data": data, } diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/body_upload_file_tests_upload_post.py b/end_to_end_tests/golden-record/my_test_api_client/models/body_upload_file_tests_upload_post.py index 97db03356..a250dbb37 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/models/body_upload_file_tests_upload_post.py +++ b/end_to_end_tests/golden-record/my_test_api_client/models/body_upload_file_tests_upload_post.py @@ -1,9 +1,9 @@ from io import BytesIO -from typing import Any, Dict, Type, TypeVar +from typing import Any, Dict, Type, TypeVar, Union import attr -from ..types import File +from ..types import UNSET, File, Unset T = TypeVar("T", bound="BodyUploadFileTestsUploadPost") @@ -13,16 +13,21 @@ class BodyUploadFileTestsUploadPost: """ """ some_file: File + some_string: Union[Unset, str] = "some_default_string" def to_dict(self) -> Dict[str, Any]: some_file = self.some_file.to_tuple() + some_string = self.some_string + field_dict: Dict[str, Any] = {} field_dict.update( { "some_file": some_file, } ) + if some_string is not UNSET: + field_dict["some_string"] = some_string return field_dict @@ -31,8 +36,11 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: d = src_dict.copy() some_file = File(payload=BytesIO(d.pop("some_file"))) + some_string = d.pop("some_string", UNSET) + body_upload_file_tests_upload_post = cls( some_file=some_file, + some_string=some_string, ) return body_upload_file_tests_upload_post diff --git a/end_to_end_tests/openapi.json b/end_to_end_tests/openapi.json index 3cc49b58d..5bd42221a 100644 --- a/end_to_end_tests/openapi.json +++ b/end_to_end_tests/openapi.json @@ -923,6 +923,11 @@ "title": "Some File", "type": "string", "format": "binary" + }, + "some_string": { + "title": "Some String", + "type": "string", + "default": "some_default_string" } }, "additionalProperties": false diff --git a/openapi_python_client/templates/endpoint_module.py.jinja b/openapi_python_client/templates/endpoint_module.py.jinja index bd738073e..5b1c434bc 100644 --- a/openapi_python_client/templates/endpoint_module.py.jinja +++ b/openapi_python_client/templates/endpoint_module.py.jinja @@ -4,7 +4,7 @@ import httpx from attr import asdict from ...client import AuthenticatedClient, Client -from ...types import Response, UNSET +from ...types import Response, UNSET{% if endpoint.multipart_body_reference %}, File {% endif %} {% for relative in endpoint.relative_imports %} {{ relative }} @@ -36,6 +36,16 @@ def _get_kwargs( {{ json_body(endpoint) | indent(4) }} + {% if endpoint.multipart_body_reference %} + files = {} + data = {} + for key, value in multipart_data.to_dict().items(): + if isinstance(value, File): + files[key] = value + else: + data[key] = value + {% endif %} + return { "url": url, "headers": headers, @@ -43,11 +53,10 @@ def _get_kwargs( "timeout": client.get_timeout(), {% if endpoint.form_body_reference %} "data": asdict(form_data), - {% endif %} - {% if endpoint.multipart_body_reference %} - "files": multipart_data.to_dict(), - {% endif %} - {% if endpoint.json_body %} + {% elif endpoint.multipart_body_reference %} + "files": files, + "data": data, + {% elif endpoint.json_body %} "json": {{ "json_" + endpoint.json_body.python_name }}, {% endif %} {% if endpoint.query_parameters %} From 78238b3501b2297dfa8e969597cf0099e7c45aa8 Mon Sep 17 00:00:00 2001 From: Dylan Anthony Date: Sat, 20 Mar 2021 12:32:14 -0600 Subject: [PATCH 2/2] ci: Remove hard-to-use unit test --- tests/test_templates/endpoint_module.py | 129 ------------------- tests/test_templates/test_endpoint_module.py | 56 -------- 2 files changed, 185 deletions(-) delete mode 100644 tests/test_templates/endpoint_module.py delete mode 100644 tests/test_templates/test_endpoint_module.py diff --git a/tests/test_templates/endpoint_module.py b/tests/test_templates/endpoint_module.py deleted file mode 100644 index d0b6a9b03..000000000 --- a/tests/test_templates/endpoint_module.py +++ /dev/null @@ -1,129 +0,0 @@ -from typing import Any, Dict, List, Optional, Union, cast - -import httpx -from attr import asdict - -from ...client import AuthenticatedClient, Client -from ...types import Response, UNSET - -import this -from __future__ import braces - - -def _get_kwargs( - *, - client: AuthenticatedClient, - form_data: FormBody, - multipart_data: MultiPartBody, - json_body: Json, -) -> Dict[str, Any]: - url = "{}/post/".format(client.base_url) - - headers: Dict[str, Any] = client.get_headers() - cookies: Dict[str, Any] = client.get_cookies() - - return { - "url": url, - "headers": headers, - "cookies": cookies, - "timeout": client.get_timeout(), - "data": asdict(form_data), - "files": multipart_data.to_dict(), - "json": json_json_body, - } - - -def _parse_response(*, response: httpx.Response) -> Optional[Union[str, int]]: - if response.status_code == 200: - response_one = response.json() - return response_one - if response.status_code == 201: - response_one = response.json() - return response_one - return None - - -def _build_response(*, response: httpx.Response) -> Response[Union[str, int]]: - return Response( - status_code=response.status_code, - content=response.content, - headers=response.headers, - parsed=_parse_response(response=response), - ) - - -def sync_detailed( - *, - client: AuthenticatedClient, - form_data: FormBody, - multipart_data: MultiPartBody, - json_body: Json, -) -> Response[Union[str, int]]: - kwargs = _get_kwargs( - client=client, - form_data=form_data, - multipart_data=multipart_data, - json_body=json_body, - ) - - response = httpx.post( - **kwargs, - ) - - return _build_response(response=response) - - -def sync( - *, - client: AuthenticatedClient, - form_data: FormBody, - multipart_data: MultiPartBody, - json_body: Json, -) -> Optional[Union[str, int]]: - """ POST endpoint """ - - return sync_detailed( - client=client, - form_data=form_data, - multipart_data=multipart_data, - json_body=json_body, - ).parsed - - -async def asyncio_detailed( - *, - client: AuthenticatedClient, - form_data: FormBody, - multipart_data: MultiPartBody, - json_body: Json, -) -> Response[Union[str, int]]: - kwargs = _get_kwargs( - client=client, - form_data=form_data, - multipart_data=multipart_data, - json_body=json_body, - ) - - async with httpx.AsyncClient() as _client: - response = await _client.post(**kwargs) - - return _build_response(response=response) - - -async def asyncio( - *, - client: AuthenticatedClient, - form_data: FormBody, - multipart_data: MultiPartBody, - json_body: Json, -) -> Optional[Union[str, int]]: - """ POST endpoint """ - - return ( - await asyncio_detailed( - client=client, - form_data=form_data, - multipart_data=multipart_data, - json_body=json_body, - ) - ).parsed diff --git a/tests/test_templates/test_endpoint_module.py b/tests/test_templates/test_endpoint_module.py deleted file mode 100644 index 868c51c58..000000000 --- a/tests/test_templates/test_endpoint_module.py +++ /dev/null @@ -1,56 +0,0 @@ -from pathlib import Path - -import pytest -from jinja2 import Template - - -@pytest.fixture(scope="session") -def template(env) -> Template: - return env.get_template("endpoint_module.py.jinja") - - -def test_async_module(template, mocker): - path_param = mocker.MagicMock(python_name="path_param_1") - path_param.name = "pathParam1" - path_param.to_string.return_value = "path_param_1: str" - query_param = mocker.MagicMock(template=None, python_name="query_param_1") - query_param.name = "queryParam" - query_param.to_string.return_value = "query_param_1: str" - header_param = mocker.MagicMock(template=None, python_name="header_param_1") - header_param.name = "headerParam" - header_param.to_string.return_value = "header_param_1: str" - - form_body_reference = mocker.MagicMock(class_name="FormBody") - multipart_body_reference = mocker.MagicMock(class_name="MultiPartBody") - json_body = mocker.MagicMock(template=None, python_name="json_body") - json_body.get_type_string.return_value = "Json" - post_response_1 = mocker.MagicMock( - status_code=200, source="response.json()", prop=mocker.MagicMock(template=None, python_name="response_one") - ) - post_response_1.prop.get_type_string.return_value = "str" - post_response_2 = mocker.MagicMock( - status_code=201, source="response.json()", prop=mocker.MagicMock(template=None, python_name="response_one") - ) - post_response_2.prop.get_type_string.return_value = "int" - post_endpoint = mocker.MagicMock( - name="camelCase", - requires_security=True, - path_parameters=[], - query_parameters=[], - form_body_reference=form_body_reference, - multipart_body_reference=multipart_body_reference, - json_body=json_body, - responses=[post_response_1, post_response_2], - description="POST endpoint", - path="/post/", - method="post", - relative_imports=["import this", "from __future__ import braces"], - ) - post_endpoint.name = "camelCase" - - result = template.render(endpoint=post_endpoint) - - import black - - expected = (Path(__file__).parent / "endpoint_module.py").read_text() - black.assert_equivalent(result, expected)