From 57c2649da3e1cd8cfd66dca8a6139a2953d9321f Mon Sep 17 00:00:00 2001
From: Arina Danilina
Date: Thu, 18 Jan 2024 15:11:43 +0100
Subject: [PATCH] ruff typing

---
 src/moscot/base/problems/_mixins.py  | 32 ++++++++++++++--------------
 src/moscot/problems/space/_mixins.py |  2 +-
 src/moscot/problems/time/_mixins.py  | 22 +++++++++----------
 3 files changed, 28 insertions(+), 28 deletions(-)

diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py
index ae5855ff2..75570df12 100644
--- a/src/moscot/base/problems/_mixins.py
+++ b/src/moscot/base/problems/_mixins.py
@@ -49,8 +49,8 @@ class AnalysisMixinProtocol(Protocol[K, B]):
 
     adata: AnnData
     _policy: SubsetPolicy[K]
-    solutions: Dict[Tuple[K, K], BaseSolverOutput]
-    problems: Dict[Tuple[K, K], B]
+    solutions: dict[tuple[K, K], BaseSolverOutput]
+    problems: dict[tuple[K, K], B]
 
     def _apply(
         self,
@@ -66,14 +66,14 @@ def _apply(
 
     def _interpolate_transport(
         self: AnalysisMixinProtocol[K, B],
-        path: Sequence[Tuple[K, K]],
+        path: Sequence[tuple[K, K]],
         scale_by_marginals: bool = True,
     ) -> LinearOperator:
         ...
 
     def _flatten(
         self: AnalysisMixinProtocol[K, B],
-        data: Dict[K, ArrayLike],
+        data: dict[K, ArrayLike],
         *,
         key: Optional[str],
     ) -> ArrayLike:
@@ -345,7 +345,7 @@ def _annotation_mapping(
                 out_len = self.solutions[(source, target)].shape[1]
                 batch_size = batch_size if batch_size is not None else out_len
                 for batch in range(0, out_len, batch_size):
-                    tm_batch: ArrayLike = self.push(  # type: ignore[attr-defined]
+                    tm_batch: ArrayLike = self.push(  # type: ignore[no-redef]
                         source=source,
                         target=target,
                         data=None,
@@ -369,7 +369,7 @@ def _annotation_mapping(
                 out_len = self.solutions[(source, target)].shape[0]
                 batch_size = batch_size if batch_size is not None else out_len
                 for batch in range(0, out_len, batch_size):
-                    tm_batch: ArrayLike = self.pull(  # type: ignore[attr-defined]
+                    tm_batch: ArrayLike = self.pull(  # type: ignore[no-redef]
                         source=source,
                         target=target,
                         data=None,
@@ -397,7 +397,7 @@ def _sample_from_tmap(
         account_for_unbalancedness: bool = False,
         interpolation_parameter: Optional[Numeric_t] = None,
         seed: Optional[int] = None,
-    ) -> Tuple[List[Any], List[ArrayLike]]:
+    ) -> tuple[list[Any], list[ArrayLike]]:
         rng = np.random.RandomState(seed)
         if account_for_unbalancedness and interpolation_parameter is None:
             raise ValueError("When accounting for unbalancedness, interpolation parameter must be provided.")
@@ -434,7 +434,7 @@ def _sample_from_tmap(
         rows_sampled = rng.choice(source_dim, p=row_probability / row_probability.sum(), size=n_samples)
         rows, counts = np.unique(rows_sampled, return_counts=True)
-        all_cols_sampled: List[str] = []
+        all_cols_sampled: list[str] = []
         for batch in range(0, len(rows), batch_size):
             rows_batch = rows[batch : batch + batch_size]
             counts_batch = counts[batch : batch + batch_size]
@@ -467,7 +467,7 @@ def _sample_from_tmap(
     def _interpolate_transport(
         self: AnalysisMixinProtocol[K, B],
         # TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key)
-        path: Sequence[Tuple[K, K]],
+        path: Sequence[tuple[K, K]],
         scale_by_marginals: bool = True,
         **_: Any,
     ) -> LinearOperator:
@@ -478,7 +478,7 @@ def _interpolate_transport(
         fst, *rest = path
         return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals)
 
-    def _flatten(self: AnalysisMixinProtocol[K, B], data: Dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
+    def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
         tmp = np.full(len(self.adata), np.nan)
         for k, v in data.items():
             mask = self.adata.obs[key] == k
@@ -490,8 +490,8 @@ def _annotation_aggregation_transition(
         source: K,
         target: K,
         annotation_key: str,
-        annotations_1: List[Any],
-        annotations_2: List[Any],
+        annotations_1: list[Any],
+        annotations_2: list[Any],
         df: pd.DataFrame,
         tm: pd.DataFrame,
         forward: bool,
@@ -526,8 +526,8 @@ def _cell_aggregation_transition(
         target: str,
         annotation_key: str,
         # TODO(MUCDK): unused variables, del below
-        annotations_1: List[Any],
-        annotations_2: List[Any],
+        annotations_1: list[Any],
+        annotations_2: list[Any],
         df_1: pd.DataFrame,
         df_2: pd.DataFrame,
         tm: pd.DataFrame,
@@ -563,9 +563,9 @@ def compute_feature_correlation(
         obs_key: str,
         corr_method: Literal["pearson", "spearman"] = "pearson",
         significance_method: Literal["fisher", "perm_test"] = "fisher",
-        annotation: Optional[Dict[str, Iterable[str]]] = None,
+        annotation: Optional[dict[str, Iterable[str]]] = None,
         layer: Optional[str] = None,
-        features: Optional[Union[List[str], Literal["human", "mouse", "drosophila"]]] = None,
+        features: Optional[Union[list[str], Literal["human", "mouse", "drosophila"]]] = None,
         confidence_level: float = 0.95,
         n_perms: int = 1000,
         seed: Optional[int] = None,
diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py
index cd101c19d..f4604bb8c 100644
--- a/src/moscot/problems/space/_mixins.py
+++ b/src/moscot/problems/space/_mixins.py
@@ -599,7 +599,7 @@ def annotation_mapping(  # type: ignore[misc]
         mapping_mode: Literal["sum", "max"],
         annotation_label: str,
         source: K,
-        target: K | str = "tgt",
+        target: Union[K, str] = "tgt",
         forward: bool = False,
         scale_by_marginals: bool = True,
         cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py
index d43f49005..f2ee7068f 100644
--- a/src/moscot/problems/time/_mixins.py
+++ b/src/moscot/problems/time/_mixins.py
@@ -38,7 +38,7 @@ class TemporalMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]):  # type: ignore[misc]
 
     adata: AnnData
-    problems: Dict[Tuple[K, K], BirthDeathProblem]
+    problems: dict[tuple[K, K], BirthDeathProblem]
     temporal_key: Optional[str]
     _temporal_key: Optional[str]
@@ -87,7 +87,7 @@ def _sample_from_tmap(
         account_for_unbalancedness: bool = False,
         interpolation_parameter: Optional[float] = None,
         seed: Optional[int] = None,
-    ) -> Tuple[List[Any], List[ArrayLike]]:
+    ) -> tuple[list[Any], list[ArrayLike]]:
         ...
 
     def _compute_wasserstein_distance(
@@ -123,7 +123,7 @@ def _get_data(
         posterior_marginals: bool = True,
         *,
         only_start: bool = False,
-    ) -> Union[Tuple[ArrayLike, AnnData], Tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]:
+    ) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]:
         ...
 
     def _interpolate_gex_randomly(
@@ -139,7 +139,7 @@ def _interpolate_gex_randomly(
 
     def _plot_temporal(
         self: TemporalMixinProtocol[K, B],
-        data: Dict[K, ArrayLike],
+        data: dict[K, ArrayLike],
         source: K,
         target: K,
         time_points: Optional[Iterable[K]] = None,
@@ -157,7 +157,7 @@ def _get_interp_param(
     ) -> float:
         ...
 
-    def __iter__(self) -> Iterator[Tuple[K, K]]:
+    def __iter__(self) -> Iterator[tuple[K, K]]:
         ...
@@ -278,7 +278,7 @@ def sankey(
         order_annotations: Optional[Sequence[str]] = None,
         key_added: Optional[str] = _constants.SANKEY,
         **kwargs: Any,
-    ) -> Optional[List[pd.DataFrame]]:
+    ) -> Optional[list[pd.DataFrame]]:
         """Compute a `Sankey diagram `_ between cells across time points.
 
         .. seealso::
@@ -392,7 +392,7 @@ def push(
         source: K,
         target: K,
         data: Optional[Union[str, ArrayLike]] = None,
-        subset: Optional[Union[str, List[str], Tuple[int, int]]] = None,
+        subset: Optional[Union[str, list[str], tuple[int, int]]] = None,
         scale_by_marginals: bool = True,
         key_added: Optional[str] = _constants.PUSH,
         return_all: bool = False,
@@ -459,7 +459,7 @@ def pull(
         source: K,
         target: K,
         data: Optional[Union[str, ArrayLike]] = None,
-        subset: Optional[Union[str, List[str], Tuple[int, int]]] = None,
+        subset: Optional[Union[str, list[str], tuple[int, int]]] = None,
         scale_by_marginals: bool = True,
         key_added: Optional[str] = _constants.PULL,
         return_all: bool = False,
@@ -614,7 +614,7 @@ def _get_data(
         posterior_marginals: bool = True,
         *,
         only_start: bool = False,
-    ) -> Union[Tuple[ArrayLike, AnnData], Tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]:
+    ) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]:
         # TODO: use .items()
         for src, tgt in self.problems:
             tag = self.problems[src, tgt].xy.tag  # type: ignore[union-attr]
@@ -821,7 +821,7 @@ def compute_time_point_distances(
         posterior_marginals: bool = True,
         backend: Literal["ott"] = "ott",
         **kwargs: Any,
-    ) -> Tuple[float, float]:
+    ) -> tuple[float, float]:
         """Compute `Wasserstein distance `_ between time points.
 
         .. seealso::
@@ -904,7 +904,7 @@ def compute_batch_distances(
         if len(data) != len(adata):
             raise ValueError(f"Expected the `data` to have length `{len(adata)}`, found `{len(data)}`.")
 
-        dist: List[float] = []
+        dist: list[float] = []
         for batch_1, batch_2 in itertools.combinations(adata.obs[batch_key].unique(), 2):
             dist.append(
                 self._compute_wasserstein_distance(
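
Note (not part of the patch): the hunks above only modernize annotations; runtime behavior is unchanged. Below is a minimal, self-contained Python sketch of the same pattern, with hypothetical function names chosen for illustration. It assumes the project targets Python >= 3.9, where the built-in dict/list/tuple are subscriptable per PEP 585, while Union[K, str] is kept instead of K | str because the pipe syntax only evaluates at runtime on Python >= 3.10 (or with postponed evaluation via `from __future__ import annotations`).

from typing import TypeVar, Union

K = TypeVar("K")  # generic key type, mirroring the K used by the mixin protocols


# PEP 585: built-in dict/list/tuple accept subscripts at runtime on Python >= 3.9,
# so the typing.Dict/List/Tuple imports become unnecessary.
def flatten_sizes(data: dict[K, list[float]]) -> tuple[int, int]:
    return len(data), sum(len(v) for v in data.values())


# Union[...] still evaluates on Python 3.9; the `K | str` spelling would need 3.10+.
def pick_target(target: Union[K, str] = "tgt") -> str:
    return str(target)


if __name__ == "__main__":
    print(flatten_sizes({"a": [1.0, 2.0], "b": [3.0]}))  # -> (2, 3)
    print(pick_target())  # -> 'tgt'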