diff --git a/pyproject.toml b/pyproject.toml index ef1fc1ac5..10102d26b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,6 +122,8 @@ ignore = [ "D107", # Missing docstring in magic method "D105", + # Use `X | Y` for type annotations + "UP007", ] line-length = 120 select = [ diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 8e1bb648c..cd09cd9b9 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -1,15 +1,16 @@ +from __future__ import annotations + +import types from typing import ( TYPE_CHECKING, Any, - Dict, Generic, Iterable, - List, Literal, + Mapping, Optional, Protocol, Sequence, - Tuple, Union, ) @@ -45,8 +46,8 @@ class AnalysisMixinProtocol(Protocol[K, B]): adata: AnnData _policy: SubsetPolicy[K] - solutions: Dict[Tuple[K, K], BaseSolverOutput] - problems: Dict[Tuple[K, K], B] + solutions: dict[tuple[K, K], BaseSolverOutput] + problems: dict[tuple[K, K], B] def _apply( self, @@ -61,15 +62,15 @@ def _apply( ... def _interpolate_transport( - self: "AnalysisMixinProtocol[K, B]", - path: Sequence[Tuple[K, K]], + self: AnalysisMixinProtocol[K, B], + path: Sequence[tuple[K, K]], scale_by_marginals: bool = True, ) -> LinearOperator: ... def _flatten( - self: "AnalysisMixinProtocol[K, B]", - data: Dict[K, ArrayLike], + self: AnalysisMixinProtocol[K, B], + data: dict[K, ArrayLike], *, key: Optional[str], ) -> ArrayLike: @@ -83,8 +84,20 @@ def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: """Pull distribution.""" ... + def _cell_transition( + self: AnalysisMixinProtocol[K, B], + source: K, + target: K, + source_groups: Str_Dict_t, + target_groups: Str_Dict_t, + aggregation_mode: Literal["annotation", "cell"] = "annotation", + key_added: Optional[str] = _constants.CELL_TRANSITION, + **kwargs: Any, + ) -> pd.DataFrame: + ... 
def _annotation_mapping(
    self: AnalysisMixinProtocol[K, B],
    mapping_mode: Literal["sum", "max"],
    annotation_label: str,
    source: K,
    target: K,
    key: str | None = None,
    forward: bool = True,
    other_adata: AnnData | None = None,
    scale_by_marginals: bool = True,
    batch_size: int | None = None,
    cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame:
    """Transfer the labels stored under ``annotation_label`` across one transport map.

    Parameters
    ----------
    mapping_mode
        ``'sum'`` delegates to :meth:`_cell_transition` and takes the label with the
        largest aggregated value; ``'max'`` takes the label of the single annotated
        cell with the largest transport value.
    annotation_label
        Column holding the labels; also the name of the returned column.
    source
        Key of the source distribution.
    target
        Key of the target distribution.
    key
        Column used to select the ``source``/``target`` cells when reading the labels.
    forward
        If :obj:`True`, labels are read from ``source`` and one label is produced per
        target cell; otherwise labels are read from ``target`` and one label is
        produced per source cell.
    other_adata
        Object holding the target annotations when they are not stored in ``self.adata``.
    scale_by_marginals
        Forwarded to ``push``/``pull`` in ``'max'`` mode.
    batch_size
        Number of query cells processed per ``push``/``pull`` call in ``'max'`` mode.
        :obj:`None` processes all of them at once.
    cell_transition_kwargs
        Extra keyword arguments for :meth:`_cell_transition` (``'sum'`` mode only).
        Values given here take precedence over the ones derived from the arguments above.

    Returns
    -------
    A single-column data frame named ``annotation_label``.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    NotImplementedError
        If ``mapping_mode`` is neither ``'sum'`` nor ``'max'``.
    """
    if mapping_mode == "sum":
        # Copy first: the default is a read-only proxy and a caller-supplied mapping must not be mutated.
        ct_kwargs = dict(cell_transition_kwargs)
        ct_kwargs.setdefault("aggregation_mode", "cell")  # aggregation mode should be set to cell
        ct_kwargs.setdefault("key", key)
        ct_kwargs.setdefault("source", source)
        ct_kwargs.setdefault("target", target)
        ct_kwargs.setdefault("other_adata", other_adata)
        # The cell transition is requested in the opposite direction of the label transfer.
        ct_kwargs.setdefault("forward", not forward)
        if forward:
            ct_kwargs.setdefault("source_groups", annotation_label)
            ct_kwargs.setdefault("target_groups", None)
            axis = 0  # rows
        else:
            ct_kwargs.setdefault("source_groups", None)
            ct_kwargs.setdefault("target_groups", annotation_label)
            axis = 1  # columns
        transition: pd.DataFrame = self._cell_transition(**ct_kwargs)
        return transition.idxmax(axis=axis).to_frame(name=annotation_label)

    if mapping_mode == "max":
        n_source, n_target = self.solutions[(source, target)].shape
        if forward:
            annotated = _get_df_cell_transition(
                self.adata,
                annotation_keys=[annotation_label],
                filter_key=key,
                filter_value=source,
            )
            # One label per *target* cell, so the target cells are the ones that get batched.
            # NOTE(review): assumes `pull` with `subset=(start, offset)` and `split_mass=True` returns
            # one column per selected target cell and one row per source cell - confirm.
            n_query, transport = n_target, self.pull
        else:
            annotated = _get_df_cell_transition(
                self.adata if other_adata is None else other_adata,
                annotation_keys=[annotation_label],
                filter_key=key,
                filter_value=target,
            )
            # One label per *source* cell; `push` is assumed to mirror `pull` (see note above).
            n_query, transport = n_source, self.push

        if batch_size is None:
            batch_size = max(n_query, 1)
        elif batch_size <= 0:
            raise ValueError(f"Expected `batch_size` to be positive, found `{batch_size}`.")

        # Positional lookup: the arg-max yields integer positions, not index labels.
        labels = annotated[annotation_label].to_numpy()
        out: list[Any] = []
        for start in range(0, n_query, batch_size):
            tm_batch: ArrayLike = transport(
                source=source,
                target=target,
                data=None,
                subset=(start, batch_size),
                normalize=True,
                return_all=False,
                scale_by_marginals=scale_by_marginals,
                split_mass=True,
                key_added=None,
            )
            # Arg-max over the annotated cells (rows) for every query cell (column) of this batch;
            # the resulting positions are global, so batching does not change the outcome.
            best = np.atleast_1d(np.asarray(tm_batch.argmax(0)))
            out.extend(labels[best])
        return pd.DataFrame(pd.Categorical(out), columns=[annotation_label])

    raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.")
scale_by_marginals=scale_by_marginals) - def _flatten(self: AnalysisMixinProtocol[K, B], data: Dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike: + def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike: tmp = np.full(len(self.adata), np.nan) for k, v in data.items(): mask = self.adata.obs[key] == k @@ -377,8 +487,8 @@ def _annotation_aggregation_transition( source: K, target: K, annotation_key: str, - annotations_1: List[Any], - annotations_2: List[Any], + annotations_1: list[Any], + annotations_2: list[Any], df: pd.DataFrame, tm: pd.DataFrame, forward: bool, @@ -413,8 +523,8 @@ def _cell_aggregation_transition( target: str, annotation_key: str, # TODO(MUCDK): unused variables, del below - annotations_1: List[Any], - annotations_2: List[Any], + annotations_1: list[Any], + annotations_2: list[Any], df_1: pd.DataFrame, df_2: pd.DataFrame, tm: pd.DataFrame, @@ -450,9 +560,9 @@ def compute_feature_correlation( obs_key: str, corr_method: Literal["pearson", "spearman"] = "pearson", significance_method: Literal["fisher", "perm_test"] = "fisher", - annotation: Optional[Dict[str, Iterable[str]]] = None, + annotation: Optional[dict[str, Iterable[str]]] = None, layer: Optional[str] = None, - features: Optional[Union[List[str], Literal["human", "mouse", "drosophila"]]] = None, + features: Optional[Union[list[str], Literal["human", "mouse", "drosophila"]]] = None, confidence_level: float = 0.95, n_perms: int = 1000, seed: Optional[int] = None, diff --git a/src/moscot/base/problems/_utils.py b/src/moscot/base/problems/_utils.py index a2c09674c..397248d3e 100644 --- a/src/moscot/base/problems/_utils.py +++ b/src/moscot/base/problems/_utils.py @@ -114,7 +114,7 @@ def _check_argument_compatibility_cell_transition( raise ValueError("Unable to infer distributions, missing `adata` and `key`.") if forward and target_annotation is None: raise ValueError("No target annotation provided.") - if not forward and source_annotation 
def annotation_mapping(  # type: ignore[misc]
    self: CrossModalityTranslationMixinProtocol[K, B],
    mapping_mode: Literal["sum", "max"],
    annotation_label: str,
    forward: bool,
    source: str = "src",
    target: str = "tgt",
    cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame:
    """Transfer annotations between distributions.

    Annotations (e.g. cell type labels) of one distribution of cells are assigned
    to the cells of the other distribution.

    Parameters
    ----------
    mapping_mode
        Rule used to choose the transferred label. Valid options are:

        - ``'max'`` - take the label of the annotated cell with the highest matching probability.
        - ``'sum'`` - group the annotated cells by label and take the label
          with the highest total matching probability.
    annotation_label
        Key in :attr:`~anndata.AnnData.obs` where the annotation is stored.
    forward
        If :obj:`True`, transfer the annotations from ``source`` to ``target``.
    source
        Key identifying the source distribution.
    target
        Key identifying the target distribution.
    cell_transition_kwargs
        Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``.

    Returns
    -------
    :class:`~pandas.DataFrame` - Returns the DataFrame of transferred annotations.
    """
    # The batch key and the target object are taken from the problem itself.
    mapping_args: dict[str, Any] = {
        "mapping_mode": mapping_mode,
        "annotation_label": annotation_label,
        "source": source,
        "target": target,
        "key": self.batch_key,
        "forward": forward,
        "other_adata": self.adata_tgt,
        "cell_transition_kwargs": cell_transition_kwargs,
    }
    return self._annotation_mapping(**mapping_args)
def annotation_mapping(  # type: ignore[misc]
    self: SpatialAlignmentMixinProtocol[K, B],
    mapping_mode: Literal["sum", "max"],
    annotation_label: str,
    forward: bool,
    source: str = "src",
    target: str = "tgt",
    cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame:
    """Transfer annotations between distributions.

    Annotations (e.g. cell type labels) of one distribution of cells are assigned
    to the cells of the other distribution.

    Parameters
    ----------
    mapping_mode
        Rule used to choose the transferred label. Valid options are:

        - ``'max'`` - take the label of the annotated cell with the highest matching probability.
        - ``'sum'`` - group the annotated cells by label and take the label
          with the highest total matching probability.
    annotation_label
        Key in :attr:`~anndata.AnnData.obs` where the annotation is stored.
    forward
        If :obj:`True`, transfer the annotations from ``source`` to ``target``.
    source
        Key identifying the source distribution.
    target
        Key identifying the target distribution.
    cell_transition_kwargs
        Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``.

    Returns
    -------
    :class:`~pandas.DataFrame` - Returns the DataFrame of transferred annotations.
    """
    # NOTE(review): the `"src"`/`"tgt"` defaults presumably only resolve if the batch column
    # actually contains those values - confirm against how the problem keys are built.
    mapping_args: dict[str, Any] = {
        "mapping_mode": mapping_mode,
        "annotation_label": annotation_label,
        "source": source,
        "target": target,
        "key": self.batch_key,
        "forward": forward,
        "cell_transition_kwargs": cell_transition_kwargs,
    }
    return self._annotation_mapping(**mapping_args)
+ """ + return self._annotation_mapping( + mapping_mode=mapping_mode, + annotation_label=annotation_label, + source=source, + target=target, + forward=forward, + key=self.batch_key, + other_adata=self.adata_sc, + cell_transition_kwargs=cell_transition_kwargs, + ) + @property def batch_key(self) -> Optional[str]: """Batch key in :attr:`~anndata.AnnData.obs`.""" diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 55b94c890..597757901 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -1,17 +1,18 @@ +from __future__ import annotations + import itertools import pathlib +import types from typing import ( TYPE_CHECKING, Any, - Dict, Iterable, Iterator, - List, Literal, + Mapping, Optional, Protocol, Sequence, - Tuple, Union, ) @@ -34,12 +35,12 @@ class TemporalMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]): # type: ignore[misc] adata: AnnData - problems: Dict[Tuple[K, K], BirthDeathProblem] + problems: dict[tuple[K, K], BirthDeathProblem] temporal_key: Optional[str] _temporal_key: Optional[str] def cell_transition( # noqa: D102 - self: "TemporalMixinProtocol[K, B]", + self: TemporalMixinProtocol[K, B], source: K, target: K, source_groups: Str_Dict_t, @@ -65,8 +66,15 @@ def _cell_transition( ) -> pd.DataFrame: ... + def _annotation_mapping( + self: AnalysisMixinProtocol[K, B], + *args: Any, + **kwargs: Any, + ) -> pd.DataFrame: + ... + def _sample_from_tmap( - self: "TemporalMixinProtocol[K, B]", + self: TemporalMixinProtocol[K, B], source: K, target: K, n_samples: int, @@ -76,11 +84,11 @@ def _sample_from_tmap( account_for_unbalancedness: bool = False, interpolation_parameter: Optional[float] = None, seed: Optional[int] = None, - ) -> Tuple[List[Any], List[ArrayLike]]: + ) -> tuple[list[Any], list[ArrayLike]]: ... 
def annotation_mapping(
    self: TemporalMixinProtocol[K, B],
    mapping_mode: Literal["sum", "max"],
    annotation_label: str,
    forward: bool,
    source: K,
    target: K,
    cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame:
    """Transfer annotations between distributions.

    Annotations (e.g. cell type labels) of one distribution of cells are assigned
    to the cells of the other distribution.

    Parameters
    ----------
    mapping_mode
        Rule used to choose the transferred label. Valid options are:

        - ``'max'`` - take the label of the annotated cell with the highest matching probability.
        - ``'sum'`` - group the annotated cells by label and take the label
          with the highest total matching probability.
    annotation_label
        Key in :attr:`~anndata.AnnData.obs` where the annotation is stored.
    forward
        If :obj:`True`, transfer the annotations from ``source`` to ``target``.
    source
        Key identifying the source distribution.
    target
        Key identifying the target distribution.
    cell_transition_kwargs
        Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``.

    Returns
    -------
    :class:`~pandas.DataFrame` - Returns the DataFrame of transferred annotations.
    """
    # Both distributions are selected through the temporal key; no second object is passed.
    mapping_args: dict[str, Any] = {
        "mapping_mode": mapping_mode,
        "annotation_label": annotation_label,
        "source": source,
        "target": target,
        "key": self._temporal_key,
        "forward": forward,
        "other_adata": None,
        "cell_transition_kwargs": cell_transition_kwargs,
    }
    return self._annotation_mapping(**mapping_args)
seealso:: @@ -359,7 +414,7 @@ def push( source: K, target: K, data: Optional[Union[str, ArrayLike]] = None, - subset: Optional[Union[str, List[str], Tuple[int, int]]] = None, + subset: Optional[Union[str, list[str], tuple[int, int]]] = None, scale_by_marginals: bool = True, key_added: Optional[str] = _constants.PUSH, return_all: bool = False, @@ -426,7 +481,7 @@ def pull( source: K, target: K, data: Optional[Union[str, ArrayLike]] = None, - subset: Optional[Union[str, List[str], Tuple[int, int]]] = None, + subset: Optional[Union[str, list[str], tuple[int, int]]] = None, scale_by_marginals: bool = True, key_added: Optional[str] = _constants.PULL, return_all: bool = False, @@ -581,7 +636,7 @@ def _get_data( posterior_marginals: bool = True, *, only_start: bool = False, - ) -> Union[Tuple[ArrayLike, AnnData], Tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]: + ) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]: # TODO: use .items() for src, tgt in self.problems: tag = self.problems[src, tgt].xy.tag # type: ignore[union-attr] @@ -788,7 +843,7 @@ def compute_time_point_distances( posterior_marginals: bool = True, backend: Literal["ott"] = "ott", **kwargs: Any, - ) -> Tuple[float, float]: + ) -> tuple[float, float]: """Compute `Wasserstein distance `_ between time points. .. 
@pytest.fixture()
def adata_anno(
    problem_kind: Literal["temporal", "cross_modality", "alignment", "mapping"],
) -> Union[AnnData, Tuple[AnnData, AnnData]]:
    """Build annotated toy data for the annotation-mapping tests.

    A 10-cell source carrying ``celltype1`` and a 15-cell target carrying ``celltype2``
    are created, and the expected ``'max'``/``'sum'`` transfer results are stored in
    ``.uns`` under ``expected_{max,sum}{1,2}``.

    For ``'cross_modality'`` and ``'mapping'`` the two objects are returned as a tuple;
    for ``'alignment'`` and ``'temporal'`` they are concatenated into one object whose
    ``'batch'``/``'day'`` column tells them apart.
    """
    # The order of the `rng` calls determines the random labels; do not reorder them.
    rng = np.random.RandomState(31)

    adata_src = AnnData(X=csr_matrix(rng.normal(size=(10, 60))))
    rng_src = rng.choice(["A", "B", "C"], size=5).tolist()
    adata_src.obs["celltype1"] = ["C", "C", "A", "B", "B"] + rng_src
    adata_src.obs["celltype1"] = adata_src.obs["celltype1"].astype("category")
    adata_src.uns["expected_max1"] = ["C", "C", "A", "B", "B"] + rng_src + rng_src
    adata_src.uns["expected_sum1"] = ["C", "C", "B", "B", "B"] + rng_src + rng_src

    adata_tgt = AnnData(X=csr_matrix(rng.normal(size=(15, 60))))
    rng_tgt = rng.choice(["A", "B", "C"], size=5).tolist()
    adata_tgt.obs["celltype2"] = ["C", "C", "A", "B", "B"] + rng_tgt + rng_tgt
    adata_tgt.obs["celltype2"] = adata_tgt.obs["celltype2"].astype("category")
    adata_tgt.uns["expected_max2"] = ["C", "C", "A", "B", "B"] + rng_tgt
    adata_tgt.uns["expected_sum2"] = ["C", "C", "B", "B", "B"] + rng_tgt

    if problem_kind == "cross_modality":
        adata_src.obs["batch"] = "0"
        adata_tgt.obs["batch"] = "1"
        adata_src.obsm["emb_src"] = rng.normal(size=(adata_src.shape[0], 5))
        adata_tgt.obsm["emb_tgt"] = rng.normal(size=(adata_tgt.shape[0], 15))
        sc.pp.pca(adata_src)
        sc.pp.pca(adata_tgt)
        return adata_src, adata_tgt
    if problem_kind == "mapping":
        adata_src.obs["batch"] = "0"
        adata_tgt.obs["batch"] = "1"
        sc.pp.pca(adata_src)
        sc.pp.pca(adata_tgt)
        adata_tgt.obsm["spatial"] = rng.normal(size=(adata_tgt.n_obs, 2))
        return adata_src, adata_tgt
    if problem_kind == "alignment":
        adata_src.obsm["spatial"] = rng.normal(size=(adata_src.n_obs, 2))
        adata_tgt.obsm["spatial"] = rng.normal(size=(adata_tgt.n_obs, 2))

    # 'alignment' and 'temporal' use a single concatenated object.
    key = "day" if problem_kind == "temporal" else "batch"
    adatas = [adata_src, adata_tgt]
    adata = ad.concat(adatas, join="outer", label=key, index_unique="-", uns_merge="unique")
    adata.obs[key] = (pd.to_numeric(adata.obs[key]) if key == "day" else adata.obs[key]).astype("category")
    # `.toarray()` instead of the deprecated sparse `.A` alias.
    adata.layers["counts"] = adata.X.toarray()
    sc.pp.pca(adata)
    return adata


@pytest.fixture()
def gt_tm_annotation() -> np.ndarray:
    """Return the hand-crafted ``(10, 15)`` transport matrix used by the annotation-mapping tests.

    Every source cell ``i`` has weight 1 on target ``i``; sources 5-9 additionally have
    weight 1 on targets 10-14. The ``[2:5, 2:5]`` block is then overwritten with 0.3
    off the diagonal and 0.4 on it.
    """
    tm = np.zeros((10, 15))
    first_ten = np.arange(10)
    tm[first_ten, first_ten] = 1
    tm[np.arange(5, 10), np.arange(10, 15)] = 1
    tm[2:5, 2:5] = 0.3
    block_diag = np.arange(2, 5)
    tm[block_diag, block_diag] = 0.4
    return tm
@pytest.mark.fast()
@pytest.mark.parametrize("forward", [True, False])
@pytest.mark.parametrize("mapping_mode", ["max", "sum"])
@pytest.mark.parametrize("problem_kind", ["cross_modality"])
def test_annotation_mapping(
    self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, gt_tm_annotation
):
    """Check that the transferred labels equal the expectation stored in ``.uns``."""
    adata_src, adata_tgt = adata_anno
    problem_keys = ("src", "tgt")

    tp = TranslationProblem(adata_src, adata_tgt).prepare(src_attr="emb_src", tgt_attr="emb_tgt")
    assert set(tp.problems.keys()) == {problem_keys}
    # Replace the solution by the hand-crafted transport matrix so the outcome is known.
    tp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation), overwrite=True)

    # The labels come from the source when mapping forward and from the target otherwise.
    if forward:
        annotation_label = "celltype1"
        expected_by_mode = {"max": adata_src.uns["expected_max1"], "sum": adata_src.uns["expected_sum1"]}
    else:
        annotation_label = "celltype2"
        expected_by_mode = {"max": adata_tgt.uns["expected_max2"], "sum": adata_tgt.uns["expected_sum2"]}

    result = tp.annotation_mapping(
        mapping_mode=mapping_mode,
        annotation_label=annotation_label,
        forward=forward,
        source="src",
        target="tgt",
    )
    expected_result = expected_by_mode[mapping_mode]
    assert (result[annotation_label] == expected_result).all()
@pytest.mark.parametrize("forward", [True, False]) + @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("problem_kind", ["alignment"]) + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): + ap = AlignmentProblem(adata=adata_anno) + ap = ap.prepare(batch_key="batch", joint_attr={"attr": "X"}) + problem_keys = ("0", "1") + assert set(ap.problems.keys()) == {problem_keys} + ap[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation)) + annotation_label = "celltype1" if forward else "celltype2" + result = ap.annotation_mapping( + mapping_mode=mapping_mode, + annotation_label=annotation_label, + source="0", + target="1", + forward=forward, + ) + if forward: + expected_result = ( + adata_anno.uns["expected_max1"] if mapping_mode == "max" else adata_anno.uns["expected_sum1"] + ) + else: + expected_result = ( + adata_anno.uns["expected_max2"] if mapping_mode == "max" else adata_anno.uns["expected_sum2"] + ) + assert (result[annotation_label] == expected_result).all() + class TestSpatialMappingAnalysisMixin: @pytest.mark.parametrize("sc_attr", [{"attr": "X"}, {"attr": "obsm", "key": "X_pca"}]) @@ -175,3 +203,27 @@ def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, n assert isinstance(result, pd.DataFrame) assert result.shape == (3, 4) + + @pytest.mark.fast() + @pytest.mark.parametrize("forward", [True, False]) + @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("problem_kind", ["mapping"]) + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): + adataref, adatasp = adata_anno + mp = MappingProblem(adataref, adatasp) + mp = mp.prepare(sc_attr={"attr": "obsm", "key": "X_pca"}, joint_attr={"attr": "X"}) + problem_keys = ("src", "tgt") + assert set(mp.problems.keys()) == {problem_keys} + mp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation.T)) + annotation_label = 
"celltype1" if not forward else "celltype2" + result = mp.annotation_mapping( + mapping_mode=mapping_mode, + annotation_label=annotation_label, + source="src", + forward=forward, + ) + if not forward: + expected_result = adataref.uns["expected_max1"] if mapping_mode == "max" else adataref.uns["expected_sum1"] + else: + expected_result = adatasp.uns["expected_max2"] if mapping_mode == "max" else adatasp.uns["expected_sum2"] + assert (result[annotation_label] == expected_result).all() diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index 612211879..cb2d9ea2a 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -50,6 +50,30 @@ def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward present_cell_type_marginal = marginal[marginal > 0] np.testing.assert_allclose(present_cell_type_marginal, 1.0) + @pytest.mark.fast() + @pytest.mark.parametrize("forward", [True, False]) + @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("problem_kind", ["temporal"]) + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): + problem = TemporalProblem(adata_anno) + problem_keys = (0, 1) + problem = problem.prepare(time_key="day", joint_attr="X_pca") + assert set(problem.problems.keys()) == {problem_keys} + problem[problem_keys]._solution = MockSolverOutput(gt_tm_annotation) + annotation_label = "celltype1" if forward else "celltype2" + result = problem.annotation_mapping( + mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source=0, target=1 + ) + if forward: + expected_result = ( + adata_anno.uns["expected_max1"] if mapping_mode == "max" else adata_anno.uns["expected_sum1"] + ) + else: + expected_result = ( + adata_anno.uns["expected_max2"] if mapping_mode == "max" else adata_anno.uns["expected_sum2"] + ) + assert (result[annotation_label] == expected_result).all() + 
@pytest.mark.fast() @pytest.mark.parametrize("forward", [True, False]) def test_cell_transition_different_groups(self, gt_temporal_adata: AnnData, forward: bool):