1
1
from math import cos , sin
2
- from typing import Literal , Optional , Tuple
2
+ from typing import Literal , Optional , Tuple , Union
3
3
4
4
import pytest
5
5
@@ -211,15 +211,22 @@ def adata_translation_split(adata_translation) -> Tuple[AnnData, AnnData]:
211
211
@pytest .fixture ()
212
212
def adata_anno (
213
213
problem_kind : Literal ["temporal" , "cross_modality" , "alignment" , "mapping" ],
214
- # forward: bool
215
- ) -> AnnData | Tuple [AnnData , AnnData ]:
214
+ ) -> Union [AnnData , Tuple [AnnData , AnnData ]]:
216
215
rng = np .random .RandomState (31 )
217
216
adata_src = AnnData (X = csr_matrix (rng .normal (size = (10 , 60 ))))
218
- adata_src .obs ["celltype" ] = _gt_source_annotation
219
- adata_src .obs ["celltype" ] = adata_src .obs ["celltype" ].astype ("category" )
220
- adata_src .uns ["expected_max" ] = _gt_target_max_annotation
221
- adata_src .uns ["expected_sum" ] = _gt_target_sum_annotation
217
+ rng_src = rng .choice (["A" , "B" , "C" ], size = 5 ).tolist ()
218
+ adata_src .obs ["celltype1" ] = ["C" , "C" , "A" , "B" , "B" ] + rng_src
219
+ adata_src .obs ["celltype1" ] = adata_src .obs ["celltype1" ].astype ("category" )
220
+ adata_src .uns ["expected_max1" ] = ["C" , "C" , "A" , "B" , "B" ] + rng_src + rng_src
221
+ adata_src .uns ["expected_sum1" ] = ["C" , "C" , "B" , "B" , "B" ] + rng_src + rng_src
222
+
222
223
adata_tgt = AnnData (X = csr_matrix (rng .normal (size = (15 , 60 ))))
224
+ rng_tgt = rng .choice (["A" , "B" , "C" ], size = 5 ).tolist ()
225
+ adata_tgt .obs ["celltype2" ] = ["C" , "C" , "A" , "B" , "B" ] + rng_tgt + rng_tgt
226
+ adata_tgt .obs ["celltype2" ] = adata_tgt .obs ["celltype2" ].astype ("category" )
227
+ adata_tgt .uns ["expected_max2" ] = ["C" , "C" , "A" , "B" , "B" ] + rng_tgt
228
+ adata_tgt .uns ["expected_sum2" ] = ["C" , "C" , "B" , "B" , "B" ] + rng_tgt
229
+
223
230
if problem_kind == "cross_modality" :
224
231
adata_src .obs ["batch" ] = "0"
225
232
adata_tgt .obs ["batch" ] = "1"
@@ -228,32 +235,33 @@ def adata_anno(
228
235
sc .pp .pca (adata_src )
229
236
sc .pp .pca (adata_tgt )
230
237
return adata_src , adata_tgt
231
- if problem_kind in ["alignment" , "mapping" ]:
238
+ if problem_kind == "mapping" :
239
+ adata_src .obs ["batch" ] = "0"
240
+ adata_tgt .obs ["batch" ] = "1"
241
+ sc .pp .pca (adata_src )
242
+ sc .pp .pca (adata_tgt )
243
+ adata_tgt .obsm ["spatial" ] = rng .normal (size = (adata_tgt .n_obs , 2 ))
244
+ return adata_src , adata_tgt
245
+ if problem_kind == "alignment" :
232
246
adata_src .obsm ["spatial" ] = rng .normal (size = (adata_src .n_obs , 2 ))
233
247
adata_tgt .obsm ["spatial" ] = rng .normal (size = (adata_tgt .n_obs , 2 ))
234
248
key = "day" if problem_kind == "temporal" else "batch"
235
- adatas = [adata_src , adata_tgt ] # if forward else [adata_tgt, adata_src]
249
+ adatas = [adata_src , adata_tgt ]
236
250
adata = ad .concat (adatas , join = "outer" , label = key , index_unique = "-" , uns_merge = "unique" )
237
251
adata .obs [key ] = (pd .to_numeric (adata .obs [key ]) if key == "day" else adata .obs [key ]).astype ("category" )
238
252
adata .layers ["counts" ] = adata .X .A
239
253
sc .pp .pca (adata )
240
254
return adata
241
255
242
256
243
- _gt_source_annotation = np .array (["A" , "A" , "B" , "A" , "B" , "C" , "A" , "A" , "A" , "A" ], dtype = "U1" )
244
-
245
- _gt_target_max_annotation = np .array (["A" , "A" , "B" , "A" , "B" , "C" , "A" , "A" , "A" , "A" , "B" , "B" , "B" , "B" , "B" ])
246
-
247
- _gt_target_sum_annotation = np .array (["A" , "A" , "B" , "A" , "B" , "C" , "A" , "A" , "A" , "A" , "A" , "A" , "A" , "A" , "A" ])
248
-
249
-
250
257
@pytest .fixture ()
251
258
def gt_tm_annotation () -> np .ndarray :
252
259
tm = np .zeros ((10 , 15 ))
253
260
for i in range (10 ):
254
261
tm [i ][i ] = 1
255
262
for i in range (10 , 15 ):
256
- tm [0 ][i ] = 0.3
257
- tm [1 ][i ] = 0.3
258
- tm [2 ][i ] = 0.4
263
+ tm [i - 5 ][i ] = 1
264
+ for j in range (2 ,5 ):
265
+ for i in range (2 ,5 ):
266
+ tm [i ][j ] = 0.3 if i != j else 0.4
259
267
return tm
0 commit comments