
Commit 722ef8c

update to include small sample two factor bootstrap correction
1 parent 3723036 commit 722ef8c

3 files changed: +103 -29 lines

src/rsatoolbox/inference/evaluate.py

Lines changed: 18 additions & 8 deletions
@@ -173,7 +173,8 @@ def eval_dual_bootstrap(
             / (matrix.shape[0] - 1)
     result = Result(models, evaluations, method=method,
                     cv_method=cv_method, noise_ceiling=noise_ceil,
-                    variances=variances, dof=dof)
+                    variances=variances, dof=dof, n_rdm=data.n_rdm,
+                    n_pattern=data.n_cond)
     return result

@@ -201,15 +202,17 @@ def eval_fixed(models, data, theta=None, method='cosine'):
     noise_ceil = boot_noise_ceiling(
         data, method=method, rdm_descriptor='index')
     if data.n_rdm > 1:
-        variances = np.cov(evaluations[0], ddof=1) \
+        variances = np.cov(evaluations[0], ddof=0) \
             / evaluations.shape[-1]
         dof = evaluations.shape[-1] - 1
     else:
         variances = None
         dof = 0
     result = Result(models, evaluations, method=method,
                     cv_method='fixed', noise_ceiling=noise_ceil,
-                    variances=variances, dof=dof)
+                    variances=variances, dof=dof, n_rdm=data.n_rdm,
+                    n_pattern=None)
+    result.n_pattern = data.n_cond
     return result

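Switching from ddof=1 to ddof=0 here is consistent with the n_rdm-based correction that extract_variances now applies downstream: the biased covariance divided by the number of RDMs, once rescaled by n/(n-1), reproduces the previous estimate. A minimal sanity-check sketch (not part of the commit; one model and random numbers, purely illustrative):

import numpy as np

evals = np.random.rand(12)   # evaluations of one model on 12 subject RDMs (illustrative)
n = evals.shape[-1]

old = np.cov(evals, ddof=1) / n                    # previous eval_fixed estimate
new = (np.cov(evals, ddof=0) / n) * (n / (n - 1))  # ddof=0 estimate after the n/(n-1) correction
assert np.isclose(old, new)                        # same variance of the mean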

@@ -269,7 +272,8 @@ def eval_bootstrap(models, data, theta=None, method='cosine', N=1000,
     dof = min(data.n_rdm, data.n_cond) - 1
     result = Result(models, evaluations, method=method,
                     cv_method='bootstrap', noise_ceiling=noise_ceil,
-                    variances=variances, dof=dof)
+                    variances=variances, dof=dof, n_rdm=data.n_rdm,
+                    n_pattern=data.n_cond)
     return result


@@ -329,7 +333,9 @@ def eval_bootstrap_pattern(models, data, theta=None, method='cosine', N=1000,
     dof = data.n_cond - 1
     result = Result(models, evaluations, method=method,
                     cv_method='bootstrap_pattern', noise_ceiling=noise_ceil,
-                    variances=variances, dof=dof)
+                    variances=variances, dof=dof, n_rdm=None,
+                    n_pattern=data.n_cond)
+    result.n_rdm = data.n_rdm
     return result


@@ -378,7 +384,9 @@ def eval_bootstrap_rdm(models, data, theta=None, method='cosine', N=1000,
     variances = np.cov(evaluations.T)
     result = Result(models, evaluations, method=method,
                     cv_method='bootstrap_rdm', noise_ceiling=noise_ceil,
-                    variances=variances, dof=dof)
+                    variances=variances, dof=dof, n_rdm=data.n_rdm,
+                    n_pattern=None)
+    result.n_pattern = data.n_cond
     return result


@@ -590,7 +598,8 @@ def bootstrap_crossval(models, data, method='cosine', fitter=None,
     variances = np.cov(np.concatenate([evals_nonan.T, noise_ceil_nonan]))
     result = Result(models, evaluations, method=method,
                     cv_method=cv_method, noise_ceiling=noise_ceil,
-                    variances=variances, dof=dof)
+                    variances=variances, dof=dof, n_rdm=data.n_rdm,
+                    n_pattern=data.n_cond)
     return result


@@ -735,7 +744,8 @@ def eval_dual_bootstrap_random(
     variances = np.cov(np.concatenate([evals_nonan.T, noise_ceil_nonan]))
     result = Result(models, evaluations, method=method,
                     cv_method=cv_method, noise_ceiling=noise_ceil,
-                    variances=variances, dof=dof)
+                    variances=variances, dof=dof, n_rdm=data.n_rdm,
+                    n_pattern=data.n_cond)
     return result

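Taken together, each evaluation wrapper now records the available sample sizes on the returned Result, so the downstream variance corrections know how many RDMs and conditions entered the bootstrap. A minimal usage sketch (not part of the commit; synthetic data and illustrative model names, assuming a standard rsatoolbox install):

import numpy as np
from rsatoolbox.rdm import RDMs
from rsatoolbox.model import ModelFixed
from rsatoolbox.inference import eval_fixed

# 5 synthetic subject RDMs over 5 conditions (10 dissimilarities each)
rdms = RDMs(dissimilarities=np.random.rand(5, 10),
            dissimilarity_measure='euclidean')
models = [ModelFixed('m1', np.random.rand(10)),
          ModelFixed('m2', np.random.rand(10))]

result = eval_fixed(models, rdms, method='cosine')
print(result.n_rdm, result.n_pattern)   # 5 subject RDMs, 5 conditions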

src/rsatoolbox/inference/result.py

Lines changed: 4 additions & 2 deletions
@@ -40,7 +40,7 @@ class Result:
     """
 
     def __init__(self, models, evaluations, method, cv_method, noise_ceiling,
-                 variances=None, dof=1, fitter=None):
+                 variances=None, dof=1, fitter=None, n_rdm=None, n_pattern=None):
         if isinstance(models, rsatoolbox.model.Model):
             models = [models]
         assert len(models) == evaluations.shape[1], 'evaluations shape does' \
@@ -55,6 +55,8 @@ def __init__(self, models, evaluations, method, cv_method, noise_ceiling,
         self.dof = dof
         self.fitter = fitter
         self.n_bootstraps = evaluations.shape[0]
+        self.n_rdm = n_rdm
+        self.n_pattern = n_pattern
         if variances is not None:
             # if the variances only refer to the models this should have the
             # same number of entries as the models list.
@@ -63,7 +65,7 @@ def __init__(self, models, evaluations, method, cv_method, noise_ceiling,
             else:
                 nc_included = variances.shape[-1] != len(models)
             self.model_var, self.diff_var, self.noise_ceil_var = \
-                extract_variances(variances, nc_included)
+                extract_variances(variances, nc_included, n_rdm, n_pattern)
         else:
            self.model_var = None
            self.diff_var = None
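When n_rdm and/or n_pattern are supplied, they are forwarded to extract_variances, so the small-sample correction is applied when the bootstrap covariance is unpacked. A hedged sketch of a direct constructor call (not part of the commit; all values invented for illustration):

import numpy as np
from rsatoolbox.inference import Result
from rsatoolbox.model import ModelFixed

models = [ModelFixed('m1', np.random.rand(10)),
          ModelFixed('m2', np.random.rand(10))]
evaluations = np.random.rand(1000, 2)        # 1000 bootstrap samples x 2 models
variances = np.cov(evaluations.T)            # 2 x 2 bootstrap covariance

result = Result(models, evaluations, method='cosine',
                cv_method='bootstrap_rdm', noise_ceiling=np.array([0.6, 0.8]),
                variances=variances, dof=19,
                n_rdm=20, n_pattern=None)     # rdm bootstrap over 20 subjects
print(result.model_var)                       # per-model variances, scaled by 20/19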

src/rsatoolbox/util/inference_util.py

Lines changed: 81 additions & 19 deletions
@@ -3,7 +3,7 @@
 """
 Inference module utilities
 """
-
+from __future__ import annotations
 from collections.abc import Iterable
 import numpy as np
 from scipy import stats
@@ -13,6 +13,9 @@
 from rsatoolbox.rdm import RDMs
 from .matrix import pairwise_contrast
 from .rdm_utils import batch_to_matrices
+from typing import TYPE_CHECKING, Optional
+if TYPE_CHECKING:
+    from numpy.typing import NDArray


 def input_check_model(models, theta=None, fitter=None, N=1):
@@ -68,7 +71,7 @@ def input_check_model(models, theta=None, fitter=None, N=1):
     return models, evaluations, theta, fitter


-def pool_rdm(rdms, method='cosine'):
+def pool_rdm(rdms, method: str = 'cosine'):
     """pools multiple RDMs into the one with maximal performance under a given
     evaluation metric
     rdm_descriptors of the generated rdms are empty
@@ -130,7 +133,7 @@ def pool_rdm(rdms, method='cosine'):
         pattern_descriptors=rdms.pattern_descriptors)


-def _nan_mean(rdm_vector):
+def _nan_mean(rdm_vector: NDArray) -> NDArray:
     """ takes the average over a rdm_vector with nans for masked entries
     without a warning

@@ -149,7 +152,7 @@ def _nan_rank_data(rdm_vector):
     return rdm_mean


-def _nan_rank_data(rdm_vector):
+def _nan_rank_data(rdm_vector: NDArray) -> NDArray:
     """ rank_data for vectors with nan entries

     Args:
@@ -166,9 +169,14 @@ def _nan_rank_data(rdm_vector):
     return ranks


-def all_tests(evaluations, noise_ceil, test_type='t-test',
-              model_var=None, diff_var=None, noise_ceil_var=None,
-              dof=1):
+def all_tests(
+        evaluations: NDArray,
+        noise_ceil: NDArray,
+        test_type: str = 't-test',
+        model_var: Optional[NDArray] = None,
+        diff_var: Optional[NDArray] = None,
+        noise_ceil_var: Optional[NDArray] = None,
+        dof: int = 1):
     """wrapper running all tests necessary for the model plot
     -> pairwise tests, tests against 0 and against noise ceiling

@@ -218,7 +226,11 @@ def all_tests(evaluations, noise_ceil, test_type='t-test',
     return p_pairwise, p_zero, p_noise


-def pair_tests(evaluations, test_type='t-test', diff_var=None, dof=1):
+def pair_tests(
+        evaluations: NDArray,
+        test_type: str = 't-test',
+        diff_var: Optional[NDArray] = None,
+        dof: int = 1):
     """wrapper running pair tests

     Args:
@@ -499,7 +511,11 @@ def t_test_nc(evaluations, variances, noise_ceil, dof=1):
     return p


-def extract_variances(variance, nc_included=True):
+def extract_variances(
+        variance,
+        nc_included: bool = True,
+        n_rdm: Optional[int] = None,
+        n_pattern: Optional[int] = None):
     """ extracts the variances for the individual model evaluations,
     differences between model evaluations and for the comparison to
     the noise ceiling
@@ -516,6 +532,12 @@ def extract_variances(variance, nc_included=True):
         to the noise ceiling results

         nc_included=False assumes that the noise ceiling is fixed instead.
+
+        To get more accurate small-sample estimates, the number of
+        subjects and/or the number of stimuli can be passed as n_rdm
+        and n_pattern respectively. This function corrects for every
+        n that is passed. If you bootstrapped only one factor, pass
+        only the n for that factor!
     """
     if variance.ndim == 0:
         variance = np.array([variance])
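The new arguments change the returned variances by a factor of n/(n-1) per bootstrapped factor. A small sketch of the effect (not part of the commit; numbers invented, two models, noise ceiling treated as fixed):

import numpy as np
from rsatoolbox.util.inference_util import extract_variances

cov = np.array([[0.04, 0.01],
                [0.01, 0.05]])   # bootstrap covariance of two model evaluations

model_var, diff_var, nc_var = extract_variances(cov, nc_included=False)
model_var_c, diff_var_c, nc_var_c = extract_variances(
    cov, nc_included=False, n_rdm=20)   # rdm bootstrap over 20 subjects

print(model_var)      # [0.04, 0.05]
print(model_var_c)    # scaled by 20/19 -> approx. [0.0421, 0.0526]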
@@ -532,6 +554,9 @@ def extract_variances(variance, nc_included=True):
         model_variances = variance
         nc_variances = np.array([variance, variance]).T
         diff_variances = np.diag(C @ np.diag(variance) @ C.T)
+        model_variances = _correct_1d(model_variances, n_pattern, n_rdm)
+        nc_variances = _correct_1d(nc_variances, n_pattern, n_rdm)
+        diff_variances = _correct_1d(diff_variances, n_pattern, n_rdm)
     elif variance.ndim == 2:
         # a single covariance matrix
         if nc_included:
@@ -546,6 +571,9 @@ def extract_variances(variance, nc_included=True):
         model_variances = np.diag(variance)
         nc_variances = np.array([model_variances, model_variances]).T
         diff_variances = np.diag(C @ variance @ C.T)
+        model_variances = _correct_1d(model_variances, n_pattern, n_rdm)
+        nc_variances = _correct_1d(nc_variances, n_pattern, n_rdm)
+        diff_variances = _correct_1d(diff_variances, n_pattern, n_rdm)
     elif variance.ndim == 3:
         # general transform for multiple covariance matrices
         if nc_included:
@@ -565,12 +593,30 @@ def extract_variances(variance, nc_included=True):
             ).transpose(1, 2, 0)
         diff_variances = np.einsum('ij,kjl,il->ki', C, variance, C)
         # dual bootstrap variance estimate from 3 covariance matrices
-        model_variances = _dual_bootstrap(model_variances)
-        nc_variances = _dual_bootstrap(nc_variances)
-        diff_variances = _dual_bootstrap(diff_variances)
+        model_variances = _dual_bootstrap(model_variances, n_rdm, n_pattern)
+        nc_variances = _dual_bootstrap(nc_variances, n_rdm, n_pattern)
+        diff_variances = _dual_bootstrap(diff_variances, n_rdm, n_pattern)
     return model_variances, diff_variances, nc_variances


+def _correct_1d(
+        variance: NDArray,
+        n_pattern: Optional[int] = None,
+        n_rdm: Optional[int] = None):
+    if (n_pattern is not None) and (n_rdm is not None):
+        # uncorrected dual bootstrap?
+        n = min(n_rdm, n_pattern)
+    elif n_pattern is not None:
+        n = n_pattern
+    elif n_rdm is not None:
+        n = n_rdm
+    else:
+        n = None
+    if n is not None:
+        variance = (n / (n - 1)) * variance
+    return variance
+
+
 def get_errorbars(model_var, evaluations, dof, error_bars='sem',
                   test_type='t-test'):
     """ computes errorbars for the model-evaluations from a results object
@@ -628,20 +674,36 @@ def get_errorbars(model_var, evaluations, dof, error_bars='sem',
     return limits


-def _dual_bootstrap(variances):
+def _dual_bootstrap(variances, n_rdm=None, n_pattern=None):
     """ helper function to perform the dual bootstrap

     Takes a 3x... array of variances and computes the corrections assuming:
     variances[0] are the variances in the double bootstrap
     variances[1] are the variances in the rdm bootstrap
     variances[2] are the variances in the pattern bootstrap
+
+    If both n_rdm and n_pattern are given, this uses
+    the more accurate small-sample formula.
     """
-    variance = 2 * (variances[1] + variances[2]) \
-        - variances[0]
-    variance = np.maximum(np.maximum(
-        variance, variances[1]), variances[2])
-    variance = np.minimum(
-        variance, variances[0])
+    if n_rdm is None or n_pattern is None:
+        variance = 2 * (variances[1] + variances[2]) \
+            - variances[0]
+        variance = np.maximum(np.maximum(
+            variance, variances[1]), variances[2])
+        variance = np.minimum(
+            variance, variances[0])
+    else:
+        variance = (
+            (n_rdm / (n_rdm - 1)) * variances[1]
+            + (n_pattern / (n_pattern - 1)) * variances[2]
+            - ((n_pattern*n_rdm / (n_pattern - 1) / (n_rdm - 1))
+               * (variances[0] - variances[1] - variances[2])))
+        variance = np.maximum(np.maximum(
+            variance,
+            (n_rdm / (n_rdm - 1)) * variances[1]),
+            (n_pattern / (n_pattern - 1)) * variances[2])
+        variance = np.minimum(
+            variance, variances[0])
     return variance

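Numerically, the new branch replaces the large-sample combination 2*(var_rdm + var_pattern) - var_both by a form that inflates each single-factor variance by n/(n-1) and rescales the interaction term, followed by the same clipping as before. A worked sketch with invented numbers (not part of the commit):

import numpy as np

# variances[0]: dual bootstrap, [1]: rdm bootstrap, [2]: pattern bootstrap (illustrative)
variances = np.array([0.10, 0.04, 0.05])
n_rdm, n_pattern = 20, 40

# previous large-sample combination (still used when either n is unknown)
naive = 2 * (variances[1] + variances[2]) - variances[0]          # 0.08

# small-sample two-factor correction introduced in this commit
corrected = ((n_rdm / (n_rdm - 1)) * variances[1]
             + (n_pattern / (n_pattern - 1)) * variances[2]
             - (n_pattern * n_rdm / (n_pattern - 1) / (n_rdm - 1))
             * (variances[0] - variances[1] - variances[2]))       # approx. 0.0826

# as in _dual_bootstrap: floor at the corrected single-factor variances,
# cap at the dual-bootstrap variance
corrected = np.maximum(corrected, (n_rdm / (n_rdm - 1)) * variances[1])
corrected = np.maximum(corrected, (n_pattern / (n_pattern - 1)) * variances[2])
corrected = np.minimum(corrected, variances[0])
print(naive, corrected)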
