Skip to content

Commit 93e5db6

Browse files
committed
Refactoring DIALOGUE
Signed-off-by: zethson <[email protected]>
1 parent 9021d84 commit 93e5db6

File tree

1 file changed

+25
-31
lines changed

1 file changed

+25
-31
lines changed

pertpy/tools/_dialogue.py

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,14 @@ def _get_pseudobulks(
5757
5858
Copied from `https://github.com/schillerlab/sc-toolbox/blob/397e80dc5e8fb8017b75f6c3fa634a1e1213d484/sc_toolbox/tools/__init__.py#L458`
5959
60-
# TODO: Replace with decoupler's implementation
61-
6260
Args:
6361
groupby: The key to groupby for pseudobulks
6462
strategy: The pseudobulking strategy. One of "median" or "mean"
6563
6664
Returns:
6765
A Pandas DataFrame of pseudobulk counts
6866
"""
67+
# TODO: Replace with decoupler's implementation
6968
pseudobulk = {"Genes": adata.var_names.values}
7069

7170
for category in adata.obs.loc[:, groupby].cat.categories:
@@ -105,18 +104,16 @@ def _pseudobulk_pca(self, adata: AnnData, groupby: str, n_components: int = 50)
105104
def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
106105
"""Row-wise mean center and scale by the standard deviation.
107106
108-
TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
109-
110107
Args:
111108
pseudobulks: The pseudobulk PCA components.
112109
normalize: Whether to mimic DIALOGUE behavior or not.
113110
114111
Returns:
115112
The scaled count matrix.
116113
"""
114+
# TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
117115
# DIALOGUE doesn't scale the data before passing to multicca, unlike what is recommended by sparsecca.
118116
# However, performing this scaling _does_ increase overall correlation of the end result
119-
# WHEN SAMPLE ORDER AND DIALOGUE2+3 PROCESSING IS IGNORED.
120117
if normalize:
121118
return pseudobulks.to_numpy()
122119
else:
@@ -371,7 +368,7 @@ def _get_residuals(self, X: np.ndarray, y: np.ndarray) -> np.ndarray:
371368
return np.array(resid)
372369

373370
def _iterative_nnls(self, A_orig: np.ndarray, y_orig: np.ndarray, feature_ranks: list[int], n_iter: int = 1000):
374-
"""Solves non-negative least squares separately for different feature categories.
371+
"""Solves non-negative least-squares separately for different feature categories.
375372
376373
Mimics DLG.iterative.nnls.
377374
Variables are notated according to:
@@ -628,9 +625,8 @@ def calculate_multifactor_PMD(
628625
>>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", n_counts_key = "nCount_RNA", n_mpcs = 3)
629626
>>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
630627
"""
631-
# IMPORTANT NOTE: the order in which matrices are passed to multicca matters. As such,
632-
# it is important here that to obtain the same result as in R, we pass the matrices in
633-
# in the same order.
628+
# IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
629+
# As such, to obtain the same result as in R, it is important that we pass the matrices in the same order.
634630
if ct_order is not None:
635631
cell_types = ct_order
636632
else:
@@ -798,7 +794,7 @@ def multilevel_modeling(
798794
for mcp in mcps:
799795
mixed_model_progress.update(mm_task, description=f"[bold blue]Determining mixed effects for {mcp}")
800796

801-
# TODO Check that the genes in result{sig_genes_1] are different and if so note that somewhere and explain why
797+
# TODO Check whether the genes in result["sig_genes_1"] are different and, if so, note that somewhere and explain why
802798
result = {}
803799
result["HLM_result_1"], result["sig_genes_1"] = self._apply_HLM_per_MCP_for_one_pair(
804800
mcp_name=mcp,
@@ -868,22 +864,19 @@ def test_association(
868864
sample_label = self.sample_id
869865
n_mcps = self.n_mcps
870866

871-
# create conditions_compare if not supplied
872867
if conditions_compare is None:
873868
conditions_compare = list(adata.obs["path_str"].cat.categories) # type: ignore
874869
if len(conditions_compare) != 2:
875870
raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
876871

877-
# create data frames to store results
878872
pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
879873
tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
880874
pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
881875

882876
response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
883877
for celltype in adata.obs[celltype_label].unique():
884-
# subset data to cell type
885878
df = adata.obs[adata.obs[celltype_label] == celltype]
886-
# run t-test for each MCP
879+
887880
for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
888881
mns = df.groupby(sample_label)[mcpnum].mean()
889882
mns = pd.concat([mns, response], axis=1)
@@ -893,11 +886,10 @@ def test_association(
893886
)
894887
pvals.loc[celltype, mcpnum] = res[1]
895888
tstats.loc[celltype, mcpnum] = res[0]
896-
# return(res)
897889

898-
# benjamini-hochberg correction for number of cell types (use BH because correlated MCPs)
899890
for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
900891
pvals_adj[mcpnum] = multipletests(pvals[mcpnum], method="fdr_bh")[1]
892+
901893
return {"pvals": pvals, "tstats": tstats, "pvals_adj": pvals_adj}
902894

903895
def get_mlm_mcp_genes(
@@ -914,7 +906,7 @@ def get_mlm_mcp_genes(
914906
celltype: Cell type of interest.
915907
results: dl.MultilevelModeling result object.
916908
MCP: MCP key of the result object.
917-
threshhold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
909+
threshold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
918910
Defaults to 0.70.
919911
focal_celltypes: None (compare against all cell types) or a list of other cell types which you want to compare against.
920912
Defaults to None.
@@ -938,7 +930,6 @@ def get_mlm_mcp_genes(
938930
# REMOVE THIS BLOCK ONCE MLM OUTPUT MATCHES STANDARD
939931
if MCP.startswith("mcp_"):
940932
MCP = MCP.replace("mcp_", "MCP")
941-
# convert from MCPx to MCPx+1
942933
MCP = "MCP" + str(int(MCP[3:]) - 1)
943934

944935
# Extract all comparison keys from the results object
@@ -1007,17 +998,16 @@ def _get_extrema_MCP_genes_single(self, ct_subs: dict, mcp: str = "mcp_0", fract
1007998
objects containing the results of gene ranking analysis.
1008999
10091000
Examples:
1010-
ct_subs = {
1011-
"subpop1": anndata_obj1,
1012-
"subpop2": anndata_obj2,
1013-
# ... more subpopulations ...
1014-
}
1015-
genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
1001+
>>> ct_subs = {
1002+
>>> "subpop1": anndata_obj1,
1003+
>>> "subpop2": anndata_obj2,
1004+
>>> # ... more subpopulations ...
1005+
>>> }
1006+
>>> genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
10161007
"""
10171008
genes = {}
10181009
for ct in ct_subs.keys():
10191010
mini = ct_subs[ct]
1020-
mini.obs[mcp]
10211011
mini.obs["extrema"] = pd.qcut(
10221012
mini.obs[mcp],
10231013
[0, 0 + fraction, 1 - fraction, 1.0],
@@ -1027,6 +1017,7 @@ def _get_extrema_MCP_genes_single(self, ct_subs: dict, mcp: str = "mcp_0", fract
10271017
mini, "extrema", groups=["high" + mcp + " " + ct], reference="low " + mcp + " " + ct
10281018
)
10291019
genes[ct] = mini # .uns['rank_genes_groups']
1020+
10301021
return genes
10311022

10321023
def get_extrema_MCP_genes(self, ct_subs: dict, fraction: float = 0.1):
@@ -1064,11 +1055,12 @@ def get_extrema_MCP_genes(self, ct_subs: dict, fraction: float = 0.1):
10641055
rank_dfs[mcp] = {}
10651056
ct_ranked = self._get_extrema_MCP_genes_single(ct_subs, mcp=mcp, fraction=fraction)
10661057
for celltype in ct_ranked.keys():
1067-
rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype], group=None)
1058+
rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype])
10681059

10691060
return rank_dfs
10701061

10711062
def plot_split_violins(
1063+
self,
10721064
adata: AnnData,
10731065
split_key: str,
10741066
celltype_key=str,
@@ -1111,18 +1103,20 @@ def plot_split_violins(
11111103

11121104
return ax
11131105

1114-
def plot_pairplot(adata: AnnData, celltype_key: str, color: str, sample_id: str, mcp: str = "mcp_0") -> PairGrid:
1106+
def plot_pairplot(
1107+
self, adata: AnnData, celltype_key: str, color: str, sample_id: str, mcp: str = "mcp_0"
1108+
) -> PairGrid:
11151109
"""Generate a pairplot visualization for multicellular program (MCP) data.
11161110
11171111
Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
11181112
then creates a pairplot to visualize the relationships between these mean MCP values.
11191113
11201114
Args:
11211115
adata: Annotated data object.
1122-
celltype_key: Key in adata.obs containing cell type annotations.
1123-
color: Key in adata.obs for color annotations. This parameter is used as the hue
1124-
sample_id: Key in adata.obs for the sample annotations.
1125-
mcp: Key in adata.obs for MCP feature values. Defaults to "mcp_0".
1116+
celltype_key: Key in `adata.obs` containing cell type annotations.
1117+
color: Key in `adata.obs` for color annotations. This parameter is used as the hue.
1118+
sample_id: Key in `adata.obs` for the sample annotations.
1119+
mcp: Key in `adata.obs` for MCP feature values. Defaults to `"mcp_0"`.
11261120
11271121
Returns:
11281122
Seaborn PairGrid object.

0 commit comments

Comments
 (0)