From 6b6b4acb6395380a94f29a5c6e28ddb8d85a936a Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Mon, 29 Jan 2024 16:00:50 +0100 Subject: [PATCH 1/6] mtp push/pull for all problems --- src/moscot/plotting/_plotting.py | 6 ++++-- src/moscot/plotting/_utils.py | 4 ++++ src/moscot/problems/generic/_mixins.py | 10 +++++++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/moscot/plotting/_plotting.py b/src/moscot/plotting/_plotting.py index 897040309..115fbb919 100644 --- a/src/moscot/plotting/_plotting.py +++ b/src/moscot/plotting/_plotting.py @@ -294,7 +294,8 @@ def push( fig = _plot_temporal( adata=adata, - temporal_key=data["temporal_key"], + temporal_key=data["temporal_key"] if "temporal_key" in data else None, + generic_key=data["key"] if "key" in data else None, key_stored=key, source=data["source"], target=data["target"], @@ -402,7 +403,8 @@ def pull( fig = _plot_temporal( adata=adata, - temporal_key=data["temporal_key"], + temporal_key=data["temporal_key"] if "temporal_key" in data else None, + generic_key=data["key"] if "key" in data else None, key_stored=key, source=data["source"], target=data["target"], diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py index e4aaf7595..1fd188f69 100644 --- a/src/moscot/plotting/_utils.py +++ b/src/moscot/plotting/_utils.py @@ -379,6 +379,7 @@ def _input_to_adatas( def _plot_temporal( adata: AnnData, temporal_key: str, + generic_key: str, key_stored: str, source: float, target: float, @@ -431,6 +432,9 @@ def _plot_temporal( else: titles = [ f"{categories if categories is not None else 'Cells'} at time {source if push else target} and {name}" + ] if generic_key is None else [ + f"{'Push' if push else 'Pull'} {categories if categories is not None else 'cells'} " + + f"from {source if push else target} to {target if push else source}" ] for i, ax in enumerate(axs): # we need to create adata_view because otherwise the view of the adata is copied in the next step i+1 diff --git a/src/moscot/problems/generic/_mixins.py b/src/moscot/problems/generic/_mixins.py index 4b20d7537..a9070e2ac 100644 --- a/src/moscot/problems/generic/_mixins.py +++ b/src/moscot/problems/generic/_mixins.py @@ -170,7 +170,11 @@ def push( if TYPE_CHECKING: assert isinstance(key_added, str) plot_vars = { - "distribution_key": self.batch_key, + "source": source, + "target": target, + "key": self.batch_key, + "data": data if isinstance(data, str) else None, + "subset": subset, } self.adata.obs[key_added] = self._flatten(result, key=self.batch_key) set_plotting_vars(self.adata, _constants.PUSH, key=key_added, value=plot_vars) @@ -233,6 +237,10 @@ def pull( if key_added is not None: plot_vars = { "key": self.batch_key, + "data": data if isinstance(data, str) else None, + "subset": subset, + "source": source, + "target": target, } self.adata.obs[key_added] = self._flatten(result, key=self.batch_key) set_plotting_vars(self.adata, _constants.PULL, key=key_added, value=plot_vars) From 88c06f50b6fce05c853f35104712dd96cee9541d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Jan 2024 15:03:02 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/plotting/_utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py index 1fd188f69..7e9a150b8 100644 --- a/src/moscot/plotting/_utils.py +++ b/src/moscot/plotting/_utils.py @@ -430,12 +430,14 @@ def _plot_temporal( titles = [f"{categories if categories is not None else 'Cells'} at time {source if push else target}"] titles.extend([f"{name} at time {time_points[i]}" for i in range(1, len(time_points))]) else: - titles = [ - f"{categories if categories is not None else 'Cells'} at time {source if push else target} and {name}" - ] if generic_key is None else [ - f"{'Push' if push else 'Pull'} {categories if categories is not None else 'cells'} " + - f"from {source if push else target} to {target if push else source}" - ] + titles = ( + [f"{categories if categories is not None else 'Cells'} at time {source if push else target} and {name}"] + if generic_key is None + else [ + f"{'Push' if push else 'Pull'} {categories if categories is not None else 'cells'} " + + f"from {source if push else target} to {target if push else source}" + ] + ) for i, ax in enumerate(axs): # we need to create adata_view because otherwise the view of the adata is copied in the next step i+1 with RandomKeys(adata, n=2, where="obs") as keys: From da532d9c473e9d67ecf1009dbb177a4dbd8381c5 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Tue, 6 Feb 2024 08:45:12 +0100 Subject: [PATCH 3/6] renaming --- src/moscot/plotting/_plotting.py | 10 ++++------ src/moscot/plotting/_utils.py | 5 ++--- src/moscot/problems/time/_mixins.py | 4 ++-- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/moscot/plotting/_plotting.py b/src/moscot/plotting/_plotting.py index 23a5c8132..a649e464a 100644 --- a/src/moscot/plotting/_plotting.py +++ b/src/moscot/plotting/_plotting.py @@ -23,7 +23,7 @@ _create_col_colors, _heatmap, _input_to_adatas, - _plot_temporal, + _plot_scatter, _sankey, get_plotting_vars, ) @@ -292,10 +292,9 @@ def push( if data["data"] is not None and data["subset"] is not None and cmap is None: cmap = _create_col_colors(adata, data["data"], data["subset"]) - fig = _plot_temporal( + fig = _plot_scatter( adata=adata, - temporal_key=data["temporal_key"] if "temporal_key" in data else None, - generic_key=data["key"] if "key" in data else None, + generic_key=data["key"], key_stored=key, source=data["source"], target=data["target"], @@ -401,9 +400,8 @@ def pull( if data["data"] is not None and data["subset"] is not None and cmap is None: cmap = _create_col_colors(adata, data["data"], data["subset"]) - fig = _plot_temporal( + fig = _plot_scatter( adata=adata, - temporal_key=data["temporal_key"] if "temporal_key" in data else None, generic_key=data["key"] if "key" in data else None, key_stored=key, source=data["source"], diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py index 7e9a150b8..01a3fbb10 100644 --- a/src/moscot/plotting/_utils.py +++ b/src/moscot/plotting/_utils.py @@ -376,9 +376,8 @@ def _input_to_adatas( raise ValueError(f"Unable to interpret input of type `{type(inp)}`.") -def _plot_temporal( +def _plot_scatter( adata: AnnData, - temporal_key: str, generic_key: str, key_stored: str, source: float, @@ -452,7 +451,7 @@ def _plot_temporal( adata_view = adata else: tmp = np.full(len(adata), constant_fill_value) - mask = adata.obs[temporal_key] == time_points[i] + mask = adata.obs[generic_key] == time_points[i] tmp[mask] = adata[mask].obs[key_stored] if scale: diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 883efc8f8..8b9401a57 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -462,7 +462,7 @@ def push( plot_vars = { "source": source, "target": target, - "temporal_key": self.temporal_key, + "key": self.temporal_key, "data": data if isinstance(data, str) else None, "subset": subset, } @@ -526,7 +526,7 @@ def pull( if key_added is not None: plot_vars = { - "temporal_key": self.temporal_key, + "key": self.temporal_key, "data": data if isinstance(data, str) else None, "subset": subset, "source": source, From eb930fb3588925a3af6f42dee70e0a1c7a8c02c0 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Tue, 20 Feb 2024 08:03:48 +0100 Subject: [PATCH 4/6] universal title and ruff fix --- src/moscot/plotting/_plotting.py | 2 +- src/moscot/plotting/_utils.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/moscot/plotting/_plotting.py b/src/moscot/plotting/_plotting.py index a649e464a..363f2cd1a 100644 --- a/src/moscot/plotting/_plotting.py +++ b/src/moscot/plotting/_plotting.py @@ -402,7 +402,7 @@ def pull( fig = _plot_scatter( adata=adata, - generic_key=data["key"] if "key" in data else None, + generic_key=data["key"], key_stored=key, source=data["source"], target=data["target"], diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py index 01a3fbb10..277ad621f 100644 --- a/src/moscot/plotting/_utils.py +++ b/src/moscot/plotting/_utils.py @@ -429,14 +429,10 @@ def _plot_scatter( titles = [f"{categories if categories is not None else 'Cells'} at time {source if push else target}"] titles.extend([f"{name} at time {time_points[i]}" for i in range(1, len(time_points))]) else: - titles = ( - [f"{categories if categories is not None else 'Cells'} at time {source if push else target} and {name}"] - if generic_key is None - else [ + titles = ([ f"{'Push' if push else 'Pull'} {categories if categories is not None else 'cells'} " + f"from {source if push else target} to {target if push else source}" - ] - ) + ]) for i, ax in enumerate(axs): # we need to create adata_view because otherwise the view of the adata is copied in the next step i+1 with RandomKeys(adata, n=2, where="obs") as keys: From 6f37cb1ee1ed714898f1c1088d0c7a063fe7166d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Feb 2024 07:04:28 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/plotting/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py index 277ad621f..358ae5fa2 100644 --- a/src/moscot/plotting/_utils.py +++ b/src/moscot/plotting/_utils.py @@ -429,10 +429,10 @@ def _plot_scatter( titles = [f"{categories if categories is not None else 'Cells'} at time {source if push else target}"] titles.extend([f"{name} at time {time_points[i]}" for i in range(1, len(time_points))]) else: - titles = ([ - f"{'Push' if push else 'Pull'} {categories if categories is not None else 'cells'} " - + f"from {source if push else target} to {target if push else source}" - ]) + titles = [ + f"{'Push' if push else 'Pull'} {categories if categories is not None else 'cells'} " + + f"from {source if push else target} to {target if push else source}" + ] for i, ax in enumerate(axs): # we need to create adata_view because otherwise the view of the adata is copied in the next step i+1 with RandomKeys(adata, n=2, where="obs") as keys: From 9b05edadd778aa4c55d7c5f7bfa07fa27eebb049 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Tue, 20 Feb 2024 08:36:10 +0100 Subject: [PATCH 6/6] key name in tests --- tests/plotting/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/plotting/conftest.py b/tests/plotting/conftest.py index 017992e29..50ba864bf 100644 --- a/tests/plotting/conftest.py +++ b/tests/plotting/conftest.py @@ -41,7 +41,7 @@ def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData: @pytest.fixture() def adata_pl_push(adata_time: AnnData) -> AnnData: rng = np.random.RandomState(0) - plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} + plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"] adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category") set_plotting_vars(adata_time, _constants.PUSH, key=_constants.PUSH, value=plot_vars) @@ -60,7 +60,7 @@ def adata_pl_push(adata_time: AnnData) -> AnnData: @pytest.fixture() def adata_pl_pull(adata_time: AnnData) -> AnnData: rng = np.random.RandomState(0) - plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} + plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"] adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category") set_plotting_vars(adata_time, _constants.PULL, key=_constants.PULL, value=plot_vars)