@@ -113,7 +113,7 @@ class StoufferCombinationTest(CombinationTest):
113
113
# Maps Dataset attributes onto fit() args; see BaseEstimator for details.
114
114
_dataset_attr_map = {"z" : "y" , "w" : "n" , "g" : "v" }
115
115
116
- def _inflation_term (self , z , w , g ):
116
+ def _inflation_term (self , z , w , g , corr = None ):
117
117
"""Calculate the variance inflation term for each group.
118
118
119
119
This term is used to adjust the variance of the combined z-score when
@@ -127,6 +127,8 @@ def _inflation_term(self, z, w, g):
127
127
Array of weights.
128
128
g : :obj:`numpy.ndarray` of shape (n, d)
129
129
Array of group labels.
130
+ corr : :obj:`numpy.ndarray` of shape (n, n), optional
131
+ The correlation matrix of the z-values. If None, it will be calculated.
130
132
131
133
Returns
132
134
-------
@@ -157,26 +159,38 @@ def _inflation_term(self, z, w, g):
157
159
continue
158
160
159
161
# Calculate the within group correlation matrix and sum the non-diagonal elements
160
- corr = np .corrcoef (group_z , rowvar = True )
162
+ if corr is None :
163
+ if z .shape [1 ] < 2 :
164
+ raise ValueError ("The number of features must be greater than 1." )
165
+ group_corr = np .corrcoef (group_z , rowvar = True )
166
+ else :
167
+ group_corr = corr [group_indices ][:, group_indices ]
168
+
161
169
upper_indices = np .triu_indices (n_samples , k = 1 )
162
- non_diag_corr = corr [upper_indices ]
170
+ non_diag_corr = group_corr [upper_indices ]
163
171
w_i , w_j = weights [upper_indices [0 ]], weights [upper_indices [1 ]]
164
172
165
173
sigma += (2 * w_i * w_j * non_diag_corr ).sum ()
166
174
167
175
return sigma
168
176
169
- def fit (self , z , w = None , g = None ):
177
+ def fit (self , z , w = None , g = None , corr = None ):
170
178
"""Fit the estimator to z-values, optionally with weights and groups."""
171
- return super ().fit (z , w = w , g = g )
179
+ return super ().fit (z , w = w , g = g , corr = corr )
172
180
173
- def p_value (self , z , w = None , g = None ):
181
+ def p_value (self , z , w = None , g = None , corr = None ):
174
182
"""Calculate p-values."""
175
183
if w is None :
176
184
w = np .ones_like (z )
177
185
186
+ if g is None and corr is not None :
187
+ warnings .warn ("Correlation matrix provided without groups. Ignoring." )
188
+
189
+ if g is not None and corr is not None and g .shape [0 ] != corr .shape [0 ]:
190
+ raise ValueError ("Group labels must have the same length as the correlation matrix." )
191
+
178
192
# Calculate the variance inflation term, sum of non-diagonal elements of sigma.
179
- sigma = self ._inflation_term (z , w , g ) if g is not None else 0
193
+ sigma = self ._inflation_term (z , w , g , corr = corr ) if g is not None else 0
180
194
181
195
# The sum of diagonal elements of sigma is given by (w**2).sum(0).
182
196
variance = (w ** 2 ).sum (0 ) + sigma
0 commit comments