Skip to content

Commit e461977

Browse files
committed
Add support for query parameters
Closes #5
1 parent 927c399 commit e461977

File tree

9 files changed

+113
-98
lines changed

9 files changed

+113
-98
lines changed

openapi_python_client/__init__.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55

66
import orjson
77
import requests
8-
import stringcase
98
from jinja2 import Environment, PackageLoader
109

11-
from .models import OpenAPI
10+
from .models import OpenAPI, import_string_from_reference
1211

1312

1413
def main():
@@ -57,17 +56,23 @@ def _build_project(openapi: OpenAPI):
5756
models_dir = package_dir / "models"
5857
models_dir.mkdir()
5958
models_init = models_dir / "__init__.py"
60-
models_init.write_text('""" Contains all the data models used in inputs/outputs """')
59+
imports = []
60+
6161
model_template = env.get_template("model.pyi")
6262
for schema in openapi.schemas.values():
63-
module_path = models_dir / f"{stringcase.snakecase(schema.title)}.py"
63+
module_path = models_dir / f"{schema.reference.module_name}.py"
6464
module_path.write_text(model_template.render(schema=schema))
65+
imports.append(import_string_from_reference(schema.reference))
6566

6667
# Generate enums
6768
enum_template = env.get_template("enum.pyi")
6869
for enum in openapi.enums.values():
6970
module_path = models_dir / f"{enum.name}.py"
7071
module_path.write_text(enum_template.render(enum=enum))
72+
imports.append(import_string_from_reference(enum.reference))
73+
74+
models_init_template = env.get_template("models_init.pyi")
75+
models_init.write_text(models_init_template.render(imports=imports))
7176

7277
# Generate Client
7378
client_path = package_dir / "client.py"
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
""" Classes representing the data in the OpenAPI schema """
22

3-
from .api_info import APIInfo
4-
from .openapi import OpenAPI
3+
from .openapi import OpenAPI, import_string_from_reference

openapi_python_client/models/api_info.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

openapi_python_client/models/openapi.py

Lines changed: 60 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,22 @@
22

33
from dataclasses import dataclass, field
44
from enum import Enum
5-
from typing import Dict, List, Optional, Set
6-
7-
import stringcase
5+
from typing import Dict, List, Optional, Set, Iterable, Generator
86

97
from .properties import Property, property_from_dict, ListProperty, RefProperty, EnumProperty
10-
from .responses import Response, response_from_dict
118
from .reference import Reference
9+
from .responses import Response, response_from_dict
1210

1311

14-
class ParameterLocation(Enum):
12+
class ParameterLocation(str, Enum):
1513
""" The places Parameters can be put when calling an Endpoint """
1614

1715
QUERY = "query"
1816
PATH = "path"
1917

2018

21-
@dataclass
22-
class Parameter:
23-
""" A parameter in an Endpoint """
24-
25-
location: ParameterLocation
26-
property: Property
27-
28-
@staticmethod
29-
def from_dict(d: Dict, /) -> Parameter:
30-
""" Construct a parameter from it's OpenAPI dict form """
31-
return Parameter(
32-
location=ParameterLocation(d["in"]),
33-
property=property_from_dict(name=d["name"], required=d["required"], data=d["schema"]),
34-
)
35-
36-
37-
def _import_string_from_reference(reference: Reference, prefix: str = "") -> str:
19+
def import_string_from_reference(reference: Reference, prefix: str = "") -> str:
20+
""" Create a string which is used to import a reference """
3821
return f"from {prefix}.{reference.module_name} import {reference.class_name}"
3922

4023

@@ -52,11 +35,24 @@ def from_dict(d: Dict[str, Dict[str, Dict]], /) -> Dict[str, EndpointCollection]
5235
endpoints_by_tag: Dict[str, EndpointCollection] = {}
5336
for path, path_data in d.items():
5437
for method, method_data in path_data.items():
55-
parameters: List[Parameter] = []
38+
query_parameters: List[Property] = []
39+
path_parameters: List[Property] = []
5640
responses: List[Response] = []
57-
for param_dict in method_data.get("parameters", []):
58-
parameters.append(Parameter.from_dict(param_dict))
5941
tag = method_data.get("tags", ["default"])[0]
42+
collection = endpoints_by_tag.setdefault(tag, EndpointCollection(tag=tag))
43+
for param_dict in method_data.get("parameters", []):
44+
prop = property_from_dict(
45+
name=param_dict["name"], required=param_dict["required"], data=param_dict["schema"]
46+
)
47+
if isinstance(prop, (ListProperty, RefProperty, EnumProperty)) and prop.reference:
48+
collection.relative_imports.add(import_string_from_reference(prop.reference, prefix="..models"))
49+
if param_dict["in"] == ParameterLocation.QUERY:
50+
query_parameters.append(prop)
51+
elif param_dict["in"] == ParameterLocation.PATH:
52+
path_parameters.append(prop)
53+
else:
54+
raise ValueError(f"Don't know where to put this parameter: {param_dict}")
55+
6056
for code, response_dict in method_data["responses"].items():
6157
response = response_from_dict(status_code=int(code), data=response_dict)
6258
responses.append(response)
@@ -69,17 +65,17 @@ def from_dict(d: Dict[str, Dict[str, Dict]], /) -> Dict[str, EndpointCollection]
6965
method=method,
7066
description=method_data.get("description"),
7167
name=method_data["operationId"],
72-
parameters=parameters,
68+
query_parameters=query_parameters,
69+
path_parameters=path_parameters,
7370
responses=responses,
7471
form_body_reference=form_body_reference,
7572
requires_security=method_data.get("security"),
7673
)
7774

78-
collection = endpoints_by_tag.setdefault(tag, EndpointCollection(tag=tag))
7975
collection.endpoints.append(endpoint)
8076
if form_body_reference:
8177
collection.relative_imports.add(
82-
_import_string_from_reference(form_body_reference, prefix="..models")
78+
import_string_from_reference(form_body_reference, prefix="..models")
8379
)
8480
return endpoints_by_tag
8581

@@ -94,7 +90,8 @@ class Endpoint:
9490
method: str
9591
description: Optional[str]
9692
name: str
97-
parameters: List[Parameter]
93+
query_parameters: List[Property]
94+
path_parameters: List[Property]
9895
responses: List[Response]
9996
requires_security: bool
10097
form_body_reference: Optional[Reference]
@@ -118,7 +115,7 @@ class Schema:
118115
These will all be converted to dataclasses in the client
119116
"""
120117

121-
title: str
118+
reference: Reference
122119
required_properties: List[Property]
123120
optional_properties: List[Property]
124121
description: str
@@ -139,10 +136,10 @@ def from_dict(d: Dict, /) -> Schema:
139136
required_properties.append(p)
140137
else:
141138
optional_properties.append(p)
142-
if isinstance(p, (ListProperty, RefProperty)) and p.reference:
143-
relative_imports.add(_import_string_from_reference(p.reference))
139+
if isinstance(p, (ListProperty, RefProperty, EnumProperty)) and p.reference:
140+
relative_imports.add(import_string_from_reference(p.reference))
144141
schema = Schema(
145-
title=stringcase.pascalcase(d["title"]),
142+
reference=Reference(d["title"]),
146143
required_properties=required_properties,
147144
optional_properties=optional_properties,
148145
relative_imports=relative_imports,
@@ -156,7 +153,7 @@ def dict(d: Dict, /) -> Dict[str, Schema]:
156153
result = {}
157154
for data in d.values():
158155
s = Schema.from_dict(data)
159-
result[s.title] = s
156+
result[s.reference.class_name] = s
160157
return result
161158

162159

@@ -172,29 +169,44 @@ class OpenAPI:
172169
endpoint_collections_by_tag: Dict[str, EndpointCollection]
173170
enums: Dict[str, EnumProperty]
174171

172+
@staticmethod
173+
def check_enums(schemas: Iterable[Schema], collections: Iterable[EndpointCollection]) -> Dict[str, EnumProperty]:
174+
enums: Dict[str, EnumProperty] = {}
175+
176+
def _iterate_properties() -> Generator[Property]:
177+
for schema in schemas:
178+
yield from schema.required_properties
179+
yield from schema.optional_properties
180+
for collection in collections:
181+
for endpoint in collection.endpoints:
182+
yield from endpoint.path_parameters
183+
yield from endpoint.query_parameters
184+
185+
for prop in _iterate_properties():
186+
if not isinstance(prop, EnumProperty):
187+
continue
188+
189+
if prop.reference.class_name in enums:
190+
# We already have an enum with this name, make sure the values match
191+
assert (
192+
prop.values == enums[prop.reference.class_name].values
193+
), f"Encountered conflicting enum named {prop.reference.class_name}"
194+
195+
enums[prop.reference.class_name] = prop
196+
return enums
197+
175198
@staticmethod
176199
def from_dict(d: Dict, /) -> OpenAPI:
177200
""" Create an OpenAPI from dict """
178201
schemas = Schema.dict(d["components"]["schemas"])
179-
enums: Dict[str, EnumProperty] = {}
180-
for schema in schemas.values():
181-
for prop in schema.required_properties + schema.optional_properties:
182-
if not isinstance(prop, EnumProperty):
183-
continue
184-
schema.relative_imports.add(f"from .{prop.name} import {prop.class_name}")
185-
if prop.class_name in enums:
186-
# We already have an enum with this name, make sure the values match
187-
assert (
188-
prop.values == enums[prop.class_name].values
189-
), f"Encountered conflicting enum named {prop.class_name}"
190-
191-
enums[prop.class_name] = prop
202+
endpoint_collections_by_tag = EndpointCollection.from_dict(d["paths"])
203+
enums = OpenAPI.check_enums(schemas.values(), endpoint_collections_by_tag.values())
192204

193205
return OpenAPI(
194206
title=d["info"]["title"],
195207
description=d["info"]["description"],
196208
version=d["info"]["version"],
197-
endpoint_collections_by_tag=EndpointCollection.from_dict(d["paths"]),
209+
endpoint_collections_by_tag=endpoint_collections_by_tag,
198210
schemas=schemas,
199211
security_schemes=d["components"]["securitySchemes"],
200212
enums=enums,

openapi_python_client/models/properties.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from dataclasses import dataclass, field
22
from typing import Optional, List, Dict, Union, ClassVar
33

4-
import stringcase
5-
64
from .reference import Reference
75

86

@@ -36,6 +34,10 @@ def to_string(self) -> str:
3634
else:
3735
return f"{self.name}: {self.get_type_string()}"
3836

37+
def transform(self) -> str:
38+
""" What it takes to turn this object into a native python type """
39+
return self.name
40+
3941

4042
@dataclass
4143
class StringProperty(Property):
@@ -100,17 +102,25 @@ class EnumProperty(Property):
100102
""" A property that should use an enum """
101103

102104
values: Dict[str, str]
103-
class_name: str = field(init=False)
105+
inverse_values: Dict[str, str] = field(init=False)
106+
reference: Reference = field(init=False)
104107

105108
def __post_init__(self):
106-
self.class_name = stringcase.pascalcase(self.name)
109+
self.reference = Reference(self.name)
110+
self.inverse_values = {v: k for k, v in self.values.items()}
111+
if self.default is not None:
112+
self.default = f"{self.reference.class_name}.{self.inverse_values[self.default]}"
107113

108114
def get_type_string(self):
109115
""" Get a string representation of type that should be used when declaring this property """
110116

111117
if self.required:
112-
return self.class_name
113-
return f"Optional[{self.class_name}]"
118+
return self.reference.class_name
119+
return f"Optional[{self.reference.class_name}]"
120+
121+
def transform(self) -> str:
122+
""" Output to the template, convert this Enum into a JSONable value """
123+
return f"{self.name}.value"
114124

115125
@staticmethod
116126
def values_from_list(l: List[str], /) -> Dict[str, str]:
@@ -140,21 +150,13 @@ def get_type_string(self):
140150
return self.reference.class_name
141151
return f"Optional[{self.reference.class_name}]"
142152

143-
def to_string(self) -> str:
144-
""" How this should be declared in a dataclass """
145-
return f"{self.name}: {self.get_type_string()}"
146-
147153

148154
@dataclass
149155
class DictProperty(Property):
150156
""" Property that is a general Dict """
151157

152158
_type_string: ClassVar[str] = "Dict"
153159

154-
def to_string(self) -> str:
155-
""" How this should be declared in a dataclass """
156-
return f"{self.name}: {self.get_type_string()}"
157-
158160

159161
_openapi_types_to_python_type_strings = {
160162
"string": "str",
Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import asdict
2+
from typing import Optional
23

34
import requests
45

@@ -11,43 +12,42 @@ from ..client import AuthenticatedClient, Client
1112

1213
def {{ endpoint.name }}(
1314
*,
15+
{# Proper client based on whether or not the endpoint requires authentication #}
1416
{% if endpoint.requires_security %}
1517
client: AuthenticatedClient,
1618
{% else %}
1719
client: Client,
1820
{% endif %}
21+
{# Form data if any #}
1922
{% if endpoint.form_body_reference %}
2023
form_data: {{ endpoint.form_body_reference.class_name }},
2124
{% endif %}
25+
{# query parameters #}
26+
{% if endpoint.query_parameters %}
27+
{% for parameter in endpoint.query_parameters %}
28+
{{ parameter.to_string() }},
29+
{% endfor %}
30+
{% endif %}
2231
):
2332
""" {{ endpoint.description }} """
2433
url = client.base_url + "{{ endpoint.path }}"
2534

26-
{% if endpoint.method == "get" %}
27-
return requests.get(url=url, headers=client.get_headers())
28-
{% elif endpoint.method == "post" %}
29-
return requests.post(
35+
{% if endpoint.query_parameters %}
36+
params = {
37+
{% for parameter in endpoint.query_parameters %}
38+
"{{ parameter.name }}": {{ parameter.transform() }},
39+
{% endfor %}
40+
}
41+
{% endif %}
42+
43+
return requests.{{ endpoint.method }}(
3044
url=url,
3145
headers=client.get_headers(),
3246
{% if endpoint.form_body_reference %}
3347
data=asdict(form_data),
3448
{% endif %}
35-
)
36-
{% elif endpoint.method == "patch" %}
37-
return requests.patch(
38-
url=url,
39-
headers=client.get_headers()
40-
{% if endpoint.form_body_reference %}
41-
data=asdict(form_data),
49+
{% if endpoint.query_parameters %}
50+
params=params,
4251
{% endif %}
4352
)
44-
{% elif endpoint.method == "put" %}
45-
return requests.put(
46-
url=url,
47-
headers=client.get_headers()
48-
{% if endpoint.form_body_reference %}
49-
data=asdict(form_data),
50-
{% endif %}
51-
)
52-
{% endif %}
5353
{% endfor %}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from enum import Enum
22

3-
class {{ enum.class_name }}(Enum):
3+
class {{ enum.reference.class_name }}(Enum):
44
{% for key, value in enum.values.items() %}
55
{{ key }} = "{{ value }}"
66
{% endfor %}

0 commit comments

Comments
 (0)