Skip to content

Commit b7d872d

Browse files
committed
add sklearn API
1 parent eb02269 commit b7d872d

File tree

1 file changed

+46
-3
lines changed

1 file changed

+46
-3
lines changed

src/FlowSOM/main.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,26 @@
1111
from scipy.sparse import issparse
1212
from scipy.spatial.distance import cdist, pdist, squareform
1313
from scipy.stats import median_abs_deviation
14+
from sklearn.base import BaseEstimator, ClusterMixin
1415
from sklearn.cluster import AgglomerativeClustering
1516

1617
from .io import read_csv, read_FCS
1718
from .tl import SOM, ConsensusCluster, get_channels, get_markers, map_data_to_codes
1819

1920

20-
class FlowSOM:
21+
class FlowSOM(BaseEstimator, ClusterMixin):
2122
"""A class that contains all the FlowSOM data using MuData objects."""
2223

23-
def __init__(self, inp=None, cols_to_use: np.ndarray | None = None, n_clus=10, seed: int | None = None, **kwargs):
24+
def __init__(
25+
self,
26+
inp=None,
27+
cols_to_use: np.ndarray | None = None,
28+
n_clus=10,
29+
n_clusters=None,
30+
seed: int | None = None,
31+
random_state=None,
32+
**kwargs,
33+
):
2434
"""Initialize the FlowSOM AnnData object.
2535
2636
:param inp: A file path to an FCS file or a AnnData FCS file to cluster
@@ -34,13 +44,46 @@ def __init__(self, inp=None, cols_to_use: np.ndarray | None = None, n_clus=10, s
3444
"""
3545
if seed is not None:
3646
random.seed(seed)
47+
if n_clus is not None and n_clusters is None:
48+
n_clusters = n_clus
49+
self.n_clusters = n_clusters
50+
self.mudata = None
3751
if inp is not None:
3852
self.mudata = self.read_input(inp)
3953
self.build_SOM(cols_to_use, **kwargs)
4054
self.build_MST()
41-
self.metacluster(n_clus)
55+
self.metacluster(self.n_clusters)
4256
self._update_derived_values()
4357

58+
@property
def labels_(self):
    """Return the per-cell cluster labels, or ``None`` when unavailable.

    The labels are read from the ``"clustering"`` column of the ``obs``
    table of the ``"cell_data"`` modality in ``self.mudata``.

    :return: The stored ``"clustering"`` column, or ``None`` when
        ``self.mudata`` is ``None``, when it has no ``"cell_data"``
        modality, or when that modality has no ``"clustering"`` column
    """
    mudata = self.mudata
    # ``__init__`` leaves ``mudata`` as None when no input was given, so an
    # instance without data must report "no labels" instead of failing on
    # ``None.mod``.
    if mudata is None or "cell_data" not in mudata.mod:
        return None
    cell_data = mudata["cell_data"]
    if "clustering" not in cell_data.obs_keys():
        return None
    return cell_data.obs["clustering"]


@labels_.setter
def labels_(self, value):
    """Store ``value`` as the ``"clustering"`` column of the cell data.

    :param value: The labels to write to ``obs["clustering"]`` of the
        ``"cell_data"`` modality
    :raises ValueError: If ``self.mudata`` is ``None`` or has no
        ``"cell_data"`` modality
    """
    mudata = self.mudata
    # Treat a missing MuData object the same as a missing modality so the
    # caller gets the descriptive error rather than an AttributeError.
    if mudata is None or "cell_data" not in mudata.mod:
        raise ValueError("No cell data found in the MuData object.")
    mudata["cell_data"].obs["clustering"] = value
73+
74+
def fit(self, X, y=None):
    """Fit the model: build the SOM and MST, metacluster, refresh derived values.

    :param X: Forwarded unchanged as the first argument of ``build_SOM``
    :param y: Unused by this method
    :return: ``self``, so calls can be chained
    """
    # NOTE(review): ``__init__`` passes ``cols_to_use`` in this position and
    # loads the data with ``read_input`` beforehand. Confirm whether ``X`` is
    # meant to be data that should first be read into ``self.mudata``.
    self.build_SOM(X)
    self.build_MST()
    # Pass the configured number of clusters explicitly, exactly as
    # ``__init__`` does, instead of calling ``metacluster`` with no argument.
    self.metacluster(self.n_clusters)
    self._update_derived_values()
    return self
81+
82+
def predict(self, X, y=None):
    """Pass ``X`` to ``self.new_data`` and return its result.

    :param X: Input forwarded unchanged to ``self.new_data``
    :param y: Unused by this method
    :return: Whatever ``self.new_data(X)`` returns
    """
    # NOTE(review): the value returned is the object produced by
    # ``new_data``, not a label array -- confirm this is the intended
    # contract for ``predict``.
    return self.new_data(X)
86+
4487
def read_input(self, inp):
4588
"""Converts input to a FlowSOM AnnData object.
4689

0 commit comments

Comments
 (0)