From 3825d67248b0ff5600fec9d558733d51bcd4397f Mon Sep 17 00:00:00 2001 From: Jelle Teijema Date: Thu, 6 Feb 2025 15:52:27 +0100 Subject: [PATCH] Update metric argument cleaning --- asreviewcontrib/insights/metrics.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/asreviewcontrib/insights/metrics.py b/asreviewcontrib/insights/metrics.py index ecee421..09f8ce3 100644 --- a/asreviewcontrib/insights/metrics.py +++ b/asreviewcontrib/insights/metrics.py @@ -196,14 +196,19 @@ def get_metrics( y_absolute=False, version=None, ): - recall = ( - [recall] - if recall and not isinstance(recall, list) - else [0.1, 0.25, 0.5, 0.75, 0.9] - ) - wss = [wss] if wss and not isinstance(wss, list) else [0.95] - erf = [erf] if erf and not isinstance(erf, list) else [0.10] - cm = [cm] if cm and not isinstance(cm, list) else [0.1, 0.25, 0.5, 0.75, 0.9] + def ensure_list_of_floats(value, default): + if value is None: + return default + if isinstance(value, float): + return [value] + if isinstance(value, list) and all(isinstance(i, float) for i in value): + return value + raise ValueError(f"Invalid input: {value}. Must be a float or a list of floats.") + + recall = ensure_list_of_floats(recall, [0.1, 0.25, 0.5, 0.75, 0.9]) + wss = ensure_list_of_floats(wss, [0.95]) + erf = ensure_list_of_floats(erf, [0.10]) + cm = ensure_list_of_floats(cm, [0.1, 0.25, 0.5, 0.75, 0.9]) labels = _pad_simulation_labels(state_obj, priors=priors)