3
3
"""
4
4
Inference module utilities
5
5
"""
6
-
6
+ from __future__ import annotations
7
7
from collections .abc import Iterable
8
+ from typing import TYPE_CHECKING , Optional
9
+ import warnings
8
10
import numpy as np
9
11
from scipy import stats
10
12
from scipy .stats import rankdata , wilcoxon
13
15
from rsatoolbox .rdm import RDMs
14
16
from .matrix import pairwise_contrast
15
17
from .rdm_utils import batch_to_matrices
18
+ if TYPE_CHECKING :
19
+ from numpy .typing import NDArray
16
20
17
21
18
22
def input_check_model (models , theta = None , fitter = None , N = 1 ):
@@ -68,7 +72,7 @@ def input_check_model(models, theta=None, fitter=None, N=1):
68
72
return models , evaluations , theta , fitter
69
73
70
74
71
- def pool_rdm (rdms , method = 'cosine' ):
75
+ def pool_rdm (rdms , method : str = 'cosine' ):
72
76
"""pools multiple RDMs into the one with maximal performance under a given
73
77
evaluation metric
74
78
rdm_descriptors of the generated rdms are empty
@@ -114,11 +118,11 @@ def pool_rdm(rdms, method='cosine'):
114
118
rdm_vec = np .array ([_nan_rank_data (v ) for v in rdm_vec ])
115
119
rdm_vec = _nan_mean (rdm_vec )
116
120
elif method in ('kendall' , 'tau-b' ):
117
- Warning ('Noise ceiling for tau based on averaged ranks!' )
121
+ warnings . warn ('Noise ceiling for tau based on averaged ranks!' )
118
122
rdm_vec = np .array ([_nan_rank_data (v ) for v in rdm_vec ])
119
123
rdm_vec = _nan_mean (rdm_vec )
120
124
elif method == 'tau-a' :
121
- Warning ('Noise ceiling for tau based on averaged ranks!' )
125
+ warnings . warn ('Noise ceiling for tau based on averaged ranks!' )
122
126
rdm_vec = np .array ([_nan_rank_data (v ) for v in rdm_vec ])
123
127
rdm_vec = _nan_mean (rdm_vec )
124
128
else :
@@ -130,7 +134,7 @@ def pool_rdm(rdms, method='cosine'):
130
134
pattern_descriptors = rdms .pattern_descriptors )
131
135
132
136
133
- def _nan_mean (rdm_vector ) :
137
+ def _nan_mean (rdm_vector : NDArray ) -> NDArray :
134
138
""" takes the average over a rdm_vector with nans for masked entries
135
139
without a warning
136
140
@@ -149,7 +153,7 @@ def _nan_mean(rdm_vector):
149
153
return rdm_mean
150
154
151
155
152
- def _nan_rank_data (rdm_vector ) :
156
+ def _nan_rank_data (rdm_vector : NDArray ) -> NDArray :
153
157
""" rank_data for vectors with nan entries
154
158
155
159
Args:
@@ -166,9 +170,14 @@ def _nan_rank_data(rdm_vector):
166
170
return ranks
167
171
168
172
169
- def all_tests (evaluations , noise_ceil , test_type = 't-test' ,
170
- model_var = None , diff_var = None , noise_ceil_var = None ,
171
- dof = 1 ):
173
+ def all_tests (
174
+ evaluations : NDArray ,
175
+ noise_ceil : NDArray ,
176
+ test_type : str = 't-test' ,
177
+ model_var : Optional [NDArray ] = None ,
178
+ diff_var : Optional [NDArray ] = None ,
179
+ noise_ceil_var : Optional [NDArray ] = None ,
180
+ dof : int = 1 ):
172
181
"""wrapper running all tests necessary for the model plot
173
182
-> pairwise tests, tests against 0 and against noise ceiling
174
183
@@ -218,7 +227,11 @@ def all_tests(evaluations, noise_ceil, test_type='t-test',
218
227
return p_pairwise , p_zero , p_noise
219
228
220
229
221
- def pair_tests (evaluations , test_type = 't-test' , diff_var = None , dof = 1 ):
230
+ def pair_tests (
231
+ evaluations : NDArray ,
232
+ test_type : str = 't-test' ,
233
+ diff_var : Optional [NDArray ] = None ,
234
+ dof : int = 1 ):
222
235
"""wrapper running pair tests
223
236
224
237
Args:
@@ -393,7 +406,7 @@ def bootstrap_pair_tests(evaluations):
393
406
proportions = np .zeros ((evaluations .shape [1 ], evaluations .shape [1 ]))
394
407
while len (evaluations .shape ) > 2 :
395
408
evaluations = np .mean (evaluations , axis = - 1 )
396
- for i_model in range (evaluations .shape [1 ]- 1 ):
409
+ for i_model in range (evaluations .shape [1 ] - 1 ):
397
410
for j_model in range (i_model + 1 , evaluations .shape [1 ]):
398
411
proportions [i_model , j_model ] = np .sum (
399
412
evaluations [:, i_model ] < evaluations [:, j_model ]) \
@@ -499,7 +512,11 @@ def t_test_nc(evaluations, variances, noise_ceil, dof=1):
499
512
return p
500
513
501
514
502
- def extract_variances (variance , nc_included = True ):
515
+ def extract_variances (
516
+ variance ,
517
+ nc_included : bool = True ,
518
+ n_rdm : Optional [int ] = None ,
519
+ n_pattern : Optional [int ] = None ):
503
520
""" extracts the variances for the individual model evaluations,
504
521
differences between model evaluations and for the comparison to
505
522
the noise ceiling
@@ -516,6 +533,12 @@ def extract_variances(variance, nc_included=True):
516
533
to the noise ceiling results
517
534
518
535
nc_included=False assumes that the noise ceiling is fixed instead.
536
+
537
+ To get more accurate estimates that take the sample sizes into
538
+ account, the number of subjects and/or the number of stimuli
539
+ can be passed as n_rdm and n_pattern respectively.
540
+ This function corrects for all ns that are passed. If you bootstrapped
541
+ only one factor only pass the N for that factor!
519
542
"""
520
543
if variance .ndim == 0 :
521
544
variance = np .array ([variance ])
@@ -532,6 +555,9 @@ def extract_variances(variance, nc_included=True):
532
555
model_variances = variance
533
556
nc_variances = np .array ([variance , variance ]).T
534
557
diff_variances = np .diag (C @ np .diag (variance ) @ C .T )
558
+ model_variances = _correct_1d (model_variances , n_pattern , n_rdm )
559
+ nc_variances = _correct_1d (nc_variances , n_pattern , n_rdm )
560
+ diff_variances = _correct_1d (diff_variances , n_pattern , n_rdm )
535
561
elif variance .ndim == 2 :
536
562
# a single covariance matrix
537
563
if nc_included :
@@ -546,6 +572,9 @@ def extract_variances(variance, nc_included=True):
546
572
model_variances = np .diag (variance )
547
573
nc_variances = np .array ([model_variances , model_variances ]).T
548
574
diff_variances = np .diag (C @ variance @ C .T )
575
+ model_variances = _correct_1d (model_variances , n_pattern , n_rdm )
576
+ nc_variances = _correct_1d (nc_variances , n_pattern , n_rdm )
577
+ diff_variances = _correct_1d (diff_variances , n_pattern , n_rdm )
549
578
elif variance .ndim == 3 :
550
579
# general transform for multiple covariance matrices
551
580
if nc_included :
@@ -565,12 +594,30 @@ def extract_variances(variance, nc_included=True):
565
594
).transpose (1 , 2 , 0 )
566
595
diff_variances = np .einsum ('ij,kjl,il->ki' , C , variance , C )
567
596
# dual bootstrap variance estimate from 3 covariance matrices
568
- model_variances = _dual_bootstrap (model_variances )
569
- nc_variances = _dual_bootstrap (nc_variances )
570
- diff_variances = _dual_bootstrap (diff_variances )
597
+ model_variances = _dual_bootstrap (model_variances , n_rdm , n_pattern )
598
+ nc_variances = _dual_bootstrap (nc_variances , n_rdm , n_pattern )
599
+ diff_variances = _dual_bootstrap (diff_variances , n_rdm , n_pattern )
571
600
return model_variances , diff_variances , nc_variances
572
601
573
602
603
def _correct_1d(
        variance: NDArray,
        n_pattern: Optional[int] = None,
        n_rdm: Optional[int] = None):
    """ rescales bootstrap variances by the small sample factor n / (n - 1)

    If both n_pattern and n_rdm are given, the smaller of the two is used
    as n. If only one is given, that one is used. If neither is given the
    variance is returned unchanged.

    Args:
        variance (numpy.ndarray): variance estimate(s) to rescale
        n_pattern (int, optional): number of patterns, None to ignore
        n_rdm (int, optional): number of rdms, None to ignore

    Returns:
        numpy.ndarray: variance multiplied by n / (n - 1), or the input
            itself when no n was passed

    Raises:
        ValueError: if the n that would be used is not larger than 1,
            because n / (n - 1) is undefined in that case
    """
    passed = [n for n in (n_rdm, n_pattern) if n is not None]
    if not passed:
        return variance
    # with both factors given the smaller n yields the larger correction
    # (the original code marks this case as "uncorrected dual bootstrap?")
    n = min(passed)
    if n <= 1:
        raise ValueError(
            'variance correction requires n > 1, but got n = '
            + str(n) + '.')
    return (n / (n - 1)) * variance
619
+
620
+
574
621
def get_errorbars (model_var , evaluations , dof , error_bars = 'sem' ,
575
622
test_type = 't-test' ):
576
623
""" computes errorbars for the model-evaluations from a results object
@@ -617,31 +664,47 @@ def get_errorbars(model_var, evaluations, dof, error_bars='sem',
617
664
errorbar_high = std_eval \
618
665
* tdist .ppf (prop_cut , dof )
619
666
else :
620
- raise Exception ('computing errorbars: Argument ' +
621
- 'error_bars is incorrectly defined as '
622
- + str (error_bars ) + '.' )
667
+ raise ValueError ('computing errorbars: Argument ' +
668
+ 'error_bars is incorrectly defined as '
669
+ + str (error_bars ) + '.' )
623
670
limits = np .stack ((errorbar_low , errorbar_high ))
624
671
if np .isnan (limits ).any () or (abs (limits ) == np .inf ).any ():
625
- raise Exception (
672
+ raise ValueError (
626
673
'computing errorbars: Too few bootstrap samples for the ' +
627
674
'requested confidence interval: ' + error_bars + '.' )
628
675
return limits
629
676
630
677
631
def _dual_bootstrap(variances, n_rdm=None, n_pattern=None):
    """ helper function to perform the dual bootstrap

    Takes a 3x... array of variances and computes the corrections assuming:
    variances[0] are the variances in the double bootstrap
    variances[1] are the variances in the rdm bootstrap
    variances[2] are the variances in the pattern bootstrap

    If both n_rdm and n_pattern are given this uses
    the more accurate small sample formula.

    Args:
        variances (numpy.ndarray): 3 x ... array of variances, ordered
            as described above
        n_rdm (int, optional): number of rdms; only used if n_pattern
            is given as well
        n_pattern (int, optional): number of patterns; only used if n_rdm
            is given as well

    Returns:
        numpy.ndarray: corrected variance estimate, one entry per element
            of variances[0]

    Raises:
        ValueError: if both n_rdm and n_pattern are given and one of them
            is not larger than 1 (the n / (n - 1) factors are undefined)
    """
    var_both = variances[0]
    var_rdm = variances[1]
    var_pattern = variances[2]
    if n_rdm is None or n_pattern is None:
        # uncorrected formula, the single-factor variances are the floor
        floor_rdm = var_rdm
        floor_pattern = var_pattern
        variance = 2 * (var_rdm + var_pattern) - var_both
    else:
        if n_rdm <= 1 or n_pattern <= 1:
            raise ValueError(
                'dual bootstrap correction requires n_rdm > 1 and '
                + 'n_pattern > 1, but got n_rdm = ' + str(n_rdm)
                + ' and n_pattern = ' + str(n_pattern) + '.')
        # small sample formula, the corrected single-factor variances
        # are the floor
        floor_rdm = (n_rdm / (n_rdm - 1)) * var_rdm
        floor_pattern = (n_pattern / (n_pattern - 1)) * var_pattern
        variance = (
            floor_rdm
            + floor_pattern
            - ((n_pattern * n_rdm / (n_pattern - 1) / (n_rdm - 1))
               * (var_both - var_rdm - var_pattern)))
    # never report less than either single-factor estimate
    variance = np.maximum(np.maximum(variance, floor_rdm), floor_pattern)
    # NOTE(review): the upper bound is the *uncorrected* double bootstrap
    # variance in both branches -- confirm this is intended for the
    # small sample formula.
    variance = np.minimum(variance, var_both)
    return variance
646
709
647
710
0 commit comments