Skip to content

Commit 595e164

Browse files
michaelosthegetwiecki
authored andcommitted
Simplify _check_start_shape
Closes #5031
1 parent 5f248fc commit 595e164

File tree

2 files changed

+13
-29
lines changed

2 files changed

+13
-29
lines changed

pymc/sampling.py

+11-26
Original file line numberDiff line numberDiff line change
@@ -667,33 +667,18 @@ def _check_start_shape(model, start: PointType):
667667
The complete dictionary mapping (transformed) variable names to numeric initial values.
668668
"""
669669
e = ""
670-
for var in model.basic_RVs:
671-
try:
672-
var_shape = model.fastfn(var.shape)(start)
673-
if var.name in start.keys():
674-
start_var_shape = np.shape(start[var.name])
675-
if start_var_shape:
676-
if not np.array_equal(var_shape, start_var_shape):
677-
e += "\nExpected shape {} for var '{}', got: {}".format(
678-
tuple(var_shape), var.name, start_var_shape
679-
)
680-
# if start var has no shape
681-
else:
682-
# if model var has a specified shape
683-
if var_shape.size > 0:
684-
e += "\nExpected shape {} for var " "'{}', got scalar {}".format(
685-
tuple(var_shape), var.name, start[var.name]
686-
)
687-
except NotImplementedError as ex:
688-
if ex.args[0].startswith("Cannot sample"):
689-
_log.warning(
690-
f"Unable to check start shape of {var} because the RV does not implement random sampling."
691-
)
692-
else:
693-
raise
694-
670+
try:
671+
actual_shapes = model.eval_rv_shapes()
672+
except NotImplementedError as ex:
673+
warnings.warn(f"Unable to validate shapes: {ex.args[0]}", UserWarning)
674+
return
675+
for name, sval in start.items():
676+
ashape = actual_shapes.get(name)
677+
sshape = np.shape(sval)
678+
if ashape != tuple(sshape):
679+
e += f"\nExpected shape {ashape} for var '{name}', got: {sshape}"
695680
if e != "":
696-
raise ValueError(f"Bad shape for start argument:{e}")
681+
raise ValueError(f"Bad shape in start point:{e}")
697682

698683

699684
def _sample_many(

pymc/tests/test_sampling.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,9 @@ def test_sampler_stat_tune(self, cores):
235235
@pytest.mark.parametrize(
236236
"start, error",
237237
[
238-
([1, 2], TypeError),
239-
({"x": 1}, TypeError),
238+
({"x": 1}, ValueError),
240239
({"x": [1, 2, 3]}, ValueError),
241-
({"x": np.array([[1, 1], [1, 1]])}, TypeError),
240+
({"x": np.array([[1, 1], [1, 1]])}, ValueError),
242241
],
243242
)
244243
def test_sample_start_bad_shape(self, start, error):

0 commit comments

Comments
 (0)