Skip to content

Commit 8b9ad60

Browse files
Merge pull request #113 from theislab/dev
fixed bug in .mean of pairwise test
2 parents f8635bb + e929694 commit 8b9ad60

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

diffxpy/testing/det.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -1555,13 +1555,13 @@ def __init__(
15551555
mean_x0 = np.asarray(np.mean(x0, axis=0)).flatten().astype(dtype=np.float)
15561556
mean_x1 = np.asarray(np.mean(x1, axis=0)).flatten().astype(dtype=np.float)
15571557
# Avoid unnecessary mean computation:
1558-
self._mean = np.average(
1558+
self._mean = np.asarray(np.average(
15591559
a=np.vstack([mean_x0, mean_x1]),
15601560
weights=np.array([x0.shape[0] / (x0.shape[0] + x1.shape[0]),
15611561
x1.shape[0] / (x0.shape[0] + x1.shape[0])]),
15621562
axis=0,
15631563
returned=False
1564-
)
1564+
)).flatten()
15651565
self._ave_nonzero = self._mean != 0 # omit all-zero features
15661566
if isinstance(x0, scipy.sparse.csr_matrix):
15671567
# Efficient analytic expression of variance without densification.
@@ -1603,6 +1603,8 @@ def __init__(
16031603
if is_logged:
16041604
self._logfc = mean_x1 - mean_x0
16051605
else:
1606+
mean_x0 = np.nextafter(0, np.inf, out=mean_x0, where=mean_x0 < np.nextafter(0, np.inf))
1607+
mean_x1 = np.nextafter(0, np.inf, out=mean_x1, where=mean_x1 < np.nextafter(0, np.inf))
16061608
self._logfc = np.log(mean_x1) - np.log(mean_x0)
16071609

16081610
@property
@@ -1679,13 +1681,13 @@ def __init__(
16791681
mean_x0 = np.asarray(np.mean(x0, axis=0)).flatten().astype(dtype=np.float)
16801682
mean_x1 = np.asarray(np.mean(x1, axis=0)).flatten().astype(dtype=np.float)
16811683
# Avoid unnecessary mean computation:
1682-
self._mean = np.average(
1684+
self._mean = np.asarray(np.average(
16831685
a=np.vstack([mean_x0, mean_x1]),
16841686
weights=np.array([x0.shape[0] / (x0.shape[0] + x1.shape[0]),
16851687
x1.shape[0] / (x0.shape[0] + x1.shape[0])]),
16861688
axis=0,
16871689
returned=False
1688-
)
1690+
)).flatten()
16891691
if isinstance(x0, scipy.sparse.csr_matrix):
16901692
# Efficient analytic expression of variance without densification.
16911693
var_x0 = np.asarray(np.mean(x0.power(2), axis=0)).flatten().astype(dtype=np.float) - np.square(mean_x0)
@@ -1724,6 +1726,8 @@ def __init__(
17241726
if is_logged:
17251727
self._logfc = mean_x1 - mean_x0
17261728
else:
1729+
mean_x0 = np.nextafter(0, np.inf, out=mean_x0, where=mean_x0 < np.nextafter(0, np.inf))
1730+
mean_x1 = np.nextafter(0, np.inf, out=mean_x1, where=mean_x1 < np.nextafter(0, np.inf))
17271731
self._logfc = np.log(mean_x1) - np.log(mean_x0)
17281732

17291733
@property
@@ -1881,7 +1885,7 @@ def __init__(self, gene_ids, pval, logfc, ave, groups, tests, correction_type: s
18811885
self._gene_ids = np.asarray(gene_ids)
18821886
self._logfc = logfc
18831887
self._pval = pval
1884-
self._mean = ave
1888+
self._mean = np.asarray(ave).flatten()
18851889
self.groups = list(np.asarray(groups))
18861890
self._tests = tests
18871891

@@ -2673,7 +2677,7 @@ def __init__(
26732677
self._gene_ids = np.asarray(gene_ids)
26742678
self._pval = pval
26752679
self._logfc = logfc
2676-
self._mean = ave
2680+
self._mean = np.asarray(ave).flatten()
26772681
self.groups = list(np.asarray(groups))
26782682
self._tests = tests
26792683

@@ -2794,7 +2798,7 @@ def __init__(self, partitions, tests, ave, correction_type: str = "by_test"):
27942798
self._gene_ids = tests[0].gene_ids
27952799
self._pval = np.expand_dims(np.vstack([x.pval for x in tests]), axis=0)
27962800
self._logfc = np.expand_dims(np.vstack([x.log_fold_change() for x in tests]), axis=0)
2797-
self._mean = ave
2801+
self._mean = np.asarray(ave).flatten()
27982802

27992803
_ = self.qval
28002804

0 commit comments

Comments
 (0)