Skip to content

Commit 602a849

Browse files
author
sprenger
committed
[MLIO] improve cell handling
1 parent ead8c15 commit 602a849

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

neo/rawio/monkeylogicrawio.py

+30-8
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
_spike_channel_dtype, _event_channel_dtype)
1717

1818

19-
class MLBLock(dict):
19+
class MLBlock(dict):
2020
n_byte_dtype = {'logical': (1, '?'),
2121
'char': (1, 'c'),
2222
'integers': (8, 'Q'),
@@ -53,7 +53,7 @@ def generate_block(f):
5353
var_size = struct.unpack(f'{DV}Q', var_size)
5454
# print(var_size)
5555

56-
return MLBLock(LN, var_name, LT, var_type, DV, var_size)
56+
return MLBlock(LN, var_name, LT, var_type, DV, var_size)
5757

5858
def __bool__(self):
5959
if any((self.LN, self.LT)):
@@ -107,7 +107,11 @@ def read_data(self, f, recursive=False):
107107
d = struct.unpack(format, d)[0]
108108
data[i] = d
109109

110-
data = data.reshape(self.var_size)
110+
# convert to simple / expected data shape
111+
if self.var_size == (1, 1):
112+
data = data[0]
113+
else:
114+
data = data.reshape(self.var_size)
111115

112116
# decoding characters
113117
if self.var_type == 'char':
@@ -127,17 +131,23 @@ def read_data(self, f, recursive=False):
127131
n_fields = struct.unpack('Q', n_fields)[0]
128132

129133
for field in range(n_fields * np.prod(self.var_size)):
130-
bl = MLBLock.generate_block(f)
134+
bl = MLBlock.generate_block(f)
131135
if recursive:
132136
self[bl.var_name] = bl
133137
bl.read_data(f, recursive=recursive)
134138

135139
elif self.var_type == 'cell':
140+
# cells are always 2D
141+
assert len(self.var_size) == 2, 'Unexpected dimensions of cells'
142+
data = np.empty(shape=np.prod(self.var_size), dtype=object)
136143
for field in range(np.prod(self.var_size)):
137-
bl = MLBLock.generate_block(f)
144+
bl = MLBlock.generate_block(f)
138145
if recursive:
139-
self[bl.var_name] = bl
146+
data[field] = bl
147+
140148
bl.read_data(f, recursive=recursive)
149+
data = data.reshape(self.var_size)
150+
self.data = data
141151

142152
else:
143153
raise ValueError(f'unknown variable type {self.var_type}')
@@ -146,9 +156,12 @@ def read_data(self, f, recursive=False):
146156

147157
def flatten(self):
148158
'''
149-
Reassigning data objects to be children of parent dict
159+
Flatten structure by
160+
1) Reassigning data objects to be children of parent dict
150161
block1.block2.data -> block1.data as block2 anyway does not contain keys
162+
2) converting data arrays items from blocks to data objects
151163
'''
164+
152165
for k, v in self.items():
153166
# Sanity check: Blocks can either have children or contain data
154167
if v.data is not None and len(v.keys()):
@@ -157,6 +170,15 @@ def flatten(self):
157170
if v.data is not None:
158171
self[k] = v.data
159172

173+
# converting arrays of MLBlocks (cells) to (nested) list of objects
174+
if isinstance(self[k], np.ndarray) and all([isinstance(b, MLBlock) for b in self[k].flat]):
175+
assert len(self[k].shape) == 2
176+
for i in range(self[k].shape[0]):
177+
for j in range(self[k].shape[1]):
178+
self[k][i, j] = self[k][i, j].data
179+
self[k] = self[k].tolist()
180+
181+
160182

161183
class MonkeyLogicRawIO(BaseRawIO):
162184
extensions = ['bhv2']
@@ -183,7 +205,7 @@ def _parse_header(self):
183205
self.ml_blocks = {}
184206

185207
with open(self.filename, 'rb') as f:
186-
while bl := MLBLock.generate_block(f):
208+
while bl := MLBlock.generate_block(f):
187209
bl.read_data(f, recursive=True)
188210
self.ml_blocks[bl.var_name] = bl
189211

0 commit comments

Comments
 (0)