diff --git a/baybe/utils/validation.py b/baybe/utils/validation.py index 84b47ad2f3..84a2108c1a 100644 --- a/baybe/utils/validation.py +++ b/baybe/utils/validation.py @@ -92,28 +92,26 @@ def validate_target_input(data: pd.DataFrame, targets: Iterable[Target]) -> None from baybe.targets import BinaryTarget, NumericalTarget if len(data) < 1: - raise ValueError("The provided input dataframe cannot be emtpy.") + raise ValueError("The provided input dataframe cannot be empty.") for t in targets: if data[t.name].isna().any(): raise ValueError( - f"The target '{t.name}' has missing values or NaNs in the provided " - f"dataframe. Missing target values are not supported." + f"The target '{t.name}' has missing values in the provided dataframe." ) if isinstance(t, NumericalTarget): if data[t.name].dtype.kind not in "iufb": raise TypeError( f"The numerical target '{t.name}' has non-numeric entries in the " - f"provided dataframe. Non-numeric target values are not supported." + f"provided dataframe." ) elif isinstance(t, BinaryTarget): - if not ( - data[t.name].isin(allowed := [t.failure_value, t.success_value]).all() - ): + allowed = {t.failure_value, t.success_value} + if invalid := set(data[t.name].unique()) - allowed: raise ValueError( - f"The binary target '{t.name}' has nvalid entries in the provided " - f"dataframe. Allowed values are: {allowed}." + f"The binary target '{t.name}' has invalid entries {invalid} " + f"in the provided dataframe. Allowed values are: {allowed}." ) @@ -137,13 +135,13 @@ def validate_parameter_input( TypeError: If a parameter contains non-numeric values. """ if len(data) < 1: - raise ValueError("The provided input dataframe cannot be emtpy.") + raise ValueError("The provided input dataframe cannot be empty.") for p in parameters: if data[p.name].isna().any(): raise ValueError( - f"The parameter '{p.name}' has missing values or NaNs in the provided " - f"dataframe. Missing parameter values are not supported." + f"The parameter '{p.name}' has missing values in the provided " + f"dataframe." ) if p.is_numerical and (data[p.name].dtype.kind not in "iufb"): raise TypeError(