diff --git a/pertpy/plot/_coda.py b/pertpy/plot/_coda.py index 6a5b597d..420c6738 100644 --- a/pertpy/plot/_coda.py +++ b/pertpy/plot/_coda.py @@ -79,7 +79,7 @@ def __stackbar( # pragma: no cover if show_legend: ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1) ax.set_xticks(r) - ax.set_xticklabels(level_names, rotation=45) + ax.set_xticklabels(level_names, rotation=45, ha="right") ax.set_ylabel("Proportion") return ax diff --git a/pertpy/tools/_base_coda.py b/pertpy/tools/_base_coda.py index 5973cffc..29bfda79 100644 --- a/pertpy/tools/_base_coda.py +++ b/pertpy/tools/_base_coda.py @@ -1350,7 +1350,7 @@ def import_tree( def from_scanpy( adata: AnnData, cell_type_identifier: str, - sample_identifier: str, + sample_identifier: str | list[str], covariate_uns: str | None = None, covariate_obs: list[str] | None = None, covariate_df: pd.DataFrame | None = None, @@ -1358,16 +1358,16 @@ def from_scanpy( """ Creates a compositional analysis dataset from a single anndata object, as it is produced by e.g. scanpy. - The anndata object needs to have a column in adata.obs that contains the cell type assignment, - and one column that specifies the grouping into samples. - Covariates can either be specified via a key in adata.uns, or as a separate DataFrame. + The anndata object needs to have a column in adata.obs that contains the cell type assignment. + Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample. + Further covariates (e.g. subject age) can either be specified via addidional column names in adata.obs, a key in adata.uns, or as a separate DataFrame. NOTE: The order of samples in the returned dataset is determined by the first occurence of cells from each sample in `adata` Args: adata: An anndata object from scanpy cell_type_identifier: column name in adata.obs that specifies the cell types - sample_identifier: column name in adata.obs that specifies the sample + sample_identifier: column name or list of column names in adata.obs that uniquely identify each sample covariate_uns: key for adata.uns, where covariate values are stored covariate_obs: list of column names in adata.obs, where covariate values are stored. Note: If covariate values are not unique for a value of sample_identifier, this covaariate will be skipped. covariate_df: DataFrame with covariates @@ -1377,6 +1377,19 @@ def from_scanpy( """ + if type(sample_identifier) == str: + sample_identifier = [sample_identifier] + + if covariate_obs: + covariate_obs += sample_identifier + else: + covariate_obs = sample_identifier # type: ignore + + # join sample identifiers + if type(sample_identifier) == list: + adata.obs["scCODA_sample_id"] = adata.obs[sample_identifier].agg("-".join, axis=1) + sample_identifier = "scCODA_sample_id" + # get cell type counts groups = adata.obs.value_counts([sample_identifier, cell_type_identifier]) count_data = groups.unstack(level=cell_type_identifier) diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 91259bfa..356b8496 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -383,8 +383,17 @@ def lda( return adata def _get_perturbation_markers( - self, adata, split_masks, categories, labels, control, layer, pval_cutoff, min_de_genes, logfc_threshold - ): + self, + adata: AnnData, + split_masks: list[np.ndarray], + categories: list[str], + labels: str, + control: str, + layer: str, + pval_cutoff: float, + min_de_genes: float, + logfc_threshold: float, + ) -> dict[tuple, np.ndarray]: """determine gene sets across all splits/groups through differential gene expression Args: