Skip to content

Commit 46b87a5

Browse files
authored
Support precomputed correlation matrix for calculating variance inflation term in Stouffers (#121)
* Support precomputed correlation matrix for calculating variance inflation term in Stouffers * Update test_results.py * Switch to macos-12 per actions/setup-python#850 * Update test_results.py
1 parent 624968e commit 46b87a5

File tree

4 files changed

+49
-9
lines changed

4 files changed

+49
-9
lines changed

.github/workflows/testing.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
strategy:
3939
fail-fast: false
4040
matrix:
41-
os: ["ubuntu-latest", "macos-latest"]
41+
os: ["ubuntu-latest", "macos-12"]
4242
python-version: ["3.8", "3.9", "3.10", "3.11"]
4343

4444
name: ${{ matrix.os }} with Python ${{ matrix.python-version }}

pymare/estimators/combination.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class StoufferCombinationTest(CombinationTest):
113113
# Maps Dataset attributes onto fit() args; see BaseEstimator for details.
114114
_dataset_attr_map = {"z": "y", "w": "n", "g": "v"}
115115

116-
def _inflation_term(self, z, w, g):
116+
def _inflation_term(self, z, w, g, corr=None):
117117
"""Calculate the variance inflation term for each group.
118118
119119
This term is used to adjust the variance of the combined z-score when
@@ -127,6 +127,8 @@ def _inflation_term(self, z, w, g):
127127
Array of weights.
128128
g : :obj:`numpy.ndarray` of shape (n, d)
129129
Array of group labels.
130+
corr : :obj:`numpy.ndarray` of shape (n, n), optional
131+
The correlation matrix of the z-values. If None, it will be calculated.
130132
131133
Returns
132134
-------
@@ -157,26 +159,38 @@ def _inflation_term(self, z, w, g):
157159
continue
158160

159161
# Calculate the within group correlation matrix and sum the non-diagonal elements
160-
corr = np.corrcoef(group_z, rowvar=True)
162+
if corr is None:
163+
if z.shape[1] < 2:
164+
raise ValueError("The number of features must be greater than 1.")
165+
group_corr = np.corrcoef(group_z, rowvar=True)
166+
else:
167+
group_corr = corr[group_indices][:, group_indices]
168+
161169
upper_indices = np.triu_indices(n_samples, k=1)
162-
non_diag_corr = corr[upper_indices]
170+
non_diag_corr = group_corr[upper_indices]
163171
w_i, w_j = weights[upper_indices[0]], weights[upper_indices[1]]
164172

165173
sigma += (2 * w_i * w_j * non_diag_corr).sum()
166174

167175
return sigma
168176

169-
def fit(self, z, w=None, g=None):
177+
def fit(self, z, w=None, g=None, corr=None):
170178
"""Fit the estimator to z-values, optionally with weights and groups."""
171-
return super().fit(z, w=w, g=g)
179+
return super().fit(z, w=w, g=g, corr=corr)
172180

173-
def p_value(self, z, w=None, g=None):
181+
def p_value(self, z, w=None, g=None, corr=None):
174182
"""Calculate p-values."""
175183
if w is None:
176184
w = np.ones_like(z)
177185

186+
if g is None and corr is not None:
187+
warnings.warn("Correlation matrix provided without groups. Ignoring.")
188+
189+
if g is not None and corr is not None and g.shape[0] != corr.shape[0]:
190+
raise ValueError("Group labels must have the same length as the correlation matrix.")
191+
178192
# Calculate the variance inflation term, sum of non-diagonal elements of sigma.
179-
sigma = self._inflation_term(z, w, g) if g is not None else 0
193+
sigma = self._inflation_term(z, w, g, corr=corr) if g is not None else 0
180194

181195
# The sum of diagonal elements of sigma is given by (w**2).sum(0).
182196
variance = (w**2).sum(0) + sigma

pymare/tests/test_combination_tests.py

+25
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,28 @@ def test_stouffer_adjusted():
7979
sigma_l1 = n_maps_l1 * (n_maps_l1 - 1) # Expected inflation term
8080
z_expected_l1 = n_maps_l1 * common_sample / np.sqrt(n_maps_l1 + sigma_l1)
8181
assert np.allclose(z_l1, z_expected_l1, atol=1e-5)
82+
83+
# Test with correlation matrix and groups.
84+
data_corr = data - data.mean(0)
85+
corr = np.corrcoef(data_corr, rowvar=True)
86+
results_corr = (
87+
StoufferCombinationTest("directed").fit(z=data, w=weights, g=groups, corr=corr).params_
88+
)
89+
z_corr = ss.norm.isf(results_corr["p"])
90+
91+
z_corr_expected = np.array([5.00088912, 3.70356943, 4.05465924, 5.4633001, 5.18927878])
92+
assert np.allclose(z_corr, z_corr_expected, atol=1e-5)
93+
94+
# Test with no correlation matrix and groups, but only one feature.
95+
with pytest.raises(ValueError):
96+
StoufferCombinationTest("directed").fit(z=data[:, :1], w=weights[:, :1], g=groups)
97+
98+
# Test with correlation matrix and groups of different shapes.
99+
with pytest.raises(ValueError):
100+
StoufferCombinationTest("directed").fit(z=data, w=weights, g=groups, corr=corr[:-2, :-2])
101+
102+
# Test with correlation matrix and no groups.
103+
results1 = StoufferCombinationTest("directed").fit(z=_z1, corr=corr).params_
104+
z1 = ss.norm.isf(results1["p"])
105+
106+
assert np.allclose(z1, [4.69574], atol=1e-5)

pymare/tests/test_results.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,9 @@ def test_combination_test_results_from_arrays(dataset):
8787

8888
# fit overwrites dataset_ attribute with None
8989
assert fitted_estimator.dataset_ is None
90+
9091
# fit_dataset overwrites it with the Dataset
91-
fitted_estimator.fit_dataset(dataset)
92+
fitted_estimator.fit_dataset(Dataset(dataset.y))
9293
assert isinstance(fitted_estimator.dataset_, Dataset)
9394
# fit sets it back to None
9495
fitted_estimator.fit(z=dataset.y)

0 commit comments

Comments
 (0)