@@ -667,33 +667,18 @@ def _check_start_shape(model, start: PointType):
667667 The complete dictionary mapping (transformed) variable names to numeric initial values.
668668 """
669669 e = ""
670- for var in model .basic_RVs :
671- try :
672- var_shape = model .fastfn (var .shape )(start )
673- if var .name in start .keys ():
674- start_var_shape = np .shape (start [var .name ])
675- if start_var_shape :
676- if not np .array_equal (var_shape , start_var_shape ):
677- e += "\n Expected shape {} for var '{}', got: {}" .format (
678- tuple (var_shape ), var .name , start_var_shape
679- )
680- # if start var has no shape
681- else :
682- # if model var has a specified shape
683- if var_shape .size > 0 :
684- e += "\n Expected shape {} for var " "'{}', got scalar {}" .format (
685- tuple (var_shape ), var .name , start [var .name ]
686- )
687- except NotImplementedError as ex :
688- if ex .args [0 ].startswith ("Cannot sample" ):
689- _log .warning (
690- f"Unable to check start shape of { var } because the RV does not implement random sampling."
691- )
692- else :
693- raise
694-
670+ try :
671+ actual_shapes = model .eval_rv_shapes ()
672+ except NotImplementedError as ex :
673+ warnings .warn (f"Unable to validate shapes: { ex .args [0 ]} " , UserWarning )
674+ return
675+ for name , sval in start .items ():
676+ ashape = actual_shapes .get (name )
677+ sshape = np .shape (sval )
678+ if ashape != tuple (sshape ):
679+ e += f"\n Expected shape { ashape } for var '{ name } ', got: { sshape } "
695680 if e != "" :
696- raise ValueError (f"Bad shape for start argument :{ e } " )
681+ raise ValueError (f"Bad shape in start point :{ e } " )
697682
698683
699684def _sample_many (
0 commit comments