Skip to content

Commit 11d4641

Browse files
Merge branch 'master' of https://github.com/MouseLand/rastermap
2 parents 6cedd0c + ec0a963 commit 11d4641

File tree

2 files changed

+133
-44
lines changed

2 files changed

+133
-44
lines changed

Diff for: rastermap/gui.py

+104-43
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from pyqtgraph import GraphicsScene
99
from scipy.stats import zscore
1010
from matplotlib import cm
11-
from rastermap.roi import gROI
11+
from rastermap.roi import gROI, dbROI
1212
import rastermap.run
1313
from rastermap import Rastermap
14+
from sklearn.cluster import DBSCAN
1415

1516
def triangle_area(p):
1617
area = 0.5 * np.abs(p[0,0] * p[1,1] - p[0,0] * p[2,1] +
@@ -180,7 +181,7 @@ def __init__(self):
180181
self.p1 = self.win.addPlot(row=1, col=2, colspan=3,
181182
rowspan=3, invertY=True, padding=0)
182183
self.p1.setMouseEnabled(x=False, y=False)
183-
self.img = pg.ImageItem(autoDownsample=False)
184+
self.img = pg.ImageItem(autoDownsample=True)
184185
self.p1.hideAxis('left')
185186
self.p1.setMenuEnabled(False)
186187
self.p1.scene().contextMenuItem = self.p1
@@ -231,16 +232,31 @@ def __init__(self):
231232
self.makegrid.clicked.connect(self.make_grid)
232233
self.makegrid.setStyleSheet(self.styleInactive)
233234
self.makegrid.setEnabled(False)
234-
self.makegrid.setFixedWidth(100)
235+
self.makegrid.setFixedWidth(200)
235236
self.l0.addWidget(self.makegrid, rs+7, 0, 1, 1)
236237
self.gridsize = QtGui.QLineEdit(self)
237-
self.gridsize.setValidator(QtGui.QIntValidator(0, 500))
238-
self.gridsize.setText("10")
238+
self.gridsize.setValidator(QtGui.QIntValidator(2, 20))
239+
self.gridsize.setText("5")
239240
self.gridsize.setFixedWidth(45)
240241
self.gridsize.setAlignment(QtCore.Qt.AlignRight)
241242
self.gridsize.returnPressed.connect(self.make_grid)
242243
self.l0.addWidget(self.gridsize, rs+7, 1, 1, 1)
243244

245+
self.dbbutton = QtGui.QPushButton("DBSCAN clusters, ms=")
246+
self.dbbutton.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
247+
self.dbbutton.clicked.connect(self.dbscan)
248+
self.dbbutton.setStyleSheet(self.styleInactive)
249+
self.dbbutton.setEnabled(False)
250+
self.dbbutton.setFixedWidth(200)
251+
self.l0.addWidget(self.dbbutton, rs+8, 0, 1, 1)
252+
self.min_samples = QtGui.QLineEdit(self)
253+
self.min_samples.setValidator(QtGui.QIntValidator(5, 200))
254+
self.min_samples.setText("50")
255+
self.min_samples.setFixedWidth(45)
256+
self.min_samples.setAlignment(QtCore.Qt.AlignRight)
257+
self.min_samples.returnPressed.connect(self.dbscan)
258+
self.l0.addWidget(self.min_samples, rs+8, 1, 1, 1)
259+
244260
ysm = QtGui.QLabel("<font color='white'>y-binning</font>")
245261
ysm.setFixedWidth(100)
246262
self.l0.addWidget(ysm, rs+6, 0, 1, 1)
@@ -295,11 +311,13 @@ def __init__(self):
295311
# self.load_behavior('C:/Users/carse/github/TX4/beh.npy')
296312
self.file_iscell = None
297313
#self.fname = '/media/carsen/DATA2/grive/rastermap/DATA/embedding.npy'
298-
#self.load_proc(self.fname)
314+
self.fname = 'D:/grive/cshl_suite2p/TX39/embedding.npy'
315+
self.load_proc(self.fname)
299316

300317
self.show()
301318
self.win.show()
302319

320+
303321
def add_imgROI(self):
304322
if hasattr(self, 'imgROI'):
305323
self.pfull.removeItem(self.imgROI)
@@ -343,23 +361,24 @@ def smooth_activity(self):
343361
self.sp_smoothed /= 12
344362

345363
def plot_activity(self):
346-
self.smooth_activity()
347-
nn = self.sp_smoothed.shape[0]
348-
nt = self.sp_smoothed.shape[1]
349-
self.imgfull.setImage(self.sp_smoothed)
350-
self.imgfull.setLevels([self.sat[0],self.sat[1]])
351-
self.img.setImage(self.sp_smoothed)
352-
self.img.setLevels([self.sat[0],self.sat[1]])
353-
self.p1.setXRange(0, nt, padding=0)
354-
self.p1.setYRange(0, nn, padding=0)
355-
self.p1.setLimits(xMin=0,xMax=nt,yMin=0,yMax=nn)
356-
self.pfull.setXRange(0, nt, padding=0)
357-
self.pfull.setYRange(0, nn, padding=0)
358-
self.pfull.setLimits(xMin=-1,xMax=nt+1,yMin=-1,yMax=nn+1)
359-
self.imgROI.setPos(-.5,-.5)
360-
self.imgROI.setSize([nt+.5,nn+.5])
361-
self.imgROI.maxBounds = QtCore.QRectF(-1.,-1.,nt+1,nn+1)
362-
if 0:
364+
if self.embedded:
365+
self.smooth_activity()
366+
nn = self.sp_smoothed.shape[0]
367+
nt = self.sp_smoothed.shape[1]
368+
self.imgfull.setImage(self.sp_smoothed)
369+
self.imgfull.setLevels([self.sat[0],self.sat[1]])
370+
self.img.setImage(self.sp_smoothed)
371+
self.img.setLevels([self.sat[0],self.sat[1]])
372+
self.p1.setXRange(0, nt, padding=0)
373+
self.p1.setYRange(0, nn, padding=0)
374+
self.p1.setLimits(xMin=0,xMax=nt,yMin=0,yMax=nn)
375+
self.pfull.setXRange(0, nt, padding=0)
376+
self.pfull.setYRange(0, nn, padding=0)
377+
self.pfull.setLimits(xMin=-1,xMax=nt+1,yMin=-1,yMax=nn+1)
378+
self.imgROI.setPos(-.5,-.5)
379+
self.imgROI.setSize([nt+.5,nn+.5])
380+
self.imgROI.maxBounds = QtCore.QRectF(-1.,-1.,nt+1,nn+1)
381+
else:
363382
nn = self.sp.shape[0]
364383
nt = self.sp.shape[1]
365384
self.imgfull.setImage(self.sp)
@@ -382,8 +401,10 @@ def plot_activity(self):
382401
def plot_colorbar(self):
383402
nneur = self.colormat_plot.shape[0]
384403
self.colorimg.setImage(self.colormat_plot)
385-
N = int(self.smooth.text())
386-
NN = self.sp_smoothed.shape[0]*N
404+
if self.embedded:
405+
N = int(self.smooth.text())
406+
else:
407+
N = 1
387408
self.p3.setYRange(self.yrange[0]*N, self.yrange[-1]*N)
388409
self.p3.setXRange(0,10)
389410
self.p3.setLimits(yMin=self.yrange[0]*N,yMax=self.yrange[-1]*N,xMin=0,xMax=10)
@@ -415,17 +436,17 @@ def imgROI_range(self):
415436
yrange = (np.arange(0,int(sizey)) + np.floor(posy)).astype(np.int32)
416437
xrange = xrange[xrange>=0]
417438
yrange = yrange[yrange>=0]
418-
#if self.embedded:
419-
xrange = xrange[xrange<self.sp_smoothed.shape[1]]
420-
yrange = yrange[yrange<self.sp_smoothed.shape[0]]
421-
#else:
422-
# xrange = xrange[xrange<self.sp.shape[1]]
423-
# yrange = yrange[yrange<self.sp.shape[0]]
439+
if self.embedded:
440+
xrange = xrange[xrange<self.sp_smoothed.shape[1]]
441+
yrange = yrange[yrange<self.sp_smoothed.shape[0]]
442+
else:
443+
xrange = xrange[xrange<self.sp.shape[1]]
444+
yrange = yrange[yrange<self.sp.shape[0]]
424445
return xrange,yrange
425446

426447
def imgROI_position(self):
427448
xrange,yrange = self.imgROI_range()
428-
if 1:
449+
if self.embedded:
429450
self.img.setImage(self.sp_smoothed[np.ix_(yrange,xrange)])
430451
else:
431452
self.img.setImage(self.sp[np.ix_(yrange,xrange)])
@@ -437,7 +458,10 @@ def imgROI_position(self):
437458
axy = self.p3.getAxis('left')
438459
axx = self.p1.getAxis('bottom')
439460
self.plot_colorbar()
440-
N = int(self.smooth.text())
461+
if self.embedded:
462+
N = int(self.smooth.text())
463+
else:
464+
N = 1
441465
axy.setTicks([[(0,str(self.yrange[0])),(self.yrange[-1]*N,str(self.yrange[-1]*N))]])
442466
axx.setTicks([[(0.0,str(xrange[0])),(float(xrange.size),str(xrange[-1]))]])
443467

@@ -481,7 +505,9 @@ def ROI_selection(self, loaded=False):
481505

482506
self.colormat_plot = self.colormat.copy()
483507
self.plot_activity()
508+
print('plotted activity')
484509
self.plot_colorbar()
510+
print('plotted colorbar')
485511
self.win.show()
486512

487513
def update_selected(self, ineur):
@@ -501,19 +527,52 @@ def update_selected(self, ineur):
501527
ineur = self.selected[ineur]
502528
self.xp.setData(pos=self.embedding[ineur,:][np.newaxis,:])
503529

530+
def dbscan(self):
531+
ms = int(self.min_samples.text())
532+
# remove previous ROIs
533+
if len(self.ROIs) > 0:
534+
for n in range(len(self.ROIs)):
535+
self.ROI_delete()
536+
537+
db = DBSCAN(eps=0.8, min_samples=ms).fit(self.embedding)
538+
ilabels = np.unique(db.labels_)
539+
ilabels = ilabels[ilabels>=0]
540+
print(ilabels)
541+
#ilabels = ilabels[:1]
542+
for i in ilabels:
543+
self.dbROI_add((db.labels_==i).nonzero()[0])
544+
self.ROI_selection()
545+
504546
def make_grid(self):
505547
ng = int(self.gridsize.text())
506548
if len(self.ROIs) > 0:
507549
for n in range(len(self.ROIs)):
508550
self.ROI_delete()
509-
sz = self.embedding.max() / ng
510-
print(sz)
511-
corners = np.array([j*sz for j in range(0,ng)])
512-
print(corners)
513-
#for j in range(ng):
514-
# for k in range(ng):
515-
# prect = np.array([[corners[j],corners[k]],
516-
# [corners[j],corners[k]],
551+
sz = (self.embedding.max() - self.embedding.min()) / ng
552+
corners = np.linspace(self.embedding.min(), self.embedding.max(), ng+1)
553+
jet = cm.get_cmap('jet')
554+
jet = jet(np.linspace(0,1,ng**2))
555+
jet = jet[:,:3]
556+
for j in range(ng):
557+
for k in range(ng):
558+
prect = [np.array([[corners[j],corners[k]],
559+
[corners[j+1],corners[k]],
560+
[corners[j+1],corners[k+1]],
561+
[corners[j],corners[k+1]],
562+
[corners[j],corners[k]]])]
563+
pos = [np.array([[corners[j+1],corners[k]+sz/2],
564+
[corners[j],corners[k]+sz/2]])]
565+
self.ROI_add(pos, prect, color=jet[j+k*ng]*255.0)
566+
self.ROI_selection()
567+
568+
def dbROI_add(self, selected, color=None):
569+
if color is None:
570+
color = np.random.randint(255,size=(3,))
571+
self.ROIs.append(dbROI(selected, color, self))
572+
self.Rselected.append(self.ROIs[-1].selected)
573+
self.Rcolors.append(np.reshape(np.tile(self.ROIs[-1].color, 10 * self.Rselected[-1].size),
574+
(self.Rselected[-1].size, 10, 3)))
575+
self.ROIorder.append(len(self.ROIs)-1)
517576

518577
def ROI_add(self, pos, prect, color=None):
519578
if color is None:
@@ -570,9 +629,12 @@ def enable_embedded(self):
570629
self.updateROI.setEnabled(True)
571630
self.saveROI.setEnabled(True)
572631
self.makegrid.setEnabled(True)
632+
self.dbbutton.setEnabled(True)
633+
573634
self.updateROI.setStyleSheet(self.styleUnpressed)
574635
self.saveROI.setStyleSheet(self.styleUnpressed)
575636
self.makegrid.setStyleSheet(self.styleUnpressed)
637+
self.dbbutton.setStyleSheet(self.styleUnpressed)
576638

577639
def disable_embedded(self):
578640
self.updateROI.setEnabled(False)
@@ -633,8 +695,6 @@ def mouse_moved_bar(self, pos):
633695
ineur = min(self.colormat.shape[0]-1, max(0, int(np.floor(y))))
634696
self.update_selected(ineur)
635697

636-
637-
638698
def plot_clicked(self, event):
639699
"""left-click chooses a cell, right-click flips cell to other view"""
640700
flip = False
@@ -778,6 +838,7 @@ def load_mat(self, name=None):
778838
self.ROI_selection()
779839
self.enable_loaded()
780840
self.show()
841+
print('done loading')
781842
self.loaded = True
782843

783844
def load_iscell(self):

Diff for: rastermap/roi.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,34 @@ def triangle_area(p0, p1, p2):
1111
p2[:,0] * p0[1] - p2[:,0] * p1[1])
1212
return area
1313

14+
class dbROI():
15+
"""
16+
ROI that is premade from DBSCAN clustering
17+
"""
18+
def __init__(self, selected, color, parent=None):
19+
self.color = color
20+
self.selected = selected
21+
self.positions = parent.embedding[self.selected, :]
22+
self.pen = pg.mkPen(pg.mkColor(self.color),
23+
width=1,
24+
style=QtCore.Qt.SolidLine)
25+
self.ROIplot = pg.ScatterPlotItem(pos=self.positions, pen=self.pen, symbol='o', size=2)
26+
parent.p0.addItem(self.ROIplot)
27+
parent.p0.removeItem(parent.xp)
28+
parent.p0.addItem(parent.xp)
29+
30+
def inROI(self, Y):
31+
dists = np.zeros((Y.shape[0],))
32+
for k,y in enumerate(Y):
33+
dists[k] = (((self.positions - self.y)**2).sum(axis=1)**0.5).min()
34+
inroi = dists < 0.5
35+
36+
return Y[inroi], dists[inroi]
37+
38+
def remove(self, parent):
39+
'''remove ROI'''
40+
parent.p0.removeItem(self.ROIplot)
41+
1442
class gROI():
1543
'''
1644
draw a line segment which is the gradient over which to plot the points
@@ -21,7 +49,7 @@ def __init__(self, pos, prect, color, parent=None):
2149
self.d = ((prect[0][0,:] - prect[0][1,:])**2).sum()**0.5 / 2
2250
#self.slope = (pos[1,1] - pos[0,1]) / (pos[1,0] - pos[0,0])
2351
#self.yint = pos[1,0] - self.slope * pos[0,0]
24-
np.save('groi.npy', {'prect': self.prect, 'pos': self.pos})
52+
#np.save('groi.npy', {'prect': self.prect, 'pos': self.pos})
2553
self.color = color
2654
self.pen = pg.mkPen(pg.mkColor(self.color),
2755
width=3,

0 commit comments

Comments
 (0)