From c1d61185d61a4829adb6d48ddd9a92a74ea1df05 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Sat, 20 Jan 2024 14:32:31 +0100 Subject: [PATCH 1/8] passing batch_size fix --- src/moscot/base/problems/_mixins.py | 8 ++++---- src/moscot/problems/cross_modality/_mixins.py | 2 ++ src/moscot/problems/space/_mixins.py | 4 ++++ src/moscot/problems/time/_mixins.py | 2 ++ tests/problems/cross_modality/test_mixins.py | 2 +- tests/problems/space/test_mixins.py | 2 ++ tests/problems/time/test_mixins.py | 2 +- 7 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index cd09cd9b9..70e26f461 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -342,7 +342,7 @@ def _annotation_mapping( out_len = self.solutions[(source, target)].shape[1] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch: ArrayLike = self.push( + tm_batch: ArrayLike = self.pull( source=source, target=target, data=None, @@ -353,7 +353,7 @@ def _annotation_mapping( split_mass=True, key_added=None, ) - v = np.array(tm_batch.argmax(1)) + v = (tm_batch.argmax(0)) out.extend(source_df[annotation_label][v[i]] for i in range(len(v))) else: @@ -366,7 +366,7 @@ def _annotation_mapping( out_len = self.solutions[(source, target)].shape[0] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch: ArrayLike = self.pull( # type: ignore[no-redef] + tm_batch: ArrayLike = self.push( # type: ignore[no-redef] source=source, target=target, data=None, @@ -377,7 +377,7 @@ def _annotation_mapping( split_mass=True, key_added=None, ) - v = np.array(tm_batch.argmax(1)) + v = (tm_batch.argmax(0)) out.extend(target_df[annotation_label][v[i]] for i in range(len(v))) categories = pd.Categorical(out) return pd.DataFrame(categories, columns=[annotation_label]) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index ce58f84a4..0e1ead762 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -195,6 +195,7 @@ def annotation_mapping( # type: ignore[misc] source: str = "src", target: str = "tgt", cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), + **kwargs: Mapping[str, Any] ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -232,6 +233,7 @@ def annotation_mapping( # type: ignore[misc] forward=forward, other_adata=self.adata_tgt, cell_transition_kwargs=cell_transition_kwargs, + **kwargs ) @property diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 0aa84c326..ac601a51e 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -292,6 +292,7 @@ def annotation_mapping( # type: ignore[misc] source: str = "src", target: str = "tgt", cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), + **kwargs: Mapping[str, Any] ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -328,6 +329,7 @@ def annotation_mapping( # type: ignore[misc] key=self.batch_key, forward=forward, cell_transition_kwargs=cell_transition_kwargs, + **kwargs ) @property @@ -627,6 +629,7 @@ def annotation_mapping( # type: ignore[misc] target: Union[K, str] = "tgt", forward: bool = False, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), + **kwargs: Mapping[str, Any] ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -664,6 +667,7 @@ def annotation_mapping( # type: ignore[misc] key=self.batch_key, other_adata=self.adata_sc, cell_transition_kwargs=cell_transition_kwargs, + **kwargs ) @property diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 597757901..fdc6c92b8 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -248,6 +248,7 @@ def annotation_mapping( source: K, target: K, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), + **kwargs: Mapping[str, Any] ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -285,6 +286,7 @@ def annotation_mapping( forward=forward, other_adata=None, cell_transition_kwargs=cell_transition_kwargs, + **kwargs ) def sankey( diff --git a/tests/problems/cross_modality/test_mixins.py b/tests/problems/cross_modality/test_mixins.py index 079e153a4..940ee79e4 100644 --- a/tests/problems/cross_modality/test_mixins.py +++ b/tests/problems/cross_modality/test_mixins.py @@ -122,7 +122,7 @@ def test_annotation_mapping( tp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation), overwrite=True) annotation_label = "celltype1" if forward else "celltype2" result = tp.annotation_mapping( - mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source="src", target="tgt" + mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source="src", target="tgt", batch_size=7, ) if forward: expected_result = ( diff --git a/tests/problems/space/test_mixins.py b/tests/problems/space/test_mixins.py index a6b70031c..d81eaab20 100644 --- a/tests/problems/space/test_mixins.py +++ b/tests/problems/space/test_mixins.py @@ -110,6 +110,7 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo source="0", target="1", forward=forward, + batch_size=7, ) if forward: expected_result = ( @@ -221,6 +222,7 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo annotation_label=annotation_label, source="src", forward=forward, + batch_size=7, ) if not forward: expected_result = adataref.uns["expected_max1"] if mapping_mode == "max" else adataref.uns["expected_sum1"] diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index cb2d9ea2a..f8eaa604f 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -62,7 +62,7 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo problem[problem_keys]._solution = MockSolverOutput(gt_tm_annotation) annotation_label = "celltype1" if forward else "celltype2" result = problem.annotation_mapping( - mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source=0, target=1 + mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source=0, target=1, batch_size=7, ) if forward: expected_result = ( From f94d838a3ed01140965c05240bde157d12e9eb62 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 Jan 2024 13:33:35 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 4 ++-- src/moscot/problems/cross_modality/_mixins.py | 4 ++-- src/moscot/problems/space/_mixins.py | 8 ++++---- src/moscot/problems/time/_mixins.py | 4 ++-- tests/problems/cross_modality/test_mixins.py | 7 ++++++- tests/problems/time/test_mixins.py | 7 ++++++- 6 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 70e26f461..56267dba1 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -353,7 +353,7 @@ def _annotation_mapping( split_mass=True, key_added=None, ) - v = (tm_batch.argmax(0)) + v = tm_batch.argmax(0) out.extend(source_df[annotation_label][v[i]] for i in range(len(v))) else: @@ -377,7 +377,7 @@ def _annotation_mapping( split_mass=True, key_added=None, ) - v = (tm_batch.argmax(0)) + v = tm_batch.argmax(0) out.extend(target_df[annotation_label][v[i]] for i in range(len(v))) categories = pd.Categorical(out) return pd.DataFrame(categories, columns=[annotation_label]) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 0e1ead762..23e763de6 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -195,7 +195,7 @@ def annotation_mapping( # type: ignore[misc] source: str = "src", target: str = "tgt", cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **kwargs: Mapping[str, Any] + **kwargs: Mapping[str, Any], ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -233,7 +233,7 @@ def annotation_mapping( # type: ignore[misc] forward=forward, other_adata=self.adata_tgt, cell_transition_kwargs=cell_transition_kwargs, - **kwargs + **kwargs, ) @property diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index ac601a51e..9c53eb6fd 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -292,7 +292,7 @@ def annotation_mapping( # type: ignore[misc] source: str = "src", target: str = "tgt", cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **kwargs: Mapping[str, Any] + **kwargs: Mapping[str, Any], ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -329,7 +329,7 @@ def annotation_mapping( # type: ignore[misc] key=self.batch_key, forward=forward, cell_transition_kwargs=cell_transition_kwargs, - **kwargs + **kwargs, ) @property @@ -629,7 +629,7 @@ def annotation_mapping( # type: ignore[misc] target: Union[K, str] = "tgt", forward: bool = False, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **kwargs: Mapping[str, Any] + **kwargs: Mapping[str, Any], ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -667,7 +667,7 @@ def annotation_mapping( # type: ignore[misc] key=self.batch_key, other_adata=self.adata_sc, cell_transition_kwargs=cell_transition_kwargs, - **kwargs + **kwargs, ) @property diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index fdc6c92b8..a631cab78 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -248,7 +248,7 @@ def annotation_mapping( source: K, target: K, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **kwargs: Mapping[str, Any] + **kwargs: Mapping[str, Any], ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -286,7 +286,7 @@ def annotation_mapping( forward=forward, other_adata=None, cell_transition_kwargs=cell_transition_kwargs, - **kwargs + **kwargs, ) def sankey( diff --git a/tests/problems/cross_modality/test_mixins.py b/tests/problems/cross_modality/test_mixins.py index 940ee79e4..97b498f1e 100644 --- a/tests/problems/cross_modality/test_mixins.py +++ b/tests/problems/cross_modality/test_mixins.py @@ -122,7 +122,12 @@ def test_annotation_mapping( tp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation), overwrite=True) annotation_label = "celltype1" if forward else "celltype2" result = tp.annotation_mapping( - mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source="src", target="tgt", batch_size=7, + mapping_mode=mapping_mode, + annotation_label=annotation_label, + forward=forward, + source="src", + target="tgt", + batch_size=7, ) if forward: expected_result = ( diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index f8eaa604f..a6860a5dc 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -62,7 +62,12 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo problem[problem_keys]._solution = MockSolverOutput(gt_tm_annotation) annotation_label = "celltype1" if forward else "celltype2" result = problem.annotation_mapping( - mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source=0, target=1, batch_size=7, + mapping_mode=mapping_mode, + annotation_label=annotation_label, + forward=forward, + source=0, + target=1, + batch_size=7, ) if forward: expected_result = ( From e5a1f83de866cc04d339bf240f97565abae39420 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Sat, 20 Jan 2024 16:10:34 +0100 Subject: [PATCH 3/8] index not a scalar fix --- src/moscot/base/problems/_mixins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 56267dba1..31efd57c6 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -353,7 +353,7 @@ def _annotation_mapping( split_mass=True, key_added=None, ) - v = tm_batch.argmax(0) + v = np.array(tm_batch.argmax(0)) out.extend(source_df[annotation_label][v[i]] for i in range(len(v))) else: @@ -377,7 +377,7 @@ def _annotation_mapping( split_mass=True, key_added=None, ) - v = tm_batch.argmax(0) + v = np.array(tm_batch.argmax(0)) out.extend(target_df[annotation_label][v[i]] for i in range(len(v))) categories = pd.Categorical(out) return pd.DataFrame(categories, columns=[annotation_label]) From 98bf16a719eb2b72c6548dfcb419200bee075a01 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Sat, 20 Jan 2024 19:18:57 +0100 Subject: [PATCH 4/8] expose batch_size, parametrize it in tests --- src/moscot/base/problems/_mixins.py | 1 + src/moscot/problems/cross_modality/_mixins.py | 5 +++++ src/moscot/problems/space/_mixins.py | 10 ++++++++++ src/moscot/problems/time/_mixins.py | 5 +++++ tests/problems/cross_modality/test_mixins.py | 5 +++-- tests/problems/space/test_mixins.py | 10 ++++++---- tests/problems/time/test_mixins.py | 5 +++-- 7 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 31efd57c6..dc76889fd 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -320,6 +320,7 @@ def _annotation_mapping( cell_transition_kwargs.setdefault("target", target) cell_transition_kwargs.setdefault("other_adata", other_adata) cell_transition_kwargs.setdefault("forward", not forward) + cell_transition_kwargs.setdefault("batch_size", batch_size) if forward: cell_transition_kwargs.setdefault("source_groups", annotation_label) cell_transition_kwargs.setdefault("target_groups", None) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 23e763de6..e07260e74 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -194,6 +194,7 @@ def annotation_mapping( # type: ignore[misc] forward: bool, source: str = "src", target: str = "tgt", + batch_size: int | None = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), **kwargs: Mapping[str, Any], ) -> pd.DataFrame: @@ -217,6 +218,9 @@ def annotation_mapping( # type: ignore[misc] Key identifying the source distribution. target Key identifying the target distribution. + batch_size + Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. + Larger value will require more memory. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -232,6 +236,7 @@ def annotation_mapping( # type: ignore[misc] key=self.batch_key, forward=forward, other_adata=self.adata_tgt, + batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, **kwargs, ) diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 9c53eb6fd..ad8cf9186 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -291,6 +291,7 @@ def annotation_mapping( # type: ignore[misc] forward: bool, source: str = "src", target: str = "tgt", + batch_size: int | None = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), **kwargs: Mapping[str, Any], ) -> pd.DataFrame: @@ -314,6 +315,9 @@ def annotation_mapping( # type: ignore[misc] Key identifying the source distribution. target Key identifying the target distribution. + batch_size + Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. + Larger value will require more memory. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -328,6 +332,7 @@ def annotation_mapping( # type: ignore[misc] target=target, key=self.batch_key, forward=forward, + batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, **kwargs, ) @@ -628,6 +633,7 @@ def annotation_mapping( # type: ignore[misc] source: K, target: Union[K, str] = "tgt", forward: bool = False, + batch_size: int | None = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), **kwargs: Mapping[str, Any], ) -> pd.DataFrame: @@ -651,6 +657,9 @@ def annotation_mapping( # type: ignore[misc] Key identifying the source distribution. target Key identifying the target distribution. + batch_size + Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. + Larger value will require more memory. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -666,6 +675,7 @@ def annotation_mapping( # type: ignore[misc] forward=forward, key=self.batch_key, other_adata=self.adata_sc, + batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, **kwargs, ) diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index a631cab78..cc47df43a 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -247,6 +247,7 @@ def annotation_mapping( forward: bool, source: K, target: K, + batch_size: int | None = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), **kwargs: Mapping[str, Any], ) -> pd.DataFrame: @@ -270,6 +271,9 @@ def annotation_mapping( Key identifying the source distribution. target Key identifying the target distribution. + batch_size + Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. + Larger value will require more memory. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -285,6 +289,7 @@ def annotation_mapping( key=self._temporal_key, forward=forward, other_adata=None, + batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, **kwargs, ) diff --git a/tests/problems/cross_modality/test_mixins.py b/tests/problems/cross_modality/test_mixins.py index 97b498f1e..781c7c8d1 100644 --- a/tests/problems/cross_modality/test_mixins.py +++ b/tests/problems/cross_modality/test_mixins.py @@ -110,9 +110,10 @@ def test_cell_transition_pipeline( @pytest.mark.fast() @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("batch_size", [3, 7, None]) @pytest.mark.parametrize("problem_kind", ["cross_modality"]) def test_annotation_mapping( - self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, gt_tm_annotation + self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, batch_size, gt_tm_annotation ): adata_src, adata_tgt = adata_anno tp = TranslationProblem(adata_src, adata_tgt) @@ -127,7 +128,7 @@ def test_annotation_mapping( forward=forward, source="src", target="tgt", - batch_size=7, + batch_size=batch_size, ) if forward: expected_result = ( diff --git a/tests/problems/space/test_mixins.py b/tests/problems/space/test_mixins.py index d81eaab20..c105683a9 100644 --- a/tests/problems/space/test_mixins.py +++ b/tests/problems/space/test_mixins.py @@ -96,8 +96,9 @@ def test_cell_transition_pipeline(self, adata_space_rotate: AnnData, forward: bo @pytest.mark.fast() @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("batch_size", [3, 7, None]) @pytest.mark.parametrize("problem_kind", ["alignment"]) - def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, batch_size, gt_tm_annotation): ap = AlignmentProblem(adata=adata_anno) ap = ap.prepare(batch_key="batch", joint_attr={"attr": "X"}) problem_keys = ("0", "1") @@ -110,7 +111,7 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo source="0", target="1", forward=forward, - batch_size=7, + batch_size=batch_size, ) if forward: expected_result = ( @@ -208,8 +209,9 @@ def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, n @pytest.mark.fast() @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("batch_size", [3, 7, None]) @pytest.mark.parametrize("problem_kind", ["mapping"]) - def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, batch_size, gt_tm_annotation): adataref, adatasp = adata_anno mp = MappingProblem(adataref, adatasp) mp = mp.prepare(sc_attr={"attr": "obsm", "key": "X_pca"}, joint_attr={"attr": "X"}) @@ -222,7 +224,7 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo annotation_label=annotation_label, source="src", forward=forward, - batch_size=7, + batch_size=batch_size, ) if not forward: expected_result = adataref.uns["expected_max1"] if mapping_mode == "max" else adataref.uns["expected_sum1"] diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index a6860a5dc..e5c9e36f8 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -53,8 +53,9 @@ def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward @pytest.mark.fast() @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("batch_size", [3, 7, None]) @pytest.mark.parametrize("problem_kind", ["temporal"]) - def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, batch_size, gt_tm_annotation): problem = TemporalProblem(adata_anno) problem_keys = (0, 1) problem = problem.prepare(time_key="day", joint_attr="X_pca") @@ -67,7 +68,7 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo forward=forward, source=0, target=1, - batch_size=7, + batch_size=batch_size, ) if forward: expected_result = ( From 3da07cff861197f127d51b5f89f380363428c7cf Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Sat, 20 Jan 2024 19:23:23 +0100 Subject: [PATCH 5/8] Optional instead of | None --- src/moscot/problems/cross_modality/_mixins.py | 2 +- src/moscot/problems/space/_mixins.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index e07260e74..23a066a14 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -194,7 +194,7 @@ def annotation_mapping( # type: ignore[misc] forward: bool, source: str = "src", target: str = "tgt", - batch_size: int | None = None, + batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), **kwargs: Mapping[str, Any], ) -> pd.DataFrame: diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index ad8cf9186..dd0e8a28b 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -291,7 +291,7 @@ def annotation_mapping( # type: ignore[misc] forward: bool, source: str = "src", target: str = "tgt", - batch_size: int | None = None, + batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), **kwargs: Mapping[str, Any], ) -> pd.DataFrame: @@ -633,7 +633,7 @@ def annotation_mapping( # type: ignore[misc] source: K, target: Union[K, str] = "tgt", forward: bool = False, - batch_size: int | None = None, + batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), **kwargs: Mapping[str, Any], ) -> pd.DataFrame: From 281bae6a664fe6892c837aee011185b7eaa9a4a1 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Mon, 22 Jan 2024 12:35:39 +0100 Subject: [PATCH 6/8] if batch_size is None --- src/moscot/problems/cross_modality/_mixins.py | 1 + src/moscot/problems/space/_mixins.py | 2 ++ src/moscot/problems/time/_mixins.py | 3 ++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 23a066a14..1ad0436b4 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -220,6 +220,7 @@ def annotation_mapping( # type: ignore[misc] Key identifying the target distribution. batch_size Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. + If :obj:`None` the entire cost matrix will be materialized. Larger value will require more memory. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index dd0e8a28b..80f60d53b 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -317,6 +317,7 @@ def annotation_mapping( # type: ignore[misc] Key identifying the target distribution. batch_size Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. + If :obj:`None`, the entire cost matrix will be materialized. Larger value will require more memory. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -659,6 +660,7 @@ def annotation_mapping( # type: ignore[misc] Key identifying the target distribution. batch_size Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. + If :obj:`None`, the entire cost matrix will be materialized. Larger value will require more memory. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index cc47df43a..117bec531 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -247,7 +247,7 @@ def annotation_mapping( forward: bool, source: K, target: K, - batch_size: int | None = None, + batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), **kwargs: Mapping[str, Any], ) -> pd.DataFrame: @@ -273,6 +273,7 @@ def annotation_mapping( Key identifying the target distribution. batch_size Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. + If :obj:`None`, the entire cost matrix will be materialized. Larger value will require more memory. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. From 88e321a29d36367f1a83ae0b249f835e427c3f19 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Mon, 22 Jan 2024 12:36:26 +0100 Subject: [PATCH 7/8] if batch_size is None --- src/moscot/problems/cross_modality/_mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 1ad0436b4..86678ea49 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -220,7 +220,7 @@ def annotation_mapping( # type: ignore[misc] Key identifying the target distribution. batch_size Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. - If :obj:`None` the entire cost matrix will be materialized. + If :obj:`None`, the entire cost matrix will be materialized. Larger value will require more memory. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. From cd0e30111c8cc9a2cf49b2273114ba12ed932452 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Tue, 23 Jan 2024 10:45:43 +0100 Subject: [PATCH 8/8] line order --- src/moscot/problems/cross_modality/_mixins.py | 2 +- src/moscot/problems/space/_mixins.py | 4 ++-- src/moscot/problems/time/_mixins.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 86678ea49..5299e1dab 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -220,8 +220,8 @@ def annotation_mapping( # type: ignore[misc] Key identifying the target distribution. batch_size Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. - If :obj:`None`, the entire cost matrix will be materialized. Larger value will require more memory. + If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 80f60d53b..d351eb821 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -317,8 +317,8 @@ def annotation_mapping( # type: ignore[misc] Key identifying the target distribution. batch_size Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. - If :obj:`None`, the entire cost matrix will be materialized. Larger value will require more memory. + If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -660,8 +660,8 @@ def annotation_mapping( # type: ignore[misc] Key identifying the target distribution. batch_size Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. - If :obj:`None`, the entire cost matrix will be materialized. Larger value will require more memory. + If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 117bec531..c2e940e79 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -273,8 +273,8 @@ def annotation_mapping( Key identifying the target distribution. batch_size Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. - If :obj:`None`, the entire cost matrix will be materialized. Larger value will require more memory. + If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``.