Skip to content

Commit e3c9208

Browse files
gbruninppdebreuck
andauthored
Random state in feature selection (ppdebreuck#168)
* Upgraded pymatgen and matminer requirements * backward compatibility warning * Possibility to remove all NaNs features or not after featurization. * Arg in featurize. * Arg in preset because there are clean_df there as well. * Easier setting of drop_allnan. * Let this for another PR. * Possibility to tune random_state in feature selection. Useful when segfaults appear with very small datasets (testing). * update doscstring --------- Co-authored-by: ppdebreuck <[email protected]>
1 parent f49b1ca commit e3c9208

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

modnet/preprocessing.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,7 @@ def feature_selection(
805805
drop_thr: float = 0.2,
806806
n_jobs: int = None,
807807
ignore_names: Optional[List] = [],
808+
random_state: int = None,
808809
):
809810
"""Compute the mutual information between features and targets,
810811
then apply relevance-redundancy rankings to choose the top `n`
@@ -823,6 +824,7 @@ def feature_selection(
823824
n_jobs: max. number of processes to use when calculating cross NMI.
824825
ignore_names (List): Optional list of property names to ignore during feature selection.
825826
Feature selection will be performed w.r.t. all properties except the ones in ignore_names.
827+
random_state (int): Seed used to compute the NMI.
826828
827829
"""
828830
if getattr(self, "df_featurized", None) is None:
@@ -867,7 +869,11 @@ def feature_selection(
867869
else:
868870
df = self.df_featurized.copy()
869871
self.cross_nmi, self.feature_entropy = get_cross_nmi(
870-
df, return_entropy=True, drop_thr=drop_thr, n_jobs=n_jobs
872+
df,
873+
return_entropy=True,
874+
drop_thr=drop_thr,
875+
n_jobs=n_jobs,
876+
random_state=random_state,
871877
)
872878

873879
if self.cross_nmi.isna().sum().sum() > 0:
@@ -897,6 +903,7 @@ def feature_selection(
897903
df,
898904
df_target,
899905
task_type,
906+
random_state=random_state,
900907
)[name]
901908

902909
LOG.info("Computing optimal features...")

0 commit comments

Comments
 (0)