Skip to content

Commit e9a710d

Browse files
authored
Merge pull request #1630 from NeuralEnsemble/black-formatting
Black formatting
2 parents a75cd42 + e485357 commit e9a710d

8 files changed

+54
-86
lines changed

neo/core/spiketrain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def normalize_times_array(times, units=None, dtype=None, copy=None):
199199
"In order to facilitate the deprecation copy can be set to None but will raise an "
200200
"error if set to True/False since this will silently do nothing. This argument will be completely "
201201
"removed in Neo 0.15.0. Please update your code base as necessary."
202-
)
202+
)
203203

204204
if dtype is None:
205205
if not hasattr(times, "dtype"):

neo/rawio/blackrockrawio.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -681,8 +681,7 @@ def _get_timestamp_slice(self, timestamp, seg_index, t_start, t_stop):
681681
if t_start is None:
682682
t_start = self._seg_t_starts[seg_index]
683683
if t_stop is None:
684-
t_stop = self._seg_t_stops[seg_index] + 1 / float(
685-
self.__nev_basic_header['timestamp_resolution'])
684+
t_stop = self._seg_t_stops[seg_index] + 1 / float(self.__nev_basic_header["timestamp_resolution"])
686685

687686
if t_start is None:
688687
ind_start = None
@@ -715,15 +714,16 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start,
715714
)
716715
unit_spikes = all_spikes[mask]
717716

718-
wf_dtype = self.__nev_params('waveform_dtypes')[channel_id]
719-
wf_size = self.__nev_params('waveform_size')[channel_id]
717+
wf_dtype = self.__nev_params("waveform_dtypes")[channel_id]
718+
wf_size = self.__nev_params("waveform_size")[channel_id]
720719
wf_byte_size = np.dtype(wf_dtype).itemsize * wf_size
721720

722721
dt1 = [
723-
('extra', 'S{}'.format(unit_spikes['waveform'].dtype.itemsize - wf_byte_size)),
724-
('ch_waveform', 'S{}'.format(wf_byte_size))]
722+
("extra", "S{}".format(unit_spikes["waveform"].dtype.itemsize - wf_byte_size)),
723+
("ch_waveform", "S{}".format(wf_byte_size)),
724+
]
725725

726-
waveforms = unit_spikes['waveform'].view(dt1)['ch_waveform'].flatten().view(wf_dtype)
726+
waveforms = unit_spikes["waveform"].view(dt1)["ch_waveform"].flatten().view(wf_dtype)
727727

728728
waveforms = waveforms.reshape(int(unit_spikes.size), 1, int(wf_size))
729729

@@ -1365,7 +1365,9 @@ def __match_nsx_and_nev_segment_ids(self, nsx_nb):
13651365

13661366
# Show warning if spikes do not fit any segment (+- 1 sampling 'tick')
13671367
# Spike should belong to segment before
1368-
mask_outside = (ev_ids == i) & (data["timestamp"] < int(seg["timestamp"]) - int(nsx_offset) - int(nsx_period))
1368+
mask_outside = (ev_ids == i) & (
1369+
data["timestamp"] < int(seg["timestamp"]) - int(nsx_offset) - int(nsx_period)
1370+
)
13691371

13701372
if len(data[mask_outside]) > 0:
13711373
warnings.warn(f"Spikes outside any segment. Detected on segment #{i}")
@@ -1995,7 +1997,6 @@ def __get_nsx_param_variant_a(self, nsx_nb):
19951997
else:
19961998
units = "uV"
19971999

1998-
19992000
nsx_parameters = {
20002001
"nb_data_points": int(
20012002
(self.__get_file_size(filename) - bytes_in_headers)

neo/rawio/intanrawio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class IntanRawIO(BaseRawIO):
9292
one long vector, which must be post-processed to extract individual digital channel information.
9393
See the intantech website for more information on performing this post-processing.
9494
95-
95+
9696
Examples
9797
--------
9898
>>> import neo.rawio

neo/rawio/micromedrawio.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def __init__(self, filename=""):
5252

5353
def _parse_header(self):
5454

55-
5655
with open(self.filename, "rb") as fid:
5756
f = StructFile(fid)
5857

@@ -99,7 +98,6 @@ def _parse_header(self):
9998
if zname != zname2.decode("ascii").strip(" "):
10099
raise NeoReadWriteError("expected the zone name to match")
101100

102-
103101
# "TRONCA" zone define segments
104102
zname2, pos, length = zones["TRONCA"]
105103
f.seek(pos)
@@ -114,7 +112,7 @@ def _parse_header(self):
114112
break
115113
else:
116114
self.info_segments.append((seg_start, trace_offset))
117-
115+
118116
if len(self.info_segments) == 0:
119117
# one unique segment = general case
120118
self.info_segments.append((0, 0))
@@ -152,8 +150,9 @@ def _parse_header(self):
152150
(sampling_rate,) = f.read_f("H")
153151
sampling_rate *= Rate_Min
154152
chan_id = str(c)
155-
signal_channels.append((chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id))
156-
153+
signal_channels.append(
154+
(chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id)
155+
)
157156

158157
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
159158

@@ -166,31 +165,31 @@ def _parse_header(self):
166165
self._sampling_rate = float(np.unique(signal_channels["sampling_rate"])[0])
167166

168167
# memmap traces buffer
169-
full_signal_shape = get_memmap_shape(self.filename, sig_dtype, num_channels=Num_Chan, offset=Data_Start_Offset)
168+
full_signal_shape = get_memmap_shape(
169+
self.filename, sig_dtype, num_channels=Num_Chan, offset=Data_Start_Offset
170+
)
170171
seg_limits = [trace_offset for seg_start, trace_offset in self.info_segments] + [full_signal_shape[0]]
171172
self._t_starts = []
172-
self._buffer_descriptions = {0 :{}}
173+
self._buffer_descriptions = {0: {}}
173174
for seg_index in range(nb_segment):
174175
seg_start, trace_offset = self.info_segments[seg_index]
175176
self._t_starts.append(seg_start / self._sampling_rate)
176177

177178
start = seg_limits[seg_index]
178179
stop = seg_limits[seg_index + 1]
179-
180+
180181
shape = (stop - start, Num_Chan)
181-
file_offset = Data_Start_Offset + ( start * np.dtype(sig_dtype).itemsize * Num_Chan)
182+
file_offset = Data_Start_Offset + (start * np.dtype(sig_dtype).itemsize * Num_Chan)
182183
self._buffer_descriptions[0][seg_index] = {}
183184
self._buffer_descriptions[0][seg_index][buffer_id] = {
184-
"type" : "raw",
185-
"file_path" : str(self.filename),
186-
"dtype" : sig_dtype,
185+
"type": "raw",
186+
"file_path": str(self.filename),
187+
"dtype": sig_dtype,
187188
"order": "C",
188-
"file_offset" : file_offset,
189-
"shape" : shape,
189+
"file_offset": file_offset,
190+
"shape": shape,
190191
}
191192

192-
193-
194193
# Event channels
195194
event_channels = []
196195
event_channels.append(("Trigger", "", "event"))
@@ -217,14 +216,9 @@ def _parse_header(self):
217216
for seg_index in range(nb_segment):
218217
left_lim = seg_limits[seg_index]
219218
right_lim = seg_limits[seg_index + 1]
220-
keep = (
221-
(rawevent["start"] >= left_lim)
222-
& (rawevent["start"] < right_lim)
223-
& (rawevent["start"] != 0)
224-
)
219+
keep = (rawevent["start"] >= left_lim) & (rawevent["start"] < right_lim) & (rawevent["start"] != 0)
225220
self._raw_events[-1].append(rawevent[keep])
226221

227-
228222
# No spikes
229223
spike_channels = []
230224
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)

neo/rawio/neuronexusrawio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(self, filename: str | Path = ""):
7373
* The *.xdat.json metadata file
7474
* The *_data.xdat binary file of all raw data
7575
* The *_timestamps.xdat binary file of the timestamp data
76-
76+
7777
From the metadata the other two files are located within the same directory
7878
and loaded.
7979

neo/rawio/spikeglxrawio.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class SpikeGLXRawIO(BaseRawWithBufferApiIO):
8585
* This IO reads the entire folder and subfolders locating the `.bin` and `.meta` files
8686
* Handles gates and triggers as segments (based on the `_gt0`, `_gt1`, `_t0` , `_t1` in filenames)
8787
* Handles all signals coming from different acquisition cards ("imec0", "imec1", etc) in a typical
88-
PXIe chassis setup and also external signal like "nidq".
88+
PXIe chassis setup and also external signal like "nidq".
8989
* For imec devices both "ap" and "lf" are extracted so even a one device setup will have several "streams"
9090
9191
Examples
@@ -227,22 +227,19 @@ def _parse_header(self):
227227

228228
self._t_starts = {stream_name: {} for stream_name in stream_names}
229229
self._t_stops = {seg_index: 0.0 for seg_index in range(nb_segment)}
230-
230+
231231
for stream_name in stream_names:
232232
for seg_index in range(nb_segment):
233233
info = self.signals_info_dict[seg_index, stream_name]
234234

235235
frame_start = float(info["meta"]["firstSample"])
236236
sampling_frequency = info["sampling_rate"]
237237
t_start = frame_start / sampling_frequency
238-
239-
self._t_starts[stream_name][seg_index] = t_start
238+
239+
self._t_starts[stream_name][seg_index] = t_start
240240
t_stop = info["sample_length"] / info["sampling_rate"]
241241
self._t_stops[seg_index] = max(self._t_stops[seg_index], t_stop)
242242

243-
244-
245-
246243
# fille into header dict
247244
self.header = {}
248245
self.header["nb_block"] = 1
@@ -361,24 +358,23 @@ def scan_files(dirname):
361358
raise FileNotFoundError(f"No appropriate combination of .meta and .bin files were detected in {dirname}")
362359

363360
# This sets non-integers values before integers
364-
normalize = lambda x: x if isinstance(x, int) else -1
361+
normalize = lambda x: x if isinstance(x, int) else -1
365362

366363
# Segment index is determined by the gate_num and trigger_num in that order
367364
def get_segment_tuple(info):
368365
# Create a key from the normalized gate_num and trigger_num
369366
gate_num = normalize(info.get("gate_num"))
370367
trigger_num = normalize(info.get("trigger_num"))
371368
return (gate_num, trigger_num)
372-
369+
373370
unique_segment_tuples = {get_segment_tuple(info) for info in info_list}
374371
sorted_keys = sorted(unique_segment_tuples)
375372

376373
# Map each unique key to a corresponding index
377374
segment_tuple_to_segment_index = {key: idx for idx, key in enumerate(sorted_keys)}
378375

379376
for info in info_list:
380-
info["seg_index"] = segment_tuple_to_segment_index[get_segment_tuple(info)]
381-
377+
info["seg_index"] = segment_tuple_to_segment_index[get_segment_tuple(info)]
382378

383379
# Probe index calculation
384380
# The calculation is ordered by slot, port, dock in that order, this is the number that appears in the filename
@@ -409,7 +405,7 @@ def get_probe_tuple(info):
409405
stream_name = f"{device_kind}{device_index}{stream_kind}"
410406

411407
info["stream_name"] = stream_name
412-
408+
413409
return info_list
414410

415411

neo/test/rawiotest/test_micromedrawio.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212

13+
1314
class TestMicromedRawIO(
1415
BaseTestRawIO,
1516
unittest.TestCase,
@@ -25,15 +26,15 @@ class TestMicromedRawIO(
2526
def test_micromed_multi_segments(self):
2627
file_full = self.get_local_path("micromed/File_mircomed2.TRC")
2728
file_splitted = self.get_local_path("micromed/File_mircomed2_2segments.TRC")
28-
29+
2930
# the second file contains 2 pieces of the first file
3031
# so it is 2 segments with the same traces but reduced
3132
# note that traces in the splited can differ at the very end of the cut
3233

3334
reader1 = MicromedRawIO(file_full)
3435
reader1.parse_header()
3536
assert reader1.segment_count(block_index=0) == 1
36-
assert reader1.get_signal_t_start(block_index=0, seg_index=0, stream_index=0) == 0.
37+
assert reader1.get_signal_t_start(block_index=0, seg_index=0, stream_index=0) == 0.0
3738
traces1 = reader1.get_analogsignal_chunk(stream_index=0)
3839

3940
reader2 = MicromedRawIO(file_splitted)
@@ -48,11 +49,10 @@ def test_micromed_multi_segments(self):
4849
sr = reader2.get_signal_sampling_rate(stream_index=0)
4950
ind_start = int(t_start * sr)
5051
traces2 = reader2.get_analogsignal_chunk(block_index=0, seg_index=seg_index, stream_index=0)
51-
traces1_chunk = traces1[ind_start: ind_start+traces2.shape[0]]
52+
traces1_chunk = traces1[ind_start : ind_start + traces2.shape[0]]
5253
# we remove the last 100 sample because tools that cut traces is truncating the last buffer
5354
assert np.array_equal(traces2[:-100], traces1_chunk[:-100])
5455

5556

56-
5757
if __name__ == "__main__":
5858
unittest.main()

neo/test/rawiotest/test_spikeglxrawio.py

Lines changed: 13 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -114,63 +114,40 @@ def test_nidq_digital_channel(self):
114114

115115
def test_t_start_reading(self):
116116
"""Test that t_start values are correctly read for all streams and segments."""
117-
117+
118118
# Expected t_start values for each stream and segment
119119
expected_t_starts = {
120-
'imec0.ap': {
121-
0: 15.319535472007237,
122-
1: 15.339535431281986,
123-
2: 21.284723325294053,
124-
3: 21.3047232845688
125-
},
126-
'imec1.ap': {
127-
0: 15.319554693264516,
128-
1: 15.339521518106308,
129-
2: 21.284735282142822,
130-
3: 21.304702106984614
131-
},
132-
'imec0.lf': {
133-
0: 15.3191688060872,
134-
1: 15.339168765361949,
135-
2: 21.284356659374016,
136-
3: 21.304356618648765
137-
},
138-
'imec1.lf': {
139-
0: 15.319321358082725,
140-
1: 15.339321516521915,
141-
2: 21.284568614155827,
142-
3: 21.30456877259502
143-
}
120+
"imec0.ap": {0: 15.319535472007237, 1: 15.339535431281986, 2: 21.284723325294053, 3: 21.3047232845688},
121+
"imec1.ap": {0: 15.319554693264516, 1: 15.339521518106308, 2: 21.284735282142822, 3: 21.304702106984614},
122+
"imec0.lf": {0: 15.3191688060872, 1: 15.339168765361949, 2: 21.284356659374016, 3: 21.304356618648765},
123+
"imec1.lf": {0: 15.319321358082725, 1: 15.339321516521915, 2: 21.284568614155827, 3: 21.30456877259502},
144124
}
145-
125+
146126
# Initialize the RawIO
147127
rawio = SpikeGLXRawIO(self.get_local_path("spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI4"))
148128
rawio.parse_header()
149-
129+
150130
# Get list of stream names
151131
stream_names = rawio.header["signal_streams"]["name"]
152-
132+
153133
# Test t_start for each stream and segment
154134
for stream_name, expected_values in expected_t_starts.items():
155135
# Get stream index
156136
stream_index = list(stream_names).index(stream_name)
157-
137+
158138
# Check each segment
159139
for seg_index, expected_t_start in expected_values.items():
160-
actual_t_start = rawio.get_signal_t_start(
161-
block_index=0,
162-
seg_index=seg_index,
163-
stream_index=stream_index
164-
)
165-
140+
actual_t_start = rawio.get_signal_t_start(block_index=0, seg_index=seg_index, stream_index=stream_index)
141+
166142
# Use numpy.testing for proper float comparison
167143
np.testing.assert_allclose(
168144
actual_t_start,
169145
expected_t_start,
170146
rtol=1e-9,
171147
atol=1e-9,
172-
err_msg=f"Mismatch in t_start for stream '{stream_name}', segment {seg_index}"
148+
err_msg=f"Mismatch in t_start for stream '{stream_name}', segment {seg_index}",
173149
)
174150

151+
175152
if __name__ == "__main__":
176153
unittest.main()

0 commit comments

Comments
 (0)