From ae719dd21cfcf1c4930762d7c8616d897db09c87 Mon Sep 17 00:00:00 2001
From: sprenger
Date: Wed, 22 Sep 2021 16:03:02 +0200
Subject: [PATCH 1/7] [MLIO] add first version of ML rawio

---
 neo/rawio/monkeylogicrawio.py               | 451 ++++++++++++++++++++
 neo/test/rawiotest/test_monkeylogicrawio.py |  43 ++
 2 files changed, 494 insertions(+)
 create mode 100644 neo/rawio/monkeylogicrawio.py
 create mode 100644 neo/test/rawiotest/test_monkeylogicrawio.py

diff --git a/neo/rawio/monkeylogicrawio.py b/neo/rawio/monkeylogicrawio.py
new file mode 100644
index 000000000..10c50b82a
--- /dev/null
+++ b/neo/rawio/monkeylogicrawio.py
@@ -0,0 +1,451 @@
+"""
+RawIO Class for MonkeyLogic files
+
+The RawIO assumes that all blocks and all segments of a file share the same
+structure. It exposes analog signals and behavioral event codes; MonkeyLogic
+does not record spike data. This IO does not support lazy loading.
+Reading of bhv2 files is based on
+https://monkeylogic.nimh.nih.gov/docs_BHV2BinaryStructure.html
+
+Author: Julia Sprenger
+"""
+
+import numpy as np
+import struct
+
+from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
+                        _spike_channel_dtype, _event_channel_dtype)
+
+class MLBLock(dict):
+    n_byte_dtype = {'logical': (1, '?'),
+                    'char': (1, 'c'),
+                    'integers': (8, 'Q'),
+                    'uint8': (1, 'B'),
+                    'single': (4, 'f'),
+                    'double': (8, 'd')}
+
+    @staticmethod
+    def generate_block(f):
+        """
+        Read the next variable header and generate a new ML block object.
+
+        :param f: open file handle positioned at the start of a variable record
+        :return: a new MLBLock, or None if the end of the file is reached
+        """
+        LN = f.read(8)
+        # No MLBlock, e.g. due to EOF
+        if not LN:
+            return None
+        LN = struct.unpack('Q', LN)[0]
+        # print(f'\nLN: {LN}')
+        var_name = f.read(LN)
+        # print(var_name)
+
+        LT = f.read(8)
+        LT = struct.unpack('Q', LT)[0]
+        # print(f'LT: {LT}')
+        var_type = f.read(LT)
+        # print(var_type)
+
+        DV = f.read(8)
+        DV = struct.unpack('Q', DV)[0]
+        # print(f'DV: {DV}')
+        var_size = f.read(DV * 8)
+        var_size = struct.unpack(f'{DV}Q', var_size)
+        # print(var_size)
+
+        return MLBLock(LN, var_name, LT, var_type, DV, var_size)
+
+    def __bool__(self):
+        if any((self.LN, self.LT)):
+            return True
+        else:
+            return False
+
+    def __init__(self, LN, var_name, LT, var_type, DV, var_size):
+        self.LN = LN
+        self.var_name = var_name.decode()
+        self.LT = LT
+        self.var_type = var_type.decode()
+        self.DV = DV
+        self.var_size = var_size
+
+        self.children = self
+        self.data = None
+
+    def __repr__(self):
+        if self.data is None:
+            shape = 0
+            dt = ''
+        else:
+            shape = getattr(self.data, 'shape', len(self.data))
+            dt = f' dtype: {self.var_type}'
+
+        return f'MLBLock [{shape}|{len(self)}] "{self.var_name}"{dt}'
+
+    def read_data(self, f, recursive=False):
+        """
+        Read the data of this block from the file handle ``f``.
+
+        Parameters
+        ----------
+        f : file handle
+            Open file handle positioned at the start of the block data.
+        recursive : bool
+            If True, child blocks are registered and their data is read, too.
+        """
+
+        # reading basic data types
+        if self.var_type in self.n_byte_dtype:
+            n_byte, format = self.n_byte_dtype[self.var_type]
+
+            data = np.empty(shape=np.prod(self.var_size), dtype=format)
+
+            for i in range(np.prod(self.var_size)):
+                d = f.read(n_byte)
+                d = struct.unpack(format, d)[0]
+                data[i] = d
+
+            data = data.reshape(self.var_size)
+
+            # decoding characters
+            if self.var_type == 'char':
+                data = np.char.decode(data)
+
+                # convert the array to a string when it only has a single dimension
+                if np.prod(self.var_size) == np.max(self.var_size):
+                    data = ''.join(c for c in data.flatten())
+
+            # print(f'data: {data}')
+
+            self.data = data
+
+        # reading potentially nested data types
+        elif self.var_type == 'struct':
+            n_fields = f.read(8)
+            n_fields = struct.unpack('Q', n_fields)[0]
+
+            for field in range(n_fields*np.prod(self.var_size)):
+                bl = MLBLock.generate_block(f)
+                if recursive:
+                    self[bl.var_name] = bl
+                    bl.read_data(f, recursive=recursive)
+
+        elif self.var_type == 'cell':
+            for field in range(np.prod(self.var_size)):
+                bl = MLBLock.generate_block(f)
+                if recursive:
+                    self[bl.var_name] = bl
+                    bl.read_data(f, recursive=recursive)
+
+        else:
+            raise ValueError(f'unknown variable type {self.var_type}')
+
+        # Sanity check: Blocks can only have children or contain data
+        if self.data is not None and len(self.keys()):
+            raise ValueError(f'Block {self.var_name} has {len(self)} children and data: {self.data}')
+
+
+
+
+
+class MonkeyLogicRawIO(BaseRawIO):
+
+    extensions = ['bhv2']
+    rawmode = 'one-file'
+
+    def __init__(self, filename=''):
+        BaseRawIO.__init__(self)
+        self.filename = str(filename)
+
+    def _source_name(self):
+        return self.filename
+
+    def _parse_header(self):
+        self.ml_blocks = {}
+
+        with open(self.filename, 'rb') as f:
+            while bl := MLBLock.generate_block(f):
+                bl.read_data(f, recursive=True)
+                self.ml_blocks[bl.var_name] = bl
+
+        trial_rec = self.ml_blocks['TrialRecord']
+        self.trial_ids = np.arange(1, int(trial_rec['CurrentTrialNumber'].data))
+
+        exclude_signals = ['SampleInterval']
+
+        # rawio configuration
+        signal_streams = []
+        signal_channels = []
+
+        if 'Trial1' in self.ml_blocks:
+            chan_id = 0
+            stream_id = 0
+            chan_names = []
+
+            ana_block = self.ml_blocks['Trial1']['AnalogData']
+
+            def _register_signal(sig_block, prefix=''):
+                nonlocal stream_id
+                nonlocal chan_id
+                if sig_data.data is not None and any(sig_data.data.shape):
+                    signal_streams.append((prefix + sig_data.var_name, stream_id))
+
+                ch_name = sig_data.var_name
+                sr = 1  # TODO: Where to get the sampling rate info?
+                dtype = type(sig_data.data)
+                units = ''  # TODO: Where to find the unit info?
+                gain = 1  # TODO: Where to find the gain info?
+                offset = 0  # TODO: Can signals have an offset in ML?
+                stream_id = 0  # all analog data belong to same stream
+
+                if sig_block.data.shape[1] == 1:
+                    signal_channels.append((prefix + ch_name, chan_id, sr, dtype, units, gain, offset,
+                                            stream_id))
+                    chan_id += 1
+                else:
+                    for sub_chan_id in range(sig_block.data.shape[1]):
+                        signal_channels.append(
+                            (prefix + ch_name, chan_id, sr, dtype, units, gain, offset,
+                             stream_id))
+                        chan_id += 1
+
+
+
+            # 1st level signals ('Trial1'/'AnalogData'/<signal>)
+            for sig_name, sig_data in ana_block.items():
+                if sig_name in exclude_signals:
+                    continue
+
+                # 1st level signals
+                if sig_data.data is not None and any(sig_data.data.shape):
+                    _register_signal(sig_data)
+
+                # 2nd level signals
+                elif sig_data.keys():
+                    for sig_sub_name, sig_sub_data in sig_data.items():
+                        if sig_sub_data.data is not None:
+                            chan_names.append(f'{sig_name}/{sig_sub_name}')
+                            _register_signal(sig_sub_data, prefix=f'{sig_name}/')
+
+
+        spike_channels = []
+
+        signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
+        signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
+        spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
+
+        event_channels = []
+        event_channels.append(('ML Trials', 0, 'event'))
+        # event_channels.append(('ML Trials', 1, 'epoch'))  # no epochs supported yet
+        event_channels = np.array(event_channels, dtype=_event_channel_dtype)
+
+
+
+        self.header = {}
+        self.header['nb_block'] = 1
+        self.header['nb_segment'] = [1]
+        self.header['signal_streams'] = signal_streams
+        self.header['signal_channels'] = signal_channels
+        self.header['spike_channels'] = spike_channels
+        self.header['event_channels'] = event_channels
+
+        self._generate_minimal_annotations()
+
+        # adding custom annotations and array annotations
+
+        ignore_annotations = ['AnalogData', 'AbsoluteTrialStartTime']
+        array_annotation_keys = []
+
+        ml_anno = {k: v for k, v in sorted(self.ml_blocks.items()) if not k.startswith('Trial')}
+        bl_ann = self.raw_annotations['block'][0]
+        bl_ann.update(ml_anno)
+
+        # TODO annotate segments according to trial properties
+        seg_ann = self.raw_annotations['blocks'][0]['segments'][0]
+        seg_ann.update(ml_anno)
+
+        event_ann = seg_ann['events'][0]  # 0 is event
+        # epoch_ann = seg_ann['events'][1]  # 1 is epoch
+
+        # TODO: add annotations for AnalogSignals
+        # TODO: add array_annotations for AnalogSignals
+
+        # ml_anno = {k: v for k, v in sorted(self.ml_blocks.items()) if k.startswith('Trial')}
+        #
+        # raise NotImplementedError()
+        #
+        # # extract array annotations
+        # event_ann.update(self._filter_properties(props, 'ep'))
+        # ev_idx += 1
+        #
+        # # adding array annotations to analogsignals
+        # annotated_anasigs = []
+        # sig_ann = seg_ann['signals']
+        # # this implementation relies on analogsignals always being
+        # # stored in the same stream order across segments
+        # stream_id = 0
+        # for da_idx, da in enumerate(group.data_arrays):
+        #     if da.type != "neo.analogsignal":
+        #         continue
+        #     anasig_id = da.name.split('.')[-2]
+        #     # skip already annotated signals as each channel already
+        #     # contains the complete set of annotations and
+        #     # array_annotations
+        #     if anasig_id in annotated_anasigs:
+        #         continue
+        #     annotated_anasigs.append(anasig_id)
+        #
+        #     # collect annotation properties
+        #     props = [p for p in da.metadata.props
+        #              if p.type != 'ARRAYANNOTATION']
+        #     props_dict = self._filter_properties(props, "analogsignal")
+        #     sig_ann[stream_id].update(props_dict)
+        #
+        #     # collect array annotation properties
+        #     props = [p for p in da.metadata.props
+        #              if p.type == 'ARRAYANNOTATION']
+        #     props_dict = self._filter_properties(props, "analogsignal")
+        #     sig_ann[stream_id]['__array_annotations__'].update(
+        #         props_dict)
+        #
+        #     stream_id += 1
+        #
+        # return
+
+    def _segment_t_start(self, block_index, seg_index):
+        if 'Trial1' in self.ml_blocks:
+            t_start = self.ml_blocks['Trial1']['AbsoluteTrialStartTime'].data[0][0]
+        else:
+            t_start = 0
+        return t_start
+
+    def _segment_t_stop(self, block_index, seg_index):
+        last_trial = self.ml_blocks[f'Trial{self.trial_ids[-1]}']
+
+        t_start = last_trial['AbsoluteTrialStartTime'].data[0][0]
+        t_stop = t_start + 10  # TODO: Find sampling rates to determine trial end
+        return t_stop
+
+    def _get_signal_size(self, block_index, seg_index, stream_index):
+        stream_name, stream_id = self.header['signal_streams'][stream_index]
+
+        block = self.ml_blocks[f'Trial{seg_index + 1}']['AnalogData']
+        for sn in stream_name.split('/'):  # dealing with 1st and 2nd level signals
+            block = block[sn]
+
+        size = block.data.shape[0]
+        return size  # size is per signal, not the sum of all channel_indexes
+
+    def _get_signal_t_start(self, block_index, seg_index, stream_index):
+        sig_t_start = self._segment_t_start(block_index, seg_index)
+        return sig_t_start
+
+    def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index,
+                                channel_indexes):
+        stream_name, stream_id = self.header['signal_streams'][stream_index]
+
+        if i_start is None:
+            i_start = 0
+        if i_stop is None:
+            i_stop = self.get_signal_size(block_index, seg_index, stream_index)
+
+        raw_signals_list = []
+        block = self.ml_blocks[f'Trial{seg_index+1}']['AnalogData']
+        for sn in stream_name.split('/'):
+            block = block[sn]
+
+        if channel_indexes is None:
+            raw_signals = block.data
+        else:
+            raw_signals = block.data[channel_indexes]
+
+        raw_signals = raw_signals[i_start:i_stop]
+        return raw_signals
+
+    def _spike_count(self, block_index, seg_index, unit_index):
+        count = 0
+        head_id = self.header['spike_channels'][unit_index][1]
+        for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
+            for src in mt.sources:
+                if mt.type == 'neo.spiketrain' and [src.type == "neo.unit"]:
+                    if head_id == src.id:
+                        return len(mt.positions)
+        return count
+
+    def _get_spike_timestamps(self, block_index, seg_index, unit_index,
+                              t_start, t_stop):
+        block = self.unit_list['blocks'][block_index]
+        segment = block['segments'][seg_index]
+        spike_dict = segment['spiketrains']
+        spike_timestamps = spike_dict[unit_index]
+        spike_timestamps = np.transpose(spike_timestamps)
+
+        if t_start is not None or t_stop is not None:
+            lim0 = t_start
+            lim1 = t_stop
+            mask = (spike_timestamps >= lim0) & (spike_timestamps <= lim1)
+            spike_timestamps = spike_timestamps[mask]
+        return spike_timestamps
+
+    def _rescale_spike_timestamp(self, spike_timestamps, dtype):
+        spike_times = spike_timestamps.astype(dtype)
+        return spike_times
+
+    def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index,
+                                 t_start, t_stop):
+        # this must return a 3D numpy array (nb_spike, nb_channel, nb_sample)
+        seg = self.unit_list['blocks'][block_index]['segments'][seg_index]
+        waveforms = seg['spiketrains_unit'][unit_index]['waveforms']
+        if not waveforms:
+            return None
+        raw_waveforms = np.array(waveforms)
+
+        if t_start is not None:
+            lim0 = t_start
+            mask = (raw_waveforms >= lim0)
+            # use nan to keep the shape
+            raw_waveforms = np.where(mask, raw_waveforms, np.nan)
+        if t_stop is not None:
+            lim1 = t_stop
+            mask = (raw_waveforms <= lim1)
+            raw_waveforms = np.where(mask, raw_waveforms, np.nan)
+        return raw_waveforms
+
+    def _event_count(self, block_index, seg_index, event_channel_index):
+        assert event_channel_index == 0
+        times = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes'].data
+
+        return len(times)
+
+    def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
+
+        durations = None
+        assert block_index == 0
+        assert event_channel_index == 0
+
+        timestamp = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes'].data
+        timestamp = timestamp.flatten()
+        labels = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeNumbers'].data
+        labels = labels.flatten()
+
+        if t_start is not None:
+            keep = timestamp >= t_start
+            timestamp, labels = timestamp[keep], labels[keep]
+
+        if t_stop is not None:
+            keep = timestamp <= t_stop
+            timestamp, labels = timestamp[keep], labels[keep]
+
+        return timestamp, durations, labels
+
+    def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index):
+        # TODO: Figure out unit and scaling of event timestamps
+        event_timestamps /= 1000  # assume this is in milliseconds
+        return event_timestamps.astype(dtype)  # return in seconds
+
+    def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
+        # TODO: Figure out unit and scaling of event timestamps
+        raw_duration /= 1000  # assume this is in milliseconds
+        return raw_duration.astype(dtype)  # return in seconds
+
diff --git a/neo/test/rawiotest/test_monkeylogicrawio.py b/neo/test/rawiotest/test_monkeylogicrawio.py
new file mode 100644
index 000000000..e97da165e
--- /dev/null
+++ b/neo/test/rawiotest/test_monkeylogicrawio.py
@@ -0,0 +1,43 @@
+import unittest
+
+from neo.rawio.monkeylogicrawio import MonkeyLogicRawIO
+from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
+
+import logging
+
+logging.getLogger().setLevel(logging.INFO)
+
+
+class TestMonkeyLogicRawIO(BaseTestRawIO, unittest.TestCase):
+    rawioclass = MonkeyLogicRawIO
+    entities_to_download = [
+        'monkeylogic'
+    ]
+    entities_to_test = []
+
+    def setUp(self):
+        # TODO: update this once ML test files are available on GIN
+        filename = '/home/sprengerj/projects/monkey_logic/210909_TSCM_5cj_5cl_Riesling.bhv2'
+        filename = '/home/sprengerj/projects/monkey_logic/sabrina/210810__learndms_userloop.bhv2'
+        # filename = '/home/sprengerj/projects/monkey_logic/sabrina/210916__learndms_userloop.bhv2'
+        # filename = '/home/sprengerj/projects/monkey_logic/sabrina/210917__learndms_userloop.bhv2'
+
+        self.rawio = MonkeyLogicRawIO(filename)
+
+    def test_parse_header(self):
+        # basic smoke test: parsing the header should not raise
+        self.rawio.parse_header()
+
+        # TODO: assert header values here once reference files are available,
+        # e.g. number of segments, signal sizes and t_start/t_stop per segment
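For reference, every variable in a bhv2 container is a length-prefixed record: an 8-byte unsigned integer LN followed by the variable name, an 8-byte LT followed by the type string, and an 8-byte dimension count DV followed by DV 8-byte dimension sizes, then the payload. The snippet below is a minimal round-trip sketch against the parser added above; the buffer and record contents are made up for illustration.

    import io
    import struct

    from neo.rawio.monkeylogicrawio import MLBLock

    buf = io.BytesIO()
    name, vtype = b'AbsoluteTrialStartTime', b'double'
    buf.write(struct.pack('Q', len(name)) + name)    # LN + variable name
    buf.write(struct.pack('Q', len(vtype)) + vtype)  # LT + variable type
    buf.write(struct.pack('QQQ', 2, 1, 1))           # DV = 2, dimensions (1, 1)
    buf.write(struct.pack('d', 1.25))                # payload: a single float64
    buf.seek(0)

    bl = MLBLock.generate_block(buf)
    bl.read_data(buf)
    print(bl.var_name, bl.var_type, bl.data)         # AbsoluteTrialStartTime double [[1.25]]
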
From d41dba1716f05ed4f9bfb0dbca5ab087aa3c6c36 Mon Sep 17 00:00:00 2001
From: sprenger
Date: Wed, 22 Sep 2021 16:03:47 +0200
Subject: [PATCH 2/7] [MLIO] add main MLIO

---
 neo/io/__init__.py                    |  7 ++++++
 neo/io/monkeylogicio.py               | 21 ++++++++++++++++++
 neo/test/iotest/test_monkeylogicio.py | 32 +++++++++++++++++++++++++++
 3 files changed, 60 insertions(+)
 create mode 100644 neo/io/monkeylogicio.py
 create mode 100644 neo/test/iotest/test_monkeylogicio.py

diff --git a/neo/io/__init__.py b/neo/io/__init__.py
index 6d10278ff..2291da6bb 100644
--- a/neo/io/__init__.py
+++ b/neo/io/__init__.py
@@ -37,6 +37,7 @@
 * :attr:`KwikIO`
 * :attr:`MaxwellIO`
 * :attr:`MicromedIO`
+* :attr:`MonkeyLogicIO`
 * :attr:`NeoMatlabIO`
 * :attr:`NestIO`
 * :attr:`NeuralynxIO`
@@ -158,6 +159,10 @@
 
     .. autoattribute:: extensions
 
+.. autoclass:: neo.io.MonkeyLogicIO
+
+    .. autoattribute:: extensions
+
 .. autoclass:: neo.io.NeoMatlabIO
 
     .. autoattribute:: extensions
@@ -295,6 +300,7 @@
 from neo.io.mearecio import MEArecIO
 from neo.io.maxwellio import MaxwellIO
 from neo.io.micromedio import MicromedIO
+from neo.io.monkeylogicio import MonkeyLogicIO
 from neo.io.neomatlabio import NeoMatlabIO
 from neo.io.nestio import NestIO
 from neo.io.neuralynxio import NeuralynxIO
@@ -345,6 +351,7 @@
     MEArecIO,
     MaxwellIO,
     MicromedIO,
+    MonkeyLogicIO,
     NixIO,  # place NixIO before other IOs that use HDF5 to make it the default for .h5 files
     NeoMatlabIO,
     NestIO,
diff --git a/neo/io/monkeylogicio.py b/neo/io/monkeylogicio.py
new file mode 100644
index 000000000..1f6297858
--- /dev/null
+++ b/neo/io/monkeylogicio.py
@@ -0,0 +1,21 @@
+from neo.io.basefromrawio import BaseFromRaw
+from neo.rawio.monkeylogicrawio import MonkeyLogicRawIO
+
+
+class MonkeyLogicIO(MonkeyLogicRawIO, BaseFromRaw):
+
+    name = 'MonkeyLogicIO'
+
+    _prefered_signal_group_mode = 'group-by-same-units'
+    _prefered_units_group_mode = 'all-in-one'
+
+    def __init__(self, filename):
+        MonkeyLogicRawIO.__init__(self, filename)
+        BaseFromRaw.__init__(self, filename)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.header = None
+        self.file.close()
diff --git a/neo/test/iotest/test_monkeylogicio.py b/neo/test/iotest/test_monkeylogicio.py
new file mode 100644
index 000000000..b62e89cf0
--- /dev/null
+++ b/neo/test/iotest/test_monkeylogicio.py
@@ -0,0 +1,32 @@
+"""
+Tests of neo.io.monkeylogicio
+"""
+
+import unittest
+
+from neo.io import MonkeyLogicIO
+from neo.test.iotest.common_io_test import BaseTestIO
+
+# class TestMonkeyLogicIO(BaseTestIO, unittest.TestCase):
+#     entities_to_download = [
+#         'monkeylogic'
+#     ]
+#     entities_to_test = [
+#         'monkeylogic/mearec_test_10s.h5'
+#     ]
+#     ioclass = MonkeyLogicIO
+
+
+class TestMonkeyLogicIO(unittest.TestCase):
+    # TODO: Adjust this once ML files are on GIN
+
+    def test_read(self):
+        filename = '/home/sprengerj/projects/monkey_logic/210909_TSCM_5cj_5cl_Riesling.bhv2'
+        filename = '/home/sprengerj/projects/monkey_logic/sabrina/210810__learndms_userloop.bhv2'
+        # filename = '/home/sprengerj/projects/monkey_logic/sabrina/210916__learndms_userloop.bhv2'
+        # filename = '/home/sprengerj/projects/monkey_logic/sabrina/210917__learndms_userloop.bhv2'
+        io = MonkeyLogicIO(filename)
+        io.read_block()
+
+if __name__ == "__main__":
+    unittest.main()
From 5073a7b28d96470d54f8f3463f833222ef4cb965 Mon Sep 17 00:00:00 2001
From: sprenger
Date: Wed, 22 Sep 2021 17:54:21 +0200
Subject: [PATCH 3/7] [MLIO] add non-lazy warning

---
 neo/io/monkeylogicio.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/neo/io/monkeylogicio.py b/neo/io/monkeylogicio.py
index 1f6297858..88b1c03be 100644
--- a/neo/io/monkeylogicio.py
+++ b/neo/io/monkeylogicio.py
@@ -1,6 +1,6 @@
 from neo.io.basefromrawio import BaseFromRaw
 from neo.rawio.monkeylogicrawio import MonkeyLogicRawIO
-
+import warnings
 
 class MonkeyLogicIO(MonkeyLogicRawIO, BaseFromRaw):
 
@@ -13,9 +13,15 @@ def __init__(self, filename):
         MonkeyLogicRawIO.__init__(self, filename)
         BaseFromRaw.__init__(self, filename)
 
-    def __enter__(self):
-        return self
+    def read_block(self, block_index=0, lazy=False,
+                   create_group_across_segment=None,
+                   signal_group_mode=None, load_waveforms=False):
+
+        if lazy:
+            warnings.warn('Lazy loading is not supported by MonkeyLogicIO. '
+                          'Ignoring `lazy=True` parameter.')
 
-    def __exit__(self, *args):
-        self.header = None
-        self.file.close()
+        return BaseFromRaw.read_block(self, block_index=block_index, lazy=False,
+                                      create_group_across_segment=create_group_across_segment,
+                                      signal_group_mode=signal_group_mode,
+                                      load_waveforms=load_waveforms)
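The guard added in this patch degrades gracefully instead of failing; a quick sketch of the expected behavior (the file path is a placeholder):

    import warnings

    from neo.io import MonkeyLogicIO

    io = MonkeyLogicIO('session.bhv2')  # placeholder path to a .bhv2 file
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        bl = io.read_block(lazy=True)   # warns, then falls back to eager loading
    assert any('Lazy loading is not supported' in str(w.message) for w in caught)
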
From 76ddbd83141ce301bada4b4783f7ab0ad19a59cc Mon Sep 17 00:00:00 2001
From: sprenger
Date: Wed, 22 Sep 2021 17:54:59 +0200
Subject: [PATCH 4/7] [MLIO] flatten ML_block dictionary structure

---
 neo/rawio/monkeylogicrawio.py | 159 +++++++++++++++-------------------
 1 file changed, 71 insertions(+), 88 deletions(-)

diff --git a/neo/rawio/monkeylogicrawio.py b/neo/rawio/monkeylogicrawio.py
index 10c50b82a..c39b0f913 100644
--- a/neo/rawio/monkeylogicrawio.py
+++ b/neo/rawio/monkeylogicrawio.py
@@ -141,12 +141,20 @@ def read_data(self, f, recursive=False):
         else:
             raise ValueError(f'unknown variable type {self.var_type}')
 
-        # Sanity check: Blocks can only have children or contain data
-        if self.data is not None and len(self.keys()):
-            raise ValueError(f'Block {self.var_name} has {len(self)} children and data: {self.data}')
-
+        self.flatten()
 
+    def flatten(self):
+        '''
+        Reassign data objects to be children of the parent dict:
+        block1.block2.data -> block1.data, as block2 does not contain keys anyway
+        '''
+        for k, v in self.items():
+            # Sanity check: Blocks can either have children or contain data
+            if v.data is not None and len(v.keys()):
+                raise ValueError(f'Block {k} has {len(v)} children and data: {v.data}')
 
+            if v.data is not None:
+                self[k] = v.data
 
 
 class MonkeyLogicRawIO(BaseRawIO):
@@ -161,6 +169,16 @@ def __init__(self, filename=''):
     def _source_name(self):
         return self.filename
 
+    def _data_sanity_checks(self):
+        for trial_id in self.trial_ids:
+            events = self.ml_blocks[f'Trial{trial_id}']['BehavioralCodes']
+
+            # sanity check: first event == trial start, last event == trial end
+            first_event_code = events['CodeNumbers'][0]
+            last_event_code = events['CodeNumbers'][-1]
+            assert first_event_code == 9  # 9 denotes sending of trial start event
+            assert last_event_code == 18  # 18 denotes sending of trial end event
+
     def _parse_header(self):
         self.ml_blocks = {}
 
@@ -170,7 +188,9 @@ def _parse_header(self):
                 self.ml_blocks[bl.var_name] = bl
 
         trial_rec = self.ml_blocks['TrialRecord']
-        self.trial_ids = np.arange(1, int(trial_rec['CurrentTrialNumber'].data))
+        self.trial_ids = np.arange(1, int(trial_rec['CurrentTrialNumber']))
+
+        self._data_sanity_checks()
 
         exclude_signals = ['SampleInterval']
 
@@ -185,49 +205,48 @@ def _parse_header(self):
 
             ana_block = self.ml_blocks['Trial1']['AnalogData']
 
-            def _register_signal(sig_block, prefix=''):
+            def _register_signal(sig_block, name):
                 nonlocal stream_id
                 nonlocal chan_id
-                if sig_data.data is not None and any(sig_data.data.shape):
-                    signal_streams.append((prefix + sig_data.var_name, stream_id))
+                if not isinstance(sig_block, dict) and any(sig_block.shape):
+                    signal_streams.append((name, stream_id))
 
-                ch_name = sig_data.var_name
                 sr = 1  # TODO: Where to get the sampling rate info?
-                dtype = type(sig_data.data)
+                dtype = sig_block.dtype
                 units = ''  # TODO: Where to find the unit info?
                 gain = 1  # TODO: Where to find the gain info?
                 offset = 0  # TODO: Can signals have an offset in ML?
                 stream_id = 0  # all analog data belong to same stream
 
-                if sig_block.data.shape[1] == 1:
-                    signal_channels.append((prefix + ch_name, chan_id, sr, dtype, units, gain, offset,
+                if sig_block.shape[1] == 1:
+                    signal_channels.append((name, chan_id, sr, dtype, units, gain, offset,
                                             stream_id))
                     chan_id += 1
                 else:
-                    for sub_chan_id in range(sig_block.data.shape[1]):
+                    for sub_chan_id in range(sig_block.shape[1]):
                         signal_channels.append(
-                            (prefix + ch_name, chan_id, sr, dtype, units, gain, offset,
+                            (name, chan_id, sr, dtype, units, gain, offset,
                              stream_id))
                         chan_id += 1
 
-
-
-            # 1st level signals ('Trial1'/'AnalogData'/<signal>)
+
             for sig_name, sig_data in ana_block.items():
                 if sig_name in exclude_signals:
                     continue
 
-                # 1st level signals
-                if sig_data.data is not None and any(sig_data.data.shape):
-                    _register_signal(sig_data)
+                # 1st level signals ('Trial1'/'AnalogData'/<signal>)
+                if not isinstance(sig_data, dict) and any(sig_data.shape):
+                    _register_signal(sig_data, name=sig_name)
 
-                # 2nd level signals
-                elif sig_data.keys():
+                # 2nd level signals ('Trial1'/'AnalogData'/<signal>/<sub_signal>)
+                elif isinstance(sig_data, dict):
                     for sig_sub_name, sig_sub_data in sig_data.items():
-                        if sig_sub_data.data is not None:
-                            chan_names.append(f'{sig_name}/{sig_sub_name}')
-                            _register_signal(sig_sub_data, prefix=f'{sig_name}/')
+                        if not isinstance(sig_sub_data, dict):
+                            name = f'{sig_name}/{sig_sub_name}'
+                            chan_names.append(name)
+                            _register_signal(sig_sub_data, name=name)
 
         spike_channels = []
 
@@ -244,7 +263,7 @@ def _register_signal(sig_block, name):
 
         self.header = {}
         self.header['nb_block'] = 1
-        self.header['nb_segment'] = [1]
+        self.header['nb_segment'] = [len(self.trial_ids)]
         self.header['signal_streams'] = signal_streams
         self.header['signal_channels'] = signal_channels
         self.header['spike_channels'] = spike_channels
@@ -258,7 +277,7 @@ def _register_signal(sig_block, name):
         array_annotation_keys = []
 
         ml_anno = {k: v for k, v in sorted(self.ml_blocks.items()) if not k.startswith('Trial')}
-        bl_ann = self.raw_annotations['block'][0]
+        bl_ann = self.raw_annotations['blocks'][0]
         bl_ann.update(ml_anno)
 
         # TODO annotate segments according to trial properties
@@ -314,18 +333,17 @@ def _register_signal(sig_block, name):
 
     def _segment_t_start(self, block_index, seg_index):
-        if 'Trial1' in self.ml_blocks:
-            t_start = self.ml_blocks['Trial1']['AbsoluteTrialStartTime'].data[0][0]
-        else:
-            t_start = 0
+        assert block_index == 0
+
+        t_start = self.ml_blocks[f'Trial{seg_index+1}']['AbsoluteTrialStartTime'][0][0]
         return t_start
 
     def _segment_t_stop(self, block_index, seg_index):
-        last_trial = self.ml_blocks[f'Trial{self.trial_ids[-1]}']
+        t_start = self._segment_t_start(block_index, seg_index)
+        # using stream 0 as all analogsignal streams should have the same duration
+        duration = self._get_signal_size(block_index, seg_index, 0)
 
-        t_start = last_trial['AbsoluteTrialStartTime'].data[0][0]
-        t_stop = t_start + 10  # TODO: Find sampling rates to determine trial end
-        return t_stop
+        return t_start + duration
 
     def _get_signal_size(self, block_index, seg_index, stream_index):
         stream_name, stream_id = self.header['signal_streams'][stream_index]
@@ -334,7 +352,7 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
         for sn in stream_name.split('/'):  # dealing with 1st and 2nd level signals
             block = block[sn]
 
-        size = block.data.shape[0]
+        size = block.shape[0]
         return size  # size is per signal, not the sum of all channel_indexes
 
     def _get_signal_t_start(self, block_index, seg_index, stream_index):
@@ -356,65 +374,28 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
             block = block[sn]
 
         if channel_indexes is None:
-            raw_signals = block.data
+            raw_signals = block
         else:
-            raw_signals = block.data[channel_indexes]
+            raw_signals = block[:, channel_indexes]
 
         raw_signals = raw_signals[i_start:i_stop]
         return raw_signals
 
     def _spike_count(self, block_index, seg_index, unit_index):
-        count = 0
-        head_id = self.header['spike_channels'][unit_index][1]
-        for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
-            for src in mt.sources:
-                if mt.type == 'neo.spiketrain' and [src.type == "neo.unit"]:
-                    if head_id == src.id:
-                        return len(mt.positions)
-        return count
-
-    def _get_spike_timestamps(self, block_index, seg_index, unit_index,
-                              t_start, t_stop):
-        block = self.unit_list['blocks'][block_index]
-        segment = block['segments'][seg_index]
-        spike_dict = segment['spiketrains']
-        spike_timestamps = spike_dict[unit_index]
-        spike_timestamps = np.transpose(spike_timestamps)
-
-        if t_start is not None or t_stop is not None:
-            lim0 = t_start
-            lim1 = t_stop
-            mask = (spike_timestamps >= lim0) & (spike_timestamps <= lim1)
-            spike_timestamps = spike_timestamps[mask]
-        return spike_timestamps
+        return None
+
+    def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
+        return None
 
     def _rescale_spike_timestamp(self, spike_timestamps, dtype):
-        spike_times = spike_timestamps.astype(dtype)
-        return spike_times
-
-    def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index,
-                                 t_start, t_stop):
-        # this must return a 3D numpy array (nb_spike, nb_channel, nb_sample)
-        seg = self.unit_list['blocks'][block_index]['segments'][seg_index]
-        waveforms = seg['spiketrains_unit'][unit_index]['waveforms']
-        if not waveforms:
-            return None
-        raw_waveforms = np.array(waveforms)
-
-        if t_start is not None:
-            lim0 = t_start
-            mask = (raw_waveforms >= lim0)
-            # use nan to keep the shape
-            raw_waveforms = np.where(mask, raw_waveforms, np.nan)
-        if t_stop is not None:
-            lim1 = t_stop
-            mask = (raw_waveforms <= lim1)
-            raw_waveforms = np.where(mask, raw_waveforms, np.nan)
-        return raw_waveforms
+        return None
+
+    def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
+        return None
 
     def _event_count(self, block_index, seg_index, event_channel_index):
         assert event_channel_index == 0
-        times = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes'].data
+        times = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes']
 
         return len(times)
 
@@ -424,9 +405,9 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
 
         durations = None
        assert block_index == 0
         assert event_channel_index == 0
 
-        timestamp = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes'].data
+        timestamp = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes']
         timestamp = timestamp.flatten()
-        labels = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeNumbers'].data
+        labels = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeNumbers']
         labels = labels.flatten()
 
         if t_start is not None:
@@ -440,12 +421,14 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
         return timestamp, durations, labels
 
     def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index):
-        # TODO: Figure out unit and scaling of event timestamps
-        event_timestamps /= 1000  # assume this is in milliseconds
+        # times are stored in milliseconds, see
+        # https://monkeylogic.nimh.nih.gov/docs_GettingStarted.html#FormatsSupported
+        event_timestamps /= 1000
         return event_timestamps.astype(dtype)  # return in seconds
 
     def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
-        # TODO: Figure out unit and scaling of event timestamps
-        raw_duration /= 1000  # assume this is in milliseconds
+        # times are stored in milliseconds, see
+        # https://monkeylogic.nimh.nih.gov/docs_GettingStarted.html#FormatsSupported
+        raw_duration /= 1000
         return raw_duration.astype(dtype)  # return in seconds
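With the flattened layout, leaf values are plain numpy arrays rather than nested blocks; a short access sketch (the file path is a placeholder):

    from neo.rawio.monkeylogicrawio import MonkeyLogicRawIO

    raw = MonkeyLogicRawIO('session.bhv2')  # placeholder path to a .bhv2 file
    raw.parse_header()

    codes = raw.ml_blocks['Trial1']['BehavioralCodes']
    # before this patch: codes['CodeNumbers'].data; now the array is stored directly
    print(codes['CodeNumbers'].flatten())   # starts with 9 and ends with 18, per the sanity checks
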
From ead8c15d70bcc25b05e164795f0c357cbcbee4d7 Mon Sep 17 00:00:00 2001
From: sprenger
Date: Thu, 23 Sep 2021 13:12:36 +0200
Subject: [PATCH 5/7] [MLIO] add block and trial/segment annotations

---
 neo/rawio/monkeylogicrawio.py | 107 +++++++++++++---------------------
 1 file changed, 41 insertions(+), 66 deletions(-)

diff --git a/neo/rawio/monkeylogicrawio.py b/neo/rawio/monkeylogicrawio.py
index c39b0f913..4c24ab190 100644
--- a/neo/rawio/monkeylogicrawio.py
+++ b/neo/rawio/monkeylogicrawio.py
@@ -15,6 +15,7 @@
 from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
                         _spike_channel_dtype, _event_channel_dtype)
 
+
 class MLBLock(dict):
     n_byte_dtype = {'logical': (1, '?'),
                     'char': (1, 'c'),
@@ -125,7 +126,7 @@ def read_data(self, f, recursive=False):
             n_fields = f.read(8)
             n_fields = struct.unpack('Q', n_fields)[0]
 
-            for field in range(n_fields*np.prod(self.var_size)):
+            for field in range(n_fields * np.prod(self.var_size)):
                 bl = MLBLock.generate_block(f)
                 if recursive:
                     self[bl.var_name] = bl
@@ -158,7 +159,6 @@ def flatten(self):
 
 
 class MonkeyLogicRawIO(BaseRawIO):
-
     extensions = ['bhv2']
     rawmode = 'one-file'
 
@@ -220,7 +220,7 @@ def _register_signal(sig_block, name):
                 if sig_block.shape[1] == 1:
                     signal_channels.append((name, chan_id, sr, dtype, units, gain, offset,
-                                    stream_id))
+                                            stream_id))
                     chan_id += 1
                 else:
                     for sub_chan_id in range(sig_block.shape[1]):
@@ -229,9 +229,6 @@ def _register_signal(sig_block, name):
                              stream_id))
                         chan_id += 1
 
-
-
-
             for sig_name, sig_data in ana_block.items():
                 if sig_name in exclude_signals:
                     continue
@@ -248,6 +245,7 @@ def _register_signal(sig_block, name):
                             chan_names.append(name)
                             _register_signal(sig_sub_data, name=name)
 
+        # ML does not record spike information
         spike_channels = []
 
         signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
@@ -259,8 +257,6 @@ def _register_signal(sig_block, name):
         # event_channels.append(('ML Trials', 1, 'epoch'))  # no epochs supported yet
         event_channels = np.array(event_channels, dtype=_event_channel_dtype)
 
-
-
         self.header = {}
         self.header['nb_block'] = 1
         self.header['nb_segment'] = [len(self.trial_ids)]
@@ -273,69 +269,49 @@ def _register_signal(sig_block, name):
 
         # adding custom annotations and array annotations
 
-        ignore_annotations = ['AnalogData', 'AbsoluteTrialStartTime']
-        array_annotation_keys = []
+        ignore_annotations = [
+            # data blocks
+            'AnalogData', 'AbsoluteTrialStartTime', 'BehavioralCodes', 'CodeNumbers',
+            # ML temporary variables
+            'ConditionsThisBlock',
+            'CurrentBlock', 'CurrentBlockCount', 'CurrentBlockCondition',
+            'CurrentBlockInfo', 'CurrentBlockStimulusInfo', 'CurrentTrialNumber',
+            'CurrentTrialWithinBlock', 'LastTrialAnalogData', 'LastTrialCodes',
+            'NextBlock', 'NextCondition']
+
+        def _filter_keys(full_dict, ignore_keys, simplify=True):
+            res = {}
+            for k, v in full_dict.items():
+                if k in ignore_keys:
+                    continue
+
+                if isinstance(v, dict):
+                    res[k] = _filter_keys(v, ignore_keys, simplify=simplify)
+                else:
+                    if simplify and isinstance(v, np.ndarray) and np.prod(v.shape) == 1:
+                        v = v.flat[0]
+                    res[k] = v
+            return res
 
-        ml_anno = {k: v for k, v in sorted(self.ml_blocks.items()) if not k.startswith('Trial')}
+        ml_ann = {k: v for k, v in self.ml_blocks.items() if k in ['MLConfig', 'TrialRecord']}
+        ml_ann = _filter_keys(ml_ann, ignore_annotations)
         bl_ann = self.raw_annotations['blocks'][0]
-        bl_ann.update(ml_anno)
+        bl_ann.update(ml_ann)
+
+        for trial_id in self.trial_ids:
+            ml_trial = self.ml_blocks[f'Trial{trial_id}']
+            assert ml_trial['Trial'] == trial_id
 
-        # TODO annotate segments according to trial properties
-        seg_ann = self.raw_annotations['blocks'][0]['segments'][0]
-        seg_ann.update(ml_anno)
+            seg_ann = self.raw_annotations['blocks'][0]['segments'][trial_id - 1]
+            seg_ann.update(_filter_keys(ml_trial, ignore_annotations))
 
         event_ann = seg_ann['events'][0]  # 0 is event
         # epoch_ann = seg_ann['events'][1]  # 1 is epoch
 
-        # TODO: add annotations for AnalogSignals
-        # TODO: add array_annotations for AnalogSignals
-
-        # ml_anno = {k: v for k, v in sorted(self.ml_blocks.items()) if k.startswith('Trial')}
-        #
-        # raise NotImplementedError()
-        #
-        # # extract array annotations
-        # event_ann.update(self._filter_properties(props, 'ep'))
-        # ev_idx += 1
-        #
-        # # adding array annotations to analogsignals
-        # annotated_anasigs = []
-        # sig_ann = seg_ann['signals']
-        # # this implementation relies on analogsignals always being
-        # # stored in the same stream order across segments
-        # stream_id = 0
-        # for da_idx, da in enumerate(group.data_arrays):
-        #     if da.type != "neo.analogsignal":
-        #         continue
-        #     anasig_id = da.name.split('.')[-2]
-        #     # skip already annotated signals as each channel already
-        #     # contains the complete set of annotations and
-        #     # array_annotations
-        #     if anasig_id in annotated_anasigs:
-        #         continue
-        #     annotated_anasigs.append(anasig_id)
-        #
-        #     # collect annotation properties
-        #     props = [p for p in da.metadata.props
-        #              if p.type != 'ARRAYANNOTATION']
-        #     props_dict = self._filter_properties(props, "analogsignal")
-        #     sig_ann[stream_id].update(props_dict)
-        #
-        #     # collect array annotation properties
-        #     props = [p for p in da.metadata.props
-        #              if p.type == 'ARRAYANNOTATION']
-        #     props_dict = self._filter_properties(props, "analogsignal")
-        #     sig_ann[stream_id]['__array_annotations__'].update(
-        #         props_dict)
-        #
-        #     stream_id += 1
-        #
-        # return
-
     def _segment_t_start(self, block_index, seg_index):
         assert block_index == 0
 
-        t_start = self.ml_blocks[f'Trial{seg_index+1}']['AbsoluteTrialStartTime'][0][0]
+        t_start = self.ml_blocks[f'Trial{seg_index + 1}']['AbsoluteTrialStartTime'][0][0]
         return t_start
 
     def _segment_t_stop(self, block_index, seg_index):
@@ -369,7 +345,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
             i_stop = self.get_signal_size(block_index, seg_index, stream_index)
 
         raw_signals_list = []
-        block = self.ml_blocks[f'Trial{seg_index+1}']['AnalogData']
+        block = self.ml_blocks[f'Trial{seg_index + 1}']['AnalogData']
         for sn in stream_name.split('/'):
             block = block[sn]
 
@@ -395,7 +371,7 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start,
 
     def _event_count(self, block_index, seg_index, event_channel_index):
         assert event_channel_index == 0
-        times = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes']
+        times = self.ml_blocks[f'Trial{seg_index + 1}']['BehavioralCodes']['CodeTimes']
 
         return len(times)
 
@@ -405,9 +381,9 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
         assert block_index == 0
         assert event_channel_index == 0
 
-        timestamp = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes']
+        timestamp = self.ml_blocks[f'Trial{seg_index + 1}']['BehavioralCodes']['CodeTimes']
         timestamp = timestamp.flatten()
-        labels = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeNumbers']
+        labels = self.ml_blocks[f'Trial{seg_index + 1}']['BehavioralCodes']['CodeNumbers']
         labels = labels.flatten()
 
         if t_start is not None:
@@ -431,4 +407,3 @@ def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
         # times are stored in milliseconds, see
         # https://monkeylogic.nimh.nih.gov/docs_GettingStarted.html#FormatsSupported
         raw_duration /= 1000
         return raw_duration.astype(dtype)  # return in seconds
-
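Downstream, the block now carries the filtered session configuration and each segment carries its trial record; a reading sketch (the path is a placeholder, and apart from 'Trial' the available annotation keys depend on the recorded session):

    from neo.io import MonkeyLogicIO

    io = MonkeyLogicIO('session.bhv2')   # placeholder path
    bl = io.read_block()

    print(sorted(bl.annotations.keys())) # filtered MLConfig / TrialRecord entries
    for seg in bl.segments:
        print(seg.annotations['Trial'])  # trial id, asserted to match the segment order
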
From 602a84991de5eccc5ed547a5d3c87c27546e846e Mon Sep 17 00:00:00 2001
From: sprenger
Date: Sat, 25 Sep 2021 13:15:46 +0200
Subject: [PATCH 6/7] [MLIO] improve cell handling

---
 neo/rawio/monkeylogicrawio.py | 38 +++++++++++++++++++++++++++--------
 1 file changed, 30 insertions(+), 8 deletions(-)

diff --git a/neo/rawio/monkeylogicrawio.py b/neo/rawio/monkeylogicrawio.py
index 4c24ab190..4d79eae6d 100644
--- a/neo/rawio/monkeylogicrawio.py
+++ b/neo/rawio/monkeylogicrawio.py
@@ -16,7 +16,7 @@
                         _spike_channel_dtype, _event_channel_dtype)
 
 
-class MLBLock(dict):
+class MLBlock(dict):
     n_byte_dtype = {'logical': (1, '?'),
                     'char': (1, 'c'),
                     'integers': (8, 'Q'),
@@ -53,7 +53,7 @@ def generate_block(f):
         var_size = struct.unpack(f'{DV}Q', var_size)
         # print(var_size)
 
-        return MLBLock(LN, var_name, LT, var_type, DV, var_size)
+        return MLBlock(LN, var_name, LT, var_type, DV, var_size)
 
     def __bool__(self):
         if any((self.LN, self.LT)):
@@ -107,7 +107,11 @@ def read_data(self, f, recursive=False):
                 d = struct.unpack(format, d)[0]
                 data[i] = d
 
-            data = data.reshape(self.var_size)
+            # convert to simple / expected data shape
+            if self.var_size == (1, 1):
+                data = data[0]
+            else:
+                data = data.reshape(self.var_size)
 
             # decoding characters
             if self.var_type == 'char':
@@ -127,17 +131,23 @@ def read_data(self, f, recursive=False):
             n_fields = struct.unpack('Q', n_fields)[0]
 
             for field in range(n_fields * np.prod(self.var_size)):
-                bl = MLBLock.generate_block(f)
+                bl = MLBlock.generate_block(f)
                 if recursive:
                     self[bl.var_name] = bl
                     bl.read_data(f, recursive=recursive)
 
         elif self.var_type == 'cell':
+            # cells are always 2D
+            assert len(self.var_size) == 2, 'Unexpected dimensions of cells'
+            data = np.empty(shape=np.prod(self.var_size), dtype=object)
             for field in range(np.prod(self.var_size)):
-                bl = MLBLock.generate_block(f)
+                bl = MLBlock.generate_block(f)
                 if recursive:
-                    self[bl.var_name] = bl
+                    data[field] = bl
                     bl.read_data(f, recursive=recursive)
+
+            data = data.reshape(self.var_size)
+            self.data = data
 
         else:
             raise ValueError(f'unknown variable type {self.var_type}')
@@ -146,9 +156,12 @@ def read_data(self, f, recursive=False):
 
     def flatten(self):
         '''
-        Reassign data objects to be children of the parent dict:
+        Flatten the structure by
+        1) reassigning data objects to be children of the parent dict:
         block1.block2.data -> block1.data, as block2 does not contain keys anyway
+        2) converting arrays of blocks (cells) to plain data objects
         '''
+
         for k, v in self.items():
             # Sanity check: Blocks can either have children or contain data
             if v.data is not None and len(v.keys()):
                 raise ValueError(f'Block {k} has {len(v)} children and data: {v.data}')
 
             if v.data is not None:
                 self[k] = v.data
 
+            # converting arrays of MLBlocks (cells) to (nested) lists of objects
+            if isinstance(self[k], np.ndarray) and all([isinstance(b, MLBlock) for b in self[k].flat]):
+                assert len(self[k].shape) == 2
+                for i in range(self[k].shape[0]):
+                    for j in range(self[k].shape[1]):
+                        self[k][i, j] = self[k][i, j].data
+                self[k] = self[k].tolist()
+
 
 class MonkeyLogicRawIO(BaseRawIO):
     extensions = ['bhv2']
@@ -183,7 +196,7 @@ def _parse_header(self):
         self.ml_blocks = {}
 
         with open(self.filename, 'rb') as f:
-            while bl := MLBLock.generate_block(f):
+            while bl := MLBlock.generate_block(f):
                 bl.read_data(f, recursive=True)
                 self.ml_blocks[bl.var_name] = bl
 
From cdfc3a1a3d6149e242afa87e74e737c73407e4a8 Mon Sep 17 00:00:00 2001
From: sprenger
Date: Sat, 25 Sep 2021 13:16:39 +0200
Subject: [PATCH 7/7] [MLIO] update test

---
 neo/test/iotest/test_monkeylogicio.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/neo/test/iotest/test_monkeylogicio.py b/neo/test/iotest/test_monkeylogicio.py
index b62e89cf0..7bf29ae8e 100644
--- a/neo/test/iotest/test_monkeylogicio.py
+++ b/neo/test/iotest/test_monkeylogicio.py
@@ -21,12 +21,18 @@ class TestMonkeyLogicIO(unittest.TestCase):
     # TODO: Adjust this once ML files are on GIN
 
     def test_read(self):
-        filename = '/home/sprengerj/projects/monkey_logic/210909_TSCM_5cj_5cl_Riesling.bhv2'
-        filename = '/home/sprengerj/projects/monkey_logic/sabrina/210810__learndms_userloop.bhv2'
+        filename = '/home/sprengerj/projects/monkey_logic/guilhem/210909_TSCM_5cj_5cl_Riesling.bhv2'
+        # filename = '/home/sprengerj/projects/monkey_logic/sabrina/210810__learndms_userloop.bhv2'
         # filename = '/home/sprengerj/projects/monkey_logic/sabrina/210916__learndms_userloop.bhv2'
         # filename = '/home/sprengerj/projects/monkey_logic/sabrina/210917__learndms_userloop.bhv2'
         io = MonkeyLogicIO(filename)
-        io.read_block()
+        bl = io.read_block()
+
+        assert len(bl.segments) == len(io.trial_ids)
+        assert 'Trial' in bl.segments[0].annotations
+        assert len(bl.segments[0].events) == 1
+        print(bl.segments[0].events[0].times)
+
 
 if __name__ == "__main__":
     unittest.main()
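Taken together, a typical read with the new IO looks like this (a sketch; the path is a placeholder and the printed fields depend on the recorded session):

    from neo.io import MonkeyLogicIO

    io = MonkeyLogicIO('session.bhv2')  # placeholder path to a .bhv2 file
    bl = io.read_block()                # one Block, one Segment per trial

    for seg in bl.segments:
        ev = seg.events[0]              # behavioral code events ('ML Trials'), times in seconds
        print(seg.annotations['Trial'], ev.times[0], ev.labels[0], ev.times[-1], ev.labels[-1])
        for sig in seg.analogsignals:   # e.g. eye or joystick traces
            print(sig.name, sig.shape)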