-
Notifications
You must be signed in to change notification settings - Fork 209
Add multi-segment capability to BaseRasterWidget and children #3805
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
unit_ids : array-like | None, default: None | ||
List of unit_ids to plot | ||
segment_index : int | list | None, default: None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would simplify this by using segment_indices
and only allow list | None
. Two advantages: 1) as a user it's not obvious that segment_index
can take lists since it's singular. 2) it simplifies the code for checking that the segments exist lower down.
# Multiple segments specified | ||
for idx in segment_index: | ||
if idx not in available_segments: | ||
raise ValueError(f"segment_index {idx} not found in data") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be good to display the available segments. Something like
raise ValueError(f"segment_index {idx} not found in available segments {available_segments}")
# Calculate total duration across all segments for x-axis limits | ||
total_duration = 0 | ||
for idx in segment_indices: | ||
duration = sorting_analyzer.get_num_samples(idx) / sorting_analyzer.sampling_frequency |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd suggest we make a durations
list here, and then update BaseRasterWidget
to take a list of durations.
unit_ids : array-like | None, default: None | ||
List of unit_ids to plot | ||
segment_index : int | list | None, default: None | ||
For multi-segment data, specifies which segment(s) to plot. If None, uses all available segments. | ||
For single-segment data, this parameter is ignored. | ||
total_duration : int | None, default: None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggest replace total_duration
with durations
which takes a list of durations per segment
all_units.update(spike_train_data[seg_idx].keys()) | ||
unit_ids = list(all_units) | ||
|
||
# Calculate segment durations and boundaries |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we used durations
in args, we can delete this computation
cumulative_durations.append(cumulative_durations[-1] + duration) | ||
|
||
# Segment boundaries for visualization (only internal boundaries) | ||
segment_boundaries = cumulative_durations[1:] if len(segments_to_use) > 1 else None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a numpy funciton for this. Something like
segment_boundaries = np.cumsum(segment_durations)
cumulative_durations = np.concatenate([[0], segment_boundaries])
should do the job!
spike_times = spike_train_data[seg_idx][unit_id] | ||
|
||
# Adjust spike times by adding cumulative duration of previous segments | ||
if offset > 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the only case where offset <= 0
when it is equal to 0
? If so, it's ok to do adjusted_times = spike_times + 0
. Then we can replace these four lines with just
adjusted_times = spike_times + offset
y_values = y_axis_data[seg_idx][unit_id] | ||
|
||
# Concatenate with any existing data | ||
if len(concatenated_spike_trains[unit_id]) > 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's ok to concatenate a full array to an empty list. So I think the stuff under else
is redundant, and you can remove the if
statement
concatenated_spike_trains = {unit_id: [] for unit_id in unit_ids} | ||
concatenated_y_axis = {unit_id: [] for unit_id in unit_ids} | ||
|
||
for i, seg_idx in enumerate(segments_to_use): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can simplify some of the gross indexing in the for loop by using some zip
ping here. Try a for
loop that looks like:
for offset, spike_train_segment, y_axis_segment in zip(cumulative_durations, spike_train_data.values(), y_axis_data.values()):
for unit_id, spike_times in spike_train_segment.items():
y_values = y_axis_segment[unit_id]
Should make the code more readable!
|
||
# Update spike train and y-axis data with concatenated values | ||
processed_spike_train_data = concatenated_spike_trains | ||
processed_y_axis_data = concatenated_y_axis |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can pass concatenated_spike_trains
and concatenated_y_axis
directly to plot_data
Adds the option to pass a list of segment indices to the
AmplitudesWidget
,DriftRasterMapWidget
, andRasterWidget
to plot across multiple segments, by updating how the base widget handles plotting data. Maintains current default behaviour and SortingView capability. resolves #3801