8
8
from pyqtgraph import GraphicsScene
9
9
from scipy .stats import zscore
10
10
from matplotlib import cm
11
- from rastermap .roi import gROI
11
+ from rastermap .roi import gROI , dbROI
12
12
import rastermap .run
13
13
from rastermap import Rastermap
14
+ from sklearn .cluster import DBSCAN
14
15
15
16
def triangle_area (p ):
16
17
area = 0.5 * np .abs (p [0 ,0 ] * p [1 ,1 ] - p [0 ,0 ] * p [2 ,1 ] +
@@ -231,7 +232,7 @@ def __init__(self):
231
232
self .makegrid .clicked .connect (self .make_grid )
232
233
self .makegrid .setStyleSheet (self .styleInactive )
233
234
self .makegrid .setEnabled (False )
234
- self .makegrid .setFixedWidth (100 )
235
+ self .makegrid .setFixedWidth (200 )
235
236
self .l0 .addWidget (self .makegrid , rs + 7 , 0 , 1 , 1 )
236
237
self .gridsize = QtGui .QLineEdit (self )
237
238
self .gridsize .setValidator (QtGui .QIntValidator (2 , 20 ))
@@ -241,6 +242,21 @@ def __init__(self):
241
242
self .gridsize .returnPressed .connect (self .make_grid )
242
243
self .l0 .addWidget (self .gridsize , rs + 7 , 1 , 1 , 1 )
243
244
245
+ self .dbbutton = QtGui .QPushButton ("DBSCAN clusters, ms=" )
246
+ self .dbbutton .setFont (QtGui .QFont ("Arial" , 8 , QtGui .QFont .Bold ))
247
+ self .dbbutton .clicked .connect (self .dbscan )
248
+ self .dbbutton .setStyleSheet (self .styleInactive )
249
+ self .dbbutton .setEnabled (False )
250
+ self .dbbutton .setFixedWidth (200 )
251
+ self .l0 .addWidget (self .dbbutton , rs + 8 , 0 , 1 , 1 )
252
+ self .min_samples = QtGui .QLineEdit (self )
253
+ self .min_samples .setValidator (QtGui .QIntValidator (5 , 200 ))
254
+ self .min_samples .setText ("50" )
255
+ self .min_samples .setFixedWidth (45 )
256
+ self .min_samples .setAlignment (QtCore .Qt .AlignRight )
257
+ self .min_samples .returnPressed .connect (self .dbscan )
258
+ self .l0 .addWidget (self .min_samples , rs + 8 , 1 , 1 , 1 )
259
+
244
260
ysm = QtGui .QLabel ("<font color='white'>y-binning</font>" )
245
261
ysm .setFixedWidth (100 )
246
262
self .l0 .addWidget (ysm , rs + 6 , 0 , 1 , 1 )
@@ -295,11 +311,13 @@ def __init__(self):
295
311
# self.load_behavior('C:/Users/carse/github/TX4/beh.npy')
296
312
self .file_iscell = None
297
313
#self.fname = '/media/carsen/DATA2/grive/rastermap/DATA/embedding.npy'
298
- #self.load_proc(self.fname)
314
+ self .fname = 'D:/grive/cshl_suite2p/TX39/embedding.npy'
315
+ self .load_proc (self .fname )
299
316
300
317
self .show ()
301
318
self .win .show ()
302
319
320
+
303
321
def add_imgROI (self ):
304
322
if hasattr (self , 'imgROI' ):
305
323
self .pfull .removeItem (self .imgROI )
@@ -509,6 +527,22 @@ def update_selected(self, ineur):
509
527
ineur = self .selected [ineur ]
510
528
self .xp .setData (pos = self .embedding [ineur ,:][np .newaxis ,:])
511
529
530
+ def dbscan (self ):
531
+ ms = int (self .min_samples .text ())
532
+ # remove previous ROIs
533
+ if len (self .ROIs ) > 0 :
534
+ for n in range (len (self .ROIs )):
535
+ self .ROI_delete ()
536
+
537
+ db = DBSCAN (eps = 0.8 , min_samples = ms ).fit (self .embedding )
538
+ ilabels = np .unique (db .labels_ )
539
+ ilabels = ilabels [ilabels >= 0 ]
540
+ print (ilabels )
541
+ #ilabels = ilabels[:1]
542
+ for i in ilabels :
543
+ self .dbROI_add ((db .labels_ == i ).nonzero ()[0 ])
544
+ self .ROI_selection ()
545
+
512
546
def make_grid (self ):
513
547
ng = int (self .gridsize .text ())
514
548
if len (self .ROIs ) > 0 :
@@ -531,6 +565,15 @@ def make_grid(self):
531
565
self .ROI_add (pos , prect , color = jet [j + k * ng ]* 255.0 )
532
566
self .ROI_selection ()
533
567
568
+ def dbROI_add (self , selected , color = None ):
569
+ if color is None :
570
+ color = np .random .randint (255 ,size = (3 ,))
571
+ self .ROIs .append (dbROI (selected , color , self ))
572
+ self .Rselected .append (self .ROIs [- 1 ].selected )
573
+ self .Rcolors .append (np .reshape (np .tile (self .ROIs [- 1 ].color , 10 * self .Rselected [- 1 ].size ),
574
+ (self .Rselected [- 1 ].size , 10 , 3 )))
575
+ self .ROIorder .append (len (self .ROIs )- 1 )
576
+
534
577
def ROI_add (self , pos , prect , color = None ):
535
578
if color is None :
536
579
color = np .random .randint (255 ,size = (3 ,))
@@ -586,9 +629,12 @@ def enable_embedded(self):
586
629
self .updateROI .setEnabled (True )
587
630
self .saveROI .setEnabled (True )
588
631
self .makegrid .setEnabled (True )
632
+ self .dbbutton .setEnabled (True )
633
+
589
634
self .updateROI .setStyleSheet (self .styleUnpressed )
590
635
self .saveROI .setStyleSheet (self .styleUnpressed )
591
636
self .makegrid .setStyleSheet (self .styleUnpressed )
637
+ self .dbbutton .setStyleSheet (self .styleUnpressed )
592
638
593
639
def disable_embedded (self ):
594
640
self .updateROI .setEnabled (False )
@@ -649,8 +695,6 @@ def mouse_moved_bar(self, pos):
649
695
ineur = min (self .colormat .shape [0 ]- 1 , max (0 , int (np .floor (y ))))
650
696
self .update_selected (ineur )
651
697
652
-
653
-
654
698
def plot_clicked (self , event ):
655
699
"""left-click chooses a cell, right-click flips cell to other view"""
656
700
flip = False
0 commit comments