Skip to content

Commit 927c399

Browse files
committed
Added very basic security support
Closes #4
1 parent 8857e2a commit 927c399

File tree

5 files changed

+32
-8
lines changed

5 files changed

+32
-8
lines changed

openapi_python_client/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def _build_project(openapi: OpenAPI):
4141
package_dir.mkdir()
4242
package_init = package_dir / "__init__.py"
4343
package_description = f"A client library for accessing {openapi.title}"
44-
package_init.write_text(f'""" {package_description} """')
44+
package_init_template = env.get_template("package_init.pyi")
45+
package_init.write_text(package_init_template.render(description=package_description))
4546

4647
# Create a pyproject.toml file
4748
pyproject_template = env.get_template("pyproject.toml")

openapi_python_client/models/openapi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def from_dict(d: Dict[str, Dict[str, Dict]], /) -> Dict[str, EndpointCollection]
7272
parameters=parameters,
7373
responses=responses,
7474
form_body_reference=form_body_reference,
75+
requires_security=method_data.get("security"),
7576
)
7677

7778
collection = endpoints_by_tag.setdefault(tag, EndpointCollection(tag=tag))
@@ -95,6 +96,7 @@ class Endpoint:
9596
name: str
9697
parameters: List[Parameter]
9798
responses: List[Response]
99+
requires_security: bool
98100
form_body_reference: Optional[Reference]
99101

100102
@staticmethod
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
11
from dataclasses import dataclass
2+
from typing import Dict
23

34
@dataclass
45
class Client:
56
""" A class for keeping track of data related to the API """
67
base_url: str
8+
9+
def get_headers(self) -> Dict[str, str]:
10+
""" Get headers to be used in all endpoints """
11+
return {}
12+
13+
14+
@dataclass
15+
class AuthenticatedClient(Client):
16+
""" A Client which has been authenticated for use on secured endpoints """
17+
token: str
18+
19+
def get_headers(self) -> Dict[str, str]:
20+
""" Get headers to be used in authenticated endpoints """
21+
return {"Authorization": f"Bearer {self.token}"}

openapi_python_client/templates/endpoint_module.pyi

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,20 @@ from dataclasses import asdict
22

33
import requests
44

5-
from ..client import Client
5+
from ..client import AuthenticatedClient, Client
66
{% for relative in collection.relative_imports %}
77
{{ relative }}
88
{% endfor %}
9-
109
{% for endpoint in collection.endpoints %}
10+
11+
1112
def {{ endpoint.name }}(
1213
*,
14+
{% if endpoint.requires_security %}
15+
client: AuthenticatedClient,
16+
{% else %}
1317
client: Client,
18+
{% endif %}
1419
{% if endpoint.form_body_reference %}
1520
form_data: {{ endpoint.form_body_reference.class_name }},
1621
{% endif %}
@@ -19,31 +24,30 @@ def {{ endpoint.name }}(
1924
url = client.base_url + "{{ endpoint.path }}"
2025

2126
{% if endpoint.method == "get" %}
22-
return requests.get(url=url)
27+
return requests.get(url=url, headers=client.get_headers())
2328
{% elif endpoint.method == "post" %}
2429
return requests.post(
2530
url=url,
31+
headers=client.get_headers(),
2632
{% if endpoint.form_body_reference %}
2733
data=asdict(form_data),
2834
{% endif %}
2935
)
3036
{% elif endpoint.method == "patch" %}
3137
return requests.patch(
3238
url=url,
39+
headers=client.get_headers()
3340
{% if endpoint.form_body_reference %}
3441
data=asdict(form_data),
3542
{% endif %}
3643
)
3744
{% elif endpoint.method == "put" %}
3845
return requests.put(
3946
url=url,
47+
headers=client.get_headers()
4048
{% if endpoint.form_body_reference %}
4149
data=asdict(form_data),
4250
{% endif %}
4351
)
4452
{% endif %}
45-
46-
4753
{% endfor %}
48-
49-
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
""" {{ description }} """
2+
from .client import Client, AuthenticatedClient

0 commit comments

Comments
 (0)