Skip to content

Add multi-segment capability to BaseRasterWidget and children #3805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

jakeswann1
Copy link
Contributor

@jakeswann1 jakeswann1 commented Mar 25, 2025

Adds the option to pass a list of segment indices to the AmplitudesWidget, DriftRasterMapWidget, and RasterWidget to plot across multiple segments, by updating how the base widget handles plotting data. Maintains current default behaviour and SortingView capability. resolves #3801

@jakeswann1
Copy link
Contributor Author

Multi-segment plots would look like this:

image
image
image

@zm711 zm711 added the widgets Related to widgets module label Mar 26, 2025
unit_ids : array-like | None, default: None
List of unit_ids to plot
segment_index : int | list | None, default: None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would simplify this by using segment_indices and only allow list | None. Two advantages: 1) as a user it's not obvious that segment_index can take lists since it's singular. 2) it simplifies the code for checking that the segments exist lower down.

# Multiple segments specified
for idx in segment_index:
if idx not in available_segments:
raise ValueError(f"segment_index {idx} not found in data")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to display the available segments. Something like

raise ValueError(f"segment_index {idx} not found in available segments {available_segments}")

# Calculate total duration across all segments for x-axis limits
total_duration = 0
for idx in segment_indices:
duration = sorting_analyzer.get_num_samples(idx) / sorting_analyzer.sampling_frequency
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest we make a durations list here, and then update BaseRasterWidget to take a list of durations.

unit_ids : array-like | None, default: None
List of unit_ids to plot
segment_index : int | list | None, default: None
For multi-segment data, specifies which segment(s) to plot. If None, uses all available segments.
For single-segment data, this parameter is ignored.
total_duration : int | None, default: None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest replace total_duration with durations which takes a list of durations per segment

all_units.update(spike_train_data[seg_idx].keys())
unit_ids = list(all_units)

# Calculate segment durations and boundaries
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we used durations in args, we can delete this computation

cumulative_durations.append(cumulative_durations[-1] + duration)

# Segment boundaries for visualization (only internal boundaries)
segment_boundaries = cumulative_durations[1:] if len(segments_to_use) > 1 else None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a numpy funciton for this. Something like

segment_boundaries = np.cumsum(segment_durations)
cumulative_durations = np.concatenate([[0], segment_boundaries])

should do the job!

spike_times = spike_train_data[seg_idx][unit_id]

# Adjust spike times by adding cumulative duration of previous segments
if offset > 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the only case where offset <= 0 when it is equal to 0? If so, it's ok to do adjusted_times = spike_times + 0. Then we can replace these four lines with just

adjusted_times = spike_times + offset

y_values = y_axis_data[seg_idx][unit_id]

# Concatenate with any existing data
if len(concatenated_spike_trains[unit_id]) > 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's ok to concatenate a full array to an empty list. So I think the stuff under else is redundant, and you can remove the if statement

concatenated_spike_trains = {unit_id: [] for unit_id in unit_ids}
concatenated_y_axis = {unit_id: [] for unit_id in unit_ids}

for i, seg_idx in enumerate(segments_to_use):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can simplify some of the gross indexing in the for loop by using some zipping here. Try a for loop that looks like:

for offset, spike_train_segment, y_axis_segment in zip(cumulative_durations, spike_train_data.values(), y_axis_data.values()):
    for unit_id, spike_times in spike_train_segment.items():
        y_values = y_axis_segment[unit_id]

Should make the code more readable!


# Update spike train and y-axis data with concatenated values
processed_spike_train_data = concatenated_spike_trains
processed_y_axis_data = concatenated_y_axis
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can pass concatenated_spike_trains and concatenated_y_axis directly to plot_data

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
widgets Related to widgets module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Multi-segment support for AmplitudesWidget
3 participants