@@ -141,12 +141,20 @@ def read_data(self, f, recursive=False):
141141 else :
142142 raise ValueError (f'unknown variable type { self .var_type } ' )
143143
144- # Sanity check: Blocks can only have children or contain data
145- if self .data is not None and len (self .keys ()):
146- raise ValueError (f'Block { self .var_name } has { len (self )} children and data: { self .data } ' )
147-
144+ self .flatten ()
148145
146+ def flatten (self ):
147+ '''
148+ Reassigning data objects to be children of parent dict
149+ block1.block2.data -> block1.data as block2 anyway does not contain keys
150+ '''
151+ for k , v in self .items ():
152+ # Sanity check: Blocks can either have children or contain data
153+ if v .data is not None and len (v .keys ()):
154+ raise ValueError (f'Block { k } has { len (k )} children and data: { v .data } ' )
149155
156+ if v .data is not None :
157+ self [k ] = v .data
150158
151159
152160class MonkeyLogicRawIO (BaseRawIO ):
@@ -161,6 +169,16 @@ def __init__(self, filename=''):
161169 def _source_name (self ):
162170 return self .filename
163171
172+ def _data_sanity_checks (self ):
173+ for trial_id in self .trial_ids :
174+ events = self .ml_blocks [f'Trial{ trial_id } ' ]['BehavioralCodes' ]
175+
176+ # sanity check: last event == trial end
177+ first_event_code = events ['CodeNumbers' ][0 ]
178+ last_event_code = events ['CodeNumbers' ][- 1 ]
179+ assert first_event_code == 9 # 9 denotes sending of trial start event
180+ assert last_event_code == 18 # 18 denotes sending of trial end event
181+
164182 def _parse_header (self ):
165183 self .ml_blocks = {}
166184
@@ -170,7 +188,9 @@ def _parse_header(self):
170188 self .ml_blocks [bl .var_name ] = bl
171189
172190 trial_rec = self .ml_blocks ['TrialRecord' ]
173- self .trial_ids = np .arange (1 , int (trial_rec ['CurrentTrialNumber' ].data ))
191+ self .trial_ids = np .arange (1 , int (trial_rec ['CurrentTrialNumber' ]))
192+
193+ self ._data_sanity_checks ()
174194
175195 exclude_signals = ['SampleInterval' ]
176196
@@ -185,49 +205,48 @@ def _parse_header(self):
185205
186206 ana_block = self .ml_blocks ['Trial1' ]['AnalogData' ]
187207
188- def _register_signal (sig_block , prefix = '' ):
208+ def _register_signal (sig_block , name ):
189209 nonlocal stream_id
190210 nonlocal chan_id
191- if sig_data . data is not None and any (sig_data . data .shape ):
192- signal_streams .append ((prefix + sig_data . var_name , stream_id ))
211+ if not isinstance ( sig_data , dict ) and any (sig_data .shape ):
212+ signal_streams .append ((name , stream_id ))
193213
194- ch_name = sig_data .var_name
195214 sr = 1 # TODO: Where to get the sampling rate info?
196- dtype = type (sig_data . data )
215+ dtype = type (sig_data )
197216 units = '' # TODO: Where to find the unit info?
198217 gain = 1 # TODO: Where to find the gain info?
199218 offset = 0 # TODO: Can signals have an offset in ML?
200219 stream_id = 0 # all analog data belong to same stream
201220
202- if sig_block .data . shape [1 ] == 1 :
203- signal_channels .append ((prefix + ch_name , chan_id , sr , dtype , units , gain , offset ,
221+ if sig_block .shape [1 ] == 1 :
222+ signal_channels .append ((name , chan_id , sr , dtype , units , gain , offset ,
204223 stream_id ))
205224 chan_id += 1
206225 else :
207- for sub_chan_id in range (sig_block .data . shape [1 ]):
226+ for sub_chan_id in range (sig_block .shape [1 ]):
208227 signal_channels .append (
209- (prefix + ch_name , chan_id , sr , dtype , units , gain , offset ,
228+ (name , chan_id , sr , dtype , units , gain , offset ,
210229 stream_id ))
211230 chan_id += 1
212231
213232
214233
215- # 1st level signals ('Trial1'/'AnalogData'/<signal>')
234+
216235 for sig_name , sig_data in ana_block .items ():
217236 if sig_name in exclude_signals :
218237 continue
219238
220- # 1st level signals
221- if sig_data . data is not None and any (sig_data . data .shape ):
222- _register_signal (sig_data )
239+ # 1st level signals ('Trial1'/'AnalogData'/<signal>')
240+ if not isinstance ( sig_data , dict ) and any (sig_data .shape ):
241+ _register_signal (sig_data , name = sig_name )
223242
224- # 2nd level signals
225- elif sig_data . keys ( ):
243+ # 2nd level signals ('Trial1'/'AnalogData'/<signal_group>/<signal>')
244+ elif isinstance ( sig_data , dict ):
226245 for sig_sub_name , sig_sub_data in sig_data .items ():
227- if sig_sub_data . data is not None :
228- chan_names . append ( f'{ sig_name } /{ sig_sub_name } ' )
229- _register_signal ( sig_sub_data , prefix = f' { sig_name } /' )
230-
246+ if not isinstance ( sig_sub_data , dict ) :
247+ name = f'{ sig_name } /{ sig_sub_name } '
248+ chan_names . append ( name )
249+ _register_signal ( sig_sub_data , name = name )
231250
232251 spike_channels = []
233252
@@ -244,7 +263,7 @@ def _register_signal(sig_block, prefix=''):
244263
245264 self .header = {}
246265 self .header ['nb_block' ] = 1
247- self .header ['nb_segment' ] = [1 ]
266+ self .header ['nb_segment' ] = [len ( self . trial_ids ) ]
248267 self .header ['signal_streams' ] = signal_streams
249268 self .header ['signal_channels' ] = signal_channels
250269 self .header ['spike_channels' ] = spike_channels
@@ -258,7 +277,7 @@ def _register_signal(sig_block, prefix=''):
258277 array_annotation_keys = []
259278
260279 ml_anno = {k : v for k , v in sorted (self .ml_blocks .items ()) if not k .startswith ('Trial' )}
261- bl_ann = self .raw_annotations ['block ' ][0 ]
280+ bl_ann = self .raw_annotations ['blocks ' ][0 ]
262281 bl_ann .update (ml_anno )
263282
264283 # TODO annotate segments according to trial properties
@@ -314,18 +333,17 @@ def _register_signal(sig_block, prefix=''):
314333 # return
315334
316335 def _segment_t_start (self , block_index , seg_index ):
317- if 'Trial1' in self .ml_blocks :
318- t_start = self .ml_blocks ['Trial1' ]['AbsoluteTrialStartTime' ].data [0 ][0 ]
319- else :
320- t_start = 0
336+ assert block_index == 0
337+
338+ t_start = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['AbsoluteTrialStartTime' ][0 ][0 ]
321339 return t_start
322340
323341 def _segment_t_stop (self , block_index , seg_index ):
324- last_trial = self .ml_blocks [f'Trial{ self .trial_ids [- 1 ]} ' ]
342+ t_start = self ._segment_t_start (block_index , seg_index )
343+ # using stream 0 as all analogsignal stream should have the same duration
344+ duration = self ._get_signal_size (block_index , seg_index , 0 )
325345
326- t_start = last_trial ['AbsoluteTrialStartTime' ].data [0 ][0 ]
327- t_stop = t_start + 10 # TODO: Find sampling rates to determine trial end
328- return t_stop
346+ return t_start + duration
329347
330348 def _get_signal_size (self , block_index , seg_index , stream_index ):
331349 stream_name , stream_id = self .header ['signal_streams' ][stream_index ]
@@ -334,7 +352,7 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
334352 for sn in stream_name .split ('/' ): # dealing with 1st and 2nd level signals
335353 block = block [sn ]
336354
337- size = block .data . shape [0 ]
355+ size = block .shape [0 ]
338356 return size # size is per signal, not the sum of all channel_indexes
339357
340358 def _get_signal_t_start (self , block_index , seg_index , stream_index ):
@@ -356,65 +374,28 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
356374 block = block [sn ]
357375
358376 if channel_indexes is None :
359- raw_signals = block . data
377+ raw_signals = block
360378 else :
361- raw_signals = block . data [channel_indexes ]
379+ raw_signals = block [channel_indexes ]
362380
363381 raw_signals = raw_signals [i_start :i_stop ]
364382 return raw_signals
365383
366384 def _spike_count (self , block_index , seg_index , unit_index ):
367- count = 0
368- head_id = self .header ['spike_channels' ][unit_index ][1 ]
369- for mt in self .file .blocks [block_index ].groups [seg_index ].multi_tags :
370- for src in mt .sources :
371- if mt .type == 'neo.spiketrain' and [src .type == "neo.unit" ]:
372- if head_id == src .id :
373- return len (mt .positions )
374- return count
375-
376- def _get_spike_timestamps (self , block_index , seg_index , unit_index ,
377- t_start , t_stop ):
378- block = self .unit_list ['blocks' ][block_index ]
379- segment = block ['segments' ][seg_index ]
380- spike_dict = segment ['spiketrains' ]
381- spike_timestamps = spike_dict [unit_index ]
382- spike_timestamps = np .transpose (spike_timestamps )
383-
384- if t_start is not None or t_stop is not None :
385- lim0 = t_start
386- lim1 = t_stop
387- mask = (spike_timestamps >= lim0 ) & (spike_timestamps <= lim1 )
388- spike_timestamps = spike_timestamps [mask ]
389- return spike_timestamps
385+ return None
386+
387+ def _get_spike_timestamps (self , block_index , seg_index , unit_index , t_start , t_stop ):
388+ return None
390389
391390 def _rescale_spike_timestamp (self , spike_timestamps , dtype ):
392- spike_times = spike_timestamps .astype (dtype )
393- return spike_times
394-
395- def _get_spike_raw_waveforms (self , block_index , seg_index , unit_index ,
396- t_start , t_stop ):
397- # this must return a 3D numpy array (nb_spike, nb_channel, nb_sample)
398- seg = self .unit_list ['blocks' ][block_index ]['segments' ][seg_index ]
399- waveforms = seg ['spiketrains_unit' ][unit_index ]['waveforms' ]
400- if not waveforms :
401- return None
402- raw_waveforms = np .array (waveforms )
391+ return None
403392
404- if t_start is not None :
405- lim0 = t_start
406- mask = (raw_waveforms >= lim0 )
407- # use nan to keep the shape
408- raw_waveforms = np .where (mask , raw_waveforms , np .nan )
409- if t_stop is not None :
410- lim1 = t_stop
411- mask = (raw_waveforms <= lim1 )
412- raw_waveforms = np .where (mask , raw_waveforms , np .nan )
413- return raw_waveforms
393+ def _get_spike_raw_waveforms (self , block_index , seg_index , unit_index , t_start , t_stop ):
394+ return None
414395
415396 def _event_count (self , block_index , seg_index , event_channel_index ):
416397 assert event_channel_index == 0
417- times = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]. data
398+ times = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]
418399
419400 return len (times )
420401
@@ -424,9 +405,9 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
424405 assert block_index == 0
425406 assert event_channel_index == 0
426407
427- timestamp = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]. data
408+ timestamp = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]
428409 timestamp = timestamp .flatten ()
429- labels = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeNumbers' ]. data
410+ labels = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeNumbers' ]
430411 labels = labels .flatten ()
431412
432413 if t_start is not None :
@@ -440,12 +421,14 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
440421 return timestamp , durations , labels
441422
442423 def _rescale_event_timestamp (self , event_timestamps , dtype , event_channel_index ):
443- # TODO: Figure out unit and scaling of event timestamps
444- event_timestamps /= 1000 # assume this is in milliseconds
424+ # times are stored in millisecond, see
425+ # shttps://monkeylogic.nimh.nih.gov/docs_GettingStarted.html#FormatsSupported
426+ event_timestamps /= 1000
445427 return event_timestamps .astype (dtype ) # return in seconds
446428
447429 def _rescale_epoch_duration (self , raw_duration , dtype , event_channel_index ):
448- # TODO: Figure out unit and scaling of event timestamps
449- raw_duration /= 1000 # assume this is in milliseconds
430+ # times are stored in millisecond, see
431+ # shttps://monkeylogic.nimh.nih.gov/docs_GettingStarted.html#FormatsSupported
432+ raw_duration /= 1000
450433 return raw_duration .astype (dtype ) # return in seconds
451434
0 commit comments