Skip to content

Commit b802f17

Browse files
Merge pull request #317 from rsagroup/fix_print_result
fixing print for crossvalidation
2 parents 42f1d79 + e0f6ed9 commit b802f17

File tree

5 files changed

+80
-197
lines changed

5 files changed

+80
-197
lines changed

demos/demo_bootstrap.ipynb

+31-171
Large diffs are not rendered by default.

src/rsatoolbox/inference/result.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,16 @@ def summary(self, test_type='t-test'):
9696
name_length = max([max(len(m.name) for m in self.models) + 1, 6])
9797
means = self.get_means()
9898
sems = self.get_sem()
99-
p_zero = self.test_zero(test_type=test_type)
100-
p_noise = self.test_noise(test_type=test_type)
99+
if means is None:
100+
means = np.nan * np.ones(self.n_model)
101+
if sems is None:
102+
sems = np.nan * np.ones(self.n_model)
103+
try:
104+
p_zero = self.test_zero(test_type=test_type)
105+
p_noise = self.test_noise(test_type=test_type)
106+
except ValueError:
107+
p_zero = np.nan * np.ones(self.n_model)
108+
p_noise = np.nan * np.ones(self.n_model)
101109
# header of the results table
102110
summary += 'Model' + (' ' * (name_length - 5))
103111
summary += '| Eval \u00B1 SEM |'
@@ -118,7 +126,9 @@ def summary(self, test_type='t-test'):
118126
summary += f'{p_noise[i]:>14.3f} |'
119127
summary += '\n'
120128
summary += '\n'
121-
if test_type == 't-test':
129+
if self.cv_method == 'crossvalidation':
130+
summary += 'No p-values available as crossvalidation provides no variance estimate'
131+
elif test_type == 't-test':
122132
summary += 'p-values are based on uncorrected t-tests'
123133
elif test_type == 'bootstrap':
124134
summary += 'p-values are based on percentiles of the bootstrap samples'
@@ -250,11 +260,11 @@ def get_errorbars(self, eb_type='sem', test_type='t-test'):
250260
ci_percent = float(eb_type[2:]) / 100
251261
ci = self.get_ci(ci_percent, test_type)
252262
means = self.get_means()
253-
errorbar_low = -(ci[0] - means)
254-
errorbar_high = (ci[1] - means)
263+
errorbar_low = means - ci[0]
264+
errorbar_high = ci[1] - means
255265
limits = np.concatenate((errorbar_low, errorbar_high))
256266
if np.isnan(limits).any() or (abs(limits) == np.inf).any():
257-
raise Exception(
267+
raise ValueError(
258268
'plot_model_comparison: Too few bootstrap samples for ' +
259269
'the requested confidence interval: ' + eb_type + '.')
260270
return (errorbar_low, errorbar_high)

src/rsatoolbox/vis/model_plot.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,10 @@ def plot_model_comparison(result, sort=False, colors=None,
264264
models = [models[i] for i in idx]
265265
if not ('descend' in sort.lower() or
266266
'ascend' in sort.lower()):
267-
raise Exception('plot_model_comparison: Argument ' +
268-
'sort is incorrectly defined as '
269-
+ sort + '.')
267+
raise ValueError(
268+
'plot_model_comparison: Argument ' +
269+
'sort is incorrectly defined as ' +
270+
sort + '.')
270271

271272
# run tests
272273
if any([test_pair_comparisons,
@@ -290,9 +291,10 @@ def plot_model_comparison(result, sort=False, colors=None,
290291
elif 'nili' in test_pair_comparisons.lower():
291292
h_pair_tests = 0.4
292293
else:
293-
raise Exception('plot_model_comparison: Argument ' +
294-
'test_pair_comparisons is incorrectly defined as '
295-
+ test_pair_comparisons + '.')
294+
raise ValueError(
295+
'plot_model_comparison: Argument ' +
296+
'test_pair_comparisons is incorrectly defined as ' +
297+
test_pair_comparisons + '.')
296298
ax = plt.axes((l, b, w, h*(1-h_pair_tests)))
297299
axbar = plt.axes((l, b + h * (1 - h_pair_tests), w,
298300
h * h_pair_tests * 0.7))
@@ -360,7 +362,7 @@ def plot_model_comparison(result, sort=False, colors=None,
360362
marker=10, markersize=half_sym_size,
361363
linewidth=0)
362364
else:
363-
raise Exception(
365+
raise ValueError(
364366
'plot_model_comparison: Argument test_above_0' +
365367
' is incorrectly defined as ' + test_above_0 + '.')
366368

@@ -397,7 +399,7 @@ def plot_model_comparison(result, sort=False, colors=None,
397399
markerfacecolor=noise_ceil_col,
398400
markeredgecolor='none', linewidth=0)
399401
else:
400-
raise Exception(
402+
raise ValueError(
401403
'plot_model_comparison: Argument ' +
402404
'test_below_noise_ceil is incorrectly defined as ' +
403405
test_below_noise_ceil + '.')
@@ -429,7 +431,7 @@ def plot_model_comparison(result, sort=False, colors=None,
429431
significant = p_pairwise < crit
430432
else:
431433
if 'uncorrected' not in multiple_pair_testing.lower():
432-
raise Exception(
434+
raise ValueError(
433435
'plot_model_comparison: Argument ' +
434436
'multiple_pair_testing is incorrectly defined as ' +
435437
multiple_pair_testing + '.')
@@ -793,7 +795,11 @@ def plot_arrows(axbar, significant):
793795
k += 1
794796
axbar.plot((i, j), (k, k), 'k-', linewidth=2)
795797
occupied[k-1, i*3+2:j*3+1] = 1
796-
h = occupied.sum(axis=1).nonzero()[0].max()+1
798+
h = occupied.sum(axis=1)
799+
if np.any(h > 0):
800+
h = h.nonzero()[0].max()+1
801+
else:
802+
h = 1
797803
axbar.set_ylim((0, max(expected_n_lines, h)))
798804

799805

@@ -905,7 +911,7 @@ def _get_model_comp_descr(test_type, n_models, multiple_pair_testing, alpha,
905911
' model-pair comparisons)')
906912
else:
907913
if 'uncorrected' not in multiple_pair_testing.lower():
908-
raise Exception(
914+
raise ValueError(
909915
'plot_model_comparison: Argument ' +
910916
'multiple_pair_testing is incorrectly defined as ' +
911917
multiple_pair_testing + '.')

tests/test_crossval.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,19 @@ def test_crossval(self):
3737
from rsatoolbox.inference import crossval
3838
rdms = self.rdms
3939
m = self.m
40-
train_set = [(rdms.subset_pattern('type', [0, 1]), np.array([0, 1])),
41-
(rdms.subset_pattern('type', [0, 4]), np.array([0, 4])),
40+
train_set = [(rdms.subset_pattern('type', [0, 1, 2]), np.array([0, 1, 2])),
41+
(rdms.subset_pattern('type', [3, 4, 5]), np.array([3, 4, 5])),
4242
]
43-
test_set = [(rdms.subset_pattern('type', [2, 4]), np.array([2, 4])),
44-
(rdms.subset_pattern('type', [1, 2]), np.array([1, 2])),
43+
test_set = [(rdms.subset_pattern('type', [3, 4, 5]), np.array([3, 4, 5])),
44+
(rdms.subset_pattern('type', [0, 1, 2]), np.array([0, 1, 2])),
4545
]
46-
ceil_set = [(rdms.subset_pattern('type', [2, 4]), np.array([2, 4])),
47-
(rdms.subset_pattern('type', [1, 2]), np.array([1, 2])),
46+
ceil_set = [(rdms.subset_pattern('type', [3, 4, 5]), np.array([3, 4, 5])),
47+
(rdms.subset_pattern('type', [0, 1, 2]), np.array([0, 1, 2])),
4848
]
49-
crossval(m, rdms, train_set, test_set, ceil_set,
50-
pattern_descriptor='type')
49+
res = crossval(
50+
m, rdms, train_set, test_set, ceil_set,
51+
pattern_descriptor='type')
52+
print(res)
5153

5254
def test_bootstrap_crossval(self):
5355
from rsatoolbox.inference import bootstrap_crossval

tests/test_demo.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,12 @@ def test_exercise_all(self):
292292
models_flex, rdms_data, train_set, test_set,
293293
ceil_set=ceil_set, method='corr')
294294
# plot results
295-
rsatoolbox.vis.plot_model_comparison(results_3_cv)
295+
rsatoolbox.vis.plot_model_comparison(
296+
results_3_cv,
297+
error_bars=False,
298+
test_pair_comparisons=False,
299+
test_above_0=False,
300+
test_below_noise_ceil=False)
296301

297302
results_3_full = rsatoolbox.inference.bootstrap_crossval(
298303
models_flex, rdms_data, k_pattern=4, k_rdm=2, method='corr', N=5)

0 commit comments

Comments
 (0)