Minimal recursive references support #69

Draft
wants to merge 4 commits into base: master
190 changes: 122 additions & 68 deletions src/hypothesis_jsonschema/_canonicalise.py
@@ -18,6 +18,7 @@
import re
from copy import deepcopy
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
from urllib.parse import urljoin

import jsonschema
from hypothesis.errors import InvalidArgument
@@ -78,9 +79,20 @@ def _get_validator_class(schema: Schema) -> JSONSchemaValidator:
return validator


def make_validator(schema: Schema) -> JSONSchemaValidator:
class LocalResolver(jsonschema.RefResolver):
def resolve_remote(self, uri: str) -> NoReturn:
raise HypothesisRefResolutionError(
f"hypothesis-jsonschema does not fetch remote references (uri={uri!r})"
)


def make_validator(
schema: Schema, resolver: LocalResolver = None
) -> JSONSchemaValidator:
if resolver is None:
resolver = LocalResolver.from_schema(schema)
validator = _get_validator_class(schema)
return validator(schema)
return validator(schema, resolver=resolver)
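For orientation: the hunk above moves the `LocalResolver` definition ahead of `make_validator` and threads a resolver through validator construction, so every `$ref` lookup stays within the local schema instead of being fetched over the network. A minimal sketch of the same pattern against the public `jsonschema` resolver API (the class name `NoRemoteResolver` and the example schema are invented for illustration and are not part of this change):

import jsonschema

class NoRemoteResolver(jsonschema.RefResolver):  # illustrative stand-in for LocalResolver
    def resolve_remote(self, uri):
        # Refuse to fetch anything over the network.
        raise RuntimeError(f"remote references are not fetched (uri={uri!r})")

schema = {"$ref": "https://example.com/defs.json#/thing"}
resolver = NoRemoteResolver.from_schema(schema)
validator = jsonschema.Draft7Validator(schema, resolver=resolver)
# Validating any instance now raises when the remote reference is resolved,
# rather than silently downloading the document.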


class HypothesisRefResolutionError(jsonschema.exceptions.RefResolutionError):
@@ -202,7 +214,7 @@ def get_integer_bounds(schema: Schema) -> Tuple[Optional[int], Optional[int]]:
return lower, upper


def canonicalish(schema: JSONType) -> Dict[str, Any]:
def canonicalish(schema: JSONType, resolver: LocalResolver = None) -> Dict[str, Any]:
"""Convert a schema into a more-canonical form.

This is obviously incomplete, but improves best-effort recognition of
@@ -224,12 +236,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
"but expected a dict."
)

if resolver is None:
resolver = LocalResolver.from_schema(schema)

if "const" in schema:
if not make_validator(schema).is_valid(schema["const"]):
if not make_validator(schema, resolver=resolver).is_valid(schema["const"]):
return FALSEY
return {"const": schema["const"]}
if "enum" in schema:
validator = make_validator(schema)
validator = make_validator(schema, resolver=resolver)
enum_ = sorted(
(v for v in schema["enum"] if validator.is_valid(v)), key=sort_key
)
@@ -253,15 +268,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
# Recurse into the value of each keyword with a schema (or list of them) as a value
for key in SCHEMA_KEYS:
if isinstance(schema.get(key), list):
schema[key] = [canonicalish(v) for v in schema[key]]
schema[key] = [canonicalish(v, resolver=resolver) for v in schema[key]]
elif isinstance(schema.get(key), (bool, dict)):
schema[key] = canonicalish(schema[key])
schema[key] = canonicalish(schema[key], resolver=resolver)
else:
assert key not in schema, (key, schema[key])
for key in SCHEMA_OBJECT_KEYS:
if key in schema:
schema[key] = {
k: v if isinstance(v, list) else canonicalish(v)
k: v if isinstance(v, list) else canonicalish(v, resolver=resolver)
for k, v in schema[key].items()
}

@@ -307,7 +322,9 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:

if "array" in type_ and "contains" in schema:
if isinstance(schema.get("items"), dict):
contains_items = merged([schema["contains"], schema["items"]])
contains_items = merged(
[schema["contains"], schema["items"]], resolver=resolver
)
if contains_items is not None:
schema["contains"] = contains_items

@@ -461,9 +478,9 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
type_.remove(t)
if t not in ("integer", "number"):
not_["type"].remove(t)
not_ = canonicalish(not_)
not_ = canonicalish(not_, resolver=resolver)

m = merged([not_, {**schema, "type": type_}])
m = merged([not_, {**schema, "type": type_}], resolver=resolver)
if m is not None:
not_ = m
if not_ != FALSEY:
@@ -542,7 +559,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
else:
tmp = schema.copy()
ao = tmp.pop("allOf")
out = merged([tmp] + ao)
out = merged([tmp] + ao, resolver=resolver)
if isinstance(out, dict): # pragma: no branch
schema = out
# TODO: this assertion is solely because mypy 0.750 doesn't know
@@ -554,7 +571,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
one_of = sorted(one_of, key=encode_canonical_json)
one_of = [s for s in one_of if s != FALSEY]
if len(one_of) == 1:
m = merged([schema, one_of[0]])
m = merged([schema, one_of[0]], resolver=resolver)
if m is not None: # pragma: no branch
return m
if (not one_of) or one_of.count(TRUTHY) > 1:
@@ -569,70 +586,99 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
FALSEY = canonicalish(False)


class LocalResolver(jsonschema.RefResolver):
def resolve_remote(self, uri: str) -> NoReturn:
raise HypothesisRefResolutionError(
f"hypothesis-jsonschema does not fetch remote references (uri={uri!r})"
)
def is_recursive_reference(reference: str, resolver: LocalResolver) -> bool:
"""Detect if the given reference is recursive."""
# Special case: a reference to the schema's root is always recursive
if reference == "#":
return True
# During reference resolution the scope may move to external schemas. `hypothesis-jsonschema` does not support
# schemas behind remote references, but the underlying `jsonschema` library ships meta-schemas for the
# various JSON Schema drafts and resolves them transparently, and those count as external schemas in this context.
# For this reason we need to check the reference relative to the base URI.
full_reference = urljoin(resolver.base_uri, reference)
# If the fully-qualified reference is already on the resolution stack, we are encountering it for the
# second time, i.e. it is a recursive reference.
return full_reference in resolver._scopes_stack
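To make the check above concrete, here is a small, self-contained illustration of how joining a reference against the base URI interacts with the scopes stack (the URIs and stack contents are invented for the example and merely mimic `RefResolver._scopes_stack`):

from urllib.parse import urljoin

base_uri = "http://example.com/schema.json"
# Pretend "#/definitions/node" is currently being resolved.
scopes_stack = [base_uri, urljoin(base_uri, "#/definitions/node")]

assert urljoin(base_uri, "#/definitions/node") in scopes_stack       # seen again, so recursive
assert urljoin(base_uri, "#/definitions/other") not in scopes_stack  # first visit, safe to resolve
# A bare "#" is special-cased above because urljoin(base_uri, "#") yields
# "http://example.com/schema.json#", which need not appear literally on the
# stack even though it denotes the root schema.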


def resolve_all_refs(
schema: Union[bool, Schema], *, resolver: LocalResolver = None
) -> Schema:
"""
Resolve all references in the given schema.
schema: Union[bool, Schema], *, resolver: LocalResolver
) -> Tuple[Schema, bool]:
"""Resolve all non-recursive references in the given schema.

This handles nested definitions, but not recursive definitions.
The latter require special handling to convert to strategies and are much
less common, so we just ignore them (and error out) for now.
When a recursive reference is detected, traversal of the branch currently being resolved stops and that branch is left as-is.
"""
if isinstance(schema, bool):
return canonicalish(schema)
return canonicalish(schema), False
assert isinstance(schema, dict), schema
if resolver is None:
resolver = LocalResolver.from_schema(deepcopy(schema))
if not isinstance(resolver, jsonschema.RefResolver):
raise InvalidArgument(
f"resolver={resolver} (type {type(resolver).__name__}) is not a RefResolver"
)

if "$ref" in schema:
s = dict(schema)
ref = s.pop("$ref")
with resolver.resolving(ref) as got:
if s == {}:
return resolve_all_refs(got, resolver=resolver)
m = merged([s, got])
if m is None: # pragma: no cover
msg = f"$ref:{ref!r} had incompatible base schema {s!r}"
raise HypothesisRefResolutionError(msg)
return resolve_all_refs(m, resolver=resolver)
assert "$ref" not in schema
# Recursive references are skipped to avoid infinite recursion.
if not is_recursive_reference(schema["$ref"], resolver):
s = dict(schema)
ref = s.pop("$ref")
with resolver.resolving(ref) as got:
if s == {}:
return resolve_all_refs(deepcopy(got), resolver=resolver)
m = merged([s, got], resolver=resolver)
if m is None: # pragma: no cover
msg = f"$ref:{ref!r} had incompatible base schema {s!r}"
raise HypothesisRefResolutionError(msg)
# `deepcopy` is not needed here because the schemas are copied inside the `merged` call above
return resolve_all_refs(m, resolver=resolver)
else:
return schema, True

for key in SCHEMA_KEYS:
val = schema.get(key, False)
if isinstance(val, list):
schema[key] = [
resolve_all_refs(v, resolver=resolver) if isinstance(v, dict) else v
for v in val
]
value = []
for v in val:
if isinstance(v, dict):
resolved, is_recursive = resolve_all_refs(
deepcopy(v), resolver=resolver
)
if is_recursive:
return schema, True
else:
value.append(resolved)
else:
value.append(v)
schema[key] = value
elif isinstance(val, dict):
schema[key] = resolve_all_refs(val, resolver=resolver)
resolved, is_recursive = resolve_all_refs(deepcopy(val), resolver=resolver)
if is_recursive:
return schema, True
else:
schema[key] = resolved
else:
assert isinstance(val, bool)
for key in SCHEMA_OBJECT_KEYS: # values are keys-to-schema-dicts, not schemas
if key in schema:
subschema = schema[key]
assert isinstance(subschema, dict)
schema[key] = {
k: resolve_all_refs(v, resolver=resolver) if isinstance(v, dict) else v
for k, v in subschema.items()
}
value = {}
for k, v in subschema.items():
if isinstance(v, dict):
resolved, is_recursive = resolve_all_refs(
deepcopy(v), resolver=resolver
)
if is_recursive:
return schema, True
else:
value[k] = resolved
else:
value[k] = v
schema[key] = value
assert isinstance(schema, dict)
return schema
return schema, False
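Since `resolve_all_refs` now requires an explicit resolver and returns a `(schema, is_recursive)` pair, call sites have to unpack both values. A hedged sketch of what a call might look like (the example schema is invented; how the flag is consumed downstream is outside this diff):

from copy import deepcopy

recursive_schema = {
    "definitions": {
        "node": {
            "type": "object",
            "properties": {"child": {"$ref": "#/definitions/node"}},
        }
    },
    "$ref": "#/definitions/node",
}
resolver = LocalResolver.from_schema(recursive_schema)
resolved, is_recursive = resolve_all_refs(deepcopy(recursive_schema), resolver=resolver)
# is_recursive is True here: the top-level $ref is expanded once, but the
# nested "#/definitions/node" is found on the scopes stack during that
# expansion and is left in place instead of being expanded forever.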


def merged(schemas: List[Any]) -> Optional[Schema]:
def merged(schemas: List[Any], resolver: LocalResolver = None) -> Optional[Schema]:
"""Merge *n* schemas into a single schema, or None if result is invalid.

Takes the logical intersection, so any object that validates against the returned
@@ -645,7 +691,9 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
It's currently also used for keys that could be merged but aren't yet.
"""
assert schemas, "internal error: must pass at least one schema to merge"
schemas = sorted((canonicalish(s) for s in schemas), key=upper_bound_instances)
schemas = sorted(
(canonicalish(s, resolver=resolver) for s in schemas), key=upper_bound_instances
)
if any(s == FALSEY for s in schemas):
return FALSEY
out = schemas[0]
Expand All @@ -654,11 +702,11 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
continue
# If we have a const or enum, this is fairly easy by filtering:
if "const" in out:
if make_validator(s).is_valid(out["const"]):
if make_validator(s, resolver=resolver).is_valid(out["const"]):
continue
return FALSEY
if "enum" in out:
validator = make_validator(s)
validator = make_validator(s, resolver=resolver)
enum_ = [v for v in out["enum"] if validator.is_valid(v)]
if not enum_:
return FALSEY
@@ -709,36 +757,41 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
else:
out_combined = merged(
[s for p, s in out_pat.items() if re.search(p, prop_name)]
or [out_add]
or [out_add],
resolver=resolver,
)
if prop_name in s_props:
s_combined = s_props[prop_name]
else:
s_combined = merged(
[s for p, s in s_pat.items() if re.search(p, prop_name)]
or [s_add]
or [s_add],
resolver=resolver,
)
if out_combined is None or s_combined is None: # pragma: no cover
# Note that this can only be the case if we were actually going to
# use the schema which we attempted to merge, i.e. prop_name was
# not in the schema and there were unmergeable pattern schemas.
return None
m = merged([out_combined, s_combined])
m = merged([out_combined, s_combined], resolver=resolver)
if m is None:
return None
out_props[prop_name] = m
# With all the property names done, it's time to handle the patterns. This is
# simpler as we merge with either an identical pattern, or additionalProperties.
if out_pat or s_pat:
for pattern in set(out_pat) | set(s_pat):
m = merged([out_pat.get(pattern, out_add), s_pat.get(pattern, s_add)])
m = merged(
[out_pat.get(pattern, out_add), s_pat.get(pattern, s_add)],
resolver=resolver,
)
if m is None: # pragma: no cover
return None
out_pat[pattern] = m
out["patternProperties"] = out_pat
# Finally, we merge together the additionalProperties schemas.
if out_add or s_add:
m = merged([out_add, s_add])
m = merged([out_add, s_add], resolver=resolver)
if m is None: # pragma: no cover
return None
out["additionalProperties"] = m
Expand Down Expand Up @@ -772,7 +825,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
return None
if "contains" in out and "contains" in s and out["contains"] != s["contains"]:
# If one `contains` schema is a subset of the other, we can discard it.
m = merged([out["contains"], s["contains"]])
m = merged([out["contains"], s["contains"]], resolver=resolver)
if m == out["contains"] or m == s["contains"]:
out["contains"] = m
s.pop("contains")
@@ -802,7 +855,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
v = {"required": v}
elif isinstance(sval, list):
sval = {"required": sval}
m = merged([v, sval])
m = merged([v, sval], resolver=resolver)
if m is None:
return None
odeps[k] = m
@@ -816,26 +869,27 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
[
out.get("additionalItems", TRUTHY),
s.get("additionalItems", TRUTHY),
]
],
resolver=resolver,
)
for a, b in itertools.zip_longest(oitems, sitems):
if a is None:
a = out.get("additionalItems", TRUTHY)
elif b is None:
b = s.get("additionalItems", TRUTHY)
out["items"].append(merged([a, b]))
out["items"].append(merged([a, b], resolver=resolver))
elif isinstance(oitems, list):
out["items"] = [merged([x, sitems]) for x in oitems]
out["items"] = [merged([x, sitems], resolver=resolver) for x in oitems]
out["additionalItems"] = merged(
[out.get("additionalItems", TRUTHY), sitems]
[out.get("additionalItems", TRUTHY), sitems], resolver=resolver
)
elif isinstance(sitems, list):
out["items"] = [merged([x, oitems]) for x in sitems]
out["items"] = [merged([x, oitems], resolver=resolver) for x in sitems]
out["additionalItems"] = merged(
[s.get("additionalItems", TRUTHY), oitems]
[s.get("additionalItems", TRUTHY), oitems], resolver=resolver
)
else:
out["items"] = merged([oitems, sitems])
out["items"] = merged([oitems, sitems], resolver=resolver)
if out["items"] is None:
return None
if isinstance(out["items"], list) and None in out["items"]:
@@ -859,7 +913,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
# If non-validation keys like `title` or `description` don't match,
# that doesn't really matter and we'll just go with first we saw.
return None
out = canonicalish(out)
out = canonicalish(out, resolver=resolver)
if out == FALSEY:
return FALSEY
assert isinstance(out, dict)
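As a usage note for the resolver-threaded `merged` above, a short, hedged example (the inputs are invented and the exact canonical output may differ in minor details):

a = {"type": "integer", "minimum": 0}
b = {"type": "integer", "maximum": 10}
combined = merged([a, b])  # resolver defaults to None; canonicalish builds one when needed
# `combined` accepts exactly the integers from 0 to 10, the logical
# intersection of `a` and `b`. An unsatisfiable pair (say, minimum 5 with
# maximum 1) collapses to the FALSEY sentinel, and schemas the merge logic
# cannot combine yield None.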