@@ -60,7 +60,7 @@ def test_model(
60
60
)
61
61
62
62
63
- def _validate_input_shape (shape : Tuple [int , ...], shape_spec ) -> bool :
63
+ def check_input_shape (shape : Tuple [int , ...], shape_spec ) -> bool :
64
64
if isinstance (shape_spec , list ):
65
65
if shape != tuple (shape_spec ):
66
66
return False
@@ -81,7 +81,7 @@ def _validate_input_shape(shape: Tuple[int, ...], shape_spec) -> bool:
81
81
return True
82
82
83
83
84
- def _validate_output_shape (shape : Tuple [int , ...], shape_spec , input_shapes ) -> bool :
84
+ def check_output_shape (shape : Tuple [int , ...], shape_spec , input_shapes ) -> bool :
85
85
if isinstance (shape_spec , list ):
86
86
return shape == tuple (shape_spec )
87
87
elif isinstance (shape_spec , ImplicitOutputShape ):
@@ -129,7 +129,7 @@ def test_resource(
129
129
assert len (inputs ) == len (model .inputs ) # should be checked by validation
130
130
input_shapes = {}
131
131
for idx , (ipt , ipt_spec ) in enumerate (zip (inputs , model .inputs )):
132
- if not _validate_input_shape (tuple (ipt .shape ), ipt_spec .shape ):
132
+ if not check_input_shape (tuple (ipt .shape ), ipt_spec .shape ):
133
133
raise ValidationError (
134
134
f"Shape { tuple (ipt .shape )} of test input { idx } '{ ipt_spec .name } ' does not match "
135
135
f"input shape description: { ipt_spec .shape } ."
@@ -138,7 +138,7 @@ def test_resource(
138
138
139
139
assert len (expected ) == len (model .outputs ) # should be checked by validation
140
140
for idx , (out , out_spec ) in enumerate (zip (expected , model .outputs )):
141
- if not _validate_output_shape (tuple (out .shape ), out_spec .shape , input_shapes ):
141
+ if not check_output_shape (tuple (out .shape ), out_spec .shape , input_shapes ):
142
142
error = (error or "" ) + (
143
143
f"Shape { tuple (out .shape )} of test output { idx } '{ out_spec .name } ' does not match "
144
144
f"output shape description: { out_spec .shape } ."
0 commit comments