diff --git a/q2_sample_classifier/classify.py b/q2_sample_classifier/classify.py
index 0e35436..ca449e2 100644
--- a/q2_sample_classifier/classify.py
+++ b/q2_sample_classifier/classify.py
@@ -15,6 +15,7 @@
 from sklearn.pipeline import Pipeline
 
 import qiime2
+from qiime2.plugin import get_available_cores
 import pandas as pd
 import biom
 import skbio
@@ -107,6 +108,9 @@ def _fit_predict_knn_cv(
         x: pd.DataFrame, y: pd.Series, k: int, cv: int, random_state: int,
         n_jobs: int
 ) -> (pd.Series, pd.Series):
+    if n_jobs == 0:
+        n_jobs = get_available_cores()
+
     kf = KFold(n_splits=cv, shuffle=True, random_state=random_state)
 
     # train and test with CV
@@ -291,6 +295,9 @@ def fit_regressor(table: biom.Table,
 
 
 def predict_base(table, sample_estimator, n_jobs):
+    if n_jobs == 0:
+        n_jobs = get_available_cores()
+
     # extract feature data from biom
     feature_data = _extract_features(table)
     index = table.ids()
diff --git a/q2_sample_classifier/plugin_setup.py b/q2_sample_classifier/plugin_setup.py
index 58a3ad5..6f6c098 100644
--- a/q2_sample_classifier/plugin_setup.py
+++ b/q2_sample_classifier/plugin_setup.py
@@ -10,7 +10,7 @@
 from qiime2.plugin import (
     Int, Str, Float, Range, Bool, Plugin, Metadata, Choices, MetadataColumn,
-    Numeric, Categorical, Citations, Visualization, TypeMatch)
+    Numeric, Categorical, Citations, Visualization, TypeMatch, Threads)
 from q2_types.feature_table import (
     FeatureTable, Frequency, RelativeFrequency, PresenceAbsence, Balance,
     PercentileNormalized, Design, Composition)
@@ -100,7 +100,7 @@ parameters = {
     'base': {
         'random_state': Int,
-        'n_jobs': Int,
+        'n_jobs': Threads,
         'n_estimators': Int % Range(1, None),
         'missing_samples': Str % Choices(['error', 'ignore'])},
     'splitter': {
diff --git a/q2_sample_classifier/utilities.py b/q2_sample_classifier/utilities.py
index 334044c..b224676 100644
--- a/q2_sample_classifier/utilities.py
+++ b/q2_sample_classifier/utilities.py
@@ -27,6 +27,7 @@
 )
 from sklearn.pipeline import Pipeline
 
+from qiime2.plugin import get_available_cores
 import q2templates
 import pandas as pd
 import numpy as np
@@ -264,6 +265,9 @@ def nested_cross_validation(table, metadata, cv, random_state, n_jobs,
                             n_estimators, estimator, stratify,
                             parameter_tuning, classification, scoring,
                             missing_samples='error'):
+    if n_jobs == 0:
+        n_jobs = get_available_cores()
+
     # extract column name from NumericMetadataColumn
     column = metadata.name
 
@@ -301,6 +305,9 @@ def _fit_estimator(features, targets, estimator, n_estimators=100, step=0.05,
                    cv=5, random_state=None, n_jobs=1,
                    optimize_feature_selection=False, parameter_tuning=False,
                    missing_samples='error', classification=True):
+    if n_jobs == 0:
+        n_jobs = get_available_cores()
+
     # extract column name from CategoricalMetadataColumn
     column = targets.to_series().name
 