From 2719731ab717e3af30ee29850e54cbce7d3030f1 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Fri, 23 Feb 2024 21:22:04 +0000 Subject: [PATCH] Path finder cls configuration --- openapi_core/app.py | 41 ++-- openapi_core/templating/paths/__init__.py | 7 + openapi_core/templating/paths/finders.py | 160 +++------------ openapi_core/templating/paths/iterators.py | 185 ++++++++++++++++++ openapi_core/templating/paths/protocols.py | 39 ++++ openapi_core/templating/paths/types.py | 5 + .../unmarshalling/request/protocols.py | 3 + .../unmarshalling/request/unmarshallers.py | 4 + .../unmarshalling/response/protocols.py | 3 + openapi_core/unmarshalling/unmarshallers.py | 3 + openapi_core/validation/configurations.py | 7 + openapi_core/validation/request/protocols.py | 3 + openapi_core/validation/request/validators.py | 9 +- openapi_core/validation/response/protocols.py | 3 + openapi_core/validation/validators.py | 18 +- 15 files changed, 334 insertions(+), 156 deletions(-) create mode 100644 openapi_core/templating/paths/iterators.py create mode 100644 openapi_core/templating/paths/protocols.py create mode 100644 openapi_core/templating/paths/types.py diff --git a/openapi_core/app.py b/openapi_core/app.py index 5a2c5588..50c73904 100644 --- a/openapi_core/app.py +++ b/openapi_core/app.py @@ -1,5 +1,6 @@ """OpenAPI core app module""" +from functools import cached_property from pathlib import Path from typing import Optional @@ -142,19 +143,19 @@ def check_spec(self) -> None: def version(self) -> SpecVersion: return self._get_version() - @property + @cached_property def request_validator_cls(self) -> Optional[RequestValidatorType]: if not isinstance(self.config.request_validator_cls, Unset): return self.config.request_validator_cls return REQUEST_VALIDATORS.get(self.version) - @property + @cached_property def response_validator_cls(self) -> Optional[ResponseValidatorType]: if not isinstance(self.config.response_validator_cls, Unset): return self.config.response_validator_cls return RESPONSE_VALIDATORS.get(self.version) - @property + @cached_property def webhook_request_validator_cls( self, ) -> Optional[WebhookRequestValidatorType]: @@ -162,7 +163,7 @@ def webhook_request_validator_cls( return self.config.webhook_request_validator_cls return WEBHOOK_REQUEST_VALIDATORS.get(self.version) - @property + @cached_property def webhook_response_validator_cls( self, ) -> Optional[WebhookResponseValidatorType]: @@ -170,19 +171,19 @@ def webhook_response_validator_cls( return self.config.webhook_response_validator_cls return WEBHOOK_RESPONSE_VALIDATORS.get(self.version) - @property + @cached_property def request_unmarshaller_cls(self) -> Optional[RequestUnmarshallerType]: if not isinstance(self.config.request_unmarshaller_cls, Unset): return self.config.request_unmarshaller_cls return REQUEST_UNMARSHALLERS.get(self.version) - @property + @cached_property def response_unmarshaller_cls(self) -> Optional[ResponseUnmarshallerType]: if not isinstance(self.config.response_unmarshaller_cls, Unset): return self.config.response_unmarshaller_cls return RESPONSE_UNMARSHALLERS.get(self.version) - @property + @cached_property def webhook_request_unmarshaller_cls( self, ) -> Optional[WebhookRequestUnmarshallerType]: @@ -190,7 +191,7 @@ def webhook_request_unmarshaller_cls( return self.config.webhook_request_unmarshaller_cls return WEBHOOK_REQUEST_UNMARSHALLERS.get(self.version) - @property + @cached_property def webhook_response_unmarshaller_cls( self, ) -> Optional[WebhookResponseUnmarshallerType]: @@ -200,7 +201,7 @@ def webhook_response_unmarshaller_cls( return self.config.webhook_response_unmarshaller_cls return WEBHOOK_RESPONSE_UNMARSHALLERS.get(self.version) - @property + @cached_property def request_validator(self) -> RequestValidator: if self.request_validator_cls is None: raise SpecError("Validator class not found") @@ -211,13 +212,14 @@ def request_validator(self) -> RequestValidator: media_type_deserializers_factory=self.config.media_type_deserializers_factory, schema_casters_factory=self.config.schema_casters_factory, schema_validators_factory=self.config.schema_validators_factory, + path_finder_cls=self.config.path_finder_cls, spec_validator_cls=self.config.spec_validator_cls, extra_format_validators=self.config.extra_format_validators, extra_media_type_deserializers=self.config.extra_media_type_deserializers, security_provider_factory=self.config.security_provider_factory, ) - @property + @cached_property def response_validator(self) -> ResponseValidator: if self.response_validator_cls is None: raise SpecError("Validator class not found") @@ -228,12 +230,13 @@ def response_validator(self) -> ResponseValidator: media_type_deserializers_factory=self.config.media_type_deserializers_factory, schema_casters_factory=self.config.schema_casters_factory, schema_validators_factory=self.config.schema_validators_factory, + path_finder_cls=self.config.path_finder_cls, spec_validator_cls=self.config.spec_validator_cls, extra_format_validators=self.config.extra_format_validators, extra_media_type_deserializers=self.config.extra_media_type_deserializers, ) - @property + @cached_property def webhook_request_validator(self) -> WebhookRequestValidator: if self.webhook_request_validator_cls is None: raise SpecError("Validator class not found") @@ -244,13 +247,14 @@ def webhook_request_validator(self) -> WebhookRequestValidator: media_type_deserializers_factory=self.config.media_type_deserializers_factory, schema_casters_factory=self.config.schema_casters_factory, schema_validators_factory=self.config.schema_validators_factory, + path_finder_cls=self.config.webhook_path_finder_cls, spec_validator_cls=self.config.spec_validator_cls, extra_format_validators=self.config.extra_format_validators, extra_media_type_deserializers=self.config.extra_media_type_deserializers, security_provider_factory=self.config.security_provider_factory, ) - @property + @cached_property def webhook_response_validator(self) -> WebhookResponseValidator: if self.webhook_response_validator_cls is None: raise SpecError("Validator class not found") @@ -261,12 +265,13 @@ def webhook_response_validator(self) -> WebhookResponseValidator: media_type_deserializers_factory=self.config.media_type_deserializers_factory, schema_casters_factory=self.config.schema_casters_factory, schema_validators_factory=self.config.schema_validators_factory, + path_finder_cls=self.config.webhook_path_finder_cls, spec_validator_cls=self.config.spec_validator_cls, extra_format_validators=self.config.extra_format_validators, extra_media_type_deserializers=self.config.extra_media_type_deserializers, ) - @property + @cached_property def request_unmarshaller(self) -> RequestUnmarshaller: if self.request_unmarshaller_cls is None: raise SpecError("Unmarshaller class not found") @@ -277,6 +282,7 @@ def request_unmarshaller(self) -> RequestUnmarshaller: media_type_deserializers_factory=self.config.media_type_deserializers_factory, schema_casters_factory=self.config.schema_casters_factory, schema_validators_factory=self.config.schema_validators_factory, + path_finder_cls=self.config.path_finder_cls, spec_validator_cls=self.config.spec_validator_cls, extra_format_validators=self.config.extra_format_validators, extra_media_type_deserializers=self.config.extra_media_type_deserializers, @@ -285,7 +291,7 @@ def request_unmarshaller(self) -> RequestUnmarshaller: extra_format_unmarshallers=self.config.extra_format_unmarshallers, ) - @property + @cached_property def response_unmarshaller(self) -> ResponseUnmarshaller: if self.response_unmarshaller_cls is None: raise SpecError("Unmarshaller class not found") @@ -296,6 +302,7 @@ def response_unmarshaller(self) -> ResponseUnmarshaller: media_type_deserializers_factory=self.config.media_type_deserializers_factory, schema_casters_factory=self.config.schema_casters_factory, schema_validators_factory=self.config.schema_validators_factory, + path_finder_cls=self.config.path_finder_cls, spec_validator_cls=self.config.spec_validator_cls, extra_format_validators=self.config.extra_format_validators, extra_media_type_deserializers=self.config.extra_media_type_deserializers, @@ -303,7 +310,7 @@ def response_unmarshaller(self) -> ResponseUnmarshaller: extra_format_unmarshallers=self.config.extra_format_unmarshallers, ) - @property + @cached_property def webhook_request_unmarshaller(self) -> WebhookRequestUnmarshaller: if self.webhook_request_unmarshaller_cls is None: raise SpecError("Unmarshaller class not found") @@ -314,6 +321,7 @@ def webhook_request_unmarshaller(self) -> WebhookRequestUnmarshaller: media_type_deserializers_factory=self.config.media_type_deserializers_factory, schema_casters_factory=self.config.schema_casters_factory, schema_validators_factory=self.config.schema_validators_factory, + path_finder_cls=self.config.webhook_path_finder_cls, spec_validator_cls=self.config.spec_validator_cls, extra_format_validators=self.config.extra_format_validators, extra_media_type_deserializers=self.config.extra_media_type_deserializers, @@ -322,7 +330,7 @@ def webhook_request_unmarshaller(self) -> WebhookRequestUnmarshaller: extra_format_unmarshallers=self.config.extra_format_unmarshallers, ) - @property + @cached_property def webhook_response_unmarshaller(self) -> WebhookResponseUnmarshaller: if self.webhook_response_unmarshaller_cls is None: raise SpecError("Unmarshaller class not found") @@ -333,6 +341,7 @@ def webhook_response_unmarshaller(self) -> WebhookResponseUnmarshaller: media_type_deserializers_factory=self.config.media_type_deserializers_factory, schema_casters_factory=self.config.schema_casters_factory, schema_validators_factory=self.config.schema_validators_factory, + path_finder_cls=self.config.webhook_path_finder_cls, spec_validator_cls=self.config.spec_validator_cls, extra_format_validators=self.config.extra_format_validators, extra_media_type_deserializers=self.config.extra_media_type_deserializers, diff --git a/openapi_core/templating/paths/__init__.py b/openapi_core/templating/paths/__init__.py index e69de29b..93e94f74 100644 --- a/openapi_core/templating/paths/__init__.py +++ b/openapi_core/templating/paths/__init__.py @@ -0,0 +1,7 @@ +from openapi_core.templating.paths.finders import APICallPathFinder +from openapi_core.templating.paths.finders import WebhookPathFinder + +__all__ = [ + "APICallPathFinder", + "WebhookPathFinder", +] diff --git a/openapi_core/templating/paths/finders.py b/openapi_core/templating/paths/finders.py index 4c0c04d0..bd4dc033 100644 --- a/openapi_core/templating/paths/finders.py +++ b/openapi_core/templating/paths/finders.py @@ -1,49 +1,57 @@ """OpenAPI core templating paths finders module""" -from typing import Iterator -from typing import List from typing import Optional -from urllib.parse import urljoin -from urllib.parse import urlparse from jsonschema_path import SchemaPath from more_itertools import peekable -from openapi_core.schema.servers import is_absolute -from openapi_core.templating.datatypes import TemplateResult -from openapi_core.templating.paths.datatypes import Path -from openapi_core.templating.paths.datatypes import PathOperation from openapi_core.templating.paths.datatypes import PathOperationServer from openapi_core.templating.paths.exceptions import OperationNotFound from openapi_core.templating.paths.exceptions import PathNotFound -from openapi_core.templating.paths.exceptions import PathsNotFound from openapi_core.templating.paths.exceptions import ServerNotFound -from openapi_core.templating.paths.util import template_path_len -from openapi_core.templating.util import parse -from openapi_core.templating.util import search +from openapi_core.templating.paths.iterators import SimpleOperationsIterator +from openapi_core.templating.paths.iterators import SimplePathsIterator +from openapi_core.templating.paths.iterators import SimpleServersIterator +from openapi_core.templating.paths.iterators import TemplatePathsIterator +from openapi_core.templating.paths.iterators import TemplateServersIterator +from openapi_core.templating.paths.protocols import OperationsIterator +from openapi_core.templating.paths.protocols import PathsIterator +from openapi_core.templating.paths.protocols import ServersIterator class BasePathFinder: + paths_iterator: PathsIterator = NotImplemented + operations_iterator: OperationsIterator = NotImplemented + servers_iterator: ServersIterator = NotImplemented + def __init__(self, spec: SchemaPath, base_url: Optional[str] = None): self.spec = spec self.base_url = base_url def find(self, method: str, name: str) -> PathOperationServer: - paths_iter = self._get_paths_iter(name) + paths_iter = self.paths_iterator( + name, + self.spec, + base_url=self.base_url, + ) paths_iter_peek = peekable(paths_iter) if not paths_iter_peek: raise PathNotFound(name) - operations_iter = self._get_operations_iter(method, paths_iter_peek) + operations_iter = self.operations_iterator( + method, + paths_iter_peek, + self.spec, + base_url=self.base_url, + ) operations_iter_peek = peekable(operations_iter) if not operations_iter_peek: raise OperationNotFound(name, method) - servers_iter = self._get_servers_iter( - name, - operations_iter_peek, + servers_iter = self.servers_iterator( + name, operations_iter_peek, self.spec, base_url=self.base_url ) try: @@ -51,117 +59,13 @@ def find(self, method: str, name: str) -> PathOperationServer: except StopIteration: raise ServerNotFound(name) - def _get_paths_iter(self, name: str) -> Iterator[Path]: - raise NotImplementedError - - def _get_operations_iter( - self, method: str, paths_iter: Iterator[Path] - ) -> Iterator[PathOperation]: - for path, path_result in paths_iter: - if method not in path: - continue - operation = path / method - yield PathOperation(path, operation, path_result) - - def _get_servers_iter( - self, name: str, operations_iter: Iterator[PathOperation] - ) -> Iterator[PathOperationServer]: - raise NotImplementedError - class APICallPathFinder(BasePathFinder): - def __init__(self, spec: SchemaPath, base_url: Optional[str] = None): - self.spec = spec - self.base_url = base_url + paths_iterator: PathsIterator = TemplatePathsIterator("paths") + operations_iterator: OperationsIterator = SimpleOperationsIterator() + servers_iterator: ServersIterator = TemplateServersIterator() + - def _get_paths_iter(self, name: str) -> Iterator[Path]: - paths = self.spec / "paths" - if not paths.exists(): - raise PathsNotFound(paths.as_uri()) - template_paths: List[Path] = [] - for path_pattern, path in list(paths.items()): - # simple path. - # Return right away since it is always the most concrete - if name.endswith(path_pattern): - path_result = TemplateResult(path_pattern, {}) - yield Path(path, path_result) - # template path - else: - result = search(path_pattern, name) - if result: - path_result = TemplateResult(path_pattern, result.named) - template_paths.append(Path(path, path_result)) - - # Fewer variables -> more concrete path - yield from sorted(template_paths, key=template_path_len) - - def _get_servers_iter( - self, name: str, operations_iter: Iterator[PathOperation] - ) -> Iterator[PathOperationServer]: - for path, operation, path_result in operations_iter: - servers = ( - path.get("servers", None) - or operation.get("servers", None) - or self.spec.get("servers", None) - ) - if not servers: - servers = [SchemaPath.from_dict({"url": "/"})] - for server in servers: - server_url_pattern = name.rsplit(path_result.resolved, 1)[0] - server_url = server["url"] - if not is_absolute(server_url): - # relative to absolute url - if self.base_url is not None: - server_url = urljoin(self.base_url, server["url"]) - # if no base url check only path part - else: - server_url_pattern = urlparse(server_url_pattern).path - if server_url.endswith("/"): - server_url = server_url[:-1] - # simple path - if server_url_pattern == server_url: - server_result = TemplateResult(server["url"], {}) - yield PathOperationServer( - path, - operation, - server, - path_result, - server_result, - ) - # template path - else: - result = parse(server["url"], server_url_pattern) - if result: - server_result = TemplateResult( - server["url"], result.named - ) - yield PathOperationServer( - path, - operation, - server, - path_result, - server_result, - ) - - -class WebhookPathFinder(BasePathFinder): - def _get_paths_iter(self, name: str) -> Iterator[Path]: - webhooks = self.spec / "webhooks" - if not webhooks.exists(): - raise PathsNotFound(webhooks.as_uri()) - for webhook_name, path in list(webhooks.items()): - if name == webhook_name: - path_result = TemplateResult(webhook_name, {}) - yield Path(path, path_result) - - def _get_servers_iter( - self, name: str, operations_iter: Iterator[PathOperation] - ) -> Iterator[PathOperationServer]: - for path, operation, path_result in operations_iter: - yield PathOperationServer( - path, - operation, - None, - path_result, - {}, - ) +class WebhookPathFinder(APICallPathFinder): + paths_iterator = SimplePathsIterator("webhooks") + servers_iterator = SimpleServersIterator() diff --git a/openapi_core/templating/paths/iterators.py b/openapi_core/templating/paths/iterators.py new file mode 100644 index 00000000..f78d3342 --- /dev/null +++ b/openapi_core/templating/paths/iterators.py @@ -0,0 +1,185 @@ +from typing import Iterator +from typing import List +from typing import Optional +from urllib.parse import urljoin +from urllib.parse import urlparse + +from jsonschema_path import SchemaPath + +from openapi_core.schema.servers import is_absolute +from openapi_core.templating.datatypes import TemplateResult +from openapi_core.templating.paths.datatypes import Path +from openapi_core.templating.paths.datatypes import PathOperation +from openapi_core.templating.paths.datatypes import PathOperationServer +from openapi_core.templating.paths.exceptions import PathsNotFound +from openapi_core.templating.paths.util import template_path_len +from openapi_core.templating.util import parse +from openapi_core.templating.util import search + + +class SimplePathsIterator: + def __init__(self, paths_part: str): + self.paths_part = paths_part + + def __call__( + self, name: str, spec: SchemaPath, base_url: Optional[str] = None + ) -> Iterator[Path]: + paths = spec / self.paths_part + if not paths.exists(): + raise PathsNotFound(paths.as_uri()) + for path_name, path in list(paths.items()): + if name == path_name: + path_result = TemplateResult(path_name, {}) + yield Path(path, path_result) + + +class TemplatePathsIterator: + def __init__(self, paths_part: str): + self.paths_part = paths_part + + def __call__( + self, name: str, spec: SchemaPath, base_url: Optional[str] = None + ) -> Iterator[Path]: + paths = spec / self.paths_part + if not paths.exists(): + raise PathsNotFound(paths.as_uri()) + template_paths: List[Path] = [] + for path_pattern, path in list(paths.items()): + # simple path. + # Return right away since it is always the most concrete + if name.endswith(path_pattern): + path_result = TemplateResult(path_pattern, {}) + yield Path(path, path_result) + # template path + else: + result = search(path_pattern, name) + if result: + path_result = TemplateResult(path_pattern, result.named) + template_paths.append(Path(path, path_result)) + + # Fewer variables -> more concrete path + yield from sorted(template_paths, key=template_path_len) + + +class SimpleOperationsIterator: + def __call__( + self, + method: str, + paths_iter: Iterator[Path], + spec: SchemaPath, + base_url: Optional[str] = None, + ) -> Iterator[PathOperation]: + for path, path_result in paths_iter: + if method not in path: + continue + operation = path / method + yield PathOperation(path, operation, path_result) + + +class CatchAllMethodOperationsIterator(SimpleOperationsIterator): + def __init__(self, ca_method_name: str, ca_operation_name: str): + self.ca_method_name = ca_method_name + self.ca_operation_name = ca_operation_name + + def __call__( + self, + method: str, + paths_iter: Iterator[Path], + spec: SchemaPath, + base_url: Optional[str] = None, + ) -> Iterator[PathOperation]: + if method == self.ca_method_name: + yield from super().__call__( + self.ca_operation_name, paths_iter, spec, base_url=base_url + ) + else: + yield from super().__call__( + method, paths_iter, spec, base_url=base_url + ) + + +class SimpleServersIterator: + def __call__( + self, + name: str, + operations_iter: Iterator[PathOperation], + spec: SchemaPath, + base_url: Optional[str] = None, + ) -> Iterator[PathOperationServer]: + for path, operation, path_result in operations_iter: + yield PathOperationServer( + path, + operation, + None, + path_result, + {}, + ) + + +class TemplateServersIterator: + def __call__( + self, + name: str, + operations_iter: Iterator[PathOperation], + spec: SchemaPath, + base_url: Optional[str] = None, + ) -> Iterator[PathOperationServer]: + for path, operation, path_result in operations_iter: + servers = ( + path.get("servers", None) + or operation.get("servers", None) + or spec.get("servers", None) + ) + if not servers: + servers = [SchemaPath.from_dict({"url": "/"})] + for server in servers: + server_url_pattern = name.rsplit(path_result.resolved, 1)[0] + server_url = server["url"] + if not is_absolute(server_url): + # relative to absolute url + if base_url is not None: + server_url = urljoin(base_url, server["url"]) + # if no base url check only path part + else: + server_url_pattern = urlparse(server_url_pattern).path + if server_url.endswith("/"): + server_url = server_url[:-1] + # simple path + if server_url_pattern == server_url: + server_result = TemplateResult(server["url"], {}) + yield PathOperationServer( + path, + operation, + server, + path_result, + server_result, + ) + # template path + else: + result = parse(server["url"], server_url_pattern) + if result: + server_result = TemplateResult( + server["url"], result.named + ) + yield PathOperationServer( + path, + operation, + server, + path_result, + server_result, + ) + # servers should'n end with tailing slash + # but let's search for this too + server_url_pattern += "/" + result = parse(server["url"], server_url_pattern) + if result: + server_result = TemplateResult( + server["url"], result.named + ) + yield PathOperationServer( + path, + operation, + server, + path_result, + server_result, + ) diff --git a/openapi_core/templating/paths/protocols.py b/openapi_core/templating/paths/protocols.py new file mode 100644 index 00000000..e73c690c --- /dev/null +++ b/openapi_core/templating/paths/protocols.py @@ -0,0 +1,39 @@ +from typing import Iterator +from typing import Optional +from typing import Protocol +from typing import runtime_checkable + +from jsonschema_path import SchemaPath + +from openapi_core.templating.paths.datatypes import Path +from openapi_core.templating.paths.datatypes import PathOperation +from openapi_core.templating.paths.datatypes import PathOperationServer + + +@runtime_checkable +class PathsIterator(Protocol): + def __call__( + self, name: str, spec: SchemaPath, base_url: Optional[str] = None + ) -> Iterator[Path]: ... + + +@runtime_checkable +class OperationsIterator(Protocol): + def __call__( + self, + method: str, + paths_iter: Iterator[Path], + spec: SchemaPath, + base_url: Optional[str] = None, + ) -> Iterator[PathOperation]: ... + + +@runtime_checkable +class ServersIterator(Protocol): + def __call__( + self, + name: str, + operations_iter: Iterator[PathOperation], + spec: SchemaPath, + base_url: Optional[str] = None, + ) -> Iterator[PathOperationServer]: ... diff --git a/openapi_core/templating/paths/types.py b/openapi_core/templating/paths/types.py new file mode 100644 index 00000000..6067a18a --- /dev/null +++ b/openapi_core/templating/paths/types.py @@ -0,0 +1,5 @@ +from typing import Type + +from openapi_core.templating.paths.finders import BasePathFinder + +PathFinderType = Type[BasePathFinder] diff --git a/openapi_core/unmarshalling/request/protocols.py b/openapi_core/unmarshalling/request/protocols.py index 0c725191..43a18cbe 100644 --- a/openapi_core/unmarshalling/request/protocols.py +++ b/openapi_core/unmarshalling/request/protocols.py @@ -25,6 +25,7 @@ from openapi_core.protocols import WebhookRequest from openapi_core.security import security_provider_factory from openapi_core.security.factories import SecurityProviderFactory +from openapi_core.templating.paths.types import PathFinderType from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult from openapi_core.unmarshalling.schemas.datatypes import ( FormatUnmarshallersDict, @@ -46,6 +47,7 @@ def __init__( media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_casters_factory: Optional[SchemaCastersFactory] = None, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + path_finder_cls: Optional[PathFinderType] = None, spec_validator_cls: Optional[SpecValidatorType] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, @@ -76,6 +78,7 @@ def __init__( media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_casters_factory: Optional[SchemaCastersFactory] = None, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + path_finder_cls: Optional[PathFinderType] = None, spec_validator_cls: Optional[SpecValidatorType] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, diff --git a/openapi_core/unmarshalling/request/unmarshallers.py b/openapi_core/unmarshalling/request/unmarshallers.py index 10f69b69..efd45930 100644 --- a/openapi_core/unmarshalling/request/unmarshallers.py +++ b/openapi_core/unmarshalling/request/unmarshallers.py @@ -23,6 +23,7 @@ from openapi_core.security import security_provider_factory from openapi_core.security.factories import SecurityProviderFactory from openapi_core.templating.paths.exceptions import PathError +from openapi_core.templating.paths.types import PathFinderType from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult from openapi_core.unmarshalling.schemas import ( oas30_write_schema_unmarshallers_factory, @@ -88,6 +89,7 @@ def __init__( media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_casters_factory: Optional[SchemaCastersFactory] = None, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + path_finder_cls: Optional[PathFinderType] = None, spec_validator_cls: Optional[SpecValidatorType] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, @@ -109,6 +111,7 @@ def __init__( media_type_deserializers_factory=media_type_deserializers_factory, schema_casters_factory=schema_casters_factory, schema_validators_factory=schema_validators_factory, + path_finder_cls=path_finder_cls, spec_validator_cls=spec_validator_cls, format_validators=format_validators, extra_format_validators=extra_format_validators, @@ -125,6 +128,7 @@ def __init__( media_type_deserializers_factory=media_type_deserializers_factory, schema_casters_factory=schema_casters_factory, schema_validators_factory=schema_validators_factory, + path_finder_cls=path_finder_cls, spec_validator_cls=spec_validator_cls, format_validators=format_validators, extra_format_validators=extra_format_validators, diff --git a/openapi_core/unmarshalling/response/protocols.py b/openapi_core/unmarshalling/response/protocols.py index edb6fde4..de90c58d 100644 --- a/openapi_core/unmarshalling/response/protocols.py +++ b/openapi_core/unmarshalling/response/protocols.py @@ -24,6 +24,7 @@ from openapi_core.protocols import Request from openapi_core.protocols import Response from openapi_core.protocols import WebhookRequest +from openapi_core.templating.paths.types import PathFinderType from openapi_core.unmarshalling.response.datatypes import ( ResponseUnmarshalResult, ) @@ -47,6 +48,7 @@ def __init__( media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_casters_factory: Optional[SchemaCastersFactory] = None, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + path_finder_cls: Optional[PathFinderType] = None, spec_validator_cls: Optional[SpecValidatorType] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, @@ -77,6 +79,7 @@ def __init__( media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_casters_factory: Optional[SchemaCastersFactory] = None, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + path_finder_cls: Optional[PathFinderType] = None, spec_validator_cls: Optional[SpecValidatorType] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, diff --git a/openapi_core/unmarshalling/unmarshallers.py b/openapi_core/unmarshalling/unmarshallers.py index 9869b9c7..ddc8b891 100644 --- a/openapi_core/unmarshalling/unmarshallers.py +++ b/openapi_core/unmarshalling/unmarshallers.py @@ -20,6 +20,7 @@ from openapi_core.deserializing.styles.factories import ( StyleDeserializersFactory, ) +from openapi_core.templating.paths.types import PathFinderType from openapi_core.unmarshalling.schemas.datatypes import ( FormatUnmarshallersDict, ) @@ -42,6 +43,7 @@ def __init__( media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_casters_factory: Optional[SchemaCastersFactory] = None, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + path_finder_cls: Optional[PathFinderType] = None, spec_validator_cls: Optional[SpecValidatorType] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, @@ -65,6 +67,7 @@ def __init__( media_type_deserializers_factory=media_type_deserializers_factory, schema_casters_factory=schema_casters_factory, schema_validators_factory=schema_validators_factory, + path_finder_cls=path_finder_cls, spec_validator_cls=spec_validator_cls, format_validators=format_validators, extra_format_validators=extra_format_validators, diff --git a/openapi_core/validation/configurations.py b/openapi_core/validation/configurations.py index 17149428..ebc32fc4 100644 --- a/openapi_core/validation/configurations.py +++ b/openapi_core/validation/configurations.py @@ -17,6 +17,7 @@ ) from openapi_core.security import security_provider_factory from openapi_core.security.factories import SecurityProviderFactory +from openapi_core.templating.paths.types import PathFinderType from openapi_core.validation.schemas.datatypes import FormatValidatorsDict from openapi_core.validation.schemas.factories import SchemaValidatorsFactory @@ -28,6 +29,10 @@ class ValidatorConfig: Attributes: server_base_url Server base URI. + path_finder_cls + Path finder class. + webhook_path_finder_cls + Webhook path finder class. style_deserializers_factory Style deserializers factory. media_type_deserializers_factory @@ -45,6 +50,8 @@ class ValidatorConfig: """ server_base_url: Optional[str] = None + path_finder_cls: Optional[PathFinderType] = None + webhook_path_finder_cls: Optional[PathFinderType] = None style_deserializers_factory: StyleDeserializersFactory = ( style_deserializers_factory diff --git a/openapi_core/validation/request/protocols.py b/openapi_core/validation/request/protocols.py index 2554e59e..983864e2 100644 --- a/openapi_core/validation/request/protocols.py +++ b/openapi_core/validation/request/protocols.py @@ -26,6 +26,7 @@ from openapi_core.protocols import WebhookRequest from openapi_core.security import security_provider_factory from openapi_core.security.factories import SecurityProviderFactory +from openapi_core.templating.paths.types import PathFinderType from openapi_core.validation.schemas.datatypes import FormatValidatorsDict from openapi_core.validation.schemas.factories import SchemaValidatorsFactory @@ -40,6 +41,7 @@ def __init__( media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_casters_factory: Optional[SchemaCastersFactory] = None, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + path_finder_cls: Optional[PathFinderType] = None, spec_validator_cls: Optional[SpecValidatorType] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, @@ -70,6 +72,7 @@ def __init__( media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_casters_factory: Optional[SchemaCastersFactory] = None, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + path_finder_cls: Optional[PathFinderType] = None, spec_validator_cls: Optional[SpecValidatorType] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index 4d205416..34e23ecd 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -36,7 +36,7 @@ from openapi_core.security.exceptions import SecurityProviderError from openapi_core.security.factories import SecurityProviderFactory from openapi_core.templating.paths.exceptions import PathError -from openapi_core.templating.paths.finders import WebhookPathFinder +from openapi_core.templating.paths.types import PathFinderType from openapi_core.templating.security.exceptions import SecurityNotFound from openapi_core.util import chainiters from openapi_core.validation.decorators import ValidationErrorWrapper @@ -75,6 +75,7 @@ def __init__( media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_casters_factory: Optional[SchemaCastersFactory] = None, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + path_finder_cls: Optional[PathFinderType] = None, spec_validator_cls: Optional[SpecValidatorType] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, @@ -90,6 +91,7 @@ def __init__( media_type_deserializers_factory=media_type_deserializers_factory, schema_casters_factory=schema_casters_factory, schema_validators_factory=schema_validators_factory, + path_finder_cls=path_finder_cls, spec_validator_cls=spec_validator_cls, format_validators=format_validators, extra_format_validators=extra_format_validators, @@ -444,32 +446,27 @@ class V31RequestValidator(APICallRequestValidator): spec_validator_cls = OpenAPIV31SpecValidator schema_casters_factory = oas31_schema_casters_factory schema_validators_factory = oas31_schema_validators_factory - path_finder_cls = WebhookPathFinder class V31WebhookRequestBodyValidator(WebhookRequestBodyValidator): spec_validator_cls = OpenAPIV31SpecValidator schema_casters_factory = oas31_schema_casters_factory schema_validators_factory = oas31_schema_validators_factory - path_finder_cls = WebhookPathFinder class V31WebhookRequestParametersValidator(WebhookRequestParametersValidator): spec_validator_cls = OpenAPIV31SpecValidator schema_casters_factory = oas31_schema_casters_factory schema_validators_factory = oas31_schema_validators_factory - path_finder_cls = WebhookPathFinder class V31WebhookRequestSecurityValidator(WebhookRequestSecurityValidator): spec_validator_cls = OpenAPIV31SpecValidator schema_casters_factory = oas31_schema_casters_factory schema_validators_factory = oas31_schema_validators_factory - path_finder_cls = WebhookPathFinder class V31WebhookRequestValidator(WebhookRequestValidator): spec_validator_cls = OpenAPIV31SpecValidator schema_casters_factory = oas31_schema_casters_factory schema_validators_factory = oas31_schema_validators_factory - path_finder_cls = WebhookPathFinder diff --git a/openapi_core/validation/response/protocols.py b/openapi_core/validation/response/protocols.py index 168c6483..f0f33dc6 100644 --- a/openapi_core/validation/response/protocols.py +++ b/openapi_core/validation/response/protocols.py @@ -25,6 +25,7 @@ from openapi_core.protocols import Request from openapi_core.protocols import Response from openapi_core.protocols import WebhookRequest +from openapi_core.templating.paths.types import PathFinderType from openapi_core.validation.schemas.datatypes import FormatValidatorsDict from openapi_core.validation.schemas.factories import SchemaValidatorsFactory @@ -39,6 +40,7 @@ def __init__( media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_casters_factory: Optional[SchemaCastersFactory] = None, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + path_finder_cls: Optional[PathFinderType] = None, spec_validator_cls: Optional[SpecValidatorType] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, @@ -70,6 +72,7 @@ def __init__( media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_casters_factory: Optional[SchemaCastersFactory] = None, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + path_finder_cls: Optional[PathFinderType] = None, spec_validator_cls: Optional[SpecValidatorType] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index 4389e118..09275368 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -36,6 +36,7 @@ from openapi_core.templating.paths.finders import APICallPathFinder from openapi_core.templating.paths.finders import BasePathFinder from openapi_core.templating.paths.finders import WebhookPathFinder +from openapi_core.templating.paths.types import PathFinderType from openapi_core.validation.schemas.datatypes import FormatValidatorsDict from openapi_core.validation.schemas.factories import SchemaValidatorsFactory @@ -43,6 +44,7 @@ class BaseValidator: schema_casters_factory: SchemaCastersFactory = NotImplemented schema_validators_factory: SchemaValidatorsFactory = NotImplemented + path_finder_cls: PathFinderType = NotImplemented spec_validator_cls: Optional[SpecValidatorType] = None def __init__( @@ -53,6 +55,7 @@ def __init__( media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_casters_factory: Optional[SchemaCastersFactory] = None, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + path_finder_cls: Optional[PathFinderType] = None, spec_validator_cls: Optional[SpecValidatorType] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, @@ -79,11 +82,18 @@ def __init__( raise NotImplementedError( "schema_validators_factory is not assigned" ) + self.path_finder_cls = path_finder_cls or self.path_finder_cls + if self.path_finder_cls is NotImplemented: # type: ignore[comparison-overlap] + raise NotImplementedError("path_finder_cls is not assigned") self.spec_validator_cls = spec_validator_cls or self.spec_validator_cls self.format_validators = format_validators self.extra_format_validators = extra_format_validators self.extra_media_type_deserializers = extra_media_type_deserializers + @cached_property + def path_finder(self) -> BasePathFinder: + return self.path_finder_cls(self.spec, base_url=self.base_url) + def check_spec(self, spec: SchemaPath) -> None: if self.spec_validator_cls is None: return @@ -267,9 +277,7 @@ def _get_media_type_value( class BaseAPICallValidator(BaseValidator): - @cached_property - def path_finder(self) -> BasePathFinder: - return APICallPathFinder(self.spec, base_url=self.base_url) + path_finder_cls = APICallPathFinder def _find_path(self, request: Request) -> PathOperationServer: path_pattern = getattr(request, "path_pattern", None) or request.path @@ -278,9 +286,7 @@ def _find_path(self, request: Request) -> PathOperationServer: class BaseWebhookValidator(BaseValidator): - @cached_property - def path_finder(self) -> BasePathFinder: - return WebhookPathFinder(self.spec, base_url=self.base_url) + path_finder_cls = WebhookPathFinder def _find_path(self, request: WebhookRequest) -> PathOperationServer: return self.path_finder.find(request.method, request.name)