Skip to content

Commit 5cc6984

Browse files
committed
Add de_res_to_anndata
Signed-off-by: zethson <[email protected]>
1 parent 5a53238 commit 5cc6984

File tree

2 files changed

+98
-1
lines changed

2 files changed

+98
-1
lines changed

pertpy/tools/_differential_gene_expression.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy.typing as npt
88
import pandas as pd
99
from scipy.stats import kendalltau, pearsonr, spearmanr
10+
from statsmodels.stats.multitest import fdrcorrection
1011

1112
if TYPE_CHECKING:
1213
from anndata import AnnData
@@ -216,6 +217,78 @@ def calculate_cohens_d(self, de_res_1: pd.DataFrame, de_res_2: pd.DataFrame) ->
216217

217218
return cohens_d
218219

220+
def de_res_to_anndata(
221+
self,
222+
adata: AnnData,
223+
de_res: pd.DataFrame,
224+
*,
225+
groupby: str,
226+
gene_id_col: str = "gene_symbols",
227+
score_col: str = "scores",
228+
pval_col: str = "pvals",
229+
pval_adj_col: str | None = "pvals_adj",
230+
lfc_col: str = "logfoldchanges",
231+
key_added: str = "rank_genes_groups",
232+
) -> None:
233+
"""Add tabular differential expression result to AnnData as if it was produced by `scanpy.tl.rank_genes_groups`.
234+
235+
Args:
236+
adata:
237+
Annotated data matrix
238+
de_res:
239+
Tablular de result
240+
groupby:
241+
Column in `de_res` that indicates the group. This column must also exist in `adata.obs`.
242+
gene_id_col:
243+
Column in `de_res` that holds the gene identifiers
244+
score_col:
245+
Column in `de_res` that holds the score (results will be ordered by score).
246+
pval_col:
247+
Column in `de_res` that holds the unadjusted pvalue
248+
pval_adj_col:
249+
Column in `de_res` that holds the adjusted pvalue.
250+
If not specified, the unadjusted pvalues will be FDR-adjusted.
251+
lfc_col:
252+
Column in `de_res` that holds the log fold change
253+
key_added:
254+
Key under which the results will be stored in `adata.uns`
255+
"""
256+
if groupby not in adata.obs.columns or groupby not in de_res.columns:
257+
raise ValueError("groupby column must exist in both adata and de_res.")
258+
res_dict = {
259+
"params": {
260+
"groupby": groupby,
261+
"reference": "rest",
262+
"method": "other",
263+
"use_raw": True,
264+
"layer": None,
265+
"corr_method": "other",
266+
},
267+
"names": [],
268+
"scores": [],
269+
"pvals": [],
270+
"pvals_adj": [],
271+
"logfoldchanges": [],
272+
}
273+
df_groupby = de_res.groupby(groupby)
274+
for _, tmp_df in df_groupby:
275+
tmp_df = tmp_df.sort_values(score_col, ascending=False)
276+
res_dict["names"].append(tmp_df[gene_id_col].values) # type: ignore
277+
res_dict["scores"].append(tmp_df[score_col].values) # type: ignore
278+
res_dict["pvals"].append(tmp_df[pval_col].values) # type: ignore
279+
if pval_adj_col is not None:
280+
res_dict["pvals_adj"].append(tmp_df[pval_adj_col].values) # type: ignore
281+
else:
282+
res_dict["pvals_adj"].append(fdrcorrection(tmp_df[pval_col].values)[1]) # type: ignore
283+
res_dict["logfoldchanges"].append(tmp_df[lfc_col].values) # type: ignore
284+
285+
for key in ["names", "scores", "pvals", "pvals_adj", "logfoldchanges"]:
286+
res_dict[key] = pd.DataFrame(
287+
np.vstack(res_dict[key]).T,
288+
columns=list(df_groupby.groups.keys()),
289+
).to_records(index=False, column_dtypes="O")
290+
adata.uns[key_added] = res_dict
291+
219292
def de_analysis(
220293
self,
221294
adata: AnnData,

tests/tools/test_differential_gene_expression.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
1+
import numpy as np
12
import pandas as pd
23
import pertpy as pt
34
import pytest
5+
from anndata import AnnData
46

57

68
@pytest.fixture
79
def dummy_de_results():
8-
data1 = {"pvals": [0.1, 0.2, 0.3, 0.4], "pvals_adj": [0.1, 0.25, 0.35, 0.45], "logfoldchanges": [1, 2, 3, 4]}
10+
data1 = {
11+
"pvals": [0.1, 0.2, 0.3, 0.4],
12+
"pvals_adj": [0.1, 0.25, 0.35, 0.45],
13+
"logfoldchanges": [1, 2, 3, 4],
14+
"group": ["group_1", "group_1", "group_2", "group_2"],
15+
"scores": [5, 10, 4, 20],
16+
"gene_symbols": ["BRCA1", "TP53", "EGFR", "MYC"],
17+
}
918
data2 = {"pvals": [0.1, 0.2, 0.3, 0.4], "pvals_adj": [0.15, 0.2, 0.35, 0.5], "logfoldchanges": [2, 3, 4, 5]}
1019
de_res_1 = pd.DataFrame(data1)
1120
de_res_2 = pd.DataFrame(data2)
@@ -55,3 +64,18 @@ def test_calculate_cohens_d(dummy_de_results, pt_de):
5564

5665
cohens_d = pt_de.calculate_cohens_d(de_res_1, de_res_2)
5766
assert isinstance(cohens_d, float)
67+
68+
69+
def test_de_res_to_anndata(dummy_de_results, pt_de):
70+
de_res_1, de_res_2 = dummy_de_results
71+
72+
rng = np.random.default_rng()
73+
X = rng.random((4, 5))
74+
adata = AnnData(X)
75+
adata.obs["group"] = ["group_1"] * 2 + ["group_2"] * 2
76+
77+
pt_de.de_res_to_anndata(adata, de_res_1, groupby="group")
78+
assert "rank_genes_groups" in adata.uns
79+
assert all(
80+
col in adata.uns["rank_genes_groups"] for col in ["names", "scores", "pvals", "pvals_adj", "logfoldchanges"]
81+
)

0 commit comments

Comments
 (0)