Skip to content

Commit

Permalink
ruff typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Arina Danilina committed Jan 18, 2024
1 parent ab89a42 commit 57c2649
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 28 deletions.
32 changes: 16 additions & 16 deletions src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class AnalysisMixinProtocol(Protocol[K, B]):

adata: AnnData
_policy: SubsetPolicy[K]
solutions: Dict[Tuple[K, K], BaseSolverOutput]
problems: Dict[Tuple[K, K], B]
solutions: dict[tuple[K, K], BaseSolverOutput]
problems: dict[tuple[K, K], B]

def _apply(
self,
Expand All @@ -66,14 +66,14 @@ def _apply(

def _interpolate_transport(
self: AnalysisMixinProtocol[K, B],
path: Sequence[Tuple[K, K]],
path: Sequence[tuple[K, K]],
scale_by_marginals: bool = True,
) -> LinearOperator:
...

def _flatten(
self: AnalysisMixinProtocol[K, B],
data: Dict[K, ArrayLike],
data: dict[K, ArrayLike],
*,
key: Optional[str],
) -> ArrayLike:
Expand Down Expand Up @@ -345,7 +345,7 @@ def _annotation_mapping(
out_len = self.solutions[(source, target)].shape[1]
batch_size = batch_size if batch_size is not None else out_len
for batch in range(0, out_len, batch_size):
tm_batch: ArrayLike = self.push( # type: ignore[attr-defined]
tm_batch: ArrayLike = self.push( # type: ignore[no-redef]
source=source,
target=target,
data=None,
Expand All @@ -369,7 +369,7 @@ def _annotation_mapping(
out_len = self.solutions[(source, target)].shape[0]
batch_size = batch_size if batch_size is not None else out_len
for batch in range(0, out_len, batch_size):
tm_batch: ArrayLike = self.pull( # type: ignore[attr-defined]
tm_batch: ArrayLike = self.pull( # type: ignore[no-redef]
source=source,
target=target,
data=None,
Expand Down Expand Up @@ -397,7 +397,7 @@ def _sample_from_tmap(
account_for_unbalancedness: bool = False,
interpolation_parameter: Optional[Numeric_t] = None,
seed: Optional[int] = None,
) -> Tuple[List[Any], List[ArrayLike]]:
) -> tuple[list[Any], list[ArrayLike]]:
rng = np.random.RandomState(seed)
if account_for_unbalancedness and interpolation_parameter is None:
raise ValueError("When accounting for unbalancedness, interpolation parameter must be provided.")
Expand Down Expand Up @@ -434,7 +434,7 @@ def _sample_from_tmap(

rows_sampled = rng.choice(source_dim, p=row_probability / row_probability.sum(), size=n_samples)
rows, counts = np.unique(rows_sampled, return_counts=True)
all_cols_sampled: List[str] = []
all_cols_sampled: list[str] = []
for batch in range(0, len(rows), batch_size):
rows_batch = rows[batch : batch + batch_size]
counts_batch = counts[batch : batch + batch_size]
Expand Down Expand Up @@ -467,7 +467,7 @@ def _sample_from_tmap(
def _interpolate_transport(
self: AnalysisMixinProtocol[K, B],
# TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key)
path: Sequence[Tuple[K, K]],
path: Sequence[tuple[K, K]],
scale_by_marginals: bool = True,
**_: Any,
) -> LinearOperator:
Expand All @@ -478,7 +478,7 @@ def _interpolate_transport(
fst, *rest = path
return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals)

def _flatten(self: AnalysisMixinProtocol[K, B], data: Dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
tmp = np.full(len(self.adata), np.nan)
for k, v in data.items():
mask = self.adata.obs[key] == k
Expand All @@ -490,8 +490,8 @@ def _annotation_aggregation_transition(
source: K,
target: K,
annotation_key: str,
annotations_1: List[Any],
annotations_2: List[Any],
annotations_1: list[Any],
annotations_2: list[Any],
df: pd.DataFrame,
tm: pd.DataFrame,
forward: bool,
Expand Down Expand Up @@ -526,8 +526,8 @@ def _cell_aggregation_transition(
target: str,
annotation_key: str,
# TODO(MUCDK): unused variables, del below
annotations_1: List[Any],
annotations_2: List[Any],
annotations_1: list[Any],
annotations_2: list[Any],
df_1: pd.DataFrame,
df_2: pd.DataFrame,
tm: pd.DataFrame,
Expand Down Expand Up @@ -563,9 +563,9 @@ def compute_feature_correlation(
obs_key: str,
corr_method: Literal["pearson", "spearman"] = "pearson",
significance_method: Literal["fisher", "perm_test"] = "fisher",
annotation: Optional[Dict[str, Iterable[str]]] = None,
annotation: Optional[dict[str, Iterable[str]]] = None,
layer: Optional[str] = None,
features: Optional[Union[List[str], Literal["human", "mouse", "drosophila"]]] = None,
features: Optional[Union[list[str], Literal["human", "mouse", "drosophila"]]] = None,
confidence_level: float = 0.95,
n_perms: int = 1000,
seed: Optional[int] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/problems/space/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def annotation_mapping( # type: ignore[misc]
mapping_mode: Literal["sum", "max"],
annotation_label: str,
source: K,
target: K | str = "tgt",
target: Union[K, str] = "tgt",
forward: bool = False,
scale_by_marginals: bool = True,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
Expand Down
22 changes: 11 additions & 11 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

class TemporalMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]): # type: ignore[misc]
adata: AnnData
problems: Dict[Tuple[K, K], BirthDeathProblem]
problems: dict[tuple[K, K], BirthDeathProblem]
temporal_key: Optional[str]
_temporal_key: Optional[str]

Expand Down Expand Up @@ -87,7 +87,7 @@ def _sample_from_tmap(
account_for_unbalancedness: bool = False,
interpolation_parameter: Optional[float] = None,
seed: Optional[int] = None,
) -> Tuple[List[Any], List[ArrayLike]]:
) -> tuple[list[Any], list[ArrayLike]]:
...

def _compute_wasserstein_distance(
Expand Down Expand Up @@ -123,7 +123,7 @@ def _get_data(
posterior_marginals: bool = True,
*,
only_start: bool = False,
) -> Union[Tuple[ArrayLike, AnnData], Tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]:
) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]:
...

def _interpolate_gex_randomly(
Expand All @@ -139,7 +139,7 @@ def _interpolate_gex_randomly(

def _plot_temporal(
self: TemporalMixinProtocol[K, B],
data: Dict[K, ArrayLike],
data: dict[K, ArrayLike],
source: K,
target: K,
time_points: Optional[Iterable[K]] = None,
Expand All @@ -157,7 +157,7 @@ def _get_interp_param(
) -> float:
...

def __iter__(self) -> Iterator[Tuple[K, K]]:
def __iter__(self) -> Iterator[tuple[K, K]]:
...


Expand Down Expand Up @@ -278,7 +278,7 @@ def sankey(
order_annotations: Optional[Sequence[str]] = None,
key_added: Optional[str] = _constants.SANKEY,
**kwargs: Any,
) -> Optional[List[pd.DataFrame]]:
) -> Optional[list[pd.DataFrame]]:
"""Compute a `Sankey diagram <https://en.wikipedia.org/wiki/Sankey_diagram>`_ between cells across time points.
.. seealso::
Expand Down Expand Up @@ -392,7 +392,7 @@ def push(
source: K,
target: K,
data: Optional[Union[str, ArrayLike]] = None,
subset: Optional[Union[str, List[str], Tuple[int, int]]] = None,
subset: Optional[Union[str, list[str], tuple[int, int]]] = None,
scale_by_marginals: bool = True,
key_added: Optional[str] = _constants.PUSH,
return_all: bool = False,
Expand Down Expand Up @@ -459,7 +459,7 @@ def pull(
source: K,
target: K,
data: Optional[Union[str, ArrayLike]] = None,
subset: Optional[Union[str, List[str], Tuple[int, int]]] = None,
subset: Optional[Union[str, list[str], tuple[int, int]]] = None,
scale_by_marginals: bool = True,
key_added: Optional[str] = _constants.PULL,
return_all: bool = False,
Expand Down Expand Up @@ -614,7 +614,7 @@ def _get_data(
posterior_marginals: bool = True,
*,
only_start: bool = False,
) -> Union[Tuple[ArrayLike, AnnData], Tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]:
) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]:
# TODO: use .items()
for src, tgt in self.problems:
tag = self.problems[src, tgt].xy.tag # type: ignore[union-attr]
Expand Down Expand Up @@ -821,7 +821,7 @@ def compute_time_point_distances(
posterior_marginals: bool = True,
backend: Literal["ott"] = "ott",
**kwargs: Any,
) -> Tuple[float, float]:
) -> tuple[float, float]:
"""Compute `Wasserstein distance <https://en.wikipedia.org/wiki/Wasserstein_metric>`_ between time points.
.. seealso::
Expand Down Expand Up @@ -904,7 +904,7 @@ def compute_batch_distances(
if len(data) != len(adata):
raise ValueError(f"Expected the `data` to have length `{len(adata)}`, found `{len(data)}`.")

dist: List[float] = []
dist: list[float] = []
for batch_1, batch_2 in itertools.combinations(adata.obs[batch_key].unique(), 2):
dist.append(
self._compute_wasserstein_distance(
Expand Down

0 comments on commit 57c2649

Please sign in to comment.