@@ -141,12 +141,20 @@ def read_data(self, f, recursive=False):
141
141
else :
142
142
raise ValueError (f'unknown variable type { self .var_type } ' )
143
143
144
- # Sanity check: Blocks can only have children or contain data
145
- if self .data is not None and len (self .keys ()):
146
- raise ValueError (f'Block { self .var_name } has { len (self )} children and data: { self .data } ' )
147
-
144
+ self .flatten ()
148
145
146
+ def flatten (self ):
147
+ '''
148
+ Reassigning data objects to be children of parent dict
149
+ block1.block2.data -> block1.data as block2 anyway does not contain keys
150
+ '''
151
+ for k , v in self .items ():
152
+ # Sanity check: Blocks can either have children or contain data
153
+ if v .data is not None and len (v .keys ()):
154
+ raise ValueError (f'Block { k } has { len (v )} children and data: { v .data } ' )
149
155
156
+ if v .data is not None :
157
+ self [k ] = v .data
150
158
151
159
152
160
class MonkeyLogicRawIO (BaseRawIO ):
@@ -161,6 +169,16 @@ def __init__(self, filename=''):
161
169
def _source_name (self ):
162
170
return self .filename
163
171
172
+ def _data_sanity_checks (self ):
173
+ for trial_id in self .trial_ids :
174
+ events = self .ml_blocks [f'Trial{ trial_id } ' ]['BehavioralCodes' ]
175
+
176
+ # sanity check: first event == trial start, last event == trial end
177
+ first_event_code = events ['CodeNumbers' ][0 ]
178
+ last_event_code = events ['CodeNumbers' ][- 1 ]
179
+ assert first_event_code == 9 # 9 denotes sending of trial start event
180
+ assert last_event_code == 18 # 18 denotes sending of trial end event
181
+
164
182
def _parse_header (self ):
165
183
self .ml_blocks = {}
166
184
@@ -170,7 +188,9 @@ def _parse_header(self):
170
188
self .ml_blocks [bl .var_name ] = bl
171
189
172
190
trial_rec = self .ml_blocks ['TrialRecord' ]
173
- self .trial_ids = np .arange (1 , int (trial_rec ['CurrentTrialNumber' ].data ))
191
+ self .trial_ids = np .arange (1 , int (trial_rec ['CurrentTrialNumber' ]))
192
+
193
+ self ._data_sanity_checks ()
174
194
175
195
exclude_signals = ['SampleInterval' ]
176
196
@@ -185,49 +205,48 @@ def _parse_header(self):
185
205
186
206
ana_block = self .ml_blocks ['Trial1' ]['AnalogData' ]
187
207
188
- def _register_signal (sig_block , prefix = '' ):
208
+ def _register_signal (sig_block , name ):
189
209
nonlocal stream_id
190
210
nonlocal chan_id
191
- if sig_data . data is not None and any (sig_data . data .shape ):
192
- signal_streams .append ((prefix + sig_data . var_name , stream_id ))
211
+ if not isinstance ( sig_data , dict ) and any (sig_data .shape ):
212
+ signal_streams .append ((name , stream_id ))
193
213
194
- ch_name = sig_data .var_name
195
214
sr = 1 # TODO: Where to get the sampling rate info?
196
- dtype = type (sig_data . data )
215
+ dtype = type (sig_data )
197
216
units = '' # TODO: Where to find the unit info?
198
217
gain = 1 # TODO: Where to find the gain info?
199
218
offset = 0 # TODO: Can signals have an offset in ML?
200
219
stream_id = 0 # all analog data belong to same stream
201
220
202
- if sig_block .data . shape [1 ] == 1 :
203
- signal_channels .append ((prefix + ch_name , chan_id , sr , dtype , units , gain , offset ,
221
+ if sig_block .shape [1 ] == 1 :
222
+ signal_channels .append ((name , chan_id , sr , dtype , units , gain , offset ,
204
223
stream_id ))
205
224
chan_id += 1
206
225
else :
207
- for sub_chan_id in range (sig_block .data . shape [1 ]):
226
+ for sub_chan_id in range (sig_block .shape [1 ]):
208
227
signal_channels .append (
209
- (prefix + ch_name , chan_id , sr , dtype , units , gain , offset ,
228
+ (name , chan_id , sr , dtype , units , gain , offset ,
210
229
stream_id ))
211
230
chan_id += 1
212
231
213
232
214
233
215
- # 1st level signals ('Trial1'/'AnalogData'/<signal>')
234
+
216
235
for sig_name , sig_data in ana_block .items ():
217
236
if sig_name in exclude_signals :
218
237
continue
219
238
220
- # 1st level signals
221
- if sig_data . data is not None and any (sig_data . data .shape ):
222
- _register_signal (sig_data )
239
+ # 1st level signals ('Trial1'/'AnalogData'/<signal>')
240
+ if not isinstance ( sig_data , dict ) and any (sig_data .shape ):
241
+ _register_signal (sig_data , name = sig_name )
223
242
224
- # 2nd level signals
225
- elif sig_data . keys ( ):
243
+ # 2nd level signals ('Trial1'/'AnalogData'/<signal_group>/<signal>')
244
+ elif isinstance ( sig_data , dict ):
226
245
for sig_sub_name , sig_sub_data in sig_data .items ():
227
- if sig_sub_data . data is not None :
228
- chan_names . append ( f'{ sig_name } /{ sig_sub_name } ' )
229
- _register_signal ( sig_sub_data , prefix = f' { sig_name } /' )
230
-
246
+ if not isinstance ( sig_sub_data , dict ) :
247
+ name = f'{ sig_name } /{ sig_sub_name } '
248
+ chan_names . append ( name )
249
+ _register_signal ( sig_sub_data , name = name )
231
250
232
251
spike_channels = []
233
252
@@ -244,7 +263,7 @@ def _register_signal(sig_block, prefix=''):
244
263
245
264
self .header = {}
246
265
self .header ['nb_block' ] = 1
247
- self .header ['nb_segment' ] = [1 ]
266
+ self .header ['nb_segment' ] = [len ( self . trial_ids ) ]
248
267
self .header ['signal_streams' ] = signal_streams
249
268
self .header ['signal_channels' ] = signal_channels
250
269
self .header ['spike_channels' ] = spike_channels
@@ -258,7 +277,7 @@ def _register_signal(sig_block, prefix=''):
258
277
array_annotation_keys = []
259
278
260
279
ml_anno = {k : v for k , v in sorted (self .ml_blocks .items ()) if not k .startswith ('Trial' )}
261
- bl_ann = self .raw_annotations ['block ' ][0 ]
280
+ bl_ann = self .raw_annotations ['blocks ' ][0 ]
262
281
bl_ann .update (ml_anno )
263
282
264
283
# TODO annotate segments according to trial properties
@@ -314,18 +333,17 @@ def _register_signal(sig_block, prefix=''):
314
333
# return
315
334
316
335
def _segment_t_start (self , block_index , seg_index ):
317
- if 'Trial1' in self .ml_blocks :
318
- t_start = self .ml_blocks ['Trial1' ]['AbsoluteTrialStartTime' ].data [0 ][0 ]
319
- else :
320
- t_start = 0
336
+ assert block_index == 0
337
+
338
+ t_start = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['AbsoluteTrialStartTime' ][0 ][0 ]
321
339
return t_start
322
340
323
341
def _segment_t_stop (self , block_index , seg_index ):
324
- last_trial = self .ml_blocks [f'Trial{ self .trial_ids [- 1 ]} ' ]
342
+ t_start = self ._segment_t_start (block_index , seg_index )
343
+ # using stream 0 as all analogsignal stream should have the same duration
344
+ duration = self ._get_signal_size (block_index , seg_index , 0 )
325
345
326
- t_start = last_trial ['AbsoluteTrialStartTime' ].data [0 ][0 ]
327
- t_stop = t_start + 10 # TODO: Find sampling rates to determine trial end
328
- return t_stop
346
+ return t_start + duration
329
347
330
348
def _get_signal_size (self , block_index , seg_index , stream_index ):
331
349
stream_name , stream_id = self .header ['signal_streams' ][stream_index ]
@@ -334,7 +352,7 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
334
352
for sn in stream_name .split ('/' ): # dealing with 1st and 2nd level signals
335
353
block = block [sn ]
336
354
337
- size = block .data . shape [0 ]
355
+ size = block .shape [0 ]
338
356
return size # size is per signal, not the sum of all channel_indexes
339
357
340
358
def _get_signal_t_start (self , block_index , seg_index , stream_index ):
@@ -356,65 +374,28 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
356
374
block = block [sn ]
357
375
358
376
if channel_indexes is None :
359
- raw_signals = block . data
377
+ raw_signals = block
360
378
else :
361
- raw_signals = block . data [channel_indexes ]
379
+ raw_signals = block [channel_indexes ]
362
380
363
381
raw_signals = raw_signals [i_start :i_stop ]
364
382
return raw_signals
365
383
366
384
def _spike_count (self , block_index , seg_index , unit_index ):
367
- count = 0
368
- head_id = self .header ['spike_channels' ][unit_index ][1 ]
369
- for mt in self .file .blocks [block_index ].groups [seg_index ].multi_tags :
370
- for src in mt .sources :
371
- if mt .type == 'neo.spiketrain' and [src .type == "neo.unit" ]:
372
- if head_id == src .id :
373
- return len (mt .positions )
374
- return count
375
-
376
- def _get_spike_timestamps (self , block_index , seg_index , unit_index ,
377
- t_start , t_stop ):
378
- block = self .unit_list ['blocks' ][block_index ]
379
- segment = block ['segments' ][seg_index ]
380
- spike_dict = segment ['spiketrains' ]
381
- spike_timestamps = spike_dict [unit_index ]
382
- spike_timestamps = np .transpose (spike_timestamps )
383
-
384
- if t_start is not None or t_stop is not None :
385
- lim0 = t_start
386
- lim1 = t_stop
387
- mask = (spike_timestamps >= lim0 ) & (spike_timestamps <= lim1 )
388
- spike_timestamps = spike_timestamps [mask ]
389
- return spike_timestamps
385
+ return None
386
+
387
+ def _get_spike_timestamps (self , block_index , seg_index , unit_index , t_start , t_stop ):
388
+ return None
390
389
391
390
def _rescale_spike_timestamp (self , spike_timestamps , dtype ):
392
- spike_times = spike_timestamps .astype (dtype )
393
- return spike_times
394
-
395
- def _get_spike_raw_waveforms (self , block_index , seg_index , unit_index ,
396
- t_start , t_stop ):
397
- # this must return a 3D numpy array (nb_spike, nb_channel, nb_sample)
398
- seg = self .unit_list ['blocks' ][block_index ]['segments' ][seg_index ]
399
- waveforms = seg ['spiketrains_unit' ][unit_index ]['waveforms' ]
400
- if not waveforms :
401
- return None
402
- raw_waveforms = np .array (waveforms )
391
+ return None
403
392
404
- if t_start is not None :
405
- lim0 = t_start
406
- mask = (raw_waveforms >= lim0 )
407
- # use nan to keep the shape
408
- raw_waveforms = np .where (mask , raw_waveforms , np .nan )
409
- if t_stop is not None :
410
- lim1 = t_stop
411
- mask = (raw_waveforms <= lim1 )
412
- raw_waveforms = np .where (mask , raw_waveforms , np .nan )
413
- return raw_waveforms
393
+ def _get_spike_raw_waveforms (self , block_index , seg_index , unit_index , t_start , t_stop ):
394
+ return None
414
395
415
396
def _event_count (self , block_index , seg_index , event_channel_index ):
416
397
assert event_channel_index == 0
417
- times = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]. data
398
+ times = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]
418
399
419
400
return len (times )
420
401
@@ -424,9 +405,9 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
424
405
assert block_index == 0
425
406
assert event_channel_index == 0
426
407
427
- timestamp = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]. data
408
+ timestamp = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeTimes' ]
428
409
timestamp = timestamp .flatten ()
429
- labels = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeNumbers' ]. data
410
+ labels = self .ml_blocks [f'Trial{ seg_index + 1 } ' ]['BehavioralCodes' ]['CodeNumbers' ]
430
411
labels = labels .flatten ()
431
412
432
413
if t_start is not None :
@@ -440,12 +421,14 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
440
421
return timestamp , durations , labels
441
422
442
423
def _rescale_event_timestamp (self , event_timestamps , dtype , event_channel_index ):
443
- # TODO: Figure out unit and scaling of event timestamps
444
- event_timestamps /= 1000 # assume this is in milliseconds
424
+ # times are stored in millisecond, see
425
+ # shttps://monkeylogic.nimh.nih.gov/docs_GettingStarted.html#FormatsSupported
426
+ event_timestamps /= 1000
445
427
return event_timestamps .astype (dtype ) # return in seconds
446
428
447
429
def _rescale_epoch_duration (self , raw_duration , dtype , event_channel_index ):
448
- # TODO: Figure out unit and scaling of event timestamps
449
- raw_duration /= 1000 # assume this is in milliseconds
430
+ # times are stored in millisecond, see
431
+ # shttps://monkeylogic.nimh.nih.gov/docs_GettingStarted.html#FormatsSupported
432
+ raw_duration /= 1000
450
433
return raw_duration .astype (dtype ) # return in seconds
451
434
0 commit comments