11
11
from scipy .sparse import issparse
12
12
from scipy .spatial .distance import cdist , pdist , squareform
13
13
from scipy .stats import median_abs_deviation
14
+ from sklearn .base import BaseEstimator , ClusterMixin
14
15
from sklearn .cluster import AgglomerativeClustering
15
16
16
17
from .io import read_csv , read_FCS
17
18
from .tl import SOM , ConsensusCluster , get_channels , get_markers , map_data_to_codes
18
19
19
20
20
- class FlowSOM :
21
+ class FlowSOM ( BaseEstimator , ClusterMixin ) :
21
22
"""A class that contains all the FlowSOM data using MuData objects."""
22
23
23
- def __init__ (self , inp = None , cols_to_use : np .ndarray | None = None , n_clus = 10 , seed : int | None = None , ** kwargs ):
24
+ def __init__ (
25
+ self ,
26
+ inp = None ,
27
+ cols_to_use : np .ndarray | None = None ,
28
+ n_clus = 10 ,
29
+ n_clusters = None ,
30
+ seed : int | None = None ,
31
+ random_state = None ,
32
+ ** kwargs ,
33
+ ):
24
34
"""Initialize the FlowSOM AnnData object.
25
35
26
36
:param inp: A file path to an FCS file or a AnnData FCS file to cluster
@@ -34,13 +44,46 @@ def __init__(self, inp=None, cols_to_use: np.ndarray | None = None, n_clus=10, s
34
44
"""
35
45
if seed is not None :
36
46
random .seed (seed )
47
+ if n_clus is not None and n_clusters is None :
48
+ n_clusters = n_clus
49
+ self .n_clusters = n_clusters
50
+ self .mudata = None
37
51
if inp is not None :
38
52
self .mudata = self .read_input (inp )
39
53
self .build_SOM (cols_to_use , ** kwargs )
40
54
self .build_MST ()
41
- self .metacluster (n_clus )
55
+ self .metacluster (self . n_clusters )
42
56
self ._update_derived_values ()
43
57
58
+ @property
59
+ def labels_ (self ):
60
+ """Get the labels."""
61
+ if "cell_data" in self .mudata .mod .keys ():
62
+ if "clustering" in self .mudata ["cell_data" ].obs_keys ():
63
+ return self .mudata ["cell_data" ].obs ["clustering" ]
64
+ return None
65
+
66
+ @labels_ .setter
67
+ def labels_ (self , value ):
68
+ """Set the labels."""
69
+ if "cell_data" in self .mudata .mod .keys ():
70
+ self .mudata ["cell_data" ].obs ["clustering" ] = value
71
+ else :
72
+ raise ValueError ("No cell data found in the MuData object." )
73
+
74
+ def fit (self , X , y = None ):
75
+ """Fit the model."""
76
+ self .build_SOM (X )
77
+ self .build_MST ()
78
+ self .metacluster ()
79
+ self ._update_derived_values ()
80
+ return self
81
+
82
+ def predict (self , X , y = None ):
83
+ """Predict the model."""
84
+ new_fsom = self .new_data (X )
85
+ return new_fsom
86
+
44
87
def read_input (self , inp ):
45
88
"""Converts input to a FlowSOM AnnData object.
46
89
0 commit comments