Skip to content

Commit

Permalink
some mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Arina Danilina committed Jan 18, 2024
1 parent e2b9c5b commit e1608b0
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 16 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ ignore = [
"D107",
# Missing docstring in magic method
"D105",
# Use `X | Y` for type annotations
"UP007",
]
line-length = 120
select = [
Expand Down
10 changes: 5 additions & 5 deletions src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _annotation_mapping(
forward: bool,
source: K,
target: K,
key: str,
key: str | None = None,
other_adata: Optional[str] = None,
scale_by_marginals: bool = True,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
Expand Down Expand Up @@ -342,10 +342,10 @@ def _annotation_mapping(
filter_key=key,
filter_value=source,
)
out_len = self[(source, target)].solution.shape[1]
out_len = self.solutions[(source, target)].shape[1]
batch_size = batch_size if batch_size is not None else out_len
for batch in range(0, out_len, batch_size):
tm_batch = self.push(
tm_batch : ArrayLike = self.push( # type: ignore[attr-defined]
source=source,
target=target,
data=None,
Expand All @@ -366,10 +366,10 @@ def _annotation_mapping(
filter_key=key,
filter_value=target,
)
out_len = self[(source, target)].solution.shape[0]
out_len = self.solutions[(source, target)].shape[0]
batch_size = batch_size if batch_size is not None else out_len
for batch in range(0, out_len, batch_size):
tm_batch = self.pull(
tm_batch : ArrayLike = self.pull( # type: ignore[attr-defined]
source=source,
target=target,
data=None,
Expand Down
10 changes: 6 additions & 4 deletions src/moscot/problems/cross_modality/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class CrossModalityTranslationMixinProtocol(AnalysisMixinProtocol[K, B]):
def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame:
...

def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame:
...

class CrossModalityTranslationMixin(AnalysisMixin[K, B]):
"""Cross modality translation analysis mixin class."""
Expand Down Expand Up @@ -184,13 +186,13 @@ def cell_transition( # type: ignore[misc]
key_added=key_added,
)

def annotation_mapping(
def annotation_mapping( # type: ignore[misc]
self: CrossModalityTranslationMixinProtocol[K, B],
mapping_mode: Literal["sum", "max"],
annotation_label: str,
forward: bool,
source: K = "src",
target: K = "tgt",
source: str = "src",
target: str = "tgt",
scale_by_marginals: bool = True,
other_adata: Optional[str] = None,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
Expand All @@ -202,7 +204,7 @@ def annotation_mapping(
target=target,
key=self.batch_key,
forward=forward,
other_adata=self.adata_tgt,
other_adata=self.adata_tgt if other_adata is None else other_adata,
scale_by_marginals=scale_by_marginals,
cell_transition_kwargs=cell_transition_kwargs,
)
Expand Down
14 changes: 7 additions & 7 deletions src/moscot/problems/space/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ def cell_transition( # type: ignore[misc]
key_added=key_added,
)

def annotation_mapping(
self: AnalysisMixinProtocol[K, B],
def annotation_mapping( # type: ignore[misc]
self: SpatialAlignmentMixinProtocol[K, B],
mapping_mode: Literal["sum", "max"],
annotation_label: str,
forward: bool,
Expand All @@ -299,7 +299,7 @@ def annotation_mapping(
annotation_label=annotation_label,
source=source,
target=target,
key=self._batch_key,
key=self.batch_key,
forward=forward,
scale_by_marginals=scale_by_marginals,
cell_transition_kwargs=cell_transition_kwargs,
Expand Down Expand Up @@ -594,12 +594,12 @@ def cell_transition( # type: ignore[misc]
key_added=key_added,
)

def annotation_mapping(
self: AnalysisMixinProtocol[K, B],
def annotation_mapping( # type: ignore[misc]
self: SpatialMappingMixinProtocol[K, B],
mapping_mode: Literal["sum", "max"],
annotation_label: str,
source: str,
target: str = "tgt",
source: K,
target: K | str = "tgt",
forward: bool = False,
scale_by_marginals: bool = True,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
Expand Down
7 changes: 7 additions & 0 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def _cell_transition(
) -> pd.DataFrame:
...

def _annotation_mapping(
self: AnalysisMixinProtocol[K, B],
*args: Any,
**kwargs: Any,
) -> pd.DataFrame:
...

def _sample_from_tmap(
self: TemporalMixinProtocol[K, B],
source: K,
Expand Down

0 comments on commit e1608b0

Please sign in to comment.