Skip to content
138 changes: 81 additions & 57 deletions neo/rawio/blackrockrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,19 +323,21 @@ def _parse_header(self):
self.__nsx_data_header = {}

for nsx_nb in self._avail_nsx:
spec = self.__nsx_spec[nsx_nb] = self.__extract_nsx_file_spec(nsx_nb)
spec_version = self.__nsx_spec[nsx_nb] = self.__extract_nsx_file_spec(nsx_nb)
# read nsx headers
self.__nsx_basic_header[nsx_nb], self.__nsx_ext_header[nsx_nb] = self.__nsx_header_reader[spec](nsx_nb)
nsx_header_reader = self.__nsx_header_reader[spec_version]
self.__nsx_basic_header[nsx_nb], self.__nsx_ext_header[nsx_nb] = nsx_header_reader(nsx_nb)

# The only way to know if it is the PTP-variant of file spec 3.0
# The only way to know if it is the peak-to-peak-variant of file spec 3.0
# is to check for nanosecond timestamp resolution.
if (
is_ptp_variant = (
"timestamp_resolution" in self.__nsx_basic_header[nsx_nb].dtype.names
and self.__nsx_basic_header[nsx_nb]["timestamp_resolution"] == 1_000_000_000
):
)
if is_ptp_variant:
nsx_dataheader_reader = self.__nsx_dataheader_reader["3.0-ptp"]
else:
nsx_dataheader_reader = self.__nsx_dataheader_reader[spec]
nsx_dataheader_reader = self.__nsx_dataheader_reader[spec_version]
# for nsxdef get_analogsignal_shape(self, block_index, seg_index):
self.__nsx_data_header[nsx_nb] = nsx_dataheader_reader(nsx_nb)

Expand All @@ -355,8 +357,12 @@ def _parse_header(self):
else:
raise (ValueError("nsx_to_load is wrong"))

if not all(nsx_nb in self._avail_nsx for nsx_nb in self.nsx_to_load):
raise FileNotFoundError(f"nsx_to_load does not match available nsx list")
missing_nsx_files = [nsx_nb for nsx_nb in self.nsx_to_load if nsx_nb not in self._avail_nsx]
if missing_nsx_files:
missing_list = ", ".join(f"ns{nsx_nb}" for nsx_nb in missing_nsx_files)
raise FileNotFoundError(
f"Requested NSX file(s) not found: {missing_list}. Available NSX files: {self._avail_nsx}"
)

# check that all files come from the same specification
all_spec = [self.__nsx_spec[nsx_nb] for nsx_nb in self.nsx_to_load]
Expand All @@ -381,27 +387,29 @@ def _parse_header(self):
self.sig_sampling_rates = {}
if len(self.nsx_to_load) > 0:
for nsx_nb in self.nsx_to_load:
spec = self.__nsx_spec[nsx_nb]
# The only way to know if it is the PTP-variant of file spec 3.0
basic_header = self.__nsx_basic_header[nsx_nb]
spec_version = self.__nsx_spec[nsx_nb]
# The only way to know if it is the peak-to-peak-variant of file spec 3.0
# is to check for nanosecond timestamp resolution.
if (
"timestamp_resolution" in self.__nsx_basic_header[nsx_nb].dtype.names
and self.__nsx_basic_header[nsx_nb]["timestamp_resolution"] == 1_000_000_000
):
is_ptp_variant = (
"timestamp_resolution" in basic_header.dtype.names
and basic_header["timestamp_resolution"] == 1_000_000_000
)
if is_ptp_variant:
_data_reader_fun = self.__nsx_data_reader["3.0-ptp"]
else:
_data_reader_fun = self.__nsx_data_reader[spec]
_data_reader_fun = self.__nsx_data_reader[spec_version]
self.nsx_datas[nsx_nb] = _data_reader_fun(nsx_nb)

sr = float(self.main_sampling_rate / self.__nsx_basic_header[nsx_nb]["period"])
sr = float(self.main_sampling_rate / basic_header["period"])
self.sig_sampling_rates[nsx_nb] = sr

if spec in ["2.2", "2.3", "3.0"]:
if spec_version in ["2.2", "2.3", "3.0"]:
ext_header = self.__nsx_ext_header[nsx_nb]
elif spec == "2.1":
elif spec_version == "2.1":
ext_header = []
keys = ["labels", "units", "min_analog_val", "max_analog_val", "min_digital_val", "max_digital_val"]
params = self.__nsx_params[spec](nsx_nb)
params = self.__nsx_params[spec_version](nsx_nb)
for i in range(len(params["labels"])):
d = {}
for key in keys:
Expand All @@ -415,11 +423,11 @@ def _parse_header(self):
signal_buffers.append((stream_name, buffer_id))
signal_streams.append((stream_name, stream_id, buffer_id))
for i, chan in enumerate(ext_header):
if spec in ["2.2", "2.3", "3.0"]:
if spec_version in ["2.2", "2.3", "3.0"]:
ch_name = chan["electrode_label"].decode()
ch_id = str(chan["electrode_id"])
units = chan["units"].decode()
elif spec == "2.1":
elif spec_version == "2.1":
ch_name = chan["labels"]
ch_id = str(self.__nsx_ext_header[nsx_nb][i]["electrode_id"])
units = chan["units"]
Expand Down Expand Up @@ -809,7 +817,7 @@ def __extract_nsx_file_spec(self, nsx_nb):
"""
Extract file specification from an .nsx file.
"""
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"

# Header structure of files specification 2.2 and higher. For files 2.1
# and lower, the entries ver_major and ver_minor are not supported.
Expand All @@ -829,7 +837,7 @@ def __extract_nev_file_spec(self):
"""
Extract file specification from an .nev file
"""
filename = ".".join([self._filenames["nev"], "nev"])
filename = f"{self._filenames['nev']}.nev"
# Header structure of files specification 2.2 and higher. For files 2.1
# and lower, the entries ver_major and ver_minor are not supported.
dt0 = [("file_id", "S8"), ("ver_major", "uint8"), ("ver_minor", "uint8")]
Expand Down Expand Up @@ -879,7 +887,7 @@ def __read_nsx_header_variant_b(self, nsx_nb):
"""
Extract nsx header information from a 2.2 or 2.3 .nsx file
"""
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"

# basic header (file_id: NEURALCD)
dt0 = [
Expand Down Expand Up @@ -911,7 +919,6 @@ def __read_nsx_header_variant_b(self, nsx_nb):

# extended header (type: CC)
offset_dt0 = np.dtype(dt0).itemsize
shape = nsx_basic_header["channel_count"]
dt1 = [
("type", "S2"),
("electrode_id", "uint16"),
Expand All @@ -930,28 +937,32 @@ def __read_nsx_header_variant_b(self, nsx_nb):
# filter settings used to create nsx from source signal
("hi_freq_corner", "uint32"),
("hi_freq_order", "uint32"),
("hi_freq_type", "uint16"), # 0=None, 1=Butterworth
("hi_freq_type", "uint16"), # 0=None, 1=Butterworth, 2=Chebyshev
("lo_freq_corner", "uint32"),
("lo_freq_order", "uint32"),
("lo_freq_type", "uint16"),
] # 0=None, 1=Butterworth
] # 0=None, 1=Butterworth, 2=Chebyshev

nsx_ext_header = np.memmap(filename, shape=shape, offset=offset_dt0, dtype=dt1, mode="r")
channel_count = int(nsx_basic_header["channel_count"])
nsx_ext_header = np.memmap(filename, shape=channel_count, offset=offset_dt0, dtype=dt1, mode="r")

return nsx_basic_header, nsx_ext_header

def __read_nsx_dataheader(self, nsx_nb, offset):
"""
Reads data header following the given offset of an nsx file.
"""
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"

ts_size = "uint64" if self.__nsx_basic_header[nsx_nb]["ver_major"] >= 3 else "uint32"
major_version = self.__nsx_basic_header[nsx_nb]["ver_major"]
ts_size = "uint64" if major_version >= 3 else "uint32"

# dtypes data header, the header flag is always set to 1
dt2 = [("header_flag", "uint8"), ("timestamp", ts_size), ("nb_data_points", "uint32")]

# dtypes data header
dt2 = [("header", "uint8"), ("timestamp", ts_size), ("nb_data_points", "uint32")]
packet_header = np.memmap(filename, dtype=dt2, shape=1, offset=offset, mode="r")[0]

return np.memmap(filename, dtype=dt2, shape=1, offset=offset, mode="r")[0]
return packet_header

def __read_nsx_dataheader_variant_a(self, nsx_nb, filesize=None, offset=None):
"""
Expand All @@ -971,32 +982,44 @@ def __read_nsx_dataheader_variant_b(
Reads the nsx data header for each data block following the offset of
file spec 2.2, 2.3, and 3.0.
"""
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"

filesize = self.__get_file_size(filename)

data_header = {}
index = 0

if offset is None:
offset = self.__nsx_basic_header[nsx_nb]["bytes_in_headers"]
offset_to_first_data_block = offset or int(self.__nsx_basic_header[nsx_nb]["bytes_in_headers"])

channel_count = int(self.__nsx_basic_header[nsx_nb]["channel_count"])
offset = offset_to_first_data_block
data_block_index = 0
while offset < filesize:
dh = self.__read_nsx_dataheader(nsx_nb, offset)
data_header[index] = {
"header": dh["header"],
"timestamp": dh["timestamp"],
"nb_data_points": dh["nb_data_points"],
"offset_to_data_block": offset + dh.dtype.itemsize,
packet_header = self.__read_nsx_dataheader(nsx_nb, offset)
header_flag = packet_header["header_flag"]
# NSX data blocks must have header_flag = 1, other values indicate file corruption
if header_flag != 1:
raise ValueError(
f"Invalid NSX data block header at offset {offset:#x} in ns{nsx_nb} file. "
f"Expected header_flag=1, got {header_flag}. "
f"This may indicate file corruption or unsupported NSX format variant. "
f"Block index: {data_block_index}, File size: {filesize} bytes"
)
timestamp = packet_header["timestamp"]
offset_to_data_block_start = offset + packet_header.dtype.itemsize
num_data_points = int(packet_header["nb_data_points"])

data_header[data_block_index] = {
"header": header_flag,
"timestamp": timestamp,
"nb_data_points": num_data_points,
"offset_to_data_block": offset_to_data_block_start,
}

# data size = number of data points * (2bytes * number of channels)
# use of `int` avoids overflow problem
data_size = int(dh["nb_data_points"]) * int(self.__nsx_basic_header[nsx_nb]["channel_count"]) * 2
# define new offset (to possible next data block)
offset = int(data_header[index]["offset_to_data_block"]) + data_size
# Jump to the next data block
data_block_size = num_data_points * channel_count * np.dtype("int16").itemsize
offset = offset_to_data_block_start + data_block_size

index += 1
data_block_index += 1

return data_header

Expand Down Expand Up @@ -1082,19 +1105,20 @@ def __read_nsx_data_variant_b(self, nsx_nb):
Extract nsx data (blocks) from a 2.2, 2.3, or 3.0 .nsx file.
Blocks can arise if the recording was paused by the user.
"""
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"

data = {}
for data_bl in self.__nsx_data_header[nsx_nb].keys():
data_header = self.__nsx_data_header[nsx_nb]
number_of_channels = int(self.__nsx_basic_header[nsx_nb]["channel_count"])

for data_block in data_header.keys():
# get shape and offset of data
shape = (
int(self.__nsx_data_header[nsx_nb][data_bl]["nb_data_points"]),
int(self.__nsx_basic_header[nsx_nb]["channel_count"]),
)
offset = int(self.__nsx_data_header[nsx_nb][data_bl]["offset_to_data_block"])
number_of_samples = int(data_header[data_block]["nb_data_points"])
shape = (number_of_samples, number_of_channels)
offset = int(data_header[data_block]["offset_to_data_block"])

# read data
data[data_bl] = np.memmap(filename, dtype="int16", shape=shape, offset=offset, mode="r")
data[data_block] = np.memmap(filename, dtype="int16", shape=shape, offset=offset, mode="r")

return data

Expand Down
Loading