Skip to content

Commit ec0a963

Browse files
added DBSCAN clusters to autodraw on plot
1 parent 1790bf6 commit ec0a963

File tree

2 files changed

+78
-6
lines changed

2 files changed

+78
-6
lines changed

rastermap/gui.py

+49-5
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from pyqtgraph import GraphicsScene
99
from scipy.stats import zscore
1010
from matplotlib import cm
11-
from rastermap.roi import gROI
11+
from rastermap.roi import gROI, dbROI
1212
import rastermap.run
1313
from rastermap import Rastermap
14+
from sklearn.cluster import DBSCAN
1415

1516
def triangle_area(p):
1617
area = 0.5 * np.abs(p[0,0] * p[1,1] - p[0,0] * p[2,1] +
@@ -231,7 +232,7 @@ def __init__(self):
231232
self.makegrid.clicked.connect(self.make_grid)
232233
self.makegrid.setStyleSheet(self.styleInactive)
233234
self.makegrid.setEnabled(False)
234-
self.makegrid.setFixedWidth(100)
235+
self.makegrid.setFixedWidth(200)
235236
self.l0.addWidget(self.makegrid, rs+7, 0, 1, 1)
236237
self.gridsize = QtGui.QLineEdit(self)
237238
self.gridsize.setValidator(QtGui.QIntValidator(2, 20))
@@ -241,6 +242,21 @@ def __init__(self):
241242
self.gridsize.returnPressed.connect(self.make_grid)
242243
self.l0.addWidget(self.gridsize, rs+7, 1, 1, 1)
243244

245+
self.dbbutton = QtGui.QPushButton("DBSCAN clusters, ms=")
246+
self.dbbutton.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
247+
self.dbbutton.clicked.connect(self.dbscan)
248+
self.dbbutton.setStyleSheet(self.styleInactive)
249+
self.dbbutton.setEnabled(False)
250+
self.dbbutton.setFixedWidth(200)
251+
self.l0.addWidget(self.dbbutton, rs+8, 0, 1, 1)
252+
self.min_samples = QtGui.QLineEdit(self)
253+
self.min_samples.setValidator(QtGui.QIntValidator(5, 200))
254+
self.min_samples.setText("50")
255+
self.min_samples.setFixedWidth(45)
256+
self.min_samples.setAlignment(QtCore.Qt.AlignRight)
257+
self.min_samples.returnPressed.connect(self.dbscan)
258+
self.l0.addWidget(self.min_samples, rs+8, 1, 1, 1)
259+
244260
ysm = QtGui.QLabel("<font color='white'>y-binning</font>")
245261
ysm.setFixedWidth(100)
246262
self.l0.addWidget(ysm, rs+6, 0, 1, 1)
@@ -295,11 +311,13 @@ def __init__(self):
295311
# self.load_behavior('C:/Users/carse/github/TX4/beh.npy')
296312
self.file_iscell = None
297313
#self.fname = '/media/carsen/DATA2/grive/rastermap/DATA/embedding.npy'
298-
#self.load_proc(self.fname)
314+
self.fname = 'D:/grive/cshl_suite2p/TX39/embedding.npy'
315+
self.load_proc(self.fname)
299316

300317
self.show()
301318
self.win.show()
302319

320+
303321
def add_imgROI(self):
304322
if hasattr(self, 'imgROI'):
305323
self.pfull.removeItem(self.imgROI)
@@ -509,6 +527,22 @@ def update_selected(self, ineur):
509527
ineur = self.selected[ineur]
510528
self.xp.setData(pos=self.embedding[ineur,:][np.newaxis,:])
511529

530+
def dbscan(self):
531+
ms = int(self.min_samples.text())
532+
# remove previous ROIs
533+
if len(self.ROIs) > 0:
534+
for n in range(len(self.ROIs)):
535+
self.ROI_delete()
536+
537+
db = DBSCAN(eps=0.8, min_samples=ms).fit(self.embedding)
538+
ilabels = np.unique(db.labels_)
539+
ilabels = ilabels[ilabels>=0]
540+
print(ilabels)
541+
#ilabels = ilabels[:1]
542+
for i in ilabels:
543+
self.dbROI_add((db.labels_==i).nonzero()[0])
544+
self.ROI_selection()
545+
512546
def make_grid(self):
513547
ng = int(self.gridsize.text())
514548
if len(self.ROIs) > 0:
@@ -531,6 +565,15 @@ def make_grid(self):
531565
self.ROI_add(pos, prect, color=jet[j+k*ng]*255.0)
532566
self.ROI_selection()
533567

568+
def dbROI_add(self, selected, color=None):
569+
if color is None:
570+
color = np.random.randint(255,size=(3,))
571+
self.ROIs.append(dbROI(selected, color, self))
572+
self.Rselected.append(self.ROIs[-1].selected)
573+
self.Rcolors.append(np.reshape(np.tile(self.ROIs[-1].color, 10 * self.Rselected[-1].size),
574+
(self.Rselected[-1].size, 10, 3)))
575+
self.ROIorder.append(len(self.ROIs)-1)
576+
534577
def ROI_add(self, pos, prect, color=None):
535578
if color is None:
536579
color = np.random.randint(255,size=(3,))
@@ -586,9 +629,12 @@ def enable_embedded(self):
586629
self.updateROI.setEnabled(True)
587630
self.saveROI.setEnabled(True)
588631
self.makegrid.setEnabled(True)
632+
self.dbbutton.setEnabled(True)
633+
589634
self.updateROI.setStyleSheet(self.styleUnpressed)
590635
self.saveROI.setStyleSheet(self.styleUnpressed)
591636
self.makegrid.setStyleSheet(self.styleUnpressed)
637+
self.dbbutton.setStyleSheet(self.styleUnpressed)
592638

593639
def disable_embedded(self):
594640
self.updateROI.setEnabled(False)
@@ -649,8 +695,6 @@ def mouse_moved_bar(self, pos):
649695
ineur = min(self.colormat.shape[0]-1, max(0, int(np.floor(y))))
650696
self.update_selected(ineur)
651697

652-
653-
654698
def plot_clicked(self, event):
655699
"""left-click chooses a cell, right-click flips cell to other view"""
656700
flip = False

rastermap/roi.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,34 @@ def triangle_area(p0, p1, p2):
1111
p2[:,0] * p0[1] - p2[:,0] * p1[1])
1212
return area
1313

14+
class dbROI():
15+
"""
16+
ROI that is premade from DBSCAN clustering
17+
"""
18+
def __init__(self, selected, color, parent=None):
19+
self.color = color
20+
self.selected = selected
21+
self.positions = parent.embedding[self.selected, :]
22+
self.pen = pg.mkPen(pg.mkColor(self.color),
23+
width=1,
24+
style=QtCore.Qt.SolidLine)
25+
self.ROIplot = pg.ScatterPlotItem(pos=self.positions, pen=self.pen, symbol='o', size=2)
26+
parent.p0.addItem(self.ROIplot)
27+
parent.p0.removeItem(parent.xp)
28+
parent.p0.addItem(parent.xp)
29+
30+
def inROI(self, Y):
31+
dists = np.zeros((Y.shape[0],))
32+
for k,y in enumerate(Y):
33+
dists[k] = (((self.positions - self.y)**2).sum(axis=1)**0.5).min()
34+
inroi = dists < 0.5
35+
36+
return Y[inroi], dists[inroi]
37+
38+
def remove(self, parent):
39+
'''remove ROI'''
40+
parent.p0.removeItem(self.ROIplot)
41+
1442
class gROI():
1543
'''
1644
draw a line segment which is the gradient over which to plot the points
@@ -21,7 +49,7 @@ def __init__(self, pos, prect, color, parent=None):
2149
self.d = ((prect[0][0,:] - prect[0][1,:])**2).sum()**0.5 / 2
2250
#self.slope = (pos[1,1] - pos[0,1]) / (pos[1,0] - pos[0,0])
2351
#self.yint = pos[1,0] - self.slope * pos[0,0]
24-
np.save('groi.npy', {'prect': self.prect, 'pos': self.pos})
52+
#np.save('groi.npy', {'prect': self.prect, 'pos': self.pos})
2553
self.color = color
2654
self.pen = pg.mkPen(pg.mkColor(self.color),
2755
width=3,

0 commit comments

Comments
 (0)