Skip to content

fix: Properly replace reserved words in class and module names [#475, #476]. Thanks @mtovts! #477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .parameters import ParametersEndpoints
from .tag1 import Tag1Endpoints
from .tests import TestsEndpoints
from .true_ import True_Endpoints


class MyTestApiClientApi:
Expand All @@ -29,3 +30,7 @@ def tag1(cls) -> Type[Tag1Endpoints]:
@classmethod
def location(cls) -> Type[LocationEndpoints]:
return LocationEndpoints

@classmethod
def true_(cls) -> Type[True_Endpoints]:
return True_Endpoints
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
""" Contains methods for accessing the API Endpoints """

import types

from . import false_


class True_Endpoints:
@classmethod
def false_(cls) -> types.ModuleType:
return false_
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Any, Dict

import httpx

from ...client import Client
from ...types import UNSET, Response


def _get_kwargs(
*,
client: Client,
import_: str,
) -> Dict[str, Any]:
url = "{}/naming/keywords".format(client.base_url)

headers: Dict[str, Any] = client.get_headers()
cookies: Dict[str, Any] = client.get_cookies()

params: Dict[str, Any] = {
"import": import_,
}
params = {k: v for k, v in params.items() if v is not UNSET and v is not None}

return {
"url": url,
"headers": headers,
"cookies": cookies,
"timeout": client.get_timeout(),
"params": params,
}


def _build_response(*, response: httpx.Response) -> Response[Any]:
return Response(
status_code=response.status_code,
content=response.content,
headers=response.headers,
parsed=None,
)


def sync_detailed(
*,
client: Client,
import_: str,
) -> Response[Any]:
kwargs = _get_kwargs(
client=client,
import_=import_,
)

response = httpx.get(
**kwargs,
)

return _build_response(response=response)


async def asyncio_detailed(
*,
client: Client,
import_: str,
) -> Response[Any]:
kwargs = _get_kwargs(
client=client,
import_=import_,
)

async with httpx.AsyncClient() as _client:
response = await _client.get(**kwargs)

return _build_response(response=response)
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .different_enum import DifferentEnum
from .free_form_model import FreeFormModel
from .http_validation_error import HTTPValidationError
from .import_ import Import
from .model_from_all_of import ModelFromAllOf
from .model_name import ModelName
from .model_with_additional_properties_inlined import ModelWithAdditionalPropertiesInlined
Expand All @@ -35,6 +36,7 @@
from .model_with_union_property_inlined import ModelWithUnionPropertyInlined
from .model_with_union_property_inlined_fruit_type_0 import ModelWithUnionPropertyInlinedFruitType0
from .model_with_union_property_inlined_fruit_type_1 import ModelWithUnionPropertyInlinedFruitType1
from .none import None_
from .test_inline_objects_json_body import TestInlineObjectsJsonBody
from .test_inline_objects_response_200 import TestInlineObjectsResponse200
from .validation_error import ValidationError
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Any, Dict, List, Type, TypeVar

import attr

T = TypeVar("T", bound="Import")


@attr.s(auto_attribs=True)
class Import:
""" """

additional_properties: Dict[str, Any] = attr.ib(init=False, factory=dict)

def to_dict(self) -> Dict[str, Any]:

field_dict: Dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})

return field_dict

@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
d = src_dict.copy()
import_ = cls()

import_.additional_properties = d
return import_

@property
def additional_keys(self) -> List[str]:
return list(self.additional_properties.keys())

def __getitem__(self, key: str) -> Any:
return self.additional_properties[key]

def __setitem__(self, key: str, value: Any) -> None:
self.additional_properties[key] = value

def __delitem__(self, key: str) -> None:
del self.additional_properties[key]

def __contains__(self, key: str) -> bool:
return key in self.additional_properties
44 changes: 44 additions & 0 deletions end_to_end_tests/golden-record/my_test_api_client/models/none.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Any, Dict, List, Type, TypeVar

import attr

T = TypeVar("T", bound="None_")


@attr.s(auto_attribs=True)
class None_:
""" """

additional_properties: Dict[str, Any] = attr.ib(init=False, factory=dict)

def to_dict(self) -> Dict[str, Any]:

field_dict: Dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})

return field_dict

@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
d = src_dict.copy()
none = cls()

none.additional_properties = d
return none

@property
def additional_keys(self) -> List[str]:
return list(self.additional_properties.keys())

def __getitem__(self, key: str) -> Any:
return self.additional_properties[key]

def __setitem__(self, key: str, value: Any) -> None:
self.additional_properties[key] = value

def __delitem__(self, key: str) -> None:
del self.additional_properties[key]

def __contains__(self, key: str) -> bool:
return key in self.additional_properties
27 changes: 27 additions & 0 deletions end_to_end_tests/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,27 @@
],
"responses": {}
}
},
"/naming/keywords": {
"description": "Ensure that Python keywords are renamed properly.",
"get": {
"tags": [
"true"
],
"operationId": "false",
"parameters": [
{
"name": "import",
"required": true,
"schema": {
"type": "string",
"nullable": false
},
"in": "query"
}
],
"responses": {}
}
}
},
"components": {
Expand Down Expand Up @@ -1761,6 +1782,12 @@
"AByteStream": {
"type": "string",
"format": "byte"
},
"import": {
"type": "object"
},
"None": {
"type": "object"
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions end_to_end_tests/test_custom_templates/api_init.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from typing import Type
{% for tag in endpoint_collections_by_tag.keys() %}
from .{{ tag }} import {{ utils.pascal_case(tag) }}Endpoints
from .{{ tag }} import {{ class_name(tag) }}Endpoints
{% endfor %}

class {{ utils.pascal_case(package_name) }}Api:
class {{ class_name(package_name) }}Api:
{% for tag in endpoint_collections_by_tag.keys() %}
@classmethod
def {{ tag }}(cls) -> Type[{{ utils.pascal_case(tag) }}Endpoints]:
return {{ utils.pascal_case(tag) }}Endpoints
def {{ tag }}(cls) -> Type[{{ class_name(tag) }}Endpoints]:
return {{ class_name(tag) }}Endpoints
{% endfor %}
8 changes: 4 additions & 4 deletions end_to_end_tests/test_custom_templates/endpoint_init.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import types
{% for endpoint in endpoint_collection.endpoints %}
from . import {{ utils.snake_case(endpoint.name) }}
from . import {{ python_identifier(endpoint.name) }}
{% endfor %}

class {{ utils.pascal_case(endpoint_collection.tag) }}Endpoints:
class {{ class_name(endpoint_collection.tag) }}Endpoints:

{% for endpoint in endpoint_collection.endpoints %}

@classmethod
def {{ utils.snake_case(endpoint.name) }}(cls) -> types.ModuleType:
def {{ python_identifier(endpoint.name) }}(cls) -> types.ModuleType:
{% if endpoint.description %}
"""
{{ endpoint.description }}
Expand All @@ -20,5 +20,5 @@ class {{ utils.pascal_case(endpoint_collection.tag) }}Endpoints:
{{ endpoint.summary }}
"""
{% endif %}
return {{ utils.snake_case(endpoint.name) }}
return {{ python_identifier(endpoint.name) }}
{% endfor %}
6 changes: 4 additions & 2 deletions openapi_python_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from .config import Config
from .parser import GeneratorData, import_string_from_class
from .parser.errors import GeneratorError
from .utils import snake_case

if sys.version_info.minor < 8: # version did not exist before 3.8, need to use a backport
from importlib_metadata import version
Expand Down Expand Up @@ -58,6 +57,7 @@ def __init__(
self.openapi: GeneratorData = openapi
self.meta: MetaType = meta
self.file_encoding = file_encoding
self.config = config

package_loader = PackageLoader(__package__)
loader: BaseLoader
Expand Down Expand Up @@ -87,6 +87,8 @@ def __init__(
self.env.filters.update(TEMPLATE_FILTERS)
self.env.globals.update(
utils=utils,
python_identifier=lambda x: utils.PythonIdentifier(x, config.field_prefix),
class_name=lambda x: utils.ClassName(x, config.field_prefix),
package_name=self.package_name,
package_dir=self.package_dir,
package_description=self.package_description,
Expand Down Expand Up @@ -267,7 +269,7 @@ def _build_api(self) -> None:
)

for endpoint in collection.endpoints:
module_path = tag_dir / f"{snake_case(endpoint.name)}.py"
module_path = tag_dir / f"{utils.PythonIdentifier(endpoint.name, self.config.field_prefix)}.py"
module_path.write_text(endpoint_template.render(endpoint=endpoint), encoding=self.file_encoding)


Expand Down
18 changes: 9 additions & 9 deletions openapi_python_client/parser/properties/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ... import Config
from ... import schema as oai
from ... import utils
from ...utils import ClassName, PythonIdentifier
from ..errors import ParseError, PropertyError

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -17,7 +17,6 @@


_ReferencePath = NewType("_ReferencePath", str)
_ClassName = NewType("_ClassName", str)


def parse_reference_path(ref_path_raw: str) -> Union[_ReferencePath, ParseError]:
Expand All @@ -38,33 +37,34 @@ def parse_reference_path(ref_path_raw: str) -> Union[_ReferencePath, ParseError]
class Class:
"""Represents Python class which will be generated from an OpenAPI schema"""

name: _ClassName
module_name: str
name: ClassName
module_name: PythonIdentifier

@staticmethod
def from_string(*, string: str, config: Config) -> "Class":
"""Get a Class from an arbitrary string"""
class_name = string.split("/")[-1] # Get rid of ref path stuff
class_name = utils.pascal_case(class_name)
class_name = ClassName(class_name, config.field_prefix)
override = config.class_overrides.get(class_name)

if override is not None and override.class_name is not None:
class_name = override.class_name
class_name = ClassName(override.class_name, config.field_prefix)

if override is not None and override.module_name is not None:
module_name = override.module_name
else:
module_name = utils.snake_case(class_name)
module_name = class_name
module_name = PythonIdentifier(module_name, config.field_prefix)

return Class(name=cast(_ClassName, class_name), module_name=module_name)
return Class(name=class_name, module_name=module_name)


@attr.s(auto_attribs=True, frozen=True)
class Schemas:
"""Structure for containing all defined, shareable, and reusable schemas (attr classes and Enums)"""

classes_by_reference: Dict[_ReferencePath, Property] = attr.ib(factory=dict)
classes_by_name: Dict[_ClassName, Property] = attr.ib(factory=dict)
classes_by_name: Dict[ClassName, Property] = attr.ib(factory=dict)
errors: List[ParseError] = attr.ib(factory=list)


Expand Down
Loading