-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathget_data.py
185 lines (151 loc) · 6.11 KB
/
get_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""
A set of helper functions for downloading and preprocessing hippocampal data.
@author: Gautam Agarwal
@author: Ben Toker
"""
from scipy.signal import decimate
import numpy as np
import math
import os
from scipy import io
from scipy.stats import mode
import h5py
def get_behav(mat_file, fs=25):
    """
    Organizes information about rat behavior into a matrix
    Input:
        mat_file: .mat file containing the 'Track' struct with the fields
                  lapID, mazeSect, corrChoice, xMM and yMM
        fs = sampling frequency of the output; every int(1250 / fs)-th sample
             of the label fields is kept and the positions are decimated by
             the same factor
    Output:
        lapID = Data stored in the format listed below (one row per sample)
    lapID format:
        Column 0: Trial Number (stored lapID minus 1)
        Column 1: Maze arm (-1/0/1/2) (-1 = not in maze arm)
        Column 2: Correct (0/1)
        Column 3: other/first approach/port/last departure (0/1/2/3)
        Column 4: x position in mm
        Column 5: y position in mm
    Trials in which the rat never reaches the end section of an arm are
    skipped, so their column 3 stays 0.
    """
    dec = int(1250 / fs)  # decimation factor
    track = io.loadmat(mat_file, variable_names=['Track'])['Track']

    def subsample(field):
        # Label-like fields are subsampled (not filtered) so values stay discrete.
        return np.squeeze(track[field][0][0])[::dec]

    trial = np.array([subsample("lapID")], dtype='float32') - 1
    rows = [
        trial,
        [subsample("mazeSect")],
        [subsample("corrChoice")],
        np.zeros((1, trial.shape[1])),  # placeholder for column 3, filled below
        decimate(track["xMM"][0][0].T, dec),
        decimate(track["yMM"][0][0].T, dec),
    ]
    lapID = np.concatenate(rows, axis=0).T

    # Both masks are taken from the raw section codes, before column 1 is remapped.
    section = lapID[:, 1]
    in_arm = np.isin(section, np.arange(4, 10))  # rat is occupying a maze arm
    in_end = np.isin(section, np.arange(7, 10))  # sections 7-9 (subset of in_arm)

    # Sections 4..9 collapse onto arm ids 0..2; everything else is "not in arm".
    lapID[in_arm, 1] = (lapID[in_arm, 1] - 1) % 3
    lapID[~in_arm, 1] = -1

    # Construct column 3 trial by trial.
    # NOTE(review): range() stops before the largest trial index, so that
    # trial keeps column 3 == 0 - confirm this is intended.
    for i in range(int(np.max(lapID[:, 0]))):
        in_trial = lapID[:, 0] == i
        r = np.logical_and(in_trial, in_end)
        all_end = np.where(r)[0]
        if all_end.size == 0:
            # The rat never reached sections 7-9 in this trial: there is no
            # approach/port/departure to label (previously an IndexError).
            continue
        inds = np.where(np.logical_and(in_trial, in_arm))[0]
        lapID[inds[inds < all_end[0]], 3] = 1   # in arm before first end sample
        lapID[inds[inds > all_end[-1]], 3] = 3  # in arm after last end sample
        lapID[longest_stretch(r), 3] = 2        # longest contiguous end visit

    # Return structured data
    return lapID
def longest_stretch(bool_array):
    """
    Finds longest contiguous stretch of True values
    Input:
        bool_array = boolean vector (coerced to bool, so 0/1 vectors and
                     lists are accepted)
    Output:
        bool_most_common = boolean vector, True only for longest stretch of
                           'True' in bool_array. If several stretches share
                           the maximal length the earliest one is returned;
                           if bool_array holds no True value (or is empty)
                           the result is all False.
    """
    flags = np.asarray(bool_array, dtype=bool)
    if not flags.any():
        # No True run exists, so there is nothing to select.
        return np.zeros(flags.shape, dtype=bool)

    # Give every run of equal values its own id: the id increases by one at
    # each position whose value differs from the previous one.
    changes = np.concatenate(([False], flags[1:] != flags[:-1]))
    run_ids = np.cumsum(changes)

    # Count samples per run using the True samples only; argmax returns the
    # first maximum, i.e. the earliest of the longest True runs.
    true_run_lengths = np.bincount(run_ids[flags])
    bool_most_common = run_ids == np.argmax(true_run_lengths)
    return bool_most_common
def get_spikes(mat_file, fs=25):
    """
    Counts spikes for each neuron in each time bin
    Input:
        mat_file = .mat file containing the 'Spike', 'Clu' and 'xml' structs
        fs = sampling rate of the binned output; spike sample indices are
             integer-divided by int(1250 / fs) to obtain the time bin
    Output:
        sp = binned spikes as a uint8 array with one row per time bin
             (bin 0 up to the bin of the last spike) and one column per
             retained cluster
    """
    mat = io.loadmat(mat_file, variable_names=['Spike', 'Clu', 'xml'])
    n_channels = mat['xml']['nChannels'][0][0][0][0]
    dec = int(1250 / fs)

    # Time-bin index and zero-based cluster index of every spike.
    spike_res = np.ravel(mat['Spike']['res'][0][0]) // dec
    spike_clu = np.ravel(mat['Spike']['totclu'][0][0]) - 1

    # Unit-width bin edges covering 0..max index. histogram2d treats the last
    # bin as right-inclusive, so the edges must extend one past the largest
    # index; otherwise a spike landing exactly on the final edge is merged
    # into the preceding bin.
    n_time_bins = int(np.max(spike_res)) + 1
    n_clusters = int(np.max(spike_clu)) + 1
    bins_res = np.arange(n_time_bins + 1)
    bins_clu = np.arange(n_clusters + 1)

    # Bin both dimensions using histogram2d.
    sp, _, _ = np.histogram2d(spike_res, spike_clu, bins=(bins_res, bins_clu))
    # NOTE: counts above 255 in a single bin would wrap around in uint8.
    sp = sp.astype(np.uint8)

    # Keep only clusters recorded on the first ceil(n_channels / 8) shanks.
    # NOTE(review): presumably this drops shanks outside the region of
    # interest - confirm against the recording layout.
    mask = mat['Clu']['shank'][0][0][0] <= math.ceil(n_channels / 8)
    sp = sp[:, mask]
    return sp
def get_LFP(lfp_file, n_channels, init_fs, fs=25):
    """
    Decimates LFPs to desired sampling rate
    Input:
        lfp_file = raw lfp data file of type .lfp (read as int16 values laid
                   out as samples x channels)
        n_channels = number of channels stored in the file
        init_fs = initial sampling rate of the data
        fs = desired sampling rate (to decimate to)
    Output:
        X = formatted lfp data, float32 array of shape
            (ceil(n_samples / dec), n_keep) with dec = int(init_fs / fs)
    """
    dec = int(init_fs / fs)

    # Derive the number of samples per channel from the file size.
    bytes_per_value = np.dtype('int16').itemsize
    total_elements = os.path.getsize(lfp_file) // bytes_per_value
    n_samples = total_elements // n_channels

    # Clip the rows to remove electrodes implanted in mPFC.
    if n_channels > 256:  # sessions 1 and 2
        n_keep = 255
    else:  # sessions 3 and 4
        n_keep = 192

    final_length = math.ceil(n_samples / dec)
    X = np.zeros((final_length, n_keep), dtype=np.float32)

    # Map the file once (the mapping is loop-invariant) and decimate one
    # channel at a time: decimating all kept channels in a single call is
    # simpler but needs far more memory.
    raw = np.memmap(lfp_file, dtype='int16', mode='r', shape=(n_samples, n_channels))
    for channel in range(n_keep):
        X[:, channel] = decimate(raw[:, channel], dec, axis=0)
        print(channel)  # progress indicator
    return X
def get_LFP_from_mat(lfp_data, n_channels, init_fs, fs=25):
    """
    Decimates LFPs held in memory (e.g. loaded from a MATLAB file) to the
    desired sampling rate
    Input:
        lfp_data = LFP data array indexed as [channel, sample]
        n_channels = number of channels in the data; only used to choose how
                     many leading channels are kept
        init_fs = initial sampling rate of the data
        fs = desired sampling rate (to decimate to)
    Output:
        X = formatted LFP data, float32 array of shape
            (ceil(n_samples / step), n_keep) with step = int(init_fs / fs)
    """
    step = int(init_fs / fs)
    total_samples = lfp_data.shape[1]

    # Drop the electrodes implanted in mPFC: sessions 1 and 2 (more than 256
    # channels) keep 255 channels, sessions 3 and 4 keep 192.
    n_keep = 255 if n_channels > 256 else 192

    out_rows = math.ceil(total_samples / step)
    X = np.zeros((out_rows, n_keep), dtype=np.float32)

    # Decimate channel by channel into the pre-allocated output.
    for idx in range(n_keep):
        trace = lfp_data[idx, :]
        X[:, idx] = decimate(trace, step, axis=0)
        print(idx)
    return X