From 5cc6984102f3075a2c2d4d187e1e2d434bab0e52 Mon Sep 17 00:00:00 2001 From: zethson Date: Sat, 6 Jan 2024 18:06:18 +0100 Subject: [PATCH] Add de_res_to_anndata Signed-off-by: zethson --- pertpy/tools/_differential_gene_expression.py | 73 +++++++++++++++++++ .../test_differential_gene_expression.py | 26 ++++++- 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/pertpy/tools/_differential_gene_expression.py b/pertpy/tools/_differential_gene_expression.py index 43af17d2..1613716d 100644 --- a/pertpy/tools/_differential_gene_expression.py +++ b/pertpy/tools/_differential_gene_expression.py @@ -7,6 +7,7 @@ import numpy.typing as npt import pandas as pd from scipy.stats import kendalltau, pearsonr, spearmanr +from statsmodels.stats.multitest import fdrcorrection if TYPE_CHECKING: from anndata import AnnData @@ -216,6 +217,78 @@ def calculate_cohens_d(self, de_res_1: pd.DataFrame, de_res_2: pd.DataFrame) -> return cohens_d + def de_res_to_anndata( + self, + adata: AnnData, + de_res: pd.DataFrame, + *, + groupby: str, + gene_id_col: str = "gene_symbols", + score_col: str = "scores", + pval_col: str = "pvals", + pval_adj_col: str | None = "pvals_adj", + lfc_col: str = "logfoldchanges", + key_added: str = "rank_genes_groups", + ) -> None: + """Add tabular differential expression result to AnnData as if it was produced by `scanpy.tl.rank_genes_groups`. + + Args: + adata: + Annotated data matrix + de_res: + Tablular de result + groupby: + Column in `de_res` that indicates the group. This column must also exist in `adata.obs`. + gene_id_col: + Column in `de_res` that holds the gene identifiers + score_col: + Column in `de_res` that holds the score (results will be ordered by score). + pval_col: + Column in `de_res` that holds the unadjusted pvalue + pval_adj_col: + Column in `de_res` that holds the adjusted pvalue. + If not specified, the unadjusted pvalues will be FDR-adjusted. + lfc_col: + Column in `de_res` that holds the log fold change + key_added: + Key under which the results will be stored in `adata.uns` + """ + if groupby not in adata.obs.columns or groupby not in de_res.columns: + raise ValueError("groupby column must exist in both adata and de_res.") + res_dict = { + "params": { + "groupby": groupby, + "reference": "rest", + "method": "other", + "use_raw": True, + "layer": None, + "corr_method": "other", + }, + "names": [], + "scores": [], + "pvals": [], + "pvals_adj": [], + "logfoldchanges": [], + } + df_groupby = de_res.groupby(groupby) + for _, tmp_df in df_groupby: + tmp_df = tmp_df.sort_values(score_col, ascending=False) + res_dict["names"].append(tmp_df[gene_id_col].values) # type: ignore + res_dict["scores"].append(tmp_df[score_col].values) # type: ignore + res_dict["pvals"].append(tmp_df[pval_col].values) # type: ignore + if pval_adj_col is not None: + res_dict["pvals_adj"].append(tmp_df[pval_adj_col].values) # type: ignore + else: + res_dict["pvals_adj"].append(fdrcorrection(tmp_df[pval_col].values)[1]) # type: ignore + res_dict["logfoldchanges"].append(tmp_df[lfc_col].values) # type: ignore + + for key in ["names", "scores", "pvals", "pvals_adj", "logfoldchanges"]: + res_dict[key] = pd.DataFrame( + np.vstack(res_dict[key]).T, + columns=list(df_groupby.groups.keys()), + ).to_records(index=False, column_dtypes="O") + adata.uns[key_added] = res_dict + def de_analysis( self, adata: AnnData, diff --git a/tests/tools/test_differential_gene_expression.py b/tests/tools/test_differential_gene_expression.py index b304d75a..177d0e6a 100644 --- a/tests/tools/test_differential_gene_expression.py +++ b/tests/tools/test_differential_gene_expression.py @@ -1,11 +1,20 @@ +import numpy as np import pandas as pd import pertpy as pt import pytest +from anndata import AnnData @pytest.fixture def dummy_de_results(): - data1 = {"pvals": [0.1, 0.2, 0.3, 0.4], "pvals_adj": [0.1, 0.25, 0.35, 0.45], "logfoldchanges": [1, 2, 3, 4]} + data1 = { + "pvals": [0.1, 0.2, 0.3, 0.4], + "pvals_adj": [0.1, 0.25, 0.35, 0.45], + "logfoldchanges": [1, 2, 3, 4], + "group": ["group_1", "group_1", "group_2", "group_2"], + "scores": [5, 10, 4, 20], + "gene_symbols": ["BRCA1", "TP53", "EGFR", "MYC"], + } data2 = {"pvals": [0.1, 0.2, 0.3, 0.4], "pvals_adj": [0.15, 0.2, 0.35, 0.5], "logfoldchanges": [2, 3, 4, 5]} de_res_1 = pd.DataFrame(data1) de_res_2 = pd.DataFrame(data2) @@ -55,3 +64,18 @@ def test_calculate_cohens_d(dummy_de_results, pt_de): cohens_d = pt_de.calculate_cohens_d(de_res_1, de_res_2) assert isinstance(cohens_d, float) + + +def test_de_res_to_anndata(dummy_de_results, pt_de): + de_res_1, de_res_2 = dummy_de_results + + rng = np.random.default_rng() + X = rng.random((4, 5)) + adata = AnnData(X) + adata.obs["group"] = ["group_1"] * 2 + ["group_2"] * 2 + + pt_de.de_res_to_anndata(adata, de_res_1, groupby="group") + assert "rank_genes_groups" in adata.uns + assert all( + col in adata.uns["rank_genes_groups"] for col in ["names", "scores", "pvals", "pvals_adj", "logfoldchanges"] + )