diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ee787023f..098c19035 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: additional_dependencies: [numpy>=1.25.0] files: ^src - repo: https://github.com/psf/black - rev: 23.12.1 + rev: 24.2.0 hooks: - id: black additional_dependencies: [toml] @@ -42,7 +42,7 @@ repos: - id: check-yaml - id: check-toml - repo: https://github.com/asottile/pyupgrade - rev: v3.15.0 + rev: v3.15.1 hooks: - id: pyupgrade args: [--py3-plus, --py38-plus, --keep-runtime-typing] @@ -63,7 +63,7 @@ repos: - id: doc8 - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.1.14 + rev: v0.2.2 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/docs/notebooks b/docs/notebooks index 5e1e67141..a14755a3f 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit 5e1e671418156edd35deaac6de95c7d80809dc56 +Subproject commit a14755a3fa26b94c823c78d6eec9208e18cb3d02 diff --git a/src/moscot/_types.py b/src/moscot/_types.py index 3d463c851..024471886 100644 --- a/src/moscot/_types.py +++ b/src/moscot/_types.py @@ -8,7 +8,7 @@ try: from numpy.typing import DTypeLike, NDArray - ArrayLike = NDArray[np.float_] + ArrayLike = NDArray[np.float64] except (ImportError, TypeError): ArrayLike = np.ndarray # type: ignore[misc] DTypeLike = np.dtype # type: ignore[misc] diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index b17a330e4..9c460eac2 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -26,7 +26,6 @@ from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.problems._utils import ( _check_argument_compatibility_cell_transition, - _compute_conditional_entropy, _correlation_test, _get_df_cell_transition, _order_transition_matrix, @@ -58,23 +57,20 @@ def _apply( return_all: bool = False, scale_by_marginals: bool = False, **kwargs: Any, - ) -> ApplyOutput_t[K]: - ... + ) -> ApplyOutput_t[K]: ... def _interpolate_transport( self: AnalysisMixinProtocol[K, B], path: Sequence[tuple[K, K]], scale_by_marginals: bool = True, - ) -> LinearOperator: - ... + ) -> LinearOperator: ... def _flatten( self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: Optional[str], - ) -> ArrayLike: - ... + ) -> ArrayLike: ... def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: """Push distribution.""" @@ -93,8 +89,7 @@ def _cell_transition( aggregation_mode: Literal["annotation", "cell"] = "annotation", key_added: Optional[str] = _constants.CELL_TRANSITION, **kwargs: Any, - ) -> pd.DataFrame: - ... + ) -> pd.DataFrame: ... def _cell_transition_online( self: AnalysisMixinProtocol[K, B], @@ -109,8 +104,7 @@ def _cell_transition_online( other_adata: Optional[str] = None, batch_size: Optional[int] = None, normalize: bool = True, - ) -> pd.DataFrame: - ... + ) -> pd.DataFrame: ... def _annotation_mapping( self: AnalysisMixinProtocol[K, B], @@ -123,8 +117,7 @@ def _annotation_mapping( other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - ) -> pd.DataFrame: - ... + ) -> pd.DataFrame: ... class AnalysisMixin(Generic[K, B]): @@ -672,12 +665,13 @@ def compute_entropy( forward: bool = True, key_added: Optional[str] = "conditional_entropy", batch_size: Optional[int] = None, + c: float = 1e-10, + **kwargs: Any, ) -> Optional[pd.DataFrame]: """Compute the conditional entropy per cell. The conditional entropy reflects the uncertainty of the mapping of a single cell. - Parameters ---------- source @@ -691,12 +685,18 @@ def compute_entropy( Key in :attr:`~anndata.AnnData.obs` where the entropy is stored. batch_size Batch size for the computation of the entropy. If :obj:`None`, the entire dataset is used. + c + Constant added to each row of the transport matrix to avoid numerical instability. + kwargs + Kwargs for :func:`~scipy.stats.entropy`. Returns ------- :obj:`None` if ``key_added`` is not None. Otherwise, returns a data frame of shape ``(n_cells, 1)`` containing the conditional entropy per cell. """ + from scipy import stats + filter_value = source if forward else target df = pd.DataFrame( index=self.adata[self.adata.obs[self._policy.key] == filter_value, :].obs_names, @@ -716,7 +716,7 @@ def compute_entropy( split_mass=True, key_added=None, ) - df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = _compute_conditional_entropy(cond_dists) # type: ignore[arg-type] + df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = stats.entropy(cond_dists + c, **kwargs) # type: ignore[operator] if key_added is not None: self.adata.obs[key_added] = df return df if key_added is None else None diff --git a/src/moscot/base/problems/_utils.py b/src/moscot/base/problems/_utils.py index 397248d3e..7219b58e3 100644 --- a/src/moscot/base/problems/_utils.py +++ b/src/moscot/base/problems/_utils.py @@ -719,7 +719,3 @@ def _get_n_cores(n_cores: Optional[int], n_jobs: Optional[int]) -> int: return multiprocessing.cpu_count() + 1 + n_cores return n_cores - - -def _compute_conditional_entropy(p_xy: ArrayLike) -> ArrayLike: - return -np.sum(p_xy * np.log(p_xy / p_xy.sum(axis=0)), axis=0) diff --git a/src/moscot/base/problems/birth_death.py b/src/moscot/base/problems/birth_death.py index 2121e2d1d..6f6004e6c 100644 --- a/src/moscot/base/problems/birth_death.py +++ b/src/moscot/base/problems/birth_death.py @@ -38,8 +38,7 @@ def score_genes_for_marginals( # noqa: D102 proliferation_key: str = "proliferation", apoptosis_key: str = "apoptosis", **kwargs: Any, - ) -> "BirthDeathProtocol": - ... + ) -> "BirthDeathProtocol": ... class BirthDeathProblemProtocol(BirthDeathProtocol, Protocol): # noqa: D101 @@ -166,6 +165,7 @@ def estimate_marginals( source: bool, proliferation_key: Optional[str] = None, apoptosis_key: Optional[str] = None, + scaling: Optional[float] = None, **kwargs: Any, ) -> ArrayLike: """Estimate the source or target :term:`marginals` based on marker genes, either with the @@ -184,6 +184,11 @@ def estimate_marginals( Key in :attr:`~anndata.AnnData.obs` where proliferation scores are stored. apoptosis_key Key in :attr:`~anndata.AnnData.obs` where apoptosis scores are stored. + scaling + A parameter for prior growth rate estimation. + If :obj:`float` is passed, it will be used as a scaling parameter in an exponential kernel + with proliferation and apoptosis scores. + If :obj:`None`, parameters corresponding to the birth and death processes will be used. kwargs Keyword arguments for :func:`~moscot.base.problems.birth_death.beta` and :func:`~moscot.base.problems.birth_death.delta`. @@ -217,9 +222,8 @@ def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike], **kwargs: Any) self.proliferation_key = proliferation_key self.apoptosis_key = apoptosis_key - if "scaling" in kwargs: + if scaling: beta_fn = delta_fn = lambda x, *_, **__: x - scaling = kwargs["scaling"] else: beta_fn, delta_fn = beta, delta scaling = 1.0 diff --git a/src/moscot/plotting/_plotting.py b/src/moscot/plotting/_plotting.py index 897040309..363f2cd1a 100644 --- a/src/moscot/plotting/_plotting.py +++ b/src/moscot/plotting/_plotting.py @@ -23,7 +23,7 @@ _create_col_colors, _heatmap, _input_to_adatas, - _plot_temporal, + _plot_scatter, _sankey, get_plotting_vars, ) @@ -104,12 +104,12 @@ def cell_transition( row_adata=adata1, col_adata=adata2, transition_matrix=data["transition_matrix"], - row_annotation=data["source_groups"] - if isinstance(data["source_groups"], str) - else next(iter(data["source_groups"])), - col_annotation=data["target_groups"] - if isinstance(data["target_groups"], str) - else next(iter(data["target_groups"])), + row_annotation=( + data["source_groups"] if isinstance(data["source_groups"], str) else next(iter(data["source_groups"])) + ), + col_annotation=( + data["target_groups"] if isinstance(data["target_groups"], str) else next(iter(data["target_groups"])) + ), row_annotation_label=data["source"] if row_label is None else row_label, col_annotation_label=data["target"] if col_label is None else col_label, cont_cmap=cmap, @@ -292,9 +292,9 @@ def push( if data["data"] is not None and data["subset"] is not None and cmap is None: cmap = _create_col_colors(adata, data["data"], data["subset"]) - fig = _plot_temporal( + fig = _plot_scatter( adata=adata, - temporal_key=data["temporal_key"], + generic_key=data["key"], key_stored=key, source=data["source"], target=data["target"], @@ -400,9 +400,9 @@ def pull( if data["data"] is not None and data["subset"] is not None and cmap is None: cmap = _create_col_colors(adata, data["data"], data["subset"]) - fig = _plot_temporal( + fig = _plot_scatter( adata=adata, - temporal_key=data["temporal_key"], + generic_key=data["key"], key_stored=key, source=data["source"], target=data["target"], diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py index e4aaf7595..358ae5fa2 100644 --- a/src/moscot/plotting/_utils.py +++ b/src/moscot/plotting/_utils.py @@ -376,9 +376,9 @@ def _input_to_adatas( raise ValueError(f"Unable to interpret input of type `{type(inp)}`.") -def _plot_temporal( +def _plot_scatter( adata: AnnData, - temporal_key: str, + generic_key: str, key_stored: str, source: float, target: float, @@ -430,7 +430,8 @@ def _plot_temporal( titles.extend([f"{name} at time {time_points[i]}" for i in range(1, len(time_points))]) else: titles = [ - f"{categories if categories is not None else 'Cells'} at time {source if push else target} and {name}" + f"{'Push' if push else 'Pull'} {categories if categories is not None else 'cells'} " + + f"from {source if push else target} to {target if push else source}" ] for i, ax in enumerate(axs): # we need to create adata_view because otherwise the view of the adata is copied in the next step i+1 @@ -446,7 +447,7 @@ def _plot_temporal( adata_view = adata else: tmp = np.full(len(adata), constant_fill_value) - mask = adata.obs[temporal_key] == time_points[i] + mask = adata.obs[generic_key] == time_points[i] tmp[mask] = adata[mask].obs[key_stored] if scale: diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py index 9d154a0c7..7f7efc9fe 100644 --- a/src/moscot/problems/_utils.py +++ b/src/moscot/problems/_utils.py @@ -36,9 +36,7 @@ def handle_joint_attr( return xy, kwargs # if this is True we have custom cost matrix or moscot cost - in this case we have a custom cost matrix - if joint_attr.get("tag", None) == "cost_matrix" and ( - len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp" - ): + if joint_attr.get("tag", None) == "cost_matrix" and (len(joint_attr) == 2 or kwargs.get("attr") == "obsp"): joint_attr.setdefault("cost", "custom") joint_attr.setdefault("attr", "obsp") kwargs["xy_callback"] = "cost-matrix" diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 5299e1dab..5a3967e1f 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -22,11 +22,9 @@ class CrossModalityTranslationMixinProtocol(AnalysisMixinProtocol[K, B]): _tgt_attr: Optional[Dict[str, Any]] batch_key: Optional[str] - def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: - ... + def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... - def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: - ... + def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... class CrossModalityTranslationMixin(AnalysisMixin[K, B]): @@ -82,7 +80,7 @@ def _get_features( attr: Dict[str, Any], ) -> ArrayLike: data = getattr(adata, attr["attr"]) - key = attr.get("key", None) + key = attr.get("key") return data if key is None else data[key] if self._src_attr is None: diff --git a/src/moscot/problems/generic/_mixins.py b/src/moscot/problems/generic/_mixins.py index 4b20d7537..6580c2ad5 100644 --- a/src/moscot/problems/generic/_mixins.py +++ b/src/moscot/problems/generic/_mixins.py @@ -24,8 +24,7 @@ def _cell_transition( self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any, - ) -> pd.DataFrame: - ... + ) -> pd.DataFrame: ... class GenericAnalysisMixin(AnalysisMixin[K, B]): @@ -170,7 +169,11 @@ def push( if TYPE_CHECKING: assert isinstance(key_added, str) plot_vars = { - "distribution_key": self.batch_key, + "source": source, + "target": target, + "key": self.batch_key, + "data": data if isinstance(data, str) else None, + "subset": subset, } self.adata.obs[key_added] = self._flatten(result, key=self.batch_key) set_plotting_vars(self.adata, _constants.PUSH, key=key_added, value=plot_vars) @@ -233,6 +236,10 @@ def pull( if key_added is not None: plot_vars = { "key": self.batch_key, + "data": data if isinstance(data, str) else None, + "subset": subset, + "source": source, + "target": target, } self.adata.obs[key_added] = self._flatten(result, key=self.batch_key) set_plotting_vars(self.adata, _constants.PULL, key=key_added, value=plot_vars) diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 667a53962..1a94b5cda 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -46,30 +46,26 @@ def _subset_spatial( # type:ignore[empty-body] self: "SpatialAlignmentMixinProtocol[K, B]", k: K, spatial_key: str, - ) -> ArrayLike: - ... + ) -> ArrayLike: ... def _interpolate_scheme( # type:ignore[empty-body] self: "SpatialAlignmentMixinProtocol[K, B]", reference: K, mode: Literal["warp", "affine"], spatial_key: str, - ) -> Tuple[Dict[K, ArrayLike], Optional[Dict[K, Optional[ArrayLike]]]]: - ... + ) -> Tuple[Dict[K, ArrayLike], Optional[Dict[K, Optional[ArrayLike]]]]: ... def _cell_transition( self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any, - ) -> pd.DataFrame: - ... + ) -> pd.DataFrame: ... def _annotation_mapping( self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any, - ) -> pd.DataFrame: - ... + ) -> pd.DataFrame: ... class SpatialMappingMixinProtocol(AnalysisMixinProtocol[K, B]): @@ -84,14 +80,11 @@ class SpatialMappingMixinProtocol(AnalysisMixinProtocol[K, B]): def _filter_vars( self: "SpatialMappingMixinProtocol[K, B]", var_names: Optional[Sequence[str]] = None, - ) -> Optional[List[str]]: - ... + ) -> Optional[List[str]]: ... - def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: - ... + def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... - def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: - ... + def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... class SpatialAlignmentMixin(AnalysisMixin[K, B]): diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index c2e940e79..8b9401a57 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -50,28 +50,23 @@ def cell_transition( # noqa: D102 batch_size: Optional[int] = None, normalize: bool = True, key_added: Optional[str] = _constants.CELL_TRANSITION, - ) -> pd.DataFrame: - ... + ) -> pd.DataFrame: ... - def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: - ... + def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: ... - def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: - ... + def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: ... def _cell_transition( self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any, - ) -> pd.DataFrame: - ... + ) -> pd.DataFrame: ... def _annotation_mapping( self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any, - ) -> pd.DataFrame: - ... + ) -> pd.DataFrame: ... def _sample_from_tmap( self: TemporalMixinProtocol[K, B], @@ -84,8 +79,7 @@ def _sample_from_tmap( account_for_unbalancedness: bool = False, interpolation_parameter: Optional[float] = None, seed: Optional[int] = None, - ) -> tuple[list[Any], list[ArrayLike]]: - ... + ) -> tuple[list[Any], list[ArrayLike]]: ... def _compute_wasserstein_distance( self: TemporalMixinProtocol[K, B], @@ -95,8 +89,7 @@ def _compute_wasserstein_distance( b: Optional[ArrayLike] = None, backend: Literal["ott"] = "ott", **kwargs: Any, - ) -> float: - ... + ) -> float: ... def _interpolate_gex_with_ot( self: TemporalMixinProtocol[K, B], @@ -109,8 +102,7 @@ def _interpolate_gex_with_ot( account_for_unbalancedness: bool = True, batch_size: int = 256, seed: Optional[int] = None, - ) -> ArrayLike: - ... + ) -> ArrayLike: ... def _get_data( self: TemporalMixinProtocol[K, B], @@ -120,8 +112,7 @@ def _get_data( posterior_marginals: bool = True, *, only_start: bool = False, - ) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]: - ... + ) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]: ... def _interpolate_gex_randomly( self: TemporalMixinProtocol[K, B], @@ -131,8 +122,7 @@ def _interpolate_gex_randomly( interpolation_parameter: float, growth_rates: Optional[ArrayLike] = None, seed: Optional[int] = None, - ) -> ArrayLike: - ... + ) -> ArrayLike: ... def _plot_temporal( self: TemporalMixinProtocol[K, B], @@ -145,17 +135,14 @@ def _plot_temporal( fill_value: float = 0.0, save: Optional[Union[str, pathlib.Path]] = None, **kwargs: Any, - ) -> None: - ... + ) -> None: ... @staticmethod def _get_interp_param( source: K, intermediate: K, target: K, interpolation_parameter: Optional[float] = None - ) -> float: - ... + ) -> float: ... - def __iter__(self) -> Iterator[tuple[K, K]]: - ... + def __iter__(self) -> Iterator[tuple[K, K]]: ... class TemporalMixin(AnalysisMixin[K, B]): @@ -475,7 +462,7 @@ def push( plot_vars = { "source": source, "target": target, - "temporal_key": self.temporal_key, + "key": self.temporal_key, "data": data if isinstance(data, str) else None, "subset": subset, } @@ -539,7 +526,7 @@ def pull( if key_added is not None: plot_vars = { - "temporal_key": self.temporal_key, + "key": self.temporal_key, "data": data if isinstance(data, str) else None, "subset": subset, "source": source, diff --git a/src/moscot/utils/_data/mouse_proliferation.txt b/src/moscot/utils/_data/mouse_proliferation.txt index 6376e6a54..1a5fdb3d8 100644 --- a/src/moscot/utils/_data/mouse_proliferation.txt +++ b/src/moscot/utils/_data/mouse_proliferation.txt @@ -76,7 +76,6 @@ Tpx2 Aurka Anln Chaf1b -Hjurp Tacc3 Mcm5 Anp32e diff --git a/tests/plotting/conftest.py b/tests/plotting/conftest.py index 017992e29..50ba864bf 100644 --- a/tests/plotting/conftest.py +++ b/tests/plotting/conftest.py @@ -41,7 +41,7 @@ def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData: @pytest.fixture() def adata_pl_push(adata_time: AnnData) -> AnnData: rng = np.random.RandomState(0) - plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} + plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"] adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category") set_plotting_vars(adata_time, _constants.PUSH, key=_constants.PUSH, value=plot_vars) @@ -60,7 +60,7 @@ def adata_pl_push(adata_time: AnnData) -> AnnData: @pytest.fixture() def adata_pl_pull(adata_time: AnnData) -> AnnData: rng = np.random.RandomState(0) - plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} + plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"] adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category") set_plotting_vars(adata_time, _constants.PULL, key=_constants.PULL, value=plot_vars) diff --git a/tests/problems/generic/test_mixins.py b/tests/problems/generic/test_mixins.py index 65a44f835..85c702c63 100644 --- a/tests/problems/generic/test_mixins.py +++ b/tests/problems/generic/test_mixins.py @@ -281,8 +281,9 @@ def test_compute_feature_correlation_transcription_factors( @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("key_added", [None, "test"]) @pytest.mark.parametrize("batch_size", [None, 2]) + @pytest.mark.parametrize("c", [0.0, 0.1]) def test_compute_entropy_pipeline( - self, adata_time: AnnData, forward: bool, key_added: Optional[str], batch_size: int + self, adata_time: AnnData, forward: bool, key_added: Optional[str], batch_size: int, c: float ): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() @@ -295,7 +296,9 @@ def test_compute_entropy_pipeline( problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential") problem[0, 1]._solution = MockSolverOutput(tmap) - out = problem.compute_entropy(source=0, target=1, forward=forward, key_added=key_added, batch_size=batch_size) + out = problem.compute_entropy( + source=0, target=1, forward=forward, key_added=key_added, batch_size=batch_size, c=c + ) if key_added is None: assert isinstance(out, pd.DataFrame) assert len(out) == n0 @@ -312,18 +315,7 @@ def test_compute_entropy_pipeline( @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("batch_size", [None, 2, 15]) def test_compute_entropy_regression(self, adata_time: AnnData, forward: bool, batch_size: Optional[int]): - def gt_conditional_entropy(matrix): - px = np.sum(matrix, axis=1) - - # Initialize conditional entropy vector - h_y_given_x = np.zeros((len(px), 1)) - - # Compute conditional entropy for each value of x - for i in range(matrix.shape[0]): - for j in range(matrix.shape[1]): - h_y_given_x[i] -= matrix[i, j] * np.log(matrix[i, j] / px[i]) - - return h_y_given_x + from scipy import stats rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() @@ -337,7 +329,8 @@ def gt_conditional_entropy(matrix): problem[0, 1]._solution = MockSolverOutput(tmap) moscot_out = problem.compute_entropy(source=0, target=1, forward=forward, batch_size=batch_size, key_added=None) - gt_out = gt_conditional_entropy(tmap) if forward else gt_conditional_entropy(tmap.T) + gt_out = stats.entropy(tmap + 1e-10, axis=1 if forward else 0) + gt_out = np.expand_dims(gt_out, axis=1) if forward else np.expand_dims(gt_out, axis=0).T np.testing.assert_allclose( np.array(moscot_out, dtype=float), np.array(gt_out, dtype=float), rtol=RTOL, atol=ATOL