Skip to content

Commit 9556ec9

Browse files
committed
fix: Properly replace reserved words in class and module names [#475, #476]. Thanks @mtovts!
1 parent faef048 commit 9556ec9

File tree

11 files changed

+230
-12
lines changed

11 files changed

+230
-12
lines changed

end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .default import DefaultEndpoints
66
from .location import LocationEndpoints
7+
from .naming import NamingEndpoints
78
from .parameters import ParametersEndpoints
89
from .tag1 import Tag1Endpoints
910
from .tests import TestsEndpoints
@@ -29,3 +30,7 @@ def tag1(cls) -> Type[Tag1Endpoints]:
2930
@classmethod
3031
def location(cls) -> Type[LocationEndpoints]:
3132
return LocationEndpoints
33+
34+
@classmethod
35+
def naming(cls) -> Type[NamingEndpoints]:
36+
return NamingEndpoints
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
""" Contains methods for accessing the API Endpoints """
2+
3+
import types
4+
5+
from . import get_naming_keywords
6+
7+
8+
class NamingEndpoints:
9+
@classmethod
10+
def get_naming_keywords(cls) -> types.ModuleType:
11+
return get_naming_keywords

end_to_end_tests/golden-record/my_test_api_client/api/naming/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from typing import Any, Dict
2+
3+
import httpx
4+
5+
from ...client import Client
6+
from ...types import UNSET, Response
7+
8+
9+
def _get_kwargs(
10+
*,
11+
client: Client,
12+
import_: str,
13+
) -> Dict[str, Any]:
14+
url = "{}/naming/keywords".format(client.base_url)
15+
16+
headers: Dict[str, Any] = client.get_headers()
17+
cookies: Dict[str, Any] = client.get_cookies()
18+
19+
params: Dict[str, Any] = {
20+
"import": import_,
21+
}
22+
params = {k: v for k, v in params.items() if v is not UNSET and v is not None}
23+
24+
return {
25+
"url": url,
26+
"headers": headers,
27+
"cookies": cookies,
28+
"timeout": client.get_timeout(),
29+
"params": params,
30+
}
31+
32+
33+
def _build_response(*, response: httpx.Response) -> Response[Any]:
34+
return Response(
35+
status_code=response.status_code,
36+
content=response.content,
37+
headers=response.headers,
38+
parsed=None,
39+
)
40+
41+
42+
def sync_detailed(
43+
*,
44+
client: Client,
45+
import_: str,
46+
) -> Response[Any]:
47+
kwargs = _get_kwargs(
48+
client=client,
49+
import_=import_,
50+
)
51+
52+
response = httpx.get(
53+
**kwargs,
54+
)
55+
56+
return _build_response(response=response)
57+
58+
59+
async def asyncio_detailed(
60+
*,
61+
client: Client,
62+
import_: str,
63+
) -> Response[Any]:
64+
kwargs = _get_kwargs(
65+
client=client,
66+
import_=import_,
67+
)
68+
69+
async with httpx.AsyncClient() as _client:
70+
response = await _client.get(**kwargs)
71+
72+
return _build_response(response=response)

end_to_end_tests/golden-record/my_test_api_client/models/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .different_enum import DifferentEnum
2020
from .free_form_model import FreeFormModel
2121
from .http_validation_error import HTTPValidationError
22+
from .import_ import Import
2223
from .model_from_all_of import ModelFromAllOf
2324
from .model_name import ModelName
2425
from .model_with_additional_properties_inlined import ModelWithAdditionalPropertiesInlined
@@ -35,6 +36,7 @@
3536
from .model_with_union_property_inlined import ModelWithUnionPropertyInlined
3637
from .model_with_union_property_inlined_fruit_type_0 import ModelWithUnionPropertyInlinedFruitType0
3738
from .model_with_union_property_inlined_fruit_type_1 import ModelWithUnionPropertyInlinedFruitType1
39+
from .none import None_
3840
from .test_inline_objects_json_body import TestInlineObjectsJsonBody
3941
from .test_inline_objects_response_200 import TestInlineObjectsResponse200
4042
from .validation_error import ValidationError
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from typing import Any, Dict, List, Type, TypeVar
2+
3+
import attr
4+
5+
T = TypeVar("T", bound="Import")
6+
7+
8+
@attr.s(auto_attribs=True)
9+
class Import:
10+
""" """
11+
12+
additional_properties: Dict[str, Any] = attr.ib(init=False, factory=dict)
13+
14+
def to_dict(self) -> Dict[str, Any]:
15+
16+
field_dict: Dict[str, Any] = {}
17+
field_dict.update(self.additional_properties)
18+
field_dict.update({})
19+
20+
return field_dict
21+
22+
@classmethod
23+
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
24+
d = src_dict.copy()
25+
import_ = cls()
26+
27+
import_.additional_properties = d
28+
return import_
29+
30+
@property
31+
def additional_keys(self) -> List[str]:
32+
return list(self.additional_properties.keys())
33+
34+
def __getitem__(self, key: str) -> Any:
35+
return self.additional_properties[key]
36+
37+
def __setitem__(self, key: str, value: Any) -> None:
38+
self.additional_properties[key] = value
39+
40+
def __delitem__(self, key: str) -> None:
41+
del self.additional_properties[key]
42+
43+
def __contains__(self, key: str) -> bool:
44+
return key in self.additional_properties
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from typing import Any, Dict, List, Type, TypeVar
2+
3+
import attr
4+
5+
T = TypeVar("T", bound="None_")
6+
7+
8+
@attr.s(auto_attribs=True)
9+
class None_:
10+
""" """
11+
12+
additional_properties: Dict[str, Any] = attr.ib(init=False, factory=dict)
13+
14+
def to_dict(self) -> Dict[str, Any]:
15+
16+
field_dict: Dict[str, Any] = {}
17+
field_dict.update(self.additional_properties)
18+
field_dict.update({})
19+
20+
return field_dict
21+
22+
@classmethod
23+
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
24+
d = src_dict.copy()
25+
none = cls()
26+
27+
none.additional_properties = d
28+
return none
29+
30+
@property
31+
def additional_keys(self) -> List[str]:
32+
return list(self.additional_properties.keys())
33+
34+
def __getitem__(self, key: str) -> Any:
35+
return self.additional_properties[key]
36+
37+
def __setitem__(self, key: str, value: Any) -> None:
38+
self.additional_properties[key] = value
39+
40+
def __delitem__(self, key: str) -> None:
41+
del self.additional_properties[key]
42+
43+
def __contains__(self, key: str) -> bool:
44+
return key in self.additional_properties

end_to_end_tests/openapi.json

+26
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,26 @@
933933
],
934934
"responses": {}
935935
}
936+
},
937+
"/naming/keywords": {
938+
"description": "Ensure that Python keywords are renamed properly.",
939+
"get": {
940+
"tags": [
941+
"naming"
942+
],
943+
"parameters": [
944+
{
945+
"name": "import",
946+
"required": true,
947+
"schema": {
948+
"type": "string",
949+
"nullable": false
950+
},
951+
"in": "query"
952+
}
953+
],
954+
"responses": {}
955+
}
936956
}
937957
},
938958
"components": {
@@ -1761,6 +1781,12 @@
17611781
"AByteStream": {
17621782
"type": "string",
17631783
"format": "byte"
1784+
},
1785+
"import": {
1786+
"type": "object"
1787+
},
1788+
"None": {
1789+
"type": "object"
17641790
}
17651791
}
17661792
}

openapi_python_client/parser/properties/schemas.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from ... import Config
99
from ... import schema as oai
10-
from ... import utils
10+
from ...utils import ClassName, PythonIdentifier
1111
from ..errors import ParseError, PropertyError
1212

1313
if TYPE_CHECKING: # pragma: no cover
@@ -17,7 +17,6 @@
1717

1818

1919
_ReferencePath = NewType("_ReferencePath", str)
20-
_ClassName = NewType("_ClassName", str)
2120

2221

2322
def parse_reference_path(ref_path_raw: str) -> Union[_ReferencePath, ParseError]:
@@ -38,33 +37,34 @@ def parse_reference_path(ref_path_raw: str) -> Union[_ReferencePath, ParseError]
3837
class Class:
3938
"""Represents Python class which will be generated from an OpenAPI schema"""
4039

41-
name: _ClassName
42-
module_name: str
40+
name: ClassName
41+
module_name: PythonIdentifier
4342

4443
@staticmethod
4544
def from_string(*, string: str, config: Config) -> "Class":
4645
"""Get a Class from an arbitrary string"""
4746
class_name = string.split("/")[-1] # Get rid of ref path stuff
48-
class_name = utils.pascal_case(class_name)
47+
class_name = ClassName(class_name, config.field_prefix)
4948
override = config.class_overrides.get(class_name)
5049

5150
if override is not None and override.class_name is not None:
52-
class_name = override.class_name
51+
class_name = ClassName(override.class_name, config.field_prefix)
5352

5453
if override is not None and override.module_name is not None:
5554
module_name = override.module_name
5655
else:
57-
module_name = utils.snake_case(class_name)
56+
module_name = class_name
57+
module_name = PythonIdentifier(module_name, config.field_prefix)
5858

59-
return Class(name=cast(_ClassName, class_name), module_name=module_name)
59+
return Class(name=class_name, module_name=module_name)
6060

6161

6262
@attr.s(auto_attribs=True, frozen=True)
6363
class Schemas:
6464
"""Structure for containing all defined, shareable, and reusable schemas (attr classes and Enums)"""
6565

6666
classes_by_reference: Dict[_ReferencePath, Property] = attr.ib(factory=dict)
67-
classes_by_name: Dict[_ClassName, Property] = attr.ib(factory=dict)
67+
classes_by_name: Dict[ClassName, Property] = attr.ib(factory=dict)
6868
errors: List[ParseError] = attr.ib(factory=list)
6969

7070

openapi_python_client/utils.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
class PythonIdentifier(str):
10-
"""A string which has been validated / transformed into a valid identifier for Python"""
10+
"""A snake_case string which has been validated / transformed into a valid identifier for Python"""
1111

1212
def __new__(cls, value: str, prefix: str) -> "PythonIdentifier":
1313
new_value = fix_reserved_words(snake_case(sanitize(value)))
@@ -20,6 +20,20 @@ def __deepcopy__(self, _: Any) -> "PythonIdentifier":
2020
return self
2121

2222

23+
class ClassName(str):
24+
"""A PascalCase string which has been validated / transformed into a valid class name for Python"""
25+
26+
def __new__(cls, value: str, prefix: str) -> "ClassName":
27+
new_value = fix_reserved_words(pascal_case(sanitize(value)))
28+
29+
if not new_value.isidentifier():
30+
new_value = f"{prefix}{new_value}"
31+
return str.__new__(cls, new_value)
32+
33+
def __deepcopy__(self, _: Any) -> "ClassName":
34+
return self
35+
36+
2337
def sanitize(value: str) -> str:
2438
"""Removes every character that isn't 0-9, A-Z, a-z, or a known delimiter"""
2539
return re.sub(rf"[^\w{DELIMITERS}]+", "", value)

tests/test_parser/test_properties/test_model_property.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_additional_schemas(self, additional_properties_schema, expected_additio
7373
)
7474

7575
model, _ = build_model_property(
76-
data=data, name="prop", schemas=Schemas(), required=True, parent_name="parent", config=MagicMock()
76+
data=data, name="prop", schemas=Schemas(), required=True, parent_name="parent", config=Config()
7777
)
7878

7979
assert model.additional_properties == expected_additional_properties
@@ -151,7 +151,7 @@ def test_bad_props_return_error(self):
151151
schemas = Schemas()
152152

153153
err, new_schemas = build_model_property(
154-
data=data, name="prop", schemas=schemas, required=True, parent_name=None, config=MagicMock()
154+
data=data, name="prop", schemas=schemas, required=True, parent_name=None, config=Config()
155155
)
156156

157157
assert new_schemas == schemas

0 commit comments

Comments
 (0)