diff --git a/flytekit/models/task.py b/flytekit/models/task.py
index 88430aa28a..1bd5eda2d6 100644
--- a/flytekit/models/task.py
+++ b/flytekit/models/task.py
@@ -1053,6 +1053,7 @@ def to_flyte_idl(self) -> _core_task.K8sPod:
             metadata=self._metadata.to_flyte_idl() if self.metadata else None,
             pod_spec=_json_format.Parse(_json.dumps(self.pod_spec), _struct.Struct()) if self.pod_spec else None,
             data_config=self.data_config.to_flyte_idl() if self.data_config else None,
+            primary_container_name=self.primary_container_name,
         )
 
     @classmethod
@@ -1081,6 +1082,7 @@ def from_pod_template(cls, pod_template: "PodTemplate") -> "K8sPod":
         return cls(
             metadata=K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations),
             pod_spec=ApiClient().sanitize_for_serialization(pod_template.pod_spec),
+            primary_container_name=pod_template.primary_container_name,
         )
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/models.py b/plugins/flytekit-spark/flytekitplugins/spark/models.py
index e74a9fbe3f..1f185609f4 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/models.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/models.py
@@ -7,6 +7,7 @@
 
 from flytekit.exceptions import user as _user_exceptions
 from flytekit.models import common as _common
+from flytekit.models.task import K8sPod
 
 
 class SparkType(enum.Enum):
@@ -27,6 +28,8 @@ def __init__(
         executor_path: str,
         databricks_conf: Optional[Dict[str, Dict[str, Dict]]] = None,
         databricks_instance: Optional[str] = None,
+        driver_pod: Optional[K8sPod] = None,
+        executor_pod: Optional[K8sPod] = None,
     ):
         """
         This defines a SparkJob target. It will execute the appropriate SparkJob.
@@ -47,6 +50,8 @@ def __init__(
             databricks_conf = {}
         self._databricks_conf = databricks_conf
         self._databricks_instance = databricks_instance
+        self._driver_pod = driver_pod
+        self._executor_pod = executor_pod
 
     def with_overrides(
         self,
@@ -71,6 +76,8 @@ def with_overrides(
             hadoop_conf=new_hadoop_conf,
             databricks_conf=new_databricks_conf,
             databricks_instance=self.databricks_instance,
+            driver_pod=self.driver_pod,
+            executor_pod=self.executor_pod,
             executor_path=self.executor_path,
         )
 
@@ -139,6 +146,22 @@ def databricks_instance(self) -> str:
         """
         return self._databricks_instance
 
+    @property
+    def driver_pod(self) -> K8sPod:
+        """
+        Additional pod spec for the driver pod.
+        :rtype: K8sPod
+        """
+        return self._driver_pod
+
+    @property
+    def executor_pod(self) -> K8sPod:
+        """
+        Additional pod spec for the executor (worker node) pods.
+        :rtype: K8sPod
+        """
+        return self._executor_pod
+
     def to_flyte_idl(self):
         """
         :rtype: flyteidl.plugins.spark_pb2.SparkJob
@@ -167,6 +190,8 @@ def to_flyte_idl(self):
             hadoopConf=self.hadoop_conf,
             databricksConf=databricks_conf,
             databricksInstance=self.databricks_instance,
+            driverPod=self.driver_pod.to_flyte_idl() if self.driver_pod else None,
+            executorPod=self.executor_pod.to_flyte_idl() if self.executor_pod else None,
         )
 
     @classmethod
@@ -193,4 +218,6 @@ def from_flyte_idl(cls, pb2_object):
             executor_path=pb2_object.executorPath,
             databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
             databricks_instance=pb2_object.databricksInstance,
+            driver_pod=K8sPod.from_flyte_idl(pb2_object.driverPod) if pb2_object.HasField("driverPod") else None,
+            executor_pod=K8sPod.from_flyte_idl(pb2_object.executorPod) if pb2_object.HasField("executorPod") else None,
         )
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py
index 7d2f718617..e1914ca772 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/task.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -10,9 +10,11 @@
 from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
 from flytekit.configuration import DefaultImages, SerializationSettings
 from flytekit.core.context_manager import ExecutionParameters
+from flytekit.core.pod_template import PRIMARY_CONTAINER_DEFAULT_NAME, PodTemplate
 from flytekit.extend import ExecutionState, TaskPlugins
 from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
 from flytekit.image_spec import ImageSpec
+from flytekit.models.task import K8sPod
 
 from .models import SparkJob, SparkType
 
@@ -26,17 +28,21 @@ class Spark(object):
     Use this to configure a SparkContext for your task. Tasks marked with this will automatically execute
     natively on K8s as a distributed Spark execution.
 
-    Args:
-        spark_conf: Dictionary of spark config. The variables should match what spark expects
-        hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark
-        executor_path: Python binary executable to use for PySpark in driver and executor.
-        applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute.
+    Attributes:
+        spark_conf (Optional[Dict[str, str]]): Spark configuration dictionary.
+        hadoop_conf (Optional[Dict[str, str]]): Hadoop configuration dictionary.
+        executor_path (Optional[str]): Path to the Python binary for PySpark execution.
+        applications_path (Optional[str]): Path to the main application file.
+        driver_pod (Optional[PodTemplate]): The pod template for the Spark driver pod.
+        executor_pod (Optional[PodTemplate]): The pod template for the Spark executor pods.
""" spark_conf: Optional[Dict[str, str]] = None hadoop_conf: Optional[Dict[str, str]] = None executor_path: Optional[str] = None applications_path: Optional[str] = None + driver_pod: Optional[PodTemplate] = None + executor_pod: Optional[PodTemplate] = None def __post_init__(self): if self.spark_conf is None: @@ -168,6 +174,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: executor_path=self._default_executor_path or settings.python_interpreter, main_class="", spark_type=SparkType.PYTHON, + driver_pod=self.to_k8s_pod(self.task_config.driver_pod), + executor_pod=self.to_k8s_pod(self.task_config.executor_pod), ) if isinstance(self.task_config, (Databricks, DatabricksV2)): cfg = cast(DatabricksV2, self.task_config) @@ -176,6 +184,27 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: return MessageToDict(job.to_flyte_idl()) + def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8sPod]: + """ + Convert the podTemplate to K8sPod + """ + if pod_template is None: + return None + + task_primary_container_name = ( + self.pod_template.primary_container_name if self.pod_template else PRIMARY_CONTAINER_DEFAULT_NAME + ) + + if pod_template.primary_container_name != task_primary_container_name: + logger.warning( + "Primary container name ('%s') set in spark differs from the one in @task ('%s'). " + "The primary container name in @task will be overridden.", + pod_template.primary_container_name, + task_primary_container_name, + ) + + return K8sPod.from_pod_template(pod_template) + def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: import pyspark as _pyspark diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index 7ce5f14ebf..a8426e2896 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -5,23 +5,58 @@ import pyspark import pytest +from google.protobuf.json_format import MessageToDict +from flytekit import PodTemplate from flytekit.core import context_manager from flytekitplugins.spark import Spark from flytekitplugins.spark.task import Databricks, new_spark_session from pyspark.sql import SparkSession import flytekit -from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task, ImageSpec -from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings, DefaultImages -from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState +from flytekit import ( + StructuredDataset, + StructuredDatasetTransformerEngine, + task, + ImageSpec, +) +from flytekit.configuration import ( + Image, + ImageConfig, + SerializationSettings, + FastSerializationSettings, + DefaultImages, +) +from flytekit.core.context_manager import ( + ExecutionParameters, + FlyteContextManager, + ExecutionState, +) +from flytekit.models.task import K8sObjectMetadata, K8sPod +from kubernetes.client.models import ( + V1Container, + V1PodSpec, + V1Toleration, + V1EnvVar, +) + + +# @pytest.fixture(scope="function") +# def reset_spark_session() -> None: +# pyspark.sql.SparkSession.builder.getOrCreate().stop() +# yield +# pyspark.sql.SparkSession.builder.getOrCreate().stop() + @pytest.fixture(scope="function") def reset_spark_session() -> None: - pyspark.sql.SparkSession.builder.getOrCreate().stop() + if SparkSession._instantiatedSession: + SparkSession.builder.getOrCreate().stop() + SparkSession._instantiatedSession = None 
     yield
-    pyspark.sql.SparkSession.builder.getOrCreate().stop()
-
+    if SparkSession._instantiatedSession:
+        SparkSession.builder.getOrCreate().stop()
+        SparkSession._instantiatedSession = None
 
 
 def test_spark_task(reset_spark_session):
     databricks_conf = {
@@ -68,7 +103,10 @@ def my_spark(a: str) -> int:
     retrieved_settings = my_spark.get_custom(settings)
     assert retrieved_settings["sparkConf"] == {"spark": "1"}
     assert retrieved_settings["executorPath"] == "/usr/bin/python3"
-    assert retrieved_settings["mainApplicationFile"] == "local:///usr/local/bin/entrypoint.py"
+    assert (
+        retrieved_settings["mainApplicationFile"]
+        == "local:///usr/local/bin/entrypoint.py"
+    )
 
     pb = ExecutionParameters.new_builder()
     pb.working_dir = "/tmp"
@@ -121,11 +159,13 @@ def test_to_html():
     df = spark.createDataFrame([("Bob", 10)], ["name", "age"])
     sd = StructuredDataset(dataframe=df)
     tf = StructuredDatasetTransformerEngine()
-    output = tf.to_html(FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame)
+    output = tf.to_html(
+        FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame
+    )
     assert pd.DataFrame(df.schema, columns=["StructField"]).to_html() == output
 
 
-@mock.patch('pyspark.context.SparkContext.addPyFile')
+@mock.patch("pyspark.context.SparkContext.addPyFile")
 def test_spark_addPyFile(mock_add_pyfile):
     @task(
         task_config=Spark(
@@ -151,8 +191,11 @@ def my_spark(a: int) -> int:
 
     ctx = context_manager.FlyteContextManager.current_context()
     with context_manager.FlyteContextManager.with_context(
-        ctx.with_execution_state(
-            ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)).with_serialization_settings(serialization_settings)
+        ctx.with_execution_state(
+            ctx.new_execution_state().with_params(
+                mode=ExecutionState.Mode.TASK_EXECUTION
+            )
+        ).with_serialization_settings(serialization_settings)
     ) as new_ctx:
         my_spark.pre_execute(new_ctx.user_space_params)
     mock_add_pyfile.assert_called_once()
@@ -173,7 +216,10 @@ def spark1(partitions: int) -> float:
         print("Starting Spark with Partitions: {}".format(partitions))
         return 1.0
 
-    assert spark1.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
+    assert (
+        spark1.container_image.base_image
+        == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
+    )
     assert spark1._default_executor_path == "/usr/bin/python3"
     assert spark1._default_applications_path == "local:///usr/local/bin/entrypoint.py"
 
@@ -185,6 +231,229 @@ def spark2(partitions: int) -> float:
         print("Starting Spark with Partitions: {}".format(partitions))
         return 1.0
 
-    assert spark2.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
+    assert (
+        spark2.container_image.base_image
+        == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
+    )
     assert spark2._default_executor_path == "/usr/bin/python3"
     assert spark2._default_applications_path == "local:///usr/local/bin/entrypoint.py"
+
+
+def clean_dict(d):
+    """
+    Recursively remove keys with None values from dictionaries and lists.
+ """ + if isinstance(d, dict): + return {k: clean_dict(v) for k, v in d.items() if v is not None} + elif isinstance(d, list): + return [clean_dict(item) for item in d if item is not None] + else: + return d + + +def test_spark_driver_executor_podSpec(reset_spark_session): + custom_image = ImageSpec( + registry="ghcr.io/flyteorg", + packages=["flytekitplugins-spark"], + ) + + driver_pod_spec = V1PodSpec( + containers=[ + V1Container( + name="driver-primary", + image="ghcr.io/flyteorg", + command=["echo"], + args=["wow"], + env=[V1EnvVar(name="x/custom-driver", value="driver")], + ), + V1Container( + name="not-primary", + image="ghcr.io/flyteorg", + command=["echo"], + args=["not_primary"], + ), + ], + tolerations=[ + V1Toleration( + key="x/custom-driver", + operator="Equal", + value="foo-driver", + effect="NoSchedule", + ), + ], + ) + + executor_pod_spec = V1PodSpec( + containers=[ + V1Container( + name="executor-primary", + image="ghcr.io/flyteorg", + command=["echo"], + args=["wow"], + env=[V1EnvVar(name="x/custom-executor", value="executor")], + ), + V1Container( + name="not-primary", + image="ghcr.io/flyteorg", + command=["echo"], + args=["not_primary"], + ), + ], + tolerations=[ + V1Toleration( + key="x/custom-executor", + operator="Equal", + value="foo-executor", + effect="NoSchedule", + ), + ], + ) + + driver_pod = PodTemplate( + labels={"lKeyA_d": "lValA", "lKeyB_d": "lValB"}, + annotations={"aKeyA_d": "aValA", "aKeyB_d": "aValB"}, + primary_container_name="driver-primary", + pod_spec=driver_pod_spec, + ) + + executor_pod = PodTemplate( + labels={"lKeyA_e": "lValA", "lKeyB_e": "lValB"}, + annotations={"aKeyA_e": "aValA", "aKeyB_e": "aValB"}, + primary_container_name="executor-primary", + pod_spec=executor_pod_spec, + ) + + expect_driver_pod_spec = V1PodSpec( + containers=[ + V1Container( + name="driver-primary", + image="ghcr.io/flyteorg", + command=["echo"], + args=["wow"], + env=[ + V1EnvVar(name="x/custom-driver", value="driver"), + ], + ), + V1Container( + name="not-primary", + image="ghcr.io/flyteorg", + command=["echo"], + args=["not_primary"], + ), + ], + tolerations=[ + V1Toleration( + key="x/custom-driver", + operator="Equal", + value="foo-driver", + effect="NoSchedule", + ), + ], + ) + + expect_executor_pod_spec = V1PodSpec( + containers=[ + V1Container( + name="executor-primary", + image="ghcr.io/flyteorg", + command=["echo"], + args=["wow"], + env=[ + V1EnvVar(name="x/custom-executor", value="executor"), + ], + ), + V1Container( + name="not-primary", + image="ghcr.io/flyteorg", + command=["echo"], + args=["not_primary"], + ), + ], + tolerations=[ + V1Toleration( + key="x/custom-executor", + operator="Equal", + value="foo-executor", + effect="NoSchedule", + ), + ], + ) + + driver_pod_spec_dict_remove_None = expect_driver_pod_spec.to_dict() + executor_pod_spec_dict_remove_None = expect_executor_pod_spec.to_dict() + + driver_pod_spec_dict_remove_None = clean_dict(driver_pod_spec_dict_remove_None) + executor_pod_spec_dict_remove_None = clean_dict(executor_pod_spec_dict_remove_None) + + target_driver_k8sPod = K8sPod( + metadata=K8sObjectMetadata( + labels={"lKeyA_d": "lValA", "lKeyB_d": "lValB"}, + annotations={"aKeyA_d": "aValA", "aKeyB_d": "aValB"}, + ), + pod_spec=driver_pod_spec_dict_remove_None, # type: ignore + primary_container_name="driver-primary" + ) + + target_executor_k8sPod = K8sPod( + metadata=K8sObjectMetadata( + labels={"lKeyA_e": "lValA", "lKeyB_e": "lValB"}, + annotations={"aKeyA_e": "aValA", "aKeyB_e": "aValB"}, + ), + 
+        pod_spec=executor_pod_spec_dict_remove_None,  # type: ignore
+        primary_container_name="executor-primary",
+    )
+
+    @task(
+        task_config=Spark(
+            spark_conf={"spark.driver.memory": "1000M"},
+            driver_pod=driver_pod,
+            executor_pod=executor_pod,
+        ),
+        container_image=custom_image,
+        pod_template=PodTemplate(primary_container_name="primary"),
+    )
+    def my_spark(a: str) -> int:
+        session = flytekit.current_context().spark_session
+        configs = session.sparkContext.getConf().getAll()
+        assert ("spark.driver.memory", "1000M") in configs
+        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
+        return 10
+
+    assert my_spark.task_config is not None
+    assert my_spark.task_config.spark_conf == {"spark.driver.memory": "1000M"}
+    default_img = Image(name="default", fqn="test", tag="tag")
+
+    settings = SerializationSettings(
+        project="project",
+        domain="domain",
+        version="version",
+        env={"FOO": "baz"},
+        image_config=ImageConfig(default_image=default_img, images=[default_img]),
+    )
+
+    retrieved_settings = my_spark.get_custom(settings)
+    assert retrieved_settings["sparkConf"] == {"spark.driver.memory": "1000M"}
+    assert retrieved_settings["executorPath"] == "/usr/bin/python3"
+    assert (
+        retrieved_settings["mainApplicationFile"]
+        == "local:///usr/local/bin/entrypoint.py"
+    )
+    assert retrieved_settings["driverPod"] == MessageToDict(
+        target_driver_k8sPod.to_flyte_idl()
+    )
+    assert retrieved_settings["executorPod"] == MessageToDict(
+        target_executor_k8sPod.to_flyte_idl()
+    )
+
+    pb = ExecutionParameters.new_builder()
+    pb.working_dir = "/tmp"
+    pb.execution_id = "ex:local:local:local"
+    p = pb.build()
+    new_p = my_spark.pre_execute(p)
+    assert new_p is not None
+    assert new_p.has_attr("SPARK_SESSION")
+
+    assert my_spark.sess is not None
+    configs = my_spark.sess.sparkContext.getConf().getAll()
+    assert ("spark.driver.memory", "1000M") in configs
+    assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs
diff --git a/pydoclint-errors-baseline.txt b/pydoclint-errors-baseline.txt
index 043fcce00b..808a93731e 100644
--- a/pydoclint-errors-baseline.txt
+++ b/pydoclint-errors-baseline.txt
@@ -602,8 +602,6 @@ plugins/flytekit-spark/flytekitplugins/spark/models.py
     DOC301: Class `SparkJob`: __init__() should not have a docstring; please combine it with the docstring of the class
 --------------------
 plugins/flytekit-spark/flytekitplugins/spark/task.py
-    DOC601: Class `Spark`: Class docstring contains fewer class attributes than actual class attributes. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.)
-    DOC603: Class `Spark`: Class docstring attributes are different from actual class attributes. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Attributes in the class definition but not in the docstring: [applications_path: Optional[str], executor_path: Optional[str], hadoop_conf: Optional[Dict[str, str]], spark_conf: Optional[Dict[str, str]]]. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.)
     DOC601: Class `DatabricksV2`: Class docstring contains fewer class attributes than actual class attributes. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.)
     DOC603: Class `DatabricksV2`: Class docstring attributes are different from actual class attributes. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Attributes in the class definition but not in the docstring: [databricks_conf: Optional[Dict[str, Union[str, dict]]], databricks_instance: Optional[str]]. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.)
 --------------------
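
A minimal usage sketch of the new fields, mirroring the test above (illustrative only, not part of the diff; the toleration key/value and node setup are hypothetical, and it assumes flytekitplugins-spark and the kubernetes Python client are installed). With this change a task can attach separate pod templates to the Spark driver and executors; each template is converted to a K8sPod by to_k8s_pod() at serialization time:

    from flytekit import task, PodTemplate
    from flytekitplugins.spark import Spark
    from kubernetes.client.models import V1Container, V1PodSpec, V1Toleration

    # Hypothetical executor template: tolerate a taint so executors can
    # land on dedicated Spark nodes.
    executor_pod = PodTemplate(
        primary_container_name="primary",
        pod_spec=V1PodSpec(
            containers=[V1Container(name="primary")],
            tolerations=[
                V1Toleration(
                    key="spark-only",
                    operator="Equal",
                    value="true",
                    effect="NoSchedule",
                ),
            ],
        ),
    )

    @task(
        task_config=Spark(
            spark_conf={"spark.driver.memory": "1000M"},
            executor_pod=executor_pod,  # driver_pod works the same way
        )
    )
    def my_spark_task(n: int) -> float:
        return float(n)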