Skip to content

Commit

Permalink
Add volume and volume mounts arguments to TrainingClient.create_job API (#2449)
Browse files Browse the repository at this point in the history

Signed-off-by: Antonin Stefanutti <[email protected]>
  • Loading branch information
astefanutti authored Feb 25, 2025
1 parent 078ec30 commit 3860d3d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
12 changes: 12 additions & 0 deletions sdk/python/kubeflow/training/api/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ def create_job(
env_vars: Optional[
Union[Dict[str, str], List[Union[models.V1EnvVar, models.V1EnvVar]]]
] = None,
volumes: Optional[List[models.V1Volume]] = None,
volume_mounts: Optional[List[models.V1VolumeMount]] = None,
):
"""Create the Training Job.
Job can be created using one of the following options:
Expand Down Expand Up @@ -418,6 +420,8 @@ def create_job(
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
or a kubernetes.client.models.V1EnvFromSource (documented here:
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
volumes: Volume(s) to be attached to the replicas.
volume_mounts: VolumeMount(s) specifying where to mount the volume(s) into the replicas.
Raises:
ValueError: Invalid input parameters.
Expand Down Expand Up @@ -448,6 +452,12 @@ def create_job(
f"Job kind must be one of these: {constants.JOB_PARAMETERS.keys()}"
)

if len(volumes or []) != len(volume_mounts or []):
raise ValueError(
"Volumes and VolumeMounts must be the same length: "
f"{len(volumes or [])} vs. {len(volume_mounts or [])}"
)

# If Training function or base image is set, configure Job template.
if job is None and (train_func is not None or base_image is not None):
# Job name must be set to configure Job template.
Expand Down Expand Up @@ -496,11 +506,13 @@ def create_job(
args=args,
resources=resources_per_worker,
env_vars=env_vars,
volume_mounts=volume_mounts,
)

# Get Pod template spec using the above container.
pod_template_spec = utils.get_pod_template_spec(
containers=[container_spec],
volumes=volumes,
)

# Configure template for different Jobs.
Expand Down
44 changes: 42 additions & 2 deletions sdk/python/kubeflow/training/api/training_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
V1ObjectMeta,
V1PodSpec,
V1PodTemplateSpec,
V1Volume,
V1VolumeMount,
)

TEST_NAME = "test"
Expand Down Expand Up @@ -142,6 +144,8 @@ def create_job(
args=None,
num_workers=2,
env_vars=None,
volumes=None,
volume_mounts=None,
):
# Handle env_vars as either a dict or a list
if env_vars:
Expand All @@ -158,6 +162,7 @@ def create_job(
command=command,
args=args,
env=env_vars,
volume_mounts=volume_mounts,
)

master = KubeflowOrgV1ReplicaSpec(
Expand All @@ -166,7 +171,10 @@ def create_job(
metadata=V1ObjectMeta(
annotations={constants.ISTIO_SIDECAR_INJECTION: "false"}
),
spec=V1PodSpec(containers=[container]),
spec=V1PodSpec(
containers=[container],
volumes=volumes,
),
),
)

Expand All @@ -180,7 +188,10 @@ def create_job(
metadata=V1ObjectMeta(
annotations={constants.ISTIO_SIDECAR_INJECTION: "false"}
),
spec=V1PodSpec(containers=[container]),
spec=V1PodSpec(
containers=[container],
volumes=volumes,
),
),
)

Expand Down Expand Up @@ -530,6 +541,35 @@ def __init__(self):
env_vars=[V1EnvVar(name="ENV_VAR", value="env_value")], num_workers=2
),
),
(
"create job with a volume and a volume mount",
{
"name": TEST_NAME,
"namespace": TEST_NAME,
"base_image": TEST_IMAGE,
"num_workers": 1,
"volumes": [V1Volume(name="vol")],
"volume_mounts": [V1VolumeMount(name="vol", mount_path="/mnt")],
},
SUCCESS,
create_job(
num_workers=1,
volumes=[V1Volume(name="vol")],
volume_mounts=[V1VolumeMount(name="vol", mount_path="/mnt")],
),
),
(
"invalid number of volume mount",
{
"name": TEST_NAME,
"namespace": TEST_NAME,
"base_image": TEST_IMAGE,
"num_workers": 1,
"volumes": [V1Volume(name="vol")],
},
ValueError,
None,
),
]

test_data_get_job_pods = [
Expand Down

0 comments on commit 3860d3d

Please sign in to comment.