@@ -667,33 +667,18 @@ def _check_start_shape(model, start: PointType):
667
667
The complete dictionary mapping (transformed) variable names to numeric initial values.
668
668
"""
669
669
e = ""
670
- for var in model .basic_RVs :
671
- try :
672
- var_shape = model .fastfn (var .shape )(start )
673
- if var .name in start .keys ():
674
- start_var_shape = np .shape (start [var .name ])
675
- if start_var_shape :
676
- if not np .array_equal (var_shape , start_var_shape ):
677
- e += "\n Expected shape {} for var '{}', got: {}" .format (
678
- tuple (var_shape ), var .name , start_var_shape
679
- )
680
- # if start var has no shape
681
- else :
682
- # if model var has a specified shape
683
- if var_shape .size > 0 :
684
- e += "\n Expected shape {} for var " "'{}', got scalar {}" .format (
685
- tuple (var_shape ), var .name , start [var .name ]
686
- )
687
- except NotImplementedError as ex :
688
- if ex .args [0 ].startswith ("Cannot sample" ):
689
- _log .warning (
690
- f"Unable to check start shape of { var } because the RV does not implement random sampling."
691
- )
692
- else :
693
- raise
694
-
670
+ try :
671
+ actual_shapes = model .eval_rv_shapes ()
672
+ except NotImplementedError as ex :
673
+ warnings .warn (f"Unable to validate shapes: { ex .args [0 ]} " , UserWarning )
674
+ return
675
+ for name , sval in start .items ():
676
+ ashape = actual_shapes .get (name )
677
+ sshape = np .shape (sval )
678
+ if ashape != tuple (sshape ):
679
+ e += f"\n Expected shape { ashape } for var '{ name } ', got: { sshape } "
695
680
if e != "" :
696
- raise ValueError (f"Bad shape for start argument :{ e } " )
681
+ raise ValueError (f"Bad shape in start point :{ e } " )
697
682
698
683
699
684
def _sample_many (
0 commit comments