1515from .baserawio import (BaseRawIO , _signal_channel_dtype , _signal_stream_dtype ,
1616 _spike_channel_dtype , _event_channel_dtype )
1717
18+
1819class MLBLock (dict ):
1920 n_byte_dtype = {'logical' : (1 , '?' ),
2021 'char' : (1 , 'c' ),
@@ -125,7 +126,7 @@ def read_data(self, f, recursive=False):
125126 n_fields = f .read (8 )
126127 n_fields = struct .unpack ('Q' , n_fields )[0 ]
127128
128- for field in range (n_fields * np .prod (self .var_size )):
129+ for field in range (n_fields * np .prod (self .var_size )):
129130 bl = MLBLock .generate_block (f )
130131 if recursive :
131132 self [bl .var_name ] = bl
@@ -158,7 +159,6 @@ def flatten(self):
158159
159160
160161class MonkeyLogicRawIO (BaseRawIO ):
161-
162162 extensions = ['bhv2' ]
163163 rawmode = 'one-file'
164164
@@ -220,7 +220,7 @@ def _register_signal(sig_block, name):
220220
221221 if sig_block .shape [1 ] == 1 :
222222 signal_channels .append ((name , chan_id , sr , dtype , units , gain , offset ,
223- stream_id ))
223+ stream_id ))
224224 chan_id += 1
225225 else :
226226 for sub_chan_id in range (sig_block .shape [1 ]):
@@ -229,9 +229,6 @@ def _register_signal(sig_block, name):
229229 stream_id ))
230230 chan_id += 1
231231
232-
233-
234-
235232 for sig_name , sig_data in ana_block .items ():
236233 if sig_name in exclude_signals :
237234 continue
@@ -248,6 +245,7 @@ def _register_signal(sig_block, name):
248245 chan_names .append (name )
249246 _register_signal (sig_sub_data , name = name )
250247
248+ # ML does not record spike information
251249 spike_channels = []
252250
253251 signal_channels = np .array (signal_channels , dtype = _signal_channel_dtype )
@@ -259,8 +257,6 @@ def _register_signal(sig_block, name):
259257 # event_channels.append(('ML Trials', 1, 'epoch')) # no epochs supported yet
260258 event_channels = np .array (event_channels , dtype = _event_channel_dtype )
261259
262-
263-
264260 self .header = {}
265261 self .header ['nb_block' ] = 1
266262 self .header ['nb_segment' ] = [len (self .trial_ids )]
@@ -273,69 +269,49 @@ def _register_signal(sig_block, name):
273269
274270 # adding custom annotations and array annotations
275271
276- ignore_annotations = ['AnalogData' , 'AbsoluteTrialStartTime' ]
277- array_annotation_keys = []
272+ ignore_annotations = [
273+ # data blocks
274+ 'AnalogData' , 'AbsoluteTrialStartTime' , 'BehavioralCodes' , 'CodeNumbers' ,
275+ # ML temporary variables
276+ 'ConditionsThisBlock' ,
277+ 'CurrentBlock' , 'CurrentBlockCount' , 'CurrentBlockCondition' ,
278+ 'CurrentBlockInfo' , 'CurrentBlockStimulusInfo' , 'CurrentTrialNumber' ,
279+ 'CurrentTrialWithinBlock' , 'LastTrialAnalogData' , 'LastTrialCodes' ,
280+ 'NextBlock' , 'NextCondition' ]
281+
282+ def _filter_keys (full_dict , ignore_keys , simplify = True ):
283+ res = {}
284+ for k , v in full_dict .items ():
285+ if k in ignore_keys :
286+ continue
287+
288+ if isinstance (v , dict ):
289+ res [k ] = _filter_keys (v , ignore_keys )
290+ else :
291+ if simplify and isinstance (v , np .ndarray ) and np .prod (v .shape ) == 1 :
292+ v = v .flat [0 ]
293+ res [k ] = v
294+ return res
278295
279- ml_anno = {k : v for k , v in sorted (self .ml_blocks .items ()) if not k .startswith ('Trial' )}
296+ ml_ann = {k : v for k , v in self .ml_blocks .items () if k in ['MLConfig' , 'TrialRecord' ]}
297+ ml_ann = _filter_keys (ml_ann , ignore_annotations )
280298 bl_ann = self .raw_annotations ['blocks' ][0 ]
281- bl_ann .update (ml_anno )
299+ bl_ann .update (ml_ann )
300+
301+ for trial_id in self .trial_ids :
302+ ml_trial = self .ml_blocks [f'Trial{ trial_id } ' ]
303+ assert ml_trial ['Trial' ] == trial_id
282304
283- # TODO annotate segments according to trial properties
284- seg_ann = self .raw_annotations ['blocks' ][0 ]['segments' ][0 ]
285- seg_ann .update (ml_anno )
305+ seg_ann = self .raw_annotations ['blocks' ][0 ]['segments' ][trial_id - 1 ]
306+ seg_ann .update (_filter_keys (ml_trial , ignore_annotations ))
286307
287308 event_ann = seg_ann ['events' ][0 ] # 0 is event
288309 # epoch_ann = seg_ann['events'][1] # 1 is epoch
289310
290- # TODO: add annotations for AnalogSignals
291- # TODO: add array_annotations for AnalogSignals
292-
293- # ml_anno = {k: v for k, v in sorted(self.ml_blocks.items()) if k.startswith('Trial')}
294- #
295- # raise NotImplementedError()
296- #
297- # # extract array annotations
298- # event_ann.update(self._filter_properties(props, 'ep'))
299- # ev_idx += 1
300- #
301- # # adding array annotations to analogsignals
302- # annotated_anasigs = []
303- # sig_ann = seg_ann['signals']
304- # # this implementation relies on analogsignals always being
305- # # stored in the same stream order across segments
306- # stream_id = 0
307- # for da_idx, da in enumerate(group.data_arrays):
308- # if da.type != "neo.analogsignal":
309- # continue
310- # anasig_id = da.name.split('.')[-2]
311- # # skip already annotated signals as each channel already
312- # # contains the complete set of annotations and
313- # # array_annotations
314- # if anasig_id in annotated_anasigs:
315- # continue
316- # annotated_anasigs.append(anasig_id)
317- #
318- # # collect annotation properties
319- # props = [p for p in da.metadata.props
320- # if p.type != 'ARRAYANNOTATION']
321- # props_dict = self._filter_properties(props, "analogsignal")
322- # sig_ann[stream_id].update(props_dict)
323- #
324- # # collect array annotation properties
325- # props = [p for p in da.metadata.props
326- # if p.type == 'ARRAYANNOTATION']
327- # props_dict = self._filter_properties(props, "analogsignal")
328- # sig_ann[stream_id]['__array_annotations__'].update(
329- # props_dict)
330- #
331- # stream_id += 1
332- #
333- # return
334-
335311 def _segment_t_start (self , block_index , seg_index ):
336312 assert block_index == 0
337313
338- t_start = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['AbsoluteTrialStartTime' ][0 ][0 ]
314+ t_start = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['AbsoluteTrialStartTime' ][0 ][0 ]
339315 return t_start
340316
341317 def _segment_t_stop (self , block_index , seg_index ):
@@ -369,7 +345,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
369345 i_stop = self .get_signal_size (block_index , seg_index , stream_index )
370346
371347 raw_signals_list = []
372- block = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['AnalogData' ]
348+ block = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['AnalogData' ]
373349 for sn in stream_name .split ('/' ):
374350 block = block [sn ]
375351
@@ -395,7 +371,7 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start,
395371
396372 def _event_count (self , block_index , seg_index , event_channel_index ):
397373 assert event_channel_index == 0
398- times = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]
374+ times = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]
399375
400376 return len (times )
401377
@@ -405,9 +381,9 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
405381 assert block_index == 0
406382 assert event_channel_index == 0
407383
408- timestamp = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]
384+ timestamp = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]
409385 timestamp = timestamp .flatten ()
410- labels = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeNumbers' ]
386+ labels = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeNumbers' ]
411387 labels = labels .flatten ()
412388
413389 if t_start is not None :
@@ -431,4 +407,3 @@ def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
431407 # shttps://monkeylogic.nimh.nih.gov/docs_GettingStarted.html#FormatsSupported
432408 raw_duration /= 1000
433409 return raw_duration .astype (dtype ) # return in seconds
434-
0 commit comments