diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py
index a38fae2e9..ba52e3eb7 100644
--- a/src/moscot/base/problems/_mixins.py
+++ b/src/moscot/base/problems/_mixins.py
@@ -666,6 +666,7 @@ def compute_entropy(
         forward: bool = True,
         key_added: Optional[str] = "conditional_entropy",
         batch_size: Optional[int] = None,
+        c: float = 0.0,
     ) -> Optional[pd.DataFrame]:
         """Compute the conditional entropy per cell.
 
@@ -685,6 +686,8 @@
             Key in :attr:`~anndata.AnnData.obs` where the entropy is stored.
         batch_size
             Batch size for the computation of the entropy. If :obj:`None`, the entire dataset is used.
+        c
+            Constant added to each row of the transport matrix to avoid numerical instability.
 
         Returns
         -------
@@ -710,7 +713,7 @@
                 split_mass=True,
                 key_added=None,
             )
-            df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = _compute_conditional_entropy(cond_dists)  # type: ignore[arg-type]
+            df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = _compute_conditional_entropy(cond_dists + c)  # type: ignore[arg-type, operator]
         if key_added is not None:
             self.adata.obs[key_added] = df
         return df if key_added is None else None
diff --git a/src/moscot/base/problems/_utils.py b/src/moscot/base/problems/_utils.py
index 397248d3e..6b52ec3b2 100644
--- a/src/moscot/base/problems/_utils.py
+++ b/src/moscot/base/problems/_utils.py
@@ -721,5 +721,5 @@ def _get_n_cores(n_cores: Optional[int], n_jobs: Optional[int]) -> int:
     return n_cores
 
 
-def _compute_conditional_entropy(p_xy: ArrayLike) -> ArrayLike:
+def _compute_conditional_entropy(p_xy: ArrayLike, c: float = 0.0) -> ArrayLike:
     return -np.sum(p_xy * np.log(p_xy / p_xy.sum(axis=0)), axis=0)
diff --git a/tests/problems/generic/test_mixins.py b/tests/problems/generic/test_mixins.py
index 65a44f835..062dbcd3d 100644
--- a/tests/problems/generic/test_mixins.py
+++ b/tests/problems/generic/test_mixins.py
@@ -281,8 +281,9 @@ def test_compute_feature_correlation_transcription_factors(
     @pytest.mark.parametrize("forward", [True, False])
     @pytest.mark.parametrize("key_added", [None, "test"])
     @pytest.mark.parametrize("batch_size", [None, 2])
+    @pytest.mark.parametrize("c", [0.0, 0.1])
     def test_compute_entropy_pipeline(
-        self, adata_time: AnnData, forward: bool, key_added: Optional[str], batch_size: int
+        self, adata_time: AnnData, forward: bool, key_added: Optional[str], batch_size: int, c: float
     ):
         rng = np.random.RandomState(42)
         adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
@@ -295,7 +296,9 @@ def test_compute_entropy_pipeline(
         problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
         problem[0, 1]._solution = MockSolverOutput(tmap)
 
-        out = problem.compute_entropy(source=0, target=1, forward=forward, key_added=key_added, batch_size=batch_size)
+        out = problem.compute_entropy(
+            source=0, target=1, forward=forward, key_added=key_added, batch_size=batch_size, c=c
+        )
         if key_added is None:
            assert isinstance(out, pd.DataFrame)
            assert len(out) == n0
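
Note (not part of the patch): a minimal standalone sketch of what the `c` offset guards against, using the same entropy expression as `_compute_conditional_entropy`. The toy `p_xy` matrix and the local helper `conditional_entropy` are made up for illustration only; the patched code applies the offset on the caller side as `cond_dists + c`.

import numpy as np

# Columns are per-cell conditional distributions pulled from the transport
# matrix; one entry is exactly zero, so 0 * log(0) yields NaN without an offset.
p_xy = np.array(
    [
        [0.5, 0.0],
        [0.5, 1.0],
    ]
)


def conditional_entropy(p_xy: np.ndarray, c: float = 0.0) -> np.ndarray:
    # Same formula as _compute_conditional_entropy, with the caller-side
    # constant folded in before taking the logarithm.
    p = p_xy + c
    return -np.sum(p * np.log(p / p.sum(axis=0)), axis=0)


print(conditional_entropy(p_xy))         # second entry is NaN (0 * log(0))
print(conditional_entropy(p_xy, 1e-12))  # both entries finite

A small positive `c` (e.g. 1e-12) leaves the entropies essentially unchanged for well-behaved rows while keeping cells with zero mass in some entries from producing NaN.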