Skip to content

Commit 4af9e0b

Browse files
committed
Pass resolver to validators
1 parent e9e6d51 commit 4af9e0b

File tree

4 files changed

+148
-88
lines changed

4 files changed

+148
-88
lines changed

src/hypothesis_jsonschema/_canonicalise.py

+58-43
Original file line number | Diff line number | Diff line change
@@ -79,9 +79,20 @@ def _get_validator_class(schema: Schema) -> JSONSchemaValidator:
7979
return validator
8080

8181

82-
def make_validator(schema: Schema) -> JSONSchemaValidator:
82+
class LocalResolver(jsonschema.RefResolver):
83+
def resolve_remote(self, uri: str) -> NoReturn:
84+
raise HypothesisRefResolutionError(
85+
f"hypothesis-jsonschema does not fetch remote references (uri={uri!r})"
86+
)
87+
88+
89+
def make_validator(
90+
schema: Schema, resolver: LocalResolver = None
91+
) -> JSONSchemaValidator:
92+
if resolver is None:
93+
resolver = LocalResolver.from_schema(schema)
8394
validator = _get_validator_class(schema)
84-
return validator(schema)
95+
return validator(schema, resolver=resolver)
8596

8697

8798
class HypothesisRefResolutionError(jsonschema.exceptions.RefResolutionError):
@@ -203,7 +214,7 @@ def get_integer_bounds(schema: Schema) -> Tuple[Optional[int], Optional[int]]:
203214
return lower, upper
204215

205216

206-
def canonicalish(schema: JSONType) -> Dict[str, Any]:
217+
def canonicalish(schema: JSONType, resolver: LocalResolver = None) -> Dict[str, Any]:
207218
"""Convert a schema into a more-canonical form.
208219
209220
This is obviously incomplete, but improves best-effort recognition of
@@ -225,12 +236,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
225236
"but expected a dict."
226237
)
227238

239+
if resolver is None:
240+
resolver = LocalResolver.from_schema(schema)
241+
228242
if "const" in schema:
229-
if not make_validator(schema).is_valid(schema["const"]):
243+
if not make_validator(schema, resolver=resolver).is_valid(schema["const"]):
230244
return FALSEY
231245
return {"const": schema["const"]}
232246
if "enum" in schema:
233-
validator = make_validator(schema)
247+
validator = make_validator(schema, resolver=resolver)
234248
enum_ = sorted(
235249
(v for v in schema["enum"] if validator.is_valid(v)), key=sort_key
236250
)
@@ -254,15 +268,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
254268
# Recurse into the value of each keyword with a schema (or list of them) as a value
255269
for key in SCHEMA_KEYS:
256270
if isinstance(schema.get(key), list):
257-
schema[key] = [canonicalish(v) for v in schema[key]]
271+
schema[key] = [canonicalish(v, resolver=resolver) for v in schema[key]]
258272
elif isinstance(schema.get(key), (bool, dict)):
259-
schema[key] = canonicalish(schema[key])
273+
schema[key] = canonicalish(schema[key], resolver=resolver)
260274
else:
261275
assert key not in schema, (key, schema[key])
262276
for key in SCHEMA_OBJECT_KEYS:
263277
if key in schema:
264278
schema[key] = {
265-
k: v if isinstance(v, list) else canonicalish(v)
279+
k: v if isinstance(v, list) else canonicalish(v, resolver=resolver)
266280
for k, v in schema[key].items()
267281
}
268282

@@ -308,7 +322,9 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
308322

309323
if "array" in type_ and "contains" in schema:
310324
if isinstance(schema.get("items"), dict):
311-
contains_items = merged([schema["contains"], schema["items"]])
325+
contains_items = merged(
326+
[schema["contains"], schema["items"]], resolver=resolver
327+
)
312328
if contains_items is not None:
313329
schema["contains"] = contains_items
314330

@@ -462,9 +478,9 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
462478
type_.remove(t)
463479
if t not in ("integer", "number"):
464480
not_["type"].remove(t)
465-
not_ = canonicalish(not_)
481+
not_ = canonicalish(not_, resolver=resolver)
466482

467-
m = merged([not_, {**schema, "type": type_}])
483+
m = merged([not_, {**schema, "type": type_}], resolver=resolver)
468484
if m is not None:
469485
not_ = m
470486
if not_ != FALSEY:
@@ -543,7 +559,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
543559
else:
544560
tmp = schema.copy()
545561
ao = tmp.pop("allOf")
546-
out = merged([tmp] + ao)
562+
out = merged([tmp] + ao, resolver=resolver)
547563
if isinstance(out, dict): # pragma: no branch
548564
schema = out
549565
# TODO: this assertion is solely because mypy 0.750 doesn't know
@@ -555,7 +571,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
555571
one_of = sorted(one_of, key=encode_canonical_json)
556572
one_of = [s for s in one_of if s != FALSEY]
557573
if len(one_of) == 1:
558-
m = merged([schema, one_of[0]])
574+
m = merged([schema, one_of[0]], resolver=resolver)
559575
if m is not None: # pragma: no branch
560576
return m
561577
if (not one_of) or one_of.count(TRUTHY) > 1:
@@ -570,13 +586,6 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
570586
FALSEY = canonicalish(False)
571587

572588

573-
class LocalResolver(jsonschema.RefResolver):
574-
def resolve_remote(self, uri: str) -> NoReturn:
575-
raise HypothesisRefResolutionError(
576-
f"hypothesis-jsonschema does not fetch remote references (uri={uri!r})"
577-
)
578-
579-
580589
def is_recursive_reference(reference: str, resolver: LocalResolver) -> bool:
581590
"""Detect if the given reference is recursive."""
582591
# Special case: a reference to the schema's root is always recursive
@@ -593,7 +602,7 @@ def is_recursive_reference(reference: str, resolver: LocalResolver) -> bool:
593602

594603

595604
def resolve_all_refs(
596-
schema: Union[bool, Schema], *, resolver: LocalResolver = None
605+
schema: Union[bool, Schema], *, resolver: LocalResolver
597606
) -> Tuple[Schema, bool]:
598607
"""Resolve all non-recursive references in the given schema.
599608
@@ -602,8 +611,6 @@ def resolve_all_refs(
602611
if isinstance(schema, bool):
603612
return canonicalish(schema), False
604613
assert isinstance(schema, dict), schema
605-
if resolver is None:
606-
resolver = LocalResolver.from_schema(deepcopy(schema))
607614
if not isinstance(resolver, jsonschema.RefResolver):
608615
raise InvalidArgument(
609616
f"resolver={resolver} (type {type(resolver).__name__}) is not a RefResolver"
@@ -617,7 +624,7 @@ def resolve_all_refs(
617624
with resolver.resolving(ref) as got:
618625
if s == {}:
619626
return resolve_all_refs(deepcopy(got), resolver=resolver)
620-
m = merged([s, got])
627+
m = merged([s, got], resolver=resolver)
621628
if m is None: # pragma: no cover
622629
msg = f"$ref:{ref!r} had incompatible base schema {s!r}"
623630
raise HypothesisRefResolutionError(msg)
@@ -671,7 +678,7 @@ def resolve_all_refs(
671678
return schema, False
672679

673680

674-
def merged(schemas: List[Any]) -> Optional[Schema]:
681+
def merged(schemas: List[Any], resolver: LocalResolver = None) -> Optional[Schema]:
675682
"""Merge *n* schemas into a single schema, or None if result is invalid.
676683
677684
Takes the logical intersection, so any object that validates against the returned
@@ -684,7 +691,9 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
684691
It's currently also used for keys that could be merged but aren't yet.
685692
"""
686693
assert schemas, "internal error: must pass at least one schema to merge"
687-
schemas = sorted((canonicalish(s) for s in schemas), key=upper_bound_instances)
694+
schemas = sorted(
695+
(canonicalish(s, resolver=resolver) for s in schemas), key=upper_bound_instances
696+
)
688697
if any(s == FALSEY for s in schemas):
689698
return FALSEY
690699
out = schemas[0]
@@ -693,11 +702,11 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
693702
continue
694703
# If we have a const or enum, this is fairly easy by filtering:
695704
if "const" in out:
696-
if make_validator(s).is_valid(out["const"]):
705+
if make_validator(s, resolver=resolver).is_valid(out["const"]):
697706
continue
698707
return FALSEY
699708
if "enum" in out:
700-
validator = make_validator(s)
709+
validator = make_validator(s, resolver=resolver)
701710
enum_ = [v for v in out["enum"] if validator.is_valid(v)]
702711
if not enum_:
703712
return FALSEY
@@ -748,36 +757,41 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
748757
else:
749758
out_combined = merged(
750759
[s for p, s in out_pat.items() if re.search(p, prop_name)]
751-
or [out_add]
760+
or [out_add],
761+
resolver=resolver,
752762
)
753763
if prop_name in s_props:
754764
s_combined = s_props[prop_name]
755765
else:
756766
s_combined = merged(
757767
[s for p, s in s_pat.items() if re.search(p, prop_name)]
758-
or [s_add]
768+
or [s_add],
769+
resolver=resolver,
759770
)
760771
if out_combined is None or s_combined is None: # pragma: no cover
761772
# Note that this can only be the case if we were actually going to
762773
# use the schema which we attempted to merge, i.e. prop_name was
763774
# not in the schema and there were unmergeable pattern schemas.
764775
return None
765-
m = merged([out_combined, s_combined])
776+
m = merged([out_combined, s_combined], resolver=resolver)
766777
if m is None:
767778
return None
768779
out_props[prop_name] = m
769780
# With all the property names done, it's time to handle the patterns. This is
770781
# simpler as we merge with either an identical pattern, or additionalProperties.
771782
if out_pat or s_pat:
772783
for pattern in set(out_pat) | set(s_pat):
773-
m = merged([out_pat.get(pattern, out_add), s_pat.get(pattern, s_add)])
784+
m = merged(
785+
[out_pat.get(pattern, out_add), s_pat.get(pattern, s_add)],
786+
resolver=resolver,
787+
)
774788
if m is None: # pragma: no cover
775789
return None
776790
out_pat[pattern] = m
777791
out["patternProperties"] = out_pat
778792
# Finally, we merge together the additionalProperties schemas.
779793
if out_add or s_add:
780-
m = merged([out_add, s_add])
794+
m = merged([out_add, s_add], resolver=resolver)
781795
if m is None: # pragma: no cover
782796
return None
783797
out["additionalProperties"] = m
@@ -811,7 +825,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
811825
return None
812826
if "contains" in out and "contains" in s and out["contains"] != s["contains"]:
813827
# If one `contains` schema is a subset of the other, we can discard it.
814-
m = merged([out["contains"], s["contains"]])
828+
m = merged([out["contains"], s["contains"]], resolver=resolver)
815829
if m == out["contains"] or m == s["contains"]:
816830
out["contains"] = m
817831
s.pop("contains")
@@ -841,7 +855,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
841855
v = {"required": v}
842856
elif isinstance(sval, list):
843857
sval = {"required": sval}
844-
m = merged([v, sval])
858+
m = merged([v, sval], resolver=resolver)
845859
if m is None:
846860
return None
847861
odeps[k] = m
@@ -855,26 +869,27 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
855869
[
856870
out.get("additionalItems", TRUTHY),
857871
s.get("additionalItems", TRUTHY),
858-
]
872+
],
873+
resolver=resolver,
859874
)
860875
for a, b in itertools.zip_longest(oitems, sitems):
861876
if a is None:
862877
a = out.get("additionalItems", TRUTHY)
863878
elif b is None:
864879
b = s.get("additionalItems", TRUTHY)
865-
out["items"].append(merged([a, b]))
880+
out["items"].append(merged([a, b], resolver=resolver))
866881
elif isinstance(oitems, list):
867-
out["items"] = [merged([x, sitems]) for x in oitems]
882+
out["items"] = [merged([x, sitems], resolver=resolver) for x in oitems]
868883
out["additionalItems"] = merged(
869-
[out.get("additionalItems", TRUTHY), sitems]
884+
[out.get("additionalItems", TRUTHY), sitems], resolver=resolver
870885
)
871886
elif isinstance(sitems, list):
872-
out["items"] = [merged([x, oitems]) for x in sitems]
887+
out["items"] = [merged([x, oitems], resolver=resolver) for x in sitems]
873888
out["additionalItems"] = merged(
874-
[s.get("additionalItems", TRUTHY), oitems]
889+
[s.get("additionalItems", TRUTHY), oitems], resolver=resolver
875890
)
876891
else:
877-
out["items"] = merged([oitems, sitems])
892+
out["items"] = merged([oitems, sitems], resolver=resolver)
878893
if out["items"] is None:
879894
return None
880895
if isinstance(out["items"], list) and None in out["items"]:
@@ -898,7 +913,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
898913
# If non-validation keys like `title` or `description` don't match,
899914
# that doesn't really matter and we'll just go with the first one we saw.
900915
return None
901-
out = canonicalish(out)
916+
out = canonicalish(out, resolver=resolver)
902917
if out == FALSEY:
903918
return FALSEY
904919
assert isinstance(out, dict)

0 commit comments

Comments
 (0)