Skip to content

Refactor contact_vector to _probe_* properties #3629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ def to_dict(
dump_dict["annotations"] = {k: self._annotations.get(k, None) for k in self._main_annotations}

if include_properties:
print(self._properties.keys())
dump_dict["properties"] = self._properties
else:
# include only main properties
Expand Down Expand Up @@ -688,7 +689,14 @@ def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path:
)
return file_path

def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=None) -> None:
def dump(
self,
file_path: Union[str, Path, None] = None,
relative_to: Union[str, Path, bool, None] = None,
include_annotations: bool = True,
include_properties: bool = True,
folder_metadata: Union[str, Path, None] = None,
) -> None:
"""
Dumps extractor to json or pickle

Expand All @@ -699,18 +707,37 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No
relative_to: str, Path, True or None
If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder.
This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute.
include_annotations: bool, default: True
If True, all annotations are dumped
include_properties: bool, default: True
If True, all properties are dumped
folder_metadata: str, Path, or None
Folder with files containing additional information (e.g. probe in BaseRecording) and properties
"""
if str(file_path).endswith(".json"):
self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata)
self.dump_to_json(
file_path,
relative_to=relative_to,
include_annotations=include_annotations,
include_properties=include_properties,
folder_metadata=folder_metadata,
)
elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"):
self.dump_to_pickle(file_path, folder_metadata=folder_metadata)
self.dump_to_pickle(
file_path,
include_annotations=include_annotations,
include_properties=include_properties,
folder_metadata=folder_metadata,
)
else:
raise ValueError("Dump: file must .json or .pkl")

def dump_to_json(
self,
file_path: Union[str, Path, None] = None,
relative_to: Union[str, Path, bool, None] = None,
include_annotations: bool = True,
include_properties: bool = True,
folder_metadata: Union[str, Path, None] = None,
) -> None:
"""
Expand All @@ -724,6 +751,10 @@ def dump_to_json(
relative_to: str, Path, True or None
If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder.
This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute.
include_annotations: bool, default: True
If True, all annotations are dumped
include_properties: bool, default: True
If True, all properties are dumped
folder_metadata: str, Path, or None
Folder with files containing additional information (e.g. probe in BaseRecording) and properties
"""
Expand All @@ -735,8 +766,8 @@ def dump_to_json(
relative_to = relative_to.resolve().absolute()

dump_dict = self.to_dict(
include_annotations=True,
include_properties=False,
include_annotations=include_annotations,
include_properties=include_properties,
relative_to=relative_to,
folder_metadata=folder_metadata,
recursive=True,
Expand All @@ -753,6 +784,7 @@ def dump_to_pickle(
file_path: Union[str, Path, None] = None,
relative_to: Union[str, Path, bool, None] = None,
include_properties: bool = True,
include_annotations: bool = True,
folder_metadata: Union[str, Path, None] = None,
):
"""
Expand All @@ -766,8 +798,10 @@ def dump_to_pickle(
relative_to: str, Path, True or None
If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder.
This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute.
include_properties: bool
include_properties: bool, default: True
If True, all properties are dumped
include_annotations: bool, default: True
If True, all annotations are dumped
folder_metadata: str, Path, or None
Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
"""
Expand All @@ -783,7 +817,7 @@ def dump_to_pickle(
recursive = False

dump_dict = self.to_dict(
include_annotations=True,
include_annotations=include_annotations,
include_properties=include_properties,
folder_metadata=folder_metadata,
relative_to=relative_to,
Expand Down Expand Up @@ -1128,6 +1162,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor:

extractor._annotations.update(dic["annotations"])
for k, v in dic["properties"].items():
print(f"Loading property {k}")
extractor.set_property(k, v)

return extractor
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
else:
raise ValueError(f"format {format} not supported")

if self.get_property("contact_vector") is not None:
if self.has_probe():
probegroup = self.get_probegroup()
cached.set_probegroup(probegroup)

Expand Down Expand Up @@ -672,7 +672,7 @@ def _extra_metadata_from_folder(self, folder):

def _extra_metadata_to_folder(self, folder):
# save probe
if self.get_property("contact_vector") is not None:
if self.has_probe():
probegroup = self.get_probegroup()
write_probeinterface(folder / "probe.json", probegroup)

Expand Down
47 changes: 40 additions & 7 deletions src/spikeinterface/core/baserecordingsnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from warnings import warn

_minimal_probe_properties = ["_probe_x", "_probe_y"]


class BaseRecordingSnippets(BaseExtractor):
"""
Expand All @@ -22,6 +24,8 @@ def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype
BaseExtractor.__init__(self, channel_ids)
self._sampling_frequency = float(sampling_frequency)
self._dtype = np.dtype(dtype)
if self.get_property("contact_vector") is not None:
self._set_properties_from_contact_vector(self.get_property("contact_vector"))

@property
def channel_ids(self):
Expand Down Expand Up @@ -62,7 +66,7 @@ def has_scaled(self):
return self.has_scaleable_traces()

def has_probe(self) -> bool:
return "contact_vector" in self.get_property_keys()
return all(prop in self.get_property_keys() for prop in _minimal_probe_properties)

def has_channel_location(self) -> bool:
return self.has_probe() or "location" in self.get_property_keys()
Expand Down Expand Up @@ -196,7 +200,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False
sub_recording = self.select_channels(new_channel_ids)

# create a vector that handle all contacts in property
sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None)
sub_recording._set_probe_numpy_array_to_properties(probe_as_numpy_array)

# planar_contour is saved in annotations
for probe_index, probe in enumerate(probegroup.probes):
Expand Down Expand Up @@ -261,7 +265,7 @@ def get_probes(self):
return probegroup.probes

def get_probegroup(self):
arr = self.get_property("contact_vector")
arr = self._get_probe_numpy_array_from_properties()
if arr is None:
positions = self.get_property("location")
if positions is None:
Expand All @@ -280,6 +284,35 @@ def get_probegroup(self):
probe.set_planar_contour(contour)
return probegroup

def _set_probe_numpy_array_to_properties(self, probe_as_numpy_array):
for key in probe_as_numpy_array.dtype.names:
self.set_property(f"_probe_{key}", probe_as_numpy_array[key], ids=None)

def _get_probe_numpy_array_from_properties(self):
# first construct numpy structured dtype from properties
structured_dtype = []
for key in self.get_property_keys():
if key.startswith("_probe_"):
values = self.get_property(key)
if values.dtype.kind == "SU":
dtype = "U64"
elif values.dtype.kind == "f":
dtype = "float64"
elif values.dtype.kind == "i":
dtype = "int64"
else:
dtype = values.dtype
structured_dtype.append((key[len("_probe_") :], dtype))
arr = np.zeros(self.get_num_channels(), dtype=structured_dtype)
for key in arr.dtype.names:
arr[key] = self.get_property(f"_probe_{key}")
return arr

def _set_properties_from_contact_vector(self, contact_vector):
for key in contact_vector.dtype.names:
self.set_property(f"_probe_{key}", contact_vector[key], ids=None)
self.delete_property("contact_vector")

def _extra_metadata_from_folder(self, folder):
# load probe
folder = Path(folder)
Expand All @@ -289,7 +322,7 @@ def _extra_metadata_from_folder(self, folder):

def _extra_metadata_to_folder(self, folder):
# save probe
if self.get_property("contact_vector") is not None:
if self.has_probe():
probegroup = self.get_probegroup()
write_probeinterface(folder / "probe.json", probegroup)

Expand Down Expand Up @@ -346,17 +379,17 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params
self.set_probe(probe, in_place=True)

def set_channel_locations(self, locations, channel_ids=None):
if self.get_property("contact_vector") is not None:
if self.has_probe():
raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)")
self.set_property("location", locations, ids=channel_ids)

def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray:
if channel_ids is None:
channel_ids = self.get_channel_ids()
channel_indices = self.ids_to_indices(channel_ids)
contact_vector = self.get_property("contact_vector")
if contact_vector is not None:
if self.has_probe():
# here we bypass the probe reconstruction so this works both for probe and probegroup
contact_vector = self._get_probe_numpy_array_from_properties()
ndim = len(axes)
all_positions = np.zeros((contact_vector.size, ndim), dtype="float64")
for i, dim in enumerate(axes):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _save(self, format="npy", **save_kwargs):
else:
raise ValueError(f"format {format} not supported")

if self.get_property("contact_vector") is not None:
if self.has_probe():
probegroup = self.get_probegroup()
cached.set_probegroup(probegroup)

Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/core/channelsaggregationrecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record
break

for prop_name, prop_values in property_dict.items():
if prop_name == "contact_vector":
# remap device channel indices correctly
prop_values["device_channel_indices"] = np.arange(self.get_num_channels())
# remap device channel indices correctly
if prop_name == "_probe_device_channel_indices":
prop_values = np.arange(self.get_num_channels())
self.set_property(key=prop_name, values=prop_values)

# if locations are present, check that they are all different!
Expand Down
12 changes: 4 additions & 8 deletions src/spikeinterface/core/channelslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,8 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None)
self._parent = parent_recording

# change the wiring of the probe
contact_vector = self.get_property("contact_vector")
if contact_vector is not None:
contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64")
self.set_property("contact_vector", contact_vector)
if parent_recording.has_probe():
self.set_property("_probe_device_channel_indices", np.arange(len(channel_ids), dtype="int64"))

# update dump dict
self._kwargs = {
Expand Down Expand Up @@ -155,10 +153,8 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None):
parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids)

# change the wiring of the probe
contact_vector = self.get_property("contact_vector")
if contact_vector is not None:
contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64")
self.set_property("contact_vector", contact_vector)
if parent_snippets.has_probe():
self.set_property("_probe_device_channel_indices", np.arange(len(channel_ids), dtype="int64"))

# update dump dict
self._kwargs = {
Expand Down
2 changes: 0 additions & 2 deletions src/spikeinterface/core/recording_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,8 +1046,6 @@ def get_rec_attributes(recording):
The rec_attributes dictionary
"""
properties_to_attrs = deepcopy(recording._properties)
if "contact_vector" in properties_to_attrs:
del properties_to_attrs["contact_vector"]
rec_attributes = dict(
channel_ids=recording.channel_ids,
sampling_frequency=recording.get_sampling_frequency(),
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/zarrextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def add_recording_to_zarr_group(
)

# save probe
if recording.get_property("contact_vector") is not None:
if recording.has_probe():
probegroup = recording.get_probegroup()
zarr_group.attrs["probe"] = check_json(probegroup.to_dict(array_as_list=True))

Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/neoextractors/biocam.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def __init__(
probe_kwargs["electrode_width"] = electrode_width
probe = probeinterface.read_3brain(file_path, **probe_kwargs)
self.set_probe(probe, in_place=True)
self.set_property("row", self.get_property("contact_vector")["row"])
self.set_property("col", self.get_property("contact_vector")["col"])
self.set_property("row", self.get_property("_probe_row"))
self.set_property("col", self.get_property("_probe_col"))

self._kwargs.update(
{
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/neoextractors/maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
rec_name = self.neo_reader.rec_name
probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name)
self.set_probe(probe, in_place=True)
self.set_property("electrode", self.get_property("contact_vector")["electrode"])
self.set_property("electrode", self.get_property("_probe_electrode"))
self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name))

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/tests/test_iblextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def test_property_keys(self):
expected_property_keys = [
"gain_to_uV",
"offset_to_uV",
"contact_vector",
"location",
"group",
"shank",
Expand All @@ -94,7 +93,8 @@ def test_property_keys(self):
"adc",
"index_on_probe",
]
self.assertCountEqual(first=self.recording.get_property_keys(), second=expected_property_keys)
recording_properties = [k for k in self.recording.get_property_keys() if not k.startswith("_probe")]
self.assertCountEqual(first=recording_properties, second=expected_property_keys)

def test_trace_shape(self):
expected_shape = (21, 384)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan

# distribute default probe locations across 4 shanks if set
x = np.random.choice(shanks, num_channels)
for idx, __ in enumerate(recording._properties["contact_vector"]):
recording._properties["contact_vector"][idx][1] = x[idx]
recording._properties["_probe_x"] = x

# generate random bad channel locations
bad_channel_indexes = np.random.choice(num_channels, np.random.randint(1, int(num_channels / 5)), replace=False)
Expand Down Expand Up @@ -137,14 +136,15 @@ def test_output_values():
bad_channel_indexes = np.array([0])
bad_channel_ids = recording.channel_ids[bad_channel_indexes]

new_probe_locs = [
[5, 7, 3, 5, 5], # 5 channels, a in the center ('bad channel', zero index)
[5, 5, 5, 7, 3],
] # all others equal distance away.
new_probe_locs = np.array(
[
[5, 7, 3, 5, 5], # 5 channels, a in the center ('bad channel', zero index)
[5, 5, 5, 7, 3],
]
).T # all others equal distance away.
# Overwrite the probe information with the new locations
for idx, (x, y) in enumerate(zip(*new_probe_locs)):
recording._properties["contact_vector"][idx][1] = x
recording._properties["contact_vector"][idx][2] = y
recording._properties["_probe_x"] = new_probe_locs[:, 0]
recording._properties["_probe_y"] = new_probe_locs[:, 1]

# Run interpolation in SI and check the interpolated channel
# 0 is a linear combination of other channels
Expand All @@ -158,8 +158,8 @@ def test_output_values():
# Shift the last channel position so that it is 4 units, rather than 2
# away. Setting sigma_um = p = 1 allows easy calculation of the expected
# weights.
recording._properties["contact_vector"][-1][1] = 5
recording._properties["contact_vector"][-1][2] = 9
recording._properties["_probe_x"][-1] = 5
recording._properties["_probe_y"][-1] = 9
expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)]
expected_weights /= np.sum(expected_weights)

Expand Down
Loading
Loading