Skip to content

Commit ead8c15

Browse files
author
sprenger
committed
[MLIO] add block and trial/segment annotations
1 parent 76ddbd8 commit ead8c15

File tree

1 file changed

+41
-66
lines changed

1 file changed

+41
-66
lines changed

neo/rawio/monkeylogicrawio.py

+41-66
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
1616
_spike_channel_dtype, _event_channel_dtype)
1717

18+
1819
class MLBLock(dict):
1920
n_byte_dtype = {'logical': (1, '?'),
2021
'char': (1, 'c'),
@@ -125,7 +126,7 @@ def read_data(self, f, recursive=False):
125126
n_fields = f.read(8)
126127
n_fields = struct.unpack('Q', n_fields)[0]
127128

128-
for field in range(n_fields*np.prod(self.var_size)):
129+
for field in range(n_fields * np.prod(self.var_size)):
129130
bl = MLBLock.generate_block(f)
130131
if recursive:
131132
self[bl.var_name] = bl
@@ -158,7 +159,6 @@ def flatten(self):
158159

159160

160161
class MonkeyLogicRawIO(BaseRawIO):
161-
162162
extensions = ['bhv2']
163163
rawmode = 'one-file'
164164

@@ -220,7 +220,7 @@ def _register_signal(sig_block, name):
220220

221221
if sig_block.shape[1] == 1:
222222
signal_channels.append((name, chan_id, sr, dtype, units, gain, offset,
223-
stream_id))
223+
stream_id))
224224
chan_id += 1
225225
else:
226226
for sub_chan_id in range(sig_block.shape[1]):
@@ -229,9 +229,6 @@ def _register_signal(sig_block, name):
229229
stream_id))
230230
chan_id += 1
231231

232-
233-
234-
235232
for sig_name, sig_data in ana_block.items():
236233
if sig_name in exclude_signals:
237234
continue
@@ -248,6 +245,7 @@ def _register_signal(sig_block, name):
248245
chan_names.append(name)
249246
_register_signal(sig_sub_data, name=name)
250247

248+
# ML does not record spike information
251249
spike_channels = []
252250

253251
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
@@ -259,8 +257,6 @@ def _register_signal(sig_block, name):
259257
# event_channels.append(('ML Trials', 1, 'epoch')) # no epochs supported yet
260258
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
261259

262-
263-
264260
self.header = {}
265261
self.header['nb_block'] = 1
266262
self.header['nb_segment'] = [len(self.trial_ids)]
@@ -273,69 +269,49 @@ def _register_signal(sig_block, name):
273269

274270
# adding custom annotations and array annotations
275271

276-
ignore_annotations = ['AnalogData', 'AbsoluteTrialStartTime']
277-
array_annotation_keys = []
272+
ignore_annotations = [
273+
# data blocks
274+
'AnalogData', 'AbsoluteTrialStartTime', 'BehavioralCodes', 'CodeNumbers',
275+
# ML temporary variables
276+
'ConditionsThisBlock',
277+
'CurrentBlock', 'CurrentBlockCount', 'CurrentBlockCondition',
278+
'CurrentBlockInfo', 'CurrentBlockStimulusInfo', 'CurrentTrialNumber',
279+
'CurrentTrialWithinBlock', 'LastTrialAnalogData', 'LastTrialCodes',
280+
'NextBlock', 'NextCondition']
281+
282+
def _filter_keys(full_dict, ignore_keys, simplify=True):
283+
res = {}
284+
for k, v in full_dict.items():
285+
if k in ignore_keys:
286+
continue
287+
288+
if isinstance(v, dict):
289+
res[k] = _filter_keys(v, ignore_keys)
290+
else:
291+
if simplify and isinstance(v, np.ndarray) and np.prod(v.shape) == 1:
292+
v = v.flat[0]
293+
res[k] = v
294+
return res
278295

279-
ml_anno = {k: v for k, v in sorted(self.ml_blocks.items()) if not k.startswith('Trial')}
296+
ml_ann = {k: v for k, v in self.ml_blocks.items() if k in ['MLConfig', 'TrialRecord']}
297+
ml_ann = _filter_keys(ml_ann, ignore_annotations)
280298
bl_ann = self.raw_annotations['blocks'][0]
281-
bl_ann.update(ml_anno)
299+
bl_ann.update(ml_ann)
300+
301+
for trial_id in self.trial_ids:
302+
ml_trial = self.ml_blocks[f'Trial{trial_id}']
303+
assert ml_trial['Trial'] == trial_id
282304

283-
# TODO annotate segments according to trial properties
284-
seg_ann = self.raw_annotations['blocks'][0]['segments'][0]
285-
seg_ann.update(ml_anno)
305+
seg_ann = self.raw_annotations['blocks'][0]['segments'][trial_id-1]
306+
seg_ann.update(_filter_keys(ml_trial, ignore_annotations))
286307

287308
event_ann = seg_ann['events'][0] # 0 is event
288309
# epoch_ann = seg_ann['events'][1] # 1 is epoch
289310

290-
# TODO: add annotations for AnalogSignals
291-
# TODO: add array_annotations for AnalogSignals
292-
293-
# ml_anno = {k: v for k, v in sorted(self.ml_blocks.items()) if k.startswith('Trial')}
294-
#
295-
# raise NotImplementedError()
296-
#
297-
# # extract array annotations
298-
# event_ann.update(self._filter_properties(props, 'ep'))
299-
# ev_idx += 1
300-
#
301-
# # adding array annotations to analogsignals
302-
# annotated_anasigs = []
303-
# sig_ann = seg_ann['signals']
304-
# # this implementation relies on analogsignals always being
305-
# # stored in the same stream order across segments
306-
# stream_id = 0
307-
# for da_idx, da in enumerate(group.data_arrays):
308-
# if da.type != "neo.analogsignal":
309-
# continue
310-
# anasig_id = da.name.split('.')[-2]
311-
# # skip already annotated signals as each channel already
312-
# # contains the complete set of annotations and
313-
# # array_annotations
314-
# if anasig_id in annotated_anasigs:
315-
# continue
316-
# annotated_anasigs.append(anasig_id)
317-
#
318-
# # collect annotation properties
319-
# props = [p for p in da.metadata.props
320-
# if p.type != 'ARRAYANNOTATION']
321-
# props_dict = self._filter_properties(props, "analogsignal")
322-
# sig_ann[stream_id].update(props_dict)
323-
#
324-
# # collect array annotation properties
325-
# props = [p for p in da.metadata.props
326-
# if p.type == 'ARRAYANNOTATION']
327-
# props_dict = self._filter_properties(props, "analogsignal")
328-
# sig_ann[stream_id]['__array_annotations__'].update(
329-
# props_dict)
330-
#
331-
# stream_id += 1
332-
#
333-
# return
334-
335311
def _segment_t_start(self, block_index, seg_index):
336312
assert block_index == 0
337313

338-
t_start = self.ml_blocks[f'Trial{seg_index+1}']['AbsoluteTrialStartTime'][0][0]
314+
t_start = self.ml_blocks[f'Trial{seg_index + 1}']['AbsoluteTrialStartTime'][0][0]
339315
return t_start
340316

341317
def _segment_t_stop(self, block_index, seg_index):
@@ -369,7 +345,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
369345
i_stop = self.get_signal_size(block_index, seg_index, stream_index)
370346

371347
raw_signals_list = []
372-
block = self.ml_blocks[f'Trial{seg_index+1}']['AnalogData']
348+
block = self.ml_blocks[f'Trial{seg_index + 1}']['AnalogData']
373349
for sn in stream_name.split('/'):
374350
block = block[sn]
375351

@@ -395,7 +371,7 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start,
395371

396372
def _event_count(self, block_index, seg_index, event_channel_index):
397373
assert event_channel_index == 0
398-
times = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes']
374+
times = self.ml_blocks[f'Trial{seg_index + 1}']['BehavioralCodes']['CodeTimes']
399375

400376
return len(times)
401377

@@ -405,9 +381,9 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
405381
assert block_index == 0
406382
assert event_channel_index == 0
407383

408-
timestamp = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes']
384+
timestamp = self.ml_blocks[f'Trial{seg_index + 1}']['BehavioralCodes']['CodeTimes']
409385
timestamp = timestamp.flatten()
410-
labels = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeNumbers']
386+
labels = self.ml_blocks[f'Trial{seg_index + 1}']['BehavioralCodes']['CodeNumbers']
411387
labels = labels.flatten()
412388

413389
if t_start is not None:
@@ -431,4 +407,3 @@ def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
431407
# shttps://monkeylogic.nimh.nih.gov/docs_GettingStarted.html#FormatsSupported
432408
raw_duration /= 1000
433409
return raw_duration.astype(dtype) # return in seconds
434-

0 commit comments

Comments
 (0)