diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 107c2ec4c0..48a6989c69 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -688,7 +688,14 @@ def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path:
             )
         return file_path
 
-    def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=None) -> None:
+    def dump(
+        self,
+        file_path: Union[str, Path, None] = None,
+        relative_to: Union[str, Path, bool, None] = None,
+        include_annotations: bool = True,
+        include_properties: bool = True,
+        folder_metadata: Union[str, Path, None] = None,
+    ) -> None:
         """
         Dumps extractor to json or pickle
 
@@ -699,11 +706,28 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No
         relative_to: str, Path, True or None
             If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder.
             This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute.
+        include_annotations: bool, default: True
+            If True, all annotations are dumped
+        include_properties: bool, default: True
+            If True, all properties are dumped
+        folder_metadata: str, Path, or None
+            Folder with files containing additional information (e.g. probe in BaseRecording) and properties
         """
         if str(file_path).endswith(".json"):
-            self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata)
+            self.dump_to_json(
+                file_path,
+                relative_to=relative_to,
+                include_annotations=include_annotations,
+                include_properties=include_properties,
+                folder_metadata=folder_metadata,
+            )
         elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"):
-            self.dump_to_pickle(file_path, folder_metadata=folder_metadata)
+            self.dump_to_pickle(
+                file_path,
+                include_annotations=include_annotations,
+                include_properties=include_properties,
+                folder_metadata=folder_metadata,
+            )
         else:
             raise ValueError("Dump: file must .json or .pkl")
 
@@ -711,6 +735,8 @@ def dump_to_json(
         self,
         file_path: Union[str, Path, None] = None,
         relative_to: Union[str, Path, bool, None] = None,
+        include_annotations: bool = True,
+        include_properties: bool = True,
         folder_metadata: Union[str, Path, None] = None,
     ) -> None:
         """
@@ -724,6 +750,10 @@ def dump_to_json(
         relative_to: str, Path, True or None
             If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder.
             This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute.
+        include_annotations: bool, default: True
+            If True, all annotations are dumped
+        include_properties: bool, default: True
+            If True, all properties are dumped
         folder_metadata: str, Path, or None
             Folder with files containing additional information (e.g. probe in BaseRecording) and properties
         """
@@ -735,8 +765,8 @@ def dump_to_json(
             relative_to = relative_to.resolve().absolute()
 
         dump_dict = self.to_dict(
-            include_annotations=True,
-            include_properties=False,
+            include_annotations=include_annotations,
+            include_properties=include_properties,
             relative_to=relative_to,
             folder_metadata=folder_metadata,
             recursive=True,
@@ -753,6 +783,7 @@ def dump_to_pickle(
         file_path: Union[str, Path, None] = None,
         relative_to: Union[str, Path, bool, None] = None,
         include_properties: bool = True,
+        include_annotations: bool = True,
         folder_metadata: Union[str, Path, None] = None,
     ):
         """
@@ -766,8 +797,10 @@ def dump_to_pickle(
         relative_to: str, Path, True or None
             If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder.
             This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute.
-        include_properties: bool
+        include_properties: bool, default: True
             If True, all properties are dumped
+        include_annotations: bool, default: True
+            If True, all annotations are dumped
         folder_metadata: str, Path, or None
             Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
         """
@@ -783,7 +816,7 @@ def dump_to_pickle(
             recursive = False
 
         dump_dict = self.to_dict(
-            include_annotations=True,
+            include_annotations=include_annotations,
             include_properties=include_properties,
             folder_metadata=folder_metadata,
             relative_to=relative_to,
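Illustrative usage (not part of the patch): with the new `include_properties` / `include_annotations` keywords, a JSON dump can now carry the extractor's properties and be reloaded with them. The sketch assumes the lazy `generate_recording` helper from `spikeinterface.core`, whose output is JSON-serializable, and the existing `load_extractor` loader.

```python
from spikeinterface.core import generate_recording, load_extractor

# a small JSON-serializable recording with one extra property
recording = generate_recording(num_channels=4, durations=[1.0])
recording.set_property("quality", ["good", "good", "bad", "good"])

# properties (and annotations) now travel with the dump when requested
recording.dump("recording.json", include_properties=True, include_annotations=True)

loaded = load_extractor("recording.json")
print(loaded.get_property("quality"))
```

`dump` simply forwards the same flags to `dump_to_json` or `dump_to_pickle` depending on the file extension.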
diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py
index 8d4c81ecd5..c41daa91f2 100644
--- a/src/spikeinterface/core/baserecording.py
+++ b/src/spikeinterface/core/baserecording.py
@@ -644,7 +644,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
         else:
             raise ValueError(f"format {format} not supported")
 
-        if self.get_property("contact_vector") is not None:
+        if self.has_probe():
             probegroup = self.get_probegroup()
             cached.set_probegroup(probegroup)
 
@@ -672,7 +672,7 @@ def _extra_metadata_from_folder(self, folder):
 
     def _extra_metadata_to_folder(self, folder):
         # save probe
-        if self.get_property("contact_vector") is not None:
+        if self.has_probe():
             probegroup = self.get_probegroup()
             write_probeinterface(folder / "probe.json", probegroup)
 
diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py
index 8933561669..e0999e96a9 100644
--- a/src/spikeinterface/core/baserecordingsnippets.py
+++ b/src/spikeinterface/core/baserecordingsnippets.py
@@ -10,6 +10,8 @@
 
 from warnings import warn
 
+_minimal_probe_properties = ["_probe_x", "_probe_y"]
+
 
 class BaseRecordingSnippets(BaseExtractor):
     """
@@ -22,6 +24,8 @@ def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype
         BaseExtractor.__init__(self, channel_ids)
         self._sampling_frequency = float(sampling_frequency)
         self._dtype = np.dtype(dtype)
+        if self.get_property("contact_vector") is not None:
+            self._set_properties_from_contact_vector(self.get_property("contact_vector"))
 
     @property
     def channel_ids(self):
@@ -62,7 +66,7 @@ def has_scaled(self):
         return self.has_scaleable_traces()
 
     def has_probe(self) -> bool:
-        return "contact_vector" in self.get_property_keys()
+        return all(prop in self.get_property_keys() for prop in _minimal_probe_properties)
 
     def has_channel_location(self) -> bool:
         return self.has_probe() or "location" in self.get_property_keys()
@@ -196,7 +200,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False
             sub_recording = self.select_channels(new_channel_ids)
 
         # create a vector that handle all contacts in property
-        sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None)
+        sub_recording._set_probe_numpy_array_to_properties(probe_as_numpy_array)
 
         # planar_contour is saved in annotations
         for probe_index, probe in enumerate(probegroup.probes):
@@ -261,7 +265,7 @@ def get_probes(self):
         return probegroup.probes
 
     def get_probegroup(self):
-        arr = self.get_property("contact_vector")
+        arr = self._get_probe_numpy_array_from_properties()
         if arr is None:
             positions = self.get_property("location")
             if positions is None:
@@ -280,6 +284,38 @@ def get_probegroup(self):
             probe.set_planar_contour(contour)
         return probegroup
 
+    def _set_probe_numpy_array_to_properties(self, probe_as_numpy_array):
+        for key in probe_as_numpy_array.dtype.names:
+            self.set_property(f"_probe_{key}", probe_as_numpy_array[key], ids=None)
+
+    def _get_probe_numpy_array_from_properties(self):
+        # first construct numpy structured dtype from properties
+        structured_dtype = []
+        for key in self.get_property_keys():
+            if key.startswith("_probe_"):
+                values = self.get_property(key)
+                if values.dtype.kind in "SU":
+                    dtype = "U64"
+                elif values.dtype.kind == "f":
+                    dtype = "float64"
+                elif values.dtype.kind == "i":
+                    dtype = "int64"
+                else:
+                    dtype = values.dtype
+                structured_dtype.append((key[len("_probe_") :], dtype))
+        if len(structured_dtype) == 0:
+            # no "_probe_*" properties: let callers fall back to the "location" property
+            return None
+        arr = np.zeros(self.get_num_channels(), dtype=structured_dtype)
+        for key in arr.dtype.names:
+            arr[key] = self.get_property(f"_probe_{key}")
+        return arr
+
+    def _set_properties_from_contact_vector(self, contact_vector):
+        for key in contact_vector.dtype.names:
+            self.set_property(f"_probe_{key}", contact_vector[key], ids=None)
+        self.delete_property("contact_vector")
+
     def _extra_metadata_from_folder(self, folder):
         # load probe
         folder = Path(folder)
@@ -289,7 +325,7 @@ def _extra_metadata_from_folder(self, folder):
 
     def _extra_metadata_to_folder(self, folder):
         # save probe
-        if self.get_property("contact_vector") is not None:
+        if self.has_probe():
             probegroup = self.get_probegroup()
             write_probeinterface(folder / "probe.json", probegroup)
 
@@ -346,7 +382,7 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params
         self.set_probe(probe, in_place=True)
 
     def set_channel_locations(self, locations, channel_ids=None):
-        if self.get_property("contact_vector") is not None:
+        if self.has_probe():
            raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)")
         self.set_property("location", locations, ids=channel_ids)
 
@@ -354,9 +390,9 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra
         if channel_ids is None:
             channel_ids = self.get_channel_ids()
         channel_indices = self.ids_to_indices(channel_ids)
-        contact_vector = self.get_property("contact_vector")
-        if contact_vector is not None:
+        if self.has_probe():
             # here we bypass the probe reconstruction so this works both for probe and probegroup
+            contact_vector = self._get_probe_numpy_array_from_properties()
             ndim = len(axes)
             all_positions = np.zeros((contact_vector.size, ndim), dtype="float64")
             for i, dim in enumerate(axes):
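Illustrative sketch (not part of the patch) of the new storage scheme: attaching a probe now writes one plain `_probe_<field>` property per column of the probeinterface contact table instead of a single structured `contact_vector`, and `has_probe()`, `get_channel_locations()` and `get_probegroup()` are rebuilt from those properties. The `generate_recording` and `generate_linear_probe` helpers used here are assumed test utilities from spikeinterface and probeinterface.

```python
import numpy as np
import probeinterface
from spikeinterface.core import generate_recording

recording = generate_recording(num_channels=16, durations=[1.0])
probe = probeinterface.generate_linear_probe(num_elec=16)
probe.set_device_channel_indices(np.arange(16))
recording = recording.set_probe(probe)

print(recording.has_probe())                   # True once "_probe_x" and "_probe_y" exist
print(recording.get_property("_probe_x")[:4])  # per-contact x coordinates
print(recording.get_channel_locations()[:4])   # rebuilt from the "_probe_*" properties
print(recording.get_probegroup())              # properties -> structured array -> ProbeGroup
```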
diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py
index 869842779d..cd24036413 100644
--- a/src/spikeinterface/core/basesnippets.py
+++ b/src/spikeinterface/core/basesnippets.py
@@ -214,7 +214,7 @@ def _save(self, format="npy", **save_kwargs):
         else:
             raise ValueError(f"format {format} not supported")
 
-        if self.get_property("contact_vector") is not None:
+        if self.has_probe():
             probegroup = self.get_probegroup()
             cached.set_probegroup(probegroup)
 
diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py
index 0797313793..613db38445 100644
--- a/src/spikeinterface/core/channelsaggregationrecording.py
+++ b/src/spikeinterface/core/channelsaggregationrecording.py
@@ -93,9 +93,9 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record
                     break
 
         for prop_name, prop_values in property_dict.items():
-            if prop_name == "contact_vector":
-                # remap device channel indices correctly
-                prop_values["device_channel_indices"] = np.arange(self.get_num_channels())
+            # remap device channel indices correctly
+            if prop_name == "_probe_device_channel_indices":
+                prop_values = np.arange(self.get_num_channels())
             self.set_property(key=prop_name, values=prop_values)
 
         # if locations are present, check that they are all different!
diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py
index 8a4f29e86c..f7013f5390 100644
--- a/src/spikeinterface/core/channelslice.py
+++ b/src/spikeinterface/core/channelslice.py
@@ -65,10 +65,8 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None)
         self._parent = parent_recording
 
         # change the wiring of the probe
-        contact_vector = self.get_property("contact_vector")
-        if contact_vector is not None:
-            contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64")
-            self.set_property("contact_vector", contact_vector)
+        if parent_recording.has_probe():
+            self.set_property("_probe_device_channel_indices", np.arange(len(channel_ids), dtype="int64"))
 
         # update dump dict
         self._kwargs = {
@@ -155,10 +153,8 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None):
         parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids)
 
         # change the wiring of the probe
-        contact_vector = self.get_property("contact_vector")
-        if contact_vector is not None:
-            contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64")
-            self.set_property("contact_vector", contact_vector)
+        if parent_snippets.has_probe():
+            self.set_property("_probe_device_channel_indices", np.arange(len(channel_ids), dtype="int64"))
 
         # update dump dict
         self._kwargs = {
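Illustrative sketch (not part of the patch): channel selection and channel aggregation now rewire the plain `_probe_device_channel_indices` property rather than mutating a structured `contact_vector` in place. It assumes the generated recording carries a dummy probe, as the current `generate_recording` helper does; if no probe is attached, `get_property` simply returns None.

```python
from spikeinterface.core import aggregate_channels, generate_recording

recording = generate_recording(num_channels=8, durations=[1.0])
left = recording.select_channels(recording.channel_ids[:4])
right = recording.select_channels(recording.channel_ids[4:])

# each slice is re-indexed 0..3 by ChannelSliceRecording
print(left.get_property("_probe_device_channel_indices"))

# aggregation re-indexes the wiring across the combined object, 0..7
combined = aggregate_channels([left, right])
print(combined.get_property("_probe_device_channel_indices"))
```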
diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py
index f7a2bce6a7..28ce2f73ef 100644
--- a/src/spikeinterface/core/recording_tools.py
+++ b/src/spikeinterface/core/recording_tools.py
@@ -1046,8 +1046,6 @@ def get_rec_attributes(recording):
         The rec_attributes dictionary
     """
     properties_to_attrs = deepcopy(recording._properties)
-    if "contact_vector" in properties_to_attrs:
-        del properties_to_attrs["contact_vector"]
     rec_attributes = dict(
         channel_ids=recording.channel_ids,
         sampling_frequency=recording.get_sampling_frequency(),
diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py
index e3c407b2f6..f1d87054ea 100644
--- a/src/spikeinterface/core/zarrextractors.py
+++ b/src/spikeinterface/core/zarrextractors.py
@@ -485,7 +485,7 @@ def add_recording_to_zarr_group(
     )
 
     # save probe
-    if recording.get_property("contact_vector") is not None:
+    if recording.has_probe():
         probegroup = recording.get_probegroup()
         zarr_group.attrs["probe"] = check_json(probegroup.to_dict(array_as_list=True))
 
diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py
index 6713bbbc2c..57211cceb9 100644
--- a/src/spikeinterface/extractors/neoextractors/biocam.py
+++ b/src/spikeinterface/extractors/neoextractors/biocam.py
@@ -73,8 +73,8 @@ def __init__(
             probe_kwargs["electrode_width"] = electrode_width
         probe = probeinterface.read_3brain(file_path, **probe_kwargs)
         self.set_probe(probe, in_place=True)
-        self.set_property("row", self.get_property("contact_vector")["row"])
-        self.set_property("col", self.get_property("contact_vector")["col"])
+        self.set_property("row", self.get_property("_probe_row"))
+        self.set_property("col", self.get_property("_probe_col"))
 
         self._kwargs.update(
             {
diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py
index 04e41433e1..f90bdc51c1 100644
--- a/src/spikeinterface/extractors/neoextractors/maxwell.py
+++ b/src/spikeinterface/extractors/neoextractors/maxwell.py
@@ -77,7 +77,7 @@ def __init__(
             rec_name = self.neo_reader.rec_name
         probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name)
         self.set_probe(probe, in_place=True)
-        self.set_property("electrode", self.get_property("contact_vector")["electrode"])
+        self.set_property("electrode", self.get_property("_probe_electrode"))
         self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name))
 
     @classmethod
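Minimal migration sketch (not part of the patch) for reader-style code such as the BioCAM and Maxwell extractors above: per-contact fields that used to be pulled out of the structured `contact_vector` are now read from individual `_probe_<field>` properties. Reader-specific fields like `row`, `col` or `electrode` only exist when the probe's contact table defines them; a missing field simply returns None. `generate_recording` is used here only to have something to query.

```python
from spikeinterface.core import generate_recording

recording = generate_recording(num_channels=4, durations=[1.0])

x = recording.get_property("_probe_x")      # generic field, present for any attached probe
row = recording.get_property("_probe_row")  # was: recording.get_property("contact_vector")["row"]
print(x, row)                               # row is None for this synthetic probe
```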
diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py
index 334ca63d6a..0d1d2786dd 100644
--- a/src/spikeinterface/extractors/tests/test_iblextractors.py
+++ b/src/spikeinterface/extractors/tests/test_iblextractors.py
@@ -83,7 +83,6 @@ def test_property_keys(self):
         expected_property_keys = [
             "gain_to_uV",
             "offset_to_uV",
-            "contact_vector",
             "location",
             "group",
             "shank",
@@ -94,7 +93,8 @@ def test_property_keys(self):
             "adc",
             "index_on_probe",
         ]
-        self.assertCountEqual(first=self.recording.get_property_keys(), second=expected_property_keys)
+        recording_properties = [k for k in self.recording.get_property_keys() if not k.startswith("_probe")]
+        self.assertCountEqual(first=recording_properties, second=expected_property_keys)
 
     def test_trace_shape(self):
         expected_shape = (21, 384)
diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py
index 06bde4e3d1..1b56694b6e 100644
--- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py
+++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py
@@ -98,8 +98,7 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan
 
     # distribute default probe locations across 4 shanks if set
     x = np.random.choice(shanks, num_channels)
-    for idx, __ in enumerate(recording._properties["contact_vector"]):
-        recording._properties["contact_vector"][idx][1] = x[idx]
+    recording._properties["_probe_x"] = x
 
     # generate random bad channel locations
    bad_channel_indexes = np.random.choice(num_channels, np.random.randint(1, int(num_channels / 5)), replace=False)
@@ -137,14 +136,15 @@ def test_output_values():
     bad_channel_indexes = np.array([0])
     bad_channel_ids = recording.channel_ids[bad_channel_indexes]
 
-    new_probe_locs = [
-        [5, 7, 3, 5, 5],  # 5 channels, a in the center ('bad channel', zero index)
-        [5, 5, 5, 7, 3],
-    ]  # all others equal distance away.
+    new_probe_locs = np.array(
+        [
+            [5, 7, 3, 5, 5],  # 5 channels, one in the center ('bad channel', zero index)
+            [5, 5, 5, 7, 3],
+        ]
+    ).T  # all others equal distance away.
 
     # Overwrite the probe information with the new locations
-    for idx, (x, y) in enumerate(zip(*new_probe_locs)):
-        recording._properties["contact_vector"][idx][1] = x
-        recording._properties["contact_vector"][idx][2] = y
+    recording._properties["_probe_x"] = new_probe_locs[:, 0]
+    recording._properties["_probe_y"] = new_probe_locs[:, 1]
     # Run interpolation in SI and check the interpolated channel
     # 0 is a linear combination of other channels
@@ -158,8 +158,8 @@ def test_output_values():
     # Shift the last channel position so that it is 4 units, rather than 2
     # away. Setting sigma_um = p = 1 allows easy calculation of the expected
     # weights.
- recording._properties["contact_vector"][-1][1] = 5 - recording._properties["contact_vector"][-1][2] = 9 + recording._properties["_probe_x"][-1] = 5 + recording._properties["_probe_y"][-1] = 9 expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index c06baf525a..718cb1f27f 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -161,7 +161,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: "The new mapping cannot exceed total number of channels " "in the zero-chanenl-padded recording." ) else: - if "locations" in recording.get_property_keys() or "contact_vector" in recording.get_property_keys(): + if recording.has_probe(): self.channel_mapping = np.argsort(recording.get_channel_locations()[:, 1]) else: self.channel_mapping = np.arange(recording.get_num_channels()) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 2a64f9f7ea..34ce41105d 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -408,10 +408,7 @@ def __init__( if border_mode == "remove_channels": # change the wiring of the probe # TODO this is also done in ChannelSliceRecording, this should be done in a common place - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + self.set_property("_probe_device_channel_indices", np.arange(len(channel_ids), dtype="int64")) # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below