Skip to content

Commit

Permalink
Add de_res_to_anndata
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <[email protected]>
  • Loading branch information
Zethson committed Jan 6, 2024
1 parent 5a53238 commit 5cc6984
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 1 deletion.
73 changes: 73 additions & 0 deletions pertpy/tools/_differential_gene_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy.typing as npt
import pandas as pd
from scipy.stats import kendalltau, pearsonr, spearmanr
from statsmodels.stats.multitest import fdrcorrection

if TYPE_CHECKING:
from anndata import AnnData
Expand Down Expand Up @@ -216,6 +217,78 @@ def calculate_cohens_d(self, de_res_1: pd.DataFrame, de_res_2: pd.DataFrame) ->

return cohens_d

def de_res_to_anndata(
self,
adata: AnnData,
de_res: pd.DataFrame,
*,
groupby: str,
gene_id_col: str = "gene_symbols",
score_col: str = "scores",
pval_col: str = "pvals",
pval_adj_col: str | None = "pvals_adj",
lfc_col: str = "logfoldchanges",
key_added: str = "rank_genes_groups",
) -> None:
"""Add tabular differential expression result to AnnData as if it was produced by `scanpy.tl.rank_genes_groups`.
Args:
adata:
Annotated data matrix
de_res:
Tablular de result
groupby:
Column in `de_res` that indicates the group. This column must also exist in `adata.obs`.
gene_id_col:
Column in `de_res` that holds the gene identifiers
score_col:
Column in `de_res` that holds the score (results will be ordered by score).
pval_col:
Column in `de_res` that holds the unadjusted pvalue
pval_adj_col:
Column in `de_res` that holds the adjusted pvalue.
If not specified, the unadjusted pvalues will be FDR-adjusted.
lfc_col:
Column in `de_res` that holds the log fold change
key_added:
Key under which the results will be stored in `adata.uns`
"""
if groupby not in adata.obs.columns or groupby not in de_res.columns:
raise ValueError("groupby column must exist in both adata and de_res.")
res_dict = {
"params": {
"groupby": groupby,
"reference": "rest",
"method": "other",
"use_raw": True,
"layer": None,
"corr_method": "other",
},
"names": [],
"scores": [],
"pvals": [],
"pvals_adj": [],
"logfoldchanges": [],
}
df_groupby = de_res.groupby(groupby)
for _, tmp_df in df_groupby:
tmp_df = tmp_df.sort_values(score_col, ascending=False)
res_dict["names"].append(tmp_df[gene_id_col].values) # type: ignore
res_dict["scores"].append(tmp_df[score_col].values) # type: ignore
res_dict["pvals"].append(tmp_df[pval_col].values) # type: ignore
if pval_adj_col is not None:
res_dict["pvals_adj"].append(tmp_df[pval_adj_col].values) # type: ignore
else:
res_dict["pvals_adj"].append(fdrcorrection(tmp_df[pval_col].values)[1]) # type: ignore
res_dict["logfoldchanges"].append(tmp_df[lfc_col].values) # type: ignore

for key in ["names", "scores", "pvals", "pvals_adj", "logfoldchanges"]:
res_dict[key] = pd.DataFrame(
np.vstack(res_dict[key]).T,
columns=list(df_groupby.groups.keys()),
).to_records(index=False, column_dtypes="O")
adata.uns[key_added] = res_dict

def de_analysis(
self,
adata: AnnData,
Expand Down
26 changes: 25 additions & 1 deletion tests/tools/test_differential_gene_expression.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import numpy as np
import pandas as pd
import pertpy as pt
import pytest
from anndata import AnnData


@pytest.fixture
def dummy_de_results():
data1 = {"pvals": [0.1, 0.2, 0.3, 0.4], "pvals_adj": [0.1, 0.25, 0.35, 0.45], "logfoldchanges": [1, 2, 3, 4]}
data1 = {
"pvals": [0.1, 0.2, 0.3, 0.4],
"pvals_adj": [0.1, 0.25, 0.35, 0.45],
"logfoldchanges": [1, 2, 3, 4],
"group": ["group_1", "group_1", "group_2", "group_2"],
"scores": [5, 10, 4, 20],
"gene_symbols": ["BRCA1", "TP53", "EGFR", "MYC"],
}
data2 = {"pvals": [0.1, 0.2, 0.3, 0.4], "pvals_adj": [0.15, 0.2, 0.35, 0.5], "logfoldchanges": [2, 3, 4, 5]}
de_res_1 = pd.DataFrame(data1)
de_res_2 = pd.DataFrame(data2)
Expand Down Expand Up @@ -55,3 +64,18 @@ def test_calculate_cohens_d(dummy_de_results, pt_de):

cohens_d = pt_de.calculate_cohens_d(de_res_1, de_res_2)
assert isinstance(cohens_d, float)


def test_de_res_to_anndata(dummy_de_results, pt_de):
de_res_1, de_res_2 = dummy_de_results

rng = np.random.default_rng()
X = rng.random((4, 5))
adata = AnnData(X)
adata.obs["group"] = ["group_1"] * 2 + ["group_2"] * 2

pt_de.de_res_to_anndata(adata, de_res_1, groupby="group")
assert "rank_genes_groups" in adata.uns
assert all(
col in adata.uns["rank_genes_groups"] for col in ["names", "scores", "pvals", "pvals_adj", "logfoldchanges"]
)

0 comments on commit 5cc6984

Please sign in to comment.