Skip to content

Commit 76ddbd8

Browse files
author
sprenger
committed
[MLIO] flatten ML_block dictionary structure
1 parent 5073a7b commit 76ddbd8

File tree

1 file changed

+71
-88
lines changed

1 file changed

+71
-88
lines changed

neo/rawio/monkeylogicrawio.py

Lines changed: 71 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,20 @@ def read_data(self, f, recursive=False):
141141
else:
142142
raise ValueError(f'unknown variable type {self.var_type}')
143143

144-
# Sanity check: Blocks can only have children or contain data
145-
if self.data is not None and len(self.keys()):
146-
raise ValueError(f'Block {self.var_name} has {len(self)} children and data: {self.data}')
147-
144+
self.flatten()
148145

146+
def flatten(self):
147+
'''
148+
Reassigning data objects to be children of parent dict
149+
block1.block2.data -> block1.data as block2 anyway does not contain keys
150+
'''
151+
for k, v in self.items():
152+
# Sanity check: Blocks can either have children or contain data
153+
if v.data is not None and len(v.keys()):
154+
raise ValueError(f'Block {k} has {len(k)} children and data: {v.data}')
149155

156+
if v.data is not None:
157+
self[k] = v.data
150158

151159

152160
class MonkeyLogicRawIO(BaseRawIO):
@@ -161,6 +169,16 @@ def __init__(self, filename=''):
161169
def _source_name(self):
162170
return self.filename
163171

172+
def _data_sanity_checks(self):
173+
for trial_id in self.trial_ids:
174+
events = self.ml_blocks[f'Trial{trial_id}']['BehavioralCodes']
175+
176+
# sanity check: first event == trial start, last event == trial end
177+
first_event_code = events['CodeNumbers'][0]
178+
last_event_code = events['CodeNumbers'][-1]
179+
assert first_event_code == 9 # 9 denotes sending of trial start event
180+
assert last_event_code == 18 # 18 denotes sending of trial end event
181+
164182
def _parse_header(self):
165183
self.ml_blocks = {}
166184

@@ -170,7 +188,9 @@ def _parse_header(self):
170188
self.ml_blocks[bl.var_name] = bl
171189

172190
trial_rec = self.ml_blocks['TrialRecord']
173-
self.trial_ids = np.arange(1, int(trial_rec['CurrentTrialNumber'].data))
191+
self.trial_ids = np.arange(1, int(trial_rec['CurrentTrialNumber']))
192+
193+
self._data_sanity_checks()
174194

175195
exclude_signals = ['SampleInterval']
176196

@@ -185,49 +205,48 @@ def _parse_header(self):
185205

186206
ana_block = self.ml_blocks['Trial1']['AnalogData']
187207

188-
def _register_signal(sig_block, prefix=''):
208+
def _register_signal(sig_block, name):
189209
nonlocal stream_id
190210
nonlocal chan_id
191-
if sig_data.data is not None and any(sig_data.data.shape):
192-
signal_streams.append((prefix + sig_data.var_name, stream_id))
211+
if not isinstance(sig_data, dict) and any(sig_data.shape):
212+
signal_streams.append((name, stream_id))
193213

194-
ch_name = sig_data.var_name
195214
sr = 1 # TODO: Where to get the sampling rate info?
196-
dtype = type(sig_data.data)
215+
dtype = type(sig_data)
197216
units = '' # TODO: Where to find the unit info?
198217
gain = 1 # TODO: Where to find the gain info?
199218
offset = 0 # TODO: Can signals have an offset in ML?
200219
stream_id = 0 # all analog data belong to same stream
201220

202-
if sig_block.data.shape[1] == 1:
203-
signal_channels.append((prefix + ch_name, chan_id, sr, dtype, units, gain, offset,
221+
if sig_block.shape[1] == 1:
222+
signal_channels.append((name, chan_id, sr, dtype, units, gain, offset,
204223
stream_id))
205224
chan_id += 1
206225
else:
207-
for sub_chan_id in range(sig_block.data.shape[1]):
226+
for sub_chan_id in range(sig_block.shape[1]):
208227
signal_channels.append(
209-
(prefix + ch_name, chan_id, sr, dtype, units, gain, offset,
228+
(name, chan_id, sr, dtype, units, gain, offset,
210229
stream_id))
211230
chan_id += 1
212231

213232

214233

215-
# 1st level signals ('Trial1'/'AnalogData'/<signal>')
234+
216235
for sig_name, sig_data in ana_block.items():
217236
if sig_name in exclude_signals:
218237
continue
219238

220-
# 1st level signals
221-
if sig_data.data is not None and any(sig_data.data.shape):
222-
_register_signal(sig_data)
239+
# 1st level signals ('Trial1'/'AnalogData'/<signal>')
240+
if not isinstance(sig_data, dict) and any(sig_data.shape):
241+
_register_signal(sig_data, name=sig_name)
223242

224-
# 2nd level signals
225-
elif sig_data.keys():
243+
# 2nd level signals ('Trial1'/'AnalogData'/<signal_group>/<signal>')
244+
elif isinstance(sig_data, dict):
226245
for sig_sub_name, sig_sub_data in sig_data.items():
227-
if sig_sub_data.data is not None:
228-
chan_names.append(f'{sig_name}/{sig_sub_name}')
229-
_register_signal(sig_sub_data, prefix=f'{sig_name}/')
230-
246+
if not isinstance(sig_sub_data, dict):
247+
name = f'{sig_name}/{sig_sub_name}'
248+
chan_names.append(name)
249+
_register_signal(sig_sub_data, name=name)
231250

232251
spike_channels = []
233252

@@ -244,7 +263,7 @@ def _register_signal(sig_block, prefix=''):
244263

245264
self.header = {}
246265
self.header['nb_block'] = 1
247-
self.header['nb_segment'] = [1]
266+
self.header['nb_segment'] = [len(self.trial_ids)]
248267
self.header['signal_streams'] = signal_streams
249268
self.header['signal_channels'] = signal_channels
250269
self.header['spike_channels'] = spike_channels
@@ -258,7 +277,7 @@ def _register_signal(sig_block, prefix=''):
258277
array_annotation_keys = []
259278

260279
ml_anno = {k: v for k, v in sorted(self.ml_blocks.items()) if not k.startswith('Trial')}
261-
bl_ann = self.raw_annotations['block'][0]
280+
bl_ann = self.raw_annotations['blocks'][0]
262281
bl_ann.update(ml_anno)
263282

264283
# TODO annotate segments according to trial properties
@@ -314,18 +333,17 @@ def _register_signal(sig_block, prefix=''):
314333
# return
315334

316335
def _segment_t_start(self, block_index, seg_index):
317-
if 'Trial1' in self.ml_blocks:
318-
t_start = self.ml_blocks['Trial1']['AbsoluteTrialStartTime'].data[0][0]
319-
else:
320-
t_start = 0
336+
assert block_index == 0
337+
338+
t_start = self.ml_blocks[f'Trial{seg_index+1}']['AbsoluteTrialStartTime'][0][0]
321339
return t_start
322340

323341
def _segment_t_stop(self, block_index, seg_index):
324-
last_trial = self.ml_blocks[f'Trial{self.trial_ids[-1]}']
342+
t_start = self._segment_t_start(block_index, seg_index)
343+
# using stream 0 as all analogsignal stream should have the same duration
344+
duration = self._get_signal_size(block_index, seg_index, 0)
325345

326-
t_start = last_trial['AbsoluteTrialStartTime'].data[0][0]
327-
t_stop = t_start + 10 # TODO: Find sampling rates to determine trial end
328-
return t_stop
346+
return t_start + duration
329347

330348
def _get_signal_size(self, block_index, seg_index, stream_index):
331349
stream_name, stream_id = self.header['signal_streams'][stream_index]
@@ -334,7 +352,7 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
334352
for sn in stream_name.split('/'): # dealing with 1st and 2nd level signals
335353
block = block[sn]
336354

337-
size = block.data.shape[0]
355+
size = block.shape[0]
338356
return size # size is per signal, not the sum of all channel_indexes
339357

340358
def _get_signal_t_start(self, block_index, seg_index, stream_index):
@@ -356,65 +374,28 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
356374
block = block[sn]
357375

358376
if channel_indexes is None:
359-
raw_signals = block.data
377+
raw_signals = block
360378
else:
361-
raw_signals = block.data[channel_indexes]
379+
raw_signals = block[channel_indexes]
362380

363381
raw_signals = raw_signals[i_start:i_stop]
364382
return raw_signals
365383

366384
def _spike_count(self, block_index, seg_index, unit_index):
367-
count = 0
368-
head_id = self.header['spike_channels'][unit_index][1]
369-
for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
370-
for src in mt.sources:
371-
if mt.type == 'neo.spiketrain' and [src.type == "neo.unit"]:
372-
if head_id == src.id:
373-
return len(mt.positions)
374-
return count
375-
376-
def _get_spike_timestamps(self, block_index, seg_index, unit_index,
377-
t_start, t_stop):
378-
block = self.unit_list['blocks'][block_index]
379-
segment = block['segments'][seg_index]
380-
spike_dict = segment['spiketrains']
381-
spike_timestamps = spike_dict[unit_index]
382-
spike_timestamps = np.transpose(spike_timestamps)
383-
384-
if t_start is not None or t_stop is not None:
385-
lim0 = t_start
386-
lim1 = t_stop
387-
mask = (spike_timestamps >= lim0) & (spike_timestamps <= lim1)
388-
spike_timestamps = spike_timestamps[mask]
389-
return spike_timestamps
385+
return None
386+
387+
def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
388+
return None
390389

391390
def _rescale_spike_timestamp(self, spike_timestamps, dtype):
392-
spike_times = spike_timestamps.astype(dtype)
393-
return spike_times
394-
395-
def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index,
396-
t_start, t_stop):
397-
# this must return a 3D numpy array (nb_spike, nb_channel, nb_sample)
398-
seg = self.unit_list['blocks'][block_index]['segments'][seg_index]
399-
waveforms = seg['spiketrains_unit'][unit_index]['waveforms']
400-
if not waveforms:
401-
return None
402-
raw_waveforms = np.array(waveforms)
391+
return None
403392

404-
if t_start is not None:
405-
lim0 = t_start
406-
mask = (raw_waveforms >= lim0)
407-
# use nan to keep the shape
408-
raw_waveforms = np.where(mask, raw_waveforms, np.nan)
409-
if t_stop is not None:
410-
lim1 = t_stop
411-
mask = (raw_waveforms <= lim1)
412-
raw_waveforms = np.where(mask, raw_waveforms, np.nan)
413-
return raw_waveforms
393+
def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
394+
return None
414395

415396
def _event_count(self, block_index, seg_index, event_channel_index):
416397
assert event_channel_index == 0
417-
times = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes'].data
398+
times = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes']
418399

419400
return len(times)
420401

@@ -424,9 +405,9 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
424405
assert block_index == 0
425406
assert event_channel_index == 0
426407

427-
timestamp = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes'].data
408+
timestamp = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeTimes']
428409
timestamp = timestamp.flatten()
429-
labels = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeNumbers'].data
410+
labels = self.ml_blocks[f'Trial{seg_index+1}']['BehavioralCodes']['CodeNumbers']
430411
labels = labels.flatten()
431412

432413
if t_start is not None:
@@ -440,12 +421,14 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
440421
return timestamp, durations, labels
441422

442423
def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index):
443-
# TODO: Figure out unit and scaling of event timestamps
444-
event_timestamps /= 1000 # assume this is in milliseconds
424+
# times are stored in milliseconds, see
425+
# https://monkeylogic.nimh.nih.gov/docs_GettingStarted.html#FormatsSupported
426+
event_timestamps /= 1000
445427
return event_timestamps.astype(dtype) # return in seconds
446428

447429
def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
448-
# TODO: Figure out unit and scaling of event timestamps
449-
raw_duration /= 1000 # assume this is in milliseconds
430+
# times are stored in milliseconds, see
431+
# https://monkeylogic.nimh.nih.gov/docs_GettingStarted.html#FormatsSupported
432+
raw_duration /= 1000
450433
return raw_duration.astype(dtype) # return in seconds
451434

0 commit comments

Comments
 (0)