diff --git a/src/moscot/plotting/_plotting.py b/src/moscot/plotting/_plotting.py index afff73c71..363f2cd1a 100644 --- a/src/moscot/plotting/_plotting.py +++ b/src/moscot/plotting/_plotting.py @@ -23,7 +23,7 @@ _create_col_colors, _heatmap, _input_to_adatas, - _plot_temporal, + _plot_scatter, _sankey, get_plotting_vars, ) @@ -292,9 +292,9 @@ def push( if data["data"] is not None and data["subset"] is not None and cmap is None: cmap = _create_col_colors(adata, data["data"], data["subset"]) - fig = _plot_temporal( + fig = _plot_scatter( adata=adata, - temporal_key=data["temporal_key"], + generic_key=data["key"], key_stored=key, source=data["source"], target=data["target"], @@ -400,9 +400,9 @@ def pull( if data["data"] is not None and data["subset"] is not None and cmap is None: cmap = _create_col_colors(adata, data["data"], data["subset"]) - fig = _plot_temporal( + fig = _plot_scatter( adata=adata, - temporal_key=data["temporal_key"], + generic_key=data["key"], key_stored=key, source=data["source"], target=data["target"], diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py index e4aaf7595..358ae5fa2 100644 --- a/src/moscot/plotting/_utils.py +++ b/src/moscot/plotting/_utils.py @@ -376,9 +376,9 @@ def _input_to_adatas( raise ValueError(f"Unable to interpret input of type `{type(inp)}`.") -def _plot_temporal( +def _plot_scatter( adata: AnnData, - temporal_key: str, + generic_key: str, key_stored: str, source: float, target: float, @@ -430,7 +430,8 @@ def _plot_temporal( titles.extend([f"{name} at time {time_points[i]}" for i in range(1, len(time_points))]) else: titles = [ - f"{categories if categories is not None else 'Cells'} at time {source if push else target} and {name}" + f"{'Push' if push else 'Pull'} {categories if categories is not None else 'cells'} " + + f"from {source if push else target} to {target if push else source}" ] for i, ax in enumerate(axs): # we need to create adata_view because otherwise the view of the adata is copied in the next step i+1 @@ -446,7 +447,7 @@ def _plot_temporal( adata_view = adata else: tmp = np.full(len(adata), constant_fill_value) - mask = adata.obs[temporal_key] == time_points[i] + mask = adata.obs[generic_key] == time_points[i] tmp[mask] = adata[mask].obs[key_stored] if scale: diff --git a/src/moscot/problems/generic/_mixins.py b/src/moscot/problems/generic/_mixins.py index 5acf7f08e..6580c2ad5 100644 --- a/src/moscot/problems/generic/_mixins.py +++ b/src/moscot/problems/generic/_mixins.py @@ -169,7 +169,11 @@ def push( if TYPE_CHECKING: assert isinstance(key_added, str) plot_vars = { - "distribution_key": self.batch_key, + "source": source, + "target": target, + "key": self.batch_key, + "data": data if isinstance(data, str) else None, + "subset": subset, } self.adata.obs[key_added] = self._flatten(result, key=self.batch_key) set_plotting_vars(self.adata, _constants.PUSH, key=key_added, value=plot_vars) @@ -232,6 +236,10 @@ def pull( if key_added is not None: plot_vars = { "key": self.batch_key, + "data": data if isinstance(data, str) else None, + "subset": subset, + "source": source, + "target": target, } self.adata.obs[key_added] = self._flatten(result, key=self.batch_key) set_plotting_vars(self.adata, _constants.PULL, key=key_added, value=plot_vars) diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 883efc8f8..8b9401a57 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -462,7 +462,7 @@ def push( plot_vars = { "source": source, "target": target, - "temporal_key": self.temporal_key, + "key": self.temporal_key, "data": data if isinstance(data, str) else None, "subset": subset, } @@ -526,7 +526,7 @@ def pull( if key_added is not None: plot_vars = { - "temporal_key": self.temporal_key, + "key": self.temporal_key, "data": data if isinstance(data, str) else None, "subset": subset, "source": source, diff --git a/tests/plotting/conftest.py b/tests/plotting/conftest.py index 017992e29..50ba864bf 100644 --- a/tests/plotting/conftest.py +++ b/tests/plotting/conftest.py @@ -41,7 +41,7 @@ def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData: @pytest.fixture() def adata_pl_push(adata_time: AnnData) -> AnnData: rng = np.random.RandomState(0) - plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} + plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"] adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category") set_plotting_vars(adata_time, _constants.PUSH, key=_constants.PUSH, value=plot_vars) @@ -60,7 +60,7 @@ def adata_pl_push(adata_time: AnnData) -> AnnData: @pytest.fixture() def adata_pl_pull(adata_time: AnnData) -> AnnData: rng = np.random.RandomState(0) - plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} + plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"] adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category") set_plotting_vars(adata_time, _constants.PULL, key=_constants.PULL, value=plot_vars)