From 9ccf0cddc5e8dcb94a43fd67db0bf8141268e23b Mon Sep 17 00:00:00 2001 From: jteijema Date: Wed, 30 Oct 2024 14:24:23 +0100 Subject: [PATCH] Add new tests --- tests/test_metrics.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 9a844bd..c106892 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -119,7 +119,7 @@ def test_loss(): Path(TEST_ASREVIEW_FILES, "sim_van_de_schoot_2017_stop_if_min.asreview") ) as s: loss_value = loss(s) - assert_almost_equal(loss_value, 0.011590940352087164, decimal=6) + assert_almost_equal(loss_value, 0.011592855205548452, decimal=6) def test_loss_value_function(): labels = [1, 0] @@ -132,11 +132,15 @@ def test_loss_value_function(): labels = [1, 1, 0, 0, 0] loss_value = _loss_value(labels) - assert_almost_equal(loss_value, 0, decimal=6) + assert_almost_equal(loss_value, 0, decimal=6), f"{loss_value} is {int(loss_value)}" labels = [0, 0, 0, 1, 1] loss_value = _loss_value(labels) - assert_almost_equal(loss_value, 1, decimal=6) + assert_almost_equal(loss_value, 1, decimal=6) + + labels = [1, 0, 1] + loss_value = _loss_value(labels) + assert_almost_equal(loss_value, 0.5, decimal=6) import random for i in range(100):