diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/defaults_tests_defaults_post.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/defaults_tests_defaults_post.py index 88e4421f4..e671e60ab 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/defaults_tests_defaults_post.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/defaults_tests_defaults_post.py @@ -131,7 +131,7 @@ def _get_kwargs( } -def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPValidationError]]: +def _parse_response(*, response: httpx.Response) -> Optional[Union[HTTPValidationError, None]]: if response.status_code == 200: response_200 = None @@ -143,7 +143,7 @@ def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPVal return None -def _build_response(*, response: httpx.Response) -> Response[Union[None, HTTPValidationError]]: +def _build_response(*, response: httpx.Response) -> Response[Union[HTTPValidationError, None]]: return Response( status_code=response.status_code, content=response.content, @@ -172,7 +172,7 @@ def sync_detailed( required_model_prop: ModelWithUnionProperty, nullable_model_prop: Union[Unset, None, ModelWithUnionProperty] = UNSET, nullable_required_model_prop: Optional[ModelWithUnionProperty], -) -> Response[Union[None, HTTPValidationError]]: +) -> Response[Union[HTTPValidationError, None]]: kwargs = _get_kwargs( client=client, string_prop=string_prop, @@ -221,7 +221,7 @@ def sync( required_model_prop: ModelWithUnionProperty, nullable_model_prop: Union[Unset, None, ModelWithUnionProperty] = UNSET, nullable_required_model_prop: Optional[ModelWithUnionProperty], -) -> Optional[Union[None, HTTPValidationError]]: +) -> Optional[Union[HTTPValidationError, None]]: """ """ return sync_detailed( @@ -266,7 +266,7 @@ async def asyncio_detailed( required_model_prop: ModelWithUnionProperty, nullable_model_prop: Union[Unset, None, ModelWithUnionProperty] = UNSET, nullable_required_model_prop: Optional[ModelWithUnionProperty], -) -> Response[Union[None, HTTPValidationError]]: +) -> Response[Union[HTTPValidationError, None]]: kwargs = _get_kwargs( client=client, string_prop=string_prop, @@ -314,7 +314,7 @@ async def asyncio( required_model_prop: ModelWithUnionProperty, nullable_model_prop: Union[Unset, None, ModelWithUnionProperty] = UNSET, nullable_required_model_prop: Optional[ModelWithUnionProperty], -) -> Optional[Union[None, HTTPValidationError]]: +) -> Optional[Union[HTTPValidationError, None]]: """ """ return ( diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/get_user_list.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/get_user_list.py index ec2216810..095f4be9d 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/get_user_list.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/get_user_list.py @@ -47,7 +47,7 @@ def _get_kwargs( } -def _parse_response(*, response: httpx.Response) -> Optional[Union[List[AModel], HTTPValidationError]]: +def _parse_response(*, response: httpx.Response) -> Optional[Union[HTTPValidationError, List[AModel]]]: if response.status_code == 200: response_200 = [] _response_200 = response.json() @@ -61,10 +61,14 @@ def _parse_response(*, response: httpx.Response) -> Optional[Union[List[AModel], response_422 = HTTPValidationError.from_dict(response.json()) return response_422 + if response.status_code == 423: + response_423 = HTTPValidationError.from_dict(response.json()) + + return response_423 return None -def _build_response(*, response: httpx.Response) -> Response[Union[List[AModel], HTTPValidationError]]: +def _build_response(*, response: httpx.Response) -> Response[Union[HTTPValidationError, List[AModel]]]: return Response( status_code=response.status_code, content=response.content, @@ -78,7 +82,7 @@ def sync_detailed( client: Client, an_enum_value: List[AnEnum], some_date: Union[datetime.date, datetime.datetime], -) -> Response[Union[List[AModel], HTTPValidationError]]: +) -> Response[Union[HTTPValidationError, List[AModel]]]: kwargs = _get_kwargs( client=client, an_enum_value=an_enum_value, @@ -97,7 +101,7 @@ def sync( client: Client, an_enum_value: List[AnEnum], some_date: Union[datetime.date, datetime.datetime], -) -> Optional[Union[List[AModel], HTTPValidationError]]: +) -> Optional[Union[HTTPValidationError, List[AModel]]]: """ Get a list of things """ return sync_detailed( @@ -112,7 +116,7 @@ async def asyncio_detailed( client: Client, an_enum_value: List[AnEnum], some_date: Union[datetime.date, datetime.datetime], -) -> Response[Union[List[AModel], HTTPValidationError]]: +) -> Response[Union[HTTPValidationError, List[AModel]]]: kwargs = _get_kwargs( client=client, an_enum_value=an_enum_value, @@ -130,7 +134,7 @@ async def asyncio( client: Client, an_enum_value: List[AnEnum], some_date: Union[datetime.date, datetime.datetime], -) -> Optional[Union[List[AModel], HTTPValidationError]]: +) -> Optional[Union[HTTPValidationError, List[AModel]]]: """ Get a list of things """ return ( diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/int_enum_tests_int_enum_post.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/int_enum_tests_int_enum_post.py index 7d14632c4..6691aecf0 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/int_enum_tests_int_enum_post.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/int_enum_tests_int_enum_post.py @@ -34,7 +34,7 @@ def _get_kwargs( } -def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPValidationError]]: +def _parse_response(*, response: httpx.Response) -> Optional[Union[HTTPValidationError, None]]: if response.status_code == 200: response_200 = None @@ -46,7 +46,7 @@ def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPVal return None -def _build_response(*, response: httpx.Response) -> Response[Union[None, HTTPValidationError]]: +def _build_response(*, response: httpx.Response) -> Response[Union[HTTPValidationError, None]]: return Response( status_code=response.status_code, content=response.content, @@ -59,7 +59,7 @@ def sync_detailed( *, client: Client, int_enum: AnIntEnum, -) -> Response[Union[None, HTTPValidationError]]: +) -> Response[Union[HTTPValidationError, None]]: kwargs = _get_kwargs( client=client, int_enum=int_enum, @@ -76,7 +76,7 @@ def sync( *, client: Client, int_enum: AnIntEnum, -) -> Optional[Union[None, HTTPValidationError]]: +) -> Optional[Union[HTTPValidationError, None]]: """ """ return sync_detailed( @@ -89,7 +89,7 @@ async def asyncio_detailed( *, client: Client, int_enum: AnIntEnum, -) -> Response[Union[None, HTTPValidationError]]: +) -> Response[Union[HTTPValidationError, None]]: kwargs = _get_kwargs( client=client, int_enum=int_enum, @@ -105,7 +105,7 @@ async def asyncio( *, client: Client, int_enum: AnIntEnum, -) -> Optional[Union[None, HTTPValidationError]]: +) -> Optional[Union[HTTPValidationError, None]]: """ """ return ( diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/json_body_tests_json_body_post.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/json_body_tests_json_body_post.py index 074ab9d89..2c21e806e 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/json_body_tests_json_body_post.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/json_body_tests_json_body_post.py @@ -29,7 +29,7 @@ def _get_kwargs( } -def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPValidationError]]: +def _parse_response(*, response: httpx.Response) -> Optional[Union[HTTPValidationError, None]]: if response.status_code == 200: response_200 = None @@ -41,7 +41,7 @@ def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPVal return None -def _build_response(*, response: httpx.Response) -> Response[Union[None, HTTPValidationError]]: +def _build_response(*, response: httpx.Response) -> Response[Union[HTTPValidationError, None]]: return Response( status_code=response.status_code, content=response.content, @@ -54,7 +54,7 @@ def sync_detailed( *, client: Client, json_body: AModel, -) -> Response[Union[None, HTTPValidationError]]: +) -> Response[Union[HTTPValidationError, None]]: kwargs = _get_kwargs( client=client, json_body=json_body, @@ -71,7 +71,7 @@ def sync( *, client: Client, json_body: AModel, -) -> Optional[Union[None, HTTPValidationError]]: +) -> Optional[Union[HTTPValidationError, None]]: """ Try sending a JSON body """ return sync_detailed( @@ -84,7 +84,7 @@ async def asyncio_detailed( *, client: Client, json_body: AModel, -) -> Response[Union[None, HTTPValidationError]]: +) -> Response[Union[HTTPValidationError, None]]: kwargs = _get_kwargs( client=client, json_body=json_body, @@ -100,7 +100,7 @@ async def asyncio( *, client: Client, json_body: AModel, -) -> Optional[Union[None, HTTPValidationError]]: +) -> Optional[Union[HTTPValidationError, None]]: """ Try sending a JSON body """ return ( diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/optional_value_tests_optional_query_param.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/optional_value_tests_optional_query_param.py index 64431ba2f..0b60a77a2 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/optional_value_tests_optional_query_param.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/optional_value_tests_optional_query_param.py @@ -35,7 +35,7 @@ def _get_kwargs( } -def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPValidationError]]: +def _parse_response(*, response: httpx.Response) -> Optional[Union[HTTPValidationError, None]]: if response.status_code == 200: response_200 = None @@ -47,7 +47,7 @@ def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPVal return None -def _build_response(*, response: httpx.Response) -> Response[Union[None, HTTPValidationError]]: +def _build_response(*, response: httpx.Response) -> Response[Union[HTTPValidationError, None]]: return Response( status_code=response.status_code, content=response.content, @@ -60,7 +60,7 @@ def sync_detailed( *, client: Client, query_param: Union[Unset, List[str]] = UNSET, -) -> Response[Union[None, HTTPValidationError]]: +) -> Response[Union[HTTPValidationError, None]]: kwargs = _get_kwargs( client=client, query_param=query_param, @@ -77,7 +77,7 @@ def sync( *, client: Client, query_param: Union[Unset, List[str]] = UNSET, -) -> Optional[Union[None, HTTPValidationError]]: +) -> Optional[Union[HTTPValidationError, None]]: """ Test optional query parameters """ return sync_detailed( @@ -90,7 +90,7 @@ async def asyncio_detailed( *, client: Client, query_param: Union[Unset, List[str]] = UNSET, -) -> Response[Union[None, HTTPValidationError]]: +) -> Response[Union[HTTPValidationError, None]]: kwargs = _get_kwargs( client=client, query_param=query_param, @@ -106,7 +106,7 @@ async def asyncio( *, client: Client, query_param: Union[Unset, List[str]] = UNSET, -) -> Optional[Union[None, HTTPValidationError]]: +) -> Optional[Union[HTTPValidationError, None]]: """ Test optional query parameters """ return ( diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/token_with_cookie_auth_token_with_cookie_get.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/token_with_cookie_auth_token_with_cookie_get.py index c497a0b3a..1ca44278f 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/token_with_cookie_auth_token_with_cookie_get.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/token_with_cookie_auth_token_with_cookie_get.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict import httpx @@ -26,24 +26,12 @@ def _get_kwargs( } -def _parse_response(*, response: httpx.Response) -> Optional[Union[None, None]]: - if response.status_code == 200: - response_200 = None - - return response_200 - if response.status_code == 401: - response_401 = None - - return response_401 - return None - - -def _build_response(*, response: httpx.Response) -> Response[Union[None, None]]: +def _build_response(*, response: httpx.Response) -> Response[None]: return Response( status_code=response.status_code, content=response.content, headers=response.headers, - parsed=_parse_response(response=response), + parsed=None, ) @@ -51,7 +39,7 @@ def sync_detailed( *, client: Client, my_token: str, -) -> Response[Union[None, None]]: +) -> Response[None]: kwargs = _get_kwargs( client=client, my_token=my_token, @@ -64,24 +52,11 @@ def sync_detailed( return _build_response(response=response) -def sync( - *, - client: Client, - my_token: str, -) -> Optional[Union[None, None]]: - """ Test optional cookie parameters """ - - return sync_detailed( - client=client, - my_token=my_token, - ).parsed - - async def asyncio_detailed( *, client: Client, my_token: str, -) -> Response[Union[None, None]]: +) -> Response[None]: kwargs = _get_kwargs( client=client, my_token=my_token, @@ -91,18 +66,3 @@ async def asyncio_detailed( response = await _client.get(**kwargs) return _build_response(response=response) - - -async def asyncio( - *, - client: Client, - my_token: str, -) -> Optional[Union[None, None]]: - """ Test optional cookie parameters """ - - return ( - await asyncio_detailed( - client=client, - my_token=my_token, - ) - ).parsed diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_file_tests_upload_post.py b/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_file_tests_upload_post.py index 705443d95..fb18e9f61 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_file_tests_upload_post.py +++ b/end_to_end_tests/golden-record/my_test_api_client/api/tests/upload_file_tests_upload_post.py @@ -40,7 +40,7 @@ def _get_kwargs( } -def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPValidationError]]: +def _parse_response(*, response: httpx.Response) -> Optional[Union[HTTPValidationError, None]]: if response.status_code == 200: response_200 = None @@ -52,7 +52,7 @@ def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPVal return None -def _build_response(*, response: httpx.Response) -> Response[Union[None, HTTPValidationError]]: +def _build_response(*, response: httpx.Response) -> Response[Union[HTTPValidationError, None]]: return Response( status_code=response.status_code, content=response.content, @@ -66,7 +66,7 @@ def sync_detailed( client: Client, multipart_data: BodyUploadFileTestsUploadPost, keep_alive: Union[Unset, bool] = UNSET, -) -> Response[Union[None, HTTPValidationError]]: +) -> Response[Union[HTTPValidationError, None]]: kwargs = _get_kwargs( client=client, multipart_data=multipart_data, @@ -85,7 +85,7 @@ def sync( client: Client, multipart_data: BodyUploadFileTestsUploadPost, keep_alive: Union[Unset, bool] = UNSET, -) -> Optional[Union[None, HTTPValidationError]]: +) -> Optional[Union[HTTPValidationError, None]]: """ Upload a file """ return sync_detailed( @@ -100,7 +100,7 @@ async def asyncio_detailed( client: Client, multipart_data: BodyUploadFileTestsUploadPost, keep_alive: Union[Unset, bool] = UNSET, -) -> Response[Union[None, HTTPValidationError]]: +) -> Response[Union[HTTPValidationError, None]]: kwargs = _get_kwargs( client=client, multipart_data=multipart_data, @@ -118,7 +118,7 @@ async def asyncio( client: Client, multipart_data: BodyUploadFileTestsUploadPost, keep_alive: Union[Unset, bool] = UNSET, -) -> Optional[Union[None, HTTPValidationError]]: +) -> Optional[Union[HTTPValidationError, None]]: """ Upload a file """ return ( diff --git a/end_to_end_tests/openapi.json b/end_to_end_tests/openapi.json index 359419473..ed0de8918 100644 --- a/end_to_end_tests/openapi.json +++ b/end_to_end_tests/openapi.json @@ -70,6 +70,16 @@ } } } + }, + "423": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } } } } diff --git a/openapi_python_client/parser/openapi.py b/openapi_python_client/parser/openapi.py index 2361c3c66..283dd1948 100644 --- a/openapi_python_client/parser/openapi.py +++ b/openapi_python_client/parser/openapi.py @@ -265,6 +265,15 @@ def from_data( return result, schemas + def response_type(self) -> str: + """ Get the Python type of any response from this endpoint """ + types = sorted({response.prop.get_type_string() for response in self.responses}) + if len(types) == 0: + return "None" + if len(types) == 1: + return self.responses[0].prop.get_type_string() + return f"Union[{', '.join(types)}]" + @dataclass class GeneratorData: diff --git a/openapi_python_client/templates/endpoint_macros.py.jinja b/openapi_python_client/templates/endpoint_macros.py.jinja index 705985aab..e822ff320 100644 --- a/openapi_python_client/templates/endpoint_macros.py.jinja +++ b/openapi_python_client/templates/endpoint_macros.py.jinja @@ -71,20 +71,6 @@ params = {k: v for k, v in params.items() if v is not UNSET and v is not None} {% endif %} {% endmacro %} -{% macro return_type(endpoint) %} -{% if endpoint.responses | length == 0 %} -None -{%- elif endpoint.responses | length == 1 %} -{{ endpoint.responses[0].prop.get_type_string() }} -{%- else %} -Union[ - {% for response in endpoint.responses %} - {{ response.prop.get_type_string() }}{{ "," if not loop.last }} - {% endfor %} -] -{%- endif %} -{% endmacro %} - {# The all the kwargs passed into an endpoint (and variants thereof)) #} {% macro arguments(endpoint) %} *, diff --git a/openapi_python_client/templates/endpoint_module.py.jinja b/openapi_python_client/templates/endpoint_module.py.jinja index 5b1c434bc..174d67b9d 100644 --- a/openapi_python_client/templates/endpoint_module.py.jinja +++ b/openapi_python_client/templates/endpoint_module.py.jinja @@ -10,9 +10,9 @@ from ...types import Response, UNSET{% if endpoint.multipart_body_reference %}, {{ relative }} {% endfor %} -{% from "endpoint_macros.py.jinja" import header_params, cookie_params, query_params, json_body, return_type, arguments, client, kwargs, parse_response %} +{% from "endpoint_macros.py.jinja" import header_params, cookie_params, query_params, json_body, arguments, client, kwargs, parse_response %} -{% set return_string = return_type(endpoint) %} +{% set return_string = endpoint.response_type() %} {% set parsed_responses = (endpoint.responses | length > 0) and return_string != "None" %} def _get_kwargs( diff --git a/tests/test_parser/test_openapi.py b/tests/test_parser/test_openapi.py index 1bbaa3e6e..5f0ecdc90 100644 --- a/tests/test_parser/test_openapi.py +++ b/tests/test_parser/test_openapi.py @@ -1,3 +1,7 @@ +from unittest.mock import MagicMock + +import pytest + import openapi_python_client.schema as oai from openapi_python_client import GeneratorError from openapi_python_client.parser.errors import ParseError @@ -114,6 +118,19 @@ def test_from_dict_invalid_version(self, mocker): class TestEndpoint: + def make_endpoint(self): + from openapi_python_client.parser.openapi import Endpoint + + return Endpoint( + path="path", + method="method", + description=None, + name="name", + requires_security=False, + tag="tag", + relative_imports={"import_3"}, + ) + def test_parse_request_form_body(self, mocker): ref = mocker.MagicMock() body = oai.RequestBody.construct( @@ -195,15 +212,7 @@ def test_add_body_no_data(self, mocker): from openapi_python_client.parser.openapi import Endpoint, Schemas parse_request_form_body = mocker.patch.object(Endpoint, "parse_request_form_body") - endpoint = Endpoint( - path="path", - method="method", - description=None, - name="name", - requires_security=False, - tag="tag", - relative_imports={"import_3"}, - ) + endpoint = self.make_endpoint() schemas = Schemas() Endpoint._add_body(endpoint=endpoint, data=oai.Operation.construct(), schemas=schemas) @@ -217,15 +226,7 @@ def test_add_body_bad_data(self, mocker): parse_error = ParseError(data=mocker.MagicMock()) other_schemas = mocker.MagicMock() mocker.patch.object(Endpoint, "parse_request_json_body", return_value=(parse_error, other_schemas)) - endpoint = Endpoint( - path="path", - method="method", - description=None, - name="name", - requires_security=False, - tag="tag", - relative_imports={"import_3"}, - ) + endpoint = self.make_endpoint() request_body = mocker.MagicMock() schemas = Schemas() @@ -263,15 +264,7 @@ def test_add_body_happy(self, mocker): f"{MODULE_NAME}.import_string_from_reference", side_effect=["import_1", "import_2"] ) - endpoint = Endpoint( - path="path", - method="method", - description=None, - name="name", - requires_security=False, - tag="tag", - relative_imports={"import_3"}, - ) + endpoint = self.make_endpoint() initial_schemas = mocker.MagicMock() (endpoint, response_schemas) = Endpoint._add_body( @@ -303,15 +296,7 @@ def test__add_responses_status_code_error(self, mocker): data = { "not_a_number": response_1_data, } - endpoint = Endpoint( - path="path", - method="method", - description=None, - name="name", - requires_security=False, - tag="tag", - relative_imports={"import_3"}, - ) + endpoint = self.make_endpoint() parse_error = ParseError(data=mocker.MagicMock()) response_from_data = mocker.patch(f"{MODULE_NAME}.response_from_data", return_value=(parse_error, schemas)) @@ -333,15 +318,7 @@ def test__add_responses_error(self, mocker): "200": response_1_data, "404": response_2_data, } - endpoint = Endpoint( - path="path", - method="method", - description=None, - name="name", - requires_security=False, - tag="tag", - relative_imports={"import_3"}, - ) + endpoint = self.make_endpoint() parse_error = ParseError(data=mocker.MagicMock()) response_from_data = mocker.patch(f"{MODULE_NAME}.response_from_data", return_value=(parse_error, schemas)) @@ -374,15 +351,7 @@ def test__add_responses(self, mocker): "200": response_1_data, "404": response_2_data, } - endpoint = Endpoint( - path="path", - method="method", - description=None, - name="name", - requires_security=False, - tag="tag", - relative_imports={"import_3"}, - ) + endpoint = self.make_endpoint() schemas = mocker.MagicMock() schemas_1 = mocker.MagicMock() schemas_2 = mocker.MagicMock() @@ -420,14 +389,7 @@ def test__add_responses(self, mocker): def test__add_parameters_handles_no_params(self): from openapi_python_client.parser.openapi import Endpoint, Schemas - endpoint = Endpoint( - path="path", - method="method", - description=None, - name="name", - requires_security=False, - tag="tag", - ) + endpoint = self.make_endpoint() schemas = Schemas() # Just checking there's no exception here assert Endpoint._add_parameters(endpoint=endpoint, data=oai.Operation.construct(), schemas=schemas) == ( @@ -438,14 +400,7 @@ def test__add_parameters_handles_no_params(self): def test__add_parameters_parse_error(self, mocker): from openapi_python_client.parser.openapi import Endpoint - endpoint = Endpoint( - path="path", - method="method", - description=None, - name="name", - requires_security=False, - tag="tag", - ) + endpoint = self.make_endpoint() initial_schemas = mocker.MagicMock() parse_error = ParseError(data=mocker.MagicMock()) property_schemas = mocker.MagicMock() @@ -463,14 +418,7 @@ def test__add_parameters_parse_error(self, mocker): def test__add_parameters_fail_loudly_when_location_not_supported(self, mocker): from openapi_python_client.parser.openapi import Endpoint, Schemas - endpoint = Endpoint( - path="path", - method="method", - description=None, - name="name", - requires_security=False, - tag="tag", - ) + endpoint = self.make_endpoint() parsed_schemas = mocker.MagicMock() mocker.patch(f"{MODULE_NAME}.property_from_data", return_value=(mocker.MagicMock(), parsed_schemas)) param = oai.Parameter.construct( @@ -487,15 +435,7 @@ def test__add_parameters_happy(self, mocker): from openapi_python_client.parser.openapi import Endpoint from openapi_python_client.parser.properties import Property - endpoint = Endpoint( - path="path", - method="method", - description=None, - name="name", - requires_security=False, - tag="tag", - relative_imports={"import_3"}, - ) + endpoint = self.make_endpoint() path_prop = mocker.MagicMock(autospec=Property) path_prop_import = mocker.MagicMock() path_prop.get_imports = mocker.MagicMock(return_value={path_prop_import}) @@ -731,6 +671,19 @@ def test_from_data_no_security(self, mocker): endpoint=_add_responses.return_value[0], data=data, schemas=_add_responses.return_value[1] ) + @pytest.mark.parametrize( + "response_types, expected", + (([], "None"), (["Something"], "Something"), (["First", "Second", "Second"], "Union[First, Second]")), + ) + def test_response_type(self, response_types, expected): + endpoint = self.make_endpoint() + for response_type in response_types: + mock_response = MagicMock() + mock_response.prop.get_type_string.return_value = response_type + endpoint.responses.append(mock_response) + + assert endpoint.response_type() == expected + class TestImportStringFromReference: def test_import_string_from_reference_no_prefix(self, mocker):