3
3
"""
4
4
Inference module utilities
5
5
"""
6
-
6
+ from __future__ import annotations
7
7
from collections .abc import Iterable
8
8
import numpy as np
9
9
from scipy import stats
13
13
from rsatoolbox .rdm import RDMs
14
14
from .matrix import pairwise_contrast
15
15
from .rdm_utils import batch_to_matrices
16
+ from typing import TYPE_CHECKING , Optional
17
+ if TYPE_CHECKING :
18
+ from numpy .typing import NDArray
16
19
17
20
18
21
def input_check_model (models , theta = None , fitter = None , N = 1 ):
@@ -68,7 +71,7 @@ def input_check_model(models, theta=None, fitter=None, N=1):
68
71
return models , evaluations , theta , fitter
69
72
70
73
71
- def pool_rdm (rdms , method = 'cosine' ):
74
+ def pool_rdm (rdms , method : str = 'cosine' ):
72
75
"""pools multiple RDMs into the one with maximal performance under a given
73
76
evaluation metric
74
77
rdm_descriptors of the generated rdms are empty
@@ -130,7 +133,7 @@ def pool_rdm(rdms, method='cosine'):
130
133
pattern_descriptors = rdms .pattern_descriptors )
131
134
132
135
133
- def _nan_mean (rdm_vector ) :
136
+ def _nan_mean (rdm_vector : NDArray ) -> NDArray :
134
137
""" takes the average over a rdm_vector with nans for masked entries
135
138
without a warning
136
139
@@ -149,7 +152,7 @@ def _nan_mean(rdm_vector):
149
152
return rdm_mean
150
153
151
154
152
- def _nan_rank_data (rdm_vector ) :
155
+ def _nan_rank_data (rdm_vector : NDArray ) -> NDArray :
153
156
""" rank_data for vectors with nan entries
154
157
155
158
Args:
@@ -166,9 +169,14 @@ def _nan_rank_data(rdm_vector):
166
169
return ranks
167
170
168
171
169
- def all_tests (evaluations , noise_ceil , test_type = 't-test' ,
170
- model_var = None , diff_var = None , noise_ceil_var = None ,
171
- dof = 1 ):
172
+ def all_tests (
173
+ evaluations : NDArray ,
174
+ noise_ceil : NDArray ,
175
+ test_type : str = 't-test' ,
176
+ model_var : Optional [NDArray ] = None ,
177
+ diff_var : Optional [NDArray ] = None ,
178
+ noise_ceil_var : Optional [NDArray ] = None ,
179
+ dof : int = 1 ):
172
180
"""wrapper running all tests necessary for the model plot
173
181
-> pairwise tests, tests against 0 and against noise ceiling
174
182
@@ -218,7 +226,11 @@ def all_tests(evaluations, noise_ceil, test_type='t-test',
218
226
return p_pairwise , p_zero , p_noise
219
227
220
228
221
- def pair_tests (evaluations , test_type = 't-test' , diff_var = None , dof = 1 ):
229
+ def pair_tests (
230
+ evaluations : NDArray ,
231
+ test_type : str = 't-test' ,
232
+ diff_var : Optional [NDArray ] = None ,
233
+ dof : int = 1 ):
222
234
"""wrapper running pair tests
223
235
224
236
Args:
@@ -499,7 +511,11 @@ def t_test_nc(evaluations, variances, noise_ceil, dof=1):
499
511
return p
500
512
501
513
502
- def extract_variances (variance , nc_included = True ):
514
+ def extract_variances (
515
+ variance ,
516
+ nc_included : bool = True ,
517
+ n_rdm : Optional [int ] = None ,
518
+ n_pattern : Optional [int ] = None ):
503
519
""" extracts the variances for the individual model evaluations,
504
520
differences between model evaluations and for the comparison to
505
521
the noise ceiling
@@ -516,6 +532,12 @@ def extract_variances(variance, nc_included=True):
516
532
to the noise ceiling results
517
533
518
534
nc_included=False assumes that the noise ceiling is fixed instead.
535
+
536
+ To get the more accurate estimates that take into account
537
+ the number of subjects and/or the numbers of stimuli
538
+ can be passed as n_rdm and n_pattern respectively.
539
+ This function corrects for all ns that are passed. If you bootstrapped
540
+ only one factor only pass the N for that factor!
519
541
"""
520
542
if variance .ndim == 0 :
521
543
variance = np .array ([variance ])
@@ -532,6 +554,9 @@ def extract_variances(variance, nc_included=True):
532
554
model_variances = variance
533
555
nc_variances = np .array ([variance , variance ]).T
534
556
diff_variances = np .diag (C @ np .diag (variance ) @ C .T )
557
+ model_variances = _correct_1d (model_variances , n_pattern , n_rdm )
558
+ nc_variances = _correct_1d (nc_variances , n_pattern , n_rdm )
559
+ diff_variances = _correct_1d (diff_variances , n_pattern , n_rdm )
535
560
elif variance .ndim == 2 :
536
561
# a single covariance matrix
537
562
if nc_included :
@@ -546,6 +571,9 @@ def extract_variances(variance, nc_included=True):
546
571
model_variances = np .diag (variance )
547
572
nc_variances = np .array ([model_variances , model_variances ]).T
548
573
diff_variances = np .diag (C @ variance @ C .T )
574
+ model_variances = _correct_1d (model_variances , n_pattern , n_rdm )
575
+ nc_variances = _correct_1d (nc_variances , n_pattern , n_rdm )
576
+ diff_variances = _correct_1d (diff_variances , n_pattern , n_rdm )
549
577
elif variance .ndim == 3 :
550
578
# general transform for multiple covariance matrices
551
579
if nc_included :
@@ -565,12 +593,30 @@ def extract_variances(variance, nc_included=True):
565
593
).transpose (1 , 2 , 0 )
566
594
diff_variances = np .einsum ('ij,kjl,il->ki' , C , variance , C )
567
595
# dual bootstrap variance estimate from 3 covariance matrices
568
- model_variances = _dual_bootstrap (model_variances )
569
- nc_variances = _dual_bootstrap (nc_variances )
570
- diff_variances = _dual_bootstrap (diff_variances )
596
+ model_variances = _dual_bootstrap (model_variances , n_rdm , n_pattern )
597
+ nc_variances = _dual_bootstrap (nc_variances , n_rdm , n_pattern )
598
+ diff_variances = _dual_bootstrap (diff_variances , n_rdm , n_pattern )
571
599
return model_variances , diff_variances , nc_variances
572
600
573
601
602
+ def _correct_1d (
603
+ variance : NDArray ,
604
+ n_pattern : Optional [int ] = None ,
605
+ n_rdm : Optional [int ] = None ):
606
+ if (n_pattern is not None ) and (n_rdm is not None ):
607
+ # uncorrected dual bootstrap?
608
+ n = min (n_rdm , n_pattern )
609
+ elif n_pattern is not None :
610
+ n = n_pattern
611
+ elif n_rdm is not None :
612
+ n = n_rdm
613
+ else :
614
+ n = None
615
+ if n is not None :
616
+ variance = (n / (n - 1 )) * variance
617
+ return variance
618
+
619
+
574
620
def get_errorbars (model_var , evaluations , dof , error_bars = 'sem' ,
575
621
test_type = 't-test' ):
576
622
""" computes errorbars for the model-evaluations from a results object
@@ -628,20 +674,36 @@ def get_errorbars(model_var, evaluations, dof, error_bars='sem',
628
674
return limits
629
675
630
676
631
- def _dual_bootstrap (variances ):
677
+ def _dual_bootstrap (variances , n_rdm = None , n_pattern = None ):
632
678
""" helper function to perform the dual bootstrap
633
679
634
680
Takes a 3x... array of variances and computes the corrections assuming:
635
681
variances[0] are the variances in the double bootstrap
636
682
variances[1] are the variances in the rdm bootstrap
637
683
variances[2] are the variances in the pattern bootstrap
684
+
685
+ If both n_rdm and n_pattern are given this uses
686
+ the more accurate small sample formula.
638
687
"""
639
- variance = 2 * (variances [1 ] + variances [2 ]) \
640
- - variances [0 ]
641
- variance = np .maximum (np .maximum (
642
- variance , variances [1 ]), variances [2 ])
643
- variance = np .minimum (
644
- variance , variances [0 ])
688
+ if n_rdm is None or n_pattern is None :
689
+ variance = 2 * (variances [1 ] + variances [2 ]) \
690
+ - variances [0 ]
691
+ variance = np .maximum (np .maximum (
692
+ variance , variances [1 ]), variances [2 ])
693
+ variance = np .minimum (
694
+ variance , variances [0 ])
695
+ else :
696
+ variance = (
697
+ (n_rdm / (n_rdm - 1 )) * variances [1 ]
698
+ + (n_pattern / (n_pattern - 1 )) * variances [2 ]
699
+ - ((n_pattern * n_rdm / (n_pattern - 1 ) / (n_rdm - 1 ))
700
+ * (variances [0 ] - variances [1 ] - variances [2 ])))
701
+ variance = np .maximum (np .maximum (
702
+ variance ,
703
+ (n_rdm / (n_rdm - 1 )) * variances [1 ]),
704
+ (n_pattern / (n_pattern - 1 )) * variances [2 ])
705
+ variance = np .minimum (
706
+ variance , variances [0 ])
645
707
return variance
646
708
647
709
0 commit comments