diff --git a/pyproject.toml b/pyproject.toml index ca60bee..e82d5d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "geneweaver-core" -version = "0.9.0a6" +version = "0.9.0a7" description = "The core of the Jax-Geneweaver Python library" authors = ["Jax Computational Sciences "] readme = "README.md" diff --git a/src/geneweaver/core/threshold.py b/src/geneweaver/core/threshold.py new file mode 100644 index 0000000..d243801 --- /dev/null +++ b/src/geneweaver/core/threshold.py @@ -0,0 +1,42 @@ +"""A module for functions dealing with thresholding.""" + +from typing import List + +from geneweaver.core.schema.score import GenesetScoreType + + +def check_threshold(geneset_score: GenesetScoreType, value: float) -> bool: + """Check to see if a value falls within the score threshold. + + :param geneset_score: The geneset score type and threshold arguments. + :param value: The value to check against. + :return: A boolean that indicates if the value falls within the range specified by + the threshold. + """ + score_type = int(geneset_score.score_type) + value = float(value) + + # P and Q values + if score_type == 1 or score_type == 2: + return value <= geneset_score.threshold + + # Correlation and effect scores + elif score_type == 4 or score_type == 5: + return geneset_score.threshold_low <= value <= geneset_score.threshold + + # Binary + else: + return value >= geneset_score.threshold + + +def check_threshold_list( + geneset_score: GenesetScoreType, values: List[float] +) -> List[bool]: + """Check to see if a list of values falls within the score threshold. + + :param geneset_score: The geneset score type and threshold arguments. + :param values: The list of values to check against. + :return: A list of booleans that indicates if the value falls within the range + specified by the threshold. + """ + return [check_threshold(geneset_score, value) for value in values] diff --git a/tests/unit/threshold/__init__.py b/tests/unit/threshold/__init__.py new file mode 100644 index 0000000..b7e93f7 --- /dev/null +++ b/tests/unit/threshold/__init__.py @@ -0,0 +1 @@ +"""Tests for threshold functions.""" diff --git a/tests/unit/threshold/test_check_threshold.py b/tests/unit/threshold/test_check_threshold.py new file mode 100644 index 0000000..8ea30e3 --- /dev/null +++ b/tests/unit/threshold/test_check_threshold.py @@ -0,0 +1,35 @@ +"""Test the check_threshold function.""" + +import pytest +from geneweaver.core.schema.score import GenesetScoreType +from geneweaver.core.threshold import check_threshold + + +@pytest.mark.parametrize( + ("geneset_score", "value", "expected"), + [ + (GenesetScoreType(score_type=1, threshold=0.05), 0.01, True), + (GenesetScoreType(score_type=1, threshold=0.05), 0.06, False), + (GenesetScoreType(score_type=2, threshold=0.05), 0.01, True), + (GenesetScoreType(score_type=2, threshold=0.05), 0.06, False), + (GenesetScoreType(score_type=3, threshold=0.05), 0.01, False), + ( + GenesetScoreType(score_type=4, threshold=0.05, threshold_low=0.01), + 0.03, + True, + ), + ( + GenesetScoreType(score_type=4, threshold=0.05, threshold_low=0.01), + 0.06, + False, + ), + ( + GenesetScoreType(score_type=5, threshold=0.05, threshold_low=0.01), + 0.03, + True, + ), + ], +) +def test_check_threshold(geneset_score, value, expected): + """Check each parametrized case with valid arguments.""" + assert check_threshold(geneset_score, value) == expected diff --git a/tests/unit/threshold/test_check_threshold_list.py b/tests/unit/threshold/test_check_threshold_list.py new file mode 100644 index 0000000..e789588 --- /dev/null +++ b/tests/unit/threshold/test_check_threshold_list.py @@ -0,0 +1,55 @@ +"""Test the check_threshold_list function.""" + +import pytest +from geneweaver.core.schema.score import GenesetScoreType +from geneweaver.core.threshold import check_threshold_list + + +@pytest.mark.parametrize( + ("geneset_scores", "values", "expected"), + [ + (GenesetScoreType(score_type=1, threshold=0.05), [0.01, 0.06], [True, False]), + (GenesetScoreType(score_type=2, threshold=0.05), [0.01, 0.06], [True, False]), + (GenesetScoreType(score_type=3, threshold=0.05), [0.01, 0.06], [False, True]), + ( + GenesetScoreType(score_type=4, threshold=0.05, threshold_low=0.01), + [0.03, 0.06], + [True, False], + ), + ( + GenesetScoreType(score_type=5, threshold=0.05, threshold_low=0.01), + [0.03, 0.06], + [True, False], + ), + ( + GenesetScoreType(score_type="p-value", threshold=0.05), + [0.01, 0.06], + [True, False], + ), + ( + GenesetScoreType(score_type="q-value", threshold=0.05), + [0.01, 0.06], + [True, False], + ), + ( + GenesetScoreType(score_type="binary", threshold=0.05), + [0.01, 0.06], + [False, True], + ), + ( + GenesetScoreType( + score_type="correlation", threshold=0.05, threshold_low=0.01 + ), + [0.03, 0.06], + [True, False], + ), + ( + GenesetScoreType(score_type="effect", threshold=0.05, threshold_low=0.01), + [0.03, 0.06], + [True, False], + ), + ], +) +def test_check_threshold_list(geneset_scores, values, expected): + """Check each parametrized case with valid arguments.""" + assert check_threshold_list(geneset_scores, values) == expected