diff --git a/src/zenml/integrations/mlflow/__init__.py b/src/zenml/integrations/mlflow/__init__.py index 3a7bfdafc5a..2c067596ceb 100644 --- a/src/zenml/integrations/mlflow/__init__.py +++ b/src/zenml/integrations/mlflow/__init__.py @@ -35,7 +35,7 @@ class MlflowIntegration(Integration): # does not pin it. They fixed this in a later version, so we can probably # remove this once we update the mlflow version. REQUIREMENTS = [ - "mlflow>=2.1.1,<=2.6.0", + "mlflow>=2.1.1,<=2.9.2", "mlserver>=1.3.3", "mlserver-mlflow>=1.3.3", ] diff --git a/src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py b/src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py index c0f8f3c9cb7..3ce4416a6cc 100644 --- a/src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py +++ b/src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """Implementation of the MLflow experiment tracker for ZenML.""" +import importlib import os from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast @@ -219,34 +220,37 @@ def get_step_run_metadata( } def disable_autologging(self) -> None: - """Disables MLflow autologging.""" - from mlflow import ( - fastai, - gluon, - lightgbm, - pytorch, - sklearn, - spark, - statsmodels, - tensorflow, - xgboost, - ) - - # There is no way to disable auto-logging for all frameworks at once. - # If auto-logging is explicitly enabled for a framework by calling its - # autolog() method, it cannot be disabled by calling - # `mlflow.autolog(disable=True)`. Therefore, we need to disable - # auto-logging for all frameworks explicitly. - - tensorflow.autolog(disable=True) - gluon.autolog(disable=True) - xgboost.autolog(disable=True) - lightgbm.autolog(disable=True) - statsmodels.autolog(disable=True) - spark.autolog(disable=True) - sklearn.autolog(disable=True) - fastai.autolog(disable=True) - pytorch.autolog(disable=True) + """Disables MLflow autologging for all supported frameworks.""" + frameworks = [ + "tensorflow", + "gluon", + "xgboost", + "lightgbm", + "statsmodels", + "spark", + "sklearn", + "fastai", + "pytorch", + ] + + failed_frameworks = [] + + for framework in frameworks: + try: + # Correctly prefix the module name with 'mlflow.' + module_name = f"mlflow.{framework}" + # Dynamically import the module corresponding to the framework + module = importlib.import_module(module_name) + # Call the autolog function with disable=True + module.autolog(disable=True) + except Exception: + failed_frameworks.append(framework) + + if len(failed_frameworks) > 0: + logger.warning( + f"Failed to disable MLflow autologging for the following frameworks: " + f"{failed_frameworks}." + ) def cleanup_step_run( self,