diff --git a/tests/test_metrics.py b/tests/test_metrics.py index aba6a01..40e92a8 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -162,3 +162,35 @@ def test_single_value_formats(): assert isinstance(_wss([1,1,0,0], 0.5), float) assert isinstance(_loss_value([1,1,0,0]), float) assert isinstance(_erf([1,1,0,0], 0.5), float) + +def test_get_metrics(): + with open_state( + Path(TEST_ASREVIEW_FILES, "sim_van_de_schoot_2017_stop_if_min.asreview") + ) as s: + metrics = get_metrics(s, wss=[0.75, 0.85, 0.95], erf=[0.75, 0.85, 0.95]) + + wss_data = next( + (item["value"] for item in metrics["data"]["items"] if item["id"] == "wss"), + None + ) + assert wss_data is not None, "WSS key missing in metrics" + + erf_data = next( + (item["value"] for item in metrics["data"]["items"] if item["id"] == "erf"), + None + ) + assert erf_data is not None, "ERF key missing in metrics" + + wss_values = {val[0]: val[1] for val in wss_data} + for value in [0.75, 0.85, 0.95]: + assert value in wss_values, f"WSS value {value} missing in output" + + for wss_score in wss_values.values(): + assert 0 <= wss_score <= 1, f"WSS value {wss_score} out of expected range" + + erf_values = {val[0]: val[1] for val in wss_data} + for value in [0.75, 0.85, 0.95]: + assert value in erf_values, f"ERF value {value} missing in output" + + for erf_score in erf_values.values(): + assert 0 <= erf_score <= 1, f"ERF value {wss_score} out of expected range" \ No newline at end of file