From e1608b0667f9d4dc2a03fc2905ad70844d23f264 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 18 Jan 2024 14:54:04 +0100 Subject: [PATCH] some mypy fixes --- pyproject.toml | 2 ++ src/moscot/base/problems/_mixins.py | 10 +++++----- src/moscot/problems/cross_modality/_mixins.py | 10 ++++++---- src/moscot/problems/space/_mixins.py | 14 +++++++------- src/moscot/problems/time/_mixins.py | 7 +++++++ 5 files changed, 27 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ef1fc1ac5..10102d26b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,6 +122,8 @@ ignore = [ "D107", # Missing docstring in magic method "D105", + # Use `X | Y` for type annotations + "UP007", ] line-length = 120 select = [ diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 36863551a..43ee8f248 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -122,7 +122,7 @@ def _annotation_mapping( forward: bool, source: K, target: K, - key: str, + key: str | None = None, other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), @@ -342,10 +342,10 @@ def _annotation_mapping( filter_key=key, filter_value=source, ) - out_len = self[(source, target)].solution.shape[1] + out_len = self.solutions[(source, target)].shape[1] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch = self.push( + tm_batch : ArrayLike = self.push( # type: ignore[attr-defined] source=source, target=target, data=None, @@ -366,10 +366,10 @@ def _annotation_mapping( filter_key=key, filter_value=target, ) - out_len = self[(source, target)].solution.shape[0] + out_len = self.solutions[(source, target)].shape[0] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch = self.pull( + tm_batch : ArrayLike = self.pull( # type: ignore[attr-defined] source=source, target=target, data=None, diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index a42ecb14b..46367c9e0 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -25,6 +25,8 @@ class CrossModalityTranslationMixinProtocol(AnalysisMixinProtocol[K, B]): def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... + def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: + ... class CrossModalityTranslationMixin(AnalysisMixin[K, B]): """Cross modality translation analysis mixin class.""" @@ -184,13 +186,13 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def annotation_mapping( + def annotation_mapping( # type: ignore[misc] self: CrossModalityTranslationMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, - source: K = "src", - target: K = "tgt", + source: str = "src", + target: str = "tgt", scale_by_marginals: bool = True, other_adata: Optional[str] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), @@ -202,7 +204,7 @@ def annotation_mapping( target=target, key=self.batch_key, forward=forward, - other_adata=self.adata_tgt, + other_adata=self.adata_tgt if other_adata is None else other_adata, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 3ee462985..cd101c19d 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -284,8 +284,8 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def annotation_mapping( - self: AnalysisMixinProtocol[K, B], + def annotation_mapping( # type: ignore[misc] + self: SpatialAlignmentMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, @@ -299,7 +299,7 @@ def annotation_mapping( annotation_label=annotation_label, source=source, target=target, - key=self._batch_key, + key=self.batch_key, forward=forward, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, @@ -594,12 +594,12 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def annotation_mapping( - self: AnalysisMixinProtocol[K, B], + def annotation_mapping( # type: ignore[misc] + self: SpatialMappingMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], annotation_label: str, - source: str, - target: str = "tgt", + source: K, + target: K | str = "tgt", forward: bool = False, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 0c6f1de2c..d43f49005 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -69,6 +69,13 @@ def _cell_transition( ) -> pd.DataFrame: ... + def _annotation_mapping( + self: AnalysisMixinProtocol[K, B], + *args: Any, + **kwargs: Any, + ) -> pd.DataFrame: + ... + def _sample_from_tmap( self: TemporalMixinProtocol[K, B], source: K,