15
15
from .baserawio import (BaseRawIO , _signal_channel_dtype , _signal_stream_dtype ,
16
16
_spike_channel_dtype , _event_channel_dtype )
17
17
18
+
18
19
class MLBLock (dict ):
19
20
n_byte_dtype = {'logical' : (1 , '?' ),
20
21
'char' : (1 , 'c' ),
@@ -125,7 +126,7 @@ def read_data(self, f, recursive=False):
125
126
n_fields = f .read (8 )
126
127
n_fields = struct .unpack ('Q' , n_fields )[0 ]
127
128
128
- for field in range (n_fields * np .prod (self .var_size )):
129
+ for field in range (n_fields * np .prod (self .var_size )):
129
130
bl = MLBLock .generate_block (f )
130
131
if recursive :
131
132
self [bl .var_name ] = bl
@@ -158,7 +159,6 @@ def flatten(self):
158
159
159
160
160
161
class MonkeyLogicRawIO (BaseRawIO ):
161
-
162
162
extensions = ['bhv2' ]
163
163
rawmode = 'one-file'
164
164
@@ -220,7 +220,7 @@ def _register_signal(sig_block, name):
220
220
221
221
if sig_block .shape [1 ] == 1 :
222
222
signal_channels .append ((name , chan_id , sr , dtype , units , gain , offset ,
223
- stream_id ))
223
+ stream_id ))
224
224
chan_id += 1
225
225
else :
226
226
for sub_chan_id in range (sig_block .shape [1 ]):
@@ -229,9 +229,6 @@ def _register_signal(sig_block, name):
229
229
stream_id ))
230
230
chan_id += 1
231
231
232
-
233
-
234
-
235
232
for sig_name , sig_data in ana_block .items ():
236
233
if sig_name in exclude_signals :
237
234
continue
@@ -248,6 +245,7 @@ def _register_signal(sig_block, name):
248
245
chan_names .append (name )
249
246
_register_signal (sig_sub_data , name = name )
250
247
248
+ # ML does not record spike information
251
249
spike_channels = []
252
250
253
251
signal_channels = np .array (signal_channels , dtype = _signal_channel_dtype )
@@ -259,8 +257,6 @@ def _register_signal(sig_block, name):
259
257
# event_channels.append(('ML Trials', 1, 'epoch')) # no epochs supported yet
260
258
event_channels = np .array (event_channels , dtype = _event_channel_dtype )
261
259
262
-
263
-
264
260
self .header = {}
265
261
self .header ['nb_block' ] = 1
266
262
self .header ['nb_segment' ] = [len (self .trial_ids )]
@@ -273,69 +269,49 @@ def _register_signal(sig_block, name):
273
269
274
270
# adding custom annotations and array annotations
275
271
276
- ignore_annotations = ['AnalogData' , 'AbsoluteTrialStartTime' ]
277
- array_annotation_keys = []
272
+ ignore_annotations = [
273
+ # data blocks
274
+ 'AnalogData' , 'AbsoluteTrialStartTime' , 'BehavioralCodes' , 'CodeNumbers' ,
275
+ # ML temporary variables
276
+ 'ConditionsThisBlock' ,
277
+ 'CurrentBlock' , 'CurrentBlockCount' , 'CurrentBlockCondition' ,
278
+ 'CurrentBlockInfo' , 'CurrentBlockStimulusInfo' , 'CurrentTrialNumber' ,
279
+ 'CurrentTrialWithinBlock' , 'LastTrialAnalogData' , 'LastTrialCodes' ,
280
+ 'NextBlock' , 'NextCondition' ]
281
+
282
+ def _filter_keys (full_dict , ignore_keys , simplify = True ):
283
+ res = {}
284
+ for k , v in full_dict .items ():
285
+ if k in ignore_keys :
286
+ continue
287
+
288
+ if isinstance (v , dict ):
289
+ res [k ] = _filter_keys (v , ignore_keys )
290
+ else :
291
+ if simplify and isinstance (v , np .ndarray ) and np .prod (v .shape ) == 1 :
292
+ v = v .flat [0 ]
293
+ res [k ] = v
294
+ return res
278
295
279
- ml_anno = {k : v for k , v in sorted (self .ml_blocks .items ()) if not k .startswith ('Trial' )}
296
+ ml_ann = {k : v for k , v in self .ml_blocks .items () if k in ['MLConfig' , 'TrialRecord' ]}
297
+ ml_ann = _filter_keys (ml_ann , ignore_annotations )
280
298
bl_ann = self .raw_annotations ['blocks' ][0 ]
281
- bl_ann .update (ml_anno )
299
+ bl_ann .update (ml_ann )
300
+
301
+ for trial_id in self .trial_ids :
302
+ ml_trial = self .ml_blocks [f'Trial{ trial_id } ' ]
303
+ assert ml_trial ['Trial' ] == trial_id
282
304
283
- # TODO annotate segments according to trial properties
284
- seg_ann = self .raw_annotations ['blocks' ][0 ]['segments' ][0 ]
285
- seg_ann .update (ml_anno )
305
+ seg_ann = self .raw_annotations ['blocks' ][0 ]['segments' ][trial_id - 1 ]
306
+ seg_ann .update (_filter_keys (ml_trial , ignore_annotations ))
286
307
287
308
event_ann = seg_ann ['events' ][0 ] # 0 is event
288
309
# epoch_ann = seg_ann['events'][1] # 1 is epoch
289
310
290
- # TODO: add annotations for AnalogSignals
291
- # TODO: add array_annotations for AnalogSignals
292
-
293
- # ml_anno = {k: v for k, v in sorted(self.ml_blocks.items()) if k.startswith('Trial')}
294
- #
295
- # raise NotImplementedError()
296
- #
297
- # # extract array annotations
298
- # event_ann.update(self._filter_properties(props, 'ep'))
299
- # ev_idx += 1
300
- #
301
- # # adding array annotations to analogsignals
302
- # annotated_anasigs = []
303
- # sig_ann = seg_ann['signals']
304
- # # this implementation relies on analogsignals always being
305
- # # stored in the same stream order across segments
306
- # stream_id = 0
307
- # for da_idx, da in enumerate(group.data_arrays):
308
- # if da.type != "neo.analogsignal":
309
- # continue
310
- # anasig_id = da.name.split('.')[-2]
311
- # # skip already annotated signals as each channel already
312
- # # contains the complete set of annotations and
313
- # # array_annotations
314
- # if anasig_id in annotated_anasigs:
315
- # continue
316
- # annotated_anasigs.append(anasig_id)
317
- #
318
- # # collect annotation properties
319
- # props = [p for p in da.metadata.props
320
- # if p.type != 'ARRAYANNOTATION']
321
- # props_dict = self._filter_properties(props, "analogsignal")
322
- # sig_ann[stream_id].update(props_dict)
323
- #
324
- # # collect array annotation properties
325
- # props = [p for p in da.metadata.props
326
- # if p.type == 'ARRAYANNOTATION']
327
- # props_dict = self._filter_properties(props, "analogsignal")
328
- # sig_ann[stream_id]['__array_annotations__'].update(
329
- # props_dict)
330
- #
331
- # stream_id += 1
332
- #
333
- # return
334
-
335
311
def _segment_t_start (self , block_index , seg_index ):
336
312
assert block_index == 0
337
313
338
- t_start = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['AbsoluteTrialStartTime' ][0 ][0 ]
314
+ t_start = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['AbsoluteTrialStartTime' ][0 ][0 ]
339
315
return t_start
340
316
341
317
def _segment_t_stop (self , block_index , seg_index ):
@@ -369,7 +345,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
369
345
i_stop = self .get_signal_size (block_index , seg_index , stream_index )
370
346
371
347
raw_signals_list = []
372
- block = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['AnalogData' ]
348
+ block = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['AnalogData' ]
373
349
for sn in stream_name .split ('/' ):
374
350
block = block [sn ]
375
351
@@ -395,7 +371,7 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start,
395
371
396
372
def _event_count (self , block_index , seg_index , event_channel_index ):
397
373
assert event_channel_index == 0
398
- times = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]
374
+ times = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]
399
375
400
376
return len (times )
401
377
@@ -405,9 +381,9 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
405
381
assert block_index == 0
406
382
assert event_channel_index == 0
407
383
408
- timestamp = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]
384
+ timestamp = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]
409
385
timestamp = timestamp .flatten ()
410
- labels = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeNumbers' ]
386
+ labels = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeNumbers' ]
411
387
labels = labels .flatten ()
412
388
413
389
if t_start is not None :
@@ -431,4 +407,3 @@ def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
431
407
# shttps://monkeylogic.nimh.nih.gov/docs_GettingStarted.html#FormatsSupported
432
408
raw_duration /= 1000
433
409
return raw_duration .astype (dtype ) # return in seconds
434
-
0 commit comments