Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 621be8d

Browse files
committedMar 26, 2023
AWS API Gateway with Amazon Lambda integrations support
1 parent 0898d87 commit 621be8d

18 files changed

+976
-165
lines changed
 

‎docs/integrations.rst

+53
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,59 @@ Integrations
33

44
Openapi-core integrates with your popular libraries and frameworks. Each integration offers different levels of integration that help validate and unmarshal your request and response data.
55

6+
Amazon API Gateway
7+
------------------
8+
9+
This section describes integration with `Amazon API Gateway <https://aws.amazon.com/api-gateway/>`__.
10+
11+
It is useful for:
12+
* `AWS Lambda integrations <https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html>`__ where Lambda functions handle events from API Gateway (Amazon API Gateway event format version 1.0 and 2.0).
13+
* `AWS Lambda function URLs <https://docs.aws.amazon.com/lambda/latest/dg/lambda-urls.html>` where Lambda functions handle events from dedicated HTTP(S) endpoint (Amazon API Gateway event format version 2.0).
14+
15+
Low level
16+
~~~~~~~~~
17+
18+
You can use ``APIGatewayEventV2OpenAPIRequest`` as an API Gateway event (format version 2.0) request factory:
19+
20+
.. code-block:: python
21+
22+
from openapi_core import unmarshal_request
23+
from openapi_core.contrib.aws import APIGatewayEventV2OpenAPIRequest
24+
25+
openapi_request = APIGatewayEventV2OpenAPIRequest(event)
26+
result = unmarshal_request(openapi_request, spec=spec)
27+
28+
If you use format version 1.0, then import and use ``APIGatewayEventOpenAPIRequest`` as an API Gateway event (format version 1.0) request factory.
29+
30+
You can use ``APIGatewayEventV2ResponseOpenAPIResponse`` as an API Gateway event (format version 2.0) response factory:
31+
32+
.. code-block:: python
33+
34+
from openapi_core import unmarshal_response
35+
from openapi_core.contrib.aws import APIGatewayEventV2ResponseOpenAPIResponse
36+
37+
openapi_response = APIGatewayEventV2ResponseOpenAPIResponse(response)
38+
result = unmarshal_response(openapi_request, openapi_response, spec=spec)
39+
40+
If you use format version 1.0, then import and use ``APIGatewayEventResponseOpenAPIResponse`` as an API Gateway event (format version 1.0) response factory.
41+
42+
ANY method
43+
~~~~~~~~~~
44+
45+
API Gateway have special ``ANY`` method that catches all HTTP methods. It's specified as `x-amazon-apigateway-any-method <https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-swagger-extensions-any-method.html>`__ OpenAPI extension. If you use the extension, you want to define ``path_finder_cls`` to be ``APIGatewayPathFinder``:
46+
47+
.. code-block:: python
48+
49+
from openapi_core.contrib.aws import APIGatewayPathFinder
50+
51+
result = unmarshal_response(
52+
openapi_request,
53+
openapi_response,
54+
spec=spec,
55+
path_finder_cls=APIGatewayPathFinder,
56+
)
57+
58+
659
Bottle
760
------
861

‎openapi_core/contrib/aws/__init__.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""OpenAPI core contrib aws module"""
2+
from openapi_core.contrib.aws.finders import APIGatewayPathFinder
3+
from openapi_core.contrib.aws.requests import APIGatewayEventOpenAPIRequest
4+
from openapi_core.contrib.aws.requests import APIGatewayEventV2OpenAPIRequest
5+
from openapi_core.contrib.aws.responses import (
6+
APIGatewayEventResponseOpenAPIResponse,
7+
)
8+
from openapi_core.contrib.aws.responses import (
9+
APIGatewayEventV2ResponseOpenAPIResponse,
10+
)
11+
12+
__all__ = [
13+
"APIGatewayEventOpenAPIRequest",
14+
"APIGatewayEventV2OpenAPIRequest",
15+
"APIGatewayEventResponseOpenAPIResponse",
16+
"APIGatewayEventV2ResponseOpenAPIResponse",
17+
"APIGatewayPathFinder",
18+
]

‎openapi_core/contrib/aws/datatypes.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from typing import Dict
2+
from typing import List
3+
from typing import Optional
4+
5+
from pydantic import Field
6+
from pydantic.dataclasses import dataclass
7+
8+
9+
class APIGatewayEventConfig:
10+
extra = "allow"
11+
12+
13+
@dataclass(config=APIGatewayEventConfig, frozen=True)
14+
class APIGatewayEvent:
15+
"""AWS API Gateway event"""
16+
17+
headers: Dict[str, str]
18+
19+
path: str
20+
httpMethod: str
21+
resource: str
22+
23+
queryStringParameters: Optional[Dict[str, str]] = None
24+
isBase64Encoded: Optional[bool] = None
25+
body: Optional[str] = None
26+
pathParameters: Optional[Dict[str, str]] = None
27+
stageVariables: Optional[Dict[str, str]] = None
28+
29+
multiValueHeaders: Optional[Dict[str, List[str]]] = None
30+
version: Optional[str] = "1.0"
31+
multiValueQueryStringParameters: Optional[Dict[str, List[str]]] = None
32+
33+
34+
@dataclass(config=APIGatewayEventConfig, frozen=True)
35+
class APIGatewayEventV2:
36+
"""AWS API Gateway event v2"""
37+
38+
headers: Dict[str, str]
39+
40+
version: str
41+
routeKey: str
42+
rawPath: str
43+
rawQueryString: str
44+
45+
queryStringParameters: Optional[Dict[str, str]] = None
46+
isBase64Encoded: Optional[bool] = None
47+
body: Optional[str] = None
48+
pathParameters: Optional[Dict[str, str]] = None
49+
stageVariables: Optional[Dict[str, str]] = None
50+
51+
cookies: Optional[List[str]] = None
52+
53+
54+
@dataclass(config=APIGatewayEventConfig, frozen=True)
55+
class APIGatewayEventResponse:
56+
"""AWS API Gateway event response"""
57+
58+
body: str
59+
isBase64Encoded: bool
60+
statusCode: int
61+
headers: Dict[str, str]
62+
multiValueHeaders: Dict[str, List[str]]
63+
64+
65+
@dataclass(config=APIGatewayEventConfig, frozen=True)
66+
class APIGatewayEventV2Response:
67+
"""AWS API Gateway event v2 response"""
68+
69+
body: str
70+
isBase64Encoded: bool = False
71+
statusCode: int = 200
72+
headers: Dict[str, str] = Field(
73+
default_factory=lambda: {"content-type": "application/json"}
74+
)

‎openapi_core/contrib/aws/finders.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from openapi_core.templating.paths.finders import APICallPathFinder
2+
from openapi_core.templating.paths.iterators import AnyMethodOperationsIterator
3+
4+
5+
class APIGatewayPathFinder(APICallPathFinder):
6+
operations_iterator = AnyMethodOperationsIterator(
7+
any_method="x-amazon-apigateway-any-method",
8+
)

‎openapi_core/contrib/aws/requests.py

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from typing import Dict
2+
from typing import Optional
3+
4+
from werkzeug.datastructures import Headers
5+
from werkzeug.datastructures import ImmutableMultiDict
6+
7+
from openapi_core.contrib.aws.datatypes import APIGatewayEvent
8+
from openapi_core.contrib.aws.datatypes import APIGatewayEventV2
9+
from openapi_core.contrib.aws.types import APIGatewayEventPayload
10+
from openapi_core.datatypes import RequestParameters
11+
12+
13+
class APIGatewayEventOpenAPIRequest:
14+
"""
15+
Converts an API Gateway event payload to an OpenAPI request
16+
"""
17+
18+
def __init__(self, payload: APIGatewayEventPayload):
19+
self.event = APIGatewayEvent(**payload)
20+
21+
self.parameters = RequestParameters(
22+
query=ImmutableMultiDict(self.query_params),
23+
header=Headers(self.event.headers),
24+
cookie=ImmutableMultiDict(),
25+
)
26+
27+
@property
28+
def query_params(self) -> Dict[str, str]:
29+
params = self.event.queryStringParameters
30+
if params is None:
31+
return {}
32+
return params
33+
34+
@property
35+
def proto(self) -> str:
36+
return self.event.headers.get("X-Forwarded-Proto", "https")
37+
38+
@property
39+
def host(self) -> str:
40+
return self.event.headers["Host"]
41+
42+
@property
43+
def host_url(self) -> str:
44+
return "://".join([self.proto, self.host])
45+
46+
@property
47+
def path(self) -> str:
48+
return self.event.resource
49+
50+
@property
51+
def method(self) -> str:
52+
return self.event.httpMethod.lower()
53+
54+
@property
55+
def body(self) -> Optional[str]:
56+
return self.event.body
57+
58+
@property
59+
def mimetype(self) -> str:
60+
return self.event.headers.get("Content-Type", "")
61+
62+
63+
class APIGatewayEventV2OpenAPIRequest:
64+
"""
65+
Converts an API Gateway event v2 payload to an OpenAPI request
66+
"""
67+
68+
def __init__(self, payload: APIGatewayEventPayload):
69+
self.event = APIGatewayEventV2(**payload)
70+
71+
self.parameters = RequestParameters(
72+
query=ImmutableMultiDict(self.query_params),
73+
header=Headers(self.event.headers),
74+
cookie=ImmutableMultiDict(),
75+
)
76+
77+
@property
78+
def query_params(self) -> Dict[str, str]:
79+
if self.event.queryStringParameters is None:
80+
return {}
81+
return self.event.queryStringParameters
82+
83+
@property
84+
def proto(self) -> str:
85+
return self.event.headers.get("x-forwarded-proto", "https")
86+
87+
@property
88+
def host(self) -> str:
89+
return self.event.headers["host"]
90+
91+
@property
92+
def host_url(self) -> str:
93+
return "://".join([self.proto, self.host])
94+
95+
@property
96+
def path(self) -> str:
97+
return self.event.rawPath
98+
99+
@property
100+
def method(self) -> str:
101+
return self.event.routeKey.lower()
102+
103+
@property
104+
def body(self) -> Optional[str]:
105+
return self.event.body
106+
107+
@property
108+
def mimetype(self) -> str:
109+
return self.event.headers.get("content-type", "")

‎openapi_core/contrib/aws/responses.py

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from json import dumps
2+
from typing import Union
3+
4+
from werkzeug.datastructures import Headers
5+
6+
from openapi_core.contrib.aws.datatypes import APIGatewayEventResponse
7+
from openapi_core.contrib.aws.datatypes import APIGatewayEventV2Response
8+
from openapi_core.contrib.aws.types import APIGatewayEventResponsePayload
9+
10+
APIGatewayEventV2ResponseType = Union[APIGatewayEventV2Response, dict, str]
11+
12+
13+
class APIGatewayEventResponseOpenAPIResponse:
14+
"""
15+
Converts an API Gateway event response payload to an OpenAPI request
16+
"""
17+
18+
def __init__(self, payload: APIGatewayEventResponsePayload):
19+
self.response = APIGatewayEventResponse(**payload)
20+
21+
@property
22+
def data(self) -> str:
23+
return self.response.body
24+
25+
@property
26+
def status_code(self) -> int:
27+
return self.response.statusCode
28+
29+
@property
30+
def headers(self) -> Headers:
31+
return Headers(self.response.headers)
32+
33+
@property
34+
def mimetype(self) -> str:
35+
content_type = self.response.headers.get("Content-Type", "")
36+
assert isinstance(content_type, str)
37+
return content_type
38+
39+
40+
class APIGatewayEventV2ResponseOpenAPIResponse:
41+
"""
42+
Converts an API Gateway event v2 response payload to an OpenAPI request
43+
"""
44+
45+
def __init__(self, payload: Union[APIGatewayEventResponsePayload, str]):
46+
if not isinstance(payload, dict):
47+
payload = self._construct_payload(payload)
48+
elif "statusCode" not in payload:
49+
body = dumps(payload)
50+
payload = self._construct_payload(body)
51+
52+
self.response = APIGatewayEventV2Response(**payload)
53+
54+
@staticmethod
55+
def _construct_payload(body: str) -> APIGatewayEventResponsePayload:
56+
return {
57+
"isBase64Encoded": False,
58+
"statusCode": 200,
59+
"headers": {
60+
"content-type": "application/json",
61+
},
62+
"body": body,
63+
}
64+
65+
@property
66+
def data(self) -> str:
67+
return self.response.body
68+
69+
@property
70+
def status_code(self) -> int:
71+
return self.response.statusCode
72+
73+
@property
74+
def headers(self) -> Headers:
75+
return Headers(self.response.headers)
76+
77+
@property
78+
def mimetype(self) -> str:
79+
content_type = self.response.headers.get(
80+
"content-type", "application/json"
81+
)
82+
assert isinstance(content_type, str)
83+
return content_type

‎openapi_core/contrib/aws/types.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from typing import Any
2+
from typing import Dict
3+
4+
APIGatewayEventPayload = Dict[str, Any]
5+
APIGatewayEventResponsePayload = Dict[str, Any]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from openapi_core.templating.paths.finders import APICallPathFinder
2+
from openapi_core.templating.paths.finders import WebhookPathFinder
3+
4+
__all__ = [
5+
"APICallPathFinder",
6+
"WebhookPathFinder",
7+
]

‎openapi_core/templating/paths/finders.py

+33-115
Original file line numberDiff line numberDiff line change
@@ -17,148 +17,66 @@
1717
from openapi_core.templating.paths.exceptions import PathNotFound
1818
from openapi_core.templating.paths.exceptions import PathsNotFound
1919
from openapi_core.templating.paths.exceptions import ServerNotFound
20+
from openapi_core.templating.paths.iterators import SimpleOperationsIterator
21+
from openapi_core.templating.paths.iterators import SimplePathsIterator
22+
from openapi_core.templating.paths.iterators import SimpleServersIterator
23+
from openapi_core.templating.paths.iterators import TemplatePathsIterator
24+
from openapi_core.templating.paths.iterators import TemplateServersIterator
25+
from openapi_core.templating.paths.protocols import OperationsIterator
26+
from openapi_core.templating.paths.protocols import PathsIterator
27+
from openapi_core.templating.paths.protocols import ServersIterator
2028
from openapi_core.templating.paths.util import template_path_len
2129
from openapi_core.templating.util import parse
2230
from openapi_core.templating.util import search
2331

2432

25-
class BasePathFinder:
33+
class PathFinder:
34+
paths_iterator: PathsIterator = NotImplemented
35+
operations_iterator: OperationsIterator = NotImplemented
36+
servers_iterator: ServersIterator = NotImplemented
37+
2638
def __init__(self, spec: Spec, base_url: Optional[str] = None):
2739
self.spec = spec
2840
self.base_url = base_url
2941

3042
def find(self, method: str, name: str) -> PathOperationServer:
31-
paths_iter = self._get_paths_iter(name)
43+
paths_iter = self.paths_iterator(
44+
name,
45+
self.spec,
46+
base_url=self.base_url,
47+
)
3248
paths_iter_peek = peekable(paths_iter)
3349

3450
if not paths_iter_peek:
3551
raise PathNotFound(name)
3652

37-
operations_iter = self._get_operations_iter(method, paths_iter_peek)
53+
operations_iter = self.operations_iterator(
54+
method,
55+
paths_iter_peek,
56+
self.spec,
57+
base_url=self.base_url,
58+
)
3859
operations_iter_peek = peekable(operations_iter)
3960

4061
if not operations_iter_peek:
4162
raise OperationNotFound(name, method)
4263

43-
servers_iter = self._get_servers_iter(
44-
name,
45-
operations_iter_peek,
64+
servers_iter = self.servers_iterator(
65+
name, operations_iter_peek, self.spec, base_url=self.base_url
4666
)
4767

4868
try:
4969
return next(servers_iter)
5070
except StopIteration:
5171
raise ServerNotFound(name)
5272

53-
def _get_paths_iter(self, name: str) -> Iterator[Path]:
54-
raise NotImplementedError
55-
56-
def _get_operations_iter(
57-
self, method: str, paths_iter: Iterator[Path]
58-
) -> Iterator[PathOperation]:
59-
for path, path_result in paths_iter:
60-
if method not in path:
61-
continue
62-
operation = path / method
63-
yield PathOperation(path, operation, path_result)
64-
65-
def _get_servers_iter(
66-
self, name: str, operations_iter: Iterator[PathOperation]
67-
) -> Iterator[PathOperationServer]:
68-
raise NotImplementedError
69-
70-
71-
class APICallPathFinder(BasePathFinder):
72-
def __init__(self, spec: Spec, base_url: Optional[str] = None):
73-
self.spec = spec
74-
self.base_url = base_url
75-
76-
def _get_paths_iter(self, name: str) -> Iterator[Path]:
77-
paths = self.spec / "paths"
78-
if not paths.exists():
79-
raise PathsNotFound(paths.uri())
80-
template_paths: List[Path] = []
81-
for path_pattern, path in list(paths.items()):
82-
# simple path.
83-
# Return right away since it is always the most concrete
84-
if name.endswith(path_pattern):
85-
path_result = TemplateResult(path_pattern, {})
86-
yield Path(path, path_result)
87-
# template path
88-
else:
89-
result = search(path_pattern, name)
90-
if result:
91-
path_result = TemplateResult(path_pattern, result.named)
92-
template_paths.append(Path(path, path_result))
93-
94-
# Fewer variables -> more concrete path
95-
yield from sorted(template_paths, key=template_path_len)
96-
97-
def _get_servers_iter(
98-
self, name: str, operations_iter: Iterator[PathOperation]
99-
) -> Iterator[PathOperationServer]:
100-
for path, operation, path_result in operations_iter:
101-
servers = (
102-
path.get("servers", None)
103-
or operation.get("servers", None)
104-
or self.spec.get("servers", [{"url": "/"}])
105-
)
106-
for server in servers:
107-
server_url_pattern = name.rsplit(path_result.resolved, 1)[0]
108-
server_url = server["url"]
109-
if not is_absolute(server_url):
110-
# relative to absolute url
111-
if self.base_url is not None:
112-
server_url = urljoin(self.base_url, server["url"])
113-
# if no base url check only path part
114-
else:
115-
server_url_pattern = urlparse(server_url_pattern).path
116-
if server_url.endswith("/"):
117-
server_url = server_url[:-1]
118-
# simple path
119-
if server_url_pattern == server_url:
120-
server_result = TemplateResult(server["url"], {})
121-
yield PathOperationServer(
122-
path,
123-
operation,
124-
server,
125-
path_result,
126-
server_result,
127-
)
128-
# template path
129-
else:
130-
result = parse(server["url"], server_url_pattern)
131-
if result:
132-
server_result = TemplateResult(
133-
server["url"], result.named
134-
)
135-
yield PathOperationServer(
136-
path,
137-
operation,
138-
server,
139-
path_result,
140-
server_result,
141-
)
14273

74+
class APICallPathFinder(PathFinder):
75+
paths_iterator: PathsIterator = TemplatePathsIterator("paths")
76+
operations_iterator: OperationsIterator = SimpleOperationsIterator()
77+
servers_iterator: ServersIterator = TemplateServersIterator()
14378

144-
class WebhookPathFinder(BasePathFinder):
145-
def _get_paths_iter(self, name: str) -> Iterator[Path]:
146-
webhooks = self.spec / "webhooks"
147-
if not webhooks.exists():
148-
raise PathsNotFound(webhooks.uri())
149-
for webhook_name, path in list(webhooks.items()):
150-
if name == webhook_name:
151-
path_result = TemplateResult(webhook_name, {})
152-
yield Path(path, path_result)
15379

154-
def _get_servers_iter(
155-
self, name: str, operations_iter: Iterator[PathOperation]
156-
) -> Iterator[PathOperationServer]:
157-
for path, operation, path_result in operations_iter:
158-
yield PathOperationServer(
159-
path,
160-
operation,
161-
None,
162-
path_result,
163-
{},
164-
)
80+
class WebhookPathFinder(APICallPathFinder):
81+
paths_iterator = SimplePathsIterator("webhooks")
82+
servers_iterator = SimpleServersIterator()
+186
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
from itertools import tee
2+
from typing import Iterator
3+
from typing import List
4+
from typing import Optional
5+
from urllib.parse import urljoin
6+
from urllib.parse import urlparse
7+
8+
from more_itertools import peekable
9+
10+
from openapi_core.schema.servers import is_absolute
11+
from openapi_core.spec import Spec
12+
from openapi_core.templating.datatypes import TemplateResult
13+
from openapi_core.templating.paths.datatypes import Path
14+
from openapi_core.templating.paths.datatypes import PathOperation
15+
from openapi_core.templating.paths.datatypes import PathOperationServer
16+
from openapi_core.templating.paths.exceptions import OperationNotFound
17+
from openapi_core.templating.paths.exceptions import PathNotFound
18+
from openapi_core.templating.paths.exceptions import PathsNotFound
19+
from openapi_core.templating.paths.exceptions import ServerNotFound
20+
from openapi_core.templating.paths.util import template_path_len
21+
from openapi_core.templating.util import parse
22+
from openapi_core.templating.util import search
23+
24+
25+
class SimplePathsIterator:
26+
def __init__(self, paths_part: str):
27+
self.paths_part = paths_part
28+
29+
def __call__(
30+
self, name: str, spec: Spec, base_url: Optional[str] = None
31+
) -> Iterator[Path]:
32+
paths = spec / self.paths_part
33+
if not paths.exists():
34+
raise PathsNotFound(paths.uri())
35+
for path_name, path in list(paths.items()):
36+
if name == path_name:
37+
path_result = TemplateResult(path_name, {})
38+
yield Path(path, path_result)
39+
40+
41+
class TemplatePathsIterator:
42+
def __init__(self, paths_part: str):
43+
self.paths_part = paths_part
44+
45+
def __call__(
46+
self, name: str, spec: Spec, base_url: Optional[str] = None
47+
) -> Iterator[Path]:
48+
paths = spec / self.paths_part
49+
if not paths.exists():
50+
raise PathsNotFound(paths.uri())
51+
template_paths: List[Path] = []
52+
for path_pattern, path in list(paths.items()):
53+
# simple path.
54+
# Return right away since it is always the most concrete
55+
if name.endswith(path_pattern):
56+
path_result = TemplateResult(path_pattern, {})
57+
yield Path(path, path_result)
58+
# template path
59+
else:
60+
result = search(path_pattern, name)
61+
if result:
62+
path_result = TemplateResult(path_pattern, result.named)
63+
template_paths.append(Path(path, path_result))
64+
65+
# Fewer variables -> more concrete path
66+
yield from sorted(template_paths, key=template_path_len)
67+
68+
69+
class SimpleOperationsIterator:
70+
def __call__(
71+
self,
72+
method: str,
73+
paths_iter: Iterator[Path],
74+
spec: Spec,
75+
base_url: Optional[str] = None,
76+
) -> Iterator[PathOperation]:
77+
for path, path_result in paths_iter:
78+
if method not in path:
79+
continue
80+
operation = path / method
81+
yield PathOperation(path, operation, path_result)
82+
83+
84+
class AnyMethodOperationsIterator(SimpleOperationsIterator):
85+
def __init__(self, any_method: str):
86+
self.any_method = any_method
87+
88+
def __call__(
89+
self,
90+
method: str,
91+
paths_iter: Iterator[Path],
92+
spec: Spec,
93+
base_url: Optional[str] = None,
94+
) -> Iterator[PathOperation]:
95+
paths_iter_1, paths_iter_2 = tee(paths_iter, 2)
96+
yield from super().__call__(
97+
method, paths_iter_1, spec, base_url=base_url
98+
)
99+
yield from super().__call__(
100+
self.any_method, paths_iter_2, spec, base_url=base_url
101+
)
102+
103+
104+
class SimpleServersIterator:
105+
def __call__(
106+
self,
107+
name: str,
108+
operations_iter: Iterator[PathOperation],
109+
spec: Spec,
110+
base_url: Optional[str] = None,
111+
) -> Iterator[PathOperationServer]:
112+
for path, operation, path_result in operations_iter:
113+
yield PathOperationServer(
114+
path,
115+
operation,
116+
None,
117+
path_result,
118+
{},
119+
)
120+
121+
122+
class TemplateServersIterator:
123+
def __call__(
124+
self,
125+
name: str,
126+
operations_iter: Iterator[PathOperation],
127+
spec: Spec,
128+
base_url: Optional[str] = None,
129+
) -> Iterator[PathOperationServer]:
130+
for path, operation, path_result in operations_iter:
131+
servers = (
132+
path.get("servers", None)
133+
or operation.get("servers", None)
134+
or spec.get("servers", [{"url": "/"}])
135+
)
136+
for server in servers:
137+
server_url_pattern = name.rsplit(path_result.resolved, 1)[0]
138+
server_url = server["url"]
139+
if not is_absolute(server_url):
140+
# relative to absolute url
141+
if base_url is not None:
142+
server_url = urljoin(base_url, server["url"])
143+
# if no base url check only path part
144+
else:
145+
server_url_pattern = urlparse(server_url_pattern).path
146+
if server_url.endswith("/"):
147+
server_url = server_url[:-1]
148+
# simple path
149+
if server_url_pattern == server_url:
150+
server_result = TemplateResult(server["url"], {})
151+
yield PathOperationServer(
152+
path,
153+
operation,
154+
server,
155+
path_result,
156+
server_result,
157+
)
158+
# template path
159+
else:
160+
result = parse(server["url"], server_url_pattern)
161+
if result:
162+
server_result = TemplateResult(
163+
server["url"], result.named
164+
)
165+
yield PathOperationServer(
166+
path,
167+
operation,
168+
server,
169+
path_result,
170+
server_result,
171+
)
172+
# servers should'n end with tailing slash
173+
# but let's search for this too
174+
server_url_pattern += "/"
175+
result = parse(server["url"], server_url_pattern)
176+
if result:
177+
server_result = TemplateResult(
178+
server["url"], result.named
179+
)
180+
yield PathOperationServer(
181+
path,
182+
operation,
183+
server,
184+
path_result,
185+
server_result,
186+
)
+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import sys
2+
from typing import Iterator
3+
from typing import Optional
4+
5+
if sys.version_info >= (3, 8):
6+
from typing import Protocol
7+
from typing import runtime_checkable
8+
else:
9+
from typing_extensions import Protocol
10+
from typing_extensions import runtime_checkable
11+
12+
from openapi_core.spec import Spec
13+
from openapi_core.templating.paths.datatypes import Path
14+
from openapi_core.templating.paths.datatypes import PathOperation
15+
from openapi_core.templating.paths.datatypes import PathOperationServer
16+
17+
18+
@runtime_checkable
19+
class PathsIterator(Protocol):
20+
def __call__(
21+
self, name: str, spec: Spec, base_url: Optional[str] = None
22+
) -> Iterator[Path]:
23+
...
24+
25+
26+
@runtime_checkable
27+
class OperationsIterator(Protocol):
28+
def __call__(
29+
self,
30+
method: str,
31+
paths_iter: Iterator[Path],
32+
spec: Spec,
33+
base_url: Optional[str] = None,
34+
) -> Iterator[PathOperation]:
35+
...
36+
37+
38+
@runtime_checkable
39+
class ServersIterator(Protocol):
40+
def __call__(
41+
self,
42+
name: str,
43+
operations_iter: Iterator[PathOperation],
44+
spec: Spec,
45+
base_url: Optional[str] = None,
46+
) -> Iterator[PathOperationServer]:
47+
...
+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from typing import Type
2+
3+
from openapi_core.templating.paths.finders import PathFinder
4+
5+
PathFinderType = Type[PathFinder]

‎openapi_core/unmarshalling/request/unmarshallers.py

+4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from openapi_core.security.factories import SecurityProviderFactory
2626
from openapi_core.spec import Spec
2727
from openapi_core.templating.paths.exceptions import PathError
28+
from openapi_core.templating.paths.types import PathFinderType
2829
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult
2930
from openapi_core.unmarshalling.request.proxies import (
3031
SpecRequestValidatorProxy,
@@ -92,6 +93,7 @@ def __init__(
9293
schema_casters_factory: SchemaCastersFactory = schema_casters_factory,
9394
parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory,
9495
media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory,
96+
path_finder_cls: Optional[PathFinderType] = None,
9597
schema_validators_factory: Optional[SchemaValidatorsFactory] = None,
9698
format_validators: Optional[FormatValidatorsDict] = None,
9799
extra_format_validators: Optional[FormatValidatorsDict] = None,
@@ -112,6 +114,7 @@ def __init__(
112114
schema_casters_factory=schema_casters_factory,
113115
parameter_deserializers_factory=parameter_deserializers_factory,
114116
media_type_deserializers_factory=media_type_deserializers_factory,
117+
path_finder_cls=path_finder_cls,
115118
schema_validators_factory=schema_validators_factory,
116119
format_validators=format_validators,
117120
extra_format_validators=extra_format_validators,
@@ -127,6 +130,7 @@ def __init__(
127130
schema_casters_factory=schema_casters_factory,
128131
parameter_deserializers_factory=parameter_deserializers_factory,
129132
media_type_deserializers_factory=media_type_deserializers_factory,
133+
path_finder_cls=path_finder_cls,
130134
schema_validators_factory=schema_validators_factory,
131135
format_validators=format_validators,
132136
extra_format_validators=extra_format_validators,

‎openapi_core/unmarshalling/unmarshallers.py

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ParameterDeserializersFactory,
2222
)
2323
from openapi_core.spec import Spec
24+
from openapi_core.templating.paths.types import PathFinderType
2425
from openapi_core.unmarshalling.schemas.datatypes import (
2526
FormatUnmarshallersDict,
2627
)
@@ -42,6 +43,7 @@ def __init__(
4243
schema_casters_factory: SchemaCastersFactory = schema_casters_factory,
4344
parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory,
4445
media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory,
46+
path_finder_cls: Optional[PathFinderType] = None,
4547
schema_validators_factory: Optional[SchemaValidatorsFactory] = None,
4648
format_validators: Optional[FormatValidatorsDict] = None,
4749
extra_format_validators: Optional[FormatValidatorsDict] = None,
@@ -64,6 +66,7 @@ def __init__(
6466
schema_casters_factory=schema_casters_factory,
6567
parameter_deserializers_factory=parameter_deserializers_factory,
6668
media_type_deserializers_factory=media_type_deserializers_factory,
69+
path_finder_cls=path_finder_cls,
6770
schema_validators_factory=schema_validators_factory,
6871
format_validators=format_validators,
6972
extra_format_validators=extra_format_validators,

‎openapi_core/validation/request/validators.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from openapi_core.spec.paths import Spec
3434
from openapi_core.templating.paths.exceptions import PathError
3535
from openapi_core.templating.paths.finders import WebhookPathFinder
36+
from openapi_core.templating.paths.types import PathFinderType
3637
from openapi_core.templating.security.exceptions import SecurityNotFound
3738
from openapi_core.util import chainiters
3839
from openapi_core.validation.decorators import ValidationErrorWrapper
@@ -70,6 +71,7 @@ def __init__(
7071
schema_casters_factory: SchemaCastersFactory = schema_casters_factory,
7172
parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory,
7273
media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory,
74+
path_finder_cls: Optional[PathFinderType] = None,
7375
schema_validators_factory: Optional[SchemaValidatorsFactory] = None,
7476
format_validators: Optional[FormatValidatorsDict] = None,
7577
extra_format_validators: Optional[FormatValidatorsDict] = None,
@@ -84,6 +86,7 @@ def __init__(
8486
schema_casters_factory=schema_casters_factory,
8587
parameter_deserializers_factory=parameter_deserializers_factory,
8688
media_type_deserializers_factory=media_type_deserializers_factory,
89+
path_finder_cls=path_finder_cls,
8790
schema_validators_factory=schema_validators_factory,
8891
format_validators=format_validators,
8992
extra_format_validators=extra_format_validators,
@@ -414,24 +417,19 @@ class V31RequestSecurityValidator(APICallRequestSecurityValidator):
414417

415418
class V31RequestValidator(APICallRequestValidator):
416419
schema_validators_factory = oas31_schema_validators_factory
417-
path_finder_cls = WebhookPathFinder
418420

419421

420422
class V31WebhookRequestBodyValidator(WebhookRequestBodyValidator):
421423
schema_validators_factory = oas31_schema_validators_factory
422-
path_finder_cls = WebhookPathFinder
423424

424425

425426
class V31WebhookRequestParametersValidator(WebhookRequestParametersValidator):
426427
schema_validators_factory = oas31_schema_validators_factory
427-
path_finder_cls = WebhookPathFinder
428428

429429

430430
class V31WebhookRequestSecurityValidator(WebhookRequestSecurityValidator):
431431
schema_validators_factory = oas31_schema_validators_factory
432-
path_finder_cls = WebhookPathFinder
433432

434433

435434
class V31WebhookRequestValidator(WebhookRequestValidator):
436435
schema_validators_factory = oas31_schema_validators_factory
437-
path_finder_cls = WebhookPathFinder

‎openapi_core/validation/validators.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Mapping
55
from typing import Optional
66
from typing import Tuple
7+
from typing import Type
78
from urllib.parse import urljoin
89

910
if sys.version_info >= (3, 8):
@@ -34,13 +35,15 @@
3435
from openapi_core.templating.media_types.datatypes import MediaType
3536
from openapi_core.templating.paths.datatypes import PathOperationServer
3637
from openapi_core.templating.paths.finders import APICallPathFinder
37-
from openapi_core.templating.paths.finders import BasePathFinder
38+
from openapi_core.templating.paths.finders import PathFinder
3839
from openapi_core.templating.paths.finders import WebhookPathFinder
40+
from openapi_core.templating.paths.types import PathFinderType
3941
from openapi_core.validation.schemas.datatypes import FormatValidatorsDict
4042
from openapi_core.validation.schemas.factories import SchemaValidatorsFactory
4143

4244

4345
class BaseValidator:
46+
path_finder_cls: PathFinderType = NotImplemented
4447
schema_validators_factory: SchemaValidatorsFactory = NotImplemented
4548

4649
def __init__(
@@ -50,6 +53,7 @@ def __init__(
5053
schema_casters_factory: SchemaCastersFactory = schema_casters_factory,
5154
parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory,
5255
media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory,
56+
path_finder_cls: Optional[PathFinderType] = None,
5357
schema_validators_factory: Optional[SchemaValidatorsFactory] = None,
5458
format_validators: Optional[FormatValidatorsDict] = None,
5559
extra_format_validators: Optional[FormatValidatorsDict] = None,
@@ -65,6 +69,9 @@ def __init__(
6569
self.media_type_deserializers_factory = (
6670
media_type_deserializers_factory
6771
)
72+
self.path_finder_cls = path_finder_cls or self.path_finder_cls
73+
if self.path_finder_cls is NotImplemented: # type: ignore[comparison-overlap]
74+
raise NotImplementedError("path_finder_cls is not assigned")
6875
self.schema_validators_factory = (
6976
schema_validators_factory or self.schema_validators_factory
7077
)
@@ -76,6 +83,10 @@ def __init__(
7683
self.extra_format_validators = extra_format_validators
7784
self.extra_media_type_deserializers = extra_media_type_deserializers
7885

86+
@cached_property
87+
def path_finder(self) -> PathFinder:
88+
return self.path_finder_cls(self.spec, base_url=self.base_url)
89+
7990
def _get_media_type(self, content: Spec, mimetype: str) -> MediaType:
8091
from openapi_core.templating.media_types.finders import MediaTypeFinder
8192

@@ -176,9 +187,7 @@ def _get_content_value_and_schema(
176187

177188

178189
class BaseAPICallValidator(BaseValidator):
179-
@cached_property
180-
def path_finder(self) -> BasePathFinder:
181-
return APICallPathFinder(self.spec, base_url=self.base_url)
190+
path_finder_cls = APICallPathFinder
182191

183192
def _find_path(self, request: Request) -> PathOperationServer:
184193
path_pattern = getattr(request, "path_pattern", None) or request.path
@@ -187,9 +196,7 @@ def _find_path(self, request: Request) -> PathOperationServer:
187196

188197

189198
class BaseWebhookValidator(BaseValidator):
190-
@cached_property
191-
def path_finder(self) -> BasePathFinder:
192-
return WebhookPathFinder(self.spec, base_url=self.base_url)
199+
path_finder_cls = WebhookPathFinder
193200

194201
def _find_path(self, request: WebhookRequest) -> PathOperationServer:
195202
return self.path_finder.find(request.method, request.name)

‎poetry.lock

+318-38
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎pyproject.toml

+6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ output = "reports/coverage.xml"
1111

1212
[tool.mypy]
1313
files = "openapi_core"
14+
plugins = [
15+
"pydantic.mypy"
16+
]
1417
strict = true
1518

1619
[[tool.mypy.overrides]]
@@ -72,6 +75,7 @@ jsonschema-spec = "^0.1.1"
7275
backports-cached-property = {version = "^1.0.2", python = "<3.8" }
7376
sphinx = {version = "^5.3.0", optional = true}
7477
sphinx-immaterial = {version = "^0.11.0", optional = true}
78+
pydantic = "^1.10.7"
7579

7680
[tool.poetry.extras]
7781
docs = ["sphinx", "sphinx-immaterial"]
@@ -83,11 +87,13 @@ starlette = ["starlette", "httpx"]
8387

8488
[tool.poetry.dev-dependencies]
8589
black = "^23.1.0"
90+
boto3 = "^1.26.96"
8691
django = ">=3.0"
8792
djangorestframework = "^3.11.2"
8893
falcon = ">=3.0"
8994
flask = "*"
9095
isort = "^5.11.5"
96+
moto = "^4.1.5"
9197
pre-commit = "*"
9298
pytest = "^7"
9399
pytest-flake8 = "*"

0 commit comments

Comments
 (0)
Please sign in to comment.