[FEAT] add driver/executor pod in Spark #3016

Open · wants to merge 16 commits into master
2 changes: 2 additions & 0 deletions flytekit/models/task.py
@@ -1053,6 +1053,7 @@ def to_flyte_idl(self) -> _core_task.K8sPod:
metadata=self._metadata.to_flyte_idl() if self.metadata else None,
pod_spec=_json_format.Parse(_json.dumps(self.pod_spec), _struct.Struct()) if self.pod_spec else None,
data_config=self.data_config.to_flyte_idl() if self.data_config else None,
primary_container_name=self.primary_container_name,
)

@classmethod
@@ -1081,6 +1082,7 @@ def from_pod_template(cls, pod_template: "PodTemplate") -> "K8sPod":
return cls(
metadata=K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations),
pod_spec=ApiClient().sanitize_for_serialization(pod_template.pod_spec),
primary_container_name=pod_template.primary_container_name,
)
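The hunk above threads `primary_container_name` from a `PodTemplate` into the `K8sPod` model and its IDL form. A minimal round-trip sketch, assuming the kubernetes Python client is installed; the container name and label below are hypothetical:

```python
# Sketch only, not part of this diff; "spark-driver" and the label are
# hypothetical values used to exercise the new primary_container_name field.
from flytekit.core.pod_template import PodTemplate
from flytekit.models.task import K8sPod
from kubernetes.client import V1Container, V1PodSpec

template = PodTemplate(
    primary_container_name="spark-driver",
    pod_spec=V1PodSpec(containers=[V1Container(name="spark-driver")]),
    labels={"team": "data"},
)

pod = K8sPod.from_pod_template(template)
idl = pod.to_flyte_idl()  # primary_container_name now survives serialization
```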


27 changes: 27 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/models.py
@@ -7,6 +7,7 @@

from flytekit.exceptions import user as _user_exceptions
from flytekit.models import common as _common
from flytekit.models.task import K8sPod


class SparkType(enum.Enum):
@@ -27,6 +28,8 @@ def __init__(
executor_path: str,
databricks_conf: Optional[Dict[str, Dict[str, Dict]]] = None,
databricks_instance: Optional[str] = None,
driver_pod: Optional[K8sPod] = None,
executor_pod: Optional[K8sPod] = None,
):
"""
This defines a SparkJob target. It will execute the appropriate SparkJob.
@@ -47,6 +50,8 @@
databricks_conf = {}
self._databricks_conf = databricks_conf
self._databricks_instance = databricks_instance
self._driver_pod = driver_pod
self._executor_pod = executor_pod

def with_overrides(
self,
@@ -71,6 +76,8 @@
hadoop_conf=new_hadoop_conf,
databricks_conf=new_databricks_conf,
databricks_instance=self.databricks_instance,
driver_pod=self.driver_pod,
executor_pod=self.executor_pod,
executor_path=self.executor_path,
)

@@ -139,6 +146,22 @@ def databricks_instance(self) -> str:
"""
return self._databricks_instance

@property
def driver_pod(self) -> K8sPod:
"""
Additional pod spec for the driver pod.
:rtype: K8sPod
"""
return self._driver_pod

@property
def executor_pod(self) -> K8sPod:
"""
Additional pod spec for the executor pods.
:rtype: K8sPod
"""
return self._executor_pod

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins.spark_pb2.SparkJob
@@ -167,6 +190,8 @@ def to_flyte_idl(self):
hadoopConf=self.hadoop_conf,
databricksConf=databricks_conf,
databricksInstance=self.databricks_instance,
driverPod=self.driver_pod.to_flyte_idl() if self.driver_pod else None,
executorPod=self.executor_pod.to_flyte_idl() if self.executor_pod else None,
)

@classmethod
@@ -193,4 +218,6 @@ def from_flyte_idl(cls, pb2_object):
executor_path=pb2_object.executorPath,
databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
databricks_instance=pb2_object.databricksInstance,
driver_pod=K8sPod.from_flyte_idl(pb2_object.driverPod) if pb2_object.HasField("driverPod") else None,
executor_pod=K8sPod.from_flyte_idl(pb2_object.executorPod) if pb2_object.HasField("executorPod") else None,
)
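Taken together, the model changes let a `SparkJob` carry separate pod overrides for the driver and the executors. A hedged construction sketch: all values are placeholders, and the `application_file`/`main_class` parameter names are assumed to match the existing `SparkJob` model (only `executor_path` and the new fields are visible in this diff):

```python
from flytekit.models.task import K8sPod
from flytekitplugins.spark.models import SparkJob, SparkType

# Placeholder values throughout; this only shows how driver_pod/executor_pod
# flow through to_flyte_idl(). K8sPod accepts a plain-dict pod spec.
job = SparkJob(
    spark_type=SparkType.PYTHON,
    application_file="local:///usr/local/bin/entrypoint.py",
    main_class="",
    spark_conf={"spark.executor.instances": "2"},
    hadoop_conf={},
    executor_path="/usr/bin/python3",
    driver_pod=K8sPod(pod_spec={"nodeSelector": {"pool": "drivers"}}),
    executor_pod=K8sPod(pod_spec={"nodeSelector": {"pool": "executors"}}),
)

idl = job.to_flyte_idl()  # driverPod/executorPod are set only when provided
```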
39 changes: 34 additions & 5 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -10,9 +10,11 @@
from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.pod_template import PRIMARY_CONTAINER_DEFAULT_NAME, PodTemplate
from flytekit.extend import ExecutionState, TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec
from flytekit.models.task import K8sPod

from .models import SparkJob, SparkType

@@ -26,17 +28,21 @@ class Spark(object):
Use this to configure a SparkContext for your task. Tasks marked with this will automatically execute
natively on K8s as a distributed Spark job.

Args:
spark_conf: Dictionary of spark config. The variables should match what spark expects
hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark
executor_path: Python binary executable to use for PySpark in driver and executor.
applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute.
Attributes:
spark_conf (Optional[Dict[str, str]]): Spark configuration dictionary.
hadoop_conf (Optional[Dict[str, str]]): Hadoop configuration dictionary.
executor_path (Optional[str]): Path to the Python binary for PySpark execution.
applications_path (Optional[str]): Path to the main application file.
driver_pod (Optional[PodTemplate]): The pod template for the Spark driver pod.
executor_pod (Optional[PodTemplate]): The pod template for the Spark executor pod.
"""

spark_conf: Optional[Dict[str, str]] = None
hadoop_conf: Optional[Dict[str, str]] = None
executor_path: Optional[str] = None
applications_path: Optional[str] = None
driver_pod: Optional[PodTemplate] = None
executor_pod: Optional[PodTemplate] = None

def __post_init__(self):
if self.spark_conf is None:
@@ -168,6 +174,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
executor_path=self._default_executor_path or settings.python_interpreter,
main_class="",
spark_type=SparkType.PYTHON,
driver_pod=self.to_k8s_pod(self.task_config.driver_pod),
executor_pod=self.to_k8s_pod(self.task_config.executor_pod),
)
if isinstance(self.task_config, (Databricks, DatabricksV2)):
cfg = cast(DatabricksV2, self.task_config)
@@ -176,6 +184,27 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:

return MessageToDict(job.to_flyte_idl())

def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8sPod]:
"""
Convert the given PodTemplate to a K8sPod.
"""
if pod_template is None:
return None

task_primary_container_name = (
self.pod_template.primary_container_name if self.pod_template else PRIMARY_CONTAINER_DEFAULT_NAME
)

if pod_template.primary_container_name != task_primary_container_name:
logger.warning(
"Primary container name ('%s') set in spark differs from the one in @task ('%s'). "
"The primary container name in @task will be overridden.",
pod_template.primary_container_name,
task_primary_container_name,
)

return K8sPod.from_pod_template(pod_template)

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
import pyspark as _pyspark

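End to end, users pass PodTemplates through the `Spark` task config, and `get_custom()` converts each one with `to_k8s_pod()` before embedding it in the SparkJob proto. A usage sketch under the PR's intended API; the image names and node selectors below are hypothetical:

```python
from flytekit import task
from flytekit.core.pod_template import PodTemplate
from flytekitplugins.spark import Spark
from kubernetes.client import V1Container, V1PodSpec

@task(
    task_config=Spark(
        spark_conf={"spark.executor.instances": "2"},
        driver_pod=PodTemplate(
            # Container named "primary" to match flytekit's default primary
            # container name, so the override warning in to_k8s_pod() stays quiet.
            pod_spec=V1PodSpec(
                containers=[V1Container(name="primary", image="ghcr.io/example/spark:driver")],
                node_selector={"pool": "drivers"},
            ),
        ),
        executor_pod=PodTemplate(
            pod_spec=V1PodSpec(
                containers=[V1Container(name="primary", image="ghcr.io/example/spark:executor")],
                node_selector={"pool": "executors"},
            ),
        ),
    )
)
def aggregate() -> int:
    # At serialization time, each PodTemplate above is converted to a K8sPod
    # and embedded in the SparkJob custom proto (see get_custom/to_k8s_pod).
    return 1
```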