Skip to content

Commit 83a0106

Browse files
committed
fixed use of predict_proba
1 parent ab70ac9 commit 83a0106

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/autora/experimentalist/falsification/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,14 @@ def sample(
240240

241241
# get target pattern for popper net
242242
model_predict = getattr(model, "predict_proba", None)
243-
if callable(model_predict) is False:
244-
model_predict = getattr(model, "predict", None)
245-
246243
if callable(model_predict) is False or model_predict is None:
247-
raise Exception("Model must have `predict` or `predict_proba` method.")
244+
model_predict = getattr(model, "predict", None)
245+
if callable(model_predict) is False or model_predict is None:
246+
raise Exception("Model must have `predict` or `predict_proba` method.")
247+
else:
248+
predicted_observations = model.predict(reference_conditions)
248249
else:
249-
predicted_observations = model.predict(reference_conditions)
250+
predicted_observations = model.predict_proba(reference_conditions)
250251

251252
if isinstance(predicted_observations, np.ndarray) is False:
252253
try:

0 commit comments

Comments
 (0)