Skip to content

Commit 4d3c8db

Browse files
author
Arina Danilina
committed
fully passing tests
1 parent 93b2e8d commit 4d3c8db

File tree

7 files changed

+125
-82
lines changed

7 files changed

+125
-82
lines changed

src/moscot/base/problems/_mixins.py

+43-12
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,9 @@ def _annotation_mapping(
310310
target: K,
311311
key: str | None = None,
312312
forward: bool = True,
313-
other_adata: Optional[str] = None,
313+
other_adata: str | None = None,
314314
scale_by_marginals: bool = True,
315+
batch_size: int | None = None,
315316
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
316317
) -> pd.DataFrame:
317318
if mapping_mode == "sum":
@@ -321,37 +322,67 @@ def _annotation_mapping(
321322
cell_transition_kwargs.setdefault("source", source)
322323
cell_transition_kwargs.setdefault("target", target)
323324
cell_transition_kwargs.setdefault("other_adata", other_adata)
324-
cell_transition_kwargs.setdefault("forward", forward)
325+
cell_transition_kwargs.setdefault("forward", not forward)
325326
if forward:
326-
cell_transition_kwargs.setdefault("source_groups", None)
327-
cell_transition_kwargs.setdefault("target_groups", annotation_label)
328-
axis = 1 # columns
329-
else:
330327
cell_transition_kwargs.setdefault("source_groups", annotation_label)
331328
cell_transition_kwargs.setdefault("target_groups", None)
332329
axis = 0 # rows
330+
else:
331+
cell_transition_kwargs.setdefault("source_groups", None)
332+
cell_transition_kwargs.setdefault("target_groups", annotation_label)
333+
axis = 1 # columns
333334
out: pd.DataFrame = self._cell_transition(**cell_transition_kwargs)
334335
return out.idxmax(axis=axis).to_frame(name=annotation_label)
335336
if mapping_mode == "max":
337+
out = []
336338
if forward:
337339
source_df = _get_df_cell_transition(
338340
self.adata,
339341
annotation_keys=[annotation_label],
340342
filter_key=key,
341343
filter_value=source,
342344
)
343-
dummy = pd.get_dummies(source_df, prefix="", prefix_sep="")
344-
out: ArrayLike = self[(source, target)].push(dummy, scale_by_marginals=scale_by_marginals)
345+
out_len = self[(source, target)].solution.shape[1]
346+
batch_size = batch_size if batch_size is not None else out_len
347+
for batch in range(0, out_len, batch_size):
348+
tm_batch = self.push(
349+
source=source,
350+
target=target,
351+
data=None,
352+
subset=(batch, batch_size),
353+
normalize=True,
354+
return_all=False,
355+
scale_by_marginals=scale_by_marginals,
356+
split_mass=True,
357+
key_added=None,
358+
)
359+
v = np.array(tm_batch.argmax(1))
360+
out.extend(source_df[annotation_label][v[i]] for i in range(len(v)))
361+
345362
else:
346363
target_df = _get_df_cell_transition(
347364
self.adata if other_adata is None else other_adata,
348365
annotation_keys=[annotation_label],
349366
filter_key=key,
350367
filter_value=target,
351368
)
352-
dummy = pd.get_dummies(target_df, prefix="", prefix_sep="")
353-
out: ArrayLike = self[(source, target)].pull(dummy, scale_by_marginals=scale_by_marginals)
354-
categories = pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))])
369+
out_len = self[(source, target)].solution.shape[0]
370+
batch_size = batch_size if batch_size is not None else out_len
371+
for batch in range(0, out_len, batch_size):
372+
tm_batch = self.pull(
373+
source=source,
374+
target=target,
375+
data=None,
376+
subset=(batch, batch_size),
377+
normalize=True,
378+
return_all=False,
379+
scale_by_marginals=scale_by_marginals,
380+
split_mass=True,
381+
key_added=None,
382+
)
383+
v = np.array(tm_batch.argmax(1))
384+
out.extend(target_df[annotation_label][v[i]] for i in range(len(v)))
385+
categories = pd.Categorical(out)
355386
return pd.DataFrame(categories, columns=[annotation_label])
356387
raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.")
357388

@@ -507,7 +538,7 @@ def _cell_aggregation_transition(
507538
if batch_size is None:
508539
batch_size = len(df_2)
509540
for batch in range(0, len(df_2), batch_size):
510-
result = func( # TODO(@MUCDK) check how to make compatiAnalysisMixinProtocolcelltyble with all policies
541+
result = func( # TODO(@MUCDK) check how to make compatible with all policies
511542
source=source,
512543
target=target,
513544
data=None,

src/moscot/problems/cross_modality/_mixins.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def annotation_mapping(
202202
target=target,
203203
key=self.batch_key,
204204
forward=forward,
205-
other_adata=self.adata_tgt if forward else self.adata_src,
205+
other_adata=self.adata_tgt,
206206
scale_by_marginals=scale_by_marginals,
207207
cell_transition_kwargs=cell_transition_kwargs,
208208
)

src/moscot/problems/space/_mixins.py

+1-15
Original file line numberDiff line numberDiff line change
@@ -604,26 +604,12 @@ def annotation_mapping(
604604
scale_by_marginals: bool = True,
605605
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
606606
) -> pd.DataFrame:
607-
"""
608-
609-
Notes
610-
-----
611-
If forward is True, it means that the annotation columns (annotation label) needs to be in the target adata,
612-
If forward is False, it means that the annotation column (annotation label) needs to be in the source adata.
613-
"""
614-
cell_transition_kwargs = dict(cell_transition_kwargs)
615-
if forward:
616-
cell_transition_kwargs.setdefault("source_groups", annotation_label)
617-
cell_transition_kwargs.setdefault("target_groups", None)
618-
else:
619-
cell_transition_kwargs.setdefault("source_groups", None)
620-
cell_transition_kwargs.setdefault("target_groups", annotation_label)
621607
return self._annotation_mapping(
622608
mapping_mode=mapping_mode,
623609
annotation_label=annotation_label,
624610
source=source,
625611
target=target,
626-
forward=not forward if mapping_mode == "sum" else forward,
612+
forward=forward,
627613
key=self.batch_key,
628614
other_adata=self.adata_sc,
629615
scale_by_marginals=scale_by_marginals,

tests/conftest.py

+27-19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from math import cos, sin
2-
from typing import Literal, Optional, Tuple
2+
from typing import Literal, Optional, Tuple, Union
33

44
import pytest
55

@@ -211,15 +211,22 @@ def adata_translation_split(adata_translation) -> Tuple[AnnData, AnnData]:
211211
@pytest.fixture()
212212
def adata_anno(
213213
problem_kind: Literal["temporal", "cross_modality", "alignment", "mapping"],
214-
# forward: bool
215-
) -> AnnData | Tuple[AnnData, AnnData]:
214+
) -> Union[AnnData, Tuple[AnnData, AnnData]]:
216215
rng = np.random.RandomState(31)
217216
adata_src = AnnData(X=csr_matrix(rng.normal(size=(10, 60))))
218-
adata_src.obs["celltype"] = _gt_source_annotation
219-
adata_src.obs["celltype"] = adata_src.obs["celltype"].astype("category")
220-
adata_src.uns["expected_max"] = _gt_target_max_annotation
221-
adata_src.uns["expected_sum"] = _gt_target_sum_annotation
217+
rng_src = rng.choice(["A", "B", "C"], size=5).tolist()
218+
adata_src.obs["celltype1"] = ["C", "C", "A", "B", "B"] + rng_src
219+
adata_src.obs["celltype1"] = adata_src.obs["celltype1"].astype("category")
220+
adata_src.uns["expected_max1"] = ["C", "C", "A", "B", "B"] + rng_src + rng_src
221+
adata_src.uns["expected_sum1"] = ["C", "C", "B", "B", "B"] + rng_src + rng_src
222+
222223
adata_tgt = AnnData(X=csr_matrix(rng.normal(size=(15, 60))))
224+
rng_tgt = rng.choice(["A", "B", "C"], size=5).tolist()
225+
adata_tgt.obs["celltype2"] = ["C", "C", "A", "B", "B"] + rng_tgt + rng_tgt
226+
adata_tgt.obs["celltype2"] = adata_tgt.obs["celltype2"].astype("category")
227+
adata_tgt.uns["expected_max2"] = ["C", "C", "A", "B", "B"] + rng_tgt
228+
adata_tgt.uns["expected_sum2"] = ["C", "C", "B", "B", "B"] + rng_tgt
229+
223230
if problem_kind == "cross_modality":
224231
adata_src.obs["batch"] = "0"
225232
adata_tgt.obs["batch"] = "1"
@@ -228,32 +235,33 @@ def adata_anno(
228235
sc.pp.pca(adata_src)
229236
sc.pp.pca(adata_tgt)
230237
return adata_src, adata_tgt
231-
if problem_kind in ["alignment", "mapping"]:
238+
if problem_kind == "mapping":
239+
adata_src.obs["batch"] = "0"
240+
adata_tgt.obs["batch"] = "1"
241+
sc.pp.pca(adata_src)
242+
sc.pp.pca(adata_tgt)
243+
adata_tgt.obsm["spatial"] = rng.normal(size=(adata_tgt.n_obs, 2))
244+
return adata_src, adata_tgt
245+
if problem_kind == "alignment":
232246
adata_src.obsm["spatial"] = rng.normal(size=(adata_src.n_obs, 2))
233247
adata_tgt.obsm["spatial"] = rng.normal(size=(adata_tgt.n_obs, 2))
234248
key = "day" if problem_kind == "temporal" else "batch"
235-
adatas = [adata_src, adata_tgt] # if forward else [adata_tgt, adata_src]
249+
adatas = [adata_src, adata_tgt]
236250
adata = ad.concat(adatas, join="outer", label=key, index_unique="-", uns_merge="unique")
237251
adata.obs[key] = (pd.to_numeric(adata.obs[key]) if key == "day" else adata.obs[key]).astype("category")
238252
adata.layers["counts"] = adata.X.A
239253
sc.pp.pca(adata)
240254
return adata
241255

242256

243-
_gt_source_annotation = np.array(["A", "A", "B", "A", "B", "C", "A", "A", "A", "A"], dtype="U1")
244-
245-
_gt_target_max_annotation = np.array(["A", "A", "B", "A", "B", "C", "A", "A", "A", "A", "B", "B", "B", "B", "B"])
246-
247-
_gt_target_sum_annotation = np.array(["A", "A", "B", "A", "B", "C", "A", "A", "A", "A", "A", "A", "A", "A", "A"])
248-
249-
250257
@pytest.fixture()
251258
def gt_tm_annotation() -> np.ndarray:
252259
tm = np.zeros((10, 15))
253260
for i in range(10):
254261
tm[i][i] = 1
255262
for i in range(10, 15):
256-
tm[0][i] = 0.3
257-
tm[1][i] = 0.3
258-
tm[2][i] = 0.4
263+
tm[i-5][i] = 1
264+
for j in range(2,5):
265+
for i in range(2,5):
266+
tm[i][j] = 0.3 if i != j else 0.4
259267
return tm

tests/problems/cross_modality/test_mixins.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -108,29 +108,28 @@ def test_cell_transition_pipeline(
108108
pd.testing.assert_frame_equal(result1, result2)
109109

110110
@pytest.mark.fast()
111-
@pytest.mark.parametrize("forward", [True]) # , False])
112-
@pytest.mark.parametrize(
113-
"mapping_mode",
114-
[
115-
"max",
116-
],
117-
) # "sum"])
111+
@pytest.mark.parametrize("forward", [True, False])
112+
@pytest.mark.parametrize("mapping_mode",["max", "sum"])
118113
@pytest.mark.parametrize("problem_kind", ["cross_modality"])
119114
def test_annotation_mapping(
120115
self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, gt_tm_annotation
121116
):
122-
rng = np.random.RandomState(0)
123117
adata_src, adata_tgt = adata_anno
124118
tp = TranslationProblem(adata_src, adata_tgt)
125119
tp = tp.prepare(src_attr="emb_src", tgt_attr="emb_tgt")
126120
problem_keys = ("src", "tgt")
127121
assert set(tp.problems.keys()) == {problem_keys}
128122
tp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation), overwrite=True)
129-
123+
annotation_label = "celltype1" if forward else "celltype2"
130124
result = tp.annotation_mapping(
131125
mapping_mode=mapping_mode,
132-
annotation_label="celltype",
126+
annotation_label=annotation_label,
133127
forward=forward,
128+
source="src",
129+
target="tgt"
134130
)
135-
expected_result = adata_src.uns["expected_max"] if mapping_mode == "max" else adata_src.uns["expected_sum"]
136-
assert (result["celltype"] == expected_result).all()
131+
if forward:
132+
expected_result = adata_src.uns["expected_max1"] if mapping_mode == "max" else adata_src.uns["expected_sum1"]
133+
else:
134+
expected_result = adata_tgt.uns["expected_max2"] if mapping_mode == "max" else adata_tgt.uns["expected_sum2"]
135+
assert (result[annotation_label] == expected_result).all()

tests/problems/space/test_mixins.py

+32-12
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,29 @@ def test_cell_transition_pipeline(self, adata_space_rotate: AnnData, forward: bo
9393
assert isinstance(result, pd.DataFrame)
9494
assert result.shape == (3, 3)
9595

96+
@pytest.mark.fast()
97+
@pytest.mark.parametrize("forward", [True, False])
98+
@pytest.mark.parametrize("mapping_mode", ["max", "sum"])
99+
@pytest.mark.parametrize("problem_kind", ["alignment"])
100+
def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation):
101+
ap = AlignmentProblem(adata=adata_anno)
102+
ap = ap.prepare(batch_key="batch", joint_attr={"attr": "X"})
103+
problem_keys = ("0", "1")
104+
assert set(ap.problems.keys()) == {problem_keys}
105+
ap[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation))
106+
annotation_label = "celltype1" if forward else "celltype2"
107+
result = ap.annotation_mapping(
108+
mapping_mode=mapping_mode,
109+
annotation_label=annotation_label,
110+
source="0",
111+
target="1",
112+
forward=forward,
113+
)
114+
if forward:
115+
expected_result = adata_anno.uns["expected_max1"] if mapping_mode == "max" else adata_anno.uns["expected_sum1"]
116+
else:
117+
expected_result = adata_anno.uns["expected_max2"] if mapping_mode == "max" else adata_anno.uns["expected_sum2"]
118+
assert (result[annotation_label] == expected_result).all()
96119

97120
class TestSpatialMappingAnalysisMixin:
98121
@pytest.mark.parametrize("sc_attr", [{"attr": "X"}, {"attr": "obsm", "key": "X_pca"}])
@@ -177,28 +200,25 @@ def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, n
177200
assert result.shape == (3, 4)
178201

179202
@pytest.mark.fast()
180-
@pytest.mark.parametrize(
181-
"forward",
182-
[
183-
False,
184-
],
185-
) # True])
203+
@pytest.mark.parametrize("forward", [True, False])
186204
@pytest.mark.parametrize("mapping_mode", ["max", "sum"])
187205
@pytest.mark.parametrize("problem_kind", ["mapping"])
188206
def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation):
189-
rng = np.random.RandomState(0)
190-
adataref, adatasp = _adata_spatial_split(adata_anno)
207+
adataref, adatasp = adata_anno
191208
mp = MappingProblem(adataref, adatasp)
192209
mp = mp.prepare(sc_attr={"attr": "obsm", "key": "X_pca"}, joint_attr={"attr": "X"})
193210
problem_keys = ("src", "tgt")
194211
assert set(mp.problems.keys()) == {problem_keys}
195212
mp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation.T))
196-
213+
annotation_label = "celltype1" if not forward else "celltype2"
197214
result = mp.annotation_mapping(
198215
mapping_mode=mapping_mode,
199-
annotation_label="celltype",
216+
annotation_label=annotation_label,
200217
source="src",
201218
forward=forward,
202219
)
203-
expected_result = adataref.uns["expected_max"] if mapping_mode == "max" else adataref.uns["expected_sum"]
204-
assert (result["celltype"] == expected_result).all()
220+
if not forward:
221+
expected_result = adataref.uns["expected_max1"] if mapping_mode == "max" else adataref.uns["expected_sum1"]
222+
else:
223+
expected_result = adatasp.uns["expected_max2"] if mapping_mode == "max" else adatasp.uns["expected_sum2"]
224+
assert (result[annotation_label] == expected_result).all()

tests/problems/time/test_mixins.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -51,25 +51,24 @@ def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward
5151
np.testing.assert_allclose(present_cell_type_marginal, 1.0)
5252

5353
@pytest.mark.fast()
54-
@pytest.mark.parametrize(
55-
"forward",
56-
[
57-
True,
58-
],
59-
) # False])
60-
@pytest.mark.parametrize("mapping_mode", ["max"]) # , "sum"])
54+
@pytest.mark.parametrize("forward",[True, False])
55+
@pytest.mark.parametrize("mapping_mode", ["max", "sum"])
6156
@pytest.mark.parametrize("problem_kind", ["temporal"])
6257
def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation):
6358
problem = TemporalProblem(adata_anno)
6459
problem_keys = (0, 1)
6560
problem = problem.prepare(time_key="day", joint_attr="X_pca")
6661
assert set(problem.problems.keys()) == {problem_keys}
6762
problem[problem_keys]._solution = MockSolverOutput(gt_tm_annotation)
63+
annotation_label = "celltype1" if forward else "celltype2"
6864
result = problem.annotation_mapping(
69-
mapping_mode=mapping_mode, annotation_label="celltype", forward=forward, source=0, target=1
70-
)
71-
expected_result = adata_anno.uns["expected_max"] if mapping_mode == "max" else adata_anno.uns["expected_sum"]
72-
assert (result["celltype"] == expected_result).all()
65+
mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source=0, target=1
66+
)
67+
if forward:
68+
expected_result = adata_anno.uns["expected_max1"] if mapping_mode == "max" else adata_anno.uns["expected_sum1"]
69+
else:
70+
expected_result = adata_anno.uns["expected_max2"] if mapping_mode == "max" else adata_anno.uns["expected_sum2"]
71+
assert (result[annotation_label] == expected_result).all()
7372

7473
@pytest.mark.fast()
7574
@pytest.mark.parametrize("forward", [True, False])

0 commit comments

Comments
 (0)