Skip to content

Commit 79beab9

Browse files
timecourse plot: style fixes
1 parent e2703a8 commit 79beab9

File tree

1 file changed

+81
-51
lines changed

1 file changed

+81
-51
lines changed

src/rsatoolbox/vis/timecourse.py

+81-51
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""Lineplot of dissimilarity over time
2+
3+
See demo_meg_mne for an example.
4+
"""
15
from __future__ import annotations
26
from typing import TYPE_CHECKING, Tuple, List, Optional
37
import matplotlib.pyplot as plt
@@ -6,93 +10,110 @@
610
from rsatoolbox.rdm.rdms import RDMs
711
from matplotlib.axes._axes import Axes
812
from matplotlib.figure import Figure
13+
from numpy.typing import NDArray
914

1015

1116
def plot_timecourse(
12-
rdms_data: RDMs,
13-
descriptor: str,
14-
n_t_display:int = 20, #
15-
fig_width: Optional[int] = None,
16-
timecourse_plot_rel_height: Optional[int] = None,
17+
rdms_data: RDMs,
18+
descriptor: str,
19+
n_t_display:int = 20,
20+
fig_width: Optional[int] = None,
21+
timecourse_plot_rel_height: Optional[int] = None,
1722
time_formatted: Optional[List[str]] = None,
1823
colored_conditions: Optional[list] = None,
1924
plot_individual_dissimilarities: Optional[bool] = None,
20-
) -> Tuple[Figure, List[Axes]]:
25+
) -> Tuple[Figure, List[Axes]]:
2126
""" plots the RDM movie for a given descriptor
2227
2328
Args:
2429
rdms_data (rsatoolbox.rdm.RDMs): rdm movie
2530
descriptor (str): name of the descriptor that created the rdm movie
2631
n_t_display (int, optional): number of RDM time points to display. Defaults to 20.
2732
fig_width (int, optional): width of the figure (in inches). Defaults to None.
28-
timecourse_plot_rel_height (int, optional): height of the timecourse plot (relative to the rdm movie row).
29-
time_formatted (List[str], optional): time points formatted as strings.
30-
Defaults to None (i.e., rdms_data.time_descriptors['time'] is considered to be in seconds)
31-
colored_condiitons (list, optional): vector of pattern condition names to dissimilarities according to a categorical model on colored_conditions Defaults to None.
32-
plot_individual_dissimilarities (bool, optional): whether to plot the individual dissimilarities. Defaults to None (i.e., False if colored_conditions is notNone, True otherwise).
33+
timecourse_plot_rel_height (int, optional): height of the timecourse plot (relative to
34+
the rdm movie row).
35+
time_formatted (List[str], optional): time points formatted as strings.
36+
Defaults to None (i.e., rdms_data.time_descriptors['time'] is considered to
37+
be in seconds)
38+
colored_condiitons (list, optional): vector of pattern condition names to dissimilarities
39+
according to a categorical model on colored_conditions Defaults to None.
40+
plot_individual_dissimilarities (bool, optional): whether to plot the individual
41+
dissimilarities. Defaults to None (i.e., False if colored_conditions is not
42+
None, True otherwise).
3343
3444
Returns:
3545
Tuple[matplotlib.figure.Figure, npt.ArrayLike, collections.defaultdict]:
36-
46+
3747
Tuple of
3848
- Handle to created figure
3949
- Subplot axis handles from plt.subplots.
4050
"""
4151
# create labels
4252
time = rdms_data.rdm_descriptors['time']
4353
unique_time = np.unique(time)
44-
time_formatted = time_formatted or ['%0.0f ms' % (np.round(x*1000,2)) for x in unique_time]
45-
46-
n_dissimilarity_elements = rdms_data.dissimilarities.shape[1]
47-
54+
time_formatted = time_formatted or [f'{np.round(x*1000,2):0.0f} ms' for x in unique_time]
55+
56+
n_dissimilarity_elements = rdms_data.dissimilarities.shape[1]
57+
4858
# color mapping from colored conditions
49-
unsquareform = lambda a: a[np.nonzero(np.triu(a, k=1))]
50-
if colored_conditions is not None:
51-
plot_individual_dissimilarities = False if plot_individual_dissimilarities is None else plot_individual_dissimilarities
52-
unsquare_idx = np.triu_indices(n_dissimilarity_elements, k=1)
53-
pairwise_conds = unsquareform(np.array([[{c1, c2} for c1 in colored_conditions] for c2 in colored_conditions]))
59+
if colored_conditions is not None:
60+
if plot_individual_dissimilarities is None:
61+
plot_individual_dissimilarities = False
62+
sf_conds = [[{c1, c2} for c1 in colored_conditions] for c2 in colored_conditions]
63+
pairwise_conds = unsquareform(np.array(sf_conds))
5464
pairwise_conds_unique = np.unique(pairwise_conds)
55-
cnames = np.unique(colored_conditions)
56-
color_index = {f'{list(x)[0]} vs {list(x)[1]}' if len(list(x))==2 else f'{list(x)[0]} vs {list(x)[0]}': pairwise_conds==x for x in pairwise_conds_unique}
65+
color_index = {}
66+
for x in pairwise_conds_unique:
67+
if len(list(x))==2:
68+
key = f'{list(x)[0]} vs {list(x)[1]}'
69+
else:
70+
key = f'{list(x)[0]} vs {list(x)[0]}'
71+
color_index[key] = pairwise_conds==x
5772
else:
5873
color_index = {'': np.array([True]*n_dissimilarity_elements)}
5974
plot_individual_dissimilarities = True
60-
75+
6176
colors = plt.get_cmap('turbo')(np.linspace(0, 1, len(color_index)+1))
62-
77+
6378
# how many rdms to display
64-
t_display_idx = (np.round(np.linspace(0, len(unique_time)-1, min(len(unique_time), n_t_display)))).astype(int)
79+
n_times = len(unique_time)
80+
t_display_idx = (np.round(np.linspace(0, n_times-1, min(n_times, n_t_display)))).astype(int)
6581
t_display_idx = np.unique(t_display_idx)
6682
n_t_display = len(t_display_idx)
67-
83+
6884
# auto determine relative sizes of axis
6985
timecourse_plot_rel_height = timecourse_plot_rel_height or n_t_display // 3
70-
base_size = 40 / n_t_display if fig_width is None else fig_width / n_t_display
71-
86+
base_size = 40 / n_t_display if fig_width is None else fig_width / n_t_display
87+
7288
# figure layout
73-
fig = plt.figure(constrained_layout=True, figsize=(base_size*n_t_display,base_size*timecourse_plot_rel_height))
89+
fig = plt.figure(
90+
constrained_layout=True,
91+
figsize=(base_size*n_t_display,base_size*timecourse_plot_rel_height)
92+
)
7493
gs = fig.add_gridspec(timecourse_plot_rel_height+1, n_t_display)
7594
tc_ax = fig.add_subplot(gs[:-1,:])
7695
rdm_axes = [fig.add_subplot(gs[-1,i]) for i in range(n_t_display)]
7796

78-
# plot dissimilarity timecourses
79-
97+
# plot dissimilarity timecourses
8098
dissimilarities_mean = np.zeros((rdms_data.dissimilarities.shape[1], len(unique_time)))
81-
8299
for i, t in enumerate(unique_time):
83100
dissimilarities_mean[:, i] = np.mean(rdms_data.dissimilarities[t == time, :], axis=0)
84-
85-
def _plot_mean_dissimilarities(labels=False):
101+
102+
def _plot_mean_dissimilarities(labels=False):
86103
for i, (pairwise_name, idx) in enumerate(color_index.items()):
87104
mn = np.mean(dissimilarities_mean[idx, :],axis=0)
88-
se = np.std(dissimilarities_mean[idx, :],axis=0)/ np.sqrt(dissimilarities_mean.shape[0]) # se is over dissimilarities, not over subjects
105+
n = np.sqrt(dissimilarities_mean.shape[0])
106+
# se is over dissimilarities, not over subjects
107+
se = np.std(dissimilarities_mean[idx, :],axis=0)/n
89108
tc_ax.fill_between(unique_time, mn-se, mn+se, color=colors[i], alpha=.3)
90-
tc_ax.plot(unique_time, mn, color=colors[i], linewidth=2, label=pairwise_name if labels else None)
91-
109+
label = pairwise_name if labels else None
110+
tc_ax.plot(unique_time, mn, color=colors[i], linewidth=2, label=label)
111+
92112
def _plot_individual_dissimilarities():
93-
for i, (pairwise_name, idx) in enumerate(color_index.items()):
94-
tc_ax.plot(unique_time, dissimilarities_mean[idx, :].T, color=colors[i], alpha=max(1/255., 1/n_dissimilarity_elements))
95-
113+
for i, (_, idx) in enumerate(color_index.items()):
114+
a = max(1/255., 1/n_dissimilarity_elements)
115+
tc_ax.plot(unique_time, dissimilarities_mean[idx, :].T, color=colors[i], alpha=a)
116+
96117
if plot_individual_dissimilarities:
97118
if colored_conditions is not None:
98119
_plot_mean_dissimilarities()
@@ -101,29 +122,38 @@ def _plot_individual_dissimilarities():
101122
tc_ax.set_ylim(yl)
102123
else:
103124
_plot_individual_dissimilarities()
104-
125+
105126
if colored_conditions is not None:
106127
_plot_mean_dissimilarities(True)
107-
128+
108129
yl = tc_ax.get_ylim()
109130
for t in unique_time[t_display_idx]:
110131
tc_ax.plot([t,t], yl, linestyle=':', color='b', alpha=0.3)
111132
tc_ax.set_ylabel(f'Dissimilarity\n({rdms_data.dissimilarity_measure})')
112133
tc_ax.set_xticks(unique_time)
113-
tc_ax.set_xticklabels([time_formatted[idx] if idx in t_display_idx else '' for idx in range(len(unique_time))])
134+
tc_ax.set_xticklabels([
135+
time_formatted[idx] if idx in t_display_idx else '' for idx in range(len(unique_time))
136+
])
114137
dt = np.diff(unique_time[t_display_idx])[0]
115138
tc_ax.set_xlim(unique_time[t_display_idx[0]]-dt/2, unique_time[t_display_idx[-1]]+dt/2)
116139

117140
tc_ax.legend()
118-
141+
119142
# display (selected) rdms
120143
vmax = np.std(rdms_data.dissimilarities) * 2
121-
for i, (tidx, a) in enumerate(zip(t_display_idx, rdm_axes)):
122-
a.imshow(np.mean(rdms_data.subset('time', unique_time[tidx]).get_matrices(),axis=0), vmin=0, vmax=vmax);
123-
a.set_title('%0.0f ms' % (np.round(unique_time[tidx]*1000,2)))
144+
for i, (tidx, a) in enumerate(zip(t_display_idx, rdm_axes)):
145+
mean_dissim = np.mean(rdms_data.subset('time', unique_time[tidx]).get_matrices(),axis=0)
146+
a.imshow(mean_dissim, vmin=0, vmax=vmax)
147+
a.set_title(f'{np.round(unique_time[tidx]*1000,2):0.0f} ms')
124148
a.set_yticklabels([])
125-
a.set_yticks([])
149+
a.set_yticks([])
126150
a.set_xticklabels([])
127-
a.set_xticks([])
128-
151+
a.set_xticks([])
152+
129153
return fig, [tc_ax] + rdm_axes
154+
155+
156+
def unsquareform(a: NDArray) -> NDArray:
157+
"""Helper function; convert squareform to vector
158+
"""
159+
return a[np.nonzero(np.triu(a, k=1))]

0 commit comments

Comments
 (0)