From fea380b81953b28eef8e2439efec37de39b3610e Mon Sep 17 00:00:00 2001 From: Edwin Onuonga Date: Fri, 27 Dec 2024 20:54:35 +0000 Subject: [PATCH] add `HMMClassifier.fit` multiprocessing --- docs/source/sections/models/index.rst | 4 ++-- sequentia/model_selection/_search.py | 5 +++-- sequentia/models/hmm/classifier.py | 20 +++++++++++++++++-- tests/unit/test_models/hmm/test_classifier.py | 5 ++++- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/docs/source/sections/models/index.rst b/docs/source/sections/models/index.rst index 2b9708e..ba03888 100644 --- a/docs/source/sections/models/index.rst +++ b/docs/source/sections/models/index.rst @@ -16,9 +16,9 @@ The following models provided by Sequentia all support variable length sequences | | | | +----------+------------+ | | | | | Training | Prediction | +=========================+==============================+================+===============+==============+==========+============+ -| :class:`.HMMClassifier` | :class:`.GaussianMixtureHMM` | Classification | Real | ✔ | ✗ | ✔ | +| :class:`.HMMClassifier` | :class:`.GaussianMixtureHMM` | Classification | Real | ✔ | ✔ | ✔ | | +------------------------------+----------------+---------------+--------------+----------+------------+ -| | :class:`.CategoricalHMM` | Classification | Categorical | ✗ | ✗ | ✔ | +| | :class:`.CategoricalHMM` | Classification | Categorical | ✗ | ✔ | ✔ | +-------------------------+------------------------------+----------------+---------------+--------------+----------+------------+ | :class:`.KNNRegressor` | Regression | Real | ✔ | N/A | ✔ | +--------------------------------------------------------+----------------+---------------+--------------+----------+------------+ diff --git a/sequentia/model_selection/_search.py b/sequentia/model_selection/_search.py index e8be5ff..d0b5ef5 100644 --- a/sequentia/model_selection/_search.py +++ b/sequentia/model_selection/_search.py @@ -79,9 +79,10 @@ def param_grid(**kwargs: list[t.Any]) -> 
list[dict[str, t.Any]]: settings for :class:`.GaussianMixtureHMM`, which is a nested model specified in the constructor of a :class:`.HMMClassifier`. :: - from sklearn.preprocessing import Pipeline, minmax_scale + from sklearn.pipeline import Pipeline + from sklearn.preprocessing import minmax_scale - from sequenta.enums import PriorMode, CovarianceMode, TopologyMode + from sequentia.enums import PriorMode, CovarianceMode, TopologyMode from sequentia.models import HMMClassifier, GaussianMixtureHMM from sequentia.preprocessing import IndependentFunctionTransformer from sequentia.model_selection import GridSearchCV, StratifiedKFold diff --git a/sequentia/models/hmm/classifier.py b/sequentia/models/hmm/classifier.py index 0ccf032..5cfe488 100644 --- a/sequentia/models/hmm/classifier.py +++ b/sequentia/models/hmm/classifier.py @@ -356,8 +356,24 @@ def fit( lengths=lengths, classes=self.classes_, ) - for X_c, lengths_c, c in dataset.iter_by_class(): - self.models[c].fit(X_c, lengths=lengths_c) + + # get number of jobs + n_jobs = _multiprocessing.effective_n_jobs( + self.n_jobs, x=self.classes_ + ) + + # fit models in parallel + self.models = dict( + zip( + self.classes_, + joblib.Parallel(n_jobs=n_jobs, max_nbytes=None)( + joblib.delayed(self.models[c].fit)( + X_c, lengths=lengths_c + ) + for X_c, lengths_c, c in dataset.iter_by_class() + ), + ) + ) # Set class priors models: t.Iterable[int, variants.BaseHMM] = self.models.items() diff --git a/tests/unit/test_models/hmm/test_classifier.py b/tests/unit/test_models/hmm/test_classifier.py index 31dbedc..e7d180d 100644 --- a/tests/unit/test_models/hmm/test_classifier.py +++ b/tests/unit/test_models/hmm/test_classifier.py @@ -121,6 +121,7 @@ def assert_fit(clf: BaseHMM): ], ) @pytest.mark.parametrize("fit_mode", list(FitMode)) +@pytest.mark.parametrize("n_jobs", [1, -1]) def test_classifier_e2e( request: SubRequest, helpers: t.Any, @@ -128,9 +129,10 @@ def test_classifier_e2e( dataset: SequentialDataset, prior: 
enums.PriorMode | dict[int, float], fit_mode: FitMode, + n_jobs: int, random_state: np.random.RandomState, ) -> None: - clf = HMMClassifier(prior=prior) + clf = HMMClassifier(prior=prior, n_jobs=n_jobs) clf.add_models({i: copy.deepcopy(model) for i in range(n_classes)}) assert clf.prior == prior @@ -156,6 +158,7 @@ def test_classifier_e2e( variant=type(model), model_kwargs=model.get_params(), prior=prior, + n_jobs=n_jobs, ) clf.fit(**train.X_y_lengths)