Skip to content

Commit

Permalink
Merge branch 'ig/neural_solvers' of github.com:theislab/moscot into ig/neural_solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Feb 26, 2024
2 parents a0f94a4 + 25844d7 commit f017c9e
Show file tree
Hide file tree
Showing 16 changed files with 90 additions and 114 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
additional_dependencies: [numpy>=1.25.0]
files: ^src
- repo: https://github.com/psf/black
rev: 23.12.1
rev: 24.2.0
hooks:
- id: black
additional_dependencies: [toml]
Expand Down Expand Up @@ -42,7 +42,7 @@ repos:
- id: check-yaml
- id: check-toml
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
rev: v3.15.1
hooks:
- id: pyupgrade
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
Expand All @@ -63,7 +63,7 @@ repos:
- id: doc8
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.14
rev: v0.2.2
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
2 changes: 1 addition & 1 deletion src/moscot/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
try:
from numpy.typing import DTypeLike, NDArray

ArrayLike = NDArray[np.float_]
ArrayLike = NDArray[np.float64]
except (ImportError, TypeError):
ArrayLike = np.ndarray # type: ignore[misc]
DTypeLike = np.dtype # type: ignore[misc]
Expand Down
30 changes: 15 additions & 15 deletions src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from moscot.base.output import BaseDiscreteSolverOutput
from moscot.base.problems._utils import (
_check_argument_compatibility_cell_transition,
_compute_conditional_entropy,
_correlation_test,
_get_df_cell_transition,
_order_transition_matrix,
Expand Down Expand Up @@ -58,23 +57,20 @@ def _apply(
return_all: bool = False,
scale_by_marginals: bool = False,
**kwargs: Any,
) -> ApplyOutput_t[K]:
...
) -> ApplyOutput_t[K]: ...

def _interpolate_transport(
self: AnalysisMixinProtocol[K, B],
path: Sequence[tuple[K, K]],
scale_by_marginals: bool = True,
) -> LinearOperator:
...
) -> LinearOperator: ...

def _flatten(
self: AnalysisMixinProtocol[K, B],
data: dict[K, ArrayLike],
*,
key: Optional[str],
) -> ArrayLike:
...
) -> ArrayLike: ...

def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]:
"""Push distribution."""
Expand All @@ -93,8 +89,7 @@ def _cell_transition(
aggregation_mode: Literal["annotation", "cell"] = "annotation",
key_added: Optional[str] = _constants.CELL_TRANSITION,
**kwargs: Any,
) -> pd.DataFrame:
...
) -> pd.DataFrame: ...

def _cell_transition_online(
self: AnalysisMixinProtocol[K, B],
Expand All @@ -109,8 +104,7 @@ def _cell_transition_online(
other_adata: Optional[str] = None,
batch_size: Optional[int] = None,
normalize: bool = True,
) -> pd.DataFrame:
...
) -> pd.DataFrame: ...

def _annotation_mapping(
self: AnalysisMixinProtocol[K, B],
Expand All @@ -123,8 +117,7 @@ def _annotation_mapping(
other_adata: Optional[str] = None,
scale_by_marginals: bool = True,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame:
...
) -> pd.DataFrame: ...


class AnalysisMixin(Generic[K, B]):
Expand Down Expand Up @@ -672,12 +665,13 @@ def compute_entropy(
forward: bool = True,
key_added: Optional[str] = "conditional_entropy",
batch_size: Optional[int] = None,
c: float = 1e-10,
**kwargs: Any,
) -> Optional[pd.DataFrame]:
"""Compute the conditional entropy per cell.
The conditional entropy reflects the uncertainty of the mapping of a single cell.
Parameters
----------
source
Expand All @@ -691,12 +685,18 @@ def compute_entropy(
Key in :attr:`~anndata.AnnData.obs` where the entropy is stored.
batch_size
Batch size for the computation of the entropy. If :obj:`None`, the entire dataset is used.
c
Constant added to each row of the transport matrix to avoid numerical instability.
kwargs
Kwargs for :func:`~scipy.stats.entropy`.
Returns
-------
:obj:`None` if ``key_added`` is not None. Otherwise, returns a data frame of shape ``(n_cells, 1)`` containing
the conditional entropy per cell.
"""
from scipy import stats

filter_value = source if forward else target
df = pd.DataFrame(
index=self.adata[self.adata.obs[self._policy.key] == filter_value, :].obs_names,
Expand All @@ -716,7 +716,7 @@ def compute_entropy(
split_mass=True,
key_added=None,
)
df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = _compute_conditional_entropy(cond_dists) # type: ignore[arg-type]
df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = stats.entropy(cond_dists + c, **kwargs) # type: ignore[operator]
if key_added is not None:
self.adata.obs[key_added] = df
return df if key_added is None else None
4 changes: 0 additions & 4 deletions src/moscot/base/problems/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,3 @@ def _get_n_cores(n_cores: Optional[int], n_jobs: Optional[int]) -> int:
return multiprocessing.cpu_count() + 1 + n_cores

return n_cores


def _compute_conditional_entropy(p_xy: ArrayLike) -> ArrayLike:
return -np.sum(p_xy * np.log(p_xy / p_xy.sum(axis=0)), axis=0)
12 changes: 8 additions & 4 deletions src/moscot/base/problems/birth_death.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def score_genes_for_marginals( # noqa: D102
proliferation_key: str = "proliferation",
apoptosis_key: str = "apoptosis",
**kwargs: Any,
) -> "BirthDeathProtocol":
...
) -> "BirthDeathProtocol": ...


class BirthDeathProblemProtocol(BirthDeathProtocol, Protocol): # noqa: D101
Expand Down Expand Up @@ -166,6 +165,7 @@ def estimate_marginals(
source: bool,
proliferation_key: Optional[str] = None,
apoptosis_key: Optional[str] = None,
scaling: Optional[float] = None,
**kwargs: Any,
) -> ArrayLike:
"""Estimate the source or target :term:`marginals` based on marker genes, either with the
Expand All @@ -184,6 +184,11 @@ def estimate_marginals(
Key in :attr:`~anndata.AnnData.obs` where proliferation scores are stored.
apoptosis_key
Key in :attr:`~anndata.AnnData.obs` where apoptosis scores are stored.
scaling
A parameter for prior growth rate estimation.
If :obj:`float` is passed, it will be used as a scaling parameter in an exponential kernel
with proliferation and apoptosis scores.
If :obj:`None`, parameters corresponding to the birth and death processes will be used.
kwargs
Keyword arguments for :func:`~moscot.base.problems.birth_death.beta` and
:func:`~moscot.base.problems.birth_death.delta`.
Expand Down Expand Up @@ -217,9 +222,8 @@ def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike], **kwargs: Any)
self.proliferation_key = proliferation_key
self.apoptosis_key = apoptosis_key

if "scaling" in kwargs:
if scaling:
beta_fn = delta_fn = lambda x, *_, **__: x
scaling = kwargs["scaling"]
else:
beta_fn, delta_fn = beta, delta
scaling = 1.0
Expand Down
22 changes: 11 additions & 11 deletions src/moscot/plotting/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_create_col_colors,
_heatmap,
_input_to_adatas,
_plot_temporal,
_plot_scatter,
_sankey,
get_plotting_vars,
)
Expand Down Expand Up @@ -104,12 +104,12 @@ def cell_transition(
row_adata=adata1,
col_adata=adata2,
transition_matrix=data["transition_matrix"],
row_annotation=data["source_groups"]
if isinstance(data["source_groups"], str)
else next(iter(data["source_groups"])),
col_annotation=data["target_groups"]
if isinstance(data["target_groups"], str)
else next(iter(data["target_groups"])),
row_annotation=(
data["source_groups"] if isinstance(data["source_groups"], str) else next(iter(data["source_groups"]))
),
col_annotation=(
data["target_groups"] if isinstance(data["target_groups"], str) else next(iter(data["target_groups"]))
),
row_annotation_label=data["source"] if row_label is None else row_label,
col_annotation_label=data["target"] if col_label is None else col_label,
cont_cmap=cmap,
Expand Down Expand Up @@ -292,9 +292,9 @@ def push(
if data["data"] is not None and data["subset"] is not None and cmap is None:
cmap = _create_col_colors(adata, data["data"], data["subset"])

fig = _plot_temporal(
fig = _plot_scatter(
adata=adata,
temporal_key=data["temporal_key"],
generic_key=data["key"],
key_stored=key,
source=data["source"],
target=data["target"],
Expand Down Expand Up @@ -400,9 +400,9 @@ def pull(
if data["data"] is not None and data["subset"] is not None and cmap is None:
cmap = _create_col_colors(adata, data["data"], data["subset"])

fig = _plot_temporal(
fig = _plot_scatter(
adata=adata,
temporal_key=data["temporal_key"],
generic_key=data["key"],
key_stored=key,
source=data["source"],
target=data["target"],
Expand Down
9 changes: 5 additions & 4 deletions src/moscot/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,9 @@ def _input_to_adatas(
raise ValueError(f"Unable to interpret input of type `{type(inp)}`.")


def _plot_temporal(
def _plot_scatter(
adata: AnnData,
temporal_key: str,
generic_key: str,
key_stored: str,
source: float,
target: float,
Expand Down Expand Up @@ -430,7 +430,8 @@ def _plot_temporal(
titles.extend([f"{name} at time {time_points[i]}" for i in range(1, len(time_points))])
else:
titles = [
f"{categories if categories is not None else 'Cells'} at time {source if push else target} and {name}"
f"{'Push' if push else 'Pull'} {categories if categories is not None else 'cells'} "
+ f"from {source if push else target} to {target if push else source}"
]
for i, ax in enumerate(axs):
# we need to create adata_view because otherwise the view of the adata is copied in the next step i+1
Expand All @@ -446,7 +447,7 @@ def _plot_temporal(
adata_view = adata
else:
tmp = np.full(len(adata), constant_fill_value)
mask = adata.obs[temporal_key] == time_points[i]
mask = adata.obs[generic_key] == time_points[i]

tmp[mask] = adata[mask].obs[key_stored]
if scale:
Expand Down
4 changes: 1 addition & 3 deletions src/moscot/problems/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def handle_joint_attr(
return xy, kwargs

# if this is True we have custom cost matrix or moscot cost - in this case we have a custom cost matrix
if joint_attr.get("tag", None) == "cost_matrix" and (
len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp"
):
if joint_attr.get("tag", None) == "cost_matrix" and (len(joint_attr) == 2 or kwargs.get("attr") == "obsp"):
joint_attr.setdefault("cost", "custom")
joint_attr.setdefault("attr", "obsp")
kwargs["xy_callback"] = "cost-matrix"
Expand Down
8 changes: 3 additions & 5 deletions src/moscot/problems/cross_modality/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@ class CrossModalityTranslationMixinProtocol(AnalysisMixinProtocol[K, B]):
_tgt_attr: Optional[Dict[str, Any]]
batch_key: Optional[str]

def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame:
...
def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ...

def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame:
...
def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ...


class CrossModalityTranslationMixin(AnalysisMixin[K, B]):
Expand Down Expand Up @@ -82,7 +80,7 @@ def _get_features(
attr: Dict[str, Any],
) -> ArrayLike:
data = getattr(adata, attr["attr"])
key = attr.get("key", None)
key = attr.get("key")
return data if key is None else data[key]

if self._src_attr is None:
Expand Down
13 changes: 10 additions & 3 deletions src/moscot/problems/generic/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ def _cell_transition(
self: AnalysisMixinProtocol[K, B],
*args: Any,
**kwargs: Any,
) -> pd.DataFrame:
...
) -> pd.DataFrame: ...


class GenericAnalysisMixin(AnalysisMixin[K, B]):
Expand Down Expand Up @@ -170,7 +169,11 @@ def push(
if TYPE_CHECKING:
assert isinstance(key_added, str)
plot_vars = {
"distribution_key": self.batch_key,
"source": source,
"target": target,
"key": self.batch_key,
"data": data if isinstance(data, str) else None,
"subset": subset,
}
self.adata.obs[key_added] = self._flatten(result, key=self.batch_key)
set_plotting_vars(self.adata, _constants.PUSH, key=key_added, value=plot_vars)
Expand Down Expand Up @@ -233,6 +236,10 @@ def pull(
if key_added is not None:
plot_vars = {
"key": self.batch_key,
"data": data if isinstance(data, str) else None,
"subset": subset,
"source": source,
"target": target,
}
self.adata.obs[key_added] = self._flatten(result, key=self.batch_key)
set_plotting_vars(self.adata, _constants.PULL, key=key_added, value=plot_vars)
Expand Down
21 changes: 7 additions & 14 deletions src/moscot/problems/space/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,30 +46,26 @@ def _subset_spatial( # type:ignore[empty-body]
self: "SpatialAlignmentMixinProtocol[K, B]",
k: K,
spatial_key: str,
) -> ArrayLike:
...
) -> ArrayLike: ...

def _interpolate_scheme( # type:ignore[empty-body]
self: "SpatialAlignmentMixinProtocol[K, B]",
reference: K,
mode: Literal["warp", "affine"],
spatial_key: str,
) -> Tuple[Dict[K, ArrayLike], Optional[Dict[K, Optional[ArrayLike]]]]:
...
) -> Tuple[Dict[K, ArrayLike], Optional[Dict[K, Optional[ArrayLike]]]]: ...

def _cell_transition(
self: AnalysisMixinProtocol[K, B],
*args: Any,
**kwargs: Any,
) -> pd.DataFrame:
...
) -> pd.DataFrame: ...

def _annotation_mapping(
self: AnalysisMixinProtocol[K, B],
*args: Any,
**kwargs: Any,
) -> pd.DataFrame:
...
) -> pd.DataFrame: ...


class SpatialMappingMixinProtocol(AnalysisMixinProtocol[K, B]):
Expand All @@ -84,14 +80,11 @@ class SpatialMappingMixinProtocol(AnalysisMixinProtocol[K, B]):
def _filter_vars(
self: "SpatialMappingMixinProtocol[K, B]",
var_names: Optional[Sequence[str]] = None,
) -> Optional[List[str]]:
...
) -> Optional[List[str]]: ...

def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame:
...
def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ...

def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame:
...
def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ...


class SpatialAlignmentMixin(AnalysisMixin[K, B]):
Expand Down
Loading

0 comments on commit f017c9e

Please sign in to comment.