From a30a677ab0522c7f38d49bbccf8a9a0147269cd9 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 2 Dec 2024 15:47:55 +0100 Subject: [PATCH] Disable client-side caching for some orchestrators --- .../flavors/kubernetes_orchestrator_flavor.py | 11 +++++++++++ .../flavors/lightning_orchestrator_flavor.py | 11 +++++++++++ .../flavors/skypilot_orchestrator_base_vm_config.py | 12 ++++++++++++ src/zenml/orchestrators/base_orchestrator.py | 10 ++++++++++ .../service_connectors/service_connector_utils.py | 12 +++--------- 5 files changed, 47 insertions(+), 9 deletions(-) diff --git a/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py b/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py index 6401c11ab75..4f5902efc21 100644 --- a/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +++ b/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py @@ -134,6 +134,17 @@ def is_schedulable(self) -> bool: """ return True + @property + def supports_client_side_caching(self) -> bool: + """Whether the orchestrator supports client side caching. + + Returns: + Whether the orchestrator supports client side caching. + """ + # The Kubernetes orchestrator starts step pods from a pipeline pod. + # This is currently not supported when using client-side caching. + return False + class KubernetesOrchestratorFlavor(BaseOrchestratorFlavor): """Kubernetes orchestrator flavor.""" diff --git a/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py b/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py index 77cc2b08959..bf66791123e 100644 --- a/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +++ b/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py @@ -94,6 +94,17 @@ def is_schedulable(self) -> bool: """ return False + @property + def supports_client_side_caching(self) -> bool: + """Whether the orchestrator supports client side caching. + + Returns: + Whether the orchestrator supports client side caching. + """ + # The Lightning orchestrator starts step studios from a pipeline studio. + # This is currently not supported when using client-side caching. + return False + class LightningOrchestratorFlavor(BaseOrchestratorFlavor): """Lightning orchestrator flavor.""" diff --git a/src/zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py b/src/zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py index 91322a4dcc7..c80926620ac 100644 --- a/src/zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py +++ b/src/zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py @@ -144,3 +144,15 @@ def is_local(self) -> bool: True if this config is for a local component, False otherwise. """ return False + + @property + def supports_client_side_caching(self) -> bool: + """Whether the orchestrator supports client side caching. + + Returns: + Whether the orchestrator supports client side caching. + """ + # The Skypilot orchestrator runs the entire pipeline in a single VM, or + # starts additional VMs from the root VM. Both of those cases are + # currently not supported when using client-side caching. + return False diff --git a/src/zenml/orchestrators/base_orchestrator.py b/src/zenml/orchestrators/base_orchestrator.py index b0bd86647b7..fc4b0e5962f 100644 --- a/src/zenml/orchestrators/base_orchestrator.py +++ b/src/zenml/orchestrators/base_orchestrator.py @@ -84,6 +84,15 @@ def is_schedulable(self) -> bool: """ return False + @property + def supports_client_side_caching(self) -> bool: + """Whether the orchestrator supports client side caching. + + Returns: + Whether the orchestrator supports client side caching. + """ + return True + class BaseOrchestrator(StackComponent, ABC): """Base class for all orchestrators. @@ -205,6 +214,7 @@ def run( if ( placeholder_run + and self.config.supports_client_side_caching and not deployment.schedule and not prevent_client_side_caching ): diff --git a/src/zenml/service_connectors/service_connector_utils.py b/src/zenml/service_connectors/service_connector_utils.py index d97f097faf5..a20f4847ee3 100644 --- a/src/zenml/service_connectors/service_connector_utils.py +++ b/src/zenml/service_connectors/service_connector_utils.py @@ -60,15 +60,9 @@ def _raise_specific_cloud_exception_if_needed( orchestrators: List[ResourcesInfo], container_registries: List[ResourcesInfo], ) -> None: - AWS_DOCS = ( - "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/aws-service-connector" - ) - GCP_DOCS = ( - "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/gcp-service-connector" - ) - AZURE_DOCS = ( - "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/azure-service-connector" - ) + AWS_DOCS = "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/aws-service-connector" + GCP_DOCS = "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/gcp-service-connector" + AZURE_DOCS = "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/azure-service-connector" if not artifact_stores: error_msg = (