Skip to content

Commit 3e040d3

Browse files
committed
Added some form data support
Work toward #11, not complete yet.
1 parent e332665 commit 3e040d3

File tree

3 files changed

+83
-27
lines changed

3 files changed

+83
-27
lines changed

openapi_python_client/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def _build_project(openapi: OpenAPI):
7979
api_init = api_dir / "__init__.py"
8080
api_init.write_text('""" Contains all methods for accessing the API """')
8181
endpoint_template = env.get_template("endpoint_module.pyi")
82-
for tag, endpoints in openapi.endpoints_by_tag.items():
82+
for tag, collection in openapi.endpoint_collections_by_tag.items():
8383
module_path = api_dir / f"{tag}.py"
84-
module_path.write_text(endpoint_template.render(endpoints=endpoints))
84+
module_path.write_text(endpoint_template.render(collection=collection))
8585

8686

openapi_python_client/models/openapi.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from collections import defaultdict
43
from dataclasses import dataclass, field
54
from enum import Enum
65
from typing import Dict, List, Optional, Set
@@ -34,23 +33,21 @@ def from_dict(d: Dict, /) -> Parameter:
3433
)
3534

3635

37-
@dataclass
38-
class Endpoint:
39-
"""
40-
Describes a single endpoint on the server
41-
"""
36+
def _import_string_from_ref(ref: str, prefix: str = "") -> str:
37+
return f"from {prefix}.{stringcase.snakecase(ref)} import {ref}"
4238

43-
path: str
44-
method: str
45-
description: Optional[str]
46-
name: str
47-
parameters: List[Parameter]
48-
responses: List[Response]
39+
40+
@dataclass
41+
class EndpointCollection:
42+
""" A bunch of endpoints grouped under a tag that will become a module """
43+
tag: str
44+
endpoints: List[Endpoint] = field(default_factory=list)
45+
relative_imports: Set[str] = field(default_factory=set)
4946

5047
@staticmethod
51-
def get_by_tags_from_dict(d: Dict[str, Dict[str, Dict]], /) -> Dict[str, List[Endpoint]]:
52-
""" Parse the openapi paths data to get a list of endpoints """
53-
endpoints_by_tag: Dict[str, List[Endpoint]] = defaultdict(list)
48+
def from_dict(d: Dict[str, Dict[str, Dict]], /) -> Dict[str, EndpointCollection]:
49+
""" Parse the openapi paths data to get EndpointCollections by tag """
50+
endpoints_by_tag: Dict[str, EndpointCollection] = {}
5451
for path, path_data in d.items():
5552
for method, method_data in path_data.items():
5653
parameters: List[Parameter] = []
@@ -64,18 +61,52 @@ def get_by_tags_from_dict(d: Dict[str, Dict[str, Dict]], /) -> Dict[str, List[En
6461
data=response_dict,
6562
)
6663
responses.append(response)
64+
form_body_ref = None
65+
if "requestBody" in method_data:
66+
form_body_ref = Endpoint.parse_request_body(method_data["requestBody"])
67+
6768
endpoint = Endpoint(
6869
path=path,
6970
method=method,
7071
description=method_data.get("description"),
7172
name=method_data["operationId"],
7273
parameters=parameters,
7374
responses=responses,
75+
form_body_ref=form_body_ref,
7476
)
75-
endpoints_by_tag[tag].append(endpoint)
77+
78+
collection = endpoints_by_tag.setdefault(tag, EndpointCollection(tag=tag))
79+
collection.endpoints.append(endpoint)
80+
if form_body_ref:
81+
collection.relative_imports.add(_import_string_from_ref(form_body_ref, prefix="..models"))
7682
return endpoints_by_tag
7783

7884

85+
@dataclass
86+
class Endpoint:
87+
"""
88+
Describes a single endpoint on the server
89+
"""
90+
91+
path: str
92+
method: str
93+
description: Optional[str]
94+
name: str
95+
parameters: List[Parameter]
96+
responses: List[Response]
97+
form_body_ref: Optional[str]
98+
99+
@staticmethod
100+
def parse_request_body(body: Dict, /) -> Optional[str]:
101+
""" Return form_body_ref """
102+
form_body_ref = None
103+
body_content = body["content"]
104+
form_body = body_content.get("application/x-www-form-urlencoded")
105+
if form_body:
106+
form_body_ref = form_body["schema"]["$ref"].split("/")[-1]
107+
return form_body_ref
108+
109+
79110
@dataclass
80111
class Schema:
81112
"""
@@ -99,7 +130,7 @@ def from_dict(d: Dict, /) -> Schema:
99130
p = property_from_dict(name=key, required=key in required, data=value)
100131
properties.append(p)
101132
if isinstance(p, (ListProperty, RefProperty)) and p.ref:
102-
schema.relative_imports.add(f"from .{stringcase.snakecase(p.ref)} import {p.ref}")
133+
schema.relative_imports.add(_import_string_from_ref(p.ref))
103134
return schema
104135

105136
@staticmethod
@@ -121,7 +152,7 @@ class OpenAPI:
121152
version: str
122153
security_schemes: Dict
123154
schemas: Dict[str, Schema]
124-
endpoints_by_tag: Dict[str, List[Endpoint]]
155+
endpoint_collections_by_tag: Dict[str, EndpointCollection]
125156
enums: Dict[str, EnumProperty]
126157

127158
@staticmethod
@@ -146,7 +177,7 @@ def from_dict(d: Dict, /) -> OpenAPI:
146177
title=d["info"]["title"],
147178
description=d["info"]["description"],
148179
version=d["info"]["version"],
149-
endpoints_by_tag=Endpoint.get_by_tags_from_dict(d["paths"]),
180+
endpoint_collections_by_tag=EndpointCollection.from_dict(d["paths"]),
150181
schemas=schemas,
151182
security_schemes=d["components"]["securitySchemes"],
152183
enums=enums,

openapi_python_client/templates/endpoint_module.pyi

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,46 @@
1+
from dataclasses import asdict
2+
13
import requests
24

35
from ..client import Client
6+
{% for relative in collection.relative_imports %}
7+
{{ relative }}
8+
{% endfor %}
49

5-
6-
{% for endpoint in endpoints %}
7-
def {{ endpoint.name }}(client: Client):
10+
{% for endpoint in collection.endpoints %}
11+
def {{ endpoint.name }}(
12+
*,
13+
client: Client,
14+
{% if endpoint.form_body_ref %}
15+
form_data: {{ endpoint.form_body_ref }},
16+
{% endif %}
17+
):
818
""" {{ endpoint.description }} """
919
url = client.base_url + "{{ endpoint.path }}"
1020

1121
{% if endpoint.method == "get" %}
1222
return requests.get(url=url)
1323
{% elif endpoint.method == "post" %}
14-
return requests.post(url=url)
24+
return requests.post(
25+
url=url,
26+
{% if endpoint.form_body_ref %}
27+
data=asdict(form_data),
28+
{% endif %}
29+
)
1530
{% elif endpoint.method == "patch" %}
16-
return requests.patch(url=url)
31+
return requests.patch(
32+
url=url,
33+
{% if endpoint.form_body_ref %}
34+
data=asdict(form_data),
35+
{% endif %}
36+
)
1737
{% elif endpoint.method == "put" %}
18-
return requests.put(url=url)
38+
return requests.put(
39+
url=url,
40+
{% if endpoint.form_body_ref %}
41+
data=asdict(form_data),
42+
{% endif %}
43+
)
1944
{% endif %}
2045

2146

0 commit comments

Comments
 (0)